From 00c65539035953a79fafff90b1dbeb0f591e0fb2 Mon Sep 17 00:00:00 2001
From: Lixin Yu
Date: Thu, 24 Mar 2016 13:40:52 +0800
Subject: [PATCH] add decode_response option for protocol, fix #190

TBinaryProtocol, TCompactProtocol, TCyBinaryProtocol and their factories
accept a new decode_response argument (default True).  When it is False,
STRING/BINARY values are returned as raw bytes and no UTF-8 decoding is
attempted.  With the default, all three protocols still try UTF-8 first
and fall back to the raw bytes when the payload is not valid UTF-8.

---
 tests/test_protocol_binary.py     |  7 +++++
 tests/test_protocol_compact.py    |  8 +++++
 tests/test_protocol_cybinary.py   |  7 +++++
 thriftpy/protocol/binary.py       | 36 +++++++++++++++--------
 thriftpy/protocol/compact.py      | 20 ++++++++-----
 thriftpy/protocol/cybin/cybin.pyx | 50 +++++++++++++++++++++-----------
 6 files changed, 92 insertions(+), 36 deletions(-)

diff --git a/tests/test_protocol_binary.py b/tests/test_protocol_binary.py
index 81cdcb1..beb03cc 100644
--- a/tests/test_protocol_binary.py
+++ b/tests/test_protocol_binary.py
@@ -89,6 +89,13 @@ def test_unpack_string():
     assert u("你好世界") == proto.read_val(b, TType.STRING)
 
 
+def test_unpack_binary():
+    bs = BytesIO(b"\x00\x00\x00\x0c"
+                 b"\xe4\xbd\xa0\xe5\xa5\xbd\xe4\xb8\x96\xe7\x95\x8c")
+    assert u("你好世界").encode("utf-8") == proto.read_val(
+        bs, TType.STRING, decode_response=False)
+
+
 def test_write_message_begin():
     b = BytesIO()
     proto.TBinaryProtocol(b).write_message_begin("test", TType.STRING, 1)
diff --git a/tests/test_protocol_compact.py b/tests/test_protocol_compact.py
index 6ea62ea..21ae4a0 100644
--- a/tests/test_protocol_compact.py
+++ b/tests/test_protocol_compact.py
@@ -107,6 +107,14 @@ def test_unpack_string():
     assert u('你好世界') == proto.read_val(TType.STRING)
 
 
+def test_unpack_binary():
+    b, proto = gen_proto(b'\x0c\xe4\xbd\xa0\xe5\xa5'
+                         b'\xbd\xe4\xb8\x96\xe7\x95\x8c')
+    proto.decode_response = False
+
+    assert u('你好世界').encode("utf-8") == proto.read_val(TType.STRING)
+
+
 def test_pack_bool():
     b, proto = gen_proto()
     proto.write_bool(True)
diff --git a/tests/test_protocol_cybinary.py b/tests/test_protocol_cybinary.py
index 7d6da11..3f5a2ac 100644
--- a/tests/test_protocol_cybinary.py
+++ b/tests/test_protocol_cybinary.py
@@ -140,6 +140,13 @@ def test_read_string():
     assert u("你好世界") == proto.read_val(b, TType.STRING)
 
 
+def test_read_binary():
+    b = TCyMemoryBuffer(b"\x00\x00\x00\x0c"
+                        b"\xe4\xbd\xa0\xe5\xa5\xbd\xe4\xb8\x96\xe7\x95\x8c")
+    assert u("你好世界").encode("utf-8") == proto.read_val(
+        b, TType.STRING, decode_response=False)
+
+
 def test_write_message_begin():
     trans = TCyMemoryBuffer()
     b = proto.TCyBinaryProtocol(trans)
diff --git a/thriftpy/protocol/binary.py b/thriftpy/protocol/binary.py
index 5efbb47..406cb0b 100644
--- a/thriftpy/protocol/binary.py
+++ b/thriftpy/protocol/binary.py
@@ -205,7 +205,7 @@ def read_map_begin(inbuf):
     return k_type, v_type, sz
 
 
-def read_val(inbuf, ttype, spec=None):
+def read_val(inbuf, ttype, spec=None, decode_response=True):
     if ttype == TType.BOOL:
         return bool(unpack_i8(inbuf.read(1)))
 
@@ -227,11 +227,15 @@ def read_val(inbuf, ttype, spec=None):
     elif ttype == TType.STRING:
         sz = unpack_i32(inbuf.read(4))
         byte_payload = inbuf.read(sz)
-        # Since we cannot tell if we're getting STRING or BINARY, try both
-        try:
-            return byte_payload.decode('utf-8')
-        except UnicodeDecodeError:
-            return byte_payload
+
+        # Since we cannot tell if we're getting STRING or BINARY
+        # if not asked not to decode, try both
+        if decode_response:
+            try:
+                return byte_payload.decode('utf-8')
+            except UnicodeDecodeError:
+                pass
+        return byte_payload
 
     elif ttype == TType.SET or ttype == TType.LIST:
         if isinstance(spec, tuple):
@@ -285,7 +289,7 @@ def read_val(inbuf, ttype, spec=None):
         return obj
 
 
-def read_struct(inbuf, obj):
+def read_struct(inbuf, obj, decode_response=True):
     while True:
         f_type, fid = read_field_begin(inbuf)
         if f_type == TType.STOP:
@@ -307,7 +311,8 @@ def read_struct(inbuf, obj):
             skip(inbuf, f_type)
             continue
 
-        setattr(obj, f_name, read_val(inbuf, f_type, f_container_spec))
+        setattr(obj, f_name,
+                read_val(inbuf, f_type, f_container_spec, decode_response))
 
 
 def skip(inbuf, ftype):
@@ -351,10 +356,13 @@ def skip(inbuf, ftype):
 class TBinaryProtocol(object):
     """Binary implementation of the Thrift protocol driver."""
 
-    def __init__(self, trans, strict_read=True, strict_write=True):
+    def __init__(self, trans,
+                 strict_read=True, strict_write=True,
+                 decode_response=True):
         self.trans = trans
         self.strict_read = strict_read
         self.strict_write = strict_write
+        self.decode_response = decode_response
 
     def skip(self, ttype):
         skip(self.trans, ttype)
@@ -375,16 +383,20 @@ def write_message_end(self):
         pass
 
     def read_struct(self, obj):
-        return read_struct(self.trans, obj)
+        return read_struct(self.trans, obj, self.decode_response)
 
     def write_struct(self, obj):
         write_val(self.trans, TType.STRUCT, obj)
 
 
 class TBinaryProtocolFactory(object):
-    def __init__(self, strict_read=True, strict_write=True):
+    def __init__(self, strict_read=True, strict_write=True,
+                 decode_response=True):
         self.strict_read = strict_read
         self.strict_write = strict_write
+        self.decode_response = decode_response
 
     def get_protocol(self, trans):
-        return TBinaryProtocol(trans, self.strict_read, self.strict_write)
+        return TBinaryProtocol(trans,
+                               self.strict_read, self.strict_write,
+                               self.decode_response)
diff --git a/thriftpy/protocol/compact.py b/thriftpy/protocol/compact.py
index b4f0056..9f49e05 100644
--- a/thriftpy/protocol/compact.py
+++ b/thriftpy/protocol/compact.py
@@ -122,12 +122,13 @@ class TCompactProtocol(object):
     TYPE_BITS = 0x07
     TYPE_SHIFT_AMOUNT = 5
 
-    def __init__(self, trans):
+    def __init__(self, trans, decode_response=True):
         self.trans = trans
         self._last_fid = 0
         self._bool_fid = None
         self._bool_value = None
         self._structs = []
+        self.decode_response = decode_response
 
     def _get_ttype(self, byte):
         return TTYPES[byte & 0x0f]
@@ -227,12 +228,14 @@ def read_double(self):
 
     def read_string(self):
         len = self._read_size()
-
         byte_payload = self.trans.read(len)
-        try:
-            return byte_payload.decode('utf-8')
-        except UnicodeDecodeError:
-            return byte_payload
+
+        if self.decode_response:
+            try:
+                byte_payload = byte_payload.decode('utf-8')
+            except UnicodeDecodeError:
+                pass
+        return byte_payload
 
     def read_bool(self):
         if self._bool_value is not None:
@@ -556,5 +559,8 @@ def skip(self, ttype):
 
 
 class TCompactProtocolFactory(object):
+    def __init__(self, decode_response=True):
+        self.decode_response = decode_response
+
     def get_protocol(self, trans):
-        return TCompactProtocol(trans)
+        return TCompactProtocol(trans, decode_response=self.decode_response)
diff --git a/thriftpy/protocol/cybin/cybin.pyx b/thriftpy/protocol/cybin/cybin.pyx
index c171dea..b239bd2 100644
--- a/thriftpy/protocol/cybin/cybin.pyx
+++ b/thriftpy/protocol/cybin/cybin.pyx
@@ -153,7 +153,7 @@ cdef inline write_dict(CyTransportBase buf, object val, spec):
         c_write_val(buf, v_type, v, v_spec)
 
 
-cdef inline read_struct(CyTransportBase buf, obj):
+cdef inline read_struct(CyTransportBase buf, obj, decode_response=True):
     cdef dict field_specs = obj.thrift_spec
     cdef int fid
     cdef TType field_type, ttype
