Merge branch 'conversion-support'
This commit is contained in:
commit
1f1d3dd9bb
8 changed files with 373 additions and 146 deletions
|
@ -1,6 +1,4 @@
|
||||||
from .record import *
|
from collections import Callable
|
||||||
from .reader import RecordReader
|
|
||||||
import collections
|
|
||||||
|
|
||||||
VERSION = (0, 2012, 0)
|
VERSION = (0, 2012, 0)
|
||||||
|
|
||||||
|
@ -14,77 +12,55 @@ RECORD_TYPES = [
|
||||||
'OptionalTotalRecord',
|
'OptionalTotalRecord',
|
||||||
'StateTotalRecord',
|
'StateTotalRecord',
|
||||||
'FinalRecord'
|
'FinalRecord'
|
||||||
]
|
]
|
||||||
|
|
||||||
def test():
|
|
||||||
from . import record, model
|
|
||||||
from .fields import ValidationError
|
|
||||||
for rname in RECORD_TYPES:
|
|
||||||
inst = record.__dict__[rname]()
|
|
||||||
try:
|
|
||||||
output_length = len(inst.output())
|
|
||||||
except ValidationError as e:
|
|
||||||
print(e.msg, type(inst), inst.record_identifier)
|
|
||||||
continue
|
|
||||||
|
|
||||||
print(type(inst), inst.record_identifier, output_length)
|
|
||||||
|
|
||||||
|
|
||||||
def test_dump():
|
def get_record_types():
|
||||||
import record, io
|
|
||||||
records = [
|
|
||||||
record.SubmitterRecord(),
|
|
||||||
record.EmployerRecord(),
|
|
||||||
record.EmployeeWageRecord(),
|
|
||||||
]
|
|
||||||
out = io.BytesIO()
|
|
||||||
dump(records, out, None)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def test_record_order():
|
|
||||||
from . import record
|
|
||||||
records = [
|
|
||||||
record.SubmitterRecord(),
|
|
||||||
record.EmployerRecord(),
|
|
||||||
record.EmployeeWageRecord(),
|
|
||||||
record.TotalRecord(),
|
|
||||||
record.FinalRecord(),
|
|
||||||
]
|
|
||||||
validate_record_order(records)
|
|
||||||
|
|
||||||
|
|
||||||
def test_load(fp):
|
|
||||||
return load(fp)
|
|
||||||
|
|
||||||
def load(fp):
|
|
||||||
# BUILD LIST OF RECORD TYPES
|
|
||||||
from . import record
|
from . import record
|
||||||
types = {}
|
types = {}
|
||||||
for r in RECORD_TYPES:
|
for r in RECORD_TYPES:
|
||||||
klass = record.__dict__[r]
|
klass = record.__dict__[r]
|
||||||
types[klass.record_identifier] = klass
|
types[klass.record_identifier] = klass
|
||||||
|
return types
|
||||||
|
|
||||||
|
|
||||||
|
def load(fp, record_types):
|
||||||
|
distinct_identifier_lengths = set([len(record_types[k].record_identifier) for k in record_types])
|
||||||
|
assert(len(distinct_identifier_lengths) == 1)
|
||||||
|
ident_length = list(distinct_identifier_lengths)[0]
|
||||||
|
|
||||||
|
# Add aliases for the record types based on their record_identifier since that's all
|
||||||
|
# we have to work with with the e1099 data.
|
||||||
|
record_types_by_ident = {}
|
||||||
|
for k in record_types:
|
||||||
|
record_type = record_types[k]
|
||||||
|
record_identifier = record_type.record_identifier
|
||||||
|
record_types_by_ident[record_identifier] = record_type
|
||||||
|
|
||||||
# PARSE DATA INTO RECORDS AND YIELD THEM
|
# PARSE DATA INTO RECORDS AND YIELD THEM
|
||||||
while fp.tell() < fp.len:
|
while True:
|
||||||
record_ident = fp.read(2)
|
record_ident = fp.read(ident_length)
|
||||||
if record_ident in types:
|
if not record_ident:
|
||||||
record = types[record_ident]()
|
break
|
||||||
|
if record_ident in record_types_by_ident:
|
||||||
|
record = record_types_by_ident[record_ident]()
|
||||||
record.read(fp)
|
record.read(fp)
|
||||||
yield record
|
yield record
|
||||||
|
|
||||||
def loads(s):
|
|
||||||
|
def loads(s, record_types=get_record_types()):
|
||||||
import io
|
import io
|
||||||
fp = io.BytesIO(s)
|
fp = io.BytesIO(s)
|
||||||
return load(fp)
|
return load(fp, record_types)
|
||||||
|
|
||||||
|
|
||||||
def dump(records, fp, delim=None):
|
def dump(fp, records, delim=None):
|
||||||
for r in records:
|
for r in records:
|
||||||
fp.write(r.output())
|
fp.write(r.output())
|
||||||
if delim:
|
if delim:
|
||||||
fp.write(delim)
|
fp.write(delim)
|
||||||
|
|
||||||
|
|
||||||
def dumps(records, delim=None):
|
def dumps(records, delim=None):
|
||||||
import io
|
import io
|
||||||
fp = io.BytesIO()
|
fp = io.BytesIO()
|
||||||
|
@ -92,15 +68,15 @@ def dumps(records, delim=None):
|
||||||
fp.seek(0)
|
fp.seek(0)
|
||||||
return fp.read()
|
return fp.read()
|
||||||
|
|
||||||
|
|
||||||
def json_dumps(records):
|
def json_dumps(records):
|
||||||
import json
|
import json
|
||||||
from . import model
|
|
||||||
import decimal
|
import decimal
|
||||||
|
|
||||||
class JSONEncoder(json.JSONEncoder):
|
class JSONEncoder(json.JSONEncoder):
|
||||||
|
|
||||||
def default(self, o):
|
def default(self, o):
|
||||||
if hasattr(o, 'toJSON') and isinstance(getattr(o, 'toJSON'), collections.Callable):
|
if hasattr(o, 'toJSON') and isinstance(getattr(o, 'toJSON'), Callable):
|
||||||
return o.toJSON()
|
return o.toJSON()
|
||||||
|
|
||||||
if type(o) is bytes:
|
if type(o) is bytes:
|
||||||
|
@ -111,37 +87,76 @@ def json_dumps(records):
|
||||||
|
|
||||||
return super(JSONEncoder, self).default(o)
|
return super(JSONEncoder, self).default(o)
|
||||||
|
|
||||||
return json.dumps(records, cls=JSONEncoder, indent=2)
|
return json.dumps(list(records), cls=JSONEncoder, indent=2)
|
||||||
|
|
||||||
|
|
||||||
def json_loads(s, record_classes):
|
def json_dump(fp, records):
|
||||||
|
fp.write(json_dumps(records))
|
||||||
|
|
||||||
|
|
||||||
|
def json_loads(s, record_types):
|
||||||
import json
|
import json
|
||||||
from . import fields
|
from . import fields
|
||||||
import decimal
|
import decimal
|
||||||
import re
|
|
||||||
|
|
||||||
if not isinstance(record_classes, dict):
|
if not isinstance(record_types, dict):
|
||||||
record_classes = dict([ (x.__class__.__name__, x) for x in record_classes])
|
record_types = dict([ (x.__name__, x) for x in record_types])
|
||||||
|
|
||||||
def object_hook(o):
|
def object_hook(o):
|
||||||
if '__class__' in o:
|
if '__class__' in o:
|
||||||
klass = o['__class__']
|
klass = o['__class__']
|
||||||
|
if klass in record_types:
|
||||||
if klass in record_classes:
|
record = record_types[klass]()
|
||||||
return record_classes[klass]().fromJSON(o)
|
record.fromJSON(o)
|
||||||
|
return record
|
||||||
elif hasattr(fields, klass):
|
elif hasattr(fields, klass):
|
||||||
return getattr(fields, klass)().fromJSON(o)
|
return getattr(fields, klass)().fromJSON(o)
|
||||||
|
|
||||||
return o
|
return o
|
||||||
|
|
||||||
#print "OBJECTHOOK", str(o)
|
|
||||||
#return {'object_hook':str(o)}
|
|
||||||
#def default(self, o):
|
|
||||||
# return super(JSONDecoder, self).default(o)
|
|
||||||
|
|
||||||
return json.loads(s, parse_float=decimal.Decimal, object_hook=object_hook)
|
return json.loads(s, parse_float=decimal.Decimal, object_hook=object_hook)
|
||||||
|
|
||||||
|
def json_load(fp, record_types):
|
||||||
|
return json_loads(fp.read(), record_types)
|
||||||
|
|
||||||
|
def text_dump(fp, records):
|
||||||
|
for r in records:
|
||||||
|
fp.write(r.output(format='text').encode('ascii'))
|
||||||
|
|
||||||
|
|
||||||
|
def text_dumps(records):
|
||||||
|
import io
|
||||||
|
fp = io.BytesIO()
|
||||||
|
text_dump(fp, records)
|
||||||
|
fp.seek(0)
|
||||||
|
return fp.read()
|
||||||
|
|
||||||
|
|
||||||
|
def text_load(fp, record_classes):
|
||||||
|
records = []
|
||||||
|
current_record = None
|
||||||
|
|
||||||
|
if not isinstance(record_classes, dict):
|
||||||
|
record_classes = dict([ (x.__name__, x) for x in record_classes])
|
||||||
|
|
||||||
|
while True: #fp.readable():
|
||||||
|
line = fp.readline().decode('ascii')
|
||||||
|
if not line:
|
||||||
|
break
|
||||||
|
if line.startswith('---'):
|
||||||
|
record_name = line.strip('---').strip()
|
||||||
|
current_record = record_classes[record_name]()
|
||||||
|
records.append(current_record)
|
||||||
|
elif ':' in line:
|
||||||
|
field, value = [x.strip() for x in line.split(':')]
|
||||||
|
current_record.set_field_value(field, value)
|
||||||
|
return records
|
||||||
|
|
||||||
|
def text_loads(s, record_classes):
|
||||||
|
import io
|
||||||
|
fp = io.BytesIO(s)
|
||||||
|
return text_load(fp, record_classes)
|
||||||
|
|
||||||
|
|
||||||
# THIS WAS IN CONTROLLER, BUT UNLESS WE
|
# THIS WAS IN CONTROLLER, BUT UNLESS WE
|
||||||
# REALLY NEED A CONTROLLER CLASS, IT'S SIMPLER
|
# REALLY NEED A CONTROLLER CLASS, IT'S SIMPLER
|
||||||
# TO JUST KEEP IT IN HERE.
|
# TO JUST KEEP IT IN HERE.
|
||||||
|
@ -153,7 +168,7 @@ def validate_required_records(records):
|
||||||
klass = record.__dict__[r]
|
klass = record.__dict__[r]
|
||||||
if klass.required:
|
if klass.required:
|
||||||
req_types.append(klass.__name__)
|
req_types.append(klass.__name__)
|
||||||
|
|
||||||
while req_types:
|
while req_types:
|
||||||
req = req_types[0]
|
req = req_types[0]
|
||||||
if req not in types:
|
if req not in types:
|
||||||
|
@ -162,10 +177,11 @@ def validate_required_records(records):
|
||||||
else:
|
else:
|
||||||
req_types.remove(req)
|
req_types.remove(req)
|
||||||
|
|
||||||
|
|
||||||
def validate_record_order(records):
|
def validate_record_order(records):
|
||||||
from . import record
|
from . import record
|
||||||
from .fields import ValidationError
|
from .fields import ValidationError
|
||||||
|
|
||||||
# 1st record must be SubmitterRecord
|
# 1st record must be SubmitterRecord
|
||||||
if not isinstance(records[0], record.SubmitterRecord):
|
if not isinstance(records[0], record.SubmitterRecord):
|
||||||
raise ValidationError("First record must be SubmitterRecord")
|
raise ValidationError("First record must be SubmitterRecord")
|
||||||
|
@ -211,15 +227,10 @@ def test_unique_fields():
|
||||||
r1 = EmployeeWageRecord()
|
r1 = EmployeeWageRecord()
|
||||||
|
|
||||||
r1.employee_first_name.value = "John Johnson"
|
r1.employee_first_name.value = "John Johnson"
|
||||||
|
|
||||||
r2 = EmployeeWageRecord()
|
r2 = EmployeeWageRecord()
|
||||||
print('r1:', r1.employee_first_name.value, r1.employee_first_name, r1.employee_first_name.creation_counter)
|
print('r1:', r1.employee_first_name.value, r1.employee_first_name, r1.employee_first_name.creation_counter)
|
||||||
print('r2:', r2.employee_first_name.value, r2.employee_first_name, r2.employee_first_name.creation_counter)
|
print('r2:', r2.employee_first_name.value, r2.employee_first_name, r2.employee_first_name.creation_counter)
|
||||||
|
|
||||||
if r1.employee_first_name.value == r2.employee_first_name.value:
|
if r1.employee_first_name.value == r2.employee_first_name.value:
|
||||||
raise ValidationError("Horrible problem involving shared values across records")
|
raise ValidationError("Horrible problem involving shared values across records")
|
||||||
|
|
||||||
#def state_postal_code(state_abbr):
|
|
||||||
# import enums
|
|
||||||
# return enums.state_postal_numeric[ state_abbr.upper() ]
|
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,10 @@
|
||||||
import decimal, datetime
|
import decimal, datetime
|
||||||
import inspect
|
import inspect
|
||||||
|
from six import string_types
|
||||||
from . import enums
|
from . import enums
|
||||||
|
|
||||||
|
def is_blank_space(val):
|
||||||
|
return len(val.strip()) == 0
|
||||||
|
|
||||||
class ValidationError(Exception):
|
class ValidationError(Exception):
|
||||||
def __init__(self, msg, field=None):
|
def __init__(self, msg, field=None):
|
||||||
|
@ -17,6 +20,7 @@ class ValidationError(Exception):
|
||||||
|
|
||||||
class Field(object):
|
class Field(object):
|
||||||
creation_counter = 0
|
creation_counter = 0
|
||||||
|
is_read_only = False
|
||||||
|
|
||||||
def __init__(self, name=None, max_length=0, required=True, uppercase=True, creation_counter=None):
|
def __init__(self, name=None, max_length=0, required=True, uppercase=True, creation_counter=None):
|
||||||
self.name = name
|
self.name = name
|
||||||
|
@ -29,10 +33,10 @@ class Field(object):
|
||||||
Field.creation_counter += 1
|
Field.creation_counter += 1
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
raise NotImplemented
|
raise NotImplementedError
|
||||||
|
|
||||||
def get_data(self):
|
def get_data(self):
|
||||||
raise NotImplemented
|
raise NotImplementedError
|
||||||
|
|
||||||
def __setvalue(self, value):
|
def __setvalue(self, value):
|
||||||
self._value = value
|
self._value = value
|
||||||
|
@ -77,7 +81,7 @@ class Field(object):
|
||||||
required=o['required'],
|
required=o['required'],
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(o['value'], str) and re.match('^\d*\.\d*$', o['value']):
|
if isinstance(o['value'], str) and re.match(r'^\d*\.\d*$', o['value']):
|
||||||
o['value'] = decimal.Decimal(o['value'])
|
o['value'] = decimal.Decimal(o['value'])
|
||||||
|
|
||||||
self.value = o['value']
|
self.value = o['value']
|
||||||
|
@ -164,9 +168,10 @@ class StateField(TextField):
|
||||||
else:
|
else:
|
||||||
self.value = s
|
self.value = s
|
||||||
|
|
||||||
|
|
||||||
class EmailField(TextField):
|
class EmailField(TextField):
|
||||||
def __init__(self, name=None, required=True, max_length=None):
|
def __init__(self, name=None, required=True, max_length=None):
|
||||||
return super(EmailField, self).__init__(name=name, max_length=max_length,
|
super(EmailField, self).__init__(name=name, max_length=max_length,
|
||||||
required=required, uppercase=False)
|
required=required, uppercase=False)
|
||||||
|
|
||||||
class IntegerField(TextField):
|
class IntegerField(TextField):
|
||||||
|
@ -183,7 +188,10 @@ class IntegerField(TextField):
|
||||||
return value.zfill(self.max_length)[:self.max_length]
|
return value.zfill(self.max_length)[:self.max_length]
|
||||||
|
|
||||||
def parse(self, s):
|
def parse(self, s):
|
||||||
self.value = int(s)
|
if not is_blank_space(s):
|
||||||
|
self.value = int(s)
|
||||||
|
else:
|
||||||
|
self.value = 0
|
||||||
|
|
||||||
|
|
||||||
class StaticField(TextField):
|
class StaticField(TextField):
|
||||||
|
@ -197,8 +205,10 @@ class StaticField(TextField):
|
||||||
|
|
||||||
|
|
||||||
class BlankField(TextField):
|
class BlankField(TextField):
|
||||||
|
is_read_only = True
|
||||||
|
|
||||||
def __init__(self, name=None, max_length=0, required=False):
|
def __init__(self, name=None, max_length=0, required=False):
|
||||||
super(TextField, self).__init__(name=name, max_length=max_length, required=required, uppercase=False)
|
super(BlankField, self).__init__(name=name, max_length=max_length, required=required, uppercase=False)
|
||||||
|
|
||||||
def get_data(self):
|
def get_data(self):
|
||||||
return b' ' * self.max_length
|
return b' ' * self.max_length
|
||||||
|
@ -208,13 +218,17 @@ class BlankField(TextField):
|
||||||
|
|
||||||
|
|
||||||
class ZeroField(BlankField):
|
class ZeroField(BlankField):
|
||||||
|
is_read_only = True
|
||||||
|
|
||||||
def get_data(self):
|
def get_data(self):
|
||||||
return b'0' * self.max_length
|
return b'0' * self.max_length
|
||||||
|
|
||||||
|
|
||||||
class CRLFField(TextField):
|
class CRLFField(TextField):
|
||||||
|
is_read_only = True
|
||||||
|
|
||||||
def __init__(self, name=None, required=False):
|
def __init__(self, name=None, required=False):
|
||||||
super(TextField, self).__init__(name=name, max_length=2, required=required, uppercase=False)
|
super(CRLFField, self).__init__(name=name, max_length=2, required=required, uppercase=False)
|
||||||
|
|
||||||
def __setvalue(self, value):
|
def __setvalue(self, value):
|
||||||
self._value = value
|
self._value = value
|
||||||
|
@ -262,12 +276,27 @@ class MoneyField(Field):
|
||||||
return formatted[:self.max_length]
|
return formatted[:self.max_length]
|
||||||
|
|
||||||
def parse(self, s):
|
def parse(self, s):
|
||||||
self.value = decimal.Decimal(s) * decimal.Decimal('0.01')
|
if not is_blank_space(s):
|
||||||
|
self.value = decimal.Decimal(s) * decimal.Decimal('0.01')
|
||||||
|
else:
|
||||||
|
self.value = decimal.Decimal(0.0)
|
||||||
|
|
||||||
|
def __setvalue(self, value):
|
||||||
|
new_value = value
|
||||||
|
if isinstance(new_value, string_types):
|
||||||
|
new_value = decimal.Decimal(new_value or '0')
|
||||||
|
if '.' not in value: # must be cents?
|
||||||
|
new_value *= decimal.Decimal('100.')
|
||||||
|
self._value = new_value
|
||||||
|
|
||||||
|
def __getvalue(self):
|
||||||
|
return self._value
|
||||||
|
|
||||||
|
value = property(__getvalue, __setvalue)
|
||||||
|
|
||||||
class DateField(TextField):
|
class DateField(TextField):
|
||||||
def __init__(self, name=None, required=True, value=None):
|
def __init__(self, name=None, required=True, value=None):
|
||||||
super(TextField, self).__init__(name=name, required=required, max_length=8)
|
super(DateField, self).__init__(name=name, required=required, max_length=8)
|
||||||
if value:
|
if value:
|
||||||
self.value = value
|
self.value = value
|
||||||
|
|
||||||
|
@ -298,7 +327,7 @@ class DateField(TextField):
|
||||||
|
|
||||||
class MonthYearField(TextField):
|
class MonthYearField(TextField):
|
||||||
def __init__(self, name=None, required=True, value=None):
|
def __init__(self, name=None, required=True, value=None):
|
||||||
super(TextField, self).__init__(name=name, required=required, max_length=6)
|
super(MonthYearField, self).__init__(name=name, required=required, max_length=6)
|
||||||
if value:
|
if value:
|
||||||
self.value = value
|
self.value = value
|
||||||
|
|
||||||
|
|
|
@ -4,11 +4,15 @@ import collections
|
||||||
|
|
||||||
|
|
||||||
class Model(object):
|
class Model(object):
|
||||||
|
record_length = -1
|
||||||
record_identifier = ' '
|
record_identifier = ' '
|
||||||
required = False
|
required = False
|
||||||
target_size = 512
|
target_size = 512
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
if self.record_length == -1:
|
||||||
|
raise ValueError(self.record_length)
|
||||||
|
|
||||||
for (key, value) in list(self.__class__.__dict__.items()):
|
for (key, value) in list(self.__class__.__dict__.items()):
|
||||||
if isinstance(value, Field):
|
if isinstance(value, Field):
|
||||||
# GRAB THE FIELD INSTANCE FROM THE CLASS DEFINITION
|
# GRAB THE FIELD INSTANCE FROM THE CLASS DEFINITION
|
||||||
|
@ -19,15 +23,22 @@ class Model(object):
|
||||||
if not src_field.name:
|
if not src_field.name:
|
||||||
setattr(src_field, 'name', key)
|
setattr(src_field, 'name', key)
|
||||||
setattr(src_field, 'parent_name', self.__class__.__name__)
|
setattr(src_field, 'parent_name', self.__class__.__name__)
|
||||||
self.__dict__[key] = copy.copy(src_field)
|
new_field_instance = copy.copy(src_field)
|
||||||
|
new_field_instance._orig_value = None
|
||||||
|
new_field_instance._value = None
|
||||||
|
self.__dict__[key] = new_field_instance
|
||||||
|
|
||||||
def __setattr__(self, key, value):
|
def __setattr__(self, key, value):
|
||||||
if hasattr(self, key) and isinstance(getattr(self, key), Field):
|
if hasattr(self, key) and isinstance(getattr(self, key), Field):
|
||||||
getattr(self, key).value = value
|
self.set_field_value(key, value)
|
||||||
else:
|
else:
|
||||||
# MAYBE THIS SHOULD RAISE A PROPERTY ERROR?
|
# MAYBE THIS SHOULD RAISE A PROPERTY ERROR?
|
||||||
self.__dict__[key] = value
|
self.__dict__[key] = value
|
||||||
|
|
||||||
|
def set_field_value(self, field_name, value):
|
||||||
|
print('setfieldval: ' + field_name + ' ' + value)
|
||||||
|
getattr(self, field_name).value = value
|
||||||
|
|
||||||
def get_fields(self):
|
def get_fields(self):
|
||||||
identifier = TextField("record_identifier", max_length=len(self.record_identifier), creation_counter=-1)
|
identifier = TextField("record_identifier", max_length=len(self.record_identifier), creation_counter=-1)
|
||||||
identifier.value = self.record_identifier
|
identifier.value = self.record_identifier
|
||||||
|
@ -55,18 +66,28 @@ class Model(object):
|
||||||
if isinstance(custom_validator, collections.Callable):
|
if isinstance(custom_validator, collections.Callable):
|
||||||
custom_validator(f)
|
custom_validator(f)
|
||||||
|
|
||||||
def output(self):
|
def output(self, format='binary'):
|
||||||
|
if format == 'text':
|
||||||
|
return self.output_text()
|
||||||
|
return self.output_efile()
|
||||||
|
|
||||||
|
def output_efile(self):
|
||||||
result = b''.join([field.get_data() for field in self.get_sorted_fields()])
|
result = b''.join([field.get_data() for field in self.get_sorted_fields()])
|
||||||
|
if self.record_length < 0 or len(result) != self.record_length:
|
||||||
if hasattr(self, 'record_length') and len(result) != self.record_length:
|
|
||||||
raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.record_length, len(result)))
|
raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.record_length, len(result)))
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def output_text(self):
|
||||||
|
fields = self.get_sorted_fields()[1:] # skip record identifier
|
||||||
|
fields = [field for field in fields if not field.is_read_only]
|
||||||
|
header = ''.join(['---', self.__class__.__name__, '\n'])
|
||||||
|
return header + '\n'.join([f.name + ': ' + (str(f.value) if f.value else '') for f in fields]) + '\n\n'
|
||||||
|
|
||||||
def read(self, fp):
|
def read(self, fp):
|
||||||
# Skip the first record, since that's an identifier
|
# Skip the first record, since that's an identifier
|
||||||
for field in self.get_sorted_fields()[1:]:
|
for field in self.get_sorted_fields()[1:]:
|
||||||
field.read(fp)
|
field.read(fp)
|
||||||
|
print(field.name, '"' + (str(field.value) or '') + '"', field.max_length, field._orig_value)
|
||||||
|
|
||||||
def toJSON(self):
|
def toJSON(self):
|
||||||
return {
|
return {
|
||||||
|
@ -77,6 +98,9 @@ class Model(object):
|
||||||
def fromJSON(self, o):
|
def fromJSON(self, o):
|
||||||
fields = o['fields']
|
fields = o['fields']
|
||||||
|
|
||||||
|
identifier, fields = fields[0], fields[1:]
|
||||||
|
assert(identifier.value == self.record_identifier)
|
||||||
|
|
||||||
for f in fields:
|
for f in fields:
|
||||||
target = self.__dict__[f.name]
|
target = self.__dict__[f.name]
|
||||||
|
|
||||||
|
@ -84,7 +108,7 @@ class Model(object):
|
||||||
or target.max_length != f.max_length):
|
or target.max_length != f.max_length):
|
||||||
print("Warning: value mismatch on import")
|
print("Warning: value mismatch on import")
|
||||||
|
|
||||||
target._value = f._value
|
target.value = f.value
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ import re
|
||||||
|
|
||||||
|
|
||||||
class ClassEntryCommentSequence(object):
|
class ClassEntryCommentSequence(object):
|
||||||
re_rangecomment = re.compile('#\s+(\d+)\-?(\d*)$')
|
re_rangecomment = re.compile(r'#\s+(\d+)\-?(\d*)$')
|
||||||
|
|
||||||
def __init__(self, classname, line):
|
def __init__(self, classname, line):
|
||||||
self.classname = classname,
|
self.classname = classname,
|
||||||
|
@ -72,7 +72,7 @@ class ModelDefParser(object):
|
||||||
|
|
||||||
classmatch = self.re_classdef.match(line)
|
classmatch = self.re_classdef.match(line)
|
||||||
if classmatch:
|
if classmatch:
|
||||||
classname, subclass = classmatch.groups()
|
classname, _subclass = classmatch.groups()
|
||||||
self.beginclass(classname, self.line)
|
self.beginclass(classname, self.line)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
|
@ -109,7 +109,7 @@ class RangeToken(BaseToken):
|
||||||
|
|
||||||
|
|
||||||
class NumericToken(BaseToken):
|
class NumericToken(BaseToken):
|
||||||
regexp = re.compile('^(\d+)$')
|
regexp = re.compile(r'^(\d+)$')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def value(self):
|
def value(self):
|
||||||
|
|
75
scripts/pyaccuwage-convert
Executable file
75
scripts/pyaccuwage-convert
Executable file
|
@ -0,0 +1,75 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
import pyaccuwage
|
||||||
|
import argparse
|
||||||
|
import os, os.path
|
||||||
|
import sys
|
||||||
|
|
||||||
|
"""
|
||||||
|
Command line tool for converting IRS e-file fixed field records
|
||||||
|
to/from JSON or a simple text format.
|
||||||
|
|
||||||
|
Attempts to load record types from a python module in the current working
|
||||||
|
directory named record_types.py
|
||||||
|
|
||||||
|
The module must export a RECORD_TYPES list with the names of the classes to
|
||||||
|
import as valid record types.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_record_types():
|
||||||
|
try:
|
||||||
|
sys.path.append(os.getcwd())
|
||||||
|
import record_types
|
||||||
|
r = {}
|
||||||
|
for record_type in record_types.RECORD_TYPES:
|
||||||
|
r[record_type] = getattr(record_types, record_type)
|
||||||
|
return r
|
||||||
|
except ImportError:
|
||||||
|
print('warning: using default record types (failed to import record_types.py)')
|
||||||
|
return pyaccuwage.get_record_types()
|
||||||
|
|
||||||
|
|
||||||
|
def read_file(fd, filename, record_types):
|
||||||
|
filename, extension = os.path.splitext(filename)
|
||||||
|
if extension == '.json':
|
||||||
|
return pyaccuwage.json_load(fd, record_types)
|
||||||
|
elif extension == '.txt':
|
||||||
|
return pyaccuwage.text_load(fd, record_types)
|
||||||
|
else:
|
||||||
|
return pyaccuwage.load(fd, record_types)
|
||||||
|
|
||||||
|
|
||||||
|
def write_file(outfile, filename, records):
|
||||||
|
filename, extension = os.path.splitext(filename)
|
||||||
|
if extension == '.json':
|
||||||
|
pyaccuwage.json_dump(outfile, records)
|
||||||
|
elif extension == '.txt':
|
||||||
|
pyaccuwage.text_dump(outfile, records)
|
||||||
|
else:
|
||||||
|
pyaccuwage.dump(outfile, records)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Convert accuwage efile data between different formats."
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("-i", '--input',
|
||||||
|
nargs=1,
|
||||||
|
required=True,
|
||||||
|
metavar="file",
|
||||||
|
type=argparse.FileType('r'),
|
||||||
|
help="Source file to convert")
|
||||||
|
|
||||||
|
parser.add_argument("-o", "--output",
|
||||||
|
nargs=1,
|
||||||
|
required=True,
|
||||||
|
metavar="file",
|
||||||
|
type=argparse.FileType('w'),
|
||||||
|
help="Destination file to output")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
in_file = args.input[0]
|
||||||
|
out_file = args.output[0]
|
||||||
|
|
||||||
|
records = list(read_file(in_file, in_file.name, get_record_types()))
|
||||||
|
write_file(out_file, out_file.name, records)
|
|
@ -1,17 +1,15 @@
|
||||||
import unittest
|
import unittest
|
||||||
import decimal
|
|
||||||
from pyaccuwage.fields import TextField
|
from pyaccuwage.fields import TextField
|
||||||
from pyaccuwage.fields import IntegerField
|
# from pyaccuwage.fields import IntegerField
|
||||||
from pyaccuwage.fields import StateField
|
# from pyaccuwage.fields import StateField
|
||||||
from pyaccuwage.fields import BlankField
|
# from pyaccuwage.fields import BlankField
|
||||||
from pyaccuwage.fields import ZeroField
|
# from pyaccuwage.fields import ZeroField
|
||||||
from pyaccuwage.fields import MoneyField
|
# from pyaccuwage.fields import MoneyField
|
||||||
from pyaccuwage.fields import ValidationError
|
from pyaccuwage.fields import ValidationError
|
||||||
from pyaccuwage.model import Model
|
from pyaccuwage.model import Model
|
||||||
|
|
||||||
|
|
||||||
class TestTextField(unittest.TestCase):
|
class TestTextField(unittest.TestCase):
|
||||||
|
|
||||||
def testStringShortOptional(self):
|
def testStringShortOptional(self):
|
||||||
field = TextField(max_length=6, required=False)
|
field = TextField(max_length=6, required=False)
|
||||||
field.validate() # optional
|
field.validate() # optional
|
||||||
|
@ -30,43 +28,6 @@ class TestTextField(unittest.TestCase):
|
||||||
def testStringLongOptional(self):
|
def testStringLongOptional(self):
|
||||||
field = TextField(max_length=6, required=False)
|
field = TextField(max_length=6, required=False)
|
||||||
field.value = 'Hello, World!' # too long
|
field.value = 'Hello, World!' # too long
|
||||||
self.assertEqual(len(field.get_data()), field.max_length)
|
data = field.get_data()
|
||||||
|
self.assertEqual(len(data), field.max_length)
|
||||||
|
self.assertEqual(data, b'HELLO,')
|
||||||
class TestModelOutput(unittest.TestCase):
|
|
||||||
class TestModel(Model):
|
|
||||||
record_length = 128
|
|
||||||
record_identifier = 'TEST' # 4 bytes
|
|
||||||
field1 = TextField(max_length=16)
|
|
||||||
field2 = IntegerField(max_length=16)
|
|
||||||
blank1 = BlankField(max_length=16)
|
|
||||||
zero1 = ZeroField(max_length=16)
|
|
||||||
money = MoneyField(max_length=32)
|
|
||||||
state_txt = StateField()
|
|
||||||
state_num = StateField(use_numeric=True)
|
|
||||||
blank2 = BlankField(max_length=24)
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
self.model = TestModelOutput.TestModel()
|
|
||||||
|
|
||||||
def testModelOutput(self):
|
|
||||||
model = self.model
|
|
||||||
model.field1.value = 'Hello, sir!'
|
|
||||||
model.field2.value = 12345
|
|
||||||
model.money.value = decimal.Decimal('1234.56')
|
|
||||||
model.state_txt.value = 'IA'
|
|
||||||
model.state_num.value = 'IA'
|
|
||||||
|
|
||||||
expected = b''.join([
|
|
||||||
b'TEST',
|
|
||||||
b'HELLO, SIR!'.ljust(16),
|
|
||||||
b'12345'.zfill(16),
|
|
||||||
b' ' * 16,
|
|
||||||
b'0' * 16,
|
|
||||||
b'123456'.zfill(32),
|
|
||||||
b'IA',
|
|
||||||
b'19',
|
|
||||||
b' ' * 24,
|
|
||||||
])
|
|
||||||
|
|
||||||
self.assertEqual(model.output(), expected)
|
|
||||||
|
|
127
tests/test_records.py
Normal file
127
tests/test_records.py
Normal file
|
@ -0,0 +1,127 @@
|
||||||
|
import unittest
|
||||||
|
import decimal
|
||||||
|
import pyaccuwage
|
||||||
|
from pyaccuwage.fields import BlankField
|
||||||
|
from pyaccuwage.fields import IntegerField
|
||||||
|
from pyaccuwage.fields import MoneyField
|
||||||
|
from pyaccuwage.fields import StateField
|
||||||
|
from pyaccuwage.fields import TextField
|
||||||
|
from pyaccuwage.fields import ZeroField
|
||||||
|
from pyaccuwage.model import Model
|
||||||
|
|
||||||
|
class TestModelOutput(unittest.TestCase):
|
||||||
|
class TestModel(Model):
|
||||||
|
record_length = 128
|
||||||
|
record_identifier = 'TEST' # 4 bytes
|
||||||
|
field1 = TextField(max_length=16)
|
||||||
|
field2 = IntegerField(max_length=16)
|
||||||
|
blank1 = BlankField(max_length=16)
|
||||||
|
zero1 = ZeroField(max_length=16)
|
||||||
|
money = MoneyField(max_length=32)
|
||||||
|
state_txt = StateField()
|
||||||
|
state_num = StateField(use_numeric=True)
|
||||||
|
blank2 = BlankField(max_length=24)
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.model = TestModelOutput.TestModel()
|
||||||
|
|
||||||
|
def testModelBinaryOutput(self):
|
||||||
|
model = self.model
|
||||||
|
model.field1.value = 'Hello, sir!'
|
||||||
|
model.field2.value = 12345
|
||||||
|
model.money.value = decimal.Decimal('3133.77')
|
||||||
|
model.state_txt.value = 'IA'
|
||||||
|
model.state_num.value = 'IA'
|
||||||
|
|
||||||
|
expected = b''.join([
|
||||||
|
b'TEST',
|
||||||
|
b'HELLO, SIR!'.ljust(16),
|
||||||
|
b'12345'.zfill(16),
|
||||||
|
b' ' * 16,
|
||||||
|
b'0' * 16,
|
||||||
|
b'313377'.zfill(32),
|
||||||
|
b'IA',
|
||||||
|
b'19',
|
||||||
|
b' ' * 24,
|
||||||
|
])
|
||||||
|
|
||||||
|
output = model.output()
|
||||||
|
self.assertEqual(len(output), TestModelOutput.TestModel.record_length)
|
||||||
|
self.assertEqual(output, expected)
|
||||||
|
|
||||||
|
def testModelTextOutput(self):
|
||||||
|
model = self.model
|
||||||
|
model.field1.value = 'Hello, sir!'
|
||||||
|
model.field2.value = 12345
|
||||||
|
model.money.value = decimal.Decimal('3133.77')
|
||||||
|
model.state_txt.value = 'IA'
|
||||||
|
model.state_num.value = 'IA'
|
||||||
|
output = model.output(format='text')
|
||||||
|
|
||||||
|
self.assertEqual(output, '''---TestModel
|
||||||
|
field1: Hello, sir!
|
||||||
|
field2: 12345
|
||||||
|
money: 3133.77
|
||||||
|
state_txt: IA
|
||||||
|
state_num: IA
|
||||||
|
|
||||||
|
''')
|
||||||
|
|
||||||
|
|
||||||
|
class TestFileFormats(unittest.TestCase):
|
||||||
|
class TestModelA(pyaccuwage.model.Model):
|
||||||
|
record_length = 128
|
||||||
|
record_identifier = 'A' # 1 byte
|
||||||
|
field1 = TextField(max_length=16)
|
||||||
|
field2 = IntegerField(max_length=16)
|
||||||
|
blank1 = BlankField(max_length=16)
|
||||||
|
zero1 = ZeroField(max_length=16)
|
||||||
|
money = MoneyField(max_length=32)
|
||||||
|
state_txt = StateField()
|
||||||
|
state_num = StateField(use_numeric=True)
|
||||||
|
blank2 = BlankField(max_length=27)
|
||||||
|
|
||||||
|
class TestModelB(pyaccuwage.model.Model):
|
||||||
|
record_length = 128
|
||||||
|
record_identifier = 'B' # 1 byte
|
||||||
|
zero1 = ZeroField(max_length=32)
|
||||||
|
text1 = TextField(max_length=71)
|
||||||
|
blank2 = BlankField(max_length=24)
|
||||||
|
|
||||||
|
record_types = [TestModelA, TestModelB]
|
||||||
|
|
||||||
|
def createExampleRecords(self):
|
||||||
|
model_a = TestFileFormats.TestModelA()
|
||||||
|
model_a.field1.value = 'I am model a'
|
||||||
|
model_a.field2.value = 5522
|
||||||
|
model_a.money.value = decimal.Decimal('23.00')
|
||||||
|
model_a.state_txt.value = 'IA'
|
||||||
|
model_a.state_num.value = 'IA'
|
||||||
|
|
||||||
|
model_b = TestFileFormats.TestModelB()
|
||||||
|
model_b.text1.value = 'hey I am model b and I have a big text field'
|
||||||
|
|
||||||
|
return [
|
||||||
|
model_a,
|
||||||
|
model_b,
|
||||||
|
]
|
||||||
|
|
||||||
|
def testJSONSerialization(self):
|
||||||
|
records = self.createExampleRecords()
|
||||||
|
record_types = self.record_types
|
||||||
|
json_data = pyaccuwage.json_dumps(records)
|
||||||
|
records_loaded = pyaccuwage.json_loads(json_data, record_types)
|
||||||
|
|
||||||
|
original_bytes = pyaccuwage.dumps(records)
|
||||||
|
reloaded_bytes = pyaccuwage.dumps(records_loaded)
|
||||||
|
self.assertEqual(original_bytes, reloaded_bytes)
|
||||||
|
|
||||||
|
def testTxtSerialization(self):
|
||||||
|
records = self.createExampleRecords()
|
||||||
|
record_types = self.record_types
|
||||||
|
text_data = pyaccuwage.text_dumps(records)
|
||||||
|
records_loaded = pyaccuwage.text_loads(text_data, record_types)
|
||||||
|
|
||||||
|
original_bytes = pyaccuwage.dumps(records)
|
||||||
|
reloaded_bytes = pyaccuwage.dumps(records_loaded)
|
||||||
|
self.assertEqual(original_bytes, reloaded_bytes)
|
Loading…
Add table
Add a link
Reference in a new issue