Compare commits

...
Sign in to create a new pull request.

29 commits

Author SHA1 Message Date
9029659f98 update internal VERSION property 2025-05-13 12:45:51 -05:00
1302de9df7 bump version to 0.2025.0 2025-05-13 12:14:02 -05:00
fb8091fb09 change Iowa RS record state_employer_account_num from TextField to IntegerField 2025-05-13 12:09:49 -05:00
4408da71a9 mark some fields as optional 2024-04-10 09:41:10 -04:00
e0e4c1291d add min_length option to TextField for SSNs and stuff like that 2024-03-31 11:52:22 -04:00
5f4dc8b80f add 'blank' field option to allow empty text in required fields (default: false) 2024-03-31 11:14:16 -04:00
74b7935ced bump version to 2024 2024-03-29 10:50:25 -04:00
66573e4d1d update for 2023 p1220 parsing, stupid irs 2024-03-29 10:48:04 -04:00
86f8861da1 encode record delimiter as ascii bytes when str is passed 2022-02-06 11:06:51 -06:00
042de7ecb0 import typing.Callable (python 3.10+) 2021-12-18 08:56:43 -05:00
f28cd6edf2 bump version 0.2020.0 2021-09-03 07:48:24 -05:00
0bd82e09c4 Fix StaticField + tests for StaticField and unset optional TextField 2021-09-03 05:45:01 -05:00
558e3fd232 hopefully fix StaticField 2021-09-02 17:40:35 -05:00
7867a52a0c flipped args around like a simpleton 2021-01-29 16:26:26 -05:00
bfd43b7448 release 0.2018.2 2020-06-12 14:45:08 -05:00
1f1d3dd9bb Merge branch 'conversion-support' 2020-06-12 13:13:28 -05:00
431b594c1e add pyaccuwage-convert 2020-06-12 13:10:13 -05:00
8f86f76167 add format interchange functions, add tests, fix stuff 2020-06-12 13:07:41 -05:00
6af5067fca add option for record delimiter 2019-01-30 14:25:24 -06:00
250ca8d31f fix flubbed blank field specifier on StateTotalRecordIA 2019-01-28 13:29:14 -06:00
7ddcfcc1c3 clean up some indent 2019-01-27 10:36:37 -06:00
d08f1ca586 hopefully fix python 2 and 3 compatibility 2019-01-27 09:30:22 -06:00
6381f8b1ec bump version to 2018.01 2019-01-26 16:11:24 -06:00
7c32cb0dd3 add StateTotalRecord for Iowa 2019-01-26 11:49:31 -06:00
5afdcd6a50 add 'permitted_benefits_health' to RT and RO records for 2017 2018-01-27 11:38:23 -06:00
706c39f7bb CRLFField return binary data for get_data() 2017-10-29 10:41:52 -05:00
078273f49f fix json encoding by encoding bytes as ascii 2017-01-07 17:00:29 -06:00
9320c68961 use BytesIO to work with python3 2017-01-07 14:52:33 -06:00
16bf2c41d0 run through 2to3 2017-01-07 13:58:33 -06:00
14 changed files with 1153 additions and 960 deletions

View file

@ -1,7 +1,9 @@
from record import * try:
from reader import RecordReader from collections import Callable
except:
from typing import Callable # Python 3.10+
VERSION = (0, 2012, 0) VERSION = (0, 2025, 0)
RECORD_TYPES = [ RECORD_TYPES = [
'SubmitterRecord', 'SubmitterRecord',
@ -13,129 +15,156 @@ RECORD_TYPES = [
'OptionalTotalRecord', 'OptionalTotalRecord',
'StateTotalRecord', 'StateTotalRecord',
'FinalRecord' 'FinalRecord'
] ]
def test():
import record, model
from fields import ValidationError
for rname in RECORD_TYPES:
inst = record.__dict__[rname]()
try:
output_length = len(inst.output())
except ValidationError, e:
print e.msg, type(inst), inst.record_identifier
continue
print type(inst), inst.record_identifier, output_length
def test_dump(): def get_record_types():
import record, StringIO from . import record
records = [
record.SubmitterRecord(),
record.EmployerRecord(),
record.EmployeeWageRecord(),
]
out = StringIO.StringIO()
dump(records, out)
return out
def test_record_order():
import record
records = [
record.SubmitterRecord(),
record.EmployerRecord(),
record.EmployeeWageRecord(),
record.TotalRecord(),
record.FinalRecord(),
]
validate_record_order(records)
def test_load(fp):
return load(fp)
def load(fp):
# BUILD LIST OF RECORD TYPES
import record
types = {} types = {}
for r in RECORD_TYPES: for r in RECORD_TYPES:
klass = record.__dict__[r] klass = record.__dict__[r]
types[klass.record_identifier] = klass types[klass.record_identifier] = klass
return types
def load(fp, record_types):
distinct_identifier_lengths = set([len(record_types[k].record_identifier) for k in record_types])
assert(len(distinct_identifier_lengths) == 1)
ident_length = list(distinct_identifier_lengths)[0]
# Add aliases for the record types based on their record_identifier since that's all
# we have to work with in the e1099 data.
record_types_by_ident = {}
for k in record_types:
record_type = record_types[k]
record_identifier = record_type.record_identifier
record_types_by_ident[record_identifier] = record_type
# PARSE DATA INTO RECORDS AND YIELD THEM # PARSE DATA INTO RECORDS AND YIELD THEM
while fp.tell() < fp.len: while True:
record_ident = fp.read(2) record_ident = fp.read(ident_length)
if record_ident in types: if not record_ident:
record = types[record_ident]() break
if record_ident in record_types_by_ident:
record = record_types_by_ident[record_ident]()
record.read(fp) record.read(fp)
yield record yield record
def loads(s):
import StringIO def loads(s, record_types=get_record_types()):
fp = StringIO.StringIO(s) import io
return load(fp) fp = io.BytesIO(s)
return load(fp, record_types)
def dump(records, fp): def dump(fp, records, delim=None):
if type(delim) is str:
delim = delim.encode('ascii')
for r in records: for r in records:
fp.write(r.output()) fp.write(r.output())
if delim:
fp.write(delim)
def dumps(records):
import StringIO def dumps(records, delim=None, skip_validation=False):
fp = StringIO.StringIO() import io
dump(records, fp) fp = io.BytesIO()
if not skip_validation:
for record in records:
record.validate()
dump(fp, records, delim=delim)
fp.seek(0) fp.seek(0)
return fp.read() return fp.read()
def json_dumps(records): def json_dumps(records):
import json import json
import model
import decimal import decimal
class JSONEncoder(json.JSONEncoder): class JSONEncoder(json.JSONEncoder):
def default(self, o): def default(self, o):
if hasattr(o, 'toJSON') and callable(getattr(o, 'toJSON')): if hasattr(o, 'toJSON') and isinstance(getattr(o, 'toJSON'), Callable):
return o.toJSON() return o.toJSON()
if type(o) is bytes:
return o.decode('ascii')
elif isinstance(o, decimal.Decimal): elif isinstance(o, decimal.Decimal):
return str(o.quantize(decimal.Decimal('0.01'))) return str(o.quantize(decimal.Decimal('0.01')))
return super(JSONEncoder, self).default(o) return super(JSONEncoder, self).default(o)
return json.dumps(records, cls=JSONEncoder, indent=2) return json.dumps(list(records), cls=JSONEncoder, indent=2)
def json_loads(s, record_classes): def json_dump(fp, records):
fp.write(json_dumps(records))
def json_loads(s, record_types):
import json import json
import fields from . import fields
import decimal import decimal
import re
if not isinstance(record_classes, dict): if not isinstance(record_types, dict):
record_classes = dict([ (x.__class__.__name__, x) for x in record_classes]) record_types = dict([ (x.__name__, x) for x in record_types])
def object_hook(o): def object_hook(o):
if '__class__' in o: if '__class__' in o:
klass = o['__class__'] klass = o['__class__']
if klass in record_types:
if klass in record_classes: record = record_types[klass]()
return record_classes[klass]().fromJSON(o) record.fromJSON(o)
return record
elif hasattr(fields, klass): elif hasattr(fields, klass):
return getattr(fields, klass)().fromJSON(o) return getattr(fields, klass)().fromJSON(o)
return o return o
#print "OBJECTHOOK", str(o)
#return {'object_hook':str(o)}
#def default(self, o):
# return super(JSONDecoder, self).default(o)
return json.loads(s, parse_float=decimal.Decimal, object_hook=object_hook) return json.loads(s, parse_float=decimal.Decimal, object_hook=object_hook)
def json_load(fp, record_types):
return json_loads(fp.read(), record_types)
def text_dump(fp, records):
for r in records:
fp.write(r.output(format='text').encode('ascii'))
def text_dumps(records):
import io
fp = io.BytesIO()
text_dump(fp, records)
fp.seek(0)
return fp.read()
def text_load(fp, record_classes):
records = []
current_record = None
if not isinstance(record_classes, dict):
record_classes = dict([ (x.__name__, x) for x in record_classes])
while True: #fp.readable():
line = fp.readline().decode('ascii')
if not line:
break
if line.startswith('---'):
record_name = line.strip('---').strip()
current_record = record_classes[record_name]()
records.append(current_record)
elif ':' in line:
field, value = [x.strip() for x in line.split(':')]
current_record.set_field_value(field, value)
return records
def text_loads(s, record_classes):
import io
fp = io.BytesIO(s)
return text_load(fp, record_classes)
# THIS WAS IN CONTROLLER, BUT UNLESS WE # THIS WAS IN CONTROLLER, BUT UNLESS WE
# REALLY NEED A CONTROLLER CLASS, IT'S SIMPLER # REALLY NEED A CONTROLLER CLASS, IT'S SIMPLER
# TO JUST KEEP IT IN HERE. # TO JUST KEEP IT IN HERE.
@ -151,14 +180,15 @@ def validate_required_records(records):
while req_types: while req_types:
req = req_types[0] req = req_types[0]
if req not in types: if req not in types:
from fields import ValidationError from .fields import ValidationError
raise ValidationError("Record set missing required record: %s" % req) raise ValidationError("Record set missing required record: %s" % req)
else: else:
req_types.remove(req) req_types.remove(req)
def validate_record_order(records): def validate_record_order(records):
import record from . import record
from fields import ValidationError from .fields import ValidationError
# 1st record must be SubmitterRecord # 1st record must be SubmitterRecord
if not isinstance(records[0], record.SubmitterRecord): if not isinstance(records[0], record.SubmitterRecord):
@ -178,10 +208,10 @@ def validate_record_order(records):
if not isinstance(records[i+1], record.EmployeeWageRecord): if not isinstance(records[i+1], record.EmployeeWageRecord):
raise ValidationError("All EmployerRecords must be followed by an EmployeeWageRecord") raise ValidationError("All EmployerRecords must be followed by an EmployeeWageRecord")
num_ro_records = len(filter(lambda x:isinstance(x, record.OptionalEmployeeWageRecord), records)) num_ro_records = len([x for x in records if isinstance(x, record.OptionalEmployeeWageRecord)])
num_ru_records = len(filter(lambda x:isinstance(x, record.OptionalTotalRecord), records)) num_ru_records = len([x for x in records if isinstance(x, record.OptionalTotalRecord)])
num_employer_records = len(filter(lambda x:isinstance(x, record.EmployerRecord), records)) num_employer_records = len([x for x in records if isinstance(x, record.EmployerRecord)])
num_total_records = len(filter(lambda x: isinstance(x, record.TotalRecord), records)) num_total_records = len([x for x in records if isinstance(x, record.TotalRecord)])
# a TotalRecord is required for each instance of an EmployerRecord # a TotalRecord is required for each instance of an EmployerRecord
if num_total_records != num_employer_records: if num_total_records != num_employer_records:
@ -194,7 +224,7 @@ def validate_record_order(records):
num_ro_records, num_ru_records)) num_ro_records, num_ru_records))
# FinalRecord - Must appear only once on each file. # FinalRecord - Must appear only once on each file.
if len(filter(lambda x:isinstance(x, record.FinalRecord), records)) != 1: if len([x for x in records if isinstance(x, record.FinalRecord)]) != 1:
raise ValidationError("Incorrect number of FinalRecords") raise ValidationError("Incorrect number of FinalRecords")
def validate_records(records): def validate_records(records):
@ -207,13 +237,8 @@ def test_unique_fields():
r1.employee_first_name.value = "John Johnson" r1.employee_first_name.value = "John Johnson"
r2 = EmployeeWageRecord() r2 = EmployeeWageRecord()
print 'r1:', r1.employee_first_name.value, r1.employee_first_name, r1.employee_first_name.creation_counter print('r1:', r1.employee_first_name.value, r1.employee_first_name, r1.employee_first_name.creation_counter)
print 'r2:', r2.employee_first_name.value, r2.employee_first_name, r2.employee_first_name.creation_counter print('r2:', r2.employee_first_name.value, r2.employee_first_name, r2.employee_first_name.creation_counter)
if r1.employee_first_name.value == r2.employee_first_name.value: if r1.employee_first_name.value == r2.employee_first_name.value:
raise ValidationError("Horrible problem involving shared values across records") raise ValidationError("Horrible problem involving shared values across records")
#def state_postal_code(state_abbr):
# import enums
# return enums.state_postal_numeric[ state_abbr.upper() ]

