add format interchange functions, add tests, fix stuff

This commit is contained in:
Mark Riedesel 2020-06-12 13:07:41 -05:00
parent 6af5067fca
commit 8f86f76167
7 changed files with 298 additions and 146 deletions

View file

@ -1,6 +1,4 @@
from .record import * from collections import Callable
from .reader import RecordReader
import collections
VERSION = (0, 2012, 0) VERSION = (0, 2012, 0)
@ -14,77 +12,55 @@ RECORD_TYPES = [
'OptionalTotalRecord', 'OptionalTotalRecord',
'StateTotalRecord', 'StateTotalRecord',
'FinalRecord' 'FinalRecord'
] ]
def test():
from . import record, model
from .fields import ValidationError
for rname in RECORD_TYPES:
inst = record.__dict__[rname]()
try:
output_length = len(inst.output())
except ValidationError as e:
print(e.msg, type(inst), inst.record_identifier)
continue
print(type(inst), inst.record_identifier, output_length)
def test_dump(): def get_record_types():
import record, io
records = [
record.SubmitterRecord(),
record.EmployerRecord(),
record.EmployeeWageRecord(),
]
out = io.BytesIO()
dump(records, out, None)
return out
def test_record_order():
from . import record
records = [
record.SubmitterRecord(),
record.EmployerRecord(),
record.EmployeeWageRecord(),
record.TotalRecord(),
record.FinalRecord(),
]
validate_record_order(records)
def test_load(fp):
return load(fp)
def load(fp):
# BUILD LIST OF RECORD TYPES
from . import record from . import record
types = {} types = {}
for r in RECORD_TYPES: for r in RECORD_TYPES:
klass = record.__dict__[r] klass = record.__dict__[r]
types[klass.record_identifier] = klass types[klass.record_identifier] = klass
return types
def load(fp, record_types):
distinct_identifier_lengths = set([len(record_types[k].record_identifier) for k in record_types])
assert(len(distinct_identifier_lengths) == 1)
ident_length = list(distinct_identifier_lengths)[0]
# Add aliases for the record types based on their record_identifier since that's all
# we have to work with with the e1099 data.
record_types_by_ident = {}
for k in record_types:
record_type = record_types[k]
record_identifier = record_type.record_identifier
record_types_by_ident[record_identifier] = record_type
# PARSE DATA INTO RECORDS AND YIELD THEM # PARSE DATA INTO RECORDS AND YIELD THEM
while fp.tell() < fp.len: while True:
record_ident = fp.read(2) record_ident = fp.read(ident_length)
if record_ident in types: if not record_ident:
record = types[record_ident]() break
if record_ident in record_types_by_ident:
record = record_types_by_ident[record_ident]()
record.read(fp) record.read(fp)
yield record yield record
def loads(s):
def loads(s, record_types=get_record_types()):
import io import io
fp = io.BytesIO(s) fp = io.BytesIO(s)
return load(fp) return load(fp, record_types)
def dump(records, fp, delim=None): def dump(fp, records, delim=None):
for r in records: for r in records:
fp.write(r.output()) fp.write(r.output())
if delim: if delim:
fp.write(delim) fp.write(delim)
def dumps(records, delim=None): def dumps(records, delim=None):
import io import io
fp = io.BytesIO() fp = io.BytesIO()
@ -92,15 +68,15 @@ def dumps(records, delim=None):
fp.seek(0) fp.seek(0)
return fp.read() return fp.read()
def json_dumps(records): def json_dumps(records):
import json import json
from . import model
import decimal import decimal
class JSONEncoder(json.JSONEncoder): class JSONEncoder(json.JSONEncoder):
def default(self, o): def default(self, o):
if hasattr(o, 'toJSON') and isinstance(getattr(o, 'toJSON'), collections.Callable): if hasattr(o, 'toJSON') and isinstance(getattr(o, 'toJSON'), Callable):
return o.toJSON() return o.toJSON()
if type(o) is bytes: if type(o) is bytes:
@ -111,37 +87,76 @@ def json_dumps(records):
return super(JSONEncoder, self).default(o) return super(JSONEncoder, self).default(o)
return json.dumps(records, cls=JSONEncoder, indent=2) return json.dumps(list(records), cls=JSONEncoder, indent=2)
def json_loads(s, record_classes): def json_dump(fp, records):
fp.write(json_dumps(records))
def json_loads(s, record_types):
import json import json
from . import fields from . import fields
import decimal import decimal
import re
if not isinstance(record_classes, dict): if not isinstance(record_types, dict):
record_classes = dict([ (x.__class__.__name__, x) for x in record_classes]) record_types = dict([ (x.__name__, x) for x in record_types])
def object_hook(o): def object_hook(o):
if '__class__' in o: if '__class__' in o:
klass = o['__class__'] klass = o['__class__']
if klass in record_types:
if klass in record_classes: record = record_types[klass]()
return record_classes[klass]().fromJSON(o) record.fromJSON(o)
return record
elif hasattr(fields, klass): elif hasattr(fields, klass):
return getattr(fields, klass)().fromJSON(o) return getattr(fields, klass)().fromJSON(o)
return o return o
#print "OBJECTHOOK", str(o)
#return {'object_hook':str(o)}
#def default(self, o):
# return super(JSONDecoder, self).default(o)
return json.loads(s, parse_float=decimal.Decimal, object_hook=object_hook) return json.loads(s, parse_float=decimal.Decimal, object_hook=object_hook)
def json_load(fp, record_types):
return json_loads(fp.read(), record_types)
def text_dump(fp, records):
for r in records:
fp.write(r.output(format='text').encode('ascii'))
def text_dumps(records):
import io
fp = io.BytesIO()
text_dump(fp, records)
fp.seek(0)
return fp.read()
def text_load(fp, record_classes):
records = []
current_record = None
if not isinstance(record_classes, dict):
record_classes = dict([ (x.__name__, x) for x in record_classes])
while True: #fp.readable():
line = fp.readline().decode('ascii')
if not line:
break
if line.startswith('---'):
record_name = line.strip('---').strip()
current_record = record_classes[record_name]()
records.append(current_record)
elif ':' in line:
field, value = [x.strip() for x in line.split(':')]
current_record.set_field_value(field, value)
return records
def text_loads(s, record_classes):
import io
fp = io.BytesIO(s)
return text_load(fp, record_classes)
# THIS WAS IN CONTROLLER, BUT UNLESS WE # THIS WAS IN CONTROLLER, BUT UNLESS WE
# REALLY NEED A CONTROLLER CLASS, IT'S SIMPLER # REALLY NEED A CONTROLLER CLASS, IT'S SIMPLER
# TO JUST KEEP IT IN HERE. # TO JUST KEEP IT IN HERE.
@ -162,6 +177,7 @@ def validate_required_records(records):
else: else:
req_types.remove(req) req_types.remove(req)
def validate_record_order(records): def validate_record_order(records):
from . import record from . import record
from .fields import ValidationError from .fields import ValidationError
@ -218,8 +234,3 @@ def test_unique_fields():
if r1.employee_first_name.value == r2.employee_first_name.value: if r1.employee_first_name.value == r2.employee_first_name.value:
raise ValidationError("Horrible problem involving shared values across records") raise ValidationError("Horrible problem involving shared values across records")
#def state_postal_code(state_abbr):
# import enums
# return enums.state_postal_numeric[ state_abbr.upper() ]

