Merge branch 'conversion-support'

This commit is contained in:
Mark Riedesel 2020-06-12 13:13:28 -05:00
commit 1f1d3dd9bb
8 changed files with 373 additions and 146 deletions

View file

@ -1,6 +1,4 @@
from .record import *
from .reader import RecordReader
import collections
from collections import Callable
VERSION = (0, 2012, 0)
@ -14,77 +12,55 @@ RECORD_TYPES = [
'OptionalTotalRecord',
'StateTotalRecord',
'FinalRecord'
]
def test():
from . import record, model
from .fields import ValidationError
for rname in RECORD_TYPES:
inst = record.__dict__[rname]()
try:
output_length = len(inst.output())
except ValidationError as e:
print(e.msg, type(inst), inst.record_identifier)
continue
print(type(inst), inst.record_identifier, output_length)
]
def test_dump():
import record, io
records = [
record.SubmitterRecord(),
record.EmployerRecord(),
record.EmployeeWageRecord(),
]
out = io.BytesIO()
dump(records, out, None)
return out
def test_record_order():
from . import record
records = [
record.SubmitterRecord(),
record.EmployerRecord(),
record.EmployeeWageRecord(),
record.TotalRecord(),
record.FinalRecord(),
]
validate_record_order(records)
def test_load(fp):
return load(fp)
def load(fp):
# BUILD LIST OF RECORD TYPES
def get_record_types():
    """Return the built-in record classes keyed by their ``record_identifier``.

    Every name listed in ``RECORD_TYPES`` is looked up in the ``record``
    module's namespace; a missing name raises ``KeyError``.
    """
    from . import record

    namespace = vars(record)
    classes = (namespace[class_name] for class_name in RECORD_TYPES)
    return {klass.record_identifier: klass for klass in classes}
def load(fp, record_types):
distinct_identifier_lengths = set([len(record_types[k].record_identifier) for k in record_types])
assert(len(distinct_identifier_lengths) == 1)
ident_length = list(distinct_identifier_lengths)[0]
# Add aliases for the record types based on their record_identifier since that's all
# we have to work with in the e1099 data.
record_types_by_ident = {}
for k in record_types:
record_type = record_types[k]
record_identifier = record_type.record_identifier
record_types_by_ident[record_identifier] = record_type
# PARSE DATA INTO RECORDS AND YIELD THEM
while fp.tell() < fp.len:
record_ident = fp.read(2)
if record_ident in types:
record = types[record_ident]()
while True:
record_ident = fp.read(ident_length)
if not record_ident:
break
if record_ident in record_types_by_ident:
record = record_types_by_ident[record_ident]()
record.read(fp)
yield record
def loads(s):
def loads(s, record_types=get_record_types()):
import io
fp = io.BytesIO(s)
return load(fp)
return load(fp, record_types)
def dump(records, fp, delim=None):
def dump(fp, records, delim=None):
for r in records:
fp.write(r.output())
if delim:
fp.write(delim)
def dumps(records, delim=None):
import io
fp = io.BytesIO()
@ -92,15 +68,15 @@ def dumps(records, delim=None):
fp.seek(0)
return fp.read()
def json_dumps(records):
import json
from . import model
import decimal
class JSONEncoder(json.JSONEncoder):
def default(self, o):
if hasattr(o, 'toJSON') and isinstance(getattr(o, 'toJSON'), collections.Callable):
if hasattr(o, 'toJSON') and isinstance(getattr(o, 'toJSON'), Callable):
return o.toJSON()
if type(o) is bytes:
@ -111,37 +87,76 @@ def json_dumps(records):
return super(JSONEncoder, self).default(o)
return json.dumps(records, cls=JSONEncoder, indent=2)
return json.dumps(list(records), cls=JSONEncoder, indent=2)
def json_loads(s, record_classes):
def json_dump(fp, records):
    """Serialize ``records`` with ``json_dumps`` and write the result to ``fp``."""
    serialized = json_dumps(records)
    fp.write(serialized)