View file

@ -1,338 +1,340 @@
state_postal_numeric = { state_postal_numeric = {
'AL': 1, 'AL': 1,
'AK': 2, 'AK': 2,
'AZ': 4, 'AZ': 4,
'AR': 5, 'AR': 5,
'CA': 6, 'CA': 6,
'CO': 8, 'CO': 8,
'CT': 9, 'CT': 9,
'DE': 10, 'DE': 10,
'DC': 11, 'DC': 11,
'FL': 12, 'FL': 12,
'GA': 13, 'GA': 13,
'HI': 15, 'HI': 15,
'ID': 16, 'ID': 16,
'IL': 17, 'IL': 17,
'IN': 18, 'IN': 18,
'IA': 19, 'IA': 19,
'KS': 20, 'KS': 20,
'KY': 21, 'KY': 21,
'LA': 22, 'LA': 22,
'ME': 23, 'ME': 23,
'MD': 24, 'MD': 24,
'MA': 25, 'MA': 25,
'MI': 26, 'MI': 26,
'MN': 27, 'MN': 27,
'MS': 28, 'MS': 28,
'MO': 29, 'MO': 29,
'MT': 30, 'MT': 30,
'NE': 31, 'NE': 31,
'NV': 32, 'NV': 32,
'NH': 33, 'NH': 33,
'NJ': 34, 'NJ': 34,
'NM': 35, 'NM': 35,
'NY': 36, 'NY': 36,
'NC': 37, 'NC': 37,
'ND': 38, 'ND': 38,
'OH': 39, 'OH': 39,
'OK': 40, 'OK': 40,
'OR': 41, 'OR': 41,
'PA': 42, 'PA': 42,
'RI': 44, 'RI': 44,
'SC': 45, 'SC': 45,
'SD': 46, 'SD': 46,
'TN': 47, 'TN': 47,
'TX': 48, 'TX': 48,
'UT': 49, 'UT': 49,
'VT': 50, 'VT': 50,
'VA': 51, 'VA': 51,
'WA': 53, 'WA': 53,
'WV': 54, 'WV': 54,
'WI': 55, 'WI': 55,
'WY': 56, 'WY': 56,
} }
countries = ( countries = (
('AF', 'Afghanistan'), ('AF', 'Afghanistan'),
('AX', 'Aland Islands'), ('AX', 'Aland Islands'),
('AL', 'Albania'), ('AL', 'Albania'),
('DZ', 'Algeria'), ('DZ', 'Algeria'),
('AS', 'American Samoa'), ('AS', 'American Samoa'),
('AD', 'Andorra'), ('AD', 'Andorra'),
('AO', 'Angola'), ('AO', 'Angola'),
('AI', 'Anguilla'), ('AI', 'Anguilla'),
('AQ', 'Antarctica'), ('AQ', 'Antarctica'),
('AG', 'Antigua and Barbuda'), ('AG', 'Antigua and Barbuda'),
('AR', 'Argentina'), ('AR', 'Argentina'),
('AM', 'Armenia'), ('AM', 'Armenia'),
('AW', 'Aruba'), ('AW', 'Aruba'),
('AU', 'Australia'), ('AU', 'Australia'),
('AT', 'Austria'), ('AT', 'Austria'),
('AZ', 'Azerbaijan'), ('AZ', 'Azerbaijan'),
('BS', 'Bahamas'), ('BS', 'Bahamas'),
('BH', 'Bahrain'), ('BH', 'Bahrain'),
('BD', 'Bangladesh'), ('BD', 'Bangladesh'),
('BB', 'Barbados'), ('BB', 'Barbados'),
('BY', 'Belarus'), ('BY', 'Belarus'),
('BE', 'Belgium'), ('BE', 'Belgium'),
('BZ', 'Belize'), ('BZ', 'Belize'),
('BJ', 'Benin'), ('BJ', 'Benin'),
('BM', 'Bermuda'), ('BM', 'Bermuda'),
('BT', 'Bhutan'), ('BT', 'Bhutan'),
('BO', 'Bolivia, Plurinational State of'), ('BO', 'Bolivia, Plurinational State of'),
('BQ', 'Bonaire, Saint Eustatius and Saba'), ('BQ', 'Bonaire, Saint Eustatius and Saba'),
('BA', 'Bosnia and Herzegovina'), ('BA', 'Bosnia and Herzegovina'),
('BW', 'Botswana'), ('BW', 'Botswana'),
('BV', 'Bouvet Island'), ('BV', 'Bouvet Island'),
('BR', 'Brazil'), ('BR', 'Brazil'),
('IO', 'British Indian Ocean Territory'), ('IO', 'British Indian Ocean Territory'),
('BN', 'Brunei Darussalam'), ('BN', 'Brunei Darussalam'),
('BG', 'Bulgaria'), ('BG', 'Bulgaria'),
('BF', 'Burkina Faso'), ('BF', 'Burkina Faso'),
('BI', 'Burundi'), ('BI', 'Burundi'),
('KH', 'Cambodia'), ('KH', 'Cambodia'),
('CM', 'Cameroon'), ('CM', 'Cameroon'),
('CA', 'Canada'), ('CA', 'Canada'),
('CV', 'Cape Verde'), ('CV', 'Cape Verde'),
('KY', 'Cayman Islands'), ('KY', 'Cayman Islands'),
('CF', 'Central African Republic'), ('CF', 'Central African Republic'),
('TD', 'Chad'), ('TD', 'Chad'),
('CL', 'Chile'), ('CL', 'Chile'),
('CN', 'China'), ('CN', 'China'),
('CX', 'Christmas Island'), ('CX', 'Christmas Island'),
('CC', 'Cocos (Keeling) Islands'), ('CC', 'Cocos (Keeling) Islands'),
('CO', 'Colombia'), ('CO', 'Colombia'),
('KM', 'Comoros'), ('KM', 'Comoros'),
('CG', 'Congo'), ('CG', 'Congo'),
('CD', 'Congo, The Democratic Republic of the'), ('CD', 'Congo, The Democratic Republic of the'),
('CK', 'Cook Islands'), ('CK', 'Cook Islands'),
('CR', 'Costa Rica'), ('CR', 'Costa Rica'),
('CI', "Cote D'ivoire"), ('CI', "Cote D'ivoire"),
('HR', 'Croatia'), ('HR', 'Croatia'),
('CU', 'Cuba'), ('CU', 'Cuba'),
('CW', 'Curacao'), ('CW', 'Curacao'),
('CY', 'Cyprus'), ('CY', 'Cyprus'),
('CZ', 'Czech Republic'), ('CZ', 'Czech Republic'),
('DK', 'Denmark'), ('DK', 'Denmark'),
('DJ', 'Djibouti'), ('DJ', 'Djibouti'),
('DM', 'Dominica'), ('DM', 'Dominica'),
('DO', 'Dominican Republic'), ('DO', 'Dominican Republic'),
('EC', 'Ecuador'), ('EC', 'Ecuador'),
('EG', 'Egypt'), ('EG', 'Egypt'),
('SV', 'El Salvador'), ('SV', 'El Salvador'),
('GQ', 'Equatorial Guinea'), ('GQ', 'Equatorial Guinea'),
('ER', 'Eritrea'), ('ER', 'Eritrea'),
('EE', 'Estonia'), ('EE', 'Estonia'),
('ET', 'Ethiopia'), ('ET', 'Ethiopia'),
('FK', 'Falkland Islands (Malvinas)'), ('FK', 'Falkland Islands (Malvinas)'),
('FO', 'Faroe Islands'), ('FO', 'Faroe Islands'),
('FJ', 'Fiji'), ('FJ', 'Fiji'),
('FI', 'Finland'), ('FI', 'Finland'),
('FR', 'France'), ('FR', 'France'),
('GF', 'French Guiana'), ('GF', 'French Guiana'),
('PF', 'French Polynesia'), ('PF', 'French Polynesia'),
('TF', 'French Southern Territories'), ('TF', 'French Southern Territories'),
('GA', 'Gabon'), ('GA', 'Gabon'),
('GM', 'Gambia'), ('GM', 'Gambia'),
('GE', 'Georgia'), ('GE', 'Georgia'),
('DE', 'Germany'), ('DE', 'Germany'),
('GH', 'Ghana'), ('GH', 'Ghana'),
('GI', 'Gibraltar'), ('GI', 'Gibraltar'),
('GR', 'Greece'), ('GR', 'Greece'),
('GL', 'Greenland'), ('GL', 'Greenland'),
('GD', 'Grenada'), ('GD', 'Grenada'),
('GP', 'Guadeloupe'), ('GP', 'Guadeloupe'),
('GU', 'Guam'), ('GU', 'Guam'),
('GT', 'Guatemala'), ('GT', 'Guatemala'),
('GG', 'Guernsey'), ('GG', 'Guernsey'),
('GN', 'Guinea'), ('GN', 'Guinea'),
('GW', 'Guinea-Bissau'), ('GW', 'Guinea-Bissau'),
('GY', 'Guyana'), ('GY', 'Guyana'),
('HT', 'Haiti'), ('HT', 'Haiti'),
('HM', 'Heard Island and McDonald Islands'), ('HM', 'Heard Island and McDonald Islands'),
('VA', 'Holy See (Vatican City State)'), ('VA', 'Holy See (Vatican City State)'),
('HN', 'Honduras'), ('HN', 'Honduras'),
('HK', 'Hong Kong'), ('HK', 'Hong Kong'),
('HU', 'Hungary'), ('HU', 'Hungary'),
('IS', 'Iceland'), ('IS', 'Iceland'),
('IN', 'India'), ('IN', 'India'),
('ID', 'Indonesia'), ('ID', 'Indonesia'),
('IR', 'Iran, Islamic Republic of'), ('IR', 'Iran, Islamic Republic of'),
('IQ', 'Iraq'), ('IQ', 'Iraq'),
('IE', 'Ireland'), ('IE', 'Ireland'),
('IM', 'Isle of Man'), ('IM', 'Isle of Man'),
('IL', 'Israel'), ('IL', 'Israel'),
('IT', 'Italy'), ('IT', 'Italy'),
('JM', 'Jamaica'), ('JM', 'Jamaica'),
('JP', 'Japan'), ('JP', 'Japan'),
('JE', 'Jersey'), ('JE', 'Jersey'),
('JO', 'Jordan'), ('JO', 'Jordan'),
('KZ', 'Kazakhstan'), ('KZ', 'Kazakhstan'),
('KE', 'Kenya'), ('KE', 'Kenya'),
('KI', 'Kiribati'), ('KI', 'Kiribati'),
('KP', "Korea, Democratic People's Republic of"), ('KP', "Korea, Democratic People's Republic of"),
('KR', 'Korea, Republic of'), ('KR', 'Korea, Republic of'),
('KW', 'Kuwait'), ('KW', 'Kuwait'),
('KG', 'Kyrgyzstan'), ('KG', 'Kyrgyzstan'),
('LA', "Lao People's Democratic Republic"), ('LA', "Lao People's Democratic Republic"),
('LV', 'Latvia'), ('LV', 'Latvia'),
('LB', 'Lebanon'), ('LB', 'Lebanon'),
('LS', 'Lesotho'), ('LS', 'Lesotho'),
('LR', 'Liberia'), ('LR', 'Liberia'),
('LY', 'Libyan Arab Jamahiriya'), ('LY', 'Libyan Arab Jamahiriya'),
('LI', 'Liechtenstein'), ('LI', 'Liechtenstein'),
('LT', 'Lithuania'), ('LT', 'Lithuania'),
('LU', 'Luxembourg'), ('LU', 'Luxembourg'),
('MO', 'Macao'), ('MO', 'Macao'),
('MK', 'Macedonia, The Former Yugoslav Republic of'), ('MK', 'Macedonia, The Former Yugoslav Republic of'),
('MG', 'Madagascar'), ('MG', 'Madagascar'),
('MW', 'Malawi'), ('MW', 'Malawi'),
('MY', 'Malaysia'), ('MY', 'Malaysia'),
('MV', 'Maldives'), ('MV', 'Maldives'),
('ML', 'Mali'), ('ML', 'Mali'),
('MT', 'Malta'), ('MT', 'Malta'),
('MH', 'Marshall Islands'), ('MH', 'Marshall Islands'),
('MQ', 'Martinique'), ('MQ', 'Martinique'),
('MR', 'Mauritania'), ('MR', 'Mauritania'),
('MU', 'Mauritius'), ('MU', 'Mauritius'),
('YT', 'Mayotte'), ('YT', 'Mayotte'),
('MX', 'Mexico'), ('MX', 'Mexico'),
('FM', 'Micronesia, Federated States of'), ('FM', 'Micronesia, Federated States of'),
('MD', 'Moldova, Republic of'), ('MD', 'Moldova, Republic of'),
('MC', 'Monaco'), ('MC', 'Monaco'),
('MN', 'Mongolia'), ('MN', 'Mongolia'),
('ME', 'Montenegro'), ('ME', 'Montenegro'),
('MS', 'Montserrat'), ('MS', 'Montserrat'),
('MA', 'Morocco'), ('MA', 'Morocco'),
('MZ', 'Mozambique'), ('MZ', 'Mozambique'),
('MM', 'Myanmar'), ('MM', 'Myanmar'),
('NA', 'Namibia'), ('NA', 'Namibia'),
('NR', 'Nauru'), ('NR', 'Nauru'),
('NP', 'Nepal'), ('NP', 'Nepal'),
('NL', 'Netherlands'), ('NL', 'Netherlands'),
('NC', 'New Caledonia'), ('NC', 'New Caledonia'),
('NZ', 'New Zealand'), ('NZ', 'New Zealand'),
('NI', 'Nicaragua'), ('NI', 'Nicaragua'),
('NE', 'Niger'), ('NE', 'Niger'),
('NG', 'Nigeria'), ('NG', 'Nigeria'),
('NU', 'Niue'), ('NU', 'Niue'),
('NF', 'Norfolk Island'), ('NF', 'Norfolk Island'),
('MP', 'Northern Mariana Islands'), ('MP', 'Northern Mariana Islands'),
('NO', 'Norway'), ('NO', 'Norway'),
('OM', 'Oman'), ('OM', 'Oman'),
('PK', 'Pakistan'), ('PK', 'Pakistan'),
('PW', 'Palau'), ('PW', 'Palau'),
('PS', 'Palestinian Territory, Occupied'), ('PS', 'Palestinian Territory, Occupied'),
('PA', 'Panama'), ('PA', 'Panama'),
('PG', 'Papua New Guinea'), ('PG', 'Papua New Guinea'),
('PY', 'Paraguay'), ('PY', 'Paraguay'),
('PE', 'Peru'), ('PE', 'Peru'),
('PH', 'Philippines'), ('PH', 'Philippines'),
('PN', 'Pitcairn'), ('PN', 'Pitcairn'),
('PL', 'Poland'), ('PL', 'Poland'),
('PT', 'Portugal'), ('PT', 'Portugal'),
('PR', 'Puerto Rico'), ('PR', 'Puerto Rico'),
('QA', 'Qatar'), ('QA', 'Qatar'),
('RE', 'Reunion'), ('RE', 'Reunion'),
('RO', 'Romania'), ('RO', 'Romania'),
('RU', 'Russian Federation'), ('RU', 'Russian Federation'),
('RW', 'Rwanda'), ('RW', 'Rwanda'),
('BL', 'Saint Barthelemy'), ('BL', 'Saint Barthelemy'),
('SH', 'Saint Helena, Ascension and Tristan Da Cunha'), ('SH', 'Saint Helena, Ascension and Tristan Da Cunha'),
('KN', 'Saint Kitts and Nevis'), ('KN', 'Saint Kitts and Nevis'),
('LC', 'Saint Lucia'), ('LC', 'Saint Lucia'),
('MF', 'Saint Martin (French Part)'), ('MF', 'Saint Martin (French Part)'),
('PM', 'Saint Pierre and Miquelon'), ('PM', 'Saint Pierre and Miquelon'),
('VC', 'Saint Vincent and the Grenadines'), ('VC', 'Saint Vincent and the Grenadines'),
('WS', 'Samoa'), ('WS', 'Samoa'),
('SM', 'San Marino'), ('SM', 'San Marino'),
('ST', 'Sao Tome and Principe'), ('ST', 'Sao Tome and Principe'),
('SA', 'Saudi Arabia'), ('SA', 'Saudi Arabia'),
('SN', 'Senegal'), ('SN', 'Senegal'),
('RS', 'Serbia'), ('RS', 'Serbia'),
('SC', 'Seychelles'), ('SC', 'Seychelles'),
('SL', 'Sierra Leone'), ('SL', 'Sierra Leone'),
('SG', 'Singapore'), ('SG', 'Singapore'),
('SX', 'Sint Maarten (Dutch Part)'), ('SX', 'Sint Maarten (Dutch Part)'),
('SK', 'Slovakia'), ('SK', 'Slovakia'),
('SI', 'Slovenia'), ('SI', 'Slovenia'),
('SB', 'Solomon Islands'), ('SB', 'Solomon Islands'),
('SO', 'Somalia'), ('SO', 'Somalia'),
('ZA', 'South Africa'), ('ZA', 'South Africa'),
('GS', 'South Georgia and the South Sandwich Islands'), ('GS', 'South Georgia and the South Sandwich Islands'),
('ES', 'Spain'), ('ES', 'Spain'),
('LK', 'Sri Lanka'), ('LK', 'Sri Lanka'),
('SD', 'Sudan'), ('SD', 'Sudan'),
('SR', 'Suriname'), ('SR', 'Suriname'),
('SJ', 'Svalbard and Jan Mayen'), ('SJ', 'Svalbard and Jan Mayen'),
('SZ', 'Swaziland'), ('SZ', 'Swaziland'),
('SE', 'Sweden'), ('SE', 'Sweden'),
('CH', 'Switzerland'), ('CH', 'Switzerland'),
('SY', 'Syrian Arab Republic'), ('SY', 'Syrian Arab Republic'),
('TW', 'Taiwan, Province of China'), ('TW', 'Taiwan, Province of China'),
('TJ', 'Tajikistan'), ('TJ', 'Tajikistan'),
('TZ', 'Tanzania, United Republic of'), ('TZ', 'Tanzania, United Republic of'),
('TH', 'Thailand'), ('TH', 'Thailand'),
('TL', 'Timor-Leste'), ('TL', 'Timor-Leste'),
('TG', 'Togo'), ('TG', 'Togo'),
('TK', 'Tokelau'), ('TK', 'Tokelau'),
('TO', 'Tonga'), ('TO', 'Tonga'),
('TT', 'Trinidad and Tobago'), ('TT', 'Trinidad and Tobago'),
('TN', 'Tunisia'), ('TN', 'Tunisia'),
('TR', 'Turkey'), ('TR', 'Turkey'),
('TM', 'Turkmenistan'), ('TM', 'Turkmenistan'),
('TC', 'Turks and Caicos Islands'), ('TC', 'Turks and Caicos Islands'),
('TV', 'Tuvalu'), ('TV', 'Tuvalu'),
('UG', 'Uganda'), ('UG', 'Uganda'),
('UA', 'Ukraine'), ('UA', 'Ukraine'),
('AE', 'United Arab Emirates'), ('AE', 'United Arab Emirates'),
('GB', 'United Kingdom'), ('GB', 'United Kingdom'),
('US', 'United States'), ('US', 'United States'),
('UM', 'United States Minor Outlying Islands'), ('UM', 'United States Minor Outlying Islands'),
('UY', 'Uruguay'), ('UY', 'Uruguay'),
('UZ', 'Uzbekistan'), ('UZ', 'Uzbekistan'),
('VU', 'Vanuatu'), ('VU', 'Vanuatu'),
('VE', 'Venezuela, Bolivarian Republic of'), ('VE', 'Venezuela, Bolivarian Republic of'),
('VN', 'Viet Nam'), ('VN', 'Viet Nam'),
('VG', 'Virgin Islands, British'), ('VG', 'Virgin Islands, British'),
('VI', 'Virgin Islands, U.S.'), ('VI', 'Virgin Islands, U.S.'),
('WF', 'Wallis and Futuna'), ('WF', 'Wallis and Futuna'),
('EH', 'Western Sahara'), ('EH', 'Western Sahara'),
('YE', 'Yemen'), ('YE', 'Yemen'),
('ZM', 'Zambia'), ('ZM', 'Zambia'),
('ZW', 'Zimbabwe')) ('ZW', 'Zimbabwe'),
)
employer_types = ( employer_types = (
('F','Federal Government'), ('F','Federal Government'),
('S','State and Local Governmental Employer'), ('S','State and Local Governmental Employer'),
('T','Tax Exempt Employer'), ('T','Tax Exempt Employer'),
('Y','State and Local Tax Exempt Employer'), ('Y','State and Local Tax Exempt Employer'),
('N','None Apply'), ('N','None Apply'),
) )
employment_codes = ( employment_codes = (
('A', 'Agriculture'), ('A', 'Agriculture'),
('H', 'Household'), ('H', 'Household'),
('M', 'Military'), ('M', 'Military'),
('Q', 'Medicare Qualified Government Employee'), ('Q', 'Medicare Qualified Government Employee'),
('X', 'Railroad'), ('X', 'Railroad'),
('F', 'Regular'), ('F', 'Regular'),
('R', 'Regular (all others)'), ('R', 'Regular (all others)'),
) )
tax_jurisdiction_codes = ( tax_jurisdiction_codes = (
('V', 'Virgin Islands'), (' ', 'W-2'),
('G', 'Guam'), ('V', 'Virgin Islands'),
('S', 'American Samoa'), ('G', 'Guam'),
('N', 'Northern Mariana Islands'), ('S', 'American Samoa'),
('P', 'Puerto Rico'), ('N', 'Northern Mariana Islands'),
) ('P', 'Puerto Rico'),
)
tax_type_codes = ( tax_type_codes = (
('C', 'City Income Tax'), ('C', 'City Income Tax'),
('D', 'Country Income Tax'), ('D', 'Country Income Tax'),
('E', 'School District Income Tax'), ('E', 'School District Income Tax'),
('F', 'Other Income Tax'), ('F', 'Other Income Tax'),
) )

