Compare commits

..

1 commit

14 changed files with 949 additions and 1140 deletions

View file

@ -1,9 +1,7 @@
try:
from collections import Callable
except:
from typing import Callable # Python 3.10+
from record import *
from reader import RecordReader
VERSION = (0, 2025, 0)
VERSION = (0, 2012, 0)
RECORD_TYPES = [
'SubmitterRecord',
@ -15,156 +13,129 @@ RECORD_TYPES = [
'OptionalTotalRecord',
'StateTotalRecord',
'FinalRecord'
]
]
def test():
import record, model
from fields import ValidationError
for rname in RECORD_TYPES:
inst = record.__dict__[rname]()
try:
output_length = len(inst.output())
except ValidationError, e:
print e.msg, type(inst), inst.record_identifier
continue
print type(inst), inst.record_identifier, output_length
def get_record_types():
from . import record
def test_dump():
import record, StringIO
records = [
record.SubmitterRecord(),
record.EmployerRecord(),
record.EmployeeWageRecord(),
]
out = StringIO.StringIO()
dump(records, out)
return out
def test_record_order():
import record
records = [
record.SubmitterRecord(),
record.EmployerRecord(),
record.EmployeeWageRecord(),
record.TotalRecord(),
record.FinalRecord(),
]
validate_record_order(records)
def test_load(fp):
return load(fp)
def load(fp):
# BUILD LIST OF RECORD TYPES
import record
types = {}
for r in RECORD_TYPES:
klass = record.__dict__[r]
types[klass.record_identifier] = klass
return types
def load(fp, record_types):
distinct_identifier_lengths = set([len(record_types[k].record_identifier) for k in record_types])
assert(len(distinct_identifier_lengths) == 1)
ident_length = list(distinct_identifier_lengths)[0]
# Add aliases for the record types based on their record_identifier since that's all
# we have to work with with the e1099 data.
record_types_by_ident = {}
for k in record_types:
record_type = record_types[k]
record_identifier = record_type.record_identifier
record_types_by_ident[record_identifier] = record_type
# PARSE DATA INTO RECORDS AND YIELD THEM
while True:
record_ident = fp.read(ident_length)
if not record_ident:
break
if record_ident in record_types_by_ident:
record = record_types_by_ident[record_ident]()
while fp.tell() < fp.len:
record_ident = fp.read(2)
if record_ident in types:
record = types[record_ident]()
record.read(fp)
yield record
def loads(s, record_types=get_record_types()):
import io
fp = io.BytesIO(s)
return load(fp, record_types)
def loads(s):
import StringIO
fp = StringIO.StringIO(s)
return load(fp)
def dump(fp, records, delim=None):
if type(delim) is str:
delim = delim.encode('ascii')
def dump(records, fp):
for r in records:
fp.write(r.output())
if delim:
fp.write(delim)
def dumps(records, delim=None, skip_validation=False):
import io
fp = io.BytesIO()
if not skip_validation:
for record in records:
record.validate()
dump(fp, records, delim=delim)
def dumps(records):
import StringIO
fp = StringIO.StringIO()
dump(records, fp)
fp.seek(0)
return fp.read()
def json_dumps(records):
import json
import model
import decimal
class JSONEncoder(json.JSONEncoder):
def default(self, o):
if hasattr(o, 'toJSON') and isinstance(getattr(o, 'toJSON'), Callable):
if hasattr(o, 'toJSON') and callable(getattr(o, 'toJSON')):
return o.toJSON()
if type(o) is bytes:
return o.decode('ascii')
elif isinstance(o, decimal.Decimal):
return str(o.quantize(decimal.Decimal('0.01')))
return super(JSONEncoder, self).default(o)
return json.dumps(list(records), cls=JSONEncoder, indent=2)
return json.dumps(records, cls=JSONEncoder, indent=2)
def json_dump(fp, records):
fp.write(json_dumps(records))
def json_loads(s, record_types):
def json_loads(s, record_classes):
import json
from . import fields
import fields
import decimal
import re
if not isinstance(record_types, dict):
record_types = dict([ (x.__name__, x) for x in record_types])
if not isinstance(record_classes, dict):
record_classes = dict([ (x.__class__.__name__, x) for x in record_classes])
def object_hook(o):
if '__class__' in o:
klass = o['__class__']
if klass in record_types:
record = record_types[klass]()
record.fromJSON(o)
return record
if klass in record_classes:
return record_classes[klass]().fromJSON(o)
elif hasattr(fields, klass):
return getattr(fields, klass)().fromJSON(o)
return o
#print "OBJECTHOOK", str(o)
#return {'object_hook':str(o)}
#def default(self, o):
# return super(JSONDecoder, self).default(o)
return json.loads(s, parse_float=decimal.Decimal, object_hook=object_hook)
def json_load(fp, record_types):
return json_loads(fp.read(), record_types)
def text_dump(fp, records):
    """Write the text rendering of every record to a binary stream.

    Each record is rendered via ``output(format='text')`` and the resulting
    string is ASCII-encoded before it is written, so ``fp`` must accept bytes.
    """
    for record in records:
        rendered = record.output(format='text')
        fp.write(rendered.encode('ascii'))
def text_dumps(records):
    """Return the text rendering of ``records`` as one bytes object.

    The records are written with ``text_dump`` into an in-memory buffer and
    the buffer's full contents are returned.
    """
    import io
    buffer = io.BytesIO()
    text_dump(buffer, records)
    return buffer.getvalue()
def text_load(fp, record_classes):
    """Parse records from a binary stream in the ``---Name`` / ``field: value`` text format.

    Args:
        fp: binary file-like object; each line is decoded as ASCII.
        record_classes: either a dict mapping record names to classes, or an
            iterable of classes (keyed by their ``__name__``).

    Returns:
        A list of record instances in the order their ``---`` headers appear.

    Raises:
        KeyError: a header names a record that is not in ``record_classes``.
        ValueError: a ``field: value`` line appears before any ``---`` header.
    """
    if not isinstance(record_classes, dict):
        record_classes = {cls.__name__: cls for cls in record_classes}

    records = []
    current_record = None
    while True:
        line = fp.readline().decode('ascii')
        if not line:
            break
        if line.startswith('---'):
            # Header line: everything after the dashes names the record class.
            record_name = line.strip('-').strip()
            current_record = record_classes[record_name]()
            records.append(current_record)
        elif ':' in line:
            # Split on the FIRST colon only; values such as URLs or times may
            # contain further colons and must stay intact.
            field, value = [part.strip() for part in line.split(':', 1)]
            if current_record is None:
                raise ValueError("field line before any record header: %r" % line)
            current_record.set_field_value(field, value)
    return records
def text_loads(s, record_classes):
    """Parse records from the bytes object ``s`` by delegating to ``text_load``."""
    import io
    return text_load(io.BytesIO(s), record_classes)
# THIS WAS IN CONTROLLER, BUT UNLESS WE
# REALLY NEED A CONTROLLER CLASS, IT'S SIMPLER
# TO JUST KEEP IT IN HERE.
@ -180,15 +151,14 @@ def validate_required_records(records):
while req_types:
req = req_types[0]
if req not in types:
from .fields import ValidationError
from fields import ValidationError
raise ValidationError("Record set missing required record: %s" % req)
else:
req_types.remove(req)
def validate_record_order(records):
from . import record
from .fields import ValidationError
import record
from fields import ValidationError
# 1st record must be SubmitterRecord
if not isinstance(records[0], record.SubmitterRecord):
@ -208,10 +178,10 @@ def validate_record_order(records):
if not isinstance(records[i+1], record.EmployeeWageRecord):
raise ValidationError("All EmployerRecords must be followed by an EmployeeWageRecord")
num_ro_records = len([x for x in records if isinstance(x, record.OptionalEmployeeWageRecord)])
num_ru_records = len([x for x in records if isinstance(x, record.OptionalTotalRecord)])
num_employer_records = len([x for x in records if isinstance(x, record.EmployerRecord)])
num_total_records = len([x for x in records if isinstance(x, record.TotalRecord)])
num_ro_records = len(filter(lambda x:isinstance(x, record.OptionalEmployeeWageRecord), records))
num_ru_records = len(filter(lambda x:isinstance(x, record.OptionalTotalRecord), records))
num_employer_records = len(filter(lambda x:isinstance(x, record.EmployerRecord), records))
num_total_records = len(filter(lambda x: isinstance(x, record.TotalRecord), records))
# a TotalRecord is required for each instance of an EmployeeRecord
if num_total_records != num_employer_records:
@ -224,7 +194,7 @@ def validate_record_order(records):
num_ro_records, num_ru_records))
# FinalRecord - Must appear only once on each file.
if len([x for x in records if isinstance(x, record.FinalRecord)]) != 1:
if len(filter(lambda x:isinstance(x, record.FinalRecord), records)) != 1:
raise ValidationError("Incorrect number of FinalRecords")
def validate_records(records):
@ -237,8 +207,13 @@ def test_unique_fields():
r1.employee_first_name.value = "John Johnson"
r2 = EmployeeWageRecord()
print('r1:', r1.employee_first_name.value, r1.employee_first_name, r1.employee_first_name.creation_counter)
print('r2:', r2.employee_first_name.value, r2.employee_first_name, r2.employee_first_name.creation_counter)
print 'r1:', r1.employee_first_name.value, r1.employee_first_name, r1.employee_first_name.creation_counter
print 'r2:', r2.employee_first_name.value, r2.employee_first_name, r2.employee_first_name.creation_counter
if r1.employee_first_name.value == r2.employee_first_name.value:
raise ValidationError("Horrible problem involving shared values across records")
#def state_postal_code(state_abbr):
# import enums
# return enums.state_postal_numeric[ state_abbr.upper() ]

