diff --git a/Cargo.lock b/Cargo.lock index 81ea7d6c32..a12f72cc38 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1968,6 +1968,7 @@ dependencies = [ "common-treenode", "daft-dsl", "daft-schema", + "indexmap 2.7.0", "rstest", ] diff --git a/src/common/treenode/src/lib.rs b/src/common/treenode/src/lib.rs index f749040895..4bd40310fc 100644 --- a/src/common/treenode/src/lib.rs +++ b/src/common/treenode/src/lib.rs @@ -736,6 +736,20 @@ pub trait TreeNodeIterator: Iterator { self, f: F, ) -> Result>>; + + /// Applies `f` to each item in this iterator + /// + /// # Returns + /// Error if `f` returns an error + /// + /// Ok(Transformed) such that: + /// 1. `transformed` is true if any return from `f` had transformed true + /// 2. `data` is the collected `data` from every invocation of `f` + /// 3. `tnr` from the last invocation of `f` or `Continue` if the iterator is empty + fn map_and_collect Result>>( + self, + f: F, + ) -> Result>>; } impl TreeNodeIterator for I { @@ -771,6 +785,23 @@ impl TreeNodeIterator for I { .collect::>>() .map(|data| Transformed::new(data, transformed, tnr)) } + + fn map_and_collect Result>>( + self, + mut f: F, + ) -> Result>> { + let mut tnr = TreeNodeRecursion::Continue; + let mut transformed = false; + self.map(|item| { + f(item).map(|result| { + tnr = result.tnr; + transformed |= result.transformed; + result.data + }) + }) + .collect::>>() + .map(|data| Transformed::new(data, transformed, tnr)) + } } /// Transformation helper to process a heterogeneous sequence of tree node containing diff --git a/src/daft-algebra/Cargo.toml b/src/daft-algebra/Cargo.toml index 89e942c700..680929ea3e 100644 --- a/src/daft-algebra/Cargo.toml +++ b/src/daft-algebra/Cargo.toml @@ -3,6 +3,7 @@ common-error = {path = "../common/error", default-features = false} common-treenode = {path = "../common/treenode", default-features = false} daft-dsl = {path = "../daft-dsl", default-features = false} daft-schema = {path = "../daft-schema", default-features = false} +indexmap = {workspace = true} 
[dev-dependencies] rstest = {workspace = true} diff --git a/src/daft-algebra/src/simplify.rs b/src/daft-algebra/src/simplify.rs deleted file mode 100644 index e8547b25a3..0000000000 --- a/src/daft-algebra/src/simplify.rs +++ /dev/null @@ -1,473 +0,0 @@ -use std::sync::Arc; - -use common_error::DaftResult; -use common_treenode::Transformed; -use daft_dsl::{lit, null_lit, Expr, ExprRef, LiteralValue, Operator}; -use daft_schema::{dtype::DataType, schema::SchemaRef}; - -pub fn simplify_expr(expr: Expr, schema: &SchemaRef) -> DaftResult> { - Ok(match expr { - // ---------------- - // Eq - // ---------------- - // true = A --> A - // false = A --> !A - Expr::BinaryOp { - op: Operator::Eq, - left, - right, - } - // A = true --> A - // A = false --> !A - | Expr::BinaryOp { - op: Operator::Eq, - left: right, - right: left, - } if is_bool_lit(&left) && is_bool_type(&right, schema) => { - Transformed::yes(match as_bool_lit(&left) { - Some(true) => right, - Some(false) => right.not(), - None => unreachable!(), - }) - } - - // null = A --> null - // A = null --> null - Expr::BinaryOp { - op: Operator::Eq, - left, - right, - } - | Expr::BinaryOp { - op: Operator::Eq, - left: right, - right: left, - } if is_null(&left) && is_bool_type(&right, schema) => Transformed::yes(null_lit()), - - // ---------------- - // Neq - // ---------------- - // true != A --> !A - // false != A --> A - Expr::BinaryOp { - op: Operator::NotEq, - left, - right, - } - // A != true --> !A - // A != false --> A - | Expr::BinaryOp { - op: Operator::NotEq, - left: right, - right: left, - } if is_bool_lit(&left) && is_bool_type(&right, schema) => { - Transformed::yes(match as_bool_lit(&left) { - Some(true) => right.not(), - Some(false) => right, - None => unreachable!(), - }) - } - - // null != A --> null - // A != null --> null - Expr::BinaryOp { - op: Operator::NotEq, - left, - right, - } - | Expr::BinaryOp { - op: Operator::NotEq, - left: right, - right: left, - } if is_null(&left) && is_bool_type(&right, 
schema) => Transformed::yes(null_lit()), - - // ---------------- - // OR - // ---------------- - - // true OR A --> true - Expr::BinaryOp { - op: Operator::Or, - left, - right: _, - } if is_true(&left) => Transformed::yes(left), - // false OR A --> A - Expr::BinaryOp { - op: Operator::Or, - left, - right, - } if is_false(&left) => Transformed::yes(right), - // A OR true --> true - Expr::BinaryOp { - op: Operator::Or, - left: _, - right, - } if is_true(&right) => Transformed::yes(right), - // A OR false --> A - Expr::BinaryOp { - left, - op: Operator::Or, - right, - } if is_false(&right) => Transformed::yes(left), - - // ---------------- - // AND (TODO) - // ---------------- - - // ---------------- - // Multiplication - // ---------------- - - // A * 1 --> A - // 1 * A --> A - Expr::BinaryOp { - op: Operator::Multiply, - left, - right, - }| Expr::BinaryOp { - op: Operator::Multiply, - left: right, - right: left, - } if is_one(&right) => Transformed::yes(left), - - // A * null --> null - Expr::BinaryOp { - op: Operator::Multiply, - left: _, - right, - } if is_null(&right) => Transformed::yes(right), - // null * A --> null - Expr::BinaryOp { - op: Operator::Multiply, - left, - right: _, - } if is_null(&left) => Transformed::yes(left), - - // TODO: Can't do this one because we don't have a way to determine if an expr potentially contains nulls (nullable) - // A * 0 --> 0 (if A is not null and not floating/decimal) - // 0 * A --> 0 (if A is not null and not floating/decimal) - - // ---------------- - // Division - // ---------------- - // A / 1 --> A - Expr::BinaryOp { - op: Operator::TrueDivide, - left, - right, - } if is_one(&right) => Transformed::yes(left), - // null / A --> null - Expr::BinaryOp { - op: Operator::TrueDivide, - left, - right: _, - } if is_null(&left) => Transformed::yes(left), - // A / null --> null - Expr::BinaryOp { - op: Operator::TrueDivide, - left: _, - right, - } if is_null(&right) => Transformed::yes(right), - - // ---------------- - // 
Addition - // ---------------- - // A + 0 --> A - Expr::BinaryOp { - op: Operator::Plus, - left, - right, - } if is_zero(&right) => Transformed::yes(left), - - // 0 + A --> A - Expr::BinaryOp { - op: Operator::Plus, - left, - right, - } if is_zero(&left) => Transformed::yes(right), - - // ---------------- - // Subtraction - // ---------------- - - // A - 0 --> A - Expr::BinaryOp { - op: Operator::Minus, - left, - right, - } if is_zero(&right) => Transformed::yes(left), - - // A - null --> null - Expr::BinaryOp { - op: Operator::Minus, - left: _, - right, - } if is_null(&right) => Transformed::yes(right), - // null - A --> null - Expr::BinaryOp { - op: Operator::Minus, - left, - right: _, - } if is_null(&left) => Transformed::yes(left), - - // ---------------- - // Modulus - // ---------------- - - // A % null --> null - Expr::BinaryOp { - op: Operator::Modulus, - left: _, - right, - } if is_null(&right) => Transformed::yes(right), - - // null % A --> null - Expr::BinaryOp { - op: Operator::Modulus, - left, - right: _, - } if is_null(&left) => Transformed::yes(left), - - // A BETWEEN low AND high --> A >= low AND A <= high - Expr::Between(expr, low, high) => { - Transformed::yes(expr.clone().lt_eq(high).and(expr.gt_eq(low))) - } - Expr::Not(expr) => match Arc::unwrap_or_clone(expr) { - // NOT (BETWEEN A AND B) --> A < low OR A > high - Expr::Between(expr, low, high) => { - Transformed::yes(expr.clone().lt(low).or(expr.gt(high))) - } - // expr NOT IN () --> true - Expr::IsIn(_, list) if list.is_empty() => Transformed::yes(lit(true)), - - expr => { - let expr = simplify_expr(expr, schema)?; - if expr.transformed { - Transformed::yes(expr.data.not()) - } else { - Transformed::no(expr.data.not()) - } - } - }, - // expr IN () --> false - Expr::IsIn(_, list) if list.is_empty() => Transformed::yes(lit(false)), - - other => Transformed::no(Arc::new(other)), - }) -} - -fn is_zero(s: &Expr) -> bool { - match s { - Expr::Literal(LiteralValue::Int8(0)) - | 
Expr::Literal(LiteralValue::UInt8(0)) - | Expr::Literal(LiteralValue::Int16(0)) - | Expr::Literal(LiteralValue::UInt16(0)) - | Expr::Literal(LiteralValue::Int32(0)) - | Expr::Literal(LiteralValue::UInt32(0)) - | Expr::Literal(LiteralValue::Int64(0)) - | Expr::Literal(LiteralValue::UInt64(0)) - | Expr::Literal(LiteralValue::Float64(0.)) => true, - Expr::Literal(LiteralValue::Decimal(v, _p, _s)) if *v == 0 => true, - _ => false, - } -} - -fn is_one(s: &Expr) -> bool { - match s { - Expr::Literal(LiteralValue::Int8(1)) - | Expr::Literal(LiteralValue::UInt8(1)) - | Expr::Literal(LiteralValue::Int16(1)) - | Expr::Literal(LiteralValue::UInt16(1)) - | Expr::Literal(LiteralValue::Int32(1)) - | Expr::Literal(LiteralValue::UInt32(1)) - | Expr::Literal(LiteralValue::Int64(1)) - | Expr::Literal(LiteralValue::UInt64(1)) - | Expr::Literal(LiteralValue::Float64(1.)) => true, - - Expr::Literal(LiteralValue::Decimal(v, _p, s)) => { - *s >= 0 && POWS_OF_TEN.get(*s as usize).is_some_and(|pow| v == pow) - } - _ => false, - } -} - -fn is_true(expr: &Expr) -> bool { - match expr { - Expr::Literal(LiteralValue::Boolean(v)) => *v, - _ => false, - } -} -fn is_false(expr: &Expr) -> bool { - match expr { - Expr::Literal(LiteralValue::Boolean(v)) => !*v, - _ => false, - } -} - -/// returns true if expr is a -/// `Expr::Literal(LiteralValue::Boolean(v))` , false otherwise -fn is_bool_lit(expr: &Expr) -> bool { - matches!(expr, Expr::Literal(LiteralValue::Boolean(_))) -} - -fn is_bool_type(expr: &Expr, schema: &SchemaRef) -> bool { - matches!(expr.get_type(schema), Ok(DataType::Boolean)) -} - -fn as_bool_lit(expr: &Expr) -> Option { - expr.as_literal().and_then(|l| l.as_bool()) -} - -fn is_null(expr: &Expr) -> bool { - matches!(expr, Expr::Literal(LiteralValue::Null)) -} - -static POWS_OF_TEN: [i128; 38] = [ - 1, - 10, - 100, - 1000, - 10000, - 100000, - 1000000, - 10000000, - 100000000, - 1000000000, - 10000000000, - 100000000000, - 1000000000000, - 10000000000000, - 100000000000000, - 
1000000000000000, - 10000000000000000, - 100000000000000000, - 1000000000000000000, - 10000000000000000000, - 100000000000000000000, - 1000000000000000000000, - 10000000000000000000000, - 100000000000000000000000, - 1000000000000000000000000, - 10000000000000000000000000, - 100000000000000000000000000, - 1000000000000000000000000000, - 10000000000000000000000000000, - 100000000000000000000000000000, - 1000000000000000000000000000000, - 10000000000000000000000000000000, - 100000000000000000000000000000000, - 1000000000000000000000000000000000, - 10000000000000000000000000000000000, - 100000000000000000000000000000000000, - 1000000000000000000000000000000000000, - 10000000000000000000000000000000000000, -]; - -#[cfg(test)] -mod test { - use std::sync::Arc; - - use common_error::DaftResult; - use daft_dsl::{col, lit, null_lit, ExprRef}; - use daft_schema::{ - dtype::DataType, - field::Field, - schema::{Schema, SchemaRef}, - }; - use rstest::{fixture, rstest}; - - use crate::simplify_expr; - - #[fixture] - fn schema() -> SchemaRef { - Arc::new( - Schema::new(vec![ - Field::new("bool", DataType::Boolean), - Field::new("int", DataType::Int32), - ]) - .unwrap(), - ) - } - - #[rstest] - // true = A --> A - #[case(col("bool").eq(lit(true)), col("bool"))] - // false = A --> !A - #[case(col("bool").eq(lit(false)), col("bool").not())] - // A == true ---> A - #[case(col("bool").eq(lit(true)), col("bool"))] - // null = A --> null - #[case(null_lit().eq(col("bool")), null_lit())] - // A == false ---> !A - #[case(col("bool").eq(lit(false)), col("bool").not())] - // true != A --> !A - #[case(lit(true).not_eq(col("bool")), col("bool").not())] - // false != A --> A - #[case(lit(false).not_eq(col("bool")), col("bool"))] - // true OR A --> true - #[case(lit(true).or(col("bool")), lit(true))] - // false OR A --> A - #[case(lit(false).or(col("bool")), col("bool"))] - // A OR true --> true - #[case(col("bool").or(lit(true)), lit(true))] - // A OR false --> A - 
#[case(col("bool").or(lit(false)), col("bool"))] - fn test_simplify_bool_exprs( - #[case] input: ExprRef, - #[case] expected: ExprRef, - schema: SchemaRef, - ) -> DaftResult<()> { - let optimized = simplify_expr(Arc::unwrap_or_clone(input), &schema)?; - - assert!(optimized.transformed); - assert_eq!(optimized.data, expected); - Ok(()) - } - - #[rstest] - // A * 1 --> A - #[case(col("int").mul(lit(1)), col("int"))] - // 1 * A --> A - #[case(lit(1).mul(col("int")), col("int"))] - // A / 1 --> A - #[case(col("int").div(lit(1)), col("int"))] - // A + 0 --> A - #[case(col("int").add(lit(0)), col("int"))] - // A - 0 --> A - #[case(col("int").sub(lit(0)), col("int"))] - fn test_math_exprs( - #[case] input: ExprRef, - #[case] expected: ExprRef, - schema: SchemaRef, - ) -> DaftResult<()> { - let optimized = simplify_expr(Arc::unwrap_or_clone(input), &schema)?; - - assert!(optimized.transformed); - assert_eq!(optimized.data, expected); - Ok(()) - } - - #[rstest] - fn test_not_between(schema: SchemaRef) -> DaftResult<()> { - let input = col("int").between(lit(1), lit(10)).not(); - let expected = col("int").lt(lit(1)).or(col("int").gt(lit(10))); - - let optimized = simplify_expr(Arc::unwrap_or_clone(input), &schema)?; - - assert!(optimized.transformed); - assert_eq!(optimized.data, expected); - Ok(()) - } - - #[rstest] - fn test_between(schema: SchemaRef) -> DaftResult<()> { - let input = col("int").between(lit(1), lit(10)); - let expected = col("int").lt_eq(lit(10)).and(col("int").gt_eq(lit(1))); - - let optimized = simplify_expr(Arc::unwrap_or_clone(input), &schema)?; - - assert!(optimized.transformed); - assert_eq!(optimized.data, expected); - Ok(()) - } -} diff --git a/src/daft-algebra/src/simplify/boolean.rs b/src/daft-algebra/src/simplify/boolean.rs new file mode 100644 index 0000000000..cdaa19db61 --- /dev/null +++ b/src/daft-algebra/src/simplify/boolean.rs @@ -0,0 +1,290 @@ +use common_error::DaftResult; +use common_treenode::Transformed; +use daft_dsl::{lit, Expr, 
ExprRef, LiteralValue, Operator}; +use daft_schema::{dtype::DataType, schema::SchemaRef}; +use indexmap::IndexSet; + +use crate::boolean::{ + combine_conjunction, combine_disjunction, split_conjunction, split_disjunction, +}; + +fn is_true(expr: &Expr) -> bool { + matches!(expr, Expr::Literal(LiteralValue::Boolean(true))) +} + +fn is_false(expr: &Expr) -> bool { + matches!(expr, Expr::Literal(LiteralValue::Boolean(false))) +} + +fn is_bool(expr: &Expr, schema: &SchemaRef) -> DaftResult { + Ok(matches!(expr.get_type(schema)?, DataType::Boolean)) +} + +/// Simplify boolean expressions (AND, OR, NOT, etc.) +pub(crate) fn simplify_boolean_expr( + expr: ExprRef, + _schema: &SchemaRef, +) -> DaftResult> { + Ok(match expr.as_ref() { + Expr::BinaryOp { op, left, right } => match op { + // true AND e -> e + Operator::And if is_true(left) => Transformed::yes(right.clone()), + Operator::And if is_true(right) => Transformed::yes(left.clone()), + // false AND e -> false + Operator::And if is_false(left) => Transformed::yes(lit(false)), + Operator::And if is_false(right) => Transformed::yes(lit(false)), + // false OR e -> e + Operator::Or if is_false(left) => Transformed::yes(right.clone()), + Operator::Or if is_false(right) => Transformed::yes(left.clone()), + // true OR e -> true + Operator::Or if is_true(left) => Transformed::yes(lit(true)), + Operator::Or if is_true(right) => Transformed::yes(lit(true)), + + Operator::And => simplify_and(expr.clone(), left, right), + Operator::Or => simplify_or(expr.clone(), left, right), + + _ => Transformed::no(expr), + }, + Expr::Not(not) => { + match not.as_ref() { + // !!e -> e + Expr::Not(not_not) => Transformed::yes(not_not.clone()), + // !true -> false, !false -> true + Expr::Literal(LiteralValue::Boolean(val)) => Transformed::yes(lit(!val)), + _ => Transformed::no(expr), + } + } + _ => Transformed::no(expr), + }) +} + +/// Combines common sub-expressions in AND expressions. +/// +/// 1. 
Matches up each expression in the left-side conjunction with each expression in the right-side conjunction. +/// +/// (We only match expressions between left and right sides and not within each side because this rule is run bottom-up, +/// so the sides should have already been individually simplified.) +/// +/// 2. For each pair, check if there are any common expressions in their disjunctions. +/// +/// Ex: +/// - left: (a OR b OR c OR ...) +/// - right: (a OR b OR d OR ...) +/// - common: a, b +/// +/// 3. If there are common expressions, extract them out into ORs. +/// +/// Ex (using above example): +/// - a OR b OR ((c OR ...) AND (d OR ...)) +/// +/// 4. Finally, combine all new and remaining expressions back into conjunction +fn simplify_and(expr: ExprRef, left: &ExprRef, right: &ExprRef) -> Transformed { + let left_and_exprs = split_conjunction(left); + let right_and_exprs = split_conjunction(right); + + let left_and_of_or_exprs = left_and_exprs + .iter() + .map(|e| split_disjunction(e).into_iter().collect()) + .collect::>>(); + let right_and_of_or_exprs = right_and_exprs + .iter() + .map(|e| split_disjunction(e).into_iter().collect()) + .collect::>>(); + + let mut left_eliminated = vec![false; left_and_exprs.len()]; + let mut right_eliminated = vec![false; right_and_exprs.len()]; + let mut simplified_exprs = Vec::new(); + + for (i, left_exprs) in left_and_of_or_exprs.iter().enumerate() { + for (j, right_exprs) in right_and_of_or_exprs.iter().enumerate() { + if right_eliminated[j] || left_exprs.is_disjoint(right_exprs) { + continue; + } + + let common_exprs: IndexSet<_> = left_exprs.intersection(right_exprs).cloned().collect(); + + let simplified = if common_exprs == *left_exprs || common_exprs == *right_exprs { + // (a OR b OR c OR ...) AND (a OR b) -> (a OR b) + combine_disjunction(common_exprs).unwrap() + } else { + // (a OR b OR c OR ...) AND (a OR b OR d OR ...) -> a OR b OR ((c OR ...) 
AND (d OR ...)) + let left_remaining = + combine_disjunction(left_exprs.difference(&common_exprs).cloned()).unwrap(); + let right_remaining = + combine_disjunction(right_exprs.difference(&common_exprs).cloned()).unwrap(); + + let common_disjunction = combine_disjunction(common_exprs).unwrap(); + + common_disjunction.or(left_remaining.and(right_remaining)) + }; + + simplified_exprs.push(simplified); + left_eliminated[i] = true; + right_eliminated[j] = true; + break; + } + } + + if simplified_exprs.is_empty() { + Transformed::no(expr) + } else { + let left_and_remaining = left_and_exprs + .into_iter() + .zip(left_eliminated) + .filter(|(_, eliminated)| !*eliminated) + .map(|(e, _)| e); + + let right_and_remaining = right_and_exprs + .into_iter() + .zip(right_eliminated) + .filter(|(_, eliminated)| !*eliminated) + .map(|(e, _)| e); + + Transformed::yes( + combine_conjunction( + simplified_exprs + .into_iter() + .chain(left_and_remaining) + .chain(right_and_remaining), + ) + .unwrap(), + ) + } +} + +/// Combines common sub-expressions in OR expressions. +/// +/// 1. Matches up each expression in the left-side disjunction with each expression in the right-side disjunction. +/// +/// (We only match expressions between left and right sides and not within each side because this rule is run bottom-up, +/// so the sides should have already been individually simplified.) +/// +/// 2. For each pair, check if there are any common expressions in their conjunctions. +/// +/// Ex: +/// - left: (a AND b AND c AND ...) +/// - right: (a AND b AND d AND ...) +/// - common: a, b +/// +/// 3. If there are common expressions, extract them out into ANDs. +/// +/// Ex (using above example): +/// - a AND b AND ((c AND ...) OR (d AND ...)) +/// +/// 4. 
Finally, combine all new and remaining expressions back into disjunction +fn simplify_or(expr: ExprRef, left: &ExprRef, right: &ExprRef) -> Transformed { + let left_or_exprs = split_disjunction(left); + let right_or_exprs = split_disjunction(right); + + let left_or_of_and_exprs = left_or_exprs + .iter() + .map(|e| split_conjunction(e).into_iter().collect()) + .collect::>>(); + let right_or_of_and_exprs = right_or_exprs + .iter() + .map(|e| split_conjunction(e).into_iter().collect()) + .collect::>>(); + + let mut left_eliminated = vec![false; left_or_exprs.len()]; + let mut right_eliminated = vec![false; right_or_exprs.len()]; + let mut simplified_exprs = Vec::new(); + + for (i, left_exprs) in left_or_of_and_exprs.iter().enumerate() { + for (j, right_exprs) in right_or_of_and_exprs.iter().enumerate() { + if right_eliminated[j] || left_exprs.is_disjoint(right_exprs) { + continue; + } + + let common_exprs: IndexSet<_> = left_exprs.intersection(right_exprs).cloned().collect(); + + let simplified = if common_exprs == *left_exprs || common_exprs == *right_exprs { + // (a AND b AND c AND ...) OR (a AND b) -> (a AND b) + combine_conjunction(common_exprs).unwrap() + } else { + // (a AND b AND c AND ...) OR (a AND b AND d AND ...) -> a AND b AND ((c AND ...) 
OR (d AND ...)) + let left_remaining = + combine_conjunction(left_exprs.difference(&common_exprs).cloned()).unwrap(); + let right_remaining = + combine_conjunction(right_exprs.difference(&common_exprs).cloned()).unwrap(); + + let common_conjunction = combine_conjunction(common_exprs).unwrap(); + + common_conjunction.and(left_remaining.or(right_remaining)) + }; + + simplified_exprs.push(simplified); + left_eliminated[i] = true; + right_eliminated[j] = true; + break; + } + } + + if simplified_exprs.is_empty() { + Transformed::no(expr) + } else { + let left_or_remaining = left_or_exprs + .into_iter() + .zip(left_eliminated) + .filter(|(_, eliminated)| !*eliminated) + .map(|(e, _)| e); + + let right_or_remaining = right_or_exprs + .into_iter() + .zip(right_eliminated) + .filter(|(_, eliminated)| !*eliminated) + .map(|(e, _)| e); + + Transformed::yes( + combine_disjunction( + simplified_exprs + .into_iter() + .chain(left_or_remaining) + .chain(right_or_remaining), + ) + .unwrap(), + ) + } +} + +/// Simplify binary comparison ops (==, !=, >, <, >=, <=) +pub(crate) fn simplify_binary_compare( + expr: ExprRef, + schema: &SchemaRef, +) -> DaftResult> { + Ok(match expr.as_ref() { + Expr::BinaryOp { op, left, right } => match op { + // true = e -> e + Operator::Eq if is_true(left) && is_bool(right, schema)? => { + Transformed::yes(right.clone()) + } + Operator::Eq if is_true(right) && is_bool(left, schema)? => { + Transformed::yes(left.clone()) + } + // false = e -> !e + Operator::Eq if is_false(left) && is_bool(right, schema)? => { + Transformed::yes(right.clone().not()) + } + Operator::Eq if is_false(right) && is_bool(left, schema)? => { + Transformed::yes(left.clone().not()) + } + + // true != e -> !e + Operator::NotEq if is_true(left) && is_bool(right, schema)? => { + Transformed::yes(right.clone().not()) + } + Operator::NotEq if is_true(right) && is_bool(left, schema)? 
=> { + Transformed::yes(left.clone().not()) // e != true -> !e + } + // false != e -> e + Operator::NotEq if is_false(left) && is_bool(right, schema)? => { + Transformed::yes(right.clone()) + } + Operator::NotEq if is_false(right) && is_bool(left, schema)? => { + Transformed::yes(left.clone()) + } + + _ => Transformed::no(expr), + }, + _ => Transformed::no(expr), + }) +} diff --git a/src/daft-algebra/src/simplify/mod.rs b/src/daft-algebra/src/simplify/mod.rs new file mode 100644 index 0000000000..80abd1b156 --- /dev/null +++ b/src/daft-algebra/src/simplify/mod.rs @@ -0,0 +1,165 @@ +mod boolean; +mod null; +mod numeric; + +use boolean::{simplify_binary_compare, simplify_boolean_expr}; +use common_error::DaftResult; +use common_treenode::{Transformed, TreeNode}; +use daft_dsl::{lit, Expr, ExprRef}; +use daft_schema::schema::SchemaRef; +use null::simplify_expr_with_null; +use numeric::simplify_numeric_expr; + +/// Recursively simplify expression. +pub fn simplify_expr(expr: ExprRef, schema: &SchemaRef) -> DaftResult> { + let simplify_fns = [ + simplify_boolean_expr, + simplify_binary_compare, + simplify_expr_with_null, + simplify_numeric_expr, + simplify_misc_expr, + ]; + + // Our simplify rules currently require bottom-up traversal to work + // If we introduce top-down rules in the future, please add a separate pass + // on the expression instead of changing this. 
+ expr.transform_up(|node| { + let dtype = node.to_field(schema)?.dtype; + + let transformed = simplify_fns + .into_iter() + .try_fold(Transformed::no(node), |transformed, f| { + transformed.transform_data(|e| f(e, schema)) + })?; + + // cast back to original dtype if necessary + transformed.map_data(|new_node| { + Ok(if new_node.to_field(schema)?.dtype == dtype { + new_node + } else { + new_node.cast(&dtype) + }) + }) + }) +} + +fn simplify_misc_expr(expr: ExprRef, schema: &SchemaRef) -> DaftResult> { + Ok(match expr.as_ref() { + // a BETWEEN low AND high --> a <= high AND a >= low + Expr::Between(between, low, high) => Transformed::yes( + between + .clone() + .lt_eq(high.clone()) + .and(between.clone().gt_eq(low.clone())), + ), + // e IN () --> false + Expr::IsIn(_, list) if list.is_empty() => Transformed::yes(lit(false)), + // CAST(e AS dtype) -> e if e.dtype == dtype + Expr::Cast(e, dtype) if e.get_type(schema)? == *dtype => Transformed::yes(e.clone()), + _ => Transformed::no(expr), + }) +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use common_error::DaftResult; + use daft_dsl::{col, lit, null_lit, ExprRef}; + use daft_schema::{ + dtype::DataType, + field::Field, + schema::{Schema, SchemaRef}, + }; + use rstest::{fixture, rstest}; + + use crate::simplify_expr; + + #[fixture] + fn schema() -> SchemaRef { + Arc::new( + Schema::new(vec![ + Field::new("bool", DataType::Boolean), + Field::new("int", DataType::Int32), + Field::new("a", DataType::Boolean), + Field::new("b", DataType::Boolean), + Field::new("c", DataType::Boolean), + ]) + .unwrap(), + ) + } + + #[rstest] + // true = A --> A + #[case(col("bool").eq(lit(true)), col("bool"))] + // false = A --> !A + #[case(col("bool").eq(lit(false)), col("bool").not())] + // A == true ---> A + #[case(col("bool").eq(lit(true)), col("bool"))] + // null = A --> null + #[case(null_lit().eq(col("bool")), null_lit().cast(&DataType::Boolean))] + // A == false ---> !A + #[case(col("bool").eq(lit(false)), 
col("bool").not())] + // true != A --> !A + #[case(lit(true).not_eq(col("bool")), col("bool").not())] + // false != A --> A + #[case(lit(false).not_eq(col("bool")), col("bool"))] + // true OR A --> true + #[case(lit(true).or(col("bool")), lit(true))] + // false OR A --> A + #[case(lit(false).or(col("bool")), col("bool"))] + // A OR true --> true + #[case(col("bool").or(lit(true)), lit(true))] + // A OR false --> A + #[case(col("bool").or(lit(false)), col("bool"))] + // (A OR B) AND (A OR C) -> A OR (B AND C) + #[case((col("a").or(col("b"))).and(col("a").or(col("c"))), col("a").or(col("b").and(col("c"))))] + // (A AND B) OR (A AND C) -> A AND (B OR C) + #[case((col("a").and(col("b"))).or(col("a").and(col("c"))), col("a").and(col("b").or(col("c"))))] + fn test_simplify_bool_exprs( + #[case] input: ExprRef, + #[case] expected: ExprRef, + schema: SchemaRef, + ) -> DaftResult<()> { + let optimized = simplify_expr(input, &schema)?; + + assert!(optimized.transformed); + assert_eq!(optimized.data, expected); + Ok(()) + } + + #[rstest] + // A * 1 --> A + #[case(col("int").mul(lit(1)), col("int"))] + // 1 * A --> A + #[case(lit(1).mul(col("int")), col("int"))] + // A / 1 --> A + #[case(col("int").div(lit(1)), col("int").cast(&DataType::Float64))] + // A + 0 --> A + #[case(col("int").add(lit(0)), col("int"))] + // A - 0 --> A + #[case(col("int").sub(lit(0)), col("int"))] + fn test_math_exprs( + #[case] input: ExprRef, + #[case] expected: ExprRef, + schema: SchemaRef, + ) -> DaftResult<()> { + let optimized = simplify_expr(input, &schema)?; + + assert!(optimized.transformed); + assert_eq!(optimized.data, expected); + Ok(()) + } + + #[rstest] + fn test_between(schema: SchemaRef) -> DaftResult<()> { + let input = col("int").between(lit(1), lit(10)); + let expected = col("int").lt_eq(lit(10)).and(col("int").gt_eq(lit(1))); + + let optimized = simplify_expr(input, &schema)?; + + assert!(optimized.transformed); + assert_eq!(optimized.data, expected); + Ok(()) + } +} diff --git 
a/src/daft-algebra/src/simplify/null.rs b/src/daft-algebra/src/simplify/null.rs new file mode 100644 index 0000000000..276dcf15bd --- /dev/null +++ b/src/daft-algebra/src/simplify/null.rs @@ -0,0 +1,71 @@ +use common_error::DaftResult; +use common_treenode::Transformed; +use daft_dsl::{lit, null_lit, Expr, ExprRef, LiteralValue, Operator}; +use daft_schema::schema::SchemaRef; + +fn is_null(expr: &Expr) -> bool { + matches!(expr, Expr::Literal(LiteralValue::Null)) +} + +/// Simplify expressions with null +pub(crate) fn simplify_expr_with_null( + expr: ExprRef, + _schema: &SchemaRef, +) -> DaftResult> { + Ok(match expr.as_ref() { + Expr::BinaryOp { op, left, right } if is_null(left) || is_null(right) => match op { + // ops where one null input always yields null + Operator::Eq + | Operator::NotEq + | Operator::Lt + | Operator::LtEq + | Operator::Gt + | Operator::GtEq + | Operator::Plus + | Operator::Minus + | Operator::Multiply + | Operator::TrueDivide + | Operator::FloorDivide + | Operator::Modulus + | Operator::Xor + | Operator::ShiftLeft + | Operator::ShiftRight => Transformed::yes(null_lit()), + + // operators where null input may not result in null + Operator::EqNullSafe | Operator::And | Operator::Or => Transformed::no(expr), + }, + // !NULL -> NULL + Expr::Not(e) if is_null(e) => Transformed::yes(null_lit()), + Expr::IsNull(e) => { + match e.as_ref() { + Expr::Literal(l) => { + match l { + // is_null(NULL) -> true + LiteralValue::Null => Transformed::yes(lit(true)), + // is_null(lit(x)) -> false + _ => Transformed::yes(lit(false)), + } + } + _ => Transformed::no(expr), + } + } + Expr::NotNull(e) => { + match e.as_ref() { + Expr::Literal(l) => { + match l { + // not_null(NULL) -> false + LiteralValue::Null => Transformed::yes(lit(false)), + // not_null(lit(x)) -> true + _ => Transformed::yes(lit(true)), + } + } + _ => Transformed::no(expr), + } + } + // if NULL { a } else { b } -> NULL + Expr::IfElse { predicate, .. 
} if is_null(predicate) => Transformed::yes(null_lit()), + // NULL IN (...) -> NULL + Expr::IsIn(e, _) if is_null(e) => Transformed::yes(null_lit()), + _ => Transformed::no(expr), + }) +} diff --git a/src/daft-algebra/src/simplify/numeric.rs b/src/daft-algebra/src/simplify/numeric.rs new file mode 100644 index 0000000000..265d0caa80 --- /dev/null +++ b/src/daft-algebra/src/simplify/numeric.rs @@ -0,0 +1,112 @@ +use common_error::DaftResult; +use common_treenode::Transformed; +use daft_dsl::{Expr, ExprRef, LiteralValue, Operator}; +use daft_schema::schema::SchemaRef; + +static POWS_OF_TEN: [i128; 38] = [ + 1, + 10, + 100, + 1000, + 10000, + 100000, + 1000000, + 10000000, + 100000000, + 1000000000, + 10000000000, + 100000000000, + 1000000000000, + 10000000000000, + 100000000000000, + 1000000000000000, + 10000000000000000, + 100000000000000000, + 1000000000000000000, + 10000000000000000000, + 100000000000000000000, + 1000000000000000000000, + 10000000000000000000000, + 100000000000000000000000, + 1000000000000000000000000, + 10000000000000000000000000, + 100000000000000000000000000, + 1000000000000000000000000000, + 10000000000000000000000000000, + 100000000000000000000000000000, + 1000000000000000000000000000000, + 10000000000000000000000000000000, + 100000000000000000000000000000000, + 1000000000000000000000000000000000, + 10000000000000000000000000000000000, + 100000000000000000000000000000000000, + 1000000000000000000000000000000000000, + 10000000000000000000000000000000000000, +]; + +fn is_one(s: &Expr) -> bool { + match s { + Expr::Literal(LiteralValue::Int8(1)) + | Expr::Literal(LiteralValue::UInt8(1)) + | Expr::Literal(LiteralValue::Int16(1)) + | Expr::Literal(LiteralValue::UInt16(1)) + | Expr::Literal(LiteralValue::Int32(1)) + | Expr::Literal(LiteralValue::UInt32(1)) + | Expr::Literal(LiteralValue::Int64(1)) + | Expr::Literal(LiteralValue::UInt64(1)) + | Expr::Literal(LiteralValue::Float64(1.)) => true, + + Expr::Literal(LiteralValue::Decimal(v, _p, s)) => { 
+ *s >= 0 && POWS_OF_TEN.get(*s as usize).is_some_and(|pow| v == pow) + } + _ => false, + } +} + +fn is_zero(s: &Expr) -> bool { + match s { + Expr::Literal(LiteralValue::Int8(0)) + | Expr::Literal(LiteralValue::UInt8(0)) + | Expr::Literal(LiteralValue::Int16(0)) + | Expr::Literal(LiteralValue::UInt16(0)) + | Expr::Literal(LiteralValue::Int32(0)) + | Expr::Literal(LiteralValue::UInt32(0)) + | Expr::Literal(LiteralValue::Int64(0)) + | Expr::Literal(LiteralValue::UInt64(0)) + | Expr::Literal(LiteralValue::Float64(0.)) => true, + Expr::Literal(LiteralValue::Decimal(v, _p, _s)) if *v == 0 => true, + _ => false, + } +} + +// simplify expressions with numeric operators +pub(crate) fn simplify_numeric_expr( + expr: ExprRef, + _schema: &SchemaRef, +) -> DaftResult> { + Ok(match expr.as_ref() { + Expr::BinaryOp { op, left, right } => { + match op { + // TODO: Can't do this one because we don't have a way to determine if an expr potentially contains nulls (nullable) + // A * 0 --> 0 (if A is not null and not floating/decimal) + // 0 * A --> 0 (if A is not null and not floating/decimal) + + // A * 1 -> A + Operator::Multiply if is_one(left) => Transformed::yes(right.clone()), + Operator::Multiply if is_one(right) => Transformed::yes(left.clone()), + // A / 1 -> A + Operator::TrueDivide if is_one(left) => Transformed::yes(right.clone()), + Operator::TrueDivide if is_one(right) => Transformed::yes(left.clone()), + // A + 0 -> A + Operator::Plus if is_zero(left) => Transformed::yes(right.clone()), + Operator::Plus if is_zero(right) => Transformed::yes(left.clone()), + // A - 0 -> A + Operator::Minus if is_zero(left) => Transformed::yes(right.clone()), + Operator::Minus if is_zero(right) => Transformed::yes(left.clone()), + + _ => Transformed::no(expr), + } + } + _ => Transformed::no(expr), + }) +} diff --git a/src/daft-logical-plan/src/optimization/optimizer.rs b/src/daft-logical-plan/src/optimization/optimizer.rs index d1fea88019..f21a8ef962 100644 --- 
a/src/daft-logical-plan/src/optimization/optimizer.rs +++ b/src/daft-logical-plan/src/optimization/optimizer.rs @@ -104,7 +104,7 @@ impl Default for OptimizerBuilder { // we want to simplify expressions first to make the rest of the rules easier RuleBatch::new( vec![Box::new(SimplifyExpressionsRule::new())], - RuleExecutionStrategy::FixedPoint(Some(3)), + RuleExecutionStrategy::FixedPoint(None), ), // --- Filter out null join keys --- // This rule should be run once, before any filter pushdown rules. @@ -134,6 +134,11 @@ impl Default for OptimizerBuilder { vec![Box::new(PushDownLimit::new())], RuleExecutionStrategy::FixedPoint(Some(3)), ), + // --- Simplify expressions before scans are materialized --- + RuleBatch::new( + vec![Box::new(SimplifyExpressionsRule::new())], + RuleExecutionStrategy::FixedPoint(None), + ), // --- Materialize scan nodes --- RuleBatch::new( vec![Box::new(MaterializeScans::new())], @@ -166,7 +171,7 @@ impl OptimizerBuilder { // Try to simplify expressions again as other rules could introduce new exprs. self.rule_batches.push(RuleBatch::new( vec![Box::new(SimplifyExpressionsRule::new())], - RuleExecutionStrategy::FixedPoint(Some(3)), + RuleExecutionStrategy::FixedPoint(None), )); self } diff --git a/src/daft-logical-plan/src/optimization/rules/simplify_expressions.rs b/src/daft-logical-plan/src/optimization/rules/simplify_expressions.rs index bb395ae428..286402adb0 100644 --- a/src/daft-logical-plan/src/optimization/rules/simplify_expressions.rs +++ b/src/daft-logical-plan/src/optimization/rules/simplify_expressions.rs @@ -35,12 +35,7 @@ impl OptimizerRule for SimplifyExpressionsRule { return Ok(Transformed::no(plan)); } - let schema = plan.schema(); - plan.transform(|plan| { - Ok(Arc::unwrap_or_clone(plan) - .map_expressions(|expr| simplify_expr(Arc::unwrap_or_clone(expr), &schema))? 
- .update_data(Arc::new)) - }) + plan.transform(|plan| plan.map_expressions(simplify_expr)) } } diff --git a/src/daft-logical-plan/src/treenode.rs b/src/daft-logical-plan/src/treenode.rs index a54ff46e29..d468d8b1c9 100644 --- a/src/daft-logical-plan/src/treenode.rs +++ b/src/daft-logical-plan/src/treenode.rs @@ -1,14 +1,15 @@ use std::sync::Arc; use common_error::DaftResult; -use common_treenode::{ - map_until_stop_and_collect, DynTreeNode, Transformed, TreeNodeIterator, TreeNodeRecursion, -}; +use common_scan_info::{PhysicalScanInfo, Pushdowns, ScanState}; +use common_treenode::{DynTreeNode, Transformed, TreeNodeIterator}; +use daft_core::prelude::SchemaRef; use daft_dsl::ExprRef; use crate::{ + ops::Source, partitioning::{HashRepartitionConfig, RepartitionSpec}, - LogicalPlan, + LogicalPlan, SourceInfo, }; impl DynTreeNode for LogicalPlan { @@ -37,39 +38,42 @@ impl DynTreeNode for LogicalPlan { } impl LogicalPlan { - pub fn map_expressions DaftResult>>( - self, + pub fn map_expressions DaftResult>>( + self: Arc, mut f: F, - ) -> DaftResult> { + ) -> DaftResult>> { use crate::ops::{ActorPoolProject, Explode, Filter, Join, Project, Repartition, Sort}; - Ok(match self { + Ok(match self.as_ref() { Self::Project(Project { input, projection, projected_schema, stats_state, }) => projection - .into_iter() - .map_until_stop_and_collect(f)? - .update_data(|expr| { + .iter() + .cloned() + .map_and_collect(|expr| f(expr, &input.schema()))? 
+ .update_data(|new_projection| { Self::Project(Project { - input, - projection: expr, - projected_schema, - stats_state, + input: input.clone(), + projection: new_projection, + projected_schema: projected_schema.clone(), + stats_state: stats_state.clone(), }) + .into() }), Self::Filter(Filter { input, predicate, stats_state, - }) => f(predicate)?.update_data(|expr| { + }) => f(predicate.clone(), &input.schema())?.update_data(|expr| { Self::Filter(Filter { - input, + input: input.clone(), predicate: expr, - stats_state, + stats_state: stats_state.clone(), }) + .into() }), Self::Repartition(Repartition { input, @@ -77,38 +81,39 @@ impl LogicalPlan { stats_state, }) => match repartition_spec { RepartitionSpec::Hash(HashRepartitionConfig { num_partitions, by }) => by - .into_iter() - .map_until_stop_and_collect(f)? + .iter() + .cloned() + .map_and_collect(|expr| f(expr, &input.schema()))? .update_data(|expr| { - RepartitionSpec::Hash(HashRepartitionConfig { - num_partitions, - by: expr, + Self::Repartition(Repartition { + input: input.clone(), + repartition_spec: RepartitionSpec::Hash(HashRepartitionConfig { + num_partitions: *num_partitions, + by: expr, + }), + stats_state: stats_state.clone(), }) + .into() }), - repartition_spec => Transformed::no(repartition_spec), - } - .update_data(|repartition_spec| { - Self::Repartition(Repartition { - input, - repartition_spec, - stats_state, - }) - }), + _ => Transformed::no(self.clone()), + }, Self::ActorPoolProject(ActorPoolProject { input, projection, projected_schema, stats_state, }) => projection - .into_iter() - .map_until_stop_and_collect(f)? - .update_data(|expr| { + .iter() + .cloned() + .map_and_collect(|expr| f(expr, &input.schema()))? 
+ .update_data(|new_projection| { Self::ActorPoolProject(ActorPoolProject { - input, - projection: expr, - projected_schema, - stats_state, + input: input.clone(), + projection: new_projection, + projected_schema: projected_schema.clone(), + stats_state: stats_state.clone(), }) + .into() }), Self::Sort(Sort { input, @@ -117,16 +122,18 @@ impl LogicalPlan { nulls_first, stats_state, }) => sort_by - .into_iter() - .map_until_stop_and_collect(f)? - .update_data(|expr| { + .iter() + .cloned() + .map_and_collect(|expr| f(expr, &input.schema()))? + .update_data(|new_sort_by| { Self::Sort(Sort { - input, - sort_by: expr, - descending, - nulls_first, - stats_state, + input: input.clone(), + sort_by: new_sort_by, + descending: descending.clone(), + nulls_first: nulls_first.clone(), + stats_state: stats_state.clone(), }) + .into() }), Self::Explode(Explode { input, @@ -134,15 +141,17 @@ impl LogicalPlan { exploded_schema, stats_state, }) => to_explode - .into_iter() - .map_until_stop_and_collect(f)? - .update_data(|expr| { + .iter() + .cloned() + .map_and_collect(|expr| f(expr, &input.schema()))? + .update_data(|new_to_explode| { Self::Explode(Explode { - input, - to_explode: expr, - exploded_schema, - stats_state, + input: input.clone(), + to_explode: new_to_explode, + exploded_schema: exploded_schema.clone(), + stats_state: stats_state.clone(), }) + .into() }), Self::Join(Join { left, @@ -152,39 +161,75 @@ impl LogicalPlan { null_equals_nulls, join_type, join_strategy, - .. 
+ output_schema, + stats_state, }) => { - let o = left_on - .into_iter() - .zip(right_on) - .map_until_stop_and_collect(|(l, r)| { - map_until_stop_and_collect!(f(l), r, f(r)) - })?; - let (left_on, right_on) = o.data.into_iter().unzip(); + let new_left_on = left_on + .iter() + .cloned() + .map_and_collect(|expr| f(expr, &left.schema()))?; + let new_right_on = right_on + .iter() + .cloned() + .map_and_collect(|expr| f(expr, &right.schema()))?; - if o.transformed { - Transformed::yes(Self::Join(Join::try_new( - left, - right, - left_on, - right_on, - null_equals_nulls, - join_type, - join_strategy, - )?)) + if new_left_on.transformed && new_right_on.transformed { + Transformed::yes( + Self::Join(Join { + left: left.clone(), + right: right.clone(), + left_on: new_left_on.data, + right_on: new_right_on.data, + null_equals_nulls: null_equals_nulls.clone(), + join_type: *join_type, + join_strategy: *join_strategy, + output_schema: output_schema.clone(), + stats_state: stats_state.clone(), + }) + .into(), + ) } else { - Transformed::no(Self::Join(Join::try_new( - left, - right, - left_on, - right_on, - null_equals_nulls, - join_type, - join_strategy, - )?)) + Transformed::no(self) } } - lp => Transformed::no(lp), + Self::Source(Source { + output_schema, + source_info, + stats_state, + }) => match source_info.as_ref() { + SourceInfo::Physical( + physical_scan_info @ PhysicalScanInfo { + pushdowns: + pushdowns @ Pushdowns { + filters: Some(filter), + .. + }, + scan_state: ScanState::Operator(scan_operator), + source_schema, + .. 
+ }, + ) => { + let schema = if let Some(fields) = scan_operator.0.generated_fields() { + &Arc::new(source_schema.non_distinct_union(&fields)) + } else { + source_schema + }; + + f(filter.clone(), schema)?.update_data(|new_filter| { + Self::Source(Source { + output_schema: output_schema.clone(), + source_info: Arc::new(SourceInfo::Physical( + physical_scan_info + .with_pushdowns(pushdowns.with_filters(Some(new_filter))), + )), + stats_state: stats_state.clone(), + }) + .into() + }) + } + _ => Transformed::no(self), + }, + _ => Transformed::no(self), }) } }