@@ -8,6 +8,7 @@ struct ConstraintNull <: AbstractConstraint end
88
99struct Variable{S,O} <: AbstractVariable
1010 size:: S
11+ length:: O
1112 offset:: O
1213end
1314Base. show (io:: IO , v:: Variable ) = print (
@@ -245,13 +246,33 @@ function ExaModel(c::C; prod = nothing) where {C<:ExaCore}
245246 )
246247end
247248
248- @inline Base. getindex (v:: V , i) where {V<: Variable } =
249+ @inline function Base. getindex (v:: V , i) where {V<: Variable }
250+ _bound_check (v. size, i)
249251 Var (i + (v. offset - _start (v. size[1 ]) + 1 ))
250- @inline Base. getindex (v:: V , i, j) where {V<: Variable } = Var (
251- i +
252- j * _length (v. size[1 ]) +
253- (v. offset - _start (v. size[1 ]) + 1 - _start (v. size[2 ]) * _length (v. size[1 ])),
254- )
252+ end
253+ @inline function Base. getindex (v:: V , is... ) where {V<: Variable }
254+ @assert (length (is) == length (v. size), " Variable index dimension error" )
255+ _bound_check (v. size, is)
256+ Var (v. offset + idxx (is .- (_start .(v. size) .- 1 ), _length .(v. size)))
257+ end
258+
259+ function _bound_check (sizes, i:: I ) where I <: Integer
260+ __bound_check (sizes[1 ], i)
261+ end
262+ function _bound_check (sizes, is:: NTuple{N,I} ) where {I <: Integer , N}
263+ __bound_check (sizes[1 ], is[1 ])
264+ _bound_check (sizes[2 : end ], is[2 : end ])
265+ end
266+ _bound_check (sizes, is) = nothing
267+ _bound_check (sizes, is:: Tuple{} ) = nothing
268+
269+ function __bound_check (a:: I ,b:: I ) where I <: Integer
270+ @assert (1 <= b <= a, " Variable index bound error" )
271+ end
272+ function __bound_check (a:: UnitRange{Int} ,b:: I ) where I <: Integer
273+ @assert (b in a, " Variable index bound error" )
274+ end
275+
255276
256277function append! (backend, a, b:: Base.Generator , lb)
257278 b = _adapt_gen (b)
@@ -332,12 +353,13 @@ function variable(
332353
333354
334355 o = c. nvar
335- c. nvar += total (ns)
356+ len = total (ns)
357+ c. nvar += len
336358 c. x0 = append! (c. backend, c. x0, start, total (ns))
337359 c. lvar = append! (c. backend, c. lvar, lvar, total (ns))
338360 c. uvar = append! (c. backend, c. uvar, uvar, total (ns))
339361
340- return Variable (ns, o)
362+ return Variable (ns, len, o)
341363
342364end
343365
0 commit comments