diff --git a/Cargo.lock b/Cargo.lock index b10f969..892e5aa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -198,6 +198,21 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futures" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.30" @@ -205,6 +220,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" dependencies = [ "futures-core", + "futures-sink", ] [[package]] @@ -213,6 +229,23 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" +[[package]] +name = "futures-executor" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" + [[package]] name = "futures-sink" version = "0.3.30" @@ -231,10 +264,15 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" dependencies = [ + "futures-channel", "futures-core", + "futures-io", + "futures-sink", "futures-task", + "memchr", "pin-project-lite", "pin-utils", + "slab", ] [[package]] @@ -279,6 +317,12 @@ version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + [[package]] name = "hermit-abi" version = "0.3.9" @@ -681,6 +725,12 @@ dependencies = [ "mockito", "openapi", "regex", +<<<<<<< HEAD +======= + "serde_json", + "serial_test", + "snafu", +>>>>>>> origin/main "tokio", ] @@ -866,6 +916,15 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +[[package]] +name = "scc" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76ad2bbb0ae5100a07b7a6f2ed7ab5fd0045551a4c507989b7a620046ea3efdc" +dependencies = [ + "sdd", +] + [[package]] name = "schannel" version = "0.1.23" @@ -881,6 +940,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "sdd" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b84345e4c9bd703274a082fb80caaa99b7612be48dfaa1dd9266577ec412309d" + [[package]] name = "security-framework" version = "2.11.0" @@ -947,6 +1012,31 @@ dependencies = [ "serde", ] +[[package]] +name = "serial_test" +version = "3.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b4b487fe2acf240a021cf57c6b2b4903b1e78ca0ecd862a71b71d2a51fed77d" +dependencies = [ + "futures", + "log", + "once_cell", + "parking_lot", + "scc", + "serial_test_derive", +] + +[[package]] +name = "serial_test_derive" +version = "3.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82fe9db325bcef1fbcde82e078a5cc4efdf787e96b3b9cf45b50b529f2083d67" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "signal-hook-registry" version = "1.4.2" @@ -971,6 +1061,27 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +[[package]] +name = "snafu" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "418b8136fec49956eba89be7da2847ec1909df92a9ae4178b5ff0ff092c8d95e" +dependencies = [ + "snafu-derive", +] + +[[package]] +name = "snafu-derive" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a4812a669da00d17d8266a0439eddcacbc88b17f732f927e52eeb9d196f7fb5" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "socket2" version = "0.5.7" diff --git a/pinecone_sdk/Cargo.toml b/pinecone_sdk/Cargo.toml index 08f6aee..b6a6fa4 100644 --- a/pinecone_sdk/Cargo.toml +++ b/pinecone_sdk/Cargo.toml @@ -8,5 +8,15 @@ edition = "2021" [dependencies] openapi = { path = "../openapi" } tokio = { version = "1", features = ["full"] } +<<<<<<< HEAD mockito = "0.30" regex = "1.10.4" +======= +regex = "1.10.4" +serde_json = "1.0.117" +snafu = "0.8.3" + +[dev-dependencies] +mockito = "0.30" +serial_test = "3.1.1" +>>>>>>> origin/main diff --git a/pinecone_sdk/src/config.rs b/pinecone_sdk/src/config.rs index 0c80ab3..9bfea0e 100644 --- a/pinecone_sdk/src/config.rs +++ b/pinecone_sdk/src/config.rs @@ -1,16 +1,20 @@ +use std::collections::HashMap; + #[derive(Debug, Clone)] pub struct Config { pub api_key: String, pub controller_url: String, + pub additional_headers: HashMap, pub source_tag: Option, } impl Config { - pub fn new(api_key: String) -> Self { + pub fn new(api_key: String, source_tag: Option) -> Self { Config { api_key, controller_url: "https://api.pinecone.io".to_string(), - source_tag: None, + additional_headers: HashMap::new(), + source_tag, } } } diff --git a/pinecone_sdk/src/control/list_indexes.rs b/pinecone_sdk/src/control/list_indexes.rs index 10aeea2..6b9c1ab 100644 --- a/pinecone_sdk/src/control/list_indexes.rs +++ b/pinecone_sdk/src/control/list_indexes.rs @@ -15,16 +15,14 @@ impl Pinecone { #[cfg(test)] mod tests { use super::*; - use mockito::mock; - use tokio; - use crate::control::list_indexes::models::index_model::Metric; use crate::config::Config; - use openapi::models::IndexList; - use openapi::apis::configuration::Configuration; + use crate::control::list_indexes::models::index_model::Metric; + use mockito::mock; use openapi::apis::configuration::ApiKey; + use openapi::apis::configuration::Configuration; + use openapi::models::IndexList; use openapi::models::IndexModel; - - + use tokio; #[tokio::test] async fn test_list_indexes() { @@ -66,7 +64,8 @@ mod tests { // Construct Pinecone instance with the mock server URL let api_key = "test_api_key".to_string(); - let pinecone = Pinecone::new(api_key, Some(mockito::server_url())); + let pinecone = Pinecone::new(Some(api_key), Some(mockito::server_url()), None, None) + .expect("Failed to create Pinecone instance"); // Call list_indexes and verify the result let result = pinecone.list_indexes().await; diff --git a/pinecone_sdk/src/lib.rs b/pinecone_sdk/src/lib.rs index f26d99b..df1d5a8 100644 --- a/pinecone_sdk/src/lib.rs +++ b/pinecone_sdk/src/lib.rs @@ -2,4 +2,4 @@ pub mod config; pub mod control; pub mod pinecone; pub mod utils; -pub mod models; \ No newline at end of file +pub mod models; diff --git a/pinecone_sdk/src/pinecone.rs b/pinecone_sdk/src/pinecone.rs index e42dbbd..23e4751 100644 --- a/pinecone_sdk/src/pinecone.rs +++ b/pinecone_sdk/src/pinecone.rs @@ -1,7 +1,10 @@ use crate::config::Config; +use crate::utils::errors::PineconeError; use crate::utils::user_agent::get_user_agent; use openapi::apis::configuration::ApiKey; use openapi::apis::configuration::Configuration; +use serde_json; +use std::collections::HashMap; #[derive(Debug, Clone)] pub struct Pinecone { @@ -10,13 +13,54 @@ pub struct Pinecone { } impl Pinecone { - pub fn new(api_key: String, control_plane_host: Option) -> Self { - let config = Config::new(api_key.clone()); + pub fn new( + api_key: Option, + control_plane_host: Option, + additional_headers: Option>, + source_tag: Option, + ) -> Result { + // get api key + let api_key = match api_key { + Some(key) => key, + None => match std::env::var("PINECONE_API_KEY") { + Ok(key) => key, + Err(_) => { + return Err(PineconeError::APIKeyMissingError); + } + }, + }; + + let controller_host = control_plane_host.unwrap_or( + std::env::var("PINECONE_CONTROLLER_HOST") + .unwrap_or("https://api.pinecone.io".to_string()), + ); + + let additional_headers = match additional_headers { + Some(headers) => headers, + None => match std::env::var("PINECONE_ADDITIONAL_HEADERS") { + Ok(headers) => match serde_json::from_str(&headers) { + Ok(headers) => headers, + Err(json_error) => { + return Err(PineconeError::InvalidHeadersError { json_error }); + } + }, + Err(_) => HashMap::new(), + }, + }; + + let config = Config { + api_key: api_key.clone(), + controller_url: controller_host.clone(), + additional_headers, + source_tag, + }; + + let user_agent = get_user_agent(&config); let user_agent = get_user_agent(&config); let openapi_config = Configuration { - base_path: control_plane_host.unwrap_or("https://api.pinecone.io".to_string()), + base_path: controller_host, user_agent: Some(user_agent), api_key: Some(ApiKey { prefix: None, @@ -25,13 +69,281 @@ impl Pinecone { ..Default::default() }; - Pinecone { + Ok(Pinecone { config, openapi_config, - } + }) } pub fn openapi_config(&self) -> &Configuration { &self.openapi_config } } + +#[cfg(test)] +mod tests { + use std::env; + + use super::*; + use serial_test::serial; + use tokio; + + fn set_env_var(key: &str, value: &str) { + env::set_var(key, value); + assert!(env::var(key).is_ok()); + assert!(env::var(key).unwrap() == value); + } + + fn remove_env_var(key: &str) { + env::remove_var(key); + assert!(env::var(key).is_err()); + } + + #[tokio::test] + async fn test_arg_api_key() { + let mock_api_key = "mock-arg-api-key".to_string(); + let mock_controller_host = "mock-arg-controller-host".to_string(); + + let pinecone = Pinecone::new( + Some(mock_api_key.clone()), + Some(mock_controller_host.clone()), + Some(HashMap::new()), + None, + ); + + assert!(pinecone.is_ok()); + assert_eq!(pinecone.unwrap().config.api_key, mock_api_key.clone()); + } + + #[tokio::test] + #[serial] + async fn test_env_api_key() { + let mock_api_key = "mock-env-api-key".to_string(); + let mock_controller_host = "mock-arg-controller-host".to_string(); + + set_env_var("PINECONE_API_KEY", mock_api_key.as_str()); + + let pinecone = Pinecone::new( + None, + Some(mock_controller_host.clone()), + Some(HashMap::new()), + None, + ); + + assert!(pinecone.is_ok()); + assert_eq!(pinecone.unwrap().config.api_key, mock_api_key.clone()); + } + + #[tokio::test] + #[serial] + async fn test_no_api_key() { + let mock_controller_host = "mock-arg-controller-host".to_string(); + + remove_env_var("PINECONE_API_KEY"); + + let pinecone = Pinecone::new( + None, + Some(mock_controller_host.clone()), + Some(HashMap::new()), + None, + ); + + assert!(pinecone.is_err()); + assert!(matches!( + pinecone.err().unwrap(), + PineconeError::APIKeyMissingError + )); + } + + #[tokio::test] + async fn test_arg_host() { + let mock_api_key = "mock-arg-api-key".to_string(); + let mock_controller_host = "mock-arg-controller-host".to_string(); + let pinecone = Pinecone::new( + Some(mock_api_key.clone()), + Some(mock_controller_host.clone()), + Some(HashMap::new()), + None, + ); + + assert!(pinecone.is_ok()); + assert_eq!( + pinecone.unwrap().config.controller_url, + mock_controller_host.clone() + ); + } + + #[tokio::test] + #[serial] + async fn test_env_host() { + let mock_api_key = "mock-arg-api-key".to_string(); + let mock_controller_host = "mock-env-controller-host".to_string(); + + set_env_var("PINECONE_CONTROLLER_HOST", mock_controller_host.as_str()); + + let pinecone = Pinecone::new(Some(mock_api_key.clone()), None, Some(HashMap::new()), None); + + assert!(pinecone.is_ok()); + assert_eq!( + pinecone.unwrap().config.controller_url, + mock_controller_host.clone() + ); + } + + #[tokio::test] + #[serial] + async fn test_default_host() { + let mock_api_key = "mock-arg-api-key".to_string(); + + remove_env_var("PINECONE_CONTROLLER_HOST"); + + let pinecone = Pinecone::new(Some(mock_api_key.clone()), None, Some(HashMap::new()), None); + + assert!(pinecone.is_ok()); + assert_eq!( + pinecone.unwrap().config.controller_url, + "https://api.pinecone.io".to_string() + ); + } + + #[tokio::test] + async fn test_arg_headers() { + let mock_api_key = "mock-arg-api-key".to_string(); + let mock_controller_host = "mock-arg-controller-host".to_string(); + let mock_headers = HashMap::from([ + ("argheader1".to_string(), "value1".to_string()), + ("argheader2".to_string(), "value2".to_string()), + ]); + + let pinecone = Pinecone::new( + Some(mock_api_key.clone()), + Some(mock_controller_host.clone()), + Some(mock_headers.clone()), + None, + ); + + assert!(pinecone.is_ok()); + assert_eq!( + pinecone.unwrap().config.additional_headers, + mock_headers.clone() + ); + } + + #[tokio::test] + #[serial] + async fn test_env_headers() { + let mock_api_key = "mock-arg-api-key".to_string(); + let mock_controller_host = "mock-arg-controller-host".to_string(); + let mock_headers = HashMap::from([ + ("envheader1".to_string(), "value1".to_string()), + ("envheader2".to_string(), "value2".to_string()), + ]); + + set_env_var( + "PINECONE_ADDITIONAL_HEADERS", + serde_json::to_string(&mock_headers).unwrap().as_str(), + ); + + let pinecone = Pinecone::new( + Some(mock_api_key.clone()), + Some(mock_controller_host.clone()), + None, + None, + ); + + assert!(pinecone.is_ok()); + assert_eq!( + pinecone.unwrap().config.additional_headers, + mock_headers.clone() + ); + } + + #[tokio::test] + #[serial] + async fn test_invalid_env_headers() { + let mock_api_key = "mock-arg-api-key".to_string(); + let mock_controller_host = "mock-arg-controller-host".to_string(); + + set_env_var("PINECONE_ADDITIONAL_HEADERS", "invalid-json"); + + let pinecone = Pinecone::new( + Some(mock_api_key.clone()), + Some(mock_controller_host.clone()), + None, + None, + ); + + assert!(pinecone.is_err()); + assert!(matches!( + pinecone.err().unwrap(), + PineconeError::InvalidHeadersError { .. } + )); + } + + #[tokio::test] + #[serial] + async fn test_default_headers() { + let mock_api_key = "mock-arg-api-key".to_string(); + let mock_controller_host = "mock-arg-controller-host".to_string(); + + remove_env_var("PINECONE_ADDITIONAL_HEADERS"); + + let pinecone = Pinecone::new( + Some(mock_api_key.clone()), + Some(mock_controller_host.clone()), + Some(HashMap::new()), + None, + ); + + assert!(pinecone.is_ok()); + assert_eq!(pinecone.unwrap().config.additional_headers, HashMap::new()); + } + + #[tokio::test] + #[serial] + async fn test_arg_overrides_env() { + let mock_arg_api_key = "mock-arg-api-key".to_string(); + let mock_arg_controller_host = "mock-arg-controller-host".to_string(); + let mock_arg_headers = HashMap::from([ + ("argheader1".to_string(), "value1".to_string()), + ("argheader2".to_string(), "value2".to_string()), + ]); + let mock_env_api_key = "mock-env-api-key".to_string(); + let mock_env_controller_host = "mock-env-controller-host".to_string(); + let mock_env_headers = HashMap::from([ + ("envheader1".to_string(), "value1".to_string()), + ("envheader2".to_string(), "value2".to_string()), + ]); + + set_env_var("PINECONE_API_KEY", mock_env_api_key.as_str()); + set_env_var( + "PINECONE_CONTROLLER_HOST", + mock_env_controller_host.as_str(), + ); + env::set_var( + "PINECONE_ADDITIONAL_HEADERS", + serde_json::to_string(&mock_env_headers).unwrap(), + ); + + let pinecone = Pinecone::new( + Some(mock_arg_api_key.clone()), + Some(mock_arg_controller_host.clone()), + Some(mock_arg_headers.clone()), + None, + ); + + assert!(pinecone.is_ok()); + assert_eq!( + pinecone.as_ref().unwrap().config.api_key, + mock_arg_api_key.clone() + ); + assert_eq!( + pinecone.as_ref().unwrap().config.controller_url, + mock_arg_controller_host.clone() + ); + assert_eq!( + pinecone.as_ref().unwrap().config.additional_headers, + mock_arg_headers.clone() + ); + } +} diff --git a/pinecone_sdk/src/utils/errors.rs b/pinecone_sdk/src/utils/errors.rs new file mode 100644 index 0000000..e97c306 --- /dev/null +++ b/pinecone_sdk/src/utils/errors.rs @@ -0,0 +1,10 @@ +use snafu::prelude::*; + +#[derive(Debug, Snafu)] +pub enum PineconeError { + #[snafu(display("API key missing."))] + APIKeyMissingError, + + #[snafu(display("Failed to parse headers: {}", json_error))] + InvalidHeadersError { json_error: serde_json::Error }, +} diff --git a/pinecone_sdk/src/utils/mod.rs b/pinecone_sdk/src/utils/mod.rs index c497f92..1eaf948 100644 --- a/pinecone_sdk/src/utils/mod.rs +++ b/pinecone_sdk/src/utils/mod.rs @@ -1 +1,2 @@ -pub mod user_agent; \ No newline at end of file +pub mod errors; +pub mod user_agent; diff --git a/pinecone_sdk/src/utils/user_agent.rs b/pinecone_sdk/src/utils/user_agent.rs index 5726c42..7fc1708 100644 --- a/pinecone_sdk/src/utils/user_agent.rs +++ b/pinecone_sdk/src/utils/user_agent.rs @@ -6,9 +6,9 @@ fn build_source_tag(source_tag: &String) -> String { // 1. Lowercase // 2. Limit charset to [a-z0-9_ ] // 3. Trim left/right empty space - // 4. Condence multiple spaces to one, and replace with underscore + // 4. Condense multiple spaces to one, and replace with underscore - let re = Regex::new(r"[^a-z0-9_ ]").unwrap(); + let re = Regex::new(r"[^a-z0-9_: ]").unwrap(); let lowercase_tag = source_tag.to_lowercase(); let tag = re.replace_all(&lowercase_tag, ""); return tag.trim() @@ -20,7 +20,7 @@ fn build_source_tag(source_tag: &String) -> String { // Gets user agent string pub fn get_user_agent(config: &Config) -> String { - let mut user_agent = format!("lang=rust/{}", "0.1.0"); + let mut user_agent = format!("lang=rust; pinecone-rust-client={}", "0.1.0"); if let Some(source_tag) = &config.source_tag { user_agent.push_str(&format!("; source_tag={}", build_source_tag(source_tag))); } @@ -38,23 +38,21 @@ mod tests { assert_eq!(build_source_tag(&source_tag), "hello_world"); } + #[tokio::test] + async fn test_build_source_tag_special_chars() { + let source_tag = " Hello World__:_!@#@# ".to_string(); + assert_eq!(build_source_tag(&source_tag), "hello_world__:_"); + } + #[tokio::test] async fn test_no_source_tag() { - let config = Config { - api_key: "api".to_string(), - controller_url: "https://api.pinecone.io".to_string(), - source_tag: None, - }; - assert_eq!(get_user_agent(&config), "lang=rust/0.1.0"); + let config = Config::new("api".to_string(), None); + assert_eq!(get_user_agent(&config), "lang=rust; pinecone-rust-client=0.1.0"); } #[tokio::test] async fn test_with_source_tag() { - let config = Config { - api_key: "api".to_string(), - controller_url: "https://api.pinecone.io".to_string(), - source_tag: Some("tag".to_string()), - }; - assert_eq!(get_user_agent(&config), "lang=rust/0.1.0; source_tag=tag"); + let config = Config::new("api".to_string(), Some("Tag".to_string())); + assert_eq!(get_user_agent(&config), "lang=rust; pinecone-rust-client=0.1.0; source_tag=tag"); } } \ No newline at end of file