Skip to content

Commit 5881253

Browse files
srohshapshinaoka
andauthored
merge 10 naive algorithm in contractmpo
* <src/add>: added naive algorithm to contract_mpo_mpo * <test/add>: naive test added * Restrict test with JET to newer Julia versions --------- Co-authored-by: Hiroshi Shinaoka <[email protected]>
1 parent ab53053 commit 5881253

File tree

5 files changed

+31
-7
lines changed

5 files changed

+31
-7
lines changed

src/contractMPO.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ function contract_mpo_mpo(M1::MPO, M2::MPO; alg::String = "densitymatrix", kwarg
33
return contract_densitymatrix(M1, M2; kwargs...)
44
elseif alg == "fit"
55
return contract_fit(M1, M2; kwargs...)
6+
elseif alg == "naive"
7+
return ITensors.contract(M1, M2; alg = "naive", kwargs...)
68
else
79
error("Unknown algorithm: $alg")
810
end

src/fitalgorithm_sum.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -198,14 +198,14 @@ function contract_operator_state_updater(
198198
end
199199

200200

201-
function contract_fit(input_state::MPS, init::MPS; coeff::Number=1, kwargs...)::MPS
201+
function contract_fit(input_state::MPS, init::MPS; coeff::Number = 1, kwargs...)::MPS
202202
links = ITensors.sim.(linkinds(init))
203203
init = replaceinds(linkinds, init, links)
204204
reduced_operator = ReducedFitProblem(input_state)
205205
return alternating_update(
206206
reduced_operator,
207207
init;
208-
updater=contract_operator_state_updater,
208+
updater = contract_operator_state_updater,
209209
kwargs...,
210210
)
211211
end
@@ -214,7 +214,7 @@ end
214214
function fit(
215215
input_states::AbstractVector{MPS},
216216
init::MPS;
217-
coeffs::AbstractVector{<:Number}=ones(Int, length(input_states)),
217+
coeffs::AbstractVector{<:Number} = ones(Int, length(input_states)),
218218
kwargs...,
219219
)::MPS
220220
links = ITensors.sim.(linkinds(init))
@@ -223,19 +223,19 @@ function fit(
223223
return alternating_update(
224224
reduced_operator,
225225
init;
226-
updater=contract_operator_state_updater,
226+
updater = contract_operator_state_updater,
227227
kwargs...,
228228
)
229229
end
230230

231231
function fit(
232232
input_states::AbstractVector{MPO},
233233
init::MPO;
234-
coeffs::AbstractVector{<:Number}=ones(Int, length(input_states)),
234+
coeffs::AbstractVector{<:Number} = ones(Int, length(input_states)),
235235
kwargs...,
236236
)::MPO
237237
to_mps::MPO) = MPS([x for x in Ψ])
238238

239-
res = fit(to_mps.(input_states), to_mps(init); coeffs=coeffs, kwargs...)
239+
res = fit(to_mps.(input_states), to_mps(init); coeffs = coeffs, kwargs...)
240240
return MPO([x for x in res])
241241
end

test/naive.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using Test
2+
3+
import FastMPOContractions as FMPOC
4+
using ITensors
5+
6+
7+
8+
@testset "naive (x-y-z)" begin
9+
R = 3
10+
sitesx = [Index(2, "Qubit,x=$n") for n = 1:R]
11+
sitesy = [Index(2, "Qubit,y=$n") for n = 1:R]
12+
sitesz = [Index(2, "Qubit,z=$n") for n = 1:R]
13+
14+
sitesa = collect(collect.(zip(sitesx, sitesy)))
15+
sitesb = collect(collect.(zip(sitesy, sitesz)))
16+
a = _random_mpo(sitesa)
17+
b = _random_mpo(sitesb)
18+
ab_ref = contract(a, b; alg = "naive")
19+
ab = FMPOC.contract_mpo_mpo(a, b; alg = "naive")
20+
@test ab_ref ab
21+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ include("densitymatrix.jl")
1010
include("fitalgorithm.jl")
1111
include("util.jl")
1212
include("fitalgorithm_sum.jl")
13+
include("naive.jl")

test/test_with_jet.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using JET
22
import FastMPOContractions
33

44
@testset "JET" begin
5-
if VERSION v"1.9"
5+
if VERSION v"1.10"
66
JET.test_package(FastMPOContractions; target_defined_modules = true)
77
end
88
end

0 commit comments

Comments
 (0)