diff --git a/src/model/expressions.rs b/src/model/expressions.rs index 8b2ad50..8430147 100644 --- a/src/model/expressions.rs +++ b/src/model/expressions.rs @@ -209,37 +209,273 @@ impl OwningNode { let left = Self::validate(*left, ctx)?; let right = Self::validate(*right, ctx)?; + macro_rules! fold_constants { + ( + $l1:ident, $r1:ident, $r2:ident; + $l:ident, $r:ident; + $op1:ident, $op2:ident; + $res1:ident($expr1:expr); + $res2:ident($expr2:expr); + ) => { + match (&*$l1, &*$r1, &$r2) { + (_, Int($l), Int($r) ) => Binary { op: $res1, left: Box::new(Int ($expr1)), right: $l1 }, + (_, Float($l), Float($r)) => Binary { op: $res1, left: Box::new(Float($expr1)), right: $l1 }, + + (Int($l), _, Int($r) ) => Binary { op: $res2, left: Box::new(Int ($expr2)), right: $r1 }, + (Float($l), _, Float($r)) => Binary { op: $res2, left: Box::new(Float($expr2)), right: $r1 }, + + _ => Binary { + op: $op1, + left: Box::new(Binary { op: $op2, left: $l1, right: $r1 }), + right: Box::new($r2), + }, + } + }; + } + match (op, left, right) { + (Add, Str(l), Str(r)) => Str(l + &r), //TODO: Check types before simplification (Add, Str(l), r) if l.is_empty() => r, (Add, l, Str(r)) if r.is_empty() => l, + (Add, Int(l), Int(r)) => Int(l + r), + (Add, Int(l), Float(r)) => Float(BigDecimal::from(l) + r), + //TODO: Check types before simplification (Add, Int(l), r) if l.is_zero() => r, (Add, l, Int(r)) if r.is_zero() => l, - (Sub, Int(l), r) if l.is_zero() => Unary { op: UnaryOp::Neg, expr: Box::new(r) }, - (Sub, l, Int(r)) if r.is_zero() => l, + (Add, Float(l), Int(r)) => Float(l + BigDecimal::from(r)), + (Add, Float(l), Float(r)) => Float(l + r), + //TODO: Check types before simplification (Add, Float(l), r) if l.is_zero() => r, (Add, l, Float(r)) if r.is_zero() => l, + + //---------------------------------------------------------------------------------------- + (Sub, Int(l), Int(r)) => Int(l - r), + (Sub, Int(l), Float(r)) => Float(BigDecimal::from(l) - r), + //TODO: Check types before simplification + (Sub, Int(l), r) if l.is_zero() => Unary { op: UnaryOp::Neg, expr: Box::new(r) }, + (Sub, l, Int(r)) if r.is_zero() => l, + + (Sub, Float(l), Int(r)) => Float(l - BigDecimal::from(r)), + (Sub, Float(l), Float(r)) => Float(l - r), + //TODO: Check types before simplification (Sub, Float(l), r) if l.is_zero() => Unary { op: UnaryOp::Neg, expr: Box::new(r) }, (Sub, l, Float(r)) if r.is_zero() => l, //---------------------------------------------------------------------------------------- + (Mul, Int(l), Int(r)) => Int(l * r), + (Mul, Int(l), Float(r)) => Float(l * r), + //TODO: Check types before simplification (Mul, Int(l), _) if l.is_zero() => Int(0.into()),// 0 * x = 0 - (Mul, _, Int(r)) if r.is_zero() => Int(0.into()),// x * 0 = 0 - - (Mul, Float(l), _) if l.is_zero() => Int(0.into()),// 0 * x = 0 - (Mul, _, Float(r)) if r.is_zero() => Int(0.into()),// x * 0 = 0 - (Mul, Int(l), r) if l.is_one() => r,// 1 * x = x + (Mul, _, Int(r)) if r.is_zero() => Int(0.into()),// x * 0 = 0 (Mul, l, Int(r)) if r.is_one() => l,// x * 1 = x - (Div, l, Int(r)) if r.is_one() => l,// x / 1 = x + (Mul, Float(l), Int(r)) => Float(l * r), + (Mul, Float(l), Float(r)) => Float(l * r), + //TODO: Check types before simplification + (Mul, Float(l), _) if l.is_zero() => Float(0.into()),// 0 * x = 0 (Mul, Float(l), r) if l.is_one() => r,// 1 * x = x + (Mul, _, Float(r)) if r.is_zero() => Float(0.into()),// x * 0 = 0 (Mul, l, Float(r)) if r.is_one() => l,// x * 1 = x + + //---------------------------------------------------------------------------------------- + (Div, Float(l), Int(r)) => Float(l / BigDecimal::from(r)), + (Div, Float(l), Float(r)) => Float(l / r), + (Div, l, Int(r)) if r.is_one() => l,// x / 1 = x (Div, l, Float(r)) if r.is_one() => l,// x / 1 = x //======================================================================================== + // _ + L + R L + _ + R => SUM + _ + // + + + + // / \ / \ / \ + // + R OR + R SUM _ + // / \ / \ (L+R) + // _ L L _ + (Add, Binary { op: Add, left: l1, right: r1 }, r2) => match (&*l1, &*r1, &r2) { + (_, Str(l), Str(r) ) => Binary { op: Add, left: l1, right: Box::new(Str(l.to_owned() + r)) }, + + (_, Int(l), Int(r) ) => Binary { op: Add, left: Box::new(Int (l + r)), right: l1 }, + (_, Float(l), Float(r)) => Binary { op: Add, left: Box::new(Float(l + r)), right: l1 }, + + (Int(l), _, Int(r) ) => Binary { op: Add, left: Box::new(Int (l + r)), right: r1 }, + (Float(l), _, Float(r)) => Binary { op: Add, left: Box::new(Float(l + r)), right: r1 }, + + _ => Binary { + op: Add, + left: Box::new(Binary { op: Add, left: l1, right: r1 }), + right: Box::new(r2), + }, + }, + (Mul, Binary { op: Mul, left: l1, right: r1 }, r2) => fold_constants!( + l1, r1, r2; + l, r; + Mul, Mul; + Mul(l * r); + Mul(l * r); + ), + //---------------------------------------------------------------------------------------- + // _ - L + R => SUB + _ L - _ + R => SUM - _ + // + + + - + // / \ / \ / \ / \ + // - R SUB _ OR - R SUM _ + // / \ (R-L) / \ (L+R) + // _ L L _ + (Add, Binary { op: Sub, left: l1, right: r1 }, r2) => fold_constants!( + l1, r1, r2; + l, r; + Add, Sub; + Add(r - l); + Sub(r + l); + ), + (Mul, Binary { op: Div, left: l1, right: r1 }, r2) => fold_constants!( + l1, r1, r2; + l, r; + Mul, Div; + Mul(r / l); + Div(r * l); + ), + /*(Add, Binary { op: Sub, left: l1, right: r1 }, r2) => match (&*l1, &*r1, &r2) { + (_, Int(l), Int(r) ) => Binary { op: Add, left: Box::new(Int (r - l)), right: l1 }, + (_, Float(l), Float(r)) => Binary { op: Add, left: Box::new(Float(r - l)), right: l1 }, + + (Int(l), _, Int(r) ) => Binary { op: Sub, left: Box::new(Int (l + r)), right: r1 }, + (Float(l), _, Float(r)) => Binary { op: Sub, left: Box::new(Float(l + r)), right: r1 }, + + _ => Binary { + op: Add, + left: Box::new(Binary { op: Sub, left: l1, right: r1 }), + right: Box::new(r2), + }, + },*/ + //---------------------------------------------------------------------------------------- + // _ - L - R => SUM + _ L - _ - R => SUB - _ + // - + - - + // / \ / \ / \ / \ + // - R SUM _ OR - R SUB _ + // / \ (-L-R) / \ (L-R) + // _ L L _ + (Sub, Binary { op: Sub, left: l1, right: r1 }, r2) => fold_constants!( + l1, r1, r2; + l, r; + Sub, Sub; + Add(-l - r); + Sub( l - r); + ), + (Div, Binary { op: Div, left: l1, right: r1 }, r2) => fold_constants!( + l1, r1, r2; + l, r; + Div, Div; + Mul(1/(l * r)); + Div(l / r); + ), + /*(Sub, Binary { op: Sub, left: l1, right: r1 }, r2) => match (&*l1, &*r1, &r2) { + (_, Int(l), Int(r) ) => Binary { op: Sub, left: l1, right: Box::new(Int (l + r)) }, + (_, Float(l), Float(r)) => Binary { op: Sub, left: l1, right: Box::new(Float(l + r)) }, + + (Int(l), _, Int(r) ) => Binary { op: Sub, left: Box::new(Int (l - r)), right: r1 }, + (Float(l), _, Float(r)) => Binary { op: Sub, left: Box::new(Float(l - r)), right: r1 }, + + _ => Binary { + op: Sub, + left: Box::new(Binary { op: Sub, left: l1, right: r1 }), + right: Box::new(r2), + }, + },*/ + //---------------------------------------------------------------------------------------- + // _ + L - R L + _ - R => SUB + _ + // - - + + // / \ / \ / \ + // + R OR + R SUB _ + // / \ / \ (L-R) + // _ L L _ + (Sub, Binary { op: Add, left: l1, right: r1 }, r2) => fold_constants!( + l1, r1, r2; + l, r; + Sub, Add; + Add(l - r); + Add(l - r); + ), + (Div, Binary { op: Mul, left: l1, right: r1 }, r2) => fold_constants!( + l1, r1, r2; + l, r; + Div, Mul; + Mul(l / r); + Mul(l / r); + ), + /*(Sub, Binary { op: Add, left: l1, right: r1 }, r2) => match (&*l1, &*r1, &r2) { + (_, Int(l), Int(r) ) => Binary { op: Add, left: Box::new(Int (l - r)), right: l1 }, + (_, Float(l), Float(r)) => Binary { op: Add, left: Box::new(Float(l - r)), right: l1 }, + + (Int(l), _, Int(r) ) => Binary { op: Add, left: Box::new(Int (l - r)), right: r1 }, + (Float(l), _, Float(r)) => Binary { op: Add, left: Box::new(Float(l - r)), right: r1 }, + + _ => Binary { + op: Sub, + left: Box::new(Binary { op: Add, left: l1, right: r1 }), + right: Box::new(r2), + }, + },*/ + //======================================================================================== + (Le, Str(l), Str(r)) => Bool(l <= r), + (Ge, Str(l), Str(r)) => Bool(l >= r), + (Lt, Str(l), Str(r)) => Bool(l < r), + (Gt, Str(l), Str(r)) => Bool(l > r), + + //---------------------------------------------------------------------------------------- + (Div, Int(l), Int(r)) => Int(l / r), + // Rust `%` uses modulo operation (negative result for negative `l`), but Kaitai Struct + // uses remainder operation (always positive result): https://doc.kaitai.io/user_guide.html#_operators + // (Rem, Int(l), Int(r)) => Int(l.rem_euclid(r)), + + // (Shl, Int(l), Int(r)) => Int(l << r), + // (Shr, Int(l), Int(r)) => Int(l >> r), + + (Le, Int(l), Int(r)) => Bool(l <= r), + (Ge, Int(l), Int(r)) => Bool(l >= r), + (Lt, Int(l), Int(r)) => Bool(l < r), + (Gt, Int(l), Int(r)) => Bool(l > r), + + (BitAnd, Int(l), Int(r)) => Int(l & r), + (BitOr, Int(l), Int(r)) => Int(l | r), + (BitXor, Int(l), Int(r)) => Int(l ^ r), + //---------------------------------------------------------------------------------------- + + (Le, Float(l), Float(r)) => Bool(l <= r), + (Ge, Float(l), Float(r)) => Bool(l >= r), + (Lt, Float(l), Float(r)) => Bool(l < r), + (Gt, Float(l), Float(r)) => Bool(l > r), + //---------------------------------------------------------------------------------------- + + (And, Bool(l), Bool(r)) => Bool(l && r), + (Or, Bool(l), Bool(r)) => Bool(l || r), + + //---------------------------------------------------------------------------------------- + // If two primitives are not equal after normalization, + // then the result is known at compile-time + (Eq, Str(l), Str(r)) => Bool(l == r), + (Eq, Int(l), Int(r)) => Bool(l == r), + (Eq, Int(l), Float(r)) => Bool(r == l.into()), + (Eq, Float(l), Int(r)) => Bool(l == r.into()), + (Eq, Float(l), Float(r)) => Bool(l == r), + (Eq, Bool(l), Bool(r)) => Bool(l == r), + // Symbolic calculations: if two subtrees are equal after normalization, + // then the result is known at compile-time + (Eq, l, r) if l == r => Bool(true), + + //---------------------------------------------------------------------------------------- + // If two primitives are not equal after normalization, + // then the result is known at compile-time + (Ne, Str(l), Str(r)) => Bool(l != r), + (Ne, Int(l), Int(r)) => Bool(l != r), + (Ne, Int(l), Float(r)) => Bool(r != l.into()), + (Ne, Float(l), Int(r)) => Bool(l != r.into()), + (Ne, Float(l), Float(r)) => Bool(l != r), + (Ne, Bool(l), Bool(r)) => Bool(l != r), + + //---------------------------------------------------------------------------------------- (_, l, r) => Binary { op, left: Box::new(l), @@ -1389,6 +1625,99 @@ mod evaluation { } } } + + /// Checks that folding constants in triplets behaves correctly + mod triplets { + use super::*; + use pretty_assertions::assert_eq; + use BinaryOp::Add; + + /// Tests folding of the integral numbers' constants + #[test] + fn int() { + assert_eq!(parse("x + 1 + 2").unwrap(), Binary { + op: Add, + left: Box::new(Int(3.into())), + right: Box::new(Attr(FieldName::valid("x").into())), + }); + assert_eq!(parse("1 + x + 2").unwrap(), Binary { + op: Add, + left: Box::new(Int(3.into())), + right: Box::new(Attr(FieldName::valid("x").into())), + }); + assert_eq!(parse("1 + 2 + x").unwrap(), Binary { + op: Add, + left: Box::new(Int(3.into())), + right: Box::new(Attr(FieldName::valid("x").into())), + }); + + assert_eq!(parse("1 + 2 + 3 + x + 4 + 5 + 6").unwrap(), Binary { + op: Add, + left: Box::new(Int(21.into())), + right: Box::new(Attr(FieldName::valid("x").into())), + }); + } + + /// Tests folding of the floating-points' constants + #[test] + fn float() { + assert_eq!(parse("x + 1.0 + 2.0").unwrap(), Binary { + op: Add, + left: Box::new(Float(3.into())), + right: Box::new(Attr(FieldName::valid("x").into())), + }); + assert_eq!(parse("1.0 + x + 2.0").unwrap(), Binary { + op: Add, + left: Box::new(Float(3.into())), + right: Box::new(Attr(FieldName::valid("x").into())), + }); + assert_eq!(parse("1.0 + 2.0 + x").unwrap(), Binary { + op: Add, + left: Box::new(Float(3.into())), + right: Box::new(Attr(FieldName::valid("x").into())), + }); + + assert_eq!(parse("1.0 + 2.0 + 3.0 + x + 4.0 + 5.0 + 6.0").unwrap(), Binary { + op: Add, + left: Box::new(Float(21.into())), + right: Box::new(Attr(FieldName::valid("x").into())), + }); + } + + /// Tests folding of the string constants + #[test] + fn str() { + assert_eq!(parse("x + 'a' + 'b'").unwrap(), Binary { + op: Add, + left: Box::new(Attr(FieldName::valid("x").into())), + right: Box::new(Str("ab".into())), + }); + assert_eq!(parse("'a' + x + 'b'").unwrap(), Binary { + op: Add, + left: Box::new(Binary { + op: Add, + left: Box::new(Str("a".into())), + right: Box::new(Attr(FieldName::valid("x").into())), + }), + right: Box::new(Str("b".into())), + }); + assert_eq!(parse("'a' + 'b' + x").unwrap(), Binary { + op: Add, + left: Box::new(Str("ab".into())), + right: Box::new(Attr(FieldName::valid("x").into())), + }); + + assert_eq!(parse("'a' + 'b' + 'c' + x + 'd' + 'e' + 'f'").unwrap(), Binary { + op: Add, + left: Box::new(Binary { + op: Add, + left: Box::new(Str("abc".into())), + right: Box::new(Attr(FieldName::valid("x").into())), + }), + right: Box::new(Str("def".into())), + }); + } + } } #[test] @@ -1401,4 +1730,22 @@ mod evaluation { if_false: Box::new(Attr(FieldName::valid("b").into())), })); } + + #[test] + #[ignore] + fn index() {//TODO: Index validation + assert_eq!(parse("[3, 1, 4][-1]"), Ok(Int(4.into()))); + assert_eq!(parse("[3, 1, 4][ 2]"), Ok(Int(4.into()))); + assert_eq!(parse("[3, 1, x][ 2]"), Ok(Attr(FieldName::valid("x").into()))); + assert_eq!(parse("[3, 1, 4][ x]"), Ok(Index { + expr: Box::new(List(vec![ + Int(3.into()), + Int(1.into()), + Int(4.into()), + ])), + index: Box::new(Attr(FieldName::valid("x").into())), + })); + + assert_eq!(parse("[3, 1, 4][ 3]"), Err(Validation("".into()))); + } } diff --git a/src/parser/expressions.rs b/src/parser/expressions.rs index ec90c1f..cf94c75 100644 --- a/src/parser/expressions.rs +++ b/src/parser/expressions.rs @@ -394,7 +394,14 @@ pub enum BinaryOp { Mul, /// `/`: Division of two numeric arguments. Div, - /// `%`: Remainder of division of two numeric arguments. + /// `%`: Nonnegative remainder of division of two numeric arguments. + /// + /// This operation is different from Rust [`%`] operator: [`-5 % 3` is `1`, not `-2`][operators]. + /// Analogous in Rust for this operation is [`rem_euclid`]. + /// + /// [`%`]: std::ops::Rem + /// [`rem_euclid`]: i64::rem_euclid + /// [operators]: https://doc.kaitai.io/user_guide.html#_operators Rem, /// `<<`: The left shift operator.