From ad6d50e43a8e69c9a0a7164dc37967b79fe5bb9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?B=C3=ADma=2C=20Jan?= Date: Sat, 16 Sep 2023 23:09:22 +0200 Subject: [PATCH] fix plotting, solution, log --- src/ReactiveDynamics.jl | 16 +- src/compilers.jl | 8 +- src/interface/create.jl | 14 +- src/interface/plots.jl | 19 +++ src/interface/reaction_parser.jl | 2 +- src/interface/solve.jl | 50 +------ src/loadsave.jl | 6 +- src/operators/equalize.jl | 2 +- src/operators/joins.jl | 4 +- src/optim.jl | 244 ------------------------------- src/solvers.jl | 53 ++++--- src/state.jl | 67 ++++----- tutorial/basics.jl | 36 ++--- 13 files changed, 113 insertions(+), 408 deletions(-) create mode 100644 src/interface/plots.jl delete mode 100644 src/optim.jl diff --git a/src/ReactiveDynamics.jl b/src/ReactiveDynamics.jl index ed1da88..b0eef6f 100644 --- a/src/ReactiveDynamics.jl +++ b/src/ReactiveDynamics.jl @@ -74,7 +74,7 @@ end @acset_type FoldedReactionNetworkType(TheoryReactionNetwork) -const ReactionNetwork = FoldedReactionNetworkType{ +const ReactionNetworkSchema = FoldedReactionNetworkType{ Symbol, Union{String,Symbol,Missing}, SampleableValues, @@ -131,11 +131,11 @@ defargs = Dict( ) compilable_attrs = - filter(attr -> eltype(attr) == SampleableValues, propertynames(ReactionNetwork())) + filter(attr -> eltype(attr) == SampleableValues, propertynames(ReactionNetworkSchema())) species_modalities = [:nonblock, :conserved, :rate] -function assign_defaults!(acs::ReactionNetwork) +function assign_defaults!(acs::ReactionNetworkSchema) for (_, v_) in defargs, (k, v) in v_ for i = 1:length(subpart(acs, k)) isnothing(acs[i, k]) && (subpart(acs, k)[i] = v) @@ -160,12 +160,12 @@ function assign_defaults!(acs::ReactionNetwork) return acs end -function ReactionNetwork(transitions, reactants, obs, events) - return merge_acs!(ReactionNetwork(), transitions, reactants, obs, events) +function ReactionNetworkSchema(transitions, reactants, obs, events) + return merge_acs!(ReactionNetworkSchema(), transitions, reactants, obs, events) end -function ReactionNetwork(transitions, reactants, obs) - return merge_acs!(ReactionNetwork(), transitions, reactants, obs, []) +function ReactionNetworkSchema(transitions, reactants, obs) + return merge_acs!(ReactionNetworkSchema(), transitions, reactants, obs, []) end function add_obs!(acs, obs) @@ -194,7 +194,7 @@ function add_obs!(acs, obs) return acs end -function merge_acs!(acs::ReactionNetwork, transitions, reactants, obs, events) +function merge_acs!(acs::ReactionNetworkSchema, transitions, reactants, obs, events) foreach( t -> add_part!(acs, :T; trans = t[1][2], transRate = t[1][1], t[2]...), transitions, diff --git a/src/compilers.jl b/src/compilers.jl index 4588b69..9edccc6 100644 --- a/src/compilers.jl +++ b/src/compilers.jl @@ -112,12 +112,12 @@ function wrap_expr(fex, species_names, prm_names, varmap) ) push!(letex.args[2].args, fex) - # the function shall be a function of the dynamic ReactiveNetwork structure: letex -> :(state -> $letex) + # the function shall be a function of the dynamic ReactionNetworkSchema structure: letex -> :(state -> $letex) # eval the expression to a Julia function, save that function into the "compiled" acset return eval(:(state -> $letex)) end -function get_wrap_fun(acs::ReactionNetwork) +function get_wrap_fun(acs::ReactionNetworkSchema) species_names = collect(acs[:, :specName]) prm_names = collect(acs[:, :prmName]) varmap = Dict([name => :(state.u[$i]) for (i, name) in enumerate(species_names)]) @@ -133,7 +133,7 @@ function skip_compile(attr) (string(attr) == "trans") end -function compile_attrs(acs::ReactionNetwork) +function compile_attrs(acs::ReactionNetworkSchema) species_names = collect(acs[:, :specName]) prm_names = collect(acs[:, :prmName]) varmap = Dict([name => :(state.u[$i]) for (i, name) in enumerate(species_names)]) @@ -169,7 +169,7 @@ function compile_attrs(acs::ReactionNetwork) return attrs, transitions, wrap_fun end -function remove_choose(acs::ReactionNetwork) +function remove_choose(acs::ReactionNetworkSchema) acs = deepcopy(acs) pcs = [] for attr in propertynames(acs.subparts) diff --git a/src/interface/create.jl b/src/interface/create.jl index fe23701..ec49f85 100644 --- a/src/interface/create.jl +++ b/src/interface/create.jl @@ -1,6 +1,6 @@ # reaction network DSL: CREATE part; reaction line and event parsing -export @ReactionNetwork +export @ReactionNetworkSchema using MacroTools: prewalk, postwalk, striplines, isexpr using Symbolics: build_function, get_variables @@ -45,7 +45,7 @@ Custom functions and sampleable objects can be used as numeric parameters. Note # Examples ```julia -acs = @ReactionNetwork begin +acs = @ReactionNetworkSchema begin 1.0, X ⟶ Y 1.0, X ⟶ Y, priority => 6.0, prob => 0.7, capacity => 3.0 1.0, ∅ --> (Poisson(0.3γ)X, Poisson(0.5)Y) @@ -57,17 +57,17 @@ end @solve_and_plot acs ``` """ -macro ReactionNetwork end +macro ReactionNetworkSchema end -macro ReactionNetwork() +macro ReactionNetworkSchema() return make_ReactionNetwork(:()) end -macro ReactionNetwork(ex) +macro ReactionNetworkSchema(ex) return make_ReactionNetwork(ex; eval_module = __module__) end -macro ReactionNetwork(ex, args...) +macro ReactionNetworkSchema(ex, args...) return make_ReactionNetwork( generate(Expr(:braces, ex, args...); eval_module = __module__); eval_module = __module__, @@ -78,7 +78,7 @@ function make_ReactionNetwork(ex::Expr; eval_module = @__MODULE__) blockex = generate(ex; eval_module) blockex = unblock_shallow!(blockex) - return :(ReactionNetwork(get_data($(QuoteNode(blockex)))...)) + return :(ReactionNetworkSchema(get_data($(QuoteNode(blockex)))...)) end ### Functions that process the input and rephrase it as a reaction system ### diff --git a/src/interface/plots.jl b/src/interface/plots.jl new file mode 100644 index 0000000..22de29f --- /dev/null +++ b/src/interface/plots.jl @@ -0,0 +1,19 @@ +using Plots + +function plot_df(df::DataFrames.DataFrame, t_ix = 1) + data = Matrix(df) + t = @view data[:, t_ix] + data_ = @view data[:, setdiff(1:size(data, 2), (t_ix,))] + colnames = reshape(DataFrames.names(df)[setdiff(1:size(data, 2), (t_ix,))], 1, :) + + Plots.plot(t, data_, labels = colnames, xlabel = "t") +end + +# plot reduction +function AlgebraicAgents._draw(prob::ReactionNetworkProblem, vars = string.(prob.acs[:, :specName]); kwargs...) + p = plot() + for var in vars + p = plot!(p, prob.sol[!, "t"], prob.sol[!, var]; label="$var", xlabel="time", ylabel="quantity", kwargs...) + end + p +end diff --git a/src/interface/reaction_parser.jl b/src/interface/reaction_parser.jl index df69f53..2c566bb 100644 --- a/src/interface/reaction_parser.jl +++ b/src/interface/reaction_parser.jl @@ -29,7 +29,7 @@ function recursively_choose(r_line, state) end end -function extract_reactants(r_line, state::ReactiveNetwork) +function extract_reactants(r_line, state::ReactionNetworkProblem) r_line = recursively_choose(r_line, state) return recursive_find_reactants!( diff --git a/src/interface/solve.jl b/src/interface/solve.jl index 029aee3..e8a1913 100644 --- a/src/interface/solve.jl +++ b/src/interface/solve.jl @@ -1,56 +1,8 @@ -export @agentize, @solve, @plot -export @optimize, @fit, @fit_and_plot, @build_solver +export @agentize import MacroTools import Plots -""" -Convert a model to a `ReactiveNetwork`. If passed a problem instance, return the instance. - -# Examples - -```julia -@agentize acs tspan = 1:100 -``` -""" -macro agentize(acsex, args...) - args, kwargs = args_kwargs(args) - quote - if $(esc(acsex)) isa ReactiveNetwork - $(esc(acsex)) - else - ReactiveNetwork($(esc(acsex)), $(args...); $(kwargs...)) - end - end -end - -""" -Solve the problem. Solverargs passed at the calltime take precedence. - -# Examples - -```julia -@solve prob -@solve prob tspan = 1:100 -@solve prob tspan = 100 -``` -""" -macro solve(probex, args...) - args, kwargs = args_kwargs(args) - mode = find_kwargex_delete!(kwargs, :mode, nothing) - !isnothing(findfirst(el -> el.args[1] == :trajectories, kwargs)) && (mode = :ensemble) - - quote - prob = if $(esc(probex)) isa ReactiveNetwork - $(esc(probex)) - else - ReactiveNetwork($(esc(probex)), $(args...); $(kwargs...)) - end - - simulate(prob) - end -end - # auxiliary plotting functions function plot_summary(s, labels, ixs; kwargs...) isempty(ixs) && return @warn "Set of species to plot must be non-empty!" diff --git a/src/loadsave.jl b/src/loadsave.jl index da3d8b4..89571fb 100644 --- a/src/loadsave.jl +++ b/src/loadsave.jl @@ -16,7 +16,7 @@ const objects_aliases = Dict( :obs => "obs", ) -const RN_attrs = string.(propertynames(ReactionNetwork().subparts)) +const RN_attrs = string.(propertynames(ReactionNetworkSchema().subparts)) function get_attrs(object) object = object isa Symbol ? objects_aliases[object] : object @@ -24,7 +24,7 @@ function get_attrs(object) return filter(x -> occursin(object, x), RN_attrs) end -function export_network(acs::ReactionNetwork) +function export_network(acs::ReactionNetworkSchema) dict = Dict() for (key, val) in objects_aliases push!(dict, val => []) @@ -113,7 +113,7 @@ function import_network(path::AbstractString) end end -function export_network(acs::ReactionNetwork, path::AbstractString) +function export_network(acs::ReactionNetworkSchema, path::AbstractString) if splitext(path)[2] == ".csv" exported_network = export_network(acs) paths = DataFrame(; type = [], path = []) diff --git a/src/operators/equalize.jl b/src/operators/equalize.jl index cbe357d..e679ec9 100644 --- a/src/operators/equalize.jl +++ b/src/operators/equalize.jl @@ -21,7 +21,7 @@ function get_eqs_ff(eq) end end -function equalize!(acs::ReactionNetwork, eqs = []) +function equalize!(acs::ReactionNetworkSchema, eqs = []) specmap = Dict() for block in eqs block_alias = findfirst(e -> e[1] == :alias, block) diff --git a/src/operators/joins.jl b/src/operators/joins.jl index 40ebc73..f5c3276 100644 --- a/src/operators/joins.jl +++ b/src/operators/joins.jl @@ -59,7 +59,7 @@ end """ Prepend species names with a model identifier (unless a global species name). """ -function prepend!(acs::ReactionNetwork, name = gensym("acs"), eqs = []) +function prepend!(acs::ReactionNetworkSchema, name = gensym("acs"), eqs = []) specmap = Dict() for i = 1:nparts(acs, :S) new_name = normalize_name(name, i, acs[i, :specName], eqs) @@ -199,7 +199,7 @@ Model variables / parameter values and metadata are propagated; the last model t macro join(exs...) callex = :( begin - acs_new = ReactionNetwork() + acs_new = ReactionNetworkSchema() end ) exs = collect(exs) diff --git a/src/optim.jl b/src/optim.jl deleted file mode 100644 index a30e179..0000000 --- a/src/optim.jl +++ /dev/null @@ -1,244 +0,0 @@ -function build_parametrized_solver(acs, init_vec, u0, params; trajectories = 1) - prob = DiscreteProblem(acs) - vars = prob.p[:__state__][:, :specInitUncertainty] - init_vec = deepcopy(init_vec) - - function (vec) - vec = vec isa ComponentVector ? vec : (init_vec .= vec) - data = [] - for _ = 1:trajectories - prob.p[:__state__] = deepcopy(prob.p[:__state0__]) - for i in eachindex(prob.u0) - rv = randn() * vars[i] - prob.u0[i] = if (sign(rv + prob.u0[i]) == sign(prob.u0[i])) - rv + prob.u0[i] - else - prob.u0[i] - end - end - - for (i, k) in enumerate(wkeys(u0)) - prob.u0[k] = vec.species[i] - end - for k in wkeys(params) - prob.p[k] = vec[k] - end - - sync!(prob.p[:__state__], prob.u0, prob.p) - push!(data, solve(prob)) - end - - return data - end -end - -function build_parametrized_solver_(acs, init_vec, u0, params; trajectories = 1) - prob = DiscreteProblem(acs) - vars = prob.p[:__state__][:, :specInitUncertainty] - init_vec = deepcopy(init_vec) - - function (vec) - vec = vec isa ComponentVector ? vec : (init_vec .= vec; init_vec) - data = map(1:trajectories) do _ - prob.p[:__state__] = deepcopy(prob.p[:__state0__]) - for i in eachindex(prob.u0) - rv = randn() * vars[i] - prob.u0[i] = if (sign(rv + prob.u0[i]) == sign(prob.u0[i])) - rv + prob.u0[i] - else - prob.u0[i] - end - end - - for (i, k) in enumerate(wkeys(u0)) - prob.u0[k] = vec.species[i] - end - for k in wkeys(params) - prob.p[k] = vec[k] - end - - sync!(prob.p[:__state__], prob.u0, prob.p) - - return solve(prob) - end - - return trajectories == 1 ? data[1] : EnsembleSolution(data, 0.0, true) - end -end - -## optimization part - -BOUND_DEFAULT = 5000 - -function optim!(obj, init; nlopt_kwargs...) - nlopt_kwargs = Dict(nlopt_kwargs) - alg = pop!(nlopt_kwargs, :algorithm, :GN_DIRECT) - - opt = Opt(alg, length(init)) - - # match to a ComponentVector - foreach( - o -> setproperty!(opt, o...), - filter(x -> x[1] in propertynames(opt), nlopt_kwargs), - ) - if get(nlopt_kwargs, :objective, min) == min - (opt.min_objective = obj) - else - (opt.max_objective = obj) - end - - return optimize(opt, deepcopy(init)) -end - -const n_steps = 100 - -# loss objective given an objective expression -function build_loss_objective( - acs, - init_vec, - u0, - params, - obex; - loss = identity, - trajectories = 1, - min_t = -Inf, - max_t = Inf, - final_only = false, -) - ob = eval(get_wrap_fun(acs)(obex)) - obj_ = build_parametrized_solver(acs, init_vec, u0, params; trajectories) - - function (vec, _) - ls = [] - for sol in obj_(vec) - t_points = if final_only - [last(sol.t)] - else - min_t = max(min_t, sol.prob.tspan[1]) - max_t = min(max_t, sol.prob.tspan[2]) - - range(min_t, max_t; length = n_steps) - end - - push!( - ls, - mean(t -> loss(ob(as_state(sol(t), t, sol.prob.p[:__state__]))), t_points), - ) - end - - return mean(ls) - end -end - -# loss objective given empirical data -function build_loss_objective_datapoints( - acs, - init_vec, - u0, - params, - t, - data, - vars; - loss = abs2, - trajectories = 1, -) - obj_ = build_parametrized_solver(acs, init_vec, u0, params; trajectories) - - function (vec, _) - ls = [] - for sol in obj_(vec) - push!( - ls, - mean( - t -> - sum(i -> loss(sol(t[2])[i[2]] - data[i[1], t[1]]), enumerate(vars)), - enumerate(t), - ), - ) - end - - return mean(ls) - end -end - -# set initial model parameter values in an optimization problem -function prep_params!(params, prob) - for (k, v) in params - (v === NaN) && wset!(params, k, get(prob.p, k, NaN)) - end - any(p -> (p[2] === NaN) && @warn("Uninitialized prm: $p"), params) - - return params -end - -# set initial model variable values in an optimization problem -function prep_u0!(u0, prob) - for (k, v) in u0 - (v === NaN) && wset!(u0, k, get(prob.u0, k, NaN)) - end - any(u -> (u[2] === NaN) && @warn("Uninitialized prm: $(u[1])"), u0) - - return u0 -end - -""" -Extract symbolic variables referenced in `acs`, `args`. -""" -function get_free_vars(acs, args) - u0_syms = collect(acs[:, :specName]) - p_syms = collect(acs[:, :prmName]) - u0 = [] - p = [] - - for arg in args - if arg isa Symbol - (k, v) = (arg, NaN) - elseif isexpr(arg, :(=)) - (k, v) = (arg.args[1], arg.args[2]) - else - continue - end - - if ((k in u0_syms || k isa Number) && !in(k, wkeys(u0))) - push!(u0, k => v) - elseif (k in p_syms && !in(k, wkeys(p))) - push!(p, k => v) - end - end - - u0_ = [] - for (k, v) in u0 - if k isa Number - push!(u0_, Int(k) => v) - else - for i = 1:length(subpart(acs, :specName)) - (acs[i, :specName] == k) && (push!(u0_, i => v); break) - end - end - end - - return u0_, p -end - -""" -Resolve symbolic / positional model variable names to positional. -""" -function get_vars(acs, args) - (args == :()) && return args - args_ = [] - - for arg in (MacroTools.isexpr(args, :vect, :tuple) ? args.args : [args]) - arg = recursively_expand_dots(arg) - if arg isa Number - push!(args_, Int(arg)) - else - for i = 1:length(subpart(acs, :specName)) - !isnothing(acs[i, :specName]) && - (acs[i, :specName] == arg) && - (push!(args_, i); break) - end - end - end - - return args_ -end diff --git a/src/solvers.jl b/src/solvers.jl index 27cefcf..2737fe1 100644 --- a/src/solvers.jl +++ b/src/solvers.jl @@ -1,6 +1,8 @@ using Distributions using Random +export ReactionNetworkProblem + function get_sampled_transition(state, i) transition = Dict{Symbol,Any}() foreach(k -> push!(transition, k => state[i, k]), keys(state.transitions)) @@ -107,11 +109,11 @@ function get_init_satisfied(allocs, qs, state) (reqs[tok.index, i] += tok.stoich) end end - @show 2 reqs + for i in eachindex(allocs) allocs[i] = reqs[i] == 0.0 ? Inf : floor(allocs[i] / reqs[i]) end - @show allocs + foreach(i -> qs[i] = min(qs[i], minimum(allocs[:, i])), 1:size(reqs, 2)) foreach(i -> allocs[:, i] .= reqs[:, i] * qs[i], 1:size(reqs, 2)) @@ -133,7 +135,7 @@ function evolve!(state) 1:nparts(state, :T), ) qs .= ceil.(Ref(Int), qs) - @show qs + for i = 1:nparts(state, :T) new_instances = qs[i] + state[i, :transToSpawn] capacity = @@ -145,13 +147,12 @@ function evolve!(state) end reqs = get_reqs_init!(reqs, qs, state) - @show reqs + allocs = get_allocs!(reqs, state.u, state, state[:, :transPriority], state.p[:strategy]) - @show allocs + qs .= get_init_satisfied(allocs, qs, state) - @show qs - println("====") + push!( state.log, ( @@ -260,7 +261,7 @@ function finish!(state) (in(:rate, tok.modality) ? trans_[:transCycleTime] : 1) ) end - + q = if trans_.state >= trans_[:transCycleTime] rand(Distributions.Binomial(Int(trans_.q), trans_[:transProbOfSuccess])) else @@ -297,7 +298,6 @@ function free_blocked_species!(state) end end - ## resolve tspan, tstep function get_tcontrol(tspan, args) @@ -310,8 +310,8 @@ function get_tcontrol(tspan, args) return ((0.0, tspan), dt) end -function ReactiveNetwork( - acs::ReactionNetwork, +function ReactionNetworkProblem( + acs::ReactionNetworkSchema, u0 = Dict(), p = Dict(); name = "reactive_network", @@ -350,7 +350,8 @@ function ReactiveNetwork( ) ∪ [:transLHS, :transRHS, :transToSpawn, :transHash] transitions = Dict{Symbol,Vector}(a => [] for a in transitions_attrs) - network = ReactiveNetwork( + sol = DataFrame("t" => Float64[], (string(name) => Float64[] for name in acs[:, :specName])...) + network = ReactionNetworkProblem( name, acs, attrs, @@ -359,7 +360,6 @@ function ReactiveNetwork( merge( p, Dict( - :tstep => get(keywords, :tstep, 1), :strategy => get(keywords, :alloc_strategy, :weighted), ), ), @@ -370,10 +370,8 @@ function ReactiveNetwork( ongoing_transitions, log, observables, - kwargs, wrap_fun, - Vector{Float64}[], - Float64[], + sol ) save!(network) @@ -381,7 +379,18 @@ function ReactiveNetwork( return network end -function AlgebraicAgents.step!(state::ReactiveNetwork) +function AlgebraicAgents._reinit!(state::ReactionNetworkProblem) + state.u .= isempty(state.sol) ? state.u : Vector(state.sol[1, 2:end]) + state.t = state.tspan[1] + empty!(state.ongoing_transitions) + empty!(state.log) + state.observables = compile_observables(state.acs) + empty!(state.sol) + + state +end + +function AlgebraicAgents._step!(state::ReactionNetworkProblem) free_blocked_species!(state) update_observables(state) sample_transitions!(state) @@ -398,15 +407,11 @@ function AlgebraicAgents.step!(state::ReactiveNetwork) return state.t += state.dt end -function AlgebraicAgents._projected_to(state::ReactiveNetwork) - if state.t >= state.tspan[2] - true - else - state.t - end +function AlgebraicAgents._projected_to(state::ReactionNetworkProblem) + state.t > state.tspan[2] ? true : state.t end -function fetch_params(acs::ReactionNetwork) +function fetch_params(acs::ReactionNetworkSchema) return Dict{Symbol,Any}(( acs[i, :prmName] => acs[i, :prmVal] for i in Iterators.filter(i -> !isnothing(acs[i, :prmVal]), 1:nparts(acs, :P)) diff --git a/src/state.jl b/src/state.jl index d70d3bf..0aba482 100644 --- a/src/state.jl +++ b/src/state.jl @@ -1,4 +1,5 @@ -using AlgebraicAgents +@reexport using AlgebraicAgents +using DataFrames struct UnfoldedReactant index::Int @@ -29,8 +30,8 @@ Base.getindex(state::Transition, key) = state.trans[key] sampled::Any end -@aagent struct ReactiveNetwork - acs::ReactionNetwork +@aagent struct ReactionNetworkProblem + acs::ReactionNetworkSchema attrs::Dict{Symbol,Vector} transition_recipes::Dict{Symbol,Vector} @@ -47,38 +48,36 @@ end log::Vector{Tuple} observables::Dict{Symbol,Observable} - solverargs::Any wrap_fun::Any - history_u::Vector{Vector{Float64}} - history_t::Vector{Float64} + sol::DataFrame end # get value of a numeric expression # evaluate compiled numeric expression in context of (u, p, t) -function context_eval(state::ReactiveNetwork, o) +function context_eval(state::ReactionNetworkProblem, o) o = o isa Function ? Base.invokelatest(o, state) : o return o isa Sampleable ? rand(o) : o end -function Base.getindex(state::ReactiveNetwork, keys...) +function Base.getindex(state::ReactionNetworkProblem, keys...) return context_eval( state, (contains(string(keys[2]), "trans") ? state.transitions : state.attrs)[keys[2]][keys[1]], ) end -function init_u!(state::ReactiveNetwork) +function init_u!(state::ReactionNetworkProblem) return (u = fill(0.0, nparts(state, :S)); foreach(i -> u[i] = state[i, :specInitVal], 1:nparts(state, :S)); state.u = u) end -function save!(state::ReactiveNetwork) - return (push!(state.history_u, copy(state.u)); push!(state.history_t, state.t)) +function save!(state::ReactionNetworkProblem) + return push!(state.sol, (state.t, state.u[:]...)) end -function compile_observables(acs::ReactionNetwork) +function compile_observables(acs::ReactionNetworkSchema) observables = Dict{Symbol,Observable}() species_names = collect(acs[:, :specName]) prm_names = collect(acs[:, :prmName]) @@ -116,16 +115,16 @@ function sample_range(rng, state) return r isa Sampleable ? rand(r) : r end -function resample!(state::ReactiveNetwork, o::Observable) +function resample!(state::ReactionNetworkProblem, o::Observable) o.last = state.t isempty(o.range) && (return o.val = missing) return o.sampled = context_eval(state, sample_range(o.range, state)) end -resample(state::ReactiveNetwork, o::Symbol) = resample!(state, state.observables[o]) +resample(state::ReactionNetworkProblem, o::Symbol) = resample!(state, state.observables[o]) -function update_observables(state::ReactiveNetwork) +function update_observables(state::ReactionNetworkProblem) return foreach( o -> (state.t - o.last) >= o.every && resample!(state, o), values(state.observables), @@ -151,11 +150,11 @@ function prune_r_line(r_line) end end -function find_index(species::Symbol, state::ReactiveNetwork) +function find_index(species::Symbol, state::ReactionNetworkProblem) return findfirst(i -> state[i, :specName] == species, 1:nparts(state, :S)) end -function sample_transitions!(state::ReactiveNetwork) +function sample_transitions!(state::ReactionNetworkProblem) for (_, v) in state.transitions empty!(v) end @@ -194,47 +193,35 @@ function sample_transitions!(state::ReactiveNetwork) end end -## sync -update_u!(state::ReactiveNetwork, u) = (state.u .= u) -update_t!(state::ReactiveNetwork, t) = (state.t = t) -sync_p!(p, state::ReactiveNetwork) = merge!(p, state.p) - -function sync!(state::ReactiveNetwork, u, p) - state.u .= u - for k in keys(state.p) - haskey(p, k) && (state.p[k] = p[k]) - end -end - -function as_state(u, t, state::ReactiveNetwork) +function as_state(u, t, state::ReactionNetworkProblem) return (state = deepcopy(state); state.u .= u; state.t = t; state) end -function Catlab.CategoricalAlgebra.nparts(state::ReactiveNetwork, obj::Symbol) +function Catlab.CategoricalAlgebra.nparts(state::ReactionNetworkProblem, obj::Symbol) return obj == :T ? length(state.transitions[:transLHS]) : nparts(state.acs, obj) end ## query the state -t(state::ReactiveNetwork) = state.t -solverarg(state::ReactiveNetwork, arg) = state.p[arg] -take(state::ReactiveNetwork, pcs::Symbol) = state.observables[pcs].sampled -log(state::ReactiveNetwork, msg) = (println(msg); push!(state.log, (:log, msg))) -state(state::ReactiveNetwork) = state +t(state::ReactionNetworkProblem) = state.t +solverarg(state::ReactionNetworkProblem, arg) = state.p[arg] +take(state::ReactionNetworkProblem, pcs::Symbol) = state.observables[pcs].sampled +log(state::ReactionNetworkProblem, msg) = (println(msg); push!(state.log, (:log, msg))) +state(state::ReactionNetworkProblem) = state -function periodic(state::ReactiveNetwork, period) +function periodic(state::ReactionNetworkProblem, period) return period == 0.0 || ( length(state.history_t) > 1 && (fld(state.t, period) - fld(state.history_t[end-1], period) > 0) ) end -set_params(state::ReactiveNetwork, vals...) = +set_params(state::ReactionNetworkProblem, vals...) = for (p, v) in vals state.p[p] = v end -function add_to_spawn!(state::ReactiveNetwork, hash, n) +function add_to_spawn!(state::ReactionNetworkProblem, hash, n) ix = findfirst(ix -> state.transition_recipes[:transHash][ix] == hash) return !isnothing(ix) && (state.transition_recipes[:transHash][ix] += n) -end +end \ No newline at end of file diff --git a/tutorial/basics.jl b/tutorial/basics.jl index e05f7f7..de13b55 100644 --- a/tutorial/basics.jl +++ b/tutorial/basics.jl @@ -1,35 +1,21 @@ -using ReactiveDynamics, AlgebraicAgents +using ReactiveDynamics -# acs = @ReactionNetwork begin -# 1.0, X ⟺ Y -# end - -acs = @ReactionNetwork begin +# define the network +acs = @ReactionNetworkSchema begin 1.0, X --> Y, name => "transition1" end @prob_init acs X = 10 Y = 20 @prob_params acs -@prob_meta acs tspan = 250 dt = 0.11 - - -# sol = ReactiveDynamics.solve(prob) - -#sol = @solve prob - -prob = @agentize acs - -for i in 1:30 - AlgebraicAgents.step!(prob) -end -prob.history_u - +@prob_meta acs tspan = 25 dt = 0.10 +# convert network into an AlgAgents hierarchy +prob = ReactionNetworkProblem(acs) -using Plots +# simulate +simulate(prob) -@plot sol plot_type = summary -∪ +# access solution +prob.sol -prob = @problematize acs -@solve prob +draw(prob)