241 lines
7.5 KiB
Python
241 lines
7.5 KiB
Python
try:
    # collections.Callable existed through Python 3.9; it was removed in
    # Python 3.10, so fall back to typing.Callable (which also supports
    # isinstance() checks).  Catch only ImportError -- the original bare
    # "except:" would also swallow KeyboardInterrupt/SystemExit.
    from collections import Callable
except ImportError:
    from typing import Callable  # Python 3.10+
|
|
|
|
# Package version tuple.  NOTE(review): the middle component (2012) looks
# like a tax year rather than a conventional minor version -- confirm.
VERSION = (0, 2012, 0)

# Names of the record classes handled by this module, in their natural file
# order.  Each name is resolved against the sibling ``record`` module by
# get_record_types() and validate_required_records().
RECORD_TYPES = [
    'SubmitterRecord',
    'EmployerRecord',
    'EmployeeWageRecord',
    'OptionalEmployeeWageRecord',
    'TotalRecord',
    'StateWageRecord',
    'OptionalTotalRecord',
    'StateTotalRecord',
    'FinalRecord'
]
|
|
|
|
|
|
def get_record_types():
    """Return a mapping of record_identifier -> record class.

    Each name in RECORD_TYPES is resolved against the sibling ``record``
    module; the class's ``record_identifier`` bytestring becomes its key.
    """
    from . import record
    classes = (record.__dict__[name] for name in RECORD_TYPES)
    return {klass.record_identifier: klass for klass in classes}
|
|
|
|
|
|
def load(fp, record_types):
    """Parse fixed-width records from a binary file-like object.

    ``fp`` must be opened in binary mode.  ``record_types`` maps names to
    record classes; each class exposes a ``record_identifier`` bytestring
    and a ``read(fp)`` method.  Yields one populated record instance per
    recognized identifier.

    Raises ValueError if the identifiers are not all the same length
    (otherwise we cannot know how many bytes to read ahead).  Unrecognized
    identifiers are skipped silently, consuming their bytes -- this mirrors
    the original behaviour, but NOTE(review): it can misalign the stream;
    confirm it is intended for e1099 data.
    """
    distinct_identifier_lengths = set(
        len(record_types[k].record_identifier) for k in record_types)
    # Explicit check instead of assert(), which is stripped under "python -O".
    if len(distinct_identifier_lengths) != 1:
        raise ValueError("record identifiers must all have the same length")
    ident_length = distinct_identifier_lengths.pop()

    # Add aliases for the record types based on their record_identifier
    # since that's all we have to work with with the e1099 data.
    record_types_by_ident = {}
    for k in record_types:
        record_type = record_types[k]
        record_types_by_ident[record_type.record_identifier] = record_type

    # Parse data into records and yield them.
    while True:
        record_ident = fp.read(ident_length)
        if not record_ident:
            break
        if record_ident in record_types_by_ident:
            record = record_types_by_ident[record_ident]()
            record.read(fp)
            yield record
|
|
|
|
|
|
def loads(s, record_types=None):
    """Parse records from a bytestring ``s``; see ``load``.

    ``record_types`` defaults to ``get_record_types()``.  The default is
    resolved lazily here rather than in the signature: the original
    ``record_types=get_record_types()`` default ran at module-import time,
    importing ``.record`` as a side effect and freezing the mapping once
    for all future calls.
    """
    import io
    if record_types is None:
        record_types = get_record_types()
    fp = io.BytesIO(s)
    return load(fp, record_types)
|
|
|
|
|
|
def dump(fp, records, delim=None):
    """Write each record's fixed-width output to the binary file ``fp``.

    ``delim``, if given, is written after every record; a str delimiter is
    encoded as ASCII first.  An empty delimiter is treated as no delimiter.
    """
    # isinstance() is the idiomatic type check (and also accepts str
    # subclasses, which "type(delim) is str" did not).
    if isinstance(delim, str):
        delim = delim.encode('ascii')
    for r in records:
        fp.write(r.output())
        if delim:
            fp.write(delim)
|
|
|
|
|
|
def dumps(records, delim=None):
    """Serialize ``records`` to a bytestring via ``dump``.

    ``delim`` is passed through unchanged.
    """
    import io
    buffer = io.BytesIO()
    dump(buffer, records, delim=delim)
    return buffer.getvalue()
