Compare commits

..

1 commit

14 changed files with 949 additions and 1140 deletions

View file

@ -1,9 +1,7 @@
try: from record import *
from collections import Callable from reader import RecordReader
except:
from typing import Callable # Python 3.10+
VERSION = (0, 2025, 0) VERSION = (0, 2012, 0)
RECORD_TYPES = [ RECORD_TYPES = [
'SubmitterRecord', 'SubmitterRecord',
@ -15,156 +13,129 @@ RECORD_TYPES = [
'OptionalTotalRecord', 'OptionalTotalRecord',
'StateTotalRecord', 'StateTotalRecord',
'FinalRecord' 'FinalRecord'
] ]
def test():
import record, model
from fields import ValidationError
for rname in RECORD_TYPES:
inst = record.__dict__[rname]()
try:
output_length = len(inst.output())
except ValidationError, e:
print e.msg, type(inst), inst.record_identifier
continue
print type(inst), inst.record_identifier, output_length
def get_record_types(): def test_dump():
from . import record import record, StringIO
records = [
record.SubmitterRecord(),
record.EmployerRecord(),
record.EmployeeWageRecord(),
]
out = StringIO.StringIO()
dump(records, out)
return out
def test_record_order():
import record
records = [
record.SubmitterRecord(),
record.EmployerRecord(),
record.EmployeeWageRecord(),
record.TotalRecord(),
record.FinalRecord(),
]
validate_record_order(records)
def test_load(fp):
return load(fp)
def load(fp):
# BUILD LIST OF RECORD TYPES
import record
types = {} types = {}
for r in RECORD_TYPES: for r in RECORD_TYPES:
klass = record.__dict__[r] klass = record.__dict__[r]
types[klass.record_identifier] = klass types[klass.record_identifier] = klass
return types
def load(fp, record_types):
distinct_identifier_lengths = set([len(record_types[k].record_identifier) for k in record_types])
assert(len(distinct_identifier_lengths) == 1)
ident_length = list(distinct_identifier_lengths)[0]
# Add aliases for the record types based on their record_identifier since that's all
# we have to work with with the e1099 data.
record_types_by_ident = {}
for k in record_types:
record_type = record_types[k]
record_identifier = record_type.record_identifier
record_types_by_ident[record_identifier] = record_type
# PARSE DATA INTO RECORDS AND YIELD THEM # PARSE DATA INTO RECORDS AND YIELD THEM
while True: while fp.tell() < fp.len:
record_ident = fp.read(ident_length) record_ident = fp.read(2)
if not record_ident: if record_ident in types:
break record = types[record_ident]()
if record_ident in record_types_by_ident:
record = record_types_by_ident[record_ident]()
record.read(fp) record.read(fp)
yield record yield record
def loads(s):
def loads(s, record_types=get_record_types()): import StringIO
import io fp = StringIO.StringIO(s)
fp = io.BytesIO(s) return load(fp)
return load(fp, record_types)
def dump(fp, records, delim=None): def dump(records, fp):
if type(delim) is str:
delim = delim.encode('ascii')
for r in records: for r in records:
fp.write(r.output()) fp.write(r.output())
if delim:
fp.write(delim)
def dumps(records):
def dumps(records, delim=None, skip_validation=False): import StringIO
import io fp = StringIO.StringIO()
fp = io.BytesIO() dump(records, fp)
if not skip_validation:
for record in records:
record.validate()
dump(fp, records, delim=delim)
fp.seek(0) fp.seek(0)
return fp.read() return fp.read()
def json_dumps(records): def json_dumps(records):
import json import json
import model
import decimal import decimal
class JSONEncoder(json.JSONEncoder): class JSONEncoder(json.JSONEncoder):
def default(self, o): def default(self, o):
if hasattr(o, 'toJSON') and isinstance(getattr(o, 'toJSON'), Callable): if hasattr(o, 'toJSON') and callable(getattr(o, 'toJSON')):
return o.toJSON() return o.toJSON()
if type(o) is bytes:
return o.decode('ascii')
elif isinstance(o, decimal.Decimal): elif isinstance(o, decimal.Decimal):
return str(o.quantize(decimal.Decimal('0.01'))) return str(o.quantize(decimal.Decimal('0.01')))
return super(JSONEncoder, self).default(o) return super(JSONEncoder, self).default(o)
return json.dumps(list(records), cls=JSONEncoder, indent=2) return json.dumps(records, cls=JSONEncoder, indent=2)
def json_dump(fp, records): def json_loads(s, record_classes):
fp.write(json_dumps(records))
def json_loads(s, record_types):
import json import json
from . import fields import fields
import decimal import decimal
import re
if not isinstance(record_types, dict): if not isinstance(record_classes, dict):
record_types = dict([ (x.__name__, x) for x in record_types]) record_classes = dict([ (x.__class__.__name__, x) for x in record_classes])
def object_hook(o): def object_hook(o):
if '__class__' in o: if '__class__' in o:
klass = o['__class__'] klass = o['__class__']
if klass in record_types:
record = record_types[klass]() if klass in record_classes:
record.fromJSON(o) return record_classes[klass]().fromJSON(o)
return record
elif hasattr(fields, klass): elif hasattr(fields, klass):
return getattr(fields, klass)().fromJSON(o) return getattr(fields, klass)().fromJSON(o)
return o return o
#print "OBJECTHOOK", str(o)
#return {'object_hook':str(o)}
#def default(self, o):
# return super(JSONDecoder, self).default(o)
return json.loads(s, parse_float=decimal.Decimal, object_hook=object_hook) return json.loads(s, parse_float=decimal.Decimal, object_hook=object_hook)
def json_load(fp, record_types):
return json_loads(fp.read(), record_types)
def text_dump(fp, records):
for r in records:
fp.write(r.output(format='text').encode('ascii'))
def text_dumps(records):
import io
fp = io.BytesIO()
text_dump(fp, records)
fp.seek(0)
return fp.read()
def text_load(fp, record_classes):
records = []
current_record = None
if not isinstance(record_classes, dict):
record_classes = dict([ (x.__name__, x) for x in record_classes])
while True: #fp.readable():
line = fp.readline().decode('ascii')
if not line:
break
if line.startswith('---'):
record_name = line.strip('---').strip()
current_record = record_classes[record_name]()
records.append(current_record)
elif ':' in line:
field, value = [x.strip() for x in line.split(':')]
current_record.set_field_value(field, value)
return records
def text_loads(s, record_classes):
import io
fp = io.BytesIO(s)
return text_load(fp, record_classes)
# THIS WAS IN CONTROLLER, BUT UNLESS WE # THIS WAS IN CONTROLLER, BUT UNLESS WE
# REALLY NEED A CONTROLLER CLASS, IT'S SIMPLER # REALLY NEED A CONTROLLER CLASS, IT'S SIMPLER
# TO JUST KEEP IT IN HERE. # TO JUST KEEP IT IN HERE.
@ -180,15 +151,14 @@ def validate_required_records(records):
while req_types: while req_types:
req = req_types[0] req = req_types[0]
if req not in types: if req not in types:
from .fields import ValidationError from fields import ValidationError
raise ValidationError("Record set missing required record: %s" % req) raise ValidationError("Record set missing required record: %s" % req)
else: else:
req_types.remove(req) req_types.remove(req)
def validate_record_order(records): def validate_record_order(records):
from . import record import record
from .fields import ValidationError from fields import ValidationError
# 1st record must be SubmitterRecord # 1st record must be SubmitterRecord
if not isinstance(records[0], record.SubmitterRecord): if not isinstance(records[0], record.SubmitterRecord):
@ -208,10 +178,10 @@ def validate_record_order(records):
if not isinstance(records[i+1], record.EmployeeWageRecord): if not isinstance(records[i+1], record.EmployeeWageRecord):
raise ValidationError("All EmployerRecords must be followed by an EmployeeWageRecord") raise ValidationError("All EmployerRecords must be followed by an EmployeeWageRecord")
num_ro_records = len([x for x in records if isinstance(x, record.OptionalEmployeeWageRecord)]) num_ro_records = len(filter(lambda x:isinstance(x, record.OptionalEmployeeWageRecord), records))
num_ru_records = len([x for x in records if isinstance(x, record.OptionalTotalRecord)]) num_ru_records = len(filter(lambda x:isinstance(x, record.OptionalTotalRecord), records))
num_employer_records = len([x for x in records if isinstance(x, record.EmployerRecord)]) num_employer_records = len(filter(lambda x:isinstance(x, record.EmployerRecord), records))
num_total_records = len([x for x in records if isinstance(x, record.TotalRecord)]) num_total_records = len(filter(lambda x: isinstance(x, record.TotalRecord), records))
# a TotalRecord is required for each instance of an EmployeeRecord # a TotalRecord is required for each instance of an EmployeeRecord
if num_total_records != num_employer_records: if num_total_records != num_employer_records:
@ -224,7 +194,7 @@ def validate_record_order(records):
num_ro_records, num_ru_records)) num_ro_records, num_ru_records))
# FinalRecord - Must appear only once on each file. # FinalRecord - Must appear only once on each file.
if len([x for x in records if isinstance(x, record.FinalRecord)]) != 1: if len(filter(lambda x:isinstance(x, record.FinalRecord), records)) != 1:
raise ValidationError("Incorrect number of FinalRecords") raise ValidationError("Incorrect number of FinalRecords")
def validate_records(records): def validate_records(records):
@ -237,8 +207,13 @@ def test_unique_fields():
r1.employee_first_name.value = "John Johnson" r1.employee_first_name.value = "John Johnson"
r2 = EmployeeWageRecord() r2 = EmployeeWageRecord()
print('r1:', r1.employee_first_name.value, r1.employee_first_name, r1.employee_first_name.creation_counter) print 'r1:', r1.employee_first_name.value, r1.employee_first_name, r1.employee_first_name.creation_counter
print('r2:', r2.employee_first_name.value, r2.employee_first_name, r2.employee_first_name.creation_counter) print 'r2:', r2.employee_first_name.value, r2.employee_first_name, r2.employee_first_name.creation_counter
if r1.employee_first_name.value == r2.employee_first_name.value: if r1.employee_first_name.value == r2.employee_first_name.value:
raise ValidationError("Horrible problem involving shared values across records") raise ValidationError("Horrible problem involving shared values across records")
#def state_postal_code(state_abbr):
# import enums
# return enums.state_postal_numeric[ state_abbr.upper() ]

View file

