diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3bff4de0..3a9c4bc9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,19 @@ file is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and
 this project adheres to [Semantic
 Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [0.14.5] - 2023-11-02
+
+### Changed
+
+- Modified Ranked Outliers GraphQL query to take in a SearchFilter with
+  distance range and time range
+
+### Added
+
+- Added new methods for Ranked Outliers `load_ranked_outliers_with_filter`,
+  `load_nodes_with_search_filter`, and `iter_through_search_filter_nodes`
+  to load Ranked Outliers depending on new Search Filter.
+
 ## [0.14.4] - 2023-10-19
 
 ### Added
@@ -307,6 +320,7 @@ across our system.
 
 - An initial version.
 
+[0.14.5]: https://github.com/aicers/review-web/compare/0.14.4...0.14.5
 [0.14.4]: https://github.com/aicers/review-web/compare/0.14.3...0.14.4
 [0.14.3]: https://github.com/aicers/review-web/compare/0.14.2...0.14.3
 [0.14.2]: https://github.com/aicers/review-web/compare/0.14.1...0.14.2
diff --git a/Cargo.toml b/Cargo.toml
index 586731c8..1725ee55 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "review-web"
-version = "0.14.4"
+version = "0.14.5"
 edition = "2021"
 
 [dependencies]
diff --git a/src/graphql/outlier.rs b/src/graphql/outlier.rs
index 6c0b4e14..ac087bad 100644
--- a/src/graphql/outlier.rs
+++ b/src/graphql/outlier.rs
@@ -1,4 +1,5 @@
 use super::{always_true, model::ModelDigest, Role, RoleGuard, DEFAULT_CONNECTION_SIZE};
+use crate::graphql::{earliest_key, latest_key};
 use anyhow::anyhow;
 use async_graphql::{
     connection::{query, Connection, Edge, EmptyFields},
@@ -11,6 +12,7 @@ use data_encoding::BASE64;
 use num_traits::ToPrimitive;
 use review_database::{types::FromKeyValue, Database, Direction, IterableMap};
 use serde::Deserialize;
+use serde::Serialize;
 use std::cmp;
 
 pub const TIMESTAMP_SIZE: usize = 8;
@@ -53,6 +55,26 @@ impl OutlierMutation {
 #[derive(Default)]
 pub(super) struct OutlierQuery;
 
+#[allow(clippy::module_name_repetitions)]
+#[derive(InputObject, Serialize)]
+pub struct OutlierTimeRange {
+    start: Option<DateTime<Utc>>,
+    end: Option<DateTime<Utc>>,
+}
+
+#[allow(clippy::module_name_repetitions)]
+#[derive(InputObject, Serialize)]
+pub struct OutlierDistanceRange {
+    start: Option<f64>,
+    end: Option<f64>,
+}
+
+#[derive(InputObject, Serialize)]
+pub struct SearchFilterInput {
+    pub time: Option<OutlierTimeRange>,
+    distance: Option<OutlierDistanceRange>,
+}
+
 #[Object]
 impl OutlierQuery {
     /// A list of outliers.
@@ -120,9 +142,10 @@ impl OutlierQuery {
         before: Option<String>,
         first: Option<usize>,
         last: Option<usize>,
+        filter: Option<SearchFilterInput>,
     ) -> Result<Connection<String, RankedOutlier, RankedOutlierTotalCount, EmptyFields>> {
-        let filter = Some;
-        load_outliers(ctx, model_id, time, after, before, first, last, filter).await
+        load_ranked_outliers_with_filter(ctx, model_id, time, after, before, first, last, filter)
+            .await
     }
 }
 
@@ -586,3 +609,187 @@ fn latest_outlier_key(before: &str) -> Result<Vec<u8>> {
     }
     Ok(end)
 }
+
+/// Loads a page of ranked outliers for a model, keeping only the nodes accepted by `filter`.
+///
+/// The scan is limited to keys under the model's prefix, narrowed further to a single
+/// timestamp when `time` is given. `after`/`before` are pagination cursors and
+/// `first`/`last` bound the page size.
+///
+/// # Errors
+///
+/// Returns an error if `model_id` is not a valid `i32`, if the key prefix cannot be
+/// serialized, or if reading from the store fails.
+#[allow(clippy::too_many_arguments, clippy::type_complexity)] // since this is called within `load` only
+async fn load_ranked_outliers_with_filter(
+    ctx: &Context<'_>,
+    model_id: ID,
+    time: Option<NaiveDateTime>,
+    after: Option<String>,
+    before: Option<String>,
+    first: Option<usize>,
+    last: Option<usize>,
+    filter: Option<SearchFilterInput>,
+) -> Result<Connection<String, RankedOutlier, RankedOutlierTotalCount, EmptyFields>> {
+    let model_id: i32 = model_id.as_str().parse()?;
+    let timestamp = time.map(|t| t.timestamp_nanos_opt().unwrap_or_default());
+
+    let prefix = if let Some(timestamp) = timestamp {
+        bincode::DefaultOptions::new().serialize(&(model_id, timestamp))?
+    } else {
+        bincode::DefaultOptions::new().serialize(&model_id)?
+    };
+
+    let store = crate::graphql::get_store(ctx).await?;
+    let map = store.outlier_map().into_prefix_map(&prefix);
+
+    let (nodes, has_previous, has_next) =
+        load_nodes_with_search_filter(&map, &filter, after, before, first, last)?;
+
+    let mut connection = Connection::with_additional_fields(
+        has_previous,
+        has_next,
+        RankedOutlierTotalCount {
+            model_id,
+            timestamp,
+        },
+    );
+    connection
+        .edges
+        .extend(nodes.into_iter().map(|(k, ev)| Edge::new(k, ev)));
+    Ok(connection)
+}
+
+/// Collects the nodes of one page from `map`, applying `filter` to every candidate.
+///
+/// With `last`, the map is scanned backward from `before` (or from the end of the map)
+/// until `after` is passed. Otherwise it is scanned forward from `after` (or from the
+/// beginning) until `before` is passed, taking `first` nodes, or
+/// `DEFAULT_CONNECTION_SIZE` when `first` is not given.
+///
+/// Returns the nodes in scan order (backward for `last`), followed by the `has_previous`
+/// and `has_next` flags.
+///
+/// # Errors
+///
+/// Returns an error if a cursor cannot be converted into a key, if the map cannot be
+/// iterated, or if an entry cannot be decoded.
+#[allow(clippy::type_complexity)] // since this is called within `load` only
+fn load_nodes_with_search_filter<'m, M, I>(
+    map: &'m M,
+    filter: &Option<SearchFilterInput>,
+    after: Option<String>,
+    before: Option<String>,
+    first: Option<usize>,
+    last: Option<usize>,
+) -> Result<(Vec<(String, RankedOutlier)>, bool, bool)>
+where
+    M: IterableMap<'m, I>,
+    I: Iterator<Item = (Box<[u8]>, Box<[u8]>)> + 'm,
+{
+    if let Some(last) = last {
+        let iter = if let Some(before) = before {
+            let end = latest_key(&before)?;
+            map.iter_from(&end, Direction::Reverse)?
+        } else {
+            map.iter_backward()?
+        };
+
+        let (nodes, has_more) = if let Some(after) = after {
+            let to = earliest_key(&after)?;
+            iter_through_search_filter_nodes(iter, &to, cmp::Ordering::is_ge, filter, last)
+        } else {
+            iter_through_search_filter_nodes(iter, &[], always_true, filter, last)
+        }?;
+        Ok((nodes, has_more, false))
+    } else {
+        let first = first.unwrap_or(DEFAULT_CONNECTION_SIZE);
+        let iter = if let Some(after) = after {
+            let start = earliest_key(&after)?;
+            map.iter_from(&start, Direction::Forward)?
+        } else {
+            map.iter_forward()?
+        };
+
+        let (nodes, has_more) = if let Some(before) = before {
+            let to = latest_key(&before)?;
+            iter_through_search_filter_nodes(iter, &to, cmp::Ordering::is_le, filter, first)
+        } else {
+            iter_through_search_filter_nodes(iter, &[], always_true, filter, first)
+        }?;
+        Ok((nodes, false, has_more))
+    }
+}
+
+/// Walks `iter` and gathers up to `len` nodes that satisfy `filter`.
+///
+/// Iteration stops as soon as `cond` rejects the ordering of the current key relative to
+/// `to`. Every bound in `filter` is optional and inclusive; a bound that is not set does
+/// not restrict the result.
+///
+/// Returns the collected nodes and whether one more matching node was found beyond `len`.
+///
+/// # Errors
+///
+/// Returns an error if an entry cannot be decoded into a `RankedOutlier`.
+fn iter_through_search_filter_nodes<I>(
+    iter: I,
+    to: &[u8],
+    cond: fn(cmp::Ordering) -> bool,
+    filter: &Option<SearchFilterInput>,
+    len: usize,
+) -> Result<(Vec<(String, RankedOutlier)>, bool)>
+where
+    I: Iterator<Item = (Box<[u8]>, Box<[u8]>)>,
+{
+    let mut nodes = Vec::new();
+    let mut exceeded = false;
+    for (k, v) in iter {
+        if !(cond)(k.as_ref().cmp(to)) {
+            break;
+        }
+
+        let node = RankedOutlier::from_key_value(&k, &v)?;
+
+        if let Some(filter) = filter {
+            if let Some(time) = &filter.time {
+                if let Some(start) = time.start {
+                    if node.timestamp < start.timestamp_nanos_opt().unwrap_or_default() {
+                        continue;
+                    }
+                }
+                if let Some(end) = time.end {
+                    if node.timestamp > end.timestamp_nanos_opt().unwrap_or_default() {
+                        continue;
+                    }
+                }
+            }
+
+            if let Some(distance) = &filter.distance {
+                if let Some(start) = distance.start {
+                    if node.distance < start {
+                        continue;
+                    }
+                }
+                // An upper bound applies on its own, even when no lower bound is given.
+                if let Some(end) = distance.end {
+                    if node.distance > end {
+                        continue;
+                    }
+                }
+            }
+        }
+
+        // Encode the cursor only for nodes that passed the filter.
+        nodes.push((BASE64.encode(&k), node));
+        exceeded = nodes.len() > len;
+        if exceeded {
+            break;
+        }
+    }
+
+    if exceeded {
+        nodes.pop();
+    }
+    Ok((nodes, exceeded))
+}