add format interchange functions, add tests, fix stuff
This commit is contained in:
parent
6af5067fca
commit
8f86f76167
7 changed files with 298 additions and 146 deletions
|
@ -1,6 +1,4 @@
|
|||
import collections
from collections.abc import Callable

from .record import *
from .reader import RecordReader
VERSION = (0, 2012, 0)
|
||||
|
||||
|
@ -14,77 +12,55 @@ RECORD_TYPES = [
|
|||
'OptionalTotalRecord',
|
||||
'StateTotalRecord',
|
||||
'FinalRecord'
|
||||
]
|
||||
|
||||
def test():
    """Smoke-test every record type: instantiate each one and try to render it.

    Records that fail validation are reported and skipped so the remaining
    types still get exercised.
    """
    from . import record, model
    from .fields import ValidationError

    for record_name in RECORD_TYPES:
        instance = record.__dict__[record_name]()
        try:
            rendered_length = len(instance.output())
        except ValidationError as e:
            print(e.msg, type(instance), instance.record_identifier)
            continue

        print(type(instance), instance.record_identifier, rendered_length)
def test_dump():
    """Exercise dump() with a few freshly-constructed records.

    Returns the io.BytesIO buffer the records were written into so callers
    can inspect the serialized output.
    """
    from . import record  # relative import: this module lives inside a package
    import io

    records = [
        record.SubmitterRecord(),
        record.EmployerRecord(),
        record.EmployeeWageRecord(),
    ]
    out = io.BytesIO()
    # dump() takes the destination file object first: dump(fp, records, delim).
    dump(out, records, None)
    return out
def test_record_order():
    """Build a minimal, correctly-ordered record sequence and validate it."""
    from . import record

    ordered_records = [
        record.SubmitterRecord(),
        record.EmployerRecord(),
        record.EmployeeWageRecord(),
        record.TotalRecord(),
        record.FinalRecord(),
    ]
    validate_record_order(ordered_records)
def test_load(fp):
    """Parse records from *fp* using the default record-type mapping.

    Returns the generator produced by load().  load() requires an explicit
    record_types mapping, so supply the module default here.
    """
    return load(fp, get_record_types())
# BUILD LIST OF RECORD TYPES
def get_record_types():
    """Return a dict mapping each record_identifier to its record class.

    Covers every class named in RECORD_TYPES.
    """
    from . import record
    return {
        record.__dict__[name].record_identifier: record.__dict__[name]
        for name in RECORD_TYPES
    }
def load(fp, record_types):
    """Lazily parse records from the binary file object *fp*.

    *record_types* maps arbitrary keys to record classes; records are matched
    by each class's ``record_identifier``, all of which must share a single
    length so the stream can be framed.

    Yields one instantiated record per recognized identifier.  Unrecognized
    identifiers are silently skipped — the stream just advances by the
    identifier length — so a corrupt stream degrades to fewer records rather
    than an error.

    Raises ValueError if the identifier lengths are not all equal.
    """
    identifier_lengths = {len(record_types[k].record_identifier) for k in record_types}
    # Mixed identifier lengths would make it impossible to know how many
    # bytes to read per probe; fail loudly instead of asserting (asserts are
    # stripped under ``python -O``).
    if len(identifier_lengths) != 1:
        raise ValueError("record identifiers must all have the same length")
    ident_length = identifier_lengths.pop()

    # Index the record types by record_identifier, since that's all we have
    # to work with in the e1099 data.
    record_types_by_ident = {
        rt.record_identifier: rt for rt in record_types.values()
    }

    # PARSE DATA INTO RECORDS AND YIELD THEM
    while True:
        record_ident = fp.read(ident_length)
        if not record_ident:
            break  # EOF
        if record_ident in record_types_by_ident:
            rec = record_types_by_ident[record_ident]()
            rec.read(fp)  # record consumes its own remaining bytes
            yield rec
def loads(s, record_types=None):
    """Parse records from the bytes object *s*; see load().

    *record_types* defaults to get_record_types(), resolved at call time.
    (A def-time default of ``get_record_types()`` would be evaluated once at
    import, triggering the ``.record`` import as a side effect and freezing
    a possibly-stale mapping.)
    """
    import io
    if record_types is None:
        record_types = get_record_types()
    fp = io.BytesIO(s)
    return load(fp, record_types)
def dump(fp, records, delim=None):
    """Write each record's serialized output to the file object *fp*.

    If *delim* is given, it is written after every record (including the
    last one).
    """
    write = fp.write
    for rec in records:
        write(rec.output())
        if delim:
            write(delim)
def dumps(records, delim=None):
|
||||
import io
|
||||
fp = io.BytesIO()
|
||||
|
@ -92,15 +68,15 @@ def dumps(records, delim=None):
|
|||
fp.seek(0)
|
||||
return fp.read()
|
||||
|
||||
|
||||
def json_dumps(records):
|
||||
import json
|
||||
from . import model
|
||||
import decimal
|
||||
|
||||
class JSONEncoder(json.JSONEncoder):
|
||||
|
||||
def default(self, o):
|
||||
if hasattr(o, 'toJSON') and isinstance(getattr(o, 'toJSON'), collections.Callable):
|
||||
if hasattr(o, 'toJSON') and isinstance(getattr(o, 'toJSON'), Callable):
|
||||
return o.toJSON()
|
||||
|
||||
if type(o) is bytes:
|
||||
|
@ -111,37 +87,76 @@ def json_dumps(records):
|
|||
|
||||
return super(JSONEncoder, self).default(o)
|
||||
|
||||
return json.dumps(records, cls=JSONEncoder, indent=2)
|
||||
return json.dumps(list(records), cls=JSONEncoder, indent=2)
|
||||
|
||||
|
||||
def json_dump(fp, records):
    """Serialize *records* to JSON text (via json_dumps) and write it to *fp*."""
    fp.write(json_dumps(records))
def json_loads(s, record_types):
    """Deserialize records from the JSON string *s*.

    *record_types* is either a dict mapping class names to record classes,
    or a plain iterable of record classes.  Any JSON object carrying a
    ``'__class__'`` key is rebuilt as the matching record type, or — failing
    that — as a type looked up by name on the ``fields`` module; all other
    objects pass through unchanged.  Floats are parsed as decimal.Decimal
    to avoid binary-float precision loss.
    """
    import json
    from . import fields
    import decimal

    if not isinstance(record_types, dict):
        record_types = {x.__name__: x for x in record_types}

    def object_hook(o):
        # Reverse of the JSONEncoder used by json_dumps: dispatch on the
        # embedded '__class__' tag when present.
        if '__class__' in o:
            klass = o['__class__']
            if klass in record_types:
                record = record_types[klass]()
                record.fromJSON(o)
                return record
            elif hasattr(fields, klass):
                return getattr(fields, klass)().fromJSON(o)
        return o

    return json.loads(s, parse_float=decimal.Decimal, object_hook=object_hook)
def json_load(fp, record_types):
    """Read all of *fp* and deserialize it via json_loads()."""
    return json_loads(fp.read(), record_types)
def text_dump(fp, records):
    """Render each record in its human-readable text format and write the
    ASCII-encoded result to the binary file object *fp*."""
    for rec in records:
        text = rec.output(format='text')
        fp.write(text.encode('ascii'))
def text_dumps(records):
    """Return the text-format serialization of *records* as a bytes object."""
    import io
    buf = io.BytesIO()
    text_dump(buf, records)
    # getvalue() returns the full buffer contents regardless of position.
    return buf.getvalue()
def text_load(fp, record_classes):
    """Parse records from the text format written by text_dump().

    *fp* is a binary file object containing ASCII lines.  A line starting
    with ``---`` names the record class to instantiate; subsequent
    ``field: value`` lines set fields on the current record.
    *record_classes* is either a dict mapping class names to classes, or an
    iterable of classes.

    Returns the list of parsed records.
    """
    if not isinstance(record_classes, dict):
        record_classes = {x.__name__: x for x in record_classes}

    records = []
    current_record = None

    while True:
        line = fp.readline().decode('ascii')
        if not line:
            break  # EOF
        if line.startswith('---'):
            # Header like '--- RecordName ---' (or '---RecordName').
            # A plain strip('---') leaves a trailing '---' in place because
            # the char-set strip stops at the newline; peel whitespace and
            # dashes from both ends instead.
            record_name = line.strip().strip('-').strip()
            current_record = record_classes[record_name]()
            records.append(current_record)
        elif ':' in line:
            # Split only on the FIRST ':' so values may themselves contain
            # colons (times, URLs, ...).
            field, value = [part.strip() for part in line.split(':', 1)]
            current_record.set_field_value(field, value)
    return records
def text_loads(s, record_classes):
    """Parse records from the bytes object *s*; see text_load()."""
    import io
    return text_load(io.BytesIO(s), record_classes)
# THIS WAS IN CONTROLLER, BUT UNLESS WE
|
||||
# REALLY NEED A CONTROLLER CLASS, IT'S SIMPLER
|
||||
# TO JUST KEEP IT IN HERE.
|
||||
|
@ -153,7 +168,7 @@ def validate_required_records(records):
|
|||
klass = record.__dict__[r]
|
||||
if klass.required:
|
||||
req_types.append(klass.__name__)
|
||||
|
||||
|
||||
while req_types:
|
||||
req = req_types[0]
|
||||
if req not in types:
|
||||
|
@ -162,10 +177,11 @@ def validate_required_records(records):
|
|||
else:
|
||||
req_types.remove(req)
|
||||
|
||||
|
||||
def validate_record_order(records):
|
||||
from . import record
|
||||
from .fields import ValidationError
|
||||
|
||||
|
||||
# 1st record must be SubmitterRecord
|
||||
if not isinstance(records[0], record.SubmitterRecord):
|
||||
raise ValidationError("First record must be SubmitterRecord")
|
||||
|
@ -211,15 +227,10 @@ def test_unique_fields():
|
|||
r1 = EmployeeWageRecord()
|
||||
|
||||
r1.employee_first_name.value = "John Johnson"
|
||||
|
||||
|
||||
r2 = EmployeeWageRecord()
|
||||
print('r1:', r1.employee_first_name.value, r1.employee_first_name, r1.employee_first_name.creation_counter)
|
||||
print('r2:', r2.employee_first_name.value, r2.employee_first_name, r2.employee_first_name.creation_counter)
|
||||
|
||||
|
||||
if r1.employee_first_name.value == r2.employee_first_name.value:
|
||||
raise ValidationError("Horrible problem involving shared values across records")
|
||||
|
||||
#def state_postal_code(state_abbr):
|
||||
# import enums
|
||||
# return enums.state_postal_numeric[ state_abbr.upper() ]
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue