From 34cfd52ebdd55c6b8ac2d2ca519f7edca89965e1 Mon Sep 17 00:00:00 2001 From: marioEinsis Date: Thu, 19 Oct 2023 16:21:22 -0700 Subject: [PATCH] Add search filter for ranked outliers --- src/graphql/outlier.rs | 171 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 169 insertions(+), 2 deletions(-) diff --git a/src/graphql/outlier.rs b/src/graphql/outlier.rs index 6c0b4e14..561935fb 100644 --- a/src/graphql/outlier.rs +++ b/src/graphql/outlier.rs @@ -1,4 +1,5 @@ use super::{always_true, model::ModelDigest, Role, RoleGuard, DEFAULT_CONNECTION_SIZE}; +use crate::graphql::{earliest_key, latest_key}; use anyhow::anyhow; use async_graphql::{ connection::{query, Connection, Edge, EmptyFields}, @@ -11,6 +12,7 @@ use data_encoding::BASE64; use num_traits::ToPrimitive; use review_database::{types::FromKeyValue, Database, Direction, IterableMap}; use serde::Deserialize; +use serde::Serialize; use std::cmp; pub const TIMESTAMP_SIZE: usize = 8; @@ -53,6 +55,18 @@ impl OutlierMutation { #[derive(Default)] pub(super) struct OutlierQuery; +#[derive(InputObject, Serialize)] +pub struct OutlierTimeRange { + start: Option>, + end: Option>, +} + +#[derive(InputObject, Serialize)] +pub struct SearchFilterInput { + pub time: Option, + distance: Option, +} + #[Object] impl OutlierQuery { /// A list of outliers. @@ -120,9 +134,10 @@ impl OutlierQuery { before: Option, first: Option, last: Option, + filter: Option, ) -> Result> { - let filter = Some; - load_outliers(ctx, model_id, time, after, before, first, last, filter).await + load_ranked_outliers_with_filter(ctx, model_id, time, after, before, first, last, filter) + .await } } @@ -586,3 +601,155 @@ fn latest_outlier_key(before: &str) -> Result> { } Ok(end) } + +async fn load_ranked_outliers_with_filter( + ctx: &Context<'_>, + model_id: ID, + time: Option, + after: Option, + before: Option, + first: Option, + last: Option, + filter: Option, +) -> Result> { + let model_id: i32 = model_id.as_str().parse()?; + let timestamp = time.map(|t| t.timestamp_nanos_opt().unwrap_or_default()); + + let prefix = if let Some(timestamp) = timestamp { + bincode::DefaultOptions::new().serialize(&(model_id, timestamp))? + } else { + bincode::DefaultOptions::new().serialize(&model_id)? + }; + + let store = crate::graphql::get_store(ctx).await?; + let map = store.outlier_map().into_prefix_map(&prefix); + + let (nodes, has_previous, has_next) = + load_nodes_with_search_filter(&map, &filter, after, before, first, last)?; + + let mut connection = Connection::with_additional_fields( + has_previous, + has_next, + RankedOutlierTotalCount { + model_id, + timestamp, + }, + ); + connection + .edges + .extend(nodes.into_iter().map(|(k, ev)| Edge::new(k, ev))); + Ok(connection) +} + +fn load_nodes_with_search_filter<'m, M, I>( + map: &'m M, + filter: &Option, + after: Option, + before: Option, + first: Option, + last: Option, +) -> Result<(Vec<(String, RankedOutlier)>, bool, bool)> +where + M: IterableMap<'m, I>, + I: Iterator, Box<[u8]>)> + 'm, +{ + if let Some(last) = last { + let iter = if let Some(before) = before { + let end = latest_key(&before)?; + map.iter_from(&end, Direction::Reverse)? + } else { + map.iter_backward()? + }; + + let (nodes, has_more) = if let Some(after) = after { + let to = earliest_key(&after)?; + iter_through_search_filter_nodes(iter, &to, cmp::Ordering::is_ge, &filter, last) + } else { + iter_through_search_filter_nodes(iter, &[], always_true, &filter, last) + }?; + Ok((nodes, has_more, false)) + } else { + let first = first.unwrap_or(DEFAULT_CONNECTION_SIZE); + let iter = if let Some(after) = after { + let start = earliest_key(&after)?; + map.iter_from(&start, Direction::Forward)? + } else { + map.iter_forward()? + }; + + let (nodes, has_more) = if let Some(before) = before { + let to = latest_key(&before)?; + iter_through_search_filter_nodes(iter, &to, cmp::Ordering::is_le, &filter, first) + } else { + iter_through_search_filter_nodes(iter, &[], always_true, &filter, first) + }?; + Ok((nodes, false, has_more)) + } +} + +fn iter_through_search_filter_nodes( + iter: I, + to: &[u8], + cond: fn(cmp::Ordering) -> bool, + filter: &Option, + len: usize, +) -> Result<(Vec<(String, RankedOutlier)>, bool)> +where + I: Iterator, Box<[u8]>)>, +{ + let mut nodes = Vec::new(); + let mut exceeded = false; + for (k, v) in iter { + if !(cond)(k.as_ref().cmp(to)) { + break; + } + + let curser = BASE64.encode(&k); + let Some(node) = RankedOutlier::from_key_value(&k, &v)?.into() else { + continue; + }; + + if let Some(filter) = filter { + if let Some(time) = &filter.time { + let start = time.start; + let end = time.end; + if let Some(start) = start { + if let Some(end) = end { + if node.timestamp < start.timestamp_nanos_opt().unwrap_or_default() + || node.timestamp > end.timestamp_nanos_opt().unwrap_or_default() + { + continue; + } + } else { + if node.timestamp < start.timestamp_nanos_opt().unwrap_or_default() { + continue; + } + } + } else { + if let Some(end) = end { + if node.timestamp > end.timestamp_nanos_opt().unwrap_or_default() { + continue; + } + } + } + } + + if let Some(distance) = filter.distance { + if node.distance != distance { + continue; + } + } + } + + nodes.push((curser, node)); + exceeded = nodes.len() > len; + if exceeded { + break; + } + } + + if exceeded { + nodes.pop(); + } + Ok((nodes, exceeded)) +}