Compare commits

..

16 commits

Author SHA1 Message Date
9029659f98 update internal VERSION property 2025-05-13 12:45:51 -05:00
1302de9df7 bump version to 0.2025.0 2025-05-13 12:14:02 -05:00
fb8091fb09 change Iowa RS record state_employer_account_num from TextField to IntegerField 2025-05-13 12:09:49 -05:00
4408da71a9 mark some fields as optional 2024-04-10 09:41:10 -04:00
e0e4c1291d add min_length option to TextField for SSNs and stuff like that 2024-03-31 11:52:22 -04:00
5f4dc8b80f add 'blank' field option to allow empty text in required fields (default: false) 2024-03-31 11:14:16 -04:00
74b7935ced bump version to 2024 2024-03-29 10:50:25 -04:00
66573e4d1d update for 2023 p1220 parsing, stupid irs 2024-03-29 10:48:04 -04:00
86f8861da1 encode record delimiter as ascii bytes when str is passed 2022-02-06 11:06:51 -06:00
042de7ecb0 import typing.Callable (python 3.10+) 2021-12-18 08:56:43 -05:00
f28cd6edf2 bump version 0.2020.0 2021-09-03 07:48:24 -05:00
0bd82e09c4 Fix StaticField + tests for StaticField and unset optional TextField 2021-09-03 05:45:01 -05:00
558e3fd232 hopefully fix STaticField 2021-09-02 17:40:35 -05:00
7867a52a0c fliped args around like a simpleton 2021-01-29 16:26:26 -05:00
bfd43b7448 release 0.2018.2 2020-06-12 14:45:08 -05:00
1f1d3dd9bb Merge branch 'conversion-support' 2020-06-12 13:13:28 -05:00
12 changed files with 251 additions and 383 deletions

View file

