Skip to content

Commit a2b42ee

Browse files
committed
Remove all allocations in the in-place version of block-MINRES
1 parent 6c0b3c8 commit a2b42ee

File tree

3 files changed

+43
-27
lines changed

3 files changed

+43
-27
lines changed

src/block_krylov_workspaces.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,19 @@ mutable struct BlockMinresWorkspace{T,FC,SV,SM} <: BlockKrylovWorkspace{T,FC,SV,
2222
C :: SM
2323
D :: SM
2424
Φ :: SM
25+
Ψₖ :: SM
26+
Ωₖ :: SM
27+
Ψₖ₊₁ :: SM
28+
Πₖ₋₂ :: SM
29+
Γbarₖ₋₁ :: SM
30+
Γₖ₋₁ :: SM
31+
Λbarₖ :: SM
32+
Λₖ :: SM
2533
Vₖ₋₁ :: SM
2634
Vₖ :: SM
2735
wₖ₋₂ :: SM
2836
wₖ₋₁ :: SM
37+
wₖ :: SM
2938
Hₖ₋₂ :: SM
3039
Hₖ₋₁ :: SM
3140
τₖ₋₂ :: SV
@@ -45,10 +54,19 @@ function BlockMinresWorkspace(m::Integer, n::Integer, p::Integer, SV::Type, SM::
4554
C = SM(undef, p, p)
4655
D = SM(undef, 2p, p)
4756
Φ = SM(undef, p, p)
57+
Ψₖ = SM(undef, p, p)
58+
Ωₖ = SM(undef, p, p)
59+
Ψₖ₊₁ = SM(undef, p, p)
60+
Πₖ₋₂ = SM(undef, p, p)
61+
Γbarₖ₋₁ = SM(undef, p, p)
62+
Γₖ₋₁ = SM(undef, p, p)
63+
Λbarₖ = SM(undef, p, p)
64+
Λₖ = SM(undef, p, p)
4865
Vₖ₋₁ = SM(undef, n, p)
4966
Vₖ = SM(undef, n, p)
5067
wₖ₋₂ = SM(undef, n, p)
5168
wₖ₋₁ = SM(undef, n, p)
69+
wₖ = SM(undef, n, p)
5270
Hₖ₋₂ = SM(undef, 2p, p)
5371
Hₖ₋₁ = SM(undef, 2p, p)
5472
τₖ₋₂ = SV(undef, p)
@@ -60,7 +78,8 @@ function BlockMinresWorkspace(m::Integer, n::Integer, p::Integer, SV::Type, SM::
6078
korgqr_buffer!(Vₖ, τₖ₋₁), korgqr_buffer!(Hₖ₋₁, τₖ₋₁),
6179
kormqr_buffer!('L', FC <: AbstractFloat ? 'T' : 'C', Hₖ₋₁, τₖ₋₁, D)) : 0
6280
buffer = SV(undef, size_buffer)
63-
workspace = BlockMinresWorkspace{T,FC,SV,SM}(m, n, p, ΔX, X, P, Q, C, D, Φ, Vₖ₋₁, Vₖ, wₖ₋₂, wₖ₋₁, Hₖ₋₂, Hₖ₋₁, τₖ₋₂, τₖ₋₁, buffer, false, stats)
81+
workspace = BlockMinresWorkspace{T,FC,SV,SM}(m, n, p, ΔX, X, P, Q, C, D, Φ, Ψₖ, Ωₖ, Ψₖ₊₁, Πₖ₋₂, Γbarₖ₋₁, Γₖ₋₁, Λbarₖ, Λₖ,
82+
Vₖ₋₁, Vₖ, wₖ₋₂, wₖ₋₁, wₖ, Hₖ₋₂, Hₖ₋₁, τₖ₋₂, τₖ₋₁, buffer, false, stats)
6483
return workspace
6584
end
6685

src/block_minres.jl

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ kwargs_block_minres = (:M, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose, :his
116116
Vₖ₋₁, Vₖ = workspace.Vₖ₋₁, workspace.Vₖ
117117
ΔX, X, Q, C = workspace.ΔX, workspace.X, workspace.Q, workspace.C
118118
D, Φ, stats = workspace.D, workspace.Φ, workspace.stats
119-
wₖ₋₂, wₖ₋₁ = workspace.wₖ₋₂, workspace.wₖ₋₁
119+
wₖ₋₂, wₖ₋₁, wₖ = workspace.wₖ₋₂, workspace.wₖ₋₁, workspace.wₖ
120120
Hₖ₋₂, Hₖ₋₁ = workspace.Hₖ₋₂, workspace.Hₖ₋₁
121121
τₖ₋₂, τₖ₋₁ = workspace.τₖ₋₂, workspace.τₖ₋₁
122122
buffer = workspace.buffer
@@ -125,15 +125,15 @@ kwargs_block_minres = (:M, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose, :his
125125
reset!(stats)
126126
R₀ = warm_start ? Q : B
127127

128-
# Temporary buffers -- should be stored in the workspace
129-
Ψₖ = similar(B, p, p)
130-
Ωₖ = similar(B, p, p)
131-
Ψₖ₊₁ = similar(B, p, p)
132-
Πₖ₋₂ = similar(B, p, p)
133-
Γbarₖ₋₁ = similar(B, p, p)
134-
Γₖ₋₁ = similar(B, p, p)
135-
Λbarₖ = similar(B, p, p)
136-
Λₖ = similar(B, p, p)
128+
# Matrices in the workspace (some of them could be removed in the future)
129+
Ψₖ = workspace.Ψₖ
130+
Ωₖ = workspace.Ωₖ
131+
Ψₖ₊₁ = workspace.Ψₖ₊₁
132+
Πₖ₋₂ = workspace.Πₖ₋₂
133+
Γbarₖ₋₁ = workspace.Γbarₖ₋₁
134+
Γₖ₋₁ = workspace.Γₖ₋₁
135+
Λbarₖ = workspace.Λbarₖ
136+
Λₖ = workspace.Λₖ
137137

138138
# Define the blocks D1 and D2
139139
D1 = view(D, 1:p, :)
@@ -242,29 +242,28 @@ kwargs_block_minres = (:M, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose, :his
242242
# Compute the directions Wₖ, the last columns of Wₖ = Vₖ(Rₖ)⁻¹ ⟷ (Rₖ)ᵀ(Wₖ)ᵀ = (Vₖ)ᵀ
243243
# w₁Λ₁ = v₁
244244
if iter == 1
245-
wₖ = wₖ₋₁
246245
wₖ .= Vₖ
247246
rdiv!(wₖ, UpperTriangular(Λₖ))
248247
end
249248
# w₂Λ₂ = v₂ - w₁Γ₁
250249
if iter == 2
251-
wₖ = wₖ₋₂
252-
wₖ .= (-wₖ₋₁ * Γₖ₋₁)
253-
wₖ .+= Vₖ
250+
@kswap!(wₖ₋₁, wₖ)
251+
wₖ .= Vₖ
252+
mul!(wₖ, wₖ₋₁, Γₖ₋₁, α, β)
254253
rdiv!(wₖ, UpperTriangular(Λₖ))
255254
end
256255
# wₖΛₖ = vₖ - wₖ₋₁Γₖ₋₁ - wₖ₋₂Πₖ₋₂
257256
if iter ≥ 3
258-
wₖ = wₖ₋₂
259-
wₖ .= (-wₖ₋₂ * Πₖ₋₂)
260-
wₖ .= (wₖ - wₖ₋₁ * Γₖ₋₁)
261-
wₖ .+= Vₖ
257+
@kswap!(wₖ₋₂, wₖ₋₁)
258+
@kswap!(wₖ₋₁, wₖ)
259+
wₖ .= Vₖ
260+
mul!(wₖ, wₖ₋₂, Πₖ₋₂, α, β)
261+
mul!(wₖ, wₖ₋₁, Γₖ₋₁, α, β)
262262
rdiv!(wₖ, UpperTriangular(Λₖ))
263263
end
264264

265265
# Update Xₖ = VₖYₖ = WₖZₖ
266266
# Xₖ = Xₖ₋₁ + wₖ * Φₖ
267-
R = B - A * X
268267
mul!(X, wₖ, Φₖ, γ, β)
269268

270269
# Update residual norm estimate.
@@ -277,13 +276,11 @@ kwargs_block_minres = (:M, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose, :his
277276
copyto!(Vₖ₋₁, Vₖ) # vₖ₋₁ ← vₖ
278277
copyto!(Vₖ, Q) # vₖ ← vₖ₊₁
279278

280-
# Update directions for X and other variables...
279+
# Swap the pointers for Hᵢ and τᵢ
281280
if iter ≥ 2
282-
@kswap!(wₖ₋₂, wₖ₋₁)
283281
@kswap!(Hₖ₋₂, Hₖ₋₁)
284282
@kswap!(τₖ₋₂, τₖ₋₁)
285283
end
286-
287284
if iter == 1
288285
copyto!(Hₖ₋₁, Hₖ)
289286
copyto!(τₖ₋₁, τₖ)

test/test_allocations.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -763,10 +763,10 @@
763763
# @test expected_block_minres_bytes ≤ actual_block_minres_bytes ≤ 1.08 * expected_block_minres_bytes
764764
# end
765765

766-
# Workspace = BlockMinresWorkspace(A, B)
767-
# block_minres!(Workspace, A, B) # warmup
768-
# inplace_block_minres_bytes = @allocated block_minres!(Workspace, A, B)
769-
# @test inplace_block_minres_bytes == 0
766+
Workspace = BlockMinresWorkspace(A, B)
767+
block_minres!(Workspace, A, B) # warmup
768+
inplace_block_minres_bytes = @allocated block_minres!(Workspace, A, B)
769+
@test inplace_block_minres_bytes == 0
770770
end
771771
end
772772
end

0 commit comments

Comments
 (0)