View file

@ -1,6 +1,10 @@
import decimal, datetime import decimal, datetime
import inspect import inspect
import enums from six import string_types
from . import enums
def is_blank_space(val):
return len(val.strip()) == 0
class ValidationError(Exception): class ValidationError(Exception):
def __init__(self, msg, field=None): def __init__(self, msg, field=None):
@ -16,22 +20,26 @@ class ValidationError(Exception):
class Field(object): class Field(object):
creation_counter = 0 creation_counter = 0
is_read_only = False
_value = None
def __init__(self, name=None, max_length=0, required=True, uppercase=True, creation_counter=None): def __init__(self, name=None, min_length=0, max_length=0, blank=False, required=True, uppercase=True, creation_counter=None):
self.name = name self.name = name
self._value = None self._value = None
self._orig_value = None self._orig_value = None
self.min_length = min_length
self.max_length = max_length self.max_length = max_length
self.blank = blank
self.required = required self.required = required
self.uppercase = uppercase self.uppercase = uppercase
self.creation_counter = creation_counter or Field.creation_counter self.creation_counter = creation_counter or Field.creation_counter
Field.creation_counter += 1 Field.creation_counter += 1
def validate(self): def validate(self):
raise NotImplemented raise NotImplementedError
def get_data(self): def get_data(self):
raise NotImplemented raise NotImplementedError
def __setvalue(self, value): def __setvalue(self, value):
self._value = value self._value = value
@ -76,7 +84,7 @@ class Field(object):
required=o['required'], required=o['required'],
) )
if isinstance(o['value'], basestring) and re.match('^\d*\.\d*$', o['value']): if isinstance(o['value'], str) and re.match(r'^\d*\.\d*$', o['value']):
o['value'] = decimal.Decimal(o['value']) o['value'] = decimal.Decimal(o['value'])
self.value = o['value'] self.value = o['value']
@ -90,14 +98,10 @@ class Field(object):
wrapper = textwrap.TextWrapper(replace_whitespace=False, drop_whitespace=False) wrapper = textwrap.TextWrapper(replace_whitespace=False, drop_whitespace=False)
wrapper.width = 100 wrapper.width = 100
value = wrapper.wrap(value) value = wrapper.wrap(value)
#value = textwrap.wrap(value, 100) value = list([(" " * 9) + ('"' + x + '"') for x in value])
#print value value.append(" " * 10 + ('_' * 10) * int(wrapper.width / 10))
value = list(map(lambda x:(" " * 9) + ('"' + x + '"'), value)) value.append(" " * 10 + ('0123456789') * int(wrapper.width / 10))
#value[0] = '"' + value[0] + '"' value.append(" " * 10 + ''.join(([str(x) + (' ' * 9) for x in range(int(wrapper.width / 10))])))
value.append(" " * 10 + ('_' * 10) * (wrapper.width / 10))
value.append(" " * 10 + ('0123456789') * (wrapper.width / 10))
value.append(" " * 10 + ''.join((map(lambda x:str(x) + (' ' * 9), range(wrapper.width / 10 )))))
#value.append((" " * 59) + map(lambda x:("%x" % x), range(16))
start = counter['c'] start = counter['c']
counter['c'] += len(self._orig_value or self.value) counter['c'] += len(self._orig_value or self.value)
@ -115,22 +119,28 @@ class Field(object):
class TextField(Field): class TextField(Field):
def validate(self): def validate(self):
if self.value == None and self.required: if self.value is None and self.required:
raise ValidationError("value required", field=self) raise ValidationError("value required", field=self)
if len(self.get_data()) > self.max_length: data = self.get_data()
if len(data) > self.max_length:
raise ValidationError("value is too long", field=self) raise ValidationError("value is too long", field=self)
stripped_data_length = len(data.strip())
if stripped_data_length < self.min_length:
raise ValidationError("value is too short", field=self)
if stripped_data_length == 0 and (not self.blank and self.required):
raise ValidationError("field cannot be blank", field=self)
def get_data(self): def get_data(self):
value = self.value or "" value = str(self.value or '').encode('ascii') or b''
if self.uppercase: if self.uppercase:
value = value.upper() value = value.upper()
return value.ljust(self.max_length).encode('ascii')[:self.max_length] return value.ljust(self.max_length)[:self.max_length]
def __setvalue(self, value): def __setvalue(self, value):
# NO NEWLINES # NO NEWLINES
try: try:
value = value.replace('\n', '').replace('\r', '') value = value.replace('\n', '').replace('\r', '')
except AttributeError, e: except AttributeError:
pass pass
self._value = value self._value = value
@ -142,31 +152,35 @@ class TextField(Field):
class StateField(TextField): class StateField(TextField):
def __init__(self, name=None, required=True, use_numeric=False, max_length=2): def __init__(self, name=None, required=True, use_numeric=False, max_length=2):
super(StateField, self).__init__(name=name, max_length=2, required=required) super(StateField, self).__init__(name=name, max_length=max_length, required=required)
self.use_numeric = use_numeric self.use_numeric = use_numeric
def get_data(self): def get_data(self):
value = self.value or "" value = str(self.value or 'XX')
if value.strip() and self.use_numeric: if value.strip() and self.use_numeric:
return str(enums.state_postal_numeric[value.upper()]).zfill(self.max_length) postcode = enums.state_postal_numeric[value.upper()]
postcode = str(postcode).encode('ascii')
return postcode.zfill(self.max_length)
else: else:
return value.ljust(self.max_length).encode('ascii')[:self.max_length] formatted = value.encode('ascii').ljust(self.max_length)
return formatted[:self.max_length]
def validate(self): def validate(self):
super(StateField, self).validate() super(StateField, self).validate()
if self.value and self.value.upper() not in enums.state_postal_numeric.keys(): if self.value and self.value.upper() not in list(enums.state_postal_numeric.keys()):
raise ValidationError("%s is not a valid state abbreviation" % self.value, field=self) raise ValidationError("%s is not a valid state abbreviation" % self.value, field=self)
def parse(self, s): def parse(self, s):
if s.strip() and self.use_numeric: if s.strip() and self.use_numeric:
states = dict( [(v,k) for (k,v) in enums.state_postal_numeric.items()] ) states = dict([(v, k) for (k, v) in list(enums.state_postal_numeric.items())])
self.value = states[int(s)] self.value = states[int(s)]
else: else:
self.value = s self.value = s
class EmailField(TextField): class EmailField(TextField):
def __init__(self, name=None, required=True, max_length=None): def __init__(self, name=None, required=True, max_length=None):
return super(EmailField, self).__init__(name=name, max_length=max_length, super(EmailField, self).__init__(name=name, max_length=max_length,
required=required, uppercase=False) required=required, uppercase=False)
class IntegerField(TextField): class IntegerField(TextField):
@ -178,37 +192,58 @@ class IntegerField(TextField):
except ValueError: except ValueError:
raise ValidationError("field contains non-numeric characters", field=self) raise ValidationError("field contains non-numeric characters", field=self)
def get_data(self): def get_data(self):
value = self.value or "" value = str(self.value).encode('ascii') if self.value else b''
return str(value).zfill(self.max_length)[:self.max_length] return value.zfill(self.max_length)[:self.max_length]
def parse(self, s): def parse(self, s):
self.value = int(s) if not is_blank_space(s):
self.value = int(s)
else:
self.value = 0
class StaticField(TextField): class StaticField(TextField):
def __init__(self, name=None, required=True, value=None): def __init__(self, name=None, required=True, value=None, uppercase=False):
super(StaticField, self).__init__(name=name, required=required, super(StaticField, self).__init__(name=name,
max_length=len(value)) required=required,
max_length=len(value),
uppercase=uppercase)
self._static_value = value
self._value = value self._value = value
def parse(self, s): def parse(self, s):
pass pass
class BlankField(TextField): class BlankField(TextField):
is_read_only = True
def __init__(self, name=None, max_length=0, required=False): def __init__(self, name=None, max_length=0, required=False):
super(TextField, self).__init__(name=name, max_length=max_length, required=required, uppercase=False) super(BlankField, self).__init__(name=name, max_length=max_length, required=required, uppercase=False)
def get_data(self): def get_data(self):
return " " * self.max_length return b' ' * self.max_length
def parse(self, s): def parse(self, s):
pass pass
def validate(self):
if len(self.get_data()) != self.max_length:
raise ValidationError("blank field did not match expected length", field=self)
class ZeroField(BlankField):
is_read_only = True
def get_data(self):
return b'0' * self.max_length
class CRLFField(TextField): class CRLFField(TextField):
is_read_only = True
def __init__(self, name=None, required=False): def __init__(self, name=None, required=False):
super(TextField, self).__init__(name=name, max_length=2, required=required, uppercase=False) super(CRLFField, self).__init__(name=name, max_length=2, required=required, uppercase=False)
def __setvalue(self, value): def __setvalue(self, value):
self._value = value self._value = value
@ -219,11 +254,12 @@ class CRLFField(TextField):
value = property(__getvalue, __setvalue) value = property(__getvalue, __setvalue)
def get_data(self): def get_data(self):
return '\r\n' return b'\r\n'
def parse(self, s): def parse(self, s):
self.value = s self.value = s
class BooleanField(Field): class BooleanField(Field):
def __init__(self, name=None, required=True, value=None): def __init__(self, name=None, required=True, value=None):
super(BooleanField, self).__init__(name=name, required=required, max_length=1) super(BooleanField, self).__init__(name=name, required=required, max_length=1)
@ -233,7 +269,7 @@ class BooleanField(Field):
pass pass
def get_data(self): def get_data(self):
return '1' if self._value else '0' return b'1' if self._value else b'0'
def parse(self, s): def parse(self, s):
self.value = (s == '1') self.value = (s == '1')
@ -250,26 +286,43 @@ class MoneyField(Field):
raise ValidationError("value is too long", field=self) raise ValidationError("value is too long", field=self)
def get_data(self): def get_data(self):
return str(int((self.value or 0)*100)).encode('ascii').zfill(self.max_length)[:self.max_length] cents = int((self.value or 0) * 100)
formatted = str(cents).encode('ascii').zfill(self.max_length)
return formatted[:self.max_length]
def parse(self, s): def parse(self, s):
self.value = decimal.Decimal(s) * decimal.Decimal('0.01') if not is_blank_space(s):
self.value = decimal.Decimal(s) * decimal.Decimal('0.01')
else:
self.value = decimal.Decimal(0.0)
def __setvalue(self, value):
new_value = value
if isinstance(new_value, string_types):
new_value = decimal.Decimal(new_value or '0')
if '.' not in value: # must be cents?
new_value *= decimal.Decimal('100.')
self._value = new_value
def __getvalue(self):
return self._value
value = property(__getvalue, __setvalue)
class DateField(TextField): class DateField(TextField):
def __init__(self, name=None, required=True, value=None): def __init__(self, name=None, required=True, value=None):
super(TextField, self).__init__(name=name, required=required, max_length=8) super(DateField, self).__init__(name=name, required=required, max_length=8)
if value: if value:
self.value = value self.value = value
def get_data(self): def get_data(self):
if self._value: if self._value:
return self._value.strftime('%m%d%Y') return self._value.strftime('%m%d%Y').encode('ascii')
return '0' * self.max_length return b'0' * self.max_length
def parse(self, s): def parse(self, s):
if int(s) > 0: if int(s) > 0:
self.value = datetime.date(*[int(x) for x in s[4:8], s[0:2], s[2:4]]) self.value = datetime.date(*[int(x) for x in (s[4:8], s[0:2], s[2:4])])
else: else:
self.value = None self.value = None
@ -277,7 +330,7 @@ class DateField(TextField):
if isinstance(value, datetime.date): if isinstance(value, datetime.date):
self._value = value self._value = value
elif value: elif value:
self._value = datetime.date(*[int(x) for x in value[4:8], value[0:2], value[2:4]]) self._value = datetime.date(*[int(x) for x in (value[4:8], value[0:2], value[2:4])])
else: else:
self._value = None self._value = None
@ -289,19 +342,18 @@ class DateField(TextField):
class MonthYearField(TextField): class MonthYearField(TextField):
def __init__(self, name=None, required=True, value=None): def __init__(self, name=None, required=True, value=None):
super(TextField, self).__init__(name=name, required=required, max_length=6) super(MonthYearField, self).__init__(name=name, required=required, max_length=6)
if value: if value:
self.value = value self.value = value
def get_data(self): def get_data(self):
if self._value: if self._value:
return self._value.strftime("%m%Y") return str(self._value.strftime('%m%Y').encode('ascii'))
return '0' * self.max_length return b'0' * self.max_length
def parse(self, s): def parse(self, s):
if int(s) > 0: if int(s) > 0:
self.value = datetime.date(*[int(x) for x in s[2:6], s[0:2], 1]) self.value = datetime.date(*[int(x) for x in (s[2:6], s[0:2], 1)])
else: else:
self.value = None self.value = None
@ -309,7 +361,7 @@ class MonthYearField(TextField):
if isinstance(value, datetime.date): if isinstance(value, datetime.date):
self._value = value self._value = value
elif value: elif value:
self._value = datetime.date(*[int(x) for x in value[2:6], value[0:2], 1]) self._value = datetime.date(*[int(x) for x in (value[2:6], value[0:2], 1)])
else: else:
self._value = None self._value = None
@ -317,4 +369,3 @@ class MonthYearField(TextField):
return self._value return self._value
value = property(__getvalue, __setvalue) value = property(__getvalue, __setvalue)

