Skip to content

Commit 80c9487

Browse files
authored
Merge pull request #8 from tensor4all/7-support-fit-algorithm-for-sum-of-mpssmpos
Implement fit algorithm for sum of MPSs
2 parents 4eed68b + 3bea7e4 commit 80c9487

File tree

9 files changed

+329
-24
lines changed

9 files changed

+329
-24
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
2020
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
2121
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2222
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
23+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
2324

2425
[targets]
25-
test = ["Test", "Random", "Aqua", "JET"]
26+
test = ["Test", "Random", "Aqua", "JET", "StableRNGs"]

docs/make.jl

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
11
using FastMPOContractions
22
using Documenter
33

4-
DocMeta.setdocmeta!(FastMPOContractions, :DocTestSetup, :(using FastMPOContractions); recursive=true)
4+
DocMeta.setdocmeta!(
5+
FastMPOContractions,
6+
:DocTestSetup,
7+
:(using FastMPOContractions);
8+
recursive = true,
9+
)
510

611
makedocs(;
7-
modules=[FastMPOContractions],
8-
authors="Hiroshi Shinaoka <[email protected]> and contributors",
9-
sitename="FastMPOContractions.jl",
10-
format=Documenter.HTML(;
11-
canonical="https://github.com/tensor4all/FastMPOContractions.jl",
12-
edit_link="main",
13-
assets=String[]),
14-
pages=[
15-
"Home" => "index.md",
16-
])
17-
18-
deploydocs(;
19-
repo="github.com/tensor4all/FastMPOContractions.jl.git",
20-
devbranch="main",
12+
modules = [FastMPOContractions],
13+
authors = "Hiroshi Shinaoka <[email protected]> and contributors",
14+
sitename = "FastMPOContractions.jl",
15+
format = Documenter.HTML(;
16+
canonical = "https://github.com/tensor4all/FastMPOContractions.jl",
17+
edit_link = "main",
18+
assets = String[],
19+
),
20+
pages = ["Home" => "index.md"],
2121
)
22+
23+
deploydocs(; repo = "github.com/tensor4all/FastMPOContractions.jl.git", devbranch = "main")

src/FastMPOContractions.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@ module FastMPOContractions
33
using StaticArrays
44

55
using ITensors
6-
import ITensors.ITensorMPS: AbstractMPS, sim!, setleftlim!, setrightlim!, check_hascommoninds
6+
import ITensors.ITensorMPS:
7+
AbstractMPS, sim!, setleftlim!, setrightlim!, check_hascommoninds
78

89
using ITensorTDVP
910

1011
include("densitymatrix.jl")
1112
include("fitalgorithm.jl")
1213
include("util.jl")
1314
include("contractMPO.jl")
15+
include("fitalgorithm_sum.jl")
1416

1517
end

