hopefully fix python 2 and 3 compatability

This commit is contained in:
Mark Riedesel 2019-01-27 09:30:22 -06:00
parent 6381f8b1ec
commit d08f1ca586
5 changed files with 169 additions and 89 deletions

View file

@ -115,22 +115,22 @@ class Field(object):
class TextField(Field): class TextField(Field):
def validate(self): def validate(self):
if self.value == None and self.required: if self.value is None and self.required:
raise ValidationError("value required", field=self) raise ValidationError("value required", field=self)
if len(self.get_data()) > self.max_length: if len(self.get_data()) > self.max_length:
raise ValidationError("value is too long", field=self) raise ValidationError("value is too long", field=self)
def get_data(self): def get_data(self):
value = self.value or "" value = str(self.value).encode('ascii') or b''
if self.uppercase: if self.uppercase:
value = value.upper() value = value.upper()
return value.ljust(self.max_length).encode('ascii')[:self.max_length] return value.ljust(self.max_length)[:self.max_length]
def __setvalue(self, value): def __setvalue(self, value):
# NO NEWLINES # NO NEWLINES
try: try:
value = value.replace('\n', '').replace('\r', '') value = value.replace('\n', '').replace('\r', '')
except AttributeError as e: except AttributeError:
pass pass
self._value = value self._value = value
@ -146,12 +146,15 @@ class StateField(TextField):
self.use_numeric = use_numeric self.use_numeric = use_numeric
def get_data(self): def get_data(self):
value = self.value or "" # value = str(self.value or 'XX').encode('ascii') or b''
value = str(self.value or 'XX')
if value.strip() and self.use_numeric: if value.strip() and self.use_numeric:
postcode = bytes(str(enums.state_postal_numeric[value.upper()]), 'ascii') postcode = enums.state_postal_numeric[value.upper()]
postcode = str(postcode).encode('ascii')
return postcode.zfill(self.max_length) return postcode.zfill(self.max_length)
else: else:
return value.ljust(self.max_length).encode('ascii')[:self.max_length] formatted = value.encode('ascii').ljust(self.max_length)
return formatted[:self.max_length]
def validate(self): def validate(self):
super(StateField, self).validate() super(StateField, self).validate()
@ -160,7 +163,7 @@ class StateField(TextField):
def parse(self, s): def parse(self, s):
if s.strip() and self.use_numeric: if s.strip() and self.use_numeric:
states = dict( [(v,k) for (k,v) in list(enums.state_postal_numeric.items())] ) states = dict([(v, k) for (k, v) in list(enums.state_postal_numeric.items())])
self.value = states[int(s)] self.value = states[int(s)]
else: else:
self.value = s self.value = s
@ -179,9 +182,8 @@ class IntegerField(TextField):
except ValueError: except ValueError:
raise ValidationError("field contains non-numeric characters", field=self) raise ValidationError("field contains non-numeric characters", field=self)
def get_data(self): def get_data(self):
value = bytes(str(self.value), 'ascii') if self.value else b'' value = str(self.value).encode('ascii') if self.value else b''
return value.zfill(self.max_length)[:self.max_length] return value.zfill(self.max_length)[:self.max_length]
def parse(self, s): def parse(self, s):
@ -209,7 +211,7 @@ class BlankField(TextField):
class ZeroField(BlankField): class ZeroField(BlankField):
def get_data(self): def get_data(self):
return '0' * self.max_length return b'0' * self.max_length
class CRLFField(TextField): class CRLFField(TextField):
def __init__(self, name=None, required=False): def __init__(self, name=None, required=False):
@ -255,7 +257,9 @@ class MoneyField(Field):
raise ValidationError("value is too long", field=self) raise ValidationError("value is too long", field=self)
def get_data(self): def get_data(self):
return str(int((self.value or 0)*100)).encode('ascii').zfill(self.max_length)[:self.max_length] cents = int((self.value or 0) * 100)
formatted = str(cents).encode('ascii').zfill(self.max_length)
return formatted[:self.max_length]
def parse(self, s): def parse(self, s):
self.value = decimal.Decimal(s) * decimal.Decimal('0.01') self.value = decimal.Decimal(s) * decimal.Decimal('0.01')
@ -269,7 +273,7 @@ class DateField(TextField):
def get_data(self): def get_data(self):
if self._value: if self._value:
return bytes(self._value.strftime('%m%d%Y'), 'ascii') return self._value.strftime('%m%d%Y').encode('ascii')
return b'0' * self.max_length return b'0' * self.max_length
def parse(self, s): def parse(self, s):
@ -301,7 +305,7 @@ class MonthYearField(TextField):
def get_data(self): def get_data(self):
if self._value: if self._value:
return bytes(self._value.strftime("%m%Y"), 'ascii') return str(self._value.strftime('%m%Y').encode('ascii'))
return b'0' * self.max_length return b'0' * self.max_length
def parse(self, s): def parse(self, s):

View file

@ -1,6 +1,5 @@
from .fields import Field, TextField, ValidationError from .fields import Field, TextField, ValidationError
import copy import copy
import pdb
import collections import collections
@ -42,7 +41,7 @@ class Model(object):
def get_sorted_fields(self): def get_sorted_fields(self):
fields = self.get_fields() fields = self.get_fields()
fields.sort(key=lambda x:x.creation_counter) fields.sort(key=lambda x: x.creation_counter)
return fields return fields
def validate(self): def validate(self):
@ -51,19 +50,17 @@ class Model(object):
try: try:
custom_validator = getattr(self, 'validate_' + f.name) custom_validator = getattr(self, 'validate_' + f.name)
except AttributeError as e: except AttributeError:
continue continue
if isinstance(custom_validator, collections.Callable): if isinstance(custom_validator, collections.Callable):
custom_validator(f) custom_validator(f)
def output(self): def output(self):
result = b''.join([field.get_data() for field in self.get_sorted_fields()]) result = b''.join([field.get_data() for field in self.get_sorted_fields()])
if hasattr(self, 'record_length') and len(result) != self.record_length: if hasattr(self, 'record_length') and len(result) != self.record_length:
raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.record_length, len(result))) raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.record_length, len(result)))
#result = ''.join([self.record_identifier] + [field.get_data() for field in self.get_sorted_fields()])
#if len(result) != self.target_size:
# raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.target_size, len(result)))
return result return result
def read(self, fp): def read(self, fp):
@ -71,7 +68,6 @@ class Model(object):
for field in self.get_sorted_fields()[1:]: for field in self.get_sorted_fields()[1:]:
field.read(fp) field.read(fp)
def toJSON(self): def toJSON(self):
return { return {
'__class__': self.__class__.__name__, '__class__': self.__class__.__name__,
@ -84,8 +80,8 @@ class Model(object):
for f in fields: for f in fields:
target = self.__dict__[f.name] target = self.__dict__[f.name]
if (target.required != f.required or if (target.required != f.required
target.max_length != f.max_length): or target.max_length != f.max_length):
print("Warning: value mismatch on import") print("Warning: value mismatch on import")
target._value = f._value target._value = f._value

View file

@ -1,86 +1,86 @@
#!/usr/bin/env python
import re import re
class ClassEntryCommentSequence(object): class ClassEntryCommentSequence(object):
re_rangecomment = re.compile('#\s+(\d+)\-?(\d*)$') re_rangecomment = re.compile('#\s+(\d+)\-?(\d*)$')
def __init__(self, classname, line): def __init__(self, classname, line):
self.classname = classname, self.classname = classname,
self.line = line self.line = line
self.lines = [] self.lines = []
def add_line(self, line): def add_line(self, line):
self.lines.append(line) self.lines.append(line)
def process(self): def process(self):
i = 0 i = 0
for (line_no, line) in enumerate(self.lines): for (line_no, line) in enumerate(self.lines):
match = self.re_rangecomment.search(line) match = self.re_rangecomment.search(line)
if match: if match:
(a, b) = match.groups() (a, b) = match.groups()
a = int(a) a = int(a)
if (i + 1) != a: if (i + 1) != a:
line_number = self.line + line_no line_number = self.line + line_no
print(("ERROR\tline:%d\tnear:%s\texpected:%d\tsaw:%d" % (line_number, line.split(' ')[0].strip(), i+1, a))) print(("ERROR\tline:%d\tnear:%s\texpected:%d\tsaw:%d" % (
line_number, line.split(' ')[0].strip(), i+1, a)))
i = int(b) if b else a
i = int(b) if b else a
class ModelDefParser(object): class ModelDefParser(object):
re_triplequote = re.compile('"""') re_triplequote = re.compile('"""')
re_whitespace = re.compile("^(\s*)[^\s]+") re_whitespace = re.compile(r"^(\s*)[^\s]+")
re_classdef = re.compile(r"^\s*class\s(.*)\((.*)\):\s*$") re_classdef = re.compile(r"^\s*class\s(.*)\((.*)\):\s*$")
def __init__(self, infile, entryclass): def __init__(self, infile, entryclass):
self.infile = infile self.infile = infile
self.line = 0 self.line = 0
self.EntryClass = entryclass self.EntryClass = entryclass
def endclass(self): def endclass(self):
if self.current_class: if self.current_class:
self.current_class.process() self.current_class.process()
self.current_class = None self.current_class = None
def beginclass(self, classname, line): def beginclass(self, classname, line):
self.current_class = self.EntryClass(classname, line) self.current_class = self.EntryClass(classname, line)
def parse(self): def parse(self):
infile = self.infile infile = self.infile
whitespace = 0 whitespace = 0
in_block_comment = False in_block_comment = False
self.current_class = None self.current_class = None
for line in infile: for line in infile:
self.line += 1 self.line += 1
if line.startswith('#'): if line.startswith('#'):
continue continue
if self.re_triplequote.search(line): if self.re_triplequote.search(line):
in_block_comment = not in_block_comment in_block_comment = not in_block_comment
if in_block_comment: if in_block_comment:
continue continue
match_whitespace = self.re_whitespace.match(line) match_whitespace = self.re_whitespace.match(line)
if match_whitespace: if match_whitespace:
match_whitespace = len(match_whitespace.groups()[0]) match_whitespace = len(match_whitespace.groups()[0])
else: else:
match_whitespace = 0 match_whitespace = 0
classmatch = self.re_classdef.match(line) classmatch = self.re_classdef.match(line)
if classmatch: if classmatch:
classname, subclass = classmatch.groups() classname, subclass = classmatch.groups()
self.beginclass(classname, self.line) self.beginclass(classname, self.line)
continue continue
if match_whitespace < whitespace:
whitespace = match_whitespace
self.endclass()
continue
if self.current_class:
whitespace = match_whitespace
self.current_class.add_line(line)
if match_whitespace < whitespace:
whitespace = match_whitespace
self.endclass()
continue
if self.current_class:
whitespace = match_whitespace
self.current_class.add_line(line)

View file

@ -1,4 +1,11 @@
from distutils.core import setup from setuptools import setup
import unittest
def pyaccuwage_tests():
test_loader = unittest.TestLoader()
test_suite = test_loader.discover('tests', pattern='test_*.py')
return test_suite
setup(name='pyaccuwage', setup(name='pyaccuwage',
version='0.2018.1', version='0.2018.1',
packages=['pyaccuwage'], packages=['pyaccuwage'],
@ -9,4 +16,5 @@ setup(name='pyaccuwage',
'scripts/pyaccuwage-genfieldfill' 'scripts/pyaccuwage-genfieldfill'
], ],
zip_safe=True, zip_safe=True,
test_suite='setup.pyaccuwage_tests',
) )

72
tests/test_fields.py Normal file
View file

@ -0,0 +1,72 @@
import unittest
import decimal
from pyaccuwage.fields import TextField
from pyaccuwage.fields import IntegerField
from pyaccuwage.fields import StateField
from pyaccuwage.fields import BlankField
from pyaccuwage.fields import ZeroField
from pyaccuwage.fields import MoneyField
from pyaccuwage.fields import ValidationError
from pyaccuwage.model import Model
class TestTextField(unittest.TestCase):
def testStringShortOptional(self):
field = TextField(max_length=6, required=False)
field.validate() # optional
field.value = 'Hello'
field.validate()
self.assertEqual(field.get_data(), b'HELLO ')
def testStringShortRequired(self):
field = TextField(max_length=6, required=True)
with self.assertRaises(ValidationError):
field.validate()
field.value = 'Hello'
field.validate()
self.assertEqual(field.get_data(), b'HELLO ')
def testStringLongOptional(self):
field = TextField(max_length=6, required=False)
field.value = 'Hello, World!' # too long
self.assertEqual(len(field.get_data()), field.max_length)
class TestModelOutput(unittest.TestCase):
class TestModel(Model):
record_length = 128
record_identifier = 'TEST' # 4 bytes
field1 = TextField(max_length=16)
field2 = IntegerField(max_length=16)
blank1 = BlankField(max_length=16)
zero1 = ZeroField(max_length=16)
money = MoneyField(max_length=32)
state_txt = StateField()
state_num = StateField(use_numeric=True)
blank2 = BlankField(max_length=24)
def setUp(self):
self.model = TestModelOutput.TestModel()
def testModelOutput(self):
model = self.model
model.field1.value = 'Hello, sir!'
model.field2.value = 12345
model.money.value = decimal.Decimal('1234.56')
model.state_txt.value = 'IA'
model.state_num.value = 'IA'
expected = b''.join([
b'TEST',
b'HELLO, SIR!'.ljust(16),
b'12345'.zfill(16),
b' ' * 16,
b'0' * 16,
b'123456'.zfill(32),
b'IA',
b'19',
b' ' * 24,
])
self.assertEqual(model.output(), expected)