219 lines
6.6 KiB
Python
219 lines
6.6 KiB
Python
from record import *
|
|
from reader import RecordReader
|
|
|
|
# Version identifier for this module.
VERSION = (0, 2012, 0)

# Names of the record classes that the functions in this module look up on
# the ``record`` module (see ``load`` and ``test``).
RECORD_TYPES = [
    'SubmitterRecord',
    'EmployerRecord',
    'EmployeeWageRecord',
    'OptionalEmployeeWageRecord',
    'TotalRecord',
    'StateWageRecord',
    'OptionalTotalRecord',
    'StateTotalRecord',
    'FinalRecord',
]
|
|
|
|
def test():
    """Instantiate every record type and print the length of its output.

    Each name in RECORD_TYPES is looked up on the ``record`` module and
    called with no arguments.  When a record's ``output()`` raises
    ValidationError, the error message is printed together with the record's
    type and identifier, and the loop moves on to the next type.
    """
    import record
    import model
    from fields import ValidationError

    for type_name in RECORD_TYPES:
        rec = record.__dict__[type_name]()
        try:
            size = len(rec.output())
        except ValidationError as err:
            print("%s %s %s" % (err.msg, type(rec), rec.record_identifier))
        else:
            print("%s %s %s" % (type(rec), rec.record_identifier, size))
|
|
|
|
|
|
def test_dump():
    """Dump three default-constructed records into an in-memory buffer.

    Returns:
        The StringIO object that ``dump`` wrote to.
    """
    import record
    import StringIO

    buf = StringIO.StringIO()
    sample = [
        record.SubmitterRecord(),
        record.EmployerRecord(),
        record.EmployeeWageRecord(),
    ]
    dump(sample, buf)
    return buf
|
|
|
|
|
|
def test_record_order():
    """Run ``validate_record_order`` on a five-record sample sequence.

    The sample is Submitter, Employer, EmployeeWage, Total, Final, each
    constructed with no arguments.
    """
    import record

    sample_classes = (
        record.SubmitterRecord,
        record.EmployerRecord,
        record.EmployeeWageRecord,
        record.TotalRecord,
        record.FinalRecord,
    )
    validate_record_order([cls() for cls in sample_classes])
|
|
|
|
|
|
def test_load(fp):
    """Return whatever ``load`` produces for the file-like object ``fp``."""
    parsed = load(fp)
    return parsed
|
|
|
|
def load(fp):
    """Lazily parse records from a file-like object.

    This is a generator: nothing is read until iteration starts.  Two
    characters are read at a time as a record identifier.  When the
    identifier matches one of the classes named in RECORD_TYPES, an instance
    is created, asked to ``read(fp)`` itself, and yielded.  Identifiers that
    match no known record type are skipped silently.

    Args:
        fp: object with a ``read(size)`` method.  Reading stops at the first
            empty read, so any file-like object works, not only Python 2
            ``StringIO.StringIO`` (the only type providing the ``.len``
            attribute the previous loop condition required).

    Yields:
        Record instances, in the order they appear in ``fp``.
    """
    # Imported under an alias so the loop variable below cannot shadow it.
    import record as record_module

    # Map each record identifier to the class that parses it.
    types = {}
    for type_name in RECORD_TYPES:
        klass = record_module.__dict__[type_name]
        types[klass.record_identifier] = klass

    while True:
        record_ident = fp.read(2)
        if not record_ident:
            # An empty read means end of input.
            break
        if record_ident in types:
            rec = types[record_ident]()
            rec.read(fp)
            yield rec
|
|
|
|
def loads(s):
    """Parse records from the string ``s``.

    The string is wrapped in a StringIO buffer and handed to ``load``; the
    return value is whatever ``load`` returns.
    """
    from StringIO import StringIO

    return load(StringIO(s))
|
|
|
|
|
|
def dump(records, fp):
    """Write each record's ``output()`` to ``fp`` in iteration order.

    Args:
        records: iterable of objects exposing an ``output()`` method.
        fp: object with a ``write`` method.
    """
    rendered = (rec.output() for rec in records)
    for chunk in rendered:
        fp.write(chunk)
