Skip to content
This repository was archived by the owner on Dec 10, 2018. It is now read-only.

Commit

Permalink
minor code cleanups for protocol.compact
Browse files Browse the repository at this point in the history
  • Loading branch information
lxyu committed Mar 24, 2016
1 parent b5c9649 commit fc1f91b
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 65 deletions.
119 changes: 55 additions & 64 deletions thriftpy/protocol/compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

from thriftpy._compat import PY3


CLEAR = 0
FIELD_WRITE = 1
VALUE_WRITE = 2
Expand Down Expand Up @@ -125,15 +124,15 @@ class TCompactProtocol(object):

def __init__(self, trans):
self.trans = trans
self.__last_fid = 0
self.__bool_fid = None
self.__bool_value = None
self.__structs = []
self._last_fid = 0
self._bool_fid = None
self._bool_value = None
self._structs = []

def __getTType(self, byte):
def _get_ttype(self, byte):
return TTYPES[byte & 0x0f]

def __read_size(self):
def _read_size(self):
result = read_varint(self.trans)
if result < 0:
raise TException("Length < 0")
Expand All @@ -155,55 +154,56 @@ def read_message_begin(self):
% (version, self.VERSION))
seqid = read_varint(self.trans)
name = self.read_string()
return (name, type, seqid)
return name, type, seqid

def read_message_end(self):
assert len(self.__structs) == 0
assert len(self._structs) == 0

def read_field_begin(self):
type = self.read_ubyte()
if type & 0x0f == TType.STOP:
return (None, 0, 0)
return None, 0, 0

delta = type >> 4
if delta == 0:
fid = from_zig_zag(read_varint(self.trans))
else:
fid = self.__last_fid + delta
self.__last_fid = fid
fid = self._last_fid + delta
self._last_fid = fid

type = type & 0x0f
if type == CompactType.TRUE:
self.__bool_value = True
self._bool_value = True
elif type == CompactType.FALSE:
self.__bool_value = False
else:
pass
return (None, self.__getTType(type), fid)
self._bool_value = False

return None, self._get_ttype(type), fid

def read_field_end(self):
pass

def read_struct_begin(self):
self.__structs.append(self.__last_fid)
self.__last_fid = 0
self._structs.append(self._last_fid)
self._last_fid = 0

def read_struct_end(self):
self.__last_fid = self.__structs.pop()
self._last_fid = self._structs.pop()

def read_map_begin(self):
size = self.__read_size()
size = self._read_size()
types = 0
if size > 0:
types = self.read_ubyte()
vtype = self.__getTType(types)
ktype = self.__getTType(types >> 4)
vtype = self._get_ttype(types)
ktype = self._get_ttype(types >> 4)
return (ktype, vtype, size)

def read_collection_begin(self):
size_type = self.read_ubyte()
size = size_type >> 4
type = self.__getTType(size_type)
type = self._get_ttype(size_type)
if size == 15:
size = self.__read_size()
size = self._read_size()
return type, size

def read_collection_end(self):
Expand All @@ -226,7 +226,7 @@ def read_double(self):
return val

def read_string(self):
len = self.__read_size()
len = self._read_size()

byte_payload = self.trans.read(len)
try:
Expand All @@ -235,16 +235,16 @@ def read_string(self):
return byte_payload

def read_bool(self):
if self.__bool_value is not None:
result = self.__bool_value
self.__bool_value = None
if self._bool_value is not None:
result = self._bool_value
self._bool_value = None
return result
return self.read_byte() == CompactType.TRUE

def read_struct(self, obj):
self.read_struct_begin()
while True:
(fname, ftype, fid) = self.read_field_begin()
fname, ftype, fid = self.read_field_begin()
if ftype == TType.STOP:
break

Expand Down Expand Up @@ -275,7 +275,7 @@ def read_val(self, ttype, spec=None):
elif ttype == TType.BYTE:
return self.read_byte()

elif ttype == TType.I16 or ttype == TType.I32 or ttype == TType.I64:
elif ttype in (TType.I16, TType.I32, TType.I64):
return self.read_int()

elif ttype == TType.DOUBLE:
Expand All @@ -284,7 +284,7 @@ def read_val(self, ttype, spec=None):
elif ttype == TType.STRING:
return self.read_string()

elif ttype == TType.LIST or ttype == TType.SET:
elif ttype in (TType.LIST, TType.SET):
if isinstance(spec, tuple):
v_type, v_spec = spec[0], spec[1]
else:
Expand Down Expand Up @@ -332,17 +332,17 @@ def read_val(self, ttype, spec=None):
self.read_struct(obj)
return obj

def __write_size(self, i32):
def _write_size(self, i32):
write_varint(self.trans, i32)