|
|
|
|
|
|
def json_dumps(records):
    """Serialize ``records`` to a JSON string (2-space indent).

    Objects providing a callable ``toJSON`` hook are serialized via that
    hook; ``bytes`` are decoded as ASCII; ``Decimal`` values are rendered
    as strings quantized to two decimal places (money fields).
    """
    import json
    import decimal

    class JSONEncoder(json.JSONEncoder):

        def default(self, o):
            # Prefer the object's own JSON hook when it provides one.
            # The builtin callable() replaces isinstance(..., Callable),
            # which depended on the fragile module-level Callable import
            # (collections.Callable was removed in Python 3.10).
            toJSON = getattr(o, 'toJSON', None)
            if callable(toJSON):
                return toJSON()

            if type(o) is bytes:
                return o.decode('ascii')

            elif isinstance(o, decimal.Decimal):
                return str(o.quantize(decimal.Decimal('0.01')))

            return super(JSONEncoder, self).default(o)

    return json.dumps(list(records), cls=JSONEncoder, indent=2)
|
|
|
|
|
|
def json_dump(fp, records):
    """Write the JSON serialization of ``records`` to ``fp``.

    NOTE: ``json_dumps`` returns str, so ``fp`` must accept text.
    """
    serialized = json_dumps(records)
    fp.write(serialized)
|
|
|
|
|
|
def json_loads(s, record_types):
    """Deserialize records from a JSON string produced by ``json_dumps``.

    ``record_types`` may be a dict of name -> class or a plain sequence of
    classes.  Objects carrying a ``__class__`` key are rehydrated into the
    matching record class (via ``fromJSON``) or, failing that, into a field
    class from the sibling ``fields`` module.  Floats parse as Decimal.
    """
    import json
    from . import fields
    import decimal

    if not isinstance(record_types, dict):
        record_types = {klass.__name__: klass for klass in record_types}

    def object_hook(o):
        # Plain objects pass through untouched.
        if '__class__' not in o:
            return o
        klass = o['__class__']
        if klass in record_types:
            rec = record_types[klass]()
            rec.fromJSON(o)
            return rec
        if hasattr(fields, klass):
            return getattr(fields, klass)().fromJSON(o)
        return o

    return json.loads(s, parse_float=decimal.Decimal, object_hook=object_hook)
|
|
|
|
def json_load(fp, record_types):
    """Deserialize records from a JSON file-like object; see ``json_loads``."""
    contents = fp.read()
    return json_loads(contents, record_types)
|
|
|
|
def text_dump(fp, records):
    """Write each record's human-readable text form to binary file ``fp``.

    Each record renders itself via ``output(format='text')``; the result is
    ASCII-encoded before writing.
    """
    for rec in records:
        text = rec.output(format='text')
        fp.write(text.encode('ascii'))
|
|
|
|
|
|
def text_dumps(records):
    """Return the text-format serialization of ``records`` as bytes."""
    import io
    buffer = io.BytesIO()
    text_dump(buffer, records)
    return buffer.getvalue()
|
|
|
|
|
|
def text_load(fp, record_classes):
    """Parse records from the text format written by ``text_dump``.

    ``fp`` is a binary file-like object.  A line starting with ``---``
    begins a new record (the record class name follows the dashes, e.g.
    ``--- EmployeeWageRecord``); subsequent ``field: value`` lines set
    fields on that record.  ``record_classes`` may be a dict of
    name -> class or a sequence of classes.

    Returns the list of populated record instances.
    """
    records = []
    current_record = None

    if not isinstance(record_classes, dict):
        record_classes = dict([(x.__name__, x) for x in record_classes])

    while True:
        line = fp.readline().decode('ascii')
        if not line:
            break
        if line.startswith('---'):
            # Header: strip the leading dashes and whitespace to get the
            # record class name.
            record_name = line.strip('---').strip()
            current_record = record_classes[record_name]()
            records.append(current_record)
        elif ':' in line:
            # Split on the FIRST colon only -- the original split(':')
            # raised ValueError whenever a field value itself contained
            # a colon (e.g. times, suffixed names).
            field, value = [x.strip() for x in line.split(':', 1)]
            current_record.set_field_value(field, value)
    return records
|
|
|
|
def text_loads(s, record_classes):
    """Parse text-format records from bytestring ``s``; see ``text_load``."""
    import io
    buffer = io.BytesIO(s)
    return text_load(buffer, record_classes)
