diff --git a/cleanbuilddir.go b/clean_build_dir.go similarity index 100% rename from cleanbuilddir.go rename to clean_build_dir.go diff --git a/connecttodb.go b/connect_to_db.go similarity index 100% rename from connecttodb.go rename to connect_to_db.go diff --git a/getcolumns.go b/get_columns.go similarity index 100% rename from getcolumns.go rename to get_columns.go diff --git a/gettables.go b/get_tables.go similarity index 100% rename from gettables.go rename to get_tables.go diff --git a/get_tables_test.go b/get_tables_test.go new file mode 100644 index 0000000..6ea5b60 --- /dev/null +++ b/get_tables_test.go @@ -0,0 +1,30 @@ +package main + +import ( + "fmt" + "log" + "os" + "testing" +) + +func TestGetTables(t *testing.T) { + dbUsername = os.Getenv("SNOWFLAKE_SANDBOX_USERNAME") + dbAccount = os.Getenv("SNOWFLAKE_SANDBOX_ACCOUNT") + dbDatabase = "ANALYTICS" + dbSchema = "DBT_WINNIE" + databaseType := "snowflake" + connStr := fmt.Sprintf("%s@%s/%s/%s?authenticator=externalbrowser", dbUsername, dbAccount, dbDatabase, dbSchema) + ctx, db, err := ConnectToDB(connStr, databaseType) + if err != nil { + log.Fatal(err) + } + tables, err := GetTables(db, ctx) + if err != nil { + t.Errorf("Error getting tables: %v", err) + } + if len(tables.SourceTables) == 0 { + t.Errorf("No tables found") + } else { + t.Logf("%v tables found.", len(tables.SourceTables)) + } +} diff --git a/main.go b/main.go index bb8ed70..c381b14 100644 --- a/main.go +++ b/main.go @@ -3,7 +3,9 @@ package main import ( "fmt" "log" + "os" "strings" + "sync" "github.com/charmbracelet/huh" _ "github.com/snowflakedb/gosnowflake" @@ -26,8 +28,9 @@ type Column struct { } type SourceTable struct { - Name string `yaml:"name"` - Columns []Column `yaml:"columns"` + Name string `yaml:"name"` + Columns []Column `yaml:"columns"` + DataTypeGroups map[string][]Column `yaml:"-"` } type SourceTables struct { @@ -35,11 +38,13 @@ type SourceTables struct { } func main() { - form := huh.NewForm( - huh.NewGroup( - huh.NewNote(). - Title("Welcome to tbd!🏎️✨"). - Description(`A fast and friendly code generator for dbt. + use_form := false + if use_form { + form := huh.NewForm( + huh.NewGroup( + huh.NewNote(). + Title("Welcome to tbd!🏎️✨"). + Description(`A fast and friendly code generator for dbt. We will generate sources YAML config and SQL staging models for all the tables in the schema you specify. To prepare, make sure you have the following: ✴︎ *_Username_* (e.g. aragorn@dunedain.king) @@ -48,41 +53,49 @@ To prepare, make sure you have the following: ✴︎ *_Database_* that schema is in (e.g. gondor) Authentication will be handled via SSO in the web browser. For security, we don't currently support password-based authentication.`), - ), - huh.NewGroup( - huh.NewSelect[string](). - Title("Choose your warehouse."). - Options( - huh.NewOption("Snowflake", "snowflake"), - ). - Value(&warehouse), + ), + huh.NewGroup( + huh.NewSelect[string](). + Title("Choose your warehouse."). + Options( + huh.NewOption("Snowflake", "snowflake"), + ). + Value(&warehouse), - huh.NewInput(). - Title("What is your username?"). - Value(&dbUsername), + huh.NewInput(). + Title("What is your username?"). + Value(&dbUsername), - huh.NewInput(). - Title("What is your Snowflake account id?"). - Value(&dbAccount), + huh.NewInput(). + Title("What is your Snowflake account id?"). + Value(&dbAccount), - huh.NewInput(). - Title("What is the schema you want to generate?"). - Value(&dbSchema), + huh.NewInput(). + Title("What is the schema you want to generate?"). + Value(&dbSchema), - huh.NewInput(). - Title("What database is that schema in?"). - Value(&dbDatabase), - ), - huh.NewGroup( - huh.NewConfirm(). - Title("Are you ready to go?"). - Value(&confirm), - ), - ) - form.WithTheme(huh.ThemeCatppuccin()) - err := form.Run() - if err != nil { - log.Fatal(err) + huh.NewInput(). + Title("What database is that schema in?"). + Value(&dbDatabase), + ), + huh.NewGroup( + huh.NewConfirm(). + Title("Are you ready to go?"). + Value(&confirm), + ), + ) + form.WithTheme(huh.ThemeCatppuccin()) + err := form.Run() + if err != nil { + log.Fatal(err) + } + } else { + warehouse = "snowflake" + dbUsername = os.Getenv("SNOWFLAKE_SANDBOX_USER") + dbAccount = os.Getenv("SNOWFLAKE_SANDBOX_ACCOUNT") + dbDatabase = "ANALYTICS" + dbSchema = "JAFFLE_SHOP_RAW" + confirm = true } if confirm { databaseType := warehouse @@ -97,13 +110,25 @@ For security, we don't currently support password-based authentication.`), if err != nil { log.Fatal(err) } + tables, err := GetTables(db, ctx) if err != nil { log.Fatal(err) } + PutColumnsOnTables(db, ctx, tables) CleanBuildDir("build") - WriteYAML(tables) - WriteStagingModels(tables) + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + WriteYAML(tables) + }() + go func() { + defer wg.Done() + WriteStagingModels(tables) + }() + wg.Wait() } } diff --git a/put_columns_on_tables.go b/put_columns_on_tables.go new file mode 100644 index 0000000..3bcc9e2 --- /dev/null +++ b/put_columns_on_tables.go @@ -0,0 +1,54 @@ +package main + +import ( + "context" + "database/sql" + "fmt" + "log" + "regexp" + "sync" +) + +func PutColumnsOnTables(db *sql.DB, ctx context.Context, tables SourceTables) { + mutex := sync.Mutex{} + + var wg sync.WaitGroup + wg.Add(len(tables.SourceTables)) + + dataTypeGroupMap := map[string]string{ + "(text|char)": "text", + "(float|int|num)": "numbers", + "(bool|bit)": "booleans", + "json": "json", + "date": "datetimes", + "timestamp": "timestamps", + } + + for i := range tables.SourceTables { + go func(i int) { + defer wg.Done() + + columns, err := GetColumns(db, ctx, tables.SourceTables[i]) + if err != nil { + log.Printf("Error fetching columns for table %s: %v\n", tables.SourceTables[i].Name, err) + return + } + + mutex.Lock() + tables.SourceTables[i].Columns = columns + tables.SourceTables[i].DataTypeGroups = make(map[string][]Column) + // Create a map of data types groups to hold column slices by data type + // This lets us group columns by their data type e.g. in templates + for j := range tables.SourceTables[i].Columns { + for k, v := range dataTypeGroupMap { + r, _ := regexp.Compile(fmt.Sprintf(`(?i).*%s.*`, k)) + if r.MatchString(tables.SourceTables[i].Columns[j].DataType) { + tables.SourceTables[i].DataTypeGroups[v] = append(tables.SourceTables[i].DataTypeGroups[v], tables.SourceTables[i].Columns[j]) + } + } + } + mutex.Unlock() + }(i) + } + wg.Wait() +} diff --git a/putcolumnsontables.go b/putcolumnsontables.go deleted file mode 100644 index 0ccae88..0000000 --- a/putcolumnsontables.go +++ /dev/null @@ -1,32 +0,0 @@ -package main - -import ( - "context" - "database/sql" - "log" - "sync" -) - -func PutColumnsOnTables(db *sql.DB, ctx context.Context, tables SourceTables) { - mutex := sync.Mutex{} - - var wg sync.WaitGroup - wg.Add(len(tables.SourceTables)) - - for i := range tables.SourceTables { - go func(i int) { - defer wg.Done() - - columns, err := GetColumns(db, ctx, tables.SourceTables[i]) - if err != nil { - log.Printf("Error fetching columns for table %s: %v\n", tables.SourceTables[i].Name, err) - return - } - - mutex.Lock() - tables.SourceTables[i].Columns = columns - mutex.Unlock() - }(i) - } - wg.Wait() -} diff --git a/staging_template.sql b/staging_template.sql new file mode 100644 index 0000000..8b73275 --- /dev/null +++ b/staging_template.sql @@ -0,0 +1,47 @@ +with + +source as ( + + select * from {{ "{{" }} ref('{{.Name}}') {{ "}}" }} + +), + +renamed as ( + + select +{{- range $group, $columns := .DataTypeGroups -}} + {{- if eq $group "text" }} + -- text + {{ range $columns -}} + {{- .Name }} as {{ .Name | lower }}, + {{ end -}} + {{- end -}} + {{- if eq $group "numbers" }} + -- numbers + {{ range $columns -}} + {{- .Name }} as {{ .Name | lower }}, + {{ end -}} + {{- end -}} + {{- if eq $group "booleans" }} + -- booleans + {{ range $columns -}} + {{- .Name }} as {{ .Name | lower }}, + {{ end -}} + {{- end -}} + {{- if eq $group "datetimes" }} + -- datetimes + {{ range $columns -}} + {{- .Name }} as {{ .Name | lower }}, + {{ end -}} + {{- end -}} + {{- if eq $group "timestamps" }} + -- timestamps + {{ range $columns -}} + {{- .Name }} as {{ .Name | lower }}, + {{ end -}} + {{- end -}} +{{ end }} + from source +) + +select * from renamed diff --git a/template.sql b/template.sql deleted file mode 100644 index c19596b..0000000 --- a/template.sql +++ /dev/null @@ -1,20 +0,0 @@ -with - -source as ( - - select * from {{ "{{" }} ref('{{.Name}}') {{ "}}" }} - -), - -renamed as ( - - select - {{ range .Columns -}} - {{ .Name }} as {{ .Name }}, - {{ end }} - - from source - -) - -select * from renamed diff --git a/writestagingmodels.go b/write_staging_models.go similarity index 78% rename from writestagingmodels.go rename to write_staging_models.go index 93c41ff..8163e4f 100644 --- a/writestagingmodels.go +++ b/write_staging_models.go @@ -2,7 +2,6 @@ package main import ( "fmt" - "log" "os" "strings" "sync" @@ -18,9 +17,10 @@ func WriteStagingModels(tables SourceTables) { go func(table SourceTable) { defer wg.Done() - tmpl, err := template.ParseFiles("template.sql") + tmpl := template.New("staging_template.sql").Funcs(template.FuncMap{"lower": strings.ToLower}) + tmpl, err := tmpl.ParseFiles("staging_template.sql") if err != nil { - log.Fatal(err) + panic(err) } filename := fmt.Sprintf("build/stg_" + strings.ToLower(table.Name) + ".sql") diff --git a/writeyaml.go b/write_yaml.go similarity index 100% rename from writeyaml.go rename to write_yaml.go