def __write_field_header(self, type, fid):
delta = fid - self.__last_fid
def _write_field_header(self, type, fid):
delta = fid - self._last_fid
if 0 < delta <= 15:
self.write_ubyte(delta << 4 | type)
else:
self.write_byte(type)
self.write_i16(fid)
self.__last_fid = fid
self._last_fid = fid

def write_message_begin(self, name, type, seqid):
self.write_ubyte(self.PROTOCOL_ID)
Expand All @@ -358,32 +358,32 @@ def write_field_stop(self):

def write_field_begin(self, name, type, fid):
if type == TType.BOOL:
self.__bool_fid = fid
self._bool_fid = fid
else:
self.__write_field_header(CTYPES[type], fid)
self._write_field_header(CTYPES[type], fid)

def write_field_end(self):
pass

def write_struct_begin(self, name):
self.__structs.append(self.__last_fid)
self.__last_fid = 0
def write_struct_begin(self):
self._structs.append(self._last_fid)
self._last_fid = 0

def write_struct_end(self):
self.__last_fid = self.__structs.pop()
self._last_fid = self._structs.pop()

def write_collection_begin(self, etype, size):
if size <= 14:
self.write_ubyte(size << 4 | CTYPES[etype])
else:
self.write_ubyte(0xf0 | CTYPES[etype])
self.__write_size(size)
self._write_size(size)

def write_map_begin(self, ktype, vtype, size):
if size == 0:
self.write_byte(0)
else:
self.__write_size(size)
self._write_size(size)
self.write_ubyte(CTYPES[ktype] << 4 | CTYPES[vtype])

def write_collection_end(self):
Expand All @@ -396,13 +396,13 @@ def write_byte(self, byte):
self.trans.write(pack('!b', byte))

def write_bool(self, bool):
if self.__bool_fid and self.__bool_fid > self.__last_fid \
and self.__bool_fid - self.__last_fid <= 15:
if self._bool_fid and self._bool_fid > self._last_fid \
and self._bool_fid - self._last_fid <= 15:
if bool:
ctype = CompactType.TRUE
else:
ctype = CompactType.FALSE
self.__write_field_header(ctype, self.__bool_fid)
self._write_field_header(ctype, self._bool_fid)
else:
if bool:
self.write_byte(CompactType.TRUE)
Expand All @@ -422,18 +422,15 @@ def write_double(self, dub):
self.trans.write(pack('<d', dub))

def write_string(self, s):
if PY3:
self.__write_size(len(bytearray(s, 'utf-8')))
else:
self.__write_size(len(s))
if not isinstance(s, bytes):
s = s.encode('utf-8')
self._write_size(len(s))
self.trans.write(s)

def write_struct(self, obj):
self.write_struct_begin(obj.__class__.__name__)
self.write_struct_begin()

for field in iter(obj.thrift_spec):
for field in obj.thrift_spec:
if field is None:
continue
fspec = obj.thrift_spec[field]
Expand Down Expand Up @@ -519,13 +516,7 @@ def skip(self, ttype):
elif ttype == TType.BYTE:
self.read_byte()

elif ttype == TType.I16:
from_zig_zag(read_varint(self.trans))

elif ttype == TType.I32:
from_zig_zag(read_varint(self.trans))

elif ttype == TType.I64:
elif ttype in (TType.I16, TType.I32, TType.I64):
from_zig_zag(read_varint(self.trans))

elif ttype == TType.DOUBLE:
Expand All @@ -545,25 +536,25 @@ def skip(self, ttype):
self.read_struct_end()

elif ttype == TType.MAP:
(ktype, vtype, size) = self.read_map_begin()
ktype, vtype, size = self.read_map_begin()
for i in range(size):
self.skip(ktype)
self.skip(vtype)
self.read_collection_end()

elif ttype == TType.SET:
(etype, size) = self.read_collection_begin()
etype, size = self.read_collection_begin()
for i in range(size):
self.skip(etype)
self.read_collection_end()

elif ttype == TType.LIST:
(etype, size) = self.read_collection_begin()
etype, size = self.read_collection_begin()
for i in range(size):
self.skip(etype)
self.read_collection_end()


class TCompactProtocolFactory:
class TCompactProtocolFactory(object):
def get_protocol(self, trans):
return TCompactProtocol(trans)
4 changes: 3 additions & 1 deletion thriftpy/thrift.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ def args2kwargs(thrift_spec, *args):

def parse_spec(ttype, spec=None):
name_map = TType._VALUES_TO_NAMES
_type = lambda s: parse_spec(*s) if isinstance(s, tuple) else name_map[s]

def _type(s):
return parse_spec(*s) if isinstance(s, tuple) else name_map[s]

if spec is None:
return name_map[ttype]
Expand Down

0 comments on commit fc1f91b

Please sign in to comment.