From e58d6697f271fa10d65f82940d2531ae0fecf727 Mon Sep 17 00:00:00 2001 From: Niall Date: Tue, 20 Apr 2021 16:34:26 +0100 Subject: [PATCH] Add fix for bug with hyperindices (#8) --- .github/workflows/ci.yml | 2 +- Project.toml | 4 ++-- src/tensor.jl | 43 ++++++++++++++++++++-------------------- test/test_tensor.jl | 8 ++++++++ 4 files changed, 33 insertions(+), 24 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3613067..9d831bf 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,7 +10,7 @@ jobs: fail-fast: false matrix: version: - - '1.5' + - '1.6' os: - ubuntu-latest - macOS-latest diff --git a/Project.toml b/Project.toml index 26c2ce8..253d85a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "QXTns" uuid = "995e1dad-72e3-4b97-b122-7deaeb8b44f9" authors = ["QuantEx team"] -version = "0.1.6" +version = "0.1.7" [deps] DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" @@ -15,7 +15,7 @@ DataStructures = "0.18" ITensors = "0.1" NDTensors = "0.1" TestSetExtensions = "2.0" -julia = "1.5" +julia = "1.6" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/src/tensor.jl b/src/tensor.jl index a221b80..d6c8a72 100644 --- a/src/tensor.jl +++ b/src/tensor.jl @@ -163,34 +163,35 @@ function contract_hyper_indices(a_indices::Array{<:Index, 1}, a_hyper_indices::Array{<:Array{<:Index, 1}, 1}, b_indices::Array{<:Index, 1}, b_hyper_indices::Array{<:Array{<:Index, 1}, 1}) - common_indices = intersect(a_indices, b_indices) - remaining_indices = setdiff(union(a_indices, b_indices), common_indices) - # join hyper groups where there are overlaps. O(N^2) complexity but number of groups should be small - final_groups = Array{Array{Index, 1}, 1}() - b_found = zeros(Bool, length(b_hyper_indices)) - for (i, a_group) in enumerate(a_hyper_indices) - for (j, b_group) in enumerate(b_hyper_indices) - if length(intersect(a_group, b_group)) > 0 - b_found[j] = true - a_group = union(a_group, b_group) - end + + ag = [a_hyper_indices..., b_hyper_indices...] + fg = Array{Array{Index, 1}, 1}() + + # @show all_groups + while length(ag) > 1 + overlapping = findall(x -> length(intersect(ag[1], x)) > 0, ag[2:end]) .+ 1 # add one for slice offset + if length(overlapping) == 0 + push!(fg, ag[1]) + popat!(ag, 1) end - remaining = setdiff(a_group, common_indices) - if length(remaining) > 1 - push!(final_groups, remaining) + for i in sort(overlapping, rev=true) + ag[1] = union(ag[1], ag[i]) + popat!(ag, i) end end - # add any groups in b that have not been added and are still present - for i in findall(x -> !x, b_found) - b_group = b_hyper_indices[i] - if length(setdiff(b_group, common_indices)) > 1 - push!(final_groups, setdiff(b_group, common_indices)) - end + if length(ag) == 1 + push!(fg, ag[1]) end - final_groups + + # now check that final groups still exist after contraction + common_indices = intersect(a_indices, b_indices) + remaining_indices = setdiff(union(a_indices, b_indices), common_indices) + fg = map(x -> setdiff(x, common_indices), fg) + filter(x -> length(x) > 1, fg) end + """ contract_tensors(A::QXTensor, B::QXTensor; mock::Bool=false) diff --git a/test/test_tensor.jl b/test/test_tensor.jl index 371c447..4299669 100644 --- a/test/test_tensor.jl +++ b/test/test_tensor.jl @@ -49,6 +49,14 @@ end b_hyper_indices = [[bs[3], bs[4]]] @test QXTns.contract_hyper_indices(as, a_hyper_indices, bs, b_hyper_indices) == [[as[2], as[3]]] + + # next an example where the first tensor has two groups linked with group from second tensor + as = [Index(2), Index(2), Index(2), Index(2)] + a_hyper_indices = [[as[1], as[2]], [as[3], as[4]]] + bs = [as[1], as[4]] + b_hyper_indices = [[bs[1], bs[2]]] + + @test QXTns.contract_hyper_indices(as, a_hyper_indices, bs, b_hyper_indices) == [[as[2], as[3]]] end @testset "Test tensor_data when considering hyperedges" begin