Skip to content

Commit

Permalink
Merge pull request #2 from rustodeans/main
Browse files Browse the repository at this point in the history
Cargo format
  • Loading branch information
martinjrobins authored Mar 21, 2024
2 parents 8e09490 + a815637 commit 8d45e9c
Show file tree
Hide file tree
Showing 30 changed files with 894 additions and 649 deletions.
90 changes: 49 additions & 41 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,3 @@
#[cfg(feature = "diffsl-llvm4")]
pub extern crate diffsl4_0 as diffsl;
#[cfg(feature = "diffsl-llvm5")]
pub extern crate diffsl5_0 as diffsl;
#[cfg(feature = "diffsl-llvm6")]
pub extern crate diffsl6_0 as diffsl;
#[cfg(feature = "diffsl-llvm7")]
pub extern crate diffsl7_0 as diffsl;
#[cfg(feature = "diffsl-llvm8")]
pub extern crate diffsl8_0 as diffsl;
#[cfg(feature = "diffsl-llvm9")]
pub extern crate diffsl9_0 as diffsl;
#[cfg(feature = "diffsl-llvm10")]
pub extern crate diffsl10_0 as diffsl;
#[cfg(feature = "diffsl-llvm11")]
Expand All @@ -26,41 +14,68 @@ pub extern crate diffsl15_0 as diffsl;
pub extern crate diffsl16_0 as diffsl;
#[cfg(feature = "diffsl-llvm17")]
pub extern crate diffsl17_0 as diffsl;
#[cfg(feature = "diffsl-llvm4")]
pub extern crate diffsl4_0 as diffsl;
#[cfg(feature = "diffsl-llvm5")]
pub extern crate diffsl5_0 as diffsl;
#[cfg(feature = "diffsl-llvm6")]
pub extern crate diffsl6_0 as diffsl;
#[cfg(feature = "diffsl-llvm7")]
pub extern crate diffsl7_0 as diffsl;
#[cfg(feature = "diffsl-llvm8")]
pub extern crate diffsl8_0 as diffsl;
#[cfg(feature = "diffsl-llvm9")]
pub extern crate diffsl9_0 as diffsl;

