@@ -122,6 +122,7 @@ mutable struct TCI2PatchCreator{T} <: AbstractPatchCreator{T,TensorTrainState{T}
122
122
checkbatchevaluatable:: Bool
123
123
loginterval:: Int
124
124
initialpivots:: Vector{MultiIndex} # Make it to Vector{MMultiIndex}?
125
+ recyclepivots:: Bool
125
126
end
126
127
127
128
function Base. show (io:: IO , obj:: TCI2PatchCreator{T} ) where {T}
@@ -141,6 +142,7 @@ function TCI2PatchCreator{T}(obj::TCI2PatchCreator{T})::TCI2PatchCreator{T} wher
141
142
obj. checkbatchevaluatable,
142
143
obj. loginterval,
143
144
obj. initialpivots,
145
+ obj. recyclepivots,
144
146
)
145
147
end
146
148
@@ -157,7 +159,8 @@ function TCI2PatchCreator(
157
159
ninitialpivot= 5 ,
158
160
checkbatchevaluatable= false ,
159
161
loginterval= 10 ,
160
- initialpivots= Vector{MultiIndex}[],
162
+ initialpivots= MultiIndex[],
163
+ recyclepivots= false ,
161
164
):: TCI2PatchCreator{T} where {T}
162
165
# t1 = time_ns()
163
166
if projector === nothing
@@ -183,6 +186,7 @@ function TCI2PatchCreator(
183
186
checkbatchevaluatable,
184
187
loginterval,
185
188
initialpivots,
189
+ recyclepivots,
186
190
)
187
191
end
188
192
@@ -206,6 +210,7 @@ function _crossinterpolate2!(
206
210
verbosity:: Int = 0 ,
207
211
checkbatchevaluatable= false ,
208
212
loginterval= 10 ,
213
+ recyclepivots= false ,
209
214
) where {T}
210
215
ncheckhistory = 3
211
216
ranks, errors = TCI. optimize! (
@@ -231,13 +236,45 @@ function _crossinterpolate2!(
231
236
ncheckhistory_ = min (ncheckhistory, length (errors))
232
237
maxbonddim_hist = maximum (ranks[(end - ncheckhistory_ + 1 ): end ])
233
238
234
- return PatchCreatorResult {T,TensorTrain{T,3}} (
235
- TensorTrain (tci), TCI. maxbonderror (tci) < tolerance && maxbonddim_hist < maxbonddim
236
- )
239
+ if recyclepivots
240
+ return PatchCreatorResult {T,TensorTrain{T,3}} (
241
+ TensorTrain (tci),
242
+ TCI. maxbonderror (tci) < tolerance && maxbonddim_hist < maxbonddim,
243
+ _globalpivots (tci),
244
+ )
245
+
246
+ else
247
+ return PatchCreatorResult {T,TensorTrain{T,3}} (
248
+ TensorTrain (tci),
249
+ TCI. maxbonderror (tci) < tolerance && maxbonddim_hist < maxbonddim,
250
+ )
251
+ end
252
+ end
253
+
254
+ # Generating global pivots from local ones
255
+ function _globalpivots (
256
+ tci:: TCI.TensorCI2{T} ; onlydiagonal= true
257
+ ):: Vector{MultiIndex} where {T}
258
+ Isets = tci. Iset
259
+ Jsets = tci. Jset
260
+ L = length (Isets)
261
+ p = Set {MultiIndex} ()
262
+ # Pivot matrices
263
+ for bondindex in 1 : (L - 1 )
264
+ if onlydiagonal
265
+ for (x, y) in zip (Isets[bondindex + 1 ], Jsets[bondindex])
266
+ push! (p, vcat (x, y))
267
+ end
268
+ else
269
+ for x in Isets[bondindex + 1 ], y in Jsets[bondindex]
270
+ push! (p, vcat (x, y))
271
+ end
272
+ end
273
+ end
274
+ return collect (p)
237
275
end
238
276
239
277
function createpatch (obj:: TCI2PatchCreator{T} ) where {T}
240
- proj = obj. projector
241
278
fsubset = _FuncAdapterTCI2Subset (obj. f)
242
279
243
280
tci = if isapproxttavailable (obj. f)
@@ -253,21 +290,25 @@ function createpatch(obj::TCI2PatchCreator{T}) where {T}
253
290
end
254
291
tci
255
292
else
256
- # Random initial pivots
257
293
initialpivots = MultiIndex[]
258
- let
259
- mask = [! isprojectedat (proj, n) for n in 1 : length (proj)]
260
- for idx in obj. initialpivots
261
- idx_ = [[i] for i in idx]
262
- if idx_ <= proj
263
- push! (initialpivots, idx[mask])
264
- end
294
+ if obj. recyclepivots
295
+ # First patching iteration: random pivots
296
+ if length (fsubset. localdims) == length (obj. localdims)
297
+ initialpivots = union (
298
+ obj. initialpivots,
299
+ findinitialpivots (fsubset, fsubset. localdims, obj. ninitialpivot),
300
+ )
301
+ # Next iterations: recycle previously generated pivots
302
+ else
303
+ initialpivots = copy (obj. initialpivots)
265
304
end
305
+ else
306
+ initialpivots = union (
307
+ obj. initialpivots,
308
+ findinitialpivots (fsubset, fsubset. localdims, obj. ninitialpivot),
309
+ )
266
310
end
267
- append! (
268
- initialpivots,
269
- findinitialpivots (fsubset, fsubset. localdims, obj. ninitialpivot),
270
- )
311
+
271
312
if all (fsubset .(initialpivots) .== 0 )
272
313
return PatchCreatorResult {T,TensorTrainState{T}} (nothing , true )
273
314
end
@@ -282,6 +323,7 @@ function createpatch(obj::TCI2PatchCreator{T}) where {T}
282
323
verbosity= obj. verbosity,
283
324
checkbatchevaluatable= obj. checkbatchevaluatable,
284
325
loginterval= obj. loginterval,
326
+ recyclepivots= obj. recyclepivots,
285
327
)
286
328
end
287
329
@@ -301,9 +343,9 @@ function adaptiveinterpolate(
301
343
verbosity= 0 ,
302
344
maxbonddim= typemax (Int),
303
345
tolerance= 1e-8 ,
304
- initialpivots= Vector{MultiIndex}[], # Make it to Vector{MMultiIndex}?
346
+ initialpivots= MultiIndex[], # Make it to Vector{MMultiIndex}?
347
+ recyclepivots= false ,
305
348
):: ProjTTContainer{T} where {T}
306
- t1 = time_ns ()
307
349
creator = TCI2PatchCreator (
308
350
T,
309
351
f,
@@ -313,6 +355,7 @@ function adaptiveinterpolate(
313
355
verbosity,
314
356
ntry= 10 ,
315
357
initialpivots= initialpivots,
358
+ recyclepivots= recyclepivots,
316
359
)
317
360
tmp = adaptiveinterpolate (creator, pordering; verbosity)
318
361
return reshape (tmp, f. sitedims)
0 commit comments