def json_loads(s, record_types):
import json
from . import fields
import decimal
import re
if not isinstance(record_classes, dict):
record_classes = dict([ (x.__class__.__name__, x) for x in record_classes])
if not isinstance(record_types, dict):
record_types = dict([ (x.__name__, x) for x in record_types])
def object_hook(o):
if '__class__' in o:
klass = o['__class__']
if klass in record_classes:
return record_classes[klass]().fromJSON(o)
if klass in record_types:
record = record_types[klass]()
record.fromJSON(o)
return record
elif hasattr(fields, klass):
return getattr(fields, klass)().fromJSON(o)
return o
#print "OBJECTHOOK", str(o)
#return {'object_hook':str(o)}
#def default(self, o):
# return super(JSONDecoder, self).default(o)
return json.loads(s, parse_float=decimal.Decimal, object_hook=object_hook)
def json_load(fp, record_types):
    """Read all of ``fp`` and decode it with ``json_loads``."""
    payload = fp.read()
    return json_loads(payload, record_types)
def text_dump(fp, records):
    """Write the text rendering of each record to ``fp`` as ASCII bytes.

    ``fp`` must accept bytes; each record is rendered through
    ``output(format='text')`` and encoded before being written.
    """
    for rec in records:
        rendered = rec.output(format='text')
        fp.write(rendered.encode('ascii'))
def text_dumps(records):
    """Return the text rendering of ``records`` as a single bytes object."""
    import io

    buffer = io.BytesIO()
    text_dump(buffer, records)
    # getvalue() yields the whole buffer regardless of the current position.
    return buffer.getvalue()
def text_load(fp, record_classes):
    """Parse records from the ``---ClassName`` / ``field: value`` text format.

    Args:
        fp: binary file-like object; every line is decoded as ASCII.
        record_classes: either a dict mapping class names to record classes,
            or an iterable of record classes (keyed by ``__name__`` here).

    Returns:
        A list of record instances in the order their headers appear.

    Raises:
        KeyError: a ``---`` header names a class not in ``record_classes``.
        ValueError: a ``field: value`` line appears before any header.
    """
    if not isinstance(record_classes, dict):
        record_classes = {klass.__name__: klass for klass in record_classes}

    records = []
    current_record = None
    while True:
        line = fp.readline().decode('ascii')
        if not line:
            # readline() only returns an empty value at end of file;
            # blank lines inside the file still carry their newline.
            break
        if line.startswith('---'):
            # strip() takes a set of characters, so '-' alone removes the marker.
            record_name = line.strip('-').strip()
            current_record = record_classes[record_name]()
            records.append(current_record)
        elif ':' in line:
            if current_record is None:
                raise ValueError(
                    "field line found before any '---' record header: %r" % line)
            # Split on the first colon only so values may contain ':' themselves.
            field, _, value = line.partition(':')
            current_record.set_field_value(field.strip(), value.strip())
    return records
def text_loads(s, record_classes):
    """Parse text-format records from the bytes object ``s`` via ``text_load``."""
    import io

    return text_load(io.BytesIO(s), record_classes)
# THIS WAS IN CONTROLLER, BUT UNLESS WE
# REALLY NEED A CONTROLLER CLASS, IT'S SIMPLER
# TO JUST KEEP IT IN HERE.
@ -153,7 +168,7 @@ def validate_required_records(records):
klass = record.__dict__[r]
if klass.required:
req_types.append(klass.__name__)
while req_types:
req = req_types[0]
if req not in types:
@ -162,10 +177,11 @@ def validate_required_records(records):
else:
req_types.remove(req)
def validate_record_order(records):
from . import record
from .fields import ValidationError
# 1st record must be SubmitterRecord
if not isinstance(records[0], record.SubmitterRecord):
raise ValidationError("First record must be SubmitterRecord")
@ -211,15 +227,10 @@ def test_unique_fields():
r1 = EmployeeWageRecord()
r1.employee_first_name.value = "John Johnson"
r2 = EmployeeWageRecord()
print('r1:', r1.employee_first_name.value, r1.employee_first_name, r1.employee_first_name.creation_counter)
print('r2:', r2.employee_first_name.value, r2.employee_first_name, r2.employee_first_name.creation_counter)
if r1.employee_first_name.value == r2.employee_first_name.value:
raise ValidationError("Horrible problem involving shared values across records")
#def state_postal_code(state_abbr):
# import enums
# return enums.state_postal_numeric[ state_abbr.upper() ]

