Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom Scan enhancement #538 #166

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -136,15 +136,15 @@ class ConstructClassCompiler(classSpecs: ClassSpecs, topClass: ClassSpec) extend
s"Int${width.width * 8}${signToStr(signed)}${fixedEndianToStr(endianOpt.get)}"
case FloatMultiType(width, endianOpt) =>
s"Float${width.width * 8}${fixedEndianToStr(endianOpt.get)}"
case BytesEosType(terminator, include, padRight, process) =>
case BytesEosType(terminator, include, padRight, process, scanEnd) =>
"GreedyBytes"
case blt: BytesLimitType =>
attrBytesLimitType(blt)
case btt: BytesTerminatedType =>
attrBytesTerminatedType(btt, "GreedyBytes")
case StrFromBytesType(bytes, encoding) =>
bytes match {
case BytesEosType(terminator, include, padRight, process) =>
case BytesEosType(terminator, include, padRight, process, scanEnd) =>
s"GreedyString(encoding='$encoding')"
case blt: BytesLimitType =>
attrBytesLimitType(blt, s"GreedyString(encoding='$encoding')")
Expand All @@ -156,7 +156,7 @@ class ConstructClassCompiler(classSpecs: ClassSpecs, topClass: ClassSpec) extend
case utb: UserTypeFromBytes =>
utb.bytes match {
//case BytesEosType(terminator, include, padRight, process) =>
case BytesLimitType(size, terminator, include, padRight, process) =>
case BytesLimitType(size, terminator, include, padRight, process, scanEnd) =>
s"FixedSized(${translator.translate(size)}, LazyBound(lambda: ${type2class(utb.classSpec.get)}))"
//case BytesTerminatedType(terminator, include, consume, eosError, process) =>
case _ => "???"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -407,8 +407,8 @@ object GraphvizClassCompiler extends LanguageCompilerStatic {
dataType match {
case rt: ReadableType => rt.apiCall(None) // FIXME
case ut: UserType => type2display(ut.name)
case FixedBytesType(contents, _) => contents.map(_.formatted("%02X")).mkString(" ")
case BytesTerminatedType(terminator, include, consume, eosError, _) =>
case FixedBytesType(contents, _, _) => contents.map(_.formatted("%02X")).mkString(" ")
case BytesTerminatedType(terminator, include, consume, eosError, _, _) =>
val args = ListBuffer[String]()
if (terminator != 0)
args += s"term=$terminator"
Expand Down
28 changes: 20 additions & 8 deletions shared/src/main/scala/io/kaitai/struct/datatype/DataType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -62,30 +62,42 @@ object DataType {
def process: Option[ProcessExpr]
}

abstract class BytesType extends DataType with Processing
/**
  * Mixin for data types that can carry an optional `scan-end` specification.
  * `None` means no scanner was configured for the type.
  */
trait ScanEnd {
// Parsed scanner reference, if any.
def scanEnd: Option[ScanExpr]
}

abstract class BytesType extends DataType with Processing with ScanEnd
case object CalcBytesType extends BytesType {
override def process = None
override def scanEnd = None
}
case class FixedBytesType(contents: Array[Byte], override val process: Option[ProcessExpr]) extends BytesType
case class FixedBytesType(contents: Array[Byte], override val process: Option[ProcessExpr], override val scanEnd: Option[ScanExpr] = None) extends BytesType
case class BytesEosType(
terminator: Option[Int],
include: Boolean,
padRight: Option[Int],
override val process: Option[ProcessExpr]
override val process: Option[ProcessExpr],
override val scanEnd: Option[ScanExpr] = None
) extends BytesType
case class BytesLimitType(
size: Ast.expr,
terminator: Option[Int],
include: Boolean,
padRight: Option[Int],
override val process: Option[ProcessExpr]
override val process: Option[ProcessExpr],
override val scanEnd: Option[ScanExpr] = None
) extends BytesType
case class BytesTerminatedType(
terminator: Int,
include: Boolean,
consume: Boolean,
eosError: Boolean,
override val process: Option[ProcessExpr]
override val process: Option[ProcessExpr],
override val scanEnd: Option[ScanExpr] = None
) extends BytesType
/**
  * Byte array type that has no size, size-eos, or terminator of its own.
  * NOTE(review): presumably its extent is meant to be found by the scanner in
  * `scanEnd`; since `scanEnd` defaults to `None`, an instance may exist without
  * any scanner -- confirm that callers handle that case.
  */
case class BytesScanEndType(
override val process: Option[ProcessExpr],
override val scanEnd: Option[ScanExpr] = None
) extends BytesType

abstract class StrType extends DataType
Expand Down Expand Up @@ -276,9 +288,9 @@ object DataType {
} else {
(arg.size, arg.sizeEos) match {
case (Some(sizeValue), false) =>
Map(SwitchType.ELSE_CONST -> BytesLimitType(sizeValue, None, false, None, arg.process))
Map(SwitchType.ELSE_CONST -> BytesLimitType(sizeValue, None, false, None, arg.process, arg.scanEnd))
case (None, true) =>
Map(SwitchType.ELSE_CONST -> BytesEosType(None, false, None, arg.process))
Map(SwitchType.ELSE_CONST -> BytesEosType(None, false, None, arg.process, arg.scanEnd))
case (None, false) =>
Map()
case (Some(_), true) =>
Expand All @@ -304,7 +316,7 @@ object DataType {
val r = dto match {
case None =>
arg.contents match {
case Some(c) => FixedBytesType(c, arg.process)
case Some(c) => FixedBytesType(c, arg.process, arg.scanEnd)
case _ => arg.getByteArrayType(path)
}
case Some(dt) => dt match {
Expand Down
13 changes: 9 additions & 4 deletions shared/src/main/scala/io/kaitai/struct/format/AttrSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ case class YamlAttrArgs(
contents: Option[Array[Byte]],
enumRef: Option[String],
parent: Option[Ast.expr],
process: Option[ProcessExpr]
process: Option[ProcessExpr],
scanEnd: Option[ScanExpr]
) {
def getByteArrayType(path: List[String]) = {
(size, sizeEos) match {
Expand All @@ -96,7 +97,8 @@ case class YamlAttrArgs(
case Some(term) =>
BytesTerminatedType(term, include, consume, eosError, process)
case None =>
throw new YAMLParseException("'size', 'size-eos' or 'terminator' must be specified", path)
BytesScanEndType(process, scanEnd)
// throw new YAMLParseException("'size', 'size-eos' or 'terminator' must be specified", path)
}
case (Some(_), true) =>
throw new YAMLParseException("only one of 'size' or 'size-eos' must be specified", path)
Expand All @@ -115,7 +117,8 @@ object AttrSpec {
"consume",
"include",
"eos-error",
"repeat"
"repeat",
"scan-end"
)

val LEGAL_KEYS_BYTES = Set(
Expand Down Expand Up @@ -165,6 +168,8 @@ object AttrSpec {
def fromYaml2(srcMap: Map[String, Any], path: List[String], metaDef: MetaSpec, id: Identifier): AttrSpec = {
val doc = DocSpec.fromYaml(srcMap, path)
val process = ProcessExpr.fromStr(ParseUtils.getOptValueStr(srcMap, "process", path), path)
val scanEnd = ScanExpr.fromStr(ParseUtils.getOptValueStr(srcMap, "scan-end", path), path)

// TODO: add proper path propagation
val contents = srcMap.get("contents").map(parseContentSpec(_, path ++ List("contents")))
val size = ParseUtils.getOptValueExpression(srcMap, "size", path)
Expand All @@ -184,7 +189,7 @@ object AttrSpec {
val yamlAttrArgs = YamlAttrArgs(
size, sizeEos,
encoding, terminator, include, consume, eosError, padRight,
contents, enum, parent, process
contents, enum, parent, process, scanEnd
)

// Unfortunately, this monstrous match can't rewritten in simpler way due to Java type erasure
Expand Down
35 changes: 35 additions & 0 deletions shared/src/main/scala/io/kaitai/struct/format/ScanExpr.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package io.kaitai.struct.format

import io.kaitai.struct.exprlang.{Ast, Expressions}

/** Parsed form of a scanner specification string; see `ScanExpr.fromStr`. */
sealed trait ScanExpr

/**
  * Reference to a user-supplied scanner.
  *
  * @param name scanner name split on dots, e.g. `"pkg.scanner"` becomes `List("pkg", "scanner")`
  * @param args argument expressions parsed from the parenthesised list; empty when none were given
  */
case class ScanCustom(name: List[String], args: Seq[Ast.expr]) extends ScanExpr

/**
  * Parser for scanner specification strings (the value of the `scan-end` key).
  */
object ScanExpr {
  // Scanner reference followed by a parenthesised argument list: `name(args)`.
  private val ReCustom = "^([a-z][a-z0-9_.]*)\\(\\s*(.*?)\\s*\\)$".r
  // Bare scanner reference without an argument list: `name`.
  private val ReCustomNoArg = "^([a-z][a-z0-9_.]*)$".r

  /**
    * Parses an optional scanner specification.
    *
    * Accepted forms are `name` and `name(arg1, arg2, ...)`, where `name` is a
    * lower-case, optionally dot-separated identifier and the arguments are
    * parsed as an expression list.
    *
    * @param s    raw string from the YAML source, or `None` if the key is absent
    * @param path YAML path of the attribute, used for error reporting
    * @return the parsed scanner reference, or `None` when `s` is `None`
    * @throws YAMLParseException if the string matches neither accepted form or
    *                            its arguments fail to parse as expressions
    */
  def fromStr(s: Option[String], path: List[String]): Option[ScanExpr] =
    s.map { spec =>
      try {
        spec match {
          case ReCustom(name, args) =>
            ScanCustom(name.split('.').toList, Expressions.parseList(args))
          case ReCustomNoArg(name) =>
            ScanCustom(name.split('.').toList, Seq())
          case _ =>
            // NOTE(review): this reuses the error factory named for `process`;
            // confirm its message is appropriate for `scan-end` values.
            throw YAMLParseException.badProcess(spec, path)
        }
      } catch {
        case epe: Expressions.ParseException =>
          throw YAMLParseException.expression(epe, path)
      }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,10 @@ class CSharpCompiler(val typeProvider: ClassTypeProvider, config: RuntimeConfig)
override def attrFixedContentsParse(attrName: Identifier, contents: String): Unit =
out.puts(s"${privateMemberName(attrName)} = $normalIO.EnsureFixedContents($contents);")

/** Scan-end hook for the C# target: currently emits no code. */
override def attrScanCustom(scanEnd: ScanExpr, varSrc: Identifier, varDest: Identifier): Unit = {
  // Intentionally empty: nothing is generated for this target.
}

override def attrProcess(proc: ProcessExpr, varSrc: Identifier, varDest: Identifier): Unit = {
val srcName = privateMemberName(varSrc)
val destName = privateMemberName(varDest)
Expand Down Expand Up @@ -356,7 +360,7 @@ class CSharpCompiler(val typeProvider: ClassTypeProvider, config: RuntimeConfig)
s"$io.ReadBytes(${expression(blt.size)})"
case _: BytesEosType =>
s"$io.ReadBytesFull()"
case BytesTerminatedType(terminator, include, consume, eosError, _) =>
case BytesTerminatedType(terminator, include, consume, eosError, _, _) =>
s"$io.ReadBytesTerm($terminator, $include, $consume, $eosError)"
case BitsType1 =>
s"$io.ReadBitsInt(1) != 0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,10 @@ class CppCompiler(
override def attrFixedContentsParse(attrName: Identifier, contents: String): Unit =
outSrc.puts(s"${privateMemberName(attrName)} = $normalIO->ensure_fixed_contents($contents);")

/** Scan-end hook for the C++ target: currently emits no code. */
override def attrScanCustom(scanEnd: ScanExpr, varSrc: Identifier, varDest: Identifier): Unit = {
  // Intentionally empty: nothing is generated for this target.
}

override def attrProcess(proc: ProcessExpr, varSrc: Identifier, varDest: Identifier): Unit = {
val srcName = privateMemberName(varSrc)
val destName = privateMemberName(varDest)
Expand Down Expand Up @@ -661,7 +665,7 @@ class CppCompiler(
s"$io->read_bytes(${expression(blt.size)})"
case _: BytesEosType =>
s"$io->read_bytes_full()"
case BytesTerminatedType(terminator, include, consume, eosError, _) =>
case BytesTerminatedType(terminator, include, consume, eosError, _, _) =>
s"$io->read_bytes_term($terminator, $include, $consume, $eosError)"
case BitsType1 =>
s"$io->read_bits_int(1)"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ class GoCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
out.puts("}")
}

/** Scan-end hook for the Go target: currently emits no code. */
override def attrScanCustom(scanEnd: ScanExpr, varSrc: Identifier, varDest: Identifier): Unit = {
  // Intentionally empty: nothing is generated for this target.
}

override def attrProcess(proc: ProcessExpr, varSrc: Identifier, varDest: Identifier): Unit = {
val srcName = privateMemberName(varSrc)
val destName = privateMemberName(varDest)
Expand Down Expand Up @@ -276,7 +280,7 @@ class GoCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
s"$io.ReadBytes(int(${expression(blt.size)}))"
case _: BytesEosType =>
s"$io.ReadBytesFull()"
case BytesTerminatedType(terminator, include, consume, eosError, _) =>
case BytesTerminatedType(terminator, include, consume, eosError, _, _) =>
s"$io.ReadBytesTerm($terminator, $include, $consume, $eosError)"
case BitsType1 =>
s"$io.ReadBitsInt(1)"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,10 @@ class JavaCompiler(val typeProvider: ClassTypeProvider, config: RuntimeConfig)
out.puts(s"${privateMemberName(attrName)} = $normalIO.ensureFixedContents($contents);")
}

/** Scan-end hook for the Java target: currently emits no code. */
override def attrScanCustom(scanEnd: ScanExpr, varSrc: Identifier, varDest: Identifier): Unit = {
  // Intentionally empty: nothing is generated for this target.
}

override def attrProcess(proc: ProcessExpr, varSrc: Identifier, varDest: Identifier): Unit = {
val srcName = privateMemberName(varSrc)
val destName = privateMemberName(varDest)
Expand Down Expand Up @@ -429,7 +433,7 @@ class JavaCompiler(val typeProvider: ClassTypeProvider, config: RuntimeConfig)
s"$io.readBytes(${expression(blt.size)})"
case _: BytesEosType =>
s"$io.readBytesFull()"
case BytesTerminatedType(terminator, include, consume, eosError, _) =>
case BytesTerminatedType(terminator, include, consume, eosError, _, _) =>
s"$io.readBytesTerm($terminator, $include, $consume, $eosError)"
case BitsType1 =>
s"$io.readBitsInt(1) != 0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ class JavaScriptCompiler(val typeProvider: ClassTypeProvider, config: RuntimeCon
override def indent: String = " "
override def outFileName(topClassName: String): String = s"${type2class(topClassName)}.js"

/** Scan-end hook for the JavaScript target: currently emits no code. */
override def attrScanCustom(scanEnd: ScanExpr, varSrc: Identifier, varDest: Identifier): Unit = {
  // Intentionally empty: nothing is generated for this target.
}

override def outImports(topClass: ClassSpec) = {
val impList = importList.toList
val quotedImpList = impList.map((x) => s"'$x'")
Expand Down Expand Up @@ -370,7 +374,7 @@ class JavaScriptCompiler(val typeProvider: ClassTypeProvider, config: RuntimeCon
s"$io.readBytes(${expression(blt.size)})"
case _: BytesEosType =>
s"$io.readBytesFull()"
case BytesTerminatedType(terminator, include, consume, eosError, _) =>
case BytesTerminatedType(terminator, include, consume, eosError, _, _) =>
s"$io.readBytesTerm($terminator, $include, $consume, $eosError)"
case BitsType1 =>
s"$io.readBitsInt(1) != 0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,9 @@ class LuaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
out.dec
}

/** Scan-end hook for the Lua target: currently emits no code. */
override def attrScanCustom(scanEnd: ScanExpr, varSrc: Identifier, varDest: Identifier): Unit = {
  // Intentionally empty: nothing is generated for this target.
}
override def attrProcess(proc: ProcessExpr, varSrc: Identifier, varDest: Identifier): Unit = {
val srcName = privateMemberName(varSrc)
val destName = privateMemberName(varDest)
Expand Down Expand Up @@ -320,7 +323,7 @@ class LuaCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
s"$io:read_bytes(${expression(blt.size)})"
case _: BytesEosType =>
s"$io:read_bytes_full()"
case BytesTerminatedType(terminator, include, consume, eosError, _) =>
case BytesTerminatedType(terminator, include, consume, eosError, _, _) =>
s"$io:read_bytes_term($terminator, $include, $consume, $eosError)"
case BitsType1 =>
s"$io:read_bits_int(1)"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ class PHPCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
override def attrFixedContentsParse(attrName: Identifier, contents: String): Unit =
out.puts(s"${privateMemberName(attrName)} = $normalIO->ensureFixedContents($contents);")

/** Scan-end hook for the PHP target: currently emits no code. */
override def attrScanCustom(scanEnd: ScanExpr, varSrc: Identifier, varDest: Identifier): Unit = {
  // Intentionally empty: nothing is generated for this target.
}
override def attrProcess(proc: ProcessExpr, varSrc: Identifier, varDest: Identifier): Unit = {
val srcName = privateMemberName(varSrc)
val destName = privateMemberName(varDest)
Expand Down Expand Up @@ -318,7 +321,7 @@ class PHPCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
s"$io->readBytes(${expression(blt.size)})"
case _: BytesEosType =>
s"$io->readBytesFull()"
case BytesTerminatedType(terminator, include, consume, eosError, _) =>
case BytesTerminatedType(terminator, include, consume, eosError, _, _) =>
s"$io->readBytesTerm($terminator, $include, $consume, $eosError)"
case BitsType1 =>
s"$io->readBitsInt(1) != 0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@ class PerlCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
out.puts(s"${privateMemberName(attrName)} = $normalIO->ensure_fixed_contents($contents);")
}

/** Scan-end hook for the Perl target: currently emits no code. */
override def attrScanCustom(scanEnd: ScanExpr, varSrc: Identifier, varDest: Identifier): Unit = {
  // Intentionally empty: nothing is generated for this target.
}
override def attrProcess(proc: ProcessExpr, varSrc: Identifier, varDest: Identifier): Unit = {
val srcName = privateMemberName(varSrc)
val destName = privateMemberName(varDest)
Expand Down Expand Up @@ -287,7 +290,7 @@ class PerlCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
s"$io->read_bytes(${expression(blt.size)})"
case _: BytesEosType =>
s"$io->read_bytes_full()"
case BytesTerminatedType(terminator, include, consume, eosError, _) =>
case BytesTerminatedType(terminator, include, consume, eosError, _, _) =>
s"$io->read_bytes_term($terminator, ${boolLiteral(include)}, ${boolLiteral(consume)}, ${boolLiteral(eosError)})"
case BitsType1 =>
s"$io->read_bits_int(1)"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,33 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
out.dec
}

/**
  * Emits Python code that reads a byte array whose end is located by a
  * user-supplied scanner class.
  *
  * The generated code records the current stream position, constructs the
  * scanner with the stream and the translated arguments, calls its `scan()`
  * method, records the position again, seeks back to the first position, and
  * reads the bytes between the two positions into the destination member.
  *
  * @param scanEnd scanner reference to generate code for
  * @param varSrc  source identifier; not used by this implementation
  * @param varDest identifier of the member that receives the bytes read
  */
override def attrScanCustom(scanEnd: ScanExpr, varSrc: Identifier, varDest: Identifier): Unit = {
  val destName = privateMemberName(varDest)

  scanEnd match {
    case ScanCustom(name, args) =>
      // A single-component name is imported as `from <module> import <Class>`;
      // a dotted name imports the package and uses a fully qualified class reference.
      val scanClass = if (name.length == 1) {
        val onlyName = name.head
        val className = type2class(onlyName)
        importList.add(s"from $onlyName import $className")
        className
      } else {
        val pkgName = name.init.mkString(".")
        importList.add(s"import $pkgName")
        s"$pkgName.${type2class(name.last)}"
      }

      // Prepending the stream to the argument list avoids a dangling comma
      // in the generated call when no scanner arguments are given.
      val ctorArgs = ("self._io" +: args.map(expression)).mkString(", ")

      out.puts("pos1 = self._io.pos()")
      out.puts(s"_scanner = $scanClass($ctorArgs)")
      out.puts("_scanner.scan()")
      out.puts("pos2 = self._io.pos()")
      out.puts("self._io.seek(pos1)")
      out.puts(s"$destName = self._io.read_bytes(pos2 - pos1)")
  }
}

override def attrProcess(proc: ProcessExpr, varSrc: Identifier, varDest: Identifier): Unit = {
val srcName = privateMemberName(varSrc)
val destName = privateMemberName(varDest)
Expand Down Expand Up @@ -356,8 +383,10 @@ class PythonCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
s"$io.read_bytes(${expression(blt.size)})"
case _: BytesEosType =>
s"$io.read_bytes_full()"
case BytesTerminatedType(terminator, include, consume, eosError, _) =>
case BytesTerminatedType(terminator, include, consume, eosError, _, _) =>
s"$io.read_bytes_term($terminator, ${bool2Py(include)}, ${bool2Py(consume)}, ${bool2Py(eosError)})"
case BytesScanEndType(_, scanEnd) =>
s"''"
case BitsType1 =>
s"$io.read_bits_int(1) != 0"
case BitsType(width: Int) =>
Expand Down
Loading