From 27d4bf852cad5e1cfd96e34525db692fe92ad20f Mon Sep 17 00:00:00 2001 From: mattcleigh <36259408+mattcleigh@users.noreply.github.com> Date: Mon, 4 Dec 2023 15:36:29 +0100 Subject: [PATCH] Updating depricated torch.trtrs (#54) --- normflows/flows/mixing.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/normflows/flows/mixing.py b/normflows/flows/mixing.py index d35cbd5..5822836 100644 --- a/normflows/flows/mixing.py +++ b/normflows/flows/mixing.py @@ -505,9 +505,9 @@ def weight_inverse(self): """ lower, upper = self._create_lower_upper() identity = torch.eye(self.features, self.features) - lower_inverse, _ = torch.trtrs(identity, lower, upper=False, unitriangular=True) - weight_inverse, _ = torch.trtrs( - lower_inverse, upper, upper=True, unitriangular=False + lower_inverse = torch.linalg.solve_triangular(lower, identity, upper=False, unitriangular=True) + weight_inverse = torch.linalg.solve_triangular( + upper, lower_inverse, upper=True, unitriangular=False ) return weight_inverse