diff --git a/Cargo.toml b/Cargo.toml
index 134c902..cb53990 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "polars_xdt"
-version = "0.12.11"
+version = "0.13.0"
 edition = "2021"
 
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
diff --git a/polars_xdt/__init__.py b/polars_xdt/__init__.py
index 9099b37..abe34c1 100644
--- a/polars_xdt/__init__.py
+++ b/polars_xdt/__init__.py
@@ -2,6 +2,7 @@
 import polars_xdt.namespace  # noqa: F401
 
 from polars_xdt.functions import (
+    arg_previous_greater,
     ceil,
     day_name,
     format_localized,
@@ -29,5 +30,6 @@
     "to_julian_date",
     "to_local_datetime",
     "workday_count",
+    "arg_previous_greater",
     "__version__",
 ]
diff --git a/polars_xdt/functions.py b/polars_xdt/functions.py
index 366d7a8..38307bd 100644
--- a/polars_xdt/functions.py
+++ b/polars_xdt/functions.py
@@ -731,3 +731,102 @@ def workday_count(
             "holidays": holidays_int,
         },
     )
+
+
+def arg_previous_greater(expr: IntoExpr) -> pl.Expr:
+    """
+    Find the row index of the previous value greater than the current one.
+
+    If no previous value is greater, the row's own index is returned.
+    Null values, and values with no non-null value before them, give null.
+
+    Parameters
+    ----------
+    expr
+        Expression.
+
+    Returns
+    -------
+    Expr
+        UInt64 or UInt32 type, depending on the platform.
+
+    Examples
+    --------
+    >>> import polars as pl
+    >>> import polars_xdt as xdt
+    >>> df = pl.DataFrame({"value": [1, 9, 6, 7, 3]})
+    >>> df.with_columns(result=xdt.arg_previous_greater("value"))
+    shape: (5, 2)
+    ┌───────┬────────┐
+    │ value ┆ result │
+    │ ---   ┆ ---    │
+    │ i64   ┆ u32    │
+    ╞═══════╪════════╡
+    │ 1     ┆ null   │
+    │ 9     ┆ 1      │
+    │ 6     ┆ 1      │
+    │ 7     ┆ 1      │
+    │ 3     ┆ 3      │
+    └───────┴────────┘
+
+    This can be useful when working with time series. For example,
+    if you have a dataset like this:
+
+    >>> df = pl.DataFrame(
+    ...     {
+    ...         "date": [
+    ...             "2024-02-01",
+    ...             "2024-02-02",
+    ...             "2024-02-03",
+    ...             "2024-02-04",
+    ...             "2024-02-05",
+    ...             "2024-02-06",
+    ...             "2024-02-07",
+    ...             "2024-02-08",
+    ...             "2024-02-09",
+    ...             "2024-02-10",
+    ...         ],
+    ...         "group": ["A", "A", "A", "A", "A", "B", "B", "B", "B", "B"],
+    ...         "value": [1, 9, None, 7, 3, 2, 4, 5, 1, 9],
+    ...     }
+    ... )
+    >>> df = df.with_columns(pl.col("date").str.to_date())
+
+    and want to find out, for each day and each item, how many days it's
+    been since `'value'` was higher than it currently is, you could do
+
+    >>> df.with_columns(
+    ...     result=(
+    ...         (
+    ...             pl.col("date")
+    ...             - pl.col("date")
+    ...             .gather(xdt.arg_previous_greater("value"))
+    ...             .over("group")
+    ...         ).dt.total_days()
+    ...     ),
+    ... )
+    shape: (10, 4)
+    ┌────────────┬───────┬───────┬────────┐
+    │ date       ┆ group ┆ value ┆ result │
+    │ ---        ┆ ---   ┆ ---   ┆ ---    │
+    │ date       ┆ str   ┆ i64   ┆ i64    │
+    ╞════════════╪═══════╪═══════╪════════╡
+    │ 2024-02-01 ┆ A     ┆ 1     ┆ null   │
+    │ 2024-02-02 ┆ A     ┆ 9     ┆ 0      │
+    │ 2024-02-03 ┆ A     ┆ null  ┆ null   │
+    │ 2024-02-04 ┆ A     ┆ 7     ┆ 2      │
+    │ 2024-02-05 ┆ A     ┆ 3     ┆ 1      │
+    │ 2024-02-06 ┆ B     ┆ 2     ┆ null   │
+    │ 2024-02-07 ┆ B     ┆ 4     ┆ 0      │
+    │ 2024-02-08 ┆ B     ┆ 5     ┆ 0      │
+    │ 2024-02-09 ┆ B     ┆ 1     ┆ 1      │
+    │ 2024-02-10 ┆ B     ┆ 9     ┆ 0      │
+    └────────────┴───────┴───────┴────────┘
+
+    """
+    expr = parse_into_expr(expr)
+    return expr.register_plugin(
+        lib=lib,
+        symbol="arg_previous_greater",
+        is_elementwise=False,
+    )
diff --git a/src/arg_previous_greater.rs b/src/arg_previous_greater.rs
new file mode 100644
index 0000000..b3e4907
--- /dev/null
+++ b/src/arg_previous_greater.rs
@@ -0,0 +1,37 @@
+use polars::prelude::*;
+
+/// For every element of `ca`, find the index of the closest preceding element
+/// that is strictly greater than it.
+///
+/// * Null elements, and elements with no non-null element before them, map to null.
+/// * Elements with no greater predecessor map to their own index.
+pub(crate) fn impl_arg_previous_greater<T>(ca: &ChunkedArray<T>) -> IdxCa
+where
+    T: PolarsNumericType,
+{
+    // `prev_greater[j]` caches the answer for row `j` (`None` if it has none), so a
+    // candidate that is not greater lets us jump straight to *its* previous greater
+    // element instead of rescanning every row in between.
+    let mut prev_greater: Vec<Option<usize>> = Vec::with_capacity(ca.len());
+    ca.into_iter()
+        .enumerate()
+        .map(|(i_curr, opt_val)| {
+            if opt_val.is_none() {
+                prev_greater.push(None);
+                return None;
+            }
+            // Closest preceding non-null element. Positions stay `usize` throughout:
+            // casting them to `i32` would wrap for inputs longer than `i32::MAX` rows.
+            let Some(start) = (0..i_curr).rev().find(|&j| ca.get(j).is_some()) else {
+                prev_greater.push(None);
+                return None;
+            };
+            let mut candidate = Some(start);
+            while let Some(j) = candidate.filter(|&j| opt_val >= ca.get(j)) {
+                candidate = prev_greater[j];
+            }
+            prev_greater.push(candidate);
+            Some(candidate.unwrap_or(i_curr) as IdxSize)
+        })
+        .collect()
+}
diff --git a/src/expressions.rs b/src/expressions.rs
index 9f5d8a2..8331725 100644
--- a/src/expressions.rs
+++ b/src/expressions.rs
@@ -1,4 +1,5 @@
 #![allow(clippy::unit_arg, clippy::unused_unit)]
+use crate::arg_previous_greater::*;
 use crate::business_days::*;
 use crate::format_localized::*;
 use crate::is_workday::*;
@@ -146,3 +147,23 @@ fn dst_offset(inputs: &[Series]) -> PolarsResult<Series> {
         _ => polars_bail!(InvalidOperation: "base_utc_offset only works on Datetime type."),
     }
 }
+
+/// Output schema of `arg_previous_greater`: a flat column of row indices, not a list.
+fn idx_dtype(input_fields: &[Field]) -> PolarsResult<Field> {
+    Ok(Field::new(input_fields[0].name(), IDX_DTYPE))
+}
+
+// Dispatches on the input dtype; each `unwrap` is guarded by the matching dtype arm.
+#[polars_expr(output_type_func=idx_dtype)]
+fn arg_previous_greater(inputs: &[Series]) -> PolarsResult<Series> {
+    let ser = &inputs[0];
+    match ser.dtype() {
+        DataType::Int64 => Ok(impl_arg_previous_greater(ser.i64().unwrap()).into_series()),
+        DataType::Int32 => Ok(impl_arg_previous_greater(ser.i32().unwrap()).into_series()),
+        DataType::UInt64 => Ok(impl_arg_previous_greater(ser.u64().unwrap()).into_series()),
+        DataType::UInt32 => Ok(impl_arg_previous_greater(ser.u32().unwrap()).into_series()),
+        DataType::Float64 => Ok(impl_arg_previous_greater(ser.f64().unwrap()).into_series()),
+        DataType::Float32 => Ok(impl_arg_previous_greater(ser.f32().unwrap()).into_series()),
+        dt => polars_bail!(ComputeError:"Expected numeric data type, got: {}", dt),
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 5698d4f..dbc8d40 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,3 +1,4 @@
+mod arg_previous_greater;
 mod business_days;
 mod expressions;
 mod format_localized;