Skip to content

Commit

Permalink
solve
Browse files Browse the repository at this point in the history
  • Loading branch information
thevolatilebit committed Sep 15, 2023
1 parent 809a2df commit d2f9aba
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 91 deletions.
13 changes: 0 additions & 13 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,12 @@ Catlab = "134e5e36-593f-5add-ad60-77f754baafbe"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Crayons = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterMarkdown = "997ab1e6-3595-5248-9280-8efb232c3433"
GeneratedExpressions = "84d730a5-1eb9-4187-a799-27dd07f33a14"
IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NLopt = "76087f3c-5699-56af-9a33-bf431cd00edd"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781"
PlutoUI = "7f904dfe-b85e-4ff6-b463-dae2292396a8"
Expand All @@ -32,26 +26,19 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
CSV = "0.10"
Catlab = "0.14"
ComponentArrays = "0.14"
Crayons = "4.1"
DataFrames = "1.6"
DiffEqBase = "6.128"
DifferentialEquations = "7.9"
Distributions = "0.25"
Documenter = "0.27"
DocumenterMarkdown = "0.2"
GeneratedExpressions = "0.1"
IJulia = "1.24"
JLD2 = "0.4"
JSON = "0.21"
MacroTools = "0.5"
NLopt = "1.0"
OrdinaryDiffEq = "6.55"
Plots = "1.39"
Pluto = "0.19"
PlutoUI = "0.7"
Expand Down
1 change: 0 additions & 1 deletion src/ReactiveDynamics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ module ReactiveDynamics
using Catlab, Catlab.CategoricalAlgebra, Catlab.Present
using Reexport
using MacroTools
using NLopt
using ComponentArrays

@reexport using GeneratedExpressions
Expand Down
116 changes: 46 additions & 70 deletions src/solvers.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
# assortment of SciML-compatible problem solvers

export DiscreteProblem

using DiffEqBase, DifferentialEquations
using Distributions
using Random

Expand Down Expand Up @@ -37,7 +32,7 @@ function get_reqs_ongoing!(reqs, qs, state)
for tok in state.ongoing_transitions[i][:transLHS]
in(:rate, tok.modality) &&
(state.ongoing_transitions[i][:transCycleTime] > 0) &&
(reqs[tok.index, i] += qs[i] * tok.stoich * state.solverargs[:tstep])
(reqs[tok.index, i] += qs[i] * tok.stoich * state.dt)
in(:nonblock, tok.modality) && (reqs[tok.index, i] += qs[i] * tok.stoich)
end
end
Expand Down Expand Up @@ -124,9 +119,8 @@ end
"""
Evolve transitions, spawn new transitions.
"""
function evolve!(u, state)
update_u!(state, u)
actual_allocs = zero(u)
function evolve!(state)
actual_allocs = zero(state.u)

## schedule new transitions
reqs = zeros(nparts(state, :S), nparts(state, :T))
Expand All @@ -137,8 +131,9 @@ function evolve!(u, state)
1:nparts(state, :T),
)
qs .= ceil.(Ref(Int), qs)
@show qs
for i = 1:nparts(state, :T)
new_instances = state.solverargs[:tstep] * qs[i] + state[i, :transToSpawn]
new_instances = state.dt * qs[i] + state[i, :transToSpawn]
capacity =
state[i, :transCapacity] -
count(t -> t[:transHash] == state[i, :transHash], state.ongoing_transitions)
Expand All @@ -148,10 +143,13 @@ function evolve!(u, state)
end

reqs = get_reqs_init!(reqs, qs, state)
@show reqs
allocs =
get_allocs!(reqs, u, state, state[:, :transPriority], state.solverargs[:strategy])
get_allocs!(reqs, state.u, state, state[:, :transPriority], state.p[:strategy])
@show allocs
qs .= get_init_satisfied(allocs, qs, state)

@show qs
println("====")
push!(
state.log,
(
Expand All @@ -160,7 +158,7 @@ function evolve!(u, state)
[(hash, q) for (hash, q) in zip(state[:, :transHash], qs)]...,
),
)
u .-= sum(allocs; dims = 2)
state.u .-= sum(allocs; dims = 2)
actual_allocs .+= sum(allocs; dims = 2)

