try:
    from collections.abc import Callable
except ImportError:  # very old interpreters without collections.abc
    from collections import Callable

VERSION = (0, 2012, 0)

# Names of the record classes (looked up in the sibling ``record`` module)
# that make up a file.
RECORD_TYPES = [
    'SubmitterRecord',
    'EmployerRecord',
    'EmployeeWageRecord',
    'OptionalEmployeeWageRecord',
    'TotalRecord',
    'StateWageRecord',
    'OptionalTotalRecord',
    'StateTotalRecord',
    'FinalRecord',
]


def get_record_types():
    """Return a dict mapping each record class's ``record_identifier`` to the class.

    The classes named in ``RECORD_TYPES`` are looked up in the sibling
    ``record`` module, which is imported lazily (only when this is called).
    """
    from . import record

    types = {}
    for name in RECORD_TYPES:
        klass = record.__dict__[name]
        types[klass.record_identifier] = klass
    return types


def load(fp, record_types):
    """Parse records from the binary stream *fp* and yield them one at a time.

    :param fp: readable binary file-like object.
    :param record_types: dict whose *values* are record classes. Each class
        must define ``record_identifier``, and its instances must provide
        ``read(fp)``. The dict keys are ignored; records are matched only by
        their identifier prefix.
    :raises ValueError: on first iteration, if the record identifiers do not
        all have the same length.

    The parser reads an identifier-sized chunk. If it matches a known
    identifier, it instantiates that class and lets ``read`` consume the rest
    of the record. Chunks that match nothing are discarded and scanning
    continues ``ident_length`` bytes later.
    """
    # The identifier is all we have to go on when reading (e.g. e1099 data),
    # so re-key the mapping by identifier however the caller keyed it.
    record_types_by_ident = {
        record_type.record_identifier: record_type
        for record_type in record_types.values()
    }

    ident_lengths = {len(ident) for ident in record_types_by_ident}
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if len(ident_lengths) != 1:
        raise ValueError(
            "record identifiers must all have the same length, got lengths %r"
            % sorted(ident_lengths))
    (ident_length,) = ident_lengths

    while True:
        record_ident = fp.read(ident_length)
        if not record_ident:
            break
        record_class = record_types_by_ident.get(record_ident)
        if record_class is None:
            # Unrecognised chunk: skip it. Presumably this is how record
            # delimiters (e.g. CR/LF) are stepped over -- TODO confirm.
            continue
        record = record_class()
        record.read(fp)
        yield record


def loads(s, record_types=None):
    """Parse records from the bytes *s*; see ``load``.

    :param record_types: mapping as accepted by ``load``. Defaults to
        ``get_record_types()``, computed at call time rather than at import
        time, so importing this module does not import ``record``.
    :returns: a generator of record instances.
    """
    import io

    if record_types is None:
        record_types = get_record_types()
    return load(io.BytesIO(s), record_types)


def dump(fp, records, delim=None):
    """Write each record's ``output()`` bytes to *fp*.

    :param delim: optional separator written after every record. A ``str``
        is ASCII-encoded; ``bytes`` are written as-is.
    """
    if isinstance(delim, str):
        delim = delim.encode('ascii')
    for r in records:
        fp.write(r.output())
        if delim:
            fp.write(delim)


def dumps(records, delim=None):
    """Return the bytes ``dump`` would write for *records*."""
    import io

    fp = io.BytesIO()
    dump(fp, records, delim=delim)
    return fp.getvalue()


def json_dumps(records):
    """Serialize *records* to a JSON string indented by two spaces.

    Objects that expose a callable ``toJSON`` are serialized through it.
    ``bytes`` are ASCII-decoded, and ``decimal.Decimal`` values are rendered
    as strings quantized to two decimal places.
    """
    import json
    import decimal

    class JSONEncoder(json.JSONEncoder):
        def default(self, o):
            to_json = getattr(o, 'toJSON', None)
            if isinstance(to_json, Callable):
                return to_json()
            if isinstance(o, bytes):
                return o.decode('ascii')
            if isinstance(o, decimal.Decimal):
                return str(o.quantize(decimal.Decimal('0.01')))
            return super(JSONEncoder, self).default(o)

    return json.dumps(list(records), cls=JSONEncoder, indent=2)


def json_dump(fp, records):
    """Write ``json_dumps(records)`` to the text stream *fp*."""
    fp.write(json_dumps(records))


def json_loads(s, record_types):
    """Rebuild records from JSON produced by ``json_dumps``.

    :param record_types: dict of class name -> record class, or an iterable of
        record classes (keyed by ``__name__``).

    JSON objects carrying a ``__class__`` key are turned back into instances.
    Record classes in *record_types* get ``fromJSON(o)`` called on a fresh
    instance, which is then returned. Otherwise, a same-named class in the
    ``fields`` module is instantiated and the result of its ``fromJSON(o)``
    is returned. Floats are parsed as ``decimal.Decimal``.
    """
    import json
    import decimal
    from . import fields

    if not isinstance(record_types, dict):
        record_types = {cls.__name__: cls for cls in record_types}

    def object_hook(o):
        if '__class__' in o:
            klass = o['__class__']
            if klass in record_types:
                record = record_types[klass]()
                record.fromJSON(o)
                return record
            elif hasattr(fields, klass):
                return getattr(fields, klass)().fromJSON(o)
        return o

    return json.loads(s, parse_float=decimal.Decimal, object_hook=object_hook)


def json_load(fp, record_types):
    """Read all of *fp* and parse it with ``json_loads``."""
    return json_loads(fp.read(), record_types)


def text_dump(fp, records):
    """Write each record's text rendering (``output(format='text')``) to the
    binary stream *fp*, ASCII-encoded."""
    for r in records:
        fp.write(r.output(format='text').encode('ascii'))


def text_dumps(records):
    """Return the bytes ``text_dump`` would write for *records*."""
    import io

    fp = io.BytesIO()
    text_dump(fp, records)
    return fp.getvalue()


def text_load(fp, record_classes):
    """Parse the human-readable text format from the binary stream *fp*.

    :param record_classes: dict of class name -> record class, or an iterable
        of record classes (keyed by ``__name__``).
    :returns: list of record instances, in file order.
    :raises ValueError: if a ``field: value`` line precedes any record header.

    A line starting with ``---`` begins a new record. Its class name is the
    line with surrounding whitespace and dashes removed, so both ``---Name``
    and ``--- Name ---`` work. A line containing ``:`` sets a field on the
    current record. Only the first ``:`` separates field from value, so
    values may themselves contain colons. Other lines are ignored.
    """
    if not isinstance(record_classes, dict):
        record_classes = {cls.__name__: cls for cls in record_classes}

    records = []
    current_record = None
    while True:
        line = fp.readline().decode('ascii')
        if not line:
            break
        if line.startswith('---'):
            record_name = line.strip().strip('-').strip()
            current_record = record_classes[record_name]()
            records.append(current_record)
        elif ':' in line:
            if current_record is None:
                raise ValueError(
                    "field line found before any record header: %r" % line)
            field, value = (part.strip() for part in line.split(':', 1))
            current_record.set_field_value(field, value)
    return records


def text_loads(s, record_classes):
    """Parse the text format from the bytes *s*; see ``text_load``."""
    import io

    return text_load(io.BytesIO(s), record_classes)

# The validation helpers below used to live in a controller class; unless a
# controller class turns out to be genuinely needed, keeping them as plain
# module-level functions is simpler.
def validate_required_records(records):
    """Raise ``ValidationError`` if a required record type is absent.

    A record type is required when its class (looked up by name from this
    module's ``RECORD_TYPES`` in the sibling ``record`` module) has a truthy
    ``required`` attribute. Presence is checked by class name.
    """
    from . import record
    from .fields import ValidationError

    present = {rec.__class__.__name__ for rec in records}

    required = []
    for name in RECORD_TYPES:
        klass = record.__dict__[name]
        if klass.required:
            required.append(klass.__name__)

    for req in required:
        if req not in present:
            raise ValidationError("Record set missing required record: %s" % req)


