Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve cmd/schemagen #262

Merged
merged 2 commits into from
Mar 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 60 additions & 7 deletions cmd/schemagen/schemagen.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"log"
"os"
"path"
"regexp"
"strings"

"github.com/gocql/gocql"
Expand All @@ -19,13 +20,15 @@ import (
)

var (
cmd = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
flagCluster = cmd.String("cluster", "127.0.0.1", "a comma-separated list of host:port tuples")
flagKeyspace = cmd.String("keyspace", "", "keyspace to inspect")
flagPkgname = cmd.String("pkgname", "models", "the name you wish to assign to your generated package")
flagOutput = cmd.String("output", "models", "the name of the folder to output to")
flagUser = cmd.String("user", "", "user for password authentication")
flagPassword = cmd.String("password", "", "password for password authentication")
cmd = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
flagCluster = cmd.String("cluster", "127.0.0.1", "a comma-separated list of host:port tuples")
flagKeyspace = cmd.String("keyspace", "", "keyspace to inspect")
flagPkgname = cmd.String("pkgname", "models", "the name you wish to assign to your generated package")
flagOutput = cmd.String("output", "models", "the name of the folder to output to")
flagUser = cmd.String("user", "", "user for password authentication")
flagPassword = cmd.String("password", "", "password for password authentication")
flagIgnoreNames = cmd.String("ignore-names", "", "a comma-separated list of table, view or index names to ignore")
flagIgnoreIndexes = cmd.Bool("ignore-indexes", false, "don't generate types for indexes")
)

var (
Expand Down Expand Up @@ -82,6 +85,31 @@ func renderTemplate(md *gocql.KeyspaceMetadata) ([]byte, error) {
log.Fatalln("unable to parse models template:", err)
}

ignoredNames := make(map[string]struct{})
for _, ignoredName := range strings.Split(*flagIgnoreNames, ",") {
ignoredNames[ignoredName] = struct{}{}
}
if *flagIgnoreIndexes {
for name := range md.Tables {
if strings.HasSuffix(name, "_index") {
ignoredNames[name] = struct{}{}
}
}
}
for name := range ignoredNames {
delete(md.Tables, name)
}

orphanedTypes := make(map[string]struct{})
for userTypeName := range md.UserTypes {
if !usedInTables(userTypeName, md.Tables) {
orphanedTypes[userTypeName] = struct{}{}
}
}
for typeName := range orphanedTypes {
delete(md.UserTypes, typeName)
}

imports := make([]string, 0)
for _, t := range md.Tables {
for _, c := range t.Columns {
Expand Down Expand Up @@ -135,3 +163,28 @@ func existsInSlice(s []string, v string) bool {

return false
}

// userTypes finds Cassandra schema types enclosed in angle brackets.
// Calling FindAllStringSubmatch on it will return a slice of string slices containing two elements.
// The second element contains the name of the type.
//
// [["<my_type,", "my_type"] ["my_other_type>", "my_other_type"]]
var userTypes = regexp.MustCompile(`(?:<|\s)(\w+)(?:>|,)`) // match all types contained in set<X>, list<X>, tuple<A, B> etc.

// usedInTables reports whether the typeName is used in any of columns of the provided tables.
func usedInTables(typeName string, tables map[string]*gocql.TableMetadata) bool {
for _, table := range tables {
for _, column := range table.Columns {
if typeName == column.Validator {
return true
}
matches := userTypes.FindAllStringSubmatch(column.Validator, -1)
for _, s := range matches {
if s[1] == typeName {
return true
}
}
}
}
return false
}
109 changes: 109 additions & 0 deletions cmd/schemagen/schemagen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ import (
"fmt"
"io/ioutil"
"os"
"strings"
"testing"

"github.com/gocql/gocql"
"github.com/google/go-cmp/cmp"
"github.com/scylladb/gocqlx/v2/gocqlxtest"
)
Expand All @@ -16,6 +18,15 @@ var flagUpdate = flag.Bool("update", false, "update golden file")
func TestSchemagen(t *testing.T) {
flag.Parse()
createTestSchema(t)

// add ignored types and table
*flagIgnoreNames = strings.Join([]string{
"composers",
"composers_by_name",
"label",
}, ",")
*flagIgnoreIndexes = true

b := runSchemagen(t, "foobar")

const goldenFile = "testdata/models.go.txt"
Expand All @@ -34,6 +45,76 @@ func TestSchemagen(t *testing.T) {
}
}

func Test_usedInTables(t *testing.T) {
tests := map[string]struct {
columnValidator string
typeName string
}{
"matches given a frozen collection": {
columnValidator: "frozen<album>",
typeName: "album",
},
"matches given a set": {
columnValidator: "set<artist>",
typeName: "artist",
},
"matches given a list": {
columnValidator: "list<song>",
typeName: "song",
},
"matches given a tuple: first of two elements": {
columnValidator: "tuple<first, second>",
typeName: "first",
},
"matches given a tuple: second of two elements": {
columnValidator: "tuple<first, second>",
typeName: "second",
},
"matches given a tuple: first of three elements": {
columnValidator: "tuple<first, second, third>",
typeName: "first",
},
"matches given a tuple: second of three elements": {
columnValidator: "tuple<first, second, third>",
typeName: "second",
},
"matches given a tuple: third of three elements": {
columnValidator: "tuple<first, second, third>",
typeName: "third",
},
"matches given a frozen set": {
columnValidator: "set<frozen<album>>",
typeName: "album",
},
"matches snake_case names given a nested map": {
columnValidator: "map<album, tuple<first, map<map_key, map-value>, third>>",
typeName: "map_key",
},
}

for name, tt := range tests {
t.Run(name, func(t *testing.T) {
tables := map[string]*gocql.TableMetadata{
"table": {Columns: map[string]*gocql.ColumnMetadata{
"column": {Validator: tt.columnValidator},
}},
}
if !usedInTables(tt.typeName, tables) {
t.Fatal()
}
})
}

t.Run("doesn't panic with empty type name", func(t *testing.T) {
tables := map[string]*gocql.TableMetadata{
"table": {Columns: map[string]*gocql.ColumnMetadata{
"column": {Validator: "map<text, album>"},
}},
}
usedInTables("", tables)
})
}

func createTestSchema(t *testing.T) {
t.Helper()

Expand Down Expand Up @@ -73,6 +154,34 @@ func createTestSchema(t *testing.T) {
if err != nil {
t.Fatal("create table:", err)
}

err = session.ExecStmt(`CREATE INDEX IF NOT EXISTS songs_title ON schemagen.songs (title)`)
if err != nil {
t.Fatal("create index:", err)
}

err = session.ExecStmt(`CREATE TABLE IF NOT EXISTS schemagen.composers (
id uuid PRIMARY KEY,
name text)`)
if err != nil {
t.Fatal("create table:", err)
}

err = session.ExecStmt(`CREATE MATERIALIZED VIEW IF NOT EXISTS schemagen.composers_by_name AS
SELECT id, name
FROM composers
WHERE id IS NOT NULL AND name IS NOT NULL
PRIMARY KEY (id, name)`)
if err != nil {
t.Fatal("create view:", err)
}

err = session.ExecStmt(`CREATE TYPE IF NOT EXISTS schemagen.label (
name text,
artists set<text>)`)
if err != nil {
t.Fatal("create type:", err)
}
}

func runSchemagen(t *testing.T, pkgname string) []byte {
Expand Down
Loading