Skip to content

Commit

Permalink
Merge pull request #63 from pola-rs/ewm-time
Browse files Browse the repository at this point in the history
Add Ewma_by_time
  • Loading branch information
MarcoGorelli authored Mar 3, 2024
2 parents 835d739 + 876f9ed commit 0f1a13b
Show file tree
Hide file tree
Showing 8 changed files with 268 additions and 7 deletions.
12 changes: 6 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "polars_xdt"
version = "0.13.0"
version = "0.14.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand All @@ -10,14 +10,14 @@ crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }
pyo3-polars = { version = "0.11.1", features = ["derive"] }
pyo3-polars = { version = "0.12.0", features = ["derive"] }
serde = { version = "1", features = ["derive"] }
chrono = { version = "0.4.31", default-features = false, features = ["std", "unstable-locales"] }
chrono-tz = "0.8.5"
polars = { version = "0.37.0", features = ["strings", "dtype-date"], default-features = false }
polars-time = { version = "0.37.0", features = ["timezones"], default-features = false }
polars-ops = { version = "0.37.0", default-features = false }
polars-arrow = { version = "0.37.0", default-features = false }
polars = { version = "0.38.1", features = ["strings", "dtype-date"], default-features = false }
polars-time = { version = "0.38.1", features = ["timezones"], default-features = false }
polars-ops = { version = "0.38.1", default-features = false }
polars-arrow = { version = "0.38.1", default-features = false }

[target.'cfg(target_os = "linux")'.dependencies]
jemallocator = { version = "0.5", features = ["disable_initial_exec_tls"] }
1 change: 1 addition & 0 deletions docs/API.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ API
:toctree: api/

polars_xdt.date_range
polars_xdt.ewma_by_time
polars_xdt.workday_count
polars_xdt.ceil
polars_xdt.day_name
Expand Down
2 changes: 2 additions & 0 deletions polars_xdt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
arg_previous_greater,
ceil,
day_name,
ewma_by_time,
format_localized,
from_local_datetime,
is_workday,
Expand All @@ -31,5 +32,6 @@
"to_local_datetime",
"workday_count",
"arg_previous_greater",
"ewma_by_time",
"__version__",
]
101 changes: 100 additions & 1 deletion polars_xdt/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import re
import sys
from datetime import date
from datetime import date, timedelta
from typing import TYPE_CHECKING, Literal, Sequence

import polars as pl
Expand Down Expand Up @@ -827,3 +827,102 @@ def arg_previous_greater(expr: IntoExpr) -> pl.Expr:
symbol="arg_previous_greater",
is_elementwise=False,
)


def ewma_by_time(
values: IntoExpr,
*,
times: IntoExpr,
halflife: timedelta,
adjust: bool = True,
) -> pl.Expr:
r"""
Calculate time-based exponentially weighted moving average.
Given observations :math:`x_1, x_2, \ldots, x_n` at times
:math:`t_1, t_2, \ldots, t_n`, the **unadjusted** EWMA is calculated as
.. math::
y_0 &= x_0
\alpha_i &= \exp(-\lambda(t_i - t_{i-1}))
y_i &= \alpha_i x_i + (1 - \alpha_i) y_{i-1}; \quad i > 0
where :math:`\lambda` equals :math:`\ln(2) / \text{halflife}`.
The **adjusted** version is
.. math::
y_0 &= x_0
\alpha_i &= (\alpha_{i-1} + 1) * \exp(-\lambda(t_i - t_{i-1}))
y_i &= (x_i + \alpha_i y_{i-1}) / (1. + \alpha_i);
Parameters
----------
values
Values to calculate EWMA for. Should be signed numeric.
times
Times corresponding to `values`. Should be ``DateTime`` or ``Date``.
halflife
Unit over which observation decays to half its value.
adjust
Whether to adjust the result to account for the bias towards the
initial value. Defaults to True.
Returns
-------
pl.Expr
Float64
Examples
--------
>>> import polars as pl
>>> import polars_xdt as xdt
>>> from datetime import date, timedelta
>>> df = pl.DataFrame(
... {
... "values": [0, 1, 2, None, 4],
... "times": [
... date(2020, 1, 1),
... date(2020, 1, 3),
... date(2020, 1, 10),
... date(2020, 1, 15),
... date(2020, 1, 17),
... ],
... }
... )
>>> df.with_columns(
... ewma=xdt.ewma_by_time(
... "values", times="times", halflife=timedelta(days=4)
... ),
... )
shape: (5, 3)
┌────────┬────────────┬──────────┐
│ values ┆ times ┆ ewma │
│ --- ┆ --- ┆ --- │
│ i64 ┆ date ┆ f64 │
╞════════╪════════════╪══════════╡
│ 0 ┆ 2020-01-01 ┆ 0.0 │
│ 1 ┆ 2020-01-03 ┆ 0.585786 │
│ 2 ┆ 2020-01-10 ┆ 1.523889 │
│ null ┆ 2020-01-15 ┆ null │
│ 4 ┆ 2020-01-17 ┆ 3.233686 │
└────────┴────────────┴──────────┘
"""
times = parse_into_expr(times)
halflife_us = (
int(halflife.total_seconds()) * 1_000_000 + halflife.microseconds
)
return times.register_plugin(
lib=lib,
symbol="ewma_by_time",
is_elementwise=False,
args=[values],
kwargs={"halflife": halflife_us, "adjust": adjust},
)
84 changes: 84 additions & 0 deletions src/ewma_by_time.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
use polars::prelude::*;
use polars_arrow::array::PrimitiveArray;
use pyo3_polars::export::polars_core::export::num::Pow;

