diff --git a/pyaccuwage/fields.py b/pyaccuwage/fields.py index ca50125..94147a8 100644 --- a/pyaccuwage/fields.py +++ b/pyaccuwage/fields.py @@ -115,22 +115,22 @@ class Field(object): class TextField(Field): def validate(self): - if self.value == None and self.required: + if self.value is None and self.required: raise ValidationError("value required", field=self) if len(self.get_data()) > self.max_length: raise ValidationError("value is too long", field=self) def get_data(self): - value = self.value or "" + value = str(self.value).encode('ascii') or b'' if self.uppercase: value = value.upper() - return value.ljust(self.max_length).encode('ascii')[:self.max_length] + return value.ljust(self.max_length)[:self.max_length] def __setvalue(self, value): # NO NEWLINES try: value = value.replace('\n', '').replace('\r', '') - except AttributeError as e: + except AttributeError: pass self._value = value @@ -146,12 +146,15 @@ class StateField(TextField): self.use_numeric = use_numeric def get_data(self): - value = self.value or "" + # value = str(self.value or 'XX').encode('ascii') or b'' + value = str(self.value or 'XX') if value.strip() and self.use_numeric: - postcode = bytes(str(enums.state_postal_numeric[value.upper()]), 'ascii') + postcode = enums.state_postal_numeric[value.upper()] + postcode = str(postcode).encode('ascii') return postcode.zfill(self.max_length) else: - return value.ljust(self.max_length).encode('ascii')[:self.max_length] + formatted = value.encode('ascii').ljust(self.max_length) + return formatted[:self.max_length] def validate(self): super(StateField, self).validate() @@ -160,7 +163,7 @@ class StateField(TextField): def parse(self, s): if s.strip() and self.use_numeric: - states = dict( [(v,k) for (k,v) in list(enums.state_postal_numeric.items())] ) + states = dict([(v, k) for (k, v) in list(enums.state_postal_numeric.items())]) self.value = states[int(s)] else: self.value = s @@ -179,9 +182,8 @@ class IntegerField(TextField): except ValueError: raise ValidationError("field contains non-numeric characters", field=self) - def get_data(self): - value = bytes(str(self.value), 'ascii') if self.value else b'' + value = str(self.value).encode('ascii') if self.value else b'' return value.zfill(self.max_length)[:self.max_length] def parse(self, s): @@ -209,7 +211,7 @@ class BlankField(TextField): class ZeroField(BlankField): def get_data(self): - return '0' * self.max_length + return b'0' * self.max_length class CRLFField(TextField): def __init__(self, name=None, required=False): @@ -255,7 +257,9 @@ class MoneyField(Field): raise ValidationError("value is too long", field=self) def get_data(self): - return str(int((self.value or 0)*100)).encode('ascii').zfill(self.max_length)[:self.max_length] + cents = int((self.value or 0) * 100) + formatted = str(cents).encode('ascii').zfill(self.max_length) + return formatted[:self.max_length] def parse(self, s): self.value = decimal.Decimal(s) * decimal.Decimal('0.01') @@ -269,7 +273,7 @@ class DateField(TextField): def get_data(self): if self._value: - return bytes(self._value.strftime('%m%d%Y'), 'ascii') + return self._value.strftime('%m%d%Y').encode('ascii') return b'0' * self.max_length def parse(self, s): @@ -301,7 +305,7 @@ class MonthYearField(TextField): def get_data(self): if self._value: - return bytes(self._value.strftime("%m%Y"), 'ascii') + return str(self._value.strftime('%m%Y').encode('ascii')) return b'0' * self.max_length def parse(self, s): diff --git a/pyaccuwage/model.py b/pyaccuwage/model.py index 93d9251..ae5f9e1 100644 --- a/pyaccuwage/model.py +++ b/pyaccuwage/model.py @@ -1,6 +1,5 @@ from .fields import Field, TextField, ValidationError import copy -import pdb import collections @@ -42,7 +41,7 @@ class Model(object): def get_sorted_fields(self): fields = self.get_fields() - fields.sort(key=lambda x:x.creation_counter) + fields.sort(key=lambda x: x.creation_counter) return fields def validate(self): @@ -51,19 +50,17 @@ class Model(object): try: custom_validator = getattr(self, 'validate_' + f.name) - except AttributeError as e: + except AttributeError: continue if isinstance(custom_validator, collections.Callable): - custom_validator(f) + custom_validator(f) def output(self): result = b''.join([field.get_data() for field in self.get_sorted_fields()]) if hasattr(self, 'record_length') and len(result) != self.record_length: raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.record_length, len(result))) - #result = ''.join([self.record_identifier] + [field.get_data() for field in self.get_sorted_fields()]) - #if len(result) != self.target_size: - # raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.target_size, len(result))) + return result def read(self, fp): @@ -71,7 +68,6 @@ class Model(object): for field in self.get_sorted_fields()[1:]: field.read(fp) - def toJSON(self): return { '__class__': self.__class__.__name__, @@ -84,8 +80,8 @@ class Model(object): for f in fields: target = self.__dict__[f.name] - if (target.required != f.required or - target.max_length != f.max_length): + if (target.required != f.required + or target.max_length != f.max_length): print("Warning: value mismatch on import") target._value = f._value diff --git a/pyaccuwage/modeldef.py b/pyaccuwage/modeldef.py index a9364ca..c6c9110 100644 --- a/pyaccuwage/modeldef.py +++ b/pyaccuwage/modeldef.py @@ -1,86 +1,86 @@ -#!/usr/bin/env python import re + class ClassEntryCommentSequence(object): - re_rangecomment = re.compile('#\s+(\d+)\-?(\d*)$') + re_rangecomment = re.compile('#\s+(\d+)\-?(\d*)$') - def __init__(self, classname, line): - self.classname = classname, - self.line = line - self.lines = [] + def __init__(self, classname, line): + self.classname = classname, + self.line = line + self.lines = [] - def add_line(self, line): - self.lines.append(line) + def add_line(self, line): + self.lines.append(line) - def process(self): - i = 0 - for (line_no, line) in enumerate(self.lines): - match = self.re_rangecomment.search(line) - if match: - (a, b) = match.groups() - a = int(a) + def process(self): + i = 0 + for (line_no, line) in enumerate(self.lines): + match = self.re_rangecomment.search(line) + if match: + (a, b) = match.groups() + a = int(a) - if (i + 1) != a: - line_number = self.line + line_no - print(("ERROR\tline:%d\tnear:%s\texpected:%d\tsaw:%d" % (line_number, line.split(' ')[0].strip(), i+1, a))) + if (i + 1) != a: + line_number = self.line + line_no + print(("ERROR\tline:%d\tnear:%s\texpected:%d\tsaw:%d" % ( + line_number, line.split(' ')[0].strip(), i+1, a))) + + i = int(b) if b else a - i = int(b) if b else a class ModelDefParser(object): - re_triplequote = re.compile('"""') - re_whitespace = re.compile("^(\s*)[^\s]+") - re_classdef = re.compile(r"^\s*class\s(.*)\((.*)\):\s*$") + re_triplequote = re.compile('"""') + re_whitespace = re.compile(r"^(\s*)[^\s]+") + re_classdef = re.compile(r"^\s*class\s(.*)\((.*)\):\s*$") - def __init__(self, infile, entryclass): - self.infile = infile - self.line = 0 - self.EntryClass = entryclass + def __init__(self, infile, entryclass): + self.infile = infile + self.line = 0 + self.EntryClass = entryclass - def endclass(self): - if self.current_class: - self.current_class.process() - self.current_class = None + def endclass(self): + if self.current_class: + self.current_class.process() + self.current_class = None - def beginclass(self, classname, line): - self.current_class = self.EntryClass(classname, line) + def beginclass(self, classname, line): + self.current_class = self.EntryClass(classname, line) - def parse(self): - infile = self.infile - whitespace = 0 - in_block_comment = False - self.current_class = None + def parse(self): + infile = self.infile + whitespace = 0 + in_block_comment = False + self.current_class = None - for line in infile: - self.line += 1 + for line in infile: + self.line += 1 - if line.startswith('#'): - continue + if line.startswith('#'): + continue - if self.re_triplequote.search(line): - in_block_comment = not in_block_comment + if self.re_triplequote.search(line): + in_block_comment = not in_block_comment - if in_block_comment: - continue + if in_block_comment: + continue - match_whitespace = self.re_whitespace.match(line) - if match_whitespace: - match_whitespace = len(match_whitespace.groups()[0]) - else: - match_whitespace = 0 + match_whitespace = self.re_whitespace.match(line) + if match_whitespace: + match_whitespace = len(match_whitespace.groups()[0]) + else: + match_whitespace = 0 - classmatch = self.re_classdef.match(line) - if classmatch: - classname, subclass = classmatch.groups() - self.beginclass(classname, self.line) - continue - - if match_whitespace < whitespace: - whitespace = match_whitespace - self.endclass() - continue - - if self.current_class: - whitespace = match_whitespace - self.current_class.add_line(line) + classmatch = self.re_classdef.match(line) + if classmatch: + classname, subclass = classmatch.groups() + self.beginclass(classname, self.line) + continue + if match_whitespace < whitespace: + whitespace = match_whitespace + self.endclass() + continue + if self.current_class: + whitespace = match_whitespace + self.current_class.add_line(line) diff --git a/setup.py b/setup.py index d434a6f..b543ddb 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,11 @@ -from distutils.core import setup +from setuptools import setup +import unittest + +def pyaccuwage_tests(): + test_loader = unittest.TestLoader() + test_suite = test_loader.discover('tests', pattern='test_*.py') + return test_suite + setup(name='pyaccuwage', version='0.2018.1', packages=['pyaccuwage'], @@ -9,4 +16,5 @@ setup(name='pyaccuwage', 'scripts/pyaccuwage-genfieldfill' ], zip_safe=True, + test_suite='setup.pyaccuwage_tests', ) diff --git a/tests/test_fields.py b/tests/test_fields.py new file mode 100644 index 0000000..9293acd --- /dev/null +++ b/tests/test_fields.py @@ -0,0 +1,72 @@ +import unittest +import decimal +from pyaccuwage.fields import TextField +from pyaccuwage.fields import IntegerField +from pyaccuwage.fields import StateField +from pyaccuwage.fields import BlankField +from pyaccuwage.fields import ZeroField +from pyaccuwage.fields import MoneyField +from pyaccuwage.fields import ValidationError +from pyaccuwage.model import Model + + +class TestTextField(unittest.TestCase): + + def testStringShortOptional(self): + field = TextField(max_length=6, required=False) + field.validate() # optional + field.value = 'Hello' + field.validate() + self.assertEqual(field.get_data(), b'HELLO ') + + def testStringShortRequired(self): + field = TextField(max_length=6, required=True) + with self.assertRaises(ValidationError): + field.validate() + field.value = 'Hello' + field.validate() + self.assertEqual(field.get_data(), b'HELLO ') + + def testStringLongOptional(self): + field = TextField(max_length=6, required=False) + field.value = 'Hello, World!' # too long + self.assertEqual(len(field.get_data()), field.max_length) + + +class TestModelOutput(unittest.TestCase): + class TestModel(Model): + record_length = 128 + record_identifier = 'TEST' # 4 bytes + field1 = TextField(max_length=16) + field2 = IntegerField(max_length=16) + blank1 = BlankField(max_length=16) + zero1 = ZeroField(max_length=16) + money = MoneyField(max_length=32) + state_txt = StateField() + state_num = StateField(use_numeric=True) + blank2 = BlankField(max_length=24) + + def setUp(self): + self.model = TestModelOutput.TestModel() + + def testModelOutput(self): + model = self.model + model.field1.value = 'Hello, sir!' + model.field2.value = 12345 + model.money.value = decimal.Decimal('1234.56') + model.state_txt.value = 'IA' + model.state_num.value = 'IA' + + expected = b''.join([ + b'TEST', + b'HELLO, SIR!'.ljust(16), + b'12345'.zfill(16), + b' ' * 16, + b'0' * 16, + b'123456'.zfill(32), + b'IA', + b'19', + b' ' * 24, + ]) + + self.assertEqual(model.output(), expected)