Skip to content

Commit 4d0ac40

Browse files
feat: add discrete saving to callback structs
1 parent a6feb53 commit 4d0ac40

File tree

2 files changed

+92
-33
lines changed

2 files changed

+92
-33
lines changed

src/callbacks.jl

Lines changed: 80 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,13 @@ Contains a single callback whose `condition` is a continuous function. The callb
110110
`affect!` satisfies the constraints (or else errors). It is not recommended that `NoInit()` is
111111
used as that will lead to an unstable step following initialization. This warning can be
112112
ignored for non-DAE ODEs.
113+
114+
# Extended help
115+
116+
- `discrete_save_idxs`: An iterable of timeseries indexes to save after the callback triggers. MTK-only
117+
API
113118
"""
114-
struct ContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R} <:
119+
struct ContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R, DSI} <:
115120
AbstractContinuousCallback
116121
condition::F1
117122
affect!::F2
@@ -127,20 +132,20 @@ struct ContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R} <:
127132
reltol::T2
128133
repeat_nudge::T3
129134
initializealg::T4
135+
discrete_save_idxs::DSI
130136
function ContinuousCallback(condition::F1, affect!::F2, affect_neg!::F3,
131137
initialize::F4, finalize::F5, idxs::I, rootfind,
132138
interp_points, save_positions, dtrelax::R, abstol::T,
133-
reltol::T2,
134-
repeat_nudge::T3,
135-
initializealg::T4 = nothing) where {F1, F2, F3, F4, F5, T, T2, T3, T4, I, R
139+
reltol::T2, repeat_nudge::T3, initializealg::T4 = nothing,
140+
discrete_save_idxs::DSI = ()) where {F1, F2, F3, F4, F5, T, T2, T3, T4, I, R, DSI
136141
}
137142
_condition = prepare_function(condition)
138-
new{typeof(_condition), F2, F3, F4, F5, T, T2, T3, T4, I, R}(_condition,
143+
new{typeof(_condition), F2, F3, F4, F5, T, T2, T3, T4, I, R, DSI}(_condition,
139144
affect!, affect_neg!,
140145
initialize, finalize, idxs, rootfind,
141146
interp_points,
142147
BitArray(collect(save_positions)),
143-
dtrelax, abstol, reltol, repeat_nudge, initializealg)
148+
dtrelax, abstol, reltol, repeat_nudge, initializealg, discrete_save_idxs)
144149
end
145150
end
146151

@@ -154,12 +159,13 @@ function ContinuousCallback(condition, affect!, affect_neg!;
154159
dtrelax = 1,
155160
abstol = 10eps(), reltol = 0,
156161
repeat_nudge = 1 // 100,
157-
initializealg = nothing)
162+
initializealg = nothing,
163+
discrete_save_idxs = ())
158164
ContinuousCallback(condition, affect!, affect_neg!, initialize, finalize,
159165
idxs,
160166
rootfind, interp_points,
161167
save_positions,
162-
dtrelax, abstol, reltol, repeat_nudge, initializealg)
168+
dtrelax, abstol, reltol, repeat_nudge, initializealg, discrete_save_idxs)
163169
end
164170

165171
function ContinuousCallback(condition, affect!;
@@ -172,11 +178,11 @@ function ContinuousCallback(condition, affect!;
172178
interp_points = 10,
173179
dtrelax = 1,
174180
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100,
175-
initializealg = nothing)
181+
initializealg = nothing, discrete_save_idxs = ())
176182
ContinuousCallback(condition, affect!, affect_neg!, initialize, finalize, idxs,
177183
rootfind, interp_points,
178184
collect(save_positions),
179-
dtrelax, abstol, reltol, repeat_nudge, initializealg)
185+
dtrelax, abstol, reltol, repeat_nudge, initializealg, discrete_save_idxs)
180186
end
181187

182188
"""
@@ -219,8 +225,13 @@ multiple events.
219225
- `len`: Number of callbacks chained. This is compulsory to be specified.
220226
221227
Rest of the arguments have the same meaning as in [`ContinuousCallback`](@ref).
228+
229+
# Extended help
230+
231+
- `discrete_save_idxs`: An iterable of `len` elements, where the `i`th element is an iterable of timeseries
232+
indexes to save when the `i`th event triggers. MTK-only API.
222233
"""
223-
struct VectorContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R} <:
234+
struct VectorContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R, DSI} <:
224235
AbstractContinuousCallback
225236
condition::F1
226237
affect!::F2
@@ -237,21 +248,24 @@ struct VectorContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R} <:
237248
reltol::T2
238249
repeat_nudge::T3
239250
initializealg::T4
251+
discrete_save_idxs::DSI
240252
function VectorContinuousCallback(
241253
condition::F1, affect!::F2, affect_neg!::F3, len::Int,
242254
initialize::F4, finalize::F5, idxs::I, rootfind,
243255
interp_points, save_positions, dtrelax::R,
244-
abstol::T, reltol::T2,
245-
repeat_nudge::T3,
246-
initializealg::T4 = nothing) where {F1, F2, F3, F4, F5, T, T2,
247-
T3, T4, I, R}
256+
abstol::T, reltol::T2, repeat_nudge::T3,
257+
initializealg::T4 = nothing,
258+
discrete_save_idxs::DSI = ()) where {F1, F2, F3, F4, F5, T, T2,
259+
T3, T4, I, R, DSI}
248260
_condition = prepare_function(condition)
249-
new{typeof(_condition), F2, F3, F4, F5, T, T2, T3, T4, I, R}(_condition,
261+
new{typeof(_condition), F2, F3, F4, F5, T, T2, T3, T4, I, R, DSI}(
262+
_condition,
250263
affect!, affect_neg!, len,
251264
initialize, finalize, idxs, rootfind,
252265
interp_points,
253266
BitArray(collect(save_positions)),
254-
dtrelax, abstol, reltol, repeat_nudge, initializealg)
267+
dtrelax, abstol, reltol, repeat_nudge, initializealg,
268+
discrete_save_idxs)
255269
end
256270
end
257271

@@ -264,13 +278,13 @@ function VectorContinuousCallback(condition, affect!, affect_neg!, len;
264278
interp_points = 10,
265279
dtrelax = 1,
266280
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100,
267-
initializealg = nothing)
281+
initializealg = nothing, discrete_save_idxs = ())
268282
VectorContinuousCallback(condition, affect!, affect_neg!, len,
269283
initialize, finalize,
270284
idxs,
271285
rootfind, interp_points,
272286
save_positions, dtrelax,
273-
abstol, reltol, repeat_nudge, initializealg)
287+
abstol, reltol, repeat_nudge, initializealg, discrete_save_idxs)
274288
end
275289

276290
function VectorContinuousCallback(condition, affect!, len;
@@ -283,12 +297,12 @@ function VectorContinuousCallback(condition, affect!, len;
283297
interp_points = 10,
284298
dtrelax = 1,
285299
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100,
286-
initializealg = nothing)
300+
initializealg = nothing, discrete_save_idxs = ())
287301
VectorContinuousCallback(condition, affect!, affect_neg!, len, initialize, finalize,
288302
idxs,
289303
rootfind, interp_points,
290304
collect(save_positions),
291-
dtrelax, abstol, reltol, repeat_nudge, initializealg)
305+
dtrelax, abstol, reltol, repeat_nudge, initializealg, discrete_save_idxs)
292306
end
293307

294308
"""
@@ -339,31 +353,39 @@ DiscreteCallback(condition, affect!;
339353
`affect!` satisfies the constraints (or else errors). It is not recommended that `NoInit()` is
340354
used as that will lead to an unstable step following initialization. This warning can be
341355
ignored for non-DAE ODEs.
356+
357+
# Extended help
358+
359+
- `discrete_save_idxs`: An iterable of timeseries indexes to save after the callback triggers. MTK-only
360+
API
342361
"""
343-
struct DiscreteCallback{F1, F2, F3, F4, F5} <: AbstractDiscreteCallback
362+
struct DiscreteCallback{F1, F2, F3, F4, F5, DSI} <: AbstractDiscreteCallback
344363
condition::F1
345364
affect!::F2
346365
initialize::F3
347366
finalize::F4
348367
save_positions::BitArray{1}
349368
initializealg::F5
369+
discrete_save_idxs::DSI
350370
function DiscreteCallback(condition::F1, affect!::F2,
351371
initialize::F3, finalize::F4,
352372
save_positions,
353-
initializealg::F5 = nothing) where {F1, F2, F3, F4, F5}
373+
initializealg::F5 = nothing,
374+
discrete_save_idxs::DSI = ()) where {F1, F2, F3, F4, F5, DSI}
354375
_condition = prepare_function(condition)
355-
new{typeof(_condition), F2, F3, F4, F5}(_condition,
376+
new{typeof(_condition), F2, F3, F4, F5, DSI}(_condition,
356377
affect!, initialize, finalize,
357378
BitArray(collect(save_positions)),
358-
initializealg)
379+
initializealg, discrete_save_idxs)
359380
end
360381
end
361382
function DiscreteCallback(condition, affect!;
362383
initialize = INITIALIZE_DEFAULT, finalize = FINALIZE_DEFAULT,
363384
save_positions = (true, true),
364-
initializealg = nothing)
385+
initializealg = nothing, discrete_save_idxs = ())
365386
DiscreteCallback(
366-
condition, affect!, initialize, finalize, save_positions, initializealg)
387+
condition, affect!, initialize, finalize, save_positions, initializealg,
388+
discrete_save_idxs)
367389
end
368390

369391
"""
@@ -420,3 +442,34 @@ end
420442
split_callbacks((cs..., d.continuous_callbacks...), (ds..., d.discrete_callbacks...),
421443
args...)
422444
end
445+
446+
function save_discretes!(integrator::DEIntegrator, cb::Union{ContinuousCallback, DiscreteCallback}; skip_duplicates = false)
447+
for idx in cb.discrete_save_idxs
448+
save_discretes!(integrator, idx; skip_duplicates)
449+
end
450+
end
451+
452+
function save_discretes!(integrator::DEIntegrator, cb::VectorContinuousCallback)
453+
isempty(cb.discrete_save_idxs) && return
454+
for idx in eachindex(cb.discrete_save_idxs)
455+
save_discretes!(integrator, cb, idx; skip_duplicates = true)
456+
end
457+
end
458+
459+
function save_discretes!(integrator::DEIntegrator, cb::VectorContinuousCallback, i; skip_duplicates = false)
460+
isempty(cb.discrete_save_idxs) && return
461+
for idx in cb.discrete_save_idxs[i]
462+
save_discretes!(integrator, idx; skip_duplicates)
463+
end
464+
end
465+
466+
function _save_all_discretes!(integrator::DEIntegrator, cb::DECallback, cbs::DECallback...)
467+
save_discretes!(integrator, cb; skip_duplicates = true)
468+
_save_all_discretes!(integrator, cbs...)
469+
end
470+
471+
_save_all_discretes!(::DEIntegrator) = nothing
472+
473+
function save_discretes!(integrator::DEIntegrator, cb::CallbackSet; kw...)
474+
_save_all_discretes!(integrator, cb.continuous_callbacks..., cb.discrete_callbacks...)
475+
end

src/solutions/ode_solutions.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -422,28 +422,34 @@ end
422422
Save the parameter timeseries with index `timeseries_idx`. Calls `get_saveable_values` to
423423
get the values to save. If it returns `nothing`, then the save does not happen.
424424
"""
425-
function save_discretes!(integ::DEIntegrator, timeseries_idx)
425+
function save_discretes!(integ::DEIntegrator, timeseries_idx; skip_duplicate = false)
426426
inner_sol = get_sol(integ)
427427
vals = get_saveable_values(inner_sol, parameter_values(integ), timeseries_idx)
428428
vals === nothing && return
429-
save_discretes!(integ.sol, current_time(integ), vals, timeseries_idx)
429+
save_discretes!(integ.sol, current_time(integ), vals, timeseries_idx; skip_duplicate)
430430
end
431431

432432
save_discretes!(args...) = nothing
433433

434434
# public API, used by MTK
435-
function save_discretes!(sol::AbstractODESolution, t, vals, timeseries_idx)
435+
function save_discretes!(sol::AbstractODESolution, t, vals, timeseries_idx; skip_duplicate = false)
436436
RecursiveArrayTools.has_discretes(sol) || return
437437
disc = RecursiveArrayTools.get_discretes(sol)
438-
_save_discretes_internal!(disc[timeseries_idx], t, vals)
438+
_save_discretes_internal!(disc[timeseries_idx], t, vals; skip_duplicate)
439439
end
440440

441-
function _save_discretes_internal!(A::AbstractDiffEqArray, t, vals)
441+
function _save_discretes_internal!(A::AbstractDiffEqArray, t, vals; skip_duplicate = false)
442+
if skip_duplicate && isequal(t, A.t[end])
443+
return
444+
end
442445
push!(A.t, t)
443446
push!(A.u, vals)
444447
end
445448

446-
function _save_discretes_internal!(A::PeriodicDiffEqArray, t, vals)
449+
function _save_discretes_internal!(A::PeriodicDiffEqArray, t, vals; skip_duplicate = false)
450+
if skip_duplicate && !isempty(A.u) && isequal(A.t[length(A.u)], t)
451+
return
452+
end
447453
idx = length(A.u) + 1
448454
if A.t[idx] t
449455
error("Tried to save periodic discrete value with timeseries $(A.t) at time $t")

0 commit comments

Comments
 (0)