Added load/dump methods which work similarly to those found in

simplejson. Tests seem to work so far. Still need to figure out
how to get data into the records in some easy way.
This commit is contained in:
Binh 2011-04-02 15:28:38 -05:00
parent a32feb79ed
commit 068f1bbae4
4 changed files with 118 additions and 34 deletions

View file

@ -1,13 +1,4 @@
"""
from record import SubmitterRecord
from record import EmployerRecord
from record import EmployeeWageRecord
from record import OptionalEmployeeWageRecord
from record import TotalRecord
from record import OptionalTotalRecord
from record import StateTotalRecord
from record import FinalRecord
"""
from record import *
RECORD_TYPES = [
'SubmitterRecord',
@ -20,8 +11,6 @@ RECORD_TYPES = [
'FinalRecord',
]
def test():
import record, model
for rname in RECORD_TYPES:
@ -29,6 +18,53 @@ def test():
print type(inst), len(inst.output())
def test_dump():
import record, StringIO
records = [
record.SubmitterRecord(),
record.EmployerRecord(),
record.EmployeeWageRecord(),
]
out = StringIO.StringIO()
dump(records, out)
return out
def test_load(fp):
return load(fp)
def load(fp):
# BUILD LIST OF RECORD TYPES
import record
types = {}
for r in RECORD_TYPES:
klass = record.__dict__[r]
types[klass.record_identifier] = klass
# PARSE DATA INTO RECORDS AND YIELD THEM
while fp.tell() < fp.len:
record_ident = fp.read(2)
if record_ident in types:
record = types[record_ident]()
record.read(fp)
yield record
def loads(s):
import StringIO
fp = StringIO.StringIO(s)
return load(fp)
def dump(records, fp):
for r in records:
fp.write(r.output())
def dumps(records):
import StringIO
fp = StringIO.StringIO()
dump(records, fp)
fp.seek(0)
return fp.read()

View file

@ -1,3 +1,4 @@
import decimal
class ValidationError(Exception):
pass
@ -15,10 +16,10 @@ class Field(object):
Field.creation_counter += 1
def validate(self):
raise NotImplemented()
raise NotImplemented
def get_data(self):
raise NotImplemented()
raise NotImplemented
def __setvalue(self, value):
self._value = value
@ -28,8 +29,16 @@ class Field(object):
value = property(__getvalue, __setvalue)
def __repr__(self):
return self.name
def read(self, fp):
if fp.tell() + self.max_length <= fp.len:
data = fp.read(self.max_length)
print self, self.max_length, data
return self.parse(data)
return None
def parse(self, s):
self.value = s.strip()
class TextField(Field):
def validate(self):
@ -44,14 +53,17 @@ class TextField(Field):
value = value.upper()
return value.ljust(self.max_length).encode('ascii')
class StateField(TextField):
def __init__(self, name=None, required=True):
return super(StateField, self).__init__(name=name, max_length=2, required=required)
class EmailField(TextField):
def __init__(self, name=None, required=True, max_length=None):
return super(EmailField, self).__init__(name=name, max_length=max_length,
required=required, uppercase=False)
class NumericField(TextField):
def validate(self):
super(NumericField, self).validate()
@ -60,23 +72,45 @@ class NumericField(TextField):
except ValueError:
raise ValidationError("field contains non-numeric characters")
def get_data(self):
value = self.value or ""
return value.zfill(self.max_length)
def parse(self, s):
self.value = int(s)
class StaticField(TextField):
def __init__(self, name=None, required=True, value=None):
super(StaticField, self).__init__(name=name, required=required,
max_length=len(value))
self._value = value
def parse(self, s):
print 'STATIC', self.max_length, s, len(s)
pass
class BlankField(TextField):
def get_data(self):
return " " * self.max_length
def parse(self, s):
pass
class BooleanField(Field):
def __init__(self, name=None, required=True, value=None):
super(BooleanField, self).__init__(name=name, required=required, max_length=1)
self._value = value
def validate(self):
pass
def get_data(self):
return '1' if self._value else '0'
def parse(self, s):
self.value = (s == '1')
class MoneyField(Field):
def validate(self):
if self.value == None and self.required:
@ -87,3 +121,6 @@ class MoneyField(Field):
def get_data(self):
return str(int((self.value or 0)*100)).encode('ascii').zfill(self.max_length)
def parse(self, s):
self.value = decimal.Decimal(s) * decimal.Decimal('0.01')

View file

@ -1,6 +1,8 @@
from fields import Field, TextField, MoneyField, StateField
from fields import Field
class Model(object):
record_identifier = ' '
def __init__(self):
for (key, value) in self.__class__.__dict__.items():
if isinstance(value, Field):
@ -32,12 +34,10 @@ class Model(object):
f.validate()
def output(self):
return ''.join([field.get_data() for field in self.get_sorted_fields()])
return ''.join([self.record_identifier] + [field.get_data() for field in self.get_sorted_fields()])
def read(self, fp):
for field in self.get_sorted_fields():
field.read(fp)
class TestModel(Model):
field_a = TextField(max_length=20)
field_b = MoneyField(max_length=10)
state = StateField()

View file

@ -1,8 +1,14 @@
from fields import *
import model
from fields import *
__all__ = ['SubmitterRecord', 'EmployerRecord',
'EmployeeWageRecord', 'OptionalEmployeeWageRecord',
'TotalRecord', 'OptionalTotalRecord',
'StateTotalRecord', 'FinalRecord',]
class SubmitterRecord(model.Model):
record_identifier = StaticField(value='ra')
record_identifier = 'RA'
submitter_ein = NumericField(max_length=9)
user_id = TextField(max_length=8)
software_vendor = TextField(max_length=4)
@ -44,7 +50,8 @@ class SubmitterRecord(model.Model):
blank6 = BlankField(max_length=12)
class EmployerRecord(model.Model):
record_identifier = StaticField(value='re')
record_identifier = 'RE'
tax_year = NumericField(max_length=4)
agent_indicator = NumericField(max_length=1)
employer_ein = TextField(max_length=9)
@ -69,7 +76,8 @@ class EmployerRecord(model.Model):
blank2 = BlankField(max_length=291)
class EmployeeWageRecord(model.Model):
record_identifier = StaticField(value='rw')
record_identifier = 'RW'
ssn = NumericField(max_length=9, required=False)
employee_first_name = TextField(max_length=15)
employee_middle_name = TextField(max_length=15)
@ -119,7 +127,8 @@ class EmployeeWageRecord(model.Model):
class OptionalEmployeeWageRecord(model.Model):
record_identifier = StaticField(value='ro')
record_identifier = 'RO'
blank1 = BlankField(max_length=9)
allocated_tips = NumericField(max_length=11)
uncollected_tax_on_tips = NumericField(max_length=11)
@ -145,10 +154,11 @@ class OptionalEmployeeWageRecord(model.Model):
class TotalRecord(model.Model):
record_identifier = StaticField(value='rt')
record_identifier = 'RT'
class OptionalTotalRecord(model.Model):
record_identifier = StaticField(value='ru')
record_identifier = 'RU'
number_of_ro_records = NumericField(max_length=7)
allocated_tips = NumericField(max_length=15)
uncollected_tax_on_tips = NumericField(max_length=15)
@ -173,11 +183,12 @@ class OptionalTotalRecord(model.Model):
class StateTotalRecord(model.Model):
record_identifier = StaticField(value='rv')
record_identifier = 'RV'
supplemental_data = TextField(max_length=510)
class FinalRecord(model.Model):
record_identifier = StaticField(value='rf')
record_identifier = 'RF'
blank1 = BlankField(max_length=5)
number_of_rw_records = NumericField(max_length=9)
blank2 = BlankField(max_length=496)