|
|
|
|
def dumps(records):
    """Serialize ``records`` to a single string.

    The records are written through ``dump`` into an in-memory StringIO
    buffer, and the buffer's full contents are returned.
    """
    from StringIO import StringIO

    buf = StringIO()
    dump(records, buf)
    return buf.getvalue()
|
|
|
|
def json_dumps(records):
    """Serialize ``records`` to JSON text with a two-space indent.

    Values the stock encoder cannot handle are converted as follows:
      * objects with a callable ``toJSON`` attribute -> result of ``toJSON()``
      * ``decimal.Decimal`` -> string of the value quantized to 0.01
    Anything else is passed to the base encoder's ``default``.
    """
    import json
    import model
    import decimal

    cents = decimal.Decimal('0.01')

    class _RecordEncoder(json.JSONEncoder):

        def default(self, o):
            if hasattr(o, 'toJSON') and callable(o.toJSON):
                return o.toJSON()
            if isinstance(o, decimal.Decimal):
                return str(o.quantize(cents))
            return json.JSONEncoder.default(self, o)

    return json.dumps(records, indent=2, cls=_RecordEncoder)
|
|
|
|
|
|
def json_loads(s, record_classes):
    """Decode JSON text, rebuilding objects tagged with a ``__class__`` key.

    Args:
        s: JSON document as a string.
        record_classes: either a dict mapping a class name to a callable
            that returns a fresh instance, or an iterable of such callables.
            Classes in an iterable are keyed by their own ``__name__``
            (previously the metaclass name was used, so every class
            collapsed onto one key and lookups by tag never matched).
            Non-class items keep the old keying by ``__class__.__name__``.

    Returns:
        The decoded structure.  Floats are parsed as ``decimal.Decimal``.
        Every JSON object carrying ``__class__`` is replaced by
        ``<factory>().fromJSON(obj)``; the factory comes from
        ``record_classes`` first, then from an attribute of that name on the
        ``fields`` module.  Objects whose tag matches neither are returned as
        plain dicts.
    """
    import json
    import decimal
    import inspect

    if not isinstance(record_classes, dict):
        record_classes = dict(
            (x.__name__ if inspect.isclass(x) else x.__class__.__name__, x)
            for x in record_classes
        )

    def object_hook(o):
        if '__class__' in o:
            klass = o['__class__']

            if klass in record_classes:
                return record_classes[klass]().fromJSON(o)

            # Only the fallback lookup needs ``fields``, so import it here.
            import fields

            # NOTE(review): ``klass`` comes straight from the JSON input, so
            # any attribute of ``fields`` with that name gets called.  Confirm
            # the input is trusted or restrict this to known field classes.
            if hasattr(fields, klass):
                return getattr(fields, klass)().fromJSON(o)

        return o

    return json.loads(s, parse_float=decimal.Decimal, object_hook=object_hook)
|
|
|
|
# THIS WAS IN CONTROLLER, BUT UNLESS WE
|
|
# REALLY NEED A CONTROLLER CLASS, IT'S SIMPLER
|
|
# TO JUST KEEP IT IN HERE.
|
|
def validate_required_records(records):
    """Check that every required record type occurs in ``records``.

    A record type is required when its class (looked up by name on the
    ``record`` module) has a truthy ``required`` attribute.  Presence is
    decided by exact class name, so an instance of a subclass does not
    satisfy a requirement for its base class.

    Args:
        records: iterable of record instances.

    Raises:
        ValidationError: naming the first required record type, in
            RECORD_TYPES order, that is missing from ``records``.
    """
    # Import locally like the sibling functions do, rather than relying on
    # ``record`` happening to be bound at module level.
    import record
    from fields import ValidationError

    present = set(rec.__class__.__name__ for rec in records)

    # NOTE(review): the original read ``record.RECORD_TYPES``.  That is still
    # preferred when the record module defines it; otherwise fall back to the
    # list defined in this module.  Confirm which list is intended.
    type_names = getattr(record, 'RECORD_TYPES', RECORD_TYPES)

    for type_name in type_names:
        klass = record.__dict__[type_name]
        if klass.required and klass.__name__ not in present:
            raise ValidationError(
                "Record set missing required record: %s" % klass.__name__)
