diff --git a/pyaccuwage/__init__.py b/pyaccuwage/__init__.py index 810de18..513b8c6 100644 --- a/pyaccuwage/__init__.py +++ b/pyaccuwage/__init__.py @@ -1,9 +1,6 @@ -try: - from collections import Callable -except: - from typing import Callable # Python 3.10+ +from collections import Callable -VERSION = (0, 2025, 0) +VERSION = (0, 2012, 0) RECORD_TYPES = [ 'SubmitterRecord', @@ -58,21 +55,16 @@ def loads(s, record_types=get_record_types()): def dump(fp, records, delim=None): - if type(delim) is str: - delim = delim.encode('ascii') for r in records: fp.write(r.output()) if delim: fp.write(delim) -def dumps(records, delim=None, skip_validation=False): +def dumps(records, delim=None): import io fp = io.BytesIO() - if not skip_validation: - for record in records: - record.validate() - dump(fp, records, delim=delim) + dump(records, fp, delim=delim) fp.seek(0) return fp.read() diff --git a/pyaccuwage/enums.py b/pyaccuwage/enums.py index 66a8722..8c96ebf 100644 --- a/pyaccuwage/enums.py +++ b/pyaccuwage/enums.py @@ -323,7 +323,6 @@ employment_codes = ( ) tax_jurisdiction_codes = ( - (' ', 'W-2'), ('V', 'Virgin Islands'), ('G', 'Guam'), ('S', 'American Samoa'), diff --git a/pyaccuwage/fields.py b/pyaccuwage/fields.py index 8291d63..0d6fcd3 100644 --- a/pyaccuwage/fields.py +++ b/pyaccuwage/fields.py @@ -21,15 +21,12 @@ class ValidationError(Exception): class Field(object): creation_counter = 0 is_read_only = False - _value = None - def __init__(self, name=None, min_length=0, max_length=0, blank=False, required=True, uppercase=True, creation_counter=None): + def __init__(self, name=None, max_length=0, required=True, uppercase=True, creation_counter=None): self.name = name self._value = None self._orig_value = None - self.min_length = min_length self.max_length = max_length - self.blank = blank self.required = required self.uppercase = uppercase self.creation_counter = creation_counter or Field.creation_counter @@ -99,9 +96,9 @@ class Field(object): wrapper.width = 100 value = wrapper.wrap(value) value = list([(" " * 9) + ('"' + x + '"') for x in value]) - value.append(" " * 10 + ('_' * 10) * int(wrapper.width / 10)) - value.append(" " * 10 + ('0123456789') * int(wrapper.width / 10)) - value.append(" " * 10 + ''.join(([str(x) + (' ' * 9) for x in range(int(wrapper.width / 10))]))) + value.append(" " * 10 + ('_' * 10) * (wrapper.width / 10)) + value.append(" " * 10 + ('0123456789') * (wrapper.width / 10)) + value.append(" " * 10 + ''.join(([str(x) + (' ' * 9) for x in range(wrapper.width / 10 )]))) start = counter['c'] counter['c'] += len(self._orig_value or self.value) @@ -121,17 +118,11 @@ class TextField(Field): def validate(self): if self.value is None and self.required: raise ValidationError("value required", field=self) - data = self.get_data() - if len(data) > self.max_length: + if len(self.get_data()) > self.max_length: raise ValidationError("value is too long", field=self) - stripped_data_length = len(data.strip()) - if stripped_data_length < self.min_length: - raise ValidationError("value is too short", field=self) - if stripped_data_length == 0 and (not self.blank and self.required): - raise ValidationError("field cannot be blank", field=self) def get_data(self): - value = str(self.value or '').encode('ascii') or b'' + value = str(self.value).encode('ascii') or b'' if self.uppercase: value = value.upper() return value.ljust(self.max_length)[:self.max_length] @@ -152,7 +143,7 @@ class TextField(Field): class StateField(TextField): def __init__(self, name=None, required=True, use_numeric=False, max_length=2): - super(StateField, self).__init__(name=name, max_length=max_length, required=required) + super(StateField, self).__init__(name=name, max_length=2, required=required) self.use_numeric = use_numeric def get_data(self): @@ -204,17 +195,15 @@ class IntegerField(TextField): class StaticField(TextField): - def __init__(self, name=None, required=True, value=None, uppercase=False): - super(StaticField, self).__init__(name=name, - required=required, - max_length=len(value), - uppercase=uppercase) - self._static_value = value + def __init__(self, name=None, required=True, value=None): + super(StaticField, self).__init__(name=name, required=required, + max_length=len(value)) self._value = value def parse(self, s): pass + class BlankField(TextField): is_read_only = True @@ -227,10 +216,6 @@ class BlankField(TextField): def parse(self, s): pass - def validate(self): - if len(self.get_data()) != self.max_length: - raise ValidationError("blank field did not match expected length", field=self) - class ZeroField(BlankField): is_read_only = True diff --git a/pyaccuwage/model.py b/pyaccuwage/model.py index c950055..becd3ce 100644 --- a/pyaccuwage/model.py +++ b/pyaccuwage/model.py @@ -25,7 +25,7 @@ class Model(object): setattr(src_field, 'parent_name', self.__class__.__name__) new_field_instance = copy.copy(src_field) new_field_instance._orig_value = None - new_field_instance._value = new_field_instance.value + new_field_instance._value = None self.__dict__[key] = new_field_instance def __setattr__(self, key, value): @@ -36,14 +36,11 @@ class Model(object): self.__dict__[key] = value def set_field_value(self, field_name, value): + print('setfieldval: ' + field_name + ' ' + value) getattr(self, field_name).value = value def get_fields(self): - identifier = TextField( - "record_identifier", - max_length = len(self.record_identifier), - blank = len(self.record_identifier) == 0, - creation_counter=-1) + identifier = TextField("record_identifier", max_length=len(self.record_identifier), creation_counter=-1) identifier.value = self.record_identifier fields = [identifier] @@ -90,6 +87,7 @@ class Model(object): # Skip the first record, since that's an identifier for field in self.get_sorted_fields()[1:]: field.read(fp) + print(field.name, '"' + (str(field.value) or '') + '"', field.max_length, field._orig_value) def toJSON(self): return { diff --git a/pyaccuwage/pdfextract.py b/pyaccuwage/pdfextract.py index 352c400..2903b5d 100644 --- a/pyaccuwage/pdfextract.py +++ b/pyaccuwage/pdfextract.py @@ -3,102 +3,313 @@ import subprocess import re -import itertools -import fitz +import pdb """ pdftotext -layout -nopgbrk p1220.pdf - """ -def strip_values(items): - expr_non_alphanum = re.compile(r'[^\w\s]*', re.MULTILINE) - return [expr_non_alphanum.sub(x, '').strip().replace('\n', ' ') for x in items if x] class PDFRecordFinder(object): - field_range_expr = re.compile(r'^(\d+)[-]?(\d*)$') + def __init__(self, src, heading_exp=None): + if not heading_exp: + heading_exp = re.compile('(\s+Record Name: (.*))|Record\ Layout') - def __init__(self, src): - self.document = fitz.open(src) + field_heading_exp = re.compile('^Field.*Field.*Length.*Description') - def find_record_table_ranges(self): - matches = [] - for (page_number, page) in enumerate(self.document): - header_rects = page.search_for("Record Name:") - for header_match_rect in header_rects: - header_match_rect.x0 = header_match_rect.x1 # Start after match of "Record Name: " - header_match_rect.x1 = page.bound().x1 # Extend to right side of page - header_text = page.get_textbox(header_match_rect) - record_name = re.sub(r'[^\w\s\n]*', '', header_text).strip() - matches.append((record_name, { - 'page': page_number, - 'y': header_match_rect.y1 - 5, # Back up a hair to include header more reliably - })) - return matches - - def find_records(self): - record_ranges = self.find_record_table_ranges() - for record_index, (record_name, record_details) in enumerate(record_ranges): - current_rows = [] - next_index = record_index+1 - (_, next_record_details) = record_ranges[next_index] if next_index < len(record_ranges) else (None, {'page': self.document.page_count-1}) - for page_number in range(record_details['page'], next_record_details['page']): - page = self.document[page_number] - table_search_rect = page.bound() - if page_number == record_details['page']: - table_search_rect.y0 = record_details['y'] - tables = page.find_tables( - clip = table_search_rect, - min_words_horizontal = 1, - min_words_vertical = 1, - horizontal_strategy = "lines_strict", - intersection_tolerance = 1, - ) - for table in tables: - if table.col_count == 4: - table = table.extract() - # Parse field position (sometimes a cell has multiple - # values because IRS employees apparently smoke crack - for row in table: - first_column_lines = row[0].strip().split('\n') - if len(first_column_lines) > 1: - for sub_row in self.split_row(row): - current_rows.append(strip_values(sub_row)) - else: - current_rows.append(strip_values(row)) - consecutive_rows = self.filter_nonconsecutive_rows(current_rows) - yield(record_name, consecutive_rows) - - def split_row(self, row): - if not row[1]: - return [] - split_rows = list(itertools.zip_longest(*[x.strip().split('\n') for x in row[:3]], fillvalue=None)) - description = strip_values([row[3]])[0] - rows = [] - for row in split_rows: - if len(row) < 3 or not row[2]: - row = self.infer_field_length(row) - rows.append([*row, description]) - return rows - - def infer_field_length(self, row): - matches = PDFRecordFinder.field_range_expr.match(row[0]) - if not matches: - return row - (start, end) = ([int(x) for x in list(matches.groups()) if x] + [None])[:2] - length = str(end-start+1) if end and start else '1' - return (*row[:2], length) - - def filter_nonconsecutive_rows(self, rows): - consecutive_rows = [] - last_position = 0 - for row in rows: - matches = PDFRecordFinder.field_range_expr.match(row[0]) - if not matches: - continue - (start, end) = ([int(x) for x in list(matches.groups()) if x] + [None])[:2] - if start != last_position + 1: - continue - last_position = end if end else start - consecutive_rows.append(row) - return consecutive_rows + opts = ["pdftotext", "-layout", "-nopgbrk", "-eol", "unix", src, '-'] + pdftext = subprocess.check_output(opts) + self.textrows = pdftext.split('\n') + self.heading_exp = heading_exp + self.field_heading_exp = field_heading_exp def records(self): - return self.find_records() + headings = self.locate_heading_rows_by_field() + + #for x in headings: + # print x + + for (start, end, name) in headings: + name = name.decode('ascii', 'ignore') + yield (name, list(self.find_fields(iter(self.textrows[start+1:end]))), (start+1, end)) + + + def locate_heading_rows_by_field(self): + results = [] + record_break = [] + line_is_whitespace_exp = re.compile('^(\s*)$') + record_begin_exp = self.heading_exp #re.compile('Record\ Name') + + for (i, row) in enumerate(self.textrows): + match = self.field_heading_exp.match(row) + if match: + # work backwards until we think the header is fully copied + space_count_exp = re.compile('^(\s*)') + position = i - 1 + spaces = 0 + #last_spaces = 10000 + complete = False + header = None + while not complete: + line_is_whitespace = True if line_is_whitespace_exp.match(self.textrows[position]) else False + is_record_begin = record_begin_exp.search(self.textrows[position]) + if is_record_begin or line_is_whitespace: + header = self.textrows[position-1:i] + complete = True + position -= 1 + + name = ''.join(header).strip().decode('ascii','ignore') + print((name, position)) + results.append((i, name, position)) + else: + # See if this row forces us to break from field reading. + if re.search('Record\ Layout', row): + record_break.append(i) + + merged = [] + for (a, b) in zip(results, results[1:] + [(len(self.textrows), None)]): + end_pos = None + + #print a[0], record_break[0], b[0]-1 + + while record_break and record_break[0] < a[0]: + record_break = record_break[1:] + + if record_break[0] < b[0]-1: + end_pos = record_break[0] + record_break = record_break[1:] + else: + end_pos = b[0]-1 + + merged.append( (a[0], end_pos-1, a[1]) ) + return merged + + """ + def locate_heading_rows(self): + results = [] + for (i, row) in enumerate(self.textrows): + match = self.heading_exp.match(row) + if match: + results.append((i, ''.join(match.groups()))) + + merged = [] + for (a, b) in zip(results, results[1:] + [(len(self.textrows),None)]): + merged.append( (a[0], b[0]-1, a[1]) ) + + return merged + + def locate_layout_block_rows(self): + # Search for rows that contain "Record Layout", as these are not fields + # we are interested in because they contain the crazy blocks of field definitions + # and not the nice 4-column ones that we're looking for. + + results = [] + for (i, row) in enumerate(self.textrows): + match = re.match("Record Layout", row) + + """ + + def find_fields(self, row_iter): + cc = ColumnCollector() + blank_row_counter = 0 + + for r in row_iter: + row = r.decode('UTF-8') + #print row + row_columns = self.extract_columns_from_row(row) + + if not row_columns: + if cc.data and len(list(cc.data.keys())) > 1 and len(row.strip()) > list(cc.data.keys())[-1]: + yield cc + cc = ColumnCollector() + else: + cc.empty_row() + continue + + try: + cc.add(row_columns) + + except IsNextField as e: + yield cc + cc = ColumnCollector() + cc.add(row_columns) + except UnknownColumn as e: + raise StopIteration + + yield cc + + + def extract_columns_from_row(self, row): + re_multiwhite = re.compile(r'\s{2,}') + + # IF LINE DOESN'T CONTAIN MULTIPLE WHITESPACES, IT'S LIKELY NOT A TABLE + if not re_multiwhite.search(row): + return None + + white_ranges = [0,] + pos = 0 + while pos < len(row): + match = re_multiwhite.search(row[pos:]) + if match: + white_ranges.append(pos + match.start()) + white_ranges.append(pos + match.end()) + pos += match.end() + else: + white_ranges.append(len(row)) + pos = len(row) + + row_result = [] + white_iter = iter(white_ranges) + while white_iter: + try: + start = next(white_iter) + end = next(white_iter) + if start != end: + row_result.append( + (start, row[start:end].encode('ascii','ignore')) + ) + + except StopIteration: + white_iter = None + + #print row_result + return row_result + + +class UnknownColumn(Exception): + pass + +class IsNextField(Exception): + pass + +class ColumnCollector(object): + def __init__(self, initial=None): + self.data = None + self.column_widths = None + self.max_data_length = 0 + self.adjust_pad = 3 + self.empty_rows = 0 + pass + + def __repr__(self): + return "<%s: %s>" % ( + self.__class__.__name__, + [x if len(x) < 25 else x[:25] + '..' for x in list(self.data.values()) if self.data else '']) + + def add(self, data): + #if self.empty_rows > 2: + # raise IsNextField() + + if not self.data: + self.data = dict(data) + else: + data = self.adjust_columns(data) + if self.is_next_field(data): + raise IsNextField() + for col_id, value in data: + self.merge_column(col_id, value) + + self.update_column_widths(data) + + def empty_row(self): + self.empty_rows += 1 + + def update_column_widths(self, data): + self.last_data_length = len(data) + self.max_data_length = max(self.max_data_length, len(data)) + + if not self.column_widths: + self.column_widths = dict([[column_value[0], column_value[0] + len(column_value[1])] for column_value in data]) + else: + for col_id, value in data: + try: + self.column_widths[col_id] = max(self.column_widths[col_id], col_id + len(value.strip())) + except KeyError: + pass + + def add_old(self, data): + if not self.data: + self.data = dict(data) + else: + if self.is_next_field(data): + raise IsNextField() + for col_id, value in data: + self.merge_column(col_id, value) + + + def adjust_columns(self, data): + adjusted_data = {} + for col_id, value in data: + if col_id in list(self.data.keys()): + adjusted_data[col_id] = value.strip() + else: + for col_start, col_end in list(self.column_widths.items()): + if (col_start - self.adjust_pad) <= col_id and (col_end + self.adjust_pad) >= col_id: + if col_start in adjusted_data: + adjusted_data[col_start] += ' ' + value.strip() + else: + adjusted_data[col_start] = value.strip() + + return list(adjusted_data.items()) + + + def merge_column(self, col_id, value): + if col_id in list(self.data.keys()): + self.data[col_id] += ' ' + value.strip() + else: + # try adding a wiggle room value? + # FIXME: + # Sometimes description columns contain column-like + # layouts, and this causes the ColumnCollector to become + # confused. Perhaps we could check to see if a column occurs + # after the maximum column, and assume it's part of the + # max column? + + """ + for col_start, col_end in self.column_widths.items(): + if col_start <= col_id and (col_end) >= col_id: + self.data[col_start] += ' ' + value.strip() + return + """ + raise UnknownColumn + + def is_next_field(self, data): + """ + If the first key value contains a string + and we already have some data in the record, + then this row is probably the beginning of + the next field. Raise an exception and continue + on with a fresh ColumnCollector. + """ + + """ If the length of the value in column_id is less than the position of the next column_id, + then this is probably a continuation. + """ + + if self.data and data: + keys = list(dict(self.column_widths).keys()) + keys.sort() + keys += [None] + + if self.last_data_length < len(data): + return True + + first_key, first_value = list(dict(data).items())[0] + if list(self.data.keys())[0] == first_key: + + position = keys.index(first_key) + max_length = keys[position + 1] + if max_length: + return len(first_value) > max_length or len(data) == self.max_data_length + + return False + + + @property + def tuple(self): + #try: + if self.data: + return tuple([self.data[k] for k in sorted(self.data.keys())]) + return () + #except: + # import pdb + # pdb.set_trace() + diff --git a/pyaccuwage/record.py b/pyaccuwage/record.py index 6b91123..9e46217 100644 --- a/pyaccuwage/record.py +++ b/pyaccuwage/record.py @@ -105,8 +105,8 @@ class EmployerRecord(EFW2Record): zipcode_ext = TextField(max_length=4, required=False) kind_of_employer = TextField(max_length=1) blank1 = BlankField(max_length=4) - foreign_state_province = TextField(max_length=23, required=False) - foreign_postal_code = TextField(max_length=15, required=False) + foreign_state_province = TextField(max_length=23) + foreign_postal_code = TextField(max_length=15) country_code = TextField(max_length=2, required=False) employment_code = TextField(max_length=1) tax_jurisdiction_code = TextField(max_length=1, required=False) @@ -150,7 +150,7 @@ class EmployeeWageRecord(EFW2Record): ssn = IntegerField(max_length=9, required=False) employee_first_name = TextField(max_length=15) - employee_middle_name = TextField(max_length=15, required=False) + employee_middle_name = TextField(max_length=15) employee_last_name = TextField(max_length=20) employee_suffix = TextField(max_length=4, required=False) location_address = TextField(max_length=22) @@ -163,7 +163,7 @@ class EmployeeWageRecord(EFW2Record): blank1 = BlankField(max_length=5) foreign_state = TextField(max_length=23, required=False) foreign_postal_code = TextField(max_length=15, required=False) - country = TextField(max_length=2, required=True, blank=True) + country = TextField(max_length=2) wages_tips = MoneyField(max_length=11) federal_income_tax_withheld = MoneyField(max_length=11) social_security_wages = MoneyField(max_length=11) @@ -199,10 +199,8 @@ class EmployeeWageRecord(EFW2Record): blank6 = BlankField(max_length=23) def validate_ssn(self, f): - if str(f.value).startswith('666'): - raise ValidationError("ssn cannot start with 666", field=f) - if str(f.value).startswith('9'): - raise ValidationError("ssn cannot start with 9", field=f) + if str(f.value).startswith('666','9'): + raise ValidationError("ssn cannot start with 666 or 9", field=f) @@ -245,7 +243,7 @@ class StateWageRecord(EFW2Record): taxing_entity_code = TextField(max_length=5, required=False) ssn = IntegerField(max_length=9, required=False) employee_first_name = TextField(max_length=15) - employee_middle_name = TextField(max_length=15, required=False) + employee_middle_name = TextField(max_length=15) employee_last_name = TextField(max_length=20) employee_suffix = TextField(max_length=4, required=False) location_address = TextField(max_length=22) @@ -259,20 +257,20 @@ class StateWageRecord(EFW2Record): foreign_postal_code = TextField(max_length=15, required=False) country_code = TextField(max_length=2, required=False) optional_code = TextField(max_length=2, required=False) - reporting_period = MonthYearField(required=False) + reporting_period = MonthYearField() quarterly_unemp_ins_wages = MoneyField(max_length=11) quarterly_unemp_ins_taxable_wages = MoneyField(max_length=11) - number_of_weeks_worked = IntegerField(max_length=2, required=False) + number_of_weeks_worked = IntegerField(max_length=2) date_first_employed = DateField(required=False) date_of_separation = DateField(required=False) blank2 = BlankField(max_length=5) - state_employer_account_num = IntegerField(max_length=20, required=False) + state_employer_account_num = TextField(max_length=20) blank3 = BlankField(max_length=6) state_code_2 = StateField(use_numeric=True) state_taxable_wages = MoneyField(max_length=11) state_income_tax_wh = MoneyField(max_length=11) other_state_data = TextField(max_length=10, required=False) - tax_type_code = TextField(max_length=1, required=False) # VALIDATE C, D, E, or F + tax_type_code = TextField(max_length=1) # VALIDATE C, D, E, or F local_taxable_wages = MoneyField(max_length=11) local_income_tax_wh = MoneyField(max_length=11) state_control_number = IntegerField(max_length=7, required=False) @@ -282,8 +280,7 @@ class StateWageRecord(EFW2Record): def validate_tax_type_code(self, field): choices = [x for x,y in enums.tax_type_codes] - value = field.value - if value and value.upper() not in choices: + if field.value.upper() not in choices: raise ValidationError("%s not one of %s" % (field.value,choices), field=f) diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 82813ad..0000000 --- a/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -PyMuPDF==1.24.0 diff --git a/scripts/pyaccuwage-convert b/scripts/pyaccuwage-convert index 9239cac..9591760 100755 --- a/scripts/pyaccuwage-convert +++ b/scripts/pyaccuwage-convert @@ -73,4 +73,3 @@ if __name__ == '__main__': records = list(read_file(in_file, in_file.name, get_record_types())) write_file(out_file, out_file.name, records) - print("wrote {} records to {}".format(len(records), out_file.name)) diff --git a/scripts/pyaccuwage-pdfparse b/scripts/pyaccuwage-pdfparse index d80abaa..6a35387 100755 --- a/scripts/pyaccuwage-pdfparse +++ b/scripts/pyaccuwage-pdfparse @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/python from pyaccuwage.parser import RecordBuilder from pyaccuwage.pdfextract import PDFRecordFinder import argparse @@ -29,9 +29,48 @@ doc = PDFRecordFinder(source_file) records = doc.records() builder = RecordBuilder() -for (name, fields) in records: - name = re.sub(r'^[^a-zA-Z]*','', name.split(':')[-1]) - name = re.sub(r'[^\w]*', '', name) - sys.stdout.write("\nclass %s(pyaccuwagemodel.Model):\n" % name) - for field in builder.load(map(lambda x: x, fields[0:])): +def record_begins_at(field): + return int(fields[0].data.values()[0].split('-')[0], 10) + +def record_ends_at(fields): + return int(fields[-1].data.values()[0].split('-')[-1], 10) + +last_record_begins_at = -1 +last_record_ends_at = -1 + +for rec in records: + #if not rec[1]: + # continue # no actual fields detected + fields = rec[1] + + # strip out fields that are not 4 items long + fields = filter(lambda x:len(x.tuple) == 4, fields) + + # strip fields that don't begin at position 0 + fields = filter(lambda x: 0 in x.data, fields) + + # strip fields that don't have a length-range type item in position 0 + fields = filter(lambda x: re.match('^\d+[-]?\d*$', x.data[0]), fields) + + if not fields: + continue + + begins_at = record_begins_at(fields) + ends_at = record_ends_at(fields) + + # FIXME record_ends_at is randomly exploding due to record data being + # a lump of text and not necessarily a field entry. I assume + # this is cleaned out by the record builder class. + + #print last_record_ends_at + 1, begins_at + if last_record_ends_at + 1 != begins_at: + name = re.sub('^[^a-zA-Z]*','',rec[0].split(':')[-1]) + name = re.sub('[^\w]*', '', name) + sys.stdout.write("\nclass %s(pyaccuwagemodel.Model):\n" % name) + + for field in builder.load(map(lambda x:x.tuple, rec[1][0:])): sys.stdout.write('\t' + field + '\n') + #print field + + last_record_ends_at = ends_at + diff --git a/setup.py b/setup.py index cdef46b..b543ddb 100644 --- a/setup.py +++ b/setup.py @@ -7,14 +7,13 @@ def pyaccuwage_tests(): return test_suite setup(name='pyaccuwage', - version='0.2025.0', + version='0.2018.1', packages=['pyaccuwage'], scripts=[ - 'scripts/pyaccuwage-checkseq', - 'scripts/pyaccuwage-convert', - 'scripts/pyaccuwage-genfieldfill', 'scripts/pyaccuwage-parse', 'scripts/pyaccuwage-pdfparse', + 'scripts/pyaccuwage-checkseq', + 'scripts/pyaccuwage-genfieldfill' ], zip_safe=True, test_suite='setup.pyaccuwage_tests', diff --git a/tests/test_fields.py b/tests/test_fields.py index 2707e6f..3d8fd3e 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -1,6 +1,5 @@ import unittest from pyaccuwage.fields import TextField -from pyaccuwage.fields import StaticField # from pyaccuwage.fields import IntegerField # from pyaccuwage.fields import StateField # from pyaccuwage.fields import BlankField @@ -32,36 +31,3 @@ class TestTextField(unittest.TestCase): data = field.get_data() self.assertEqual(len(data), field.max_length) self.assertEqual(data, b'HELLO,') - - def testStringUnsetOptional(self): - field = TextField(max_length=6, required=False) - field.validate() - self.assertEqual(field.get_data(), b' ' * 6) - - def testStringRequiredUnassigned(self): - field = TextField(max_length=6) - self.assertRaises(ValidationError, lambda: field.validate()) - - def testStringRequiredNonBlank(self): - field = TextField(max_length=6) - field.value = '' - self.assertRaises(ValidationError, lambda: field.validate()) - - def testStringRequiredBlank(self): - field = TextField(max_length=6, blank=True) - field.value = '' - field.validate() - self.assertEqual(len(field.get_data()), 6) - - def testStringMinimumLength(self): - field = TextField(max_length=6, min_length=6, blank=True) # blank has no effect - field.value = '' # one character too short - self.assertRaises(ValidationError, lambda: field.validate()) - field.value = '12345' # one character too short - self.assertRaises(ValidationError, lambda: field.validate()) - field.value = '123456' # one character too short - -class TestStaticField(unittest.TestCase): - def test_static_field(self): - field = StaticField(value='TEST') - self.assertEqual(field.get_data(), b'TEST') diff --git a/tests/test_records.py b/tests/test_records.py index 67ce3ce..a6485ac 100644 --- a/tests/test_records.py +++ b/tests/test_records.py @@ -7,8 +7,6 @@ from pyaccuwage.fields import MoneyField from pyaccuwage.fields import StateField from pyaccuwage.fields import TextField from pyaccuwage.fields import ZeroField -from pyaccuwage.fields import StaticField -from pyaccuwage.fields import ValidationError from pyaccuwage.model import Model class TestModelOutput(unittest.TestCase): @@ -22,8 +20,7 @@ class TestModelOutput(unittest.TestCase): money = MoneyField(max_length=32) state_txt = StateField() state_num = StateField(use_numeric=True) - blank2 = BlankField(max_length=12) - static1 = StaticField(value='hey mister!!') + blank2 = BlankField(max_length=24) def setUp(self): self.model = TestModelOutput.TestModel() @@ -45,8 +42,7 @@ class TestModelOutput(unittest.TestCase): b'313377'.zfill(32), b'IA', b'19', - b' ' * 12, - b'hey mister!!', + b' ' * 24, ]) output = model.output() @@ -68,7 +64,6 @@ field2: 12345 money: 3133.77 state_txt: IA state_num: IA -static1: hey mister!! ''') @@ -91,8 +86,7 @@ class TestFileFormats(unittest.TestCase): record_identifier = 'B' # 1 byte zero1 = ZeroField(max_length=32) text1 = TextField(max_length=71) - text2 = TextField(max_length=20, required=False) - blank2 = BlankField(max_length=4) + blank2 = BlankField(max_length=24) record_types = [TestModelA, TestModelB] @@ -131,49 +125,3 @@ class TestFileFormats(unittest.TestCase): original_bytes = pyaccuwage.dumps(records) reloaded_bytes = pyaccuwage.dumps(records_loaded) self.assertEqual(original_bytes, reloaded_bytes) - - -class TestRequiredFields(unittest.TestCase): - def createTestRecord(self, required=False, blank=False): - class Record(pyaccuwage.model.Model): - record_length = 16 - record_identifier = '' - test_field = TextField(max_length=16, required=required, blank=blank) - record = Record() - def dump(): - return pyaccuwage.dumps([record]) - return (record, dump) - - def testRequiredBlankField(self): - (record, dump) = self.createTestRecord(required=True, blank=True) - record.test_field.value # if nothing is ever assigned, raise error - self.assertRaises(ValidationError, dump) - record.test_field.value = '' # value may be empty string - dump() - - def testRequiredNonblankField(self): - (record, dump) = self.createTestRecord(required=True, blank=False) - record.test_field.value # if nothing is ever assigned, raise error - self.assertRaises(ValidationError, dump) - record.test_field.value = '' # value must not be empty string - self.assertRaises(ValidationError, dump) - record.test_field.value = 'hello' - dump() - - def testOptionalBlankField(self): - (record, dump) = self.createTestRecord(required=False, blank=True) - record.test_field.value # OK if nothing is ever assigned - dump() - record.test_field.value = '' # OK if empty string is assigned - dump() - record.test_field.value = 'hello' - dump() - - def testOptionalNonBlankField(self): - (record, dump) = self.createTestRecord(required=False, blank=False) - record.test_field.value # OK if nothing is ever assigned - dump() - record.test_field.value = '' # OK if empty string is assigned - dump() - record.test_field.value = 'hello' - dump()