@ -1,6 +1,9 @@
from collections import Callable try:
from collections import Callable
except:
from typing import Callable # Python 3.10+
VERSION = (0, 2012, 0) VERSION = (0, 2025, 0)
RECORD_TYPES = [ RECORD_TYPES = [
'SubmitterRecord', 'SubmitterRecord',
@ -55,16 +58,21 @@ def loads(s, record_types=get_record_types()):
def dump(fp, records, delim=None): def dump(fp, records, delim=None):
if type(delim) is str:
delim = delim.encode('ascii')
for r in records: for r in records:
fp.write(r.output()) fp.write(r.output())
if delim: if delim:
fp.write(delim) fp.write(delim)
def dumps(records, delim=None): def dumps(records, delim=None, skip_validation=False):
import io import io
fp = io.BytesIO() fp = io.BytesIO()
dump(records, fp, delim=delim) if not skip_validation:
for record in records:
record.validate()
dump(fp, records, delim=delim)
fp.seek(0) fp.seek(0)
return fp.read() return fp.read()

View file

@ -323,6 +323,7 @@ employment_codes = (
) )
tax_jurisdiction_codes = ( tax_jurisdiction_codes = (
(' ', 'W-2'),
('V', 'Virgin Islands'), ('V', 'Virgin Islands'),
('G', 'Guam'), ('G', 'Guam'),
('S', 'American Samoa'), ('S', 'American Samoa'),

View file

@ -21,12 +21,15 @@ class ValidationError(Exception):
class Field(object): class Field(object):
creation_counter = 0 creation_counter = 0
is_read_only = False is_read_only = False
_value = None
def __init__(self, name=None, max_length=0, required=True, uppercase=True, creation_counter=None): def __init__(self, name=None, min_length=0, max_length=0, blank=False, required=True, uppercase=True, creation_counter=None):
self.name = name self.name = name
self._value = None self._value = None
self._orig_value = None self._orig_value = None
self.min_length = min_length
self.max_length = max_length self.max_length = max_length
self.blank = blank
self.required = required self.required = required
self.uppercase = uppercase self.uppercase = uppercase
self.creation_counter = creation_counter or Field.creation_counter self.creation_counter = creation_counter or Field.creation_counter
@ -96,9 +99,9 @@ class Field(object):
wrapper.width = 100 wrapper.width = 100
value = wrapper.wrap(value) value = wrapper.wrap(value)
value = list([(" " * 9) + ('"' + x + '"') for x in value]) value = list([(" " * 9) + ('"' + x + '"') for x in value])
value.append(" " * 10 + ('_' * 10) * (wrapper.width / 10)) value.append(" " * 10 + ('_' * 10) * int(wrapper.width / 10))
value.append(" " * 10 + ('0123456789') * (wrapper.width / 10)) value.append(" " * 10 + ('0123456789') * int(wrapper.width / 10))
value.append(" " * 10 + ''.join(([str(x) + (' ' * 9) for x in range(wrapper.width / 10 )]))) value.append(" " * 10 + ''.join(([str(x) + (' ' * 9) for x in range(int(wrapper.width / 10))])))
start = counter['c'] start = counter['c']
counter['c'] += len(self._orig_value or self.value) counter['c'] += len(self._orig_value or self.value)
@ -118,11 +121,17 @@ class TextField(Field):
def validate(self): def validate(self):
if self.value is None and self.required: if self.value is None and self.required:
raise ValidationError("value required", field=self) raise ValidationError("value required", field=self)
if len(self.get_data()) > self.max_length: data = self.get_data()
if len(data) > self.max_length:
raise ValidationError("value is too long", field=self) raise ValidationError("value is too long", field=self)
stripped_data_length = len(data.strip())
if stripped_data_length < self.min_length:
raise ValidationError("value is too short", field=self)
if stripped_data_length == 0 and (not self.blank and self.required):
raise ValidationError("field cannot be blank", field=self)
def get_data(self): def get_data(self):
value = str(self.value).encode('ascii') or b'' value = str(self.value or '').encode('ascii') or b''
if self.uppercase: if self.uppercase:
value = value.upper() value = value.upper()
return value.ljust(self.max_length)[:self.max_length] return value.ljust(self.max_length)[:self.max_length]
@ -143,7 +152,7 @@ class TextField(Field):
class StateField(TextField): class StateField(TextField):
def __init__(self, name=None, required=True, use_numeric=False, max_length=2): def __init__(self, name=None, required=True, use_numeric=False, max_length=2):
super(StateField, self).__init__(name=name, max_length=2, required=required) super(StateField, self).__init__(name=name, max_length=max_length, required=required)
self.use_numeric = use_numeric self.use_numeric = use_numeric
def get_data(self): def get_data(self):
@ -195,15 +204,17 @@ class IntegerField(TextField):
class StaticField(TextField): class StaticField(TextField):
def __init__(self, name=None, required=True, value=None): def __init__(self, name=None, required=True, value=None, uppercase=False):
super(StaticField, self).__init__(name=name, required=required, super(StaticField, self).__init__(name=name,
max_length=len(value)) required=required,
max_length=len(value),
uppercase=uppercase)
self._static_value = value
self._value = value self._value = value
def parse(self, s): def parse(self, s):
pass pass
class BlankField(TextField): class BlankField(TextField):
is_read_only = True is_read_only = True
@ -216,6 +227,10 @@ class BlankField(TextField):
def parse(self, s): def parse(self, s):
pass pass
def validate(self):
if len(self.get_data()) != self.max_length:
raise ValidationError("blank field did not match expected length", field=self)
class ZeroField(BlankField): class ZeroField(BlankField):
is_read_only = True is_read_only = True

View file

@ -25,7 +25,7 @@ class Model(object):
setattr(src_field, 'parent_name', self.__class__.__name__) setattr(src_field, 'parent_name', self.__class__.__name__)
new_field_instance = copy.copy(src_field) new_field_instance = copy.copy(src_field)
new_field_instance._orig_value = None new_field_instance._orig_value = None
new_field_instance._value = None new_field_instance._value = new_field_instance.value
self.__dict__[key] = new_field_instance self.__dict__[key] = new_field_instance
def __setattr__(self, key, value): def __setattr__(self, key, value):
@ -36,11 +36,14 @@ class Model(object):
self.__dict__[key] = value self.__dict__[key] = value
def set_field_value(self, field_name, value): def set_field_value(self, field_name, value):
print('setfieldval: ' + field_name + ' ' + value)
getattr(self, field_name).value = value getattr(self, field_name).value = value
def get_fields(self): def get_fields(self):
identifier = TextField("record_identifier", max_length=len(self.record_identifier), creation_counter=-1) identifier = TextField(
"record_identifier",
max_length = len(self.record_identifier),
blank = len(self.record_identifier) == 0,
creation_counter=-1)
identifier.value = self.record_identifier identifier.value = self.record_identifier
fields = [identifier] fields = [identifier]
@ -87,7 +90,6 @@ class Model(object):
# Skip the first record, since that's an identifier # Skip the first record, since that's an identifier
for field in self.get_sorted_fields()[1:]: for field in self.get_sorted_fields()[1:]:
field.read(fp) field.read(fp)
print(field.name, '"' + (str(field.value) or '') + '"', field.max_length, field._orig_value)
def toJSON(self): def toJSON(self):
return { return {

View file

@ -3,313 +3,102 @@
import subprocess import subprocess
import re import re
import pdb import itertools
import fitz
""" pdftotext -layout -nopgbrk p1220.pdf - """ """ pdftotext -layout -nopgbrk p1220.pdf - """
def strip_values(items):
    """Normalize the text of table cells extracted from the PDF.

    Falsy cells (``None`` or ``''``) are dropped, so the returned list may be
    shorter than ``items``. Every remaining cell has surrounding whitespace
    trimmed and embedded newlines replaced by single spaces.

    Punctuation is deliberately preserved: the position column (e.g. ``1-5``)
    is later matched against ``PDFRecordFinder.field_range_expr``, which needs
    the hyphen.

    The previous version called ``pattern.sub(x, '')`` with the replacement
    and subject swapped. That used each cell as a regex replacement template
    on an empty string: a no-op for ordinary text, but an ``re.error`` (or
    silently altered text) for any cell containing a backslash.

    :param items: iterable of cell strings (entries may be ``None``).
    :returns: list of cleaned, non-empty-input cell strings.
    """
    return [x.strip().replace('\n', ' ') for x in items if x]
class PDFRecordFinder(object): class PDFRecordFinder(object):
def __init__(self, src, heading_exp=None): field_range_expr = re.compile(r'^(\d+)[-]?(\d*)$')
if not heading_exp:
heading_exp = re.compile('(\s+Record Name: (.*))|Record\ Layout')
field_heading_exp = re.compile('^Field.*Field.*Length.*Description') def __init__(self, src):
self.document = fitz.open(src)
opts = ["pdftotext", "-layout", "-nopgbrk", "-eol", "unix", src, '-'] def find_record_table_ranges(self):
pdftext = subprocess.check_output(opts) matches = []
self.textrows = pdftext.split('\n') for (page_number, page) in enumerate(self.document):
self.heading_exp = heading_exp header_rects = page.search_for("Record Name:")
self.field_heading_exp = field_heading_exp for header_match_rect in header_rects:
header_match_rect.x0 = header_match_rect.x1 # Start after match of "Record Name: "
header_match_rect.x1 = page.bound().x1 # Extend to right side of page
header_text = page.get_textbox(header_match_rect)
record_name = re.sub(r'[^\w\s\n]*', '', header_text).strip()
matches.append((record_name, {
'page': page_number,
'y': header_match_rect.y1 - 5, # Back up a hair to include header more reliably
}))
return matches
# Generator over the records located by find_record_table_ranges().
# For each (record_name, details) entry it gathers rows from every 4-column
# table found on the record's pages, cleans them with strip_values(), keeps
# only the rows accepted by filter_nonconsecutive_rows(), and yields
# (record_name, rows).
def find_records(self):
record_ranges = self.find_record_table_ranges()
for record_index, (record_name, record_details) in enumerate(record_ranges):
current_rows = []
next_index = record_index+1
# Details of the following record; for the final record a stand-in whose
# 'page' is the index of the document's last page.
(_, next_record_details) = record_ranges[next_index] if next_index < len(record_ranges) else (None, {'page': self.document.page_count-1})
# NOTE(review): the range end is exclusive, so the page on which the next
# record starts (and, for the final record, the last page of the document)
# is never scanned; a record whose successor starts on the same page
# produces no rows. Confirm this is intended.
for page_number in range(record_details['page'], next_record_details['page']):
page = self.document[page_number]
table_search_rect = page.bound()
# On the page holding the record header, start the search at the stored
# 'y' offset so content above the header is excluded.
if page_number == record_details['page']:
table_search_rect.y0 = record_details['y']
tables = page.find_tables(
clip = table_search_rect,
min_words_horizontal = 1,
min_words_vertical = 1,
horizontal_strategy = "lines_strict",
intersection_tolerance = 1,
)
for table in tables:
# Only tables with exactly four columns are treated as field tables.
if table.col_count == 4:
table = table.extract()
# Parse the field position column. A single cell sometimes holds several
# newline-separated values; such rows are expanded by split_row().
for row in table:
# NOTE(review): assumes row[0] is always a string; presumably extract()
# can return None for an empty cell, which would raise here - verify.
first_column_lines = row[0].strip().split('\n')
if len(first_column_lines) > 1:
for sub_row in self.split_row(row):
current_rows.append(strip_values(sub_row))
else:
current_rows.append(strip_values(row))
# Drop rows whose position does not directly follow the previous kept row.
consecutive_rows = self.filter_nonconsecutive_rows(current_rows)
yield(record_name, consecutive_rows)
def split_row(self, row):
    """Expand a table row whose cells hold several newline-separated entries.

    The first three cells are split on newlines and zipped together (shorter
    columns are padded with ``None``); the fourth cell is treated as one
    description shared by every resulting row. When an expanded row has no
    third value, ``infer_field_length`` is asked to derive it from the first.

    Compared with the previous version this tolerates ``None`` cells in the
    first three columns and an empty/``None`` description (both used to
    raise), and no longer rebinds the ``row`` parameter inside the loop.

    :param row: sequence of cell strings (entries may be ``None``);
        presumably position, title, length, description - verify against
        the PDF layout.
    :returns: list of ``[first, second, third, description]`` lists; empty
        when the second cell is missing.
    """
    if not row[1]:
        return []

    columns = [(cell or '').strip().split('\n') for cell in row[:3]]

    # strip_values() drops falsy cells, so an empty description yields [].
    cleaned_description = strip_values([row[3]])
    description = cleaned_description[0] if cleaned_description else ''

    rows = []
    for parts in itertools.zip_longest(*columns, fillvalue=None):
        if len(parts) < 3 or not parts[2]:
            parts = self.infer_field_length(parts)
        rows.append([*parts, description])
    return rows
def infer_field_length(self, row):
    """Derive a missing field length from the position range in ``row[0]``.

    ``row[0]`` is matched against ``field_range_expr`` (``start`` or
    ``start-end``). A single position has length 1; a range has length
    ``end - start + 1``.

    Fixes over the previous version: a start position of ``0`` is no longer
    mistaken for "no range" (the old ``end and start`` truthiness test
    returned ``'1'`` for ``0-4``), and a ``None``/empty position, which
    ``split_row`` can produce through ``zip_longest`` padding, is returned
    untouched instead of raising ``TypeError``.

    :param row: sequence whose first item is the position text.
    :returns: ``(row[0], row[1], length_as_str)`` when the position parses,
        otherwise ``row`` unchanged.
    """
    position = row[0]
    # Looked up through self so subclasses may override the pattern.
    matches = self.field_range_expr.match(position) if position else None
    if not matches:
        return row

    start_text, end_text = matches.groups()
    start = int(start_text)
    end = int(end_text) if end_text else start
    return (*row[:2], str(end - start + 1))
def filter_nonconsecutive_rows(self, rows):
    """Keep only rows whose positions continue an unbroken run starting at 1.

    Each row's first cell is matched against ``field_range_expr``. A row is
    kept when its start position equals the previously kept row's end
    position plus one (the first kept row must start at 1). Rows that do not
    parse, or that break the sequence, are discarded.

    Unlike the previous version, empty rows and rows with an empty first
    cell are skipped instead of raising ``IndexError``/``TypeError``;
    ``strip_values`` drops falsy cells, so such rows can occur.

    :param rows: iterable of row sequences (position text first).
    :returns: list of the rows that form the consecutive run, in order.
    """
    consecutive_rows = []
    last_position = 0

    for row in rows:
        position = row[0] if row else None
        # Looked up through self so subclasses may override the pattern.
        matches = self.field_range_expr.match(position) if position else None
        if not matches:
            continue

        start = int(matches.group(1))
        end = int(matches.group(2) or 0)

        if start != last_position + 1:
            continue

        # A missing (or zero) end means a single-position field.
        last_position = end or start
        consecutive_rows.append(row)

    return consecutive_rows
def records(self): def records(self):
headings = self.locate_heading_rows_by_field() return self.find_records()
#for x in headings:
# print x
for (start, end, name) in headings:
name = name.decode('ascii', 'ignore')
yield (name, list(self.find_fields(iter(self.textrows[start+1:end]))), (start+1, end))
def locate_heading_rows_by_field(self):
results = []
record_break = []
line_is_whitespace_exp = re.compile('^(\s*)$')
record_begin_exp = self.heading_exp #re.compile('Record\ Name')
for (i, row) in enumerate(self.textrows):
match = self.field_heading_exp.match(row)
if match:
# work backwards until we think the header is fully copied
space_count_exp = re.compile('^(\s*)')
position = i - 1
spaces = 0
#last_spaces = 10000
complete = False
header = None
while not complete:
line_is_whitespace = True if line_is_whitespace_exp.match(self.textrows[position]) else False
is_record_begin = record_begin_exp.search(self.textrows[position])
if is_record_begin or line_is_whitespace:
header = self.textrows[position-1:i]
complete = True
position -= 1
name = ''.join(header).strip().decode('ascii','ignore')
print((name, position))
results.append((i, name, position))
else:
# See if this row forces us to break from field reading.
if re.search('Record\ Layout', row):
record_break.append(i)
merged = []
for (a, b) in zip(results, results[1:] + [(len(self.textrows), None)]):
end_pos = None
#print a[0], record_break[0], b[0]-1
while record_break and record_break[0] < a[0]:
record_break = record_break[1:]
if record_break[0] < b[0]-1:
end_pos = record_break[0]
record_break = record_break[1:]
else:
end_pos = b[0]-1
merged.append( (a[0], end_pos-1, a[1]) )
return merged
"""
def locate_heading_rows(self):
results = []
for (i, row) in enumerate(self.textrows):
match = self.heading_exp.match(row)
if match:
results.append((i, ''.join(match.groups())))
merged = []
for (a, b) in zip(results, results[1:] + [(len(self.textrows),None)]):
merged.append( (a[0], b[0]-1, a[1]) )
return merged
def locate_layout_block_rows(self):
# Search for rows that contain "Record Layout", as these are not fields
# we are interested in because they contain the crazy blocks of field definitions
# and not the nice 4-column ones that we're looking for.
results = []
for (i, row) in enumerate(self.textrows):
match = re.match("Record Layout", row)
"""
def find_fields(self, row_iter):
cc = ColumnCollector()
blank_row_counter = 0
for r in row_iter:
row = r.decode('UTF-8')
#print row
row_columns = self.extract_columns_from_row(row)
if not row_columns:
if cc.data and len(list(cc.data.keys())) > 1 and len(row.strip()) > list(cc.data.keys())[-1]:
yield cc
cc = ColumnCollector()
else:
cc.empty_row()
continue
try:
cc.add(row_columns)
except IsNextField as e:
yield cc
cc = ColumnCollector()
cc.add(row_columns)
except UnknownColumn as e:
raise StopIteration
yield cc
def extract_columns_from_row(self, row):
re_multiwhite = re.compile(r'\s{2,}')
# IF LINE DOESN'T CONTAIN MULTIPLE WHITESPACES, IT'S LIKELY NOT A TABLE
if not re_multiwhite.search(row):
return None
white_ranges = [0,]
pos = 0
while pos < len(row):
match = re_multiwhite.search(row[pos:])
if match:
white_ranges.append(pos + match.start())
white_ranges.append(pos + match.end())
pos += match.end()
else:
white_ranges.append(len(row))
pos = len(row)
row_result = []
white_iter = iter(white_ranges)
while white_iter:
try:
start = next(white_iter)
end = next(white_iter)
if start != end:
row_result.append(
(start, row[start:end].encode('ascii','ignore'))
)
except StopIteration:
white_iter = None
#print row_result
return row_result
class UnknownColumn(Exception):
pass
class IsNextField(Exception):
pass
class ColumnCollector(object):
def __init__(self, initial=None):
self.data = None
self.column_widths = None
self.max_data_length = 0
self.adjust_pad = 3
self.empty_rows = 0
pass
def __repr__(self):
return "<%s: %s>" % (
self.__class__.__name__,
[x if len(x) < 25 else x[:25] + '..' for x in list(self.data.values()) if self.data else ''])
def add(self, data):
#if self.empty_rows > 2:
# raise IsNextField()
if not self.data:
self.data = dict(data)
else:
data = self.adjust_columns(data)
if self.is_next_field(data):
raise IsNextField()
for col_id, value in data:
self.merge_column(col_id, value)
self.update_column_widths(data)
def empty_row(self):
self.empty_rows += 1
def update_column_widths(self, data):
self.last_data_length = len(data)
self.max_data_length = max(self.max_data_length, len(data))
if not self.column_widths:
self.column_widths = dict([[column_value[0], column_value[0] + len(column_value[1])] for column_value in data])
else:
for col_id, value in data:
try:
self.column_widths[col_id] = max(self.column_widths[col_id], col_id + len(value.strip()))
except KeyError:
pass
def add_old(self, data):
if not self.data:
self.data = dict(data)
else:
if self.is_next_field(data):
raise IsNextField()
for col_id, value in data:
self.merge_column(col_id, value)
def adjust_columns(self, data):
adjusted_data = {}
for col_id, value in data:
if col_id in list(self.data.keys()):
adjusted_data[col_id] = value.strip()
else:
for col_start, col_end in list(self.column_widths.items()):
if (col_start - self.adjust_pad) <= col_id and (col_end + self.adjust_pad) >= col_id:
if col_start in adjusted_data:
adjusted_data[col_start] += ' ' + value.strip()
else:
adjusted_data[col_start] = value.strip()
return list(adjusted_data.items())
def merge_column(self, col_id, value):
if col_id in list(self.data.keys()):
self.data[col_id] += ' ' + value.strip()
else:
# try adding a wiggle room value?
# FIXME:
# Sometimes description columns contain column-like
# layouts, and this causes the ColumnCollector to become
# confused. Perhaps we could check to see if a column occurs
# after the maximum column, and assume it's part of the
# max column?
"""
for col_start, col_end in self.column_widths.items():
if col_start <= col_id and (col_end) >= col_id:
self.data[col_start] += ' ' + value.strip()
return
"""
raise UnknownColumn
def is_next_field(self, data):
"""
If the first key value contains a string
and we already have some data in the record,
then this row is probably the beginning of
the next field. Raise an exception and continue
on with a fresh ColumnCollector.
"""
""" If the length of the value in column_id is less than the position of the next column_id,
then this is probably a continuation.
"""
if self.data and data:
keys = list(dict(self.column_widths).keys())
keys.sort()
keys += [None]
if self.last_data_length < len(data):
return True
first_key, first_value = list(dict(data).items())[0]
if list(self.data.keys())[0] == first_key:
position = keys.index(first_key)
max_length = keys[position + 1]
if max_length:
return len(first_value) > max_length or len(data) == self.max_data_length
return False
@property
def tuple(self):
#try:
if self.data:
return tuple([self.data[k] for k in sorted(self.data.keys())])
return ()
#except:
# import pdb
# pdb.set_trace()

View file

@ -105,8 +105,8 @@ class EmployerRecord(EFW2Record):
zipcode_ext = TextField(max_length=4, required=False) zipcode_ext = TextField(max_length=4, required=False)
kind_of_employer = TextField(max_length=1) kind_of_employer = TextField(max_length=1)
blank1 = BlankField(max_length=4) blank1 = BlankField(max_length=4)
foreign_state_province = TextField(max_length=23) foreign_state_province = TextField(max_length=23, required=False)
foreign_postal_code = TextField(max_length=15) foreign_postal_code = TextField(max_length=15, required=False)
country_code = TextField(max_length=2, required=False) country_code = TextField(max_length=2, required=False)
employment_code = TextField(max_length=1) employment_code = TextField(max_length=1)
tax_jurisdiction_code = TextField(max_length=1, required=False) tax_jurisdiction_code = TextField(max_length=1, required=False)
@ -150,7 +150,7 @@ class EmployeeWageRecord(EFW2Record):
ssn = IntegerField(max_length=9, required=False) ssn = IntegerField(max_length=9, required=False)
employee_first_name = TextField(max_length=15) employee_first_name = TextField(max_length=15)
employee_middle_name = TextField(max_length=15) employee_middle_name = TextField(max_length=15, required=False)
employee_last_name = TextField(max_length=20) employee_last_name = TextField(max_length=20)
employee_suffix = TextField(max_length=4, required=False) employee_suffix = TextField(max_length=4, required=False)
location_address = TextField(max_length=22) location_address = TextField(max_length=22)
@ -163,7 +163,7 @@ class EmployeeWageRecord(EFW2Record):
blank1 = BlankField(max_length=5) blank1 = BlankField(max_length=5)
foreign_state = TextField(max_length=23, required=False) foreign_state = TextField(max_length=23, required=False)
foreign_postal_code = TextField(max_length=15, required=False) foreign_postal_code = TextField(max_length=15, required=False)
country = TextField(max_length=2) country = TextField(max_length=2, required=True, blank=True)
wages_tips = MoneyField(max_length=11) wages_tips = MoneyField(max_length=11)
federal_income_tax_withheld = MoneyField(max_length=11) federal_income_tax_withheld = MoneyField(max_length=11)
social_security_wages = MoneyField(max_length=11) social_security_wages = MoneyField(max_length=11)
@ -199,8 +199,10 @@ class EmployeeWageRecord(EFW2Record):
blank6 = BlankField(max_length=23) blank6 = BlankField(max_length=23)
def validate_ssn(self, f): def validate_ssn(self, f):
if str(f.value).startswith('666','9'): if str(f.value).startswith('666'):
raise ValidationError("ssn cannot start with 666 or 9", field=f) raise ValidationError("ssn cannot start with 666", field=f)
if str(f.value).startswith('9'):
raise ValidationError("ssn cannot start with 9", field=f)
@ -243,7 +245,7 @@ class StateWageRecord(EFW2Record):
taxing_entity_code = TextField(max_length=5, required=False) taxing_entity_code = TextField(max_length=5, required=False)
ssn = IntegerField(max_length=9, required=False) ssn = IntegerField(max_length=9, required=False)
employee_first_name = TextField(max_length=15) employee_first_name = TextField(max_length=15)
employee_middle_name = TextField(max_length=15) employee_middle_name = TextField(max_length=15, required=False)
employee_last_name = TextField(max_length=20) employee_last_name = TextField(max_length=20)
employee_suffix = TextField(max_length=4, required=False) employee_suffix = TextField(max_length=4, required=False)
location_address = TextField(max_length=22) location_address = TextField(max_length=22)
@ -257,20 +259,20 @@ class StateWageRecord(EFW2Record):
foreign_postal_code = TextField(max_length=15, required=False) foreign_postal_code = TextField(max_length=15, required=False)
country_code = TextField(max_length=2, required=False) country_code = TextField(max_length=2, required=False)
optional_code = TextField(max_length=2, required=False) optional_code = TextField(max_length=2, required=False)
reporting_period = MonthYearField() reporting_period = MonthYearField(required=False)
quarterly_unemp_ins_wages = MoneyField(max_length=11) quarterly_unemp_ins_wages = MoneyField(max_length=11)
quarterly_unemp_ins_taxable_wages = MoneyField(max_length=11) quarterly_unemp_ins_taxable_wages = MoneyField(max_length=11)
number_of_weeks_worked = IntegerField(max_length=2) number_of_weeks_worked = IntegerField(max_length=2, required=False)
date_first_employed = DateField(required=False) date_first_employed = DateField(required=False)
date_of_separation = DateField(required=False) date_of_separation = DateField(required=False)
blank2 = BlankField(max_length=5) blank2 = BlankField(max_length=5)
state_employer_account_num = TextField(max_length=20) state_employer_account_num = IntegerField(max_length=20, required=False)
blank3 = BlankField(max_length=6) blank3 = BlankField(max_length=6)
state_code_2 = StateField(use_numeric=True) state_code_2 = StateField(use_numeric=True)
state_taxable_wages = MoneyField(max_length=11) state_taxable_wages = MoneyField(max_length=11)
state_income_tax_wh = MoneyField(max_length=11) state_income_tax_wh = MoneyField(max_length=11)
other_state_data = TextField(max_length=10, required=False) other_state_data = TextField(max_length=10, required=False)
tax_type_code = TextField(max_length=1) # VALIDATE C, D, E, or F tax_type_code = TextField(max_length=1, required=False) # VALIDATE C, D, E, or F
local_taxable_wages = MoneyField(max_length=11) local_taxable_wages = MoneyField(max_length=11)
local_income_tax_wh = MoneyField(max_length=11) local_income_tax_wh = MoneyField(max_length=11)
state_control_number = IntegerField(max_length=7, required=False) state_control_number = IntegerField(max_length=7, required=False)
@ -280,7 +282,8 @@ class StateWageRecord(EFW2Record):
def validate_tax_type_code(self, field):
    """Check that a supplied tax type code is one of ``enums.tax_type_codes``.

    An unset or empty value is accepted (the field is declared with
    ``required=False``). Any other value is upper-cased and compared with
    the first element of each ``enums.tax_type_codes`` pair.

    The error used to be raised with ``field=f``; ``f`` is not defined in
    this method, so an invalid code surfaced as ``NameError`` rather than
    ``ValidationError``.

    :param field: the ``tax_type_code`` field instance being validated.
    :raises ValidationError: if the value is set and not a known code.
    """
    choices = [x for x, y in enums.tax_type_codes]
    value = field.value
    if value and value.upper() not in choices:
        raise ValidationError("%s not one of %s" % (field.value,choices), field=field)

1
requirements.txt Normal file
View file

@ -0,0 +1 @@
PyMuPDF==1.24.0

View file

@ -73,3 +73,4 @@ if __name__ == '__main__':
records = list(read_file(in_file, in_file.name, get_record_types())) records = list(read_file(in_file, in_file.name, get_record_types()))
write_file(out_file, out_file.name, records) write_file(out_file, out_file.name, records)
print("wrote {} records to {}".format(len(records), out_file.name))

View file

@ -1,4 +1,4 @@
#!/usr/bin/python #!/usr/bin/env python
from pyaccuwage.parser import RecordBuilder from pyaccuwage.parser import RecordBuilder
from pyaccuwage.pdfextract import PDFRecordFinder from pyaccuwage.pdfextract import PDFRecordFinder
import argparse import argparse
@ -29,48 +29,9 @@ doc = PDFRecordFinder(source_file)
records = doc.records() records = doc.records()
builder = RecordBuilder() builder = RecordBuilder()
def record_begins_at(field): for (name, fields) in records:
return int(fields[0].data.values()[0].split('-')[0], 10) name = re.sub(r'^[^a-zA-Z]*','', name.split(':')[-1])
name = re.sub(r'[^\w]*', '', name)
def record_ends_at(fields): sys.stdout.write("\nclass %s(pyaccuwagemodel.Model):\n" % name)
return int(fields[-1].data.values()[0].split('-')[-1], 10) for field in builder.load(map(lambda x: x, fields[0:])):
last_record_begins_at = -1
last_record_ends_at = -1
for rec in records:
#if not rec[1]:
# continue # no actual fields detected
fields = rec[1]
# strip out fields that are not 4 items long
fields = filter(lambda x:len(x.tuple) == 4, fields)
# strip fields that don't begin at position 0
fields = filter(lambda x: 0 in x.data, fields)
# strip fields that don't have a length-range type item in position 0
fields = filter(lambda x: re.match('^\d+[-]?\d*$', x.data[0]), fields)
if not fields:
continue
begins_at = record_begins_at(fields)
ends_at = record_ends_at(fields)
# FIXME record_ends_at is randomly exploding due to record data being
# a lump of text and not necessarily a field entry. I assume
# this is cleaned out by the record builder class.
#print last_record_ends_at + 1, begins_at
if last_record_ends_at + 1 != begins_at:
name = re.sub('^[^a-zA-Z]*','',rec[0].split(':')[-1])
name = re.sub('[^\w]*', '', name)
sys.stdout.write("\nclass %s(pyaccuwagemodel.Model):\n" % name)
for field in builder.load(map(lambda x:x.tuple, rec[1][0:])):
sys.stdout.write('\t' + field + '\n') sys.stdout.write('\t' + field + '\n')
#print field
last_record_ends_at = ends_at

View file

@ -7,13 +7,14 @@ def pyaccuwage_tests():
return test_suite return test_suite
setup(name='pyaccuwage', setup(name='pyaccuwage',
version='0.2018.1', version='0.2025.0',
packages=['pyaccuwage'], packages=['pyaccuwage'],
scripts=[ scripts=[
'scripts/pyaccuwage-checkseq',
'scripts/pyaccuwage-convert',
'scripts/pyaccuwage-genfieldfill',
'scripts/pyaccuwage-parse', 'scripts/pyaccuwage-parse',
'scripts/pyaccuwage-pdfparse', 'scripts/pyaccuwage-pdfparse',
'scripts/pyaccuwage-checkseq',
'scripts/pyaccuwage-genfieldfill'
], ],
zip_safe=True, zip_safe=True,
test_suite='setup.pyaccuwage_tests', test_suite='setup.pyaccuwage_tests',

View file

@ -1,5 +1,6 @@
import unittest import unittest
from pyaccuwage.fields import TextField from pyaccuwage.fields import TextField
from pyaccuwage.fields import StaticField
# from pyaccuwage.fields import IntegerField # from pyaccuwage.fields import IntegerField
# from pyaccuwage.fields import StateField # from pyaccuwage.fields import StateField
# from pyaccuwage.fields import BlankField # from pyaccuwage.fields import BlankField
@ -31,3 +32,36 @@ class TestTextField(unittest.TestCase):
data = field.get_data() data = field.get_data()
self.assertEqual(len(data), field.max_length) self.assertEqual(len(data), field.max_length)
self.assertEqual(data, b'HELLO,') self.assertEqual(data, b'HELLO,')
def testStringUnsetOptional(self):
    """An optional TextField that was never assigned validates and renders as spaces."""
    unset = TextField(max_length=6, required=False)
    unset.validate()
    rendered = unset.get_data()
    self.assertEqual(rendered, b' ' * 6)
def testStringRequiredUnassigned(self):
    """A required TextField with no value assigned fails validation."""
    unset = TextField(max_length=6)
    with self.assertRaises(ValidationError):
        unset.validate()
def testStringRequiredNonBlank(self):
    """A required TextField without ``blank=True`` rejects the empty string."""
    emptied = TextField(max_length=6)
    emptied.value = ''
    with self.assertRaises(ValidationError):
        emptied.validate()
def testStringRequiredBlank(self):
    """With ``blank=True`` a required TextField accepts '' and still pads to max_length."""
    blank_ok = TextField(max_length=6, blank=True)
    blank_ok.value = ''
    blank_ok.validate()
    rendered = blank_ok.get_data()
    self.assertEqual(len(rendered), 6)
def testStringMinimumLength(self):
    """``min_length`` rejects short values even when ``blank=True``; an exact-length value passes.

    The original test assigned the valid six-character value but never
    validated it, so the passing case was not actually exercised.
    """
    field = TextField(max_length=6, min_length=6, blank=True)  # blank has no effect

    field.value = ''  # empty: shorter than min_length
    with self.assertRaises(ValidationError):
        field.validate()

    field.value = '12345'  # one character too short
    with self.assertRaises(ValidationError):
        field.validate()

    field.value = '123456'  # exactly min_length: must validate
    field.validate()
    self.assertEqual(field.get_data(), b'123456')
class TestStaticField(unittest.TestCase):
    """Output checks for StaticField."""

    def test_static_field(self):
        """A StaticField renders the value it was constructed with."""
        static = StaticField(value='TEST')
        rendered = static.get_data()
        self.assertEqual(rendered, b'TEST')

View file

@ -7,6 +7,8 @@ from pyaccuwage.fields import MoneyField
from pyaccuwage.fields import StateField from pyaccuwage.fields import StateField
from pyaccuwage.fields import TextField from pyaccuwage.fields import TextField
from pyaccuwage.fields import ZeroField from pyaccuwage.fields import ZeroField
from pyaccuwage.fields import StaticField
from pyaccuwage.fields import ValidationError
from pyaccuwage.model import Model from pyaccuwage.model import Model
class TestModelOutput(unittest.TestCase): class TestModelOutput(unittest.TestCase):
@ -20,7 +22,8 @@ class TestModelOutput(unittest.TestCase):
money = MoneyField(max_length=32) money = MoneyField(max_length=32)
state_txt = StateField() state_txt = StateField()
state_num = StateField(use_numeric=True) state_num = StateField(use_numeric=True)
blank2 = BlankField(max_length=24) blank2 = BlankField(max_length=12)
static1 = StaticField(value='hey mister!!')
def setUp(self): def setUp(self):
self.model = TestModelOutput.TestModel() self.model = TestModelOutput.TestModel()
@ -42,7 +45,8 @@ class TestModelOutput(unittest.TestCase):
b'313377'.zfill(32), b'313377'.zfill(32),
b'IA', b'IA',
b'19', b'19',
b' ' * 24, b' ' * 12,
b'hey mister!!',
]) ])
output = model.output() output = model.output()
@ -64,6 +68,7 @@ field2: 12345
money: 3133.77 money: 3133.77
state_txt: IA state_txt: IA
state_num: IA state_num: IA
static1: hey mister!!
''') ''')
@ -86,7 +91,8 @@ class TestFileFormats(unittest.TestCase):
record_identifier = 'B' # 1 byte record_identifier = 'B' # 1 byte
zero1 = ZeroField(max_length=32) zero1 = ZeroField(max_length=32)
text1 = TextField(max_length=71) text1 = TextField(max_length=71)
blank2 = BlankField(max_length=24) text2 = TextField(max_length=20, required=False)
blank2 = BlankField(max_length=4)
record_types = [TestModelA, TestModelB] record_types = [TestModelA, TestModelB]
@ -125,3 +131,49 @@ class TestFileFormats(unittest.TestCase):
original_bytes = pyaccuwage.dumps(records) original_bytes = pyaccuwage.dumps(records)
reloaded_bytes = pyaccuwage.dumps(records_loaded) reloaded_bytes = pyaccuwage.dumps(records_loaded)
self.assertEqual(original_bytes, reloaded_bytes) self.assertEqual(original_bytes, reloaded_bytes)
class TestRequiredFields(unittest.TestCase):
    """Run each required/blank combination of a TextField through pyaccuwage.dumps()."""

    def createTestRecord(self, required=False, blank=False):
        """Build a single-field record plus a callable that serializes it.

        :returns: ``(record, dump)`` where ``dump()`` returns
            ``pyaccuwage.dumps([record])``.
        """
        class Record(pyaccuwage.model.Model):
            record_length = 16
            record_identifier = ''
            test_field = TextField(max_length=16, required=required, blank=blank)

        instance = Record()
        return (instance, lambda: pyaccuwage.dumps([instance]))

    def testRequiredBlankField(self):
        """required + blank: unassigned fails, empty string is accepted."""
        record, dump = self.createTestRecord(required=True, blank=True)
        record.test_field.value  # never assigned: dumping must raise
        with self.assertRaises(ValidationError):
            dump()
        record.test_field.value = ''  # empty string is allowed
        dump()

    def testRequiredNonblankField(self):
        """required + not blank: unassigned and empty both fail, text passes."""
        record, dump = self.createTestRecord(required=True, blank=False)
        record.test_field.value  # never assigned: dumping must raise
        with self.assertRaises(ValidationError):
            dump()
        record.test_field.value = ''  # empty string is rejected
        with self.assertRaises(ValidationError):
            dump()
        record.test_field.value = 'hello'
        dump()

    def testOptionalBlankField(self):
        """optional + blank: unassigned, empty and text all dump cleanly."""
        record, dump = self.createTestRecord(required=False, blank=True)
        record.test_field.value  # never assigned: fine
        dump()
        record.test_field.value = ''  # empty string: fine
        dump()
        record.test_field.value = 'hello'
        dump()

    def testOptionalNonBlankField(self):
        """optional + not blank: unassigned, empty and text all dump cleanly."""
        record, dump = self.createTestRecord(required=False, blank=False)
        record.test_field.value  # never assigned: fine
        dump()
        record.test_field.value = ''  # empty string: fine
        dump()
        record.test_field.value = 'hello'
        dump()