Skip to content

Commit

Permalink
avro: logical type encoding in unions
Browse files Browse the repository at this point in the history
  • Loading branch information
nrwiersma authored Mar 2, 2021
1 parent 4bbeb2d commit 6f942d7
Show file tree
Hide file tree
Showing 7 changed files with 290 additions and 33 deletions.
58 changes: 57 additions & 1 deletion codec_native.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,20 @@ func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder {
default:
break
}
case reflect.Ptr:
ptrType := typ.(*reflect2.UnsafePtrType)
elemType := ptrType.Elem()

ls := getLogicalSchema(schema)
if ls == nil {
break
}
if elemType.RType() != ratRType || schema.Type() != Bytes || ls.Type() != Decimal {
break
}
dec := ls.(*DecimalLogicalSchema)

return &bytesDecimalPtrCodec{prec: dec.Precision(), scale: dec.Scale()}
}

return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())}
Expand Down Expand Up @@ -203,6 +217,12 @@ func createEncoderOfNative(schema Schema, typ reflect2.Type) ValEncoder {
case typ.RType() == timeRType && st == Long && lt == TimestampMicros:
return &timestampMicrosCodec{}

case typ.RType() == ratRType && st != Bytes || lt == Decimal:
ls := getLogicalSchema(schema)
dec := ls.(*DecimalLogicalSchema)

return &bytesDecimalCodec{prec: dec.Precision(), scale: dec.Scale()}

default:
break
}
Expand All @@ -220,7 +240,7 @@ func createEncoderOfNative(schema Schema, typ reflect2.Type) ValEncoder {
}
dec := ls.(*DecimalLogicalSchema)

return &bytesDecimalCodec{prec: dec.Precision(), scale: dec.Scale()}
return &bytesDecimalPtrCodec{prec: dec.Precision(), scale: dec.Scale()}
}

if schema.Type() == Null {
Expand Down Expand Up @@ -455,6 +475,42 @@ func ratFromBytes(b []byte, scale int) *big.Rat {
}

// Encode writes the big.Rat located at ptr to w as a bytes value holding the
// decimal's unscaled integer in big-endian two's-complement form.
//
// The unscaled integer is num * 10^scale / denom, computed with big.Int.Div,
// so digits beyond c.scale are dropped rather than reported as an error.
// c.scale is expected to be non-negative.
func (c *bytesDecimalCodec) Encode(ptr unsafe.Pointer, w *Writer) {
	r := (*big.Rat)(ptr)

	// Build 10^scale with arbitrary-precision arithmetic. Going through
	// int64(math.Pow10(scale)) overflows int64 once scale exceeds 18.
	pow := new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(c.scale)), nil)
	i := new(big.Int).Mul(r.Num(), pow)
	i = i.Div(i, r.Denom())

	var b []byte
	switch i.Sign() {
	case 0:
		b = []byte{0}

	case 1:
		b = i.Bytes()
		// A set high bit would be read back as a sign bit, so prefix a zero
		// byte to keep the value positive.
		if b[0]&0x80 > 0 {
			b = append([]byte{0}, b...)
		}

	case -1:
		// Add one<<length, where length is a whole number of bytes wider than
		// the magnitude, so Bytes() yields the two's-complement representation
		// (assumes the package-level `one` holds the value 1 — confirm at its
		// declaration).
		length := uint(i.BitLen()/8+1) * 8
		b = i.Add(i, (&big.Int{}).Lsh(one, length)).Bytes()
	}
	w.WriteBytes(b)
}

// bytesDecimalPtrCodec is the codec used for *big.Rat values stored as a
// decimal. Both fields are filled from a DecimalLogicalSchema's Precision()
// and Scale() where the codec is constructed.
type bytesDecimalPtrCodec struct {
	// prec is the decimal precision; scale is the number of fractional digits.
	prec, scale int
}

// Decode reads one bytes value from r, converts it to a *big.Rat with
// ratFromBytes at the codec's scale, and stores that pointer in the *big.Rat
// slot addressed by ptr.
//
// No intermediate big.Int is built here: interpreting the raw bytes is
// delegated entirely to ratFromBytes (presumably including the sign — confirm
// against its definition).
func (c *bytesDecimalPtrCodec) Decode(ptr unsafe.Pointer, r *Reader) {
	b := r.ReadBytes()
	*((**big.Rat)(ptr)) = ratFromBytes(b, c.scale)
}

