Skip to content

Commit

Permalink
Add semi-supervised model GraphQL api
Browse files Browse the repository at this point in the history
  • Loading branch information
kimhanbeom committed Dec 6, 2023
1 parent 60a4487 commit 0cf62fc
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 0 deletions.
4 changes: 4 additions & 0 deletions examples/minireview.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ impl AgentManager for Manager {
bail!("Not supported")
}

async fn broadcast_semi_model_list(&self, _list: &[u8]) -> Result<(), anyhow::Error> {
bail!("Not supported")
}

async fn broadcast_trusted_user_agent_list(&self, _list: &[u8]) -> Result<(), anyhow::Error> {
bail!("Not supported")
}
Expand Down
7 changes: 7 additions & 0 deletions src/graphql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ mod node;
mod outlier;
mod qualifier;
mod sampling;
mod semi_model;
mod slicing;
mod statistics;
mod status;
Expand Down Expand Up @@ -81,6 +82,7 @@ pub trait AgentManager: Send + Sync {
_networks: &[u8],
) -> Result<Vec<String>, anyhow::Error>;
async fn broadcast_trusted_user_agent_list(&self, _list: &[u8]) -> Result<(), anyhow::Error>;
async fn broadcast_semi_model_list(&self, _list: &[u8]) -> Result<(), anyhow::Error>;
async fn online_apps_by_host_id(
&self,
) -> Result<HashMap<String, Vec<(String, String)>>, anyhow::Error>;
Expand Down Expand Up @@ -185,6 +187,7 @@ pub(super) struct Query(
triage::TriagePolicyQuery,
triage::TriageResponseQuery,
trusted_domain::TrustedDomainQuery,
semi_model::SemiModelQuery,
traffic_filter::TrafficFilterQuery,
allow_network::AllowNetworkQuery,
trusted_user_agent::UserAgentQuery,
Expand Down Expand Up @@ -224,6 +227,7 @@ pub(super) struct Mutation(
triage::TriageResponseMutation,
triage::TriageMutation,
trusted_domain::TrustedDomainMutation,
semi_model::SemiModelMutation,
traffic_filter::TrafficFilterMutation,
allow_network::AllowNetworkMutation,
trusted_user_agent::UserAgentMutation,
Expand Down Expand Up @@ -517,6 +521,9 @@ impl AgentManager for MockAgentManager {
async fn broadcast_trusted_user_agent_list(&self, _list: &[u8]) -> Result<(), anyhow::Error> {
unimplemented!()
}
async fn broadcast_semi_model_list(&self, _list: &[u8]) -> Result<(), anyhow::Error> {
unimplemented!()
}
async fn broadcast_internal_networks(
&self,
_networks: &[u8],
Expand Down
164 changes: 164 additions & 0 deletions src/graphql/semi_model.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
use super::{BoxedAgentManager, Role, RoleGuard};
use async_graphql::{
connection::{query, Connection, EmptyFields},
Context, InputObject, Object, Result, SimpleObject,
};
use bincode::Options;
use chrono::{DateTime, Utc};
use database::types::FromKeyValue;
use review_database::{self as database, IterableMap, Store};
use serde::{Deserialize, Serialize};

#[derive(Default)]
pub(super) struct SemiModelQuery;

#[Object]
impl SemiModelQuery {
/// A list of semi-supervised model list.
#[graphql(guard = "RoleGuard::new(Role::SystemAdministrator)
.or(RoleGuard::new(Role::SecurityAdministrator))
.or(RoleGuard::new(Role::SecurityManager))
.or(RoleGuard::new(Role::SecurityMonitor))")]
async fn semi_model_list(
&self,
ctx: &Context<'_>,
after: Option<String>,
before: Option<String>,
first: Option<i32>,
last: Option<i32>,
) -> Result<Connection<String, SemiModelInfo, SemiModelInfoTotalCount, EmptyFields>> {
query(
after,
before,
first,
last,
|after, before, first, last| async move { load(ctx, after, before, first, last).await },
)
.await
}
}

#[derive(Default)]
pub(super) struct SemiModelMutation;

#[Object]
impl SemiModelMutation {
/// Inserts a new semi-supervised model, Returns true if the insertion was successful.
#[graphql(guard = "RoleGuard::new(Role::SystemAdministrator)
.or(RoleGuard::new(Role::SecurityAdministrator))")]
async fn insert_semi_model(&self, ctx: &Context<'_>, input_model: SemiModel) -> Result<bool> {
let store = crate::graphql::get_store(ctx).await?;
let map = store.semi_models_map();

let key = input_model.model_name.clone();
let value = bincode::serialize::<SemiModelValue>(&(input_model, Utc::now()))?;
map.put(key.as_bytes(), &value)?;
Ok(true)
}

/// Removes a semi-supervised models using model name , Returns true if the deletion was successful.
#[graphql(guard = "RoleGuard::new(Role::SystemAdministrator)
.or(RoleGuard::new(Role::SecurityAdministrator))")]
async fn remove_semi_models(&self, ctx: &Context<'_>, models: Vec<String>) -> Result<bool> {
let store = crate::graphql::get_store(ctx).await?;
let map = store.semi_models_map();
for model in models {
map.delete(model.as_bytes())?;
}
Ok(true)
}

/// Broadcast the semi-supervised model list to all Hogs.
#[graphql(guard = "RoleGuard::new(Role::SystemAdministrator)
.or(RoleGuard::new(Role::SecurityAdministrator))")]
async fn apply_semi_model(&self, ctx: &Context<'_>) -> Result<bool> {
let store = crate::graphql::get_store(ctx).await?;
let list = get_semi_model_list(&store)?;
let serialized_semi_model = bincode::DefaultOptions::new().serialize(&list)?;
let agent_manager = ctx.data::<BoxedAgentManager>()?;
agent_manager
.broadcast_trusted_user_agent_list(&serialized_semi_model)
.await?;
Ok(true)
}
}
type SemiModelValue = (SemiModel, DateTime<Utc>);

#[derive(InputObject, Serialize, Deserialize)]
struct SemiModel {
model_type: i32,
model_name: String,
model_version: String,
model_description: String,
model_data: Vec<u8>,
}

#[derive(SimpleObject, Serialize)]
struct SemiModelInfo {
model_type: i32,
model_name: String,
model_version: String,
model_description: String,
model_data: Vec<u8>,
time: DateTime<Utc>,
}

impl SemiModelInfo {
fn new(semi_model: SemiModel, time: DateTime<Utc>) -> Self {
Self {
model_type: semi_model.model_type,
model_name: semi_model.model_name,
model_version: semi_model.model_version,
model_description: semi_model.model_description,
time: time,
model_data: semi_model.model_data,
}
}
}

impl FromKeyValue for SemiModelInfo {
fn from_key_value(_key: &[u8], value: &[u8]) -> Result<Self, anyhow::Error> {
let (semi_info, time) = bincode::deserialize::<SemiModelValue>(&value)?;
Ok(SemiModelInfo::new(semi_info, time))
}
}

struct SemiModelInfoTotalCount;

#[Object]
impl SemiModelInfoTotalCount {
/// The total number of edges.
async fn total_count(&self, ctx: &Context<'_>) -> Result<usize> {
let store = crate::graphql::get_store(ctx).await?;
let map = store.semi_models_map();
let count = map.iter_forward()?.count();
Ok(count)
}
}

/// Returns the semi supervised model list.
///
/// # Errors
///
/// Returns an error if semi supervised model database could not be retrieved.
fn get_semi_model_list(db: &Store) -> Result<Vec<SemiModelInfo>> {
let map = db.semi_models_map();
let mut semi_model_list = vec![];
for (_, value) in map.iter_forward()? {
let (semi_info, time) = bincode::deserialize::<SemiModelValue>(&value)?;
semi_model_list.push(SemiModelInfo::new(semi_info, time));
}
Ok(semi_model_list)
}

async fn load(
ctx: &Context<'_>,
after: Option<String>,
before: Option<String>,
first: Option<usize>,
last: Option<usize>,
) -> Result<Connection<String, SemiModelInfo, SemiModelInfoTotalCount, EmptyFields>> {
let store = crate::graphql::get_store(ctx).await?;
let map = store.semi_models_map();
super::load(&map, after, before, first, last, SemiModelInfoTotalCount)
}

0 comments on commit 0cf62fc

Please sign in to comment.