Skip to content

Commit

Permalink
fix: Mix of columnd and field expansion (#16502)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored May 26, 2024
1 parent 3e8e6a5 commit af3a42f
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 10 deletions.
59 changes: 49 additions & 10 deletions crates/polars-plan/src/logical_plan/conversion/expr_expansion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -559,9 +559,22 @@ struct ExpansionFlags {
has_selector: bool,
has_exclude: bool,
#[cfg(feature = "dtype-struct")]
expands_fields: bool,
#[cfg(feature = "dtype-struct")]
has_struct_field_by_index: bool,
}

impl ExpansionFlags {
fn expands(&self) -> bool {
#[cfg(feature = "dtype-struct")]
let expands_fields = self.expands_fields;
#[cfg(not(feature = "dtype-struct"))]
let expands_fields = false;

self.multiple_columns || expands_fields
}
}

fn find_flags(expr: &Expr) -> PolarsResult<ExpansionFlags> {
let mut multiple_columns = false;
let mut has_nth = false;
Expand All @@ -570,6 +583,8 @@ fn find_flags(expr: &Expr) -> PolarsResult<ExpansionFlags> {
let mut has_exclude = false;
#[cfg(feature = "dtype-struct")]
let mut has_struct_field_by_index = false;
#[cfg(feature = "dtype-struct")]
let mut expands_fields = false;

// Do a single pass and collect all flags at once.
// Supertypes/modification that can be done in place are also done in that pass
Expand All @@ -592,7 +607,7 @@ fn find_flags(expr: &Expr) -> PolarsResult<ExpansionFlags> {
function: FunctionExpr::StructExpr(StructFunction::MultipleFields(_)),
..
} => {
multiple_columns = true;
expands_fields = true;
},
Expr::Exclude(_, _) => has_exclude = true,
#[cfg(feature = "dtype-struct")]
Expand All @@ -610,6 +625,8 @@ fn find_flags(expr: &Expr) -> PolarsResult<ExpansionFlags> {
has_exclude,
#[cfg(feature = "dtype-struct")]
has_struct_field_by_index,
#[cfg(feature = "dtype-struct")]
expands_fields,
})
}

Expand Down Expand Up @@ -661,14 +678,14 @@ fn replace_and_add_to_results(

// has multiple column names
// the expanded columns are added to the result
if flags.multiple_columns {
if flags.expands() {
if let Some(e) = expr.into_iter().find(|e| match e {
Expr::Columns(_) | Expr::DtypeColumn(_) | Expr::IndexColumn(_) => true,
#[cfg(feature = "dtype-struct")]
Expr::Function {
function: FunctionExpr::StructExpr(StructFunction::MultipleFields(_)),
..
} => true,
} => flags.expands_fields,
_ => false,
}) {
match &e {
Expand All @@ -686,18 +703,40 @@ fn replace_and_add_to_results(
expand_indices(&expr, result, schema, indices, &exclude)?
},
#[cfg(feature = "dtype-struct")]
Expr::Function { function, .. }
if matches!(
function,
FunctionExpr::StructExpr(StructFunction::MultipleFields(_))
) =>
{
Expr::Function { function, .. } => {
let FunctionExpr::StructExpr(StructFunction::MultipleFields(names)) = function
else {
unreachable!()
};
let exclude = prepare_excluded(&expr, schema, keys, flags.has_exclude)?;
expand_struct_fields(e, &expr, result, schema, names, &exclude)?

// has both column and field expansion
// col('a', 'b').struct.field('*')
if flags.multiple_columns {
// First expand col('a', 'b') into an intermediate result.
let mut intermediate = vec![];
let mut flags = flags;
flags.expands_fields = false;
replace_and_add_to_results(
expr.clone(),
flags,
&mut intermediate,
schema,
keys,
)?;

// Then expand the fields and add to the final result vec.
flags.expands_fields = true;
flags.multiple_columns = false;
for e in intermediate {
replace_and_add_to_results(e, flags, result, schema, keys)?;
}
}
// has only field expansion
// col('a').struct.field('*')
else {
expand_struct_fields(e, &expr, result, schema, names, &exclude)?
}
},
_ => {},
}
Expand Down
11 changes: 11 additions & 0 deletions py-polars/tests/unit/test_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,14 @@ def test_struct_field_expansion_16410() -> None:
"x": [2.0],
"y": [4],
}


def test_field_and_column_expansion() -> None:
df = pl.DataFrame({"a": [{"x": 1, "y": 2}], "b": [{"i": 3, "j": 4}]})

assert df.select(pl.col("a", "b").struct.field("*")).to_dict(as_series=False) == {
"x": [1],
"y": [2],
"i": [3],
"j": [4],
}

0 comments on commit af3a42f

Please sign in to comment.