Skip to content

Commit

Permalink
make diis more backend agnostic
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Jan 29, 2025
1 parent 768169b commit 5a24907
Showing 1 changed file with 24 additions and 8 deletions.
32 changes: 24 additions & 8 deletions quimb/experimental/belief_propagation/diis.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ def __init__(self, shape, size):
self.shape = shape
self.size = size

def __repr__(self):
    """Debug-friendly representation showing the stored array metadata."""
    return "<ArrayInfo(shape={}, size={})>".format(self.shape, self.size)


class Vectorizer:
"""Object for mapping back and forth between any nested pytree of arrays
Expand Down Expand Up @@ -134,28 +137,39 @@ def __init__(self, max_history=6, beta=1.0, rcond=1e-14):
self.errors = [None] * max_history
self.lambdas = []
self.head = self.max_history - 1

self.backend = None
self.B = None
self.y = None
self.scalar = None

def _extrapolate(self):
# TODO: make this backend agnostic
import numpy as np

# XXX: do this all on backend? (though is very small)
if self.B is None:
dtype = ar.get_dtype_name(self.guesses[0])
g0 = self.guesses[0]
self.backend = ar.infer_backend(g0)
dtype = ar.get_dtype_name(g0)
self.B = np.zeros((self.max_history + 1,) * 2, dtype=dtype)
self.y = np.zeros(self.max_history + 1, dtype=dtype)
self.B[1:, 0] = self.B[0, 1:] = self.y[0] = 1.0
# define conversion to python scalar
if "complex" in dtype:
self.scalar = complex
else:
self.scalar = float

# number of error estimates we have
d = sum(e is not None for e in self.errors)
i = self.head
error_i_conj = self.errors[i].conj()
for j in range(d):
cij = error_i_conj @ self.errors[j]
cij = self.scalar(error_i_conj @ self.errors[j])
self.B[i + 1, j + 1] = cij
if i != j:
self.B[j + 1, i + 1] = cij.conj()
self.B[j + 1, i + 1] = cij.conjugate()

# solve for coefficients, taking into account rank deficiency
Binv = np.linalg.pinv(
Expand All @@ -166,19 +180,21 @@ def _extrapolate(self):
coeffs = Binv @ self.y[: d + 1]

# first entry is -ve. lagrange multiplier -> estimated next residual
self.lambdas.append(-coeffs[0])
self.lambdas.append(abs(-coeffs[0]))
coeffs = [self.scalar(c) for c in coeffs[1:]]

# construct linear combination of previous guesses!
xnew = np.zeros_like(self.guesses[0])
for ci, xi in zip(coeffs[1:], self.guesses):
# xnew = np.zeros_like(self.guesses[0])
xnew = ar.do("zeros_like", self.guesses[0], like=self.backend)
for ci, xi in zip(coeffs, self.guesses):
xnew += ci * xi

if self.beta != 0.0:
# allow custom mix of x + xnew:
# https://prefetch.eu/know/concept/pulay-mixing/
# i.e. use not just x_i but also f(x_i) -> y_i
# original Pulay mixing is beta=1.0 == only xnews
for ci, ei in zip(coeffs[1:], self.errors):
for ci, ei in zip(coeffs, self.errors):
xnew += (self.beta * ci) * ei

return xnew
Expand Down Expand Up @@ -213,7 +229,7 @@ def update(self, y):

self.head = (self.head + 1) % self.max_history
# # TODO: make copy backend agnostic
self.guesses[self.head] = xnext.copy()
self.guesses[self.head] = ar.do("copy", xnext, like=self.backend)

# convert new extrapolated guess back to pytree
return self.vectorizer.unpack(xnext)
Expand Down

0 comments on commit 5a24907

Please sign in to comment.