pyaccuwage/pyaccuwage/__init__.py

236 lines
7.4 KiB
Python

try:
    from collections.abc import Callable
except ImportError:  # Python 2: the ABCs live directly in ``collections``
    from collections import Callable
# Package version tuple.
VERSION = (0, 2012, 0)

# Names of the record classes defined in the ``record`` submodule.
# get_record_types() resolves each name, in this order, to its class.
RECORD_TYPES = [
    'SubmitterRecord',
    'EmployerRecord',
    'EmployeeWageRecord',
    'OptionalEmployeeWageRecord',
    'TotalRecord',
    'StateWageRecord',
    'OptionalTotalRecord',
    'StateTotalRecord',
    'FinalRecord',
]
def get_record_types():
    """Return a dict mapping each record class's ``record_identifier`` to the class.

    Classes are looked up by name in the ``record`` submodule, walking
    RECORD_TYPES in order; if two classes share an identifier the later one wins.

    Raises:
        KeyError: if a name in RECORD_TYPES is not defined in ``record``.
    """
    from . import record
    namespace = vars(record)
    classes = (namespace[name] for name in RECORD_TYPES)
    return {cls.record_identifier: cls for cls in classes}
def load(fp, record_types):
    """Lazily parse fixed-width records from the file-like object *fp*.

    Repeatedly reads a record identifier (all identifiers must have the same
    length) and, when it names a known record type, instantiates that class
    and lets its ``read(fp)`` consume the rest of the record. Chunks that do
    not match any identifier are skipped.

    Args:
        fp: readable file-like object; reading stops when ``fp.read`` returns
            an empty value.
        record_types: mapping whose values are record classes. Each class must
            expose ``record_identifier``; instances must provide ``read(fp)``.

    Yields:
        One record instance per recognised identifier, in file order.

    Raises:
        ValueError: (on first iteration) if the record identifiers do not all
            have the same length, including when *record_types* is empty.
    """
    classes = [record_types[key] for key in record_types]

    ident_lengths = {len(cls.record_identifier) for cls in classes}
    # An explicit raise instead of ``assert``: assertions are stripped under
    # ``python -O``, after which an arbitrary length would be chosen.
    if len(ident_lengths) != 1:
        raise ValueError(
            "All record types must share one record_identifier length; got %s"
            % sorted(ident_lengths))
    ident_length = ident_lengths.pop()

    # Index the classes by identifier, since that is all the stream gives us.
    by_ident = {cls.record_identifier: cls for cls in classes}

    while True:
        ident = fp.read(ident_length)
        if not ident:
            break
        record_class = by_ident.get(ident)
        if record_class is None:
            # Unrecognised chunk: skipped. Presumably this lets separators that
            # are exactly ident_length bytes long (e.g. CR/LF) be ignored.
            continue
        record = record_class()
        record.read(fp)
        yield record
def loads(s, record_types=None):
    """Lazily parse fixed-width records from the bytes *s*.

    Args:
        s: bytes-like object holding the record data.
        record_types: mapping of record classes as accepted by load(). When
            omitted, the result of get_record_types() is used.

    Returns:
        The generator produced by load().
    """
    import io
    # Resolved per call rather than in the signature: a default expression
    # runs at import time (importing the ``record`` submodule as a side effect
    # of importing this package) and would share one dict across all calls.
    if record_types is None:
        record_types = get_record_types()
    return load(io.BytesIO(s), record_types)
def dump(fp, records, delim=None):
    """Write every record's ``output()`` to *fp*.

    Args:
        fp: writable file-like object.
        records: iterable of records.
        delim: optional separator; when truthy it is written after each record.
    """
    emit = fp.write
    for rec in records:
        emit(rec.output())
        if delim:
            emit(delim)
def dumps(records, delim=None):
    """Return the concatenated ``output()`` of *records* as bytes.

    Args:
        records: iterable of records.
        delim: optional separator written after each record (see dump()).

    Returns:
        bytes containing everything dump() wrote.
    """
    import io
    buf = io.BytesIO()
    # dump() takes the destination first and the records second; passing them
    # the other way round iterates the empty buffer and yields b''.
    dump(buf, records, delim=delim)
    return buf.getvalue()
