Skip to content

Commit

Permalink
make muldiv pub, remove some unnecessary copying
Browse files Browse the repository at this point in the history
  • Loading branch information
moodysalem committed Sep 17, 2024
1 parent c9fb86b commit 3dbbf89
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 14 deletions.
24 changes: 17 additions & 7 deletions src/math/muldiv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl TryFrom<U512> for U256 {
}
}

pub(crate) fn muldiv(x: U256, y: U256, d: U256, round_up: bool) -> Result<U256, MuldivError> {
pub fn muldiv(x: U256, y: U256, d: U256, round_up: bool) -> Result<U256, MuldivError> {
if d.is_zero() {
return Err(MuldivError::DenominatorZero);
}
Expand All @@ -64,7 +64,6 @@ pub(crate) fn muldiv(x: U256, y: U256, d: U256, round_up: bool) -> Result<U256,
result.try_into().map_err(|_| MuldivError::Overflow)
}


#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -128,7 +127,9 @@ mod tests {
let d = U256::from(1);
let result = muldiv(x, y, d, false).unwrap();

let expected = U256::from_dec_str("121932631137021795226185032733622923332237463801111263526900").unwrap();
let expected =
U256::from_dec_str("121932631137021795226185032733622923332237463801111263526900")
.unwrap();
assert_eq!(result, expected);
}

Expand Down Expand Up @@ -168,7 +169,6 @@ mod tests {
assert_eq!(result, expected);
}


#[test]
fn test_muldiv_max_values_no_rounding() {
// Test with maximum U256 values that result in a valid U256 output
Expand All @@ -195,7 +195,13 @@ mod tests {
let y = U256::from(1);
let d = U256::from(2);
let result = muldiv(x, y, d, true);
assert_eq!(result.unwrap(), U256::from_dec_str("57896044618658097711785492504343953926634992332820282019728792003956564819968").unwrap());
assert_eq!(
result.unwrap(),
U256::from_dec_str(
"57896044618658097711785492504343953926634992332820282019728792003956564819968"
)
.unwrap()
);
}

#[test]
Expand Down Expand Up @@ -257,7 +263,10 @@ mod tests {
let result = muldiv(x, y, d, true);
assert_eq!(
result.unwrap(),
U256::from_dec_str("115792089237316195423570985008687907853269984665640564039457584007913129639935").unwrap()
U256::from_dec_str(
"115792089237316195423570985008687907853269984665640564039457584007913129639935"
)
.unwrap()
);
}

Expand Down Expand Up @@ -287,7 +296,8 @@ mod tests {
// Test where intermediate multiplication is large but result fits in U256
let x = U256::from_dec_str("123456789012345678901234567890").unwrap();
let y = U256::from_dec_str("98765432109876543210987654321").unwrap();
let d = U256::from_dec_str("1219326311370217952261850327336229233322374638011112635269").unwrap();
let d = U256::from_dec_str("1219326311370217952261850327336229233322374638011112635269")
.unwrap();
let result = muldiv(x, y, d, false).unwrap();
assert_eq!(result, U256::from(10));
}
Expand Down
4 changes: 2 additions & 2 deletions src/quoting/base_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ impl Pool for BasePool {
type QuoteError = BasePoolQuoteError;
type Meta = ();

fn get_key(&self) -> NodeKey {
self.key
fn get_key(&self) -> &NodeKey {
&self.key
}

fn get_state(&self) -> Self::State {
Expand Down
2 changes: 1 addition & 1 deletion src/quoting/oracle_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ impl Pool for OraclePool {
type QuoteError = BasePoolQuoteError;
type Meta = BlockTimestamp;

fn get_key(&self) -> NodeKey {
fn get_key(&self) -> &NodeKey {
self.base_pool.get_key()
}

Expand Down
6 changes: 3 additions & 3 deletions src/quoting/twamm_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ impl Pool for TwammPool {
type QuoteError = TwammPoolQuoteError;
type Meta = BlockTimestamp;

fn get_key(&self) -> NodeKey {
fn get_key(&self) -> &NodeKey {
self.base_pool.get_key()
}

Expand Down Expand Up @@ -209,7 +209,7 @@ impl Pool for TwammPool {
token0_sale_rate,
token1_sale_rate,
time_elapsed as u32,
fee,
*fee,
)
.ok_or(TwammPoolQuoteError::FailedCalculateNextSqrtRatio)?;

Expand All @@ -226,7 +226,7 @@ impl Pool for TwammPool {
amount: amount
.to_i128()
.ok_or(TwammPoolQuoteError::SaleAmountOverflow)?,
token,
token: *token,
},
sqrt_ratio_limit: Some(next_sqrt_ratio),
override_state: base_pool_state_override,
Expand Down
2 changes: 1 addition & 1 deletion src/quoting/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ pub trait Pool: Send + Sync {
// Any additional data that is required to compute a quote for this pool, e.g. the block timestamp
type Meta: Copy;

fn get_key(&self) -> NodeKey;
fn get_key(&self) -> &NodeKey;

fn get_state(&self) -> Self::State;

Expand Down

0 comments on commit 3dbbf89

Please sign in to comment.