View file

@ -300,8 +300,7 @@ countries = (
('EH', 'Western Sahara'),
('YE', 'Yemen'),
('ZM', 'Zambia'),
('ZW', 'Zimbabwe'),
)
('ZW', 'Zimbabwe'))
employer_types = (
@ -323,7 +322,6 @@ employment_codes = (
)
tax_jurisdiction_codes = (
(' ', 'W-2'),
('V', 'Virgin Islands'),
('G', 'Guam'),
('S', 'American Samoa'),

View file

@ -1,10 +1,6 @@
import decimal, datetime
import inspect
from six import string_types
from . import enums
def is_blank_space(val):
    """Return True when ``val`` is empty or contains only whitespace."""
    return not val.strip()
import enums
class ValidationError(Exception):
def __init__(self, msg, field=None):
@ -20,26 +16,22 @@ class ValidationError(Exception):
class Field(object):
creation_counter = 0
is_read_only = False
_value = None
def __init__(self, name=None, min_length=0, max_length=0, blank=False, required=True, uppercase=True, creation_counter=None):
def __init__(self, name=None, max_length=0, required=True, uppercase=True, creation_counter=None):
self.name = name
self._value = None
self._orig_value = None
self.min_length = min_length
self.max_length = max_length
self.blank = blank
self.required = required
self.uppercase = uppercase
self.creation_counter = creation_counter or Field.creation_counter
Field.creation_counter += 1
def validate(self):
raise NotImplementedError
raise NotImplemented
def get_data(self):
raise NotImplementedError
raise NotImplemented
def __setvalue(self, value):
self._value = value
@ -84,7 +76,7 @@ class Field(object):
required=o['required'],
)
if isinstance(o['value'], str) and re.match(r'^\d*\.\d*$', o['value']):
if isinstance(o['value'], basestring) and re.match('^\d*\.\d*$', o['value']):
o['value'] = decimal.Decimal(o['value'])
self.value = o['value']
@ -98,10 +90,14 @@ class Field(object):
wrapper = textwrap.TextWrapper(replace_whitespace=False, drop_whitespace=False)
wrapper.width = 100
value = wrapper.wrap(value)
value = list([(" " * 9) + ('"' + x + '"') for x in value])
value.append(" " * 10 + ('_' * 10) * int(wrapper.width / 10))
value.append(" " * 10 + ('0123456789') * int(wrapper.width / 10))
value.append(" " * 10 + ''.join(([str(x) + (' ' * 9) for x in range(int(wrapper.width / 10))])))
#value = textwrap.wrap(value, 100)
#print value
value = list(map(lambda x:(" " * 9) + ('"' + x + '"'), value))
#value[0] = '"' + value[0] + '"'
value.append(" " * 10 + ('_' * 10) * (wrapper.width / 10))
value.append(" " * 10 + ('0123456789') * (wrapper.width / 10))
value.append(" " * 10 + ''.join((map(lambda x:str(x) + (' ' * 9), range(wrapper.width / 10 )))))
#value.append((" " * 59) + map(lambda x:("%x" % x), range(16))
start = counter['c']
counter['c'] += len(self._orig_value or self.value)
@ -119,28 +115,22 @@ class Field(object):
class TextField(Field):
def validate(self):
if self.value is None and self.required:
if self.value == None and self.required:
raise ValidationError("value required", field=self)
data = self.get_data()
if len(data) > self.max_length:
if len(self.get_data()) > self.max_length:
raise ValidationError("value is too long", field=self)
stripped_data_length = len(data.strip())
if stripped_data_length < self.min_length:
raise ValidationError("value is too short", field=self)
if stripped_data_length == 0 and (not self.blank and self.required):
raise ValidationError("field cannot be blank", field=self)
def get_data(self):
value = str(self.value or '').encode('ascii') or b''
value = self.value or ""
if self.uppercase:
value = value.upper()
return value.ljust(self.max_length)[:self.max_length]
return value.ljust(self.max_length).encode('ascii')[:self.max_length]
def __setvalue(self, value):
# NO NEWLINES
try:
value = value.replace('\n', '').replace('\r', '')
except AttributeError:
except AttributeError, e:
pass
self._value = value
@ -152,35 +142,31 @@ class TextField(Field):
class StateField(TextField):
def __init__(self, name=None, required=True, use_numeric=False, max_length=2):
super(StateField, self).__init__(name=name, max_length=max_length, required=required)
super(StateField, self).__init__(name=name, max_length=2, required=required)
self.use_numeric = use_numeric
def get_data(self):
value = str(self.value or 'XX')
value = self.value or ""
if value.strip() and self.use_numeric:
postcode = enums.state_postal_numeric[value.upper()]
postcode = str(postcode).encode('ascii')
return postcode.zfill(self.max_length)
return str(enums.state_postal_numeric[value.upper()]).zfill(self.max_length)
else:
formatted = value.encode('ascii').ljust(self.max_length)
return formatted[:self.max_length]
return value.ljust(self.max_length).encode('ascii')[:self.max_length]
def validate(self):
super(StateField, self).validate()
if self.value and self.value.upper() not in list(enums.state_postal_numeric.keys()):
if self.value and self.value.upper() not in enums.state_postal_numeric.keys():
raise ValidationError("%s is not a valid state abbreviation" % self.value, field=self)
def parse(self, s):
if s.strip() and self.use_numeric:
states = dict([(v, k) for (k, v) in list(enums.state_postal_numeric.items())])
states = dict( [(v,k) for (k,v) in enums.state_postal_numeric.items()] )
self.value = states[int(s)]
else:
self.value = s
class EmailField(TextField):
def __init__(self, name=None, required=True, max_length=None):
super(EmailField, self).__init__(name=name, max_length=max_length,
return super(EmailField, self).__init__(name=name, max_length=max_length,
required=required, uppercase=False)
class IntegerField(TextField):
@ -192,58 +178,37 @@ class IntegerField(TextField):
except ValueError:
raise ValidationError("field contains non-numeric characters", field=self)
def get_data(self):
value = str(self.value).encode('ascii') if self.value else b''
return value.zfill(self.max_length)[:self.max_length]
value = self.value or ""
return str(value).zfill(self.max_length)[:self.max_length]
def parse(self, s):
if not is_blank_space(s):
self.value = int(s)
else:
self.value = 0
class StaticField(TextField):
def __init__(self, name=None, required=True, value=None, uppercase=False):
super(StaticField, self).__init__(name=name,
required=required,
max_length=len(value),
uppercase=uppercase)
self._static_value = value
def __init__(self, name=None, required=True, value=None):
super(StaticField, self).__init__(name=name, required=required,
max_length=len(value))
self._value = value
def parse(self, s):
pass
class BlankField(TextField):
is_read_only = True
def __init__(self, name=None, max_length=0, required=False):
super(BlankField, self).__init__(name=name, max_length=max_length, required=required, uppercase=False)
super(TextField, self).__init__(name=name, max_length=max_length, required=required, uppercase=False)
def get_data(self):
return b' ' * self.max_length
return " " * self.max_length
def parse(self, s):
pass
def validate(self):
if len(self.get_data()) != self.max_length:
raise ValidationError("blank field did not match expected length", field=self)
class ZeroField(BlankField):
is_read_only = True
def get_data(self):
return b'0' * self.max_length
class CRLFField(TextField):
is_read_only = True
def __init__(self, name=None, required=False):
super(CRLFField, self).__init__(name=name, max_length=2, required=required, uppercase=False)
super(TextField, self).__init__(name=name, max_length=2, required=required, uppercase=False)
def __setvalue(self, value):
self._value = value
@ -254,12 +219,11 @@ class CRLFField(TextField):
value = property(__getvalue, __setvalue)
def get_data(self):
return b'\r\n'
return '\r\n'
def parse(self, s):
self.value = s
class BooleanField(Field):
def __init__(self, name=None, required=True, value=None):
super(BooleanField, self).__init__(name=name, required=required, max_length=1)
@ -269,7 +233,7 @@ class BooleanField(Field):
pass
def get_data(self):
return b'1' if self._value else b'0'
return '1' if self._value else '0'
def parse(self, s):
self.value = (s == '1')
@ -286,43 +250,26 @@ class MoneyField(Field):
raise ValidationError("value is too long", field=self)
def get_data(self):
cents = int((self.value or 0) * 100)
formatted = str(cents).encode('ascii').zfill(self.max_length)
return formatted[:self.max_length]
return str(int((self.value or 0)*100)).encode('ascii').zfill(self.max_length)[:self.max_length]
def parse(self, s):
if not is_blank_space(s):
self.value = decimal.Decimal(s) * decimal.Decimal('0.01')
else:
self.value = decimal.Decimal(0.0)
def __setvalue(self, value):
new_value = value
if isinstance(new_value, string_types):
new_value = decimal.Decimal(new_value or '0')
if '.' not in value: # must be cents?
new_value *= decimal.Decimal('100.')
self._value = new_value
def __getvalue(self):
return self._value
value = property(__getvalue, __setvalue)
class DateField(TextField):
def __init__(self, name=None, required=True, value=None):
super(DateField, self).__init__(name=name, required=required, max_length=8)
super(TextField, self).__init__(name=name, required=required, max_length=8)
if value:
self.value = value
def get_data(self):
if self._value:
return self._value.strftime('%m%d%Y').encode('ascii')
return b'0' * self.max_length
return self._value.strftime('%m%d%Y')
return '0' * self.max_length
def parse(self, s):
if int(s) > 0:
self.value = datetime.date(*[int(x) for x in (s[4:8], s[0:2], s[2:4])])
self.value = datetime.date(*[int(x) for x in s[4:8], s[0:2], s[2:4]])
else:
self.value = None
@ -330,7 +277,7 @@ class DateField(TextField):
if isinstance(value, datetime.date):
self._value = value
elif value:
self._value = datetime.date(*[int(x) for x in (value[4:8], value[0:2], value[2:4])])
self._value = datetime.date(*[int(x) for x in value[4:8], value[0:2], value[2:4]])
else:
self._value = None
@ -342,18 +289,19 @@ class DateField(TextField):
class MonthYearField(TextField):
def __init__(self, name=None, required=True, value=None):
super(MonthYearField, self).__init__(name=name, required=required, max_length=6)
super(TextField, self).__init__(name=name, required=required, max_length=6)
if value:
self.value = value
def get_data(self):
if self._value:
return str(self._value.strftime('%m%Y').encode('ascii'))
return b'0' * self.max_length
return self._value.strftime("%m%Y")
return '0' * self.max_length
def parse(self, s):
if int(s) > 0:
self.value = datetime.date(*[int(x) for x in (s[2:6], s[0:2], 1)])
self.value = datetime.date(*[int(x) for x in s[2:6], s[0:2], 1])
else:
self.value = None
@ -361,7 +309,7 @@ class MonthYearField(TextField):
if isinstance(value, datetime.date):
self._value = value
elif value:
self._value = datetime.date(*[int(x) for x in (value[2:6], value[0:2], 1)])
self._value = datetime.date(*[int(x) for x in value[2:6], value[0:2], 1])
else:
self._value = None
@ -369,3 +317,4 @@ class MonthYearField(TextField):
return self._value
value = property(__getvalue, __setvalue)

View file

@ -1,19 +1,15 @@
from .fields import Field, TextField, ValidationError
from fields import Field, TextField, ValidationError
import copy
import collections
import pdb
class Model(object):
record_length = -1
record_identifier = ' '
required = False
target_size = 512
def __init__(self):
if self.record_length == -1:
raise ValueError(self.record_length)
for (key, value) in list(self.__class__.__dict__.items()):
for (key, value) in self.__class__.__dict__.items():
if isinstance(value, Field):
# GRAB THE FIELD INSTANCE FROM THE CLASS DEFINITION
# AND MAKE A LOCAL COPY FOR THIS RECORD'S INSTANCE,
@ -23,31 +19,21 @@ class Model(object):
if not src_field.name:
setattr(src_field, 'name', key)
setattr(src_field, 'parent_name', self.__class__.__name__)
new_field_instance = copy.copy(src_field)
new_field_instance._orig_value = None
new_field_instance._value = new_field_instance.value
self.__dict__[key] = new_field_instance
self.__dict__[key] = copy.copy(src_field)
def __setattr__(self, key, value):
if hasattr(self, key) and isinstance(getattr(self, key), Field):
self.set_field_value(key, value)
getattr(self, key).value = value
else:
# MAYBE THIS SHOULD RAISE A PROPERTY ERROR?
self.__dict__[key] = value
def set_field_value(self, field_name, value):
getattr(self, field_name).value = value
def get_fields(self):
identifier = TextField(
"record_identifier",
max_length = len(self.record_identifier),
blank = len(self.record_identifier) == 0,
creation_counter=-1)
identifier = TextField("record_identifier", max_length=len(self.record_identifier), creation_counter=-1)
identifier.value = self.record_identifier
fields = [identifier]
for key in list(self.__class__.__dict__.keys()):
for key in self.__class__.__dict__.keys():
attr = getattr(self, key)
if isinstance(attr, Field):
fields.append(attr)
@ -55,7 +41,7 @@ class Model(object):
def get_sorted_fields(self):
fields = self.get_fields()
fields.sort(key=lambda x: x.creation_counter)
fields.sort(key=lambda x:x.creation_counter)
return fields
def validate(self):
@ -64,33 +50,27 @@ class Model(object):
try:
custom_validator = getattr(self, 'validate_' + f.name)
except AttributeError:
except AttributeError, e:
continue
if isinstance(custom_validator, collections.Callable):
if callable(custom_validator):
custom_validator(f)
def output(self, format='binary'):
if format == 'text':
return self.output_text()
return self.output_efile()
def output(self):
result = ''.join([field.get_data() for field in self.get_sorted_fields()])
def output_efile(self):
result = b''.join([field.get_data() for field in self.get_sorted_fields()])
if self.record_length < 0 or len(result) != self.record_length:
if hasattr(self, 'record_length') and len(result) != self.record_length:
raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.record_length, len(result)))
#result = ''.join([self.record_identifier] + [field.get_data() for field in self.get_sorted_fields()])
#if len(result) != self.target_size:
# raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.target_size, len(result)))
return result
def output_text(self):
fields = self.get_sorted_fields()[1:] # skip record identifier
fields = [field for field in fields if not field.is_read_only]
header = ''.join(['---', self.__class__.__name__, '\n'])
return header + '\n'.join([f.name + ': ' + (str(f.value) if f.value else '') for f in fields]) + '\n\n'
def read(self, fp):
# Skip the first record, since that's an identifier
for field in self.get_sorted_fields()[1:]:
field.read(fp)
def toJSON(self):
return {
'__class__': self.__class__.__name__,
@ -100,17 +80,19 @@ class Model(object):
def fromJSON(self, o):
fields = o['fields']
identifier, fields = fields[0], fields[1:]
assert(identifier.value == self.record_identifier)
for f in fields:
target = self.__dict__[f.name]
if (target.required != f.required
or target.max_length != f.max_length):
print("Warning: value mismatch on import")
if (target.required != f.required or
target.max_length != f.max_length):
print "Warning: value mismatch on import"
target.value = f.value
target._value = f._value
#print (self.__dict__[f.name].name == f.name)
#self.__dict__[f.name].name == f.name
#self.__dict__[f.name].max_length == f.max_length
return self

View file

@ -1,8 +1,8 @@
#!/usr/bin/env python
import re
class ClassEntryCommentSequence(object):
re_rangecomment = re.compile(r'#\s+(\d+)\-?(\d*)$')
re_rangecomment = re.compile('#\s+(\d+)\-?(\d*)$')
def __init__(self, classname, line):
self.classname = classname,
@ -22,15 +22,13 @@ class ClassEntryCommentSequence(object):
if (i + 1) != a:
line_number = self.line + line_no
print(("ERROR\tline:%d\tnear:%s\texpected:%d\tsaw:%d" % (
line_number, line.split(' ')[0].strip(), i+1, a)))
print("ERROR\tline:%d\tnear:%s\texpected:%d\tsaw:%d" % (line_number, line.split(' ')[0].strip(), i+1, a))
i = int(b) if b else a
class ModelDefParser(object):
re_triplequote = re.compile('"""')
re_whitespace = re.compile(r"^(\s*)[^\s]+")
re_whitespace = re.compile("^(\s*)[^\s]+")
re_classdef = re.compile(r"^\s*class\s(.*)\((.*)\):\s*$")
def __init__(self, infile, entryclass):
@ -72,7 +70,7 @@ class ModelDefParser(object):
classmatch = self.re_classdef.match(line)
if classmatch:
classname, _subclass = classmatch.groups()
classname, subclass = classmatch.groups()
self.beginclass(classname, self.line)
continue
@ -84,3 +82,5 @@ class ModelDefParser(object):
if self.current_class:
whitespace = match_whitespace
self.current_class.add_line(line)

View file

@ -1,3 +1,5 @@
#!/usr/bin/python
# coding=UTF-8
"""
Parser utility to read data from Publication 1220 and
convert it into python classes.
@ -5,7 +7,6 @@ convert it into python classes.
"""
import re
import hashlib
from functools import reduce
class SimpleDefParser(object):
def __init__(self):
@ -33,7 +34,7 @@ class SimpleDefParser(object):
item = item.upper()
if '-' in item:
parts = [self._intify(x) for x in item.split('-')]
parts = map(lambda x:self._intify(x), item.split('-'))
item = reduce(lambda x,y: y-x, parts)
else:
item = self._intify(item)
@ -55,7 +56,7 @@ class LengthExpression(object):
self.exp_cache = {}
def __call__(self, value, exps):
return len(exps) == sum([self.check(value, x) for x in exps])
return len(exps) == sum(map(lambda x: self.check(value, x), exps))
def compile_exp(self, exp):
op, val = self.REG.match(exp).groups()
@ -97,7 +98,7 @@ class RangeToken(BaseToken):
def value(self):
if '-' not in self._value:
return 1
return reduce(lambda x,y: y-x, list(map(int, self._value.split('-'))))+1
return reduce(lambda x,y: y-x, map(int, self._value.split('-')))+1
@property
def end_position(self):
@ -109,7 +110,7 @@ class RangeToken(BaseToken):
class NumericToken(BaseToken):
regexp = re.compile(r'^(\d+)$')
regexp = re.compile('^(\d+)$')
@property
def value(self):
@ -117,7 +118,7 @@ class NumericToken(BaseToken):
class RecordBuilder(object):
from . import fields
import fields
entry_max_length = 4
@ -144,7 +145,8 @@ class RecordBuilder(object):
(re.compile(r'zero\-filled', re.IGNORECASE), +1),
(re.compile(r'leading zeroes', re.IGNORECASE), +1),
(re.compile(r'left\-justif', re.IGNORECASE), -1),
(re.compile(r'left-\justif', re.IGNORECASE), -1),
],
},
}),
@ -199,15 +201,15 @@ class RecordBuilder(object):
try:
f_length = int(f_length)
except ValueError as e:
except ValueError, e:
# bad result, skip
continue
try:
assert f_length == RangeToken(f_range).value
except AssertionError as e:
except AssertionError, e:
continue
except ValueError as e:
except ValueError, e:
# bad result, skip
continue
@ -221,7 +223,7 @@ class RecordBuilder(object):
else:
required = None
f_name = '_'.join([x.lower() for x in name_parts])
f_name = u'_'.join(map(lambda x:x.lower(), name_parts))
f_name = f_name.replace('&', 'and')
f_name = re.sub(r'[^\w]','', f_name)
@ -238,7 +240,7 @@ class RecordBuilder(object):
lengthexp = LengthExpression()
for entry in entries:
matches = dict([(x[0],0) for x in self.FIELD_TYPES])
matches = dict(map(lambda x:(x[0],0), self.FIELD_TYPES))
for (classtype, criteria) in self.FIELD_TYPES:
if 'length' in criteria:
@ -246,7 +248,7 @@ class RecordBuilder(object):
continue
if 'regexp' in criteria:
for crit_key, crit_values in list(criteria['regexp'].items()):
for crit_key, crit_values in criteria['regexp'].items():
for (crit_re, score) in crit_values:
matches[classtype] += score if crit_re.search(entry[crit_key]) else 0
@ -254,7 +256,7 @@ class RecordBuilder(object):
matches = list(matches.items())
matches.sort(key=lambda x:x[1])
matches_found = True if sum([x[1] for x in matches]) > 0 else False
matches_found = True if sum(map(lambda x:x[1], matches)) > 0 else False
entry['guessed_type'] = matches[-1][0] if matches_found else self.fields.TextField
yield entry
@ -269,7 +271,7 @@ class RecordBuilder(object):
if entry['name'] == 'blank':
blank_id = hashlib.new('md5')
blank_id.update(entry['range'].encode())
add( ('blank_%s' % blank_id.hexdigest()[:8]).ljust(40) )
add( (u'blank_%s' % blank_id.hexdigest()[:8]).ljust(40) )
else:
add(entry['name'].ljust(40))
@ -384,7 +386,7 @@ class PastedDefParser(RecordBuilder):
for g in groups:
assert g['byterange'].value == g['length'].value
desc = ' '.join([str(x.value) for x in g['desc']])
desc = u' '.join(map(lambda x:unicode(x.value), g['desc']))
if g['name'][-1].value.lower() == '(optional)':
g['name'] = g['name'][0:-1]
@ -394,7 +396,7 @@ class PastedDefParser(RecordBuilder):
else:
required = None
name = '_'.join([x.value.lower() for x in g['name']])
name = u'_'.join(map(lambda x:x.value.lower(), g['name']))
name = re.sub(r'[^\w]','', name)
yield({

View file

@ -3,102 +3,314 @@
import subprocess
import re
import itertools
import fitz
import pdb
""" pdftotext -layout -nopgbrk p1220.pdf - """
def strip_values(items):
    """Normalise table-cell strings: drop empty cells, trim, and flatten newlines.

    Falsy items (``None``, ``''``) are removed. Every remaining string has its
    surrounding whitespace stripped and embedded newlines replaced by spaces.

    Punctuation is deliberately left in place: the rows returned here are
    later matched against field ranges such as ``1-4``, which need the hyphen.
    The previous implementation called ``pattern.sub(x, '')``, i.e. it used the
    cell text as a *replacement template* on an empty string. That never
    removed anything, but raised ``re.error`` for any cell containing a
    backslash. This version yields the same result for backslash-free text
    without that failure mode.
    """
    return [x.strip().replace('\n', ' ') for x in items if x]
class PDFRecordFinder(object):
field_range_expr = re.compile(r'^(\d+)[-]?(\d*)$')
def __init__(self, src, heading_exp=None):
if not heading_exp:
heading_exp = re.compile('(\s+Record Name: (.*))|Record\ Layout')
def __init__(self, src):
self.document = fitz.open(src)
field_heading_exp = re.compile('^Field.*Field.*Length.*Description')
def find_record_table_ranges(self):
    """Locate every "Record Name:" heading in the document.

    Returns a list of ``(record_name, {'page': index, 'y': offset})`` tuples,
    one per heading hit, in page order.
    """
    found = []
    for page_number, page in enumerate(self.document):
        for rect in page.search_for("Record Name:"):
            # Re-use the hit rectangle: start where the label ends and
            # stretch to the right-hand page edge to capture the name text.
            rect.x0 = rect.x1
            rect.x1 = page.bound().x1
            raw_name = page.get_textbox(rect)
            record_name = re.sub(r'[^\w\s\n]*', '', raw_name).strip()
            location = {
                'page': page_number,
                # Back up slightly so the heading itself is included.
                'y': rect.y1 - 5,
            }
            found.append((record_name, location))
    return found
def find_records(self):
# Generator: for each heading returned by find_record_table_ranges(), collect
# the rows of every 4-column table found on that record's pages and yield
# (record_name, rows), where rows is what filter_nonconsecutive_rows() keeps.
record_ranges = self.find_record_table_ranges()
for record_index, (record_name, record_details) in enumerate(record_ranges):
current_rows = []
next_index = record_index+1
# The last record has no successor, so a stand-in end page is used instead.
(_, next_record_details) = record_ranges[next_index] if next_index < len(record_ranges) else (None, {'page': self.document.page_count-1})
# NOTE(review): range() excludes the end page, so the page on which the next
# record starts (and, for the last record, the final page) is never scanned;
# two records starting on the same page give an empty range. Confirm this is
# intended.
for page_number in range(record_details['page'], next_record_details['page']):
page = self.document[page_number]
table_search_rect = page.bound()
# On the record's first page, only search below the heading.
if page_number == record_details['page']:
table_search_rect.y0 = record_details['y']
tables = page.find_tables(
clip = table_search_rect,
min_words_horizontal = 1,
min_words_vertical = 1,
horizontal_strategy = "lines_strict",
intersection_tolerance = 1,
)
for table in tables:
# Only 4-column tables are read; all others are ignored.
if table.col_count == 4:
table = table.extract()
# Parse field position: a single cell sometimes holds several values
# stacked on separate lines, in which case the row is expanded by split_row().
# NOTE(review): row[0] is assumed to be a string here; a None cell would
# raise AttributeError on .strip() - verify against the table extractor.
for row in table:
first_column_lines = row[0].strip().split('\n')
if len(first_column_lines) > 1:
for sub_row in self.split_row(row):
current_rows.append(strip_values(sub_row))
else:
current_rows.append(strip_values(row))
consecutive_rows = self.filter_nonconsecutive_rows(current_rows)
yield(record_name, consecutive_rows)
def split_row(self, row):
    """Expand a table row whose cells stack several fields on separate lines.

    The first three cells are split on newlines and zipped together (short
    columns are padded with ``None``); the fourth cell is treated as a shared
    description appended to every resulting sub-row. Sub-rows lacking a length
    value are passed through ``infer_field_length``.

    Returns an empty list when the second cell is empty or ``None``.
    """
    if not row[1]:
        return []

    # Extracted cells may be None; treat those as empty text instead of
    # crashing on .strip().
    column_lines = [(cell or '').strip().split('\n') for cell in row[:3]]

    # strip_values() drops empty cells, so an empty description yields [].
    cleaned_description = strip_values([row[3]])
    description = cleaned_description[0] if cleaned_description else ''

    rows = []
    for sub_row in itertools.zip_longest(*column_lines, fillvalue=None):
        if len(sub_row) < 3 or not sub_row[2]:
            sub_row = self.infer_field_length(sub_row)
        rows.append([*sub_row, description])
    return rows
def infer_field_length(self, row):
    """Fill in a missing length column from the row's position range.

    ``row[0]`` is matched against ``field_range_expr``. A range ``A-B`` gives
    a length of ``B - A + 1``; a single position gives ``'1'``. The result is
    a tuple of the first two cells followed by the length as a string.

    The row is returned unchanged when it is empty, when the position cell is
    missing (``None``, e.g. padding added while splitting stacked cells), or
    when the position does not look like a field range.
    """
    position = row[0] if row else None
    # Guard before matching: re.match() raises TypeError on None.
    matches = self.field_range_expr.match(position) if position else None
    if not matches:
        return row
    start, end = ([int(x) for x in matches.groups() if x] + [None])[:2]
    length = str(end - start + 1) if end and start else '1'
    return (*row[:2], length)
def filter_nonconsecutive_rows(self, rows):
    """Keep only rows whose field positions continue the running sequence.

    A row is kept when its first cell matches ``field_range_expr`` and its
    start position is exactly one past the end of the previously kept row
    (the sequence starts at 1). Everything else is dropped: empty rows, rows
    whose first cell is not a position, and rows that are out of sequence.
    """
    consecutive_rows = []
    last_position = 0
    for row in rows:
        # Rows can arrive empty once blank cells have been stripped out;
        # indexing them would raise IndexError.
        if not row:
            continue
        matches = self.field_range_expr.match(row[0])
        if not matches:
            continue
        start, end = ([int(x) for x in matches.groups() if x] + [None])[:2]
        if start != last_position + 1:
            continue
        last_position = end if end else start
        consecutive_rows.append(row)
    return consecutive_rows
opts = ["pdftotext", "-layout", "-nopgbrk", "-eol", "unix", src, '-']
pdftext = subprocess.check_output(opts)
self.textrows = pdftext.split('\n')
self.heading_exp = heading_exp
self.field_heading_exp = field_heading_exp
def records(self):
    """Return the records found in the document.

    Delegates to ``self.find_records()`` and hands its result back
    unchanged.
    """
    # The statements that used to follow this return could never run, and
    # the ``yield`` among them turned the method into a generator, so the
    # value returned here never reached the caller.
    return self.find_records()
def locate_heading_rows_by_field(self):
    """Locate each record's field table in ``self.textrows``.

    Every row matching ``self.field_heading_exp`` starts a table. The
    record name is rebuilt by walking backwards from that row until a
    blank line or a row matching ``self.heading_exp`` is found. Rows
    containing "Record Layout" end a table early.

    Returns a list of ``(heading_row_index, end_row_index, name)`` tuples.
    """
    results = []
    record_break = []
    line_is_whitespace_exp = re.compile(r'^(\s*)$')
    record_begin_exp = self.heading_exp

    for (i, row) in enumerate(self.textrows):
        match = self.field_heading_exp.match(row)
        if match:
            # work backwards until we think the header is fully copied
            position = i - 1
            complete = False
            header = None
            while not complete:
                line_is_whitespace = bool(
                    line_is_whitespace_exp.match(self.textrows[position]))
                is_record_begin = record_begin_exp.search(self.textrows[position])
                if is_record_begin or line_is_whitespace:
                    header = self.textrows[position-1:i]
                    complete = True
                position -= 1

            name = ''.join(header).strip()
            # Only byte strings need decoding; text is used as it is.
            if isinstance(name, bytes):
                name = name.decode('ascii', 'ignore')
            print (name, position)
            results.append((i, name, position))
        else:
            # See if this row forces us to break from field reading.
            if re.search(r'Record\ Layout', row):
                record_break.append(i)

    merged = []
    for (a, b) in zip(results, results[1:] + [(len(self.textrows), None)]):
        end_pos = None
        # Discard breaks that lie before the current heading.
        while record_break and record_break[0] < a[0]:
            record_break = record_break[1:]
        # record_break may be exhausted here; indexing it unguarded raised
        # IndexError for any heading that follows the last break.
        if record_break and record_break[0] < b[0]-1:
            end_pos = record_break[0]
            record_break = record_break[1:]
        else:
            end_pos = b[0]-1
        merged.append((a[0], end_pos-1, a[1]))
    return merged
"""
def locate_heading_rows(self):
results = []
for (i, row) in enumerate(self.textrows):
match = self.heading_exp.match(row)
if match:
results.append((i, ''.join(match.groups())))
merged = []
for (a, b) in zip(results, results[1:] + [(len(self.textrows),None)]):
merged.append( (a[0], b[0]-1, a[1]) )
return merged
def locate_layout_block_rows(self):
# Search for rows that contain "Record Layout", as these are not fields
# we are interested in because they contain the crazy blocks of field definitions
# and not the nice 4-column ones that we're looking for.
results = []
for (i, row) in enumerate(self.textrows):
match = re.match("Record Layout", row)
"""
def find_fields(self, row_iter):
    """Group text rows into fields, yielding one ColumnCollector per field.

    Each row is split with ``extract_columns_from_row`` and fed to the
    current collector. A new collector is started when the collector
    raises ``IsNextField``, or when a row without columns is longer than
    the offset of the collector's last column. Iteration stops silently
    when the collector raises ``UnknownColumn``.
    """
    cc = ColumnCollector()

    for r in row_iter:
        # Byte strings are decoded; rows that are already text pass through.
        row = r.decode('UTF-8') if isinstance(r, bytes) else r
        row_columns = self.extract_columns_from_row(row)

        if not row_columns:
            if cc.data and len(cc.data) > 1 and len(row.strip()) > list(cc.data)[-1]:
                yield cc
                cc = ColumnCollector()
            else:
                cc.empty_row()
            continue

        try:
            cc.add(row_columns)
        except IsNextField:
            yield cc
            cc = ColumnCollector()
            cc.add(row_columns)
        except UnknownColumn:
            # A bare return ends the generator; raising StopIteration inside
            # a generator becomes RuntimeError under PEP 479.
            return

    yield cc
def extract_columns_from_row(self, row):
    """Split a text row into columns separated by runs of 2+ whitespace.

    Returns ``None`` when the row has no such run (it is probably not a
    table row). Otherwise returns a list of ``(start_offset, text)``
    pairs, where ``text`` is the column content encoded as ASCII with
    unencodable characters dropped.
    """
    re_multiwhite = re.compile(r'\s{2,}')

    # IF LINE DOESN'T CONTAIN MULTIPLE WHITESPACES, IT'S LIKELY NOT A TABLE
    if not re_multiwhite.search(row):
        return None

    # Alternating boundaries: column start, column end, column start, ...
    white_ranges = [0]
    pos = 0
    while pos < len(row):
        # Search from ``pos`` in place instead of copying the row tail.
        match = re_multiwhite.search(row, pos)
        if match:
            white_ranges.append(match.start())
            white_ranges.append(match.end())
            pos = match.end()
        else:
            white_ranges.append(len(row))
            pos = len(row)

    row_result = []
    # Pair the boundaries up; an unpaired trailing start is ignored.
    for start, end in zip(white_ranges[0::2], white_ranges[1::2]):
        if start != end:
            row_result.append(
                (start, row[start:end].encode('ascii', 'ignore'))
            )
    return row_result
class UnknownColumn(Exception):
    """Raised by ColumnCollector.merge_column for an offset it does not track."""
class IsNextField(Exception):
    """Raised by ColumnCollector.add when a row looks like the next field."""
class ColumnCollector(object):
    """Accumulate the column fragments of one table field across text rows.

    ``data`` maps a column's start offset to its collected text, and
    ``column_widths`` maps the same offsets to the furthest end offset
    seen so far. Rows are ``(offset, text)`` pairs.
    """

    def __init__(self, initial=None):
        # ``initial`` is accepted for interface compatibility and not used.
        self.data = None
        self.column_widths = None
        self.max_data_length = 0
        # Initialised here so is_next_field never reads a missing attribute.
        self.last_data_length = 0
        self.adjust_pad = 3
        self.empty_rows = 0

    def __repr__(self):
        values = self.data.values() if self.data else ''
        return "<%s: %s>" % (
            self.__class__.__name__,
            [x if len(x) < 25 else x[:25] + '..' for x in values])

    def add(self, data):
        """Add one row of ``(offset, text)`` pairs to the collector.

        The first row defines the columns. Later rows are snapped onto
        the known columns and merged into them. Raises ``IsNextField``
        when the row appears to start a new field, and ``UnknownColumn``
        (from ``merge_column``) for an offset that is not tracked.
        """
        if not self.data:
            self.data = dict(data)
        else:
            data = self.adjust_columns(data)
            if self.is_next_field(data):
                raise IsNextField()
            for col_id, value in data:
                self.merge_column(col_id, value)

        self.update_column_widths(data)

    def empty_row(self):
        """Count a row that contained no columns."""
        self.empty_rows += 1

    def update_column_widths(self, data):
        """Record the size of ``data`` and widen the known column extents."""
        self.last_data_length = len(data)
        self.max_data_length = max(self.max_data_length, len(data))

        if not self.column_widths:
            self.column_widths = dict(
                (column, column + len(value)) for column, value in data)
        else:
            for col_id, value in data:
                try:
                    self.column_widths[col_id] = max(
                        self.column_widths[col_id],
                        col_id + len(value.strip()))
                except KeyError:
                    # Offsets that are not tracked are left alone.
                    pass

    def add_old(self, data):
        """Earlier version of ``add`` without column adjustment.

        NOTE(review): this path never calls update_column_widths, which
        is_next_field depends on - confirm whether it is still used.
        """
        if not self.data:
            self.data = dict(data)
        else:
            if self.is_next_field(data):
                raise IsNextField()
            for col_id, value in data:
                self.merge_column(col_id, value)

    def adjust_columns(self, data):
        """Snap each pair of ``data`` onto a known column offset.

        A pair whose offset is already a column keeps it. Otherwise its
        text is attached to every column whose extent, widened by
        ``adjust_pad`` on both sides, contains the offset. Pairs that fit
        no column are dropped. Returns a list of ``(offset, text)``.
        """
        adjusted_data = {}
        for col_id, value in data:
            if col_id in self.data:
                adjusted_data[col_id] = value.strip()
            else:
                for col_start, col_end in self.column_widths.items():
                    if (col_start - self.adjust_pad) <= col_id and (col_end + self.adjust_pad) >= col_id:
                        if col_start in adjusted_data:
                            adjusted_data[col_start] += ' ' + value.strip()
                        else:
                            adjusted_data[col_start] = value.strip()
        return list(adjusted_data.items())

    def merge_column(self, col_id, value):
        """Append ``value`` to column ``col_id``; raise UnknownColumn if absent."""
        if col_id in self.data:
            self.data[col_id] += ' ' + value.strip()
        else:
            # try adding a wiggle room value?
            # FIXME:
            # Sometimes description columns contain column-like
            # layouts, and this causes the ColumnCollector to become
            # confused. Perhaps we could check to see if a column occurs
            # after the maximum column, and assume it's part of the
            # max column?
            raise UnknownColumn

    def is_next_field(self, data):
        """Return True when ``data`` looks like the start of the next field.

        If the first key value contains a string and we already have some
        data in the record, then this row is probably the beginning of
        the next field. If the length of the value in column_id is less
        than the position of the next column_id, then this is probably a
        continuation.
        """
        if self.data and data:
            keys = sorted(self.column_widths) + [None]

            # More columns than the previous row: treat as a new field.
            if self.last_data_length < len(data):
                return True

            first_key, first_value = list(dict(data).items())[0]
            if list(self.data)[0] == first_key:
                position = keys.index(first_key)
                max_length = keys[position + 1]
                if max_length:
                    return len(first_value) > max_length or len(data) == self.max_data_length

        return False

    @property
    def tuple(self):
        """The collected column texts ordered by column offset."""
        if self.data:
            return tuple(self.data[k] for k in sorted(self.data))
        return ()

View file

@ -1,13 +1,11 @@
from . import model
from .fields import *
from . import enums
import model
from fields import *
import enums
__all__ = RECORD_TYPES = ['SubmitterRecord', 'EmployerRecord',
'EmployeeWageRecord', 'OptionalEmployeeWageRecord',
'TotalRecord', 'OptionalTotalRecord',
'StateTotalRecord', 'FinalRecord', 'StateWageRecord',
'StateTotalRecordIA',
]
'StateTotalRecord', 'FinalRecord', 'StateWageRecord']
class EFW2Record(model.Model):
record_length = 512
@ -105,8 +103,8 @@ class EmployerRecord(EFW2Record):
zipcode_ext = TextField(max_length=4, required=False)
kind_of_employer = TextField(max_length=1)
blank1 = BlankField(max_length=4)
foreign_state_province = TextField(max_length=23, required=False)
foreign_postal_code = TextField(max_length=15, required=False)
foreign_state_province = TextField(max_length=23)
foreign_postal_code = TextField(max_length=15)
country_code = TextField(max_length=2, required=False)
employment_code = TextField(max_length=1)
tax_jurisdiction_code = TextField(max_length=1, required=False)
@ -150,7 +148,7 @@ class EmployeeWageRecord(EFW2Record):
ssn = IntegerField(max_length=9, required=False)
employee_first_name = TextField(max_length=15)
employee_middle_name = TextField(max_length=15, required=False)
employee_middle_name = TextField(max_length=15)
employee_last_name = TextField(max_length=20)
employee_suffix = TextField(max_length=4, required=False)
location_address = TextField(max_length=22)
@ -163,7 +161,7 @@ class EmployeeWageRecord(EFW2Record):
blank1 = BlankField(max_length=5)
foreign_state = TextField(max_length=23, required=False)
foreign_postal_code = TextField(max_length=15, required=False)
country = TextField(max_length=2, required=True, blank=True)
country = TextField(max_length=2)
wages_tips = MoneyField(max_length=11)
federal_income_tax_withheld = MoneyField(max_length=11)
social_security_wages = MoneyField(max_length=11)
@ -199,10 +197,8 @@ class EmployeeWageRecord(EFW2Record):
blank6 = BlankField(max_length=23)
def validate_ssn(self, f):
    """Reject SSN values that begin with a disallowed prefix.

    Raises ValidationError when ``str(f.value)`` starts with ``666`` or
    with ``9``; otherwise returns None.
    """
    ssn = str(f.value)
    if ssn.startswith('666'):
        raise ValidationError("ssn cannot start with 666", field=f)
    if ssn.startswith('9'):
        raise ValidationError("ssn cannot start with 9", field=f)
    # The former ``startswith('666', '9')`` check passed '9' as the start
    # index and raised TypeError for every value that reached it.
@ -245,7 +241,7 @@ class StateWageRecord(EFW2Record):
taxing_entity_code = TextField(max_length=5, required=False)
ssn = IntegerField(max_length=9, required=False)
employee_first_name = TextField(max_length=15)
employee_middle_name = TextField(max_length=15, required=False)
employee_middle_name = TextField(max_length=15)
employee_last_name = TextField(max_length=20)
employee_suffix = TextField(max_length=4, required=False)
location_address = TextField(max_length=22)
@ -259,20 +255,20 @@ class StateWageRecord(EFW2Record):
foreign_postal_code = TextField(max_length=15, required=False)
country_code = TextField(max_length=2, required=False)
optional_code = TextField(max_length=2, required=False)
reporting_period = MonthYearField(required=False)
reporting_period = MonthYearField()
quarterly_unemp_ins_wages = MoneyField(max_length=11)
quarterly_unemp_ins_taxable_wages = MoneyField(max_length=11)
number_of_weeks_worked = IntegerField(max_length=2, required=False)
number_of_weeks_worked = IntegerField(max_length=2)
date_first_employed = DateField(required=False)
date_of_separation = DateField(required=False)
blank2 = BlankField(max_length=5)
state_employer_account_num = IntegerField(max_length=20, required=False)
state_employer_account_num = TextField(max_length=20)
blank3 = BlankField(max_length=6)
state_code_2 = StateField(use_numeric=True)
state_taxable_wages = MoneyField(max_length=11)
state_income_tax_wh = MoneyField(max_length=11)
other_state_data = TextField(max_length=10, required=False)
tax_type_code = TextField(max_length=1, required=False) # VALIDATE C, D, E, or F
tax_type_code = TextField(max_length=1) # VALIDATE C, D, E, or F
local_taxable_wages = MoneyField(max_length=11)
local_income_tax_wh = MoneyField(max_length=11)
state_control_number = IntegerField(max_length=7, required=False)
@ -282,8 +278,7 @@ class StateWageRecord(EFW2Record):
def validate_tax_type_code(self, field):
    """Check that a non-empty tax type code is one of the known codes.

    The allowed codes are the first items of ``enums.tax_type_codes``.
    An empty or unset value is accepted; any other value is compared
    case-insensitively. Raises ValidationError for an unknown code.
    """
    choices = [x for x, y in enums.tax_type_codes]
    value = field.value
    if value and value.upper() not in choices:
        # ``field=field``: the previous ``field=f`` named an undefined variable.
        raise ValidationError("%s not one of %s" % (field.value, choices), field=field)
@ -359,17 +354,6 @@ class StateTotalRecord(EFW2Record):
supplemental_data = TextField(max_length=510)
class StateTotalRecordIA(EFW2Record):
    # Field layout for the record identified by 'RV'.
    # NOTE(review): the "IA" suffix and iowa_confirmation_number suggest this
    # is the Iowa variant of the state total record - confirm.
    # The field lengths below add up to 510; with the 2-character identifier
    # that presumably fills the 512-character EFW2Record.record_length.
    #year=2018
    record_identifier = 'RV'
    number_of_rs_records = IntegerField(max_length=7) # num records since last 'RE' record
    wages_tips = MoneyField(max_length=15)
    state_income_tax_wh = MoneyField(max_length=15)
    employer_ben = TextField(max_length=8)
    iowa_confirmation_number = ZeroField(max_length=10)
    blank1 = BlankField(max_length=455)
class FinalRecord(EFW2Record):
#year=2012
record_identifier = 'RF'

View file

@ -1 +0,0 @@
PyMuPDF==1.24.0

View file

@ -1,76 +0,0 @@
#!/usr/bin/env python
import pyaccuwage
import argparse
import os, os.path
import sys
"""
Command line tool for converting IRS e-file fixed field records
to/from JSON or a simple text format.
Attempts to load record types from a python module in the current working
directory named record_types.py
The module must export a RECORD_TYPES list with the names of the classes to
import as valid record types.
"""
def get_record_types():
    """Return a mapping of record type name to record class.

    The current working directory is appended to ``sys.path`` and a
    ``record_types`` module is imported from it; every name listed in its
    ``RECORD_TYPES`` is looked up on the module. When the import fails, a
    warning is printed and ``pyaccuwage.get_record_types()`` is used.
    """
    try:
        sys.path.append(os.getcwd())
        import record_types
        return {
            type_name: getattr(record_types, type_name)
            for type_name in record_types.RECORD_TYPES
        }
    except ImportError:
        print('warning: using default record types (failed to import record_types.py)')
        return pyaccuwage.get_record_types()
def read_file(fd, filename, record_types):
    """Load records from ``fd`` using the loader chosen by file extension.

    ``.json`` uses ``pyaccuwage.json_load``, ``.txt`` uses
    ``pyaccuwage.text_load`` and anything else uses ``pyaccuwage.load``.
    """
    suffix = os.path.splitext(filename)[1]
    if suffix == '.json':
        return pyaccuwage.json_load(fd, record_types)
    if suffix == '.txt':
        return pyaccuwage.text_load(fd, record_types)
    return pyaccuwage.load(fd, record_types)
def write_file(outfile, filename, records):
    """Write ``records`` to ``outfile`` using the dumper chosen by extension.

    ``.json`` uses ``pyaccuwage.json_dump``, ``.txt`` uses
    ``pyaccuwage.text_dump`` and anything else uses ``pyaccuwage.dump``.
    """
    suffix = os.path.splitext(filename)[1]
    if suffix == '.json':
        pyaccuwage.json_dump(outfile, records)
        return
    if suffix == '.txt':
        pyaccuwage.text_dump(outfile, records)
        return
    pyaccuwage.dump(outfile, records)
if __name__ == '__main__':
    # Command line entry point: read the input file, convert it, and report
    # how many records were written.
    cli = argparse.ArgumentParser(
        description="Convert accuwage efile data between different formats."
    )
    # Both options take exactly one required file argument.
    file_option = dict(nargs=1, required=True, metavar="file")
    cli.add_argument("-i", '--input',
                     type=argparse.FileType('r'),
                     help="Source file to convert",
                     **file_option)
    cli.add_argument("-o", "--output",
                     type=argparse.FileType('w'),
                     help="Destination file to output",
                     **file_option)

    options = cli.parse_args()
    # nargs=1 yields one-element lists.
    (source,) = options.input
    (destination,) = options.output

    converted = list(read_file(source, source.name, get_record_types()))
    write_file(destination, destination.name, converted)
    print("wrote {} records to {}".format(len(converted), destination.name))

View file

@ -1,4 +1,4 @@
#!/usr/bin/env python
#!/usr/bin/python
from pyaccuwage.parser import RecordBuilder
from pyaccuwage.pdfextract import PDFRecordFinder
import argparse
@ -29,9 +29,48 @@ doc = PDFRecordFinder(source_file)
records = doc.records()
builder = RecordBuilder()
for (name, fields) in records:
name = re.sub(r'^[^a-zA-Z]*','', name.split(':')[-1])
name = re.sub(r'[^\w]*', '', name)
def record_begins_at(field):
    """Return the first position covered by a record's fields.

    ``field`` is the sequence of field collectors for one record. The
    first value in the first collector's ``data`` is read as a
    ``start-end`` range and the number before the dash is returned.
    """
    # Use the argument itself; the body used to read a global named
    # ``fields`` and ignore what was passed in.
    first_value = list(field[0].data.values())[0]
    return int(first_value.split('-')[0], 10)
def record_ends_at(fields):
    """Return the last position covered by a record's fields.

    The first value in the last collector's ``data`` is read as a
    ``start-end`` range and the number after the last dash is returned.
    """
    # list() keeps the lookup valid where dict.values() is a view.
    last_value = list(fields[-1].data.values())[0]
    return int(last_value.split('-')[-1], 10)
last_record_begins_at = -1
last_record_ends_at = -1
for rec in records:
#if not rec[1]:
# continue # no actual fields detected
fields = rec[1]
# strip out fields that are not 4 items long
fields = filter(lambda x:len(x.tuple) == 4, fields)
# strip fields that don't begin at position 0
fields = filter(lambda x: 0 in x.data, fields)
# strip fields that don't have a length-range type item in position 0
fields = filter(lambda x: re.match('^\d+[-]?\d*$', x.data[0]), fields)
if not fields:
continue
begins_at = record_begins_at(fields)
ends_at = record_ends_at(fields)
# FIXME record_ends_at is randomly exploding due to record data being
# a lump of text and not necessarily a field entry. I assume
# this is cleaned out by the record builder class.
#print last_record_ends_at + 1, begins_at
if last_record_ends_at + 1 != begins_at:
name = re.sub('^[^a-zA-Z]*','',rec[0].split(':')[-1])
name = re.sub('[^\w]*', '', name)
sys.stdout.write("\nclass %s(pyaccuwagemodel.Model):\n" % name)
for field in builder.load(map(lambda x: x, fields[0:])):
for field in builder.load(map(lambda x:x.tuple, rec[1][0:])):
sys.stdout.write('\t' + field + '\n')
#print field
last_record_ends_at = ends_at

View file

@ -1,21 +1,12 @@
from setuptools import setup
import unittest
def pyaccuwage_tests():
    """Discover the ``test_*.py`` modules under ``tests`` and return the suite."""
    return unittest.TestLoader().discover('tests', pattern='test_*.py')
from distutils.core import setup
setup(name='pyaccuwage',
version='0.2025.0',
version='0.2012.1',
packages=['pyaccuwage'],
scripts=[
'scripts/pyaccuwage-checkseq',
'scripts/pyaccuwage-convert',
'scripts/pyaccuwage-genfieldfill',
'scripts/pyaccuwage-parse',
'scripts/pyaccuwage-pdfparse',
'scripts/pyaccuwage-checkseq',
'scripts/pyaccuwage-genfieldfill'
],
zip_safe=True,
test_suite='setup.pyaccuwage_tests',
)

View file

@ -1,67 +0,0 @@
import unittest
from pyaccuwage.fields import TextField
from pyaccuwage.fields import StaticField
# from pyaccuwage.fields import IntegerField
# from pyaccuwage.fields import StateField
# from pyaccuwage.fields import BlankField
# from pyaccuwage.fields import ZeroField
# from pyaccuwage.fields import MoneyField
from pyaccuwage.fields import ValidationError
from pyaccuwage.model import Model
class TestTextField(unittest.TestCase):
    """Validation and output checks for TextField."""

    def testStringShortOptional(self):
        text = TextField(max_length=6, required=False)
        text.validate()  # an unset optional field is valid
        text.value = 'Hello'
        text.validate()
        self.assertEqual(text.get_data(), b'HELLO ')

    def testStringShortRequired(self):
        text = TextField(max_length=6, required=True)
        self.assertRaises(ValidationError, text.validate)
        text.value = 'Hello'
        text.validate()
        self.assertEqual(text.get_data(), b'HELLO ')

    def testStringLongOptional(self):
        text = TextField(max_length=6, required=False)
        text.value = 'Hello, World!'  # longer than max_length
        output = text.get_data()
        self.assertEqual(len(output), text.max_length)
        self.assertEqual(output, b'HELLO,')

    def testStringUnsetOptional(self):
        text = TextField(max_length=6, required=False)
        text.validate()
        self.assertEqual(text.get_data(), b' ' * 6)

    def testStringRequiredUnassigned(self):
        text = TextField(max_length=6)
        with self.assertRaises(ValidationError):
            text.validate()

    def testStringRequiredNonBlank(self):
        text = TextField(max_length=6)
        text.value = ''
        with self.assertRaises(ValidationError):
            text.validate()

    def testStringRequiredBlank(self):
        text = TextField(max_length=6, blank=True)
        text.value = ''
        text.validate()
        self.assertEqual(len(text.get_data()), 6)

    def testStringMinimumLength(self):
        text = TextField(max_length=6, min_length=6, blank=True)  # blank has no effect
        text.value = ''  # empty, shorter than min_length
        with self.assertRaises(ValidationError):
            text.validate()
        text.value = '12345'  # one character too short
        with self.assertRaises(ValidationError):
            text.validate()
        # Exactly min_length characters.
        # NOTE(review): nothing is asserted after this assignment.
        text.value = '123456'
class TestStaticField(unittest.TestCase):
    """Output check for StaticField."""

    def test_static_field(self):
        # The field emits the value it was constructed with, as bytes.
        self.assertEqual(StaticField(value='TEST').get_data(), b'TEST')

View file

@ -1,179 +0,0 @@
import unittest
import decimal
import pyaccuwage
from pyaccuwage.fields import BlankField
from pyaccuwage.fields import IntegerField
from pyaccuwage.fields import MoneyField
from pyaccuwage.fields import StateField
from pyaccuwage.fields import TextField
from pyaccuwage.fields import ZeroField
from pyaccuwage.fields import StaticField
from pyaccuwage.fields import ValidationError
from pyaccuwage.model import Model
class TestModelOutput(unittest.TestCase):
    # Checks the binary and text output of a small Model subclass.

    class TestModel(Model):
        # Identifier (4) plus the field lengths below add up to record_length:
        # 4 + 16 + 16 + 16 + 16 + 32 + 2 + 2 + 12 + 12 = 128.
        record_length = 128
        record_identifier = 'TEST' # 4 bytes
        field1 = TextField(max_length=16)
        field2 = IntegerField(max_length=16)
        blank1 = BlankField(max_length=16)
        zero1 = ZeroField(max_length=16)
        money = MoneyField(max_length=32)
        state_txt = StateField()
        state_num = StateField(use_numeric=True)
        blank2 = BlankField(max_length=12)
        static1 = StaticField(value='hey mister!!')

    def setUp(self):
        # A fresh model instance for every test.
        self.model = TestModelOutput.TestModel()

    def testModelBinaryOutput(self):
        # The binary output is the identifier followed by each field in
        # declaration order, padded to its declared length.
        model = self.model
        model.field1.value = 'Hello, sir!'
        model.field2.value = 12345
        model.money.value = decimal.Decimal('3133.77')
        model.state_txt.value = 'IA'
        model.state_num.value = 'IA'
        # Expected: text upper-cased and space-padded, numbers zero-padded,
        # money written without its decimal point, and the numeric state
        # field written as '19' for 'IA'.
        expected = b''.join([
            b'TEST',
            b'HELLO, SIR!'.ljust(16),
            b'12345'.zfill(16),
            b' ' * 16,
            b'0' * 16,
            b'313377'.zfill(32),
            b'IA',
            b'19',
            b' ' * 12,
            b'hey mister!!',
        ])
        output = model.output()
        self.assertEqual(len(output), TestModelOutput.TestModel.record_length)
        self.assertEqual(output, expected)

    def testModelTextOutput(self):
        # The text output lists the model name and one "name: value" line per
        # field; the blank and zero fields do not appear in the expected text.
        model = self.model
        model.field1.value = 'Hello, sir!'
        model.field2.value = 12345
        model.money.value = decimal.Decimal('3133.77')
        model.state_txt.value = 'IA'
        model.state_num.value = 'IA'
        output = model.output(format='text')
        self.assertEqual(output, '''---TestModel
field1: Hello, sir!
field2: 12345
money: 3133.77
state_txt: IA
state_num: IA
static1: hey mister!!
''')
class TestFileFormats(unittest.TestCase):
    """Round-trip checks for the JSON and text serializations."""

    class TestModelA(pyaccuwage.model.Model):
        record_length = 128
        record_identifier = 'A' # 1 byte
        field1 = TextField(max_length=16)
        field2 = IntegerField(max_length=16)
        blank1 = BlankField(max_length=16)
        zero1 = ZeroField(max_length=16)
        money = MoneyField(max_length=32)
        state_txt = StateField()
        state_num = StateField(use_numeric=True)
        blank2 = BlankField(max_length=27)

    class TestModelB(pyaccuwage.model.Model):
        record_length = 128
        record_identifier = 'B' # 1 byte
        zero1 = ZeroField(max_length=32)
        text1 = TextField(max_length=71)
        text2 = TextField(max_length=20, required=False)
        blank2 = BlankField(max_length=4)

    record_types = [TestModelA, TestModelB]

    def createExampleRecords(self):
        """Build one populated record of each test model."""
        first = TestFileFormats.TestModelA()
        first.field1.value = 'I am model a'
        first.field2.value = 5522
        first.money.value = decimal.Decimal('23.00')
        first.state_txt.value = 'IA'
        first.state_num.value = 'IA'

        second = TestFileFormats.TestModelB()
        second.text1.value = 'hey I am model b and I have a big text field'

        return [first, second]

    def testJSONSerialization(self):
        # Records reloaded from JSON must dump to the same bytes.
        originals = self.createExampleRecords()
        known_types = self.record_types
        serialized = pyaccuwage.json_dumps(originals)
        reloaded = pyaccuwage.json_loads(serialized, known_types)
        bytes_before = pyaccuwage.dumps(originals)
        bytes_after = pyaccuwage.dumps(reloaded)
        self.assertEqual(bytes_before, bytes_after)

    def testTxtSerialization(self):
        # Records reloaded from the text format must dump to the same bytes.
        originals = self.createExampleRecords()
        known_types = self.record_types
        serialized = pyaccuwage.text_dumps(originals)
        reloaded = pyaccuwage.text_loads(serialized, known_types)
        bytes_before = pyaccuwage.dumps(originals)
        bytes_after = pyaccuwage.dumps(reloaded)
        self.assertEqual(bytes_before, bytes_after)
class TestRequiredFields(unittest.TestCase):
def createTestRecord(self, required=False, blank=False):
class Record(pyaccuwage.model.Model):
record_length = 16
record_identifier = ''
test_field = TextField(max_length=16, required=required, blank=blank)
record = Record()
def dump():
return pyaccuwage.dumps([record])
return (record, dump)
def testRequiredBlankField(self):
(record, dump) = self.createTestRecord(required=True, blank=True)
record.test_field.value # if nothing is ever assigned, raise error
self.assertRaises(ValidationError, dump)
record.test_field.value = '' # value may be empty string
dump()
def testRequiredNonblankField(self):
(record, dump) = self.createTestRecord(required=True, blank=False)
record.test_field.value # if nothing is ever assigned, raise error
self.assertRaises(ValidationError, dump)
record.test_field.value = '' # value must not be empty string
self.assertRaises(ValidationError, dump)
record.test_field.value = 'hello'
dump()
def testOptionalBlankField(self):
(record, dump) = self.createTestRecord(required=False, blank=True)
record.test_field.value # OK if nothing is ever assigned
dump()
record.test_field.value = '' # OK if empty string is assigned
dump()
record.test_field.value = 'hello'
dump()
def testOptionalNonBlankField(self):
(record, dump) = self.createTestRecord(required=False, blank=False)
record.test_field.value # OK if nothing is ever assigned
dump()
record.test_field.value = '' # OK if empty string is assigned
dump()
record.test_field.value = 'hello'
dump()