def json_dumps(records):
    """Serialize *records* to an indented (2-space) JSON string.

    Non-JSON values are converted as follows:
      * objects with a callable ``toJSON`` attribute -> the result of ``toJSON()``;
      * ``bytes`` -> decoded as ASCII;
      * ``decimal.Decimal`` -> string quantized to two decimal places.
    Anything else raises ``TypeError`` as with the standard encoder.

    Args:
        records: iterable of records (materialized into a list).

    Returns:
        The JSON text.
    """
    import json
    import decimal

    cents = decimal.Decimal('0.01')

    class JSONEncoder(json.JSONEncoder):
        def default(self, o):
            # Builtin callable() instead of isinstance(..., collections.Callable):
            # that alias was removed from ``collections`` in Python 3.10.
            to_json = getattr(o, 'toJSON', None)
            if callable(to_json):
                return to_json()
            if isinstance(o, bytes):
                return o.decode('ascii')
            if isinstance(o, decimal.Decimal):
                return str(o.quantize(cents))
            return super(JSONEncoder, self).default(o)

    return json.dumps(list(records), cls=JSONEncoder, indent=2)
def json_dump(fp, records):
    """Serialize *records* with json_dumps() and write the text to *fp*."""
    text = json_dumps(records)
    fp.write(text)
def json_loads(s, record_types):
    """Decode JSON text *s*, rebuilding record and field objects.

    Every decoded JSON object that has a ``__class__`` key is inspected:
      * if the name is in *record_types*, an instance of that class is created,
        populated with ``fromJSON(obj)``, and returned;
      * otherwise, if ``fields`` has an attribute of that name, it is called
        with no arguments and the result of its ``fromJSON(obj)`` is returned;
      * otherwise the plain dict is kept.
    JSON floats are parsed as ``decimal.Decimal``.

    Args:
        s: JSON text.
        record_types: dict mapping class name -> record class, or an iterable
            of record classes keyed by their ``__name__``.
    """
    import decimal
    import json
    from . import fields

    if not isinstance(record_types, dict):
        record_types = {cls.__name__: cls for cls in record_types}

    def object_hook(obj):
        if '__class__' not in obj:
            return obj
        class_name = obj['__class__']
        if class_name in record_types:
            instance = record_types[class_name]()
            instance.fromJSON(obj)
            return instance
        # NOTE(review): class_name comes straight from the input document and
        # whatever attribute of ``fields`` it names gets called; consider
        # restricting this to the field classes if the JSON may be untrusted.
        if hasattr(fields, class_name):
            return getattr(fields, class_name)().fromJSON(obj)
        return obj

    return json.loads(s, parse_float=decimal.Decimal, object_hook=object_hook)
def json_load(fp, record_types):
    """Read all of *fp* and decode it with json_loads()."""
    contents = fp.read()
    return json_loads(contents, record_types)
def text_dump(fp, records):
    """Write each record's ``output(format='text')`` to *fp* as ASCII bytes."""
    for rec in records:
        rendered = rec.output(format='text')
        fp.write(rendered.encode('ascii'))
def text_dumps(records):
    """Return the text rendering of *records* (as written by text_dump()) as bytes."""
    import io
    buffer = io.BytesIO()
    text_dump(buffer, records)
    return buffer.getvalue()
def text_load(fp, record_classes):
    """Parse records from the line-oriented text format.

    A line starting with ``---`` begins a new record. The class name is what
    remains after stripping surrounding whitespace and dashes, so both
    ``--- Name`` and ``--- Name ---`` are accepted. Each later line containing
    ``:`` is a ``field: value`` pair applied to the current record through
    ``set_field_value``. Only the first colon separates field from value, so
    values may themselves contain colons. All other lines are ignored.

    Args:
        fp: file-like object in binary mode (ASCII bytes) or text mode.
        record_classes: dict mapping class name -> record class, or an iterable
            of record classes keyed by their ``__name__``.

    Returns:
        list of record instances in file order.

    Raises:
        KeyError: if a header names a class missing from *record_classes*.
        ValueError: if a ``field: value`` line appears before any header.
    """
    if not isinstance(record_classes, dict):
        record_classes = {cls.__name__: cls for cls in record_classes}

    records = []
    current_record = None
    while True:
        line = fp.readline()
        if not line:
            break
        # Binary streams (e.g. the BytesIO built by text_loads) yield bytes;
        # text-mode streams already yield str.
        if isinstance(line, bytes):
            line = line.decode('ascii')
        if line.startswith('---'):
            record_name = line.strip().strip('-').strip()
            current_record = record_classes[record_name]()
            records.append(current_record)
        elif ':' in line:
            if current_record is None:
                raise ValueError(
                    "Field line found before any record header: %r" % line)
            field, _, value = line.partition(':')
            current_record.set_field_value(field.strip(), value.strip())
    return records
def text_loads(s, record_classes):
    """Parse the text-format bytes *s* into a list of records via text_load()."""
    import io
    return text_load(io.BytesIO(s), record_classes)