src/contractMPO.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@ function contract_mpo_mpo(M1::MPO, M2::MPO; alg::String = "densitymatrix", kwarg
77
error("Unknown algorithm: $alg")
88
end
99

10-
end
10+
end

src/fitalgorithm.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Contract M1 and M2, and return the result as an MPO.
88
function contract_fit(M1::MPO, M2::MPO; init = nothing, kwargs...)::MPO
99
M2_ = MPS([M2[v] for v in eachindex(M2)])
1010
if init === nothing
11-
init_MPO::MPO = ITensors.contract(M1, M2; alg="zipup", kwargs...)
11+
init_MPO::MPO = ITensors.contract(M1, M2; alg = "zipup", kwargs...)
1212
init = MPS([init_MPO[v] for v in eachindex(init_MPO)])
1313
else
1414
init = MPS([init[v] for v in eachindex(M2)])
@@ -58,6 +58,10 @@ function contract_fit(A::MPO, psi0::MPS; init_mps = psi0, nsweeps = 1, kwargs...
5858

5959
reduced_operator = ITensorTDVP.ReducedContractProblem(psi0, A)
6060
return ITensorTDVP.alternating_update(
61-
reduced_operator, init_mps; updater=ITensorTDVP.contract_operator_state_updater, nsweeps=nsweeps, kwargs...
62-
)
61+
reduced_operator,
62+
init_mps;
63+
updater = ITensorTDVP.contract_operator_state_updater,
64+
nsweeps = nsweeps,
65+
kwargs...,
66+
)
6367
end

src/fitalgorithm_sum.jl

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
using ITensors.ITensorMPS: ITensorMPS, AbstractProjMPO, MPO, MPS
2+
using ITensors.ITensorMPS: linkinds, replaceinds
3+
using ITensors: ITensors, OneITensor
4+
import ITensorTDVP: alternating_update, rproj, lproj
5+
6+
"""
7+
A ReducedFitProblem represents the projection
8+
of an MPS `input_state` onto the basis of a different MPS `state`.
9+
`state` may be an approximation of `input_state`.
10+
```
11+
*--*--*- -*--*--*--*--*--* <state|
12+
| | | | | | | | | | |
13+
o--o--o- -o--o--o--o--o--o |input_state>
14+
```
15+
"""
16+
mutable struct ReducedFitProblem <: AbstractProjMPO
17+
lpos::Int
18+
rpos::Int
19+
nsite::Int
20+
input_state::MPS
21+
environments::Vector{ITensor}
22+
end
23+
24+
function ReducedFitProblem(input_state::MPS)
25+
lpos = 0
26+
rpos = length(input_state) + 1
27+
nsite = 2
28+
environments = Vector{ITensor}(undef, length(input_state))
29+
return ReducedFitProblem(lpos, rpos, nsite, input_state, environments)
30+
end
31+
32+
function lproj(P::ReducedFitProblem)::Union{ITensor,OneITensor}
33+
(P.lpos <= 0) && return OneITensor()
34+
return P.environments[P.lpos]
35+
end
36+
37+
function rproj(P::ReducedFitProblem)::Union{ITensor,OneITensor}
38+
(P.rpos >= length(P) + 1) && return OneITensor()
39+
return P.environments[P.rpos]
40+
end
41+
42+
43+
function Base.copy(reduced_operator::ReducedFitProblem)
44+
return ReducedFitProblem(
45+
reduced_operator.lpos,
46+
reduced_operator.rpos,
47+
reduced_operator.nsite,
48+
copy(reduced_operator.input_state),
49+
copy(reduced_operator.environments),
50+
)
51+
end
52+
53+
Base.length(reduced_operator::ReducedFitProblem) = length(reduced_operator.input_state)
54+
55+
function ITensorMPS.set_nsite!(reduced_operator::ReducedFitProblem, nsite)
56+
reduced_operator.nsite = nsite
57+
return reduced_operator
58+
end
59+
60+
function ITensorMPS.makeL!(reduced_operator::ReducedFitProblem, state::MPS, k::Int)
61+
# Save the last `L` that is made to help with caching
62+
# for DiskProjMPO
63+
ll = reduced_operator.lpos
64+
if ll k
65+
# Special case when nothing has to be done.
66+
# Still need to change the position if lproj is
67+
# being moved backward.
68+
reduced_operator.lpos = k
69+
return nothing
70+
end
71+
# Make sure ll is at least 0 for the generic logic below
72+
ll = max(ll, 0)
73+
L = lproj(reduced_operator)
74+
while ll < k
75+
L = L * reduced_operator.input_state[ll+1] * dag(state[ll+1])
76+
reduced_operator.environments[ll+1] = L
77+
ll += 1
78+
end
79+
# Needed when moving lproj backward.
80+
reduced_operator.lpos = k
81+
return reduced_operator
82+
end
83+
84+
function ITensorMPS.makeR!(reduced_operator::ReducedFitProblem, state::MPS, k::Int)
85+
# Save the last `R` that is made to help with caching
86+
# for DiskProjMPO
87+
rl = reduced_operator.rpos
88+
if rl k
89+
# Special case when nothing has to be done.
90+
# Still need to change the position if rproj is
91+
# being moved backward.
92+
reduced_operator.rpos = k
93+
return nothing
94+
end
95+
N = length(state)
96+
# Make sure rl is no bigger than `N + 1` for the generic logic below
97+
rl = min(rl, N + 1)
98+
R = rproj(reduced_operator)
99+
while rl > k
100+
R = R * reduced_operator.input_state[rl-1] * dag(state[rl-1])
101+
reduced_operator.environments[rl-1] = R
102+
rl -= 1
103+
end
104+
reduced_operator.rpos = k
105+
return reduced_operator
106+
end
107+
108+
109+
struct ReducedFitMPSsProblem <: AbstractProjMPO
110+
problems::Vector{ReducedFitProblem}
111+
coeffs::Vector{<:Number}
112+
end
113+
114+
function ReducedFitMPSsProblem(
115+
input_states::AbstractVector{MPS},
116+
coeffs::AbstractVector{<:Number},
117+
)
118+
ReducedFitMPSsProblem(ReducedFitProblem.(input_states), coeffs)
119+
end
120+
121+
function Base.copy(reduced_operator::ReducedFitMPSsProblem)
122+
return ReducedFitMPSsProblem(reduced_operator.problems, reduced_operator.coeffs)
123+
end
124+
125+
function Base.getproperty(reduced_operator::ReducedFitMPSsProblem, sym::Symbol)
126+
if sym === :nsite
127+
return getfield(reduced_operator, :problems)[1].nsite
128+
end
129+
return getfield(reduced_operator, sym)
130+
end
131+
132+
133+
Base.length(reduced_operator::ReducedFitMPSsProblem) = length(reduced_operator.problems[1])
134+
135+
function ITensorMPS.set_nsite!(reduced_operator::ReducedFitMPSsProblem, nsite)
136+
for p in reduced_operator.problems
137+
ITensorMPS.set_nsite!(p, nsite)
138+
end
139+
return reduced_operator
140+
end
141+
142+
function ITensorMPS.makeL!(reduced_operator::ReducedFitMPSsProblem, state::MPS, k::Int)
143+
for p in reduced_operator.problems
144+
ITensorMPS.makeL!(p, state, k)
145+
end
146+
return reduced_operator
147+
end
148+
149+
150+
function ITensorMPS.makeR!(reduced_operator::ReducedFitMPSsProblem, state::MPS, k::Int)
151+
for p in reduced_operator.problems
152+
ITensorMPS.makeR!(p, state, k)
153+
end
154+
return reduced_operator
155+
end
156+
157+
158+
159+
function _contract(P::ReducedFitProblem, v::ITensor)::ITensor
160+
itensor_map = Union{ITensor,OneITensor}[lproj(P)]
161+
push!(itensor_map, rproj(P))
162+
163+
# Reverse the contraction order of the map if
164+
# the first tensor is a scalar (for example we
165+
# are at the left edge of the system)
166+
if dim(first(itensor_map)) == 1
167+
reverse!(itensor_map)
168+
end
169+
170+
# Apply the map
171+
Hv = v
172+
for it in itensor_map
173+
Hv *= it
174+
end
175+
return Hv
176+
end
177+
178+
function contract_operator_state_updater(operator::ReducedFitProblem, init; internal_kwargs)
179+
state = ITensor(true)
180+
for j = (operator.lpos+1):(operator.rpos-1)
181+
state *= operator.input_state[j]
182+
end
183+
state = _contract(operator, state)
184+
return state, (;)
185+
end
186+
187+
function contract_operator_state_updater(
188+
operator::ReducedFitMPSsProblem,
189+
init;
190+
internal_kwargs,
191+
)
192+
states = ITensor[]
193+
for (p, coeff) in zip(operator.problems, operator.coeffs)
194+
res = contract_operator_state_updater(p, init; internal_kwargs)
195+
push!(states, coeff * res[1])
196+
end
197+
return sum(states), (;)
198+
end
199+
200+
201+
function contract_fit(input_state::MPS, init::MPS; coeff::Number = 1, kwargs...)
202+
links = ITensors.sim.(linkinds(init))
203+
init = replaceinds(linkinds, init, links)
204+
reduced_operator = ReducedFitProblem(input_state)
205+
return alternating_update(
206+
reduced_operator,
207+
init;
208+
updater = contract_operator_state_updater,
209+
kwargs...,
210+
)
211+
end
212+
213+
214+
function fit(
215+
input_states::AbstractVector{MPS},
216+
init::MPS;
217+
coeffs::AbstractVector{<:Number} = ones(Int, length(input_states)),
218+
kwargs...,
219+
)
220+
links = ITensors.sim.(linkinds(init))
221+
init = replaceinds(linkinds, init, links)
222+
reduced_operator = ReducedFitMPSsProblem(input_states, coeffs)
223+
return alternating_update(
224+
reduced_operator,
225+
init;
226+
updater = contract_operator_state_updater,
227+
kwargs...,
228+
)
229+
end
230+
231+
function fit(
232+
input_states::AbstractVector{MPO},
233+
init::MPO;
234+
coeffs::AbstractVector{<:Number} = ones(Int, length(input_states)),
235+
kwargs...,
236+
)
237+
:MPO
238+
to_mps::MPO) = MPS([x for x in Ψ])
239+
240+
res = fit(to_mps.(input_states), to_mps(init); coeffs = coeffs, kwargs...)
241+
return MPO([x for x in res])
242+
end

test/_util.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
using ITensors
2+
using Random
23

34
function _random_mpo(sites::Vector{Vector{Index{T}}}; linkdims = 1) where {T}
5+
_random_mpo(Random.GLOBAL_RNG, sites; linkdims = linkdims)
6+
end
7+
8+
function _random_mpo(rng, sites::Vector{Vector{Index{T}}}; linkdims = 1) where {T}
49
N = length(sites)
510
links = [Index(linkdims, "Link,n=$n") for n = 1:N-1]
611
M = MPO(N)
7-
M[1] = random_itensor(sites[1]..., links[1])
8-
M[N] = random_itensor(links[N-1], sites[N]...)
12+
M[1] = random_itensor(rng, sites[1]..., links[1])
13+
M[N] = random_itensor(rng, links[N-1], sites[N]...)
914
for n = 2:N-1
10-
M[n] = random_itensor(links[n-1], sites[n]..., links[n])
15+
M[n] = random_itensor(rng, links[n-1], sites[n]..., links[n])
1116
end
1217
return M
1318
end

0 commit comments

Comments
 (0)