use BytesIO to work with python3

This commit is contained in:
Mark Riedesel 2017-01-07 14:52:33 -06:00
parent 16bf2c41d0
commit 9320c68961
3 changed files with 14 additions and 13 deletions

View file

@ -37,7 +37,7 @@ def test_dump():
record.EmployerRecord(),
record.EmployeeWageRecord(),
]
out = io.StringIO()
out = io.BytesIO()
dump(records, out)
return out
@ -75,7 +75,7 @@ def load(fp):
def loads(s):
import io
fp = io.StringIO(s)
fp = io.BytesIO(s)
return load(fp)
@ -85,7 +85,7 @@ def dump(records, fp):
def dumps(records):
import io
fp = io.StringIO()
fp = io.BytesIO()
dump(records, fp)
fp.seek(0)
return fp.read()

View file

@ -148,7 +148,8 @@ class StateField(TextField):
def get_data(self):
value = self.value or ""
if value.strip() and self.use_numeric:
return str(enums.state_postal_numeric[value.upper()]).zfill(self.max_length)
postcode = bytes(str(enums.state_postal_numeric[value.upper()]), 'ascii')
return postcode.zfill(self.max_length)
else:
return value.ljust(self.max_length).encode('ascii')[:self.max_length]
@ -180,8 +181,8 @@ class IntegerField(TextField):
def get_data(self):
value = self.value or ""
return str(value).zfill(self.max_length)[:self.max_length]
value = bytes(str(self.value), 'ascii') if self.value else b''
return value.zfill(self.max_length)[:self.max_length]
def parse(self, s):
self.value = int(s)
@ -201,7 +202,7 @@ class BlankField(TextField):
super(TextField, self).__init__(name=name, max_length=max_length, required=required, uppercase=False)
def get_data(self):
return " " * self.max_length
return b' ' * self.max_length
def parse(self, s):
pass
@ -233,7 +234,7 @@ class BooleanField(Field):
pass
def get_data(self):
return '1' if self._value else '0'
return b'1' if self._value else b'0'
def parse(self, s):
self.value = (s == '1')
@ -264,8 +265,8 @@ class DateField(TextField):
def get_data(self):
if self._value:
return self._value.strftime('%m%d%Y')
return '0' * self.max_length
return bytes(self._value.strftime('%m%d%Y'), 'ascii')
return b'0' * self.max_length
def parse(self, s):
if int(s) > 0:
@ -296,8 +297,8 @@ class MonthYearField(TextField):
def get_data(self):
if self._value:
return self._value.strftime("%m%Y")
return '0' * self.max_length
return bytes(self._value.strftime("%m%Y"), 'ascii')
return b'0' * self.max_length
def parse(self, s):
if int(s) > 0:

View file

@ -57,7 +57,7 @@ class Model(object):
custom_validator(f)
def output(self):
result = ''.join([field.get_data() for field in self.get_sorted_fields()])
result = b''.join([field.get_data() for field in self.get_sorted_fields()])
if hasattr(self, 'record_length') and len(result) != self.record_length:
raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.record_length, len(result)))