hopefully fix python 2 and 3 compatability
This commit is contained in:
parent
6381f8b1ec
commit
d08f1ca586
5 changed files with 169 additions and 89 deletions
|
@ -115,22 +115,22 @@ class Field(object):
|
||||||
|
|
||||||
class TextField(Field):
|
class TextField(Field):
|
||||||
def validate(self):
|
def validate(self):
|
||||||
if self.value == None and self.required:
|
if self.value is None and self.required:
|
||||||
raise ValidationError("value required", field=self)
|
raise ValidationError("value required", field=self)
|
||||||
if len(self.get_data()) > self.max_length:
|
if len(self.get_data()) > self.max_length:
|
||||||
raise ValidationError("value is too long", field=self)
|
raise ValidationError("value is too long", field=self)
|
||||||
|
|
||||||
def get_data(self):
|
def get_data(self):
|
||||||
value = self.value or ""
|
value = str(self.value).encode('ascii') or b''
|
||||||
if self.uppercase:
|
if self.uppercase:
|
||||||
value = value.upper()
|
value = value.upper()
|
||||||
return value.ljust(self.max_length).encode('ascii')[:self.max_length]
|
return value.ljust(self.max_length)[:self.max_length]
|
||||||
|
|
||||||
def __setvalue(self, value):
|
def __setvalue(self, value):
|
||||||
# NO NEWLINES
|
# NO NEWLINES
|
||||||
try:
|
try:
|
||||||
value = value.replace('\n', '').replace('\r', '')
|
value = value.replace('\n', '').replace('\r', '')
|
||||||
except AttributeError as e:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
self._value = value
|
self._value = value
|
||||||
|
|
||||||
|
@ -146,12 +146,15 @@ class StateField(TextField):
|
||||||
self.use_numeric = use_numeric
|
self.use_numeric = use_numeric
|
||||||
|
|
||||||
def get_data(self):
|
def get_data(self):
|
||||||
value = self.value or ""
|
# value = str(self.value or 'XX').encode('ascii') or b''
|
||||||
|
value = str(self.value or 'XX')
|
||||||
if value.strip() and self.use_numeric:
|
if value.strip() and self.use_numeric:
|
||||||
postcode = bytes(str(enums.state_postal_numeric[value.upper()]), 'ascii')
|
postcode = enums.state_postal_numeric[value.upper()]
|
||||||
|
postcode = str(postcode).encode('ascii')
|
||||||
return postcode.zfill(self.max_length)
|
return postcode.zfill(self.max_length)
|
||||||
else:
|
else:
|
||||||
return value.ljust(self.max_length).encode('ascii')[:self.max_length]
|
formatted = value.encode('ascii').ljust(self.max_length)
|
||||||
|
return formatted[:self.max_length]
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
super(StateField, self).validate()
|
super(StateField, self).validate()
|
||||||
|
@ -160,7 +163,7 @@ class StateField(TextField):
|
||||||
|
|
||||||
def parse(self, s):
|
def parse(self, s):
|
||||||
if s.strip() and self.use_numeric:
|
if s.strip() and self.use_numeric:
|
||||||
states = dict( [(v,k) for (k,v) in list(enums.state_postal_numeric.items())] )
|
states = dict([(v, k) for (k, v) in list(enums.state_postal_numeric.items())])
|
||||||
self.value = states[int(s)]
|
self.value = states[int(s)]
|
||||||
else:
|
else:
|
||||||
self.value = s
|
self.value = s
|
||||||
|
@ -179,9 +182,8 @@ class IntegerField(TextField):
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise ValidationError("field contains non-numeric characters", field=self)
|
raise ValidationError("field contains non-numeric characters", field=self)
|
||||||
|
|
||||||
|
|
||||||
def get_data(self):
|
def get_data(self):
|
||||||
value = bytes(str(self.value), 'ascii') if self.value else b''
|
value = str(self.value).encode('ascii') if self.value else b''
|
||||||
return value.zfill(self.max_length)[:self.max_length]
|
return value.zfill(self.max_length)[:self.max_length]
|
||||||
|
|
||||||
def parse(self, s):
|
def parse(self, s):
|
||||||
|
@ -209,7 +211,7 @@ class BlankField(TextField):
|
||||||
|
|
||||||
class ZeroField(BlankField):
|
class ZeroField(BlankField):
|
||||||
def get_data(self):
|
def get_data(self):
|
||||||
return '0' * self.max_length
|
return b'0' * self.max_length
|
||||||
|
|
||||||
class CRLFField(TextField):
|
class CRLFField(TextField):
|
||||||
def __init__(self, name=None, required=False):
|
def __init__(self, name=None, required=False):
|
||||||
|
@ -255,7 +257,9 @@ class MoneyField(Field):
|
||||||
raise ValidationError("value is too long", field=self)
|
raise ValidationError("value is too long", field=self)
|
||||||
|
|
||||||
def get_data(self):
|
def get_data(self):
|
||||||
return str(int((self.value or 0)*100)).encode('ascii').zfill(self.max_length)[:self.max_length]
|
cents = int((self.value or 0) * 100)
|
||||||
|
formatted = str(cents).encode('ascii').zfill(self.max_length)
|
||||||
|
return formatted[:self.max_length]
|
||||||
|
|
||||||
def parse(self, s):
|
def parse(self, s):
|
||||||
self.value = decimal.Decimal(s) * decimal.Decimal('0.01')
|
self.value = decimal.Decimal(s) * decimal.Decimal('0.01')
|
||||||
|
@ -269,7 +273,7 @@ class DateField(TextField):
|
||||||
|
|
||||||
def get_data(self):
|
def get_data(self):
|
||||||
if self._value:
|
if self._value:
|
||||||
return bytes(self._value.strftime('%m%d%Y'), 'ascii')
|
return self._value.strftime('%m%d%Y').encode('ascii')
|
||||||
return b'0' * self.max_length
|
return b'0' * self.max_length
|
||||||
|
|
||||||
def parse(self, s):
|
def parse(self, s):
|
||||||
|
@ -301,7 +305,7 @@ class MonthYearField(TextField):
|
||||||
|
|
||||||
def get_data(self):
|
def get_data(self):
|
||||||
if self._value:
|
if self._value:
|
||||||
return bytes(self._value.strftime("%m%Y"), 'ascii')
|
return str(self._value.strftime('%m%Y').encode('ascii'))
|
||||||
return b'0' * self.max_length
|
return b'0' * self.max_length
|
||||||
|
|
||||||
def parse(self, s):
|
def parse(self, s):
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
from .fields import Field, TextField, ValidationError
|
from .fields import Field, TextField, ValidationError
|
||||||
import copy
|
import copy
|
||||||
import pdb
|
|
||||||
import collections
|
import collections
|
||||||
|
|
||||||
|
|
||||||
|
@ -42,7 +41,7 @@ class Model(object):
|
||||||
|
|
||||||
def get_sorted_fields(self):
|
def get_sorted_fields(self):
|
||||||
fields = self.get_fields()
|
fields = self.get_fields()
|
||||||
fields.sort(key=lambda x:x.creation_counter)
|
fields.sort(key=lambda x: x.creation_counter)
|
||||||
return fields
|
return fields
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
|
@ -51,7 +50,7 @@ class Model(object):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
custom_validator = getattr(self, 'validate_' + f.name)
|
custom_validator = getattr(self, 'validate_' + f.name)
|
||||||
except AttributeError as e:
|
except AttributeError:
|
||||||
continue
|
continue
|
||||||
if isinstance(custom_validator, collections.Callable):
|
if isinstance(custom_validator, collections.Callable):
|
||||||
custom_validator(f)
|
custom_validator(f)
|
||||||
|
@ -61,9 +60,7 @@ class Model(object):
|
||||||
|
|
||||||
if hasattr(self, 'record_length') and len(result) != self.record_length:
|
if hasattr(self, 'record_length') and len(result) != self.record_length:
|
||||||
raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.record_length, len(result)))
|
raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.record_length, len(result)))
|
||||||
#result = ''.join([self.record_identifier] + [field.get_data() for field in self.get_sorted_fields()])
|
|
||||||
#if len(result) != self.target_size:
|
|
||||||
# raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.target_size, len(result)))
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def read(self, fp):
|
def read(self, fp):
|
||||||
|
@ -71,7 +68,6 @@ class Model(object):
|
||||||
for field in self.get_sorted_fields()[1:]:
|
for field in self.get_sorted_fields()[1:]:
|
||||||
field.read(fp)
|
field.read(fp)
|
||||||
|
|
||||||
|
|
||||||
def toJSON(self):
|
def toJSON(self):
|
||||||
return {
|
return {
|
||||||
'__class__': self.__class__.__name__,
|
'__class__': self.__class__.__name__,
|
||||||
|
@ -84,8 +80,8 @@ class Model(object):
|
||||||
for f in fields:
|
for f in fields:
|
||||||
target = self.__dict__[f.name]
|
target = self.__dict__[f.name]
|
||||||
|
|
||||||
if (target.required != f.required or
|
if (target.required != f.required
|
||||||
target.max_length != f.max_length):
|
or target.max_length != f.max_length):
|
||||||
print("Warning: value mismatch on import")
|
print("Warning: value mismatch on import")
|
||||||
|
|
||||||
target._value = f._value
|
target._value = f._value
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
#!/usr/bin/env python
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
|
||||||
class ClassEntryCommentSequence(object):
|
class ClassEntryCommentSequence(object):
|
||||||
re_rangecomment = re.compile('#\s+(\d+)\-?(\d*)$')
|
re_rangecomment = re.compile('#\s+(\d+)\-?(\d*)$')
|
||||||
|
|
||||||
|
@ -22,13 +22,15 @@ class ClassEntryCommentSequence(object):
|
||||||
|
|
||||||
if (i + 1) != a:
|
if (i + 1) != a:
|
||||||
line_number = self.line + line_no
|
line_number = self.line + line_no
|
||||||
print(("ERROR\tline:%d\tnear:%s\texpected:%d\tsaw:%d" % (line_number, line.split(' ')[0].strip(), i+1, a)))
|
print(("ERROR\tline:%d\tnear:%s\texpected:%d\tsaw:%d" % (
|
||||||
|
line_number, line.split(' ')[0].strip(), i+1, a)))
|
||||||
|
|
||||||
i = int(b) if b else a
|
i = int(b) if b else a
|
||||||
|
|
||||||
|
|
||||||
class ModelDefParser(object):
|
class ModelDefParser(object):
|
||||||
re_triplequote = re.compile('"""')
|
re_triplequote = re.compile('"""')
|
||||||
re_whitespace = re.compile("^(\s*)[^\s]+")
|
re_whitespace = re.compile(r"^(\s*)[^\s]+")
|
||||||
re_classdef = re.compile(r"^\s*class\s(.*)\((.*)\):\s*$")
|
re_classdef = re.compile(r"^\s*class\s(.*)\((.*)\):\s*$")
|
||||||
|
|
||||||
def __init__(self, infile, entryclass):
|
def __init__(self, infile, entryclass):
|
||||||
|
@ -82,5 +84,3 @@ class ModelDefParser(object):
|
||||||
if self.current_class:
|
if self.current_class:
|
||||||
whitespace = match_whitespace
|
whitespace = match_whitespace
|
||||||
self.current_class.add_line(line)
|
self.current_class.add_line(line)
|
||||||
|
|
||||||
|
|
||||||
|
|
10
setup.py
10
setup.py
|
@ -1,4 +1,11 @@
|
||||||
from distutils.core import setup
|
from setuptools import setup
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
def pyaccuwage_tests():
|
||||||
|
test_loader = unittest.TestLoader()
|
||||||
|
test_suite = test_loader.discover('tests', pattern='test_*.py')
|
||||||
|
return test_suite
|
||||||
|
|
||||||
setup(name='pyaccuwage',
|
setup(name='pyaccuwage',
|
||||||
version='0.2018.1',
|
version='0.2018.1',
|
||||||
packages=['pyaccuwage'],
|
packages=['pyaccuwage'],
|
||||||
|
@ -9,4 +16,5 @@ setup(name='pyaccuwage',
|
||||||
'scripts/pyaccuwage-genfieldfill'
|
'scripts/pyaccuwage-genfieldfill'
|
||||||
],
|
],
|
||||||
zip_safe=True,
|
zip_safe=True,
|
||||||
|
test_suite='setup.pyaccuwage_tests',
|
||||||
)
|
)
|
||||||
|
|
72
tests/test_fields.py
Normal file
72
tests/test_fields.py
Normal file
|
@ -0,0 +1,72 @@
|
||||||
|
import unittest
|
||||||
|
import decimal
|
||||||
|
from pyaccuwage.fields import TextField
|
||||||
|
from pyaccuwage.fields import IntegerField
|
||||||
|
from pyaccuwage.fields import StateField
|
||||||
|
from pyaccuwage.fields import BlankField
|
||||||
|
from pyaccuwage.fields import ZeroField
|
||||||
|
from pyaccuwage.fields import MoneyField
|
||||||
|
from pyaccuwage.fields import ValidationError
|
||||||
|
from pyaccuwage.model import Model
|
||||||
|
|
||||||
|
|
||||||
|
class TestTextField(unittest.TestCase):
|
||||||
|
|
||||||
|
def testStringShortOptional(self):
|
||||||
|
field = TextField(max_length=6, required=False)
|
||||||
|
field.validate() # optional
|
||||||
|
field.value = 'Hello'
|
||||||
|
field.validate()
|
||||||
|
self.assertEqual(field.get_data(), b'HELLO ')
|
||||||
|
|
||||||
|
def testStringShortRequired(self):
|
||||||
|
field = TextField(max_length=6, required=True)
|
||||||
|
with self.assertRaises(ValidationError):
|
||||||
|
field.validate()
|
||||||
|
field.value = 'Hello'
|
||||||
|
field.validate()
|
||||||
|
self.assertEqual(field.get_data(), b'HELLO ')
|
||||||
|
|
||||||
|
def testStringLongOptional(self):
|
||||||
|
field = TextField(max_length=6, required=False)
|
||||||
|
field.value = 'Hello, World!' # too long
|
||||||
|
self.assertEqual(len(field.get_data()), field.max_length)
|
||||||
|
|
||||||
|
|
||||||
|
class TestModelOutput(unittest.TestCase):
|
||||||
|
class TestModel(Model):
|
||||||
|
record_length = 128
|
||||||
|
record_identifier = 'TEST' # 4 bytes
|
||||||
|
field1 = TextField(max_length=16)
|
||||||
|
field2 = IntegerField(max_length=16)
|
||||||
|
blank1 = BlankField(max_length=16)
|
||||||
|
zero1 = ZeroField(max_length=16)
|
||||||
|
money = MoneyField(max_length=32)
|
||||||
|
state_txt = StateField()
|
||||||
|
state_num = StateField(use_numeric=True)
|
||||||
|
blank2 = BlankField(max_length=24)
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.model = TestModelOutput.TestModel()
|
||||||
|
|
||||||
|
def testModelOutput(self):
|
||||||
|
model = self.model
|
||||||
|
model.field1.value = 'Hello, sir!'
|
||||||
|
model.field2.value = 12345
|
||||||
|
model.money.value = decimal.Decimal('1234.56')
|
||||||
|
model.state_txt.value = 'IA'
|
||||||
|
model.state_num.value = 'IA'
|
||||||
|
|
||||||
|
expected = b''.join([
|
||||||
|
b'TEST',
|
||||||
|
b'HELLO, SIR!'.ljust(16),
|
||||||
|
b'12345'.zfill(16),
|
||||||
|
b' ' * 16,
|
||||||
|
b'0' * 16,
|
||||||
|
b'123456'.zfill(32),
|
||||||
|
b'IA',
|
||||||
|
b'19',
|
||||||
|
b' ' * 24,
|
||||||
|
])
|
||||||
|
|
||||||
|
self.assertEqual(model.output(), expected)
|
Loading…
Add table
Add a link
Reference in a new issue