Added load/dump methods which work similarly to those found in
simplejson. Tests seem to work so far. Still need to figure out how to get data into the records in some easy way.
This commit is contained in:
parent
a32feb79ed
commit
068f1bbae4
4 changed files with 118 additions and 34 deletions
60
__init__.py
60
__init__.py
|
@ -1,13 +1,4 @@
|
||||||
"""
|
from record import *
|
||||||
from record import SubmitterRecord
|
|
||||||
from record import EmployerRecord
|
|
||||||
from record import EmployeeWageRecord
|
|
||||||
from record import OptionalEmployeeWageRecord
|
|
||||||
from record import TotalRecord
|
|
||||||
from record import OptionalTotalRecord
|
|
||||||
from record import StateTotalRecord
|
|
||||||
from record import FinalRecord
|
|
||||||
"""
|
|
||||||
|
|
||||||
RECORD_TYPES = [
|
RECORD_TYPES = [
|
||||||
'SubmitterRecord',
|
'SubmitterRecord',
|
||||||
|
@ -20,8 +11,6 @@ RECORD_TYPES = [
|
||||||
'FinalRecord',
|
'FinalRecord',
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test():
|
def test():
|
||||||
import record, model
|
import record, model
|
||||||
for rname in RECORD_TYPES:
|
for rname in RECORD_TYPES:
|
||||||
|
@ -29,6 +18,53 @@ def test():
|
||||||
print type(inst), len(inst.output())
|
print type(inst), len(inst.output())
|
||||||
|
|
||||||
|
|
||||||
|
def test_dump():
|
||||||
|
import record, StringIO
|
||||||
|
records = [
|
||||||
|
record.SubmitterRecord(),
|
||||||
|
record.EmployerRecord(),
|
||||||
|
record.EmployeeWageRecord(),
|
||||||
|
]
|
||||||
|
out = StringIO.StringIO()
|
||||||
|
dump(records, out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def test_load(fp):
|
||||||
|
return load(fp)
|
||||||
|
|
||||||
|
def load(fp):
|
||||||
|
# BUILD LIST OF RECORD TYPES
|
||||||
|
import record
|
||||||
|
types = {}
|
||||||
|
for r in RECORD_TYPES:
|
||||||
|
klass = record.__dict__[r]
|
||||||
|
types[klass.record_identifier] = klass
|
||||||
|
|
||||||
|
# PARSE DATA INTO RECORDS AND YIELD THEM
|
||||||
|
while fp.tell() < fp.len:
|
||||||
|
record_ident = fp.read(2)
|
||||||
|
if record_ident in types:
|
||||||
|
record = types[record_ident]()
|
||||||
|
record.read(fp)
|
||||||
|
yield record
|
||||||
|
|
||||||
|
|
||||||
|
def loads(s):
|
||||||
|
import StringIO
|
||||||
|
fp = StringIO.StringIO(s)
|
||||||
|
return load(fp)
|
||||||
|
|
||||||
|
|
||||||
|
def dump(records, fp):
|
||||||
|
for r in records:
|
||||||
|
fp.write(r.output())
|
||||||
|
|
||||||
|
|
||||||
|
def dumps(records):
|
||||||
|
import StringIO
|
||||||
|
fp = StringIO.StringIO()
|
||||||
|
dump(records, fp)
|
||||||
|
fp.seek(0)
|
||||||
|
return fp.read()
|
||||||
|
|
||||||
|
|
47
fields.py
47
fields.py
|
@ -1,3 +1,4 @@
|
||||||
|
import decimal
|
||||||
|
|
||||||
class ValidationError(Exception):
|
class ValidationError(Exception):
|
||||||
pass
|
pass
|
||||||
|
@ -15,10 +16,10 @@ class Field(object):
|
||||||
Field.creation_counter += 1
|
Field.creation_counter += 1
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
raise NotImplemented()
|
raise NotImplemented
|
||||||
|
|
||||||
def get_data(self):
|
def get_data(self):
|
||||||
raise NotImplemented()
|
raise NotImplemented
|
||||||
|
|
||||||
def __setvalue(self, value):
|
def __setvalue(self, value):
|
||||||
self._value = value
|
self._value = value
|
||||||
|
@ -28,8 +29,16 @@ class Field(object):
|
||||||
|
|
||||||
value = property(__getvalue, __setvalue)
|
value = property(__getvalue, __setvalue)
|
||||||
|
|
||||||
def __repr__(self):
|
def read(self, fp):
|
||||||
return self.name
|
if fp.tell() + self.max_length <= fp.len:
|
||||||
|
data = fp.read(self.max_length)
|
||||||
|
print self, self.max_length, data
|
||||||
|
return self.parse(data)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def parse(self, s):
|
||||||
|
self.value = s.strip()
|
||||||
|
|
||||||
|
|
||||||
class TextField(Field):
|
class TextField(Field):
|
||||||
def validate(self):
|
def validate(self):
|
||||||
|
@ -44,14 +53,17 @@ class TextField(Field):
|
||||||
value = value.upper()
|
value = value.upper()
|
||||||
return value.ljust(self.max_length).encode('ascii')
|
return value.ljust(self.max_length).encode('ascii')
|
||||||
|
|
||||||
|
|
||||||
class StateField(TextField):
|
class StateField(TextField):
|
||||||
def __init__(self, name=None, required=True):
|
def __init__(self, name=None, required=True):
|
||||||
return super(StateField, self).__init__(name=name, max_length=2, required=required)
|
return super(StateField, self).__init__(name=name, max_length=2, required=required)
|
||||||
|
|
||||||
|
|
||||||
class EmailField(TextField):
|
class EmailField(TextField):
|
||||||
def __init__(self, name=None, required=True, max_length=None):
|
def __init__(self, name=None, required=True, max_length=None):
|
||||||
return super(EmailField, self).__init__(name=name, max_length=max_length,
|
return super(EmailField, self).__init__(name=name, max_length=max_length,
|
||||||
required=required, uppercase=False)
|
required=required, uppercase=False)
|
||||||
|
|
||||||
class NumericField(TextField):
|
class NumericField(TextField):
|
||||||
def validate(self):
|
def validate(self):
|
||||||
super(NumericField, self).validate()
|
super(NumericField, self).validate()
|
||||||
|
@ -59,24 +71,46 @@ class NumericField(TextField):
|
||||||
int(self.value)
|
int(self.value)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise ValidationError("field contains non-numeric characters")
|
raise ValidationError("field contains non-numeric characters")
|
||||||
|
|
||||||
|
def get_data(self):
|
||||||
|
value = self.value or ""
|
||||||
|
return value.zfill(self.max_length)
|
||||||
|
|
||||||
|
def parse(self, s):
|
||||||
|
self.value = int(s)
|
||||||
|
|
||||||
|
|
||||||
class StaticField(TextField):
|
class StaticField(TextField):
|
||||||
def __init__(self, name=None, required=True, value=None):
|
def __init__(self, name=None, required=True, value=None):
|
||||||
super(StaticField, self).__init__(name=name, required=required,
|
super(StaticField, self).__init__(name=name, required=required,
|
||||||
max_length=len(value))
|
max_length=len(value))
|
||||||
self._value = value
|
self._value = value
|
||||||
|
|
||||||
|
def parse(self, s):
|
||||||
|
print 'STATIC', self.max_length, s, len(s)
|
||||||
|
pass
|
||||||
|
|
||||||
class BlankField(TextField):
|
class BlankField(TextField):
|
||||||
def get_data(self):
|
def get_data(self):
|
||||||
return " " * self.max_length
|
return " " * self.max_length
|
||||||
|
|
||||||
|
def parse(self, s):
|
||||||
|
pass
|
||||||
|
|
||||||
class BooleanField(Field):
|
class BooleanField(Field):
|
||||||
|
def __init__(self, name=None, required=True, value=None):
|
||||||
|
super(BooleanField, self).__init__(name=name, required=required, max_length=1)
|
||||||
|
self._value = value
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_data(self):
|
def get_data(self):
|
||||||
return '1' if self._value else '0'
|
return '1' if self._value else '0'
|
||||||
|
|
||||||
|
def parse(self, s):
|
||||||
|
self.value = (s == '1')
|
||||||
|
|
||||||
class MoneyField(Field):
|
class MoneyField(Field):
|
||||||
def validate(self):
|
def validate(self):
|
||||||
if self.value == None and self.required:
|
if self.value == None and self.required:
|
||||||
|
@ -87,3 +121,6 @@ class MoneyField(Field):
|
||||||
def get_data(self):
|
def get_data(self):
|
||||||
return str(int((self.value or 0)*100)).encode('ascii').zfill(self.max_length)
|
return str(int((self.value or 0)*100)).encode('ascii').zfill(self.max_length)
|
||||||
|
|
||||||
|
def parse(self, s):
|
||||||
|
self.value = decimal.Decimal(s) * decimal.Decimal('0.01')
|
||||||
|
|
||||||
|
|
16
model.py
16
model.py
|
@ -1,6 +1,8 @@
|
||||||
from fields import Field, TextField, MoneyField, StateField
|
from fields import Field
|
||||||
|
|
||||||
|
class Model(object):
|
||||||
|
record_identifier = ' '
|
||||||
|
|
||||||
class Model(object):
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
for (key, value) in self.__class__.__dict__.items():
|
for (key, value) in self.__class__.__dict__.items():
|
||||||
if isinstance(value, Field):
|
if isinstance(value, Field):
|
||||||
|
@ -32,12 +34,10 @@ class Model(object):
|
||||||
f.validate()
|
f.validate()
|
||||||
|
|
||||||
def output(self):
|
def output(self):
|
||||||
return ''.join([field.get_data() for field in self.get_sorted_fields()])
|
return ''.join([self.record_identifier] + [field.get_data() for field in self.get_sorted_fields()])
|
||||||
|
|
||||||
|
def read(self, fp):
|
||||||
|
for field in self.get_sorted_fields():
|
||||||
|
field.read(fp)
|
||||||
|
|
||||||
|
|
||||||
class TestModel(Model):
|
|
||||||
field_a = TextField(max_length=20)
|
|
||||||
field_b = MoneyField(max_length=10)
|
|
||||||
state = StateField()
|
|
||||||
|
|
||||||
|
|
29
record.py
29
record.py
|
@ -1,8 +1,14 @@
|
||||||
from fields import *
|
|
||||||
import model
|
import model
|
||||||
|
from fields import *
|
||||||
|
|
||||||
|
__all__ = ['SubmitterRecord', 'EmployerRecord',
|
||||||
|
'EmployeeWageRecord', 'OptionalEmployeeWageRecord',
|
||||||
|
'TotalRecord', 'OptionalTotalRecord',
|
||||||
|
'StateTotalRecord', 'FinalRecord',]
|
||||||
|
|
||||||
class SubmitterRecord(model.Model):
|
class SubmitterRecord(model.Model):
|
||||||
record_identifier = StaticField(value='ra')
|
record_identifier = 'RA'
|
||||||
|
|
||||||
submitter_ein = NumericField(max_length=9)
|
submitter_ein = NumericField(max_length=9)
|
||||||
user_id = TextField(max_length=8)
|
user_id = TextField(max_length=8)
|
||||||
software_vendor = TextField(max_length=4)
|
software_vendor = TextField(max_length=4)
|
||||||
|
@ -44,7 +50,8 @@ class SubmitterRecord(model.Model):
|
||||||
blank6 = BlankField(max_length=12)
|
blank6 = BlankField(max_length=12)
|
||||||
|
|
||||||
class EmployerRecord(model.Model):
|
class EmployerRecord(model.Model):
|
||||||
record_identifier = StaticField(value='re')
|
record_identifier = 'RE'
|
||||||
|
|
||||||
tax_year = NumericField(max_length=4)
|
tax_year = NumericField(max_length=4)
|
||||||
agent_indicator = NumericField(max_length=1)
|
agent_indicator = NumericField(max_length=1)
|
||||||
employer_ein = TextField(max_length=9)
|
employer_ein = TextField(max_length=9)
|
||||||
|
@ -69,7 +76,8 @@ class EmployerRecord(model.Model):
|
||||||
blank2 = BlankField(max_length=291)
|
blank2 = BlankField(max_length=291)
|
||||||
|
|
||||||
class EmployeeWageRecord(model.Model):
|
class EmployeeWageRecord(model.Model):
|
||||||
record_identifier = StaticField(value='rw')
|
record_identifier = 'RW'
|
||||||
|
|
||||||
ssn = NumericField(max_length=9, required=False)
|
ssn = NumericField(max_length=9, required=False)
|
||||||
employee_first_name = TextField(max_length=15)
|
employee_first_name = TextField(max_length=15)
|
||||||
employee_middle_name = TextField(max_length=15)
|
employee_middle_name = TextField(max_length=15)
|
||||||
|
@ -119,7 +127,8 @@ class EmployeeWageRecord(model.Model):
|
||||||
|
|
||||||
|
|
||||||
class OptionalEmployeeWageRecord(model.Model):
|
class OptionalEmployeeWageRecord(model.Model):
|
||||||
record_identifier = StaticField(value='ro')
|
record_identifier = 'RO'
|
||||||
|
|
||||||
blank1 = BlankField(max_length=9)
|
blank1 = BlankField(max_length=9)
|
||||||
allocated_tips = NumericField(max_length=11)
|
allocated_tips = NumericField(max_length=11)
|
||||||
uncollected_tax_on_tips = NumericField(max_length=11)
|
uncollected_tax_on_tips = NumericField(max_length=11)
|
||||||
|
@ -145,10 +154,11 @@ class OptionalEmployeeWageRecord(model.Model):
|
||||||
|
|
||||||
|
|
||||||
class TotalRecord(model.Model):
|
class TotalRecord(model.Model):
|
||||||
record_identifier = StaticField(value='rt')
|
record_identifier = 'RT'
|
||||||
|
|
||||||
class OptionalTotalRecord(model.Model):
|
class OptionalTotalRecord(model.Model):
|
||||||
record_identifier = StaticField(value='ru')
|
record_identifier = 'RU'
|
||||||
|
|
||||||
number_of_ro_records = NumericField(max_length=7)
|
number_of_ro_records = NumericField(max_length=7)
|
||||||
allocated_tips = NumericField(max_length=15)
|
allocated_tips = NumericField(max_length=15)
|
||||||
uncollected_tax_on_tips = NumericField(max_length=15)
|
uncollected_tax_on_tips = NumericField(max_length=15)
|
||||||
|
@ -173,11 +183,12 @@ class OptionalTotalRecord(model.Model):
|
||||||
|
|
||||||
|
|
||||||
class StateTotalRecord(model.Model):
|
class StateTotalRecord(model.Model):
|
||||||
record_identifier = StaticField(value='rv')
|
record_identifier = 'RV'
|
||||||
supplemental_data = TextField(max_length=510)
|
supplemental_data = TextField(max_length=510)
|
||||||
|
|
||||||
class FinalRecord(model.Model):
|
class FinalRecord(model.Model):
|
||||||
record_identifier = StaticField(value='rf')
|
record_identifier = 'RF'
|
||||||
|
|
||||||
blank1 = BlankField(max_length=5)
|
blank1 = BlankField(max_length=5)
|
||||||
number_of_rw_records = NumericField(max_length=9)
|
number_of_rw_records = NumericField(max_length=9)
|
||||||
blank2 = BlankField(max_length=496)
|
blank2 = BlankField(max_length=496)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue