diff --git a/pyaccuwage/__init__.py b/pyaccuwage/__init__.py index abb380f..513b8c6 100644 --- a/pyaccuwage/__init__.py +++ b/pyaccuwage/__init__.py @@ -1,6 +1,4 @@ -from .record import * -from .reader import RecordReader -import collections +from collections import Callable VERSION = (0, 2012, 0) @@ -14,77 +12,55 @@ RECORD_TYPES = [ 'OptionalTotalRecord', 'StateTotalRecord', 'FinalRecord' - ] - -def test(): - from . import record, model - from .fields import ValidationError - for rname in RECORD_TYPES: - inst = record.__dict__[rname]() - try: - output_length = len(inst.output()) - except ValidationError as e: - print(e.msg, type(inst), inst.record_identifier) - continue - - print(type(inst), inst.record_identifier, output_length) +] -def test_dump(): - import record, io - records = [ - record.SubmitterRecord(), - record.EmployerRecord(), - record.EmployeeWageRecord(), - ] - out = io.BytesIO() - dump(records, out, None) - return out - - -def test_record_order(): - from . import record - records = [ - record.SubmitterRecord(), - record.EmployerRecord(), - record.EmployeeWageRecord(), - record.TotalRecord(), - record.FinalRecord(), - ] - validate_record_order(records) - - -def test_load(fp): - return load(fp) - -def load(fp): - # BUILD LIST OF RECORD TYPES +def get_record_types(): from . import record types = {} for r in RECORD_TYPES: klass = record.__dict__[r] types[klass.record_identifier] = klass + return types + + +def load(fp, record_types): + distinct_identifier_lengths = set([len(record_types[k].record_identifier) for k in record_types]) + assert(len(distinct_identifier_lengths) == 1) + ident_length = list(distinct_identifier_lengths)[0] + + # Add aliases for the record types based on their record_identifier since that's all + # we have to work with with the e1099 data. + record_types_by_ident = {} + for k in record_types: + record_type = record_types[k] + record_identifier = record_type.record_identifier + record_types_by_ident[record_identifier] = record_type # PARSE DATA INTO RECORDS AND YIELD THEM - while fp.tell() < fp.len: - record_ident = fp.read(2) - if record_ident in types: - record = types[record_ident]() + while True: + record_ident = fp.read(ident_length) + if not record_ident: + break + if record_ident in record_types_by_ident: + record = record_types_by_ident[record_ident]() record.read(fp) yield record -def loads(s): + +def loads(s, record_types=get_record_types()): import io fp = io.BytesIO(s) - return load(fp) + return load(fp, record_types) -def dump(records, fp, delim=None): +def dump(fp, records, delim=None): for r in records: fp.write(r.output()) if delim: fp.write(delim) + def dumps(records, delim=None): import io fp = io.BytesIO() @@ -92,15 +68,15 @@ def dumps(records, delim=None): fp.seek(0) return fp.read() + def json_dumps(records): import json - from . import model import decimal class JSONEncoder(json.JSONEncoder): def default(self, o): - if hasattr(o, 'toJSON') and isinstance(getattr(o, 'toJSON'), collections.Callable): + if hasattr(o, 'toJSON') and isinstance(getattr(o, 'toJSON'), Callable): return o.toJSON() if type(o) is bytes: @@ -111,37 +87,76 @@ def json_dumps(records): return super(JSONEncoder, self).default(o) - return json.dumps(records, cls=JSONEncoder, indent=2) + return json.dumps(list(records), cls=JSONEncoder, indent=2) -def json_loads(s, record_classes): +def json_dump(fp, records): + fp.write(json_dumps(records)) + + +def json_loads(s, record_types): import json from . import fields import decimal - import re - if not isinstance(record_classes, dict): - record_classes = dict([ (x.__class__.__name__, x) for x in record_classes]) + if not isinstance(record_types, dict): + record_types = dict([ (x.__name__, x) for x in record_types]) def object_hook(o): if '__class__' in o: klass = o['__class__'] - - if klass in record_classes: - return record_classes[klass]().fromJSON(o) - + if klass in record_types: + record = record_types[klass]() + record.fromJSON(o) + return record elif hasattr(fields, klass): return getattr(fields, klass)().fromJSON(o) - return o - #print "OBJECTHOOK", str(o) - #return {'object_hook':str(o)} - #def default(self, o): - # return super(JSONDecoder, self).default(o) - return json.loads(s, parse_float=decimal.Decimal, object_hook=object_hook) +def json_load(fp, record_types): + return json_loads(fp.read(), record_types) + +def text_dump(fp, records): + for r in records: + fp.write(r.output(format='text').encode('ascii')) + + +def text_dumps(records): + import io + fp = io.BytesIO() + text_dump(fp, records) + fp.seek(0) + return fp.read() + + +def text_load(fp, record_classes): + records = [] + current_record = None + + if not isinstance(record_classes, dict): + record_classes = dict([ (x.__name__, x) for x in record_classes]) + + while True: #fp.readable(): + line = fp.readline().decode('ascii') + if not line: + break + if line.startswith('---'): + record_name = line.strip('---').strip() + current_record = record_classes[record_name]() + records.append(current_record) + elif ':' in line: + field, value = [x.strip() for x in line.split(':')] + current_record.set_field_value(field, value) + return records + +def text_loads(s, record_classes): + import io + fp = io.BytesIO(s) + return text_load(fp, record_classes) + + # THIS WAS IN CONTROLLER, BUT UNLESS WE # REALLY NEED A CONTROLLER CLASS, IT'S SIMPLER # TO JUST KEEP IT IN HERE. @@ -153,7 +168,7 @@ def validate_required_records(records): klass = record.__dict__[r] if klass.required: req_types.append(klass.__name__) - + while req_types: req = req_types[0] if req not in types: @@ -162,10 +177,11 @@ def validate_required_records(records): else: req_types.remove(req) + def validate_record_order(records): from . import record from .fields import ValidationError - + # 1st record must be SubmitterRecord if not isinstance(records[0], record.SubmitterRecord): raise ValidationError("First record must be SubmitterRecord") @@ -211,15 +227,10 @@ def test_unique_fields(): r1 = EmployeeWageRecord() r1.employee_first_name.value = "John Johnson" - + r2 = EmployeeWageRecord() print('r1:', r1.employee_first_name.value, r1.employee_first_name, r1.employee_first_name.creation_counter) print('r2:', r2.employee_first_name.value, r2.employee_first_name, r2.employee_first_name.creation_counter) - + if r1.employee_first_name.value == r2.employee_first_name.value: raise ValidationError("Horrible problem involving shared values across records") - -#def state_postal_code(state_abbr): -# import enums -# return enums.state_postal_numeric[ state_abbr.upper() ] - diff --git a/pyaccuwage/fields.py b/pyaccuwage/fields.py index ef60b92..0d6fcd3 100644 --- a/pyaccuwage/fields.py +++ b/pyaccuwage/fields.py @@ -1,7 +1,10 @@ import decimal, datetime import inspect +from six import string_types from . import enums +def is_blank_space(val): + return len(val.strip()) == 0 class ValidationError(Exception): def __init__(self, msg, field=None): @@ -17,6 +20,7 @@ class ValidationError(Exception): class Field(object): creation_counter = 0 + is_read_only = False def __init__(self, name=None, max_length=0, required=True, uppercase=True, creation_counter=None): self.name = name @@ -29,10 +33,10 @@ class Field(object): Field.creation_counter += 1 def validate(self): - raise NotImplemented + raise NotImplementedError def get_data(self): - raise NotImplemented + raise NotImplementedError def __setvalue(self, value): self._value = value @@ -77,7 +81,7 @@ class Field(object): required=o['required'], ) - if isinstance(o['value'], str) and re.match('^\d*\.\d*$', o['value']): + if isinstance(o['value'], str) and re.match(r'^\d*\.\d*$', o['value']): o['value'] = decimal.Decimal(o['value']) self.value = o['value'] @@ -164,9 +168,10 @@ class StateField(TextField): else: self.value = s + class EmailField(TextField): def __init__(self, name=None, required=True, max_length=None): - return super(EmailField, self).__init__(name=name, max_length=max_length, + super(EmailField, self).__init__(name=name, max_length=max_length, required=required, uppercase=False) class IntegerField(TextField): @@ -183,7 +188,10 @@ class IntegerField(TextField): return value.zfill(self.max_length)[:self.max_length] def parse(self, s): - self.value = int(s) + if not is_blank_space(s): + self.value = int(s) + else: + self.value = 0 class StaticField(TextField): @@ -197,8 +205,10 @@ class StaticField(TextField): class BlankField(TextField): + is_read_only = True + def __init__(self, name=None, max_length=0, required=False): - super(TextField, self).__init__(name=name, max_length=max_length, required=required, uppercase=False) + super(BlankField, self).__init__(name=name, max_length=max_length, required=required, uppercase=False) def get_data(self): return b' ' * self.max_length @@ -208,13 +218,17 @@ class BlankField(TextField): class ZeroField(BlankField): + is_read_only = True + def get_data(self): return b'0' * self.max_length class CRLFField(TextField): + is_read_only = True + def __init__(self, name=None, required=False): - super(TextField, self).__init__(name=name, max_length=2, required=required, uppercase=False) + super(CRLFField, self).__init__(name=name, max_length=2, required=required, uppercase=False) def __setvalue(self, value): self._value = value @@ -262,12 +276,27 @@ class MoneyField(Field): return formatted[:self.max_length] def parse(self, s): - self.value = decimal.Decimal(s) * decimal.Decimal('0.01') + if not is_blank_space(s): + self.value = decimal.Decimal(s) * decimal.Decimal('0.01') + else: + self.value = decimal.Decimal(0.0) + def __setvalue(self, value): + new_value = value + if isinstance(new_value, string_types): + new_value = decimal.Decimal(new_value or '0') + if '.' not in value: # must be cents? + new_value *= decimal.Decimal('100.') + self._value = new_value + + def __getvalue(self): + return self._value + + value = property(__getvalue, __setvalue) class DateField(TextField): def __init__(self, name=None, required=True, value=None): - super(TextField, self).__init__(name=name, required=required, max_length=8) + super(DateField, self).__init__(name=name, required=required, max_length=8) if value: self.value = value @@ -298,7 +327,7 @@ class DateField(TextField): class MonthYearField(TextField): def __init__(self, name=None, required=True, value=None): - super(TextField, self).__init__(name=name, required=required, max_length=6) + super(MonthYearField, self).__init__(name=name, required=required, max_length=6) if value: self.value = value diff --git a/pyaccuwage/model.py b/pyaccuwage/model.py index b71c26f..becd3ce 100644 --- a/pyaccuwage/model.py +++ b/pyaccuwage/model.py @@ -4,11 +4,15 @@ import collections class Model(object): + record_length = -1 record_identifier = ' ' required = False target_size = 512 def __init__(self): + if self.record_length == -1: + raise ValueError(self.record_length) + for (key, value) in list(self.__class__.__dict__.items()): if isinstance(value, Field): # GRAB THE FIELD INSTANCE FROM THE CLASS DEFINITION @@ -19,15 +23,22 @@ class Model(object): if not src_field.name: setattr(src_field, 'name', key) setattr(src_field, 'parent_name', self.__class__.__name__) - self.__dict__[key] = copy.copy(src_field) + new_field_instance = copy.copy(src_field) + new_field_instance._orig_value = None + new_field_instance._value = None + self.__dict__[key] = new_field_instance def __setattr__(self, key, value): if hasattr(self, key) and isinstance(getattr(self, key), Field): - getattr(self, key).value = value + self.set_field_value(key, value) else: # MAYBE THIS SHOULD RAISE A PROPERTY ERROR? self.__dict__[key] = value + def set_field_value(self, field_name, value): + print('setfieldval: ' + field_name + ' ' + value) + getattr(self, field_name).value = value + def get_fields(self): identifier = TextField("record_identifier", max_length=len(self.record_identifier), creation_counter=-1) identifier.value = self.record_identifier @@ -55,18 +66,28 @@ class Model(object): if isinstance(custom_validator, collections.Callable): custom_validator(f) - def output(self): + def output(self, format='binary'): + if format == 'text': + return self.output_text() + return self.output_efile() + + def output_efile(self): result = b''.join([field.get_data() for field in self.get_sorted_fields()]) - - if hasattr(self, 'record_length') and len(result) != self.record_length: + if self.record_length < 0 or len(result) != self.record_length: raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.record_length, len(result))) - return result + def output_text(self): + fields = self.get_sorted_fields()[1:] # skip record identifier + fields = [field for field in fields if not field.is_read_only] + header = ''.join(['---', self.__class__.__name__, '\n']) + return header + '\n'.join([f.name + ': ' + (str(f.value) if f.value else '') for f in fields]) + '\n\n' + def read(self, fp): # Skip the first record, since that's an identifier for field in self.get_sorted_fields()[1:]: field.read(fp) + print(field.name, '"' + (str(field.value) or '') + '"', field.max_length, field._orig_value) def toJSON(self): return { @@ -77,6 +98,9 @@ class Model(object): def fromJSON(self, o): fields = o['fields'] + identifier, fields = fields[0], fields[1:] + assert(identifier.value == self.record_identifier) + for f in fields: target = self.__dict__[f.name] @@ -84,7 +108,7 @@ class Model(object): or target.max_length != f.max_length): print("Warning: value mismatch on import") - target._value = f._value + target.value = f.value return self diff --git a/pyaccuwage/modeldef.py b/pyaccuwage/modeldef.py index c6c9110..6d4ce35 100644 --- a/pyaccuwage/modeldef.py +++ b/pyaccuwage/modeldef.py @@ -2,7 +2,7 @@ import re class ClassEntryCommentSequence(object): - re_rangecomment = re.compile('#\s+(\d+)\-?(\d*)$') + re_rangecomment = re.compile(r'#\s+(\d+)\-?(\d*)$') def __init__(self, classname, line): self.classname = classname, @@ -72,7 +72,7 @@ class ModelDefParser(object): classmatch = self.re_classdef.match(line) if classmatch: - classname, subclass = classmatch.groups() + classname, _subclass = classmatch.groups() self.beginclass(classname, self.line) continue diff --git a/pyaccuwage/parser.py b/pyaccuwage/parser.py index 250d122..c0fe399 100644 --- a/pyaccuwage/parser.py +++ b/pyaccuwage/parser.py @@ -109,7 +109,7 @@ class RangeToken(BaseToken): class NumericToken(BaseToken): - regexp = re.compile('^(\d+)$') + regexp = re.compile(r'^(\d+)$') @property def value(self): diff --git a/tests/test_fields.py b/tests/test_fields.py index 9293acd..3d8fd3e 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -1,17 +1,15 @@ import unittest -import decimal from pyaccuwage.fields import TextField -from pyaccuwage.fields import IntegerField -from pyaccuwage.fields import StateField -from pyaccuwage.fields import BlankField -from pyaccuwage.fields import ZeroField -from pyaccuwage.fields import MoneyField +# from pyaccuwage.fields import IntegerField +# from pyaccuwage.fields import StateField +# from pyaccuwage.fields import BlankField +# from pyaccuwage.fields import ZeroField +# from pyaccuwage.fields import MoneyField from pyaccuwage.fields import ValidationError from pyaccuwage.model import Model class TestTextField(unittest.TestCase): - def testStringShortOptional(self): field = TextField(max_length=6, required=False) field.validate() # optional @@ -30,43 +28,6 @@ class TestTextField(unittest.TestCase): def testStringLongOptional(self): field = TextField(max_length=6, required=False) field.value = 'Hello, World!' # too long - self.assertEqual(len(field.get_data()), field.max_length) - - -class TestModelOutput(unittest.TestCase): - class TestModel(Model): - record_length = 128 - record_identifier = 'TEST' # 4 bytes - field1 = TextField(max_length=16) - field2 = IntegerField(max_length=16) - blank1 = BlankField(max_length=16) - zero1 = ZeroField(max_length=16) - money = MoneyField(max_length=32) - state_txt = StateField() - state_num = StateField(use_numeric=True) - blank2 = BlankField(max_length=24) - - def setUp(self): - self.model = TestModelOutput.TestModel() - - def testModelOutput(self): - model = self.model - model.field1.value = 'Hello, sir!' - model.field2.value = 12345 - model.money.value = decimal.Decimal('1234.56') - model.state_txt.value = 'IA' - model.state_num.value = 'IA' - - expected = b''.join([ - b'TEST', - b'HELLO, SIR!'.ljust(16), - b'12345'.zfill(16), - b' ' * 16, - b'0' * 16, - b'123456'.zfill(32), - b'IA', - b'19', - b' ' * 24, - ]) - - self.assertEqual(model.output(), expected) + data = field.get_data() + self.assertEqual(len(data), field.max_length) + self.assertEqual(data, b'HELLO,') diff --git a/tests/test_records.py b/tests/test_records.py new file mode 100644 index 0000000..a6485ac --- /dev/null +++ b/tests/test_records.py @@ -0,0 +1,127 @@ +import unittest +import decimal +import pyaccuwage +from pyaccuwage.fields import BlankField +from pyaccuwage.fields import IntegerField +from pyaccuwage.fields import MoneyField +from pyaccuwage.fields import StateField +from pyaccuwage.fields import TextField +from pyaccuwage.fields import ZeroField +from pyaccuwage.model import Model + +class TestModelOutput(unittest.TestCase): + class TestModel(Model): + record_length = 128 + record_identifier = 'TEST' # 4 bytes + field1 = TextField(max_length=16) + field2 = IntegerField(max_length=16) + blank1 = BlankField(max_length=16) + zero1 = ZeroField(max_length=16) + money = MoneyField(max_length=32) + state_txt = StateField() + state_num = StateField(use_numeric=True) + blank2 = BlankField(max_length=24) + + def setUp(self): + self.model = TestModelOutput.TestModel() + + def testModelBinaryOutput(self): + model = self.model + model.field1.value = 'Hello, sir!' + model.field2.value = 12345 + model.money.value = decimal.Decimal('3133.77') + model.state_txt.value = 'IA' + model.state_num.value = 'IA' + + expected = b''.join([ + b'TEST', + b'HELLO, SIR!'.ljust(16), + b'12345'.zfill(16), + b' ' * 16, + b'0' * 16, + b'313377'.zfill(32), + b'IA', + b'19', + b' ' * 24, + ]) + + output = model.output() + self.assertEqual(len(output), TestModelOutput.TestModel.record_length) + self.assertEqual(output, expected) + + def testModelTextOutput(self): + model = self.model + model.field1.value = 'Hello, sir!' + model.field2.value = 12345 + model.money.value = decimal.Decimal('3133.77') + model.state_txt.value = 'IA' + model.state_num.value = 'IA' + output = model.output(format='text') + + self.assertEqual(output, '''---TestModel +field1: Hello, sir! +field2: 12345 +money: 3133.77 +state_txt: IA +state_num: IA + +''') + + +class TestFileFormats(unittest.TestCase): + class TestModelA(pyaccuwage.model.Model): + record_length = 128 + record_identifier = 'A' # 1 byte + field1 = TextField(max_length=16) + field2 = IntegerField(max_length=16) + blank1 = BlankField(max_length=16) + zero1 = ZeroField(max_length=16) + money = MoneyField(max_length=32) + state_txt = StateField() + state_num = StateField(use_numeric=True) + blank2 = BlankField(max_length=27) + + class TestModelB(pyaccuwage.model.Model): + record_length = 128 + record_identifier = 'B' # 1 byte + zero1 = ZeroField(max_length=32) + text1 = TextField(max_length=71) + blank2 = BlankField(max_length=24) + + record_types = [TestModelA, TestModelB] + + def createExampleRecords(self): + model_a = TestFileFormats.TestModelA() + model_a.field1.value = 'I am model a' + model_a.field2.value = 5522 + model_a.money.value = decimal.Decimal('23.00') + model_a.state_txt.value = 'IA' + model_a.state_num.value = 'IA' + + model_b = TestFileFormats.TestModelB() + model_b.text1.value = 'hey I am model b and I have a big text field' + + return [ + model_a, + model_b, + ] + + def testJSONSerialization(self): + records = self.createExampleRecords() + record_types = self.record_types + json_data = pyaccuwage.json_dumps(records) + records_loaded = pyaccuwage.json_loads(json_data, record_types) + + original_bytes = pyaccuwage.dumps(records) + reloaded_bytes = pyaccuwage.dumps(records_loaded) + self.assertEqual(original_bytes, reloaded_bytes) + + def testTxtSerialization(self): + records = self.createExampleRecords() + record_types = self.record_types + text_data = pyaccuwage.text_dumps(records) + records_loaded = pyaccuwage.text_loads(text_data, record_types) + + original_bytes = pyaccuwage.dumps(records) + reloaded_bytes = pyaccuwage.dumps(records_loaded) + self.assertEqual(original_bytes, reloaded_bytes)