View file

@ -1,7 +1,10 @@
import decimal, datetime
import inspect
from six import string_types
from . import enums
def is_blank_space(val):
    """Return True when ``val`` is empty or contains only whitespace."""
    return not val.strip()
class ValidationError(Exception):
def __init__(self, msg, field=None):
@ -17,6 +20,7 @@ class ValidationError(Exception):
class Field(object):
creation_counter = 0
is_read_only = False
def __init__(self, name=None, max_length=0, required=True, uppercase=True, creation_counter=None):
self.name = name
@ -29,10 +33,10 @@ class Field(object):
Field.creation_counter += 1
def validate(self):
raise NotImplemented
raise NotImplementedError
def get_data(self):
raise NotImplemented
raise NotImplementedError
def __setvalue(self, value):
self._value = value
@ -77,7 +81,7 @@ class Field(object):
required=o['required'],
)
if isinstance(o['value'], str) and re.match('^\d*\.\d*$', o['value']):
if isinstance(o['value'], str) and re.match(r'^\d*\.\d*$', o['value']):
o['value'] = decimal.Decimal(o['value'])
self.value = o['value']
@ -164,9 +168,10 @@ class StateField(TextField):
else:
self.value = s
class EmailField(TextField):
def __init__(self, name=None, required=True, max_length=None):
return super(EmailField, self).__init__(name=name, max_length=max_length,
super(EmailField, self).__init__(name=name, max_length=max_length,
required=required, uppercase=False)
class IntegerField(TextField):
@ -183,7 +188,10 @@ class IntegerField(TextField):
return value.zfill(self.max_length)[:self.max_length]
def parse(self, s):
self.value = int(s)
if not is_blank_space(s):
self.value = int(s)
else:
self.value = 0
class StaticField(TextField):
@ -197,8 +205,10 @@ class StaticField(TextField):
class BlankField(TextField):
is_read_only = True
def __init__(self, name=None, max_length=0, required=False):
super(TextField, self).__init__(name=name, max_length=max_length, required=required, uppercase=False)
super(BlankField, self).__init__(name=name, max_length=max_length, required=required, uppercase=False)
def get_data(self):
return b' ' * self.max_length
@ -208,13 +218,17 @@ class BlankField(TextField):
class ZeroField(BlankField):
is_read_only = True
def get_data(self):
return b'0' * self.max_length
class CRLFField(TextField):
is_read_only = True
def __init__(self, name=None, required=False):
super(TextField, self).__init__(name=name, max_length=2, required=required, uppercase=False)
super(CRLFField, self).__init__(name=name, max_length=2, required=required, uppercase=False)
def __setvalue(self, value):
self._value = value
@ -262,12 +276,27 @@ class MoneyField(Field):
return formatted[:self.max_length]
def parse(self, s):
self.value = decimal.Decimal(s) * decimal.Decimal('0.01')
if not is_blank_space(s):
self.value = decimal.Decimal(s) * decimal.Decimal('0.01')
else:
self.value = decimal.Decimal(0.0)
def __setvalue(self, value):
new_value = value
if isinstance(new_value, string_types):
new_value = decimal.Decimal(new_value or '0')
if '.' not in value: # must be cents?
new_value *= decimal.Decimal('100.')
self._value = new_value
def __getvalue(self):
return self._value
value = property(__getvalue, __setvalue)
class DateField(TextField):
def __init__(self, name=None, required=True, value=None):
super(TextField, self).__init__(name=name, required=required, max_length=8)
super(DateField, self).__init__(name=name, required=required, max_length=8)
if value:
self.value = value
@ -298,7 +327,7 @@ class DateField(TextField):
class MonthYearField(TextField):
def __init__(self, name=None, required=True, value=None):
super(TextField, self).__init__(name=name, required=required, max_length=6)
super(MonthYearField, self).__init__(name=name, required=required, max_length=6)
if value:
self.value = value

View file

@ -4,11 +4,15 @@ import collections
class Model(object):
record_length = -1
record_identifier = ' '
required = False
target_size = 512
def __init__(self):
if self.record_length == -1:
raise ValueError(self.record_length)
for (key, value) in list(self.__class__.__dict__.items()):
if isinstance(value, Field):
# GRAB THE FIELD INSTANCE FROM THE CLASS DEFINITION
@ -19,15 +23,22 @@ class Model(object):
if not src_field.name:
setattr(src_field, 'name', key)
setattr(src_field, 'parent_name', self.__class__.__name__)
self.__dict__[key] = copy.copy(src_field)
new_field_instance = copy.copy(src_field)
new_field_instance._orig_value = None
new_field_instance._value = None
self.__dict__[key] = new_field_instance
def __setattr__(self, key, value):
if hasattr(self, key) and isinstance(getattr(self, key), Field):
getattr(self, key).value = value
self.set_field_value(key, value)
else:
# MAYBE THIS SHOULD RAISE A PROPERTY ERROR?
self.__dict__[key] = value
def set_field_value(self, field_name, value):
    """Assign ``value`` to the ``value`` attribute of the field named ``field_name``.

    No diagnostic output is produced here: building a message by string
    concatenation would raise ``TypeError`` for non-string values such as
    ints or Decimals.
    """
    getattr(self, field_name).value = value
def get_fields(self):
identifier = TextField("record_identifier", max_length=len(self.record_identifier), creation_counter=-1)
identifier.value = self.record_identifier
@ -55,18 +66,28 @@ class Model(object):
if isinstance(custom_validator, collections.Callable):
custom_validator(f)
def output(self):
def output(self, format='binary'):
if format == 'text':
return self.output_text()
return self.output_efile()
def output_efile(self):
result = b''.join([field.get_data() for field in self.get_sorted_fields()])
if hasattr(self, 'record_length') and len(result) != self.record_length:
if self.record_length < 0 or len(result) != self.record_length:
raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.record_length, len(result)))
return result
def output_text(self):
    """Render the record in the text format.

    The result is a ``---ClassName`` header line, one ``name: value`` line
    per writable field, and a terminating blank line. The record identifier
    (first sorted field) and read-only fields are left out; falsy values are
    rendered as an empty string.
    """
    lines = []
    for field in self.get_sorted_fields()[1:]:  # [0] is the record identifier
        if field.is_read_only:
            continue
        rendered = str(field.value) if field.value else ''
        lines.append(field.name + ': ' + rendered)
    header = '---' + self.__class__.__name__ + '\n'
    return header + '\n'.join(lines) + '\n\n'
