From f05cae2ecda531cde049ccfc69d470cdb845bf98 Mon Sep 17 00:00:00 2001 From: Martin Robinson Date: Thu, 7 Nov 2024 16:28:03 +0000 Subject: [PATCH] bug: use direct casadi bspline function for 1D & 2D cubic interp (#4572) * bug: use direct casadi bspline function for 1D cubic interp * bug: use direct bspline func for 2d cubic interp --- .../operations/convert_to_casadi.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/src/pybamm/expression_tree/operations/convert_to_casadi.py b/src/pybamm/expression_tree/operations/convert_to_casadi.py index af53b9acd8..6b61b35263 100644 --- a/src/pybamm/expression_tree/operations/convert_to_casadi.py +++ b/src/pybamm/expression_tree/operations/convert_to_casadi.py @@ -7,6 +7,7 @@ import casadi import numpy as np from scipy import special +from scipy import interpolate class CasadiConverter: @@ -165,6 +166,18 @@ def _convert(self, symbol, t, y, y_dot, inputs): # for some reason, pybamm.Interpolant always returns a column vector, so match that test = test.T return test + elif solver == "bspline": + bspline = interpolate.make_interp_spline( + symbol.x[0], symbol.y, k=3 + ) + knots = [bspline.t] + coeffs = bspline.c.flatten() + degree = [bspline.k] + m = len(coeffs) // len(symbol.x[0]) + f = casadi.Function.bspline( + symbol.name, knots, coeffs, degree, m + ) + return f(converted_children[0]) else: return casadi.interpolant( "LUT", solver, symbol.x, symbol.y.flatten() @@ -176,6 +189,20 @@ def _convert(self, symbol, t, y, y_dot, inputs): symbol.y.ravel(order="F"), converted_children, ) + elif solver == "bspline" and len(converted_children) == 2: + bspline = interpolate.RectBivariateSpline( + symbol.x[0], symbol.x[1], symbol.y + ) + [tx, ty, c] = bspline.tck + [kx, ky] = bspline.degrees + knots = [tx, ty] + coeffs = c + degree = [kx, ky] + m = 1 + f = casadi.Function.bspline( + symbol.name, knots, coeffs, degree, m + ) + return f(casadi.hcat(converted_children).T).T else: LUT = casadi.interpolant( "LUT", solver, symbol.x, symbol.y.ravel(order="F")