Skip to content

Commit

Permalink
Fixed bool double escaping, escape more locations, diminish amount …
Browse files Browse the repository at this point in the history
…of memory allocations
  • Loading branch information
maoueh committed May 16, 2023
1 parent 13f6165 commit f9be7b4
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 24 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,18 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## v2.0.2

### Changed

- Diminish amount of allocations done to perform fields transformation.

### Fixed

- Fixed some places where escaping for either identifier or value was not done properly.

- Fixed double escaping of boolean values.

## v2.0.1

### Added
Expand Down
75 changes: 51 additions & 24 deletions db/operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,42 +79,63 @@ func (o *Operation) mergeData(newData map[string]string) error {
}

func (o *Operation) query(typeGetter TypeGetter) (string, error) {
if o.opType == OperationTypeDelete {
return fmt.Sprintf("DELETE FROM %s.%s WHERE %s = %s", o.schemaName, o.tableName, o.primaryKeyColumnName, o.primaryKey), nil
var columns, values []string
if o.opType == OperationTypeInsert || o.opType == OperationTypeUpdate {
var err error
columns, values, err = prepareColValues(o.tableName, o.data, typeGetter)
if err != nil {
return "", fmt.Errorf("preparing column & values: %w", err)
}
}

columns, values, err := prepareColValues(o.tableName, o.data, typeGetter)
if err != nil {
return "", fmt.Errorf("preparing column-values: %w", err)
}
if o.opType == OperationTypeInsert {
switch o.opType {
case OperationTypeInsert:
return fmt.Sprintf("INSERT INTO %s.%s (%s) VALUES (%s)",
o.schemaName,
o.tableName,
strings.Join(columns, ","),
strings.Join(values, ","),
), nil
}

var updates []string
for i := 0; i < len(columns); i++ {
// FIXME: merely using %s for key can lead to SQL injection. I
// know, you should trust the Substreams you're putting in
// front, but still.
update := fmt.Sprintf("%s=%s", columns[i], values[i])
updates = append(updates, update)
case OperationTypeUpdate:
updates := make([]string, len(columns))
for i := 0; i < len(columns); i++ {
updates[i] = fmt.Sprintf("%s=%s", columns[i], values[i])
}

updatesString := strings.Join(updates, ", ")
return fmt.Sprintf("UPDATE %s.%s SET %s WHERE %s = %s",
escapeIdentifier(o.schemaName),
escapeIdentifier(o.tableName),
updatesString,
escapeIdentifier(o.primaryKeyColumnName),
escapeStringValue(o.primaryKey),
), nil

case OperationTypeDelete:
return fmt.Sprintf("DELETE FROM %s.%s WHERE %s = %s",
escapeIdentifier(o.schemaName),
escapeIdentifier(o.tableName),
escapeIdentifier(o.primaryKeyColumnName),
escapeStringValue(o.primaryKey),
), nil

default:
panic(fmt.Errorf("unknown operation type %q", o.opType))
}

updatesString := strings.Join(updates, ", ")
return fmt.Sprintf("UPDATE %s.%s SET %s WHERE %s = '%s'", o.schemaName, o.tableName, updatesString, o.primaryKeyColumnName, o.primaryKey), nil
}

func prepareColValues(tableName string, colValues map[string]string, typeGetter TypeGetter) (columns []string, values []string, err error) {
for columnName, value := range colValues {
escapedColumn := escapeIdentifier(columnName)
if len(colValues) == 0 {
return
}

columns = append(columns, escapedColumn)
columns = make([]string, len(colValues))
values = make([]string, len(colValues))

i := 0
for columnName, value := range colValues {
valueType, err := typeGetter(tableName, columnName)
if err != nil {
return nil, nil, fmt.Errorf("get column type %s.%s: %w", tableName, columnName, err)
Expand All @@ -125,27 +146,32 @@ func prepareColValues(tableName string, colValues map[string]string, typeGetter
return nil, nil, fmt.Errorf("getting sql value from table %s for column %q raw value %q: %w", tableName, columnName, value, err)
}

escapedValue := escapeStringValue(normalizedValue)
columns[i] = escapeIdentifier(columnName)
values[i] = normalizedValue

values = append(values, escapedValue)
i++
}
return
}

// Format based on type, value returned unescaped
func normalizeValueType(value string, valueType reflect.Type) (string, error) {

switch valueType.Kind() {
case reflect.String:
return value, nil
return escapeStringValue(value), nil

case reflect.Bool:
return fmt.Sprintf("'%s'", value), nil

case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return value, nil

case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return value, nil

case reflect.Float32, reflect.Float64:
return value, nil

case reflect.Struct:
if valueType == reflect.TypeOf(time.Time{}) {
i, err := strconv.Atoi(value)
Expand All @@ -157,6 +183,7 @@ func normalizeValueType(value string, valueType reflect.Type) (string, error) {
return fmt.Sprintf("'%s'", v), nil
}
return "", fmt.Errorf("unsupported struct type %s", valueType)

default:
return "", fmt.Errorf("unsupported type %s", valueType)
}
Expand Down
39 changes: 39 additions & 0 deletions db/operations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ import (
"context"
"fmt"
"os"
"reflect"
"strings"
"testing"

"github.com/bobg/go-generics/v2/slices"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -127,3 +129,40 @@ func TestEscapeValues(t *testing.T) {
})
}
}

func Test_prepareColValues(t *testing.T) {
boolTypeGetter := func(tableName string, columnName string) (reflect.Type, error) { return reflect.TypeOf(true), nil }

type args struct {
tableName string
colValues map[string]string
typeGetter TypeGetter
}
tests := []struct {
name string
args args
wantColumns []string
wantValues []string
assertion require.ErrorAssertionFunc
}{
{
"bool true",
args{
"test",
map[string]string{"col": "true"},
boolTypeGetter,
},
[]string{`"col"`},
[]string{`'true'`},
require.NoError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotColumns, gotValues, err := prepareColValues(tt.args.tableName, tt.args.colValues, tt.args.typeGetter)
tt.assertion(t, err)
assert.Equal(t, tt.wantColumns, gotColumns)
assert.Equal(t, tt.wantValues, gotValues)
})
}
}

0 comments on commit f9be7b4

Please sign in to comment.