diff --git a/cmd/schemagen/schemagen.go b/cmd/schemagen/schemagen.go index 9cff543..85875c7 100644 --- a/cmd/schemagen/schemagen.go +++ b/cmd/schemagen/schemagen.go @@ -11,6 +11,7 @@ import ( "log" "os" "path" + "regexp" "strings" "github.com/gocql/gocql" @@ -19,13 +20,15 @@ import ( ) var ( - cmd = flag.NewFlagSet(os.Args[0], flag.ExitOnError) - flagCluster = cmd.String("cluster", "127.0.0.1", "a comma-separated list of host:port tuples") - flagKeyspace = cmd.String("keyspace", "", "keyspace to inspect") - flagPkgname = cmd.String("pkgname", "models", "the name you wish to assign to your generated package") - flagOutput = cmd.String("output", "models", "the name of the folder to output to") - flagUser = cmd.String("user", "", "user for password authentication") - flagPassword = cmd.String("password", "", "password for password authentication") + cmd = flag.NewFlagSet(os.Args[0], flag.ExitOnError) + flagCluster = cmd.String("cluster", "127.0.0.1", "a comma-separated list of host:port tuples") + flagKeyspace = cmd.String("keyspace", "", "keyspace to inspect") + flagPkgname = cmd.String("pkgname", "models", "the name you wish to assign to your generated package") + flagOutput = cmd.String("output", "models", "the name of the folder to output to") + flagUser = cmd.String("user", "", "user for password authentication") + flagPassword = cmd.String("password", "", "password for password authentication") + flagIgnoreNames = cmd.String("ignore-names", "", "a comma-separated list of table, view or index names to ignore") + flagIgnoreIndexes = cmd.Bool("ignore-indexes", false, "don't generate types for indexes") ) var ( @@ -82,6 +85,31 @@ func renderTemplate(md *gocql.KeyspaceMetadata) ([]byte, error) { log.Fatalln("unable to parse models template:", err) } + ignoredNames := make(map[string]struct{}) + for _, ignoredName := range strings.Split(*flagIgnoreNames, ",") { + ignoredNames[ignoredName] = struct{}{} + } + if *flagIgnoreIndexes { + for name := range md.Tables { + if strings.HasSuffix(name, "_index") { + ignoredNames[name] = struct{}{} + } + } + } + for name := range ignoredNames { + delete(md.Tables, name) + } + + orphanedTypes := make(map[string]struct{}) + for userTypeName := range md.UserTypes { + if !usedInTables(userTypeName, md.Tables) { + orphanedTypes[userTypeName] = struct{}{} + } + } + for typeName := range orphanedTypes { + delete(md.UserTypes, typeName) + } + imports := make([]string, 0) for _, t := range md.Tables { for _, c := range t.Columns { @@ -135,3 +163,28 @@ func existsInSlice(s []string, v string) bool { return false } + +// userTypes finds Cassandra schema types enclosed in angle brackets. +// Calling FindAllStringSubmatch on it will return a slice of string slices containing two elements. +// The second element contains the name of the type. +// +// [["", "my_other_type"]] +var userTypes = regexp.MustCompile(`(?:<|\s)(\w+)(?:>|,)`) // match all types contained in set, list, tuple etc. + +// usedInTables reports whether the typeName is used in any of columns of the provided tables. +func usedInTables(typeName string, tables map[string]*gocql.TableMetadata) bool { + for _, table := range tables { + for _, column := range table.Columns { + if typeName == column.Validator { + return true + } + matches := userTypes.FindAllStringSubmatch(column.Validator, -1) + for _, s := range matches { + if s[1] == typeName { + return true + } + } + } + } + return false +} diff --git a/cmd/schemagen/schemagen_test.go b/cmd/schemagen/schemagen_test.go index 382d21f..44c5925 100644 --- a/cmd/schemagen/schemagen_test.go +++ b/cmd/schemagen/schemagen_test.go @@ -5,8 +5,10 @@ import ( "fmt" "io/ioutil" "os" + "strings" "testing" + "github.com/gocql/gocql" "github.com/google/go-cmp/cmp" "github.com/scylladb/gocqlx/v2/gocqlxtest" ) @@ -16,6 +18,15 @@ var flagUpdate = flag.Bool("update", false, "update golden file") func TestSchemagen(t *testing.T) { flag.Parse() createTestSchema(t) + + // add ignored types and table + *flagIgnoreNames = strings.Join([]string{ + "composers", + "composers_by_name", + "label", + }, ",") + *flagIgnoreIndexes = true + b := runSchemagen(t, "foobar") const goldenFile = "testdata/models.go.txt" @@ -34,6 +45,76 @@ func TestSchemagen(t *testing.T) { } } +func Test_usedInTables(t *testing.T) { + tests := map[string]struct { + columnValidator string + typeName string + }{ + "matches given a frozen collection": { + columnValidator: "frozen", + typeName: "album", + }, + "matches given a set": { + columnValidator: "set", + typeName: "artist", + }, + "matches given a list": { + columnValidator: "list", + typeName: "song", + }, + "matches given a tuple: first of two elements": { + columnValidator: "tuple", + typeName: "first", + }, + "matches given a tuple: second of two elements": { + columnValidator: "tuple", + typeName: "second", + }, + "matches given a tuple: first of three elements": { + columnValidator: "tuple", + typeName: "first", + }, + "matches given a tuple: second of three elements": { + columnValidator: "tuple", + typeName: "second", + }, + "matches given a tuple: third of three elements": { + columnValidator: "tuple", + typeName: "third", + }, + "matches given a frozen set": { + columnValidator: "set>", + typeName: "album", + }, + "matches snake_case names given a nested map": { + columnValidator: "map, third>>", + typeName: "map_key", + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + tables := map[string]*gocql.TableMetadata{ + "table": {Columns: map[string]*gocql.ColumnMetadata{ + "column": {Validator: tt.columnValidator}, + }}, + } + if !usedInTables(tt.typeName, tables) { + t.Fatal() + } + }) + } + + t.Run("doesn't panic with empty type name", func(t *testing.T) { + tables := map[string]*gocql.TableMetadata{ + "table": {Columns: map[string]*gocql.ColumnMetadata{ + "column": {Validator: "map"}, + }}, + } + usedInTables("", tables) + }) +} + func createTestSchema(t *testing.T) { t.Helper() @@ -73,6 +154,34 @@ func createTestSchema(t *testing.T) { if err != nil { t.Fatal("create table:", err) } + + err = session.ExecStmt(`CREATE INDEX IF NOT EXISTS songs_title ON schemagen.songs (title)`) + if err != nil { + t.Fatal("create index:", err) + } + + err = session.ExecStmt(`CREATE TABLE IF NOT EXISTS schemagen.composers ( + id uuid PRIMARY KEY, + name text)`) + if err != nil { + t.Fatal("create table:", err) + } + + err = session.ExecStmt(`CREATE MATERIALIZED VIEW IF NOT EXISTS schemagen.composers_by_name AS + SELECT id, name + FROM composers + WHERE id IS NOT NULL AND name IS NOT NULL + PRIMARY KEY (id, name)`) + if err != nil { + t.Fatal("create view:", err) + } + + err = session.ExecStmt(`CREATE TYPE IF NOT EXISTS schemagen.label ( + name text, + artists set)`) + if err != nil { + t.Fatal("create type:", err) + } } func runSchemagen(t *testing.T, pkgname string) []byte {