Skip to content

Commit 6db8235

Browse files
Merge pull request #41 from tensor4all/recycle-pivots
Implement pivots recycling and bug fixes
2 parents 104686a + 34f57c7 commit 6db8235

File tree

4 files changed

+235
-39
lines changed

4 files changed

+235
-39
lines changed

src/crossinterpolate.jl

Lines changed: 62 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ mutable struct TCI2PatchCreator{T} <: AbstractPatchCreator{T,TensorTrainState{T}
122122
checkbatchevaluatable::Bool
123123
loginterval::Int
124124
initialpivots::Vector{MultiIndex} # Make it to Vector{MMultiIndex}?
125+
recyclepivots::Bool
125126
end
126127

127128
function Base.show(io::IO, obj::TCI2PatchCreator{T}) where {T}
@@ -141,6 +142,7 @@ function TCI2PatchCreator{T}(obj::TCI2PatchCreator{T})::TCI2PatchCreator{T} wher
141142
obj.checkbatchevaluatable,
142143
obj.loginterval,
143144
obj.initialpivots,
145+
obj.recyclepivots,
144146
)
145147
end
146148

@@ -157,7 +159,8 @@ function TCI2PatchCreator(
157159
ninitialpivot=5,
158160
checkbatchevaluatable=false,
159161
loginterval=10,
160-
initialpivots=Vector{MultiIndex}[],
162+
initialpivots=MultiIndex[],
163+
recyclepivots=false,
161164
)::TCI2PatchCreator{T} where {T}
162165
#t1 = time_ns()
163166
if projector === nothing
@@ -183,6 +186,7 @@ function TCI2PatchCreator(
183186
checkbatchevaluatable,
184187
loginterval,
185188
initialpivots,
189+
recyclepivots,
186190
)
187191
end
188192