View file

@ -1,7 +1,10 @@
import decimal, datetime import decimal, datetime
import inspect import inspect
from six import string_types
from . import enums from . import enums
def is_blank_space(val):
return len(val.strip()) == 0
class ValidationError(Exception): class ValidationError(Exception):
def __init__(self, msg, field=None): def __init__(self, msg, field=None):
@ -17,6 +20,7 @@ class ValidationError(Exception):
class Field(object): class Field(object):
creation_counter = 0 creation_counter = 0
is_read_only = False
def __init__(self, name=None, max_length=0, required=True, uppercase=True, creation_counter=None): def __init__(self, name=None, max_length=0, required=True, uppercase=True, creation_counter=None):
self.name = name self.name = name
@ -29,10 +33,10 @@ class Field(object):
Field.creation_counter += 1 Field.creation_counter += 1
def validate(self): def validate(self):
raise NotImplemented raise NotImplementedError
def get_data(self): def get_data(self):
raise NotImplemented raise NotImplementedError
def __setvalue(self, value): def __setvalue(self, value):
self._value = value self._value = value
@ -77,7 +81,7 @@ class Field(object):
required=o['required'], required=o['required'],
) )
if isinstance(o['value'], str) and re.match('^\d*\.\d*$', o['value']): if isinstance(o['value'], str) and re.match(r'^\d*\.\d*$', o['value']):
o['value'] = decimal.Decimal(o['value']) o['value'] = decimal.Decimal(o['value'])
self.value = o['value'] self.value = o['value']
@ -164,9 +168,10 @@ class StateField(TextField):
else: else:
self.value = s self.value = s
class EmailField(TextField): class EmailField(TextField):
def __init__(self, name=None, required=True, max_length=None): def __init__(self, name=None, required=True, max_length=None):
return super(EmailField, self).__init__(name=name, max_length=max_length, super(EmailField, self).__init__(name=name, max_length=max_length,
required=required, uppercase=False) required=required, uppercase=False)
class IntegerField(TextField): class IntegerField(TextField):
@ -183,7 +188,10 @@ class IntegerField(TextField):
return value.zfill(self.max_length)[:self.max_length] return value.zfill(self.max_length)[:self.max_length]
def parse(self, s): def parse(self, s):
if not is_blank_space(s):
self.value = int(s) self.value = int(s)
else:
self.value = 0
class StaticField(TextField): class StaticField(TextField):
@ -197,8 +205,10 @@ class StaticField(TextField):
class BlankField(TextField): class BlankField(TextField):
is_read_only = True
def __init__(self, name=None, max_length=0, required=False): def __init__(self, name=None, max_length=0, required=False):
super(TextField, self).__init__(name=name, max_length=max_length, required=required, uppercase=False) super(BlankField, self).__init__(name=name, max_length=max_length, required=required, uppercase=False)
def get_data(self): def get_data(self):
return b' ' * self.max_length return b' ' * self.max_length
@ -208,13 +218,17 @@ class BlankField(TextField):
class ZeroField(BlankField): class ZeroField(BlankField):
is_read_only = True
def get_data(self): def get_data(self):
return b'0' * self.max_length return b'0' * self.max_length
class CRLFField(TextField): class CRLFField(TextField):
is_read_only = True
def __init__(self, name=None, required=False): def __init__(self, name=None, required=False):
super(TextField, self).__init__(name=name, max_length=2, required=required, uppercase=False) super(CRLFField, self).__init__(name=name, max_length=2, required=required, uppercase=False)
def __setvalue(self, value): def __setvalue(self, value):
self._value = value self._value = value
@ -262,12 +276,27 @@ class MoneyField(Field):
return formatted[:self.max_length] return formatted[:self.max_length]
def parse(self, s): def parse(self, s):
if not is_blank_space(s):
self.value = decimal.Decimal(s) * decimal.Decimal('0.01') self.value = decimal.Decimal(s) * decimal.Decimal('0.01')
else:
self.value = decimal.Decimal(0.0)
def __setvalue(self, value):
new_value = value
if isinstance(new_value, string_types):
new_value = decimal.Decimal(new_value or '0')
if '.' not in value: # must be cents?
new_value *= decimal.Decimal('100.')
self._value = new_value
def __getvalue(self):
return self._value
value = property(__getvalue, __setvalue)
class DateField(TextField): class DateField(TextField):
def __init__(self, name=None, required=True, value=None): def __init__(self, name=None, required=True, value=None):
super(TextField, self).__init__(name=name, required=required, max_length=8) super(DateField, self).__init__(name=name, required=required, max_length=8)
if value: if value:
self.value = value self.value = value
@ -298,7 +327,7 @@ class DateField(TextField):
class MonthYearField(TextField): class MonthYearField(TextField):
def __init__(self, name=None, required=True, value=None): def __init__(self, name=None, required=True, value=None):
super(TextField, self).__init__(name=name, required=required, max_length=6) super(MonthYearField, self).__init__(name=name, required=required, max_length=6)
if value: if value:
self.value = value self.value = value