def validate_record_order(records):
    """Raise ``ValidationError`` if *records* are not in a valid file order.

    Checks, in this order:
      * the set is non-empty;
      * the first record is a ``SubmitterRecord``;
      * the second record is an ``EmployerRecord``;
      * the last record is a ``FinalRecord``;
      * every ``EmployerRecord`` is immediately followed by an
        ``EmployeeWageRecord``;
      * there are as many ``TotalRecord``s as ``EmployerRecord``s;
      * there are as many ``OptionalTotalRecord``s as
        ``OptionalEmployeeWageRecord``s;
      * there is exactly one ``FinalRecord``.
    """
    from . import record
    from .fields import ValidationError

    # Materialize so any iterable works and indexing is safe.
    records = list(records)
    if not records:
        raise ValidationError("Record set is empty")

    # 1st record must be a SubmitterRecord.
    if not isinstance(records[0], record.SubmitterRecord):
        raise ValidationError("First record must be SubmitterRecord")

    # 2nd record must be an EmployerRecord.
    if len(records) < 2 or not isinstance(records[1], record.EmployerRecord):
        raise ValidationError(
            "The first record after SubmitterRecord must be an EmployerRecord")

    # FinalRecord must be the last record in the file.
    if not isinstance(records[-1], record.FinalRecord):
        raise ValidationError("Last record must be a FinalRecord")

    # Every EmployerRecord must be immediately followed by an
    # EmployeeWageRecord. The last record is a FinalRecord (checked above),
    # so pairing each record with its successor covers every EmployerRecord.
    for current, following in zip(records, records[1:]):
        if (isinstance(current, record.EmployerRecord)
                and not isinstance(following, record.EmployeeWageRecord)):
            raise ValidationError(
                "All EmployerRecords must be followed by an EmployeeWageRecord")

    def count(cls):
        """Number of records that are instances of *cls*."""
        return sum(1 for rec in records if isinstance(rec, cls))

    num_ro_records = count(record.OptionalEmployeeWageRecord)
    num_ru_records = count(record.OptionalTotalRecord)
    num_employer_records = count(record.EmployerRecord)
    num_total_records = count(record.TotalRecord)

    # One TotalRecord is required for each EmployerRecord.
    if num_total_records != num_employer_records:
        raise ValidationError(
            "Number of TotalRecords (%d) does not match number of EmployerRecords (%d)" % (
                num_total_records, num_employer_records))

    # One OptionalTotalRecord is required for each OptionalEmployeeWageRecord.
    if num_ro_records != num_ru_records:
        raise ValidationError(
            "Number of OptionalEmployeeWageRecords (%d) does not match number OptionalTotalRecords (%d)" % (
                num_ro_records, num_ru_records))

    # FinalRecord must appear exactly once per file.
    if count(record.FinalRecord) != 1:
        raise ValidationError("Incorrect number of FinalRecords")


def validate_records(records):
    """Run every record-set validation; raises ``ValidationError`` on failure."""
    # Both validators iterate the records, so materialize one-shot iterables once.
    records = list(records)
    validate_required_records(records)
    validate_record_order(records)


def test_unique_fields():
    """Sanity check that field values belong to each record instance.

    Sets ``employee_first_name`` on one ``EmployeeWageRecord``. Raises
    ``ValidationError`` if a fresh, separate instance reports the same value,
    which would mean the value is shared across records.
    """
    from .record import EmployeeWageRecord
    from .fields import ValidationError

    r1 = EmployeeWageRecord()
    r1.employee_first_name.value = "John Johnson"
    r2 = EmployeeWageRecord()
    print('r1:', r1.employee_first_name.value, r1.employee_first_name,
          r1.employee_first_name.creation_counter)
    print('r2:', r2.employee_first_name.value, r2.employee_first_name,
          r2.employee_first_name.creation_counter)
    if r1.employee_first_name.value == r2.employee_first_name.value:
        raise ValidationError("Horrible problem involving shared values across records")