|
|
|
|
|
|
# THIS WAS IN CONTROLLER, BUT UNLESS WE
# REALLY NEED A CONTROLLER CLASS, IT'S SIMPLER
# TO JUST KEEP IT IN HERE.
def validate_required_records(records):
    """Check that every required record type appears in ``records``.

    Raises ValidationError naming the first required record class (in
    RECORD_TYPES order) that is missing.

    Fix: the original referenced ``record.RECORD_TYPES`` without ever
    importing ``record``, so it always raised NameError.  We now import the
    module locally and resolve names from the module-level RECORD_TYPES
    list, matching the pattern used by get_record_types().
    """
    from . import record
    from .fields import ValidationError

    present = set(rec.__class__.__name__ for rec in records)

    # Walk the declared record classes in order; fail on the first
    # required one that is absent from the record set.
    for name in RECORD_TYPES:
        klass = record.__dict__[name]
        if klass.required and klass.__name__ not in present:
            raise ValidationError(
                "Record set missing required record: %s" % klass.__name__)
|
|
|
|
|
|
def validate_record_order(records):
    """Validate structural ordering rules for a full record set.

    Raises ValidationError on the first violated rule.  Assumes ``records``
    has at least two elements (records[0] and records[1] are indexed
    unconditionally) -- callers should validate presence first.
    """
    from . import record
    from .fields import ValidationError

    # Rule: the 1st record must be a SubmitterRecord.
    if not isinstance(records[0], record.SubmitterRecord):
        raise ValidationError("First record must be SubmitterRecord")

    # Rule: the 2nd record must be an EmployerRecord.
    # NOTE(review): the check is for EmployerRecord but the message says
    # "EmployeeRecord" -- looks like a typo in the message; confirm before
    # changing it in case callers match on the text.
    if not isinstance(records[1], record.EmployerRecord):
        raise ValidationError("The first record after SubmitterRecord must be an EmployeeRecord")

    # Rule: FinalRecord must be the last record in the file.
    if not isinstance(records[-1], record.FinalRecord):
        raise ValidationError("Last record must be a FinalRecord")

    # Rule: every EmployerRecord must be immediately followed by an
    # EmployeeWageRecord.  records[i+1] cannot go out of bounds here: if an
    # EmployerRecord were last, the FinalRecord check above already raised.
    for i in range(len(records)):
        if isinstance(records[i], record.EmployerRecord):
            if not isinstance(records[i+1], record.EmployeeWageRecord):
                raise ValidationError("All EmployerRecords must be followed by an EmployeeWageRecord")

    # Tally the record kinds whose counts must balance.
    num_ro_records = len([x for x in records if isinstance(x, record.OptionalEmployeeWageRecord)])
    num_ru_records = len([x for x in records if isinstance(x, record.OptionalTotalRecord)])
    num_employer_records = len([x for x in records if isinstance(x, record.EmployerRecord)])
    num_total_records = len([x for x in records if isinstance(x, record.TotalRecord)])

    # Rule: one TotalRecord per EmployerRecord.
    if num_total_records != num_employer_records:
        raise ValidationError("Number of TotalRecords (%d) does not match number of EmployeeRecords (%d)" % (
            num_total_records, num_employer_records))

    # Rule: one OptionalTotalRecord per OptionalEmployeeWageRecord.
    if num_ro_records != num_ru_records:
        raise ValidationError("Number of OptionalEmployeeWageRecords (%d) does not match number OptionalTotalRecords (%d)" % (
            num_ro_records, num_ru_records))

    # Rule: exactly one FinalRecord per file.
    if len([x for x in records if isinstance(x, record.FinalRecord)]) != 1:
        raise ValidationError("Incorrect number of FinalRecords")
|
|
|
|
def validate_records(records):
    """Run all record-set validations, raising ValidationError on failure."""
    for check in (validate_required_records, validate_record_order):
        check(records)
|
|
|
|
def test_unique_fields():
    """Smoke-test that field values are not shared across record instances.

    Sets a field on one EmployeeWageRecord and fails if a second, fresh
    instance reflects the same value (which would mean the class-level
    field objects leak state between instances).

    Fix: the original referenced EmployeeWageRecord and ValidationError
    without any import in scope, so it always raised NameError.  Import
    them locally, matching the file's convention elsewhere.
    """
    from .record import EmployeeWageRecord
    from .fields import ValidationError

    r1 = EmployeeWageRecord()
    r1.employee_first_name.value = "John Johnson"

    r2 = EmployeeWageRecord()
    print('r1:', r1.employee_first_name.value, r1.employee_first_name, r1.employee_first_name.creation_counter)
    print('r2:', r2.employee_first_name.value, r2.employee_first_name, r2.employee_first_name.creation_counter)

    if r1.employee_first_name.value == r2.employee_first_name.value:
        raise ValidationError("Horrible problem involving shared values across records")
|