func (c *bytesDecimalPtrCodec) Encode(ptr unsafe.Pointer, w *Writer) {
r := *((**big.Rat)(ptr))
i := (&big.Int{}).Mul(r.Num(), big.NewInt(int64(math.Pow10(c.scale))))
i = i.Div(i, r.Denom())
Expand Down
18 changes: 12 additions & 6 deletions codec_union.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,18 +331,24 @@ func unionResolutionName(schema Schema) string {
func encoderOfResolverUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
union := schema.(*UnionSchema)

name, err := cfg.resolver.Name(typ)
names, err := cfg.resolver.Name(typ)
if err != nil {
return &errorEncoder{err: err}
}

if idx := strings.Index(name, ":"); idx > 0 {
name = name[:idx]
}
var pos int
for _, name := range names {
if idx := strings.Index(name, ":"); idx > 0 {
name = name[:idx]
}

schema, pos := union.Types().Get(name)
schema, pos = union.Types().Get(name)
if schema != nil {
break
}
}
if schema == nil {
return &errorEncoder{err: fmt.Errorf("avro: unknown union type %s", name)}
return &errorEncoder{err: fmt.Errorf("avro: unknown union type %s", names[0])}
}

encoder := encoderOfType(cfg, schema, typ)
Expand Down
75 changes: 65 additions & 10 deletions decoder_union_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,48 @@ func TestDecoder_UnionMapNull(t *testing.T) {
assert.Equal(t, map[string]interface{}(nil), got)
}

// TestDecoder_UnionMapWithTime decodes a ["null", long/timestamp-micros] union
// into a map and expects the time under the "long.timestamp-micros" key.
func TestDecoder_UnionMapWithTime(t *testing.T) {
	defer ConfigTeardown()

	data := []byte{0x02, 0x80, 0xCD, 0xB7, 0xA2, 0xEE, 0xC7, 0xCD, 0x05}
	schema := `["null", {"type": "long", "logicalType": "timestamp-micros"}]`
	dec, err := avro.NewDecoder(schema, bytes.NewReader(data))
	if !assert.NoError(t, err) {
		// Without a decoder there is nothing left to exercise; fail here
		// instead of panicking in Decode.
		return
	}

	var got map[string]interface{}
	err = dec.Decode(&got)

	assert.NoError(t, err)
	assert.Equal(t, time.Date(2020, 1, 2, 3, 4, 5, 0, time.UTC), got["long.timestamp-micros"])
}

// TestDecoder_UnionMapWithDuration decodes a ["null", int/time-millis] union
// into a map and expects the duration under the "int.time-millis" key.
func TestDecoder_UnionMapWithDuration(t *testing.T) {
	defer ConfigTeardown()

	data := []byte{0x02, 0xAA, 0xB4, 0xDE, 0x75}
	schema := `["null", {"type": "int", "logicalType": "time-millis"}]`
	dec, err := avro.NewDecoder(schema, bytes.NewReader(data))
	if !assert.NoError(t, err) {
		// Without a decoder there is nothing left to exercise; fail here
		// instead of panicking in Decode.
		return
	}

	var got map[string]interface{}
	err = dec.Decode(&got)

	assert.NoError(t, err)
	assert.Equal(t, 123456789*time.Millisecond, got["int.time-millis"])
}

// TestDecoder_UnionMapWithDecimal decodes a ["null", bytes/decimal] union into
// a map and expects the *big.Rat under the "bytes.decimal" key.
func TestDecoder_UnionMapWithDecimal(t *testing.T) {
	defer ConfigTeardown()

	data := []byte{0x02, 0x6, 0x00, 0x87, 0x78}
	schema := `["null", {"type": "bytes", "logicalType": "decimal", "precision": 4, "scale": 2}]`
	dec, err := avro.NewDecoder(schema, bytes.NewReader(data))
	if !assert.NoError(t, err) {
		// Without a decoder there is nothing left to exercise; fail here
		// instead of panicking in Decode.
		return
	}

	var got map[string]interface{}
	err = dec.Decode(&got)

	assert.NoError(t, err)
	assert.Equal(t, big.NewRat(1734, 5), got["bytes.decimal"])
}

func TestDecoder_UnionMapInvalidSchema(t *testing.T) {
defer ConfigTeardown()

Expand Down Expand Up @@ -412,43 +454,56 @@ func TestDecoder_UnionInterfaceWithTime(t *testing.T) {
defer ConfigTeardown()

data := []byte{0x02, 0x80, 0xCD, 0xB7, 0xA2, 0xEE, 0xC7, 0xCD, 0x05}
schema := `{"type": "record", "name": "test", "fields" : [{"name": "a", "type": ["null", {"type": "long", "logicalType": "timestamp-micros"}]}]}`
schema := `["null", {"type": "long", "logicalType": "timestamp-micros"}]`
dec, _ := avro.NewDecoder(schema, bytes.NewReader(data))

var got map[string]interface{}
var got interface{}
err := dec.Decode(&got)

assert.NoError(t, err)
assert.Equal(t, time.Date(2020, 1, 2, 3, 4, 5, 0, time.UTC), got["a"])
assert.Equal(t, time.Date(2020, 1, 2, 3, 4, 5, 0, time.UTC), got)
}

func TestDecoder_UnionInterfaceWithDuration(t *testing.T) {
defer ConfigTeardown()

data := []byte{0x02, 0xAA, 0xB4, 0xDE, 0x75}
schema := `{"type": "record", "name": "test", "fields" : [{"name": "a", "type": ["null", {"type": "int", "logicalType": "time-millis"}]}]}`
schema := `["null", {"type": "int", "logicalType": "time-millis"}]`
dec, _ := avro.NewDecoder(schema, bytes.NewReader(data))

var got map[string]interface{}
var got interface{}
err := dec.Decode(&got)

assert.NoError(t, err)
assert.Equal(t, 123456789*time.Millisecond, got["a"])
assert.Equal(t, 123456789*time.Millisecond, got)
}

func TestDecoder_UnionInterfaceWithDecimal(t *testing.T) {
defer ConfigTeardown()

data := []byte{0x02, 0x6, 0x00, 0x87, 0x78}
schema := `{"type": "record", "name": "test", "fields" : [{"name": "a", "type": ["null", {"type": "bytes", "logicalType": "decimal", "precision": 4, "scale": 2}]}]}`
schema := `["null", {"type": "bytes", "logicalType": "decimal", "precision": 4, "scale": 2}]`
dec, _ := avro.NewDecoder(schema, bytes.NewReader(data))

var got map[string]interface{}
var got interface{}
err := dec.Decode(&got)

assert.NoError(t, err)
assert.Equal(t, big.NewRat(1734, 5), got)
}

func TestDecoder_UnionInterfaceWithDecimal_Negative(t *testing.T) {
defer ConfigTeardown()

data := []byte{0x02, 0x6, 0xFF, 0x78, 0x88}
schema := `["null", {"type": "bytes", "logicalType": "decimal", "precision": 4, "scale": 2}]`
dec, _ := avro.NewDecoder(schema, bytes.NewReader(data))

var got interface{}
err := dec.Decode(&got)
expected := big.NewRat(1734, 5)

assert.NoError(t, err)
assert.Equal(t, *expected, got["a"])
assert.Equal(t, big.NewRat(-1734, 5), got)
}

func TestDecoder_UnionInterfaceUnresolvableTypeWithError(t *testing.T) {
Expand Down
42 changes: 42 additions & 0 deletions encoder_native_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,48 @@ func TestEncoder_BytesRat_Zero(t *testing.T) {
assert.Equal(t, []byte{0x02, 0x00}, buf.Bytes())
}

// TestEncoder_BytesRatNonPtr_Positive encodes a positive big.Rat passed by
// value (not by pointer) against a bytes/decimal schema.
func TestEncoder_BytesRatNonPtr_Positive(t *testing.T) {
	defer ConfigTeardown()

	const decimalSchema = `{"type":"bytes","logicalType":"decimal","precision":4,"scale":2}`
	// 1734/5 = 346.80, i.e. unscaled 34680 = 0x008778, after the length byte.
	want := []byte{0x6, 0x00, 0x87, 0x78}

	out := bytes.NewBuffer(make([]byte, 0))
	enc, newErr := avro.NewEncoder(decimalSchema, out)
	assert.NoError(t, newErr)

	val := big.NewRat(1734, 5)
	encErr := enc.Encode(*val)

	assert.NoError(t, encErr)
	assert.Equal(t, want, out.Bytes())
}

// TestEncoder_BytesRatNonPtr_Negative encodes a negative big.Rat passed by
// value (not by pointer) against a bytes/decimal schema.
func TestEncoder_BytesRatNonPtr_Negative(t *testing.T) {
	defer ConfigTeardown()

	const decimalSchema = `{"type":"bytes","logicalType":"decimal","precision":4,"scale":2}`
	// -1734/5 = -346.80, i.e. unscaled -34680, whose 3-byte two's complement
	// is 0xFF7888, after the length byte.
	want := []byte{0x6, 0xFF, 0x78, 0x88}

	out := bytes.NewBuffer(make([]byte, 0))
	enc, newErr := avro.NewEncoder(decimalSchema, out)
	assert.NoError(t, newErr)

	val := big.NewRat(-1734, 5)
	encErr := enc.Encode(*val)

	assert.NoError(t, encErr)
	assert.Equal(t, want, out.Bytes())
}

// TestEncoder_BytesRatNonPtr_Zero encodes a zero-valued big.Rat passed by
// value (not by pointer) against a bytes/decimal schema.
func TestEncoder_BytesRatNonPtr_Zero(t *testing.T) {
	defer ConfigTeardown()

	const decimalSchema = `{"type":"bytes","logicalType":"decimal","precision":4,"scale":2}`
	// Zero is written as a single 0x00 payload byte after the length byte.
	want := []byte{0x02, 0x00}

	out := bytes.NewBuffer(make([]byte, 0))
	enc, newErr := avro.NewEncoder(decimalSchema, out)
	assert.NoError(t, newErr)

	val := big.NewRat(0, 1)
	encErr := enc.Encode(*val)

	assert.NoError(t, encErr)
	assert.Equal(t, want, out.Bytes())
}

func TestEncoder_BytesRatInvalidSchema(t *testing.T) {
defer ConfigTeardown()

Expand Down
Loading

0 comments on commit 6f942d7

Please sign in to comment.