def read(self, fp):
    """Populate this record's fields by reading them in sorted order from ``fp``.

    The first sorted field is the record identifier and is skipped; every
    remaining field reads its own data through ``field.read(fp)``.

    Nothing is printed while reading, so standard output stays free for
    whatever the caller writes there.
    """
    # Skip the first field, since that's the record identifier.
    for field in self.get_sorted_fields()[1:]:
        field.read(fp)
def toJSON(self):
return {
@ -77,6 +98,9 @@ class Model(object):
def fromJSON(self, o):
fields = o['fields']
identifier, fields = fields[0], fields[1:]
assert(identifier.value == self.record_identifier)
for f in fields:
target = self.__dict__[f.name]
@ -84,7 +108,7 @@ class Model(object):
or target.max_length != f.max_length):
print("Warning: value mismatch on import")
target._value = f._value
target.value = f.value
return self

View file

@ -2,7 +2,7 @@ import re
class ClassEntryCommentSequence(object):
re_rangecomment = re.compile('#\s+(\d+)\-?(\d*)$')
re_rangecomment = re.compile(r'#\s+(\d+)\-?(\d*)$')
def __init__(self, classname, line):
self.classname = classname,
@ -72,7 +72,7 @@ class ModelDefParser(object):
classmatch = self.re_classdef.match(line)
if classmatch:
classname, subclass = classmatch.groups()
classname, _subclass = classmatch.groups()
self.beginclass(classname, self.line)
continue

