use BytesIO to work with python3
This commit is contained in:
parent
16bf2c41d0
commit
9320c68961
3 changed files with 14 additions and 13 deletions
|
@ -37,7 +37,7 @@ def test_dump():
|
|||
record.EmployerRecord(),
|
||||
record.EmployeeWageRecord(),
|
||||
]
|
||||
out = io.StringIO()
|
||||
out = io.BytesIO()
|
||||
dump(records, out)
|
||||
return out
|
||||
|
||||
|
@ -75,7 +75,7 @@ def load(fp):
|
|||
|
||||
def loads(s):
|
||||
import io
|
||||
fp = io.StringIO(s)
|
||||
fp = io.BytesIO(s)
|
||||
return load(fp)
|
||||
|
||||
|
||||
|
@ -85,7 +85,7 @@ def dump(records, fp):
|
|||
|
||||
def dumps(records):
|
||||
import io
|
||||
fp = io.StringIO()
|
||||
fp = io.BytesIO()
|
||||
dump(records, fp)
|
||||
fp.seek(0)
|
||||
return fp.read()
|
||||
|
|
|
@ -148,7 +148,8 @@ class StateField(TextField):
|
|||
def get_data(self):
|
||||
value = self.value or ""
|
||||
if value.strip() and self.use_numeric:
|
||||
return str(enums.state_postal_numeric[value.upper()]).zfill(self.max_length)
|
||||
postcode = bytes(str(enums.state_postal_numeric[value.upper()]), 'ascii')
|
||||
return postcode.zfill(self.max_length)
|
||||
else:
|
||||
return value.ljust(self.max_length).encode('ascii')[:self.max_length]
|
||||
|
||||
|
@ -180,8 +181,8 @@ class IntegerField(TextField):
|
|||
|
||||
|
||||
def get_data(self):
|
||||
value = self.value or ""
|
||||
return str(value).zfill(self.max_length)[:self.max_length]
|
||||
value = bytes(str(self.value), 'ascii') if self.value else b''
|
||||
return value.zfill(self.max_length)[:self.max_length]
|
||||
|
||||
def parse(self, s):
|
||||
self.value = int(s)
|
||||
|
@ -201,7 +202,7 @@ class BlankField(TextField):
|
|||
super(TextField, self).__init__(name=name, max_length=max_length, required=required, uppercase=False)
|
||||
|
||||
def get_data(self):
|
||||
return " " * self.max_length
|
||||
return b' ' * self.max_length
|
||||
|
||||
def parse(self, s):
|
||||
pass
|
||||
|
@ -233,7 +234,7 @@ class BooleanField(Field):
|
|||
pass
|
||||
|
||||
def get_data(self):
|
||||
return '1' if self._value else '0'
|
||||
return b'1' if self._value else b'0'
|
||||
|
||||
def parse(self, s):
|
||||
self.value = (s == '1')
|
||||
|
@ -264,8 +265,8 @@ class DateField(TextField):
|
|||
|
||||
def get_data(self):
|
||||
if self._value:
|
||||
return self._value.strftime('%m%d%Y')
|
||||
return '0' * self.max_length
|
||||
return bytes(self._value.strftime('%m%d%Y'), 'ascii')
|
||||
return b'0' * self.max_length
|
||||
|
||||
def parse(self, s):
|
||||
if int(s) > 0:
|
||||
|
@ -296,8 +297,8 @@ class MonthYearField(TextField):
|
|||
|
||||
def get_data(self):
|
||||
if self._value:
|
||||
return self._value.strftime("%m%Y")
|
||||
return '0' * self.max_length
|
||||
return bytes(self._value.strftime("%m%Y"), 'ascii')
|
||||
return b'0' * self.max_length
|
||||
|
||||
def parse(self, s):
|
||||
if int(s) > 0:
|
||||
|
|
|
@ -57,7 +57,7 @@ class Model(object):
|
|||
custom_validator(f)
|
||||
|
||||
def output(self):
|
||||
result = ''.join([field.get_data() for field in self.get_sorted_fields()])
|
||||
result = b''.join([field.get_data() for field in self.get_sorted_fields()])
|
||||
|
||||
if hasattr(self, 'record_length') and len(result) != self.record_length:
|
||||
raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.record_length, len(result)))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue