Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sorting things out #2003

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 0 additions & 34 deletions core/src/callback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,6 @@ function create_callbacks(
end
end

# Update TabulatedRatingCurve Q(h) relationships
tstops = get_tstops(tabulated_rating_curve.time.time, starttime)
tabulated_rating_curve_cb = PresetTimeCallback(
tstops,
update_tabulated_rating_curve!;
save_positions = (false, false),
)
push!(callbacks, tabulated_rating_curve_cb)

# If saveat is a vector which contains 0.0 this callback will still be called
# at t = 0.0 despite save_start = false
saveat = saveat isa Vector ? filter(x -> x != 0.0, saveat) : saveat
Expand Down Expand Up @@ -834,31 +825,6 @@ function update_allocation!(integrator)::Nothing
end
end

"Load updates from 'TabulatedRatingCurve / time' into the parameters"
function update_tabulated_rating_curve!(integrator)::Nothing
(; node_id, table, time) = integrator.p.tabulated_rating_curve
t = datetime_since(integrator.t, integrator.p.starttime)

# get groups of consecutive node_id for the current timestamp
rows = searchsorted(time.time, t)
timeblock = view(time, rows)

for group in IterTools.groupby(row -> row.node_id, timeblock)
# update the existing LinearInterpolation
id = first(group).node_id
level = [row.level for row in group]
flow_rate = [row.flow_rate for row in group]
i = searchsortedfirst(node_id, NodeID(NodeType.TabulatedRatingCurve, id, 0))
table[i] = LinearInterpolation(
flow_rate,
level;
extrapolate = true,
cache_parameters = true,
)
end
return nothing
end

function update_subgrid_level(model::Model)::Model
update_subgrid_level!(model.integrator)
return model
Expand Down
48 changes: 21 additions & 27 deletions core/src/parameter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,18 +95,18 @@
Base.Int32(id::NodeID) = id.value
Base.convert(::Type{Int32}, id::NodeID) = id.value
Base.broadcastable(id::NodeID) = Ref(id)
Base.:(==)(id_1::NodeID, id_2::NodeID) = id_1.type == id_2.type && id_1.value == id_2.value
Base.show(io::IO, id::NodeID) = print(io, id.type, " #", id.value)
config.snake_case(id::NodeID) = config.snake_case(id.type)
Base.to_index(id::NodeID) = Int(id.value)

Check warning on line 100 in core/src/parameter.jl

View check run for this annotation

Codecov / codecov/patch

core/src/parameter.jl#L100

Added line #L100 was not covered by tests

function Base.isless(id_1::NodeID, id_2::NodeID)::Bool
if id_1.type != id_2.type
error("Cannot compare NodeIDs of different types")
end
return id_1.value < id_2.value
end
# Compare only by value for working with a mix of integers from tables and processed NodeIDs
Base.:(==)(id_1::NodeID, id_2::NodeID) = id_1.value == id_2.value
Base.:(==)(id_1::Integer, id_2::NodeID) = id_1 == id_2.value
Base.:(==)(id_1::NodeID, id_2::Integer) = id_1.value == id_2

Check warning on line 105 in core/src/parameter.jl

View check run for this annotation

Codecov / codecov/patch

core/src/parameter.jl#L105

Added line #L105 was not covered by tests

Base.to_index(id::NodeID) = Int(id.value)
Base.isless(id_1::NodeID, id_2::NodeID)::Bool = id_1.value < id_2.value
Base.isless(id_1::Integer, id_2::NodeID)::Bool = id_1 < id_2.value
Base.isless(id_1::NodeID, id_2::Integer)::Bool = id_1.value < id_2

"LinearInterpolation from a Float64 to a Float64"
const ScalarInterpolation = LinearInterpolation{
Expand Down Expand Up @@ -275,7 +275,7 @@
Base.length(::EdgeMetadata) = 1

"""
The update of an parameter given by a value and a reference to the target
The update of a parameter given by a value and a reference to the target
location of the variable in memory
"""
struct ParameterUpdate{T}
Expand All @@ -289,13 +289,12 @@
end

"""
The parameter update associated with a certain control state
for discrete control
The parameter update associated with a certain control state for discrete control
"""
@kwdef struct ControlStateUpdate
@kwdef struct ControlStateUpdate{T <: AbstractInterpolation}
active::ParameterUpdate{Bool}
scalar_update::Vector{ParameterUpdate{Float64}} = []
itp_update::Vector{ParameterUpdate{ScalarInterpolation}} = []
itp_update::Vector{ParameterUpdate{T}} = ParameterUpdate{ScalarInterpolation}[]
end

"""
Expand Down Expand Up @@ -434,15 +433,10 @@
end

"""
struct TabulatedRatingCurve{C}
struct TabulatedRatingCurve

Rating curve from level to flow rate. The rating curve is a lookup table with linear
interpolation in between. Relation can be updated in time, which is done by moving data from
the `time` field into the `tables`, which is done in the `update_tabulated_rating_curve`
callback.

Type parameter C indicates the content backing the StructVector, which can be a NamedTuple
of Vectors or Arrow Primitives, and is added to avoid type instabilities.
interpolation in between. Relations can be updated in time.

node_id: node ID of the TabulatedRatingCurve node
inflow_edge: incoming flow edge metadata
Expand All @@ -451,18 +445,18 @@
The ID of the source node is always the ID of the TabulatedRatingCurve node
active: whether this node is active and thus contributes flows
max_downstream_level: The downstream level above which the TabulatedRatingCurve flow goes to zero
table: The current Q(h) relationships
time: The time table used for updating the tables
interpolations: All Q(h) relationships for the nodes over time
current_interpolation_index: Per node 1 lookup from t to an index in `interpolations`
control_mapping: dictionary from (node_id, control_state) to Q(h) and/or active state
"""
@kwdef struct TabulatedRatingCurve{C} <: AbstractParameterNode
@kwdef struct TabulatedRatingCurve <: AbstractParameterNode
node_id::Vector{NodeID}
inflow_edge::Vector{EdgeMetadata}
outflow_edge::Vector{EdgeMetadata}
active::Vector{Bool}
max_downstream_level::Vector{Float64} = fill(Inf, length(node_id))
table::Vector{ScalarInterpolation}
time::StructVector{TabulatedRatingCurveTimeV1, C, Int}
interpolations::Vector{ScalarInterpolation}
current_interpolation_index::Vector{IndexLookup}
control_mapping::Dict{Tuple{NodeID, String}, ControlStateUpdate}
end

Expand Down Expand Up @@ -935,14 +929,14 @@
Float64,
}

@kwdef mutable struct Parameters{C1, C2, C3, C4, C5, C6, C7, C8, C9, C10, C11}
@kwdef mutable struct Parameters{C1, C2, C3, C4, C6, C7, C8, C9, C10, C11}
const starttime::DateTime
const graph::ModelGraph
const allocation::Allocation
const basin::Basin{C1, C2, C3, C4}
const linear_resistance::LinearResistance
const manning_resistance::ManningResistance
const tabulated_rating_curve::TabulatedRatingCurve{C5}
const tabulated_rating_curve::TabulatedRatingCurve
const level_boundary::LevelBoundary{C6}
const flow_boundary::FlowBoundary{C7}
const pump::Pump
Expand Down
143 changes: 89 additions & 54 deletions core/src/read.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,20 +145,14 @@
)
end
elseif node_id in time_node_ids
# TODO replace (time, node_id) order by (node_id, time)
# this fits our access pattern better, so we can use views
idx = findall(==(node_id), time_node_id_vec)
time_subset = time[idx]

time_first_idx = searchsortedfirst(time_node_id_vec[idx], node_id)

time_first_idx = searchsortedfirst(time.node_id, node_id)
for parameter_name in parameter_names
# If the parameter is interpolatable, create an interpolation object
if parameter_name in time_interpolatables
val, is_valid = get_scalar_interpolation(
config.starttime,
t_end,
time_subset,
time,
node_id,
parameter_name;
default_value = hasproperty(defaults, parameter_name) ?
Expand All @@ -174,7 +168,7 @@
val = true
else
# If the parameter is not interpolatable, get the instance in the first row
val = getfield(time_subset[time_first_idx], parameter_name)
val = getfield(time[time_first_idx], parameter_name)
end
end
getfield(out, parameter_name)[node_id.idx] = val
Expand Down Expand Up @@ -300,75 +294,98 @@
static_node_ids, time_node_ids, node_ids, valid =
static_and_time_node_ids(db, static, time, NodeType.TabulatedRatingCurve)

if !valid
error(
"Problems encountered when parsing TabulatedRatingcurve static and time node IDs.",
)
end
valid || error(
"Problems encountered when parsing TabulatedRatingcurve static and time node IDs.",
)

interpolations = ScalarInterpolation[]
current_interpolation_index = IndexLookup[]
interpolation_index = 0
control_mapping = Dict{Tuple{NodeID, String}, ControlStateUpdate}()
active = Bool[]
max_downstream_level = Float64[]
errors = false

local is_active, interpolation, max_level

Check warning on line 309 in core/src/read.jl

View check run for this annotation

Codecov / codecov/patch

core/src/read.jl#L309

Added line #L309 was not covered by tests

qh_iterator = IterTools.groupby(row -> (row.node_id, row.time), time)
state = nothing # initial iterator state

for node_id in node_ids
if node_id in static_node_ids
# Loop over all static rating curves (groups) with this node_id.
# If it has a control_state add it to control_mapping.
# The last rating curve forms the initial condition and activity.
source = "static"
# For the static case the interpolation index does not depend on time,
# but it can be changed by DiscreteControl. For simplicity we do create an
# index lookup that doesn't change with time just like the dynamic case.
# DiscreteControl will then change this lookup object.
rows = searchsorted(
NodeID.(NodeType.TabulatedRatingCurve, static.node_id, node_id.idx),
node_id,
)
static_id = view(static, rows)
local is_active, interpolation
# coalesce control_state to nothing to avoid boolean groupby logic on missing
for group in
for qh_group in
IterTools.groupby(row -> coalesce(row.control_state, nothing), static_id)
control_state = first(group).control_state
is_active = coalesce(first(group).active, true)
max_level = coalesce(first(group).max_downstream_level, Inf)
table = StructVector(group)
rowrange =
findlastgroup(node_id, NodeID.(node_id.type, table.node_id, Ref(0)))
if !valid_tabulated_rating_curve(node_id, table, rowrange)
errors = true
end
interpolation = try
qh_interpolation(table, rowrange)
catch
LinearInterpolation(Float64[], Float64[])
end
interpolation_index += 1
first_row = first(qh_group)
control_state = first_row.control_state
is_active = coalesce(first_row.active, true)
max_level = coalesce(first_row.max_downstream_level, Inf)
qh_table = StructVector(qh_group)
interpolation =
qh_interpolation(node_id, qh_table.level, qh_table.flow_rate)
if !ismissing(control_state)
control_mapping[(
NodeID(NodeType.TabulatedRatingCurve, node_id, node_id.idx),
control_state,
)] = ControlStateUpdate(
# let control swap out the static lookup object
index_lookup = static_lookup(interpolation_index)
control_mapping[(node_id, control_state)] = ControlStateUpdate(
ParameterUpdate(:active, is_active),
ParameterUpdate{Float64}[],
[ParameterUpdate(:table, interpolation)],
[ParameterUpdate(:current_interpolation_index, index_lookup)],
)
end
push!(interpolations, interpolation)
end
push!(interpolations, interpolation)
push_lookup!(current_interpolation_index, interpolation_index)
push!(active, is_active)
push!(max_downstream_level, max_level)
elseif node_id in time_node_ids
source = "time"
# get the timestamp that applies to the model starttime
idx_starttime = searchsortedlast(time.time, config.starttime)
pre_table = view(time, 1:idx_starttime)
rowrange =
findlastgroup(node_id, NodeID.(node_id.type, pre_table.node_id, Ref(0)))

if !valid_tabulated_rating_curve(node_id, pre_table, rowrange)
errors = true
lookup_time = Float64[]
lookup_index = Int[]
while true
val_state = iterate(qh_iterator, state)
if val_state === nothing
# end of table
break
end
qh_group, new_state = val_state

first_row = first(qh_group)
group_node_id = first_row.node_id
# max_level just document that it doesn't work and use the first or last
max_level = coalesce(first_row.max_downstream_level, Inf)
t = seconds_since(first_row.time, config.starttime)

qh_table = StructVector(qh_group)
if group_node_id == node_id
# continue iterator
state = new_state

interpolation =
qh_interpolation(node_id, qh_table.level, qh_table.flow_rate)

interpolation_index += 1
push!(interpolations, interpolation)
push!(lookup_index, interpolation_index)
push!(lookup_time, t)
else
# end of group, new timeseries for different node has started,
# don't accept the new state
break

Check warning on line 385 in core/src/read.jl

View check run for this annotation

Codecov / codecov/patch

core/src/read.jl#L385

Added line #L385 was not covered by tests
end
end
interpolation = qh_interpolation(pre_table, rowrange)
max_level = coalesce(pre_table.max_downstream_level[rowrange][begin], Inf)
push!(interpolations, interpolation)
push_lookup!(current_interpolation_index, lookup_index, lookup_time)
push!(active, true)
push!(max_downstream_level, max_level)
else
Expand All @@ -377,17 +394,16 @@
end
end

if errors
error("Errors occurred when parsing TabulatedRatingCurve data.")
end
errors && error("Errors occurred when parsing TabulatedRatingCurve data.")

return TabulatedRatingCurve(;
node_id = node_ids,
inflow_edge = inflow_edge.(Ref(graph), node_ids),
outflow_edge = outflow_edge.(Ref(graph), node_ids),
active,
max_downstream_level,
table = interpolations,
time,
interpolations,
current_interpolation_index,
control_mapping,
)
end
Expand Down Expand Up @@ -1238,6 +1254,7 @@
)
end

"Create and push a ConstantInterpolation to the current_interpolation_index."
function push_lookup!(
current_interpolation_index::Vector{IndexLookup},
lookup_index::Vector{Int},
Expand All @@ -1252,6 +1269,24 @@
push!(current_interpolation_index, index_lookup)
end

"Create and push a static ConstantInterpolation to the current_interpolation_index."
function push_lookup!(current_interpolation_index::Vector{IndexLookup}, lookup_index::Int)
index_lookup = static_lookup(lookup_index)
push!(current_interpolation_index, index_lookup)
end

"Create an interpolation object that always returns `lookup_index`."
function static_lookup(lookup_index::Int)::IndexLookup
# TODO if https://github.com/SciML/DataInterpolations.jl/issues/373 is fixed,
# make these size 1 vectors, and remove `unique` from `valid_tabulated_curve_level`
return ConstantInterpolation(
[lookup_index, lookup_index],
[0.0, 0.0];
extrapolate = true,
cache_parameters = true,
)
end

function Subgrid(db::DB, config::Config, basin::Basin)::Subgrid
time = load_structvector(db, config, BasinSubgridTimeV1)
static = load_structvector(db, config, BasinSubgridV1)
Expand Down
Loading
Loading