# The validation helpers below used to live in the controller. Unless we
# really need a controller class, it is simpler to keep them here.
def validate_required_records(records):
    """Ensure every record type marked ``required`` occurs in *records*.

    Args:
        records: iterable of record instances (iterated once).

    Raises:
        ValidationError: naming the first required record type, in
            ``record.RECORD_TYPES`` order, that has no instance in *records*.
    """
    # Import explicitly: the module-level name ``record`` is only bound as a
    # side effect of some other code importing the submodule first.
    from . import record
    from .fields import ValidationError

    present = {rec.__class__.__name__ for rec in records}
    # NOTE(review): this reads RECORD_TYPES from the ``record`` submodule, not
    # the package-level list used by get_record_types(); confirm that's intended.
    for type_name in record.RECORD_TYPES:
        klass = record.__dict__[type_name]
        if klass.required and klass.__name__ not in present:
            raise ValidationError(
                "Record set missing required record: %s" % klass.__name__)
def validate_record_order(records):
    """Check the structural ordering rules of a complete record set.

    Rules, checked in this order:
      * the first record is a SubmitterRecord;
      * the second record is an EmployerRecord;
      * the last record is a FinalRecord;
      * every EmployerRecord is immediately followed by an EmployeeWageRecord;
      * there are as many TotalRecords as EmployerRecords;
      * there are as many OptionalTotalRecords as OptionalEmployeeWageRecords;
      * exactly one FinalRecord is present.

    Args:
        records: sequence or other iterable of record instances.

    Raises:
        ValidationError: describing the first rule violated.
    """
    from . import record
    from .fields import ValidationError

    # Materialize so one-shot iterators work and the indexing below is valid.
    records = list(records)

    # Length guards make short record sets fail validation instead of
    # raising IndexError.
    if not records or not isinstance(records[0], record.SubmitterRecord):
        raise ValidationError("First record must be SubmitterRecord")
    if len(records) < 2 or not isinstance(records[1], record.EmployerRecord):
        raise ValidationError(
            "The first record after SubmitterRecord must be an EmployerRecord")
    # FinalRecord must be the last record in the file.
    if not isinstance(records[-1], record.FinalRecord):
        raise ValidationError("Last record must be a FinalRecord")

    # Each EmployerRecord must be immediately followed by an EmployeeWageRecord.
    for current, following in zip(records, records[1:]):
        if (isinstance(current, record.EmployerRecord)
                and not isinstance(following, record.EmployeeWageRecord)):
            raise ValidationError(
                "All EmployerRecords must be followed by an EmployeeWageRecord")

    def count(cls):
        """Number of records that are instances of *cls*."""
        return sum(1 for rec in records if isinstance(rec, cls))

    num_ro_records = count(record.OptionalEmployeeWageRecord)
    num_ru_records = count(record.OptionalTotalRecord)
    num_employer_records = count(record.EmployerRecord)
    num_total_records = count(record.TotalRecord)

    # A TotalRecord is required for each EmployerRecord.
    if num_total_records != num_employer_records:
        raise ValidationError(
            "Number of TotalRecords (%d) does not match number of EmployerRecords (%d)" % (
                num_total_records, num_employer_records))
    # An OptionalTotalRecord is required for each OptionalEmployeeWageRecord.
    if num_ro_records != num_ru_records:
        raise ValidationError(
            "Number of OptionalEmployeeWageRecords (%d) does not match number of OptionalTotalRecords (%d)" % (
                num_ro_records, num_ru_records))
    # FinalRecord must appear exactly once in each file.
    if count(record.FinalRecord) != 1:
        raise ValidationError("Incorrect number of FinalRecords")
def validate_records(records):
    """Run every record-set validation on *records*.

    Args:
        records: sequence or other iterable of record instances.

    Raises:
        ValidationError: from whichever check fails first.
    """
    # Materialize once. validate_required_records() iterates the records and
    # validate_record_order() indexes them, so a one-shot iterator (such as
    # the generator returned by load()) would otherwise be exhausted.
    records = list(records)
    validate_required_records(records)
    validate_record_order(records)
def test_unique_fields():
    """Check that field values are not shared between record instances.

    Sets ``employee_first_name`` on one EmployeeWageRecord, creates a second
    one, prints both for inspection, and fails if the second sees the same value.

    Raises:
        ValidationError: if the two records report the same value.
    """
    # Import locally, like the other helpers here; these names are not bound
    # at module level, so without this the function raises NameError.
    from .record import EmployeeWageRecord
    from .fields import ValidationError

    r1 = EmployeeWageRecord()
    r1.employee_first_name.value = "John Johnson"
    r2 = EmployeeWageRecord()
    print('r1:', r1.employee_first_name.value, r1.employee_first_name, r1.employee_first_name.creation_counter)
    print('r2:', r2.employee_first_name.value, r2.employee_first_name, r2.employee_first_name.creation_counter)
    if r1.employee_first_name.value == r2.employee_first_name.value:
        raise ValidationError("Horrible problem involving shared values across records")