forked from pressly/goose
-
Notifications
You must be signed in to change notification settings - Fork 0
/
migration.go
398 lines (360 loc) · 11.4 KB
/
migration.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
package goose
import (
"context"
"database/sql"
"errors"
"fmt"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/pressly/goose/v3/internal/sqlparser"
)
// NewGoMigration returns a Go migration for the given version that runs up when migrating up and
// down when migrating down.
//
// Either function may be nil. A nil function still records (up) or removes (down) the version in
// the versions table; it simply runs no Go code. See [GoFunc] for more details.
//
// When a supplied GoFunc has a zero Mode, the mode is inferred here and written back to that
// GoFunc: TransactionEnabled for RunTx (or when neither function is set), TransactionDisabled for
// RunDB. No validation is performed here; it is done lazily when the migration is registered.
func NewGoMigration(version int64, up, down *GoFunc) *Migration {
	migration := &Migration{
		Type:       TypeGo,
		Registered: true,
		Version:    version,
		Next:       -1,
		Previous:   -1,
		// Placeholders kept when the caller passes nil for a direction.
		goUp:      &GoFunc{Mode: TransactionEnabled},
		goDown:    &GoFunc{Mode: TransactionEnabled},
		construct: true,
	}

	// withInferredMode fills in fn.Mode when the caller left it unset. If both RunTx and RunDB are
	// set, Mode stays zero (NOTE(review): presumably rejected at registration — confirm).
	withInferredMode := func(fn *GoFunc) *GoFunc {
		if fn.Mode != 0 {
			return fn
		}
		hasTx, hasDB := fn.RunTx != nil, fn.RunDB != nil
		switch {
		case hasTx && !hasDB:
			fn.Mode = TransactionEnabled
		case !hasTx && hasDB:
			fn.Mode = TransactionDisabled
		case !hasTx && !hasDB:
			// Nothing to run: default to the most common mode.
			fn.Mode = TransactionEnabled
		}
		return fn
	}

	// The legacy (pre-GoFunc) fields are populated too, for backwards compatibility; they are
	// slated for removal in a future major version in favor of [GoFunc].
	if up != nil {
		migration.goUp = withInferredMode(up)
		if runDB := up.RunDB; runDB != nil {
			migration.UpFnNoTxContext = runDB          // func(context.Context, *sql.DB) error
			migration.UpFnNoTx = withoutContext(runDB) // func(*sql.DB) error
		}
		if runTx := up.RunTx; runTx != nil {
			migration.UseTx = true
			migration.UpFnContext = runTx          // func(context.Context, *sql.Tx) error
			migration.UpFn = withoutContext(runTx) // func(*sql.Tx) error
		}
	}
	if down != nil {
		migration.goDown = withInferredMode(down)
		if runDB := down.RunDB; runDB != nil {
			migration.DownFnNoTxContext = runDB          // func(context.Context, *sql.DB) error
			migration.DownFnNoTx = withoutContext(runDB) // func(*sql.DB) error
		}
		if runTx := down.RunTx; runTx != nil {
			migration.UseTx = true
			migration.DownFnContext = runTx          // func(context.Context, *sql.Tx) error
			migration.DownFn = withoutContext(runTx) // func(*sql.Tx) error
		}
	}
	return migration
}
// Migration struct represents either a SQL or Go migration.
//
// Avoid constructing migrations manually, use [NewGoMigration] function. A migration is executed
// with Up/UpContext (up direction) or Down/DownContext (down direction).
type Migration struct {
// Type is the kind of migration; [NewGoMigration] sets it to TypeGo.
Type MigrationType
// Version is the migration version. When the migration runs with versioning enabled, this is
// the value inserted into (up) or deleted from (down) the versions table.
Version int64
// Source is the path to the .sql script or .go file. It may be empty for Go migrations that
// have been registered globally and don't have a source file.
Source string
// UpFnContext and DownFnContext are the Go functions run inside a transaction when UseTx is
// true. A nil function means no Go code runs for that direction.
UpFnContext, DownFnContext GoMigrationContext
// UpFnNoTxContext and DownFnNoTxContext are the Go functions run without a transaction when
// UseTx is false. A nil function means no Go code runs for that direction.
UpFnNoTxContext, DownFnNoTxContext GoMigrationNoTxContext
// These fields will be removed in a future major version. They are here for backwards
// compatibility and are an implementation detail.
Registered bool // a .go migration refuses to run unless this is true
UseTx bool // selects the transactional (*Context) or non-transactional (*NoTxContext) Go functions
Next int64 // next version, or -1 if none
Previous int64 // previous version, -1 if none
// We still save the non-context versions in the struct in case someone is using them. Goose
// does not use these internally anymore in favor of the context-aware versions. These fields
// will be removed in a future major version.
UpFn GoMigration // Deprecated: use UpFnContext instead.
DownFn GoMigration // Deprecated: use DownFnContext instead.
UpFnNoTx GoMigrationNoTx // Deprecated: use UpFnNoTxContext instead.
DownFnNoTx GoMigrationNoTx // Deprecated: use DownFnNoTxContext instead.
// noVersioning, when true, runs the migration without updating the versions table.
noVersioning bool
// These fields are used internally by goose and users are not expected to set them. Instead,
// use [NewGoMigration] to create a new go migration.
//
// construct is set to true by [NewGoMigration]. NOTE(review): presumably used elsewhere to tell
// constructor-built migrations apart from hand-built ones — confirm.
construct bool
// goUp and goDown hold the [GoFunc] values passed to [NewGoMigration] (with Mode inferred when
// it was zero), or TransactionEnabled placeholders when nil was passed.
goUp, goDown *GoFunc
// sql holds lazily parsed SQL migration contents; see sqlMigration.
sql sqlMigration
}
// sqlMigration holds the lazily parsed contents of a SQL migration.
type sqlMigration struct {
// The Parsed field is used to track whether the SQL migration has been parsed. It serves as an
// optimization to avoid parsing migrations that may never be needed. Typically, migrations are
// incremental, and users often run only the most recent ones, making parsing of prior
// migrations unnecessary in most cases.
Parsed bool
// Parsed must be set to true before the following fields are used.
//
// UseTx reports whether the statements should run inside a transaction. NOTE(review):
// presumably taken from the migration file's annotations by the SQL parser — confirm.
UseTx bool
// Up and Down hold the SQL statements for each direction (presumably already split into
// individual statements by the parser — confirm).
Up []string
Down []string
}
// GoFunc represents a Go migration function.
//
// One GoFunc is passed to [NewGoMigration] per direction (up and down).
type GoFunc struct {
// Exactly one of these must be set, or both must be nil.
//
// RunTx runs the migration inside a transaction; it receives the *sql.Tx.
RunTx func(ctx context.Context, tx *sql.Tx) error
// -- OR --
// RunDB runs the migration directly against the *sql.DB, outside a transaction.
RunDB func(ctx context.Context, db *sql.DB) error
// Mode is the transaction mode for the migration. Users do not need to set this field when
// supplying a run function: when Mode is left as zero, [NewGoMigration] infers it from the
// function (TransactionEnabled for RunTx, TransactionDisabled for RunDB) and writes the result
// back into this GoFunc. A non-zero Mode is kept as provided. NOTE(review): agreement between an
// explicit Mode and the supplied function is presumably checked when the migration is
// registered — confirm.
//
// If both run functions are nil, the mode defaults to TransactionEnabled. The use case for nil
// functions is to record a version in the version table without invoking a Go migration
// function.
//
// The only time this field is required is if BOTH run functions are nil AND you want to
// override the default transaction mode.
Mode TransactionMode
}
// TransactionMode represents the possible transaction modes for a migration. The zero value is
// not a valid mode: it means "not set", which lets [NewGoMigration] infer the mode from the
// supplied run function.
type TransactionMode int

