Skip to content

Commit

Permalink
add arg_previous_greater
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Feb 25, 2024
1 parent 5f9c851 commit 6e11ff5
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 34 deletions.
4 changes: 2 additions & 2 deletions polars_xdt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import polars_xdt.namespace # noqa: F401
from polars_xdt.functions import (
arg_prev_greater_value,
arg_previous_greater,
ceil,
day_name,
format_localized,
Expand Down Expand Up @@ -30,6 +30,6 @@
"to_julian_date",
"to_local_datetime",
"workday_count",
"arg_prev_greater_value",
"arg_previous_greater",
"__version__",
]
76 changes: 65 additions & 11 deletions polars_xdt/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,44 +732,98 @@ def workday_count(
},
)

def arg_prev_greater_value(
expr: IntoExpr
) -> pl.Expr:

def arg_previous_greater(expr: IntoExpr) -> pl.Expr:
"""
Find the row count of the previous value greater than the current one.
Parameters
----------
expr
Expression.
Returns
-------
Expr
UInt64 or UInt32 type, depending on the platform.
Examples
--------
>>> import polars as pl
>>> import polars_xdt as xdt
>>> df = pl.DataFrame({'value': [1, 9, 6, 7, 3]})
>>> df.with_columns(result=xdt.arg_prev_greater_value('value'))
>>> df = pl.DataFrame({"value": [1, 9, 6, 7, 3]})
>>> df.with_columns(result=xdt.arg_previous_greater("value"))
shape: (5, 2)
┌───────┬────────┐
│ value ┆ result │
│ --- ┆ --- │
│ i64 ┆ u32 │
╞═══════╪════════╡
│ 1 ┆ null │
│ 9 ┆ 0
│ 9 ┆ 1
│ 6 ┆ 1 │
│ 7 ┆ 2
│ 3 ┆ 1
│ 7 ┆ 1
│ 3 ┆ 3
└───────┴────────┘
This can be useful when working with time series. For example,
if you have a dataset like this:
>>> df = pl.DataFrame(
... {
... "date": [
... "2024-02-01",
... "2024-02-02",
... "2024-02-03",
... "2024-02-04",
... "2024-02-05",
... "2024-02-06",
... "2024-02-07",
... "2024-02-08",
... "2024-02-09",
... "2024-02-10",
... ],
... "group": ["A", "A", "A", "A", "A", "B", "B", "B", "B", "B"],
... "value": [1, 9, None, 7, 3, 2, 4, 5, 1, 9],
... }
... )
>>> df = df.with_columns(pl.col("date").str.to_date())
and want to find out, for each day and each item, how many days it's
been since `'value'` was higher than it currently is, you could do
>>> df.with_columns(
... result=(
... (
... pl.col("date")
... - pl.col("date")
... .gather(xdt.arg_previous_greater("value"))
... .over("group")
... ).dt.total_days()
... ),
... )
shape: (10, 4)
┌────────────┬───────┬───────┬────────┐
│ date ┆ group ┆ value ┆ result │
│ --- ┆ --- ┆ --- ┆ --- │
│ date ┆ str ┆ i64 ┆ i64 │
╞════════════╪═══════╪═══════╪════════╡
│ 2024-02-01 ┆ A ┆ 1 ┆ null │
│ 2024-02-02 ┆ A ┆ 9 ┆ 0 │
│ 2024-02-03 ┆ A ┆ null ┆ null │
│ 2024-02-04 ┆ A ┆ 7 ┆ 2 │
│ 2024-02-05 ┆ A ┆ 3 ┆ 1 │
│ 2024-02-06 ┆ B ┆ 2 ┆ null │
│ 2024-02-07 ┆ B ┆ 4 ┆ 0 │
│ 2024-02-08 ┆ B ┆ 5 ┆ 0 │
│ 2024-02-09 ┆ B ┆ 1 ┆ 1 │
│ 2024-02-10 ┆ B ┆ 9 ┆ 0 │
└────────────┴───────┴───────┴────────┘
"""
expr = parse_into_expr(expr)
return expr.register_plugin(
lib=lib,
symbol="arg_prev_greater_value",
symbol="arg_previous_greater",
is_elementwise=False,
)
16 changes: 9 additions & 7 deletions src/arg_prev_greater_value.rs → src/arg_previous_greater.rs
Original file line number Diff line number Diff line change
@@ -1,32 +1,34 @@
use polars::prelude::*;

