diff --git a/pyaccuwage/__init__.py b/pyaccuwage/__init__.py index 3db24fc..e9c3878 100644 --- a/pyaccuwage/__init__.py +++ b/pyaccuwage/__init__.py @@ -66,9 +66,12 @@ def dump(fp, records, delim=None): fp.write(delim) -def dumps(records, delim=None): +def dumps(records, delim=None, skip_validation=False): import io fp = io.BytesIO() + if not skip_validation: + for record in records: + record.validate() dump(fp, records, delim=delim) fp.seek(0) return fp.read() diff --git a/pyaccuwage/fields.py b/pyaccuwage/fields.py index 2a78fd5..18aaabf 100644 --- a/pyaccuwage/fields.py +++ b/pyaccuwage/fields.py @@ -23,11 +23,12 @@ class Field(object): is_read_only = False _value = None - def __init__(self, name=None, max_length=0, required=True, uppercase=True, creation_counter=None): + def __init__(self, name=None, max_length=0, blank=False, required=True, uppercase=True, creation_counter=None): self.name = name self._value = None self._orig_value = None self.max_length = max_length + self.blank = blank self.required = required self.uppercase = uppercase self.creation_counter = creation_counter or Field.creation_counter @@ -97,9 +98,9 @@ class Field(object): wrapper.width = 100 value = wrapper.wrap(value) value = list([(" " * 9) + ('"' + x + '"') for x in value]) - value.append(" " * 10 + ('_' * 10) * (wrapper.width / 10)) - value.append(" " * 10 + ('0123456789') * (wrapper.width / 10)) - value.append(" " * 10 + ''.join(([str(x) + (' ' * 9) for x in range(wrapper.width / 10 )]))) + value.append(" " * 10 + ('_' * 10) * int(wrapper.width / 10)) + value.append(" " * 10 + ('0123456789') * int(wrapper.width / 10)) + value.append(" " * 10 + ''.join(([str(x) + (' ' * 9) for x in range(int(wrapper.width / 10))]))) start = counter['c'] counter['c'] += len(self._orig_value or self.value) @@ -121,6 +122,9 @@ class TextField(Field): raise ValidationError("value required", field=self) if len(self.get_data()) > self.max_length: raise ValidationError("value is too long", field=self) + if len(self.get_data().strip()) == 0 and (not self.blank and self.required): + print(self.name, 'blank', self.blank, self.required) + raise ValidationError("field cannot be blank", field=self) def get_data(self): value = str(self.value or '').encode('ascii') or b'' @@ -144,7 +148,7 @@ class TextField(Field): class StateField(TextField): def __init__(self, name=None, required=True, use_numeric=False, max_length=2): - super(StateField, self).__init__(name=name, max_length=2, required=required) + super(StateField, self).__init__(name=name, max_length=max_length, required=required) self.use_numeric = use_numeric def get_data(self): @@ -219,6 +223,10 @@ class BlankField(TextField): def parse(self, s): pass + def validate(self): + if len(self.get_data()) != self.max_length: + raise ValidationError("blank field did not match expected length", field=self) + class ZeroField(BlankField): is_read_only = True diff --git a/pyaccuwage/model.py b/pyaccuwage/model.py index 1d7e4ed..c950055 100644 --- a/pyaccuwage/model.py +++ b/pyaccuwage/model.py @@ -39,7 +39,11 @@ class Model(object): getattr(self, field_name).value = value def get_fields(self): - identifier = TextField("record_identifier", max_length=len(self.record_identifier), creation_counter=-1) + identifier = TextField( + "record_identifier", + max_length = len(self.record_identifier), + blank = len(self.record_identifier) == 0, + creation_counter=-1) identifier.value = self.record_identifier fields = [identifier] diff --git a/setup.py b/setup.py index 078072e..c3d830a 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ def pyaccuwage_tests(): return test_suite setup(name='pyaccuwage', - version='0.2024.0', + version='0.2024.1', packages=['pyaccuwage'], scripts=[ 'scripts/pyaccuwage-checkseq', diff --git a/tests/test_records.py b/tests/test_records.py index 166b6d5..67ce3ce 100644 --- a/tests/test_records.py +++ b/tests/test_records.py @@ -8,6 +8,7 @@ from pyaccuwage.fields import StateField from pyaccuwage.fields import TextField from pyaccuwage.fields import ZeroField from pyaccuwage.fields import StaticField +from pyaccuwage.fields import ValidationError from pyaccuwage.model import Model class TestModelOutput(unittest.TestCase): @@ -90,7 +91,7 @@ class TestFileFormats(unittest.TestCase): record_identifier = 'B' # 1 byte zero1 = ZeroField(max_length=32) text1 = TextField(max_length=71) - text2 = TextField(max_length=20) + text2 = TextField(max_length=20, required=False) blank2 = BlankField(max_length=4) record_types = [TestModelA, TestModelB] @@ -130,3 +131,49 @@ class TestFileFormats(unittest.TestCase): original_bytes = pyaccuwage.dumps(records) reloaded_bytes = pyaccuwage.dumps(records_loaded) self.assertEqual(original_bytes, reloaded_bytes) + + +class TestRequiredFields(unittest.TestCase): + def createTestRecord(self, required=False, blank=False): + class Record(pyaccuwage.model.Model): + record_length = 16 + record_identifier = '' + test_field = TextField(max_length=16, required=required, blank=blank) + record = Record() + def dump(): + return pyaccuwage.dumps([record]) + return (record, dump) + + def testRequiredBlankField(self): + (record, dump) = self.createTestRecord(required=True, blank=True) + record.test_field.value # if nothing is ever assigned, raise error + self.assertRaises(ValidationError, dump) + record.test_field.value = '' # value may be empty string + dump() + + def testRequiredNonblankField(self): + (record, dump) = self.createTestRecord(required=True, blank=False) + record.test_field.value # if nothing is ever assigned, raise error + self.assertRaises(ValidationError, dump) + record.test_field.value = '' # value must not be empty string + self.assertRaises(ValidationError, dump) + record.test_field.value = 'hello' + dump() + + def testOptionalBlankField(self): + (record, dump) = self.createTestRecord(required=False, blank=True) + record.test_field.value # OK if nothing is ever assigned + dump() + record.test_field.value = '' # OK if empty string is assigned + dump() + record.test_field.value = 'hello' + dump() + + def testOptionalNonBlankField(self): + (record, dump) = self.createTestRecord(required=False, blank=False) + record.test_field.value # OK if nothing is ever assigned + dump() + record.test_field.value = '' # OK if empty string is assigned + dump() + record.test_field.value = 'hello' + dump()