Skip to content

Commit

Permalink
Merge pull request #2 from contagon/serde
Browse files Browse the repository at this point in the history
Serde Support
  • Loading branch information
contagon authored Jul 10, 2024
2 parents a4a2245 + 1454896 commit 17ac7e4
Show file tree
Hide file tree
Showing 34 changed files with 997 additions and 3,888 deletions.
3,966 changes: 314 additions & 3,652 deletions Cargo.lock

Large diffs are not rendered by default.

61 changes: 50 additions & 11 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,39 +1,78 @@
[package]
name = "samrs"
name = "factrs"
version = "0.1.0"
edition = "2021"

[dependencies]
# base
ahash = "0.8.11"
derive_more = "0.99.17"
paste = "1.0.15"
downcast-rs = "1.2.1"
log = "0.4.21"

faer = { version = "0.19.0", default-features = false, features = ["perf-warn", "std"] }
# numerical
faer = { version = "0.19.0", default-features = false, features = [
"perf-warn",
"std",
] }
faer-ext = { version = "0.2.0", features = ["nalgebra"] }
nalgebra = { version = "0.32.5" }
num-dual = "0.9.1"
matrixcompare-core = { version = "0.1", optional = true }
serde = {version = "1.0.203", optional = true }
log = "0.4.21"
rerun = { version = "0.16.1", optional = true}

[[example]]
name = "g2o-rerun"
required-features = ["rerun"]
# serialization
serde = { version = "1.0.203", optional = true }
typetag = { version = "0.2.16", optional = true }
serde_json = { version = "1.0.120", optional = true }

# rerun support
rerun = { version = "0.16.1", optional = true, default-features = false, features = [
"sdk",
] }

[features]
# Run everything with f32 instead of the defaul f64
f32 = []

# Use left instead of right for lie group updates
left = []

# use SO(n) x R instead of SE(n) for exponential map
fake_exp = []

# Necessary dependencies for matrix comparing
compare = ["dep:matrixcompare-core", "nalgebra/compare", "faer/matrixcompare"]
serde = ["dep:serde", "nalgebra/serde", "faer/serde", "ahash/serde"]
multithread = ["faer/rayon"]

# Add multithreaded support (may run slower on smaller problems)
rayon = ["faer/rayon"]

# Add support for serialization
serde = [
"dep:serde",
"dep:typetag",
"nalgebra/serde-serialize",
"faer/serde",
"ahash/serde",
]
# just used for examples
serde_json = ["dep:serde_json"]

# Support for conversion to rerun variable types
rerun = ["dep:rerun"]

[dev-dependencies]
matrixcompare = "0.3.0"
plotters = "0.3.6"
pretty_env_logger = "0.4"
nalgebra = { version = "0.32.5", features = ["compare"] }

[[example]]
name = "g2o"

[[example]]
name = "g2o-rerun"
required-features = ["rerun"]

[[example]]
name = "serde"
required-features = ["serde", "serde_json"]
24 changes: 12 additions & 12 deletions examples/g2o-rerun.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ use std::{
time::Instant,
};

