use BytesIO to work with python3

This commit is contained in:
Mark Riedesel 2017-01-07 14:52:33 -06:00
parent 16bf2c41d0
commit 9320c68961
3 changed files with 14 additions and 13 deletions

View file

@ -37,7 +37,7 @@ def test_dump():
record.EmployerRecord(), record.EmployerRecord(),
record.EmployeeWageRecord(), record.EmployeeWageRecord(),
] ]
out = io.StringIO() out = io.BytesIO()
dump(records, out) dump(records, out)
return out return out
@ -75,7 +75,7 @@ def load(fp):
def loads(s): def loads(s):
import io import io
fp = io.StringIO(s) fp = io.BytesIO(s)
return load(fp) return load(fp)
@ -85,7 +85,7 @@ def dump(records, fp):
def dumps(records): def dumps(records):
import io import io
fp = io.StringIO() fp = io.BytesIO()
dump(records, fp) dump(records, fp)
fp.seek(0) fp.seek(0)
return fp.read() return fp.read()

View file

@ -148,7 +148,8 @@ class StateField(TextField):
def get_data(self): def get_data(self):
value = self.value or "" value = self.value or ""
if value.strip() and self.use_numeric: if value.strip() and self.use_numeric:
return str(enums.state_postal_numeric[value.upper()]).zfill(self.max_length) postcode = bytes(str(enums.state_postal_numeric[value.upper()]), 'ascii')
return postcode.zfill(self.max_length)
else: else:
return value.ljust(self.max_length).encode('ascii')[:self.max_length] return value.ljust(self.max_length).encode('ascii')[:self.max_length]
@ -180,8 +181,8 @@ class IntegerField(TextField):
def get_data(self): def get_data(self):
value = self.value or "" value = bytes(str(self.value), 'ascii') if self.value else b''
return str(value).zfill(self.max_length)[:self.max_length] return value.zfill(self.max_length)[:self.max_length]
def parse(self, s): def parse(self, s):
self.value = int(s) self.value = int(s)
@ -201,7 +202,7 @@ class BlankField(TextField):
super(TextField, self).__init__(name=name, max_length=max_length, required=required, uppercase=False) super(TextField, self).__init__(name=name, max_length=max_length, required=required, uppercase=False)
def get_data(self): def get_data(self):
return " " * self.max_length return b' ' * self.max_length
def parse(self, s): def parse(self, s):
pass pass
@ -233,7 +234,7 @@ class BooleanField(Field):
pass pass
def get_data(self): def get_data(self):
return '1' if self._value else '0' return b'1' if self._value else b'0'
def parse(self, s): def parse(self, s):
self.value = (s == '1') self.value = (s == '1')
@ -264,8 +265,8 @@ class DateField(TextField):
def get_data(self): def get_data(self):
if self._value: if self._value:
return self._value.strftime('%m%d%Y') return bytes(self._value.strftime('%m%d%Y'), 'ascii')
return '0' * self.max_length return b'0' * self.max_length
def parse(self, s): def parse(self, s):
if int(s) > 0: if int(s) > 0:
@ -296,8 +297,8 @@ class MonthYearField(TextField):
def get_data(self): def get_data(self):
if self._value: if self._value:
return self._value.strftime("%m%Y") return bytes(self._value.strftime("%m%Y"), 'ascii')
return '0' * self.max_length return b'0' * self.max_length
def parse(self, s): def parse(self, s):
if int(s) > 0: if int(s) > 0:

View file

@ -57,7 +57,7 @@ class Model(object):
custom_validator(f) custom_validator(f)
def output(self): def output(self):
result = ''.join([field.get_data() for field in self.get_sorted_fields()]) result = b''.join([field.get_data() for field in self.get_sorted_fields()])
if hasattr(self, 'record_length') and len(result) != self.record_length: if hasattr(self, 'record_length') and len(result) != self.record_length:
raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.record_length, len(result))) raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.record_length, len(result)))