From 068f1bbae41b672b72ef651c7c4c42caaef72477 Mon Sep 17 00:00:00 2001 From: Binh Van Nguyen Date: Sat, 2 Apr 2011 15:28:38 -0500 Subject: [PATCH] Added load/dump methods which work similarly to those found in simplejson. Tests seem to work so far. Still need to figure out how to get data into the records in some easy way. --- __init__.py | 60 ++++++++++++++++++++++++++++++++++++++++++----------- fields.py | 47 ++++++++++++++++++++++++++++++++++++----- model.py | 16 +++++++------- record.py | 29 ++++++++++++++++++-------- 4 files changed, 118 insertions(+), 34 deletions(-) diff --git a/__init__.py b/__init__.py index 105db6c..5087ffd 100644 --- a/__init__.py +++ b/__init__.py @@ -1,13 +1,4 @@ -""" -from record import SubmitterRecord -from record import EmployerRecord -from record import EmployeeWageRecord -from record import OptionalEmployeeWageRecord -from record import TotalRecord -from record import OptionalTotalRecord -from record import StateTotalRecord -from record import FinalRecord -""" +from record import * RECORD_TYPES = [ 'SubmitterRecord', @@ -20,8 +11,6 @@ RECORD_TYPES = [ 'FinalRecord', ] - - def test(): import record, model for rname in RECORD_TYPES: @@ -29,6 +18,53 @@ def test(): print type(inst), len(inst.output()) +def test_dump(): + import record, StringIO + records = [ + record.SubmitterRecord(), + record.EmployerRecord(), + record.EmployeeWageRecord(), + ] + out = StringIO.StringIO() + dump(records, out) + return out +def test_load(fp): + return load(fp) + +def load(fp): + # BUILD LIST OF RECORD TYPES + import record + types = {} + for r in RECORD_TYPES: + klass = record.__dict__[r] + types[klass.record_identifier] = klass + + # PARSE DATA INTO RECORDS AND YIELD THEM + while fp.tell() < fp.len: + record_ident = fp.read(2) + if record_ident in types: + record = types[record_ident]() + record.read(fp) + yield record + + +def loads(s): + import StringIO + fp = StringIO.StringIO(s) + return load(fp) + + +def dump(records, fp): + for r in records: + fp.write(r.output()) + + +def dumps(records): + import StringIO + fp = StringIO.StringIO() + dump(records, fp) + fp.seek(0) + return fp.read() diff --git a/fields.py b/fields.py index 801edce..e557ee9 100644 --- a/fields.py +++ b/fields.py @@ -1,3 +1,4 @@ +import decimal class ValidationError(Exception): pass @@ -15,10 +16,10 @@ class Field(object): Field.creation_counter += 1 def validate(self): - raise NotImplemented() + raise NotImplemented def get_data(self): - raise NotImplemented() + raise NotImplemented def __setvalue(self, value): self._value = value @@ -28,8 +29,16 @@ class Field(object): value = property(__getvalue, __setvalue) - def __repr__(self): - return self.name + def read(self, fp): + if fp.tell() + self.max_length <= fp.len: + data = fp.read(self.max_length) + print self, self.max_length, data + return self.parse(data) + return None + + def parse(self, s): + self.value = s.strip() + class TextField(Field): def validate(self): @@ -44,14 +53,17 @@ class TextField(Field): value = value.upper() return value.ljust(self.max_length).encode('ascii') + class StateField(TextField): def __init__(self, name=None, required=True): return super(StateField, self).__init__(name=name, max_length=2, required=required) + class EmailField(TextField): def __init__(self, name=None, required=True, max_length=None): return super(EmailField, self).__init__(name=name, max_length=max_length, required=required, uppercase=False) + class NumericField(TextField): def validate(self): super(NumericField, self).validate() @@ -59,24 +71,46 @@ class NumericField(TextField): int(self.value) except ValueError: raise ValidationError("field contains non-numeric characters") - + + def get_data(self): + value = self.value or "" + return value.zfill(self.max_length) + + def parse(self, s): + self.value = int(s) + + class StaticField(TextField): def __init__(self, name=None, required=True, value=None): super(StaticField, self).__init__(name=name, required=required, max_length=len(value)) self._value = value + def parse(self, s): + print 'STATIC', self.max_length, s, len(s) + pass + class BlankField(TextField): def get_data(self): return " " * self.max_length + def parse(self, s): + pass + class BooleanField(Field): + def __init__(self, name=None, required=True, value=None): + super(BooleanField, self).__init__(name=name, required=required, max_length=1) + self._value = value + def validate(self): pass def get_data(self): return '1' if self._value else '0' + def parse(self, s): + self.value = (s == '1') + class MoneyField(Field): def validate(self): if self.value == None and self.required: @@ -87,3 +121,6 @@ class MoneyField(Field): def get_data(self): return str(int((self.value or 0)*100)).encode('ascii').zfill(self.max_length) + def parse(self, s): + self.value = decimal.Decimal(s) * decimal.Decimal('0.01') + diff --git a/model.py b/model.py index c74953a..27686dc 100644 --- a/model.py +++ b/model.py @@ -1,6 +1,8 @@ -from fields import Field, TextField, MoneyField, StateField +from fields import Field + +class Model(object): + record_identifier = ' ' -class Model(object): def __init__(self): for (key, value) in self.__class__.__dict__.items(): if isinstance(value, Field): @@ -32,12 +34,10 @@ class Model(object): f.validate() def output(self): - return ''.join([field.get_data() for field in self.get_sorted_fields()]) + return ''.join([self.record_identifier] + [field.get_data() for field in self.get_sorted_fields()]) + def read(self, fp): + for field in self.get_sorted_fields(): + field.read(fp) - -class TestModel(Model): - field_a = TextField(max_length=20) - field_b = MoneyField(max_length=10) - state = StateField() diff --git a/record.py b/record.py index 4a61a05..62a4e8c 100644 --- a/record.py +++ b/record.py @@ -1,8 +1,14 @@ -from fields import * import model +from fields import * + +__all__ = ['SubmitterRecord', 'EmployerRecord', + 'EmployeeWageRecord', 'OptionalEmployeeWageRecord', + 'TotalRecord', 'OptionalTotalRecord', + 'StateTotalRecord', 'FinalRecord',] class SubmitterRecord(model.Model): - record_identifier = StaticField(value='ra') + record_identifier = 'RA' + submitter_ein = NumericField(max_length=9) user_id = TextField(max_length=8) software_vendor = TextField(max_length=4) @@ -44,7 +50,8 @@ class SubmitterRecord(model.Model): blank6 = BlankField(max_length=12) class EmployerRecord(model.Model): - record_identifier = StaticField(value='re') + record_identifier = 'RE' + tax_year = NumericField(max_length=4) agent_indicator = NumericField(max_length=1) employer_ein = TextField(max_length=9) @@ -69,7 +76,8 @@ class EmployerRecord(model.Model): blank2 = BlankField(max_length=291) class EmployeeWageRecord(model.Model): - record_identifier = StaticField(value='rw') + record_identifier = 'RW' + ssn = NumericField(max_length=9, required=False) employee_first_name = TextField(max_length=15) employee_middle_name = TextField(max_length=15) @@ -119,7 +127,8 @@ class EmployeeWageRecord(model.Model): class OptionalEmployeeWageRecord(model.Model): - record_identifier = StaticField(value='ro') + record_identifier = 'RO' + blank1 = BlankField(max_length=9) allocated_tips = NumericField(max_length=11) uncollected_tax_on_tips = NumericField(max_length=11) @@ -145,10 +154,11 @@ class OptionalEmployeeWageRecord(model.Model): class TotalRecord(model.Model): - record_identifier = StaticField(value='rt') + record_identifier = 'RT' class OptionalTotalRecord(model.Model): - record_identifier = StaticField(value='ru') + record_identifier = 'RU' + number_of_ro_records = NumericField(max_length=7) allocated_tips = NumericField(max_length=15) uncollected_tax_on_tips = NumericField(max_length=15) @@ -173,11 +183,12 @@ class OptionalTotalRecord(model.Model): class StateTotalRecord(model.Model): - record_identifier = StaticField(value='rv') + record_identifier = 'RV' supplemental_data = TextField(max_length=510) class FinalRecord(model.Model): - record_identifier = StaticField(value='rf') + record_identifier = 'RF' + blank1 = BlankField(max_length=5) number_of_rw_records = NumericField(max_length=9) blank2 = BlankField(max_length=496)