View file

@ -1,15 +1,19 @@
from fields import Field, TextField, ValidationError from .fields import Field, TextField, ValidationError
import copy import copy
import pdb import collections
class Model(object): class Model(object):
record_length = -1
record_identifier = ' ' record_identifier = ' '
required = False required = False
target_size = 512 target_size = 512
def __init__(self): def __init__(self):
for (key, value) in self.__class__.__dict__.items(): if self.record_length == -1:
raise ValueError(self.record_length)
for (key, value) in list(self.__class__.__dict__.items()):
if isinstance(value, Field): if isinstance(value, Field):
# GRAB THE FIELD INSTANCE FROM THE CLASS DEFINITION # GRAB THE FIELD INSTANCE FROM THE CLASS DEFINITION
# AND MAKE A LOCAL COPY FOR THIS RECORD'S INSTANCE, # AND MAKE A LOCAL COPY FOR THIS RECORD'S INSTANCE,
@ -19,21 +23,31 @@ class Model(object):
if not src_field.name: if not src_field.name:
setattr(src_field, 'name', key) setattr(src_field, 'name', key)
setattr(src_field, 'parent_name', self.__class__.__name__) setattr(src_field, 'parent_name', self.__class__.__name__)
self.__dict__[key] = copy.copy(src_field) new_field_instance = copy.copy(src_field)
new_field_instance._orig_value = None
new_field_instance._value = new_field_instance.value
self.__dict__[key] = new_field_instance
def __setattr__(self, key, value): def __setattr__(self, key, value):
if hasattr(self, key) and isinstance(getattr(self, key), Field): if hasattr(self, key) and isinstance(getattr(self, key), Field):
getattr(self, key).value = value self.set_field_value(key, value)
else: else:
# MAYBE THIS SHOULD RAISE A PROPERTY ERROR? # MAYBE THIS SHOULD RAISE A PROPERTY ERROR?
self.__dict__[key] = value self.__dict__[key] = value
def set_field_value(self, field_name, value):
getattr(self, field_name).value = value
def get_fields(self): def get_fields(self):
identifier = TextField("record_identifier", max_length=len(self.record_identifier), creation_counter=-1) identifier = TextField(
"record_identifier",
max_length = len(self.record_identifier),
blank = len(self.record_identifier) == 0,
creation_counter=-1)
identifier.value = self.record_identifier identifier.value = self.record_identifier
fields = [identifier] fields = [identifier]
for key in self.__class__.__dict__.keys(): for key in list(self.__class__.__dict__.keys()):
attr = getattr(self, key) attr = getattr(self, key)
if isinstance(attr, Field): if isinstance(attr, Field):
fields.append(attr) fields.append(attr)
@ -41,7 +55,7 @@ class Model(object):
def get_sorted_fields(self): def get_sorted_fields(self):
fields = self.get_fields() fields = self.get_fields()
fields.sort(key=lambda x:x.creation_counter) fields.sort(key=lambda x: x.creation_counter)
return fields return fields
def validate(self): def validate(self):
@ -50,27 +64,33 @@ class Model(object):
try: try:
custom_validator = getattr(self, 'validate_' + f.name) custom_validator = getattr(self, 'validate_' + f.name)
except AttributeError, e: except AttributeError:
continue continue
if callable(custom_validator): if isinstance(custom_validator, collections.Callable):
custom_validator(f) custom_validator(f)
def output(self): def output(self, format='binary'):
result = ''.join([field.get_data() for field in self.get_sorted_fields()]) if format == 'text':
return self.output_text()
return self.output_efile()
if hasattr(self, 'record_length') and len(result) != self.record_length: def output_efile(self):
result = b''.join([field.get_data() for field in self.get_sorted_fields()])
if self.record_length < 0 or len(result) != self.record_length:
raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.record_length, len(result))) raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.record_length, len(result)))
#result = ''.join([self.record_identifier] + [field.get_data() for field in self.get_sorted_fields()])
#if len(result) != self.target_size:
# raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.target_size, len(result)))
return result return result
def output_text(self):
fields = self.get_sorted_fields()[1:] # skip record identifier
fields = [field for field in fields if not field.is_read_only]
header = ''.join(['---', self.__class__.__name__, '\n'])
return header + '\n'.join([f.name + ': ' + (str(f.value) if f.value else '') for f in fields]) + '\n\n'
def read(self, fp): def read(self, fp):
# Skip the first record, since that's an identifier # Skip the first record, since that's an identifier
for field in self.get_sorted_fields()[1:]: for field in self.get_sorted_fields()[1:]:
field.read(fp) field.read(fp)
def toJSON(self): def toJSON(self):
return { return {
'__class__': self.__class__.__name__, '__class__': self.__class__.__name__,
@ -80,19 +100,17 @@ class Model(object):
def fromJSON(self, o): def fromJSON(self, o):
fields = o['fields'] fields = o['fields']
identifier, fields = fields[0], fields[1:]
assert(identifier.value == self.record_identifier)
for f in fields: for f in fields:
target = self.__dict__[f.name] target = self.__dict__[f.name]
if (target.required != f.required or if (target.required != f.required
target.max_length != f.max_length): or target.max_length != f.max_length):
print "Warning: value mismatch on import" print("Warning: value mismatch on import")
target._value = f._value target.value = f.value
#print (self.__dict__[f.name].name == f.name)
#self.__dict__[f.name].name == f.name
#self.__dict__[f.name].max_length == f.max_length
return self return self

View file

@ -1,86 +1,86 @@
#!/usr/bin/env python
import re import re
class ClassEntryCommentSequence(object):
    """Check that trailing ``# start-end`` range comments run consecutively.

    Lines of one class body are collected with ``add_line``; ``process``
    then reads the trailing ``# N`` / ``# N-M`` comment on each line and
    prints an ERROR row whenever a range does not start right after the
    previous one ended.
    """

    re_rangecomment = re.compile(r'#\s+(\d+)\-?(\d*)$')

    def __init__(self, classname, line):
        """Remember the class name and the line number it was declared on."""
        # A trailing comma here used to store a 1-tuple instead of the name.
        self.classname = classname
        self.line = line
        self.lines = []

    def add_line(self, line):
        """Append one source line belonging to the class body."""
        self.lines.append(line)

    def process(self):
        """Print an ERROR row for each range comment that breaks the sequence."""
        i = 0  # end position of the previous range (0 before the first one)
        for (line_no, line) in enumerate(self.lines):
            match = self.re_rangecomment.search(line)
            if not match:
                continue

            (a, b) = match.groups()
            a = int(a)

            if (i + 1) != a:
                line_number = self.line + line_no
                # Strip before splitting: body lines are indented, so
                # split(' ')[0] on the raw line is always an empty string.
                near = line.strip().split(' ')[0]
                print("ERROR\tline:%d\tnear:%s\texpected:%d\tsaw:%d" % (
                    line_number, near, i + 1, a))

            # A single-position comment ("# 5") ends where it starts.
            i = int(b) if b else a
class ModelDefParser(object):
    """Scan Python source lines and hand each class body to an entry object.

    For every ``class Name(Base):`` line an ``entryclass(classname, line)``
    object is created; subsequent body lines are passed to its ``add_line``
    and its ``process`` method is called once the class ends.
    """

    re_triplequote = re.compile('"""')
    re_whitespace = re.compile(r"^(\s*)[^\s]+")
    re_classdef = re.compile(r"^\s*class\s(.*)\((.*)\):\s*$")

    def __init__(self, infile, entryclass):
        """Store the iterable of lines and the per-class collector type."""
        self.infile = infile
        self.line = 0
        self.EntryClass = entryclass
        # Defined up front so endclass()/beginclass() are safe before parse().
        self.current_class = None

    def endclass(self):
        """Process and release the class currently being collected, if any."""
        if self.current_class:
            self.current_class.process()
            self.current_class = None

    def beginclass(self, classname, line):
        """Start collecting a new class declared on ``line``."""
        # Finish any class still open so that a class header immediately
        # following another class body does not discard it unprocessed.
        self.endclass()
        self.current_class = self.EntryClass(classname, line)

    def parse(self):
        """Walk ``infile`` once, feeding class body lines to entry objects."""
        whitespace = 0
        in_block_comment = False
        self.current_class = None

        for line in self.infile:
            self.line += 1

            if line.startswith('#'):
                continue

            # Only an odd number of triple quotes opens or closes a block;
            # a one-line """docstring""" must leave the state unchanged.
            if len(self.re_triplequote.findall(line)) % 2 == 1:
                in_block_comment = not in_block_comment

            if in_block_comment:
                continue

            indent_match = self.re_whitespace.match(line)
            match_whitespace = len(indent_match.groups()[0]) if indent_match else 0

            classmatch = self.re_classdef.match(line)
            if classmatch:
                classname, _subclass = classmatch.groups()
                self.beginclass(classname, self.line)
                continue

            # A dedent (blank lines count as indent 0) closes the class.
            if match_whitespace < whitespace:
                whitespace = match_whitespace
                self.endclass()
                continue

            if self.current_class:
                whitespace = match_whitespace
                self.current_class.add_line(line)

        # The input may end while a class is still open; without this the
        # final class in the file would never be processed.
        self.endclass()

View file

@ -1,5 +1,3 @@
#!/usr/bin/python
# coding=UTF-8
""" """
Parser utility to read data from Publication 1220 and Parser utility to read data from Publication 1220 and
convert it into python classes. convert it into python classes.
@ -7,6 +5,7 @@ convert it into python classes.
""" """
import re import re
import hashlib import hashlib
from functools import reduce
class SimpleDefParser(object): class SimpleDefParser(object):
def __init__(self): def __init__(self):
@ -34,7 +33,7 @@ class SimpleDefParser(object):
item = item.upper() item = item.upper()
if '-' in item: if '-' in item:
parts = map(lambda x:self._intify(x), item.split('-')) parts = [self._intify(x) for x in item.split('-')]
item = reduce(lambda x,y: y-x, parts) item = reduce(lambda x,y: y-x, parts)
else: else:
item = self._intify(item) item = self._intify(item)
@ -56,7 +55,7 @@ class LengthExpression(object):
self.exp_cache = {} self.exp_cache = {}
def __call__(self, value, exps): def __call__(self, value, exps):
return len(exps) == sum(map(lambda x: self.check(value, x), exps)) return len(exps) == sum([self.check(value, x) for x in exps])
def compile_exp(self, exp): def compile_exp(self, exp):
op, val = self.REG.match(exp).groups() op, val = self.REG.match(exp).groups()
@ -98,7 +97,7 @@ class RangeToken(BaseToken):
def value(self): def value(self):
if '-' not in self._value: if '-' not in self._value:
return 1 return 1
return reduce(lambda x,y: y-x, map(int, self._value.split('-')))+1 return reduce(lambda x,y: y-x, list(map(int, self._value.split('-'))))+1
@property @property
def end_position(self): def end_position(self):
@ -110,7 +109,7 @@ class RangeToken(BaseToken):
class NumericToken(BaseToken): class NumericToken(BaseToken):
regexp = re.compile('^(\d+)$') regexp = re.compile(r'^(\d+)$')
@property @property
def value(self): def value(self):
@ -118,7 +117,7 @@ class NumericToken(BaseToken):
class RecordBuilder(object): class RecordBuilder(object):
import fields from . import fields
entry_max_length = 4 entry_max_length = 4
@ -145,8 +144,7 @@ class RecordBuilder(object):
(re.compile(r'zero\-filled', re.IGNORECASE), +1), (re.compile(r'zero\-filled', re.IGNORECASE), +1),
(re.compile(r'leading zeroes', re.IGNORECASE), +1), (re.compile(r'leading zeroes', re.IGNORECASE), +1),
(re.compile(r'left-\justif', re.IGNORECASE), -1), (re.compile(r'left\-justif', re.IGNORECASE), -1),
], ],
}, },
}), }),
@ -201,15 +199,15 @@ class RecordBuilder(object):
try: try:
f_length = int(f_length) f_length = int(f_length)
except ValueError, e: except ValueError as e:
# bad result, skip # bad result, skip
continue continue
try: try:
assert f_length == RangeToken(f_range).value assert f_length == RangeToken(f_range).value
except AssertionError, e: except AssertionError as e:
continue continue
except ValueError, e: except ValueError as e:
# bad result, skip # bad result, skip
continue continue
@ -223,7 +221,7 @@ class RecordBuilder(object):
else: else:
required = None required = None
f_name = u'_'.join(map(lambda x:x.lower(), name_parts)) f_name = '_'.join([x.lower() for x in name_parts])
f_name = f_name.replace('&', 'and') f_name = f_name.replace('&', 'and')
f_name = re.sub(r'[^\w]','', f_name) f_name = re.sub(r'[^\w]','', f_name)
@ -240,7 +238,7 @@ class RecordBuilder(object):
lengthexp = LengthExpression() lengthexp = LengthExpression()
for entry in entries: for entry in entries:
matches = dict(map(lambda x:(x[0],0), self.FIELD_TYPES)) matches = dict([(x[0],0) for x in self.FIELD_TYPES])
for (classtype, criteria) in self.FIELD_TYPES: for (classtype, criteria) in self.FIELD_TYPES:
if 'length' in criteria: if 'length' in criteria:
@ -248,7 +246,7 @@ class RecordBuilder(object):
continue continue
if 'regexp' in criteria: if 'regexp' in criteria:
for crit_key, crit_values in criteria['regexp'].items(): for crit_key, crit_values in list(criteria['regexp'].items()):
for (crit_re, score) in crit_values: for (crit_re, score) in crit_values:
matches[classtype] += score if crit_re.search(entry[crit_key]) else 0 matches[classtype] += score if crit_re.search(entry[crit_key]) else 0
@ -256,7 +254,7 @@ class RecordBuilder(object):
matches = list(matches.items()) matches = list(matches.items())
matches.sort(key=lambda x:x[1]) matches.sort(key=lambda x:x[1])
matches_found = True if sum(map(lambda x:x[1], matches)) > 0 else False matches_found = True if sum([x[1] for x in matches]) > 0 else False
entry['guessed_type'] = matches[-1][0] if matches_found else self.fields.TextField entry['guessed_type'] = matches[-1][0] if matches_found else self.fields.TextField
yield entry yield entry
@ -271,7 +269,7 @@ class RecordBuilder(object):
if entry['name'] == 'blank': if entry['name'] == 'blank':
blank_id = hashlib.new('md5') blank_id = hashlib.new('md5')
blank_id.update(entry['range'].encode()) blank_id.update(entry['range'].encode())
add( (u'blank_%s' % blank_id.hexdigest()[:8]).ljust(40) ) add( ('blank_%s' % blank_id.hexdigest()[:8]).ljust(40) )
else: else:
add(entry['name'].ljust(40)) add(entry['name'].ljust(40))
@ -386,7 +384,7 @@ class PastedDefParser(RecordBuilder):
for g in groups: for g in groups:
assert g['byterange'].value == g['length'].value assert g['byterange'].value == g['length'].value
desc = u' '.join(map(lambda x:unicode(x.value), g['desc'])) desc = ' '.join([str(x.value) for x in g['desc']])
if g['name'][-1].value.lower() == '(optional)': if g['name'][-1].value.lower() == '(optional)':
g['name'] = g['name'][0:-1] g['name'] = g['name'][0:-1]
@ -396,7 +394,7 @@ class PastedDefParser(RecordBuilder):
else: else:
required = None required = None
name = u'_'.join(map(lambda x:x.value.lower(), g['name'])) name = '_'.join([x.value.lower() for x in g['name']])
name = re.sub(r'[^\w]','', name) name = re.sub(r'[^\w]','', name)
yield({ yield({

View file

@ -3,314 +3,102 @@
import subprocess import subprocess
import re import re
import pdb import itertools
import fitz
""" pdftotext -layout -nopgbrk p1220.pdf - """ """ pdftotext -layout -nopgbrk p1220.pdf - """
def strip_values(items):
    """Normalise table cell text and drop empty cells.

    Every truthy item has surrounding whitespace removed and embedded
    newlines replaced by single spaces. Falsy items (``None``, ``''``) are
    omitted, so the result can be shorter than ``items``.
    """
    # NOTE(review): this used to call expr_non_alphanum.sub(x, '') with the
    # arguments reversed (Pattern.sub takes repl first, then string). That
    # made the cell text the replacement template for an empty string, so
    # the regex never removed anything and a cell containing a backslash
    # raised re.error. The no-op is kept on purpose: really stripping
    # punctuation would remove the hyphen from positions such as "1-4",
    # which PDFRecordFinder.field_range_expr relies on.
    return [x.strip().replace('\n', ' ') for x in items if x]
class PDFRecordFinder(object):
    """Locate record layout tables in a PDF and yield their field rows.

    The document is opened with ``fitz.open``. Record headers are found by
    searching each page for the text "Record Name:", and four-column tables
    following a header are read as (position, name, length, description).
    """

    # "12" or "12-34": start position with an optional end position.
    field_range_expr = re.compile(r'^(\d+)[-]?(\d*)$')

    def __init__(self, src):
        """Open ``src`` as the document to search."""
        self.document = fitz.open(src)

    def find_record_table_ranges(self):
        """Return ``(record_name, {'page': int, 'y': float})`` per header found."""
        matches = []
        for (page_number, page) in enumerate(self.document):
            header_rects = page.search_for("Record Name:")
            for header_match_rect in header_rects:
                header_match_rect.x0 = header_match_rect.x1  # Start after match of "Record Name: "
                header_match_rect.x1 = page.bound().x1  # Extend to right side of page
                header_text = page.get_textbox(header_match_rect)
                record_name = re.sub(r'[^\w\s\n]*', '', header_text).strip()
                matches.append((record_name, {
                    'page': page_number,
                    'y': header_match_rect.y1 - 5,  # Back up a hair to include header more reliably
                }))
        return matches

    def find_records(self):
        """Yield ``(record_name, rows)`` for every record header in the document.

        ``rows`` holds only the rows whose field positions follow each other
        without gaps (see ``filter_nonconsecutive_rows``).
        """
        record_ranges = self.find_record_table_ranges()
        for record_index, (record_name, record_details) in enumerate(record_ranges):
            current_rows = []
            next_index = record_index + 1
            if next_index < len(record_ranges):
                (_, next_record_details) = record_ranges[next_index]
            else:
                next_record_details = {'page': self.document.page_count - 1}

            # NOTE(review): range() is end-exclusive, so the page on which
            # the next record starts (and, for the final record, the last
            # page of the document) is not scanned. Confirm this is intended.
            for page_number in range(record_details['page'], next_record_details['page']):
                page = self.document[page_number]
                table_search_rect = page.bound()
                if page_number == record_details['page']:
                    # Ignore anything above this record's header.
                    table_search_rect.y0 = record_details['y']
                tables = page.find_tables(
                    clip=table_search_rect,
                    min_words_horizontal=1,
                    min_words_vertical=1,
                    horizontal_strategy="lines_strict",
                    intersection_tolerance=1,
                )
                for table in tables:
                    if table.col_count != 4:
                        continue
                    # A single position cell sometimes holds several
                    # newline-separated values; such rows are split up.
                    for row in table.extract():
                        # Presumably extract() can return None for an empty
                        # cell (TODO confirm); avoid calling .strip() on it.
                        first_column_lines = (row[0] or '').strip().split('\n')
                        if len(first_column_lines) > 1:
                            for sub_row in self.split_row(row):
                                current_rows.append(strip_values(sub_row))
                        else:
                            current_rows.append(strip_values(row))
            consecutive_rows = self.filter_nonconsecutive_rows(current_rows)
            yield (record_name, consecutive_rows)

    def split_row(self, row):
        """Split a row whose first three cells hold several lines each.

        Returns one ``[position, name, length, description]`` list per line;
        all of them share the row's description. Returns ``[]`` when the
        name cell is empty.
        """
        if not row[1]:
            return []
        # Empty cells are treated as '' so a missing length column does not
        # raise; the length is then inferred from the position instead.
        columns = [(cell or '').strip().split('\n') for cell in row[:3]]
        split_rows = itertools.zip_longest(*columns, fillvalue=None)
        # strip_values drops empty cells, so fall back to '' rather than
        # indexing into an empty list.
        description = (strip_values([row[3]]) or [''])[0]
        rows = []
        for sub_row in split_rows:
            if len(sub_row) < 3 or not sub_row[2]:
                sub_row = self.infer_field_length(sub_row)
            rows.append([*sub_row, description])
        return rows

    def infer_field_length(self, row):
        """Return ``row`` with its length derived from the position cell.

        ``row[0]`` is matched against ``field_range_expr``; "A-B" gives
        ``B - A + 1`` and a single position gives ``'1'``. The row is
        returned unchanged when the position is missing or does not match.
        """
        position = row[0]
        # zip_longest pads short columns with None, which cannot be matched.
        matches = PDFRecordFinder.field_range_expr.match(position) if position else None
        if not matches:
            return row
        (start, end) = ([int(x) for x in list(matches.groups()) if x] + [None])[:2]
        length = str(end - start + 1) if end and start else '1'
        return (*row[:2], length)

    def filter_nonconsecutive_rows(self, rows):
        """Keep only rows whose start position directly follows the last kept row."""
        consecutive_rows = []
        last_position = 0
        for row in rows:
            # strip_values removes empty cells, so a row may arrive empty.
            if not row or not row[0]:
                continue
            matches = PDFRecordFinder.field_range_expr.match(row[0])
            if not matches:
                continue
            (start, end) = ([int(x) for x in list(matches.groups()) if x] + [None])[:2]
            if start != last_position + 1:
                continue
            last_position = end if end else start
            consecutive_rows.append(row)
        return consecutive_rows

    def records(self):
        """Return the ``find_records`` generator (kept as the public entry point)."""
        return self.find_records()
#for x in headings:
# print x
for (start, end, name) in headings:
name = name.decode('ascii', 'ignore')
yield (name, list(self.find_fields(iter(self.textrows[start+1:end]))), (start+1, end))
def locate_heading_rows_by_field(self):
results = []
record_break = []
line_is_whitespace_exp = re.compile('^(\s*)$')
record_begin_exp = self.heading_exp #re.compile('Record\ Name')
for (i, row) in enumerate(self.textrows):
match = self.field_heading_exp.match(row)
if match:
# work backwards until we think the header is fully copied
space_count_exp = re.compile('^(\s*)')
position = i - 1
spaces = 0
#last_spaces = 10000
complete = False
header = None
while not complete:
line_is_whitespace = True if line_is_whitespace_exp.match(self.textrows[position]) else False
is_record_begin = record_begin_exp.search(self.textrows[position])
if is_record_begin or line_is_whitespace:
header = self.textrows[position-1:i]
complete = True
position -= 1
name = ''.join(header).strip().decode('ascii','ignore')
print (name, position)
results.append((i, name, position))
else:
# See if this row forces us to break from field reading.
if re.search('Record\ Layout', row):
record_break.append(i)
merged = []
for (a, b) in zip(results, results[1:] + [(len(self.textrows), None)]):
end_pos = None
#print a[0], record_break[0], b[0]-1
while record_break and record_break[0] < a[0]:
record_break = record_break[1:]
if record_break[0] < b[0]-1:
end_pos = record_break[0]
record_break = record_break[1:]
else:
end_pos = b[0]-1
merged.append( (a[0], end_pos-1, a[1]) )
return merged
"""
def locate_heading_rows(self):
results = []
for (i, row) in enumerate(self.textrows):
match = self.heading_exp.match(row)
if match:
results.append((i, ''.join(match.groups())))
merged = []
for (a, b) in zip(results, results[1:] + [(len(self.textrows),None)]):
merged.append( (a[0], b[0]-1, a[1]) )
return merged
def locate_layout_block_rows(self):
# Search for rows that contain "Record Layout", as these are not fields
# we are interested in because they contain the crazy blocks of field definitions
# and not the nice 4-column ones that we're looking for.
results = []
for (i, row) in enumerate(self.textrows):
match = re.match("Record Layout", row)
"""
def find_fields(self, row_iter):
cc = ColumnCollector()
blank_row_counter = 0
for r in row_iter:
row = r.decode('UTF-8')
#print row
row_columns = self.extract_columns_from_row(row)
if not row_columns:
if cc.data and len(cc.data.keys()) > 1 and len(row.strip()) > cc.data.keys()[-1]:
yield cc
cc = ColumnCollector()
else:
cc.empty_row()
continue
try:
cc.add(row_columns)
except IsNextField, e:
yield cc
cc = ColumnCollector()
cc.add(row_columns)
except UnknownColumn, e:
raise StopIteration
yield cc
def extract_columns_from_row(self, row):
re_multiwhite = re.compile(r'\s{2,}')
# IF LINE DOESN'T CONTAIN MULTIPLE WHITESPACES, IT'S LIKELY NOT A TABLE
if not re_multiwhite.search(row):
return None
white_ranges = [0,]
pos = 0
while pos < len(row):
match = re_multiwhite.search(row[pos:])
if match:
white_ranges.append(pos + match.start())
white_ranges.append(pos + match.end())
pos += match.end()
else:
white_ranges.append(len(row))
pos = len(row)
row_result = []
white_iter = iter(white_ranges)
while white_iter:
try:
start = white_iter.next()
end = white_iter.next()
if start != end:
row_result.append(
(start, row[start:end].encode('ascii','ignore'))
)
except StopIteration:
white_iter = None
#print row_result
return row_result
class UnknownColumn(Exception):
pass
class IsNextField(Exception):
pass
class ColumnCollector(object):
    """Accumulate the cells of one multi-line table row, keyed by column start.

    Each line handed to :meth:`add` is an iterable of
    ``(start_position, text)`` pairs. The first line establishes the columns;
    later lines are merged into those columns as continuations. When a line
    looks like the beginning of the next field, ``IsNextField`` is raised so
    the caller can start over with a fresh collector.

    The implementation avoids Python-2-only constructs (tuple-unpacking
    lambdas, indexing/sorting of ``dict.keys()``) so it runs on Python 2.7
    and Python 3 alike.
    """

    def __init__(self, initial=None):
        # ``initial`` is accepted for interface compatibility but is unused.
        self.data = None            # {column start: accumulated text}
        self.column_widths = None   # {column start: right-most end position}
        self.max_data_length = 0    # most cells seen on any single line
        self.last_data_length = 0   # cells on the most recently added line
        self.adjust_pad = 3         # slack used when matching shifted cells
        self.empty_rows = 0         # incremented by empty_row()

    def __repr__(self):
        values = self.data.values() if self.data else ()
        # Truncate long cells so the repr stays readable.
        preview = [v if len(v) < 25 else v[:25] + '..' for v in values]
        return "<%s: %s>" % (self.__class__.__name__, preview)

    def add(self, data):
        """Add one line of ``(start, text)`` cells to the collector.

        Raises:
            IsNextField: the line appears to begin a new field.
            UnknownColumn: a cell could not be matched to a known column.
        """
        if not self.data:
            self.data = dict(data)
        else:
            data = self.adjust_columns(data)
            if self.is_next_field(data):
                raise IsNextField()
            for col_id, value in data:
                self.merge_column(col_id, value)
        # Runs for the first line too: it seeds ``column_widths``, which
        # adjust_columns() and is_next_field() need on later lines.
        self.update_column_widths(data)

    def empty_row(self):
        """Record that an empty row was encountered."""
        self.empty_rows += 1

    def update_column_widths(self, data):
        """Update per-column end positions and the line-length statistics."""
        self.last_data_length = len(data)
        self.max_data_length = max(self.max_data_length, len(data))
        if not self.column_widths:
            # First line: each column ends where its (unstripped) text ends.
            self.column_widths = dict(
                (column, column + len(value)) for column, value in data)
        else:
            for col_id, value in data:
                # Cells for columns that were never registered are ignored.
                if col_id in self.column_widths:
                    self.column_widths[col_id] = max(
                        self.column_widths[col_id],
                        col_id + len(value.strip()))

    def add_old(self, data):
        """Legacy variant of :meth:`add` without column adjustment.

        NOTE(review): this path never calls update_column_widths(), so
        is_next_field() will fail on ``column_widths`` being None unless it
        was populated some other way -- confirm whether this is still used.
        """
        if not self.data:
            self.data = dict(data)
        else:
            if self.is_next_field(data):
                raise IsNextField()
            for col_id, value in data:
                self.merge_column(col_id, value)

    def adjust_columns(self, data):
        """Snap each cell onto a known column start and return the pairs.

        Cells that start exactly on a known column are kept as-is (stripped).
        Other cells are appended to every known column whose span, widened by
        ``adjust_pad`` on both sides, contains the cell's start position.
        Cells matching no column are dropped.

        Returns:
            A list of ``(column start, text)`` pairs.
        """
        adjusted_data = {}
        for col_id, value in data:
            text = value.strip()
            if col_id in self.data:
                adjusted_data[col_id] = text
                continue
            for col_start, col_end in self.column_widths.items():
                lower = col_start - self.adjust_pad
                upper = col_end + self.adjust_pad
                if lower <= col_id <= upper:
                    if col_start in adjusted_data:
                        adjusted_data[col_start] += ' ' + text
                    else:
                        adjusted_data[col_start] = text
        # A real list (not a dict view) so callers can len() and re-iterate.
        return list(adjusted_data.items())

    def merge_column(self, col_id, value):
        """Append ``value`` to the text already collected for ``col_id``.

        Raises:
            UnknownColumn: ``col_id`` is not one of the collected columns.
        """
        if col_id in self.data:
            self.data[col_id] += ' ' + value.strip()
            return
        # FIXME: Sometimes description columns contain column-like layouts,
        # and this confuses the ColumnCollector. Perhaps a cell that occurs
        # after the maximum column could be assumed to belong to that column
        # (i.e. allow some wiggle room) instead of failing here.
        raise UnknownColumn

    def is_next_field(self, data):
        """Guess whether ``data`` begins the next field.

        Returns True when data has already been collected and either:

        * the line has more cells than the previously added line, or
        * the line's first cell starts in the collector's first column and
          either that cell's text is longer than the start position of the
          following column, or the line has as many cells as the widest line
          seen so far.

        Otherwise the line is treated as a continuation and False is returned.
        """
        if self.data and data:
            # Sorted column starts, with a None sentinel after the last one.
            keys = sorted(self.column_widths)
            keys.append(None)

            if self.last_data_length < len(data):
                return True

            first_key, first_value = next(iter(dict(data).items()))
            if next(iter(self.data)) == first_key:
                position = keys.index(first_key)
                max_length = keys[position + 1]
                if max_length:
                    return (len(first_value) > max_length
                            or len(data) == self.max_data_length)

        return False

    @property
    def tuple(self):
        """Collected cell texts ordered by column start (empty if no data)."""
        if self.data:
            return tuple(self.data[k] for k in sorted(self.data))
        return ()

View file

@ -1,11 +1,13 @@
import model from . import model
from fields import * from .fields import *
import enums from . import enums
__all__ = RECORD_TYPES = ['SubmitterRecord', 'EmployerRecord', __all__ = RECORD_TYPES = ['SubmitterRecord', 'EmployerRecord',
'EmployeeWageRecord', 'OptionalEmployeeWageRecord', 'EmployeeWageRecord', 'OptionalEmployeeWageRecord',
'TotalRecord', 'OptionalTotalRecord', 'TotalRecord', 'OptionalTotalRecord',
'StateTotalRecord', 'FinalRecord', 'StateWageRecord'] 'StateTotalRecord', 'FinalRecord', 'StateWageRecord',
'StateTotalRecordIA',
]
class EFW2Record(model.Model): class EFW2Record(model.Model):
record_length = 512 record_length = 512
@ -103,8 +105,8 @@ class EmployerRecord(EFW2Record):
zipcode_ext = TextField(max_length=4, required=False) zipcode_ext = TextField(max_length=4, required=False)
kind_of_employer = TextField(max_length=1) kind_of_employer = TextField(max_length=1)
blank1 = BlankField(max_length=4) blank1 = BlankField(max_length=4)
foreign_state_province = TextField(max_length=23) foreign_state_province = TextField(max_length=23, required=False)
foreign_postal_code = TextField(max_length=15) foreign_postal_code = TextField(max_length=15, required=False)
country_code = TextField(max_length=2, required=False) country_code = TextField(max_length=2, required=False)
employment_code = TextField(max_length=1) employment_code = TextField(max_length=1)
tax_jurisdiction_code = TextField(max_length=1, required=False) tax_jurisdiction_code = TextField(max_length=1, required=False)
@ -148,7 +150,7 @@ class EmployeeWageRecord(EFW2Record):
ssn = IntegerField(max_length=9, required=False) ssn = IntegerField(max_length=9, required=False)
employee_first_name = TextField(max_length=15) employee_first_name = TextField(max_length=15)
employee_middle_name = TextField(max_length=15) employee_middle_name = TextField(max_length=15, required=False)
employee_last_name = TextField(max_length=20) employee_last_name = TextField(max_length=20)
employee_suffix = TextField(max_length=4, required=False) employee_suffix = TextField(max_length=4, required=False)
location_address = TextField(max_length=22) location_address = TextField(max_length=22)
@ -161,7 +163,7 @@ class EmployeeWageRecord(EFW2Record):
blank1 = BlankField(max_length=5) blank1 = BlankField(max_length=5)
foreign_state = TextField(max_length=23, required=False) foreign_state = TextField(max_length=23, required=False)
foreign_postal_code = TextField(max_length=15, required=False) foreign_postal_code = TextField(max_length=15, required=False)
country = TextField(max_length=2) country = TextField(max_length=2, required=True, blank=True)
wages_tips = MoneyField(max_length=11) wages_tips = MoneyField(max_length=11)
federal_income_tax_withheld = MoneyField(max_length=11) federal_income_tax_withheld = MoneyField(max_length=11)
social_security_wages = MoneyField(max_length=11) social_security_wages = MoneyField(max_length=11)
@ -188,7 +190,8 @@ class EmployeeWageRecord(EFW2Record):
designated_roth_contrib_401k = MoneyField(max_length=11, required=False) designated_roth_contrib_401k = MoneyField(max_length=11, required=False)
designated_roth_contrib_403b = MoneyField(max_length=11, required=False) designated_roth_contrib_403b = MoneyField(max_length=11, required=False)
employer_sponsored_health = MoneyField(max_length=11, required=False) employer_sponsored_health = MoneyField(max_length=11, required=False)
blank4 = BlankField(max_length=12) permitted_benefits_health = MoneyField(max_length=11, required=False)
blank4 = BlankField(max_length=1)
statutory_employee_indicator = BooleanField() statutory_employee_indicator = BooleanField()
blank5 = BlankField(max_length=1) blank5 = BlankField(max_length=1)
retirement_plan_indicator = BooleanField() retirement_plan_indicator = BooleanField()
@ -196,8 +199,10 @@ class EmployeeWageRecord(EFW2Record):
blank6 = BlankField(max_length=23) blank6 = BlankField(max_length=23)
def validate_ssn(self, f): def validate_ssn(self, f):
if str(f.value).startswith('666','9'): if str(f.value).startswith('666'):
raise ValidationError("ssn cannot start with 666 or 9", field=f) raise ValidationError("ssn cannot start with 666", field=f)
if str(f.value).startswith('9'):
raise ValidationError("ssn cannot start with 9", field=f)
@ -240,7 +245,7 @@ class StateWageRecord(EFW2Record):
taxing_entity_code = TextField(max_length=5, required=False) taxing_entity_code = TextField(max_length=5, required=False)
ssn = IntegerField(max_length=9, required=False) ssn = IntegerField(max_length=9, required=False)
employee_first_name = TextField(max_length=15) employee_first_name = TextField(max_length=15)
employee_middle_name = TextField(max_length=15) employee_middle_name = TextField(max_length=15, required=False)
employee_last_name = TextField(max_length=20) employee_last_name = TextField(max_length=20)
employee_suffix = TextField(max_length=4, required=False) employee_suffix = TextField(max_length=4, required=False)
location_address = TextField(max_length=22) location_address = TextField(max_length=22)
@ -254,20 +259,20 @@ class StateWageRecord(EFW2Record):
foreign_postal_code = TextField(max_length=15, required=False) foreign_postal_code = TextField(max_length=15, required=False)
country_code = TextField(max_length=2, required=False) country_code = TextField(max_length=2, required=False)
optional_code = TextField(max_length=2, required=False) optional_code = TextField(max_length=2, required=False)
reporting_period = MonthYearField() reporting_period = MonthYearField(required=False)
quarterly_unemp_ins_wages = MoneyField(max_length=11) quarterly_unemp_ins_wages = MoneyField(max_length=11)
quarterly_unemp_ins_taxable_wages = MoneyField(max_length=11) quarterly_unemp_ins_taxable_wages = MoneyField(max_length=11)
number_of_weeks_worked = IntegerField(max_length=2) number_of_weeks_worked = IntegerField(max_length=2, required=False)
date_first_employed = DateField(required=False) date_first_employed = DateField(required=False)
date_of_separation = DateField(required=False) date_of_separation = DateField(required=False)
blank2 = BlankField(max_length=5) blank2 = BlankField(max_length=5)
state_employer_account_num = TextField(max_length=20) state_employer_account_num = IntegerField(max_length=20, required=False)
blank3 = BlankField(max_length=6) blank3 = BlankField(max_length=6)
state_code_2 = StateField(use_numeric=True) state_code_2 = StateField(use_numeric=True)
state_taxable_wages = MoneyField(max_length=11) state_taxable_wages = MoneyField(max_length=11)
state_income_tax_wh = MoneyField(max_length=11) state_income_tax_wh = MoneyField(max_length=11)
other_state_data = TextField(max_length=10, required=False) other_state_data = TextField(max_length=10, required=False)
tax_type_code = TextField(max_length=1) # VALIDATE C, D, E, or F tax_type_code = TextField(max_length=1, required=False) # VALIDATE C, D, E, or F
local_taxable_wages = MoneyField(max_length=11) local_taxable_wages = MoneyField(max_length=11)
local_income_tax_wh = MoneyField(max_length=11) local_income_tax_wh = MoneyField(max_length=11)
state_control_number = IntegerField(max_length=7, required=False) state_control_number = IntegerField(max_length=7, required=False)
@ -277,7 +282,8 @@ class StateWageRecord(EFW2Record):
def validate_tax_type_code(self, field): def validate_tax_type_code(self, field):
choices = [x for x,y in enums.tax_type_codes] choices = [x for x,y in enums.tax_type_codes]
if field.value.upper() not in choices: value = field.value
if value and value.upper() not in choices:
raise ValidationError("%s not one of %s" % (field.value,choices), field=f) raise ValidationError("%s not one of %s" % (field.value,choices), field=f)
@ -313,7 +319,8 @@ class TotalRecord(EFW2Record):
deferred_409a_compensation_plan = MoneyField(max_length=15, required=False) deferred_409a_compensation_plan = MoneyField(max_length=15, required=False)
designated_roth_contribs_401k = MoneyField(max_length=15, required=False) designated_roth_contribs_401k = MoneyField(max_length=15, required=False)
designated_roth_contribs_403b = MoneyField(max_length=15, required=False) designated_roth_contribs_403b = MoneyField(max_length=15, required=False)
blank2 = BlankField(max_length=113) permitted_benefits_health = MoneyField(max_length=15, required=False)
blank2 = BlankField(max_length=98)
class OptionalTotalRecord(EFW2Record): class OptionalTotalRecord(EFW2Record):
@ -352,6 +359,17 @@ class StateTotalRecord(EFW2Record):
supplemental_data = TextField(max_length=510) supplemental_data = TextField(max_length=510)
class StateTotalRecordIA(EFW2Record):
#year=2018
record_identifier = 'RV'
number_of_rs_records = IntegerField(max_length=7) # num records since last 'RE' record
wages_tips = MoneyField(max_length=15)
state_income_tax_wh = MoneyField(max_length=15)
employer_ben = TextField(max_length=8)
iowa_confirmation_number = ZeroField(max_length=10)
blank1 = BlankField(max_length=455)
class FinalRecord(EFW2Record): class FinalRecord(EFW2Record):
#year=2012 #year=2012
record_identifier = 'RF' record_identifier = 'RF'

1
requirements.txt Normal file
View file

@ -0,0 +1 @@
PyMuPDF==1.24.0

76
scripts/pyaccuwage-convert Executable file
View file

@ -0,0 +1,76 @@
#!/usr/bin/env python
import pyaccuwage
import argparse
import os, os.path
import sys
"""
Command line tool for converting IRS e-file fixed field records
to/from JSON or a simple text format.
Attempts to load record types from a python module in the current working
directory named record_types.py
The module must export a RECORD_TYPES list with the names of the classes to
import as valid record types.
"""
def get_record_types():
    """Return a ``{name: class}`` mapping of the record types to use.

    The current working directory is added to ``sys.path`` and a module named
    ``record_types`` is imported from it; every name listed in its
    ``RECORD_TYPES`` is looked up on that module. If the import fails, a
    warning is printed and ``pyaccuwage.get_record_types()`` is used instead.
    """
    sys.path.append(os.getcwd())
    try:
        import record_types
        return dict(
            (type_name, getattr(record_types, type_name))
            for type_name in record_types.RECORD_TYPES
        )
    except ImportError:
        print('warning: using default record types (failed to import record_types.py)')
        return pyaccuwage.get_record_types()
def read_file(fd, filename, record_types):
    """Load records from ``fd`` using the loader chosen by file extension.

    ``.json`` uses ``pyaccuwage.json_load``, ``.txt`` uses
    ``pyaccuwage.text_load`` and anything else uses ``pyaccuwage.load``.
    The loader's result is returned unchanged.
    """
    extension = os.path.splitext(filename)[1]
    if extension == '.json':
        loader = pyaccuwage.json_load
    elif extension == '.txt':
        loader = pyaccuwage.text_load
    else:
        loader = pyaccuwage.load
    return loader(fd, record_types)
def write_file(outfile, filename, records):
    """Write ``records`` to ``outfile`` using the dumper chosen by extension.

    ``.json`` uses ``pyaccuwage.json_dump``, ``.txt`` uses
    ``pyaccuwage.text_dump`` and anything else uses ``pyaccuwage.dump``.
    """
    extension = os.path.splitext(filename)[1]
    if extension == '.json':
        dumper = pyaccuwage.json_dump
    elif extension == '.txt':
        dumper = pyaccuwage.text_dump
    else:
        dumper = pyaccuwage.dump
    dumper(outfile, records)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Convert accuwage efile data between different formats."
    )
    # Both options are mandatory and take exactly one file argument.
    file_option = {'nargs': 1, 'required': True, 'metavar': "file"}
    parser.add_argument(
        "-i", '--input',
        type=argparse.FileType('r'),
        help="Source file to convert",
        **file_option)
    parser.add_argument(
        "-o", "--output",
        type=argparse.FileType('w'),
        help="Destination file to output",
        **file_option)
    args = parser.parse_args()

    # nargs=1 makes argparse store each file in a one-element list.
    in_file = args.input[0]
    out_file = args.output[0]

    # The formats are picked from the file names' extensions.
    records = list(read_file(in_file, in_file.name, get_record_types()))
    write_file(out_file, out_file.name, records)
    print("wrote {} records to {}".format(len(records), out_file.name))

View file

@ -1,4 +1,4 @@
#!/usr/bin/python #!/usr/bin/env python
from pyaccuwage.parser import RecordBuilder from pyaccuwage.parser import RecordBuilder
from pyaccuwage.pdfextract import PDFRecordFinder from pyaccuwage.pdfextract import PDFRecordFinder
import argparse import argparse
@ -29,48 +29,9 @@ doc = PDFRecordFinder(source_file)
records = doc.records() records = doc.records()
builder = RecordBuilder() builder = RecordBuilder()
def record_begins_at(field): for (name, fields) in records:
return int(fields[0].data.values()[0].split('-')[0], 10) name = re.sub(r'^[^a-zA-Z]*','', name.split(':')[-1])
name = re.sub(r'[^\w]*', '', name)
def record_ends_at(fields): sys.stdout.write("\nclass %s(pyaccuwagemodel.Model):\n" % name)
return int(fields[-1].data.values()[0].split('-')[-1], 10) for field in builder.load(map(lambda x: x, fields[0:])):
last_record_begins_at = -1
last_record_ends_at = -1
for rec in records:
#if not rec[1]:
# continue # no actual fields detected
fields = rec[1]
# strip out fields that are not 4 items long
fields = filter(lambda x:len(x.tuple) == 4, fields)
# strip fields that don't begin at position 0
fields = filter(lambda x: 0 in x.data, fields)
# strip fields that don't have a length-range type item in position 0
fields = filter(lambda x: re.match('^\d+[-]?\d*$', x.data[0]), fields)
if not fields:
continue
begins_at = record_begins_at(fields)
ends_at = record_ends_at(fields)
# FIXME record_ends_at is randomly exploding due to record data being
# a lump of text and not necessarily a field entry. I assume
# this is cleaned out by the record builder class.
#print last_record_ends_at + 1, begins_at
if last_record_ends_at + 1 != begins_at:
name = re.sub('^[^a-zA-Z]*','',rec[0].split(':')[-1])
name = re.sub('[^\w]*', '', name)
sys.stdout.write("\nclass %s(pyaccuwagemodel.Model):\n" % name)
for field in builder.load(map(lambda x:x.tuple, rec[1][0:])):
sys.stdout.write('\t' + field + '\n') sys.stdout.write('\t' + field + '\n')
#print field
last_record_ends_at = ends_at

View file

@ -1,12 +1,21 @@
from distutils.core import setup from setuptools import setup
import unittest
def pyaccuwage_tests():
    """Discover and return the suite of ``test_*.py`` modules under ``tests``."""
    return unittest.TestLoader().discover('tests', pattern='test_*.py')
setup(name='pyaccuwage', setup(name='pyaccuwage',
version='0.2012.1', version='0.2025.0',
packages=['pyaccuwage'], packages=['pyaccuwage'],
scripts=[ scripts=[
'scripts/pyaccuwage-checkseq',
'scripts/pyaccuwage-convert',
'scripts/pyaccuwage-genfieldfill',
'scripts/pyaccuwage-parse', 'scripts/pyaccuwage-parse',
'scripts/pyaccuwage-pdfparse', 'scripts/pyaccuwage-pdfparse',
'scripts/pyaccuwage-checkseq',
'scripts/pyaccuwage-genfieldfill'
], ],
zip_safe=True, zip_safe=True,
test_suite='setup.pyaccuwage_tests',
) )

67
tests/test_fields.py Normal file
View file

@ -0,0 +1,67 @@
import unittest
from pyaccuwage.fields import TextField
from pyaccuwage.fields import StaticField
# from pyaccuwage.fields import IntegerField
# from pyaccuwage.fields import StateField
# from pyaccuwage.fields import BlankField
# from pyaccuwage.fields import ZeroField
# from pyaccuwage.fields import MoneyField
from pyaccuwage.fields import ValidationError
from pyaccuwage.model import Model
class TestTextField(unittest.TestCase):
    """Validation and serialization checks for ``TextField``."""

    def testStringShortOptional(self):
        """An optional field validates unset, then pads a short value."""
        field = TextField(max_length=6, required=False)
        field.validate()  # optional
        field.value = 'Hello'
        field.validate()
        self.assertEqual(field.get_data(), b'HELLO ')

    def testStringShortRequired(self):
        """A required field rejects being unset, then pads a short value."""
        field = TextField(max_length=6, required=True)
        with self.assertRaises(ValidationError):
            field.validate()
        field.value = 'Hello'
        field.validate()
        self.assertEqual(field.get_data(), b'HELLO ')

    def testStringLongOptional(self):
        """Output is truncated to ``max_length`` when the value is too long."""
        field = TextField(max_length=6, required=False)
        field.value = 'Hello, World!'  # too long
        data = field.get_data()
        self.assertEqual(len(data), field.max_length)
        self.assertEqual(data, b'HELLO,')

    def testStringUnsetOptional(self):
        """An unset optional field serializes as all spaces."""
        field = TextField(max_length=6, required=False)
        field.validate()
        self.assertEqual(field.get_data(), b' ' * 6)

    def testStringRequiredUnassigned(self):
        """A field left at the default ``required`` setting rejects no value."""
        field = TextField(max_length=6)
        with self.assertRaises(ValidationError):
            field.validate()

    def testStringRequiredNonBlank(self):
        """Without ``blank=True`` an empty string is rejected."""
        field = TextField(max_length=6)
        field.value = ''
        with self.assertRaises(ValidationError):
            field.validate()

    def testStringRequiredBlank(self):
        """With ``blank=True`` an empty string validates and is padded."""
        field = TextField(max_length=6, blank=True)
        field.value = ''
        field.validate()
        self.assertEqual(len(field.get_data()), 6)

    def testStringMinimumLength(self):
        """Values shorter than ``min_length`` fail; an exact-length value passes."""
        field = TextField(max_length=6, min_length=6, blank=True)  # blank has no effect
        field.value = ''  # empty: shorter than min_length
        with self.assertRaises(ValidationError):
            field.validate()
        field.value = '12345'  # one character too short
        with self.assertRaises(ValidationError):
            field.validate()
        field.value = '123456'  # exactly min_length
        # Previously this value was assigned but never checked.
        field.validate()
class TestStaticField(unittest.TestCase):
    """Serialization check for ``StaticField``."""

    def test_static_field(self):
        """The constructor's value is emitted unchanged as bytes."""
        static = StaticField(value='TEST')
        rendered = static.get_data()
        self.assertEqual(rendered, b'TEST')

179
tests/test_records.py Normal file
View file

@ -0,0 +1,179 @@
import unittest
import decimal
import pyaccuwage
from pyaccuwage.fields import BlankField
from pyaccuwage.fields import IntegerField
from pyaccuwage.fields import MoneyField
from pyaccuwage.fields import StateField
from pyaccuwage.fields import TextField
from pyaccuwage.fields import ZeroField
from pyaccuwage.fields import StaticField
from pyaccuwage.fields import ValidationError
from pyaccuwage.model import Model
class TestModelOutput(unittest.TestCase):
    """Check the binary and text renderings of a small fixed-width model."""

    class TestModel(Model):
        # Identifier (4) + 16 + 16 + 16 + 16 + 32 + 2 + 2 + 12 + 12 = 128.
        record_length = 128
        record_identifier = 'TEST' # 4 bytes
        field1 = TextField(max_length=16)
        field2 = IntegerField(max_length=16)
        blank1 = BlankField(max_length=16)
        zero1 = ZeroField(max_length=16)
        money = MoneyField(max_length=32)
        state_txt = StateField()
        state_num = StateField(use_numeric=True)
        blank2 = BlankField(max_length=12)
        static1 = StaticField(value='hey mister!!')

    def setUp(self):
        self.model = TestModelOutput.TestModel()

    def _populated_model(self):
        """Assign the values both tests use and return the fixture model."""
        record = self.model
        record.field1.value = 'Hello, sir!'
        record.field2.value = 12345
        record.money.value = decimal.Decimal('3133.77')
        record.state_txt.value = 'IA'
        record.state_num.value = 'IA'
        return record

    def testModelBinaryOutput(self):
        """The default output is the exact fixed-width byte record."""
        record = self._populated_model()
        expected_parts = [
            b'TEST',
            b'HELLO, SIR!'.ljust(16),
            b'12345'.zfill(16),
            b' ' * 16,
            b'0' * 16,
            b'313377'.zfill(32),
            b'IA',
            b'19',
            b' ' * 12,
            b'hey mister!!',
        ]
        expected = b''.join(expected_parts)
        rendered = record.output()
        self.assertEqual(len(rendered), TestModelOutput.TestModel.record_length)
        self.assertEqual(rendered, expected)

    def testModelTextOutput(self):
        """``format='text'`` lists the named, non-filler fields line by line."""
        record = self._populated_model()
        rendered = record.output(format='text')
        self.assertEqual(rendered, '''---TestModel
field1: Hello, sir!
field2: 12345
money: 3133.77
state_txt: IA
state_num: IA
static1: hey mister!!
''')
class TestFileFormats(unittest.TestCase):
    """Round-trip records through the JSON and text interchange formats."""

    class TestModelA(pyaccuwage.model.Model):
        record_length = 128
        record_identifier = 'A' # 1 byte
        field1 = TextField(max_length=16)
        field2 = IntegerField(max_length=16)
        blank1 = BlankField(max_length=16)
        zero1 = ZeroField(max_length=16)
        money = MoneyField(max_length=32)
        state_txt = StateField()
        state_num = StateField(use_numeric=True)
        blank2 = BlankField(max_length=27)

    class TestModelB(pyaccuwage.model.Model):
        record_length = 128
        record_identifier = 'B' # 1 byte
        zero1 = ZeroField(max_length=32)
        text1 = TextField(max_length=71)
        text2 = TextField(max_length=20, required=False)
        blank2 = BlankField(max_length=4)

    record_types = [TestModelA, TestModelB]

    def createExampleRecords(self):
        """Build one populated record of each model type."""
        first = TestFileFormats.TestModelA()
        first.field1.value = 'I am model a'
        first.field2.value = 5522
        first.money.value = decimal.Decimal('23.00')
        first.state_txt.value = 'IA'
        first.state_num.value = 'IA'

        second = TestFileFormats.TestModelB()
        second.text1.value = 'hey I am model b and I have a big text field'

        return [first, second]

    def _assert_round_trip(self, serialize, deserialize):
        """Serialize, reload, and compare the fixed-width dumps of both sets."""
        originals = self.createExampleRecords()
        encoded = serialize(originals)
        restored = deserialize(encoded, self.record_types)
        self.assertEqual(
            pyaccuwage.dumps(originals),
            pyaccuwage.dumps(restored),
        )

    def testJSONSerialization(self):
        """Records survive a JSON dump/load cycle unchanged."""
        self._assert_round_trip(pyaccuwage.json_dumps, pyaccuwage.json_loads)

    def testTxtSerialization(self):
        """Records survive a text dump/load cycle unchanged."""
        self._assert_round_trip(pyaccuwage.text_dumps, pyaccuwage.text_loads)
class TestRequiredFields(unittest.TestCase):
    """Exercise the four ``required`` / ``blank`` combinations of ``TextField``."""

    def createTestRecord(self, required=False, blank=False):
        """Return ``(record, dump)`` for a one-field model.

        ``dump`` is a zero-argument callable that serializes the record with
        ``pyaccuwage.dumps``.
        """
        class Record(pyaccuwage.model.Model):
            record_length = 16
            record_identifier = ''
            test_field = TextField(max_length=16, required=required, blank=blank)

        instance = Record()

        def serialize():
            return pyaccuwage.dumps([instance])

        return (instance, serialize)

    def testRequiredBlankField(self):
        """Required + blank: unset fails, empty string is accepted."""
        record, dump = self.createTestRecord(required=True, blank=True)
        record.test_field.value # if nothing is ever assigned, raise error
        with self.assertRaises(ValidationError):
            dump()
        record.test_field.value = '' # value may be empty string
        dump()

    def testRequiredNonblankField(self):
        """Required + non-blank: unset and empty both fail, text passes."""
        record, dump = self.createTestRecord(required=True, blank=False)
        record.test_field.value # if nothing is ever assigned, raise error
        with self.assertRaises(ValidationError):
            dump()
        record.test_field.value = '' # value must not be empty string
        with self.assertRaises(ValidationError):
            dump()
        record.test_field.value = 'hello'
        dump()

    def testOptionalBlankField(self):
        """Optional + blank: unset, empty and text all serialize."""
        record, dump = self.createTestRecord(required=False, blank=True)
        record.test_field.value # OK if nothing is ever assigned
        dump()
        record.test_field.value = '' # OK if empty string is assigned
        dump()
        record.test_field.value = 'hello'
        dump()

    def testOptionalNonBlankField(self):
        """Optional + non-blank: unset, empty and text all serialize."""
        record, dump = self.createTestRecord(required=False, blank=False)
        record.test_field.value # OK if nothing is ever assigned
        dump()
        record.test_field.value = '' # OK if empty string is assigned
        dump()
        record.test_field.value = 'hello'
        dump()