pub(crate) fn impl_ewma_by_time_float(
times: &Int64Chunked,
values: &Float64Chunked,
halflife: i64,
adjust: bool,
time_unit: TimeUnit,
) -> Float64Chunked {
let mut out = Vec::with_capacity(times.len());
if values.is_empty() {
return Float64Chunked::full_null("", times.len());
}

let halflife = match time_unit {
TimeUnit::Milliseconds => halflife / 1_000,
TimeUnit::Microseconds => halflife,
TimeUnit::Nanoseconds => halflife * 1_000,
};

let mut prev_time: i64 = times.get(0).unwrap();
let mut prev_result = values.get(0).unwrap();
let mut prev_alpha = 0.0;
out.push(Some(prev_result));
let _ = values
.iter()
.zip(times.iter())
.skip(1)
.map(|(value, time)| {
match (time, value) {
(Some(time), Some(value)) => {
let delta_time = time - prev_time;
let result: f64;
if adjust {
let alpha =
(prev_alpha + 1.) * Pow::pow(0.5, delta_time as f64 / halflife as f64);
result = (value + alpha * prev_result) / (1. + alpha);
prev_alpha = alpha;
} else {
// equivalent to:
// alpha = exp(-delta_time*ln(2) / halflife)
prev_alpha = (0.5_f64).powf(delta_time as f64 / halflife as f64);
result = (1. - prev_alpha) * value + prev_alpha * prev_result;
}
prev_time = time;
prev_result = result;
out.push(Some(result));
}
_ => out.push(None),
}
})
.collect::<Vec<_>>();
let arr = PrimitiveArray::<f64>::from(out);
Float64Chunked::from(arr)
}

pub(crate) fn impl_ewma_by_time(
times: &Int64Chunked,
values: &Series,
halflife: i64,
adjust: bool,
time_unit: TimeUnit,
) -> Series {
match values.dtype() {
DataType::Float64 => {
let values = values.f64().unwrap();
impl_ewma_by_time_float(times, values, halflife, adjust, time_unit).into_series()
}
DataType::Int64 | DataType::Int32 => {
let values = values.cast(&DataType::Float64).unwrap();
let values = values.f64().unwrap();
impl_ewma_by_time_float(times, values, halflife, adjust, time_unit).into_series()
}
DataType::Float32 => {
// todo: preserve Float32 in this case
let values = values.cast(&DataType::Float64).unwrap();
let values = values.f64().unwrap();
impl_ewma_by_time_float(times, values, halflife, adjust, time_unit).into_series()
}
dt => panic!("Expected values to be signed numeric, got {:?}", dt),
}
}
38 changes: 38 additions & 0 deletions src/expressions.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![allow(clippy::unit_arg, clippy::unused_unit)]
use crate::arg_previous_greater::*;
use crate::business_days::*;
use crate::ewma_by_time::*;
use crate::format_localized::*;
use crate::is_workday::*;
use crate::sub::*;
Expand Down Expand Up @@ -166,3 +167,40 @@ fn arg_previous_greater(inputs: &[Series]) -> PolarsResult<Series> {
dt => polars_bail!(ComputeError:"Expected numeric data type, got: {}", dt),
}
}

#[derive(Deserialize)]
struct EwmTimeKwargs {
halflife: i64,
adjust: bool,
}

#[polars_expr(output_type=Float64)]
fn ewma_by_time(inputs: &[Series], kwargs: EwmTimeKwargs) -> PolarsResult<Series> {
let values = &inputs[1];
match &inputs[0].dtype() {
DataType::Datetime(_, _) => {
let time = &inputs[0].datetime().unwrap();
Ok(impl_ewma_by_time(
&time.0,
values,
kwargs.halflife,
kwargs.adjust,
time.time_unit(),
)
.into_series())
}
DataType::Date => {
let binding = &inputs[0].cast(&DataType::Datetime(TimeUnit::Milliseconds, None))?;
let time = binding.datetime().unwrap();
Ok(impl_ewma_by_time(
&time.0,
values,
kwargs.halflife,
kwargs.adjust,
time.time_unit(),
)
.into_series())
}
_ => polars_bail!(InvalidOperation: "First argument should be a date or datetime type."),
}
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod arg_previous_greater;
mod business_days;
mod ewma_by_time;
mod expressions;
mod format_localized;
mod is_workday;
Expand Down
36 changes: 36 additions & 0 deletions tests/test_ewma_by_time.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import polars as pl
from polars.testing import assert_frame_equal
import polars_xdt as xdt
from datetime import datetime, timedelta


def test_ewma_by_time():
df = pl.DataFrame(
{
"values": [0.0, 1, 2, None, 4],
"times": [
datetime(2020, 1, 1),
datetime(2020, 1, 3),
datetime(2020, 1, 10),
datetime(2020, 1, 15),
datetime(2020, 1, 17),
],
}
)
result = df.select(
ewma=xdt.ewma_by_time(
"values", times="times", halflife=timedelta(days=4)
),
)
expected = pl.DataFrame(
{
"ewma": [
0.0,
0.585786437626905,
1.52388878049859,
None,
3.2336858398518338,
]
}
)
assert_frame_equal(result, expected)

0 comments on commit 0f1a13b

Please sign in to comment.