View file

@ -109,7 +109,7 @@ class RangeToken(BaseToken):
class NumericToken(BaseToken):
regexp = re.compile('^(\d+)$')
regexp = re.compile(r'^(\d+)$')
@property
def value(self):

75
scripts/pyaccuwage-convert Executable file
View file

@ -0,0 +1,75 @@
#!/usr/bin/env python
import pyaccuwage
import argparse
import os, os.path
import sys
"""
Command line tool for converting IRS e-file fixed field records
to/from JSON or a simple text format.
Attempts to load record types from a python module in the current working
directory named record_types.py
The module must export a RECORD_TYPES list with the names of the classes to
import as valid record types.
"""
def get_record_types():
    """Return the record classes to convert with, keyed by class name.

    A ``record_types`` module importable from the current working directory
    takes precedence: each name in its ``RECORD_TYPES`` list is fetched from
    the module. If the import fails, the default pyaccuwage record classes
    are used instead.

    The defaults come back from pyaccuwage keyed by ``record_identifier``;
    they are re-keyed by class name here so both sources have the same shape,
    which is what the text loader needs to resolve ``---ClassName`` headers.
    """
    cwd = os.getcwd()
    if cwd not in sys.path:
        # Only add the working directory once, however often this is called.
        sys.path.append(cwd)

    try:
        import record_types
    except ImportError:
        # Warn on stderr so the message can never end up in converted output.
        sys.stderr.write(
            'warning: using default record types (failed to import record_types.py)\n')
        defaults = pyaccuwage.get_record_types()
        return {klass.__name__: klass for klass in defaults.values()}

    return {name: getattr(record_types, name) for name in record_types.RECORD_TYPES}
def read_file(fd, filename, record_types):
    """Load records from ``fd``, choosing the parser from ``filename``'s extension.

    ``.json`` uses the JSON loader, ``.txt`` the text loader, and anything
    else the fixed-width e-file loader.

    Args:
        fd: open input file object.
        filename: name used only to determine the extension.
        record_types: record classes handed to the selected loader.
    """
    extension = os.path.splitext(filename)[1]
    if extension == '.json':
        return pyaccuwage.json_load(fd, record_types)
    if extension == '.txt':
        # text_load decodes every line from bytes, so hand it the underlying
        # binary buffer when fd is a text-mode wrapper (as argparse.FileType('r')
        # produces); objects without a .buffer attribute are passed through as-is.
        return pyaccuwage.text_load(getattr(fd, 'buffer', fd), record_types)
    return pyaccuwage.load(fd, record_types)
def write_file(outfile, filename, records):
    """Write ``records`` to ``outfile`` in the format implied by ``filename``.

    ``.json`` uses the JSON writer, ``.txt`` the text writer, and anything
    else the fixed-width e-file writer.

    Args:
        outfile: open output file object.
        filename: name used only to determine the extension.
        records: iterable of record instances to serialize.
    """
    extension = os.path.splitext(filename)[1]
    if extension == '.json':
        # The JSON writer emits text, so the stream is used as given.
        pyaccuwage.json_dump(outfile, records)
        return

    # The text and e-file writers emit bytes. When outfile is a text-mode
    # wrapper (as argparse.FileType('w') produces), write to its underlying
    # binary buffer; objects without a .buffer attribute are used as-is.
    binary_out = getattr(outfile, 'buffer', outfile)
    if extension == '.txt':
        pyaccuwage.text_dump(binary_out, records)
    else:
        pyaccuwage.dump(binary_out, records)
