From c8422a912f398558ee16bfcc707719c8aad7aea7 Mon Sep 17 00:00:00 2001 From: Guillaume Ballet <3272758+gballet@users.noreply.github.com> Date: Thu, 25 Jul 2024 17:12:37 +0200 Subject: [PATCH] Add support for uint8 en/decoding --- cmd/sszgen/opset.go | 10 ++++++++++ codec.go | 13 +++++++++++++ decoder.go | 19 +++++++++++++++++++ encoder.go | 14 ++++++++++++++ hasher.go | 7 +++++++ 5 files changed, 63 insertions(+) diff --git a/cmd/sszgen/opset.go b/cmd/sszgen/opset.go index 3985491..aee9139 100644 --- a/cmd/sszgen/opset.go +++ b/cmd/sszgen/opset.go @@ -63,6 +63,16 @@ func (p *parseContext) resolveBasicOpset(typ *types.Basic, tags *sizeTag) (opset "DecodeBool({{.Codec}}, &{{.Field}})", []int{1}, }, nil + case types.Uint8: + if tags != nil && tags.size[0] != 1 { + return nil, fmt.Errorf("byte basic type requires ssz-size=1: have %d", tags.size[0]) + } + return &opsetStatic{ + "DefineUint8({{.Codec}}, &{{.Field}})", + "EncodeUint8({{.Codec}}, &{{.Field}})", + "DecodeUint8({{.Codec}}, &{{.Field}})", + []int{1}, + }, nil case types.Uint64: if tags != nil && tags.size[0] != 8 { return nil, fmt.Errorf("uint64 basic type requires ssz-size=8: have %d", tags.size[0]) diff --git a/codec.go b/codec.go index f087888..c4dff9b 100644 --- a/codec.go +++ b/codec.go @@ -66,6 +66,19 @@ func DefineBool[T ~bool](c *Codec, v *T) { HashBool(c.has, *v) } +// DefineUint8 defines the next field as a uint64. +func DefineUint8[T ~uint8](c *Codec, n *T) { + if c.enc != nil { + EncodeUint8(c.enc, *n) + return + } + if c.dec != nil { + DecodeUint8(c.dec, n) + return + } + HashUint8(c.has, *n) +} + // DefineUint64 defines the next field as a uint64. func DefineUint64[T ~uint64](c *Codec, n *T) { if c.enc != nil { diff --git a/decoder.go b/decoder.go index aa21291..94257e2 100644 --- a/decoder.go +++ b/decoder.go @@ -109,6 +109,25 @@ func DecodeBool[T ~bool](dec *Decoder, v *T) { } } +// DecodeUint8 parses a uint8. +func DecodeUint8[T ~uint8](dec *Decoder, n *T) { + if dec.err != nil { + return + } + if dec.inReader != nil { + _, dec.err = io.ReadFull(dec.inReader, dec.buf[:1]) + *n = T(dec.buf[0]) + dec.inRead += 1 + } else { + if len(dec.inBuffer) < 8 { + dec.err = io.ErrUnexpectedEOF + return + } + *n = T(dec.inBuffer[0]) + dec.inBuffer = dec.inBuffer[1:] + } +} + // DecodeUint64 parses a uint64. func DecodeUint64[T ~uint64](dec *Decoder, n *T) { if dec.err != nil { diff --git a/encoder.go b/encoder.go index 2fadebb..ff48eea 100644 --- a/encoder.go +++ b/encoder.go @@ -99,6 +99,20 @@ func EncodeBool[T ~bool](enc *Encoder, v T) { } } +// EncodeUint8 serializes a uint8. +func EncodeUint8[T ~uint8](enc *Encoder, n T) { + if enc.outWriter != nil { + if enc.err != nil { + return + } + enc.buf[0] = byte(n) + _, enc.err = enc.outWriter.Write(enc.buf[:1]) + } else { + enc.outBuffer[0] = byte(n) + enc.outBuffer = enc.outBuffer[1:] + } +} + // EncodeUint64 serializes a uint64. func EncodeUint64[T ~uint64](enc *Encoder, n T) { // Nope, dive into actual encoding diff --git a/hasher.go b/hasher.go index 5225dec..945fdab 100644 --- a/hasher.go +++ b/hasher.go @@ -73,6 +73,13 @@ func HashBool[T ~bool](h *Hasher, v T) { } } +// HashUint8 hashes a uint8. +func HashUint8[T ~uint8](h *Hasher, n T) { + var buffer [32]byte + buffer[0] = uint8(n) + h.insertChunk(buffer, 0) +} + // HashUint64 hashes a uint64. func HashUint64[T ~uint64](h *Hasher, n T) { var buffer [32]byte