# add spawned transitions to the heap
Expand All @@ -171,18 +169,17 @@ function evolve!(u, state)
)
end

update_u!(state, u)
## evolve ongoing transitions
reqs = zeros(nparts(state, :S), length(state.ongoing_transitions))
qs = map(t -> t.q, state.ongoing_transitions)

get_reqs_ongoing!(reqs, qs, state)
allocs = get_allocs!(
reqs,
u,
state.u,
state,
map(t -> t[:transPriority], state.ongoing_transitions),
state.solverargs[:strategy],
state.p[:strategy],
)
qs .= get_frac_satisfied(allocs, reqs, state)
push!(
Expand All @@ -196,11 +193,11 @@ function evolve!(u, state)
]...,
),
)
u .-= sum(allocs; dims = 2)
state.u .-= sum(allocs; dims = 2)
actual_allocs .+= sum(allocs; dims = 2)

foreach(
i -> state.ongoing_transitions[i].state += qs[i] * state.solverargs[:tstep],
i -> state.ongoing_transitions[i].state += qs[i] * state.dt,
eachindex(state.ongoing_transitions),
)

Expand Down Expand Up @@ -229,8 +226,7 @@ function event_action!(state)
end

# collect terminated transitions
function finish!(u, state)
update_u!(state, u)
function finish!(state)
val_reward = 0
terminated_all = Dict{Symbol,Float64}()
terminated_success = Dict{Symbol,Float64}()
Expand All @@ -256,7 +252,7 @@ function finish!(u, state)
end
for tok in trans_[:transLHS]
in(:conserved, tok.modality) && (
u[tok.index] +=
state.u[tok.index] +=
trans_.q *
tok.stoich *
(in(:rate, tok.modality) ? trans_[:transCycleTime] : 1)
Expand All @@ -268,12 +264,11 @@ function finish!(u, state)
0
end
foreach(
tok -> (u[tok.index] += q * tok.stoich;
tok -> (state.u[tok.index] += q * tok.stoich;
val_reward += state[tok.index, :specReward] * q * tok.stoich),
toks_rhs,
)

update_u!(state, u)
context_eval(state, trans_.trans[:transPostAction])
terminated_all[trans_[:transHash]] =
get(terminated_all, trans_[:transHash], 0) + trans_.q
Expand All @@ -289,7 +284,7 @@ function finish!(u, state)
push!(state.log, (:terminated_success, state.t, terminated_success...))
push!(state.log, (:valuation_reward, state.t, val_reward))

return u
return state.u
end

function free_blocked_species!(state)
Expand All @@ -306,15 +301,15 @@ function get_tcontrol(tspan, args)
tunit = get(args, :tunit, oneunit(tspan))
tspan = tspan / tunit

tstep = get(args, :tstep, haskey(args, :tstops) ? tspan / args[:tstops] : tunit) / tunit
dt = get(args, :dt, haskey(args, :tstops) ? tspan / args[:tstops] : tunit) / tunit

return ((0.0, tspan), tstep)
return ((0.0, tspan), dt)
end

function ReactiveNetwork(
acs::ReactionNetwork,
u0 = Dict(),
p = DiffEqBase.NullParameters();
p = Dict();
name = "reactive_network",
kwargs...,
)
Expand All @@ -325,39 +320,21 @@ function ReactiveNetwork(
])
merge!(keywords, Dict(collect(kwargs)))
merge!(keywords, Dict(:strategy => get(keywords, :alloc_strategy, :weighted)))

keywords[:tspan], keywords[:tstep] = get_tcontrol(keywords[:tspan], keywords)

acs = remove_choose(acs)
attrs, transitions, wrap_fun = compile_attrs(acs)

init_u!(state)
save!(state)
transition_recipes = transitions
u0_init = zeros(nparts(acs, :S))

u0_init = zeros(nparts(state, :S))

u0 isa Dict && foreach(
i ->
if !isnothing(acs[i, :specName]) && haskey(u0, acs[i, :specName])
u0_init[i] = u0[acs[i, :specName]]
end,
1:nparts(state, :S),
)

p_ = p == DiffEqBase.NullParameters() ? Dict() : Dict(k => v for (k, v) in p)
prob = remake(
prob;
u0 = prob.u0,
tspan = keywords[:tspan],
dt = get(keywords, :tstep, 1),
p = merge(
prob.p,
p_,
Dict(
:tstep => get(keywords, :tstep, 1),
:strategy => get(keywords, :alloc_strategy, :weighted),
),
),
)
for i in 1:nparts(acs, :S)
if !isnothing(acs[i, :specName]) && haskey(u0, acs[i, :specName])
u0_init[i] = u0[acs[i, :specName]]
else
u0_init[i] = acs[i, :specInitVal]
end
end

ongoing_transitions = Transition[]
log = NamedTuple[]
Expand All @@ -369,21 +346,19 @@ function ReactiveNetwork(
) [:transLHS, :transRHS, :transToSpawn, :transHash]
transitions = Dict{Symbol,Vector}(a => [] for a in transitions_attrs)

return ReactiveNetwork(
network = ReactiveNetwork(
name,
acs,
attrs,
transition_recipes,
u0_init,
merge(
prob.p,
p_,
p,
Dict(
:tstep => get(keywords, :tstep, 1),
:strategy => get(keywords, :alloc_strategy, :weighted),
),
),
t,
keywords[:tspan][1],
keywords[:tspan],
get(keywords, :tstep, 1),
Expand All @@ -396,31 +371,32 @@ function ReactiveNetwork(
Vector{Float64}[],
Float64[],
)

save!(network)

return network
end

function AlgebraicAgents.step!(state::ReactiveNetwork)
du = copy(state.u)
#du = copy(state.u)
free_blocked_species!(state)
du .= state.u
#du .= state.u
update_observables(state)
sample_transitions!(state)
evolve!(du, state)
finish!(du, state)
update_u!(state, du)
evolve!(state)
finish!(state)
#update_u!(state, u)
event_action!(state)

du .= state.u
push!(
state.log,
(:valuation, t, du' * [state[i, :specValuation] for i = 1:nparts(state, :S)]),
(:valuation, state.t, state.u' * [state[i, :specValuation] for i = 1:nparts(state, :S)]),
)

t = (state.t += state.solverargs[:tstep])
update_u!(state, du)
#update_u!(state, du)
save!(state)
sync_p!(p, state)

state.u .= du
#state.u .= du
state.t += state.dt
end

Expand Down
4 changes: 1 addition & 3 deletions src/state.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using DiffEqBase: NullParameters

using AlgebraicAgents

struct UnfoldedReactant
Expand Down Expand Up @@ -77,7 +75,7 @@ function init_u!(state::ReactiveNetwork)
state.u = u)
end
function save!(state::ReactiveNetwork)
return (push!(state.history_u, state.u); push!(state.history_t, state.t))
return (push!(state.history_u, copy(state.u)); push!(state.history_t, state.t))
end

function compile_observables(acs::ReactionNetwork)
Expand Down
File renamed without changes.
17 changes: 13 additions & 4 deletions tutorial/basics.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,32 @@
using ReactiveDynamics
using ReactiveDynamics, AlgebraicAgents

# acs = @ReactionNetwork begin
# 1.0, X ⟺ Y
# end

acs = @ReactionNetwork begin
1.0, X Y, name => "transition1"
1.0, X --> Y, name => "transition1"
end

@prob_init acs X = 10 Y = 20
@prob_params acs
@prob_meta acs tspan = 250 dt = 0.1

prob = @problematize acs

# sol = ReactiveDynamics.solve(prob)

sol = @solve prob
#sol = @solve prob

prob = @agentize acs

for i in 1:30
AlgebraicAgents.step!(prob)
end
prob.history_u
using Plots

@plot sol plot_type = summary

prob = @problematize acs
@solve prob

0 comments on commit d2f9aba

Please sign in to comment.