@@ -182,7 +182,7 @@ cdef inline read_struct(CyTransportBase buf, obj):
         else:
             spec = field_spec[2]
 
-        setattr(obj, name, c_read_val(buf, ttype, spec))
+        setattr(obj, name, c_read_val(buf, ttype, spec, decode_response))
 
     return obj
 
@@ -217,7 +217,7 @@ cdef inline write_struct(CyTransportBase buf, obj):
     write_i08(buf, T_STOP)
 
 
-cdef inline c_read_string(CyTransportBase buf, int32_t size):
+cdef inline c_read_binary(CyTransportBase buf, int32_t size):
     cdef char string_val[STACK_STRING_LEN]
 
     if size > STACK_STRING_LEN:
@@ -229,13 +229,21 @@ cdef inline c_read_string(CyTransportBase buf, int32_t size):
         buf.c_read(size, string_val)
         py_data = string_val[:size]
 
-    try:
-        return py_data.decode("utf-8")
-    except UnicodeDecodeError:
-        return py_data
+    return py_data
+
 
+cdef inline c_read_string(CyTransportBase buf, int32_t size):
+    py_data = c_read_binary(buf, size)
+    # STRING and BINARY share one wire type, so fall back to the raw
+    # bytes when the payload is not valid UTF-8.
+    try:
+        return py_data.decode("utf-8")
+    except UnicodeDecodeError:
+        return py_data
 
-cdef c_read_val(CyTransportBase buf, TType ttype, spec=None):
+
+cdef c_read_val(CyTransportBase buf, TType ttype, spec=None,
+                decode_response=True):
     cdef int size
     cdef int64_t n
     cdef TType v_type, k_type, orig_type, orig_key_type
@@ -261,7 +269,10 @@ cdef c_read_val(CyTransportBase buf, TType ttype, spec=None):
 
     elif ttype == T_STRING:
         size = read_i32(buf)
-        return c_read_string(buf, size)
+        if decode_response:
+            return c_read_string(buf, size)
+        else:
+            return c_read_binary(buf, size)
 
     elif ttype == T_SET or ttype == T_LIST:
         if isinstance(spec, int):
@@ -343,7 +354,6 @@ cdef c_write_val(CyTransportBase buf, TType ttype, val, spec=None):
         write_string(buf, val)
 
     elif ttype == T_SET or ttype == T_LIST:
-        assert not isinstance(val, basestring)
         write_list(buf, val, spec)
 
     elif ttype == T_MAP:
@@ -367,7 +377,7 @@ cpdef skip(CyTransportBase buf, TType ttype):
         read_i64(buf)
     elif ttype == T_STRING:
         size = read_i32(buf)
-        c_read_string(buf, size)
+        c_read_binary(buf, size)
     elif ttype == T_SET or ttype == T_LIST:
         v_type = read_i08(buf)
         size = read_i32(buf)
@@ -389,8 +399,8 @@ cpdef skip(CyTransportBase buf, TType ttype):
             skip(buf, f_type)
 
 
-def read_val(CyTransportBase buf, TType ttype):
-    return c_read_val(buf, ttype)
+def read_val(CyTransportBase buf, TType ttype, decode_response=True):
+    return c_read_val(buf, ttype, None, decode_response)
 
 
 def write_val(CyTransportBase buf, TType ttype, val, spec=None):
@@ -401,11 +411,14 @@ cdef class TCyBinaryProtocol(object):
     cdef public CyTransportBase trans
     cdef public bool strict_read
     cdef public bool strict_write
+    cdef public bool decode_response
 
-    def __init__(self, trans, strict_read=True, strict_write=True):
+    def __init__(self, trans, strict_read=True, strict_write=True,
+                 decode_response=True):
         self.trans = trans
         self.strict_read = strict_read
         self.strict_write = strict_write
+        self.decode_response = decode_response
 
     def skip(self, ttype):
         skip(self.trans, (ttype))
@@ -452,7 +465,7 @@ cdef class TCyBinaryProtocol(object):
 
     def read_struct(self, obj):
         try:
-            return read_struct(self.trans, obj)
+            return read_struct(self.trans, obj, self.decode_response)
         except Exception:
             self.trans.clean()
             raise
@@ -466,9 +479,12 @@ cdef class TCyBinaryProtocol(object):
 
 
 class TCyBinaryProtocolFactory(object):
-    def __init__(self, strict_read=True, strict_write=True):
+    def __init__(self, strict_read=True, strict_write=True,
+                 decode_response=True):
         self.strict_read = strict_read
         self.strict_write = strict_write
+        self.decode_response = decode_response
 
     def get_protocol(self, trans):
-        return TCyBinaryProtocol(trans, self.strict_read, self.strict_write)
+        return TCyBinaryProtocol(
+            trans, self.strict_read, self.strict_write, self.decode_response)