hopefully fix python 2 and 3 compatability
This commit is contained in:
parent
6381f8b1ec
commit
d08f1ca586
5 changed files with 169 additions and 89 deletions
|
@ -115,22 +115,22 @@ class Field(object):
|
|||
|
||||
class TextField(Field):
|
||||
def validate(self):
|
||||
if self.value == None and self.required:
|
||||
if self.value is None and self.required:
|
||||
raise ValidationError("value required", field=self)
|
||||
if len(self.get_data()) > self.max_length:
|
||||
raise ValidationError("value is too long", field=self)
|
||||
|
||||
def get_data(self):
|
||||
value = self.value or ""
|
||||
value = str(self.value).encode('ascii') or b''
|
||||
if self.uppercase:
|
||||
value = value.upper()
|
||||
return value.ljust(self.max_length).encode('ascii')[:self.max_length]
|
||||
return value.ljust(self.max_length)[:self.max_length]
|
||||
|
||||
def __setvalue(self, value):
|
||||
# NO NEWLINES
|
||||
try:
|
||||
value = value.replace('\n', '').replace('\r', '')
|
||||
except AttributeError as e:
|
||||
except AttributeError:
|
||||
pass
|
||||
self._value = value
|
||||
|
||||
|
@ -146,12 +146,15 @@ class StateField(TextField):
|
|||
self.use_numeric = use_numeric
|
||||
|
||||
def get_data(self):
|
||||
value = self.value or ""
|
||||
# value = str(self.value or 'XX').encode('ascii') or b''
|
||||
value = str(self.value or 'XX')
|
||||
if value.strip() and self.use_numeric:
|
||||
postcode = bytes(str(enums.state_postal_numeric[value.upper()]), 'ascii')
|
||||
postcode = enums.state_postal_numeric[value.upper()]
|
||||
postcode = str(postcode).encode('ascii')
|
||||
return postcode.zfill(self.max_length)
|
||||
else:
|
||||
return value.ljust(self.max_length).encode('ascii')[:self.max_length]
|
||||
formatted = value.encode('ascii').ljust(self.max_length)
|
||||
return formatted[:self.max_length]
|
||||
|
||||
def validate(self):
|
||||
super(StateField, self).validate()
|
||||
|
@ -160,7 +163,7 @@ class StateField(TextField):
|
|||
|
||||
def parse(self, s):
|
||||
if s.strip() and self.use_numeric:
|
||||
states = dict( [(v,k) for (k,v) in list(enums.state_postal_numeric.items())] )
|
||||
states = dict([(v, k) for (k, v) in list(enums.state_postal_numeric.items())])
|
||||
self.value = states[int(s)]
|
||||
else:
|
||||
self.value = s
|
||||
|
@ -179,9 +182,8 @@ class IntegerField(TextField):
|
|||
except ValueError:
|
||||
raise ValidationError("field contains non-numeric characters", field=self)
|
||||
|
||||
|
||||
def get_data(self):
|
||||
value = bytes(str(self.value), 'ascii') if self.value else b''
|
||||
value = str(self.value).encode('ascii') if self.value else b''
|
||||
return value.zfill(self.max_length)[:self.max_length]
|
||||
|
||||
def parse(self, s):
|
||||
|
@ -209,7 +211,7 @@ class BlankField(TextField):
|
|||
|
||||
class ZeroField(BlankField):
|
||||
def get_data(self):
|
||||
return '0' * self.max_length
|
||||
return b'0' * self.max_length
|
||||
|
||||
class CRLFField(TextField):
|
||||
def __init__(self, name=None, required=False):
|
||||
|
@ -255,7 +257,9 @@ class MoneyField(Field):
|
|||
raise ValidationError("value is too long", field=self)
|
||||
|
||||
def get_data(self):
|
||||
return str(int((self.value or 0)*100)).encode('ascii').zfill(self.max_length)[:self.max_length]
|
||||
cents = int((self.value or 0) * 100)
|
||||
formatted = str(cents).encode('ascii').zfill(self.max_length)
|
||||
return formatted[:self.max_length]
|
||||
|
||||
def parse(self, s):
|
||||
self.value = decimal.Decimal(s) * decimal.Decimal('0.01')
|
||||
|
@ -269,7 +273,7 @@ class DateField(TextField):
|
|||
|
||||
def get_data(self):
|
||||
if self._value:
|
||||
return bytes(self._value.strftime('%m%d%Y'), 'ascii')
|
||||
return self._value.strftime('%m%d%Y').encode('ascii')
|
||||
return b'0' * self.max_length
|
||||
|
||||
def parse(self, s):
|
||||
|
@ -301,7 +305,7 @@ class MonthYearField(TextField):
|
|||
|
||||
def get_data(self):
|
||||
if self._value:
|
||||
return bytes(self._value.strftime("%m%Y"), 'ascii')
|
||||
return str(self._value.strftime('%m%Y').encode('ascii'))
|
||||
return b'0' * self.max_length
|
||||
|
||||
def parse(self, s):
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
from .fields import Field, TextField, ValidationError
|
||||
import copy
|
||||
import pdb
|
||||
import collections
|
||||
|
||||
|
||||
|
@ -42,7 +41,7 @@ class Model(object):
|
|||
|
||||
def get_sorted_fields(self):
|
||||
fields = self.get_fields()
|
||||
fields.sort(key=lambda x:x.creation_counter)
|
||||
fields.sort(key=lambda x: x.creation_counter)
|
||||
return fields
|
||||
|
||||
def validate(self):
|
||||
|
@ -51,19 +50,17 @@ class Model(object):
|
|||
|
||||
try:
|
||||
custom_validator = getattr(self, 'validate_' + f.name)
|
||||
except AttributeError as e:
|
||||
except AttributeError:
|
||||
continue
|
||||
if isinstance(custom_validator, collections.Callable):
|
||||
custom_validator(f)
|
||||
custom_validator(f)
|
||||
|
||||
def output(self):
|
||||
result = b''.join([field.get_data() for field in self.get_sorted_fields()])
|
||||
|
||||
if hasattr(self, 'record_length') and len(result) != self.record_length:
|
||||
raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.record_length, len(result)))
|
||||
#result = ''.join([self.record_identifier] + [field.get_data() for field in self.get_sorted_fields()])
|
||||
#if len(result) != self.target_size:
|
||||
# raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.target_size, len(result)))
|
||||
|
||||
return result
|
||||
|
||||
def read(self, fp):
|
||||
|
@ -71,7 +68,6 @@ class Model(object):
|
|||
for field in self.get_sorted_fields()[1:]:
|
||||
field.read(fp)
|
||||
|
||||
|
||||
def toJSON(self):
|
||||
return {
|
||||
'__class__': self.__class__.__name__,
|
||||
|
@ -84,8 +80,8 @@ class Model(object):
|
|||
for f in fields:
|
||||
target = self.__dict__[f.name]
|
||||
|
||||
if (target.required != f.required or
|
||||
target.max_length != f.max_length):
|
||||
if (target.required != f.required
|
||||
or target.max_length != f.max_length):
|
||||
print("Warning: value mismatch on import")
|
||||
|
||||
target._value = f._value
|
||||
|
|
|
@ -1,86 +1,86 @@
|
|||
#!/usr/bin/env python
|
||||
import re
|
||||
|
||||
|
||||
class ClassEntryCommentSequence(object):
|
||||
re_rangecomment = re.compile('#\s+(\d+)\-?(\d*)$')
|
||||
re_rangecomment = re.compile('#\s+(\d+)\-?(\d*)$')
|
||||
|
||||
def __init__(self, classname, line):
|
||||
self.classname = classname,
|
||||
self.line = line
|
||||
self.lines = []
|
||||
def __init__(self, classname, line):
|
||||
self.classname = classname,
|
||||
self.line = line
|
||||
self.lines = []
|
||||
|
||||
def add_line(self, line):
|
||||
self.lines.append(line)
|
||||
def add_line(self, line):
|
||||
self.lines.append(line)
|
||||
|
||||
def process(self):
|
||||
i = 0
|
||||
for (line_no, line) in enumerate(self.lines):
|
||||
match = self.re_rangecomment.search(line)
|
||||
if match:
|
||||
(a, b) = match.groups()
|
||||
a = int(a)
|
||||
def process(self):
|
||||
i = 0
|
||||
for (line_no, line) in enumerate(self.lines):
|
||||
match = self.re_rangecomment.search(line)
|
||||
if match:
|
||||
(a, b) = match.groups()
|
||||
a = int(a)
|
||||
|
||||
if (i + 1) != a:
|
||||
line_number = self.line + line_no
|
||||
print(("ERROR\tline:%d\tnear:%s\texpected:%d\tsaw:%d" % (line_number, line.split(' ')[0].strip(), i+1, a)))
|
||||
if (i + 1) != a:
|
||||
line_number = self.line + line_no
|
||||
print(("ERROR\tline:%d\tnear:%s\texpected:%d\tsaw:%d" % (
|
||||
line_number, line.split(' ')[0].strip(), i+1, a)))
|
||||
|
||||
i = int(b) if b else a
|
||||
|
||||
i = int(b) if b else a
|
||||
|
||||
class ModelDefParser(object):
|
||||
re_triplequote = re.compile('"""')
|
||||
re_whitespace = re.compile("^(\s*)[^\s]+")
|
||||
re_classdef = re.compile(r"^\s*class\s(.*)\((.*)\):\s*$")
|
||||
re_triplequote = re.compile('"""')
|
||||
re_whitespace = re.compile(r"^(\s*)[^\s]+")
|
||||
re_classdef = re.compile(r"^\s*class\s(.*)\((.*)\):\s*$")
|
||||
|
||||
def __init__(self, infile, entryclass):
|
||||
self.infile = infile
|
||||
self.line = 0
|
||||
self.EntryClass = entryclass
|
||||
def __init__(self, infile, entryclass):
|
||||
self.infile = infile
|
||||
self.line = 0
|
||||
self.EntryClass = entryclass
|
||||
|
||||
def endclass(self):
|
||||
if self.current_class:
|
||||
self.current_class.process()
|
||||
self.current_class = None
|
||||
def endclass(self):
|
||||
if self.current_class:
|
||||
self.current_class.process()
|
||||
self.current_class = None
|
||||
|
||||
def beginclass(self, classname, line):
|
||||
self.current_class = self.EntryClass(classname, line)
|
||||
def beginclass(self, classname, line):
|
||||
self.current_class = self.EntryClass(classname, line)
|
||||
|
||||
def parse(self):
|
||||
infile = self.infile
|
||||
whitespace = 0
|
||||
in_block_comment = False
|
||||
self.current_class = None
|
||||
def parse(self):
|
||||
infile = self.infile
|
||||
whitespace = 0
|
||||
in_block_comment = False
|
||||
self.current_class = None
|
||||
|
||||
for line in infile:
|
||||
self.line += 1
|
||||
for line in infile:
|
||||
self.line += 1
|
||||
|
||||
if line.startswith('#'):
|
||||
continue
|
||||
if line.startswith('#'):
|
||||
continue
|
||||
|
||||
if self.re_triplequote.search(line):
|
||||
in_block_comment = not in_block_comment
|
||||
if self.re_triplequote.search(line):
|
||||
in_block_comment = not in_block_comment
|
||||
|
||||
if in_block_comment:
|
||||
continue
|
||||
if in_block_comment:
|
||||
continue
|
||||
|
||||
match_whitespace = self.re_whitespace.match(line)
|
||||
if match_whitespace:
|
||||
match_whitespace = len(match_whitespace.groups()[0])
|
||||
else:
|
||||
match_whitespace = 0
|
||||
match_whitespace = self.re_whitespace.match(line)
|
||||
if match_whitespace:
|
||||
match_whitespace = len(match_whitespace.groups()[0])
|
||||
else:
|
||||
match_whitespace = 0
|
||||
|
||||
classmatch = self.re_classdef.match(line)
|
||||
if classmatch:
|
||||
classname, subclass = classmatch.groups()
|
||||
self.beginclass(classname, self.line)
|
||||
continue
|
||||
|
||||
if match_whitespace < whitespace:
|
||||
whitespace = match_whitespace
|
||||
self.endclass()
|
||||
continue
|
||||
|
||||
if self.current_class:
|
||||
whitespace = match_whitespace
|
||||
self.current_class.add_line(line)
|
||||
classmatch = self.re_classdef.match(line)
|
||||
if classmatch:
|
||||
classname, subclass = classmatch.groups()
|
||||
self.beginclass(classname, self.line)
|
||||
continue
|
||||
|
||||
if match_whitespace < whitespace:
|
||||
whitespace = match_whitespace
|
||||
self.endclass()
|
||||
continue
|
||||
|
||||
if self.current_class:
|
||||
whitespace = match_whitespace
|
||||
self.current_class.add_line(line)
|
||||
|
|
10
setup.py
10
setup.py
|
@ -1,4 +1,11 @@
|
|||
from distutils.core import setup
|
||||
from setuptools import setup
|
||||
import unittest
|
||||
|
||||
def pyaccuwage_tests():
|
||||
test_loader = unittest.TestLoader()
|
||||
test_suite = test_loader.discover('tests', pattern='test_*.py')
|
||||
return test_suite
|
||||
|
||||
setup(name='pyaccuwage',
|
||||
version='0.2018.1',
|
||||
packages=['pyaccuwage'],
|
||||
|
@ -9,4 +16,5 @@ setup(name='pyaccuwage',
|
|||
'scripts/pyaccuwage-genfieldfill'
|
||||
],
|
||||
zip_safe=True,
|
||||
test_suite='setup.pyaccuwage_tests',
|
||||
)
|
||||
|
|
72
tests/test_fields.py
Normal file
72
tests/test_fields.py
Normal file
|
@ -0,0 +1,72 @@
|
|||
import unittest
|
||||
import decimal
|
||||
from pyaccuwage.fields import TextField
|
||||
from pyaccuwage.fields import IntegerField
|
||||
from pyaccuwage.fields import StateField
|
||||
from pyaccuwage.fields import BlankField
|
||||
from pyaccuwage.fields import ZeroField
|
||||
from pyaccuwage.fields import MoneyField
|
||||
from pyaccuwage.fields import ValidationError
|
||||
from pyaccuwage.model import Model
|
||||
|
||||
|
||||
class TestTextField(unittest.TestCase):
|
||||
|
||||
def testStringShortOptional(self):
|
||||
field = TextField(max_length=6, required=False)
|
||||
field.validate() # optional
|
||||
field.value = 'Hello'
|
||||
field.validate()
|
||||
self.assertEqual(field.get_data(), b'HELLO ')
|
||||
|
||||
def testStringShortRequired(self):
|
||||
field = TextField(max_length=6, required=True)
|
||||
with self.assertRaises(ValidationError):
|
||||
field.validate()
|
||||
field.value = 'Hello'
|
||||
field.validate()
|
||||
self.assertEqual(field.get_data(), b'HELLO ')
|
||||
|
||||
def testStringLongOptional(self):
|
||||
field = TextField(max_length=6, required=False)
|
||||
field.value = 'Hello, World!' # too long
|
||||
self.assertEqual(len(field.get_data()), field.max_length)
|
||||
|
||||
|
||||
class TestModelOutput(unittest.TestCase):
|
||||
class TestModel(Model):
|
||||
record_length = 128
|
||||
record_identifier = 'TEST' # 4 bytes
|
||||
field1 = TextField(max_length=16)
|
||||
field2 = IntegerField(max_length=16)
|
||||
blank1 = BlankField(max_length=16)
|
||||
zero1 = ZeroField(max_length=16)
|
||||
money = MoneyField(max_length=32)
|
||||
state_txt = StateField()
|
||||
state_num = StateField(use_numeric=True)
|
||||
blank2 = BlankField(max_length=24)
|
||||
|
||||
def setUp(self):
|
||||
self.model = TestModelOutput.TestModel()
|
||||
|
||||
def testModelOutput(self):
|
||||
model = self.model
|
||||
model.field1.value = 'Hello, sir!'
|
||||
model.field2.value = 12345
|
||||
model.money.value = decimal.Decimal('1234.56')
|
||||
model.state_txt.value = 'IA'
|
||||
model.state_num.value = 'IA'
|
||||
|
||||
expected = b''.join([
|
||||
b'TEST',
|
||||
b'HELLO, SIR!'.ljust(16),
|
||||
b'12345'.zfill(16),
|
||||
b' ' * 16,
|
||||
b'0' * 16,
|
||||
b'123456'.zfill(32),
|
||||
b'IA',
|
||||
b'19',
|
||||
b' ' * 24,
|
||||
])
|
||||
|
||||
self.assertEqual(model.output(), expected)
|
Loading…
Add table
Add a link
Reference in a new issue