|
|
|
|
def validate_record_order(records):
    """Check the structural ordering rules of a record sequence.

    Rules enforced, in this order:
      1. The first record is a SubmitterRecord.
      2. The second record is an EmployerRecord.
      3. The last record is a FinalRecord.
      4. Every EmployerRecord is immediately followed by an
         EmployeeWageRecord.
      5. The number of TotalRecords equals the number of EmployerRecords.
      6. The number of OptionalEmployeeWageRecords equals the number of
         OptionalTotalRecords.
      7. There is exactly one FinalRecord.

    Args:
        records: sequence (indexable, sliceable) of record instances.

    Raises:
        ValidationError: on the first rule that fails.  Sequences that are
            empty or too short fail rules 1 and 2 with ValidationError
            rather than IndexError.
    """
    import record
    from fields import ValidationError

    def count_of(klass):
        # Number of records that are instances of ``klass``.
        return sum(1 for rec in records if isinstance(rec, klass))

    # 1st record must be SubmitterRecord
    if not records or not isinstance(records[0], record.SubmitterRecord):
        raise ValidationError("First record must be SubmitterRecord")

    # 2nd record must be EmployerRecord
    if len(records) < 2 or not isinstance(records[1], record.EmployerRecord):
        raise ValidationError("The first record after SubmitterRecord must be an EmployerRecord")

    # FinalRecord - Must be the last record on the file
    if not isinstance(records[-1], record.FinalRecord):
        raise ValidationError("Last record must be a FinalRecord")

    # an EmployeeWageRecord *must* come directly after each EmployerRecord
    for current, following in zip(records, records[1:]):
        if (isinstance(current, record.EmployerRecord)
                and not isinstance(following, record.EmployeeWageRecord)):
            raise ValidationError("All EmployerRecords must be followed by an EmployeeWageRecord")

    num_ro_records = count_of(record.OptionalEmployeeWageRecord)
    num_ru_records = count_of(record.OptionalTotalRecord)
    num_employer_records = count_of(record.EmployerRecord)
    num_total_records = count_of(record.TotalRecord)

    # a TotalRecord is required for each instance of an EmployerRecord
    if num_total_records != num_employer_records:
        raise ValidationError("Number of TotalRecords (%d) does not match number of EmployerRecords (%d)" % (
            num_total_records, num_employer_records))

    # an OptionalTotalRecord is required for each OptionalEmployeeWageRecord
    if num_ro_records != num_ru_records:
        raise ValidationError("Number of OptionalEmployeeWageRecords (%d) does not match number OptionalTotalRecords (%d)" % (
            num_ro_records, num_ru_records))

    # FinalRecord - Must appear only once on each file.
    if count_of(record.FinalRecord) != 1:
        raise ValidationError("Incorrect number of FinalRecords")
|
|
|
|
def validate_records(records):
    """Run both record-set checks: required records first, then ordering.

    Any exception raised by either check propagates to the caller.
    """
    for check in (validate_required_records, validate_record_order):
        check(records)
|
|
|
|
def test_unique_fields():
    """Check that a field value set on one record does not leak to another.

    Two EmployeeWageRecord instances are created; ``employee_first_name`` is
    assigned on the first only.  Both fields are printed (value, field
    object, ``creation_counter``) for inspection.

    Raises:
        ValidationError: if both records report the same first-name value,
            which would mean the field state is shared between instances.
    """
    # Imported locally, as the sibling functions do; this module does not
    # import ValidationError at the top level.
    from fields import ValidationError

    r1 = EmployeeWageRecord()
    r1.employee_first_name.value = "John Johnson"

    # Created after r1 was modified, so a shared field would show r1's value.
    r2 = EmployeeWageRecord()

    for label, rec in (('r1:', r1), ('r2:', r2)):
        field = rec.employee_first_name
        print("%s %s %s %s" % (label, field.value, field, field.creation_counter))

    if r1.employee_first_name.value == r2.employee_first_name.value:
        raise ValidationError("Horrible problem involving shared values across records")
|
|
|
|
#def state_postal_code(state_abbr):
|
|
# import enums
|
|
# return enums.state_postal_numeric[ state_abbr.upper() ]
|
|
|