Skip to content

Commit

Permalink
Create new AST node Group in the expression language, that represents…
Browse files Browse the repository at this point in the history
… value in parenthesis

This change fixes https://github.com/kaitai-io/kaitai_struct_tests/blob/master/formats/expr_ops_parens.ksy test
for Java and, I'm sure, for all other targets
  • Loading branch information
Mingun committed Mar 9, 2024
1 parent 6a6083d commit 5ad8e61
Show file tree
Hide file tree
Showing 10 changed files with 47 additions and 7 deletions.
8 changes: 8 additions & 0 deletions jvm/src/test/scala/io/kaitai/struct/exprlang/Ast$Test.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ class Ast$Test extends AnyFunSpec {
Expressions.parse("42 - 2").evaluateIntConst should be(Some(40))
}

it ("considers `(42 - 2)` constant") {
Expressions.parse("(42 - 2)").evaluateIntConst should be(Some(40))
}

it ("considers `(-3 + 7) * 8 / 2` constant") {
Expressions.parse("(-3 + 7) * 8 / 2").evaluateIntConst should be(Some(16))
}
Expand All @@ -33,6 +37,10 @@ class Ast$Test extends AnyFunSpec {
Expressions.parse("4 > 2 ? 1 : 5").evaluateIntConst should be(Some(1))
}

it ("considers `((4) > 2) ? ((1)) : 5` constant") {
Expressions.parse("((4) > 2) ? ((1)) : 5").evaluateIntConst should be(Some(1))
}

it ("considers `x` variable") {
Expressions.parse("x").evaluateIntConst should be(None)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,17 @@ class ExpressionsSpec extends AnyFunSpec {
)
}

it("parses group") {
Expressions.parse("(123)") should be (Group(IntNum(123)))
Expressions.parse("(foo)") should be (Group(Name(identifier("foo"))))
}

it("parses (1 + 2) / (7 * 8)") {
Expressions.parse("(1 + 2) / (7 * 8)") should be (
BinOp(
BinOp(IntNum(1), Add, IntNum(2)),
Group(BinOp(IntNum(1), Add, IntNum(2))),
Div,
BinOp(IntNum(7), Mult, IntNum(8))
Group(BinOp(IntNum(7), Mult, IntNum(8)))
)
)
}
Expand Down Expand Up @@ -128,7 +133,7 @@ class ExpressionsSpec extends AnyFunSpec {
}

it("parses ~(7+3)") {
Expressions.parse("~(7+3)") should be (UnaryOp(Invert, BinOp(IntNum(7), Add, IntNum(3))))
Expressions.parse("~(7+3)") should be (UnaryOp(Invert, Group(BinOp(IntNum(7), Add, IntNum(3)))))
}

// Enums
Expand Down Expand Up @@ -285,7 +290,7 @@ class ExpressionsSpec extends AnyFunSpec {

it("parses (123).as<u4>") {
Expressions.parse("(123).as<u4>") should be (
CastToType(IntNum(123), typeId(false, Seq("u4")))
CastToType(Group(IntNum(123)), typeId(false, Seq("u4")))
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class AttrSpec$Test extends AnyFunSpec {
spec.id should be(NamedIdentifier("foo"))
val dataType = spec.dataType.asInstanceOf[UserType]
dataType.name should be(List("bar"))
dataType.args should be(Seq(Ast.expr.IntNum(5)))
dataType.args should be(Seq(Ast.expr.Group(Ast.expr.IntNum(5))))
}
}
}
5 changes: 4 additions & 1 deletion shared/src/main/scala/io/kaitai/struct/exprlang/Ast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ object Ast {
}
}

object expr{
object expr {
case class BoolOp(op: boolop, values: Seq[expr]) extends expr
case class BinOp(left: expr, op: operator, right: expr) extends expr
case class UnaryOp(op: unaryop, operand: expr) extends expr
Expand Down Expand Up @@ -91,6 +91,9 @@ object Ast {
case class List(elts: Seq[expr]) extends expr
case class InterpolatedStr(elts: Seq[expr]) extends expr

/** Wraps `expr` into parentheses. Multiple nested groups merge into the one group */
case class Group(expr: expr) extends expr

/**
* Implicit declaration of ordering, so expressions can be used for ordering operations, e.g.
* for `SortedMap.from(...)`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ object ConstEvaluator {
case _ => value.NonConst
}

case expr.Group(expr) => evaluate(expr)
case _ => value.NonConst
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,13 @@ object Expressions {
case (lhs, trailers) =>
trailers.foldLeft(lhs)((l, t) => t(l))
}
def group[$: P]: P[Ast.expr] = ("(" ~ test ~ ")").map(expr => Ast.expr.Group(expr))
def empty_list[$: P] = P("[" ~ "]").map(_ => Ast.expr.List(Nil))
// def empty_dict[_: P] = P("{" ~ "}").map(_ => Ast.expr.Dict(Nil, Nil))
def atom[$: P]: P[Ast.expr] = P(
empty_list |
// empty_dict |
"(" ~ test ~ ")" |
group |
"[" ~ list ~ "]" |
// "{" ~ dictorsetmaker ~ "}" |
enumByName |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,12 @@ abstract class BaseTranslator(val provider: TypeProvider)
doByteSizeOfType(typeName)
case Ast.expr.BitSizeOfType(typeName) =>
doBitSizeOfType(typeName)
case Ast.expr.Group(nested) =>
nested match {
// Unpack nested groups
case Ast.expr.Group(e) => translate(e)
case e => s"(${translate(e)})"
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ class ExpressionValidator(val provider: TypeProvider)
CommonSizeOf.getBitsSizeOfType(typeName.nameAsStr, detectCastType(typeName))
case Ast.expr.InterpolatedStr(elts: Seq[Ast.expr]) =>
elts.foreach(validate)
case Ast.expr.Group(nested) =>
validate(nested)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@ class GoTranslator(out: StringLanguageOutputWriter, provider: TypeProvider, impo
doByteSizeOfType(typeName)
case Ast.expr.BitSizeOfType(typeName) =>
doBitSizeOfType(typeName)
case Ast.expr.Group(nested) =>
ResultString(nested match {
// Unpack nested groups
case Ast.expr.Group(e) => translate(e)
case e => s"(${translate(e)})"
})
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ class TypeDetector(provider: TypeProvider) {
detectCastType(typeName)
case Ast.expr.ByteSizeOfType(_) | Ast.expr.BitSizeOfType(_) =>
CalcIntType
case Ast.expr.Group(expr) =>
detectType(expr)
}
}

Expand Down Expand Up @@ -244,6 +246,12 @@ class TypeDetector(provider: TypeProvider) {
*/
def detectCallType(call: Ast.expr.Call): DataType = {
call.func match {
case Ast.expr.Group(nested) => detectCallTypeImpl(nested)
case func => detectCallTypeImpl(func)
}
}
def detectCallTypeImpl(func: Ast.expr): DataType = {
func match {
case Ast.expr.Attribute(obj: Ast.expr, methodName: Ast.identifier) =>
val objType = detectType(obj)
// TODO: check number and type of arguments in `call.args`
Expand Down

0 comments on commit 5ad8e61

Please sign in to comment.