Skip to content

Commit

Permalink
Merge pull request #7466 from roc-lang/pi-test-updates
Browse files Browse the repository at this point in the history
update benchmark platform to PI
  • Loading branch information
rtfeldman authored Jan 5, 2025
2 parents c85c864 + 07f930c commit 8b641dd
Show file tree
Hide file tree
Showing 21 changed files with 660 additions and 698 deletions.
128 changes: 64 additions & 64 deletions crates/cli/tests/benchmarks/AStar.roc
Original file line number Diff line number Diff line change
@@ -1,107 +1,107 @@
module [findPath, Model, initialModel, cheapestOpen, reconstructPath]
module [find_path, Model, initial_model, cheapest_open, reconstruct_path]

import Quicksort

findPath = \costFn, moveFn, start, end ->
astar costFn moveFn end (initialModel start)
find_path = \cost_fn, move_fn, start, end ->
astar(cost_fn, move_fn, end, initial_model(start))

Model position : {
evaluated : Set position,
openSet : Set position,
open_set : Set position,
costs : Dict position F64,
cameFrom : Dict position position,
came_from : Dict position position,
} where position implements Hash & Eq

initialModel : position -> Model position where position implements Hash & Eq
initialModel = \start -> {
evaluated: Set.empty {},
openSet: Set.single start,
costs: Dict.single start 0,
cameFrom: Dict.empty {},
initial_model : position -> Model position where position implements Hash & Eq
initial_model = \start -> {
evaluated: Set.empty({}),
open_set: Set.single(start),
costs: Dict.single(start, 0),
came_from: Dict.empty({}),
}

cheapestOpen : (position -> F64), Model position -> Result position {} where position implements Hash & Eq
cheapestOpen = \costFn, model ->
model.openSet
cheapest_open : (position -> F64), Model position -> Result position {} where position implements Hash & Eq
cheapest_open = \cost_fn, model ->
model.open_set
|> Set.toList
|> List.keepOks
(\position ->
when Dict.get model.costs position is
Err _ -> Err {}
Ok cost -> Ok { cost: cost + costFn position, position }
)
|> Quicksort.sortBy .cost
|> List.keepOks(
\position ->
when Dict.get(model.costs, position) is
Err(_) -> Err({})
Ok(cost) -> Ok({ cost: cost + cost_fn(position), position }),
)
|> Quicksort.sort_by(.cost)
|> List.first
|> Result.map .position
|> Result.mapErr (\_ -> {})
|> Result.map(.position)
|> Result.mapErr(\_ -> {})

reconstructPath : Dict position position, position -> List position where position implements Hash & Eq
reconstructPath = \cameFrom, goal ->
when Dict.get cameFrom goal is
Err _ -> []
Ok next -> List.append (reconstructPath cameFrom next) goal
reconstruct_path : Dict position position, position -> List position where position implements Hash & Eq
reconstruct_path = \came_from, goal ->
when Dict.get(came_from, goal) is
Err(_) -> []
Ok(next) -> List.append(reconstruct_path(came_from, next), goal)

updateCost : position, position, Model position -> Model position where position implements Hash & Eq
updateCost = \current, neighbor, model ->
newCameFrom =
Dict.insert model.cameFrom neighbor current
update_cost : position, position, Model position -> Model position where position implements Hash & Eq
update_cost = \current, neighbor, model ->
new_came_from =
Dict.insert(model.came_from, neighbor, current)

newCosts =
Dict.insert model.costs neighbor distanceTo
new_costs =
Dict.insert(model.costs, neighbor, distance_to)

distanceTo =
reconstructPath newCameFrom neighbor
distance_to =
reconstruct_path(new_came_from, neighbor)
|> List.len
|> Num.toFrac

newModel =
new_model =
{ model &
costs: newCosts,
cameFrom: newCameFrom,
costs: new_costs,
came_from: new_came_from,
}

when Dict.get model.costs neighbor is
Err _ ->
newModel
when Dict.get(model.costs, neighbor) is
Err(_) ->
new_model

Ok previousDistance ->
if distanceTo < previousDistance then
newModel
Ok(previous_distance) ->
if distance_to < previous_distance then
new_model
else
model

astar : (position, position -> F64), (position -> Set position), position, Model position -> Result (List position) {} where position implements Hash & Eq
astar = \costFn, moveFn, goal, model ->
when cheapestOpen (\source -> costFn source goal) model is
Err {} -> Err {}
Ok current ->
astar = \cost_fn, move_fn, goal, model ->
when cheapest_open(\source -> cost_fn(source, goal), model) is
Err({}) -> Err({})
Ok(current) ->
if current == goal then
Ok (reconstructPath model.cameFrom goal)
Ok(reconstruct_path(model.came_from, goal))
else
modelPopped =
model_popped =
{ model &
openSet: Set.remove model.openSet current,
evaluated: Set.insert model.evaluated current,
open_set: Set.remove(model.open_set, current),
evaluated: Set.insert(model.evaluated, current),
}

neighbors =
moveFn current
move_fn(current)

newNeighbors =
Set.difference neighbors modelPopped.evaluated
new_neighbors =
Set.difference(neighbors, model_popped.evaluated)

modelWithNeighbors : Model position
modelWithNeighbors =
modelPopped
|> &openSet (Set.union modelPopped.openSet newNeighbors)
model_with_neighbors : Model position
model_with_neighbors =
model_popped
|> &open_set(Set.union(model_popped.open_set, new_neighbors))

walker : Model position, position -> Model position
walker = \amodel, n -> updateCost current n amodel
walker = \amodel, n -> update_cost(current, n, amodel)

modelWithCosts =
Set.walk newNeighbors modelWithNeighbors walker
model_with_costs =
Set.walk(new_neighbors, model_with_neighbors, walker)

astar costFn moveFn goal modelWithCosts
astar(cost_fn, move_fn, goal, model_with_costs)

# takeStep = \moveFn, _goal, model, current ->
# modelPopped =
Expand Down
56 changes: 28 additions & 28 deletions crates/cli/tests/benchmarks/Base64.roc
Original file line number Diff line number Diff line change
@@ -1,38 +1,38 @@
module [fromBytes, fromStr, toBytes, toStr]
module [from_bytes, from_str, to_bytes, to_str]

import Base64.Decode
import Base64.Encode

# base 64 encoding from a sequence of bytes
fromBytes : List U8 -> Result Str [InvalidInput]
fromBytes = \bytes ->
when Base64.Decode.fromBytes bytes is
Ok v ->
Ok v
from_bytes : List U8 -> Result Str [InvalidInput]
from_bytes = \bytes ->
when Base64.Decode.from_bytes(bytes) is
Ok(v) ->
Ok(v)

Err _ ->
Err InvalidInput
Err(_) ->
Err(InvalidInput)

# base 64 encoding from a string
fromStr : Str -> Result Str [InvalidInput]
fromStr = \str ->
fromBytes (Str.toUtf8 str)
from_str : Str -> Result Str [InvalidInput]
from_str = \str ->
from_bytes(Str.toUtf8(str))

# base64-encode bytes to the original
toBytes : Str -> Result (List U8) [InvalidInput]
toBytes = \str ->
Ok (Base64.Encode.toBytes str)

toStr : Str -> Result Str [InvalidInput]
toStr = \str ->
when toBytes str is
Ok bytes ->
when Str.fromUtf8 bytes is
Ok v ->
Ok v

Err _ ->
Err InvalidInput

Err _ ->
Err InvalidInput
to_bytes : Str -> Result (List U8) [InvalidInput]
to_bytes = \str ->
Ok(Base64.Encode.to_bytes(str))

to_str : Str -> Result Str [InvalidInput]
to_str = \str ->
when to_bytes(str) is
Ok(bytes) ->
when Str.fromUtf8(bytes) is
Ok(v) ->
Ok(v)

Err(_) ->
Err(InvalidInput)

Err(_) ->
Err(InvalidInput)
88 changes: 44 additions & 44 deletions crates/cli/tests/benchmarks/Base64/Decode.roc
Original file line number Diff line number Diff line change
@@ -1,86 +1,86 @@
module [fromBytes]
module [from_bytes]

import Bytes.Decode exposing [ByteDecoder, DecodeProblem]

fromBytes : List U8 -> Result Str DecodeProblem
fromBytes = \bytes ->
Bytes.Decode.decode bytes (decodeBase64 (List.len bytes))
from_bytes : List U8 -> Result Str DecodeProblem
from_bytes = \bytes ->
Bytes.Decode.decode(bytes, decode_base64(List.len(bytes)))

decodeBase64 : U64 -> ByteDecoder Str
decodeBase64 = \width -> Bytes.Decode.loop loopHelp { remaining: width, string: "" }
decode_base64 : U64 -> ByteDecoder Str
decode_base64 = \width -> Bytes.Decode.loop(loop_help, { remaining: width, string: "" })

loopHelp : { remaining : U64, string : Str } -> ByteDecoder (Bytes.Decode.Step { remaining : U64, string : Str } Str)
loopHelp = \{ remaining, string } ->
loop_help : { remaining : U64, string : Str } -> ByteDecoder (Bytes.Decode.Step { remaining : U64, string : Str } Str)
loop_help = \{ remaining, string } ->
if remaining >= 3 then
Bytes.Decode.map3 Bytes.Decode.u8 Bytes.Decode.u8 Bytes.Decode.u8 \x, y, z ->
Bytes.Decode.map3(Bytes.Decode.u8, Bytes.Decode.u8, Bytes.Decode.u8, \x, y, z ->
a : U32
a = Num.intCast x
a = Num.intCast(x)
b : U32
b = Num.intCast y
b = Num.intCast(y)
c : U32
c = Num.intCast z
combined = Num.bitwiseOr (Num.bitwiseOr (Num.shiftLeftBy a 16) (Num.shiftLeftBy b 8)) c
c = Num.intCast(z)
combined = Num.bitwiseOr(Num.bitwiseOr(Num.shiftLeftBy(a, 16), Num.shiftLeftBy(b, 8)), c)

Loop {
Loop({
remaining: remaining - 3,
string: Str.concat string (bitsToChars combined 0),
}
string: Str.concat(string, bits_to_chars(combined, 0)),
}))
else if remaining == 0 then
Bytes.Decode.succeed (Done string)
Bytes.Decode.succeed(Done(string))
else if remaining == 2 then
Bytes.Decode.map2 Bytes.Decode.u8 Bytes.Decode.u8 \x, y ->
Bytes.Decode.map2(Bytes.Decode.u8, Bytes.Decode.u8, \x, y ->

a : U32
a = Num.intCast x
a = Num.intCast(x)
b : U32
b = Num.intCast y
combined = Num.bitwiseOr (Num.shiftLeftBy a 16) (Num.shiftLeftBy b 8)
b = Num.intCast(y)
combined = Num.bitwiseOr(Num.shiftLeftBy(a, 16), Num.shiftLeftBy(b, 8))

Done (Str.concat string (bitsToChars combined 1))
Done(Str.concat(string, bits_to_chars(combined, 1))))
else
# remaining = 1
Bytes.Decode.map Bytes.Decode.u8 \x ->
Bytes.Decode.map(Bytes.Decode.u8, \x ->

a : U32
a = Num.intCast x
a = Num.intCast(x)

Done (Str.concat string (bitsToChars (Num.shiftLeftBy a 16) 2))
Done(Str.concat(string, bits_to_chars(Num.shiftLeftBy(a, 16), 2))))

bitsToChars : U32, Int * -> Str
bitsToChars = \bits, missing ->
when Str.fromUtf8 (bitsToCharsHelp bits missing) is
Ok str -> str
Err _ -> ""
bits_to_chars : U32, Int * -> Str
bits_to_chars = \bits, missing ->
when Str.fromUtf8(bits_to_chars_help(bits, missing)) is
Ok(str) -> str
Err(_) -> ""

# Mask that can be used to get the lowest 6 bits of a binary number
lowest6BitsMask : Int *
lowest6BitsMask = 63
lowest6_bits_mask : Int *
lowest6_bits_mask = 63

bitsToCharsHelp : U32, Int * -> List U8
bitsToCharsHelp = \bits, missing ->
bits_to_chars_help : U32, Int * -> List U8
bits_to_chars_help = \bits, missing ->
# The input is 24 bits, which we have to partition into 4 6-bit segments. We achieve this by
# shifting to the right by (a multiple of) 6 to remove unwanted bits on the right, then `Num.bitwiseAnd`
# with `0b111111` (which is 2^6 - 1 or 63) (so, 6 1s) to remove unwanted bits on the left.
# any 6-bit number is a valid base64 digit, so this is actually safe
p =
Num.shiftRightZfBy bits 18
Num.shiftRightZfBy(bits, 18)
|> Num.intCast
|> unsafeToChar
|> unsafe_to_char

q =
Num.bitwiseAnd (Num.shiftRightZfBy bits 12) lowest6BitsMask
Num.bitwiseAnd(Num.shiftRightZfBy(bits, 12), lowest6_bits_mask)
|> Num.intCast
|> unsafeToChar
|> unsafe_to_char

r =
Num.bitwiseAnd (Num.shiftRightZfBy bits 6) lowest6BitsMask
Num.bitwiseAnd(Num.shiftRightZfBy(bits, 6), lowest6_bits_mask)
|> Num.intCast
|> unsafeToChar
|> unsafe_to_char

s =
Num.bitwiseAnd bits lowest6BitsMask
Num.bitwiseAnd(bits, lowest6_bits_mask)
|> Num.intCast
|> unsafeToChar
|> unsafe_to_char

equals : U8
equals = 61
Expand All @@ -94,8 +94,8 @@ bitsToCharsHelp = \bits, missing ->
[]

# Base64 index to character/digit
unsafeToChar : U8 -> U8
unsafeToChar = \n ->
unsafe_to_char : U8 -> U8
unsafe_to_char = \n ->
if n <= 25 then
# uppercase characters
65 + n
Expand Down
Loading

0 comments on commit 8b641dd

Please sign in to comment.