@ -1,340 +1,338 @@
state_postal_numeric = { state_postal_numeric = {
'AL': 1, 'AL': 1,
'AK': 2, 'AK': 2,
'AZ': 4, 'AZ': 4,
'AR': 5, 'AR': 5,
'CA': 6, 'CA': 6,
'CO': 8, 'CO': 8,
'CT': 9, 'CT': 9,
'DE': 10, 'DE': 10,
'DC': 11, 'DC': 11,
'FL': 12, 'FL': 12,
'GA': 13, 'GA': 13,
'HI': 15, 'HI': 15,
'ID': 16, 'ID': 16,
'IL': 17, 'IL': 17,
'IN': 18, 'IN': 18,
'IA': 19, 'IA': 19,
'KS': 20, 'KS': 20,
'KY': 21, 'KY': 21,
'LA': 22, 'LA': 22,
'ME': 23, 'ME': 23,
'MD': 24, 'MD': 24,
'MA': 25, 'MA': 25,
'MI': 26, 'MI': 26,
'MN': 27, 'MN': 27,
'MS': 28, 'MS': 28,
'MO': 29, 'MO': 29,
'MT': 30, 'MT': 30,
'NE': 31, 'NE': 31,
'NV': 32, 'NV': 32,
'NH': 33, 'NH': 33,
'NJ': 34, 'NJ': 34,
'NM': 35, 'NM': 35,
'NY': 36, 'NY': 36,
'NC': 37, 'NC': 37,
'ND': 38, 'ND': 38,
'OH': 39, 'OH': 39,
'OK': 40, 'OK': 40,
'OR': 41, 'OR': 41,
'PA': 42, 'PA': 42,
'RI': 44, 'RI': 44,
'SC': 45, 'SC': 45,
'SD': 46, 'SD': 46,
'TN': 47, 'TN': 47,
'TX': 48, 'TX': 48,
'UT': 49, 'UT': 49,
'VT': 50, 'VT': 50,
'VA': 51, 'VA': 51,
'WA': 53, 'WA': 53,
'WV': 54, 'WV': 54,
'WI': 55, 'WI': 55,
'WY': 56, 'WY': 56,
} }
countries = ( countries = (
('AF', 'Afghanistan'), ('AF', 'Afghanistan'),
('AX', 'Aland Islands'), ('AX', 'Aland Islands'),
('AL', 'Albania'), ('AL', 'Albania'),
('DZ', 'Algeria'), ('DZ', 'Algeria'),
('AS', 'American Samoa'), ('AS', 'American Samoa'),
('AD', 'Andorra'), ('AD', 'Andorra'),
('AO', 'Angola'), ('AO', 'Angola'),
('AI', 'Anguilla'), ('AI', 'Anguilla'),
('AQ', 'Antarctica'), ('AQ', 'Antarctica'),
('AG', 'Antigua and Barbuda'), ('AG', 'Antigua and Barbuda'),
('AR', 'Argentina'), ('AR', 'Argentina'),
('AM', 'Armenia'), ('AM', 'Armenia'),
('AW', 'Aruba'), ('AW', 'Aruba'),
('AU', 'Australia'), ('AU', 'Australia'),
('AT', 'Austria'), ('AT', 'Austria'),
('AZ', 'Azerbaijan'), ('AZ', 'Azerbaijan'),
('BS', 'Bahamas'), ('BS', 'Bahamas'),
('BH', 'Bahrain'), ('BH', 'Bahrain'),
('BD', 'Bangladesh'), ('BD', 'Bangladesh'),
('BB', 'Barbados'), ('BB', 'Barbados'),
('BY', 'Belarus'), ('BY', 'Belarus'),
('BE', 'Belgium'), ('BE', 'Belgium'),
('BZ', 'Belize'), ('BZ', 'Belize'),
('BJ', 'Benin'), ('BJ', 'Benin'),
('BM', 'Bermuda'), ('BM', 'Bermuda'),
('BT', 'Bhutan'), ('BT', 'Bhutan'),
('BO', 'Bolivia, Plurinational State of'), ('BO', 'Bolivia, Plurinational State of'),
('BQ', 'Bonaire, Saint Eustatius and Saba'), ('BQ', 'Bonaire, Saint Eustatius and Saba'),
('BA', 'Bosnia and Herzegovina'), ('BA', 'Bosnia and Herzegovina'),
('BW', 'Botswana'), ('BW', 'Botswana'),
('BV', 'Bouvet Island'), ('BV', 'Bouvet Island'),
('BR', 'Brazil'), ('BR', 'Brazil'),
('IO', 'British Indian Ocean Territory'), ('IO', 'British Indian Ocean Territory'),
('BN', 'Brunei Darussalam'), ('BN', 'Brunei Darussalam'),
('BG', 'Bulgaria'), ('BG', 'Bulgaria'),
('BF', 'Burkina Faso'), ('BF', 'Burkina Faso'),
('BI', 'Burundi'), ('BI', 'Burundi'),
('KH', 'Cambodia'), ('KH', 'Cambodia'),
('CM', 'Cameroon'), ('CM', 'Cameroon'),
('CA', 'Canada'), ('CA', 'Canada'),
('CV', 'Cape Verde'), ('CV', 'Cape Verde'),
('KY', 'Cayman Islands'), ('KY', 'Cayman Islands'),
('CF', 'Central African Republic'), ('CF', 'Central African Republic'),
('TD', 'Chad'), ('TD', 'Chad'),
('CL', 'Chile'), ('CL', 'Chile'),
('CN', 'China'), ('CN', 'China'),
('CX', 'Christmas Island'), ('CX', 'Christmas Island'),
('CC', 'Cocos (Keeling) Islands'), ('CC', 'Cocos (Keeling) Islands'),
('CO', 'Colombia'), ('CO', 'Colombia'),
('KM', 'Comoros'), ('KM', 'Comoros'),
('CG', 'Congo'), ('CG', 'Congo'),
('CD', 'Congo, The Democratic Republic of the'), ('CD', 'Congo, The Democratic Republic of the'),
('CK', 'Cook Islands'), ('CK', 'Cook Islands'),
('CR', 'Costa Rica'), ('CR', 'Costa Rica'),
('CI', "Cote D'ivoire"), ('CI', "Cote D'ivoire"),
('HR', 'Croatia'), ('HR', 'Croatia'),
('CU', 'Cuba'), ('CU', 'Cuba'),
('CW', 'Curacao'), ('CW', 'Curacao'),
('CY', 'Cyprus'), ('CY', 'Cyprus'),
('CZ', 'Czech Republic'), ('CZ', 'Czech Republic'),
('DK', 'Denmark'), ('DK', 'Denmark'),
('DJ', 'Djibouti'), ('DJ', 'Djibouti'),
('DM', 'Dominica'), ('DM', 'Dominica'),
('DO', 'Dominican Republic'), ('DO', 'Dominican Republic'),
('EC', 'Ecuador'), ('EC', 'Ecuador'),
('EG', 'Egypt'), ('EG', 'Egypt'),
('SV', 'El Salvador'), ('SV', 'El Salvador'),
('GQ', 'Equatorial Guinea'), ('GQ', 'Equatorial Guinea'),
('ER', 'Eritrea'), ('ER', 'Eritrea'),
('EE', 'Estonia'), ('EE', 'Estonia'),
('ET', 'Ethiopia'), ('ET', 'Ethiopia'),
('FK', 'Falkland Islands (Malvinas)'), ('FK', 'Falkland Islands (Malvinas)'),
('FO', 'Faroe Islands'), ('FO', 'Faroe Islands'),
('FJ', 'Fiji'), ('FJ', 'Fiji'),
('FI', 'Finland'), ('FI', 'Finland'),
('FR', 'France'), ('FR', 'France'),
('GF', 'French Guiana'), ('GF', 'French Guiana'),
('PF', 'French Polynesia'), ('PF', 'French Polynesia'),
('TF', 'French Southern Territories'), ('TF', 'French Southern Territories'),
('GA', 'Gabon'), ('GA', 'Gabon'),
('GM', 'Gambia'), ('GM', 'Gambia'),
('GE', 'Georgia'), ('GE', 'Georgia'),
('DE', 'Germany'), ('DE', 'Germany'),
('GH', 'Ghana'), ('GH', 'Ghana'),
('GI', 'Gibraltar'), ('GI', 'Gibraltar'),
('GR', 'Greece'), ('GR', 'Greece'),
('GL', 'Greenland'), ('GL', 'Greenland'),
('GD', 'Grenada'), ('GD', 'Grenada'),
('GP', 'Guadeloupe'), ('GP', 'Guadeloupe'),
('GU', 'Guam'), ('GU', 'Guam'),
('GT', 'Guatemala'), ('GT', 'Guatemala'),
('GG', 'Guernsey'), ('GG', 'Guernsey'),
('GN', 'Guinea'), ('GN', 'Guinea'),
('GW', 'Guinea-Bissau'), ('GW', 'Guinea-Bissau'),
('GY', 'Guyana'), ('GY', 'Guyana'),
('HT', 'Haiti'), ('HT', 'Haiti'),
('HM', 'Heard Island and McDonald Islands'), ('HM', 'Heard Island and McDonald Islands'),
('VA', 'Holy See (Vatican City State)'), ('VA', 'Holy See (Vatican City State)'),
('HN', 'Honduras'), ('HN', 'Honduras'),
('HK', 'Hong Kong'), ('HK', 'Hong Kong'),
('HU', 'Hungary'), ('HU', 'Hungary'),
('IS', 'Iceland'), ('IS', 'Iceland'),
('IN', 'India'), ('IN', 'India'),
('ID', 'Indonesia'), ('ID', 'Indonesia'),
('IR', 'Iran, Islamic Republic of'), ('IR', 'Iran, Islamic Republic of'),
('IQ', 'Iraq'), ('IQ', 'Iraq'),
('IE', 'Ireland'), ('IE', 'Ireland'),
('IM', 'Isle of Man'), ('IM', 'Isle of Man'),
('IL', 'Israel'), ('IL', 'Israel'),
('IT', 'Italy'), ('IT', 'Italy'),
('JM', 'Jamaica'), ('JM', 'Jamaica'),
('JP', 'Japan'), ('JP', 'Japan'),
('JE', 'Jersey'), ('JE', 'Jersey'),
('JO', 'Jordan'), ('JO', 'Jordan'),
('KZ', 'Kazakhstan'), ('KZ', 'Kazakhstan'),
('KE', 'Kenya'), ('KE', 'Kenya'),
('KI', 'Kiribati'), ('KI', 'Kiribati'),
('KP', "Korea, Democratic People's Republic of"), ('KP', "Korea, Democratic People's Republic of"),
('KR', 'Korea, Republic of'), ('KR', 'Korea, Republic of'),
('KW', 'Kuwait'), ('KW', 'Kuwait'),
('KG', 'Kyrgyzstan'), ('KG', 'Kyrgyzstan'),
('LA', "Lao People's Democratic Republic"), ('LA', "Lao People's Democratic Republic"),
('LV', 'Latvia'), ('LV', 'Latvia'),
('LB', 'Lebanon'), ('LB', 'Lebanon'),
('LS', 'Lesotho'), ('LS', 'Lesotho'),
('LR', 'Liberia'), ('LR', 'Liberia'),
('LY', 'Libyan Arab Jamahiriya'), ('LY', 'Libyan Arab Jamahiriya'),
('LI', 'Liechtenstein'), ('LI', 'Liechtenstein'),
('LT', 'Lithuania'), ('LT', 'Lithuania'),
('LU', 'Luxembourg'), ('LU', 'Luxembourg'),
('MO', 'Macao'), ('MO', 'Macao'),
('MK', 'Macedonia, The Former Yugoslav Republic of'), ('MK', 'Macedonia, The Former Yugoslav Republic of'),
('MG', 'Madagascar'), ('MG', 'Madagascar'),
('MW', 'Malawi'), ('MW', 'Malawi'),
('MY', 'Malaysia'), ('MY', 'Malaysia'),
('MV', 'Maldives'), ('MV', 'Maldives'),
('ML', 'Mali'), ('ML', 'Mali'),
('MT', 'Malta'), ('MT', 'Malta'),
('MH', 'Marshall Islands'), ('MH', 'Marshall Islands'),
('MQ', 'Martinique'), ('MQ', 'Martinique'),
('MR', 'Mauritania'), ('MR', 'Mauritania'),
('MU', 'Mauritius'), ('MU', 'Mauritius'),
('YT', 'Mayotte'), ('YT', 'Mayotte'),
('MX', 'Mexico'), ('MX', 'Mexico'),
('FM', 'Micronesia, Federated States of'), ('FM', 'Micronesia, Federated States of'),
('MD', 'Moldova, Republic of'), ('MD', 'Moldova, Republic of'),
('MC', 'Monaco'), ('MC', 'Monaco'),
('MN', 'Mongolia'), ('MN', 'Mongolia'),
('ME', 'Montenegro'), ('ME', 'Montenegro'),
('MS', 'Montserrat'), ('MS', 'Montserrat'),
('MA', 'Morocco'), ('MA', 'Morocco'),
('MZ', 'Mozambique'), ('MZ', 'Mozambique'),
('MM', 'Myanmar'), ('MM', 'Myanmar'),
('NA', 'Namibia'), ('NA', 'Namibia'),
('NR', 'Nauru'), ('NR', 'Nauru'),
('NP', 'Nepal'), ('NP', 'Nepal'),
('NL', 'Netherlands'), ('NL', 'Netherlands'),
('NC', 'New Caledonia'), ('NC', 'New Caledonia'),
('NZ', 'New Zealand'), ('NZ', 'New Zealand'),
('NI', 'Nicaragua'), ('NI', 'Nicaragua'),
('NE', 'Niger'), ('NE', 'Niger'),
('NG', 'Nigeria'), ('NG', 'Nigeria'),
('NU', 'Niue'), ('NU', 'Niue'),
('NF', 'Norfolk Island'), ('NF', 'Norfolk Island'),
('MP', 'Northern Mariana Islands'), ('MP', 'Northern Mariana Islands'),
('NO', 'Norway'), ('NO', 'Norway'),
('OM', 'Oman'), ('OM', 'Oman'),
('PK', 'Pakistan'), ('PK', 'Pakistan'),
('PW', 'Palau'), ('PW', 'Palau'),
('PS', 'Palestinian Territory, Occupied'), ('PS', 'Palestinian Territory, Occupied'),
('PA', 'Panama'), ('PA', 'Panama'),
('PG', 'Papua New Guinea'), ('PG', 'Papua New Guinea'),
('PY', 'Paraguay'), ('PY', 'Paraguay'),
('PE', 'Peru'), ('PE', 'Peru'),
('PH', 'Philippines'), ('PH', 'Philippines'),
('PN', 'Pitcairn'), ('PN', 'Pitcairn'),
('PL', 'Poland'), ('PL', 'Poland'),
('PT', 'Portugal'), ('PT', 'Portugal'),
('PR', 'Puerto Rico'), ('PR', 'Puerto Rico'),
('QA', 'Qatar'), ('QA', 'Qatar'),
('RE', 'Reunion'), ('RE', 'Reunion'),
('RO', 'Romania'), ('RO', 'Romania'),
('RU', 'Russian Federation'), ('RU', 'Russian Federation'),
('RW', 'Rwanda'), ('RW', 'Rwanda'),
('BL', 'Saint Barthelemy'), ('BL', 'Saint Barthelemy'),
('SH', 'Saint Helena, Ascension and Tristan Da Cunha'), ('SH', 'Saint Helena, Ascension and Tristan Da Cunha'),
('KN', 'Saint Kitts and Nevis'), ('KN', 'Saint Kitts and Nevis'),
('LC', 'Saint Lucia'), ('LC', 'Saint Lucia'),
('MF', 'Saint Martin (French Part)'), ('MF', 'Saint Martin (French Part)'),
('PM', 'Saint Pierre and Miquelon'), ('PM', 'Saint Pierre and Miquelon'),
('VC', 'Saint Vincent and the Grenadines'), ('VC', 'Saint Vincent and the Grenadines'),
('WS', 'Samoa'), ('WS', 'Samoa'),
('SM', 'San Marino'), ('SM', 'San Marino'),
('ST', 'Sao Tome and Principe'), ('ST', 'Sao Tome and Principe'),
('SA', 'Saudi Arabia'), ('SA', 'Saudi Arabia'),
('SN', 'Senegal'), ('SN', 'Senegal'),
('RS', 'Serbia'), ('RS', 'Serbia'),
('SC', 'Seychelles'), ('SC', 'Seychelles'),
('SL', 'Sierra Leone'), ('SL', 'Sierra Leone'),
('SG', 'Singapore'), ('SG', 'Singapore'),
('SX', 'Sint Maarten (Dutch Part)'), ('SX', 'Sint Maarten (Dutch Part)'),
('SK', 'Slovakia'), ('SK', 'Slovakia'),
('SI', 'Slovenia'), ('SI', 'Slovenia'),
('SB', 'Solomon Islands'), ('SB', 'Solomon Islands'),
('SO', 'Somalia'), ('SO', 'Somalia'),
('ZA', 'South Africa'), ('ZA', 'South Africa'),
('GS', 'South Georgia and the South Sandwich Islands'), ('GS', 'South Georgia and the South Sandwich Islands'),
('ES', 'Spain'), ('ES', 'Spain'),
('LK', 'Sri Lanka'), ('LK', 'Sri Lanka'),
('SD', 'Sudan'), ('SD', 'Sudan'),
('SR', 'Suriname'), ('SR', 'Suriname'),
('SJ', 'Svalbard and Jan Mayen'), ('SJ', 'Svalbard and Jan Mayen'),
('SZ', 'Swaziland'), ('SZ', 'Swaziland'),
('SE', 'Sweden'), ('SE', 'Sweden'),
('CH', 'Switzerland'), ('CH', 'Switzerland'),
('SY', 'Syrian Arab Republic'), ('SY', 'Syrian Arab Republic'),
('TW', 'Taiwan, Province of China'), ('TW', 'Taiwan, Province of China'),
('TJ', 'Tajikistan'), ('TJ', 'Tajikistan'),
('TZ', 'Tanzania, United Republic of'), ('TZ', 'Tanzania, United Republic of'),
('TH', 'Thailand'), ('TH', 'Thailand'),
('TL', 'Timor-Leste'), ('TL', 'Timor-Leste'),
('TG', 'Togo'), ('TG', 'Togo'),
('TK', 'Tokelau'), ('TK', 'Tokelau'),
('TO', 'Tonga'), ('TO', 'Tonga'),
('TT', 'Trinidad and Tobago'), ('TT', 'Trinidad and Tobago'),
('TN', 'Tunisia'), ('TN', 'Tunisia'),
('TR', 'Turkey'), ('TR', 'Turkey'),
('TM', 'Turkmenistan'), ('TM', 'Turkmenistan'),
('TC', 'Turks and Caicos Islands'), ('TC', 'Turks and Caicos Islands'),
('TV', 'Tuvalu'), ('TV', 'Tuvalu'),
('UG', 'Uganda'), ('UG', 'Uganda'),
('UA', 'Ukraine'), ('UA', 'Ukraine'),
('AE', 'United Arab Emirates'), ('AE', 'United Arab Emirates'),
('GB', 'United Kingdom'), ('GB', 'United Kingdom'),
('US', 'United States'), ('US', 'United States'),
('UM', 'United States Minor Outlying Islands'), ('UM', 'United States Minor Outlying Islands'),
('UY', 'Uruguay'), ('UY', 'Uruguay'),
('UZ', 'Uzbekistan'), ('UZ', 'Uzbekistan'),
('VU', 'Vanuatu'), ('VU', 'Vanuatu'),
('VE', 'Venezuela, Bolivarian Republic of'), ('VE', 'Venezuela, Bolivarian Republic of'),
('VN', 'Viet Nam'), ('VN', 'Viet Nam'),
('VG', 'Virgin Islands, British'), ('VG', 'Virgin Islands, British'),
('VI', 'Virgin Islands, U.S.'), ('VI', 'Virgin Islands, U.S.'),
('WF', 'Wallis and Futuna'), ('WF', 'Wallis and Futuna'),
('EH', 'Western Sahara'), ('EH', 'Western Sahara'),
('YE', 'Yemen'), ('YE', 'Yemen'),
('ZM', 'Zambia'), ('ZM', 'Zambia'),
('ZW', 'Zimbabwe'), ('ZW', 'Zimbabwe'))
)
employer_types = ( employer_types = (
('F','Federal Government'), ('F','Federal Government'),
('S','State and Local Governmental Employer'), ('S','State and Local Governmental Employer'),
('T','Tax Exempt Employer'), ('T','Tax Exempt Employer'),
('Y','State and Local Tax Exempt Employer'), ('Y','State and Local Tax Exempt Employer'),
('N','None Apply'), ('N','None Apply'),
) )
employment_codes = ( employment_codes = (
('A', 'Agriculture'), ('A', 'Agriculture'),
('H', 'Household'), ('H', 'Household'),
('M', 'Military'), ('M', 'Military'),
('Q', 'Medicare Qualified Government Employee'), ('Q', 'Medicare Qualified Government Employee'),
('X', 'Railroad'), ('X', 'Railroad'),
('F', 'Regular'), ('F', 'Regular'),
('R', 'Regular (all others)'), ('R', 'Regular (all others)'),
) )
tax_jurisdiction_codes = ( tax_jurisdiction_codes = (
(' ', 'W-2'), ('V', 'Virgin Islands'),
('V', 'Virgin Islands'), ('G', 'Guam'),
('G', 'Guam'), ('S', 'American Samoa'),
('S', 'American Samoa'), ('N', 'Northern Mariana Islands'),
('N', 'Northern Mariana Islands'), ('P', 'Puerto Rico'),
('P', 'Puerto Rico'), )
)
tax_type_codes = ( tax_type_codes = (
('C', 'City Income Tax'), ('C', 'City Income Tax'),
('D', 'Country Income Tax'), ('D', 'Country Income Tax'),
('E', 'School District Income Tax'), ('E', 'School District Income Tax'),
('F', 'Other Income Tax'), ('F', 'Other Income Tax'),
) )