@@ -206,6 +210,7 @@ function _crossinterpolate2!(
206210
verbosity::Int=0,
207211
checkbatchevaluatable=false,
208212
loginterval=10,
213+
recyclepivots=false,
209214
) where {T}
210215
ncheckhistory = 3
211216
ranks, errors = TCI.optimize!(
@@ -231,13 +236,45 @@ function _crossinterpolate2!(
231236
ncheckhistory_ = min(ncheckhistory, length(errors))
232237
maxbonddim_hist = maximum(ranks[(end - ncheckhistory_ + 1):end])
233238

234-
return PatchCreatorResult{T,TensorTrain{T,3}}(
235-
TensorTrain(tci), TCI.maxbonderror(tci) < tolerance && maxbonddim_hist < maxbonddim
236-
)
239+
if recyclepivots
240+
return PatchCreatorResult{T,TensorTrain{T,3}}(
241+
TensorTrain(tci),
242+
TCI.maxbonderror(tci) < tolerance && maxbonddim_hist < maxbonddim,
243+
_globalpivots(tci),
244+
)
245+
246+
else
247+
return PatchCreatorResult{T,TensorTrain{T,3}}(
248+
TensorTrain(tci),
249+
TCI.maxbonderror(tci) < tolerance && maxbonddim_hist < maxbonddim,
250+
)
251+
end
252+
end
253+
254+
# Generating global pivots from local ones
255+
function _globalpivots(
256+
tci::TCI.TensorCI2{T}; onlydiagonal=true
257+
)::Vector{MultiIndex} where {T}
258+
Isets = tci.Iset
259+
Jsets = tci.Jset
260+
L = length(Isets)
261+
p = Set{MultiIndex}()
262+
# Pivot matrices
263+
for bondindex in 1:(L - 1)
264+
if onlydiagonal
265+
for (x, y) in zip(Isets[bondindex + 1], Jsets[bondindex])
266+
push!(p, vcat(x, y))
267+
end
268+
else
269+
for x in Isets[bondindex + 1], y in Jsets[bondindex]
270+
push!(p, vcat(x, y))
271+
end
272+
end
273+
end
274+
return collect(p)
237275
end
238276

239277
function createpatch(obj::TCI2PatchCreator{T}) where {T}
240-
proj = obj.projector
241278
fsubset = _FuncAdapterTCI2Subset(obj.f)
242279

243280
tci = if isapproxttavailable(obj.f)
@@ -253,21 +290,25 @@ function createpatch(obj::TCI2PatchCreator{T}) where {T}
253290
end
254291
tci
255292
else
256-
# Random initial pivots
257293
initialpivots = MultiIndex[]
258-
let
259-
mask = [!isprojectedat(proj, n) for n in 1:length(proj)]
260-
for idx in obj.initialpivots
261-
idx_ = [[i] for i in idx]
262-
if idx_ <= proj
263-
push!(initialpivots, idx[mask])
264-
end
294+
if obj.recyclepivots
295+
# First patching iteration: random pivots
296+
if length(fsubset.localdims) == length(obj.localdims)
297+
initialpivots = union(
298+
obj.initialpivots,
299+
findinitialpivots(fsubset, fsubset.localdims, obj.ninitialpivot),
300+
)
301+
# Next iterations: recycle previously generated pivots
302+
else
303+
initialpivots = copy(obj.initialpivots)
265304
end
305+
else
306+
initialpivots = union(
307+
obj.initialpivots,
308+
findinitialpivots(fsubset, fsubset.localdims, obj.ninitialpivot),
309+
)
266310
end
267-
append!(
268-
initialpivots,
269-
findinitialpivots(fsubset, fsubset.localdims, obj.ninitialpivot),
270-
)
311+
271312
if all(fsubset.(initialpivots) .== 0)
272313
return PatchCreatorResult{T,TensorTrainState{T}}(nothing, true)
273314
end
@@ -282,6 +323,7 @@ function createpatch(obj::TCI2PatchCreator{T}) where {T}
282323
verbosity=obj.verbosity,
283324
checkbatchevaluatable=obj.checkbatchevaluatable,
284325
loginterval=obj.loginterval,
326+
recyclepivots=obj.recyclepivots,
285327
)
286328
end
287329

@@ -301,9 +343,9 @@ function adaptiveinterpolate(
301343
verbosity=0,
302344
maxbonddim=typemax(Int),
303345
tolerance=1e-8,
304-
initialpivots=Vector{MultiIndex}[], # Make it to Vector{MMultiIndex}?
346+
initialpivots=MultiIndex[], # Make it to Vector{MMultiIndex}?
347+
recyclepivots=false,
305348
)::ProjTTContainer{T} where {T}
306-
t1 = time_ns()
307349
creator = TCI2PatchCreator(
308350
T,
309351
f,
@@ -313,6 +355,7 @@ function adaptiveinterpolate(
313355
verbosity,
314356
ntry=10,
315357
initialpivots=initialpivots,
358+
recyclepivots=recyclepivots,
316359
)
317360
tmp = adaptiveinterpolate(creator, pordering; verbosity)
318361
return reshape(tmp, f.sitedims)

src/patching.jl

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,19 @@ abstract type AbstractPatchCreator{T,M} end
3838
mutable struct PatchCreatorResult{T,M}
3939
data::Union{M,Nothing}
4040
isconverged::Bool
41+
resultpivots::Vector{MultiIndex}
42+
43+
function PatchCreatorResult{T,M}(
44+
data::Union{M,Nothing}, isconverged::Bool, resultpivots::Vector{MultiIndex}
45+
)::PatchCreatorResult{T,M} where {T,M}
46+
return new{T,M}(data, isconverged, resultpivots)
47+
end
48+
49+
function PatchCreatorResult{T,M}(
50+
data::Union{M,Nothing}, isconverged::Bool
51+
)::PatchCreatorResult{T,M} where {T,M}
52+
return new{T,M}(data, isconverged, MultiIndex[])
53+
end
4154
end
4255

4356
function _reconst_prefix(projector::Projector, pordering::PatchOrdering)
@@ -63,10 +76,19 @@ function __taskfunc(creator::AbstractPatchCreator{T,M}, pordering; verbosity=0)
6376
for ic in 1:creator.localdims[pordering.ordering[length(prefix) + 1]]
6477
prefix_ = vcat(prefix, ic)
6578
projector_ = makeproj(pordering, prefix_, creator.localdims)
66-
#if verbosity > 0
67-
##println("Creating a task for $(prefix_) ...")
68-
#end
69-
push!(newtasks, project(creator, projector_))
79+
80+
# Pivots are shorter, pordering index is in a different position
81+
active_dims_ = findall(x -> x == [0], creator.projector.data)
82+
pos_ = findfirst(x -> x == pordering.ordering[length(prefix) + 1], active_dims_)
83+
pivots_ = [
84+
copy(piv) for piv in filter(piv -> piv[pos_] == ic, patch.resultpivots)
85+
]
86+
87+
if !isempty(pivots_)
88+
deleteat!.(pivots_, pos_)
89+
end
90+
91+
push!(newtasks, project(creator, projector_; pivots=pivots_))
7092
end
7193
return nothing, newtasks
7294
end
@@ -77,14 +99,19 @@ function _zerott(T, prefix, po::PatchOrdering, localdims::Vector{Int})
7799
return TensorTrain([zeros(T, 1, d, 1) for d in localdims_])
78100
end
79101

80-
function project(obj::AbstractPatchCreator{T,M}, projector::Projector) where {T,M}
102+
function project(
103+
obj::AbstractPatchCreator{T,M},
104+
projector::Projector;
105+
pivots::Vector{MultiIndex}=MultiIndex[],
106+
) where {T,M}
81107
projector <= obj.projector || error(
82108
"Projector $projector is not a subset of the original projector $(obj.f.projector)",
83109
)
84110

85111
obj_copy = TCI2PatchCreator{T}(obj) # shallow copy
86112
obj_copy.projector = deepcopy(projector)
87113
obj_copy.f = project(obj_copy.f, projector)
114+
obj_copy.initialpivots = deepcopy(pivots)
88115
return obj_copy
89116
end
90117

src/projectable_evaluator.jl

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -248,28 +248,31 @@ function batchevaluateprj(
248248
# Some of indices might be projected
249249
NL = length(leftmmultiidxset[1])
250250
NR = length(rightmmultiidxset[1])
251+
L = length(obj)
251252

252-
NL + NR + M == length(obj) ||
253-
error("Length mismatch NL: $NL, NR: $NR, M: $M, L: $(length(obj))")
253+
NL + NR + M == L || error("Length mismatch NL: $NL, NR: $NR, M: $M, L: $(L)")
254254

255-
L = length(obj)
255+
returnshape = projectedshape(obj.projector, NL + 1, L - NR)
256256
result::Array{T,M + 2} = zeros(
257-
T,
258-
length(leftmmultiidxset),
259-
prod.(obj.sitedims[(1 + NL):(L - NR)])...,
260-
length(rightmmultiidxset),
257+
T, length(leftmmultiidxset), returnshape..., length(rightmmultiidxset)
261258
)
262-
result[lmask, .., rmask] .= begin
259+
260+
projmask = map(
261+
p -> p == 0 ? Colon() : p,
262+
Iterators.flatten(obj.projector[n] for n in (1 + NL):(L - NR)),
263+
)
264+
slice = map(
265+
p -> p == 0 ? Colon() : 1,
266+
Iterators.flatten(obj.projector[n] for n in (1 + NL):(L - NR)),
267+
)
268+
269+
result[lmask, slice..., rmask] .= begin
263270
result_lrmask_multii = reshape(
264271
result_lrmask,
265272
size(result_lrmask)[1],
266273
collect(Iterators.flatten(obj.sitedims[(1 + NL):(L - NR)]))...,
267274
size(result_lrmask)[end],
268-
)
269-
projmask = map(
270-
p -> p == 0 ? Colon() : p,
271-
Iterators.flatten(obj.projector[n] for n in (1 + NL):(length(obj) - NR)),
272-
)
275+
) # Gianluca - this step might be not needed. I leave it for safety
273276
result_lrmask_multii[:, projmask..., :]
274277
end
275278
return result

0 commit comments

Comments
 (0)