View file

@ -4,11 +4,15 @@ import collections
class Model(object): class Model(object):
record_length = -1
record_identifier = ' ' record_identifier = ' '
required = False required = False
target_size = 512 target_size = 512
def __init__(self): def __init__(self):
if self.record_length == -1:
raise ValueError(self.record_length)
for (key, value) in list(self.__class__.__dict__.items()): for (key, value) in list(self.__class__.__dict__.items()):
if isinstance(value, Field): if isinstance(value, Field):
# GRAB THE FIELD INSTANCE FROM THE CLASS DEFINITION # GRAB THE FIELD INSTANCE FROM THE CLASS DEFINITION
@ -19,15 +23,22 @@ class Model(object):
if not src_field.name: if not src_field.name:
setattr(src_field, 'name', key) setattr(src_field, 'name', key)
setattr(src_field, 'parent_name', self.__class__.__name__) setattr(src_field, 'parent_name', self.__class__.__name__)
self.__dict__[key] = copy.copy(src_field) new_field_instance = copy.copy(src_field)
new_field_instance._orig_value = None
new_field_instance._value = None
self.__dict__[key] = new_field_instance
def __setattr__(self, key, value): def __setattr__(self, key, value):
if hasattr(self, key) and isinstance(getattr(self, key), Field): if hasattr(self, key) and isinstance(getattr(self, key), Field):
getattr(self, key).value = value self.set_field_value(key, value)
else: else:
# MAYBE THIS SHOULD RAISE A PROPERTY ERROR? # MAYBE THIS SHOULD RAISE A PROPERTY ERROR?
self.__dict__[key] = value self.__dict__[key] = value
def set_field_value(self, field_name, value):
print('setfieldval: ' + field_name + ' ' + value)
getattr(self, field_name).value = value
def get_fields(self): def get_fields(self):
identifier = TextField("record_identifier", max_length=len(self.record_identifier), creation_counter=-1) identifier = TextField("record_identifier", max_length=len(self.record_identifier), creation_counter=-1)
identifier.value = self.record_identifier identifier.value = self.record_identifier
@ -55,18 +66,28 @@ class Model(object):
if isinstance(custom_validator, collections.Callable): if isinstance(custom_validator, collections.Callable):
custom_validator(f) custom_validator(f)
def output(self): def output(self, format='binary'):
if format == 'text':
return self.output_text()
return self.output_efile()
def output_efile(self):
result = b''.join([field.get_data() for field in self.get_sorted_fields()]) result = b''.join([field.get_data() for field in self.get_sorted_fields()])
if self.record_length < 0 or len(result) != self.record_length:
if hasattr(self, 'record_length') and len(result) != self.record_length:
raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.record_length, len(result))) raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.record_length, len(result)))
return result return result
def output_text(self):
fields = self.get_sorted_fields()[1:] # skip record identifier
fields = [field for field in fields if not field.is_read_only]
header = ''.join(['---', self.__class__.__name__, '\n'])
return header + '\n'.join([f.name + ': ' + (str(f.value) if f.value else '') for f in fields]) + '\n\n'
def read(self, fp): def read(self, fp):
# Skip the first record, since that's an identifier # Skip the first record, since that's an identifier
for field in self.get_sorted_fields()[1:]: for field in self.get_sorted_fields()[1:]:
field.read(fp) field.read(fp)
print(field.name, '"' + (str(field.value) or '') + '"', field.max_length, field._orig_value)
def toJSON(self): def toJSON(self):
return { return {
@ -77,6 +98,9 @@ class Model(object):
def fromJSON(self, o): def fromJSON(self, o):
fields = o['fields'] fields = o['fields']
identifier, fields = fields[0], fields[1:]
assert(identifier.value == self.record_identifier)
for f in fields: for f in fields:
target = self.__dict__[f.name] target = self.__dict__[f.name]
@ -84,7 +108,7 @@ class Model(object):
or target.max_length != f.max_length): or target.max_length != f.max_length):
print("Warning: value mismatch on import") print("Warning: value mismatch on import")
target._value = f._value target.value = f.value
return self return self

