Skip to content

Commit

Permalink
Add search filter for ranked outliers
Browse files Browse the repository at this point in the history
  • Loading branch information
marioCluml committed Nov 3, 2023
1 parent 35f3b35 commit 1533d9e
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 3 deletions.
14 changes: 14 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,19 @@ file is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and
this project adheres to [Semantic
Versioning](https://semver.org/spec/v2.0.0.html).

## [0.14.5] - 2023-11-02

### Changed

- Modified Ranked Outliers graphql query to take in a SearchFilter with
distance range and time range

### Added

- Added new method for Ranked Outliers `load_ranked_outliers_with_filter`,
`load_nodes_with_search_filter`, and `iter_through_search_filter_nodes`
to load Ranked Outliers depending on new Search Filter.

## [0.14.4] - 2023-10-19

### Added
Expand Down Expand Up @@ -307,6 +320,7 @@ across our system.

- An initial version.

[0.14.5]: https://github.com/aicers/review-web/compare/0.14.4...0.14.5
[0.14.4]: https://github.com/aicers/review-web/compare/0.14.3...0.14.4
[0.14.3]: https://github.com/aicers/review-web/compare/0.14.2...0.14.3
[0.14.2]: https://github.com/aicers/review-web/compare/0.14.1...0.14.2
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "review-web"
version = "0.14.4"
version = "0.14.5"
edition = "2021"

[dependencies]
Expand Down
181 changes: 179 additions & 2 deletions src/graphql/outlier.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::{always_true, model::ModelDigest, Role, RoleGuard, DEFAULT_CONNECTION_SIZE};
use crate::graphql::{earliest_key, latest_key};
use anyhow::anyhow;
use async_graphql::{
connection::{query, Connection, Edge, EmptyFields},
Expand All @@ -11,6 +12,7 @@ use data_encoding::BASE64;
use num_traits::ToPrimitive;
use review_database::{types::FromKeyValue, Database, Direction, IterableMap};
use serde::Deserialize;
use serde::Serialize;
use std::cmp;

pub const TIMESTAMP_SIZE: usize = 8;
Expand Down Expand Up @@ -53,6 +55,26 @@ impl OutlierMutation {
#[derive(Default)]
pub(super) struct OutlierQuery;

#[allow(clippy::module_name_repetitions)]
#[derive(InputObject, Serialize)]
pub struct OutlierTimeRange {
start: Option<DateTime<Utc>>,
end: Option<DateTime<Utc>>,
}

#[allow(clippy::module_name_repetitions)]
#[derive(InputObject, Serialize)]
pub struct OutlierDistanceRange {
start: Option<f64>,
end: Option<f64>,
}

#[derive(InputObject, Serialize)]
pub struct SearchFilterInput {
pub time: Option<OutlierTimeRange>,
distance: Option<OutlierDistanceRange>,
}

#[Object]
impl OutlierQuery {
/// A list of outliers.
Expand Down Expand Up @@ -120,9 +142,10 @@ impl OutlierQuery {
before: Option<String>,
first: Option<usize>,
last: Option<usize>,
filter: Option<SearchFilterInput>,
) -> Result<Connection<String, RankedOutlier, RankedOutlierTotalCount, EmptyFields>> {
let filter = Some;
load_outliers(ctx, model_id, time, after, before, first, last, filter).await
load_ranked_outliers_with_filter(ctx, model_id, time, after, before, first, last, filter)
.await
}
}

Expand Down Expand Up @@ -586,3 +609,157 @@ fn latest_outlier_key(before: &str) -> Result<Vec<u8>> {
}
Ok(end)
}

#[allow(clippy::too_many_arguments, clippy::type_complexity)] // since this is called within `load` only
async fn load_ranked_outliers_with_filter(
ctx: &Context<'_>,
model_id: ID,
time: Option<NaiveDateTime>,
after: Option<String>,
before: Option<String>,
first: Option<usize>,
last: Option<usize>,
filter: Option<SearchFilterInput>,
) -> Result<Connection<String, RankedOutlier, RankedOutlierTotalCount, EmptyFields>> {
let model_id: i32 = model_id.as_str().parse()?;
let timestamp = time.map(|t| t.timestamp_nanos_opt().unwrap_or_default());

let prefix = if let Some(timestamp) = timestamp {
bincode::DefaultOptions::new().serialize(&(model_id, timestamp))?
} else {
bincode::DefaultOptions::new().serialize(&model_id)?
};

let store = crate::graphql::get_store(ctx).await?;
let map = store.outlier_map().into_prefix_map(&prefix);

let (nodes, has_previous, has_next) =
load_nodes_with_search_filter(&map, &filter, after, before, first, last)?;

let mut connection = Connection::with_additional_fields(
has_previous,
has_next,
RankedOutlierTotalCount {
model_id,
timestamp,
},
);
connection
.edges
.extend(nodes.into_iter().map(|(k, ev)| Edge::new(k, ev)));
Ok(connection)
}

#[allow(clippy::type_complexity)] // since this is called within `load` only
fn load_nodes_with_search_filter<'m, M, I>(
map: &'m M,
filter: &Option<SearchFilterInput>,
after: Option<String>,
before: Option<String>,
first: Option<usize>,
last: Option<usize>,
) -> Result<(Vec<(String, RankedOutlier)>, bool, bool)>
where
M: IterableMap<'m, I>,
I: Iterator<Item = (Box<[u8]>, Box<[u8]>)> + 'm,
{
if let Some(last) = last {
let iter = if let Some(before) = before {
let end = latest_key(&before)?;
map.iter_from(&end, Direction::Reverse)?
} else {
map.iter_backward()?
};

let (nodes, has_more) = if let Some(after) = after {
let to = earliest_key(&after)?;
iter_through_search_filter_nodes(iter, &to, cmp::Ordering::is_ge, filter, last)
} else {
iter_through_search_filter_nodes(iter, &[], always_true, filter, last)
}?;
Ok((nodes, has_more, false))
} else {
let first = first.unwrap_or(DEFAULT_CONNECTION_SIZE);
let iter = if let Some(after) = after {
let start = earliest_key(&after)?;
map.iter_from(&start, Direction::Forward)?
} else {
map.iter_forward()?
};

let (nodes, has_more) = if let Some(before) = before {
let to = latest_key(&before)?;
iter_through_search_filter_nodes(iter, &to, cmp::Ordering::is_le, filter, first)
} else {
iter_through_search_filter_nodes(iter, &[], always_true, filter, first)
}?;
Ok((nodes, false, has_more))
}
}

fn iter_through_search_filter_nodes<I>(
iter: I,
to: &[u8],
cond: fn(cmp::Ordering) -> bool,
filter: &Option<SearchFilterInput>,
len: usize,
) -> Result<(Vec<(String, RankedOutlier)>, bool)>
where
I: Iterator<Item = (Box<[u8]>, Box<[u8]>)>,
{
let mut nodes = Vec::new();
let mut exceeded = false;
for (k, v) in iter {
if !(cond)(k.as_ref().cmp(to)) {
break;
}

let curser = BASE64.encode(&k);
let Some(node) = RankedOutlier::from_key_value(&k, &v)?.into() else {
continue;
};

if let Some(filter) = filter {
if let Some(time) = &filter.time {
if let Some(start) = time.start {
if let Some(end) = time.end {
if node.timestamp < start.timestamp_nanos_opt().unwrap_or_default()
|| node.timestamp > end.timestamp_nanos_opt().unwrap_or_default()
{
continue;
}
} else if node.timestamp < start.timestamp_nanos_opt().unwrap_or_default() {
continue;
}
} else if let Some(end) = time.end {
if node.timestamp > end.timestamp_nanos_opt().unwrap_or_default() {
continue;
}
}
}

if let Some(distance) = &filter.distance {
if let Some(start) = distance.start {
if let Some(end) = distance.end {
if node.distance < start || node.distance > end {
continue;
}
} else if node.distance < start {
continue;
}
}
}
}

nodes.push((curser, node));
exceeded = nodes.len() > len;
if exceeded {
break;
}
}

if exceeded {
nodes.pop();
}
Ok((nodes, exceeded))
}

0 comments on commit 1533d9e

Please sign in to comment.