Skip to content

Commit

Permalink
Merge pull request #61 from pola-rs/prev-greater
Browse files Browse the repository at this point in the history
Add arg_prev_greater
  • Loading branch information
MarcoGorelli authored Feb 25, 2024
2 parents d69b299 + 0f51e60 commit 835d739
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "polars_xdt"
version = "0.12.11"
version = "0.13.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand Down
2 changes: 2 additions & 0 deletions polars_xdt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import polars_xdt.namespace # noqa: F401
from polars_xdt.functions import (
arg_previous_greater,
ceil,
day_name,
format_localized,
Expand Down Expand Up @@ -29,5 +30,6 @@
"to_julian_date",
"to_local_datetime",
"workday_count",
"arg_previous_greater",
"__version__",
]
96 changes: 96 additions & 0 deletions polars_xdt/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,3 +731,99 @@ def workday_count(
"holidays": holidays_int,
},
)


def arg_previous_greater(expr: IntoExpr) -> pl.Expr:
    """
    Locate, for each row, the closest earlier row holding a greater value.

    The result is a row index that can be passed to ``gather``. As the
    examples below show, a row with no earlier value to compare against
    yields null, and a row for which no earlier value is greater yields
    its own row index.

    Parameters
    ----------
    expr
        Expression.

    Returns
    -------
    Expr
        UInt64 or UInt32 type, depending on the platform.

    Examples
    --------
    >>> import polars as pl
    >>> import polars_xdt as xdt
    >>> df = pl.DataFrame({"value": [1, 9, 6, 7, 3]})
    >>> df.with_columns(result=xdt.arg_previous_greater("value"))
    shape: (5, 2)
    ┌───────┬────────┐
    │ value ┆ result │
    │ ---   ┆ ---    │
    │ i64   ┆ u32    │
    ╞═══════╪════════╡
    │ 1     ┆ null   │
    │ 9     ┆ 1      │
    │ 6     ┆ 1      │
    │ 7     ┆ 1      │
    │ 3     ┆ 3      │
    └───────┴────────┘

    This can be useful when working with time series. For example,
    if you have a dataset like this:

    >>> df = pl.DataFrame(
    ...     {
    ...         "date": [
    ...             "2024-02-01",
    ...             "2024-02-02",
    ...             "2024-02-03",
    ...             "2024-02-04",
    ...             "2024-02-05",
    ...             "2024-02-06",
    ...             "2024-02-07",
    ...             "2024-02-08",
    ...             "2024-02-09",
    ...             "2024-02-10",
    ...         ],
    ...         "group": ["A", "A", "A", "A", "A", "B", "B", "B", "B", "B"],
    ...         "value": [1, 9, None, 7, 3, 2, 4, 5, 1, 9],
    ...     }
    ... )
    >>> df = df.with_columns(pl.col("date").str.to_date())

    and want to find out, for each day and each group, how many days it
    has been since `'value'` was higher than it currently is, you could do

    >>> df.with_columns(
    ...     result=(
    ...         (
    ...             pl.col("date")
    ...             - pl.col("date")
    ...             .gather(xdt.arg_previous_greater("value"))
    ...             .over("group")
    ...         ).dt.total_days()
    ...     ),
    ... )
    shape: (10, 4)
    ┌────────────┬───────┬───────┬────────┐
    │ date       ┆ group ┆ value ┆ result │
    │ ---        ┆ ---   ┆ ---   ┆ ---    │
    │ date       ┆ str   ┆ i64   ┆ i64    │
    ╞════════════╪═══════╪═══════╪════════╡
    │ 2024-02-01 ┆ A     ┆ 1     ┆ null   │
    │ 2024-02-02 ┆ A     ┆ 9     ┆ 0      │
    │ 2024-02-03 ┆ A     ┆ null  ┆ null   │
    │ 2024-02-04 ┆ A     ┆ 7     ┆ 2      │
    │ 2024-02-05 ┆ A     ┆ 3     ┆ 1      │
    │ 2024-02-06 ┆ B     ┆ 2     ┆ null   │
    │ 2024-02-07 ┆ B     ┆ 4     ┆ 0      │
    │ 2024-02-08 ┆ B     ┆ 5     ┆ 0      │
    │ 2024-02-09 ┆ B     ┆ 1     ┆ 1      │
    │ 2024-02-10 ┆ B     ┆ 9     ┆ 0      │
    └────────────┴───────┴───────┴────────┘
    """
    # Each output row depends on earlier rows, so the plugin must not be
    # registered as elementwise.
    return parse_into_expr(expr).register_plugin(
        lib=lib,
        symbol="arg_previous_greater",
        is_elementwise=False,
    )
38 changes: 38 additions & 0 deletions src/arg_previous_greater.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use polars::prelude::*;

/// For every row of `ca`, compute the index of the nearest earlier row whose
/// value is strictly greater than the current one.
///
/// Per-row result:
/// * `None` when the current value is null;
/// * `None` when there is no earlier non-null value at all;
/// * the row's own index when earlier non-null values exist but none of them
///   is greater than the current value;
/// * otherwise the index of the closest earlier greater value.
///
/// Positions are tracked as `usize`, so they are not squeezed through a
/// narrower signed integer, and the nearest earlier non-null row is remembered
/// between iterations instead of being re-discovered by scanning backwards
/// over runs of nulls for every row.
pub(crate) fn impl_arg_previous_greater<T>(ca: &ChunkedArray<T>) -> IdxCa
where
    T: PolarsNumericType,
{
    // prev_greater[k] is the answer already computed for row k when that answer
    // is an *earlier* row; it is `None` for null rows and for rows that have no
    // earlier greater value. Following these links skips every value that is
    // known to be dominated by a later one.
    let mut prev_greater: Vec<Option<usize>> = Vec::with_capacity(ca.len());
    // Index of the most recent non-null row seen so far.
    let mut last_valid: Option<usize> = None;

    ca.into_iter()
        .enumerate()
        .map(|(i_curr, opt_val)| {
            if opt_val.is_none() {
                prev_greater.push(None);
                return None;
            }

            // Start from the nearest earlier non-null row, and record the
            // current row as the nearest non-null one for the rows that follow.
            let mut candidate = last_valid.replace(i_curr);
            if candidate.is_none() {
                // Nothing to compare against yet.
                prev_greater.push(None);
                return None;
            }

            // Walk the chain of "previous greater" links until a strictly
            // greater value is found or the chain runs out. Every index on the
            // chain refers to a non-null row.
            while let Some(j) = candidate {
                if opt_val >= ca.get(j) {
                    candidate = prev_greater[j];
                } else {
                    break;
                }
            }

            prev_greater.push(candidate);
            // No earlier greater value: report the row's own index.
            // The cast is an index into `ca`, which fits Polars' index type.
            Some(candidate.unwrap_or(i_curr) as IdxSize)
        })
        .collect()
}
20 changes: 20 additions & 0 deletions src/expressions.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#![allow(clippy::unit_arg, clippy::unused_unit)]
use crate::arg_previous_greater::*;
use crate::business_days::*;
use crate::format_localized::*;
use crate::is_workday::*;
Expand Down Expand Up @@ -146,3 +147,22 @@ fn dst_offset(inputs: &[Series]) -> PolarsResult<Series> {
_ => polars_bail!(InvalidOperation: "base_utc_offset only works on Datetime type."),
}
}

/// Output-schema callback used when registering the `arg_previous_greater`
/// expression.
///
/// Returns a field named after the first input field, typed as Polars' index
/// dtype (`IDX_DTYPE`). The expression produces a flat `IdxCa` converted to a
/// `Series`, not a list column, so the declared dtype must be the plain index
/// dtype for the planned schema to agree with the data actually returned.
///
/// The function name is kept unchanged because the `#[polars_expr]` attribute
/// refers to it by name.
fn list_idx_dtype(input_fields: &[Field]) -> PolarsResult<Field> {
    Ok(Field::new(input_fields[0].name(), IDX_DTYPE))
}

#[polars_expr(output_type_func=list_idx_dtype)]
fn arg_previous_greater(inputs: &[Series]) -> PolarsResult<Series> {
let ser = &inputs[0];
match ser.dtype() {
DataType::Int64 => Ok(impl_arg_previous_greater(ser.i64().unwrap()).into_series()),
DataType::Int32 => Ok(impl_arg_previous_greater(ser.i32().unwrap()).into_series()),
DataType::UInt64 => Ok(impl_arg_previous_greater(ser.u64().unwrap()).into_series()),
DataType::UInt32 => Ok(impl_arg_previous_greater(ser.u32().unwrap()).into_series()),
DataType::Float64 => Ok(impl_arg_previous_greater(ser.f64().unwrap()).into_series()),
DataType::Float32 => Ok(impl_arg_previous_greater(ser.f32().unwrap()).into_series()),
dt => polars_bail!(ComputeError:"Expected numeric data type, got: {}", dt),
}
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod arg_previous_greater;
mod business_days;
mod expressions;
mod format_localized;
Expand Down

0 comments on commit 835d739

Please sign in to comment.