pub trait Scalar: nalgebra::Scalar + From<f64> + Display + SimdRealField + ComplexField + Copy + ClosedSub + From<f64> + ClosedMul + ClosedDiv + ClosedAdd + Signed + PartialOrd + Pow<Self, Output=Self> + Pow<i32, Output=Self> {
pub trait Scalar:
nalgebra::Scalar
+ From<f64>
+ Display
+ SimdRealField
+ ComplexField
+ Copy
+ ClosedSub
+ From<f64>
+ ClosedMul
+ ClosedDiv
+ ClosedAdd
+ Signed
+ PartialOrd
+ Pow<Self, Output = Self>
+ Pow<i32, Output = Self>
{
const EPSILON: Self;
const INFINITY: Self;
}

type IndexType = usize;


impl Scalar for f64 {
const EPSILON: Self = f64::EPSILON;
const INFINITY: Self = f64::INFINITY;
}


pub mod vector;
pub mod matrix;
pub mod linear_solver;
pub mod op;
pub mod matrix;
pub mod nonlinear_solver;
pub mod ode_solver;
pub mod op;
pub mod solver;
pub mod vector;

use std::fmt::Display;

use nalgebra::{ClosedSub, ClosedMul, ClosedDiv, ClosedAdd, SimdRealField, ComplexField};
use num_traits::{Signed, Pow};
use vector::{Vector, VectorView, VectorViewMut, VectorIndex, VectorRef};
use nonlinear_solver::{NonLinearSolver, newton::NewtonNonlinearSolver};
use matrix::{DenseMatrix, MatrixViewMut, Matrix};
use linear_solver::{lu::LU, LinearSolver};
use matrix::{DenseMatrix, Matrix, MatrixViewMut};
use nalgebra::{ClosedAdd, ClosedDiv, ClosedMul, ClosedSub, ComplexField, SimdRealField};
use nonlinear_solver::{newton::NewtonNonlinearSolver, NonLinearSolver};
use num_traits::{Pow, Signed};
pub use ode_solver::{
bdf::Bdf, equations::OdeEquations, OdeSolverMethod, OdeSolverProblem, OdeSolverState,
};
use op::{LinearOp, NonLinearOp};
use solver::SolverProblem;
use linear_solver::{lu::LU, LinearSolver};
pub use ode_solver::{OdeSolverProblem, OdeSolverState, bdf::Bdf, OdeSolverMethod, equations::OdeEquations};

use vector::{Vector, VectorIndex, VectorRef, VectorView, VectorViewMut};

#[cfg(test)]
mod tests {
Expand All @@ -75,19 +90,20 @@ mod tests {
type V = nalgebra::DVector<T>;
let p = V::from_vec(vec![0.04, 1.0e4, 3.0e7]);
let mut problem = OdeSolverProblem::new_ode(
| x: &V, p: &V, _t: T, y: &mut V | {
|x: &V, p: &V, _t: T, y: &mut V| {
y[0] = -p[0] * x[0] + p[1] * x[1] * x[2];
y[1] = p[0] * x[0] - p[1] * x[1] * x[2] - p[2] * x[1] * x[1];
y[2] = p[2] * x[1] * x[1];
},
| x: &V, p: &V, _t: T, v: &V, y: &mut V | {
|x: &V, p: &V, _t: T, v: &V, y: &mut V| {
y[0] = -p[0] * v[0] + p[1] * v[1] * x[2] + p[1] * x[1] * v[2];
y[1] = p[0] * v[0] - p[1] * v[1] * x[2] - p[1] * x[1] * v[2] - 2.0 * p[2] * x[1] * v[1];
y[1] = p[0] * v[0]
- p[1] * v[1] * x[2]
- p[1] * x[1] * v[2]
- 2.0 * p[2] * x[1] * v[1];
y[2] = 2.0 * p[2] * x[1] * v[1];
},
| _p: &V, _t: T | {
V::from_vec(vec![1.0, 0.0, 0.0])
},
|_p: &V, _t: T| V::from_vec(vec![1.0, 0.0, 0.0]),
p,
);
problem.rtol = 1.0e-4;
Expand All @@ -104,15 +120,7 @@ mod tests {
solver.step(&mut state).unwrap();
}
let y2 = solver.interpolate(&state, t);

y2.assert_eq(&y, 1e-6);


y2.assert_eq(&y, 1e-6);
}
}






28 changes: 13 additions & 15 deletions src/linear_solver/gmres.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,25 @@
use anyhow::Result;

use crate::{LinearSolver, vector::VectorRef, op::LinearOp, SolverProblem};
use crate::{op::LinearOp, vector::VectorRef, LinearSolver, SolverProblem};


pub struct GMRES<C: LinearOp>
where
for <'a> &'a C::V: VectorRef<C::V>,
pub struct GMRES<C: LinearOp>
where
for<'a> &'a C::V: VectorRef<C::V>,
{
_phantom: std::marker::PhantomData<C>,
}

impl<C: LinearOp> GMRES<C>
impl<C: LinearOp> GMRES<C>
where
for <'a> &'a C::V: VectorRef<C::V>,
for<'a> &'a C::V: VectorRef<C::V>,
{
// ...
}

// implement default for gmres
impl<C: LinearOp> Default for GMRES<C>
impl<C: LinearOp> Default for GMRES<C>
where
for <'a> &'a C::V: VectorRef<C::V>,
for<'a> &'a C::V: VectorRef<C::V>,
{
fn default() -> Self {
Self {
Expand All @@ -29,27 +28,26 @@ where
}
}

impl <C: LinearOp> LinearSolver<C> for GMRES<C>
impl<C: LinearOp> LinearSolver<C> for GMRES<C>
where
for <'b> &'b C::V: VectorRef<C::V>,
for<'b> &'b C::V: VectorRef<C::V>,
{

fn problem(&self) -> Option<&SolverProblem<C>> {
todo!()
}
}
fn problem_mut(&mut self) -> Option<&mut SolverProblem<C>> {
todo!()
}

fn take_problem(&mut self) -> Option<SolverProblem<C>> {
todo!()
}

fn solve_in_place(&mut self, _state: &mut C::V) -> Result<()> {
todo!()
}

fn set_problem(&mut self, _problem: SolverProblem<C>) {
todo!()
}
}
}
7 changes: 3 additions & 4 deletions src/linear_solver/lu.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use nalgebra::{DVector, Dyn, DMatrix};
use anyhow::Result;
use nalgebra::{DMatrix, DVector, Dyn};

use crate::{op::LinearOp, LinearSolver, Scalar, SolverProblem};

/// A [LinearSolver] that uses the LU decomposition in the [`nalgebra` library](https://nalgebra.org/) to solve the linear system.
pub struct LU<T, C>
pub struct LU<T, C>
where
T: Scalar,
C: LinearOp<M = DMatrix<T>, V = DVector<T>, T = T>,
Expand All @@ -13,7 +13,7 @@ where
problem: Option<SolverProblem<C>>,
}

impl<T, C> Default for LU<T, C>
impl<T, C> Default for LU<T, C>
where
T: Scalar,
C: LinearOp<M = DMatrix<T>, V = DVector<T>, T = T>,
Expand Down Expand Up @@ -54,4 +54,3 @@ impl<T: Scalar, C: LinearOp<M = DMatrix<T>, V = DVector<T>, T = T>> LinearSolver
self.problem = Some(problem);
}
}

