Skip to content

Commit

Permalink
WIP.
Browse files Browse the repository at this point in the history
  • Loading branch information
efritz committed Sep 13, 2024
1 parent 6ebada5 commit c8a00a6
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 52 deletions.
12 changes: 6 additions & 6 deletions .envrc
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#!/bin/bash

export TEST_PGHOST=localhost
export TEST_PGPORT=5432
export TEST_PGUSER=efritz
export TEST_PGPASSWORD=
export TEST_PGDATABASE=efritz
export TEST_TEMPLATEDB=template0
export PGHOST=localhost
export PGPORT=5432
export PGUSER=postgres
export PGPASSWORD=
export PGDATABASE=postgres
export TEMPLATEDB=template0
37 changes: 35 additions & 2 deletions cmd/migrate/main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package main

import (
"fmt"
"net/url"
"os"

"github.com/go-nacelle/nacelle/v2"
Expand All @@ -21,20 +23,36 @@ var rootCmd = &cobra.Command{

var (
migrationDirectory string
logger = nacelle.NewNilLogger() // TODO
databaseURL = "postgres://efritz@localhost:5432/efritz?sslmode=disable" // TODO
databaseURL string
defaultDatabaseURL = pgutil.BuildDatabaseURL()
logger = nacelle.NewNilLogger() // TODO
)

func init() {
masked, err := maskDatabasePassword(defaultDatabaseURL)
if err != nil {
panic(err)
}

rootCmd.PersistentFlags().StringVarP(
&migrationDirectory,
"dir", "d",
"migrations",
"The directory where schema migrations are defined",
)

rootCmd.PersistentFlags().StringVarP(
&databaseURL,
"url", "u",
"",
fmt.Sprintf("The database connection URL (default %s)", masked),
)
}

func dial() (pgutil.DB, error) {
if databaseURL == "" {
databaseURL = defaultDatabaseURL
}
return pgutil.Dial(databaseURL, logger)
}

Expand All @@ -52,3 +70,18 @@ func runner() (*pgutil.Runner, error) {

return runner, nil
}

func maskDatabasePassword(databaseURL string) (string, error) {
parsedURL, err := url.Parse(databaseURL)
if err != nil {
return "", fmt.Errorf("failed to parse database URL: %w", err)
}

if parsedURL.User != nil {
if _, ok := parsedURL.User.Password(); ok {
parsedURL.User = url.UserPassword(parsedURL.User.Username(), "xxxxx")
}
}

return parsedURL.String(), nil
}
77 changes: 33 additions & 44 deletions testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"fmt"
"net/url"
"os"
"testing"

Expand All @@ -12,68 +13,56 @@ import (
"github.com/stretchr/testify/require"
)

var (
testHost = os.Getenv("TEST_PGHOST")
testPort = os.Getenv("TEST_PGPORT")
testUser = os.Getenv("TEST_PGUSER")
testPassword = os.Getenv("TEST_PGPASSWORD")
testDatabase = os.Getenv("TEST_PGDATABASE")
testTemplateDatabase = os.Getenv("TEST_TEMPLATEDB")
)

func NewTestDB(t testing.TB) DB {
return NewTestDBWithLogger(t, log.NewNilLogger())
}

func NewTestDBWithLogger(t testing.TB, logger log.Logger) DB {
t.Helper()

rawDB, err := sql.Open("postgres", fmt.Sprintf(
"postgres://%s:%s@%s:%s/%s?sslmode=disable",
testUser,
testPassword,
testHost,
testPort,
testDatabase,
))
id, err := randomHexString(16)
require.NoError(t, err)
rawLoggingDB := newLoggingDB(rawDB, log.NewNilLogger())

id, err := randomHexString(16)
var (
testDatabaseName = fmt.Sprintf("pgutil-test-%s", id)
quotedTestDatabaseName = pq.QuoteIdentifier(testDatabaseName)
quotedTemplateDatabaseName = pq.QuoteIdentifier(os.Getenv("TEMPLATEDB"))

// NOTE: Must interpolate identifiers here as placeholders aren't valid in this position.
createDatabaseQuery = Query(fmt.Sprintf("CREATE DATABASE %s TEMPLATE %s", quotedTestDatabaseName, quotedTemplateDatabaseName), Args{})
dropDatabaseQuery = Query(fmt.Sprintf("DROP DATABASE %s", quotedTestDatabaseName), Args{})
terminateConnectionsQuery = Query("SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = {:name}", Args{"name": testDatabaseName})
)

// Resolve "control" database URL
baseURL := BuildDatabaseURL()
parsedURL, err := url.Parse(baseURL)
require.NoError(t, err)
testDatabaseName := fmt.Sprintf("pgutil-test-%s", id)

require.NoError(t, rawLoggingDB.Exec(context.Background(), Query(
// NOTE: Must interpolate identifier here as placeholders aren't valid in this position.
fmt.Sprintf("CREATE DATABASE %s TEMPLATE %s", pq.QuoteIdentifier(testDatabaseName), pq.QuoteIdentifier(testTemplateDatabase)),
Args{},
)))
// Resolve "test" database URL
testDBURL := parsedURL.ResolveReference(&url.URL{
Path: "/" + testDatabaseName,
RawQuery: parsedURL.RawQuery,
})

testDB, err := sql.Open("postgres", fmt.Sprintf(
"postgres://%s:%s@%s:%s/%s?sslmode=disable",
testUser,
testPassword,
testHost,
testPort,
testDatabaseName,
))
// Open "control" database
rawDB, err := sql.Open("postgres", baseURL)
require.NoError(t, err)
rawLoggingDB := newLoggingDB(rawDB, log.NewNilLogger())

// Create "test" database
require.NoError(t, rawLoggingDB.Exec(context.Background(), createDatabaseQuery))

// Open "test" database
testDB, err := sql.Open("postgres", testDBURL.String())
require.NoError(t, err)

t.Cleanup(func() {
defer rawDB.Close()

require.NoError(t, testDB.Close())

require.NoError(t, rawLoggingDB.Exec(context.Background(), Query(
"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = {:name}",
Args{"name": testDatabaseName},
)))

require.NoError(t, rawLoggingDB.Exec(context.Background(), Query(
// NOTE: Must interpolate identifier here as placeholders aren't valid in this position.
fmt.Sprintf("DROP DATABASE %s", pq.QuoteIdentifier(testDatabaseName)),
Args{},
)))
require.NoError(t, rawLoggingDB.Exec(context.Background(), terminateConnectionsQuery))
require.NoError(t, rawLoggingDB.Exec(context.Background(), dropDatabaseQuery))
})

return newLoggingDB(testDB, logger)
Expand Down
35 changes: 35 additions & 0 deletions url.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package pgutil

import (
"fmt"
"net/url"
"os"
)

func BuildDatabaseURL() string {
var (
host = getEnvOrDefault("PGHOST", "localhost")
port = getEnvOrDefault("PGPORT", "5432")
user = getEnvOrDefault("PGUSER", "")
password = getEnvOrDefault("PGPASSWORD", "")
database = getEnvOrDefault("PGDATABASE", "")
sslmode = getEnvOrDefault("PGSSLMODE", "disable")
)

u := &url.URL{
Scheme: "postgres",
Host: fmt.Sprintf("%s:%s", host, port),
User: url.UserPassword(user, password),
Path: database,
RawQuery: url.Values{"sslmode": []string{sslmode}}.Encode(),
}
return (u).String()
}

func getEnvOrDefault(key, defaultValue string) string {
if value := os.Getenv(key); value != "" {
return value
}

return defaultValue
}

0 comments on commit c8a00a6

Please sign in to comment.