Skip to content

Commit

Permalink
Create new AST node Group in the expression language, that represents…
Browse files Browse the repository at this point in the history
… value in parenthesis

This change fixes https://github.com/kaitai-io/kaitai_struct_tests/blob/master/formats/expr_ops_parens.ksy test
for Java and, I'm sure, for all other targets
  • Loading branch information
Mingun committed Mar 9, 2024
1 parent d071323 commit 7193b62
Show file tree
Hide file tree
Showing 10 changed files with 82 additions and 42 deletions.
8 changes: 8 additions & 0 deletions jvm/src/test/scala/io/kaitai/struct/exprlang/Ast$Test.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ class Ast$Test extends AnyFunSpec {
Expressions.parse("42 - 2").evaluateIntConst should be(Some(40))
}

it ("considers `(42 - 2)` constant") {
Expressions.parse("(42 - 2)").evaluateIntConst should be(Some(40))
}

it ("considers `(-3 + 7) * 8 / 2` constant") {
Expressions.parse("(-3 + 7) * 8 / 2").evaluateIntConst should be(Some(16))
}
Expand All @@ -33,6 +37,10 @@ class Ast$Test extends AnyFunSpec {
Expressions.parse("4 > 2 ? 1 : 5").evaluateIntConst should be(Some(1))
}

it ("considers `((4) > 2) ? ((1)) : 5` constant") {
Expressions.parse("((4) > 2) ? ((1)) : 5").evaluateIntConst should be(Some(1))
}

it ("considers `x` variable") {
Expressions.parse("x").evaluateIntConst should be(None)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,17 @@ class ExpressionsSpec extends AnyFunSpec {
)
}

it("parses group") {
Expressions.parse("(123)") should be (Group(IntNum(123)))
Expressions.parse("(foo)") should be (Group(Name(identifier("foo"))))
}

it("parses (1 + 2) / (7 * 8)") {
Expressions.parse("(1 + 2) / (7 * 8)") should be (
BinOp(
BinOp(IntNum(1), Add, IntNum(2)),
Group(BinOp(IntNum(1), Add, IntNum(2))),
Div,
BinOp(IntNum(7), Mult, IntNum(8))
Group(BinOp(IntNum(7), Mult, IntNum(8)))
)
)
}
Expand Down Expand Up @@ -128,7 +133,7 @@ class ExpressionsSpec extends AnyFunSpec {
}

it("parses ~(7+3)") {
Expressions.parse("~(7+3)") should be (UnaryOp(Invert, BinOp(IntNum(7), Add, IntNum(3))))
Expressions.parse("~(7+3)") should be (UnaryOp(Invert, Group(BinOp(IntNum(7), Add, IntNum(3)))))
}

// Enums
Expand Down Expand Up @@ -285,7 +290,7 @@ class ExpressionsSpec extends AnyFunSpec {

it("parses (123).as<u4>") {
Expressions.parse("(123).as<u4>") should be (
CastToType(IntNum(123), typeId(false, Seq("u4")))
CastToType(Group(IntNum(123)), typeId(false, Seq("u4")))
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class AttrSpec$Test extends AnyFunSpec {
spec.id should be(NamedIdentifier("foo"))
val dataType = spec.dataType.asInstanceOf[UserType]
dataType.name should be(List("bar"))
dataType.args should be(Seq(Ast.expr.IntNum(5)))
dataType.args should be(Seq(Ast.expr.Group(Ast.expr.IntNum(5))))
}
}
}
5 changes: 4 additions & 1 deletion shared/src/main/scala/io/kaitai/struct/exprlang/Ast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ object Ast {
}
}

object expr{
object expr {
case class BoolOp(op: boolop, values: Seq[expr]) extends expr
case class BinOp(left: expr, op: operator, right: expr) extends expr
case class UnaryOp(op: unaryop, operand: expr) extends expr
Expand Down Expand Up @@ -91,6 +91,9 @@ object Ast {
case class List(elts: Seq[expr]) extends expr
case class InterpolatedStr(elts: Seq[expr]) extends expr

/** Wraps `expr` into parentheses. Multiple nested groups merge into the one group */
case class Group(expr: expr) extends expr

/**
* Implicit declaration of ordering, so expressions can be used for ordering operations, e.g.
* for `SortedMap.from(...)`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ object ConstEvaluator {
case _ => value.NonConst
}

case expr.Group(expr) => evaluate(expr)
case _ => value.NonConst
}

Expand Down
69 changes: 35 additions & 34 deletions shared/src/main/scala/io/kaitai/struct/exprlang/Expressions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ object Expressions {
def formatExpr[$: P]: P[Ast.expr] = P("{" ~/ test ~ "}")

def test[$: P]: P[Ast.expr] = P( or_test ~ ("?" ~ test ~ ":" ~ test).? ).map {
case (x, None) => x
case (condition, Some((ifTrue, ifFalse))) => Ast.expr.IfExp(condition, ifTrue, ifFalse)
}
case (x, None) => x
case (condition, Some((ifTrue, ifFalse))) => Ast.expr.IfExp(condition, ifTrue, ifFalse)
}
def or_test[$: P] = P( and_test.rep(1, kw("or")) ).map {
case Seq(x) => x
case xs => Ast.expr.BoolOp(Ast.boolop.Or, xs)
Expand Down Expand Up @@ -106,46 +106,47 @@ object Expressions {
power
)
// def power[_: P]: P[Ast.expr] = P( atom ~ trailer.rep ~ (Pow ~ factor).? ).map {
// case (lhs, trailers, rhs) =>
// val left = trailers.foldLeft(lhs)((l, t) => t(l))
// rhs match{
// case None => left
// case Some((op, right)) => Ast.expr.BinOp(left, op, right)
// }
// }
// case (lhs, trailers, rhs) =>
// val left = trailers.foldLeft(lhs)((l, t) => t(l))
// rhs match{
// case None => left
// case Some((op, right)) => Ast.expr.BinOp(left, op, right)
// }
// }
def power[$: P]: P[Ast.expr] = P( atom ~ trailer.rep ).map {
case (lhs, trailers) =>
trailers.foldLeft(lhs)((l, t) => t(l))
}
def group[$: P]: P[Ast.expr] = ("(" ~ test ~ ")").map(expr => Ast.expr.Group(expr))
def empty_list[$: P] = P("[" ~ "]").map(_ => Ast.expr.List(Nil))
// def empty_dict[_: P] = P("{" ~ "}").map(_ => Ast.expr.Dict(Nil, Nil))
def atom[$: P]: P[Ast.expr] = P(
empty_list |
// empty_dict |
"(" ~ test ~ ")" |
"[" ~ list ~ "]" |
// "{" ~ dictorsetmaker ~ "}" |
enumByName |
byteSizeOfType |
bitSizeOfType |
empty_list |
// empty_dict |
group |
"[" ~ list ~ "]" |
// "{" ~ dictorsetmaker ~ "}" |
enumByName |
byteSizeOfType |
bitSizeOfType |
fstring |
STRING.rep(1).map(_.mkString).map(Ast.expr.Str) |
NAME.map((x) => x.name match {
case "true" => Ast.expr.Bool(true)
case "false" => Ast.expr.Bool(false)
case _ => Ast.expr.Name(x)
}) |
FLOAT_NUMBER.map(Ast.expr.FloatNum) |
INT_NUMBER.map(Ast.expr.IntNum)
)
STRING.rep(1).map(_.mkString).map(Ast.expr.Str) |
NAME.map((x) => x.name match {
case "true" => Ast.expr.Bool(true)
case "false" => Ast.expr.Bool(false)
case _ => Ast.expr.Name(x)
}) |
FLOAT_NUMBER.map(Ast.expr.FloatNum) |
INT_NUMBER.map(Ast.expr.IntNum)
)
def list_contents[$: P] = P( test.rep(1, ",") ~ ",".? )
def list[$: P] = P( list_contents ).map(Ast.expr.List(_))

def call[$: P] = P("(" ~ arglist ~ ")").map { case (args) => (lhs: Ast.expr) => Ast.expr.Call(lhs, args)}
def slice[$: P] = P("[" ~ test ~ "]").map { case (args) => (lhs: Ast.expr) => Ast.expr.Subscript(lhs, args)}
def cast[$: P] = P( "." ~ "as" ~ "<" ~ TYPE_NAME ~ ">" ).map(
typeName => (lhs: Ast.expr) => Ast.expr.CastToType(lhs, typeName)
)
typeName => (lhs: Ast.expr) => Ast.expr.CastToType(lhs, typeName)
)
def attr[$: P] = P("." ~ NAME).map(id => (lhs: Ast.expr) => Ast.expr.Attribute(lhs, id))
def trailer[$: P]: P[Ast.expr => Ast.expr] = P( call | slice | cast | attr )

Expand All @@ -154,11 +155,11 @@ object Expressions {

// def dict_item[_: P] = P( test ~ ":" ~ test )
// def dict[_: P]: P[Ast.expr.Dict] = P(
// (dict_item.rep(1, ",") ~ ",".?).map { x =>
// val (keys, values) = x.unzip
// Ast.expr.Dict(keys, values)
// }
// )
// (dict_item.rep(1, ",") ~ ",".?).map{x =>
// val (keys, values) = x.unzip
// Ast.expr.Dict(keys, values)
// }
// )
// def dictorsetmaker[_: P]: P[Ast.expr] = P( /*dict_comp |*/ dict /*| set_comp | set*/)

def arglist[$: P]: P[Seq[Ast.expr]] = P( (test).rep(0, ",") )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,12 @@ abstract class BaseTranslator(val provider: TypeProvider)
doByteSizeOfType(typeName)
case Ast.expr.BitSizeOfType(typeName) =>
doBitSizeOfType(typeName)
case Ast.expr.Group(nested) =>
nested match {
// Unpack nested groups
case Ast.expr.Group(e) => translate(e)
case e => s"(${translate(e)})"
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ class ExpressionValidator(val provider: TypeProvider)
case Ast.expr.Subscript(container: Ast.expr, idx: Ast.expr) =>
detectType(container) match {
case _: ArrayType | _: BytesType =>
validate(container)
validate(container)
detectType(idx) match {
case _: IntType =>
validate(idx)
validate(idx)
case indexType =>
throw new TypeMismatchError(s"subscript operation on arrays require index to be integer, but found $indexType")
}
Expand All @@ -87,6 +87,8 @@ class ExpressionValidator(val provider: TypeProvider)
CommonSizeOf.getBitsSizeOfType(typeName.nameAsStr, detectCastType(typeName))
case Ast.expr.InterpolatedStr(elts: Seq[Ast.expr]) =>
elts.foreach(validate)
case Ast.expr.Group(nested) =>
validate(nested)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@ class GoTranslator(out: StringLanguageOutputWriter, provider: TypeProvider, impo
doByteSizeOfType(typeName)
case Ast.expr.BitSizeOfType(typeName) =>
doBitSizeOfType(typeName)
case Ast.expr.Group(nested) =>
ResultString(nested match {
// Unpack nested groups
case Ast.expr.Group(e) => translate(e)
case e => s"(${translate(e)})"
})
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ class TypeDetector(provider: TypeProvider) {
detectCastType(typeName)
case Ast.expr.ByteSizeOfType(_) | Ast.expr.BitSizeOfType(_) =>
CalcIntType
case Ast.expr.Group(expr) =>
detectType(expr)
}
}

Expand Down Expand Up @@ -244,6 +246,12 @@ class TypeDetector(provider: TypeProvider) {
*/
def detectCallType(call: Ast.expr.Call): DataType = {
call.func match {
case Ast.expr.Group(nested) => detectCallTypeImpl(nested)
case func => detectCallTypeImpl(func)
}
}
def detectCallTypeImpl(func: Ast.expr): DataType = {
func match {
case Ast.expr.Attribute(obj: Ast.expr, methodName: Ast.identifier) =>
val objType = detectType(obj)
// TODO: check number and type of arguments in `call.args`
Expand Down

0 comments on commit 7193b62

Please sign in to comment.