diff --git a/__init__.py b/__init__.py index 105db6c..5087ffd 100644 --- a/__init__.py +++ b/__init__.py @@ -1,13 +1,4 @@ -""" -from record import SubmitterRecord -from record import EmployerRecord -from record import EmployeeWageRecord -from record import OptionalEmployeeWageRecord -from record import TotalRecord -from record import OptionalTotalRecord -from record import StateTotalRecord -from record import FinalRecord -""" +from record import * RECORD_TYPES = [ 'SubmitterRecord', @@ -20,8 +11,6 @@ RECORD_TYPES = [ 'FinalRecord', ] - - def test(): import record, model for rname in RECORD_TYPES: @@ -29,6 +18,53 @@ def test(): print type(inst), len(inst.output()) +def test_dump(): + import record, StringIO + records = [ + record.SubmitterRecord(), + record.EmployerRecord(), + record.EmployeeWageRecord(), + ] + out = StringIO.StringIO() + dump(records, out) + return out +def test_load(fp): + return load(fp) + +def load(fp): + # BUILD LIST OF RECORD TYPES + import record + types = {} + for r in RECORD_TYPES: + klass = record.__dict__[r] + types[klass.record_identifier] = klass + + # PARSE DATA INTO RECORDS AND YIELD THEM + while fp.tell() < fp.len: + record_ident = fp.read(2) + if record_ident in types: + record = types[record_ident]() + record.read(fp) + yield record + + +def loads(s): + import StringIO + fp = StringIO.StringIO(s) + return load(fp) + + +def dump(records, fp): + for r in records: + fp.write(r.output()) + + +def dumps(records): + import StringIO + fp = StringIO.StringIO() + dump(records, fp) + fp.seek(0) + return fp.read() diff --git a/fields.py b/fields.py index 801edce..e557ee9 100644 --- a/fields.py +++ b/fields.py @@ -1,3 +1,4 @@ +import decimal class ValidationError(Exception): pass @@ -15,10 +16,10 @@ class Field(object): Field.creation_counter += 1 def validate(self): - raise NotImplemented() + raise NotImplemented def get_data(self): - raise NotImplemented() + raise NotImplemented def __setvalue(self, value): self._value = value @@ -28,8 +29,16 @@ class Field(object): value = property(__getvalue, __setvalue) - def __repr__(self): - return self.name + def read(self, fp): + if fp.tell() + self.max_length <= fp.len: + data = fp.read(self.max_length) + print self, self.max_length, data + return self.parse(data) + return None + + def parse(self, s): + self.value = s.strip() + class TextField(Field): def validate(self): @@ -44,14 +53,17 @@ class TextField(Field): value = value.upper() return value.ljust(self.max_length).encode('ascii') + class StateField(TextField): def __init__(self, name=None, required=True): return super(StateField, self).__init__(name=name, max_length=2, required=required) + class EmailField(TextField): def __init__(self, name=None, required=True, max_length=None): return super(EmailField, self).__init__(name=name, max_length=max_length, required=required, uppercase=False) + class NumericField(TextField): def validate(self): super(NumericField, self).validate() @@ -59,24 +71,46 @@ class NumericField(TextField): int(self.value) except ValueError: raise ValidationError("field contains non-numeric characters") - + + def get_data(self): + value = self.value or "" + return value.zfill(self.max_length) + + def parse(self, s): + self.value = int(s) + + class StaticField(TextField): def __init__(self, name=None, required=True, value=None): super(StaticField, self).__init__(name=name, required=required, max_length=len(value)) self._value = value + def parse(self, s): + print 'STATIC', self.max_length, s, len(s) + pass + class BlankField(TextField): def get_data(self): return " " * self.max_length + def parse(self, s): + pass + class BooleanField(Field): + def __init__(self, name=None, required=True, value=None): + super(BooleanField, self).__init__(name=name, required=required, max_length=1) + self._value = value + def validate(self): pass def get_data(self): return '1' if self._value else '0' + def parse(self, s): + self.value = (s == '1') + class MoneyField(Field): def validate(self): if self.value == None and self.required: @@ -87,3 +121,6 @@ class MoneyField(Field): def get_data(self): return str(int((self.value or 0)*100)).encode('ascii').zfill(self.max_length) + def parse(self, s): + self.value = decimal.Decimal(s) * decimal.Decimal('0.01') + diff --git a/model.py b/model.py index c74953a..27686dc 100644 --- a/model.py +++ b/model.py @@ -1,6 +1,8 @@ -from fields import Field, TextField, MoneyField, StateField +from fields import Field + +class Model(object): + record_identifier = ' ' -class Model(object): def __init__(self): for (key, value) in self.__class__.__dict__.items(): if isinstance(value, Field): @@ -32,12 +34,10 @@ class Model(object): f.validate() def output(self): - return ''.join([field.get_data() for field in self.get_sorted_fields()]) + return ''.join([self.record_identifier] + [field.get_data() for field in self.get_sorted_fields()]) + def read(self, fp): + for field in self.get_sorted_fields(): + field.read(fp) - -class TestModel(Model): - field_a = TextField(max_length=20) - field_b = MoneyField(max_length=10) - state = StateField() diff --git a/record.py b/record.py index 4a61a05..62a4e8c 100644 --- a/record.py +++ b/record.py @@ -1,8 +1,14 @@ -from fields import * import model +from fields import * + +__all__ = ['SubmitterRecord', 'EmployerRecord', + 'EmployeeWageRecord', 'OptionalEmployeeWageRecord', + 'TotalRecord', 'OptionalTotalRecord', + 'StateTotalRecord', 'FinalRecord',] class SubmitterRecord(model.Model): - record_identifier = StaticField(value='ra') + record_identifier = 'RA' + submitter_ein = NumericField(max_length=9) user_id = TextField(max_length=8) software_vendor = TextField(max_length=4) @@ -44,7 +50,8 @@ class SubmitterRecord(model.Model): blank6 = BlankField(max_length=12) class EmployerRecord(model.Model): - record_identifier = StaticField(value='re') + record_identifier = 'RE' + tax_year = NumericField(max_length=4) agent_indicator = NumericField(max_length=1) employer_ein = TextField(max_length=9) @@ -69,7 +76,8 @@ class EmployerRecord(model.Model): blank2 = BlankField(max_length=291) class EmployeeWageRecord(model.Model): - record_identifier = StaticField(value='rw') + record_identifier = 'RW' + ssn = NumericField(max_length=9, required=False) employee_first_name = TextField(max_length=15) employee_middle_name = TextField(max_length=15) @@ -119,7 +127,8 @@ class EmployeeWageRecord(model.Model): class OptionalEmployeeWageRecord(model.Model): - record_identifier = StaticField(value='ro') + record_identifier = 'RO' + blank1 = BlankField(max_length=9) allocated_tips = NumericField(max_length=11) uncollected_tax_on_tips = NumericField(max_length=11) @@ -145,10 +154,11 @@ class OptionalEmployeeWageRecord(model.Model): class TotalRecord(model.Model): - record_identifier = StaticField(value='rt') + record_identifier = 'RT' class OptionalTotalRecord(model.Model): - record_identifier = StaticField(value='ru') + record_identifier = 'RU' + number_of_ro_records = NumericField(max_length=7) allocated_tips = NumericField(max_length=15) uncollected_tax_on_tips = NumericField(max_length=15) @@ -173,11 +183,12 @@ class OptionalTotalRecord(model.Model): class StateTotalRecord(model.Model): - record_identifier = StaticField(value='rv') + record_identifier = 'RV' supplemental_data = TextField(max_length=510) class FinalRecord(model.Model): - record_identifier = StaticField(value='rf') + record_identifier = 'RF' + blank1 = BlankField(max_length=5) number_of_rw_records = NumericField(max_length=9) blank2 = BlankField(max_length=496)