mirror of
https://github.com/fnproject/fn.git
synced 2022-10-28 21:29:17 +03:00
* initial Db helper split - make SQL and datastore packages optional * abstracting log store * break out DB, MQ and log drivers as extensions * cleanup * fewer deps * fixing docker test * hmm dbness * updating db startup * Consolidate all your extensions into one convenient package * cleanup * clean up dep constraints
344 lines
9.2 KiB
Go
344 lines
9.2 KiB
Go
package migratex
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"hash/crc32"
|
|
"sort"
|
|
"strings"
|
|
|
|
"github.com/fnproject/fn/api/datastore/sql/dbhelper"
|
|
"github.com/jmoiron/sqlx"
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
var (
|
|
// use same migration table name as mattes/migrate, so that we don't have to
|
|
// migrate that.
|
|
// TODO doesn't have to be a glob
|
|
MigrationsTable = "schema_migrations"
|
|
|
|
ErrLocked = errors.New("database is locked")
|
|
)
|
|
|
|
func migrateErr(version int64, up bool, err error) ErrMigration {
|
|
dir := "up"
|
|
if !up {
|
|
dir = "down"
|
|
}
|
|
return ErrMigration(fmt.Sprintf("error running migration. version: %v direction: %v err: %v", version, dir, err))
|
|
}
|
|
|
|
// ErrMigration represents an error running a specific migration in a specific direction
|
|
type ErrMigration string
|
|
|
|
func (e ErrMigration) Error() string { return string(e) }
|
|
|
|
func dirtyErr(version int64) ErrDirty {
|
|
return ErrDirty(fmt.Sprintf("database is dirty. version: %v", version))
|
|
}
|
|
|
|
// ErrDirty is an error that is returned when a db is dirty.
|
|
type ErrDirty string
|
|
|
|
func (e ErrDirty) Error() string { return string(e) }
|
|
|
|
const (
|
|
NilVersion = -1
|
|
)
|
|
|
|
type Migration interface {
|
|
Up(context.Context, *sqlx.Tx) error
|
|
Down(context.Context, *sqlx.Tx) error
|
|
Version() int64
|
|
}
|
|
|
|
type sorted []Migration
|
|
|
|
func (s sorted) Len() int { return len(s) }
|
|
func (s sorted) Less(i, j int) bool { return s[i].Version() < s[j].Version() }
|
|
func (s sorted) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
|
|
|
var _ Migration = new(MigFields)
|
|
|
|
// MigFields implements Migration and can be used for convenience.
|
|
type MigFields struct {
|
|
UpFunc func(context.Context, *sqlx.Tx) error
|
|
DownFunc func(context.Context, *sqlx.Tx) error
|
|
VersionFunc func() int64
|
|
}
|
|
|
|
func (m MigFields) Up(ctx context.Context, tx *sqlx.Tx) error { return m.UpFunc(ctx, tx) }
|
|
func (m MigFields) Down(ctx context.Context, tx *sqlx.Tx) error { return m.DownFunc(ctx, tx) }
|
|
func (m MigFields) Version() int64 { return m.VersionFunc() }
|
|
|
|
// TODO instance must have `multiStatements` set to true ?
|
|
|
|
func Up(ctx context.Context, tx *sqlx.Tx, migs []Migration) error {
|
|
return migrate(ctx, tx, migs, true)
|
|
}
|
|
|
|
func Down(ctx context.Context, tx *sqlx.Tx, migs []Migration) error {
|
|
return migrate(ctx, tx, migs, false)
|
|
}
|
|
|
|
func migrate(ctx context.Context, tx *sqlx.Tx, migs []Migration, up bool) error {
|
|
curVersion, dirty, err := Version(ctx, tx)
|
|
if dirty {
|
|
return dirtyErr(curVersion)
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// TODO we could grab the lock here and hold it over all the migrations,
|
|
// in methodology we want each migration to run in its own tx envelope
|
|
// so that we can make as much progress as possible if we hit an error.
|
|
// not sure it makes much difference either way where we lock.
|
|
|
|
if up {
|
|
sort.Sort(sorted(migs))
|
|
} else {
|
|
sort.Sort(sort.Reverse(sorted(migs)))
|
|
}
|
|
for _, m := range migs {
|
|
// skip over migrations we have run
|
|
mVersion := m.Version()
|
|
if (up && curVersion < mVersion) || (!up && curVersion >= mVersion) {
|
|
|
|
// do each individually, for large migrations it's better to checkpoint
|
|
// than to try to do them all in one big go.
|
|
// XXX(reed): we could more gracefully handle concurrent databases trying to
|
|
// run migrations here by handling error and feeding back the version.
|
|
// get something working mode for now...
|
|
err := run(ctx, tx, m, up)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func withLock(ctx context.Context, tx *sqlx.Tx, f func(*sqlx.Tx) error) error {
|
|
err := lock(ctx, tx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = f(tx)
|
|
|
|
// NOTE: migrations happen on init and if they fail we should close our session with the db
|
|
// which will release the lock, thus, we don't need to futz with the context here to unlock.
|
|
errU := unlock(ctx, tx)
|
|
|
|
if errU != nil {
|
|
if err == nil {
|
|
err = errU
|
|
} else {
|
|
err = multiError(err, errU)
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
|
|
var _ error = multiError()
|
|
|
|
// MultiError holds multiple errors. If you have to handle one of these... I am so sorry.
|
|
type MultiError struct {
|
|
Errs []error
|
|
}
|
|
|
|
func multiError(errs ...error) MultiError {
|
|
compactErrs := make([]error, 0)
|
|
for _, e := range errs {
|
|
if e != nil {
|
|
compactErrs = append(compactErrs, e)
|
|
}
|
|
}
|
|
return MultiError{compactErrs}
|
|
}
|
|
|
|
func (m MultiError) Error() string {
|
|
var strs = make([]string, 0)
|
|
for _, e := range m.Errs {
|
|
strs = append(strs, e.Error())
|
|
}
|
|
return strings.Join(strs, "\n")
|
|
}
|
|
|
|
func run(ctx context.Context, tx *sqlx.Tx, m Migration, up bool) error {
|
|
return withLock(ctx, tx, func(tx *sqlx.Tx) error {
|
|
// within the transaction, we need to check the version and ensure this
|
|
// migration has not already been applied.
|
|
curVersion, dirty, err := Version(ctx, tx)
|
|
if dirty {
|
|
return dirtyErr(curVersion)
|
|
}
|
|
|
|
// enforce monotonicity
|
|
if up && curVersion != NilVersion && m.Version() != curVersion+1 {
|
|
return fmt.Errorf("non-contiguous migration attempted up: %v != %v", m.Version(), curVersion+1)
|
|
} else if !up && m.Version() != curVersion { // down is always unraveling
|
|
return fmt.Errorf("non-contiguous migration attempted down: %v != %v", m.Version(), curVersion)
|
|
}
|
|
|
|
// TODO is this robust enough? we could check
|
|
version := m.Version()
|
|
if !up {
|
|
version = m.Version() - 1
|
|
}
|
|
|
|
// TODO we don't need the dirty bit anymore since we're using transactions?
|
|
err = SetVersion(ctx, tx, version, true)
|
|
|
|
if up {
|
|
err = m.Up(ctx, tx)
|
|
} else {
|
|
err = m.Down(ctx, tx)
|
|
}
|
|
|
|
if err != nil {
|
|
return migrateErr(version, up, err)
|
|
}
|
|
|
|
err = SetVersion(ctx, tx, version, false)
|
|
return err
|
|
})
|
|
}
|
|
|
|
const advisoryLockIdSalt uint = 1486364155
|
|
|
|
// inspired by rails migrations, see https://goo.gl/8o9bCT
|
|
// NOTE that this means if the db server has multiple databases that use this
|
|
// library then this can cause contention... it seems a far cry.
|
|
func generateAdvisoryLockId(name string) string {
|
|
sum := crc32.ChecksumIEEE([]byte(name))
|
|
sum = sum * uint32(advisoryLockIdSalt)
|
|
return fmt.Sprintf("%v", sum)
|
|
}
|
|
|
|
func lock(ctx context.Context, tx *sqlx.Tx) error {
|
|
aid := generateAdvisoryLockId(MigrationsTable)
|
|
|
|
// pg has special locking & sqlite3 needs no locking
|
|
var query string
|
|
switch tx.DriverName() {
|
|
case "postgres", "pgx", "pq-timeouts", "cloudsqlpostgres":
|
|
query = `SELECT pg_try_advisory_lock(?)`
|
|
case "mysql", "oci8", "ora", "goracle":
|
|
query = "SELECT GET_LOCK(?, -1)"
|
|
case "sqlite3":
|
|
// sqlite3 doesn't have a lock. as long as migrate isn't called concurrently it'll be ok
|
|
return nil
|
|
default:
|
|
return fmt.Errorf("unsupported database, please add this or fix: %v", tx.DriverName())
|
|
}
|
|
|
|
query = tx.Rebind(query)
|
|
|
|
var success bool
|
|
if err := tx.QueryRowContext(ctx, query, aid).Scan(&success); err != nil {
|
|
return err
|
|
}
|
|
|
|
if success {
|
|
return nil
|
|
}
|
|
|
|
return ErrLocked
|
|
}
|
|
|
|
func unlock(ctx context.Context, tx *sqlx.Tx) error {
|
|
aid := generateAdvisoryLockId(MigrationsTable)
|
|
|
|
var query string
|
|
switch tx.DriverName() {
|
|
case "postgres", "pgx", "pq-timeouts", "cloudsqlpostgres":
|
|
query = `SELECT pg_advisory_unlock(?)`
|
|
case "mysql", "oci8", "ora", "goracle":
|
|
query = `SELECT RELEASE_LOCK(?)`
|
|
case "sqlite3":
|
|
// sqlite3 doesn't have a lock. as long as migrate isn't called concurrently it'll be ok
|
|
return nil
|
|
default:
|
|
return fmt.Errorf("unsupported database, please add this or fix: %v", tx.DriverName())
|
|
}
|
|
|
|
query = tx.Rebind(query)
|
|
|
|
_, err := tx.ExecContext(ctx, query, aid)
|
|
return err
|
|
}
|
|
|
|
func SetVersion(ctx context.Context, tx *sqlx.Tx, version int64, dirty bool) error {
|
|
err := ensureVersionTable(ctx, tx)
|
|
if err != nil {
|
|
logrus.WithError(err).Error("error ensuring version table")
|
|
return err
|
|
}
|
|
|
|
// TODO need to handle down migration better
|
|
// ideally, we have a record of each up/down migration with a timestamp for auditing,
|
|
// this just nukes the whole table which is kinda lame.
|
|
query := tx.Rebind("DELETE FROM " + MigrationsTable)
|
|
if _, err := tx.Exec(query); err != nil {
|
|
logrus.WithError(err).Error("error deleting version table")
|
|
return err
|
|
}
|
|
|
|
if version >= 0 {
|
|
query = tx.Rebind(`INSERT INTO ` + MigrationsTable + ` (version, dirty) VALUES (?, ?)`)
|
|
if _, err := tx.ExecContext(ctx, query, version, dirty); err != nil {
|
|
logrus.WithError(err).Error("error updating version table")
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func Version(ctx context.Context, tx *sqlx.Tx) (version int64, dirty bool, err error) {
|
|
helper, ok := dbhelper.GetHelper(tx.DriverName())
|
|
if !ok {
|
|
return 0, false, fmt.Errorf("no db helper registered for for %s", tx.DriverName())
|
|
}
|
|
|
|
tableExists, err := helper.CheckTableExists(tx, MigrationsTable)
|
|
|
|
if err != nil {
|
|
return 0, false, err
|
|
}
|
|
|
|
if !tableExists {
|
|
return NilVersion, false, nil
|
|
}
|
|
|
|
query := tx.Rebind(`SELECT version, dirty FROM ` + MigrationsTable + ` LIMIT 1`)
|
|
|
|
err = tx.QueryRowContext(ctx, query).Scan(&version, &dirty)
|
|
switch {
|
|
case err == sql.ErrNoRows:
|
|
return NilVersion, false, nil
|
|
|
|
case err != nil:
|
|
return 0, false, err
|
|
|
|
default:
|
|
return version, dirty, nil
|
|
}
|
|
}
|
|
|
|
func ensureVersionTable(ctx context.Context, tx *sqlx.Tx) error {
|
|
// TODO it would sure be nice to have timestamps for auditing
|
|
// TODO sqlite3 uses uint64 type? ugha, test.
|
|
query := tx.Rebind(fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %v (
|
|
version bigint NOT NULL PRIMARY KEY,
|
|
dirty boolean NOT NULL
|
|
)`, MigrationsTable))
|
|
_, err := tx.ExecContext(ctx, query)
|
|
return err
|
|
}
|