use rerun::{Arrows2D, Arrows3D, Points2D, Points3D};
use samrs::{
optimizers::{GaussNewton, Optimizer},
rerun::RerunSender,
use factrs::{
optimizers::{GaussNewton, GraphOptimizer, Optimizer},
rerun::RerunObserver,
utils::load_g20,
variables::*,
};
use rerun::{Arrows2D, Arrows3D, Points2D, Points3D};

fn main() -> Result<(), Box<dyn std::error::Error>> {
// ------------------------- Parse Arguments & Load data ------------------------- //
Expand Down Expand Up @@ -46,20 +46,20 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let topic = "base/solution";
match (dim, obj) {
("se2", "points") => {
let callback = RerunSender::<SE2, Points2D>::new(rec, topic);
optimizer.params.add_callback(callback)
let callback = RerunObserver::<SE2, Points2D>::new(rec, topic);
optimizer.observers.add(callback)
}
("se2", "arrows") => {
let callback = RerunSender::<SE2, Arrows2D>::new(rec, topic);
optimizer.params.add_callback(callback)
let callback = RerunObserver::<SE2, Arrows2D>::new(rec, topic);
optimizer.observers.add(callback)
}
("se3", "points") => {
let callback = RerunSender::<SE3, Points3D>::new(rec, topic);
optimizer.params.add_callback(callback)
let callback = RerunObserver::<SE3, Points3D>::new(rec, topic);
optimizer.observers.add(callback)
}
("se3", "arrows") => {
let callback = RerunSender::<SE3, Arrows3D>::new(rec, topic);
optimizer.params.add_callback(callback)
let callback = RerunObserver::<SE3, Arrows3D>::new(rec, topic);
optimizer.observers.add(callback)
}
_ => panic!("Invalid arguments"),
};
Expand Down
4 changes: 2 additions & 2 deletions examples/g2o.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{env, time::Instant};

use samrs::{
optimizers::{GaussNewton, Optimizer},
use factrs::{
optimizers::{GaussNewton, GraphOptimizer, Optimizer},
utils::load_g20,
};
fn main() {
Expand Down
45 changes: 45 additions & 0 deletions examples/serde.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use factrs::{
containers::{Graph, Values, X},
factors::Factor,
noise::GaussianNoise,
residuals::{BetweenResidual, PriorResidual},
robust::{GemanMcClure, L2},
variables::{SE2, SO2},
};

fn main() {
// ------------------------- Try with values ------------------------- //
let x = SO2::from_theta(0.6);
let y = SE2::new(1.0, 2.0, 0.3);
let mut values = Values::new();
values.insert(X(0), x.clone());
values.insert(X(1), y.clone());

let serialized = serde_json::to_string_pretty(&values).unwrap();
println!("serialized = {}", serialized);

// Convert the JSON string back to a Point.
let deserialized: Values = serde_json::from_str(&serialized).unwrap();
println!("deserialized = {}", deserialized);

// ------------------------- Try with graph ------------------------- //
let prior = PriorResidual::new(x);
let bet = BetweenResidual::new(y);

let prior = Factor::new_full(
&[X(0)],
prior,
GaussianNoise::from_scalar_cov(0.1),
GemanMcClure::default(),
);
let bet = Factor::new_full(&[X(0), X(1)], bet, GaussianNoise::from_scalar_cov(10.0), L2);
let mut graph = Graph::new();
graph.add_factor(prior);
graph.add_factor(bet);

let serialized = serde_json::to_string_pretty(&graph).unwrap();
println!("serialized = {}", serialized);

let deserialized: Graph = serde_json::from_str(&serialized).unwrap();
println!("deserialized = {:?}", deserialized);
}
3 changes: 2 additions & 1 deletion src/containers/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use faer::sparse::SymbolicSparseColMat;
use super::{Idx, Values, ValuesOrder};
use crate::{dtype, factors::Factor, linear::LinearGraph};

#[derive(Default)]
#[derive(Default, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Graph {
factors: Vec<Factor>,
}
Expand Down
5 changes: 5 additions & 0 deletions src/containers/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,14 @@ const CHR_MASK: u64 = (char::MAX as u64) << IDX_BITS;
const IDX_MASK: u64 = !CHR_MASK;

#[derive(Clone, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Symbol(u64);

impl Symbol {
pub fn new_raw(key: u64) -> Self {
Symbol(key)
}

pub fn chr(&self) -> char {
((self.0 & CHR_MASK) >> IDX_BITS) as u8 as char
}
Expand Down
11 changes: 9 additions & 2 deletions src/containers/values.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::{linear::LinearValues, variables::VariableSafe};
// we can just use dtype rather than using generics with Numeric

#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Values {
values: AHashMap<Symbol, Box<dyn VariableSafe>>,
}
Expand All @@ -35,8 +36,7 @@ impl Values {
key: Symbol,
value: impl VariableSafe,
) -> Option<Box<dyn VariableSafe>> {
// TODO: Avoid cloning here?
self.values.insert(key, value.clone_box())
self.values.insert(key, Box::new(value))
}

pub fn get(&self, key: &Symbol) -> Option<&Box<dyn VariableSafe>> {
Expand All @@ -62,6 +62,13 @@ impl Values {
self.values.get_mut(key)
}

// TODO: This should be some kind of error
pub fn get_mut_cast<T: VariableSafe>(&mut self, key: &Symbol) -> Option<&mut T> {
self.values
.get_mut(key)
.and_then(|value| value.downcast_mut::<T>())
}

pub fn remove(&mut self, key: &Symbol) -> Option<Box<dyn VariableSafe>> {
self.values.remove(key)
}
Expand Down
27 changes: 15 additions & 12 deletions src/factors/factor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ use crate::{
dtype,
linalg::{AllocatorBuffer, Const, DefaultAllocator, DiffResult, DualAllocator, MatrixBlock},
linear::LinearFactor,
noise::{GaussianNoise, NoiseModel, NoiseModelSafe},
noise::{NoiseModel, NoiseModelSafe, UnitNoise},
residuals::{Residual, ResidualSafe},
robust::{RobustCost, RobustCostSafe, L2},
robust::{RobustCostSafe, L2},
};

#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Factor {
pub keys: Vec<Symbol>,
residual: Box<dyn ResidualSafe>,
Expand All @@ -21,14 +23,15 @@ impl Factor {
residual: R,
) -> Self
where
R: 'static + Residual<NumVars = Const<NUM_VARS>, DimOut = Const<DIM_OUT>>,
R: 'static + Residual<NumVars = Const<NUM_VARS>, DimOut = Const<DIM_OUT>> + ResidualSafe,
AllocatorBuffer<R::DimIn>: Sync + Send,
DefaultAllocator: DualAllocator<R::DimIn>,
UnitNoise<DIM_OUT>: NoiseModelSafe,
{
Self {
keys: keys.to_vec(),
residual: Box::new(residual),
noise: Box::new(GaussianNoise::<DIM_OUT>::identity()),
noise: Box::new(UnitNoise::<DIM_OUT>),
robust: Box::new(L2),
}
}
Expand All @@ -39,8 +42,8 @@ impl Factor {
noise: N,
) -> Self
where
R: 'static + Residual<NumVars = Const<NUM_VARS>, DimOut = Const<DIM_OUT>>,
N: 'static + NoiseModel<Dim = Const<DIM_OUT>>,
R: 'static + Residual<NumVars = Const<NUM_VARS>, DimOut = Const<DIM_OUT>> + ResidualSafe,
N: 'static + NoiseModel<Dim = Const<DIM_OUT>> + NoiseModelSafe,
AllocatorBuffer<R::DimIn>: Sync + Send,
DefaultAllocator: DualAllocator<R::DimIn>,
{
Expand All @@ -59,11 +62,11 @@ impl Factor {
robust: C,
) -> Self
where
R: 'static + Residual<NumVars = Const<NUM_VARS>, DimOut = Const<DIM_OUT>>,
R: 'static + Residual<NumVars = Const<NUM_VARS>, DimOut = Const<DIM_OUT>> + ResidualSafe,
AllocatorBuffer<R::DimIn>: Sync + Send,
DefaultAllocator: DualAllocator<R::DimIn>,
N: 'static + NoiseModel<Dim = Const<DIM_OUT>>,
C: 'static + RobustCost,
N: 'static + NoiseModel<Dim = Const<DIM_OUT>> + NoiseModelSafe,
C: 'static + RobustCostSafe,
{
Self {
keys: keys.to_vec(),
Expand All @@ -75,7 +78,7 @@ impl Factor {

pub fn error(&self, values: &Values) -> dtype {
let r = self.residual.residual(values, &self.keys);
let r = self.noise.whiten_vec(r.as_view());
let r = self.noise.whiten_vec(r);
let norm2 = r.norm_squared();
self.robust.loss(norm2)
}
Expand All @@ -89,8 +92,8 @@ impl Factor {
let DiffResult { value: r, diff: a } = self.residual.residual_jacobian(values, &self.keys);

// Whiten residual and jacobian
let r = self.noise.whiten_vec(r.as_view());
let a = self.noise.whiten_mat(a.as_view());
let r = self.noise.whiten_vec(r);
let a = self.noise.whiten_mat(a);

// Weight according to robust cost
let norm2 = r.norm_squared();
Expand Down
38 changes: 38 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,41 @@ pub mod variables;

#[cfg(feature = "rerun")]
pub mod rerun;

#[cfg(feature = "serde")]
pub mod serde {
pub trait Tagged: serde::Serialize {
const TAG: &'static str;
}

#[macro_export]
macro_rules! register_typetag {
($trait:path, $ty:ty) => {
// TODO: It'd be great if this was a blanket implementation, but
// I had problems getting it to run over const generics
impl $crate::serde::Tagged for $ty {
const TAG: &'static str = stringify!($ty);
}

typetag::__private::inventory::submit! {
<dyn $trait>::typetag_register(
<$ty as $crate::serde::Tagged>::TAG, // Tag of the type
(|deserializer| typetag::__private::Result::Ok(
typetag::__private::Box::new(
typetag::__private::erased_serde::deserialize::<$ty>(deserializer)?
),
)) as typetag::__private::DeserializeFn<<dyn $trait as typetag::__private::Strictest>::Object>
)
}
};
}
}

// Dummy implementation so things don't break when the serde feature is disabled
#[cfg(not(feature = "serde"))]
pub mod serde {
#[macro_export]
macro_rules! register_typetag {
($trait:path, $ty:ty) => {};
}
}
3 changes: 2 additions & 1 deletion src/linalg/forward_prop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use super::{
MatrixDim,
};
use crate::{
dtype,
linalg::{Const, DefaultAllocator, DiffResult, DimName, Dyn, MatrixX, VectorDim, VectorX},
variables::Variable,
};
Expand All @@ -19,7 +20,7 @@ macro_rules! forward_maker {
($num:expr, $( ($name:ident: $var:ident) ),*) => {
paste! {
#[allow(unused_assignments)]
fn [<jacobian_ $num>]<$( $var: Variable<Alias<f64> = $var>, )* F: Fn($($var::Alias<Self::D>,)*) -> VectorX<Self::D>>
fn [<jacobian_ $num>]<$( $var: Variable<Alias<dtype> = $var>, )* F: Fn($($var::Alias<Self::D>,)*) -> VectorX<Self::D>>
(f: F, $($name: &$var,)*) -> DiffResult<VectorX, MatrixX>{
// Prepare variables
let mut curr_dim = 0;
Expand Down
Loading

0 comments on commit 17ac7e4

Please sign in to comment.