Compare commits

..

29 commits

Author SHA1 Message Date
9029659f98 update internal VERSION property 2025-05-13 12:45:51 -05:00
1302de9df7 bump version to 0.2025.0 2025-05-13 12:14:02 -05:00
fb8091fb09 change Iowa RS record state_employer_account_num from TextField to IntegerField 2025-05-13 12:09:49 -05:00
4408da71a9 mark some fields as optional 2024-04-10 09:41:10 -04:00
e0e4c1291d add min_length option to TextField for SSNs and stuff like that 2024-03-31 11:52:22 -04:00
5f4dc8b80f add 'blank' field option to allow empty text in required fields (default: false) 2024-03-31 11:14:16 -04:00
74b7935ced bump version to 2024 2024-03-29 10:50:25 -04:00
66573e4d1d update for 2023 p1220 parsing, stupid irs 2024-03-29 10:48:04 -04:00
86f8861da1 encode record delimiter as ascii bytes when str is passed 2022-02-06 11:06:51 -06:00
042de7ecb0 import typing.Callable (python 3.10+) 2021-12-18 08:56:43 -05:00
f28cd6edf2 bump version 0.2020.0 2021-09-03 07:48:24 -05:00
0bd82e09c4 Fix StaticField + tests for StaticField and unset optional TextField 2021-09-03 05:45:01 -05:00
558e3fd232 hopefully fix STaticField 2021-09-02 17:40:35 -05:00
7867a52a0c fliped args around like a simpleton 2021-01-29 16:26:26 -05:00
bfd43b7448 release 0.2018.2 2020-06-12 14:45:08 -05:00
1f1d3dd9bb Merge branch 'conversion-support' 2020-06-12 13:13:28 -05:00
431b594c1e add pyaccuwage-convert 2020-06-12 13:10:13 -05:00
8f86f76167 add format interchange functions, add tests, fix stuff 2020-06-12 13:07:41 -05:00
6af5067fca add option for record delimiter 2019-01-30 14:25:24 -06:00
250ca8d31f fix flubbed blank field specifier on StateTotalRecordIA 2019-01-28 13:29:14 -06:00
7ddcfcc1c3 clean up some indent 2019-01-27 10:36:37 -06:00
d08f1ca586 hopefully fix python 2 and 3 compatability 2019-01-27 09:30:22 -06:00
6381f8b1ec bump version to 2018.01 2019-01-26 16:11:24 -06:00
7c32cb0dd3 add StateTotalRecord for Iowa 2019-01-26 11:49:31 -06:00
5afdcd6a50 add 'permitted_benefits_health' to RT and RO records for 2017 2018-01-27 11:38:23 -06:00
706c39f7bb CRLFField return binary data for get_data() 2017-10-29 10:41:52 -05:00
078273f49f fix json encoding by encoding bytes as ascii 2017-01-07 17:00:29 -06:00
9320c68961 use BytesIO to work with python3 2017-01-07 14:52:33 -06:00
16bf2c41d0 run through 2to3 2017-01-07 13:58:33 -06:00
14 changed files with 1140 additions and 949 deletions

View file

@@ -1,7 +1,9 @@
from record import *
from reader import RecordReader
try:
from collections import Callable
except:
from typing import Callable # Python 3.10+
VERSION = (0, 2012, 0)
VERSION = (0, 2025, 0)
RECORD_TYPES = [
'SubmitterRecord',
@@ -13,129 +15,156 @@ RECORD_TYPES = [
'OptionalTotalRecord',
'StateTotalRecord',
'FinalRecord'
]
def test():
import record, model
from fields import ValidationError
for rname in RECORD_TYPES:
inst = record.__dict__[rname]()
try:
output_length = len(inst.output())
except ValidationError, e:
print e.msg, type(inst), inst.record_identifier
continue
print type(inst), inst.record_identifier, output_length
]
def test_dump():
import record, StringIO
records = [
record.SubmitterRecord(),
record.EmployerRecord(),
record.EmployeeWageRecord(),
]
out = StringIO.StringIO()
dump(records, out)
return out
def test_record_order():
import record
records = [
record.SubmitterRecord(),
record.EmployerRecord(),
record.EmployeeWageRecord(),
record.TotalRecord(),
record.FinalRecord(),
]
validate_record_order(records)
def test_load(fp):
return load(fp)
def load(fp):
# BUILD LIST OF RECORD TYPES
import record
def get_record_types():
from . import record
types = {}
for r in RECORD_TYPES:
klass = record.__dict__[r]
types[klass.record_identifier] = klass
return types
def load(fp, record_types):
distinct_identifier_lengths = set([len(record_types[k].record_identifier) for k in record_types])
assert(len(distinct_identifier_lengths) == 1)
ident_length = list(distinct_identifier_lengths)[0]
# Add aliases for the record types based on their record_identifier since that's all
# we have to work with with the e1099 data.
record_types_by_ident = {}
for k in record_types:
record_type = record_types[k]
record_identifier = record_type.record_identifier
record_types_by_ident[record_identifier] = record_type
# PARSE DATA INTO RECORDS AND YIELD THEM
while fp.tell() < fp.len:
record_ident = fp.read(2)
if record_ident in types:
record = types[record_ident]()
while True:
record_ident = fp.read(ident_length)
if not record_ident:
break
if record_ident in record_types_by_ident:
record = record_types_by_ident[record_ident]()
record.read(fp)
yield record
def loads(s):
import StringIO
fp = StringIO.StringIO(s)
return load(fp)
def loads(s, record_types=get_record_types()):
import io
fp = io.BytesIO(s)
return load(fp, record_types)
def dump(records, fp):
def dump(fp, records, delim=None):
if type(delim) is str:
delim = delim.encode('ascii')
for r in records:
fp.write(r.output())
if delim:
fp.write(delim)
def dumps(records):
import StringIO
fp = StringIO.StringIO()
dump(records, fp)
def dumps(records, delim=None, skip_validation=False):
import io
fp = io.BytesIO()
if not skip_validation:
for record in records:
record.validate()
dump(fp, records, delim=delim)
fp.seek(0)
return fp.read()
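The reworked interchange helpers operate on bytes and take an explicit record-type mapping: get_record_types() builds the identifier-to-class dict, dump/dumps take the file object first and an optional delimiter (a str delimiter is ASCII-encoded before being written), and load/loads yield parsed record instances. A minimal round-trip sketch, assuming the package is importable as pyaccuwage (the import name is an assumption, not part of this diff):

# Hypothetical usage sketch; "pyaccuwage" as the package name is an assumption.
import pyaccuwage
from pyaccuwage import record

records = [
    record.SubmitterRecord(),
    record.EmployerRecord(),
    record.EmployeeWageRecord(),
]

# skip_validation=True keeps the sketch runnable without filling in every
# required field; a real submission should be validated.
data = pyaccuwage.dumps(records, skip_validation=True)

# dump() now takes the file object first; a str delimiter is encoded to ASCII bytes.
with open('w2report.txt', 'wb') as fp:
    pyaccuwage.dump(fp, records, delim='\r\n')

# loads() defaults to get_record_types() and yields record instances.
for parsed in pyaccuwage.loads(data):
    print(type(parsed).__name__)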
def json_dumps(records):
import json
import model
import decimal
class JSONEncoder(json.JSONEncoder):
def default(self, o):
if hasattr(o, 'toJSON') and callable(getattr(o, 'toJSON')):
if hasattr(o, 'toJSON') and isinstance(getattr(o, 'toJSON'), Callable):
return o.toJSON()
if type(o) is bytes:
return o.decode('ascii')
elif isinstance(o, decimal.Decimal):
return str(o.quantize(decimal.Decimal('0.01')))
return super(JSONEncoder, self).default(o)
return json.dumps(records, cls=JSONEncoder, indent=2)
return json.dumps(list(records), cls=JSONEncoder, indent=2)
def json_loads(s, record_classes):
def json_dump(fp, records):
fp.write(json_dumps(records))
def json_loads(s, record_types):
import json
import fields
from . import fields
import decimal
import re
if not isinstance(record_classes, dict):
record_classes = dict([ (x.__class__.__name__, x) for x in record_classes])
if not isinstance(record_types, dict):
record_types = dict([ (x.__name__, x) for x in record_types])
def object_hook(o):
if '__class__' in o:
klass = o['__class__']
if klass in record_classes:
return record_classes[klass]().fromJSON(o)
if klass in record_types:
record = record_types[klass]()
record.fromJSON(o)
return record
elif hasattr(fields, klass):
return getattr(fields, klass)().fromJSON(o)
return o
#print "OBJECTHOOK", str(o)
#return {'object_hook':str(o)}
#def default(self, o):
# return super(JSONDecoder, self).default(o)
return json.loads(s, parse_float=decimal.Decimal, object_hook=object_hook)
def json_load(fp, record_types):
return json_loads(fp.read(), record_types)
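json_dumps serializes records through their toJSON hooks (bytes are decoded to ASCII, Decimals quantized to cents), and json_loads rebuilds them via the '__class__' key; it accepts either a name-to-class dict or an iterable of record classes. A short sketch under the same assumed package name:

from pyaccuwage import record, json_dumps, json_loads

employee = record.EmployeeWageRecord()
employee.employee_first_name = 'JOHN'

text = json_dumps([employee])
# Passing the classes themselves; json_loads builds the name-to-class map internally.
restored = json_loads(text, [record.EmployeeWageRecord])
print(restored[0].employee_first_name.value)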
def text_dump(fp, records):
for r in records:
fp.write(r.output(format='text').encode('ascii'))
def text_dumps(records):
import io
fp = io.BytesIO()
text_dump(fp, records)
fp.seek(0)
return fp.read()
def text_load(fp, record_classes):
records = []
current_record = None
if not isinstance(record_classes, dict):
record_classes = dict([ (x.__name__, x) for x in record_classes])
while True: #fp.readable():
line = fp.readline().decode('ascii')
if not line:
break
if line.startswith('---'):
record_name = line.strip('---').strip()
current_record = record_classes[record_name]()
records.append(current_record)
elif ':' in line:
field, value = [x.strip() for x in line.split(':')]
current_record.set_field_value(field, value)
return records
def text_loads(s, record_classes):
import io
fp = io.BytesIO(s)
return text_load(fp, record_classes)
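The text format writes a '---ClassName' header per record followed by 'field: value' lines, and text_load/text_loads parse that back into record instances, so records can be hand-edited between dump and load. Sketch, same package-name assumption:

from pyaccuwage import record, text_dumps, text_loads

blob = text_dumps([record.EmployerRecord()])
print(blob.decode('ascii').splitlines()[0])   # ---EmployerRecord

restored = text_loads(blob, [record.EmployerRecord])
print(type(restored[0]).__name__)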
# THIS WAS IN CONTROLLER, BUT UNLESS WE
# REALLY NEED A CONTROLLER CLASS, IT'S SIMPLER
# TO JUST KEEP IT IN HERE.
@@ -151,14 +180,15 @@ def validate_required_records(records):
while req_types:
req = req_types[0]
if req not in types:
from fields import ValidationError
from .fields import ValidationError
raise ValidationError("Record set missing required record: %s" % req)
else:
req_types.remove(req)
def validate_record_order(records):
import record
from fields import ValidationError
from . import record
from .fields import ValidationError
# 1st record must be SubmitterRecord
if not isinstance(records[0], record.SubmitterRecord):
@@ -178,10 +208,10 @@ def validate_record_order(records):
if not isinstance(records[i+1], record.EmployeeWageRecord):
raise ValidationError("All EmployerRecords must be followed by an EmployeeWageRecord")
num_ro_records = len(filter(lambda x:isinstance(x, record.OptionalEmployeeWageRecord), records))
num_ru_records = len(filter(lambda x:isinstance(x, record.OptionalTotalRecord), records))
num_employer_records = len(filter(lambda x:isinstance(x, record.EmployerRecord), records))
num_total_records = len(filter(lambda x: isinstance(x, record.TotalRecord), records))
num_ro_records = len([x for x in records if isinstance(x, record.OptionalEmployeeWageRecord)])
num_ru_records = len([x for x in records if isinstance(x, record.OptionalTotalRecord)])
num_employer_records = len([x for x in records if isinstance(x, record.EmployerRecord)])
num_total_records = len([x for x in records if isinstance(x, record.TotalRecord)])
# a TotalRecord is required for each instance of an EmployeeRecord
if num_total_records != num_employer_records:
@@ -194,7 +224,7 @@ def validate_record_order(records):
num_ro_records, num_ru_records))
# FinalRecord - Must appear only once on each file.
if len(filter(lambda x:isinstance(x, record.FinalRecord), records)) != 1:
if len([x for x in records if isinstance(x, record.FinalRecord)]) != 1:
raise ValidationError("Incorrect number of FinalRecords")
def validate_records(records):
@@ -207,13 +237,8 @@ def test_unique_fields():
r1.employee_first_name.value = "John Johnson"
r2 = EmployeeWageRecord()
print 'r1:', r1.employee_first_name.value, r1.employee_first_name, r1.employee_first_name.creation_counter
print 'r2:', r2.employee_first_name.value, r2.employee_first_name, r2.employee_first_name.creation_counter
print('r1:', r1.employee_first_name.value, r1.employee_first_name, r1.employee_first_name.creation_counter)
print('r2:', r2.employee_first_name.value, r2.employee_first_name, r2.employee_first_name.creation_counter)
if r1.employee_first_name.value == r2.employee_first_name.value:
raise ValidationError("Horrible problem involving shared values across records")
#def state_postal_code(state_abbr):
# import enums
# return enums.state_postal_numeric[ state_abbr.upper() ]

View file

@@ -1,338 +1,340 @@
state_postal_numeric = {
'AL': 1,
'AK': 2,
'AZ': 4,
'AR': 5,
'CA': 6,
'CO': 8,
'CT': 9,
'DE': 10,
'DC': 11,
'FL': 12,
'GA': 13,
'HI': 15,
'ID': 16,
'IL': 17,
'IN': 18,
'IA': 19,
'KS': 20,
'KY': 21,
'LA': 22,
'ME': 23,
'MD': 24,
'MA': 25,
'MI': 26,
'MN': 27,
'MS': 28,
'MO': 29,
'MT': 30,
'NE': 31,
'NV': 32,
'NH': 33,
'NJ': 34,
'NM': 35,
'NY': 36,
'NC': 37,
'ND': 38,
'OH': 39,
'OK': 40,
'OR': 41,
'PA': 42,
'RI': 44,
'SC': 45,
'SD': 46,
'TN': 47,
'TX': 48,
'UT': 49,
'VT': 50,
'VA': 51,
'WA': 53,
'WV': 54,
'WI': 55,
'WY': 56,
'AL': 1,
'AK': 2,
'AZ': 4,
'AR': 5,
'CA': 6,
'CO': 8,
'CT': 9,
'DE': 10,
'DC': 11,
'FL': 12,
'GA': 13,
'HI': 15,
'ID': 16,
'IL': 17,
'IN': 18,
'IA': 19,
'KS': 20,
'KY': 21,
'LA': 22,
'ME': 23,
'MD': 24,
'MA': 25,
'MI': 26,
'MN': 27,
'MS': 28,
'MO': 29,
'MT': 30,
'NE': 31,
'NV': 32,
'NH': 33,
'NJ': 34,
'NM': 35,
'NY': 36,
'NC': 37,
'ND': 38,
'OH': 39,
'OK': 40,
'OR': 41,
'PA': 42,
'RI': 44,
'SC': 45,
'SD': 46,
'TN': 47,
'TX': 48,
'UT': 49,
'VT': 50,
'VA': 51,
'WA': 53,
'WV': 54,
'WI': 55,
'WY': 56,
}
countries = (
('AF', 'Afghanistan'),
('AX', 'Aland Islands'),
('AL', 'Albania'),
('DZ', 'Algeria'),
('AS', 'American Samoa'),
('AD', 'Andorra'),
('AO', 'Angola'),
('AI', 'Anguilla'),
('AQ', 'Antarctica'),
('AG', 'Antigua and Barbuda'),
('AR', 'Argentina'),
('AM', 'Armenia'),
('AW', 'Aruba'),
('AU', 'Australia'),
('AT', 'Austria'),
('AZ', 'Azerbaijan'),
('BS', 'Bahamas'),
('BH', 'Bahrain'),
('BD', 'Bangladesh'),
('BB', 'Barbados'),
('BY', 'Belarus'),
('BE', 'Belgium'),
('BZ', 'Belize'),
('BJ', 'Benin'),
('BM', 'Bermuda'),
('BT', 'Bhutan'),
('BO', 'Bolivia, Plurinational State of'),
('BQ', 'Bonaire, Saint Eustatius and Saba'),
('BA', 'Bosnia and Herzegovina'),
('BW', 'Botswana'),
('BV', 'Bouvet Island'),
('BR', 'Brazil'),
('IO', 'British Indian Ocean Territory'),
('BN', 'Brunei Darussalam'),
('BG', 'Bulgaria'),
('BF', 'Burkina Faso'),
('BI', 'Burundi'),
('KH', 'Cambodia'),
('CM', 'Cameroon'),
('CA', 'Canada'),
('CV', 'Cape Verde'),
('KY', 'Cayman Islands'),
('CF', 'Central African Republic'),
('TD', 'Chad'),
('CL', 'Chile'),
('CN', 'China'),
('CX', 'Christmas Island'),
('CC', 'Cocos (Keeling) Islands'),
('CO', 'Colombia'),
('KM', 'Comoros'),
('CG', 'Congo'),
('CD', 'Congo, The Democratic Republic of the'),
('CK', 'Cook Islands'),
('CR', 'Costa Rica'),
('CI', "Cote D'ivoire"),
('HR', 'Croatia'),
('CU', 'Cuba'),
('CW', 'Curacao'),
('CY', 'Cyprus'),
('CZ', 'Czech Republic'),
('DK', 'Denmark'),
('DJ', 'Djibouti'),
('DM', 'Dominica'),
('DO', 'Dominican Republic'),
('EC', 'Ecuador'),
('EG', 'Egypt'),
('SV', 'El Salvador'),
('GQ', 'Equatorial Guinea'),
('ER', 'Eritrea'),
('EE', 'Estonia'),
('ET', 'Ethiopia'),
('FK', 'Falkland Islands (Malvinas)'),
('FO', 'Faroe Islands'),
('FJ', 'Fiji'),
('FI', 'Finland'),
('FR', 'France'),
('GF', 'French Guiana'),
('PF', 'French Polynesia'),
('TF', 'French Southern Territories'),
('GA', 'Gabon'),
('GM', 'Gambia'),
('GE', 'Georgia'),
('DE', 'Germany'),
('GH', 'Ghana'),
('GI', 'Gibraltar'),
('GR', 'Greece'),
('GL', 'Greenland'),
('GD', 'Grenada'),
('GP', 'Guadeloupe'),
('GU', 'Guam'),
('GT', 'Guatemala'),
('GG', 'Guernsey'),
('GN', 'Guinea'),
('GW', 'Guinea-Bissau'),
('GY', 'Guyana'),
('HT', 'Haiti'),
('HM', 'Heard Island and McDonald Islands'),
('VA', 'Holy See (Vatican City State)'),
('HN', 'Honduras'),
('HK', 'Hong Kong'),
('HU', 'Hungary'),
('IS', 'Iceland'),
('IN', 'India'),
('ID', 'Indonesia'),
('IR', 'Iran, Islamic Republic of'),
('IQ', 'Iraq'),
('IE', 'Ireland'),
('IM', 'Isle of Man'),
('IL', 'Israel'),
('IT', 'Italy'),
('JM', 'Jamaica'),
('JP', 'Japan'),
('JE', 'Jersey'),
('JO', 'Jordan'),
('KZ', 'Kazakhstan'),
('KE', 'Kenya'),
('KI', 'Kiribati'),
('KP', "Korea, Democratic People's Republic of"),
('KR', 'Korea, Republic of'),
('KW', 'Kuwait'),
('KG', 'Kyrgyzstan'),
('LA', "Lao People's Democratic Republic"),
('LV', 'Latvia'),
('LB', 'Lebanon'),
('LS', 'Lesotho'),
('LR', 'Liberia'),
('LY', 'Libyan Arab Jamahiriya'),
('LI', 'Liechtenstein'),
('LT', 'Lithuania'),
('LU', 'Luxembourg'),
('MO', 'Macao'),
('MK', 'Macedonia, The Former Yugoslav Republic of'),
('MG', 'Madagascar'),
('MW', 'Malawi'),
('MY', 'Malaysia'),
('MV', 'Maldives'),
('ML', 'Mali'),
('MT', 'Malta'),
('MH', 'Marshall Islands'),
('MQ', 'Martinique'),
('MR', 'Mauritania'),
('MU', 'Mauritius'),
('YT', 'Mayotte'),
('MX', 'Mexico'),
('FM', 'Micronesia, Federated States of'),
('MD', 'Moldova, Republic of'),
('MC', 'Monaco'),
('MN', 'Mongolia'),
('ME', 'Montenegro'),
('MS', 'Montserrat'),
('MA', 'Morocco'),
('MZ', 'Mozambique'),
('MM', 'Myanmar'),
('NA', 'Namibia'),
('NR', 'Nauru'),
('NP', 'Nepal'),
('NL', 'Netherlands'),
('NC', 'New Caledonia'),
('NZ', 'New Zealand'),
('NI', 'Nicaragua'),
('NE', 'Niger'),
('NG', 'Nigeria'),
('NU', 'Niue'),
('NF', 'Norfolk Island'),
('MP', 'Northern Mariana Islands'),
('NO', 'Norway'),
('OM', 'Oman'),
('PK', 'Pakistan'),
('PW', 'Palau'),
('PS', 'Palestinian Territory, Occupied'),
('PA', 'Panama'),
('PG', 'Papua New Guinea'),
('PY', 'Paraguay'),
('PE', 'Peru'),
('PH', 'Philippines'),
('PN', 'Pitcairn'),
('PL', 'Poland'),
('PT', 'Portugal'),
('PR', 'Puerto Rico'),
('QA', 'Qatar'),
('RE', 'Reunion'),
('RO', 'Romania'),
('RU', 'Russian Federation'),
('RW', 'Rwanda'),
('BL', 'Saint Barthelemy'),
('SH', 'Saint Helena, Ascension and Tristan Da Cunha'),
('KN', 'Saint Kitts and Nevis'),
('LC', 'Saint Lucia'),
('MF', 'Saint Martin (French Part)'),
('PM', 'Saint Pierre and Miquelon'),
('VC', 'Saint Vincent and the Grenadines'),
('WS', 'Samoa'),
('SM', 'San Marino'),
('ST', 'Sao Tome and Principe'),
('SA', 'Saudi Arabia'),
('SN', 'Senegal'),
('RS', 'Serbia'),
('SC', 'Seychelles'),
('SL', 'Sierra Leone'),
('SG', 'Singapore'),
('SX', 'Sint Maarten (Dutch Part)'),
('SK', 'Slovakia'),
('SI', 'Slovenia'),
('SB', 'Solomon Islands'),
('SO', 'Somalia'),
('ZA', 'South Africa'),
('GS', 'South Georgia and the South Sandwich Islands'),
('ES', 'Spain'),
('LK', 'Sri Lanka'),
('SD', 'Sudan'),
('SR', 'Suriname'),
('SJ', 'Svalbard and Jan Mayen'),
('SZ', 'Swaziland'),
('SE', 'Sweden'),
('CH', 'Switzerland'),
('SY', 'Syrian Arab Republic'),
('TW', 'Taiwan, Province of China'),
('TJ', 'Tajikistan'),
('TZ', 'Tanzania, United Republic of'),
('TH', 'Thailand'),
('TL', 'Timor-Leste'),
('TG', 'Togo'),
('TK', 'Tokelau'),
('TO', 'Tonga'),
('TT', 'Trinidad and Tobago'),
('TN', 'Tunisia'),
('TR', 'Turkey'),
('TM', 'Turkmenistan'),
('TC', 'Turks and Caicos Islands'),
('TV', 'Tuvalu'),
('UG', 'Uganda'),
('UA', 'Ukraine'),
('AE', 'United Arab Emirates'),
('GB', 'United Kingdom'),
('US', 'United States'),
('UM', 'United States Minor Outlying Islands'),
('UY', 'Uruguay'),
('UZ', 'Uzbekistan'),
('VU', 'Vanuatu'),
('VE', 'Venezuela, Bolivarian Republic of'),
('VN', 'Viet Nam'),
('VG', 'Virgin Islands, British'),
('VI', 'Virgin Islands, U.S.'),
('WF', 'Wallis and Futuna'),
('EH', 'Western Sahara'),
('YE', 'Yemen'),
('ZM', 'Zambia'),
('ZW', 'Zimbabwe'))
('AF', 'Afghanistan'),
('AX', 'Aland Islands'),
('AL', 'Albania'),
('DZ', 'Algeria'),
('AS', 'American Samoa'),
('AD', 'Andorra'),
('AO', 'Angola'),
('AI', 'Anguilla'),
('AQ', 'Antarctica'),
('AG', 'Antigua and Barbuda'),
('AR', 'Argentina'),
('AM', 'Armenia'),
('AW', 'Aruba'),
('AU', 'Australia'),
('AT', 'Austria'),
('AZ', 'Azerbaijan'),
('BS', 'Bahamas'),
('BH', 'Bahrain'),
('BD', 'Bangladesh'),
('BB', 'Barbados'),
('BY', 'Belarus'),
('BE', 'Belgium'),
('BZ', 'Belize'),
('BJ', 'Benin'),
('BM', 'Bermuda'),
('BT', 'Bhutan'),
('BO', 'Bolivia, Plurinational State of'),
('BQ', 'Bonaire, Saint Eustatius and Saba'),
('BA', 'Bosnia and Herzegovina'),
('BW', 'Botswana'),
('BV', 'Bouvet Island'),
('BR', 'Brazil'),
('IO', 'British Indian Ocean Territory'),
('BN', 'Brunei Darussalam'),
('BG', 'Bulgaria'),
('BF', 'Burkina Faso'),
('BI', 'Burundi'),
('KH', 'Cambodia'),
('CM', 'Cameroon'),
('CA', 'Canada'),
('CV', 'Cape Verde'),
('KY', 'Cayman Islands'),
('CF', 'Central African Republic'),
('TD', 'Chad'),
('CL', 'Chile'),
('CN', 'China'),
('CX', 'Christmas Island'),
('CC', 'Cocos (Keeling) Islands'),
('CO', 'Colombia'),
('KM', 'Comoros'),
('CG', 'Congo'),
('CD', 'Congo, The Democratic Republic of the'),
('CK', 'Cook Islands'),
('CR', 'Costa Rica'),
('CI', "Cote D'ivoire"),
('HR', 'Croatia'),
('CU', 'Cuba'),
('CW', 'Curacao'),
('CY', 'Cyprus'),
('CZ', 'Czech Republic'),
('DK', 'Denmark'),
('DJ', 'Djibouti'),
('DM', 'Dominica'),
('DO', 'Dominican Republic'),
('EC', 'Ecuador'),
('EG', 'Egypt'),
('SV', 'El Salvador'),
('GQ', 'Equatorial Guinea'),
('ER', 'Eritrea'),
('EE', 'Estonia'),
('ET', 'Ethiopia'),
('FK', 'Falkland Islands (Malvinas)'),
('FO', 'Faroe Islands'),
('FJ', 'Fiji'),
('FI', 'Finland'),
('FR', 'France'),
('GF', 'French Guiana'),
('PF', 'French Polynesia'),
('TF', 'French Southern Territories'),
('GA', 'Gabon'),
('GM', 'Gambia'),
('GE', 'Georgia'),
('DE', 'Germany'),
('GH', 'Ghana'),
('GI', 'Gibraltar'),
('GR', 'Greece'),
('GL', 'Greenland'),
('GD', 'Grenada'),
('GP', 'Guadeloupe'),
('GU', 'Guam'),
('GT', 'Guatemala'),
('GG', 'Guernsey'),
('GN', 'Guinea'),
('GW', 'Guinea-Bissau'),
('GY', 'Guyana'),
('HT', 'Haiti'),
('HM', 'Heard Island and McDonald Islands'),
('VA', 'Holy See (Vatican City State)'),
('HN', 'Honduras'),
('HK', 'Hong Kong'),
('HU', 'Hungary'),
('IS', 'Iceland'),
('IN', 'India'),
('ID', 'Indonesia'),
('IR', 'Iran, Islamic Republic of'),
('IQ', 'Iraq'),
('IE', 'Ireland'),
('IM', 'Isle of Man'),
('IL', 'Israel'),
('IT', 'Italy'),
('JM', 'Jamaica'),
('JP', 'Japan'),
('JE', 'Jersey'),
('JO', 'Jordan'),
('KZ', 'Kazakhstan'),
('KE', 'Kenya'),
('KI', 'Kiribati'),
('KP', "Korea, Democratic People's Republic of"),
('KR', 'Korea, Republic of'),
('KW', 'Kuwait'),
('KG', 'Kyrgyzstan'),
('LA', "Lao People's Democratic Republic"),
('LV', 'Latvia'),
('LB', 'Lebanon'),
('LS', 'Lesotho'),
('LR', 'Liberia'),
('LY', 'Libyan Arab Jamahiriya'),
('LI', 'Liechtenstein'),
('LT', 'Lithuania'),
('LU', 'Luxembourg'),
('MO', 'Macao'),
('MK', 'Macedonia, The Former Yugoslav Republic of'),
('MG', 'Madagascar'),
('MW', 'Malawi'),
('MY', 'Malaysia'),
('MV', 'Maldives'),
('ML', 'Mali'),
('MT', 'Malta'),
('MH', 'Marshall Islands'),
('MQ', 'Martinique'),
('MR', 'Mauritania'),
('MU', 'Mauritius'),
('YT', 'Mayotte'),
('MX', 'Mexico'),
('FM', 'Micronesia, Federated States of'),
('MD', 'Moldova, Republic of'),
('MC', 'Monaco'),
('MN', 'Mongolia'),
('ME', 'Montenegro'),
('MS', 'Montserrat'),
('MA', 'Morocco'),
('MZ', 'Mozambique'),
('MM', 'Myanmar'),
('NA', 'Namibia'),
('NR', 'Nauru'),
('NP', 'Nepal'),
('NL', 'Netherlands'),
('NC', 'New Caledonia'),
('NZ', 'New Zealand'),
('NI', 'Nicaragua'),
('NE', 'Niger'),
('NG', 'Nigeria'),
('NU', 'Niue'),
('NF', 'Norfolk Island'),
('MP', 'Northern Mariana Islands'),
('NO', 'Norway'),
('OM', 'Oman'),
('PK', 'Pakistan'),
('PW', 'Palau'),
('PS', 'Palestinian Territory, Occupied'),
('PA', 'Panama'),
('PG', 'Papua New Guinea'),
('PY', 'Paraguay'),
('PE', 'Peru'),
('PH', 'Philippines'),
('PN', 'Pitcairn'),
('PL', 'Poland'),
('PT', 'Portugal'),
('PR', 'Puerto Rico'),
('QA', 'Qatar'),
('RE', 'Reunion'),
('RO', 'Romania'),
('RU', 'Russian Federation'),
('RW', 'Rwanda'),
('BL', 'Saint Barthelemy'),
('SH', 'Saint Helena, Ascension and Tristan Da Cunha'),
('KN', 'Saint Kitts and Nevis'),
('LC', 'Saint Lucia'),
('MF', 'Saint Martin (French Part)'),
('PM', 'Saint Pierre and Miquelon'),
('VC', 'Saint Vincent and the Grenadines'),
('WS', 'Samoa'),
('SM', 'San Marino'),
('ST', 'Sao Tome and Principe'),
('SA', 'Saudi Arabia'),
('SN', 'Senegal'),
('RS', 'Serbia'),
('SC', 'Seychelles'),
('SL', 'Sierra Leone'),
('SG', 'Singapore'),
('SX', 'Sint Maarten (Dutch Part)'),
('SK', 'Slovakia'),
('SI', 'Slovenia'),
('SB', 'Solomon Islands'),
('SO', 'Somalia'),
('ZA', 'South Africa'),
('GS', 'South Georgia and the South Sandwich Islands'),
('ES', 'Spain'),
('LK', 'Sri Lanka'),
('SD', 'Sudan'),
('SR', 'Suriname'),
('SJ', 'Svalbard and Jan Mayen'),
('SZ', 'Swaziland'),
('SE', 'Sweden'),
('CH', 'Switzerland'),
('SY', 'Syrian Arab Republic'),
('TW', 'Taiwan, Province of China'),
('TJ', 'Tajikistan'),
('TZ', 'Tanzania, United Republic of'),
('TH', 'Thailand'),
('TL', 'Timor-Leste'),
('TG', 'Togo'),
('TK', 'Tokelau'),
('TO', 'Tonga'),
('TT', 'Trinidad and Tobago'),
('TN', 'Tunisia'),
('TR', 'Turkey'),
('TM', 'Turkmenistan'),
('TC', 'Turks and Caicos Islands'),
('TV', 'Tuvalu'),
('UG', 'Uganda'),
('UA', 'Ukraine'),
('AE', 'United Arab Emirates'),
('GB', 'United Kingdom'),
('US', 'United States'),
('UM', 'United States Minor Outlying Islands'),
('UY', 'Uruguay'),
('UZ', 'Uzbekistan'),
('VU', 'Vanuatu'),
('VE', 'Venezuela, Bolivarian Republic of'),
('VN', 'Viet Nam'),
('VG', 'Virgin Islands, British'),
('VI', 'Virgin Islands, U.S.'),
('WF', 'Wallis and Futuna'),
('EH', 'Western Sahara'),
('YE', 'Yemen'),
('ZM', 'Zambia'),
('ZW', 'Zimbabwe'),
)
employer_types = (
('F','Federal Government'),
('S','State and Local Governmental Employer'),
('T','Tax Exempt Employer'),
('Y','State and Local Tax Exempt Employer'),
('N','None Apply'),
)
('F','Federal Government'),
('S','State and Local Governmental Employer'),
('T','Tax Exempt Employer'),
('Y','State and Local Tax Exempt Employer'),
('N','None Apply'),
)
employment_codes = (
('A', 'Agriculture'),
('H', 'Household'),
('M', 'Military'),
('Q', 'Medicare Qualified Government Employee'),
('X', 'Railroad'),
('F', 'Regular'),
('R', 'Regular (all others)'),
)
('A', 'Agriculture'),
('H', 'Household'),
('M', 'Military'),
('Q', 'Medicare Qualified Government Employee'),
('X', 'Railroad'),
('F', 'Regular'),
('R', 'Regular (all others)'),
)
tax_jurisdiction_codes = (
('V', 'Virgin Islands'),
('G', 'Guam'),
('S', 'American Samoa'),
('N', 'Northern Mariana Islands'),
('P', 'Puerto Rico'),
)
(' ', 'W-2'),
('V', 'Virgin Islands'),
('G', 'Guam'),
('S', 'American Samoa'),
('N', 'Northern Mariana Islands'),
('P', 'Puerto Rico'),
)
tax_type_codes = (
('C', 'City Income Tax'),
('D', 'Country Income Tax'),
('E', 'School District Income Tax'),
('F', 'Other Income Tax'),
)
('C', 'City Income Tax'),
('D', 'Country Income Tax'),
('E', 'School District Income Tax'),
('F', 'Other Income Tax'),
)

View file

@@ -1,6 +1,10 @@
import decimal, datetime
import inspect
import enums
from six import string_types
from . import enums
def is_blank_space(val):
return len(val.strip()) == 0
class ValidationError(Exception):
def __init__(self, msg, field=None):
@@ -16,22 +20,26 @@ class ValidationError(Exception):
class Field(object):
creation_counter = 0
is_read_only = False
_value = None
def __init__(self, name=None, max_length=0, required=True, uppercase=True, creation_counter=None):
def __init__(self, name=None, min_length=0, max_length=0, blank=False, required=True, uppercase=True, creation_counter=None):
self.name = name
self._value = None
self._orig_value = None
self.min_length = min_length
self.max_length = max_length
self.blank = blank
self.required = required
self.uppercase = uppercase
self.creation_counter = creation_counter or Field.creation_counter
Field.creation_counter += 1
def validate(self):
raise NotImplemented
raise NotImplementedError
def get_data(self):
raise NotImplemented
raise NotImplementedError
def __setvalue(self, value):
self._value = value
@@ -76,7 +84,7 @@ class Field(object):
required=o['required'],
)
if isinstance(o['value'], basestring) and re.match('^\d*\.\d*$', o['value']):
if isinstance(o['value'], str) and re.match(r'^\d*\.\d*$', o['value']):
o['value'] = decimal.Decimal(o['value'])
self.value = o['value']
@@ -90,14 +98,10 @@ class Field(object):
wrapper = textwrap.TextWrapper(replace_whitespace=False, drop_whitespace=False)
wrapper.width = 100
value = wrapper.wrap(value)
#value = textwrap.wrap(value, 100)
#print value
value = list(map(lambda x:(" " * 9) + ('"' + x + '"'), value))
#value[0] = '"' + value[0] + '"'
value.append(" " * 10 + ('_' * 10) * (wrapper.width / 10))
value.append(" " * 10 + ('0123456789') * (wrapper.width / 10))
value.append(" " * 10 + ''.join((map(lambda x:str(x) + (' ' * 9), range(wrapper.width / 10 )))))
#value.append((" " * 59) + map(lambda x:("%x" % x), range(16))
value = list([(" " * 9) + ('"' + x + '"') for x in value])
value.append(" " * 10 + ('_' * 10) * int(wrapper.width / 10))
value.append(" " * 10 + ('0123456789') * int(wrapper.width / 10))
value.append(" " * 10 + ''.join(([str(x) + (' ' * 9) for x in range(int(wrapper.width / 10))])))
start = counter['c']
counter['c'] += len(self._orig_value or self.value)
@@ -115,22 +119,28 @@ class Field(object):
class TextField(Field):
def validate(self):
if self.value == None and self.required:
if self.value is None and self.required:
raise ValidationError("value required", field=self)
if len(self.get_data()) > self.max_length:
data = self.get_data()
if len(data) > self.max_length:
raise ValidationError("value is too long", field=self)
stripped_data_length = len(data.strip())
if stripped_data_length < self.min_length:
raise ValidationError("value is too short", field=self)
if stripped_data_length == 0 and (not self.blank and self.required):
raise ValidationError("field cannot be blank", field=self)
def get_data(self):
value = self.value or ""
value = str(self.value or '').encode('ascii') or b''
if self.uppercase:
value = value.upper()
return value.ljust(self.max_length).encode('ascii')[:self.max_length]
return value.ljust(self.max_length)[:self.max_length]
def __setvalue(self, value):
# NO NEWLINES
try:
value = value.replace('\n', '').replace('\r', '')
except AttributeError, e:
except AttributeError:
pass
self._value = value
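With the new min_length and blank options, TextField can enforce a lower bound on the stripped value and explicitly allow an empty-but-required field. A small sketch (the import path is assumed):

from pyaccuwage.fields import TextField, ValidationError

account = TextField('state_account', min_length=9, max_length=12)
account.value = '12345'
try:
    account.validate()          # stripped length 5 < min_length 9
except ValidationError:
    print('too short')

note = TextField('note', max_length=10, blank=True)
note.value = ''
note.validate()                 # blank=True permits empty text in a required field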
@@ -142,31 +152,35 @@ class TextField(Field):
class StateField(TextField):
def __init__(self, name=None, required=True, use_numeric=False, max_length=2):
super(StateField, self).__init__(name=name, max_length=2, required=required)
super(StateField, self).__init__(name=name, max_length=max_length, required=required)
self.use_numeric = use_numeric
def get_data(self):
value = self.value or ""
value = str(self.value or 'XX')
if value.strip() and self.use_numeric:
return str(enums.state_postal_numeric[value.upper()]).zfill(self.max_length)
postcode = enums.state_postal_numeric[value.upper()]
postcode = str(postcode).encode('ascii')
return postcode.zfill(self.max_length)
else:
return value.ljust(self.max_length).encode('ascii')[:self.max_length]
formatted = value.encode('ascii').ljust(self.max_length)
return formatted[:self.max_length]
def validate(self):
super(StateField, self).validate()
if self.value and self.value.upper() not in enums.state_postal_numeric.keys():
if self.value and self.value.upper() not in list(enums.state_postal_numeric.keys()):
raise ValidationError("%s is not a valid state abbreviation" % self.value, field=self)
def parse(self, s):
if s.strip() and self.use_numeric:
states = dict( [(v,k) for (k,v) in enums.state_postal_numeric.items()] )
states = dict([(v, k) for (k, v) in list(enums.state_postal_numeric.items())])
self.value = states[int(s)]
else:
self.value = s
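StateField emits either the two-letter abbreviation or, with use_numeric=True, the zero-filled numeric code looked up in enums.state_postal_numeric; parse() reverses the lookup. Sketch (import path assumed):

from pyaccuwage.fields import StateField

state = StateField('state', use_numeric=True)
state.value = 'IA'
print(state.get_data())    # b'19', zero-filled to the 2-byte field width

state.parse('19')          # reverse lookup back to the abbreviation
print(state.value)         # IA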
class EmailField(TextField):
def __init__(self, name=None, required=True, max_length=None):
return super(EmailField, self).__init__(name=name, max_length=max_length,
super(EmailField, self).__init__(name=name, max_length=max_length,
required=required, uppercase=False)
class IntegerField(TextField):
@@ -178,37 +192,58 @@ class IntegerField(TextField):
except ValueError:
raise ValidationError("field contains non-numeric characters", field=self)
def get_data(self):
value = self.value or ""
return str(value).zfill(self.max_length)[:self.max_length]
value = str(self.value).encode('ascii') if self.value else b''
return value.zfill(self.max_length)[:self.max_length]
def parse(self, s):
self.value = int(s)
if not is_blank_space(s):
self.value = int(s)
else:
self.value = 0
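IntegerField zero-fills its value on output and now tolerates an all-blank source field by parsing it as 0 instead of raising. Sketch (import path assumed):

from pyaccuwage.fields import IntegerField

zipcode = IntegerField('zipcode', max_length=5)
zipcode.value = 319
print(zipcode.get_data())   # b'00319'

zipcode.parse('     ')      # blank input now parses to 0
print(zipcode.value)        # 0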
class StaticField(TextField):
def __init__(self, name=None, required=True, value=None):
super(StaticField, self).__init__(name=name, required=required,
max_length=len(value))
def __init__(self, name=None, required=True, value=None, uppercase=False):
super(StaticField, self).__init__(name=name,
required=required,
max_length=len(value),
uppercase=uppercase)
self._static_value = value
self._value = value
def parse(self, s):
pass
class BlankField(TextField):
is_read_only = True
def __init__(self, name=None, max_length=0, required=False):
super(TextField, self).__init__(name=name, max_length=max_length, required=required, uppercase=False)
super(BlankField, self).__init__(name=name, max_length=max_length, required=required, uppercase=False)
def get_data(self):
return " " * self.max_length
return b' ' * self.max_length
def parse(self, s):
pass
def validate(self):
if len(self.get_data()) != self.max_length:
raise ValidationError("blank field did not match expected length", field=self)
class ZeroField(BlankField):
is_read_only = True
def get_data(self):
return b'0' * self.max_length
class CRLFField(TextField):
is_read_only = True
def __init__(self, name=None, required=False):
super(TextField, self).__init__(name=name, max_length=2, required=required, uppercase=False)
super(CRLFField, self).__init__(name=name, max_length=2, required=required, uppercase=False)
def __setvalue(self, value):
self._value = value
@@ -219,11 +254,12 @@ class CRLFField(TextField):
value = property(__getvalue, __setvalue)
def get_data(self):
return '\r\n'
return b'\r\n'
def parse(self, s):
self.value = s
class BooleanField(Field):
def __init__(self, name=None, required=True, value=None):
super(BooleanField, self).__init__(name=name, required=required, max_length=1)
@@ -233,7 +269,7 @@ class BooleanField(Field):
pass
def get_data(self):
return '1' if self._value else '0'
return b'1' if self._value else b'0'
def parse(self, s):
self.value = (s == '1')
@@ -250,26 +286,43 @@ class MoneyField(Field):
raise ValidationError("value is too long", field=self)
def get_data(self):
return str(int((self.value or 0)*100)).encode('ascii').zfill(self.max_length)[:self.max_length]
cents = int((self.value or 0) * 100)
formatted = str(cents).encode('ascii').zfill(self.max_length)
return formatted[:self.max_length]
def parse(self, s):
self.value = decimal.Decimal(s) * decimal.Decimal('0.01')
if not is_blank_space(s):
self.value = decimal.Decimal(s) * decimal.Decimal('0.01')
else:
self.value = decimal.Decimal(0.0)
def __setvalue(self, value):
new_value = value
if isinstance(new_value, string_types):
new_value = decimal.Decimal(new_value or '0')
if '.' not in value: # must be cents?
new_value *= decimal.Decimal('100.')
self._value = new_value
def __getvalue(self):
return self._value
value = property(__getvalue, __setvalue)
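MoneyField keeps a Decimal dollar amount internally and writes zero-filled cents; parse() converts a fixed-width cents string back into dollars, and blank input now becomes Decimal(0). Sketch (import path assumed):

import decimal
from pyaccuwage.fields import MoneyField

wages = MoneyField('wages_tips', max_length=11)
wages.value = decimal.Decimal('1234.56')
print(wages.get_data())       # b'00000123456' (cents, zero-filled to 11 bytes)

wages.parse('00000123456')    # back to Decimal('1234.56')
print(wages.value)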
class DateField(TextField):
def __init__(self, name=None, required=True, value=None):
super(TextField, self).__init__(name=name, required=required, max_length=8)
super(DateField, self).__init__(name=name, required=required, max_length=8)
if value:
self.value = value
def get_data(self):
if self._value:
return self._value.strftime('%m%d%Y')
return '0' * self.max_length
return self._value.strftime('%m%d%Y').encode('ascii')
return b'0' * self.max_length
def parse(self, s):
if int(s) > 0:
self.value = datetime.date(*[int(x) for x in s[4:8], s[0:2], s[2:4]])
self.value = datetime.date(*[int(x) for x in (s[4:8], s[0:2], s[2:4])])
else:
self.value = None
@@ -277,7 +330,7 @@ class DateField(TextField):
if isinstance(value, datetime.date):
self._value = value
elif value:
self._value = datetime.date(*[int(x) for x in value[4:8], value[0:2], value[2:4]])
self._value = datetime.date(*[int(x) for x in (value[4:8], value[0:2], value[2:4])])
else:
self._value = None
@@ -289,19 +342,18 @@ class DateField(TextField):
class MonthYearField(TextField):
def __init__(self, name=None, required=True, value=None):
super(TextField, self).__init__(name=name, required=required, max_length=6)
super(MonthYearField, self).__init__(name=name, required=required, max_length=6)
if value:
self.value = value
def get_data(self):
if self._value:
return self._value.strftime("%m%Y")
return '0' * self.max_length
return str(self._value.strftime('%m%Y').encode('ascii'))
return b'0' * self.max_length
def parse(self, s):
if int(s) > 0:
self.value = datetime.date(*[int(x) for x in s[2:6], s[0:2], 1])
self.value = datetime.date(*[int(x) for x in (s[2:6], s[0:2], 1)])
else:
self.value = None
@@ -309,7 +361,7 @@ class MonthYearField(TextField):
if isinstance(value, datetime.date):
self._value = value
elif value:
self._value = datetime.date(*[int(x) for x in value[2:6], value[0:2], 1])
self._value = datetime.date(*[int(x) for x in (value[2:6], value[0:2], 1)])
else:
self._value = None
@@ -317,4 +369,3 @@ class MonthYearField(TextField):
return self._value
value = property(__getvalue, __setvalue)

View file

@@ -1,15 +1,19 @@
from fields import Field, TextField, ValidationError
from .fields import Field, TextField, ValidationError
import copy
import pdb
import collections
class Model(object):
record_length = -1
record_identifier = ' '
required = False
target_size = 512
def __init__(self):
for (key, value) in self.__class__.__dict__.items():
if self.record_length == -1:
raise ValueError(self.record_length)
for (key, value) in list(self.__class__.__dict__.items()):
if isinstance(value, Field):
# GRAB THE FIELD INSTANCE FROM THE CLASS DEFINITION
# AND MAKE A LOCAL COPY FOR THIS RECORD'S INSTANCE,
@@ -19,21 +23,31 @@ class Model(object):
if not src_field.name:
setattr(src_field, 'name', key)
setattr(src_field, 'parent_name', self.__class__.__name__)
self.__dict__[key] = copy.copy(src_field)
new_field_instance = copy.copy(src_field)
new_field_instance._orig_value = None
new_field_instance._value = new_field_instance.value
self.__dict__[key] = new_field_instance
def __setattr__(self, key, value):
if hasattr(self, key) and isinstance(getattr(self, key), Field):
getattr(self, key).value = value
self.set_field_value(key, value)
else:
# MAYBE THIS SHOULD RAISE A PROPERTY ERROR?
self.__dict__[key] = value
def set_field_value(self, field_name, value):
getattr(self, field_name).value = value
def get_fields(self):
identifier = TextField("record_identifier", max_length=len(self.record_identifier), creation_counter=-1)
identifier = TextField(
"record_identifier",
max_length = len(self.record_identifier),
blank = len(self.record_identifier) == 0,
creation_counter=-1)
identifier.value = self.record_identifier
fields = [identifier]
for key in self.__class__.__dict__.keys():
for key in list(self.__class__.__dict__.keys()):
attr = getattr(self, key)
if isinstance(attr, Field):
fields.append(attr)
@@ -41,7 +55,7 @@ class Model(object):
def get_sorted_fields(self):
fields = self.get_fields()
fields.sort(key=lambda x:x.creation_counter)
fields.sort(key=lambda x: x.creation_counter)
return fields
def validate(self):
@@ -50,27 +64,33 @@ class Model(object):
try:
custom_validator = getattr(self, 'validate_' + f.name)
except AttributeError, e:
except AttributeError:
continue
if callable(custom_validator):
custom_validator(f)
if isinstance(custom_validator, collections.Callable):
custom_validator(f)
def output(self):
result = ''.join([field.get_data() for field in self.get_sorted_fields()])
def output(self, format='binary'):
if format == 'text':
return self.output_text()
return self.output_efile()
if hasattr(self, 'record_length') and len(result) != self.record_length:
def output_efile(self):
result = b''.join([field.get_data() for field in self.get_sorted_fields()])
if self.record_length < 0 or len(result) != self.record_length:
raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.record_length, len(result)))
#result = ''.join([self.record_identifier] + [field.get_data() for field in self.get_sorted_fields()])
#if len(result) != self.target_size:
# raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.target_size, len(result)))
return result
def output_text(self):
fields = self.get_sorted_fields()[1:] # skip record identifier
fields = [field for field in fields if not field.is_read_only]
header = ''.join(['---', self.__class__.__name__, '\n'])
return header + '\n'.join([f.name + ': ' + (str(f.value) if f.value else '') for f in fields]) + '\n\n'
def read(self, fp):
# Skip the first record, since that's an identifier
for field in self.get_sorted_fields()[1:]:
field.read(fp)
def toJSON(self):
return {
'__class__': self.__class__.__name__,
@@ -80,19 +100,17 @@ class Model(object):
def fromJSON(self, o):
fields = o['fields']
identifier, fields = fields[0], fields[1:]
assert(identifier.value == self.record_identifier)
for f in fields:
target = self.__dict__[f.name]
if (target.required != f.required or
target.max_length != f.max_length):
print "Warning: value mismatch on import"
if (target.required != f.required
or target.max_length != f.max_length):
print("Warning: value mismatch on import")
target._value = f._value
#print (self.__dict__[f.name].name == f.name)
#self.__dict__[f.name].name == f.name
#self.__dict__[f.name].max_length == f.max_length
target.value = f.value
return self
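Model copies each class-level Field into a per-instance copy, orders fields by creation_counter (the record identifier is injected with counter -1), and output() enforces record_length. A minimal sketch of a custom record, assuming the package layout used above; the field sizes are chosen only to make the lengths add up:

from pyaccuwage import model
from pyaccuwage.fields import TextField, BlankField

class TinyRecord(model.Model):
    record_length = 32            # 2 (identifier) + 10 + 20
    record_identifier = 'ZZ'
    note = TextField(max_length=10)
    padding = BlankField(max_length=20)

rec = TinyRecord()
rec.note = 'hello'                # routed through set_field_value()
data = rec.output()               # raises ValidationError if not exactly 32 bytes
print(len(data), data[:12])       # 32 b'ZZHELLO     '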

View file

@@ -1,86 +1,86 @@
#!/usr/bin/env python
import re
class ClassEntryCommentSequence(object):
re_rangecomment = re.compile('#\s+(\d+)\-?(\d*)$')
re_rangecomment = re.compile(r'#\s+(\d+)\-?(\d*)$')
def __init__(self, classname, line):
self.classname = classname,
self.line = line
self.lines = []
def __init__(self, classname, line):
self.classname = classname,
self.line = line
self.lines = []
def add_line(self, line):
self.lines.append(line)
def add_line(self, line):
self.lines.append(line)
def process(self):
i = 0
for (line_no, line) in enumerate(self.lines):
match = self.re_rangecomment.search(line)
if match:
(a, b) = match.groups()
a = int(a)
def process(self):
i = 0
for (line_no, line) in enumerate(self.lines):
match = self.re_rangecomment.search(line)
if match:
(a, b) = match.groups()
a = int(a)
if (i + 1) != a:
line_number = self.line + line_no
print("ERROR\tline:%d\tnear:%s\texpected:%d\tsaw:%d" % (line_number, line.split(' ')[0].strip(), i+1, a))
if (i + 1) != a:
line_number = self.line + line_no
print(("ERROR\tline:%d\tnear:%s\texpected:%d\tsaw:%d" % (
line_number, line.split(' ')[0].strip(), i+1, a)))
i = int(b) if b else a
i = int(b) if b else a
class ModelDefParser(object):
re_triplequote = re.compile('"""')
re_whitespace = re.compile("^(\s*)[^\s]+")
re_classdef = re.compile(r"^\s*class\s(.*)\((.*)\):\s*$")
re_triplequote = re.compile('"""')
re_whitespace = re.compile(r"^(\s*)[^\s]+")
re_classdef = re.compile(r"^\s*class\s(.*)\((.*)\):\s*$")
def __init__(self, infile, entryclass):
self.infile = infile
self.line = 0
self.EntryClass = entryclass
def __init__(self, infile, entryclass):
self.infile = infile
self.line = 0
self.EntryClass = entryclass
def endclass(self):
if self.current_class:
self.current_class.process()
self.current_class = None
def endclass(self):
if self.current_class:
self.current_class.process()
self.current_class = None
def beginclass(self, classname, line):
self.current_class = self.EntryClass(classname, line)
def beginclass(self, classname, line):
self.current_class = self.EntryClass(classname, line)
def parse(self):
infile = self.infile
whitespace = 0
in_block_comment = False
self.current_class = None
def parse(self):
infile = self.infile
whitespace = 0
in_block_comment = False
self.current_class = None
for line in infile:
self.line += 1
for line in infile:
self.line += 1
if line.startswith('#'):
continue
if line.startswith('#'):
continue
if self.re_triplequote.search(line):
in_block_comment = not in_block_comment
if self.re_triplequote.search(line):
in_block_comment = not in_block_comment
if in_block_comment:
continue
if in_block_comment:
continue
match_whitespace = self.re_whitespace.match(line)
if match_whitespace:
match_whitespace = len(match_whitespace.groups()[0])
else:
match_whitespace = 0
match_whitespace = self.re_whitespace.match(line)
if match_whitespace:
match_whitespace = len(match_whitespace.groups()[0])
else:
match_whitespace = 0
classmatch = self.re_classdef.match(line)
if classmatch:
classname, subclass = classmatch.groups()
self.beginclass(classname, self.line)
continue
if match_whitespace < whitespace:
whitespace = match_whitespace
self.endclass()
continue
if self.current_class:
whitespace = match_whitespace
self.current_class.add_line(line)
classmatch = self.re_classdef.match(line)
if classmatch:
classname, _subclass = classmatch.groups()
self.beginclass(classname, self.line)
continue
if match_whitespace < whitespace:
whitespace = match_whitespace
self.endclass()
continue
if self.current_class:
whitespace = match_whitespace
self.current_class.add_line(line)

View file

@@ -1,5 +1,3 @@
#!/usr/bin/python
# coding=UTF-8
"""
Parser utility to read data from Publication 1220 and
convert it into python classes.
@@ -7,6 +5,7 @@ convert it into python classes.
"""
import re
import hashlib
from functools import reduce
class SimpleDefParser(object):
def __init__(self):
@@ -34,7 +33,7 @@ class SimpleDefParser(object):
item = item.upper()
if '-' in item:
parts = map(lambda x:self._intify(x), item.split('-'))
parts = [self._intify(x) for x in item.split('-')]
item = reduce(lambda x,y: y-x, parts)
else:
item = self._intify(item)
@@ -56,7 +55,7 @@ class LengthExpression(object):
self.exp_cache = {}
def __call__(self, value, exps):
return len(exps) == sum(map(lambda x: self.check(value, x), exps))
return len(exps) == sum([self.check(value, x) for x in exps])
def compile_exp(self, exp):
op, val = self.REG.match(exp).groups()
@@ -98,7 +97,7 @@ class RangeToken(BaseToken):
def value(self):
if '-' not in self._value:
return 1
return reduce(lambda x,y: y-x, map(int, self._value.split('-')))+1
return reduce(lambda x,y: y-x, list(map(int, self._value.split('-'))))+1
@property
def end_position(self):
@@ -110,7 +109,7 @@ class RangeToken(BaseToken):
class NumericToken(BaseToken):
regexp = re.compile('^(\d+)$')
regexp = re.compile(r'^(\d+)$')
@property
def value(self):
@@ -118,7 +117,7 @@ class NumericToken(BaseToken):
class RecordBuilder(object):
import fields
from . import fields
entry_max_length = 4
@@ -145,8 +144,7 @@ class RecordBuilder(object):
(re.compile(r'zero\-filled', re.IGNORECASE), +1),
(re.compile(r'leading zeroes', re.IGNORECASE), +1),
(re.compile(r'left-\justif', re.IGNORECASE), -1),
(re.compile(r'left\-justif', re.IGNORECASE), -1),
],
},
}),
@@ -201,15 +199,15 @@ class RecordBuilder(object):
try:
f_length = int(f_length)
except ValueError, e:
except ValueError as e:
# bad result, skip
continue
try:
assert f_length == RangeToken(f_range).value
except AssertionError, e:
except AssertionError as e:
continue
except ValueError, e:
except ValueError as e:
# bad result, skip
continue
@@ -223,7 +221,7 @@ class RecordBuilder(object):
else:
required = None
f_name = u'_'.join(map(lambda x:x.lower(), name_parts))
f_name = '_'.join([x.lower() for x in name_parts])
f_name = f_name.replace('&', 'and')
f_name = re.sub(r'[^\w]','', f_name)
@@ -240,7 +238,7 @@ class RecordBuilder(object):
lengthexp = LengthExpression()
for entry in entries:
matches = dict(map(lambda x:(x[0],0), self.FIELD_TYPES))
matches = dict([(x[0],0) for x in self.FIELD_TYPES])
for (classtype, criteria) in self.FIELD_TYPES:
if 'length' in criteria:
@@ -248,7 +246,7 @@ class RecordBuilder(object):
continue
if 'regexp' in criteria:
for crit_key, crit_values in criteria['regexp'].items():
for crit_key, crit_values in list(criteria['regexp'].items()):
for (crit_re, score) in crit_values:
matches[classtype] += score if crit_re.search(entry[crit_key]) else 0
@@ -256,7 +254,7 @@ class RecordBuilder(object):
matches = list(matches.items())
matches.sort(key=lambda x:x[1])
matches_found = True if sum(map(lambda x:x[1], matches)) > 0 else False
matches_found = True if sum([x[1] for x in matches]) > 0 else False
entry['guessed_type'] = matches[-1][0] if matches_found else self.fields.TextField
yield entry
@@ -271,7 +269,7 @@ class RecordBuilder(object):
if entry['name'] == 'blank':
blank_id = hashlib.new('md5')
blank_id.update(entry['range'].encode())
add( (u'blank_%s' % blank_id.hexdigest()[:8]).ljust(40) )
add( ('blank_%s' % blank_id.hexdigest()[:8]).ljust(40) )
else:
add(entry['name'].ljust(40))
@@ -386,7 +384,7 @@ class PastedDefParser(RecordBuilder):
for g in groups:
assert g['byterange'].value == g['length'].value
desc = u' '.join(map(lambda x:unicode(x.value), g['desc']))
desc = ' '.join([str(x.value) for x in g['desc']])
if g['name'][-1].value.lower() == '(optional)':
g['name'] = g['name'][0:-1]
@@ -396,7 +394,7 @@ class PastedDefParser(RecordBuilder):
else:
required = None
name = u'_'.join(map(lambda x:x.value.lower(), g['name']))
name = '_'.join([x.value.lower() for x in g['name']])
name = re.sub(r'[^\w]','', name)
yield({

View file

@@ -3,314 +3,102 @@
import subprocess
import re
import pdb
import itertools
import fitz
""" pdftotext -layout -nopgbrk p1220.pdf - """
def strip_values(items):
expr_non_alphanum = re.compile(r'[^\w\s]*', re.MULTILINE)
return [expr_non_alphanum.sub(x, '').strip().replace('\n', ' ') for x in items if x]
class PDFRecordFinder(object):
def __init__(self, src, heading_exp=None):
if not heading_exp:
heading_exp = re.compile('(\s+Record Name: (.*))|Record\ Layout')
field_range_expr = re.compile(r'^(\d+)[-]?(\d*)$')
field_heading_exp = re.compile('^Field.*Field.*Length.*Description')
def __init__(self, src):
self.document = fitz.open(src)
opts = ["pdftotext", "-layout", "-nopgbrk", "-eol", "unix", src, '-']
pdftext = subprocess.check_output(opts)
self.textrows = pdftext.split('\n')
self.heading_exp = heading_exp
self.field_heading_exp = field_heading_exp
def find_record_table_ranges(self):
matches = []
for (page_number, page) in enumerate(self.document):
header_rects = page.search_for("Record Name:")
for header_match_rect in header_rects:
header_match_rect.x0 = header_match_rect.x1 # Start after match of "Record Name: "
header_match_rect.x1 = page.bound().x1 # Extend to right side of page
header_text = page.get_textbox(header_match_rect)
record_name = re.sub(r'[^\w\s\n]*', '', header_text).strip()
matches.append((record_name, {
'page': page_number,
'y': header_match_rect.y1 - 5, # Back up a hair to include header more reliably
}))
return matches
def find_records(self):
record_ranges = self.find_record_table_ranges()
for record_index, (record_name, record_details) in enumerate(record_ranges):
current_rows = []
next_index = record_index+1
(_, next_record_details) = record_ranges[next_index] if next_index < len(record_ranges) else (None, {'page': self.document.page_count-1})
for page_number in range(record_details['page'], next_record_details['page']):
page = self.document[page_number]
table_search_rect = page.bound()
if page_number == record_details['page']:
table_search_rect.y0 = record_details['y']
tables = page.find_tables(
clip = table_search_rect,
min_words_horizontal = 1,
min_words_vertical = 1,
horizontal_strategy = "lines_strict",
intersection_tolerance = 1,
)
for table in tables:
if table.col_count == 4:
table = table.extract()
# Parse field position (sometimes a cell has multiple
# values because IRS employees apparently smoke crack
for row in table:
first_column_lines = row[0].strip().split('\n')
if len(first_column_lines) > 1:
for sub_row in self.split_row(row):
current_rows.append(strip_values(sub_row))
else:
current_rows.append(strip_values(row))
consecutive_rows = self.filter_nonconsecutive_rows(current_rows)
yield(record_name, consecutive_rows)
def split_row(self, row):
if not row[1]:
return []
split_rows = list(itertools.zip_longest(*[x.strip().split('\n') for x in row[:3]], fillvalue=None))
description = strip_values([row[3]])[0]
rows = []
for row in split_rows:
if len(row) < 3 or not row[2]:
row = self.infer_field_length(row)
rows.append([*row, description])
return rows
def infer_field_length(self, row):
matches = PDFRecordFinder.field_range_expr.match(row[0])
if not matches:
return row
(start, end) = ([int(x) for x in list(matches.groups()) if x] + [None])[:2]
length = str(end-start+1) if end and start else '1'
return (*row[:2], length)
def filter_nonconsecutive_rows(self, rows):
consecutive_rows = []
last_position = 0
for row in rows:
matches = PDFRecordFinder.field_range_expr.match(row[0])
if not matches:
continue
(start, end) = ([int(x) for x in list(matches.groups()) if x] + [None])[:2]
if start != last_position + 1:
continue
last_position = end if end else start
consecutive_rows.append(row)
return consecutive_rows
def records(self):
headings = self.locate_heading_rows_by_field()
#for x in headings:
# print x
for (start, end, name) in headings:
name = name.decode('ascii', 'ignore')
yield (name, list(self.find_fields(iter(self.textrows[start+1:end]))), (start+1, end))
def locate_heading_rows_by_field(self):
results = []
record_break = []
line_is_whitespace_exp = re.compile('^(\s*)$')
record_begin_exp = self.heading_exp #re.compile('Record\ Name')
for (i, row) in enumerate(self.textrows):
match = self.field_heading_exp.match(row)
if match:
# work backwards until we think the header is fully copied
space_count_exp = re.compile('^(\s*)')
position = i - 1
spaces = 0
#last_spaces = 10000
complete = False
header = None
while not complete:
line_is_whitespace = True if line_is_whitespace_exp.match(self.textrows[position]) else False
is_record_begin = record_begin_exp.search(self.textrows[position])
if is_record_begin or line_is_whitespace:
header = self.textrows[position-1:i]
complete = True
position -= 1
name = ''.join(header).strip().decode('ascii','ignore')
print (name, position)
results.append((i, name, position))
else:
# See if this row forces us to break from field reading.
if re.search('Record\ Layout', row):
record_break.append(i)
merged = []
for (a, b) in zip(results, results[1:] + [(len(self.textrows), None)]):
end_pos = None
#print a[0], record_break[0], b[0]-1
while record_break and record_break[0] < a[0]:
record_break = record_break[1:]
if record_break[0] < b[0]-1:
end_pos = record_break[0]
record_break = record_break[1:]
else:
end_pos = b[0]-1
merged.append( (a[0], end_pos-1, a[1]) )
return merged
"""
def locate_heading_rows(self):
results = []
for (i, row) in enumerate(self.textrows):
match = self.heading_exp.match(row)
if match:
results.append((i, ''.join(match.groups())))
merged = []
for (a, b) in zip(results, results[1:] + [(len(self.textrows),None)]):
merged.append( (a[0], b[0]-1, a[1]) )
return merged
def locate_layout_block_rows(self):
# Search for rows that contain "Record Layout", as these are not fields
# we are interested in because they contain the crazy blocks of field definitions
# and not the nice 4-column ones that we're looking for.
results = []
for (i, row) in enumerate(self.textrows):
match = re.match("Record Layout", row)
"""
def find_fields(self, row_iter):
cc = ColumnCollector()
blank_row_counter = 0
for r in row_iter:
row = r.decode('UTF-8')
#print row
row_columns = self.extract_columns_from_row(row)
if not row_columns:
if cc.data and len(cc.data.keys()) > 1 and len(row.strip()) > cc.data.keys()[-1]:
yield cc
cc = ColumnCollector()
else:
cc.empty_row()
continue
try:
cc.add(row_columns)
except IsNextField, e:
yield cc
cc = ColumnCollector()
cc.add(row_columns)
except UnknownColumn, e:
raise StopIteration
yield cc
def extract_columns_from_row(self, row):
re_multiwhite = re.compile(r'\s{2,}')
# IF LINE DOESN'T CONTAIN MULTIPLE WHITESPACES, IT'S LIKELY NOT A TABLE
if not re_multiwhite.search(row):
return None
white_ranges = [0,]
pos = 0
while pos < len(row):
match = re_multiwhite.search(row[pos:])
if match:
white_ranges.append(pos + match.start())
white_ranges.append(pos + match.end())
pos += match.end()
else:
white_ranges.append(len(row))
pos = len(row)
row_result = []
white_iter = iter(white_ranges)
while white_iter:
try:
start = white_iter.next()
end = white_iter.next()
if start != end:
row_result.append(
(start, row[start:end].encode('ascii','ignore'))
)
except StopIteration:
white_iter = None
#print row_result
return row_result
class UnknownColumn(Exception):
pass
class IsNextField(Exception):
pass
class ColumnCollector(object):
def __init__(self, initial=None):
self.data = None
self.column_widths = None
self.max_data_length = 0
self.adjust_pad = 3
self.empty_rows = 0
pass
def __repr__(self):
return "<%s: %s>" % (
self.__class__.__name__,
map(lambda x:x if len(x) < 25 else x[:25] + '..',
self.data.values() if self.data else ''))
def add(self, data):
#if self.empty_rows > 2:
# raise IsNextField()
if not self.data:
self.data = dict(data)
else:
data = self.adjust_columns(data)
if self.is_next_field(data):
raise IsNextField()
for col_id, value in data:
self.merge_column(col_id, value)
self.update_column_widths(data)
def empty_row(self):
self.empty_rows += 1
def update_column_widths(self, data):
self.last_data_length = len(data)
self.max_data_length = max(self.max_data_length, len(data))
if not self.column_widths:
self.column_widths = dict(map(lambda (column, value): [column, column + len(value)], data))
else:
for col_id, value in data:
try:
self.column_widths[col_id] = max(self.column_widths[col_id], col_id + len(value.strip()))
except KeyError:
pass
def add_old(self, data):
if not self.data:
self.data = dict(data)
else:
if self.is_next_field(data):
raise IsNextField()
for col_id, value in data:
self.merge_column(col_id, value)
def adjust_columns(self, data):
adjusted_data = {}
for col_id, value in data:
if col_id in self.data.keys():
adjusted_data[col_id] = value.strip()
else:
for col_start, col_end in self.column_widths.items():
if (col_start - self.adjust_pad) <= col_id and (col_end + self.adjust_pad) >= col_id:
if col_start in adjusted_data:
adjusted_data[col_start] += ' ' + value.strip()
else:
adjusted_data[col_start] = value.strip()
return adjusted_data.items()
def merge_column(self, col_id, value):
if col_id in self.data.keys():
self.data[col_id] += ' ' + value.strip()
else:
# try adding a wiggle room value?
# FIXME:
# Sometimes description columns contain column-like
# layouts, and this causes the ColumnCollector to become
# confused. Perhaps we could check to see if a column occurs
# after the maximum column, and assume it's part of the
# max column?
"""
for col_start, col_end in self.column_widths.items():
if col_start <= col_id and (col_end) >= col_id:
self.data[col_start] += ' ' + value.strip()
return
"""
raise UnknownColumn
def is_next_field(self, data):
"""
If the first key value contains a string
and we already have some data in the record,
then this row is probably the beginning of
the next field. Raise an exception and continue
on with a fresh ColumnCollector.
"""
""" If the length of the value in column_id is less than the position of the next column_id,
then this is probably a continuation.
"""
if self.data and data:
keys = sorted(self.column_widths)
keys += [None]
if self.last_data_length < len(data):
return True
first_key, first_value = list(dict(data).items())[0]
if list(self.data.keys())[0] == first_key:
position = keys.index(first_key)
max_length = keys[position + 1]
if max_length:
return len(first_value) > max_length or len(data) == self.max_data_length
return False
@property
def tuple(self):
if self.data:
return tuple(self.data[k] for k in sorted(self.data.keys()))
return ()
return self.find_records()
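As a concrete illustration of the whitespace heuristic in extract_columns_from_row above, here is a small standalone sketch (not part of the package; the spec line and offsets are invented) showing the (offset, text) pairs a layout row effectively yields:

import re

def split_layout_row(row):
    # Prose rarely contains runs of two or more spaces; treat such runs as
    # column gaps and keep each cell together with its starting offset so
    # rows can be aligned by position, as ColumnCollector does.
    if not re.search(r'\s{2,}', row):
        return None
    return [(m.start(), m.group()) for m in re.finditer(r'\S+(?: \S+)*', row)]

print(split_layout_row("3-11      Social Security Number   9   Enter the SSN"))
# -> [(0, '3-11'), (10, 'Social Security Number'), (35, '9'), (39, 'Enter the SSN')]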
pyaccuwage/record.py
@@ -1,11 +1,13 @@
import model
from fields import *
import enums
from . import model
from .fields import *
from . import enums
__all__ = RECORD_TYPES = ['SubmitterRecord', 'EmployerRecord',
'EmployeeWageRecord', 'OptionalEmployeeWageRecord',
'TotalRecord', 'OptionalTotalRecord',
'StateTotalRecord', 'FinalRecord', 'StateWageRecord']
'StateTotalRecord', 'FinalRecord', 'StateWageRecord',
'StateTotalRecordIA',
]
class EFW2Record(model.Model):
record_length = 512
@@ -103,8 +105,8 @@ class EmployerRecord(EFW2Record):
zipcode_ext = TextField(max_length=4, required=False)
kind_of_employer = TextField(max_length=1)
blank1 = BlankField(max_length=4)
foreign_state_province = TextField(max_length=23)
foreign_postal_code = TextField(max_length=15)
foreign_state_province = TextField(max_length=23, required=False)
foreign_postal_code = TextField(max_length=15, required=False)
country_code = TextField(max_length=2, required=False)
employment_code = TextField(max_length=1)
tax_jurisdiction_code = TextField(max_length=1, required=False)
@@ -148,7 +150,7 @@ class EmployeeWageRecord(EFW2Record):
ssn = IntegerField(max_length=9, required=False)
employee_first_name = TextField(max_length=15)
employee_middle_name = TextField(max_length=15)
employee_middle_name = TextField(max_length=15, required=False)
employee_last_name = TextField(max_length=20)
employee_suffix = TextField(max_length=4, required=False)
location_address = TextField(max_length=22)
@@ -161,7 +163,7 @@ class EmployeeWageRecord(EFW2Record):
blank1 = BlankField(max_length=5)
foreign_state = TextField(max_length=23, required=False)
foreign_postal_code = TextField(max_length=15, required=False)
country = TextField(max_length=2)
country = TextField(max_length=2, required=True, blank=True)
wages_tips = MoneyField(max_length=11)
federal_income_tax_withheld = MoneyField(max_length=11)
social_security_wages = MoneyField(max_length=11)
@@ -197,8 +199,10 @@ class EmployeeWageRecord(EFW2Record):
blank6 = BlankField(max_length=23)
def validate_ssn(self, f):
if str(f.value).startswith('666','9'):
raise ValidationError("ssn cannot start with 666 or 9", field=f)
if str(f.value).startswith('666'):
raise ValidationError("ssn cannot start with 666", field=f)
if str(f.value).startswith('9'):
raise ValidationError("ssn cannot start with 9", field=f)
@@ -241,7 +245,7 @@ class StateWageRecord(EFW2Record):
taxing_entity_code = TextField(max_length=5, required=False)
ssn = IntegerField(max_length=9, required=False)
employee_first_name = TextField(max_length=15)
employee_middle_name = TextField(max_length=15)
employee_middle_name = TextField(max_length=15, required=False)
employee_last_name = TextField(max_length=20)
employee_suffix = TextField(max_length=4, required=False)
location_address = TextField(max_length=22)
@@ -255,20 +259,20 @@ class StateWageRecord(EFW2Record):
foreign_postal_code = TextField(max_length=15, required=False)
country_code = TextField(max_length=2, required=False)
optional_code = TextField(max_length=2, required=False)
reporting_period = MonthYearField()
reporting_period = MonthYearField(required=False)
quarterly_unemp_ins_wages = MoneyField(max_length=11)
quarterly_unemp_ins_taxable_wages = MoneyField(max_length=11)
number_of_weeks_worked = IntegerField(max_length=2)
number_of_weeks_worked = IntegerField(max_length=2, required=False)
date_first_employed = DateField(required=False)
date_of_separation = DateField(required=False)
blank2 = BlankField(max_length=5)
state_employer_account_num = TextField(max_length=20)
state_employer_account_num = IntegerField(max_length=20, required=False)
blank3 = BlankField(max_length=6)
state_code_2 = StateField(use_numeric=True)
state_taxable_wages = MoneyField(max_length=11)
state_income_tax_wh = MoneyField(max_length=11)
other_state_data = TextField(max_length=10, required=False)
tax_type_code = TextField(max_length=1) # VALIDATE C, D, E, or F
tax_type_code = TextField(max_length=1, required=False) # VALIDATE C, D, E, or F
local_taxable_wages = MoneyField(max_length=11)
local_income_tax_wh = MoneyField(max_length=11)
state_control_number = IntegerField(max_length=7, required=False)
@@ -278,7 +282,8 @@ class StateWageRecord(EFW2Record):
def validate_tax_type_code(self, field):
choices = [x for x,y in enums.tax_type_codes]
if field.value.upper() not in choices:
value = field.value
if value and value.upper() not in choices:
raise ValidationError("%s not one of %s" % (field.value, choices), field=field)
@@ -354,6 +359,17 @@ class StateTotalRecord(EFW2Record):
supplemental_data = TextField(max_length=510)
class StateTotalRecordIA(EFW2Record):
#year=2018
record_identifier = 'RV'
number_of_rs_records = IntegerField(max_length=7) # num records since last 'RE' record
wages_tips = MoneyField(max_length=15)
state_income_tax_wh = MoneyField(max_length=15)
employer_ben = TextField(max_length=8)
iowa_confirmation_number = ZeroField(max_length=10)
blank1 = BlankField(max_length=455)
class FinalRecord(EFW2Record):
#year=2012
record_identifier = 'RF'

requirements.txt Normal file

@@ -0,0 +1 @@
PyMuPDF==1.24.0
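PyMuPDF is pinned here presumably for the PDF-scraping side (PDFRecordFinder and the column-extraction code above); pulling a spec PDF apart into the text rows that code works on might look roughly like this sketch (the filename is invented):

import fitz  # PyMuPDF

def pdf_text_rows(path):
    # Yield every text line of every page; these are the rows the
    # whitespace/column heuristics are applied to.
    with fitz.open(path) as doc:
        for page in doc:
            for line in page.get_text().splitlines():
                yield line

# e.g. for row in pdf_text_rows("p1220.pdf"): ...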

scripts/pyaccuwage-convert Executable file

@@ -0,0 +1,76 @@
#!/usr/bin/env python
import pyaccuwage
import argparse
import os, os.path
import sys
"""
Command line tool for converting IRS e-file fixed field records
to/from JSON or a simple text format.
Attempts to load record types from a python module in the current working
directory named record_types.py
The module must export a RECORD_TYPES list with the names of the classes to
import as valid record types.
"""
def get_record_types():
try:
sys.path.append(os.getcwd())
import record_types
r = {}
for record_type in record_types.RECORD_TYPES:
r[record_type] = getattr(record_types, record_type)
return r
except ImportError:
print('warning: using default record types (failed to import record_types.py)')
return pyaccuwage.get_record_types()
def read_file(fd, filename, record_types):
filename, extension = os.path.splitext(filename)
if extension == '.json':
return pyaccuwage.json_load(fd, record_types)
elif extension == '.txt':
return pyaccuwage.text_load(fd, record_types)
else:
return pyaccuwage.load(fd, record_types)
def write_file(outfile, filename, records):
filename, extension = os.path.splitext(filename)
if extension == '.json':
pyaccuwage.json_dump(outfile, records)
elif extension == '.txt':
pyaccuwage.text_dump(outfile, records)
else:
pyaccuwage.dump(outfile, records)
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description="Convert accuwage efile data between different formats."
)
parser.add_argument("-i", '--input',
nargs=1,
required=True,
metavar="file",
type=argparse.FileType('r'),
help="Source file to convert")
parser.add_argument("-o", "--output",
nargs=1,
required=True,
metavar="file",
type=argparse.FileType('w'),
help="Destination file to output")
args = parser.parse_args()
in_file = args.input[0]
out_file = args.output[0]
records = list(read_file(in_file, in_file.name, get_record_types()))
write_file(out_file, out_file.name, records)
print("wrote {} records to {}".format(len(records), out_file.name))
scripts/pyaccuwage-pdfparse
@@ -1,4 +1,4 @@
#!/usr/bin/python
#!/usr/bin/env python
from pyaccuwage.parser import RecordBuilder
from pyaccuwage.pdfextract import PDFRecordFinder
import argparse
@@ -29,48 +29,9 @@ doc = PDFRecordFinder(source_file)
records = doc.records()
builder = RecordBuilder()
def record_begins_at(field):
return int(fields[0].data.values()[0].split('-')[0], 10)
def record_ends_at(fields):
return int(fields[-1].data.values()[0].split('-')[-1], 10)
last_record_begins_at = -1
last_record_ends_at = -1
for rec in records:
#if not rec[1]:
# continue # no actual fields detected
fields = rec[1]
# strip out fields that are not 4 items long
fields = filter(lambda x:len(x.tuple) == 4, fields)
# strip fields that don't begin at position 0
fields = filter(lambda x: 0 in x.data, fields)
# strip fields that don't have a length-range type item in position 0
fields = filter(lambda x: re.match('^\d+[-]?\d*$', x.data[0]), fields)
if not fields:
continue
begins_at = record_begins_at(fields)
ends_at = record_ends_at(fields)
# FIXME record_ends_at is randomly exploding due to record data being
# a lump of text and not necessarily a field entry. I assume
# this is cleaned out by the record builder class.
#print last_record_ends_at + 1, begins_at
if last_record_ends_at + 1 != begins_at:
name = re.sub('^[^a-zA-Z]*','',rec[0].split(':')[-1])
name = re.sub('[^\w]*', '', name)
sys.stdout.write("\nclass %s(pyaccuwagemodel.Model):\n" % name)
for field in builder.load(map(lambda x:x.tuple, rec[1][0:])):
for (name, fields) in records:
name = re.sub(r'^[^a-zA-Z]*','', name.split(':')[-1])
name = re.sub(r'[^\w]*', '', name)
sys.stdout.write("\nclass %s(pyaccuwagemodel.Model):\n" % name)
for field in builder.load(map(lambda x: x, fields[0:])):
sys.stdout.write('\t' + field + '\n')
#print field
last_record_ends_at = ends_at
setup.py
@@ -1,12 +1,21 @@
from distutils.core import setup
from setuptools import setup
import unittest
def pyaccuwage_tests():
test_loader = unittest.TestLoader()
test_suite = test_loader.discover('tests', pattern='test_*.py')
return test_suite
setup(name='pyaccuwage',
version='0.2012.1',
version='0.2025.0',
packages=['pyaccuwage'],
scripts=[
'scripts/pyaccuwage-checkseq',
'scripts/pyaccuwage-convert',
'scripts/pyaccuwage-genfieldfill',
'scripts/pyaccuwage-parse',
'scripts/pyaccuwage-pdfparse',
'scripts/pyaccuwage-checkseq',
'scripts/pyaccuwage-genfieldfill'
],
zip_safe=True,
test_suite='setup.pyaccuwage_tests',
)

tests/test_fields.py Normal file

@@ -0,0 +1,67 @@
import unittest
from pyaccuwage.fields import TextField
from pyaccuwage.fields import StaticField
# from pyaccuwage.fields import IntegerField
# from pyaccuwage.fields import StateField
# from pyaccuwage.fields import BlankField
# from pyaccuwage.fields import ZeroField
# from pyaccuwage.fields import MoneyField
from pyaccuwage.fields import ValidationError
from pyaccuwage.model import Model
class TestTextField(unittest.TestCase):
def testStringShortOptional(self):
field = TextField(max_length=6, required=False)
field.validate() # optional
field.value = 'Hello'
field.validate()
self.assertEqual(field.get_data(), b'HELLO ')
def testStringShortRequired(self):
field = TextField(max_length=6, required=True)
with self.assertRaises(ValidationError):
field.validate()
field.value = 'Hello'
field.validate()
self.assertEqual(field.get_data(), b'HELLO ')
def testStringLongOptional(self):
field = TextField(max_length=6, required=False)
field.value = 'Hello, World!' # too long
data = field.get_data()
self.assertEqual(len(data), field.max_length)
self.assertEqual(data, b'HELLO,')
def testStringUnsetOptional(self):
field = TextField(max_length=6, required=False)
field.validate()
self.assertEqual(field.get_data(), b' ' * 6)
def testStringRequiredUnassigned(self):
field = TextField(max_length=6)
self.assertRaises(ValidationError, lambda: field.validate())
def testStringRequiredNonBlank(self):
field = TextField(max_length=6)
field.value = ''
self.assertRaises(ValidationError, lambda: field.validate())
def testStringRequiredBlank(self):
field = TextField(max_length=6, blank=True)
field.value = ''
field.validate()
self.assertEqual(len(field.get_data()), 6)
def testStringMinimumLength(self):
field = TextField(max_length=6, min_length=6, blank=True) # blank has no effect when min_length is set
field.value = '' # empty string is shorter than min_length
self.assertRaises(ValidationError, lambda: field.validate())
field.value = '12345' # one character too short
self.assertRaises(ValidationError, lambda: field.validate())
field.value = '123456' # exactly the minimum length, so this should validate
field.validate()
class TestStaticField(unittest.TestCase):
def test_static_field(self):
field = StaticField(value='TEST')
self.assertEqual(field.get_data(), b'TEST')

tests/test_records.py Normal file

@@ -0,0 +1,179 @@
import unittest
import decimal
import pyaccuwage
from pyaccuwage.fields import BlankField
from pyaccuwage.fields import IntegerField
from pyaccuwage.fields import MoneyField
from pyaccuwage.fields import StateField
from pyaccuwage.fields import TextField
from pyaccuwage.fields import ZeroField
from pyaccuwage.fields import StaticField
from pyaccuwage.fields import ValidationError
from pyaccuwage.model import Model
class TestModelOutput(unittest.TestCase):
class TestModel(Model):
record_length = 128
record_identifier = 'TEST' # 4 bytes
field1 = TextField(max_length=16)
field2 = IntegerField(max_length=16)
blank1 = BlankField(max_length=16)
zero1 = ZeroField(max_length=16)
money = MoneyField(max_length=32)
state_txt = StateField()
state_num = StateField(use_numeric=True)
blank2 = BlankField(max_length=12)
static1 = StaticField(value='hey mister!!')
def setUp(self):
self.model = TestModelOutput.TestModel()
def testModelBinaryOutput(self):
model = self.model
model.field1.value = 'Hello, sir!'
model.field2.value = 12345
model.money.value = decimal.Decimal('3133.77')
model.state_txt.value = 'IA'
model.state_num.value = 'IA'
expected = b''.join([
b'TEST',
b'HELLO, SIR!'.ljust(16),
b'12345'.zfill(16),
b' ' * 16,
b'0' * 16,
b'313377'.zfill(32),
b'IA',
b'19',
b' ' * 12,
b'hey mister!!',
])
output = model.output()
self.assertEqual(len(output), TestModelOutput.TestModel.record_length)
self.assertEqual(output, expected)
def testModelTextOutput(self):
model = self.model
model.field1.value = 'Hello, sir!'
model.field2.value = 12345
model.money.value = decimal.Decimal('3133.77')
model.state_txt.value = 'IA'
model.state_num.value = 'IA'
output = model.output(format='text')
self.assertEqual(output, '''---TestModel
field1: Hello, sir!
field2: 12345
money: 3133.77
state_txt: IA
state_num: IA
static1: hey mister!!
''')
class TestFileFormats(unittest.TestCase):
class TestModelA(pyaccuwage.model.Model):
record_length = 128
record_identifier = 'A' # 1 byte
field1 = TextField(max_length=16)
field2 = IntegerField(max_length=16)
blank1 = BlankField(max_length=16)
zero1 = ZeroField(max_length=16)
money = MoneyField(max_length=32)
state_txt = StateField()
state_num = StateField(use_numeric=True)
blank2 = BlankField(max_length=27)
class TestModelB(pyaccuwage.model.Model):
record_length = 128
record_identifier = 'B' # 1 byte
zero1 = ZeroField(max_length=32)
text1 = TextField(max_length=71)
text2 = TextField(max_length=20, required=False)
blank2 = BlankField(max_length=4)
record_types = [TestModelA, TestModelB]
def createExampleRecords(self):
model_a = TestFileFormats.TestModelA()
model_a.field1.value = 'I am model a'
model_a.field2.value = 5522
model_a.money.value = decimal.Decimal('23.00')
model_a.state_txt.value = 'IA'
model_a.state_num.value = 'IA'
model_b = TestFileFormats.TestModelB()
model_b.text1.value = 'hey I am model b and I have a big text field'
return [
model_a,
model_b,
]
def testJSONSerialization(self):
records = self.createExampleRecords()
record_types = self.record_types
json_data = pyaccuwage.json_dumps(records)
records_loaded = pyaccuwage.json_loads(json_data, record_types)
original_bytes = pyaccuwage.dumps(records)
reloaded_bytes = pyaccuwage.dumps(records_loaded)
self.assertEqual(original_bytes, reloaded_bytes)
def testTxtSerialization(self):
records = self.createExampleRecords()
record_types = self.record_types
text_data = pyaccuwage.text_dumps(records)
records_loaded = pyaccuwage.text_loads(text_data, record_types)
original_bytes = pyaccuwage.dumps(records)
reloaded_bytes = pyaccuwage.dumps(records_loaded)
self.assertEqual(original_bytes, reloaded_bytes)
class TestRequiredFields(unittest.TestCase):
def createTestRecord(self, required=False, blank=False):
class Record(pyaccuwage.model.Model):
record_length = 16
record_identifier = ''
test_field = TextField(max_length=16, required=required, blank=blank)
record = Record()
def dump():
return pyaccuwage.dumps([record])
return (record, dump)
def testRequiredBlankField(self):
(record, dump) = self.createTestRecord(required=True, blank=True)
record.test_field.value # if nothing is ever assigned, raise error
self.assertRaises(ValidationError, dump)
record.test_field.value = '' # value may be empty string
dump()
def testRequiredNonblankField(self):
(record, dump) = self.createTestRecord(required=True, blank=False)
record.test_field.value # if nothing is ever assigned, raise error
self.assertRaises(ValidationError, dump)
record.test_field.value = '' # value must not be empty string
self.assertRaises(ValidationError, dump)
record.test_field.value = 'hello'
dump()
def testOptionalBlankField(self):
(record, dump) = self.createTestRecord(required=False, blank=True)
record.test_field.value # OK if nothing is ever assigned
dump()
record.test_field.value = '' # OK if empty string is assigned
dump()
record.test_field.value = 'hello'
dump()
def testOptionalNonBlankField(self):
(record, dump) = self.createTestRecord(required=False, blank=False)
record.test_field.value # OK if nothing is ever assigned
dump()
record.test_field.value = '' # OK if empty string is assigned
dump()
record.test_field.value = 'hello'
dump()