if __name__ == '__main__':
    # Command line entry point: parse -i/-o, load every record from the
    # input file, then write them out in the format of the output file.
    parser = argparse.ArgumentParser(
        description="Convert accuwage efile data between different formats.",
    )
    parser.add_argument(
        "-i", '--input',
        nargs=1,
        required=True,
        metavar="file",
        type=argparse.FileType('r'),
        help="Source file to convert",
    )
    parser.add_argument(
        "-o", "--output",
        nargs=1,
        required=True,
        metavar="file",
        type=argparse.FileType('w'),
        help="Destination file to output",
    )
    args = parser.parse_args()

    # nargs=1 wraps each value in a single-element list.
    (in_file,) = args.input
    (out_file,) = args.output

    # Materialize the records before writing; some loaders are generators.
    records = list(read_file(in_file, in_file.name, get_record_types()))
    write_file(out_file, out_file.name, records)

View file

@ -1,17 +1,15 @@
import unittest
import decimal
from pyaccuwage.fields import TextField
from pyaccuwage.fields import IntegerField
from pyaccuwage.fields import StateField
from pyaccuwage.fields import BlankField
from pyaccuwage.fields import ZeroField
from pyaccuwage.fields import MoneyField
# from pyaccuwage.fields import IntegerField
# from pyaccuwage.fields import StateField
# from pyaccuwage.fields import BlankField
# from pyaccuwage.fields import ZeroField
# from pyaccuwage.fields import MoneyField
from pyaccuwage.fields import ValidationError
from pyaccuwage.model import Model
class TestTextField(unittest.TestCase):
def testStringShortOptional(self):
field = TextField(max_length=6, required=False)
field.validate() # optional
@ -30,43 +28,6 @@ class TestTextField(unittest.TestCase):
def testStringLongOptional(self):
field = TextField(max_length=6, required=False)
field.value = 'Hello, World!' # too long
self.assertEqual(len(field.get_data()), field.max_length)
class TestModelOutput(unittest.TestCase):
class TestModel(Model):
record_length = 128
record_identifier = 'TEST' # 4 bytes
field1 = TextField(max_length=16)
field2 = IntegerField(max_length=16)
blank1 = BlankField(max_length=16)
zero1 = ZeroField(max_length=16)
money = MoneyField(max_length=32)
state_txt = StateField()
state_num = StateField(use_numeric=True)
blank2 = BlankField(max_length=24)
def setUp(self):
self.model = TestModelOutput.TestModel()
def testModelOutput(self):
model = self.model
model.field1.value = 'Hello, sir!'
model.field2.value = 12345
model.money.value = decimal.Decimal('1234.56')
model.state_txt.value = 'IA'
model.state_num.value = 'IA'
expected = b''.join([
b'TEST',
b'HELLO, SIR!'.ljust(16),
b'12345'.zfill(16),
b' ' * 16,
b'0' * 16,
b'123456'.zfill(32),
b'IA',
b'19',
b' ' * 24,
])
self.assertEqual(model.output(), expected)
data = field.get_data()
self.assertEqual(len(data), field.max_length)
self.assertEqual(data, b'HELLO,')

127
tests/test_records.py Normal file
View file

