Skip to content

Commit

Permalink
fix: Check values in strict cast Int to Time (#18854)
Browse files Browse the repository at this point in the history
  • Loading branch information
coastalwhite authored Sep 23, 2024
1 parent 33f3fa0 commit 66960ff
Show file tree
Hide file tree
Showing 7 changed files with 155 additions and 19 deletions.
4 changes: 4 additions & 0 deletions crates/polars-arrow/src/array/primitive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ impl<T: NativeType> PrimitiveArray<T> {
values: Buffer<T>,
validity: Option<Bitmap>,
) -> Self {
if cfg!(debug_assertions) {
check(&dtype, &values, validity.as_ref().map(|v| v.len())).unwrap();
}

Self {
dtype,
values,
Expand Down
14 changes: 4 additions & 10 deletions crates/polars-arrow/src/compute/cast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -689,22 +689,16 @@ pub fn cast(

// temporal casts
(Int32, Date32) => primitive_to_same_primitive_dyn::<i32>(array, to_type),
(Int32, Time32(TimeUnit::Second)) => primitive_to_same_primitive_dyn::<i32>(array, to_type),
(Int32, Time32(TimeUnit::Millisecond)) => {
primitive_to_same_primitive_dyn::<i32>(array, to_type)
},
(Int32, Time32(TimeUnit::Second)) => primitive_dyn!(array, int32_to_time32s),
(Int32, Time32(TimeUnit::Millisecond)) => primitive_dyn!(array, int32_to_time32ms),
// No support for microsecond/nanosecond with i32
(Date32, Int32) => primitive_to_same_primitive_dyn::<i32>(array, to_type),
(Date32, Int64) => primitive_to_primitive_dyn::<i32, i64>(array, to_type, options),
(Time32(_), Int32) => primitive_to_same_primitive_dyn::<i32>(array, to_type),
(Int64, Date64) => primitive_to_same_primitive_dyn::<i64>(array, to_type),
// No support for second/milliseconds with i64
(Int64, Time64(TimeUnit::Microsecond)) => {
primitive_to_same_primitive_dyn::<i64>(array, to_type)
},
(Int64, Time64(TimeUnit::Nanosecond)) => {
primitive_to_same_primitive_dyn::<i64>(array, to_type)
},
(Int64, Time64(TimeUnit::Microsecond)) => primitive_dyn!(array, int64_to_time64us),
(Int64, Time64(TimeUnit::Nanosecond)) => primitive_dyn!(array, int64_to_time64ns),

(Date64, Int32) => primitive_to_primitive_dyn::<i64, i32>(array, to_type, options),
(Date64, Int64) => primitive_to_same_primitive_dyn::<i64>(array, to_type),
Expand Down
77 changes: 77 additions & 0 deletions crates/polars-arrow/src/compute/cast/primitive_to.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,83 @@ pub fn primitive_to_dictionary<T: NativeType + Eq + Hash, K: DictionaryKey>(
Ok(array.into())
}

/// # Safety
///
/// `dtype` should be valid for primitive.
pub unsafe fn primitive_map_is_valid<T: NativeType>(
from: &PrimitiveArray<T>,
f: impl Fn(T) -> bool,
dtype: ArrowDataType,
) -> PrimitiveArray<T> {
let values = from.values().clone();

let validity: Bitmap = values.iter().map(|&v| f(v)).collect();

let validity = if validity.unset_bits() > 0 {
let new_validity = match from.validity() {
None => validity,
Some(v) => v & &validity,
};

Some(new_validity)
} else {
from.validity().cloned()
};

// SAFETY:
// - Validity did not change length
// - dtype should be valid
unsafe { PrimitiveArray::new_unchecked(dtype, values, validity) }
}

/// Conversion of `Int32` to `Time32(TimeUnit::Second)`
pub fn int32_to_time32s(from: &PrimitiveArray<i32>) -> PrimitiveArray<i32> {
// SAFETY: Time32(TimeUnit::Second) is valid for Int32
unsafe {
primitive_map_is_valid(
from,
|v| (0..SECONDS_IN_DAY as i32).contains(&v),
ArrowDataType::Time32(TimeUnit::Second),
)
}
}

/// Conversion of `Int32` to `Time32(TimeUnit::Millisecond)`
pub fn int32_to_time32ms(from: &PrimitiveArray<i32>) -> PrimitiveArray<i32> {
// SAFETY: Time32(TimeUnit::Millisecond) is valid for Int32
unsafe {
primitive_map_is_valid(
from,
|v| (0..MILLISECONDS_IN_DAY as i32).contains(&v),
ArrowDataType::Time32(TimeUnit::Millisecond),
)
}
}

/// Conversion of `Int64` to `Time32(TimeUnit::Microsecond)`
pub fn int64_to_time64us(from: &PrimitiveArray<i64>) -> PrimitiveArray<i64> {
// SAFETY: Time64(TimeUnit::Microsecond) is valid for Int64
unsafe {
primitive_map_is_valid(
from,
|v| (0..MICROSECONDS_IN_DAY).contains(&v),
ArrowDataType::Time32(TimeUnit::Microsecond),
)
}
}

/// Conversion of `Int64` to `Time32(TimeUnit::Nanosecond)`
pub fn int64_to_time64ns(from: &PrimitiveArray<i64>) -> PrimitiveArray<i64> {
// SAFETY: Time64(TimeUnit::Nanosecond) is valid for Int64
unsafe {
primitive_map_is_valid(
from,
|v| (0..NANOSECONDS_IN_DAY).contains(&v),
ArrowDataType::Time64(TimeUnit::Nanosecond),
)
}
}

/// Conversion of dates
pub fn date32_to_date64(from: &PrimitiveArray<i32>) -> PrimitiveArray<i64> {
unary(
Expand Down
43 changes: 41 additions & 2 deletions crates/polars-core/src/chunked_array/logical/time.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use arrow::compute::cast::CastOptionsImpl;

use super::*;
use crate::prelude::*;

Expand All @@ -10,8 +12,45 @@ impl From<Int64Chunked> for TimeChunked {
}

impl Int64Chunked {
pub fn into_time(self) -> TimeChunked {
TimeChunked::new_logical(self)
pub fn into_time(mut self) -> TimeChunked {
let mut null_count = 0;

// Invalid time values are replaced with `null` during the arrow cast. We utilize the
// validity coming from there to create the new TimeChunked.
let chunks = std::mem::take(&mut self.chunks)
.into_iter()
.map(|chunk| {
// We need to retain the PhysicalType underneath, but we should properly update the
// validity as that might change because Time is not valid for all values of Int64.
let casted = arrow::compute::cast::cast(
chunk.as_ref(),
&ArrowDataType::Time64(ArrowTimeUnit::Nanosecond),
CastOptionsImpl::default(),
)
.unwrap();
let validity = casted.validity();

match validity {
None => chunk,
Some(validity) => {
null_count += validity.unset_bits();
chunk.with_validity(Some(validity.clone()))
},
}
})
.collect::<Vec<Box<dyn Array>>>();

let null_count = null_count as IdxSize;

debug_assert!(null_count >= self.null_count);

// @TODO: We throw away metadata here. That is mostly not needed.
// SAFETY: We calculated the null_count again. And we are taking the rest from the previous
// Int64Chunked.
let int64chunked =
unsafe { Self::new_with_dims(self.field.clone(), chunks, self.length, null_count) };

TimeChunked::new_logical(int64chunked)
}
}

Expand Down
4 changes: 0 additions & 4 deletions crates/polars-core/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -713,10 +713,6 @@ impl Series {

#[cfg(feature = "dtype-time")]
pub(crate) fn into_time(self) -> Series {
#[cfg(not(feature = "dtype-time"))]
{
panic!("activate feature dtype-time")
}
match self.dtype() {
DataType::Int64 => self.i64().unwrap().clone().into_time().into_series(),
DataType::Time => self
Expand Down
16 changes: 13 additions & 3 deletions crates/polars-expr/src/expressions/literal.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::borrow::Cow;
use std::ops::Deref;

use arrow::temporal_conversions::NANOSECONDS_IN_DAY;
use polars_core::prelude::*;
use polars_core::utils::NoNull;
use polars_plan::constants::get_literal_name;
Expand Down Expand Up @@ -91,9 +92,18 @@ impl PhysicalExpr for LiteralExpr {
.into_date()
.into_series(),
#[cfg(feature = "dtype-time")]
Time(v) => Int64Chunked::full(get_literal_name().clone(), *v, 1)
.into_time()
.into_series(),
Time(v) => {
if !(0..NANOSECONDS_IN_DAY).contains(v) {
polars_bail!(
InvalidOperation: "value `{v}` is out-of-range for `time` which can be 0 - {}",
NANOSECONDS_IN_DAY - 1
);
}

Int64Chunked::full(get_literal_name().clone(), *v, 1)
.into_time()
.into_series()
},
Series(series) => series.deref().clone(),
OtherScalar(s) => s.clone().into_series(get_literal_name().clone()),
lv @ (Int(_) | Float(_) | StrCat(_)) => polars_core::prelude::Series::from_any_values(
Expand Down
16 changes: 16 additions & 0 deletions py-polars/tests/unit/datatypes/test_time.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from datetime import time

import pytest

import polars as pl


Expand All @@ -15,3 +17,17 @@ def test_time_microseconds_3843() -> None:
in_val = [time(0, 9, 11, 558332)]
s = pl.Series(in_val)
assert s.to_list() == in_val


def test_invalid_casts() -> None:
with pytest.raises(pl.exceptions.InvalidOperationError):
pl.DataFrame({"a": []}).with_columns(a=pl.lit(-1).cast(pl.Time))

with pytest.raises(pl.exceptions.InvalidOperationError):
pl.Series([-1]).cast(pl.Time)

with pytest.raises(pl.exceptions.InvalidOperationError):
pl.Series([24 * 60 * 60 * 1_000_000_000]).cast(pl.Time)

largest_value = pl.Series([24 * 60 * 60 * 1_000_000_000 - 1]).cast(pl.Time)
assert "23:59:59.999999999" in str(largest_value)

0 comments on commit 66960ff

Please sign in to comment.