View file

@ -2,7 +2,7 @@ import re
class ClassEntryCommentSequence(object): class ClassEntryCommentSequence(object):
re_rangecomment = re.compile('#\s+(\d+)\-?(\d*)$') re_rangecomment = re.compile(r'#\s+(\d+)\-?(\d*)$')
def __init__(self, classname, line): def __init__(self, classname, line):
self.classname = classname, self.classname = classname,
@ -72,7 +72,7 @@ class ModelDefParser(object):
classmatch = self.re_classdef.match(line) classmatch = self.re_classdef.match(line)
if classmatch: if classmatch:
classname, subclass = classmatch.groups() classname, _subclass = classmatch.groups()
self.beginclass(classname, self.line) self.beginclass(classname, self.line)
continue continue

View file

@ -109,7 +109,7 @@ class RangeToken(BaseToken):
class NumericToken(BaseToken): class NumericToken(BaseToken):
regexp = re.compile('^(\d+)$') regexp = re.compile(r'^(\d+)$')
@property @property
def value(self): def value(self):

View file

@ -1,17 +1,15 @@
import unittest import unittest
import decimal
from pyaccuwage.fields import TextField from pyaccuwage.fields import TextField
from pyaccuwage.fields import IntegerField # from pyaccuwage.fields import IntegerField
from pyaccuwage.fields import StateField # from pyaccuwage.fields import StateField
from pyaccuwage.fields import BlankField # from pyaccuwage.fields import BlankField
from pyaccuwage.fields import ZeroField # from pyaccuwage.fields import ZeroField
from pyaccuwage.fields import MoneyField # from pyaccuwage.fields import MoneyField
from pyaccuwage.fields import ValidationError from pyaccuwage.fields import ValidationError
from pyaccuwage.model import Model from pyaccuwage.model import Model
class TestTextField(unittest.TestCase): class TestTextField(unittest.TestCase):
def testStringShortOptional(self): def testStringShortOptional(self):
field = TextField(max_length=6, required=False) field = TextField(max_length=6, required=False)
field.validate() # optional field.validate() # optional
@ -30,43 +28,6 @@ class TestTextField(unittest.TestCase):
def testStringLongOptional(self): def testStringLongOptional(self):
field = TextField(max_length=6, required=False) field = TextField(max_length=6, required=False)
field.value = 'Hello, World!' # too long field.value = 'Hello, World!' # too long
self.assertEqual(len(field.get_data()), field.max_length) data = field.get_data()
self.assertEqual(len(data), field.max_length)
self.assertEqual(data, b'HELLO,')
class TestModelOutput(unittest.TestCase):
class TestModel(Model):
record_length = 128
record_identifier = 'TEST' # 4 bytes
field1 = TextField(max_length=16)
field2 = IntegerField(max_length=16)
blank1 = BlankField(max_length=16)
zero1 = ZeroField(max_length=16)
money = MoneyField(max_length=32)
state_txt = StateField()
state_num = StateField(use_numeric=True)
blank2 = BlankField(max_length=24)
def setUp(self):
self.model = TestModelOutput.TestModel()
def testModelOutput(self):
model = self.model
model.field1.value = 'Hello, sir!'
model.field2.value = 12345
model.money.value = decimal.Decimal('1234.56')
model.state_txt.value = 'IA'
model.state_num.value = 'IA'
expected = b''.join([
b'TEST',
b'HELLO, SIR!'.ljust(16),
b'12345'.zfill(16),
b' ' * 16,
b'0' * 16,
b'123456'.zfill(32),
b'IA',
b'19',
b' ' * 24,
])
self.assertEqual(model.output(), expected)

127
tests/test_records.py Normal file
View file

@ -0,0 +1,127 @@
import unittest
import decimal
import pyaccuwage
from pyaccuwage.fields import BlankField
from pyaccuwage.fields import IntegerField
from pyaccuwage.fields import MoneyField
from pyaccuwage.fields import StateField
from pyaccuwage.fields import TextField
from pyaccuwage.fields import ZeroField
from pyaccuwage.model import Model
class TestModelOutput(unittest.TestCase):
class TestModel(Model):
record_length = 128
record_identifier = 'TEST' # 4 bytes
field1 = TextField(max_length=16)
field2 = IntegerField(max_length=16)
blank1 = BlankField(max_length=16)
zero1 = ZeroField(max_length=16)
money = MoneyField(max_length=32)
state_txt = StateField()
state_num = StateField(use_numeric=True)
blank2 = BlankField(max_length=24)
def setUp(self):
self.model = TestModelOutput.TestModel()
def testModelBinaryOutput(self):
model = self.model
model.field1.value = 'Hello, sir!'
model.field2.value = 12345
model.money.value = decimal.Decimal('3133.77')
model.state_txt.value = 'IA'
model.state_num.value = 'IA'
expected = b''.join([
b'TEST',
b'HELLO, SIR!'.ljust(16),
b'12345'.zfill(16),
b' ' * 16,
b'0' * 16,
b'313377'.zfill(32),
b'IA',
b'19',
b' ' * 24,
])
output = model.output()
self.assertEqual(len(output), TestModelOutput.TestModel.record_length)
self.assertEqual(output, expected)
def testModelTextOutput(self):
model = self.model
model.field1.value = 'Hello, sir!'
model.field2.value = 12345
model.money.value = decimal.Decimal('3133.77')
model.state_txt.value = 'IA'
model.state_num.value = 'IA'
output = model.output(format='text')
self.assertEqual(output, '''---TestModel
field1: Hello, sir!
field2: 12345
money: 3133.77
state_txt: IA
state_num: IA
''')
class TestFileFormats(unittest.TestCase):
class TestModelA(pyaccuwage.model.Model):
record_length = 128
record_identifier = 'A' # 1 byte
field1 = TextField(max_length=16)
field2 = IntegerField(max_length=16)
blank1 = BlankField(max_length=16)
zero1 = ZeroField(max_length=16)
money = MoneyField(max_length=32)
state_txt = StateField()
state_num = StateField(use_numeric=True)
blank2 = BlankField(max_length=27)
class TestModelB(pyaccuwage.model.Model):
record_length = 128
record_identifier = 'B' # 1 byte
zero1 = ZeroField(max_length=32)
text1 = TextField(max_length=71)
blank2 = BlankField(max_length=24)
record_types = [TestModelA, TestModelB]
def createExampleRecords(self):
model_a = TestFileFormats.TestModelA()
model_a.field1.value = 'I am model a'
model_a.field2.value = 5522
model_a.money.value = decimal.Decimal('23.00')
model_a.state_txt.value = 'IA'
model_a.state_num.value = 'IA'
model_b = TestFileFormats.TestModelB()
model_b.text1.value = 'hey I am model b and I have a big text field'
return [
model_a,
model_b,
]
def testJSONSerialization(self):
records = self.createExampleRecords()
record_types = self.record_types
json_data = pyaccuwage.json_dumps(records)
records_loaded = pyaccuwage.json_loads(json_data, record_types)
original_bytes = pyaccuwage.dumps(records)
reloaded_bytes = pyaccuwage.dumps(records_loaded)
self.assertEqual(original_bytes, reloaded_bytes)
def testTxtSerialization(self):
records = self.createExampleRecords()
record_types = self.record_types
text_data = pyaccuwage.text_dumps(records)
records_loaded = pyaccuwage.text_loads(text_data, record_types)
original_bytes = pyaccuwage.dumps(records)
reloaded_bytes = pyaccuwage.dumps(records_loaded)
self.assertEqual(original_bytes, reloaded_bytes)