Skip to content

Commit 610a1af

Browse files
authored
Variable bound check (#72)
* bound check implemented * bound check reimplemented
1 parent dc1ce41 commit 610a1af

File tree

1 file changed

+30
-8
lines changed

1 file changed

+30
-8
lines changed

src/nlp.jl

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ struct ConstraintNull <: AbstractConstraint end
88

99
struct Variable{S,O} <: AbstractVariable
1010
size::S
11+
length::O
1112
offset::O
1213
end
1314
Base.show(io::IO, v::Variable) = print(
@@ -245,13 +246,33 @@ function ExaModel(c::C; prod = nothing) where {C<:ExaCore}
245246
)
246247
end
247248

248-
@inline Base.getindex(v::V, i) where {V<:Variable} =
249+
@inline function Base.getindex(v::V, i) where {V<:Variable}
250+
_bound_check(v.size, i)
249251
Var(i + (v.offset - _start(v.size[1]) + 1))
250-
@inline Base.getindex(v::V, i, j) where {V<:Variable} = Var(
251-
i +
252-
j * _length(v.size[1]) +
253-
(v.offset - _start(v.size[1]) + 1 - _start(v.size[2]) * _length(v.size[1])),
254-
)
252+
end
253+
@inline function Base.getindex(v::V, is...) where {V<:Variable}
254+
@assert(length(is) == length(v.size), "Variable index dimension error")
255+
_bound_check(v.size, is)
256+
Var(v.offset + idxx(is .- (_start.(v.size) .- 1), _length.(v.size)))
257+
end
258+
259+
function _bound_check(sizes, i::I) where I <: Integer
260+
__bound_check(sizes[1], i)
261+
end
262+
function _bound_check(sizes, is::NTuple{N,I}) where {I <: Integer, N}
263+
__bound_check(sizes[1], is[1])
264+
_bound_check(sizes[2:end], is[2:end])
265+
end
266+
_bound_check(sizes, is) = nothing
267+
_bound_check(sizes, is::Tuple{}) = nothing
268+
269+
function __bound_check(a::I,b::I) where I <: Integer
270+
@assert(1<= b <= a, "Variable index bound error")
271+
end
272+
function __bound_check(a::UnitRange{Int},b::I) where I <: Integer
273+
@assert(b in a, "Variable index bound error")
274+
end
275+
255276

256277
function append!(backend, a, b::Base.Generator, lb)
257278
b = _adapt_gen(b)
@@ -332,12 +353,13 @@ function variable(
332353

333354

334355
o = c.nvar
335-
c.nvar += total(ns)
356+
len = total(ns)
357+
c.nvar += len
336358
c.x0 = append!(c.backend, c.x0, start, total(ns))
337359
c.lvar = append!(c.backend, c.lvar, lvar, total(ns))
338360
c.uvar = append!(c.backend, c.uvar, uvar, total(ns))
339361

340-
return Variable(ns, o)
362+
return Variable(ns, len, o)
341363

342364
end
343365

0 commit comments

Comments
 (0)