Skip to content

Commit

Permalink
Add fix for bug with hyperindices (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
nmoran authored Apr 20, 2021
1 parent 3b8e539 commit e58d669
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 24 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.5'
- '1.6'
os:
- ubuntu-latest
- macOS-latest
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "QXTns"
uuid = "995e1dad-72e3-4b97-b122-7deaeb8b44f9"
authors = ["QuantEx team"]
version = "0.1.6"
version = "0.1.7"

[deps]
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Expand All @@ -15,7 +15,7 @@ DataStructures = "0.18"
ITensors = "0.1"
NDTensors = "0.1"
TestSetExtensions = "2.0"
julia = "1.5"
julia = "1.6"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
43 changes: 22 additions & 21 deletions src/tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,34 +163,35 @@ function contract_hyper_indices(a_indices::Array{<:Index, 1},
a_hyper_indices::Array{<:Array{<:Index, 1}, 1},
b_indices::Array{<:Index, 1},
b_hyper_indices::Array{<:Array{<:Index, 1}, 1})
common_indices = intersect(a_indices, b_indices)
remaining_indices = setdiff(union(a_indices, b_indices), common_indices)
# join hyper groups where there are overlaps. O(N^2) complexity but number of groups should be small
final_groups = Array{Array{Index, 1}, 1}()
b_found = zeros(Bool, length(b_hyper_indices))
for (i, a_group) in enumerate(a_hyper_indices)
for (j, b_group) in enumerate(b_hyper_indices)
if length(intersect(a_group, b_group)) > 0
b_found[j] = true
a_group = union(a_group, b_group)
end

ag = [a_hyper_indices..., b_hyper_indices...]
fg = Array{Array{Index, 1}, 1}()

# @show all_groups
while length(ag) > 1
overlapping = findall(x -> length(intersect(ag[1], x)) > 0, ag[2:end]) .+ 1 # add one for slice offset
if length(overlapping) == 0
push!(fg, ag[1])
popat!(ag, 1)
end
remaining = setdiff(a_group, common_indices)
if length(remaining) > 1
push!(final_groups, remaining)
for i in sort(overlapping, rev=true)
ag[1] = union(ag[1], ag[i])
popat!(ag, i)
end
end
# add any groups in b that have not been added and are still present
for i in findall(x -> !x, b_found)
b_group = b_hyper_indices[i]
if length(setdiff(b_group, common_indices)) > 1
push!(final_groups, setdiff(b_group, common_indices))
end
if length(ag) == 1
push!(fg, ag[1])
end
final_groups

# now check that final groups still exist after contraction
common_indices = intersect(a_indices, b_indices)
remaining_indices = setdiff(union(a_indices, b_indices), common_indices)
fg = map(x -> setdiff(x, common_indices), fg)
filter(x -> length(x) > 1, fg)
end



"""
contract_tensors(A::QXTensor, B::QXTensor; mock::Bool=false)
Expand Down
8 changes: 8 additions & 0 deletions test/test_tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,14 @@ end
b_hyper_indices = [[bs[3], bs[4]]]

@test QXTns.contract_hyper_indices(as, a_hyper_indices, bs, b_hyper_indices) == [[as[2], as[3]]]

# next an example where the first tensor has two groups linked with group from second tensor
as = [Index(2), Index(2), Index(2), Index(2)]
a_hyper_indices = [[as[1], as[2]], [as[3], as[4]]]
bs = [as[1], as[4]]
b_hyper_indices = [[bs[1], bs[2]]]

@test QXTns.contract_hyper_indices(as, a_hyper_indices, bs, b_hyper_indices) == [[as[2], as[3]]]
end

@testset "Test tensor_data when considering hyperedges" begin
Expand Down

0 comments on commit e58d669

Please sign in to comment.