From 58885910e676fd82963deb55e076269f111d8dfb Mon Sep 17 00:00:00 2001 From: nociza Date: Tue, 28 Mar 2023 16:08:51 -0700 Subject: [PATCH 1/2] initial commit --- README.md | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 README.md diff --git a/README.md b/README.md new file mode 100644 index 0000000..826ddfd --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +# Rust DataBase Connectivity (RDBC) + +This is a Rust implementation of the Java DataBase Connectivity (JDBC) API, a continuation and reimplementation of the [rdbc](https://github.com/tokio-rs/rdbc) project. \ No newline at end of file From 5edf39c83f8d67aae84799788e1a9f6069a8833b Mon Sep 17 00:00:00 2001 From: nociza Date: Tue, 28 Mar 2023 16:10:35 -0700 Subject: [PATCH 2/2] add postgres logic --- .github/workflows/rust.yml | 31 ++++++ .gitignore | 4 + Cargo.toml | 17 +++ src/dbc.rs | 217 +++++++++++++++++++++++++++++++++++++ src/dbc/mysql.rs | 116 ++++++++++++++++++++ src/dbc/postgres.rs | 86 +++++++++++++++ src/dbc/sqlite.rs | 69 ++++++++++++ src/lib.rs | 1 + tests/mysql_test.rs | 44 ++++++++ tests/sqlite_test.rs | 33 ++++++ 10 files changed, 618 insertions(+) create mode 100644 .github/workflows/rust.yml create mode 100644 .gitignore create mode 100644 Cargo.toml create mode 100644 src/dbc.rs create mode 100644 src/dbc/mysql.rs create mode 100644 src/dbc/postgres.rs create mode 100644 src/dbc/sqlite.rs create mode 100644 src/lib.rs create mode 100644 tests/mysql_test.rs create mode 100644 tests/sqlite_test.rs diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml new file mode 100644 index 0000000..a118416 --- /dev/null +++ b/.github/workflows/rust.yml @@ -0,0 +1,31 @@ +name: Rust + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +env: + CARGO_TERM_COLOR: always + +jobs: + build: + runs-on: ubuntu-latest + services: + mysql: + image: mysql:5.7 + env: + MYSQL_ROOT_PASSWORD: password + MYSQL_DATABASE: test_db + ports: + - 3306:3306 + options: --health-cmd="mysqladmin ping" --health-interval=10s --health-timeout=5s --health-retries=3 + env: + MYSQL_DATABASE_URL: mysql://root:password@0.0.0.0:3306/test_db + steps: + - uses: actions/checkout@v3 + - name: Build + run: cargo build + - name: Run tests + run: cargo test --verbose diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b154672 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +Cargo.lock +/target +/.idea +/.vscode \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..e0e3e02 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "rdbc2" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +mysql = "23.0.1" +mysql_common = "0.29.2" +postgres = "0.19.4" +rusqlite = { version = "0.28.0", features = ["bundled"] } +serde = { version = "1.0.15", features = ["derive", "rc"] } +serde_json = "1.0.94" +sqlparser = "0.32.0" +tokio = { version = "1.25.0", features = ["macros", "rt", "fs"] } + diff --git a/src/dbc.rs b/src/dbc.rs new file mode 100644 index 0000000..305a9d0 --- /dev/null +++ b/src/dbc.rs @@ -0,0 +1,217 @@ +use std::sync::Arc; + +use serde::{Deserialize, Serialize}; +use serde_json; + +mod mysql; +mod postgres; +mod sqlite; + +pub type Error = Box; + +pub trait Connection { + fn execute(&mut self, query: &str) -> Result; +} + +pub struct Database { + pub(crate) url: String, + pub(crate) connection: Box, +} + +impl Database { + pub fn new(url: &str) -> Result { + let connection = match url { + url if url.starts_with("mysql://") => mysql::MySQLConnection::get_connection(url)?, + url if url.starts_with("sqlite://") => sqlite::SQLiteConnection::get_connection(url)?, + _ => return Err("Unsupported dbc type".into()), + }; + + Ok(Database { + url: url.to_string(), + connection, + }) + } + + pub fn execute_query(&mut self, query: &str) -> Result { + self.connection.execute(query) + } + + pub fn execute_query_and_serialize(&mut self, query: &str) -> Result { + let result = self.execute_query(query)?; + Ok(serde_json::to_string(&result)?) + } + + pub fn execute_query_and_serialize_raw(&mut self, query: &str) -> Result, Error> { + let result = self.execute_query(query)?; + Ok(serde_json::to_vec(&result)?) + } + + pub fn execute_query_with_params( + &mut self, + query: &str, + params: &[&str], + ) -> Result { + let mut query = query.to_string(); + for param in params { + query = query.replace("?", param); + } + self.execute_query(&query) + } + + pub fn execute_query_with_params_and_serialize( + &mut self, + query: &str, + params: &[&str], + ) -> Result { + let result = self.execute_query_with_params(query, params)?; + Ok(serde_json::to_string(&result)?) + } +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub enum Value { + NULL, + Bytes(Vec), + String(String), + Bool(bool), + Int(i64), + UInt(u64), + Float(f32), + Double(f64), + /// year, month, day, hour, minutes, seconds, micro seconds + Date(u16, u8, u8, u8, u8, u8, u32), + /// is negative, days, hours, minutes, seconds, micro seconds + Time(bool, u32, u8, u8, u8, u32), +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct Column { + pub name: String, + pub column_type: ColumnType, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct Row { + values: Vec, + columns: Arc<[Column]>, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct QueryResult { + pub rows: Vec, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub enum ColumnType { + NULL, + NUMERIC, + DECIMAL, + DECIMAL_NEW, + INT, + INT2, + INT4, + INT8, + INT24, + TINY, + SHORT, + LONG, + LONGLONG, + FLOAT, + FLOAT4, + FLOAT8, + BIT, + BOOL, + DOUBLE, + STRING, + VARCHAR, + TEXT, + CHAR, + BYTEA, + TIMESTAMP, + TIMESTAMPTZ, + DATE, + TIME, + TIMETZ, + INTERVAL, + YEAR, + DATETIME, + JSON, + JSONB, + ENUM, + SET, + BLOB, + BLOB_TINY, + BLOB_MEDIUM, + BLOB_LONG, + GEOMETRY, + UUID, + OID, + XML, + CIDR, + INET, + MACADDR, + VARBIT, + REFCURSOR, +} + +macro_rules! add_postgres_types { + ($( $existing_variant:ident ),* ) => { + pub enum ColumnType { + $( + $existing_variant, + )* + Postgres(postgres::types::Type), + } + }; +} + +add_postgres_types!( + NULL, + NUMERIC, + DECIMAL, + DECIMAL_NEW, + INT, + INT2, + INT4, + INT8, + INT24, + TINY, + SHORT, + LONG, + LONGLONG, + FLOAT, + FLOAT4, + FLOAT8, + BIT, + BOOL, + DOUBLE, + STRING, + VARCHAR, + TEXT, + CHAR, + BYTEA, + TIMESTAMP, + TIMESTAMPTZ, + DATE, + TIME, + TIMETZ, + INTERVAL, + YEAR, + DATETIME, + JSON, + JSONB, + ENUM, + SET, + BLOB, + BLOB_TINY, + BLOB_MEDIUM, + BLOB_LONG, + GEOMETRY, + UUID, + OID, + XML, + CIDR, + INET, + MACADDR, + VARBIT, +); diff --git a/src/dbc/mysql.rs b/src/dbc/mysql.rs new file mode 100644 index 0000000..fd8a6f2 --- /dev/null +++ b/src/dbc/mysql.rs @@ -0,0 +1,116 @@ +use std::sync::Arc; + +use mysql; +use mysql::prelude::Queryable; +use mysql_common::constants::ColumnType; + +use crate::dbc; + +pub(crate) struct MySQLConnection { + connection: mysql::Conn, +} + +impl MySQLConnection { + pub(crate) fn get_connection(url: &str) -> Result, dbc::Error> { + Ok(Box::new(MySQLConnection { + connection: mysql::Conn::new(url)?, + }) as Box) + } +} + +impl dbc::Connection for MySQLConnection { + fn execute(&mut self, query: &str) -> Result { + let result = self.connection.query_iter(query)?; + let columns = result + .columns() + .as_ref() + .iter() + .map(|column| dbc::Column { + name: column.name_str().to_string(), + column_type: column.column_type().into(), + }) + .collect::>(); + let columns = Arc::from(columns); + + let mut rows: Vec = Vec::new(); + for row in result { + let row = row?; + let values: Vec = row + .unwrap_raw() + .iter() + .map(|value| { + if value.is_none() { + dbc::Value::NULL + } else { + value.as_ref().unwrap().into() + } + }) + .collect(); + rows.push(dbc::Row { + values, + columns: Arc::clone(&columns), + }); + } + Ok(dbc::QueryResult { rows }) + } +} + +impl From<&mysql::Value> for dbc::Value { + fn from(value: &mysql::Value) -> Self { + match value { + mysql::Value::NULL => dbc::Value::NULL, + mysql::Value::Bytes(bytes) => dbc::Value::Bytes(bytes.clone()), + mysql::Value::Int(int) => dbc::Value::Int(*int), + mysql::Value::UInt(uint) => dbc::Value::UInt(*uint), + mysql::Value::Float(float) => dbc::Value::Float(*float), + mysql::Value::Double(double) => dbc::Value::Double(*double), + mysql::Value::Date(year, month, day, hour, minute, second, microsecond) => { + dbc::Value::Date(*year, *month, *day, *hour, *minute, *second, *microsecond) + } + mysql::Value::Time(negative, days, hours, minutes, seconds, microseconds) => { + dbc::Value::Time(*negative, *days, *hours, *minutes, *seconds, *microseconds) + } + } + } +} + +impl From for dbc::ColumnType { + fn from(column_type: ColumnType) -> Self { + match column_type { + ColumnType::MYSQL_TYPE_DECIMAL => dbc::ColumnType::DECIMAL, + ColumnType::MYSQL_TYPE_TINY => dbc::ColumnType::TINY, + ColumnType::MYSQL_TYPE_SHORT => dbc::ColumnType::SHORT, + ColumnType::MYSQL_TYPE_LONG => dbc::ColumnType::LONG, + ColumnType::MYSQL_TYPE_FLOAT => dbc::ColumnType::FLOAT, + ColumnType::MYSQL_TYPE_DOUBLE => dbc::ColumnType::DOUBLE, + ColumnType::MYSQL_TYPE_NULL => dbc::ColumnType::NULL, + ColumnType::MYSQL_TYPE_TIMESTAMP => dbc::ColumnType::TIMESTAMP, + ColumnType::MYSQL_TYPE_LONGLONG => dbc::ColumnType::LONGLONG, + ColumnType::MYSQL_TYPE_INT24 => dbc::ColumnType::INT24, + ColumnType::MYSQL_TYPE_DATE => dbc::ColumnType::DATE, + ColumnType::MYSQL_TYPE_TIME => dbc::ColumnType::TIME, + ColumnType::MYSQL_TYPE_DATETIME => dbc::ColumnType::DATETIME, + ColumnType::MYSQL_TYPE_YEAR => dbc::ColumnType::YEAR, + ColumnType::MYSQL_TYPE_NEWDATE => dbc::ColumnType::DATE, // Internal? do we need this? + ColumnType::MYSQL_TYPE_VARCHAR => dbc::ColumnType::VARCHAR, + ColumnType::MYSQL_TYPE_BIT => dbc::ColumnType::BIT, + ColumnType::MYSQL_TYPE_TIMESTAMP2 => dbc::ColumnType::TIMESTAMP, // Internal? do we need this? + ColumnType::MYSQL_TYPE_DATETIME2 => dbc::ColumnType::DATETIME, // Internal? do we need this? + ColumnType::MYSQL_TYPE_TIME2 => dbc::ColumnType::TIME, // Internal? do we need this? + ColumnType::MYSQL_TYPE_JSON => dbc::ColumnType::JSON, + ColumnType::MYSQL_TYPE_NEWDECIMAL => dbc::ColumnType::DECIMAL_NEW, + ColumnType::MYSQL_TYPE_ENUM => dbc::ColumnType::ENUM, + ColumnType::MYSQL_TYPE_SET => dbc::ColumnType::SET, + ColumnType::MYSQL_TYPE_TINY_BLOB => dbc::ColumnType::BLOB_TINY, + ColumnType::MYSQL_TYPE_MEDIUM_BLOB => dbc::ColumnType::BLOB_MEDIUM, + ColumnType::MYSQL_TYPE_LONG_BLOB => dbc::ColumnType::BLOB_LONG, + ColumnType::MYSQL_TYPE_BLOB => dbc::ColumnType::BLOB, + ColumnType::MYSQL_TYPE_VAR_STRING => dbc::ColumnType::VARCHAR, + ColumnType::MYSQL_TYPE_STRING => dbc::ColumnType::STRING, + ColumnType::MYSQL_TYPE_GEOMETRY => dbc::ColumnType::GEOMETRY, + _ => { + panic!("Unknown column type: {:?}", column_type); + } + } + } +} diff --git a/src/dbc/postgres.rs b/src/dbc/postgres.rs new file mode 100644 index 0000000..0e2db4d --- /dev/null +++ b/src/dbc/postgres.rs @@ -0,0 +1,86 @@ +use std::sync::Arc; + +use postgres; + +use crate::dbc; + +pub struct PostgresConnection { + pub(crate) connection: postgres::Client, +} + +impl PostgresConnection { + pub(crate) fn get_connection(url: &str) -> Result, dbc::Error> { + Ok(Box::new(PostgresConnection { + connection: postgres::Client::connect(url, postgres::NoTls)?, + }) as Box) + } +} + +impl dbc::Connection for PostgresConnection { + fn execute(&mut self, query: &str) -> Result { + let result = self.connection.query(query, &[])?; + let columns = result + .columns() + .iter() + .map(|column| dbc::Column { + name: column.name().to_string(), + column_type: column.type_().into(), + }) + .collect::>(); + let columns = Arc::from(columns); + + let mut rows: Vec = Vec::new(); + for row in result { + let values: Vec = row + .iter() + .map(|value| { + if value.is_none() { + dbc::Value::NULL + } else { + value.unwrap().into() + } + }) + .collect(); + rows.push(dbc::Row { + values, + columns: Arc::clone(&columns), + }); + } + Ok(dbc::QueryResult { rows }) + } +} + +impl From for dbc::ColumnType { + fn from(value: postgres::types::Type) -> Self { + match value { + postgres::types::Type::BOOL => dbc::ColumnType::BOOL, + postgres::types::Type::INT2 => dbc::ColumnType::INT2, + postgres::types::Type::INT4 => dbc::ColumnType::INT4, + postgres::types::Type::INT8 => dbc::ColumnType::INT8, + postgres::types::Type::FLOAT4 => dbc::ColumnType::FLOAT4, + postgres::types::Type::FLOAT8 => dbc::ColumnType::FLOAT8, + postgres::types::Type::NUMERIC => dbc::ColumnType::NUMERIC, + postgres::types::Type::TIMESTAMP => dbc::ColumnType::TIMESTAMP, + postgres::types::Type::TIMESTAMPTZ => dbc::ColumnType::TIMESTAMPTZ, + postgres::types::Type::DATE => dbc::ColumnType::DATE, + postgres::types::Type::TIME => dbc::ColumnType::TIME, + postgres::types::Type::TIMETZ => dbc::ColumnType::TIMETZ, + postgres::types::Type::INTERVAL => dbc::ColumnType::INTERVAL, + postgres::types::Type::TEXT => dbc::ColumnType::TEXT, + postgres::types::Type::CHAR => dbc::ColumnType::CHAR, + postgres::types::Type::VARCHAR => dbc::ColumnType::VARCHAR, + postgres::types::Type::BYTEA => dbc::ColumnType::BYTEA, + postgres::types::Type::UUID => dbc::ColumnType::UUID, + postgres::types::Type::JSON => dbc::ColumnType::JSON, + postgres::types::Type::JSONB => dbc::ColumnType::JSONB, + postgres::types::Type::XML => dbc::ColumnType::XML, + postgres::types::Type::OID => dbc::ColumnType::OID, + postgres::types::Type::CIDR => dbc::ColumnType::CIDR, + postgres::types::Type::INET => dbc::ColumnType::INET, + postgres::types::Type::MACADDR => dbc::ColumnType::MACADDR, + postgres::types::Type::BIT => dbc::ColumnType::BIT, + postgres::types::Type::VARBIT => dbc::ColumnType::VARBIT, + _ => dbc::ColumnType::UNKNOWN, + } + } +} diff --git a/src/dbc/sqlite.rs b/src/dbc/sqlite.rs new file mode 100644 index 0000000..baddcab --- /dev/null +++ b/src/dbc/sqlite.rs @@ -0,0 +1,69 @@ +use std::sync::Arc; + +use rusqlite; + +use crate::dbc; + +pub(crate) struct SQLiteConnection { + connection: rusqlite::Connection, +} + +impl SQLiteConnection { + pub(crate) fn get_connection(url: &str) -> Result, dbc::Error> { + let connection; + if url == "sqlite://:memory:" { + connection = rusqlite::Connection::open_in_memory()?; + } else { + connection = rusqlite::Connection::open(url)?; + } + Ok(Box::new(SQLiteConnection { + connection, + }) as Box) + } +} + +impl dbc::Connection for SQLiteConnection { + fn execute(&mut self, query: &str) -> Result { + let mut statement = self.connection.prepare(query)?; + let columns = statement.column_names().iter().map( + |column| { + dbc::Column { + name: column.to_string(), + column_type: dbc::ColumnType::STRING, // TODO: get column type + } + } + ).collect::>(); + let columns = Arc::from(columns); + let num_columns = statement.column_count(); + + let mut rows: Vec = Vec::new(); + let mut result = statement.query([])?; + while let Some(row) = result.next()? { + let mut values: Vec = Vec::new(); + for i in 0..num_columns { + let value = row.get_ref(i).unwrap().into(); + values.push(value); + } + + rows.push(dbc::Row { + values, + columns: Arc::clone(&columns), + }); + } + Ok(dbc::QueryResult { + rows, + }) + } +} + +impl From> for dbc::Value { + fn from(value: rusqlite::types::ValueRef) -> Self { + match value { + rusqlite::types::ValueRef::Null => dbc::Value::NULL, + rusqlite::types::ValueRef::Integer(i) => dbc::Value::Int(i), + rusqlite::types::ValueRef::Real(f) => dbc::Value::Double(f), + rusqlite::types::ValueRef::Text(s) => dbc::Value::Bytes(s.to_vec()), + rusqlite::types::ValueRef::Blob(b) => dbc::Value::Bytes(b.to_vec()), + } + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..096714d --- /dev/null +++ b/src/lib.rs @@ -0,0 +1 @@ +pub mod dbc; \ No newline at end of file diff --git a/tests/mysql_test.rs b/tests/mysql_test.rs new file mode 100644 index 0000000..1e7b696 --- /dev/null +++ b/tests/mysql_test.rs @@ -0,0 +1,44 @@ +use rdbc2::dbc; + +type Error = Box; + +fn _get_mysql_connection_url() -> String { + if std::env::var("MYSQL_DATABASE_URL").is_ok() { + std::env::var("MYSQL_DATABASE_URL").unwrap() + } else { + "mysql://localhost:3306/?user=nociza&password=password".to_owned() + } +} + +#[tokio::test] +async fn test_mysql_simple_query() -> Result<(), Error> { + let url = _get_mysql_connection_url(); + let mut database = dbc::Database::new(url.as_str())?; + let query = "SELECT 1"; + let result = database.execute_query(query)?; + assert_eq!(result.rows.len(), 1); + + Ok(()) +} + +#[tokio::test] +async fn test_mysql_query_with_params() -> Result<(), Error> { + let url = _get_mysql_connection_url(); + let mut database = dbc::Database::new(url.as_str())?; + let query = "SELECT ? + ?"; + let result = database.execute_query_with_params(query, &["1", "2"])?; + assert_eq!(result.rows.len(), 1); + + Ok(()) +} + +#[tokio::test] +async fn test_mysql_query_with_params_and_serialize() -> Result<(), Error> { + let url = _get_mysql_connection_url(); + let mut database = dbc::Database::new(url.as_str())?; + let query = "SELECT ? + ?"; + let result = database.execute_query_with_params_and_serialize(query, &["1", "2"])?; + assert_eq!(result, r#"{"rows":[{"values":[{"Bytes":[50]}],"columns":[{"name":"1 + 1","column_type":"LONGLONG"}]}]}"#); + + Ok(()) +} \ No newline at end of file diff --git a/tests/sqlite_test.rs b/tests/sqlite_test.rs new file mode 100644 index 0000000..2f519df --- /dev/null +++ b/tests/sqlite_test.rs @@ -0,0 +1,33 @@ +use rdbc2::dbc; + +type Error = Box; + +#[tokio::test] +async fn test_sqlite_simple_query() -> Result<(), Error> { + let mut database = dbc::Database::new("sqlite://:memory:")?; + let query = "SELECT 1"; + let result = database.execute_query(query)?; + assert_eq!(result.rows.len(), 1); + + Ok(()) +} + +#[tokio::test] +async fn test_sqlite_query_with_params() -> Result<(), Error> { + let mut database = dbc::Database::new("sqlite://:memory:")?; + let query = "SELECT ? + ?"; + let result = database.execute_query_with_params(query, &["1", "2"])?; + assert_eq!(result.rows.len(), 1); + + Ok(()) +} + +#[tokio::test] +async fn test_sqlite_query_with_params_and_serialize() -> Result<(), Error> { + let mut database = dbc::Database::new("sqlite://:memory:")?; + let query = "SELECT ? + ?"; + let result = database.execute_query_with_params_and_serialize(query, &["1", "2"])?; + assert_eq!(result, r#"{"rows":[{"values":[{"Int":2}],"columns":[{"name":"1 + 1","column_type":"STRING"}]}]}"#); // currently all columns are STRING + + Ok(()) +} \ No newline at end of file