diff --git a/database/cassandra/cassandra.go b/database/cassandra/cassandra.go index 48b4a693b..87abd2e66 100644 --- a/database/cassandra/cassandra.go +++ b/database/cassandra/cassandra.go @@ -12,6 +12,7 @@ import ( "github.com/gocql/gocql" "github.com/golang-migrate/migrate/v4/database" + "github.com/hashicorp/go-multierror" ) func init() { @@ -240,13 +241,29 @@ func (c *Cassandra) Drop() error { return err } } - // Re-create the version table - return c.ensureVersionTable() + + return nil } -// Ensure version table exists -func (c *Cassandra) ensureVersionTable() error { - err := c.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", c.config.MigrationsTable)).Exec() +// ensureVersionTable checks if versions table exists and, if not, creates it. +// Note that this function locks the database, which deviates from the usual +// convention of "caller locks" in the Cassandra type. +func (c *Cassandra) ensureVersionTable() (err error) { + if err = c.Lock(); err != nil { + return err + } + + defer func() { + if e := c.Unlock(); e != nil { + if err == nil { + err = e + } else { + err = multierror.Append(err, e) + } + } + }() + + err = c.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", c.config.MigrationsTable)).Exec() if err != nil { return err } diff --git a/database/cassandra/cassandra_test.go b/database/cassandra/cassandra_test.go index 3e60c489c..3389f109e 100644 --- a/database/cassandra/cassandra_test.go +++ b/database/cassandra/cassandra_test.go @@ -3,6 +3,7 @@ package cassandra import ( "context" "fmt" + "github.com/golang-migrate/migrate/v4" "strconv" "testing" ) @@ -15,6 +16,7 @@ import ( import ( dt "github.com/golang-migrate/migrate/v4/database/testing" "github.com/golang-migrate/migrate/v4/dktesting" + _ "github.com/golang-migrate/migrate/v4/source/file" ) var ( @@ -72,3 +74,25 @@ func Test(t *testing.T) { dt.Test(t, d, []byte("SELECT table_name from system_schema.tables")) }) } + +func TestMigrate(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ip, port, err := c.Port(9042) + if err != nil { + t.Fatal("Unable to get mapped port:", err) + } + addr := fmt.Sprintf("cassandra://%v:%v/testks", ip, port) + p := &Cassandra{} + d, err := p.Open(addr) + if err != nil { + t.Fatalf("%v", err) + } + defer d.Close() + + m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "testks", d) + if err != nil { + t.Fatalf("%v", err) + } + dt.TestMigrate(t, m, []byte("SELECT table_name from system_schema.tables")) + }) +} diff --git a/database/cassandra/examples/migrations/1_simple_select.down.sql b/database/cassandra/examples/migrations/1_simple_select.down.sql new file mode 100644 index 000000000..29787f084 --- /dev/null +++ b/database/cassandra/examples/migrations/1_simple_select.down.sql @@ -0,0 +1 @@ +SELECT table_name from system_schema.tables \ No newline at end of file diff --git a/database/cassandra/examples/migrations/1_simple_select.up.sql b/database/cassandra/examples/migrations/1_simple_select.up.sql new file mode 100644 index 000000000..29787f084 --- /dev/null +++ b/database/cassandra/examples/migrations/1_simple_select.up.sql @@ -0,0 +1 @@ +SELECT table_name from system_schema.tables \ No newline at end of file diff --git a/database/clickhouse/clickhouse.go b/database/clickhouse/clickhouse.go index 6f98bd181..ebf5b17d6 100644 --- a/database/clickhouse/clickhouse.go +++ b/database/clickhouse/clickhouse.go @@ -11,6 +11,7 @@ import ( "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database" + "github.com/hashicorp/go-multierror" ) var DefaultMigrationsTable = "schema_migrations" @@ -159,7 +160,25 @@ func (ch *ClickHouse) SetVersion(version int, dirty bool) error { return tx.Commit() } -func (ch *ClickHouse) ensureVersionTable() error { + +// ensureVersionTable checks if versions table exists and, if not, creates it. +// Note that this function locks the database, which deviates from the usual +// convention of "caller locks" in the ClickHouse type. +func (ch *ClickHouse) ensureVersionTable() (err error) { + if err = ch.Lock(); err != nil { + return err + } + + defer func() { + if e := ch.Unlock(); e != nil { + if err == nil { + err = e + } else { + err = multierror.Append(err, e) + } + } + }() + var ( table string query = "SHOW TABLES FROM " + ch.config.DatabaseName + " LIKE '" + ch.config.MigrationsTable + "'" @@ -207,7 +226,7 @@ func (ch *ClickHouse) Drop() error { return &database.Error{OrigErr: err, Query: []byte(query)} } } - return ch.ensureVersionTable() + return nil } func (ch *ClickHouse) Lock() error { return nil } diff --git a/database/cockroachdb/cockroachdb.go b/database/cockroachdb/cockroachdb.go index df32db0d8..c18edcd7a 100644 --- a/database/cockroachdb/cockroachdb.go +++ b/database/cockroachdb/cockroachdb.go @@ -13,6 +13,7 @@ import ( import ( "github.com/cockroachdb/cockroach-go/crdb" + "github.com/hashicorp/go-multierror" "github.com/lib/pq" ) @@ -85,11 +86,12 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { config: config, } - if err := px.ensureVersionTable(); err != nil { + // ensureVersionTable is a locking operation, so we need to ensureLockTable before we ensureVersionTable. + if err := px.ensureLockTable(); err != nil { return nil, err } - if err := px.ensureLockTable(); err != nil { + if err := px.ensureVersionTable(); err != nil { return nil, err } @@ -294,15 +296,29 @@ func (c *CockroachDb) Drop() error { return &database.Error{OrigErr: err, Query: []byte(query)} } } - if err := c.ensureVersionTable(); err != nil { - return err - } } return nil } -func (c *CockroachDb) ensureVersionTable() error { +// ensureVersionTable checks if versions table exists and, if not, creates it. +// Note that this function locks the database, which deviates from the usual +// convention of "caller locks" in the CockroachDb type. +func (c *CockroachDb) ensureVersionTable() (err error) { + if err = c.Lock(); err != nil { + return err + } + + defer func() { + if e := c.Unlock(); e != nil { + if err == nil { + err = e + } else { + err = multierror.Append(err, e) + } + } + }() + // check if migration table exists var count int query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` diff --git a/database/cockroachdb/cockroachdb_test.go b/database/cockroachdb/cockroachdb_test.go index 157a6eb4a..4f5570e0b 100644 --- a/database/cockroachdb/cockroachdb_test.go +++ b/database/cockroachdb/cockroachdb_test.go @@ -6,6 +6,7 @@ import ( "context" "database/sql" "fmt" + "github.com/golang-migrate/migrate/v4" "strings" "testing" ) @@ -18,6 +19,7 @@ import ( import ( dt "github.com/golang-migrate/migrate/v4/database/testing" "github.com/golang-migrate/migrate/v4/dktesting" + _ "github.com/golang-migrate/migrate/v4/source/file" ) const defaultPort = 26257 @@ -92,6 +94,30 @@ func Test(t *testing.T) { }) } +func TestMigrate(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, ci dktest.ContainerInfo) { + createDB(t, ci) + + ip, port, err := ci.Port(26257) + if err != nil { + t.Fatal(err) + } + + addr := fmt.Sprintf("cockroach://root@%v:%v/migrate?sslmode=disable", ip, port) + c := &CockroachDb{} + d, err := c.Open(addr) + if err != nil { + t.Fatalf("%v", err) + } + + m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "migrate", d) + if err != nil { + t.Fatalf("%v", err) + } + dt.TestMigrate(t, m, []byte("SELECT 1")) + }) +} + func TestMultiStatement(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, ci dktest.ContainerInfo) { createDB(t, ci) diff --git a/database/driver.go b/database/driver.go index fa914c5d0..901e5dd66 100644 --- a/database/driver.go +++ b/database/driver.go @@ -74,6 +74,8 @@ type Driver interface { Version() (version int, dirty bool, err error) // Drop deletes everything in the database. + // Note that this is a breaking action, a new call to Open() is necessary to + // ensure subsequent calls work as expected. Drop() error } diff --git a/database/mongodb/examples/001_create_user.down.json b/database/mongodb/examples/migrations/001_create_user.down.json similarity index 100% rename from database/mongodb/examples/001_create_user.down.json rename to database/mongodb/examples/migrations/001_create_user.down.json diff --git a/database/mongodb/examples/001_create_user.up.json b/database/mongodb/examples/migrations/001_create_user.up.json similarity index 100% rename from database/mongodb/examples/001_create_user.up.json rename to database/mongodb/examples/migrations/001_create_user.up.json diff --git a/database/mongodb/examples/002_create_indexes.down.json b/database/mongodb/examples/migrations/002_create_indexes.down.json similarity index 100% rename from database/mongodb/examples/002_create_indexes.down.json rename to database/mongodb/examples/migrations/002_create_indexes.down.json diff --git a/database/mongodb/examples/002_create_indexes.up.json b/database/mongodb/examples/migrations/002_create_indexes.up.json similarity index 100% rename from database/mongodb/examples/002_create_indexes.up.json rename to database/mongodb/examples/migrations/002_create_indexes.up.json diff --git a/database/mongodb/mongodb.go b/database/mongodb/mongodb.go index 58e152e58..d6986ee23 100644 --- a/database/mongodb/mongodb.go +++ b/database/mongodb/mongodb.go @@ -59,6 +59,7 @@ func WithInstance(instance *mongo.Client, config *Config) (database.Driver, erro db: instance.Database(config.DatabaseName), config: config, } + return mc, nil } @@ -77,9 +78,6 @@ func (m *Mongo) Open(dsn string) (database.Driver, error) { return nil, err } migrationsCollection := purl.Query().Get("x-migrations-collection") - if len(migrationsCollection) == 0 { - migrationsCollection = DefaultMigrationsCollection - } transactionMode, _ := strconv.ParseBool(purl.Query().Get("x-transaction-mode")) diff --git a/database/mongodb/mongodb_test.go b/database/mongodb/mongodb_test.go index 6acbde3b7..260fd1c18 100644 --- a/database/mongodb/mongodb_test.go +++ b/database/mongodb/mongodb_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "github.com/golang-migrate/migrate/v4" "io" "os" "strconv" @@ -20,6 +21,7 @@ import ( import ( dt "github.com/golang-migrate/migrate/v4/database/testing" "github.com/golang-migrate/migrate/v4/dktesting" + _ "github.com/golang-migrate/migrate/v4/source/file" ) var ( @@ -83,6 +85,28 @@ func Test(t *testing.T) { }) } +func TestMigrate(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ip, port, err := c.FirstPort() + if err != nil { + t.Fatal(err) + } + + addr := mongoConnectionString(ip, port) + p := &Mongo{} + d, err := p.Open(addr) + if err != nil { + t.Fatalf("%v", err) + } + defer d.Close() + m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "", d) + if err != nil { + t.Fatalf("%v", err) + } + dt.TestMigrate(t, m, []byte(`[{"insert":"hello","documents":[{"wild":"world"}]}]`)) + }) +} + func TestWithAuth(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { ip, port, err := c.FirstPort() diff --git a/database/mysql/examples/migrations/1_init.down.sql b/database/mysql/examples/migrations/1_init.down.sql new file mode 100644 index 000000000..1b10e6fc0 --- /dev/null +++ b/database/mysql/examples/migrations/1_init.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS test; \ No newline at end of file diff --git a/database/mysql/examples/migrations/1_init.up.sql b/database/mysql/examples/migrations/1_init.up.sql new file mode 100644 index 000000000..2c3d7a1f2 --- /dev/null +++ b/database/mysql/examples/migrations/1_init.up.sql @@ -0,0 +1,3 @@ +CREATE TABLE IF NOT EXISTS test ( + firstname VARCHAR(16) +); \ No newline at end of file diff --git a/database/mysql/mysql.go b/database/mysql/mysql.go index 20c840e02..ccd6f3703 100644 --- a/database/mysql/mysql.go +++ b/database/mysql/mysql.go @@ -17,6 +17,7 @@ import ( import ( "github.com/go-sql-driver/mysql" + "github.com/hashicorp/go-multierror" ) import ( @@ -127,9 +128,6 @@ func (m *Mysql) Open(url string) (database.Driver, error) { purl.RawQuery = q.Encode() migrationsTable := purl.Query().Get("x-migrations-table") - if len(migrationsTable) == 0 { - migrationsTable = DefaultMigrationsTable - } // use custom TLS? ctls := purl.Query().Get("tls") @@ -342,15 +340,29 @@ func (m *Mysql) Drop() error { return &database.Error{OrigErr: err, Query: []byte(query)} } } - if err := m.ensureVersionTable(); err != nil { - return err - } } return nil } -func (m *Mysql) ensureVersionTable() error { +// ensureVersionTable checks if versions table exists and, if not, creates it. +// Note that this function locks the database, which deviates from the usual +// convention of "caller locks" in the Mysql type. +func (m *Mysql) ensureVersionTable() (err error) { + if err = m.Lock(); err != nil { + return err + } + + defer func() { + if e := m.Unlock(); e != nil { + if err == nil { + err = e + } else { + err = multierror.Append(err, e) + } + } + }() + // check if migration table exists var result string query := `SHOW TABLES LIKE "` + m.config.MigrationsTable + `"` diff --git a/database/mysql/mysql_test.go b/database/mysql/mysql_test.go index aa68a3754..a71383484 100644 --- a/database/mysql/mysql_test.go +++ b/database/mysql/mysql_test.go @@ -5,6 +5,7 @@ import ( "database/sql" sqldriver "database/sql/driver" "fmt" + "github.com/golang-migrate/migrate/v4" "net/url" "testing" ) @@ -17,6 +18,7 @@ import ( import ( dt "github.com/golang-migrate/migrate/v4/database/testing" "github.com/golang-migrate/migrate/v4/dktesting" + _ "github.com/golang-migrate/migrate/v4/source/file" ) const defaultPort = 3306 @@ -88,6 +90,40 @@ func Test(t *testing.T) { }) } +func TestMigrate(t *testing.T) { + // mysql.SetLogger(mysql.Logger(log.New(ioutil.Discard, "", log.Ltime))) + + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ip, port, err := c.Port(defaultPort) + if err != nil { + t.Fatal(err) + } + + addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port) + p := &Mysql{} + d, err := p.Open(addr) + if err != nil { + t.Fatalf("%v", err) + } + defer d.Close() + + m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d) + if err != nil { + t.Fatalf("%v", err) + } + dt.TestMigrate(t, m, []byte("SELECT 1")) + + // check ensureVersionTable + if err := d.(*Mysql).ensureVersionTable(); err != nil { + t.Fatal(err) + } + // check again + if err := d.(*Mysql).ensureVersionTable(); err != nil { + t.Fatal(err) + } + }) +} + func TestLockWorks(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { ip, port, err := c.Port(defaultPort) diff --git a/database/postgres/postgres.go b/database/postgres/postgres.go index b840df84b..a9570bc72 100644 --- a/database/postgres/postgres.go +++ b/database/postgres/postgres.go @@ -117,14 +117,12 @@ func (p *Postgres) Open(url string) (database.Driver, error) { } migrationsTable := purl.Query().Get("x-migrations-table") - if len(migrationsTable) == 0 { - migrationsTable = DefaultMigrationsTable - } px, err := WithInstance(db, &Config{ DatabaseName: purl.Path, MigrationsTable: migrationsTable, }) + if err != nil { return nil, err } @@ -325,14 +323,14 @@ func (p *Postgres) Drop() error { return &database.Error{OrigErr: err, Query: []byte(query)} } } - if err := p.ensureVersionTable(); err != nil { - return err - } } return nil } +// ensureVersionTable checks if versions table exists and, if not, creates it. +// Note that this function locks the database, which deviates from the usual +// convention of "caller locks" in the Postgres type. func (p *Postgres) ensureVersionTable() (err error) { if err = p.Lock(); err != nil { return err diff --git a/database/postgres/postgres_test.go b/database/postgres/postgres_test.go index 6b9390a8b..ad2d91f6b 100644 --- a/database/postgres/postgres_test.go +++ b/database/postgres/postgres_test.go @@ -7,6 +7,7 @@ import ( "database/sql" sqldriver "database/sql/driver" "fmt" + "github.com/golang-migrate/migrate/v4" "io" "strconv" "strings" @@ -17,6 +18,7 @@ import ( dt "github.com/golang-migrate/migrate/v4/database/testing" "github.com/golang-migrate/migrate/v4/dktesting" + _ "github.com/golang-migrate/migrate/v4/source/file" ) var ( @@ -77,6 +79,28 @@ func Test(t *testing.T) { }) } +func TestMigrate(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ip, port, err := c.FirstPort() + if err != nil { + t.Fatal(err) + } + + addr := pgConnectionString(ip, port) + p := &Postgres{} + d, err := p.Open(addr) + if err != nil { + t.Fatalf("%v", err) + } + defer d.Close() + m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "postgres", d) + if err != nil { + t.Fatalf("%v", err) + } + dt.TestMigrate(t, m, []byte("SELECT 1")) + }) +} + func TestMultiStatement(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { ip, port, err := c.FirstPort() diff --git a/database/ql/migration/33_create_table.down.sql b/database/ql/examples/migrations/33_create_table.down.sql similarity index 100% rename from database/ql/migration/33_create_table.down.sql rename to database/ql/examples/migrations/33_create_table.down.sql diff --git a/database/ql/migration/33_create_table.up.sql b/database/ql/examples/migrations/33_create_table.up.sql similarity index 100% rename from database/ql/migration/33_create_table.up.sql rename to database/ql/examples/migrations/33_create_table.up.sql diff --git a/database/ql/migration/44_alter_table.down.sql b/database/ql/examples/migrations/44_alter_table.down.sql similarity index 100% rename from database/ql/migration/44_alter_table.down.sql rename to database/ql/examples/migrations/44_alter_table.down.sql diff --git a/database/ql/migration/44_alter_table.up.sql b/database/ql/examples/migrations/44_alter_table.up.sql similarity index 100% rename from database/ql/migration/44_alter_table.up.sql rename to database/ql/examples/migrations/44_alter_table.up.sql diff --git a/database/ql/ql.go b/database/ql/ql.go index 86b2364dd..6dfd202a1 100644 --- a/database/ql/ql.go +++ b/database/ql/ql.go @@ -3,6 +3,7 @@ package ql import ( "database/sql" "fmt" + "github.com/hashicorp/go-multierror" "io" "io/ioutil" "strings" @@ -46,6 +47,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if err := instance.Ping(); err != nil { return nil, err } + if len(config.MigrationsTable) == 0 { config.MigrationsTable = DefaultMigrationsTable } @@ -59,7 +61,24 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { } return mx, nil } -func (m *Ql) ensureVersionTable() error { +// ensureVersionTable checks if versions table exists and, if not, creates it. +// Note that this function locks the database, which deviates from the usual +// convention of "caller locks" in the Ql type. +func (m *Ql) ensureVersionTable() (err error) { + if err = m.Lock(); err != nil { + return err + } + + defer func() { + if e := m.Unlock(); e != nil { + if err == nil { + err = e + } else { + err = multierror.Append(err, e) + } + } + }() + tx, err := m.db.Begin() if err != nil { return err @@ -132,9 +151,6 @@ func (m *Ql) Drop() error { return &database.Error{OrigErr: err, Query: []byte(query)} } } - if err := m.ensureVersionTable(); err != nil { - return err - } } return nil diff --git a/database/ql/ql_test.go b/database/ql/ql_test.go index 5a05e355e..49e629cce 100644 --- a/database/ql/ql_test.go +++ b/database/ql/ql_test.go @@ -22,7 +22,7 @@ func Test(t *testing.T) { defer func() { os.RemoveAll(dir) }() - fmt.Printf("DB path : %s\n", filepath.Join(dir, "ql.db")) + t.Logf("DB path : %s\n", filepath.Join(dir, "ql.db")) p := &Ql{} addr := fmt.Sprintf("ql://%s", filepath.Join(dir, "ql.db")) d, err := p.Open(addr) @@ -40,23 +40,38 @@ func Test(t *testing.T) { } }() dt.Test(t, d, []byte("CREATE TABLE t (Qty int, Name string);")) - driver, err := WithInstance(db, &Config{}) +} + +func TestMigrate(t *testing.T) { + dir, err := ioutil.TempDir("", "ql-driver-test") if err != nil { - t.Fatalf("%v", err) + return } - if err := d.Drop(); err != nil { - t.Fatal(err) + defer func() { + os.RemoveAll(dir) + }() + t.Logf("DB path : %s\n", filepath.Join(dir, "ql.db")) + + db, err := sql.Open("ql", filepath.Join(dir, "ql.db")) + if err != nil { + return } + defer func() { + if err := db.Close(); err != nil { + return + } + }() - m, err := migrate.NewWithDatabaseInstance( - "file://./migration", - "ql", driver) + driver, err := WithInstance(db, &Config{}) if err != nil { t.Fatalf("%v", err) } - fmt.Println("UP") - err = m.Up() + + m, err := migrate.NewWithDatabaseInstance( + "file://./examples/migrations", + "ql", driver) if err != nil { t.Fatalf("%v", err) } + dt.TestMigrate(t, m, []byte("CREATE TABLE t (Qty int, Name string);")) } diff --git a/database/redshift/redshift.go b/database/redshift/redshift.go index 19f1b9f78..ee4dcf624 100644 --- a/database/redshift/redshift.go +++ b/database/redshift/redshift.go @@ -14,6 +14,7 @@ import ( "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database" + "github.com/hashicorp/go-multierror" "github.com/lib/pq" ) @@ -100,9 +101,6 @@ func (p *Redshift) Open(url string) (database.Driver, error) { } migrationsTable := purl.Query().Get("x-migrations-table") - if len(migrationsTable) == 0 { - migrationsTable = DefaultMigrationsTable - } px, err := WithInstance(db, &Config{ DatabaseName: purl.Path, @@ -282,15 +280,29 @@ func (p *Redshift) Drop() error { return &database.Error{OrigErr: err, Query: []byte(query)} } } - if err := p.ensureVersionTable(); err != nil { - return err - } } return nil } -func (p *Redshift) ensureVersionTable() error { +// ensureVersionTable checks if versions table exists and, if not, creates it. +// Note that this function locks the database, which deviates from the usual +// convention of "caller locks" in the Redshift type. +func (p *Redshift) ensureVersionTable() (err error) { + if err = p.Lock(); err != nil { + return err + } + + defer func() { + if e := p.Unlock(); e != nil { + if err == nil { + err = e + } else { + err = multierror.Append(err, e) + } + } + }() + // check if migration table exists var count int query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` diff --git a/database/redshift/redshift_test.go b/database/redshift/redshift_test.go index 58c4ddd90..0fa58651d 100644 --- a/database/redshift/redshift_test.go +++ b/database/redshift/redshift_test.go @@ -8,6 +8,7 @@ import ( "database/sql" sqldriver "database/sql/driver" "fmt" + "github.com/golang-migrate/migrate/v4" "io" "strconv" "strings" @@ -21,6 +22,7 @@ import ( import ( dt "github.com/golang-migrate/migrate/v4/database/testing" "github.com/golang-migrate/migrate/v4/dktesting" + _ "github.com/golang-migrate/migrate/v4/source/file" ) var ( @@ -84,6 +86,28 @@ func Test(t *testing.T) { }) } +func TestMigrate(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ip, port, err := c.FirstPort() + if err != nil { + t.Fatal(err) + } + + addr := redshiftConnectionString(ip, port) + p := &Redshift{} + d, err := p.Open(addr) + if err != nil { + t.Fatalf("%v", err) + } + defer d.Close() + m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "postgres", d) + if err != nil { + t.Fatalf("%v", err) + } + dt.TestMigrate(t, m, []byte("SELECT 1")) + }) +} + func TestMultiStatement(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { ip, port, err := c.FirstPort() diff --git a/database/spanner/spanner.go b/database/spanner/spanner.go index 84ff25224..7b5a42b8d 100644 --- a/database/spanner/spanner.go +++ b/database/spanner/spanner.go @@ -17,6 +17,7 @@ import ( "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database" + "github.com/hashicorp/go-multierror" "google.golang.org/api/iterator" adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1" ) @@ -108,9 +109,6 @@ func (s *Spanner) Open(url string) (database.Driver, error) { } migrationsTable := purl.Query().Get("x-migrations-table") - if len(migrationsTable) == 0 { - migrationsTable = DefaultMigrationsTable - } db := &DB{admin: adminClient, data: dataClient} return WithInstance(db, &Config{ @@ -255,14 +253,27 @@ func (s *Spanner) Drop() error { return &database.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))} } - if err := s.ensureVersionTable(); err != nil { + return nil +} + +// ensureVersionTable checks if versions table exists and, if not, creates it. +// Note that this function locks the database, which deviates from the usual +// convention of "caller locks" in the Spanner type. +func (s *Spanner) ensureVersionTable() (err error) { + if err = s.Lock(); err != nil { return err } - return nil -} + defer func() { + if e := s.Unlock(); e != nil { + if err == nil { + err = e + } else { + err = multierror.Append(err, e) + } + } + }() -func (s *Spanner) ensureVersionTable() error { ctx := context.Background() tbl := s.config.MigrationsTable iter := s.db.data.Single().Read(ctx, tbl, spanner.AllKeys(), []string{"Version"}) diff --git a/database/spanner/spanner_test.go b/database/spanner/spanner_test.go index bd90d6530..b9cb52b13 100644 --- a/database/spanner/spanner_test.go +++ b/database/spanner/spanner_test.go @@ -2,10 +2,12 @@ package spanner import ( "fmt" + "github.com/golang-migrate/migrate/v4" "os" "testing" dt "github.com/golang-migrate/migrate/v4/database/testing" + _ "github.com/golang-migrate/migrate/v4/source/file" ) func Test(t *testing.T) { @@ -19,10 +21,33 @@ func Test(t *testing.T) { } s := &Spanner{} - addr := fmt.Sprintf("spanner://%v", db) + addr := fmt.Sprintf("spanner://%s", db) d, err := s.Open(addr) if err != nil { t.Fatalf("%v", err) } dt.Test(t, d, []byte("SELECT 1")) } + +func TestMigrate(t *testing.T) { + if testing.Short() { + t.Skip("skipping test in short mode.") + } + + db, ok := os.LookupEnv("SPANNER_DATABASE") + if !ok { + t.Skip("SPANNER_DATABASE not set, skipping test.") + } + + s := &Spanner{} + addr := fmt.Sprintf("spanner://%s", db) + d, err := s.Open(addr) + if err != nil { + t.Fatalf("%v", err) + } + m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", db, d) + if err != nil { + t.Fatalf("%v", err) + } + dt.TestMigrate(t, m, []byte("SELECT 1")) +} diff --git a/database/sqlite3/migration/33_create_table.down.sql b/database/sqlite3/examples/migrations/33_create_table.down.sql similarity index 100% rename from database/sqlite3/migration/33_create_table.down.sql rename to database/sqlite3/examples/migrations/33_create_table.down.sql diff --git a/database/sqlite3/migration/33_create_table.up.sql b/database/sqlite3/examples/migrations/33_create_table.up.sql similarity index 100% rename from database/sqlite3/migration/33_create_table.up.sql rename to database/sqlite3/examples/migrations/33_create_table.up.sql diff --git a/database/sqlite3/migration/44_alter_table.down.sql b/database/sqlite3/examples/migrations/44_alter_table.down.sql similarity index 100% rename from database/sqlite3/migration/44_alter_table.down.sql rename to database/sqlite3/examples/migrations/44_alter_table.down.sql diff --git a/database/sqlite3/migration/44_alter_table.up.sql b/database/sqlite3/examples/migrations/44_alter_table.up.sql similarity index 100% rename from database/sqlite3/migration/44_alter_table.up.sql rename to database/sqlite3/examples/migrations/44_alter_table.up.sql diff --git a/database/sqlite3/sqlite3.go b/database/sqlite3/sqlite3.go index d65fe8070..4826448a5 100644 --- a/database/sqlite3/sqlite3.go +++ b/database/sqlite3/sqlite3.go @@ -10,6 +10,7 @@ import ( "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database" + "github.com/hashicorp/go-multierror" _ "github.com/mattn/go-sqlite3" ) @@ -44,6 +45,7 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if err := instance.Ping(); err != nil { return nil, err } + if len(config.MigrationsTable) == 0 { config.MigrationsTable = DefaultMigrationsTable } @@ -58,7 +60,23 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { return mx, nil } -func (m *Sqlite) ensureVersionTable() error { +// ensureVersionTable checks if versions table exists and, if not, creates it. +// Note that this function locks the database, which deviates from the usual +// convention of "caller locks" in the Sqlite type. +func (m *Sqlite) ensureVersionTable() (err error) { + if err = m.Lock(); err != nil { + return err + } + + defer func() { + if e := m.Unlock(); e != nil { + if err == nil { + err = e + } else { + err = multierror.Append(err, e) + } + } + }() query := fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s (version uint64,dirty bool); @@ -125,9 +143,6 @@ func (m *Sqlite) Drop() error { return &database.Error{OrigErr: err, Query: []byte(query)} } } - if err := m.ensureVersionTable(); err != nil { - return err - } query := "VACUUM" _, err = m.db.Query(query) if err != nil { diff --git a/database/sqlite3/sqlite3_test.go b/database/sqlite3/sqlite3_test.go index af2020c03..648141c92 100644 --- a/database/sqlite3/sqlite3_test.go +++ b/database/sqlite3/sqlite3_test.go @@ -29,6 +29,24 @@ func Test(t *testing.T) { if err != nil { t.Fatalf("%v", err) } + dt.Test(t, d, []byte("CREATE TABLE t (Qty int, Name string);")) +} + +func TestMigrate(t *testing.T) { + dir, err := ioutil.TempDir("", "sqlite3-driver-test") + if err != nil { + return + } + defer func() { + os.RemoveAll(dir) + }() + t.Logf("DB path : %s\n", filepath.Join(dir, "sqlite3.db")) + p := &Sqlite{} + addr := fmt.Sprintf("sqlite3://%s", filepath.Join(dir, "sqlite3.db")) + d, err := p.Open(addr) + if err != nil { + t.Fatalf("%v", err) + } db, err := sql.Open("sqlite3", filepath.Join(dir, "sqlite3.db")) if err != nil { @@ -39,7 +57,6 @@ func Test(t *testing.T) { return } }() - dt.Test(t, d, []byte("CREATE TABLE t (Qty int, Name string);")) driver, err := WithInstance(db, &Config{}) if err != nil { t.Fatalf("%v", err) @@ -49,16 +66,12 @@ func Test(t *testing.T) { } m, err := migrate.NewWithDatabaseInstance( - "file://./migration", + "file://./examples/migrations", "ql", driver) if err != nil { t.Fatalf("%v", err) } - t.Log("UP") - err = m.Up() - if err != nil { - t.Fatalf("%v", err) - } + dt.TestMigrate(t, m, []byte("CREATE TABLE t (Qty int, Name string);")) } func TestMigrationTable(t *testing.T) { @@ -90,7 +103,7 @@ func TestMigrationTable(t *testing.T) { t.Fatalf("%v", err) } m, err := migrate.NewWithDatabaseInstance( - "file://./migration", + "file://./examples/migrations", "ql", driver) if err != nil { t.Fatalf("%v", err) diff --git a/database/stub/stub_test.go b/database/stub/stub_test.go index 2b966daa1..f755be65c 100644 --- a/database/stub/stub_test.go +++ b/database/stub/stub_test.go @@ -1,6 +1,9 @@ package stub import ( + "github.com/golang-migrate/migrate/v4" + "github.com/golang-migrate/migrate/v4/source" + "github.com/golang-migrate/migrate/v4/source/stub" "testing" dt "github.com/golang-migrate/migrate/v4/database/testing" @@ -14,3 +17,24 @@ func Test(t *testing.T) { } dt.Test(t, d, []byte("/* foobar migration */")) } + +func TestMigrate(t *testing.T) { + s := &Stub{} + d, err := s.Open("") + if err != nil { + t.Fatal(err) + } + + stubMigrations := source.NewMigrations() + stubMigrations.Append(&source.Migration{Version: 1, Direction: source.Up, Identifier: "CREATE 1"}) + stubMigrations.Append(&source.Migration{Version: 1, Direction: source.Down, Identifier: "DROP 1"}) + src := &stub.Stub{} + srcDrv, err := src.Open("") + srcDrv.(*stub.Stub).Migrations = stubMigrations + m, err := migrate.NewWithInstance("stub", srcDrv, "", d) + if err != nil { + t.Fatalf("%v", err) + } + + dt.TestMigrate(t, m, []byte("/* foobar migration */")) +} diff --git a/database/testing/migrate_testing.go b/database/testing/migrate_testing.go new file mode 100644 index 000000000..5b328a715 --- /dev/null +++ b/database/testing/migrate_testing.go @@ -0,0 +1,39 @@ +// Package testing has the database tests. +// All database drivers must pass the Test function. +// This lives in it's own package so it stays a test dependency. +package testing + +import ( + "testing" +) + +import ( + "github.com/golang-migrate/migrate/v4" +) + +// TestMigrate runs integration-tests between the Migrate layer and database implementations. +// +func TestMigrate(t *testing.T, m *migrate.Migrate, migration []byte) { + if migration == nil { + panic("test must provide migration reader") + } + + TestMigrateUp(t, m) + TestMigrateDrop(t, m) +} + +// Regression test for preventing a regression for #164 https://github.com/golang-migrate/migrate/pull/173 +// Similar to TestDrop(), but tests the dropping mechanism through the Migrate logic instead, to check for +// double-locking during the Drop logic. +func TestMigrateDrop(t *testing.T, m *migrate.Migrate) { + if err := m.Drop(); err != nil { + t.Fatal(err) + } +} + +func TestMigrateUp(t *testing.T, m *migrate.Migrate) { + t.Log("UP") + if err := m.Up(); err != nil { + t.Fatalf("%v", err) + } +} \ No newline at end of file diff --git a/database/testing/testing.go b/database/testing/testing.go index 6d561d48e..b227c6088 100644 --- a/database/testing/testing.go +++ b/database/testing/testing.go @@ -22,8 +22,9 @@ func Test(t *testing.T, d database.Driver, migration []byte) { TestNilVersion(t, d) // test first TestLockAndUnlock(t, d) TestRun(t, d, bytes.NewReader(migration)) - TestDrop(t, d) TestSetVersion(t, d) // also tests Version() + // Drop breaks the driver, so test it last. + TestDrop(t, d) } func TestNilVersion(t *testing.T, d database.Driver) {