use BytesIO to work with python3
This commit is contained in:
parent
16bf2c41d0
commit
9320c68961
3 changed files with 14 additions and 13 deletions
|
@ -37,7 +37,7 @@ def test_dump():
|
||||||
record.EmployerRecord(),
|
record.EmployerRecord(),
|
||||||
record.EmployeeWageRecord(),
|
record.EmployeeWageRecord(),
|
||||||
]
|
]
|
||||||
out = io.StringIO()
|
out = io.BytesIO()
|
||||||
dump(records, out)
|
dump(records, out)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
@ -75,7 +75,7 @@ def load(fp):
|
||||||
|
|
||||||
def loads(s):
|
def loads(s):
|
||||||
import io
|
import io
|
||||||
fp = io.StringIO(s)
|
fp = io.BytesIO(s)
|
||||||
return load(fp)
|
return load(fp)
|
||||||
|
|
||||||
|
|
||||||
|
@ -85,7 +85,7 @@ def dump(records, fp):
|
||||||
|
|
||||||
def dumps(records):
|
def dumps(records):
|
||||||
import io
|
import io
|
||||||
fp = io.StringIO()
|
fp = io.BytesIO()
|
||||||
dump(records, fp)
|
dump(records, fp)
|
||||||
fp.seek(0)
|
fp.seek(0)
|
||||||
return fp.read()
|
return fp.read()
|
||||||
|
|
|
@ -148,7 +148,8 @@ class StateField(TextField):
|
||||||
def get_data(self):
|
def get_data(self):
|
||||||
value = self.value or ""
|
value = self.value or ""
|
||||||
if value.strip() and self.use_numeric:
|
if value.strip() and self.use_numeric:
|
||||||
return str(enums.state_postal_numeric[value.upper()]).zfill(self.max_length)
|
postcode = bytes(str(enums.state_postal_numeric[value.upper()]), 'ascii')
|
||||||
|
return postcode.zfill(self.max_length)
|
||||||
else:
|
else:
|
||||||
return value.ljust(self.max_length).encode('ascii')[:self.max_length]
|
return value.ljust(self.max_length).encode('ascii')[:self.max_length]
|
||||||
|
|
||||||
|
@ -180,8 +181,8 @@ class IntegerField(TextField):
|
||||||
|
|
||||||
|
|
||||||
def get_data(self):
|
def get_data(self):
|
||||||
value = self.value or ""
|
value = bytes(str(self.value), 'ascii') if self.value else b''
|
||||||
return str(value).zfill(self.max_length)[:self.max_length]
|
return value.zfill(self.max_length)[:self.max_length]
|
||||||
|
|
||||||
def parse(self, s):
|
def parse(self, s):
|
||||||
self.value = int(s)
|
self.value = int(s)
|
||||||
|
@ -201,7 +202,7 @@ class BlankField(TextField):
|
||||||
super(TextField, self).__init__(name=name, max_length=max_length, required=required, uppercase=False)
|
super(TextField, self).__init__(name=name, max_length=max_length, required=required, uppercase=False)
|
||||||
|
|
||||||
def get_data(self):
|
def get_data(self):
|
||||||
return " " * self.max_length
|
return b' ' * self.max_length
|
||||||
|
|
||||||
def parse(self, s):
|
def parse(self, s):
|
||||||
pass
|
pass
|
||||||
|
@ -233,7 +234,7 @@ class BooleanField(Field):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_data(self):
|
def get_data(self):
|
||||||
return '1' if self._value else '0'
|
return b'1' if self._value else b'0'
|
||||||
|
|
||||||
def parse(self, s):
|
def parse(self, s):
|
||||||
self.value = (s == '1')
|
self.value = (s == '1')
|
||||||
|
@ -264,8 +265,8 @@ class DateField(TextField):
|
||||||
|
|
||||||
def get_data(self):
|
def get_data(self):
|
||||||
if self._value:
|
if self._value:
|
||||||
return self._value.strftime('%m%d%Y')
|
return bytes(self._value.strftime('%m%d%Y'), 'ascii')
|
||||||
return '0' * self.max_length
|
return b'0' * self.max_length
|
||||||
|
|
||||||
def parse(self, s):
|
def parse(self, s):
|
||||||
if int(s) > 0:
|
if int(s) > 0:
|
||||||
|
@ -296,8 +297,8 @@ class MonthYearField(TextField):
|
||||||
|
|
||||||
def get_data(self):
|
def get_data(self):
|
||||||
if self._value:
|
if self._value:
|
||||||
return self._value.strftime("%m%Y")
|
return bytes(self._value.strftime("%m%Y"), 'ascii')
|
||||||
return '0' * self.max_length
|
return b'0' * self.max_length
|
||||||
|
|
||||||
def parse(self, s):
|
def parse(self, s):
|
||||||
if int(s) > 0:
|
if int(s) > 0:
|
||||||
|
|
|
@ -57,7 +57,7 @@ class Model(object):
|
||||||
custom_validator(f)
|
custom_validator(f)
|
||||||
|
|
||||||
def output(self):
|
def output(self):
|
||||||
result = ''.join([field.get_data() for field in self.get_sorted_fields()])
|
result = b''.join([field.get_data() for field in self.get_sorted_fields()])
|
||||||
|
|
||||||
if hasattr(self, 'record_length') and len(result) != self.record_length:
|
if hasattr(self, 'record_length') and len(result) != self.record_length:
|
||||||
raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.record_length, len(result)))
|
raise ValidationError("Record result length not equal to %d bytes (%d)" % (self.record_length, len(result)))
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue