Skip to content

Commit

Permalink
[common] Add unit test for convertions (#1395)
Browse files Browse the repository at this point in the history
* [common] Add unit test for convertions
  • Loading branch information
3vilhamster authored Nov 8, 2024
1 parent 2d4612d commit 7a3beaa
Show file tree
Hide file tree
Showing 4 changed files with 238 additions and 50 deletions.
27 changes: 16 additions & 11 deletions internal/common/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,55 +38,60 @@ func Int64Ceil(v float64) int64 {

// Int32Ptr makes a copy and returns the pointer to an int32.
func Int32Ptr(v int32) *int32 {
return &v
return PtrOf(v)
}

// Float64Ptr makes a copy and returns the pointer to a float64.
func Float64Ptr(v float64) *float64 {
return &v
return PtrOf(v)
}

// Int64Ptr makes a copy and returns the pointer to an int64.
func Int64Ptr(v int64) *int64 {
return &v
return PtrOf(v)
}

// StringPtr makes a copy and returns the pointer to a string.
func StringPtr(v string) *string {
return &v
return PtrOf(v)
}

// BoolPtr makes a copy and returns the pointer to a string.
func BoolPtr(v bool) *bool {
return &v
return PtrOf(v)
}

// TaskListPtr makes a copy and returns the pointer to a TaskList.
func TaskListPtr(v s.TaskList) *s.TaskList {
return &v
return PtrOf(v)
}

// DecisionTypePtr makes a copy and returns the pointer to a DecisionType.
func DecisionTypePtr(t s.DecisionType) *s.DecisionType {
return &t
return PtrOf(t)
}

// EventTypePtr makes a copy and returns the pointer to a EventType.
func EventTypePtr(t s.EventType) *s.EventType {
return &t
return PtrOf(t)
}

// QueryTaskCompletedTypePtr makes a copy and returns the pointer to a QueryTaskCompletedType.
func QueryTaskCompletedTypePtr(t s.QueryTaskCompletedType) *s.QueryTaskCompletedType {
return &t
return PtrOf(t)
}

// TaskListKindPtr makes a copy and returns the pointer to a TaskListKind.
func TaskListKindPtr(t s.TaskListKind) *s.TaskListKind {
return &t
return PtrOf(t)
}

// QueryResultTypePtr makes a copy and returns the pointer to a QueryResultType.
func QueryResultTypePtr(t s.QueryResultType) *s.QueryResultType {
return &t
return PtrOf(t)
}

// PtrOf makes a copy and returns the pointer to a value.
func PtrOf[T any](v T) *T {
return &v
}
57 changes: 57 additions & 0 deletions internal/common/convert_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Copyright (c) 2017-2021 Uber Technologies Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package common

import (
"testing"

s "go.uber.org/cadence/.gen/go/shared"

"github.com/stretchr/testify/assert"
)

func TestPtrOf(t *testing.T) {
assert.Equal(t, "a", *PtrOf("a"))
assert.Equal(t, 1, *PtrOf(1))
assert.Equal(t, int32(1), *PtrOf(int32(1)))
assert.Equal(t, int64(1), *PtrOf(int64(1)))
assert.Equal(t, float64(1.1), *PtrOf(float64(1.1)))
assert.Equal(t, true, *PtrOf(true))
}

func TestPtrHelpers(t *testing.T) {
assert.Equal(t, int32(1), *Int32Ptr(1))
assert.Equal(t, int64(1), *Int64Ptr(1))
assert.Equal(t, 1.1, *Float64Ptr(1.1))
assert.Equal(t, true, *BoolPtr(true))
assert.Equal(t, "a", *StringPtr("a"))
assert.Equal(t, s.TaskList{Name: PtrOf("a")}, *TaskListPtr(s.TaskList{Name: PtrOf("a")}))
assert.Equal(t, s.DecisionTypeScheduleActivityTask, *DecisionTypePtr(s.DecisionTypeScheduleActivityTask))
assert.Equal(t, s.EventTypeWorkflowExecutionStarted, *EventTypePtr(s.EventTypeWorkflowExecutionStarted))
assert.Equal(t, s.QueryTaskCompletedTypeCompleted, *QueryTaskCompletedTypePtr(s.QueryTaskCompletedTypeCompleted))
assert.Equal(t, s.TaskListKindNormal, *TaskListKindPtr(s.TaskListKindNormal))
assert.Equal(t, s.QueryResultTypeFailed, *QueryResultTypePtr(s.QueryResultTypeFailed))
}

func TestCeilHelpers(t *testing.T) {
assert.Equal(t, int32(2), Int32Ceil(1.1))
assert.Equal(t, int64(2), Int64Ceil(1.1))
}
54 changes: 15 additions & 39 deletions internal/common/thrift_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,13 @@ import (
"github.com/apache/thrift/lib/go/thrift"
)

// TSerialize is used to serialize thrift TStruct to []byte
func TSerialize(ctx context.Context, t thrift.TStruct) (b []byte, err error) {
return thrift.NewTSerializer().Write(ctx, t)
}

// TListSerialize is used to serialize list of thrift TStruct to []byte
func TListSerialize(ts []thrift.TStruct) (b []byte, err error) {
func TListSerialize(ts []thrift.TStruct) ([]byte, error) {
if ts == nil {
return
return nil, nil
}

t := thrift.NewTSerializer()
t.Transport.Reset()

// NOTE: we don't write any markers as thrift by design being a streaming protocol doesn't
// recommend writing length.
Expand All @@ -48,26 +42,11 @@ func TListSerialize(ts []thrift.TStruct) (b []byte, err error) {
ctx := context.Background()
for _, v := range ts {
if e := v.Write(ctx, t.Protocol); e != nil {
err = thrift.PrependError("error writing TStruct: ", e)
return
return nil, thrift.PrependError("error writing TStruct: ", e)
}
}

if err = t.Protocol.Flush(ctx); err != nil {
return
}

if err = t.Transport.Flush(ctx); err != nil {
return
}

b = t.Transport.Bytes()
return
}

// TDeserialize is used to deserialize []byte to thrift TStruct
func TDeserialize(ctx context.Context, t thrift.TStruct, b []byte) (err error) {
return thrift.NewTDeserializer().Read(ctx, t, b)
return t.Transport.Bytes(), t.Protocol.Flush(ctx)
}

// TListDeserialize is used to deserialize []byte to list of thrift TStruct
Expand All @@ -92,15 +71,13 @@ func TListDeserialize(ts []thrift.TStruct, b []byte) (err error) {

// IsUseThriftEncoding checks if the objects passed in are all encoded using thrift.
func IsUseThriftEncoding(objs []interface{}) bool {
// NOTE: our criteria to use which encoder is simple if all the types are serializable using thrift then we use
// thrift encoder. For everything else we default to gob.

if len(objs) == 0 {
return false
}

for i := 0; i < len(objs); i++ {
if !IsThriftType(objs[i]) {
// NOTE: our criteria to use which encoder is simple if all the types are serializable using thrift then we use
// thrift encoder. For everything else we default to gob.
for _, obj := range objs {
if !IsThriftType(obj) {
return false
}
}
Expand All @@ -109,15 +86,13 @@ func IsUseThriftEncoding(objs []interface{}) bool {

// IsUseThriftDecoding checks if the objects passed in are all de-serializable using thrift.
func IsUseThriftDecoding(objs []interface{}) bool {
// NOTE: our criteria to use which encoder is simple if all the types are de-serializable using thrift then we use
// thrift decoder. For everything else we default to gob.

if len(objs) == 0 {
return false
}

for i := 0; i < len(objs); i++ {
rVal := reflect.ValueOf(objs[i])
// NOTE: our criteria to use which encoder is simple if all the types are de-serializable using thrift then we use
// thrift decoder. For everything else we default to gob.
for _, obj := range objs {
rVal := reflect.ValueOf(obj)
if rVal.Kind() != reflect.Ptr || !IsThriftType(reflect.Indirect(rVal).Interface()) {
return false
}
Expand All @@ -133,6 +108,7 @@ func IsThriftType(v interface{}) bool {
if reflect.ValueOf(v).Kind() != reflect.Ptr {
return false
}
t := reflect.TypeOf((*thrift.TStruct)(nil)).Elem()
return reflect.TypeOf(v).Implements(t)
return reflect.TypeOf(v).Implements(tStructType)
}

var tStructType = reflect.TypeOf((*thrift.TStruct)(nil)).Elem()
150 changes: 150 additions & 0 deletions internal/common/thrift_util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
// Copyright (c) 2017-2021 Uber Technologies Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package common

import (
"context"
"testing"

"github.com/apache/thrift/lib/go/thrift"
"github.com/stretchr/testify/assert"
)

func TestTListSerialize(t *testing.T) {
t.Run("nil", func(t *testing.T) {
data, err := TListSerialize(nil)
assert.NoError(t, err)
assert.Nil(t, data)
})
t.Run("normal", func(t *testing.T) {
ts := []thrift.TStruct{
&mockThriftStruct{Field1: "value1", Field2: 1},
&mockThriftStruct{Field1: "value2", Field2: 2},
}

_, err := TListSerialize(ts)
assert.NoError(t, err)
})
}

func TestTListDeserialize(t *testing.T) {
ts := []thrift.TStruct{
&mockThriftStruct{},
&mockThriftStruct{},
}

data, err := TListSerialize(ts)
assert.NoError(t, err)

err = TListDeserialize(ts, data)
assert.NoError(t, err)
}

func TestIsUseThriftEncoding(t *testing.T) {
for _, tc := range []struct {
name string
input []interface{}
expected bool
}{
{
name: "nil",
input: nil,
expected: false,
},
{
name: "success",
input: []interface{}{
&mockThriftStruct{},
&mockThriftStruct{},
},
expected: true,
},
{
name: "fail",
input: []interface{}{
&mockThriftStruct{},
PtrOf("string"),
},
expected: false,
},
} {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.expected, IsUseThriftEncoding(tc.input))
})
}
}

func TestIsUseThriftDecoding(t *testing.T) {
for _, tc := range []struct {
name string
input []interface{}
expected bool
}{
{
name: "nil",
input: nil,
expected: false,
},
{
name: "success",
input: []interface{}{
PtrOf(&mockThriftStruct{}),
PtrOf(&mockThriftStruct{}),
},
expected: true,
},
{
name: "fail",
input: []interface{}{
PtrOf(&mockThriftStruct{}),
PtrOf("string"),
},
expected: false,
},
} {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.expected, IsUseThriftDecoding(tc.input))
})
}
}

func TestIsThriftType(t *testing.T) {
assert.True(t, IsThriftType(&mockThriftStruct{}))

assert.False(t, IsThriftType(mockThriftStruct{}))
}

type mockThriftStruct struct {
Field1 string
Field2 int
}

func (m *mockThriftStruct) Read(ctx context.Context, iprot thrift.TProtocol) error {
return nil
}

func (m *mockThriftStruct) Write(ctx context.Context, oprot thrift.TProtocol) error {
return nil
}

func (m *mockThriftStruct) String() string {
return ""
}

0 comments on commit 7a3beaa

Please sign in to comment.