Skip to content

Commit

Permalink
WIP.
Browse files Browse the repository at this point in the history
  • Loading branch information
efritz committed Sep 17, 2024
1 parent e14c3cf commit 4cae252
Show file tree
Hide file tree
Showing 7 changed files with 227 additions and 8 deletions.
20 changes: 17 additions & 3 deletions describe_functions.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,29 @@
package pgutil

import "context"
import (
"context"
"slices"

"github.com/lib/pq"
)

type FunctionDescription struct {
Namespace string
Name string
Definition string
ArgTypes []string
}

func (d FunctionDescription) Equals(other FunctionDescription) bool {
return true &&
d.Namespace == other.Namespace &&
d.Name == other.Name &&
slices.Equal(d.ArgTypes, other.ArgTypes) &&
d.Definition == other.Definition
}

var scanFunctions = NewSliceScanner(func(s Scanner) (f FunctionDescription, _ error) {
err := s.Scan(&f.Namespace, &f.Name, &f.Definition)
err := s.Scan(&f.Namespace, &f.Name, &f.Definition, pq.Array(&f.ArgTypes))
return f, err
})

Expand All @@ -25,7 +32,14 @@ func DescribeFunctions(ctx context.Context, db DB) ([]FunctionDescription, error
SELECT
n.nspname AS namespace,
p.proname AS name,
pg_get_functiondef(p.oid) AS definition
pg_get_functiondef(p.oid) AS definition,
COALESCE(
ARRAY(
SELECT typ.typname
FROM unnest(p.proargtypes) AS t(type_oid)
JOIN pg_type typ ON typ.oid = t.type_oid
),
'{}'::text[]) AS argtypes
FROM pg_catalog.pg_proc p
JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
JOIN pg_language l ON l.oid = p.prolang AND l.lanname IN ('sql', 'plpgsql')
Expand Down
3 changes: 3 additions & 0 deletions drift.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ func Compare(a, b SchemaDescription) (statements []string) {
//

func viewClosures(a, b SchemaDescription) (createClosure, dropClosure closure) {
// TODO
// keys := map[string]struct{}{}
// for _, view := range a.Views {
// keys[fmt.Sprintf("%q.%q", view.Namespace, view.Name)] = struct{}{}
Expand All @@ -133,6 +134,7 @@ func viewClosures(a, b SchemaDescription) (createClosure, dropClosure closure) {
sourceKey := fmt.Sprintf("%q.%q", dependency.SourceNamespace, dependency.SourceTableOrViewName)
dependencyKey := fmt.Sprintf("%q.%q", dependency.UsedNamespace, dependency.UsedTableOrView)

// TODO
// if _, ok := keys[dependencyKey]; ok {
if _, ok := createClosure[sourceKey]; !ok {
createClosure[sourceKey] = map[string]struct{}{}
Expand All @@ -147,6 +149,7 @@ func viewClosures(a, b SchemaDescription) (createClosure, dropClosure closure) {
sourceKey := fmt.Sprintf("%q.%q", dependency.SourceNamespace, dependency.SourceTableOrViewName)
dependencyKey := fmt.Sprintf("%q.%q", dependency.UsedNamespace, dependency.UsedTableOrView)

// TODO
// if _, ok := keys[sourceKey]; ok {
if _, ok := dropClosure[dependencyKey]; !ok {
dropClosure[dependencyKey] = map[string]struct{}{}
Expand Down
6 changes: 3 additions & 3 deletions drift_functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pgutil

import (
"fmt"
"strings"
)

type FunctionModifier struct {
Expand All @@ -15,7 +16,7 @@ func NewFunctionModifier(_ SchemaDescription, d FunctionDescription) FunctionMod
}

func (m FunctionModifier) Key() string {
return fmt.Sprintf("%q.%q", m.d.Namespace, m.d.Name)
return fmt.Sprintf("%q.%q(%s)", m.d.Namespace, m.d.Name, strings.Join(m.d.ArgTypes, ", "))
}

func (m FunctionModifier) ObjectType() string {
Expand All @@ -30,7 +31,7 @@ func (m FunctionModifier) Create() string {
return fmt.Sprintf("%s;", m.d.Definition)
}

func (m FunctionModifier) AlterExisting(_ SchemaDescription, _ FunctionModifier) ([]ddlStatement, bool) {
func (m FunctionModifier) AlterExisting(_ SchemaDescription, _ FunctionDescription) ([]ddlStatement, bool) {
return []ddlStatement{
newStatement(
m.Key(),
Expand All @@ -42,6 +43,5 @@ func (m FunctionModifier) AlterExisting(_ SchemaDescription, _ FunctionModifier)
}

func (m FunctionModifier) Drop() string {
// TODO - capture args
return fmt.Sprintf("DROP FUNCTION IF EXISTS %s;", m.Key())
}
2 changes: 1 addition & 1 deletion drift_sequences.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func (m SequenceModifier) Drop() string {
return fmt.Sprintf("DROP SEQUENCE IF EXISTS %s;", m.Key())
}

func (m SequenceModifier) AlterExisting(existingSchema SchemaDescription, existingObject EnumDescription) ([]ddlStatement, bool) {
func (m SequenceModifier) AlterExisting(existingSchema SchemaDescription, existingObject SequenceDescription) ([]ddlStatement, bool) {
// TODO - implement these
// Type string
// StartValue int
Expand Down
160 changes: 159 additions & 1 deletion drift_test.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,161 @@
package pgutil

// TODO
import (
"context"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestDrift_Extensions(t *testing.T) {
t.Run("create", func(t *testing.T) {
testDrift(t,
// setup
`CREATE EXTENSION IF NOT EXISTS hstore;`,
// alter
`DROP EXTENSION IF EXISTS hstore;`,
// expected
`CREATE EXTENSION IF NOT EXISTS "public"."hstore";`,
)
})

t.Run("drop", func(t *testing.T) {
testDrift(t,
// setup
``,
// alter
`CREATE EXTENSION IF NOT EXISTS pg_trgm;`,
// expected
`DROP EXTENSION IF EXISTS "public"."pg_trgm";`,
)
})
}

func TestDrift_Enums(t *testing.T) {
t.Run("create", func(t *testing.T) {
testDrift(t,
// setup
`CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy');`,
// alter
`DROP TYPE IF EXISTS mood;`,
// expected
`CREATE TYPE "public"."mood" AS ENUM ('sad', 'ok', 'happy');`,
)
})

t.Run("drop", func(t *testing.T) {
testDrift(t,
// setup
``,
// alter
`CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy');`,
// expected
`DROP TYPE IF EXISTS "public"."mood";`,
)
})

t.Run("alter", func(t *testing.T) {
// TODO - test view closures (fix impl first)
})
}

var postgresAddFunctionDefinition = `CREATE OR REPLACE FUNCTION public.add(integer, integer)
RETURNS integer
LANGUAGE sql
AS $function$SELECT $1 + $2;$function$
;`

func TestDrift_Functions(t *testing.T) {
t.Run("create", func(t *testing.T) {
t.Skip() // TODO
testDrift(t,
// setup
`CREATE FUNCTION add(integer, integer) RETURNS integer AS 'SELECT $1 + $2;' LANGUAGE SQL;`,
// alter
`DROP FUNCTION IF EXISTS add(integer, integer);`,
// expected
postgresAddFunctionDefinition,
)
})

t.Run("drop", func(t *testing.T) {
t.Skip() // TODO
testDrift(t,
// setup
``,
// alter
`CREATE FUNCTION add(integer, integer) RETURNS integer AS 'SELECT $1 + $2;' LANGUAGE SQL;`,
// expected
`DROP FUNCTION IF EXISTS "public"."add"(int4, int4);`,
)
})

t.Run("alter mismatched definition", func(t *testing.T) {
testDrift(t,
// setup
`CREATE FUNCTION add(integer, integer) RETURNS integer AS 'SELECT $1 + $2;' LANGUAGE SQL;`,
// alter
`CREATE OR REPLACE FUNCTION add(integer, integer) RETURNS integer AS 'SELECT $1 - $2;' LANGUAGE SQL;`,
// expected
postgresAddFunctionDefinition,
)
})

t.Run("does not alter differing argument types", func(t *testing.T) {
testDrift(t,
// setup
`CREATE FUNCTION add(integer, integer) RETURNS integer AS 'SELECT $1 + $2;' LANGUAGE SQL;`,
// alter
`CREATE FUNCTION add(integer, integer, integer) RETURNS integer AS 'SELECT $1 + $2 + $3;' LANGUAGE SQL;`,
// expected (drops extra function)
`DROP FUNCTION IF EXISTS "public"."add"(int4, int4, int4);`,
)
})
}

func TestDrift_Tables(t *testing.T) {
// TODO
}

func TestDrift_Sequences(t *testing.T) {
// TODO
}

func TestDrift_Views(t *testing.T) {
// TODO
}

func TestDrift_Triggers(t *testing.T) {
// TODO
}

func TestDrift_EnumDependencies(t *testing.T) {
// TODO
}

func TestDrift_ColumnDependencies(t *testing.T) {
// TODO
}

//
//

func testDrift(t *testing.T, setupQuery string, alterQuery string, expectedQueries ...string) {
t.Helper()
db := NewTestDB(t)
ctx := context.Background()

require.NoError(t, db.Exec(ctx, RawQuery(setupQuery)))

before, err := DescribeSchema(ctx, db)
if err != nil {
t.Fatalf("Failed to describe schema: %v", err)
}

require.NoError(t, db.Exec(ctx, RawQuery(alterQuery)))

after, err := DescribeSchema(ctx, db)
require.NoError(t, err)
assert.Equal(t, expectedQueries, Compare(before, after))
}
27 changes: 27 additions & 0 deletions testdata/golden/TestDescribeSchema.golden
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,32 @@ BEGIN
END;
$function$
`,
ArgTypes: []string{},
},
{
Namespace: "public",
Name: "get_weather_description",
Definition: `CREATE OR REPLACE FUNCTION public.get_weather_description(w weather)
RETURNS text
LANGUAGE plpgsql
AS $function$
BEGIN
CASE
WHEN w = 'sunny' THEN
RETURN 'Pack some SPF!';
WHEN w = 'rainy' THEN
RETURN 'Bring an umbrella!';
WHEN w = 'cloudy' THEN
RETURN 'Wear a jacket!';
WHEN w = 'snowy' THEN
RETURN 'Bundle up!';
ELSE
RETURN 'Unknown weather';
END CASE;
END;
$function$
`,
ArgTypes: []string{"weather"},
},
{
Namespace: "public",
Expand All @@ -57,6 +83,7 @@ BEGIN
END;
$function$
`,
ArgTypes: []string{},
},
},
Tables: []pgutil.TableDescription{
Expand Down
17 changes: 17 additions & 0 deletions testdata/schemas/describe.sql
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,23 @@ BEGIN
END;
$$ LANGUAGE plpgsql;

CREATE OR REPLACE FUNCTION get_weather_description(w weather) RETURNS TEXT AS $$
BEGIN
CASE
WHEN w = 'sunny' THEN
RETURN 'Pack some SPF!';
WHEN w = 'rainy' THEN
RETURN 'Bring an umbrella!';
WHEN w = 'cloudy' THEN
RETURN 'Wear a jacket!';
WHEN w = 'snowy' THEN
RETURN 'Bundle up!';
ELSE
RETURN 'Unknown weather';
END CASE;
END;
$$ LANGUAGE plpgsql;

CREATE OR REPLACE FUNCTION update_last_modified() RETURNS TRIGGER AS $$
BEGIN
NEW.last_modified = NOW();
Expand Down

0 comments on commit 4cae252

Please sign in to comment.