View file

@ -1,10 +1,6 @@
import decimal, datetime import decimal, datetime
import inspect import inspect
from six import string_types import enums
from . import enums
def is_blank_space(val):
return len(val.strip()) == 0
class ValidationError(Exception): class ValidationError(Exception):
def __init__(self, msg, field=None): def __init__(self, msg, field=None):
@ -20,26 +16,22 @@ class ValidationError(Exception):
class Field(object): class Field(object):
creation_counter = 0 creation_counter = 0
is_read_only = False
_value = None
def __init__(self, name=None, min_length=0, max_length=0, blank=False, required=True, uppercase=True, creation_counter=None): def __init__(self, name=None, max_length=0, required=True, uppercase=True, creation_counter=None):
self.name = name self.name = name
self._value = None self._value = None
self._orig_value = None self._orig_value = None
self.min_length = min_length
self.max_length = max_length self.max_length = max_length
self.blank = blank
self.required = required self.required = required
self.uppercase = uppercase self.uppercase = uppercase
self.creation_counter = creation_counter or Field.creation_counter self.creation_counter = creation_counter or Field.creation_counter
Field.creation_counter += 1 Field.creation_counter += 1
def validate(self): def validate(self):
raise NotImplementedError raise NotImplemented
def get_data(self): def get_data(self):
raise NotImplementedError raise NotImplemented
def __setvalue(self, value): def __setvalue(self, value):
self._value = value self._value = value
@ -84,7 +76,7 @@ class Field(object):
required=o['required'], required=o['required'],
) )
if isinstance(o['value'], str) and re.match(r'^\d*\.\d*$', o['value']): if isinstance(o['value'], basestring) and re.match('^\d*\.\d*$', o['value']):
o['value'] = decimal.Decimal(o['value']) o['value'] = decimal.Decimal(o['value'])
self.value = o['value'] self.value = o['value']
@ -98,10 +90,14 @@ class Field(object):
wrapper = textwrap.TextWrapper(replace_whitespace=False, drop_whitespace=False) wrapper = textwrap.TextWrapper(replace_whitespace=False, drop_whitespace=False)
wrapper.width = 100 wrapper.width = 100
value = wrapper.wrap(value) value = wrapper.wrap(value)
value = list([(" " * 9) + ('"' + x + '"') for x in value]) #value = textwrap.wrap(value, 100)
value.append(" " * 10 + ('_' * 10) * int(wrapper.width / 10)) #print value
value.append(" " * 10 + ('0123456789') * int(wrapper.width / 10)) value = list(map(lambda x:(" " * 9) + ('"' + x + '"'), value))
value.append(" " * 10 + ''.join(([str(x) + (' ' * 9) for x in range(int(wrapper.width / 10))]))) #value[0] = '"' + value[0] + '"'
value.append(" " * 10 + ('_' * 10) * (wrapper.width / 10))
value.append(" " * 10 + ('0123456789') * (wrapper.width / 10))
value.append(" " * 10 + ''.join((map(lambda x:str(x) + (' ' * 9), range(wrapper.width / 10 )))))
#value.append((" " * 59) + map(lambda x:("%x" % x), range(16))
start = counter['c'] start = counter['c']
counter['c'] += len(self._orig_value or self.value) counter['c'] += len(self._orig_value or self.value)
@ -119,28 +115,22 @@ class Field(object):
class TextField(Field): class TextField(Field):
def validate(self): def validate(self):
if self.value is None and self.required: if self.value == None and self.required:
raise ValidationError("value required", field=self) raise ValidationError("value required", field=self)
data = self.get_data() if len(self.get_data()) > self.max_length:
if len(data) > self.max_length:
raise ValidationError("value is too long", field=self) raise ValidationError("value is too long", field=self)
stripped_data_length = len(data.strip())
if stripped_data_length < self.min_length:
raise ValidationError("value is too short", field=self)
if stripped_data_length == 0 and (not self.blank and self.required):
raise ValidationError("field cannot be blank", field=self)
def get_data(self): def get_data(self):
value = str(self.value or '').encode('ascii') or b'' value = self.value or ""
if self.uppercase: if self.uppercase:
value = value.upper() value = value.upper()
return value.ljust(self.max_length)[:self.max_length] return value.ljust(self.max_length).encode('ascii')[:self.max_length]
def __setvalue(self, value): def __setvalue(self, value):
# NO NEWLINES # NO NEWLINES
try: try:
value = value.replace('\n', '').replace('\r', '') value = value.replace('\n', '').replace('\r', '')
except AttributeError: except AttributeError, e:
pass pass
self._value = value self._value = value
@ -152,35 +142,31 @@ class TextField(Field):
class StateField(TextField): class StateField(TextField):
def __init__(self, name=None, required=True, use_numeric=False, max_length=2): def __init__(self, name=None, required=True, use_numeric=False, max_length=2):
super(StateField, self).__init__(name=name, max_length=max_length, required=required) super(StateField, self).__init__(name=name, max_length=2, required=required)
self.use_numeric = use_numeric self.use_numeric = use_numeric
def get_data(self): def get_data(self):
value = str(self.value or 'XX') value = self.value or ""
if value.strip() and self.use_numeric: if value.strip() and self.use_numeric:
postcode = enums.state_postal_numeric[value.upper()] return str(enums.state_postal_numeric[value.upper()]).zfill(self.max_length)
postcode = str(postcode).encode('ascii')
return postcode.zfill(self.max_length)
else: else:
formatted = value.encode('ascii').ljust(self.max_length) return value.ljust(self.max_length).encode('ascii')[:self.max_length]
return formatted[:self.max_length]
def validate(self): def validate(self):
super(StateField, self).validate() super(StateField, self).validate()
if self.value and self.value.upper() not in list(enums.state_postal_numeric.keys()): if self.value and self.value.upper() not in enums.state_postal_numeric.keys():
raise ValidationError("%s is not a valid state abbreviation" % self.value, field=self) raise ValidationError("%s is not a valid state abbreviation" % self.value, field=self)
def parse(self, s): def parse(self, s):
if s.strip() and self.use_numeric: if s.strip() and self.use_numeric:
states = dict([(v, k) for (k, v) in list(enums.state_postal_numeric.items())]) states = dict( [(v,k) for (k,v) in enums.state_postal_numeric.items()] )
self.value = states[int(s)] self.value = states[int(s)]
else: else:
self.value = s self.value = s
class EmailField(TextField): class EmailField(TextField):
def __init__(self, name=None, required=True, max_length=None): def __init__(self, name=None, required=True, max_length=None):
super(EmailField, self).__init__(name=name, max_length=max_length, return super(EmailField, self).__init__(name=name, max_length=max_length,
required=required, uppercase=False) required=required, uppercase=False)
class IntegerField(TextField): class IntegerField(TextField):
@ -192,58 +178,37 @@ class IntegerField(TextField):
except ValueError: except ValueError:
raise ValidationError("field contains non-numeric characters", field=self) raise ValidationError("field contains non-numeric characters", field=self)
def get_data(self): def get_data(self):
value = str(self.value).encode('ascii') if self.value else b'' value = self.value or ""
return value.zfill(self.max_length)[:self.max_length] return str(value).zfill(self.max_length)[:self.max_length]
def parse(self, s): def parse(self, s):
if not is_blank_space(s): self.value = int(s)
self.value = int(s)
else:
self.value = 0
class StaticField(TextField): class StaticField(TextField):
def __init__(self, name=None, required=True, value=None, uppercase=False): def __init__(self, name=None, required=True, value=None):
super(StaticField, self).__init__(name=name, super(StaticField, self).__init__(name=name, required=required,
required=required, max_length=len(value))
max_length=len(value),
uppercase=uppercase)
self._static_value = value
self._value = value self._value = value
def parse(self, s): def parse(self, s):
pass pass
class BlankField(TextField): class BlankField(TextField):
is_read_only = True
def __init__(self, name=None, max_length=0, required=False): def __init__(self, name=None, max_length=0, required=False):
super(BlankField, self).__init__(name=name, max_length=max_length, required=required, uppercase=False) super(TextField, self).__init__(name=name, max_length=max_length, required=required, uppercase=False)
def get_data(self): def get_data(self):
return b' ' * self.max_length return " " * self.max_length
def parse(self, s): def parse(self, s):
pass pass
def validate(self):
if len(self.get_data()) != self.max_length:
raise ValidationError("blank field did not match expected length", field=self)
class ZeroField(BlankField):
is_read_only = True
def get_data(self):
return b'0' * self.max_length
class CRLFField(TextField): class CRLFField(TextField):
is_read_only = True
def __init__(self, name=None, required=False): def __init__(self, name=None, required=False):
super(CRLFField, self).__init__(name=name, max_length=2, required=required, uppercase=False) super(TextField, self).__init__(name=name, max_length=2, required=required, uppercase=False)
def __setvalue(self, value): def __setvalue(self, value):
self._value = value self._value = value
@ -254,12 +219,11 @@ class CRLFField(TextField):
value = property(__getvalue, __setvalue) value = property(__getvalue, __setvalue)
def get_data(self): def get_data(self):
return b'\r\n' return '\r\n'
def parse(self, s): def parse(self, s):
self.value = s self.value = s
class BooleanField(Field): class BooleanField(Field):
def __init__(self, name=None, required=True, value=None): def __init__(self, name=None, required=True, value=None):
super(BooleanField, self).__init__(name=name, required=required, max_length=1) super(BooleanField, self).__init__(name=name, required=required, max_length=1)
@ -269,7 +233,7 @@ class BooleanField(Field):
pass pass
def get_data(self): def get_data(self):
return b'1' if self._value else b'0' return '1' if self._value else '0'
def parse(self, s): def parse(self, s):
self.value = (s == '1') self.value = (s == '1')
@ -286,43 +250,26 @@ class MoneyField(Field):
raise ValidationError("value is too long", field=self) raise ValidationError("value is too long", field=self)
def get_data(self): def get_data(self):
cents = int((self.value or 0) * 100) return str(int((self.value or 0)*100)).encode('ascii').zfill(self.max_length)[:self.max_length]
formatted = str(cents).encode('ascii').zfill(self.max_length)
return formatted[:self.max_length]
def parse(self, s): def parse(self, s):
if not is_blank_space(s): self.value = decimal.Decimal(s) * decimal.Decimal('0.01')
self.value = decimal.Decimal(s) * decimal.Decimal('0.01')
else:
self.value = decimal.Decimal(0.0)
def __setvalue(self, value):
new_value = value
if isinstance(new_value, string_types):
new_value = decimal.Decimal(new_value or '0')
if '.' not in value: # must be cents?
new_value *= decimal.Decimal('100.')
self._value = new_value
def __getvalue(self):
return self._value
value = property(__getvalue, __setvalue)
class DateField(TextField): class DateField(TextField):
def __init__(self, name=None, required=True, value=None): def __init__(self, name=None, required=True, value=None):
super(DateField, self).__init__(name=name, required=required, max_length=8) super(TextField, self).__init__(name=name, required=required, max_length=8)
if value: if value:
self.value = value self.value = value
def get_data(self): def get_data(self):
if self._value: if self._value:
return self._value.strftime('%m%d%Y').encode('ascii') return self._value.strftime('%m%d%Y')
return b'0' * self.max_length return '0' * self.max_length
def parse(self, s): def parse(self, s):
if int(s) > 0: if int(s) > 0:
self.value = datetime.date(*[int(x) for x in (s[4:8], s[0:2], s[2:4])]) self.value = datetime.date(*[int(x) for x in s[4:8], s[0:2], s[2:4]])
else: else:
self.value = None self.value = None
@ -330,7 +277,7 @@ class DateField(TextField):
if isinstance(value, datetime.date): if isinstance(value, datetime.date):
self._value = value self._value = value
elif value: elif value:
self._value = datetime.date(*[int(x) for x in (value[4:8], value[0:2], value[2:4])]) self._value = datetime.date(*[int(x) for x in value[4:8], value[0:2], value[2:4]])
else: else:
self._value = None self._value = None
@ -342,18 +289,19 @@ class DateField(TextField):
class MonthYearField(TextField): class MonthYearField(TextField):
def __init__(self, name=None, required=True, value=None): def __init__(self, name=None, required=True, value=None):
super(MonthYearField, self).__init__(name=name, required=required, max_length=6) super(TextField, self).__init__(name=name, required=required, max_length=6)
if value: if value:
self.value = value self.value = value
def get_data(self): def get_data(self):
if self._value: if self._value:
return str(self._value.strftime('%m%Y').encode('ascii')) return self._value.strftime("%m%Y")
return b'0' * self.max_length return '0' * self.max_length
def parse(self, s): def parse(self, s):
if int(s) > 0: if int(s) > 0:
self.value = datetime.date(*[int(x) for x in (s[2:6], s[0:2], 1)]) self.value = datetime.date(*[int(x) for x in s[2:6], s[0:2], 1])
else: else:
self.value = None self.value = None
@ -361,7 +309,7 @@ class MonthYearField(TextField):
if isinstance(value, datetime.date): if isinstance(value, datetime.date):
self._value = value self._value = value
elif value: elif value:
self._value = datetime.date(*[int(x) for x in (value[2:6], value[0:2], 1)]) self._value = datetime.date(*[int(x) for x in value[2:6], value[0:2], 1])
else: else:
self._value = None self._value = None
@ -369,3 +317,4 @@ class MonthYearField(TextField):
return self._value return self._value
value = property(__getvalue, __setvalue) value = property(__getvalue, __setvalue)

