hopefully fix python 2 and 3 compatability

This commit is contained in:
Mark Riedesel 2019-01-27 09:30:22 -06:00
parent 6381f8b1ec
commit d08f1ca586
5 changed files with 169 additions and 89 deletions

View file

@ -115,22 +115,22 @@ class Field(object):
class TextField(Field):
def validate(self):
if self.value == None and self.required:
if self.value is None and self.required:
raise ValidationError("value required", field=self)
if len(self.get_data()) > self.max_length:
raise ValidationError("value is too long", field=self)
def get_data(self):
value = self.value or ""
value = str(self.value).encode('ascii') or b''
if self.uppercase:
value = value.upper()
return value.ljust(self.max_length).encode('ascii')[:self.max_length]
return value.ljust(self.max_length)[:self.max_length]
def __setvalue(self, value):
# NO NEWLINES
try:
value = value.replace('\n', '').replace('\r', '')
except AttributeError as e:
except AttributeError:
pass
self._value = value
@ -146,12 +146,15 @@ class StateField(TextField):
self.use_numeric = use_numeric
def get_data(self):
value = self.value or ""
# value = str(self.value or 'XX').encode('ascii') or b''
value = str(self.value or 'XX')
if value.strip() and self.use_numeric:
postcode = bytes(str(enums.state_postal_numeric[value.upper()]), 'ascii')
postcode = enums.state_postal_numeric[value.upper()]
postcode = str(postcode).encode('ascii')
return postcode.zfill(self.max_length)
else:
return value.ljust(self.max_length).encode('ascii')[:self.max_length]
formatted = value.encode('ascii').ljust(self.max_length)
return formatted[:self.max_length]
def validate(self):
super(StateField, self).validate()
@ -160,7 +163,7 @@ class StateField(TextField):
def parse(self, s):
if s.strip() and self.use_numeric:
states = dict( [(v,k) for (k,v) in list(enums.state_postal_numeric.items())] )
states = dict([(v, k) for (k, v) in list(enums.state_postal_numeric.items())])
self.value = states[int(s)]
else:
self.value = s
@ -179,9 +182,8 @@ class IntegerField(TextField):
except ValueError:
raise ValidationError("field contains non-numeric characters", field=self)
def get_data(self):
value = bytes(str(self.value), 'ascii') if self.value else b''
value = str(self.value).encode('ascii') if self.value else b''
return value.zfill(self.max_length)[:self.max_length]
def parse(self, s):
@ -209,7 +211,7 @@ class BlankField(TextField):
class ZeroField(BlankField):
def get_data(self):
return '0' * self.max_length
return b'0' * self.max_length
class CRLFField(TextField):
def __init__(self, name=None, required=False):
@ -255,7 +257,9 @@ class MoneyField(Field):
raise ValidationError("value is too long", field=self)
def get_data(self):
return str(int((self.value or 0)*100)).encode('ascii').zfill(self.max_length)[:self.max_length]
cents = int((self.value or 0) * 100)
formatted = str(cents).encode('ascii').zfill(self.max_length)
return formatted[:self.max_length]
def parse(self, s):
self.value = decimal.Decimal(s) * decimal.Decimal('0.01')
@ -269,7 +273,7 @@ class DateField(TextField):
def get_data(self):
if self._value:
return bytes(self._value.strftime('%m%d%Y'), 'ascii')
return self._value.strftime('%m%d%Y').encode('ascii')
return b'0' * self.max_length
def parse(self, s):
@ -301,7 +305,7 @@ class MonthYearField(TextField):
def get_data(self):
if self._value:
return bytes(self._value.strftime("%m%Y"), 'ascii')
return str(self._value.strftime('%m%Y').encode('ascii'))
return b'0' * self.max_length
def parse(self, s):

View file

@ -1,6 +1,5 @@
from .fields import Field, TextField, ValidationError
import copy
import pdb
import collections
@ -42,7 +41,7 @@ class Model(object):
def get_sorted_fields(self):
fields = self.get_fields()
fields.sort(key=lambda x:x.creation_counter)
fields.sort(key=lambda x: x.creation_counter)
return fields
def validate(self):
@ -51,19 +50,17 @@ class Model(object):
try:
custom_validator = getattr(self, 'validate_' + f.name)
except AttributeError as e:
except AttributeError:
continue
if isinstance(custom_validator, collections.Callable):
custom_validator(f)
custom_validator(f)
def output(self):
result = b''.join([field.get_data() for field in self.get_sorted_fields()])
if hasattr(self, 'record_length') and len(result) != self.record_length:
raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.record_length, len(result)))
#result = ''.join([self.record_identifier] + [field.get_data() for field in self.get_sorted_fields()])
#if len(result) != self.target_size:
# raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.target_size, len(result)))
return result
def read(self, fp):
@ -71,7 +68,6 @@ class Model(object):
for field in self.get_sorted_fields()[1:]:
field.read(fp)
def toJSON(self):
return {
'__class__': self.__class__.__name__,
@ -84,8 +80,8 @@ class Model(object):
for f in fields:
target = self.__dict__[f.name]
if (target.required != f.required or
target.max_length != f.max_length):
if (target.required != f.required
or target.max_length != f.max_length):
print("Warning: value mismatch on import")
target._value = f._value

View file

@ -1,86 +1,86 @@
#!/usr/bin/env python
import re
class ClassEntryCommentSequence(object):
re_rangecomment = re.compile('#\s+(\d+)\-?(\d*)$')
re_rangecomment = re.compile('#\s+(\d+)\-?(\d*)$')
def __init__(self, classname, line):
self.classname = classname,
self.line = line
self.lines = []
def __init__(self, classname, line):
self.classname = classname,
self.line = line
self.lines = []
def add_line(self, line):
self.lines.append(line)
def add_line(self, line):
self.lines.append(line)
def process(self):
i = 0
for (line_no, line) in enumerate(self.lines):
match = self.re_rangecomment.search(line)
if match:
(a, b) = match.groups()
a = int(a)
def process(self):
i = 0
for (line_no, line) in enumerate(self.lines):
match = self.re_rangecomment.search(line)
if match:
(a, b) = match.groups()
a = int(a)
if (i + 1) != a:
line_number = self.line + line_no
print(("ERROR\tline:%d\tnear:%s\texpected:%d\tsaw:%d" % (line_number, line.split(' ')[0].strip(), i+1, a)))
if (i + 1) != a:
line_number = self.line + line_no
print(("ERROR\tline:%d\tnear:%s\texpected:%d\tsaw:%d" % (
line_number, line.split(' ')[0].strip(), i+1, a)))
i = int(b) if b else a
i = int(b) if b else a
class ModelDefParser(object):
re_triplequote = re.compile('"""')
re_whitespace = re.compile("^(\s*)[^\s]+")
re_classdef = re.compile(r"^\s*class\s(.*)\((.*)\):\s*$")
re_triplequote = re.compile('"""')
re_whitespace = re.compile(r"^(\s*)[^\s]+")
re_classdef = re.compile(r"^\s*class\s(.*)\((.*)\):\s*$")
def __init__(self, infile, entryclass):
self.infile = infile
self.line = 0
self.EntryClass = entryclass
def __init__(self, infile, entryclass):
self.infile = infile
self.line = 0
self.EntryClass = entryclass
def endclass(self):
if self.current_class:
self.current_class.process()
self.current_class = None
def endclass(self):
if self.current_class:
self.current_class.process()
self.current_class = None
def beginclass(self, classname, line):
self.current_class = self.EntryClass(classname, line)
def beginclass(self, classname, line):
self.current_class = self.EntryClass(classname, line)
def parse(self):
infile = self.infile
whitespace = 0
in_block_comment = False
self.current_class = None
def parse(self):
infile = self.infile
whitespace = 0
in_block_comment = False
self.current_class = None
for line in infile:
self.line += 1
for line in infile:
self.line += 1
if line.startswith('#'):
continue
if line.startswith('#'):
continue
if self.re_triplequote.search(line):
in_block_comment = not in_block_comment
if self.re_triplequote.search(line):
in_block_comment = not in_block_comment
if in_block_comment:
continue
if in_block_comment:
continue
match_whitespace = self.re_whitespace.match(line)
if match_whitespace:
match_whitespace = len(match_whitespace.groups()[0])
else:
match_whitespace = 0
match_whitespace = self.re_whitespace.match(line)
if match_whitespace:
match_whitespace = len(match_whitespace.groups()[0])
else:
match_whitespace = 0
classmatch = self.re_classdef.match(line)
if classmatch:
classname, subclass = classmatch.groups()
self.beginclass(classname, self.line)
continue
if match_whitespace < whitespace:
whitespace = match_whitespace
self.endclass()
continue
if self.current_class:
whitespace = match_whitespace
self.current_class.add_line(line)
classmatch = self.re_classdef.match(line)
if classmatch:
classname, subclass = classmatch.groups()
self.beginclass(classname, self.line)
continue
if match_whitespace < whitespace:
whitespace = match_whitespace
self.endclass()
continue
if self.current_class:
whitespace = match_whitespace
self.current_class.add_line(line)

View file

@ -1,4 +1,11 @@
from distutils.core import setup
from setuptools import setup
import unittest
def pyaccuwage_tests():
test_loader = unittest.TestLoader()
test_suite = test_loader.discover('tests', pattern='test_*.py')
return test_suite
setup(name='pyaccuwage',
version='0.2018.1',
packages=['pyaccuwage'],
@ -9,4 +16,5 @@ setup(name='pyaccuwage',
'scripts/pyaccuwage-genfieldfill'
],
zip_safe=True,
test_suite='setup.pyaccuwage_tests',
)

72
tests/test_fields.py Normal file
View file

@ -0,0 +1,72 @@
import unittest
import decimal
from pyaccuwage.fields import TextField
from pyaccuwage.fields import IntegerField
from pyaccuwage.fields import StateField
from pyaccuwage.fields import BlankField
from pyaccuwage.fields import ZeroField
from pyaccuwage.fields import MoneyField
from pyaccuwage.fields import ValidationError
from pyaccuwage.model import Model
class TestTextField(unittest.TestCase):
def testStringShortOptional(self):
field = TextField(max_length=6, required=False)
field.validate() # optional
field.value = 'Hello'
field.validate()
self.assertEqual(field.get_data(), b'HELLO ')
def testStringShortRequired(self):
field = TextField(max_length=6, required=True)
with self.assertRaises(ValidationError):
field.validate()
field.value = 'Hello'
field.validate()
self.assertEqual(field.get_data(), b'HELLO ')
def testStringLongOptional(self):
field = TextField(max_length=6, required=False)
field.value = 'Hello, World!' # too long
self.assertEqual(len(field.get_data()), field.max_length)
class TestModelOutput(unittest.TestCase):
class TestModel(Model):
record_length = 128
record_identifier = 'TEST' # 4 bytes
field1 = TextField(max_length=16)
field2 = IntegerField(max_length=16)
blank1 = BlankField(max_length=16)
zero1 = ZeroField(max_length=16)
money = MoneyField(max_length=32)
state_txt = StateField()
state_num = StateField(use_numeric=True)
blank2 = BlankField(max_length=24)
def setUp(self):
self.model = TestModelOutput.TestModel()
def testModelOutput(self):
model = self.model
model.field1.value = 'Hello, sir!'
model.field2.value = 12345
model.money.value = decimal.Decimal('1234.56')
model.state_txt.value = 'IA'
model.state_num.value = 'IA'
expected = b''.join([
b'TEST',
b'HELLO, SIR!'.ljust(16),
b'12345'.zfill(16),
b' ' * 16,
b'0' * 16,
b'123456'.zfill(32),
b'IA',
b'19',
b' ' * 24,
])
self.assertEqual(model.output(), expected)