Skip to content

Commit

Permalink
[common] Add unit test for convertions
Browse files Browse the repository at this point in the history
  • Loading branch information
3vilhamster committed Nov 7, 2024
1 parent 2d4612d commit 246ded3
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 53 deletions.
27 changes: 16 additions & 11 deletions internal/common/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,55 +38,60 @@ func Int64Ceil(v float64) int64 {

// Int32Ptr makes a copy and returns the pointer to an int32.
func Int32Ptr(v int32) *int32 {
return &v
return PtrOf(v)
}

// Float64Ptr makes a copy and returns the pointer to a float64.
func Float64Ptr(v float64) *float64 {
return &v
return PtrOf(v)
}

// Int64Ptr makes a copy and returns the pointer to an int64.
func Int64Ptr(v int64) *int64 {
return &v
return PtrOf(v)
}

// StringPtr makes a copy and returns the pointer to a string.
func StringPtr(v string) *string {
return &v
return PtrOf(v)
}

// BoolPtr makes a copy and returns the pointer to a string.
func BoolPtr(v bool) *bool {
return &v
return PtrOf(v)
}

// TaskListPtr makes a copy and returns the pointer to a TaskList.
func TaskListPtr(v s.TaskList) *s.TaskList {
return &v
return PtrOf(v)
}

// DecisionTypePtr makes a copy and returns the pointer to a DecisionType.
func DecisionTypePtr(t s.DecisionType) *s.DecisionType {
return &t
return PtrOf(t)
}

// EventTypePtr makes a copy and returns the pointer to a EventType.
func EventTypePtr(t s.EventType) *s.EventType {
return &t
return PtrOf(t)
}

// QueryTaskCompletedTypePtr makes a copy and returns the pointer to a QueryTaskCompletedType.
func QueryTaskCompletedTypePtr(t s.QueryTaskCompletedType) *s.QueryTaskCompletedType {
return &t
return PtrOf(t)
}

// TaskListKindPtr makes a copy and returns the pointer to a TaskListKind.
func TaskListKindPtr(t s.TaskListKind) *s.TaskListKind {
return &t
return PtrOf(t)
}

// QueryResultTypePtr makes a copy and returns the pointer to a QueryResultType.
func QueryResultTypePtr(t s.QueryResultType) *s.QueryResultType {
return &t
return PtrOf(t)
}

// PtrOf makes a copy and returns the pointer to a value.
func PtrOf[T any](v T) *T {
return &v
}
36 changes: 36 additions & 0 deletions internal/common/convert_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package common

import (
s "go.uber.org/cadence/.gen/go/shared"
"testing"

"github.com/stretchr/testify/assert"
)

func TestPtrOf(t *testing.T) {
assert.Equal(t, "a", *PtrOf("a"))
assert.Equal(t, 1, *PtrOf(1))
assert.Equal(t, int32(1), *PtrOf(int32(1)))
assert.Equal(t, int64(1), *PtrOf(int64(1)))
assert.Equal(t, float64(1.1), *PtrOf(float64(1.1)))
assert.Equal(t, true, *PtrOf(true))
}

func TestPtrHelpers(t *testing.T) {
assert.Equal(t, int32(1), *Int32Ptr(1))
assert.Equal(t, int64(1), *Int64Ptr(1))
assert.Equal(t, float64(1.1), *Float64Ptr(1.1))
assert.Equal(t, true, *BoolPtr(true))
assert.Equal(t, "a", *StringPtr("a"))
assert.Equal(t, s.TaskList{Name: PtrOf("a")}, *TaskListPtr(s.TaskList{Name: PtrOf("a")}))
assert.Equal(t, s.DecisionTypeScheduleActivityTask, *DecisionTypePtr(s.DecisionTypeScheduleActivityTask))
assert.Equal(t, s.EventTypeWorkflowExecutionStarted, *EventTypePtr(s.EventTypeWorkflowExecutionStarted))
assert.Equal(t, s.QueryTaskCompletedTypeCompleted, *QueryTaskCompletedTypePtr(s.QueryTaskCompletedTypeCompleted))
assert.Equal(t, s.TaskListKindNormal, *TaskListKindPtr(s.TaskListKindNormal))
assert.Equal(t, s.QueryResultTypeFailed, *QueryResultTypePtr(s.QueryResultTypeFailed))
}

func TestCeilHelpers(t *testing.T) {
assert.Equal(t, int32(2), Int32Ceil(1.1))
assert.Equal(t, int64(2), Int64Ceil(1.1))
}
54 changes: 12 additions & 42 deletions internal/common/thrift_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,13 @@ import (
"github.com/apache/thrift/lib/go/thrift"
)

// TSerialize is used to serialize thrift TStruct to []byte
func TSerialize(ctx context.Context, t thrift.TStruct) (b []byte, err error) {
return thrift.NewTSerializer().Write(ctx, t)
}

// TListSerialize is used to serialize list of thrift TStruct to []byte
func TListSerialize(ts []thrift.TStruct) (b []byte, err error) {
func TListSerialize(ts []thrift.TStruct) ([]byte, error) {
if ts == nil {
return
return nil, nil
}

t := thrift.NewTSerializer()
t.Transport.Reset()

// NOTE: we don't write any markers as thrift by design being a streaming protocol doesn't
// recommend writing length.
Expand All @@ -48,26 +42,11 @@ func TListSerialize(ts []thrift.TStruct) (b []byte, err error) {
ctx := context.Background()
for _, v := range ts {
if e := v.Write(ctx, t.Protocol); e != nil {
err = thrift.PrependError("error writing TStruct: ", e)
return
return nil, thrift.PrependError("error writing TStruct: ", e)
}
}

if err = t.Protocol.Flush(ctx); err != nil {
return
}

if err = t.Transport.Flush(ctx); err != nil {
return
}

b = t.Transport.Bytes()
return
}

// TDeserialize is used to deserialize []byte to thrift TStruct
func TDeserialize(ctx context.Context, t thrift.TStruct, b []byte) (err error) {
return thrift.NewTDeserializer().Read(ctx, t, b)
return t.Transport.Bytes(), t.Protocol.Flush(ctx)
}

// TListDeserialize is used to deserialize []byte to list of thrift TStruct
Expand All @@ -94,13 +73,8 @@ func TListDeserialize(ts []thrift.TStruct, b []byte) (err error) {
func IsUseThriftEncoding(objs []interface{}) bool {
// NOTE: our criteria to use which encoder is simple if all the types are serializable using thrift then we use
// thrift encoder. For everything else we default to gob.

if len(objs) == 0 {
return false
}

for i := 0; i < len(objs); i++ {
if !IsThriftType(objs[i]) {
for _, obj := range objs {
if !IsThriftType(obj) {
return false
}
}
Expand All @@ -111,14 +85,9 @@ func IsUseThriftEncoding(objs []interface{}) bool {
func IsUseThriftDecoding(objs []interface{}) bool {
// NOTE: our criteria to use which encoder is simple if all the types are de-serializable using thrift then we use
// thrift decoder. For everything else we default to gob.

if len(objs) == 0 {
return false
}

for i := 0; i < len(objs); i++ {
rVal := reflect.ValueOf(objs[i])
if rVal.Kind() != reflect.Ptr || !IsThriftType(reflect.Indirect(rVal).Interface()) {
for _, obj := range objs {
rVal := reflect.ValueOf(obj)
if rVal.Kind() != reflect.Ptr || !IsThriftType(obj) {
return false
}
}
Expand All @@ -133,6 +102,7 @@ func IsThriftType(v interface{}) bool {
if reflect.ValueOf(v).Kind() != reflect.Ptr {
return false
}
t := reflect.TypeOf((*thrift.TStruct)(nil)).Elem()
return reflect.TypeOf(v).Implements(t)
return reflect.TypeOf(v).Implements(tStructType)
}

var tStructType = reflect.TypeOf((*thrift.TStruct)(nil)).Elem()
96 changes: 96 additions & 0 deletions internal/common/thrift_util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package common

import (
"context"
"testing"

"github.com/apache/thrift/lib/go/thrift"
"github.com/stretchr/testify/assert"
)

func TestTListSerialize(t *testing.T) {
t.Run("nil", func(t *testing.T) {
data, err := TListSerialize(nil)
assert.NoError(t, err)
assert.Nil(t, data)
})
t.Run("normal", func(t *testing.T) {
ts := []thrift.TStruct{
&mockThriftStruct{Field1: "value1", Field2: 1},
&mockThriftStruct{Field1: "value2", Field2: 2},
}

_, err := TListSerialize(ts)
assert.NoError(t, err)
})
}

func TestTListDeserialize(t *testing.T) {
ts := []thrift.TStruct{
&mockThriftStruct{},
&mockThriftStruct{},
}

data, err := TListSerialize(ts)
assert.NoError(t, err)

err = TListDeserialize(ts, data)
assert.NoError(t, err)
}

func TestIsUseThriftEncoding(t *testing.T) {
ts := []interface{}{
&mockThriftStruct{},
&mockThriftStruct{},
}

result := IsUseThriftEncoding(ts)
assert.True(t, result)

ts = []interface{}{
&mockThriftStruct{},
"string",
}

result = IsUseThriftEncoding(ts)
assert.False(t, result)
}

func TestIsUseThriftDecoding(t *testing.T) {
ts := []interface{}{
&mockThriftStruct{},
&mockThriftStruct{},
}

assert.True(t, IsUseThriftDecoding(ts))

ts = []interface{}{
&mockThriftStruct{},
"string",
}

assert.False(t, IsUseThriftDecoding(ts))
}

func TestIsThriftType(t *testing.T) {
assert.True(t, IsThriftType(&mockThriftStruct{}))

assert.False(t, IsThriftType(mockThriftStruct{}))
}

type mockThriftStruct struct {
Field1 string
Field2 int
}

func (m *mockThriftStruct) Read(ctx context.Context, iprot thrift.TProtocol) error {
return nil
}

func (m *mockThriftStruct) Write(ctx context.Context, oprot thrift.TProtocol) error {
return nil
}

func (m *mockThriftStruct) String() string {
return ""
}

0 comments on commit 246ded3

Please sign in to comment.