
Gradient Performance Enhancement: Implementing expect Function to Exploit Hermitian Nature of H #301

Open
danielalcalde opened this issue Jun 22, 2023 · 3 comments

Comments

@danielalcalde

It has come to my attention that the loss function defined in the tutorial may not be optimized for performance:

function loss(θ)
  circuit = variationalcircuit(N, depth, θ)
  Uψ = runcircuit(ψ, circuit; cutoff, maxdim)
  return inner(Uψ', H, Uψ; cutoff, maxdim)
end

Currently, backpropagation through the inner function is relatively slow, primarily because the function does not take two crucial facts into account: Uψ' and Uψ represent the same state, and H is Hermitian.

The function expect, defined as:

function expect(ψ, H; kwargs...)
    return real(inner(ψ', H, ψ; kwargs...))
end

Zygote.@adjoint function expect(ψ, H; kwargs...)
    # Pullback: for Hermitian H, the cotangent of ψ is 2 * H|ψ⟩; H gets no gradient.
    function f̄(ȳ)
        ψbar = contract(H, ψ'; kwargs...)
        return ȳ * 2 * ψbar, nothing
    end
    return expect(ψ, H; kwargs...), f̄
end

is designed to exploit these properties and can give a considerable speedup when computing gradients (in my simulation, from 3 s to 200 ms).
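For reference, here is a short derivation of where the factor of 2 in the pullback comes from. It is only a sketch: it assumes H is exactly Hermitian and ignores any truncation introduced by cutoff and maxdim.

```latex
\begin{aligned}
E(\psi) &= \langle \psi | H | \psi \rangle, \qquad H = H^\dagger \\
\delta E &= \langle \delta\psi | H | \psi \rangle + \langle \psi | H | \delta\psi \rangle \\
         &= \langle \delta\psi | H | \psi \rangle + \overline{\langle \delta\psi | H^\dagger | \psi \rangle} \\
         &= 2\,\mathrm{Re}\,\langle \delta\psi | H | \psi \rangle
\end{aligned}
```

The first-order change in E is therefore fixed by the single vector H|ψ⟩, which the pullback contracts once and scales by 2ȳ. The last step uses H = H†, so the result does not carry over to a non-Hermitian H. The exact conjugation convention of the returned cotangent depends on the AD framework.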

The codebase does not seem to have an equivalent function. I suggest adding expect, or a similarly named function, to the ITensors.jl or PastaQ.jl package, so that users get this speedup without writing their own adjoint.

Additionally, if an equivalent function already exists in the codebase, I recommend updating the tutorial to use this function instead of inner to make it more performance-oriented and user-friendly.

@danielalcalde danielalcalde changed the title Performance Enhancement: Implementing expect Function to Exploit Hermitian Nature of H and Similarity of States Performance Enhancement: Implementing expect Function to Exploit Hermitian Nature of H Jun 22, 2023
@danielalcalde danielalcalde changed the title Performance Enhancement: Implementing expect Function to Exploit Hermitian Nature of H Gradient Performance Enhancement: Implementing expect Function to Exploit Hermitian Nature of H Jun 22, 2023
@mtfishman
Collaborator

That's an impressive speedup! I was trying to think if we can do this same kind of optimization automatically in our inner rrule but I don't really see how besides manually checking if the bra and ket MPS are the same. That could technically work but is a bit tricky.

The main issue I see with your proposal is if H isn't Hermitian, that derivative wouldn't be correct. We could have a flag ishermitian which you can pass to inner, but it's a bit funny having a keyword argument that's only used by the derivative rule. Alternatively we could just say that expect should be used with Hermitian MPOs, though I'm not sure I like that.
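A dense-matrix analogue of this concern, as a minimal sketch: real vectors and matrices stand in for the MPS and MPO, and `quadform`, `fd_gradient`, and `shortcut` are hypothetical helper names, not part of ITensors.jl or PastaQ.jl. It compares the "2 * H * x" shortcut against a finite-difference gradient for a symmetric and a non-symmetric matrix.

```julia
using LinearAlgebra

# Real quadratic form f(x) = x' * M * x, the dense analogue of ⟨ψ|H|ψ⟩.
quadform(M, x) = dot(x, M * x)

# Central finite-difference gradient, used only as a reference value.
function fd_gradient(f, x; ϵ=1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = ϵ
        g[i] = (f(x + e) - f(x - e)) / (2ϵ)
    end
    return g
end

# The shortcut gradient, valid only when M equals its own transpose.
shortcut(M, x) = 2 * (M * x)

x = [0.3, -1.2, 0.7]
A = [1.0 2.0 0.0; 0.5 -1.0 3.0; 0.0 1.0 2.0]  # deliberately not symmetric
S = (A + A') / 2                               # symmetric part of A

# Symmetric case: the shortcut matches the reference gradient.
ok_sym = isapprox(shortcut(S, x), fd_gradient(v -> quadform(S, v), x); atol=1e-6)

# Non-symmetric case: the shortcut disagrees with the reference gradient.
ok_nonsym = isapprox(shortcut(A, x), fd_gradient(v -> quadform(A, v), x); atol=1e-6)

# The general gradient (M + M') * x matches in both cases.
ok_general = isapprox((A + A') * x, fd_gradient(v -> quadform(A, v), x); atol=1e-6)

println((ok_sym, ok_nonsym, ok_general))
```

In this toy setting the shortcut is silently wrong for the non-symmetric matrix rather than raising an error, which is the reason to either check for Hermiticity or document the requirement.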

@danielalcalde
Author

As motivation, here is an example with a roughly 2x speedup. Note that the larger the bond dimension of the Hamiltonian, the larger the speedup.

using PastaQ
using ITensors
using Random
using Printf
using OptimKit
using Zygote
using BenchmarkTools

N = 10   # number of qubits
J = 1.0  # Ising exchange interaction
h = 0.5  # transverse magnetic field

# Hilbert space
hilbert = qubits(N)

function ising_hamiltonian(N; J, h)
  os = OpSum()
  for j in 1:(N - 1)
    os += -J, "Z", j, "Z", j + 1
  end
  for j in 1:N
    os += -h, "X", j
  end
  return os
end

# define the Hamiltonian
os = ising_hamiltonian(N; J, h)

# build MPO "cost function"
H = MPO(os, hilbert)

cutoff = 1e-10

# layer of single-qubit Ry gates
Rylayer(N, θ) = [("Ry", j, (θ=θ[j],)) for j in 1:N]

# brick-layer of CX gates
function CXlayer(N, Π)
  start = isodd(Π) ? 1 : 2
  return [("CX", (j, j + 1)) for j in start:2:(N - 1)]
end

# variational ansatz
function variationalcircuit(N, depth, θ)
  circuit = Tuple[]
  for d in 1:depth
    circuit = vcat(circuit, CXlayer(N, d))
    circuit = vcat(circuit, Rylayer(N, θ[d]))
  end
  return circuit
end

depth = 20
ψ = productstate(hilbert)

cutoff = 1e-8
maxdim = 200

# cost function
function loss(θ)
  circuit = variationalcircuit(N, depth, θ)
  Uψ = runcircuit(ψ, circuit; cutoff, maxdim)
  return inner(Uψ', H, Uψ; cutoff, maxdim)
end
    
function expect(ψ, H; kwargs...)
    return real(inner(ψ', H, ψ; kwargs...))
end

Zygote.@adjoint function expect(ψ, H; kwargs...)
    # Pullback: for Hermitian H, the cotangent of ψ is 2 * H|ψ⟩; H gets no gradient.
    function f̄(ȳ)
        ψbar = contract(H, ψ'; kwargs...)
        return ȳ * 2 * ψbar, nothing
    end
    return expect(ψ, H; kwargs...), f̄
end
    
function loss2(θ)
  circuit = variationalcircuit(N, depth, θ)
  Uψ = runcircuit(ψ, circuit; cutoff, maxdim)
  return expect(Uψ, H; cutoff, maxdim)
end

Random.seed!(1234)

# initialize parameters
θ₀ = [2π .* rand(N) for _ in 1:depth];
        
gradient(loss, θ₀) .≈ gradient(loss2, θ₀)
# (true,)
        
@benchmark gradient(loss, θ₀)
BenchmarkTools.Trial: 3 samples with 1 evaluation.
 Range (min … max):  1.441 s …   2.184 s  ┊ GC (min … max):  6.06% … 37.79%
 Time  (median):     1.455 s               ┊ GC (median):     6.00%
 Time  (mean ± σ):   1.694 s ± 425.015 ms  ┊ GC (mean ± σ):  19.57% ± 18.46%

  ██                                                       █  
  ██▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ ▁
  1.44 s         Histogram: frequency by time         2.18 s <

 Memory estimate: 648.61 MiB, allocs estimate: 1518320.

@benchmark gradient(loss2, θ₀)
BenchmarkTools.Trial: 6 samples with 1 evaluation.
 Range (min … max):  685.989 ms …   1.110 s  ┊ GC (min … max): 10.04% … 43.52%
 Time  (median):     840.434 ms               ┊ GC (median):    26.17%
 Time  (mean ± σ):   871.303 ms ± 204.527 ms  ┊ GC (mean ± σ):  28.53% ± 17.36%

  █                                                              
  █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇▁▁▁▁▁▁▁▁▁▁▇▁▁▁▁▁▇ ▁
  686 ms           Histogram: frequency by time          1.11 s <

 Memory estimate: 334.74 MiB, allocs estimate: 1468562

@danielalcalde
Author

danielalcalde commented Jun 23, 2023

What do you think about using this code to check whether the MPO is Hermitian, and throwing an error if the expect function is used with a non-Hermitian MPO?

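using ITensors
using LinearAlgebra
using StatsBase

# Return the index that appears twice once primes are removed (the site index),
# or nothing if there is no such index.
function repeat_inds(a::ITensor)
    for (v, i) in countmap(noprime(inds(a)))
        if i == 2
            return v
        end
    end
    return nothing
end

# Swap the primed and unprimed site index and compare with the complex conjugate.
function LinearAlgebra.ishermitian(t::ITensor)
    s = repeat_inds(t)
    ts = swapprime(t, 0=>1; tags=[s, s'])
    return t ≈ conj(ts)
end

# Tensor-by-tensor check: sufficient for H to be Hermitian, but not necessary.
LinearAlgebra.ishermitian(H::MPO) = all([ishermitian(t) for t in H])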