@ -0,0 +1,127 @@
import unittest
import decimal
import pyaccuwage
from pyaccuwage.fields import BlankField
from pyaccuwage.fields import IntegerField
from pyaccuwage.fields import MoneyField
from pyaccuwage.fields import StateField
from pyaccuwage.fields import TextField
from pyaccuwage.fields import ZeroField
from pyaccuwage.model import Model
class TestModelOutput(unittest.TestCase):
    """Checks the fixed-width and text renderings of a small sample record."""

    class TestModel(Model):
        # 4 + 16 + 16 + 16 + 16 + 32 + 2 + 2 + 24 = 128 bytes
        record_length = 128
        record_identifier = 'TEST'  # 4 bytes
        field1 = TextField(max_length=16)
        field2 = IntegerField(max_length=16)
        blank1 = BlankField(max_length=16)
        zero1 = ZeroField(max_length=16)
        money = MoneyField(max_length=32)
        state_txt = StateField()
        state_num = StateField(use_numeric=True)
        blank2 = BlankField(max_length=24)

    def setUp(self):
        self.model = TestModelOutput.TestModel()

    def _fill_model(self):
        """Give the sample record the values shared by both output tests."""
        model = self.model
        model.field1.value = 'Hello, sir!'
        model.field2.value = 12345
        model.money.value = decimal.Decimal('3133.77')
        model.state_txt.value = 'IA'
        model.state_num.value = 'IA'
        return model

    def testModelBinaryOutput(self):
        model = self._fill_model()
        expected = b''.join([
            b'TEST',
            b'HELLO, SIR!'.ljust(16),
            b'12345'.zfill(16),
            b' ' * 16,
            b'0' * 16,
            b'313377'.zfill(32),
            b'IA',
            b'19',
            b' ' * 24,
        ])
        output = model.output()
        self.assertEqual(len(output), TestModelOutput.TestModel.record_length)
        self.assertEqual(output, expected)

    def testModelTextOutput(self):
        model = self._fill_model()
        output = model.output(format='text')
        # The text format ends every record with a blank line, i.e. the last
        # field line is followed by two newlines. Building the expectation
        # from explicit pieces keeps that trailing blank line visible.
        expected = '\n'.join([
            '---TestModel',
            'field1: Hello, sir!',
            'field2: 12345',
            'money: 3133.77',
            'state_txt: IA',
            'state_num: IA',
        ]) + '\n\n'
        self.assertEqual(output, expected)
class TestFileFormats(unittest.TestCase):
    """Round-trips sample records through the JSON and text formats."""

    class TestModelA(pyaccuwage.model.Model):
        record_length = 128
        record_identifier = 'A'  # 1 byte
        field1 = TextField(max_length=16)
        field2 = IntegerField(max_length=16)
        blank1 = BlankField(max_length=16)
        zero1 = ZeroField(max_length=16)
        money = MoneyField(max_length=32)
        state_txt = StateField()
        state_num = StateField(use_numeric=True)
        blank2 = BlankField(max_length=27)

    class TestModelB(pyaccuwage.model.Model):
        record_length = 128
        record_identifier = 'B'  # 1 byte
        zero1 = ZeroField(max_length=32)
        text1 = TextField(max_length=71)
        blank2 = BlankField(max_length=24)

    record_types = [TestModelA, TestModelB]

    def createExampleRecords(self):
        """Build one populated record of each sample type."""
        first = TestFileFormats.TestModelA()
        first.field1.value = 'I am model a'
        first.field2.value = 5522
        first.money.value = decimal.Decimal('23.00')
        first.state_txt.value = 'IA'
        first.state_num.value = 'IA'

        second = TestFileFormats.TestModelB()
        second.text1.value = 'hey I am model b and I have a big text field'

        return [first, second]

    def _assertRoundTrip(self, serialize, deserialize):
        """Serialize then reload the sample records and compare the e-file bytes."""
        records = self.createExampleRecords()
        serialized = serialize(records)
        reloaded = deserialize(serialized, self.record_types)
        self.assertEqual(pyaccuwage.dumps(records), pyaccuwage.dumps(reloaded))

    def testJSONSerialization(self):
        self._assertRoundTrip(pyaccuwage.json_dumps, pyaccuwage.json_loads)

    def testTxtSerialization(self):
        self._assertRoundTrip(pyaccuwage.text_dumps, pyaccuwage.text_loads)