pub(crate) fn impl_arg_prev_greater_value<T>(ca: &ChunkedArray<T>) -> IdxCa
where T: PolarsNumericType {
pub(crate) fn impl_arg_previous_greater<T>(ca: &ChunkedArray<T>) -> IdxCa
where
T: PolarsNumericType,
{
let mut idx: Vec<Option<i32>> = Vec::with_capacity(ca.len());
let out: IdxCa = ca
.into_iter()
.enumerate()
.map(|(i, opt_val)| {
if opt_val.is_none() {
idx.push(None);
return None
return None;
}
let i_curr = i;
let mut i = Some((i as i32) - 1); // look at previous element
let mut i = Some((i as i32) - 1); // look at previous element
while i >= Some(0) && ca.get(i.unwrap() as usize).is_none() {
// find previous non-null value
i = Some(i.unwrap()-1)
i = Some(i.unwrap() - 1)
}
if i < Some(0) {
idx.push(None);
return None
return None;
}
while i.is_some() && opt_val >= ca.get(i.unwrap() as usize) {
i = idx[i.unwrap() as usize];
}
if i.is_none() {
idx.push(None);
return Some(i_curr as IdxSize)
return Some(i_curr as IdxSize);
}
idx.push(i);
i.map(|x| x as IdxSize)
Expand Down
22 changes: 9 additions & 13 deletions src/expressions.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#![allow(clippy::unit_arg, clippy::unused_unit)]
use crate::arg_previous_greater::*;
use crate::business_days::*;
use crate::format_localized::*;
use crate::is_workday::*;
use crate::arg_prev_greater_value::*;
use crate::sub::*;
use crate::timezone::*;
use crate::to_julian::*;
Expand Down Expand Up @@ -154,19 +154,15 @@ fn list_idx_dtype(input_fields: &[Field]) -> PolarsResult<Field> {
}

#[polars_expr(output_type_func=list_idx_dtype)]
fn arg_prev_greater_value(inputs: &[Series]) -> PolarsResult<Series> {
fn arg_previous_greater(inputs: &[Series]) -> PolarsResult<Series> {
let ser = &inputs[0];
// steps:
// 1. make generic on inputs[0]
// 2. optionally accept second argument?
// or at least, take-based solution
match ser.dtype() {
DataType::Int64 => Ok(impl_arg_prev_greater_value(ser.i64().unwrap()).into_series()),
DataType::Int32 => Ok(impl_arg_prev_greater_value(ser.i32().unwrap()).into_series()),
DataType::UInt64 => Ok(impl_arg_prev_greater_value(ser.u64().unwrap()).into_series()),
DataType::UInt32 => Ok(impl_arg_prev_greater_value(ser.u32().unwrap()).into_series()),
DataType::Float64 => Ok(impl_arg_prev_greater_value(ser.f64().unwrap()).into_series()),
DataType::Float32 => Ok(impl_arg_prev_greater_value(ser.f32().unwrap()).into_series()),
dt => polars_bail!(ComputeError:"Expected numeric data type, got: {}", dt)
DataType::Int64 => Ok(impl_arg_previous_greater(ser.i64().unwrap()).into_series()),
DataType::Int32 => Ok(impl_arg_previous_greater(ser.i32().unwrap()).into_series()),
DataType::UInt64 => Ok(impl_arg_previous_greater(ser.u64().unwrap()).into_series()),
DataType::UInt32 => Ok(impl_arg_previous_greater(ser.u32().unwrap()).into_series()),
DataType::Float64 => Ok(impl_arg_previous_greater(ser.f64().unwrap()).into_series()),
DataType::Float32 => Ok(impl_arg_previous_greater(ser.f32().unwrap()).into_series()),
dt => polars_bail!(ComputeError:"Expected numeric data type, got: {}", dt),
}
}
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod arg_previous_greater;
mod business_days;
mod expressions;
mod format_localized;
Expand All @@ -6,7 +7,6 @@ mod sub;
mod timezone;
mod to_julian;
mod utc_offsets;
mod arg_prev_greater_value;

use pyo3::types::PyModule;
use pyo3::{pymodule, PyResult, Python};
Expand Down

0 comments on commit 6e11ff5

Please sign in to comment.