From 044dc52e4a8992fede883902a557ba1a2260eb40 Mon Sep 17 00:00:00 2001 From: Jesse Hoobergs Date: Tue, 9 Jul 2024 11:58:32 +0200 Subject: [PATCH] Support enums, smallintegers and add around mysql column names --- Cargo.lock | 7 ++ erd/Cargo.toml | 1 + erd/src/ast.rs | 9 +++ erd/src/erd.pest | 21 ++++- erd/src/physical.rs | 11 +++ erd/src/sql.rs | 177 +++++++++++++++++++++++++++++++++++++++++- examples/physical.erd | 1 + 7 files changed, 225 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b9a2aad..3208bce 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -162,6 +162,7 @@ name = "erd_script" version = "0.1.0" dependencies = [ "clap", + "lazy_static", "pest", "pest_derive", "serde", @@ -204,6 +205,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + [[package]] name = "libc" version = "0.2.155" diff --git a/erd/Cargo.toml b/erd/Cargo.toml index b854bf5..7434c8b 100644 --- a/erd/Cargo.toml +++ b/erd/Cargo.toml @@ -13,3 +13,4 @@ pest = "2.1" pest_derive = "2.1" serde = { version = "1.0", features = ["derive"] } clap = { version = "4.5.8", features = ["derive"] } +lazy_static = "1.5.0" \ No newline at end of file diff --git a/erd/src/ast.rs b/erd/src/ast.rs index 7559503..52189f1 100644 --- a/erd/src/ast.rs +++ b/erd/src/ast.rs @@ -172,6 +172,7 @@ pub enum Expr { #[derive(Clone, Debug, Eq, PartialEq)] pub enum DataType { Integer, + SmallInteger, AutoIncrement, Float, Boolean, @@ -187,6 +188,7 @@ pub enum DataType { /// M is the precision: total number of digits (. and - not counted) /// D is the scale: the number of digits after the decimal point Decimal(usize, usize), + Enum(Vec), } impl std::convert::From for DataType { @@ -202,12 +204,17 @@ impl std::convert::From for DataType { let d = part2[..(part2.len() - 1)].trim().parse().unwrap(); Self::Decimal(m, d) + } else if s.starts_with("enum(") { + let part = &s["enum(".len()..(s.len() - 1)]; + let elements: Vec = part.split(",").map(|x| x.trim().to_string()).collect(); + Self::Enum(elements) } else { match &s[..] { "uuid" => Self::Uuid, "text" => Self::Text, "blob" => Self::Blob, "integer" => Self::Integer, + "smallinteger" => Self::SmallInteger, "autoincrement" => Self::AutoIncrement, "float" => Self::Float, "boolean" => Self::Boolean, @@ -224,6 +231,7 @@ impl DataType { pub fn foreign_key_type(&self) -> DataType { match self { Self::Integer => Self::Integer, + Self::SmallInteger => Self::SmallInteger, Self::AutoIncrement => Self::Integer, Self::Float => Self::Float, Self::Boolean => Self::Boolean, @@ -236,6 +244,7 @@ impl DataType { Self::Text => Self::Text, Self::Uuid => Self::Uuid, Self::Decimal(m, d) => Self::Decimal(*m, *d), + Self::Enum(m) => Self::Enum(m.to_owned()), } } } diff --git a/erd/src/erd.pest b/erd/src/erd.pest index 8476c15..17f65c4 100644 --- a/erd/src/erd.pest +++ b/erd/src/erd.pest @@ -4,7 +4,26 @@ COMMENT = _{ "//" ~ (!"\n" ~ ANY)* } ident = @{ ASCII_ALPHA ~ (ASCII_ALPHANUMERIC | "_")* } -datatype = { "integer" | "autoincrement" | "float" | "boolean" | (!"datetime" ~ "date") | "time" | "datetime" | "blob" | "text" | "uuid" | "varchar(" ~ (!"0" ~ ASCII_DIGIT) ~ ASCII_DIGIT* ~ ")" | "varbinary(" ~ (!"0" ~ ASCII_DIGIT) ~ ASCII_DIGIT* ~ ")" | "decimal(" ~ (!"0" ~ ASCII_DIGIT) ~ ASCII_DIGIT* ~ "," ~ (!"0" ~ ASCII_DIGIT) ~ ASCII_DIGIT* ~ ")" } +datatype = { + "integer" | + "smallinteger" | + "autoincrement" | + "float" | + "boolean" | + (!"datetime" ~ "date") | + "time" | + "datetime" | + "blob" | + "text" | + "uuid" | + "varchar(" ~ (!"0" ~ ASCII_DIGIT) ~ ASCII_DIGIT* ~ ")" | + "varbinary(" ~ (!"0" ~ ASCII_DIGIT) ~ ASCII_DIGIT* ~ ")" | + "decimal(" ~ (!"0" ~ ASCII_DIGIT) ~ ASCII_DIGIT* ~ "," ~ (!"0" ~ ASCII_DIGIT) ~ ASCII_DIGIT* ~ ")" | + "enum(" ~ enum_item ~ ("," ~ enum_item)* ~ ")" +} +not_comma_or_space = { !("," | " " | ")") ~ ANY } +enum_item = { not_comma_or_space+ } + attribute_prefix = { "attribute" | "id" } attribute = { attribute_prefix ~ ident ~ ("type" ~ datatype)? } entity = { "entity" ~ ident ~ (!"\n\n" ~ "\n" ~ attribute)* } diff --git a/erd/src/physical.rs b/erd/src/physical.rs index f229e9e..abe01d3 100644 --- a/erd/src/physical.rs +++ b/erd/src/physical.rs @@ -168,10 +168,15 @@ pub struct Table { impl Table { fn write_sql_create(&self, s: &mut String, sql: SQL) { + let mut additional_definitions = Vec::new(); + write!(s, "CREATE TABLE {} (\n", self.name); for col in self.columns.iter() { col.write_sql_create_lines(s, sql); write!(s, "\n"); + if let Some(x) = sql.to_additional_definitions(&col.datatype) { + additional_definitions.push(x); + } } write!( s, @@ -183,6 +188,12 @@ impl Table { .join(","), ); write!(s, ");"); + if !additional_definitions.is_empty() { + for def in additional_definitions { + writeln!(s, ""); + write!(s, "{}", def); + } + } } } diff --git a/erd/src/sql.rs b/erd/src/sql.rs index e4de1ca..16e7cb5 100644 --- a/erd/src/sql.rs +++ b/erd/src/sql.rs @@ -24,6 +24,16 @@ impl SQL { // The methods needed for SQL creation impl SQL { + pub fn to_additional_definitions(&self, data_type: &DataType) -> Option { + match self { + Self::MSAccess => ms_access::to_additional_definitions(data_type), + Self::LibreOfficeBase => libre_office_base::to_additional_definitions(data_type), + Self::MySQL => mysql::to_additional_definitions(data_type), + Self::PostgreSQL => postgresql::to_additional_definitions(data_type), + Self::MSSQL => mssql::to_additional_definitions(data_type), + } + } + pub fn to_data_type(&self, data_type: &DataType) -> String { match self { Self::MSAccess => ms_access::to_data_type(data_type), @@ -48,9 +58,32 @@ impl SQL { mod ms_access { use crate::ast::{DataType, Ident}; + use super::MaxLength; + + pub fn to_additional_definitions(data_type: &DataType) -> Option { + match data_type { + DataType::Enum(_) => None, + DataType::Integer + | DataType::SmallInteger + | DataType::AutoIncrement + | DataType::Float + | DataType::Boolean + | DataType::Date + | DataType::Time + | DataType::DateTime + | DataType::Varchar(_) + | DataType::Varbinary(_) + | DataType::Blob + | DataType::Text + | DataType::Uuid + | DataType::Decimal(_, _) => None, + } + } + pub fn to_data_type(data_type: &DataType) -> String { match data_type { DataType::Integer => "INTEGER".to_string(), + DataType::SmallInteger => "INTEGER".to_string(), // TODO DataType::AutoIncrement => "AUTOINCREMENT".to_string(), DataType::Float => "FLOAT".to_string(), DataType::Boolean => "YESNO".to_string(), @@ -64,6 +97,8 @@ mod ms_access { DataType::Text => "TEXT".to_string(), DataType::Uuid => "GUID".to_string(), DataType::Decimal(m, d) => format!("DECIMAL({m}, {d})"), + DataType::Enum(options) => format!("VARCHAR({})", options.max_length().unwrap()), + // TODO CHECK (mycol IN('a', 'b')) } } pub fn to_column_ident(ident: &Ident) -> String { @@ -74,10 +109,33 @@ mod ms_access { mod libre_office_base { use crate::ast::{DataType, Ident}; + use super::MaxLength; + + pub fn to_additional_definitions(data_type: &DataType) -> Option { + match data_type { + DataType::Enum(_) => None, + DataType::Integer + | DataType::SmallInteger + | DataType::AutoIncrement + | DataType::Float + | DataType::Boolean + | DataType::Date + | DataType::Time + | DataType::DateTime + | DataType::Varchar(_) + | DataType::Varbinary(_) + | DataType::Blob + | DataType::Text + | DataType::Uuid + | DataType::Decimal(_, _) => None, + } + } + // See http://www.hsqldb.org/doc/1.8/guide/guide.html#datatypes-section pub fn to_data_type(data_type: &DataType) -> String { match data_type { DataType::Integer => "INTEGER".to_string(), + DataType::SmallInteger => "INTEGER".to_string(), // TODO DataType::AutoIncrement => "INTEGER GENERATED BY DEFAULT AS IDENTITY".to_string(), DataType::Float => "FLOAT".to_string(), DataType::Boolean => "BOOLEAN".to_string(), @@ -91,6 +149,8 @@ mod libre_office_base { DataType::Text => "LONGVARCHAR".to_string(), DataType::Uuid => "UUID".to_string(), DataType::Decimal(m, d) => format!("DECIMAL({m}, {d})"), + DataType::Enum(options) => format!("VARCHAR({})", options.max_length().unwrap()), + // TODO CHECK (mycol IN('a', 'b'))? } } pub fn to_column_ident(ident: &Ident) -> String { @@ -101,9 +161,30 @@ mod libre_office_base { mod mysql { use crate::ast::{DataType, Ident}; + pub fn to_additional_definitions(data_type: &DataType) -> Option { + match data_type { + DataType::Enum(_) => None, + DataType::Integer + | DataType::SmallInteger + | DataType::AutoIncrement + | DataType::Float + | DataType::Boolean + | DataType::Date + | DataType::Time + | DataType::DateTime + | DataType::Varchar(_) + | DataType::Varbinary(_) + | DataType::Blob + | DataType::Text + | DataType::Uuid + | DataType::Decimal(_, _) => None, + } + } + pub fn to_data_type(data_type: &DataType) -> String { match data_type { DataType::Integer => "INTEGER".to_string(), + DataType::SmallInteger => "SMALLINT".to_string(), DataType::AutoIncrement => "INTEGER AUTO_INCREMENT".to_string(), DataType::Float => "FLOAT".to_string(), DataType::Boolean => "BOOLEAN".to_string(), @@ -116,19 +197,76 @@ mod mysql { DataType::Text => "TEXT".to_string(), DataType::Uuid => "UUID".to_string(), DataType::Decimal(m, d) => format!("DECIMAL({m}, {d})"), + DataType::Enum(options) => format!( + "ENUM({})", + options + .iter() + .map(|x| format!("'{}'", x.replace("'", "\'"))) + .collect::>() + .join(",") + ), } } pub fn to_column_ident(ident: &Ident) -> String { - ident.to_string() + format!("`{}`", ident.to_string()) } } mod postgresql { use crate::ast::{DataType, Ident}; + use std::{collections::HashMap, sync::Mutex}; + + lazy_static::lazy_static! { + static ref ENUM_IDX: Mutex, usize>> = Mutex::new(HashMap::new()); + } + + fn get_enum_name(options: &Vec) -> String { + let mut map = ENUM_IDX.lock().unwrap(); + + let idx = if let Some(x) = map.get(options) { + *x + } else { + let new = map.len(); + map.insert(options.clone(), new); + new + }; + + format!("enum{}", idx) + } + + pub fn to_additional_definitions(data_type: &DataType) -> Option { + match data_type { + // TODO name and enum + DataType::Enum(options) => Some(format!( + "CREATE TYPE {} AS ENUM ({});", + get_enum_name(options), + options + .iter() + .map(|x| format!("'{}'", x.replace("'", "\'"))) + .collect::>() + .join(",") + )), + DataType::Integer + | DataType::SmallInteger + | DataType::AutoIncrement + | DataType::Float + | DataType::Boolean + | DataType::Date + | DataType::Time + | DataType::DateTime + | DataType::Varchar(_) + | DataType::Varbinary(_) + | DataType::Blob + | DataType::Text + | DataType::Uuid + | DataType::Decimal(_, _) => None, + } + } pub fn to_data_type(data_type: &DataType) -> String { match data_type { DataType::Integer => "INTEGER".to_string(), + DataType::SmallInteger => "SMALLINT".to_string(), DataType::AutoIncrement => "SERIAL".to_string(), DataType::Float => "FLOAT".to_string(), DataType::Boolean => "BOOLEAN".to_string(), @@ -141,6 +279,7 @@ mod postgresql { DataType::Text => "TEXT".to_string(), DataType::Uuid => "UUID".to_string(), DataType::Decimal(m, d) => format!("DECIMAL({m}, {d})"), + DataType::Enum(options) => get_enum_name(options), } } pub fn to_column_ident(ident: &Ident) -> String { @@ -151,9 +290,32 @@ mod postgresql { mod mssql { use crate::ast::{DataType, Ident}; + use super::MaxLength; + + pub fn to_additional_definitions(data_type: &DataType) -> Option { + match data_type { + DataType::Enum(_) => None, // TODO? + DataType::Integer + | DataType::SmallInteger + | DataType::AutoIncrement + | DataType::Float + | DataType::Boolean + | DataType::Date + | DataType::Time + | DataType::DateTime + | DataType::Varchar(_) + | DataType::Varbinary(_) + | DataType::Blob + | DataType::Text + | DataType::Uuid + | DataType::Decimal(_, _) => None, + } + } + pub fn to_data_type(data_type: &DataType) -> String { match data_type { DataType::Integer => "INTEGER".to_string(), + DataType::SmallInteger => "INTEGER".to_string(), // TODO DataType::AutoIncrement => "INTEGER IDENTITY(1,1)".to_string(), DataType::Float => "FLOAT".to_string(), DataType::Boolean => "BOOLEAN".to_string(), @@ -166,9 +328,22 @@ mod mssql { DataType::Text => "NVARCHAR(max)".to_string(), DataType::Uuid => "UNIQUEIDENTIFIER".to_string(), DataType::Decimal(m, d) => format!("DECIMAL({m}, {d})"), + DataType::Enum(options) => format!("VARCHAR({})", options.max_length().unwrap()), + // TODO CHECK (mycol IN('a', 'b')) + // See https://stackoverflow.com/a/1434338 } } pub fn to_column_ident(ident: &Ident) -> String { ident.to_string() } } + +trait MaxLength { + fn max_length(&self) -> Option; +} + +impl MaxLength for Vec { + fn max_length(&self) -> Option { + self.iter().map(|x| x.len()).max() + } +} diff --git a/examples/physical.erd b/examples/physical.erd index 5d3a245..0f7f9c8 100644 --- a/examples/physical.erd +++ b/examples/physical.erd @@ -17,6 +17,7 @@ relation Friends(Is friends with) entity Car id id type uuid + attribute brand type enum(volvo, BMW, ferrari, volkswagen) attribute color type varchar(20) attribute price type float attribute dec_price type decimal(10, 2)