diff --git a/indexer/postgres/delete.go b/indexer/postgres/delete.go new file mode 100644 index 000000000000..08bdb155dc62 --- /dev/null +++ b/indexer/postgres/delete.go @@ -0,0 +1,61 @@ +package postgres + +import ( + "context" + "fmt" + "io" + "strings" +) + +// delete deletes the row with the provided key from the table. +func (tm *objectIndexer) delete(ctx context.Context, conn dbConn, key interface{}) error { + buf := new(strings.Builder) + var params []interface{} + var err error + if !tm.options.disableRetainDeletions && tm.typ.RetainDeletions { + params, err = tm.retainDeleteSqlAndParams(buf, key) + } else { + params, err = tm.deleteSqlAndParams(buf, key) + } + if err != nil { + return err + } + + sqlStr := buf.String() + tm.options.logger.Info("Delete", "sql", sqlStr, "params", params) + _, err = conn.ExecContext(ctx, sqlStr, params...) + return err +} + +// deleteSqlAndParams generates a DELETE statement and binding parameters for the provided key. +func (tm *objectIndexer) deleteSqlAndParams(w io.Writer, key interface{}) ([]interface{}, error) { + _, err := fmt.Fprintf(w, "DELETE FROM %q", tm.tableName()) + if err != nil { + return nil, err + } + + _, keyParams, err := tm.whereSqlAndParams(w, key, 1) + if err != nil { + return nil, err + } + + _, err = fmt.Fprintf(w, ";") + return keyParams, err +} + +// retainDeleteSqlAndParams generates an UPDATE statement to set the _deleted column to true for the provided key +// which is used when the table is set to retain deletions mode. +func (tm *objectIndexer) retainDeleteSqlAndParams(w io.Writer, key interface{}) ([]interface{}, error) { + _, err := fmt.Fprintf(w, "UPDATE %q SET _deleted = TRUE", tm.tableName()) + if err != nil { + return nil, err + } + + _, keyParams, err := tm.whereSqlAndParams(w, key, 1) + if err != nil { + return nil, err + } + + _, err = fmt.Fprintf(w, ";") + return keyParams, err +} diff --git a/indexer/postgres/indexer.go b/indexer/postgres/indexer.go index 2c37e9a79b11..bfaac25842e5 100644 --- a/indexer/postgres/indexer.go +++ b/indexer/postgres/indexer.go @@ -72,6 +72,7 @@ func StartIndexer(params indexer.InitParams) (indexer.InitResult, error) { opts := options{ disableRetainDeletions: config.DisableRetainDeletions, logger: params.Logger, + addressCodec: params.AddressCodec, } idx := &indexerImpl{ @@ -85,6 +86,7 @@ func StartIndexer(params indexer.InitParams) (indexer.InitResult, error) { return indexer.InitResult{ Listener: idx.listener(), + View: idx, }, nil } diff --git a/indexer/postgres/insert_update.go b/indexer/postgres/insert_update.go new file mode 100644 index 000000000000..fb246a84b61c --- /dev/null +++ b/indexer/postgres/insert_update.go @@ -0,0 +1,116 @@ +package postgres + +import ( + "context" + "fmt" + "io" + "strings" +) + +// insertUpdate inserts or updates the row with the provided key and value. +func (tm *objectIndexer) insertUpdate(ctx context.Context, conn dbConn, key, value interface{}) error { + exists, err := tm.exists(ctx, conn, key) + if err != nil { + return err + } + + buf := new(strings.Builder) + var params []interface{} + if exists { + if len(tm.typ.ValueFields) == 0 { + // special case where there are no value fields, so we can't update anything + return nil + } + + params, err = tm.updateSql(buf, key, value) + } else { + params, err = tm.insertSql(buf, key, value) + } + if err != nil { + return err + } + + sqlStr := buf.String() + if tm.options.logger != nil { + tm.options.logger.Debug("Insert or Update", "sql", sqlStr, "params", params) + } + _, err = conn.ExecContext(ctx, sqlStr, params...) + return err +} + +// insertSql generates an INSERT statement and binding parameters for the provided key and value. +func (tm *objectIndexer) insertSql(w io.Writer, key, value interface{}) ([]interface{}, error) { + keyParams, keyCols, err := tm.bindKeyParams(key) + if err != nil { + return nil, err + } + + valueParams, valueCols, err := tm.bindValueParams(value) + if err != nil { + return nil, err + } + + var allParams []interface{} + allParams = append(allParams, keyParams...) + allParams = append(allParams, valueParams...) + + allCols := make([]string, 0, len(keyCols)+len(valueCols)) + allCols = append(allCols, keyCols...) + allCols = append(allCols, valueCols...) + + var paramBindings []string + for i := 1; i <= len(allCols); i++ { + paramBindings = append(paramBindings, fmt.Sprintf("$%d", i)) + } + + _, err = fmt.Fprintf(w, "INSERT INTO %q (%s) VALUES (%s);", tm.tableName(), + strings.Join(allCols, ", "), + strings.Join(paramBindings, ", "), + ) + return allParams, err +} + +// updateSql generates an UPDATE statement and binding parameters for the provided key and value. +func (tm *objectIndexer) updateSql(w io.Writer, key, value interface{}) ([]interface{}, error) { + _, err := fmt.Fprintf(w, "UPDATE %q SET ", tm.tableName()) + if err != nil { + return nil, err + } + + valueParams, valueCols, err := tm.bindValueParams(value) + if err != nil { + return nil, err + } + + paramIdx := 1 + for i, col := range valueCols { + if i > 0 { + _, err = fmt.Fprintf(w, ", ") + if err != nil { + return nil, err + } + } + _, err = fmt.Fprintf(w, "%s = $%d", col, paramIdx) + if err != nil { + return nil, err + } + + paramIdx++ + } + + if !tm.options.disableRetainDeletions && tm.typ.RetainDeletions { + _, err = fmt.Fprintf(w, ", _deleted = FALSE") + if err != nil { + return nil, err + } + } + + _, keyParams, err := tm.whereSqlAndParams(w, key, paramIdx) + if err != nil { + return nil, err + } + + allParams := append(valueParams, keyParams...) + _, err = fmt.Fprintf(w, ";") + return allParams, err +} diff --git a/indexer/postgres/listener.go b/indexer/postgres/listener.go index 1f46c1c7c55c..21d08a736af7 100644 --- a/indexer/postgres/listener.go +++ b/indexer/postgres/listener.go @@ -25,6 +25,34 @@ func (i *indexerImpl) listener() appdata.Listener { _, err := i.tx.Exec("INSERT INTO block (number) VALUES ($1)", data.Height) return err }, + OnObjectUpdate: func(data appdata.ObjectUpdateData) error { + module := data.ModuleName + mod, ok := i.modules[module] + if !ok { + return fmt.Errorf("module %s not initialized", module) + } + + for _, update := range data.Updates { + if i.logger != nil { + i.logger.Debug("OnObjectUpdate", "module", module, "type", update.TypeName, "key", update.Key, "delete", update.Delete, "value", update.Value) + } + tm, ok := mod.tables[update.TypeName] + if !ok { + return fmt.Errorf("object type %s not found in schema for module %s", update.TypeName, module) + } + + var err error + if update.Delete { + err = tm.delete(i.ctx, i.tx, update.Key) + } else { + err = tm.insertUpdate(i.ctx, i.tx, update.Key, update.Value) + } + if err != nil { + return err + } + } + return nil + }, Commit: func(data appdata.CommitData) (func() error, error) { err := i.tx.Commit() if err != nil { diff --git a/indexer/postgres/options.go b/indexer/postgres/options.go index d18a4c4d7f2c..db905f9dbaa7 100644 --- a/indexer/postgres/options.go +++ b/indexer/postgres/options.go @@ -1,6 +1,9 @@ package postgres -import "cosmossdk.io/schema/logutil" +import ( + "cosmossdk.io/schema/addressutil" + "cosmossdk.io/schema/logutil" +) // options are the options for module and object indexers. type options struct { @@ -9,4 +12,7 @@ type options struct { // logger is the logger for the indexer to use. It may be nil. logger logutil.Logger + + // addressCodec is the codec for encoding and decoding addresses. It is expected to be non-nil. + addressCodec addressutil.AddressCodec } diff --git a/indexer/postgres/params.go b/indexer/postgres/params.go new file mode 100644 index 000000000000..b2af8f6f174a --- /dev/null +++ b/indexer/postgres/params.go @@ -0,0 +1,116 @@ +package postgres + +import ( + "fmt" + "time" + + "cosmossdk.io/schema" +) + +// bindKeyParams binds the key to the key columns. +func (tm *objectIndexer) bindKeyParams(key interface{}) ([]interface{}, []string, error) { + n := len(tm.typ.KeyFields) + if n == 0 { + // singleton, set _id = 1 + return []interface{}{1}, []string{"_id"}, nil + } else if n == 1 { + return tm.bindParams(tm.typ.KeyFields, []interface{}{key}) + } else { + key, ok := key.([]interface{}) + if !ok { + return nil, nil, fmt.Errorf("expected key to be a slice") + } + + return tm.bindParams(tm.typ.KeyFields, key) + } +} + +func (tm *objectIndexer) bindValueParams(value interface{}) (params []interface{}, valueCols []string, err error) { + n := len(tm.typ.ValueFields) + if n == 0 { + return nil, nil, nil + } else if valueUpdates, ok := value.(schema.ValueUpdates); ok { + var e error + var fields []schema.Field + var params []interface{} + if err := valueUpdates.Iterate(func(name string, value interface{}) bool { + field, ok := tm.valueFields[name] + if !ok { + e = fmt.Errorf("unknown column %q", name) + return false + } + fields = append(fields, field) + params = append(params, value) + return true + }); err != nil { + return nil, nil, err + } + if e != nil { + return nil, nil, e + } + + return tm.bindParams(fields, params) + } else if n == 1 { + return tm.bindParams(tm.typ.ValueFields, []interface{}{value}) + } else { + values, ok := value.([]interface{}) + if !ok { + return nil, nil, fmt.Errorf("expected values to be a slice") + } + + return tm.bindParams(tm.typ.ValueFields, values) + } +} + +func (tm *objectIndexer) bindParams(fields []schema.Field, values []interface{}) ([]interface{}, []string, error) { + names := make([]string, 0, len(fields)) + params := make([]interface{}, 0, len(fields)) + for i, field := range fields { + if i >= len(values) { + return nil, nil, fmt.Errorf("missing value for field %q", field.Name) + } + + param, err := tm.bindParam(field, values[i]) + if err != nil { + return nil, nil, err + } + + name, err := tm.updatableColumnName(field) + if err != nil { + return nil, nil, err + } + + names = append(names, name) + params = append(params, param) + } + return params, names, nil +} + +func (tm *objectIndexer) bindParam(field schema.Field, value interface{}) (param interface{}, err error) { + param = value + if value == nil { + if !field.Nullable { + return nil, fmt.Errorf("expected non-null value for field %q", field.Name) + } + } else if field.Kind == schema.TimeKind { + t, ok := value.(time.Time) + if !ok { + return nil, fmt.Errorf("expected time.Time value for field %q, got %T", field.Name, value) + } + + param = t.UnixNano() + } else if field.Kind == schema.DurationKind { + t, ok := value.(time.Duration) + if !ok { + return nil, fmt.Errorf("expected time.Duration value for field %q, got %T", field.Name, value) + } + + param = int64(t) + } else if field.Kind == schema.AddressKind { + param, err = tm.options.addressCodec.BytesToString(value.([]byte)) + if err != nil { + return nil, fmt.Errorf("address encoding failed for field %q: %w", field.Name, err) + } + } + return +} diff --git a/indexer/postgres/select.go b/indexer/postgres/select.go new file mode 100644 index 000000000000..46ef12d3f15c --- /dev/null +++ b/indexer/postgres/select.go @@ -0,0 +1,299 @@ +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "io" + "strconv" + "strings" + "time" + + "cosmossdk.io/schema" +) + +// Count returns the number of rows in the table. +func (tm *objectIndexer) count(ctx context.Context, conn dbConn) (int, error) { + sqlStr := fmt.Sprintf("SELECT COUNT(*) FROM %q;", tm.tableName()) + if tm.options.logger != nil { + tm.options.logger.Debug("Count", "sql", sqlStr) + } + row := conn.QueryRowContext(ctx, sqlStr) + var count int + err := row.Scan(&count) + return count, err +} + +// exists checks if a row with the provided key exists in the table. +func (tm *objectIndexer) exists(ctx context.Context, conn dbConn, key interface{}) (bool, error) { + buf := new(strings.Builder) + params, err := tm.existsSqlAndParams(buf, key) + if err != nil { + return false, err + } + + return tm.checkExists(ctx, conn, buf.String(), params) +} + +// checkExists checks if a row exists in the table. +func (tm *objectIndexer) checkExists(ctx context.Context, conn dbConn, sqlStr string, params []interface{}) (bool, error) { + if tm.options.logger != nil { + tm.options.logger.Debug("Check exists", "sql", sqlStr, "params", params) + } + var res interface{} + err := conn.QueryRowContext(ctx, sqlStr, params...).Scan(&res) + switch err { + case nil: + return true, nil + case sql.ErrNoRows: + return false, nil + default: + return false, err + } +} + +// existsSqlAndParams generates a SELECT statement to check if a row with the provided key exists in the table. +func (tm *objectIndexer) existsSqlAndParams(w io.Writer, key interface{}) ([]interface{}, error) { + _, err := fmt.Fprintf(w, "SELECT 1 FROM %q", tm.tableName()) + if err != nil { + return nil, err + } + + _, keyParams, err := tm.whereSqlAndParams(w, key, 1) + if err != nil { + return nil, err + } + + _, err = fmt.Fprintf(w, ";") + return keyParams, err +} + +func (tm *objectIndexer) get(ctx context.Context, conn dbConn, key interface{}) (schema.ObjectUpdate, bool, error) { + buf := new(strings.Builder) + params, err := tm.getSqlAndParams(buf, key) + if err != nil { + return schema.ObjectUpdate{}, false, err + } + + sqlStr := buf.String() + if tm.options.logger != nil { + tm.options.logger.Debug("Get", "sql", sqlStr, "params", params) + } + + row := conn.QueryRowContext(ctx, sqlStr, params...) + return tm.readRow(row) +} + +func (tm *objectIndexer) selectAllSql(w io.Writer) error { + err := tm.selectAllClause(w) + if err != nil { + return err + } + + _, err = fmt.Fprintf(w, ";") + return err +} + +func (tm *objectIndexer) getSqlAndParams(w io.Writer, key interface{}) ([]interface{}, error) { + err := tm.selectAllClause(w) + if err != nil { + return nil, err + } + + keyParams, keyCols, err := tm.bindKeyParams(key) + if err != nil { + return nil, err + } + + _, keyParams, err = tm.whereSql(w, keyParams, keyCols, 1) + if err != nil { + return nil, err + } + + _, err = fmt.Fprintf(w, ";") + return keyParams, err +} + +func (tm *objectIndexer) selectAllClause(w io.Writer) error { + allFields := make([]string, 0, len(tm.typ.KeyFields)+len(tm.typ.ValueFields)) + + for _, field := range tm.typ.KeyFields { + colName, err := tm.updatableColumnName(field) + if err != nil { + return err + } + allFields = append(allFields, colName) + } + + for _, field := range tm.typ.ValueFields { + colName, err := tm.updatableColumnName(field) + if err != nil { + return err + } + allFields = append(allFields, colName) + } + + if !tm.options.disableRetainDeletions && tm.typ.RetainDeletions { + allFields = append(allFields, "_deleted") + } + + _, err := fmt.Fprintf(w, "SELECT %s FROM %q", strings.Join(allFields, ", "), tm.tableName()) + if err != nil { + return err + } + + return nil +} + +func (tm *objectIndexer) readRow(row interface{ Scan(...interface{}) error }) (schema.ObjectUpdate, bool, error) { + var res []interface{} + for _, f := range tm.typ.KeyFields { + res = append(res, tm.colBindValue(f)) + } + + for _, f := range tm.typ.ValueFields { + res = append(res, tm.colBindValue(f)) + } + + if !tm.options.disableRetainDeletions && tm.typ.RetainDeletions { + res = append(res, new(bool)) + } + + err := row.Scan(res...) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return schema.ObjectUpdate{}, false, err + } + return schema.ObjectUpdate{}, false, err + } + + var keys []interface{} + for _, field := range tm.typ.KeyFields { + x, err := tm.readCol(field, res[0]) + if err != nil { + return schema.ObjectUpdate{}, false, err + } + keys = append(keys, x) + res = res[1:] + } + + var key interface{} = keys + if len(keys) == 1 { + key = keys[0] + } + + var values []interface{} + for _, field := range tm.typ.ValueFields { + x, err := tm.readCol(field, res[0]) + if err != nil { + return schema.ObjectUpdate{}, false, err + } + values = append(values, x) + res = res[1:] + } + + var value interface{} = values + if len(values) == 1 { + value = values[0] + } + + update := schema.ObjectUpdate{ + TypeName: tm.typ.Name, + Key: key, + Value: value, + } + + if !tm.options.disableRetainDeletions && tm.typ.RetainDeletions { + deleted := res[0].(*bool) + if *deleted { + update.Delete = true + } + } + + return update, true, nil +} + +func (tm *objectIndexer) colBindValue(field schema.Field) interface{} { + switch field.Kind { + case schema.BytesKind: + return new(interface{}) + default: + return new(sql.NullString) + } +} + +func (tm *objectIndexer) readCol(field schema.Field, value interface{}) (interface{}, error) { + switch field.Kind { + case schema.BytesKind: + // for bytes types we either get []byte or nil + value = *value.(*interface{}) + return value, nil + default: + } + + nullStr := *value.(*sql.NullString) + if field.Nullable { + if !nullStr.Valid { + return nil, nil + } + } + str := nullStr.String + + switch field.Kind { + case schema.StringKind, schema.EnumKind, schema.IntegerStringKind, schema.DecimalStringKind: + return str, nil + case schema.Uint8Kind: + value, err := strconv.ParseUint(str, 10, 8) + return uint8(value), err + case schema.Uint16Kind: + value, err := strconv.ParseUint(str, 10, 16) + return uint16(value), err + case schema.Uint32Kind: + value, err := strconv.ParseUint(str, 10, 32) + return uint32(value), err + case schema.Uint64Kind: + value, err := strconv.ParseUint(str, 10, 64) + return value, err + case schema.Int8Kind: + value, err := strconv.ParseInt(str, 10, 8) + return int8(value), err + case schema.Int16Kind: + value, err := strconv.ParseInt(str, 10, 16) + return int16(value), err + case schema.Int32Kind: + value, err := strconv.ParseInt(str, 10, 32) + return int32(value), err + case schema.Int64Kind: + value, err := strconv.ParseInt(str, 10, 64) + return value, err + case schema.Float32Kind: + value, err := strconv.ParseFloat(str, 32) + return float32(value), err + case schema.Float64Kind: + value, err := strconv.ParseFloat(str, 64) + return value, err + case schema.BoolKind: + value, err := strconv.ParseBool(str) + return value, err + case schema.JSONKind: + return json.RawMessage(str), nil + case schema.TimeKind: + value, err := strconv.ParseInt(str, 10, 64) + if err != nil { + return nil, err + } + return time.Unix(0, value), nil + case schema.DurationKind: + value, err := strconv.ParseInt(str, 10, 64) + if err != nil { + return nil, err + } + return time.Duration(value), nil + case schema.AddressKind: + return tm.options.addressCodec.StringToBytes(str) + default: + return value, nil + } +} diff --git a/indexer/postgres/tests/go.mod b/indexer/postgres/tests/go.mod index a72a1dc7fc07..cb642a3e7c6f 100644 --- a/indexer/postgres/tests/go.mod +++ b/indexer/postgres/tests/go.mod @@ -5,6 +5,7 @@ go 1.23 require ( cosmossdk.io/indexer/postgres v0.0.0-00010101000000-000000000000 cosmossdk.io/schema v0.1.1 + cosmossdk.io/schema/testing v0.0.0 github.com/fergusstrange/embedded-postgres v1.29.0 github.com/hashicorp/consul/sdk v0.16.1 github.com/jackc/pgx/v5 v5.6.0 @@ -13,6 +14,7 @@ require ( ) require ( + github.com/cockroachdb/apd/v3 v3.2.1 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/google/go-cmp v0.6.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect @@ -22,14 +24,18 @@ require ( github.com/lib/pq v1.10.9 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect + github.com/tidwall/btree v1.7.0 // indirect github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 // indirect golang.org/x/crypto v0.26.0 // indirect golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.23.0 // indirect golang.org/x/text v0.17.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + pgregory.net/rapid v1.1.0 // indirect ) replace cosmossdk.io/indexer/postgres => ../. replace cosmossdk.io/schema => ../../../schema + +replace cosmossdk.io/schema/testing => ../../../schema/testing diff --git a/indexer/postgres/tests/go.sum b/indexer/postgres/tests/go.sum index 809d5040b848..f310c988deaa 100644 --- a/indexer/postgres/tests/go.sum +++ b/indexer/postgres/tests/go.sum @@ -1,3 +1,5 @@ +github.com/cockroachdb/apd/v3 v3.2.1 h1:U+8j7t0axsIgvQUqthuNm82HIrYXodOV2iWLWtEaIwg= +github.com/cockroachdb/apd/v3 v3.2.1/go.mod h1:klXJcjp+FffLTHlhIG69tezTDvdP065naDsHzKhYSqc= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= @@ -32,6 +34,8 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/btree v1.7.0 h1:L1fkJH/AuEh5zBnnBbmTwQ5Lt+bRJ5A8EWecslvo9iI= +github.com/tidwall/btree v1.7.0/go.mod h1:twD9XRA5jj9VUQGELzDO4HPQTNJsoWWfYEL+EUQ2cKY= github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 h1:nIPpBwaJSVYIxUFsDv3M8ofmx9yWTog9BfvIu0q41lo= github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8/go.mod h1:HUYIGzjTL3rfEspMxjDjgmT5uz5wzYJKVo23qUhYTos= go.uber.org/goleak v1.1.12 h1:gZAh5/EyT/HQwlpkCy6wTpqfH9H8Lz8zbm3dZh+OyzA= @@ -52,3 +56,5 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= +pgregory.net/rapid v1.1.0 h1:CMa0sjHSru3puNx+J0MIAuiiEV4N0qj8/cMWGBBCsjw= +pgregory.net/rapid v1.1.0/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04= diff --git a/indexer/postgres/tests/postgres_test.go b/indexer/postgres/tests/postgres_test.go new file mode 100644 index 000000000000..fc725f9cc1cf --- /dev/null +++ b/indexer/postgres/tests/postgres_test.go @@ -0,0 +1,99 @@ +package tests + +import ( + "context" + "os" + "strings" + "testing" + + embeddedpostgres "github.com/fergusstrange/embedded-postgres" + "github.com/hashicorp/consul/sdk/freeport" + _ "github.com/jackc/pgx/v5/stdlib" + "github.com/stretchr/testify/require" + + "cosmossdk.io/indexer/postgres" + "cosmossdk.io/schema/addressutil" + "cosmossdk.io/schema/indexer" + indexertesting "cosmossdk.io/schema/testing" + "cosmossdk.io/schema/testing/appdatasim" + "cosmossdk.io/schema/testing/statesim" +) + +func TestPostgresIndexer(t *testing.T) { + t.Run("RetainDeletions", func(t *testing.T) { + testPostgresIndexer(t, true) + }) + t.Run("NoRetainDeletions", func(t *testing.T) { + testPostgresIndexer(t, false) + }) +} + +func testPostgresIndexer(t *testing.T, retainDeletions bool) { + tempDir, err := os.MkdirTemp("", "postgres-indexer-test") + require.NoError(t, err) + + dbPort := freeport.GetOne(t) + pgConfig := embeddedpostgres.DefaultConfig(). + Port(uint32(dbPort)). + DataPath(tempDir) + + dbUrl := pgConfig.GetConnectionURL() + pg := embeddedpostgres.NewDatabase(pgConfig) + require.NoError(t, pg.Start()) + + ctx, cancel := context.WithCancel(context.Background()) + + t.Cleanup(func() { + cancel() + require.NoError(t, pg.Stop()) + err := os.RemoveAll(tempDir) + require.NoError(t, err) + }) + + cfg, err := postgresConfigToIndexerConfig(postgres.Config{ + DatabaseURL: dbUrl, + DisableRetainDeletions: !retainDeletions, + }) + require.NoError(t, err) + + debugLog := &strings.Builder{} + + pgIndexer, err := postgres.StartIndexer(indexer.InitParams{ + Config: cfg, + Context: ctx, + Logger: &prettyLogger{debugLog}, + AddressCodec: addressutil.HexAddressCodec{}, + }) + require.NoError(t, err) + + sim, err := appdatasim.NewSimulator(appdatasim.Options{ + Listener: pgIndexer.Listener, + AppSchema: indexertesting.ExampleAppSchema, + StateSimOptions: statesim.Options{ + CanRetainDeletions: retainDeletions, + }, + }) + require.NoError(t, err) + + blockDataGen := sim.BlockDataGenN(10, 100) + numBlocks := 200 + if testing.Short() { + numBlocks = 10 + } + for i := 0; i < numBlocks; i++ { + // using Example generates a deterministic data set based + // on a seed so that regression tests can be created OR rapid.Check can + // be used for fully random property-based testing + blockData := blockDataGen.Example(i) + + // process the generated block data with the simulator which will also + // send it to the indexer + require.NoError(t, sim.ProcessBlockData(blockData), debugLog.String()) + + // compare the expected state in the simulator to the actual state in the indexer and expect the diff to be empty + require.Empty(t, appdatasim.DiffAppData(sim, pgIndexer.View), debugLog.String()) + + // reset the debug log after each successful block so that it doesn't get too long when debugging + debugLog.Reset() + } +} diff --git a/indexer/postgres/view.go b/indexer/postgres/view.go new file mode 100644 index 000000000000..eac2c52f8a8b --- /dev/null +++ b/indexer/postgres/view.go @@ -0,0 +1,151 @@ +package postgres + +import ( + "context" + "database/sql" + "strings" + + "cosmossdk.io/schema" + "cosmossdk.io/schema/view" +) + +var _ view.AppData = &indexerImpl{} + +func (i *indexerImpl) AppState() view.AppState { + return i +} + +func (i *indexerImpl) BlockNum() (uint64, error) { + var blockNum int64 + err := i.tx.QueryRow("SELECT coalesce(max(number), 0) FROM block").Scan(&blockNum) + if err != nil { + return 0, err + } + return uint64(blockNum), nil +} + +type moduleView struct { + moduleIndexer + ctx context.Context + conn dbConn +} + +func (i *indexerImpl) GetModule(moduleName string) (view.ModuleState, error) { + mod, ok := i.modules[moduleName] + if !ok { + return nil, nil + } + return &moduleView{ + moduleIndexer: *mod, + ctx: i.ctx, + conn: i.tx, + }, nil +} + +func (i *indexerImpl) Modules(f func(modState view.ModuleState, err error) bool) { + for _, mod := range i.modules { + if !f(&moduleView{ + moduleIndexer: *mod, + ctx: i.ctx, + conn: i.tx, + }, nil) { + return + } + } +} + +func (i *indexerImpl) NumModules() (int, error) { + return len(i.modules), nil +} + +func (m *moduleView) ModuleName() string { + return m.moduleName +} + +func (m *moduleView) ModuleSchema() schema.ModuleSchema { + return m.schema +} + +func (m *moduleView) GetObjectCollection(objectType string) (view.ObjectCollection, error) { + obj, ok := m.tables[objectType] + if !ok { + return nil, nil + } + return &objectView{ + objectIndexer: *obj, + ctx: m.ctx, + conn: m.conn, + }, nil +} + +func (m *moduleView) ObjectCollections(f func(value view.ObjectCollection, err error) bool) { + for _, obj := range m.tables { + if !f(&objectView{ + objectIndexer: *obj, + ctx: m.ctx, + conn: m.conn, + }, nil) { + return + } + } +} + +func (m *moduleView) NumObjectCollections() (int, error) { + return len(m.tables), nil +} + +type objectView struct { + objectIndexer + ctx context.Context + conn dbConn +} + +func (tm *objectView) ObjectType() schema.ObjectType { + return tm.typ +} + +func (tm *objectView) GetObject(key interface{}) (update schema.ObjectUpdate, found bool, err error) { + return tm.get(tm.ctx, tm.conn, key) +} + +func (tm *objectView) AllState(f func(schema.ObjectUpdate, error) bool) { + buf := new(strings.Builder) + err := tm.selectAllSql(buf) + if err != nil { + panic(err) + } + + sqlStr := buf.String() + if tm.options.logger != nil { + tm.options.logger.Debug("Select", "sql", sqlStr) + } + + rows, err := tm.conn.QueryContext(tm.ctx, sqlStr) + if err != nil { + panic(err) + } + defer func(rows *sql.Rows) { + err := rows.Close() + if err != nil { + panic(err) + } + }(rows) + + for rows.Next() { + update, found, err := tm.readRow(rows) + if err == nil && !found { + err = sql.ErrNoRows + } + if !f(update, err) { + return + } + } +} + +func (tm *objectView) Len() (int, error) { + n, err := tm.count(tm.ctx, tm.conn) + if err != nil { + return 0, err + } + return n, nil +} diff --git a/indexer/postgres/where.go b/indexer/postgres/where.go new file mode 100644 index 000000000000..745092781734 --- /dev/null +++ b/indexer/postgres/where.go @@ -0,0 +1,60 @@ +package postgres + +import ( + "fmt" + "io" +) + +// whereSqlAndParams generates a WHERE clause for the provided key and returns the parameters. +func (tm *objectIndexer) whereSqlAndParams(w io.Writer, key interface{}, startParamIdx int) (endParamIdx int, keyParams []interface{}, err error) { + var keyCols []string + keyParams, keyCols, err = tm.bindKeyParams(key) + if err != nil { + return + } + + endParamIdx, keyParams, err = tm.whereSql(w, keyParams, keyCols, startParamIdx) + return +} + +// whereSql generates a WHERE clause for the provided columns and returns the parameters. +func (tm *objectIndexer) whereSql(w io.Writer, params []interface{}, cols []string, startParamIdx int) (endParamIdx int, resParams []interface{}, err error) { + _, err = fmt.Fprintf(w, " WHERE ") + if err != nil { + return 0, nil, err + } + + endParamIdx = startParamIdx + for i, col := range cols { + if i > 0 { + _, err = fmt.Fprintf(w, " AND ") + if err != nil { + return 0, nil, err + } + } + + _, err = fmt.Fprintf(w, "%s ", col) + if err != nil { + return 0, nil, err + } + + if params[i] == nil { + _, err = fmt.Fprintf(w, "IS NULL") + if err != nil { + return 0, nil, err + } + + } else { + _, err = fmt.Fprintf(w, "= $%d", endParamIdx) + if err != nil { + return 0, nil, err + } + + resParams = append(resParams, params[i]) + + endParamIdx++ + } + } + + return endParamIdx, resParams, nil +}