Skip to content

Commit

Permalink
bug: use direct casadi bspline function for 1D & 2D cubic interp (#4572)
Browse files Browse the repository at this point in the history
* bug: use direct casadi bspline function for 1D cubic interp

* bug: use direct bspline func for 2d cubic interp
  • Loading branch information
martinjrobins authored Nov 7, 2024
1 parent 9a479cf commit f05cae2
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions src/pybamm/expression_tree/operations/convert_to_casadi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import casadi
import numpy as np
from scipy import special
from scipy import interpolate


class CasadiConverter:
Expand Down Expand Up @@ -165,6 +166,18 @@ def _convert(self, symbol, t, y, y_dot, inputs):
# for some reason, pybamm.Interpolant always returns a column vector, so match that
test = test.T
return test
elif solver == "bspline":
bspline = interpolate.make_interp_spline(
symbol.x[0], symbol.y, k=3
)
knots = [bspline.t]
coeffs = bspline.c.flatten()
degree = [bspline.k]
m = len(coeffs) // len(symbol.x[0])
f = casadi.Function.bspline(
symbol.name, knots, coeffs, degree, m
)
return f(converted_children[0])
else:
return casadi.interpolant(
"LUT", solver, symbol.x, symbol.y.flatten()
Expand All @@ -176,6 +189,20 @@ def _convert(self, symbol, t, y, y_dot, inputs):
symbol.y.ravel(order="F"),
converted_children,
)
elif solver == "bspline" and len(converted_children) == 2:
bspline = interpolate.RectBivariateSpline(
symbol.x[0], symbol.x[1], symbol.y
)
[tx, ty, c] = bspline.tck
[kx, ky] = bspline.degrees
knots = [tx, ty]
coeffs = c
degree = [kx, ky]
m = 1
f = casadi.Function.bspline(
symbol.name, knots, coeffs, degree, m
)
return f(casadi.hcat(converted_children).T).T
else:
LUT = casadi.interpolant(
"LUT", solver, symbol.x, symbol.y.ravel(order="F")
Expand Down

0 comments on commit f05cae2

Please sign in to comment.