53 changes: 32 additions & 21 deletions src/linear_solver/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use crate::{op::Op, solver::SolverProblem};
use anyhow::Result;

pub mod lu;
pub mod gmres;
pub mod lu;

/// A solver for the linear problem `Ax = b`.
/// A solver for the linear problem `Ax = b`.
/// The solver is parameterised by the type `C` which is the type of the linear operator `A` (see the [Op] trait for more details).
pub trait LinearSolver<C: Op> {
/// Set the problem to be solved, any previous problem is discarded.
Expand All @@ -13,25 +13,25 @@ pub trait LinearSolver<C: Op> {

/// Get a reference to the current problem, if any.
fn problem(&self) -> Option<&SolverProblem<C>>;

/// Get a mutable reference to the current problem, if any.
fn problem_mut(&mut self) -> Option<&mut SolverProblem<C>>;

/// Take the current problem, if any, and return it.
fn take_problem(&mut self) -> Option<SolverProblem<C>>;
fn reset(&mut self) {
if let Some(problem) = self.take_problem() {
self.set_problem(problem);
}
}

/// Solve the problem `Ax = b` and return the solution `x`.
fn solve(&mut self, b: &C::V) -> Result<C::V> {
let mut b = b.clone();
self.solve_in_place(&mut b)?;
Ok(b)
}

fn solve_in_place(&mut self, b: &mut C::V) -> Result<()>;
}

Expand All @@ -40,46 +40,57 @@ pub struct LinearSolveSolution<V> {
pub b: V,
}

impl <V> LinearSolveSolution<V> {
impl<V> LinearSolveSolution<V> {
pub fn new(b: V, x: V) -> Self {
Self { x, b }
}
}



#[cfg(test)]
pub mod tests {
use std::rc::Rc;

use crate::{op::{linear_closure::LinearClosure, LinearOp}, LinearSolver, vector::VectorRef, DenseMatrix, SolverProblem, Vector, LU};
use crate::{
op::{linear_closure::LinearClosure, LinearOp},
vector::VectorRef,
DenseMatrix, LinearSolver, SolverProblem, Vector, LU,
};
use num_traits::{One, Zero};

use super::LinearSolveSolution;

fn linear_problem<M: DenseMatrix + 'static>() -> (SolverProblem<impl LinearOp<M = M, V = M::V, T = M::T>>, Vec<LinearSolveSolution<M::V>>) {
fn linear_problem<M: DenseMatrix + 'static>() -> (
SolverProblem<impl LinearOp<M = M, V = M::V, T = M::T>>,
Vec<LinearSolveSolution<M::V>>,
) {
let diagonal = M::V::from_vec(vec![2.0.into(), 2.0.into()]);
let jac = M::from_diagonal(&diagonal);
let p = Rc::new(M::V::zeros(0));
let op = Rc::new(LinearClosure::new(
// f = J * x
move | x, _p, _t, y | jac.gemv(M::T::one(), x, M::T::zero(), y),
2, 2, p
move |x, _p, _t, y| jac.gemv(M::T::one(), x, M::T::zero(), y),
2,
2,
p,
));
let t = M::T::zero();
let rtol = M::T::from(1e-6);
let atol = Rc::new(M::V::from_vec(vec![1e-6.into(), 1e-6.into()]));
let problem = SolverProblem::new(op, t, atol, rtol);
let solns = vec![
LinearSolveSolution::new(M::V::from_vec(vec![2.0.into(), 4.0.into()]), M::V::from_vec(vec![1.0.into(), 2.0.into()]))
];
let solns = vec![LinearSolveSolution::new(
M::V::from_vec(vec![2.0.into(), 4.0.into()]),
M::V::from_vec(vec![1.0.into(), 2.0.into()]),
)];
(problem, solns)
}

pub fn test_linear_solver<C>(mut solver: impl LinearSolver<C>, problem: SolverProblem<C>, solns: Vec<LinearSolveSolution<C::V>>)
where
pub fn test_linear_solver<C>(
mut solver: impl LinearSolver<C>,
problem: SolverProblem<C>,
solns: Vec<LinearSolveSolution<C::V>>,
) where
C: LinearOp,
for <'a> &'a C::V: VectorRef<C::V>,
for<'a> &'a C::V: VectorRef<C::V>,
{
solver.set_problem(problem);
for soln in solns {
Expand All @@ -93,11 +104,11 @@ pub mod tests {
}

type MCpu = nalgebra::DMatrix<f64>;

#[test]
fn test_lu() {
let (p, solns) = linear_problem::<MCpu>();
let s = LU::default();
test_linear_solver(s, p, solns);
}
}
}
Loading

0 comments on commit 8d45e9c

Please sign in to comment.