diff --git a/schema/decoding/decoding_test.go b/schema/decoding/decoding_test.go index d4308390d29c..5f6872af593a 100644 --- a/schema/decoding/decoding_test.go +++ b/schema/decoding/decoding_test.go @@ -366,8 +366,9 @@ func (e exampleBankModule) subBalance(acct, denom string, amount uint64) error { return nil } -var exampleBankSchema = schema.ModuleSchema{ - ObjectTypes: []schema.ObjectType{ +func init() { + var err error + exampleBankSchema, err = schema.NewModuleSchema([]schema.ObjectType{ { Name: "balances", KeyFields: []schema.Field{ @@ -387,9 +388,14 @@ var exampleBankSchema = schema.ModuleSchema{ }, }, }, - }, + }) + if err != nil { + panic(err) + } } +var exampleBankSchema schema.ModuleSchema + func (e exampleBankModule) ModuleCodec() (schema.ModuleCodec, error) { return schema.ModuleCodec{ Schema: exampleBankSchema, @@ -426,17 +432,25 @@ type oneValueModule struct { store *testStore } -var oneValueModSchema = schema.ModuleSchema{ - ObjectTypes: []schema.ObjectType{ - { - Name: "item", - ValueFields: []schema.Field{ - {Name: "value", Kind: schema.StringKind}, +func init() { + var err error + oneValueModSchema, err = schema.NewModuleSchema( + []schema.ObjectType{ + { + Name: "item", + ValueFields: []schema.Field{ + {Name: "value", Kind: schema.StringKind}, + }, }, }, - }, + ) + if err != nil { + panic(err) + } } +var oneValueModSchema schema.ModuleSchema + func (i oneValueModule) ModuleCodec() (schema.ModuleCodec, error) { return schema.ModuleCodec{ Schema: oneValueModSchema, diff --git a/schema/decoding/resolver_test.go b/schema/decoding/resolver_test.go index f5caf287ca5f..a5e9fa33c947 100644 --- a/schema/decoding/resolver_test.go +++ b/schema/decoding/resolver_test.go @@ -10,16 +10,24 @@ import ( type modA struct{} func (m modA) ModuleCodec() (schema.ModuleCodec, error) { + modSchema, err := schema.NewModuleSchema([]schema.ObjectType{{Name: "A", KeyFields: []schema.Field{{Name: "field1", Kind: schema.StringKind}}}}) + if err != nil { + return schema.ModuleCodec{}, err + } return schema.ModuleCodec{ - Schema: schema.ModuleSchema{ObjectTypes: []schema.ObjectType{{Name: "A"}}}, + Schema: modSchema, }, nil } type modB struct{} func (m modB) ModuleCodec() (schema.ModuleCodec, error) { + modSchema, err := schema.NewModuleSchema([]schema.ObjectType{{Name: "B", KeyFields: []schema.Field{{Name: "field2", Kind: schema.StringKind}}}}) + if err != nil { + return schema.ModuleCodec{}, err + } return schema.ModuleCodec{ - Schema: schema.ModuleSchema{ObjectTypes: []schema.ObjectType{{Name: "B"}}}, + Schema: modSchema, }, nil } @@ -36,7 +44,13 @@ var testResolver = ModuleSetDecoderResolver(moduleSet) func TestModuleSetDecoderResolver_IterateAll(t *testing.T) { objectTypes := map[string]bool{} err := testResolver.IterateAll(func(moduleName string, cdc schema.ModuleCodec) error { - objectTypes[cdc.Schema.ObjectTypes[0].Name] = true + cdc.Schema.Types(func(t schema.Type) bool { + objTyp, ok := t.(schema.ObjectType) + if ok { + objectTypes[objTyp.Name] = true + } + return true + }) return nil }) if err != nil { @@ -66,8 +80,9 @@ func TestModuleSetDecoderResolver_LookupDecoder(t *testing.T) { t.Fatalf("expected to find decoder for modA") } - if decoder.Schema.ObjectTypes[0].Name != "A" { - t.Fatalf("expected object type A, got %s", decoder.Schema.ObjectTypes[0].Name) + _, ok := decoder.Schema.LookupType("A") + if !ok { + t.Fatalf("expected object type A") } decoder, found, err = testResolver.LookupDecoder("modB") @@ -79,8 +94,9 @@ func TestModuleSetDecoderResolver_LookupDecoder(t *testing.T) { t.Fatalf("expected to find decoder for modB") } - if decoder.Schema.ObjectTypes[0].Name != "B" { - t.Fatalf("expected object type B, got %s", decoder.Schema.ObjectTypes[0].Name) + _, ok = decoder.Schema.LookupType("B") + if !ok { + t.Fatalf("expected object type B") } decoder, found, err = testResolver.LookupDecoder("modC") diff --git a/schema/enum.go b/schema/enum.go index 927cc827cb3e..5afb0ecbd0b8 100644 --- a/schema/enum.go +++ b/schema/enum.go @@ -18,6 +18,8 @@ type EnumDefinition struct { Values []string } +func (EnumDefinition) isType() {} + // Validate validates the enum definition. func (e EnumDefinition) Validate() error { if !ValidateName(e.Name) { @@ -50,31 +52,3 @@ func (e EnumDefinition) ValidateValue(value string) error { } return fmt.Errorf("value %q is not a valid enum value for %s", value, e.Name) } - -// checkEnumCompatibility checks that the enum values are consistent across object types and fields. -func checkEnumCompatibility(enumValueMap map[string]map[string]bool, field Field) error { - if field.Kind != EnumKind { - return nil - } - - enum := field.EnumDefinition - - if existing, ok := enumValueMap[enum.Name]; ok { - if len(existing) != len(enum.Values) { - return fmt.Errorf("enum %q has different number of values in different object types", enum.Name) - } - - for _, value := range enum.Values { - if !existing[value] { - return fmt.Errorf("enum %q has different values in different object types", enum.Name) - } - } - } else { - valueMap := map[string]bool{} - for _, value := range enum.Values { - valueMap[value] = true - } - enumValueMap[enum.Name] = valueMap - } - return nil -} diff --git a/schema/module_schema.go b/schema/module_schema.go index 9412c4456cbe..b98744ea8142 100644 --- a/schema/module_schema.go +++ b/schema/module_schema.go @@ -1,18 +1,82 @@ package schema -import "fmt" +import ( + "fmt" + "sort" +) // ModuleSchema represents the logical schema of a module for purposes of indexing and querying. type ModuleSchema struct { - // ObjectTypes describe the types of objects that are part of the module's schema. - ObjectTypes []ObjectType + types map[string]Type +} + +// NewModuleSchema constructs a new ModuleSchema and validates it. Any module schema returned without an error +// is guaranteed to be valid. +func NewModuleSchema(objectTypes []ObjectType) (ModuleSchema, error) { + types := map[string]Type{} + + for _, objectType := range objectTypes { + types[objectType.Name] = objectType + } + + res := ModuleSchema{types: types} + + // validate adds all enum types to the type map + err := res.Validate() + if err != nil { + return ModuleSchema{}, err + } + + return res, nil +} + +func addEnumType(types map[string]Type, field Field) error { + enumDef := field.EnumDefinition + if enumDef.Name == "" { + return nil + } + + existing, ok := types[enumDef.Name] + if !ok { + types[enumDef.Name] = enumDef + return nil + } + + existingEnum, ok := existing.(EnumDefinition) + if !ok { + return fmt.Errorf("enum %q already exists as a different non-enum type", enumDef.Name) + } + + if len(existingEnum.Values) != len(enumDef.Values) { + return fmt.Errorf("enum %q has different number of values in different fields", enumDef.Name) + } + + existingValues := map[string]bool{} + for _, value := range existingEnum.Values { + existingValues[value] = true + } + + for _, value := range enumDef.Values { + _, ok := existingValues[value] + if !ok { + return fmt.Errorf("enum %q has different values in different fields", enumDef.Name) + } + } + + return nil } // Validate validates the module schema. func (s ModuleSchema) Validate() error { - enumValueMap := map[string]map[string]bool{} - for _, objType := range s.ObjectTypes { - if err := objType.validate(enumValueMap); err != nil { + for _, typ := range s.types { + objTyp, ok := typ.(ObjectType) + if !ok { + continue + } + + // all enum types get added to the type map when we call ObjectType.validate + err := objTyp.validate(s.types) + if err != nil { return err } } @@ -22,10 +86,36 @@ func (s ModuleSchema) Validate() error { // ValidateObjectUpdate validates that the update conforms to the module schema. func (s ModuleSchema) ValidateObjectUpdate(update ObjectUpdate) error { - for _, objType := range s.ObjectTypes { - if objType.Name == update.TypeName { - return objType.ValidateObjectUpdate(update) + typ, ok := s.types[update.TypeName] + if !ok { + return fmt.Errorf("object type %q not found in module schema", update.TypeName) + } + + objTyp, ok := typ.(ObjectType) + if !ok { + return fmt.Errorf("type %q is not an object type", update.TypeName) + } + + return objTyp.ValidateObjectUpdate(update) +} + +// LookupType looks up a type by name in the module schema. +func (s ModuleSchema) LookupType(name string) (Type, bool) { + typ, ok := s.types[name] + return typ, ok +} + +// Types calls the provided function for each type in the module schema and stops if the function returns false. +// The types are iterated over in sorted order by name. This function is compatible with go 1.23 iterators. +func (s ModuleSchema) Types(f func(Type) bool) { + keys := make([]string, 0, len(s.types)) + for k := range s.types { + keys = append(keys, k) + } + sort.Strings(keys) + for _, k := range keys { + if !f(s.types[k]) { + break } } - return fmt.Errorf("object type %q not found in module schema", update.TypeName) } diff --git a/schema/module_schema_test.go b/schema/module_schema_test.go index d04327811be6..189c7f3b9947 100644 --- a/schema/module_schema_test.go +++ b/schema/module_schema_test.go @@ -1,27 +1,26 @@ package schema import ( + "reflect" "strings" "testing" ) func TestModuleSchema_Validate(t *testing.T) { tests := []struct { - name string - moduleSchema ModuleSchema - errContains string + name string + objectTypes []ObjectType + errContains string }{ { name: "valid module schema", - moduleSchema: ModuleSchema{ - ObjectTypes: []ObjectType{ - { - Name: "object1", - KeyFields: []Field{ - { - Name: "field1", - Kind: StringKind, - }, + objectTypes: []ObjectType{ + { + Name: "object1", + KeyFields: []Field{ + { + Name: "field1", + Kind: StringKind, }, }, }, @@ -30,15 +29,13 @@ func TestModuleSchema_Validate(t *testing.T) { }, { name: "invalid object type", - moduleSchema: ModuleSchema{ - ObjectTypes: []ObjectType{ - { - Name: "", - KeyFields: []Field{ - { - Name: "field1", - Kind: StringKind, - }, + objectTypes: []ObjectType{ + { + Name: "", + KeyFields: []Field{ + { + Name: "field1", + Kind: StringKind, }, }, }, @@ -47,28 +44,26 @@ func TestModuleSchema_Validate(t *testing.T) { }, { name: "same enum with missing values", - moduleSchema: ModuleSchema{ - ObjectTypes: []ObjectType{ - { - Name: "object1", - KeyFields: []Field{ - { - Name: "k", - Kind: EnumKind, - EnumDefinition: EnumDefinition{ - Name: "enum1", - Values: []string{"a", "b"}, - }, + objectTypes: []ObjectType{ + { + Name: "object1", + KeyFields: []Field{ + { + Name: "k", + Kind: EnumKind, + EnumDefinition: EnumDefinition{ + Name: "enum1", + Values: []string{"a", "b"}, }, }, - ValueFields: []Field{ - { - Name: "v", - Kind: EnumKind, - EnumDefinition: EnumDefinition{ - Name: "enum1", - Values: []string{"a", "b", "c"}, - }, + }, + ValueFields: []Field{ + { + Name: "v", + Kind: EnumKind, + EnumDefinition: EnumDefinition{ + Name: "enum1", + Values: []string{"a", "b", "c"}, }, }, }, @@ -78,31 +73,29 @@ func TestModuleSchema_Validate(t *testing.T) { }, { name: "same enum with different values", - moduleSchema: ModuleSchema{ - ObjectTypes: []ObjectType{ - { - Name: "object1", - KeyFields: []Field{ - { - Name: "k", - Kind: EnumKind, - EnumDefinition: EnumDefinition{ - Name: "enum1", - Values: []string{"a", "b"}, - }, + objectTypes: []ObjectType{ + { + Name: "object1", + KeyFields: []Field{ + { + Name: "k", + Kind: EnumKind, + EnumDefinition: EnumDefinition{ + Name: "enum1", + Values: []string{"a", "b"}, }, }, }, - { - Name: "object2", - KeyFields: []Field{ - { - Name: "k", - Kind: EnumKind, - EnumDefinition: EnumDefinition{ - Name: "enum1", - Values: []string{"a", "c"}, - }, + }, + { + Name: "object2", + KeyFields: []Field{ + { + Name: "k", + Kind: EnumKind, + EnumDefinition: EnumDefinition{ + Name: "enum1", + Values: []string{"a", "c"}, }, }, }, @@ -112,42 +105,58 @@ func TestModuleSchema_Validate(t *testing.T) { }, { name: "same enum", - moduleSchema: ModuleSchema{ - ObjectTypes: []ObjectType{ + objectTypes: []ObjectType{{ + Name: "object1", + KeyFields: []Field{ { - Name: "object1", - KeyFields: []Field{ - { - Name: "k", - Kind: EnumKind, - EnumDefinition: EnumDefinition{ - Name: "enum1", - Values: []string{"a", "b"}, - }, + Name: "k", + Kind: EnumKind, + EnumDefinition: EnumDefinition{ + Name: "enum1", + Values: []string{"a", "b"}, + }, + }, + }, + }, + { + Name: "object2", + KeyFields: []Field{ + { + Name: "k", + Kind: EnumKind, + EnumDefinition: EnumDefinition{ + Name: "enum1", + Values: []string{"a", "b"}, }, }, }, - { - Name: "object2", - KeyFields: []Field{ - { - Name: "k", - Kind: EnumKind, - EnumDefinition: EnumDefinition{ - Name: "enum1", - Values: []string{"a", "b"}, - }, + }, + }, + }, + { + objectTypes: []ObjectType{ + { + Name: "type1", + ValueFields: []Field{ + { + Name: "field1", + Kind: EnumKind, + EnumDefinition: EnumDefinition{ + Name: "type1", + Values: []string{"a", "b"}, }, }, }, }, }, + errContains: "enum \"type1\" already exists as a different non-enum type", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := tt.moduleSchema.Validate() + // because validate is called when calling NewModuleSchema, we just call NewModuleSchema + _, err := NewModuleSchema(tt.objectTypes) if tt.errContains == "" { if err != nil { t.Fatalf("unexpected error: %v", err) @@ -170,19 +179,18 @@ func TestModuleSchema_ValidateObjectUpdate(t *testing.T) { }{ { name: "valid object update", - moduleSchema: ModuleSchema{ - ObjectTypes: []ObjectType{ - { - Name: "object1", - KeyFields: []Field{ - { - Name: "field1", - Kind: StringKind, - }, + moduleSchema: RequireNewModuleSchema(t, []ObjectType{ + { + Name: "object1", + KeyFields: []Field{ + { + Name: "field1", + Kind: StringKind, }, }, }, }, + ), objectUpdate: ObjectUpdate{ TypeName: "object1", Key: "abc", @@ -191,25 +199,47 @@ func TestModuleSchema_ValidateObjectUpdate(t *testing.T) { }, { name: "object type not found", - moduleSchema: ModuleSchema{ - ObjectTypes: []ObjectType{ - { - Name: "object1", - KeyFields: []Field{ - { - Name: "field1", - Kind: StringKind, - }, + moduleSchema: RequireNewModuleSchema(t, []ObjectType{ + { + Name: "object1", + KeyFields: []Field{ + { + Name: "field1", + Kind: StringKind, }, }, }, }, + ), objectUpdate: ObjectUpdate{ TypeName: "object2", Key: "abc", }, errContains: "object type \"object2\" not found in module schema", }, + { + name: "type name refers to an enum", + moduleSchema: RequireNewModuleSchema(t, []ObjectType{ + { + Name: "obj1", + KeyFields: []Field{ + { + Name: "field1", + Kind: EnumKind, + EnumDefinition: EnumDefinition{ + Name: "enum1", + Values: []string{"a", "b"}, + }, + }, + }, + }, + }), + objectUpdate: ObjectUpdate{ + TypeName: "enum1", + Key: "a", + }, + errContains: "type \"enum1\" is not an object type", + }, } for _, tt := range tests { @@ -227,3 +257,94 @@ func TestModuleSchema_ValidateObjectUpdate(t *testing.T) { }) } } + +func RequireNewModuleSchema(t *testing.T, objectTypes []ObjectType) ModuleSchema { + t.Helper() + moduleSchema, err := NewModuleSchema(objectTypes) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + return moduleSchema +} + +func TestModuleSchema_LookupType(t *testing.T) { + moduleSchema := RequireNewModuleSchema(t, []ObjectType{ + { + Name: "object1", + KeyFields: []Field{ + { + Name: "field1", + Kind: StringKind, + }, + }, + }, + }) + + typ, ok := moduleSchema.LookupType("object1") + if !ok { + t.Fatalf("expected to find object type \"object1\"") + } + + objectType, ok := typ.(ObjectType) + if !ok { + t.Fatalf("expected object type, got %T", typ) + } + + if objectType.Name != "object1" { + t.Fatalf("expected object type name \"object1\", got %q", objectType.Name) + } +} + +func TestModuleSchema_ScanTypes(t *testing.T) { + moduleSchema := RequireNewModuleSchema(t, []ObjectType{ + { + Name: "object1", + KeyFields: []Field{ + { + Name: "field1", + Kind: StringKind, + }, + }, + }, + { + Name: "object2", + KeyFields: []Field{ + { + Name: "field1", + Kind: StringKind, + }, + }, + }, + }) + + var objectTypeNames []string + moduleSchema.Types(func(typ Type) bool { + objectType, ok := typ.(ObjectType) + if !ok { + t.Fatalf("expected object type, got %T", typ) + } + objectTypeNames = append(objectTypeNames, objectType.Name) + return true + }) + + expected := []string{"object1", "object2"} + if !reflect.DeepEqual(objectTypeNames, expected) { + t.Fatalf("expected object type names %v, got %v", expected, objectTypeNames) + } + + objectTypeNames = nil + // scan just the first type and return false + moduleSchema.Types(func(typ Type) bool { + objectType, ok := typ.(ObjectType) + if !ok { + t.Fatalf("expected object type, got %T", typ) + } + objectTypeNames = append(objectTypeNames, objectType.Name) + return false + }) + + expected = []string{"object1"} + if !reflect.DeepEqual(objectTypeNames, expected) { + t.Fatalf("expected object type names %v, got %v", expected, objectTypeNames) + } +} diff --git a/schema/object_type.go b/schema/object_type.go index a8fa432d8032..9dd3742a55d5 100644 --- a/schema/object_type.go +++ b/schema/object_type.go @@ -4,8 +4,8 @@ import "fmt" // ObjectType describes an object type a module schema. type ObjectType struct { - // Name is the name of the object type. It must be unique within the module schema - // and conform to the NameFormat regular expression. + // Name is the name of the object type. It must be unique within the module schema amongst all object and enum + // types and conform to the NameFormat regular expression. Name string // KeyFields is a list of fields that make up the primary key of the object. @@ -27,14 +27,16 @@ type ObjectType struct { RetainDeletions bool } +func (ObjectType) isType() {} + // Validate validates the object type. func (o ObjectType) Validate() error { - return o.validate(map[string]map[string]bool{}) + return o.validate(map[string]Type{}) } // validate validates the object type with an enumValueMap that can be // shared across a whole module schema. -func (o ObjectType) validate(enumValueMap map[string]map[string]bool) error { +func (o ObjectType) validate(types map[string]Type) error { if !ValidateName(o.Name) { return fmt.Errorf("invalid object type name %q", o.Name) } @@ -55,7 +57,8 @@ func (o ObjectType) validate(enumValueMap map[string]map[string]bool) error { } fieldNames[field.Name] = true - if err := checkEnumCompatibility(enumValueMap, field); err != nil { + err := addEnumType(types, field) + if err != nil { return err } } @@ -70,7 +73,8 @@ func (o ObjectType) validate(enumValueMap map[string]map[string]bool) error { } fieldNames[field.Name] = true - if err := checkEnumCompatibility(enumValueMap, field); err != nil { + err := addEnumType(types, field) + if err != nil { return err } } diff --git a/schema/type.go b/schema/type.go new file mode 100644 index 000000000000..1b3ef0657734 --- /dev/null +++ b/schema/type.go @@ -0,0 +1,7 @@ +package schema + +// Type is an interface that all types in the schema implement. +// Currently these are ObjectType and EnumDefinition. +type Type interface { + isType() +}