View file

@ -1,19 +1,15 @@
from .fields import Field, TextField, ValidationError from fields import Field, TextField, ValidationError
import copy import copy
import collections import pdb
class Model(object): class Model(object):
record_length = -1
record_identifier = ' ' record_identifier = ' '
required = False required = False
target_size = 512 target_size = 512
def __init__(self): def __init__(self):
if self.record_length == -1: for (key, value) in self.__class__.__dict__.items():
raise ValueError(self.record_length)
for (key, value) in list(self.__class__.__dict__.items()):
if isinstance(value, Field): if isinstance(value, Field):
# GRAB THE FIELD INSTANCE FROM THE CLASS DEFINITION # GRAB THE FIELD INSTANCE FROM THE CLASS DEFINITION
# AND MAKE A LOCAL COPY FOR THIS RECORD'S INSTANCE, # AND MAKE A LOCAL COPY FOR THIS RECORD'S INSTANCE,
@ -23,31 +19,21 @@ class Model(object):
if not src_field.name: if not src_field.name:
setattr(src_field, 'name', key) setattr(src_field, 'name', key)
setattr(src_field, 'parent_name', self.__class__.__name__) setattr(src_field, 'parent_name', self.__class__.__name__)
new_field_instance = copy.copy(src_field) self.__dict__[key] = copy.copy(src_field)
new_field_instance._orig_value = None
new_field_instance._value = new_field_instance.value
self.__dict__[key] = new_field_instance
def __setattr__(self, key, value): def __setattr__(self, key, value):
if hasattr(self, key) and isinstance(getattr(self, key), Field): if hasattr(self, key) and isinstance(getattr(self, key), Field):
self.set_field_value(key, value) getattr(self, key).value = value
else: else:
# MAYBE THIS SHOULD RAISE A PROPERTY ERROR? # MAYBE THIS SHOULD RAISE A PROPERTY ERROR?
self.__dict__[key] = value self.__dict__[key] = value
def set_field_value(self, field_name, value):
getattr(self, field_name).value = value
def get_fields(self): def get_fields(self):
identifier = TextField( identifier = TextField("record_identifier", max_length=len(self.record_identifier), creation_counter=-1)
"record_identifier",
max_length = len(self.record_identifier),
blank = len(self.record_identifier) == 0,
creation_counter=-1)
identifier.value = self.record_identifier identifier.value = self.record_identifier
fields = [identifier] fields = [identifier]
for key in list(self.__class__.__dict__.keys()): for key in self.__class__.__dict__.keys():
attr = getattr(self, key) attr = getattr(self, key)
if isinstance(attr, Field): if isinstance(attr, Field):
fields.append(attr) fields.append(attr)
@ -55,7 +41,7 @@ class Model(object):
def get_sorted_fields(self): def get_sorted_fields(self):
fields = self.get_fields() fields = self.get_fields()
fields.sort(key=lambda x: x.creation_counter) fields.sort(key=lambda x:x.creation_counter)
return fields return fields
def validate(self): def validate(self):
@ -64,33 +50,27 @@ class Model(object):
try: try:
custom_validator = getattr(self, 'validate_' + f.name) custom_validator = getattr(self, 'validate_' + f.name)
except AttributeError: except AttributeError, e:
continue continue
if isinstance(custom_validator, collections.Callable): if callable(custom_validator):
custom_validator(f) custom_validator(f)
def output(self, format='binary'): def output(self):
if format == 'text': result = ''.join([field.get_data() for field in self.get_sorted_fields()])
return self.output_text()
return self.output_efile()
def output_efile(self): if hasattr(self, 'record_length') and len(result) != self.record_length:
result = b''.join([field.get_data() for field in self.get_sorted_fields()])
if self.record_length < 0 or len(result) != self.record_length:
raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.record_length, len(result))) raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.record_length, len(result)))
#result = ''.join([self.record_identifier] + [field.get_data() for field in self.get_sorted_fields()])
#if len(result) != self.target_size:
# raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.target_size, len(result)))
return result return result
def output_text(self):
fields = self.get_sorted_fields()[1:] # skip record identifier
fields = [field for field in fields if not field.is_read_only]
header = ''.join(['---', self.__class__.__name__, '\n'])
return header + '\n'.join([f.name + ': ' + (str(f.value) if f.value else '') for f in fields]) + '\n\n'
def read(self, fp): def read(self, fp):
# Skip the first record, since that's an identifier # Skip the first record, since that's an identifier
for field in self.get_sorted_fields()[1:]: for field in self.get_sorted_fields()[1:]:
field.read(fp) field.read(fp)
def toJSON(self): def toJSON(self):
return { return {
'__class__': self.__class__.__name__, '__class__': self.__class__.__name__,
@ -100,17 +80,19 @@ class Model(object):
def fromJSON(self, o): def fromJSON(self, o):
fields = o['fields'] fields = o['fields']
identifier, fields = fields[0], fields[1:]
assert(identifier.value == self.record_identifier)
for f in fields: for f in fields:
target = self.__dict__[f.name] target = self.__dict__[f.name]
if (target.required != f.required if (target.required != f.required or
or target.max_length != f.max_length): target.max_length != f.max_length):
print("Warning: value mismatch on import") print "Warning: value mismatch on import"
target.value = f.value target._value = f._value
#print (self.__dict__[f.name].name == f.name)
#self.__dict__[f.name].name == f.name
#self.__dict__[f.name].max_length == f.max_length
return self return self

View file

@ -1,86 +1,86 @@
#!/usr/bin/env python
import re import re
class ClassEntryCommentSequence(object): class ClassEntryCommentSequence(object):
re_rangecomment = re.compile(r'#\s+(\d+)\-?(\d*)$') re_rangecomment = re.compile('#\s+(\d+)\-?(\d*)$')
def __init__(self, classname, line): def __init__(self, classname, line):
self.classname = classname, self.classname = classname,
self.line = line self.line = line
self.lines = [] self.lines = []
def add_line(self, line): def add_line(self, line):
self.lines.append(line) self.lines.append(line)
def process(self): def process(self):
i = 0 i = 0
for (line_no, line) in enumerate(self.lines): for (line_no, line) in enumerate(self.lines):
match = self.re_rangecomment.search(line) match = self.re_rangecomment.search(line)
if match: if match:
(a, b) = match.groups() (a, b) = match.groups()
a = int(a) a = int(a)
if (i + 1) != a: if (i + 1) != a:
line_number = self.line + line_no line_number = self.line + line_no
print(("ERROR\tline:%d\tnear:%s\texpected:%d\tsaw:%d" % ( print("ERROR\tline:%d\tnear:%s\texpected:%d\tsaw:%d" % (line_number, line.split(' ')[0].strip(), i+1, a))
line_number, line.split(' ')[0].strip(), i+1, a)))
i = int(b) if b else a
i = int(b) if b else a
class ModelDefParser(object): class ModelDefParser(object):
re_triplequote = re.compile('"""') re_triplequote = re.compile('"""')
re_whitespace = re.compile(r"^(\s*)[^\s]+") re_whitespace = re.compile("^(\s*)[^\s]+")
re_classdef = re.compile(r"^\s*class\s(.*)\((.*)\):\s*$") re_classdef = re.compile(r"^\s*class\s(.*)\((.*)\):\s*$")
def __init__(self, infile, entryclass): def __init__(self, infile, entryclass):
self.infile = infile self.infile = infile
self.line = 0 self.line = 0
self.EntryClass = entryclass self.EntryClass = entryclass
def endclass(self): def endclass(self):
if self.current_class: if self.current_class:
self.current_class.process() self.current_class.process()
self.current_class = None self.current_class = None
def beginclass(self, classname, line): def beginclass(self, classname, line):
self.current_class = self.EntryClass(classname, line) self.current_class = self.EntryClass(classname, line)
def parse(self): def parse(self):
infile = self.infile infile = self.infile
whitespace = 0 whitespace = 0
in_block_comment = False in_block_comment = False
self.current_class = None self.current_class = None
for line in infile: for line in infile:
self.line += 1 self.line += 1
if line.startswith('#'): if line.startswith('#'):
continue continue
if self.re_triplequote.search(line): if self.re_triplequote.search(line):
in_block_comment = not in_block_comment in_block_comment = not in_block_comment
if in_block_comment: if in_block_comment:
continue continue
match_whitespace = self.re_whitespace.match(line) match_whitespace = self.re_whitespace.match(line)
if match_whitespace: if match_whitespace:
match_whitespace = len(match_whitespace.groups()[0]) match_whitespace = len(match_whitespace.groups()[0])
else: else:
match_whitespace = 0 match_whitespace = 0
classmatch = self.re_classdef.match(line) classmatch = self.re_classdef.match(line)
if classmatch: if classmatch:
classname, _subclass = classmatch.groups() classname, subclass = classmatch.groups()
self.beginclass(classname, self.line) self.beginclass(classname, self.line)
continue continue
if match_whitespace < whitespace:
whitespace = match_whitespace
self.endclass()
continue
if self.current_class:
whitespace = match_whitespace
self.current_class.add_line(line)
if match_whitespace < whitespace:
whitespace = match_whitespace
self.endclass()
continue
if self.current_class:
whitespace = match_whitespace
self.current_class.add_line(line)

View file

@ -1,3 +1,5 @@
#!/usr/bin/python
# coding=UTF-8
""" """
Parser utility to read data from Publication 1220 and Parser utility to read data from Publication 1220 and
convert it into python classes. convert it into python classes.
@ -5,7 +7,6 @@ convert it into python classes.
""" """
import re import re
import hashlib import hashlib
from functools import reduce
class SimpleDefParser(object): class SimpleDefParser(object):
def __init__(self): def __init__(self):
@ -33,7 +34,7 @@ class SimpleDefParser(object):
item = item.upper() item = item.upper()
if '-' in item: if '-' in item:
parts = [self._intify(x) for x in item.split('-')] parts = map(lambda x:self._intify(x), item.split('-'))
item = reduce(lambda x,y: y-x, parts) item = reduce(lambda x,y: y-x, parts)
else: else:
item = self._intify(item) item = self._intify(item)
@ -55,7 +56,7 @@ class LengthExpression(object):
self.exp_cache = {} self.exp_cache = {}
def __call__(self, value, exps): def __call__(self, value, exps):
return len(exps) == sum([self.check(value, x) for x in exps]) return len(exps) == sum(map(lambda x: self.check(value, x), exps))
def compile_exp(self, exp): def compile_exp(self, exp):
op, val = self.REG.match(exp).groups() op, val = self.REG.match(exp).groups()
@ -97,7 +98,7 @@ class RangeToken(BaseToken):
def value(self): def value(self):
if '-' not in self._value: if '-' not in self._value:
return 1 return 1
return reduce(lambda x,y: y-x, list(map(int, self._value.split('-'))))+1 return reduce(lambda x,y: y-x, map(int, self._value.split('-')))+1
@property @property
def end_position(self): def end_position(self):
@ -109,7 +110,7 @@ class RangeToken(BaseToken):
class NumericToken(BaseToken): class NumericToken(BaseToken):
regexp = re.compile(r'^(\d+)$') regexp = re.compile('^(\d+)$')
@property @property
def value(self): def value(self):
@ -117,7 +118,7 @@ class NumericToken(BaseToken):
class RecordBuilder(object): class RecordBuilder(object):
from . import fields import fields
entry_max_length = 4 entry_max_length = 4
@ -144,7 +145,8 @@ class RecordBuilder(object):
(re.compile(r'zero\-filled', re.IGNORECASE), +1), (re.compile(r'zero\-filled', re.IGNORECASE), +1),
(re.compile(r'leading zeroes', re.IGNORECASE), +1), (re.compile(r'leading zeroes', re.IGNORECASE), +1),
(re.compile(r'left\-justif', re.IGNORECASE), -1), (re.compile(r'left-\justif', re.IGNORECASE), -1),
], ],
}, },
}), }),
@ -199,15 +201,15 @@ class RecordBuilder(object):
try: try:
f_length = int(f_length) f_length = int(f_length)
except ValueError as e: except ValueError, e:
# bad result, skip # bad result, skip
continue continue
try: try:
assert f_length == RangeToken(f_range).value assert f_length == RangeToken(f_range).value
except AssertionError as e: except AssertionError, e:
continue continue
except ValueError as e: except ValueError, e:
# bad result, skip # bad result, skip
continue continue
@ -221,7 +223,7 @@ class RecordBuilder(object):
else: else:
required = None required = None
f_name = '_'.join([x.lower() for x in name_parts]) f_name = u'_'.join(map(lambda x:x.lower(), name_parts))
f_name = f_name.replace('&', 'and') f_name = f_name.replace('&', 'and')
f_name = re.sub(r'[^\w]','', f_name) f_name = re.sub(r'[^\w]','', f_name)
@ -238,7 +240,7 @@ class RecordBuilder(object):
lengthexp = LengthExpression() lengthexp = LengthExpression()
for entry in entries: for entry in entries:
matches = dict([(x[0],0) for x in self.FIELD_TYPES]) matches = dict(map(lambda x:(x[0],0), self.FIELD_TYPES))
for (classtype, criteria) in self.FIELD_TYPES: for (classtype, criteria) in self.FIELD_TYPES:
if 'length' in criteria: if 'length' in criteria:
@ -246,7 +248,7 @@ class RecordBuilder(object):
continue continue
if 'regexp' in criteria: if 'regexp' in criteria:
for crit_key, crit_values in list(criteria['regexp'].items()): for crit_key, crit_values in criteria['regexp'].items():
for (crit_re, score) in crit_values: for (crit_re, score) in crit_values:
matches[classtype] += score if crit_re.search(entry[crit_key]) else 0 matches[classtype] += score if crit_re.search(entry[crit_key]) else 0
@ -254,7 +256,7 @@ class RecordBuilder(object):
matches = list(matches.items()) matches = list(matches.items())
matches.sort(key=lambda x:x[1]) matches.sort(key=lambda x:x[1])
matches_found = True if sum([x[1] for x in matches]) > 0 else False matches_found = True if sum(map(lambda x:x[1], matches)) > 0 else False
entry['guessed_type'] = matches[-1][0] if matches_found else self.fields.TextField entry['guessed_type'] = matches[-1][0] if matches_found else self.fields.TextField
yield entry yield entry
@ -269,7 +271,7 @@ class RecordBuilder(object):
if entry['name'] == 'blank': if entry['name'] == 'blank':
blank_id = hashlib.new('md5') blank_id = hashlib.new('md5')
blank_id.update(entry['range'].encode()) blank_id.update(entry['range'].encode())
add( ('blank_%s' % blank_id.hexdigest()[:8]).ljust(40) ) add( (u'blank_%s' % blank_id.hexdigest()[:8]).ljust(40) )
else: else:
add(entry['name'].ljust(40)) add(entry['name'].ljust(40))
@ -384,7 +386,7 @@ class PastedDefParser(RecordBuilder):
for g in groups: for g in groups:
assert g['byterange'].value == g['length'].value assert g['byterange'].value == g['length'].value
desc = ' '.join([str(x.value) for x in g['desc']]) desc = u' '.join(map(lambda x:unicode(x.value), g['desc']))
if g['name'][-1].value.lower() == '(optional)': if g['name'][-1].value.lower() == '(optional)':
g['name'] = g['name'][0:-1] g['name'] = g['name'][0:-1]
@ -394,7 +396,7 @@ class PastedDefParser(RecordBuilder):
else: else:
required = None required = None
name = '_'.join([x.value.lower() for x in g['name']]) name = u'_'.join(map(lambda x:x.value.lower(), g['name']))
name = re.sub(r'[^\w]','', name) name = re.sub(r'[^\w]','', name)
yield({ yield({

View file

@ -3,102 +3,314 @@
import subprocess import subprocess
import re import re
import itertools import pdb
import fitz
""" pdftotext -layout -nopgbrk p1220.pdf - """ """ pdftotext -layout -nopgbrk p1220.pdf - """
def strip_values(items):
expr_non_alphanum = re.compile(r'[^\w\s]*', re.MULTILINE)
return [expr_non_alphanum.sub(x, '').strip().replace('\n', ' ') for x in items if x]
class PDFRecordFinder(object): class PDFRecordFinder(object):
field_range_expr = re.compile(r'^(\d+)[-]?(\d*)$') def __init__(self, src, heading_exp=None):
if not heading_exp:
heading_exp = re.compile('(\s+Record Name: (.*))|Record\ Layout')
def __init__(self, src): field_heading_exp = re.compile('^Field.*Field.*Length.*Description')
self.document = fitz.open(src)
def find_record_table_ranges(self): opts = ["pdftotext", "-layout", "-nopgbrk", "-eol", "unix", src, '-']
matches = [] pdftext = subprocess.check_output(opts)
for (page_number, page) in enumerate(self.document): self.textrows = pdftext.split('\n')
header_rects = page.search_for("Record Name:") self.heading_exp = heading_exp
for header_match_rect in header_rects: self.field_heading_exp = field_heading_exp
header_match_rect.x0 = header_match_rect.x1 # Start after match of "Record Name: "
header_match_rect.x1 = page.bound().x1 # Extend to right side of page
header_text = page.get_textbox(header_match_rect)
record_name = re.sub(r'[^\w\s\n]*', '', header_text).strip()
matches.append((record_name, {
'page': page_number,
'y': header_match_rect.y1 - 5, # Back up a hair to include header more reliably
}))
return matches
def find_records(self):
record_ranges = self.find_record_table_ranges()
for record_index, (record_name, record_details) in enumerate(record_ranges):
current_rows = []
next_index = record_index+1
(_, next_record_details) = record_ranges[next_index] if next_index < len(record_ranges) else (None, {'page': self.document.page_count-1})
for page_number in range(record_details['page'], next_record_details['page']):
page = self.document[page_number]
table_search_rect = page.bound()
if page_number == record_details['page']:
table_search_rect.y0 = record_details['y']
tables = page.find_tables(
clip = table_search_rect,
min_words_horizontal = 1,
min_words_vertical = 1,
horizontal_strategy = "lines_strict",
intersection_tolerance = 1,
)
for table in tables:
if table.col_count == 4:
table = table.extract()
# Parse field position (sometimes a cell has multiple
# values because IRS employees apparently smoke crack
for row in table:
first_column_lines = row[0].strip().split('\n')
if len(first_column_lines) > 1:
for sub_row in self.split_row(row):
current_rows.append(strip_values(sub_row))
else:
current_rows.append(strip_values(row))
consecutive_rows = self.filter_nonconsecutive_rows(current_rows)
yield(record_name, consecutive_rows)
def split_row(self, row):
if not row[1]:
return []
split_rows = list(itertools.zip_longest(*[x.strip().split('\n') for x in row[:3]], fillvalue=None))
description = strip_values([row[3]])[0]
rows = []
for row in split_rows:
if len(row) < 3 or not row[2]:
row = self.infer_field_length(row)
rows.append([*row, description])
return rows
def infer_field_length(self, row):
matches = PDFRecordFinder.field_range_expr.match(row[0])
if not matches:
return row
(start, end) = ([int(x) for x in list(matches.groups()) if x] + [None])[:2]
length = str(end-start+1) if end and start else '1'
return (*row[:2], length)
def filter_nonconsecutive_rows(self, rows):
consecutive_rows = []
last_position = 0
for row in rows:
matches = PDFRecordFinder.field_range_expr.match(row[0])
if not matches:
continue
(start, end) = ([int(x) for x in list(matches.groups()) if x] + [None])[:2]
if start != last_position + 1:
continue
last_position = end if end else start
consecutive_rows.append(row)
return consecutive_rows
def records(self): def records(self):
return self.find_records() headings = self.locate_heading_rows_by_field()
#for x in headings:
# print x
for (start, end, name) in headings:
name = name.decode('ascii', 'ignore')
yield (name, list(self.find_fields(iter(self.textrows[start+1:end]))), (start+1, end))
def locate_heading_rows_by_field(self):
results = []
record_break = []
line_is_whitespace_exp = re.compile('^(\s*)$')
record_begin_exp = self.heading_exp #re.compile('Record\ Name')
for (i, row) in enumerate(self.textrows):
match = self.field_heading_exp.match(row)
if match:
# work backwards until we think the header is fully copied
space_count_exp = re.compile('^(\s*)')
position = i - 1
spaces = 0
#last_spaces = 10000
complete = False
header = None
while not complete:
line_is_whitespace = True if line_is_whitespace_exp.match(self.textrows[position]) else False
is_record_begin = record_begin_exp.search(self.textrows[position])
if is_record_begin or line_is_whitespace:
header = self.textrows[position-1:i]
complete = True
position -= 1
name = ''.join(header).strip().decode('ascii','ignore')
print (name, position)
results.append((i, name, position))
else:
# See if this row forces us to break from field reading.
if re.search('Record\ Layout', row):
record_break.append(i)
merged = []
for (a, b) in zip(results, results[1:] + [(len(self.textrows), None)]):
end_pos = None
#print a[0], record_break[0], b[0]-1
while record_break and record_break[0] < a[0]:
record_break = record_break[1:]
if record_break[0] < b[0]-1:
end_pos = record_break[0]
record_break = record_break[1:]
else:
end_pos = b[0]-1
merged.append( (a[0], end_pos-1, a[1]) )
return merged
"""
def locate_heading_rows(self):
results = []
for (i, row) in enumerate(self.textrows):
match = self.heading_exp.match(row)
if match:
results.append((i, ''.join(match.groups())))
merged = []
for (a, b) in zip(results, results[1:] + [(len(self.textrows),None)]):
merged.append( (a[0], b[0]-1, a[1]) )
return merged
def locate_layout_block_rows(self):
# Search for rows that contain "Record Layout", as these are not fields
# we are interested in because they contain the crazy blocks of field definitions
# and not the nice 4-column ones that we're looking for.
results = []
for (i, row) in enumerate(self.textrows):
match = re.match("Record Layout", row)
"""
def find_fields(self, row_iter):
cc = ColumnCollector()
blank_row_counter = 0
for r in row_iter:
row = r.decode('UTF-8')
#print row
row_columns = self.extract_columns_from_row(row)
if not row_columns:
if cc.data and len(cc.data.keys()) > 1 and len(row.strip()) > cc.data.keys()[-1]:
yield cc
cc = ColumnCollector()
else:
cc.empty_row()
continue
try:
cc.add(row_columns)
except IsNextField, e:
yield cc
cc = ColumnCollector()
cc.add(row_columns)
except UnknownColumn, e:
raise StopIteration
yield cc
def extract_columns_from_row(self, row):
re_multiwhite = re.compile(r'\s{2,}')
# IF LINE DOESN'T CONTAIN MULTIPLE WHITESPACES, IT'S LIKELY NOT A TABLE
if not re_multiwhite.search(row):
return None
white_ranges = [0,]
pos = 0
while pos < len(row):
match = re_multiwhite.search(row[pos:])
if match:
white_ranges.append(pos + match.start())
white_ranges.append(pos + match.end())
pos += match.end()
else:
white_ranges.append(len(row))
pos = len(row)
row_result = []
white_iter = iter(white_ranges)
while white_iter:
try:
start = white_iter.next()
end = white_iter.next()
if start != end:
row_result.append(
(start, row[start:end].encode('ascii','ignore'))
)
except StopIteration:
white_iter = None
#print row_result
return row_result
class UnknownColumn(Exception):
pass
class IsNextField(Exception):
pass
class ColumnCollector(object):
def __init__(self, initial=None):
self.data = None
self.column_widths = None
self.max_data_length = 0
self.adjust_pad = 3
self.empty_rows = 0
pass
def __repr__(self):
return "<%s: %s>" % (
self.__class__.__name__,
map(lambda x:x if len(x) < 25 else x[:25] + '..',
self.data.values() if self.data else ''))
def add(self, data):
#if self.empty_rows > 2:
# raise IsNextField()
if not self.data:
self.data = dict(data)
else:
data = self.adjust_columns(data)
if self.is_next_field(data):
raise IsNextField()
for col_id, value in data:
self.merge_column(col_id, value)
self.update_column_widths(data)
def empty_row(self):
self.empty_rows += 1
def update_column_widths(self, data):
self.last_data_length = len(data)
self.max_data_length = max(self.max_data_length, len(data))
if not self.column_widths:
self.column_widths = dict(map(lambda (column, value): [column, column + len(value)], data))
else:
for col_id, value in data:
try:
self.column_widths[col_id] = max(self.column_widths[col_id], col_id + len(value.strip()))
except KeyError:
pass
def add_old(self, data):
if not self.data:
self.data = dict(data)
else:
if self.is_next_field(data):
raise IsNextField()
for col_id, value in data:
self.merge_column(col_id, value)
def adjust_columns(self, data):
adjusted_data = {}
for col_id, value in data:
if col_id in self.data.keys():
adjusted_data[col_id] = value.strip()
else:
for col_start, col_end in self.column_widths.items():
if (col_start - self.adjust_pad) <= col_id and (col_end + self.adjust_pad) >= col_id:
if col_start in adjusted_data:
adjusted_data[col_start] += ' ' + value.strip()
else:
adjusted_data[col_start] = value.strip()
return adjusted_data.items()
def merge_column(self, col_id, value):
if col_id in self.data.keys():
self.data[col_id] += ' ' + value.strip()
else:
# try adding a wiggle room value?
# FIXME:
# Sometimes description columns contain column-like
# layouts, and this causes the ColumnCollector to become
# confused. Perhaps we could check to see if a column occurs
# after the maximum column, and assume it's part of the
# max column?
"""
for col_start, col_end in self.column_widths.items():
if col_start <= col_id and (col_end) >= col_id:
self.data[col_start] += ' ' + value.strip()
return
"""
raise UnknownColumn
def is_next_field(self, data):
"""
If the first key value contains a string
and we already have some data in the record,
then this row is probably the beginning of
the next field. Raise an exception and continue
on with a fresh ColumnCollector.
"""
""" If the length of the value in column_id is less than the position of the next column_id,
then this is probably a continuation.
"""
if self.data and data:
keys = dict(self.column_widths).keys()
keys.sort()
keys += [None]
if self.last_data_length < len(data):
return True
first_key, first_value = dict(data).items()[0]
if self.data.keys()[0] == first_key:
position = keys.index(first_key)
max_length = keys[position + 1]
if max_length:
return len(first_value) > max_length or len(data) == self.max_data_length
return False
@property
def tuple(self):
#try:
if self.data:
return tuple(map(lambda k:self.data[k], sorted(self.data.keys())))
return ()
#except:
# import pdb
# pdb.set_trace()

View file

@ -1,13 +1,11 @@
from . import model import model
from .fields import * from fields import *
from . import enums import enums
__all__ = RECORD_TYPES = ['SubmitterRecord', 'EmployerRecord', __all__ = RECORD_TYPES = ['SubmitterRecord', 'EmployerRecord',
'EmployeeWageRecord', 'OptionalEmployeeWageRecord', 'EmployeeWageRecord', 'OptionalEmployeeWageRecord',
'TotalRecord', 'OptionalTotalRecord', 'TotalRecord', 'OptionalTotalRecord',
'StateTotalRecord', 'FinalRecord', 'StateWageRecord', 'StateTotalRecord', 'FinalRecord', 'StateWageRecord']
'StateTotalRecordIA',
]
class EFW2Record(model.Model): class EFW2Record(model.Model):
record_length = 512 record_length = 512
@ -105,8 +103,8 @@ class EmployerRecord(EFW2Record):
zipcode_ext = TextField(max_length=4, required=False) zipcode_ext = TextField(max_length=4, required=False)
kind_of_employer = TextField(max_length=1) kind_of_employer = TextField(max_length=1)
blank1 = BlankField(max_length=4) blank1 = BlankField(max_length=4)
foreign_state_province = TextField(max_length=23, required=False) foreign_state_province = TextField(max_length=23)
foreign_postal_code = TextField(max_length=15, required=False) foreign_postal_code = TextField(max_length=15)
country_code = TextField(max_length=2, required=False) country_code = TextField(max_length=2, required=False)
employment_code = TextField(max_length=1) employment_code = TextField(max_length=1)
tax_jurisdiction_code = TextField(max_length=1, required=False) tax_jurisdiction_code = TextField(max_length=1, required=False)
@ -150,7 +148,7 @@ class EmployeeWageRecord(EFW2Record):
ssn = IntegerField(max_length=9, required=False) ssn = IntegerField(max_length=9, required=False)
employee_first_name = TextField(max_length=15) employee_first_name = TextField(max_length=15)
employee_middle_name = TextField(max_length=15, required=False) employee_middle_name = TextField(max_length=15)
employee_last_name = TextField(max_length=20) employee_last_name = TextField(max_length=20)
employee_suffix = TextField(max_length=4, required=False) employee_suffix = TextField(max_length=4, required=False)
location_address = TextField(max_length=22) location_address = TextField(max_length=22)
@ -163,7 +161,7 @@ class EmployeeWageRecord(EFW2Record):
blank1 = BlankField(max_length=5) blank1 = BlankField(max_length=5)
foreign_state = TextField(max_length=23, required=False) foreign_state = TextField(max_length=23, required=False)
foreign_postal_code = TextField(max_length=15, required=False) foreign_postal_code = TextField(max_length=15, required=False)
country = TextField(max_length=2, required=True, blank=True) country = TextField(max_length=2)
wages_tips = MoneyField(max_length=11) wages_tips = MoneyField(max_length=11)
federal_income_tax_withheld = MoneyField(max_length=11) federal_income_tax_withheld = MoneyField(max_length=11)
social_security_wages = MoneyField(max_length=11) social_security_wages = MoneyField(max_length=11)
@ -199,10 +197,8 @@ class EmployeeWageRecord(EFW2Record):
blank6 = BlankField(max_length=23) blank6 = BlankField(max_length=23)
def validate_ssn(self, f): def validate_ssn(self, f):
if str(f.value).startswith('666'): if str(f.value).startswith('666','9'):
raise ValidationError("ssn cannot start with 666", field=f) raise ValidationError("ssn cannot start with 666 or 9", field=f)
if str(f.value).startswith('9'):
raise ValidationError("ssn cannot start with 9", field=f)
@ -245,7 +241,7 @@ class StateWageRecord(EFW2Record):
taxing_entity_code = TextField(max_length=5, required=False) taxing_entity_code = TextField(max_length=5, required=False)
ssn = IntegerField(max_length=9, required=False) ssn = IntegerField(max_length=9, required=False)
employee_first_name = TextField(max_length=15) employee_first_name = TextField(max_length=15)
employee_middle_name = TextField(max_length=15, required=False) employee_middle_name = TextField(max_length=15)
employee_last_name = TextField(max_length=20) employee_last_name = TextField(max_length=20)
employee_suffix = TextField(max_length=4, required=False) employee_suffix = TextField(max_length=4, required=False)
location_address = TextField(max_length=22) location_address = TextField(max_length=22)
@ -259,20 +255,20 @@ class StateWageRecord(EFW2Record):
foreign_postal_code = TextField(max_length=15, required=False) foreign_postal_code = TextField(max_length=15, required=False)
country_code = TextField(max_length=2, required=False) country_code = TextField(max_length=2, required=False)
optional_code = TextField(max_length=2, required=False) optional_code = TextField(max_length=2, required=False)
reporting_period = MonthYearField(required=False) reporting_period = MonthYearField()
quarterly_unemp_ins_wages = MoneyField(max_length=11) quarterly_unemp_ins_wages = MoneyField(max_length=11)
quarterly_unemp_ins_taxable_wages = MoneyField(max_length=11) quarterly_unemp_ins_taxable_wages = MoneyField(max_length=11)
number_of_weeks_worked = IntegerField(max_length=2, required=False) number_of_weeks_worked = IntegerField(max_length=2)
date_first_employed = DateField(required=False) date_first_employed = DateField(required=False)
date_of_separation = DateField(required=False) date_of_separation = DateField(required=False)
blank2 = BlankField(max_length=5) blank2 = BlankField(max_length=5)
state_employer_account_num = IntegerField(max_length=20, required=False) state_employer_account_num = TextField(max_length=20)
blank3 = BlankField(max_length=6) blank3 = BlankField(max_length=6)
state_code_2 = StateField(use_numeric=True) state_code_2 = StateField(use_numeric=True)
state_taxable_wages = MoneyField(max_length=11) state_taxable_wages = MoneyField(max_length=11)
state_income_tax_wh = MoneyField(max_length=11) state_income_tax_wh = MoneyField(max_length=11)
other_state_data = TextField(max_length=10, required=False) other_state_data = TextField(max_length=10, required=False)
tax_type_code = TextField(max_length=1, required=False) # VALIDATE C, D, E, or F tax_type_code = TextField(max_length=1) # VALIDATE C, D, E, or F
local_taxable_wages = MoneyField(max_length=11) local_taxable_wages = MoneyField(max_length=11)
local_income_tax_wh = MoneyField(max_length=11) local_income_tax_wh = MoneyField(max_length=11)
state_control_number = IntegerField(max_length=7, required=False) state_control_number = IntegerField(max_length=7, required=False)
@ -282,8 +278,7 @@ class StateWageRecord(EFW2Record):
def validate_tax_type_code(self, field): def validate_tax_type_code(self, field):
choices = [x for x,y in enums.tax_type_codes] choices = [x for x,y in enums.tax_type_codes]
value = field.value if field.value.upper() not in choices:
if value and value.upper() not in choices:
raise ValidationError("%s not one of %s" % (field.value,choices), field=f) raise ValidationError("%s not one of %s" % (field.value,choices), field=f)
@ -359,17 +354,6 @@ class StateTotalRecord(EFW2Record):
supplemental_data = TextField(max_length=510) supplemental_data = TextField(max_length=510)
class StateTotalRecordIA(EFW2Record):
#year=2018
record_identifier = 'RV'
number_of_rs_records = IntegerField(max_length=7) # num records since last 'RE' record
wages_tips = MoneyField(max_length=15)
state_income_tax_wh = MoneyField(max_length=15)
employer_ben = TextField(max_length=8)
iowa_confirmation_number = ZeroField(max_length=10)
blank1 = BlankField(max_length=455)
class FinalRecord(EFW2Record): class FinalRecord(EFW2Record):
#year=2012 #year=2012
record_identifier = 'RF' record_identifier = 'RF'

View file

@ -1 +0,0 @@
PyMuPDF==1.24.0

View file

@ -1,76 +0,0 @@
#!/usr/bin/env python
import pyaccuwage
import argparse
import os, os.path
import sys
"""
Command line tool for converting IRS e-file fixed field records
to/from JSON or a simple text format.
Attempts to load record types from a python module in the current working
directory named record_types.py
The module must export a RECORD_TYPES list with the names of the classes to
import as valid record types.
"""
def get_record_types():
try:
sys.path.append(os.getcwd())
import record_types
r = {}
for record_type in record_types.RECORD_TYPES:
r[record_type] = getattr(record_types, record_type)
return r
except ImportError:
print('warning: using default record types (failed to import record_types.py)')
return pyaccuwage.get_record_types()
def read_file(fd, filename, record_types):
filename, extension = os.path.splitext(filename)
if extension == '.json':
return pyaccuwage.json_load(fd, record_types)
elif extension == '.txt':
return pyaccuwage.text_load(fd, record_types)
else:
return pyaccuwage.load(fd, record_types)
def write_file(outfile, filename, records):
filename, extension = os.path.splitext(filename)
if extension == '.json':
pyaccuwage.json_dump(outfile, records)
elif extension == '.txt':
pyaccuwage.text_dump(outfile, records)
else:
pyaccuwage.dump(outfile, records)
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description="Convert accuwage efile data between different formats."
)
parser.add_argument("-i", '--input',
nargs=1,
required=True,
metavar="file",
type=argparse.FileType('r'),
help="Source file to convert")
parser.add_argument("-o", "--output",
nargs=1,
required=True,
metavar="file",
type=argparse.FileType('w'),
help="Destination file to output")
args = parser.parse_args()
in_file = args.input[0]
out_file = args.output[0]
records = list(read_file(in_file, in_file.name, get_record_types()))
write_file(out_file, out_file.name, records)
print("wrote {} records to {}".format(len(records), out_file.name))

View file

@ -1,4 +1,4 @@
#!/usr/bin/env python #!/usr/bin/python
from pyaccuwage.parser import RecordBuilder from pyaccuwage.parser import RecordBuilder
from pyaccuwage.pdfextract import PDFRecordFinder from pyaccuwage.pdfextract import PDFRecordFinder
import argparse import argparse
@ -29,9 +29,48 @@ doc = PDFRecordFinder(source_file)
records = doc.records() records = doc.records()
builder = RecordBuilder() builder = RecordBuilder()
for (name, fields) in records: def record_begins_at(field):
name = re.sub(r'^[^a-zA-Z]*','', name.split(':')[-1]) return int(fields[0].data.values()[0].split('-')[0], 10)
name = re.sub(r'[^\w]*', '', name)
sys.stdout.write("\nclass %s(pyaccuwagemodel.Model):\n" % name) def record_ends_at(fields):
for field in builder.load(map(lambda x: x, fields[0:])): return int(fields[-1].data.values()[0].split('-')[-1], 10)
last_record_begins_at = -1
last_record_ends_at = -1
for rec in records:
#if not rec[1]:
# continue # no actual fields detected
fields = rec[1]
# strip out fields that are not 4 items long
fields = filter(lambda x:len(x.tuple) == 4, fields)
# strip fields that don't begin at position 0
fields = filter(lambda x: 0 in x.data, fields)
# strip fields that don't have a length-range type item in position 0
fields = filter(lambda x: re.match('^\d+[-]?\d*$', x.data[0]), fields)
if not fields:
continue
begins_at = record_begins_at(fields)
ends_at = record_ends_at(fields)
# FIXME record_ends_at is randomly exploding due to record data being
# a lump of text and not necessarily a field entry. I assume
# this is cleaned out by the record builder class.
#print last_record_ends_at + 1, begins_at
if last_record_ends_at + 1 != begins_at:
name = re.sub('^[^a-zA-Z]*','',rec[0].split(':')[-1])
name = re.sub('[^\w]*', '', name)
sys.stdout.write("\nclass %s(pyaccuwagemodel.Model):\n" % name)
for field in builder.load(map(lambda x:x.tuple, rec[1][0:])):
sys.stdout.write('\t' + field + '\n') sys.stdout.write('\t' + field + '\n')
#print field
last_record_ends_at = ends_at

View file

@ -1,21 +1,12 @@
from setuptools import setup from distutils.core import setup
import unittest
def pyaccuwage_tests():
test_loader = unittest.TestLoader()
test_suite = test_loader.discover('tests', pattern='test_*.py')
return test_suite
setup(name='pyaccuwage', setup(name='pyaccuwage',
version='0.2025.0', version='0.2012.1',
packages=['pyaccuwage'], packages=['pyaccuwage'],
scripts=[ scripts=[
'scripts/pyaccuwage-checkseq',
'scripts/pyaccuwage-convert',
'scripts/pyaccuwage-genfieldfill',
'scripts/pyaccuwage-parse', 'scripts/pyaccuwage-parse',
'scripts/pyaccuwage-pdfparse', 'scripts/pyaccuwage-pdfparse',
'scripts/pyaccuwage-checkseq',
'scripts/pyaccuwage-genfieldfill'
], ],
zip_safe=True, zip_safe=True,
test_suite='setup.pyaccuwage_tests',
) )

View file

@ -1,67 +0,0 @@
import unittest
from pyaccuwage.fields import TextField
from pyaccuwage.fields import StaticField
# from pyaccuwage.fields import IntegerField
# from pyaccuwage.fields import StateField
# from pyaccuwage.fields import BlankField
# from pyaccuwage.fields import ZeroField
# from pyaccuwage.fields import MoneyField
from pyaccuwage.fields import ValidationError
from pyaccuwage.model import Model
class TestTextField(unittest.TestCase):
def testStringShortOptional(self):
field = TextField(max_length=6, required=False)
field.validate() # optional
field.value = 'Hello'
field.validate()
self.assertEqual(field.get_data(), b'HELLO ')
def testStringShortRequired(self):
field = TextField(max_length=6, required=True)
with self.assertRaises(ValidationError):
field.validate()
field.value = 'Hello'
field.validate()
self.assertEqual(field.get_data(), b'HELLO ')
def testStringLongOptional(self):
field = TextField(max_length=6, required=False)
field.value = 'Hello, World!' # too long
data = field.get_data()
self.assertEqual(len(data), field.max_length)
self.assertEqual(data, b'HELLO,')
def testStringUnsetOptional(self):
field = TextField(max_length=6, required=False)
field.validate()
self.assertEqual(field.get_data(), b' ' * 6)
def testStringRequiredUnassigned(self):
field = TextField(max_length=6)
self.assertRaises(ValidationError, lambda: field.validate())
def testStringRequiredNonBlank(self):
field = TextField(max_length=6)
field.value = ''
self.assertRaises(ValidationError, lambda: field.validate())
def testStringRequiredBlank(self):
field = TextField(max_length=6, blank=True)
field.value = ''
field.validate()
self.assertEqual(len(field.get_data()), 6)
def testStringMinimumLength(self):
field = TextField(max_length=6, min_length=6, blank=True) # blank has no effect
field.value = '' # one character too short
self.assertRaises(ValidationError, lambda: field.validate())
field.value = '12345' # one character too short
self.assertRaises(ValidationError, lambda: field.validate())
field.value = '123456' # one character too short
class TestStaticField(unittest.TestCase):
def test_static_field(self):
field = StaticField(value='TEST')
self.assertEqual(field.get_data(), b'TEST')

View file

@ -1,179 +0,0 @@
import unittest
import decimal
import pyaccuwage
from pyaccuwage.fields import BlankField
from pyaccuwage.fields import IntegerField
from pyaccuwage.fields import MoneyField
from pyaccuwage.fields import StateField
from pyaccuwage.fields import TextField
from pyaccuwage.fields import ZeroField
from pyaccuwage.fields import StaticField
from pyaccuwage.fields import ValidationError
from pyaccuwage.model import Model
class TestModelOutput(unittest.TestCase):
class TestModel(Model):
record_length = 128
record_identifier = 'TEST' # 4 bytes
field1 = TextField(max_length=16)
field2 = IntegerField(max_length=16)
blank1 = BlankField(max_length=16)
zero1 = ZeroField(max_length=16)
money = MoneyField(max_length=32)
state_txt = StateField()
state_num = StateField(use_numeric=True)
blank2 = BlankField(max_length=12)
static1 = StaticField(value='hey mister!!')
def setUp(self):
self.model = TestModelOutput.TestModel()
def testModelBinaryOutput(self):
model = self.model
model.field1.value = 'Hello, sir!'
model.field2.value = 12345
model.money.value = decimal.Decimal('3133.77')
model.state_txt.value = 'IA'
model.state_num.value = 'IA'
expected = b''.join([
b'TEST',
b'HELLO, SIR!'.ljust(16),
b'12345'.zfill(16),
b' ' * 16,
b'0' * 16,
b'313377'.zfill(32),
b'IA',
b'19',
b' ' * 12,
b'hey mister!!',
])
output = model.output()
self.assertEqual(len(output), TestModelOutput.TestModel.record_length)
self.assertEqual(output, expected)
def testModelTextOutput(self):
model = self.model
model.field1.value = 'Hello, sir!'
model.field2.value = 12345
model.money.value = decimal.Decimal('3133.77')
model.state_txt.value = 'IA'
model.state_num.value = 'IA'
output = model.output(format='text')
self.assertEqual(output, '''---TestModel
field1: Hello, sir!
field2: 12345
money: 3133.77
state_txt: IA
state_num: IA
static1: hey mister!!
''')
class TestFileFormats(unittest.TestCase):
class TestModelA(pyaccuwage.model.Model):
record_length = 128
record_identifier = 'A' # 1 byte
field1 = TextField(max_length=16)
field2 = IntegerField(max_length=16)
blank1 = BlankField(max_length=16)
zero1 = ZeroField(max_length=16)
money = MoneyField(max_length=32)
state_txt = StateField()
state_num = StateField(use_numeric=True)
blank2 = BlankField(max_length=27)
class TestModelB(pyaccuwage.model.Model):
record_length = 128
record_identifier = 'B' # 1 byte
zero1 = ZeroField(max_length=32)
text1 = TextField(max_length=71)
text2 = TextField(max_length=20, required=False)
blank2 = BlankField(max_length=4)
record_types = [TestModelA, TestModelB]
def createExampleRecords(self):
model_a = TestFileFormats.TestModelA()
model_a.field1.value = 'I am model a'
model_a.field2.value = 5522
model_a.money.value = decimal.Decimal('23.00')
model_a.state_txt.value = 'IA'
model_a.state_num.value = 'IA'
model_b = TestFileFormats.TestModelB()
model_b.text1.value = 'hey I am model b and I have a big text field'
return [
model_a,
model_b,
]
def testJSONSerialization(self):
records = self.createExampleRecords()
record_types = self.record_types
json_data = pyaccuwage.json_dumps(records)
records_loaded = pyaccuwage.json_loads(json_data, record_types)
original_bytes = pyaccuwage.dumps(records)
reloaded_bytes = pyaccuwage.dumps(records_loaded)
self.assertEqual(original_bytes, reloaded_bytes)
def testTxtSerialization(self):
records = self.createExampleRecords()
record_types = self.record_types
text_data = pyaccuwage.text_dumps(records)
records_loaded = pyaccuwage.text_loads(text_data, record_types)
original_bytes = pyaccuwage.dumps(records)
reloaded_bytes = pyaccuwage.dumps(records_loaded)
self.assertEqual(original_bytes, reloaded_bytes)
class TestRequiredFields(unittest.TestCase):
def createTestRecord(self, required=False, blank=False):
class Record(pyaccuwage.model.Model):
record_length = 16
record_identifier = ''
test_field = TextField(max_length=16, required=required, blank=blank)
record = Record()
def dump():
return pyaccuwage.dumps([record])
return (record, dump)
def testRequiredBlankField(self):
(record, dump) = self.createTestRecord(required=True, blank=True)
record.test_field.value # if nothing is ever assigned, raise error
self.assertRaises(ValidationError, dump)
record.test_field.value = '' # value may be empty string
dump()
def testRequiredNonblankField(self):
(record, dump) = self.createTestRecord(required=True, blank=False)
record.test_field.value # if nothing is ever assigned, raise error
self.assertRaises(ValidationError, dump)
record.test_field.value = '' # value must not be empty string
self.assertRaises(ValidationError, dump)
record.test_field.value = 'hello'
dump()
def testOptionalBlankField(self):
(record, dump) = self.createTestRecord(required=False, blank=True)
record.test_field.value # OK if nothing is ever assigned
dump()
record.test_field.value = '' # OK if empty string is assigned
dump()
record.test_field.value = 'hello'
dump()
def testOptionalNonBlankField(self):
(record, dump) = self.createTestRecord(required=False, blank=False)
record.test_field.value # OK if nothing is ever assigned
dump()
record.test_field.value = '' # OK if empty string is assigned
dump()
record.test_field.value = 'hello'
dump()