From aa993499ba68602ae2e24c2ae425847309d5d757 Mon Sep 17 00:00:00 2001 From: henry0715-dev Date: Wed, 4 Dec 2024 11:04:06 +0900 Subject: [PATCH] Added `IpAddress` GraphQL scalar for IP addresses Close #353 `Nic` and `NicInput` that are not used as additional content have been removed. --- CHANGELOG.md | 10 ++++ src/graphql.rs | 29 +++++++++- src/graphql/account.rs | 111 +++++++++++++++++++++++++++++-------- src/graphql/ip_location.rs | 41 +++++++------- src/graphql/node.rs | 72 +----------------------- src/graphql/sampling.rs | 100 ++++++++++++++++++++++++++++----- 6 files changed, 232 insertions(+), 131 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d1eea852..c6da21c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,16 @@ Versioning](https://semver.org/spec/v2.0.0.html). - Added the `updateTrustedDomain` GraphQL API, allowing users to modify a trusted domain. +- Added `IpAddress` GraphQL custom scalar for IP addresses. + - Applied it to the GraphQL APIs `ipLocation`, `ipLocationList`, + `insertAccount`, `updateAccount`, `insertSamplingPolicy`, and + `updateSamplingPolicy`. + - The API returns the following error message when a value cannot be parsed as + an `IpAddr` (e.g., when "abc" is given): + ```text + Failed to parse "IpAddress": Invalid IP address: abc (occurred while + parsing "[IpAddress!]") + ``` ### Changed diff --git a/src/graphql.rs b/src/graphql.rs index 6f023300..6e6027fa 100644 --- a/src/graphql.rs +++ b/src/graphql.rs @@ -36,14 +36,17 @@ mod trusted_user_agent; use std::fmt; use std::future::Future; +use std::net::IpAddr; #[cfg(test)] use std::net::SocketAddr; use std::sync::{Arc, Mutex}; -use async_graphql::connection::{ConnectionNameType, CursorType, EdgeNameType, OpaqueCursor}; +use async_graphql::connection::{ + Connection, ConnectionNameType, CursorType, Edge, EdgeNameType, EmptyFields, OpaqueCursor, +}; use async_graphql::{ - connection::{Connection, Edge, EmptyFields}, - Context, Guard, MergedObject, MergedSubscription, ObjectType, OutputType, Result, + Context, Guard, InputValueError, InputValueResult, MergedObject, MergedSubscription, + ObjectType, OutputType, Result, Scalar, ScalarType, Value, }; use chrono::TimeDelta; use num_traits::ToPrimitive; @@ -504,6 +507,26 @@ impl Guard for RoleGuard { } } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct IpAddress(pub IpAddr); + +#[Scalar] +impl ScalarType for IpAddress { + fn parse(value: Value) -> InputValueResult { + match value { + Value::String(s) => s + .parse::() + .map(IpAddress) + .map_err(|_| InputValueError::custom(format!("Invalid IP address: {s}"))), + _ => Err(InputValueError::expected_type(value)), + } + } + + fn to_value(&self) -> Value { + Value::String(self.0.to_string()) + } +} + fn fill_vacant_time_slots(series: &[database::TimeCount]) -> Vec { let mut filled_series: Vec = Vec::new(); diff --git a/src/graphql/account.rs b/src/graphql/account.rs index 2e6941ae..be912309 100644 --- a/src/graphql/account.rs +++ b/src/graphql/account.rs @@ -1,6 +1,6 @@ use std::{ env, - net::{AddrParseError, IpAddr, SocketAddr}, + net::{IpAddr, SocketAddr}, }; use anyhow::anyhow; @@ -18,7 +18,7 @@ use review_database::{ use serde::Serialize; use tracing::info; -use super::RoleGuard; +use super::{IpAddress, RoleGuard}; use crate::auth::{create_token, decode_token, insert_token, revoke_token, update_jwt_expires_in}; use crate::graphql::query_with_constraints; @@ -157,7 +157,7 @@ impl AccountMutation { name: String, department: String, language: Option, - allow_access_from: Option>, + allow_access_from: Option>, max_parallel_sessions: Option, ) -> Result { let store = crate::graphql::get_store(ctx).await?; @@ -166,7 +166,7 @@ impl AccountMutation { return Err("account already exists".into()); } let allow_access_from = if let Some(ip_addrs) = allow_access_from { - let ip_addrs = strings_to_ip_addrs(&ip_addrs)?; + let ip_addrs = to_ip_addr(&ip_addrs); Some(ip_addrs) } else { None @@ -272,16 +272,8 @@ impl AccountMutation { let dept = department.map(|d| (d.old, d.new)); let language = language.map(|d| (d.old, d.new)); let allow_access_from = if let Some(ip_addrs) = allow_access_from { - let old = if let Some(old) = ip_addrs.old { - Some(strings_to_ip_addrs(&old)?) - } else { - None - }; - let new = if let Some(new) = ip_addrs.new { - Some(strings_to_ip_addrs(&new)?) - } else { - None - }; + let old = ip_addrs.old.map(|old| to_ip_addr(&old)); + let new = ip_addrs.new.map(|new| to_ip_addr(&new)); Some((old, new)) } else { None @@ -639,13 +631,14 @@ impl From for Account { } } -fn strings_to_ip_addrs(ip_addrs: &[String]) -> Result, AddrParseError> { +fn to_ip_addr(ip_addrs: &[IpAddress]) -> Vec { let mut ip_addrs = ip_addrs .iter() - .map(|ip_addr| ip_addr.parse::()) - .collect::, _>>()?; - ip_addrs.sort(); - Ok(ip_addrs) + .map(|ip_addr| ip_addr.0) + .collect::>(); + ip_addrs.sort_unstable(); + ip_addrs.dedup(); + ip_addrs } #[derive(SimpleObject)] @@ -693,8 +686,8 @@ struct UpdateLanguage { /// The old and new values of `allowAccessFrom` to update. #[derive(InputObject)] struct UpdateAllowAccessFrom { - old: Option>, - new: Option>, + old: Option>, + new: Option>, } /// The old and new values of `maxParallelSessions` to update. @@ -1360,7 +1353,8 @@ mod tests { role: "SECURITY_ADMINISTRATOR", name: "John Doe", department: "Security", - language: "en-US" + language: "en-US", + allowAccessFrom: ["127.0.0.1"] ) }"#, ) @@ -1410,6 +1404,10 @@ mod tests { language: { old: "en-US", new: "ko-KR" + }, + allowAccessFrom: { + old: "127.0.0.1", + new: "127.0.0.2" } ) }"#, @@ -1428,6 +1426,7 @@ mod tests { name department language + allowAccessFrom } }"#, ) @@ -1435,7 +1434,45 @@ mod tests { assert_eq!( res.data.to_string(), - r#"{account: {username: "username", role: SYSTEM_ADMINISTRATOR, name: "Loren Ipsum", department: "Admin", language: "ko-KR"}}"# + r#"{account: {username: "username", role: SYSTEM_ADMINISTRATOR, name: "Loren Ipsum", department: "Admin", language: "ko-KR", allowAccessFrom: ["127.0.0.2"]}}"# + ); + + let res = schema + .execute( + r#" + mutation { + updateAccount( + username: "username", + password: "password", + role: { + old: "SECURITY_ADMINISTRATOR", + new: "SYSTEM_ADMINISTRATOR" + }, + name: { + old: "John Doe", + new: "Loren Ipsum" + }, + department: { + old: "Security", + new: "Admin" + }, + language: { + old: "en-US", + new: "ko-KR" + }, + allowAccessFrom: { + old: "127.0.0.2", + new: "127.0.0.x" + } + ) + }"#, + ) + .await; + assert_eq!( + res.errors.first().unwrap().message.to_string(), + "Failed to parse \"IpAddress\": Invalid IP address: 127.0.0.x (occurred while \ + parsing \"[IpAddress!]\") (occurred while parsing \"UpdateAllowAccessFrom\")" + .to_string() ); } @@ -1587,6 +1624,34 @@ mod tests { assert!(res.is_err()); } + #[tokio::test] + async fn invalid_ip_allow_access_from() { + let agent_manager: BoxedAgentManager = Box::new(MockAgentManager {}); + let test_addr: SocketAddr = "127.0.0.1:8080".parse().unwrap(); + + let schema = TestSchema::new_with(agent_manager, Some(test_addr)).await; + let res = schema + .execute( + r#"mutation { + insertAccount( + username: "u1", + password: "pw1", + role: "SECURITY_ADMINISTRATOR", + name: "User One", + department: "Test", + allowAccessFrom: ["127.0.0.x"] + ) + }"#, + ) + .await; + assert_eq!( + res.errors.first().unwrap().message.to_string(), + "Failed to parse \"IpAddress\": Invalid IP address: 127.0.0.x (occurred while \ + parsing \"[IpAddress!]\")" + .to_string() + ); + } + #[tokio::test] async fn language() { let schema = TestSchema::new().await; diff --git a/src/graphql/ip_location.rs b/src/graphql/ip_location.rs index 1033ae6b..8c306a2b 100644 --- a/src/graphql/ip_location.rs +++ b/src/graphql/ip_location.rs @@ -1,11 +1,8 @@ -use std::{ - net::IpAddr, - sync::{Arc, Mutex}, -}; +use std::sync::{Arc, Mutex}; use async_graphql::{Context, Object, Result, SimpleObject}; -use super::{Role, RoleGuard}; +use super::{IpAddress, Role, RoleGuard}; const MAX_NUM_IP_LOCATION_LIST: usize = 200; #[derive(Default)] @@ -18,10 +15,12 @@ impl IpLocationQuery { .or(RoleGuard::new(Role::SecurityAdministrator)) .or(RoleGuard::new(Role::SecurityManager)) .or(RoleGuard::new(Role::SecurityMonitor))")] - async fn ip_location(&self, ctx: &Context<'_>, address: String) -> Result> { - let Ok(addr) = address.parse::() else { - return Err("invalid IP address".into()); - }; + async fn ip_location( + &self, + ctx: &Context<'_>, + address: IpAddress, + ) -> Result> { + let addr = address.0; let Ok(mutex) = ctx.data::>>() else { return Err("IP location database unavailable".into()); }; @@ -47,7 +46,7 @@ impl IpLocationQuery { async fn ip_location_list( &self, ctx: &Context<'_>, - mut addresses: Vec, + mut addresses: Vec, ) -> Result> { let Ok(mutex) = ctx.data::>>() else { return Err("IP location database unavailable".into()); @@ -59,19 +58,17 @@ impl IpLocationQuery { .map_err(|_| "Failed to lock IP location database")?; let records = addresses .iter() - .filter_map(|address| { - address.parse::().ok().and_then(|addr| { - locator - .ip_lookup(addr) - .ok() - .map(std::convert::TryInto::try_into) - .and_then(|r| { - r.ok().map(|location| IpLocationItem { - address: address.clone(), - location, - }) + .filter_map(|addr| { + locator + .ip_lookup(addr.0) + .ok() + .map(std::convert::TryInto::try_into) + .and_then(|r| { + r.ok().map(|location| IpLocationItem { + address: addr.0.to_string(), + location, }) - }) + }) }) .collect(); diff --git a/src/graphql/node.rs b/src/graphql/node.rs index d66a0fea..101ab8cf 100644 --- a/src/graphql/node.rs +++ b/src/graphql/node.rs @@ -4,11 +4,10 @@ mod input; mod process; mod status; -use std::{borrow::Cow, net::IpAddr, time::Duration}; +use std::{borrow::Cow, time::Duration}; use async_graphql::{ - types::ID, ComplexObject, Context, Enum, InputObject, Object, Result, SimpleObject, - StringNumber, + types::ID, ComplexObject, Context, Enum, Object, Result, SimpleObject, StringNumber, }; use bincode::Options; use chrono::{DateTime, TimeZone, Utc}; @@ -16,7 +15,6 @@ use chrono::{DateTime, TimeZone, Utc}; pub use crud::get_customer_id_of_node; use database::Indexable; use input::NodeInput; -use ipnet::Ipv4Net; use review_database as database; use roxy::Process as RoxyProcess; use serde::{Deserialize, Serialize}; @@ -36,72 +34,6 @@ pub(super) struct NodeControlMutation; #[derive(Default)] pub(super) struct ProcessListQuery; -#[derive(Clone, Deserialize, Serialize, SimpleObject, PartialEq)] -#[graphql(complex)] -struct Nic { - name: String, - #[graphql(skip)] - interface: Ipv4Net, - #[graphql(skip)] - gateway: IpAddr, -} - -#[ComplexObject] -impl Nic { - async fn interface(&self) -> String { - self.interface.to_string() - } - - async fn gateway(&self) -> String { - self.gateway.to_string() - } -} - -#[derive(Clone, InputObject)] -struct NicInput { - name: String, - interface: String, - gateway: String, -} - -impl PartialEq for NicInput { - fn eq(&self, rhs: &Nic) -> bool { - self.name == rhs.name - && self - .interface - .as_str() - .parse::() - .map_or(false, |ip| ip == rhs.interface) - && self - .gateway - .as_str() - .parse::() - .map_or(false, |ip| ip == rhs.gateway) - } -} - -impl TryFrom for Nic { - type Error = anyhow::Error; - - fn try_from(input: NicInput) -> Result { - (&input).try_into() - } -} - -impl TryFrom<&NicInput> for Nic { - type Error = anyhow::Error; - - fn try_from(input: &NicInput) -> Result { - let interface = input.interface.as_str().parse::()?; - let gateway = input.gateway.as_str().parse::()?; - Ok(Self { - name: input.name.clone(), - interface, - gateway, - }) - } -} - #[derive(Clone, Deserialize, PartialEq, Serialize, Copy, Eq, Enum)] #[graphql(remote = "database::AgentKind")] pub enum AgentKind { diff --git a/src/graphql/sampling.rs b/src/graphql/sampling.rs index 140f2565..8fa802c0 100644 --- a/src/graphql/sampling.rs +++ b/src/graphql/sampling.rs @@ -10,7 +10,7 @@ use chrono::{DateTime, Utc}; use review_database::{Direction, Iterable}; use serde::{Deserialize, Serialize}; -use super::{BoxedAgentManager, Role, RoleGuard}; +use super::{BoxedAgentManager, IpAddress, Role, RoleGuard}; use crate::graphql::query_with_constraints; #[derive(Default)] @@ -206,8 +206,8 @@ pub(super) struct SamplingPolicyInput { pub interval: Interval, pub period: Period, pub offset: i32, - pub src_ip: Option, - pub dst_ip: Option, + pub src_ip: Option, + pub dst_ip: Option, pub node: Option, // hostname pub column: Option, pub immutable: bool, @@ -223,8 +223,8 @@ impl TryFrom for review_database::SamplingPolicyUpdate { interval: input.interval.into(), period: input.period.into(), offset: input.offset, - src_ip: input.src_ip.map(|ip| ip.parse::()).transpose()?, - dst_ip: input.dst_ip.map(|ip| ip.parse::()).transpose()?, + src_ip: input.src_ip.map(|ip| ip.0), + dst_ip: input.dst_ip.map(|ip| ip.0), node: input.node, column: input.column, immutable: input.immutable, @@ -344,8 +344,8 @@ impl SamplingPolicyMutation { interval: Interval, period: Period, offset: i32, - src_ip: Option, - dst_ip: Option, + src_ip: Option, + dst_ip: Option, node: Option, column: Option, immutable: bool, @@ -357,8 +357,8 @@ impl SamplingPolicyMutation { interval: interval.into(), period: period.into(), offset, - src_ip: src_ip.map(|ip| ip.parse::()).transpose()?, - dst_ip: dst_ip.map(|ip| ip.parse::()).transpose()?, + src_ip: src_ip.map(|ip| ip.0), + dst_ip: dst_ip.map(|ip| ip.0), node, column, immutable, @@ -470,7 +470,9 @@ mod tests { period: ONE_DAY, offset: 0, node: "sensor", - immutable: false + immutable: false, + srcIp: "127.0.0.1", + dstIp: "127.0.0.2" ) } "#, @@ -478,6 +480,30 @@ mod tests { .await; assert_eq!(res.data.to_string(), r#"{insertSamplingPolicy: "0"}"#); + let res = schema + .execute( + r#" + mutation { + insertSamplingPolicy( + name: "Policy 2", + kind: CONN, + interval: FIFTEEN_MINUTES, + period: ONE_DAY, + offset: 0, + node: "sensor", + immutable: false, + srcIp: "127.0.0.1", + dstIp: "127.0.0.x" + ) + } + "#, + ) + .await; + assert_eq!( + res.errors.first().unwrap().message.to_string(), + "Failed to parse \"IpAddress\": Invalid IP address: 127.0.0.x".to_string() + ); + let res = schema .execute( r#" @@ -491,7 +517,9 @@ mod tests { period: ONE_DAY, offset: 0, node: "sensor", - immutable: false + immutable: false, + srcIp: "127.0.0.1", + dstIp: "127.0.0.2" }, new:{ name: "Policy 2", @@ -500,7 +528,9 @@ mod tests { period: ONE_DAY, offset: 0, node: "manager", - immutable: true + immutable: true, + srcIp: "127.0.0.1", + dstIp: "127.0.0.2" } ) } @@ -522,6 +552,8 @@ mod tests { offset node immutable + srcIp + dstIp } } } @@ -540,12 +572,54 @@ mod tests { "period": "ONE_DAY", "offset": 0, "node": "manager", - "immutable": true + "immutable": true, + "srcIp": "127.0.0.1", + "dstIp": "127.0.0.2", }] } }) ); + let res = schema + .execute( + r#" + mutation { + updateSamplingPolicy( + id: "0", + old: { + name: "Policy 2", + kind: CONN, + interval: FIFTEEN_MINUTES, + period: ONE_DAY, + offset: 0, + node: "manager", + immutable: true, + srcIp: "127.0.0.1", + dstIp: "127.0.0.2" + }, + new:{ + name: "Policy 3", + kind: CONN, + interval: FIFTEEN_MINUTES, + period: ONE_DAY, + offset: 0, + node: "manager", + immutable: true, + srcIp: "127.0.0.x", + dstIp: "127.0.0.2" + } + ) + } + "#, + ) + .await; + assert_eq!( + res.errors.first().unwrap().message.to_string(), + "Failed to parse \"IpAddress\": Invalid IP address: 127.0.0.x \ + (occurred while parsing \"SamplingPolicyInput\")" + .to_string() + ); + let res = schema .execute( r#"mutation {