/* * Copyright (c) 2020 Alex aka mailoman * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * @author Alex aka mailoman * @copyright Copyright (c) 2020 Alex aka mailoman * @since 08.01.2020 * */ package gormgoose import ( "errors" "fmt" "log" "os" "path/filepath" "sort" "strconv" "strings" "text/template" "time" "github.com/jinzhu/gorm" _ "github.com/jinzhu/gorm/dialects/mysql" _ "github.com/jinzhu/gorm/dialects/postgres" _ "github.com/jinzhu/gorm/dialects/sqlite" ) var ( ErrTableDoesNotExist = errors.New("table does not exist") ErrNoPreviousVersion = errors.New("no previous version found") ) type MigrationRecord struct { ID uint `gorm:"primary_key"` VersionId int64 TStamp time.Time `gorm:"default: now()"` IsApplied bool // was this a result of up() or down() } type Migration struct { Version int64 Next int64 // next version, or -1 if none Previous int64 // previous version, -1 if none Source string // path to .go or .sql script } type migrationSorter []*Migration // helpers so we can use pkg sort func (ms migrationSorter) Len() int { return len(ms) } func (ms migrationSorter) Swap(i, j int) { ms[i], ms[j] = ms[j], ms[i] } func (ms migrationSorter) Less(i, j int) bool { return ms[i].Version < ms[j].Version } func newMigration(v int64, src string) *Migration { return &Migration{v, -1, -1, src} } func RunMigrations(conf *DBConf, migrationsDir string, target int64) (err error) { db, err := OpenDBFromDBConf(conf) if err != nil { return err } defer db.Close() return RunMigrationsOnDb(conf, migrationsDir, target, db) } // Runs migration on a specific database instance. func RunMigrationsOnDb(conf *DBConf, migrationsDir string, target int64, db *gorm.DB) (err error) { current, err := EnsureDBVersion(conf, db) if err != nil { return err } migrations, err := CollectMigrations(migrationsDir, current, target) if err != nil { return err } if len(migrations) == 0 { fmt.Printf("goose: no migrations to run. current version: %d\n", current) return nil } ms := migrationSorter(migrations) direction := current < target ms.Sort(direction) fmt.Printf("goose: migrating db environment '%v', current version: %d, target: %d\n", conf.Env, current, target) for _, m := range ms { switch filepath.Ext(m.Source) { case ".go": err = runGoMigration(conf, m.Source, m.Version, direction) case ".sql": err = runSQLMigration(conf, db, m.Source, m.Version, direction) } if err != nil { return errors.New(fmt.Sprintf("FAIL %v, quitting migration", err)) } fmt.Println("OK ", filepath.Base(m.Source)) } return nil } // collect all the valid looking migration scripts in the // migrations folder, and key them by version func CollectMigrations(dirpath string, current, target int64) (m []*Migration, err error) { // extract the numeric component of each migration, // filter out any uninteresting files, // and ensure we only have one file per migration version. filepath.Walk(dirpath, func(name string, info os.FileInfo, err error) error { if v, e := NumericComponent(name); e == nil { for _, g := range m { if v == g.Version { log.Fatalf("more than one file specifies the migration for version %d (%s and %s)", v, g.Source, filepath.Join(dirpath, name)) } } if versionFilter(v, current, target) { m = append(m, newMigration(v, name)) } } return nil }) return m, nil } func versionFilter(v, current, target int64) bool { if target > current { return v > current && v <= target } if target < current { return v <= current && v > target } return false } func (ms migrationSorter) Sort(direction bool) { // sort ascending or descending by version if direction { sort.Sort(ms) } else { sort.Sort(sort.Reverse(ms)) } // now that we're sorted in the appropriate direction, // populate next and previous for each migration for i, m := range ms { prev := int64(-1) if i > 0 { prev = ms[i-1].Version ms[i-1].Next = m.Version } ms[i].Previous = prev } } // look for migration scripts with names in the form: // XXX_descriptivename.ext // where XXX specifies the version number // and ext specifies the type of migration func NumericComponent(name string) (int64, error) { base := filepath.Base(name) if ext := filepath.Ext(base); ext != ".go" && ext != ".sql" { return 0, errors.New("not a recognized migration file type") } idx := strings.Index(base, "_") if idx < 0 { return 0, errors.New("no separator found") } n, e := strconv.ParseInt(base[:idx], 10, 64) if e == nil && n <= 0 { return 0, errors.New("migration IDs must be greater than zero") } return n, e } // EnsureDBVersion retrieve the current version for this DB. // Create and initialize the DB version table if it doesn't exist. func EnsureDBVersion(conf *DBConf, db *gorm.DB) (int64, error) { rows := []MigrationRecord{} err := db.Order("id desc").Find(&rows).Error if err != nil { return 0, createVersionTable(conf, db) } // The most recent record for each migration specifies // whether it has been applied or rolled back. // The first version we find that has been applied is the current version. toSkip := make([]int64, 0) for _, row := range rows { // have we already marked this version to be skipped? skip := false for _, v := range toSkip { if v == row.VersionId { skip = true break } } if skip { continue } // if version has been applied we're done if row.IsApplied { return row.VersionId, nil } // latest version of migration has not been applied. toSkip = append(toSkip, row.VersionId) } panic("failure in EnsureDBVersion()") } // Create the goose_db_version table // and insert the initial 0 value into it func createVersionTable(conf *DBConf, db *gorm.DB) error { txn := db.Begin() if txn.Error != nil { return txn.Error } if err := txn.CreateTable(&MigrationRecord{}).Error; err != nil { txn.Rollback() return err } record := MigrationRecord{VersionId: 0, IsApplied: true} if err := txn.Create(&record).Error; err != nil { txn.Rollback() return err } return txn.Commit().Error } // wrapper for EnsureDBVersion for callers that don't already have // their own DB instance func GetDBVersion(conf *DBConf) (version int64, err error) { db, err := OpenDBFromDBConf(conf) if err != nil { return -1, err } defer db.Close() version, err = EnsureDBVersion(conf, db) if err != nil { return -1, err } return version, nil } func GetPreviousDBVersion(dirpath string, version int64) (previous int64, err error) { previous = -1 sawGivenVersion := false filepath.Walk(dirpath, func(name string, info os.FileInfo, walkerr error) error { if !info.IsDir() { if v, e := NumericComponent(name); e == nil { if v > previous && v < version { previous = v } if v == version { sawGivenVersion = true } } } return nil }) if previous == -1 { if sawGivenVersion { // the given version is (likely) valid but we didn't find // anything before it. // 'previous' must reflect that no migrations have been applied. previous = 0 } else { err = ErrNoPreviousVersion } } return } // helper to identify the most recent possible version // within a folder of migration scripts func GetMostRecentDBVersion(dirpath string) (version int64, err error) { version = -1 filepath.Walk(dirpath, func(name string, info os.FileInfo, walkerr error) error { if walkerr != nil { return walkerr } if !info.IsDir() { if v, e := NumericComponent(name); e == nil { if v > version { version = v } } } return nil }) if version == -1 { err = errors.New("no valid version found") } return } func CreateMigration(name, migrationType, dir string, t time.Time) (path string, err error) { if migrationType != "go" && migrationType != "sql" { return "", errors.New("migration type must be 'go' or 'sql'") } timestamp := t.Format("20060102150405") filename := fmt.Sprintf("%v_%v.%v", timestamp, name, migrationType) fpath := filepath.Join(dir, filename) var tmpl *template.Template if migrationType == "sql" { tmpl = sqlMigrationTemplate } else { tmpl = goMigrationTemplate } path, err = writeTemplateToFile(fpath, tmpl, timestamp) return } // FinalizeMigration update the version table for the given migration, // and finalize the transaction. func FinalizeMigration(conf *DBConf, txn *gorm.DB, direction bool, v int64) error { // XXX: drop goose_db_version table on some minimum version number? record := MigrationRecord{VersionId: v, IsApplied: direction} if err := txn.Create(&record).Error; err != nil { txn.Rollback() return err } return txn.Commit().Error } var goMigrationTemplate = template.Must(template.New("goose.go-migration").Parse(` package main import ( "github.com/jinzhu/gorm" ) // Up is executed when this migration is applied func Up_{{ . }}(txn *gorm.DB) { } // Down is executed when this migration is rolled back func Down_{{ . }}(txn *gorm.DB) { } `)) var sqlMigrationTemplate = template.Must(template.New("goose.sql-migration").Parse(` -- +goose Up -- SQL in section 'Up' is executed when this migration is applied -- +goose Down -- SQL section 'Down' is executed when this migration is rolled back `))