Add search filter for ranked outliers
marioCluml committed Oct 19, 2023
1 parent 4056cb9 commit df80e1f
Showing 1 changed file with 169 additions and 2 deletions.
171 changes: 169 additions & 2 deletions src/graphql/outlier.rs
@@ -1,4 +1,5 @@
use super::{always_true, model::ModelDigest, Role, RoleGuard, DEFAULT_CONNECTION_SIZE};
use crate::graphql::{earliest_key, latest_key};
use anyhow::anyhow;
use async_graphql::{
connection::{query, Connection, Edge, EmptyFields},
@@ -11,6 +12,7 @@ use data_encoding::BASE64;
use num_traits::ToPrimitive;
use review_database::{types::FromKeyValue, Database, Direction, IterableMap};
use serde::Deserialize;
use serde::Serialize;
use std::cmp;

pub const TIMESTAMP_SIZE: usize = 8;
@@ -53,6 +55,18 @@ impl OutlierMutation {
#[derive(Default)]
pub(super) struct OutlierQuery;

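/// Optional time window used to filter ranked outliers; both bounds are
/// inclusive and either bound may be omitted.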
#[derive(InputObject, Serialize)]
pub struct OutlierTimeRange {
start: Option<DateTime<Utc>>,
end: Option<DateTime<Utc>>,
}

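/// Search criteria applied while scanning ranked outliers: an optional time
/// range on the outlier timestamp and an optional exact-match distance.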
#[derive(InputObject, Serialize)]
pub struct SearchFilterInput {
pub time: Option<OutlierTimeRange>,
distance: Option<f64>,
}
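// Illustration (assumption, not part of this change): through GraphQL the new
// argument could be supplied as, e.g.,
//   filter: {time: {start: "2023-10-01T00:00:00Z", end: "2023-10-31T00:00:00Z"}, distance: 1.5}
// to keep only outliers inside that window whose distance equals 1.5.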

#[Object]
impl OutlierQuery {
/// A list of outliers.
@@ -120,9 +134,10 @@ impl OutlierQuery {
before: Option<String>,
first: Option<usize>,
last: Option<usize>,
filter: Option<SearchFilterInput>,
) -> Result<Connection<String, RankedOutlier, RankedOutlierTotalCount, EmptyFields>> {
        load_ranked_outliers_with_filter(ctx, model_id, time, after, before, first, last, filter)
            .await
}
}

@@ -586,3 +601,155 @@ fn latest_outlier_key(before: &str) -> Result<Vec<u8>> {
}
Ok(end)
}

async fn load_ranked_outliers_with_filter(
ctx: &Context<'_>,
model_id: ID,
time: Option<NaiveDateTime>,
after: Option<String>,
before: Option<String>,
first: Option<usize>,
last: Option<usize>,
filter: Option<SearchFilterInput>,
) -> Result<Connection<String, RankedOutlier, RankedOutlierTotalCount, EmptyFields>> {
let model_id: i32 = model_id.as_str().parse()?;
let timestamp = time.map(|t| t.timestamp_nanos_opt().unwrap_or_default());

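    // Keys in the outlier map are expected to begin with the bincode-encoded
    // model id followed by the batch timestamp, so this prefix limits the scan
    // to one model (and to a single batch when `time` is given).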
let prefix = if let Some(timestamp) = timestamp {
bincode::DefaultOptions::new().serialize(&(model_id, timestamp))?
} else {
bincode::DefaultOptions::new().serialize(&model_id)?
};

let store = crate::graphql::get_store(ctx).await?;
let map = store.outlier_map().into_prefix_map(&prefix);

let (nodes, has_previous, has_next) =
load_nodes_with_search_filter(&map, &filter, after, before, first, last)?;

let mut connection = Connection::with_additional_fields(
has_previous,
has_next,
RankedOutlierTotalCount {
model_id,
timestamp,
},
);
connection
.edges
.extend(nodes.into_iter().map(|(k, ev)| Edge::new(k, ev)));
Ok(connection)
}

fn load_nodes_with_search_filter<'m, M, I>(
map: &'m M,
filter: &Option<SearchFilterInput>,
after: Option<String>,
before: Option<String>,
first: Option<usize>,
last: Option<usize>,
) -> Result<(Vec<(String, RankedOutlier)>, bool, bool)>
where
M: IterableMap<'m, I>,
I: Iterator<Item = (Box<[u8]>, Box<[u8]>)> + 'm,
{
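    // Relay-style pagination: `last`/`before` walk the key range backward,
    // otherwise `first`/`after` (defaulting to DEFAULT_CONNECTION_SIZE) walk it
    // forward; the returned booleans report (has_previous_page, has_next_page).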
if let Some(last) = last {
let iter = if let Some(before) = before {
let end = latest_key(&before)?;
map.iter_from(&end, Direction::Reverse)?
} else {
map.iter_backward()?
};

let (nodes, has_more) = if let Some(after) = after {
let to = earliest_key(&after)?;
iter_through_search_filter_nodes(iter, &to, cmp::Ordering::is_ge, &filter, last)
} else {
iter_through_search_filter_nodes(iter, &[], always_true, &filter, last)
}?;
Ok((nodes, has_more, false))
} else {
let first = first.unwrap_or(DEFAULT_CONNECTION_SIZE);
let iter = if let Some(after) = after {
let start = earliest_key(&after)?;
map.iter_from(&start, Direction::Forward)?
} else {
map.iter_forward()?
};

let (nodes, has_more) = if let Some(before) = before {
let to = latest_key(&before)?;
iter_through_search_filter_nodes(iter, &to, cmp::Ordering::is_le, &filter, first)
} else {
iter_through_search_filter_nodes(iter, &[], always_true, &filter, first)
}?;
Ok((nodes, false, has_more))
}
}

fn iter_through_search_filter_nodes<I>(
iter: I,
to: &[u8],
cond: fn(cmp::Ordering) -> bool,
filter: &Option<SearchFilterInput>,
len: usize,
) -> Result<(Vec<(String, RankedOutlier)>, bool)>
where
I: Iterator<Item = (Box<[u8]>, Box<[u8]>)>,
{
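    // Walk the iterator, decoding each entry into a `RankedOutlier` and keeping
    // only nodes that satisfy `filter`, until the boundary key no longer meets
    // `cond` or one more than `len` matching nodes have been collected.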
let mut nodes = Vec::new();
let mut exceeded = false;
for (k, v) in iter {
if !(cond)(k.as_ref().cmp(to)) {
break;
}

        let cursor = BASE64.encode(&k);
let Some(node) = RankedOutlier::from_key_value(&k, &v)?.into() else {
continue;
};

if let Some(filter) = filter {
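            // Time filter: skip nodes whose timestamp (nanoseconds) falls
            // outside the inclusive [start, end] window; either bound is optional.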
if let Some(time) = &filter.time {
let start = time.start;
let end = time.end;
                if let Some(start) = start {
                    if let Some(end) = end {
                        if node.timestamp < start.timestamp_nanos_opt().unwrap_or_default()
                            || node.timestamp > end.timestamp_nanos_opt().unwrap_or_default()
                        {
                            continue;
                        }
                    } else if node.timestamp < start.timestamp_nanos_opt().unwrap_or_default() {
                        continue;
                    }
                } else if let Some(end) = end {
                    if node.timestamp > end.timestamp_nanos_opt().unwrap_or_default() {
                        continue;
                    }
                }
}

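            // Distance filter: keep only nodes whose distance equals the
            // requested value exactly (a strict f64 comparison).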
if let Some(distance) = filter.distance {
if node.distance != distance {
continue;
}
}
}

        nodes.push((cursor, node));
exceeded = nodes.len() > len;
if exceeded {
break;
}
}

if exceeded {
nodes.pop();
}
Ok((nodes, exceeded))
}
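
For reference, the time-range check above reduces to an inclusive-bounds predicate over nanosecond timestamps. A minimal standalone sketch of that logic, using a hypothetical `in_time_range` helper and assuming only the `chrono` crate this module already depends on:

use chrono::{DateTime, TimeZone, Utc};

/// Hypothetical helper mirroring the time filter above: returns `true` when
/// `ts` (nanoseconds since the Unix epoch) lies inside the optional inclusive
/// [start, end] window.
fn in_time_range(ts: i64, start: Option<DateTime<Utc>>, end: Option<DateTime<Utc>>) -> bool {
    let start_ns = start.map(|s| s.timestamp_nanos_opt().unwrap_or_default());
    let end_ns = end.map(|e| e.timestamp_nanos_opt().unwrap_or_default());
    match (start_ns, end_ns) {
        (Some(s), Some(e)) => ts >= s && ts <= e,
        (Some(s), None) => ts >= s,
        (None, Some(e)) => ts <= e,
        (None, None) => true,
    }
}

fn main() {
    let start = Utc.with_ymd_and_hms(2023, 10, 1, 0, 0, 0).single();
    let end = Utc.with_ymd_and_hms(2023, 10, 31, 0, 0, 0).single();
    // An outlier recorded on Oct 19 falls inside the window.
    let ts = Utc
        .with_ymd_and_hms(2023, 10, 19, 12, 0, 0)
        .single()
        .and_then(|t| t.timestamp_nanos_opt())
        .unwrap_or_default();
    assert!(in_time_range(ts, start, end));
    // With no bounds, everything passes.
    assert!(in_time_range(ts, None, None));
}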