const (
	// TransactionEnabled is the mode of migrations that run inside a transaction (GoFunc.RunTx).
	TransactionEnabled TransactionMode = iota + 1
	// TransactionDisabled is the mode of migrations that run directly against the database,
	// outside a transaction (GoFunc.RunDB).
	TransactionDisabled
)

// String returns the snake_case name of the mode. Unknown modes, including the zero value,
// are rendered with their numeric value.
func (m TransactionMode) String() string {
	if m == TransactionEnabled {
		return "transaction_enabled"
	}
	if m == TransactionDisabled {
		return "transaction_disabled"
	}
	return fmt.Sprintf("unknown transaction mode (%d)", m)
}
// MigrationRecord struct. NOTE(review): appears to model one row of the versions table —
// confirm.
//
// Deprecated: unused and will be removed in a future major version.
type MigrationRecord struct {
VersionID int64 // presumably the migration version of the row — confirm
TStamp time.Time // presumably when the row was written — confirm
IsApplied bool // was this a result of up() or down()
}
// String returns the migration's Source path unchanged (empty for Go migrations without a
// source file).
func (m *Migration) String() string {
	return m.Source
}
// Up runs an up migration using a background context; it is shorthand for
// m.UpContext(context.Background(), db).
func (m *Migration) Up(db *sql.DB) error {
	return m.UpContext(context.Background(), db)
}
// UpContext runs an up migration and returns any error produced while running it.
func (m *Migration) UpContext(ctx context.Context, db *sql.DB) error {
	const up = true
	return m.run(ctx, db, up)
}
// Down runs a down migration using a background context; it is shorthand for
// m.DownContext(context.Background(), db).
func (m *Migration) Down(db *sql.DB) error {
	return m.DownContext(context.Background(), db)
}
// DownContext runs a down migration and returns any error produced while running it.
func (m *Migration) DownContext(ctx context.Context, db *sql.DB) error {
	const up = false
	return m.run(ctx, db, up)
}
// run executes the migration in the given direction and logs the outcome.
//
// direction selects what runs: true runs the up migration, false runs the down migration.
// The kind of migration is determined as follows:
//
//   - A ".sql" Source is opened from baseFS and parsed for the requested direction. Its statements
//     are then executed by runSQLMigration.
//   - A ".go" Source, or a Go migration (Type == TypeGo) with an empty Source such as one built by
//     [NewGoMigration], runs the registered Go function for the direction. The function runs inside
//     a transaction via runGoMigration when m.UseTx is set, and via runGoMigrationNoTx otherwise.
//
// Unless m.noVersioning is set, the versions table is updated as part of the run. Any other
// Source is rejected with an error rather than being reported as a successful no-op. That no-op
// would leave the migration unapplied and its version unrecorded.
func (m *Migration) run(ctx context.Context, db *sql.DB, direction bool) error {
	// name identifies the migration in log lines and error messages. A Go migration without a
	// source file would otherwise show up as "." (filepath.Base of an empty path).
	name := filepath.Base(m.Source)
	if m.Source == "" {
		name = m.ref()
	}

	ext := filepath.Ext(m.Source)
	switch {
	case ext == ".sql":
		f, err := baseFS.Open(m.Source)
		if err != nil {
			return fmt.Errorf("ERROR %v: failed to open SQL migration file: %w", name, err)
		}
		defer f.Close()

		statements, useTx, err := sqlparser.ParseSQLMigration(f, sqlparser.FromBool(direction), verbose)
		if err != nil {
			return fmt.Errorf("ERROR %v: failed to parse SQL migration file: %w", name, err)
		}

		start := time.Now()
		if err := runSQLMigration(ctx, db, statements, useTx, m.Version, direction, m.noVersioning); err != nil {
			return fmt.Errorf("ERROR %v: failed to run SQL migration: %w", name, err)
		}
		finish := truncateDuration(time.Since(start))

		if len(statements) > 0 {
			log.Printf("OK %s (%s)\n", name, finish)
		} else {
			log.Printf("EMPTY %s (%s)\n", name, finish)
		}

	case ext == ".go" || (m.Source == "" && m.Type == TypeGo):
		if !m.Registered {
			return fmt.Errorf("ERROR %v: failed to run Go migration: Go functions must be registered and built into a custom binary (see https://github.com/pressly/goose/tree/master/examples/go-migrations)", m.Source)
		}
		start := time.Now()
		var empty bool
		if m.UseTx {
			// Run go-based migration inside a tx.
			fn := m.DownFnContext
			if direction {
				fn = m.UpFnContext
			}
			empty = fn == nil
			if err := runGoMigration(
				ctx,
				db,
				fn,
				m.Version,
				direction,
				!m.noVersioning,
			); err != nil {
				return fmt.Errorf("ERROR go migration: %q: %w", name, err)
			}
		} else {
			// Run go-based migration outside a tx.
			fn := m.DownFnNoTxContext
			if direction {
				fn = m.UpFnNoTxContext
			}
			empty = fn == nil
			if err := runGoMigrationNoTx(
				ctx,
				db,
				fn,
				m.Version,
				direction,
				!m.noVersioning,
			); err != nil {
				return fmt.Errorf("ERROR go migration no tx: %q: %w", name, err)
			}
		}
		finish := truncateDuration(time.Since(start))
		if !empty {
			log.Printf("OK %s (%s)\n", name, finish)
		} else {
			log.Printf("EMPTY %s (%s)\n", name, finish)
		}

	default:
		// Previously this fell through and returned nil: nothing ran, no version was recorded,
		// yet the caller saw success.
		return fmt.Errorf("ERROR %v: failed to run migration: unsupported source %q (expected a .sql or .go file)", name, m.Source)
	}
	return nil
}
// runGoMigrationNoTx runs fn directly against db, without a transaction. When recordVersion is
// true, it then records (direction == true) or removes (direction == false) version in the
// versions table. A nil fn skips straight to the version bookkeeping.
//
// No transaction spans the two steps. If updating the versions table fails, the effects of fn
// remain in place.
func runGoMigrationNoTx(
	ctx context.Context,
	db *sql.DB,
	fn GoMigrationNoTxContext,
	version int64,
	direction bool,
	recordVersion bool,
) error {
	var runErr error
	if fn != nil {
		runErr = fn(ctx, db)
	}
	switch {
	case runErr != nil:
		return fmt.Errorf("failed to run go migration: %w", runErr)
	case recordVersion:
		return insertOrDeleteVersionNoTx(ctx, db, version, direction)
	default:
		return nil
	}
}
// runGoMigration runs fn inside a single database transaction. When recordVersion is true, it
// also records (direction == true) or removes (direction == false) version in the versions table
// within that same transaction. The migration and its version bookkeeping therefore commit or
// roll back together.
//
// A nil fn only updates the versions table. When fn is nil and recordVersion is false there is
// nothing to do, so no transaction is started.
//
// The transaction is rolled back on every path that does not reach a successful Commit. This
// includes a panic raised by fn, so the pooled connection is never left holding an open
// transaction.
func runGoMigration(
	ctx context.Context,
	db *sql.DB,
	fn GoMigrationContext,
	version int64,
	direction bool,
	recordVersion bool,
) error {
	if fn == nil && !recordVersion {
		return nil
	}
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	// After a successful Commit, Rollback is a no-op returning sql.ErrTxDone, so it is safe to
	// defer unconditionally. Its error is deliberately ignored: the caller needs the original
	// failure (or success), not the outcome of the cleanup.
	defer func() { _ = tx.Rollback() }()

	if fn != nil {
		// Run go migration function.
		if err := fn(ctx, tx); err != nil {
			return fmt.Errorf("failed to run go migration: %w", err)
		}
	}
	if recordVersion {
		if err := insertOrDeleteVersion(ctx, tx, version, direction); err != nil {
			return fmt.Errorf("failed to update version: %w", err)
		}
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit transaction: %w", err)
	}
	return nil
}
func insertOrDeleteVersion(ctx context.Context, tx *sql.Tx, version int64, direction bool) error {
if direction {
return store.InsertVersion(ctx, tx, TableName(), version)
}
return store.DeleteVersion(ctx, tx, TableName(), version)
}
func insertOrDeleteVersionNoTx(ctx context.Context, db *sql.DB, version int64, direction bool) error {
if direction {
return store.InsertVersionNoTx(ctx, db, TableName(), version)
}
return store.DeleteVersionNoTx(ctx, db, TableName(), version)
}
// NumericComponent parses the version from the migration file name.
//
// Only the base name of filename is examined. It must look like XXX_descriptivename.ext. Here
// XXX is a base-10 version number (leading zeros allowed) that must be greater than zero, and ext
// is either .sql or .go. Everything before the first '_' is the version.
func NumericComponent(filename string) (int64, error) {
	name := filepath.Base(filename)

	switch filepath.Ext(name) {
	case ".go", ".sql":
		// supported migration types
	default:
		return 0, errors.New("migration file does not have .sql or .go file extension")
	}

	sep := strings.IndexByte(name, '_')
	if sep == -1 {
		return 0, errors.New("no filename separator '_' found")
	}

	version, err := strconv.ParseInt(name[:sep], 10, 64)
	if err != nil {
		return 0, fmt.Errorf("failed to parse version from migration file: %s: %w", name, err)
	}
	if version <= 0 {
		return 0, errors.New("migration version must be greater than zero")
	}
	return version, nil
}
// truncateDuration rounds d for display, keeping roughly two decimal places in its largest unit.
// Values over a second are rounded to 10ms, values over a millisecond to 10µs, and values over a
// microsecond to 10ns. Anything smaller, including zero and negative durations, is returned
// unchanged.
func truncateDuration(d time.Duration) time.Duration {
	units := [...]time.Duration{time.Second, time.Millisecond, time.Microsecond}
	for i := 0; i < len(units); i++ {
		if unit := units[i]; d > unit {
			return d.Round(unit / 100)
		}
	}
	return d
}
// ref returns a short identifier for the migration, built from its Type and Version, in the form
// "(type:<type>,version:<version>)". It is meant for log lines and error messages.
func (m *Migration) ref() string {
	version := strconv.FormatInt(m.Version, 10)
	return fmt.Sprintf("(type:%s,version:%s)", m.Type, version)
}