Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compiler: Construct-Dataclasses #256

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,334 @@
package io.kaitai.struct

import io.kaitai.struct.ConstructDataclassCompiler.convertTypeToClass
import io.kaitai.struct.datatype.DataType._
import io.kaitai.struct.datatype._
import io.kaitai.struct.exprlang.Ast
import io.kaitai.struct.format._
import io.kaitai.struct.languages.components.{LanguageCompiler, LanguageCompilerStatic, PythonOps}
import io.kaitai.struct.translators.ConstructTranslator

import scala.collection.mutable.ListBuffer

class ConstructDataclassCompiler(classSpecs: ClassSpecs, topClass: ClassSpec)
extends AbstractCompiler {

val out = new StringLanguageOutputWriter(indent)
val importList = new ImportList
val provider = new ClassTypeProvider(classSpecs, topClass)
val translator = new ConstructTranslator(provider, importList)

private val classList = new ListBuffer[String]()

/**
 * Compiles `topClass` (and everything nested in it) into a single Python module.
 *
 * The module text is assembled as: all entries of `importList` joined by newlines,
 * a blank line, and then everything written to `out`.
 *
 * @return a successful compile log holding exactly one generated file
 */
override def compile: CompileLog.SpecSuccess = {
  // The header comment and the plain `import ...` lines are registered first so
  // that they precede anything added while compiling classes.
  Seq(
    "# This is a generated file! Please edit source .ksy file and use kaitai-struct-compiler to rebuild",
    "import typing as t",
    "import dataclasses as dc"
  ).foreach(line => importList.add(line))

  // compile class
  this.compileClass(topClass)
  val schemaName = convertTypeToClass(topClass.name)
  out.puts(s"_schema = $schemaName")

  // Support for wildcard import of generated python files. By adding __all__ we
  // prevent importing all symbols of construct and construct_dataclasses.
  out.puts("__all__ = [")
  out.inc
  out.putsLines("", classList.mkString(",\n"))
  out.dec
  out.puts("]")

  // From ... imports should be placed at the end
  Seq(
    "\nfrom construct import *",
    "from construct.lib import *",
    "from construct_dataclasses import *"
  ).foreach(line => importList.add(line))

  val moduleText = importList.toList.mkString("\n") + "\n\n" + out.result + "\n"
  val generatedFile = CompileLog.FileSuccess(outFileName(topClass.nameAsStr), moduleText)
  CompileLog.SpecSuccess("", List(generatedFile))
}

/**
 * Emits the Python code for one type: its enums, its nested types, the
 * `<Name>_t` dataclass and finally the `<Name> = DataclassStruct(<Name>_t)` parser.
 *
 * Both generated names are recorded in `classList` so they end up in `__all__`.
 *
 * @param cs the type specification to compile
 */
private def compileClass(cs: ClassSpec): Unit = {
  // 1. Step: compile all present enum definitions
  cs.enums.foreach { case (_, enumSpec) => this.compileEnum(enumSpec) }

  // 2. Step: compile all other structs
  cs.types.foreach { case (_, nestedSpec) => this.compileClass(nestedSpec) }

  val className = convertTypeToClass(cs.name)

  // See https://github.com/MatrixEditor/construct-dataclasses/pull/3
  out.puts("@container")

  // REVISIT: what if we have to process bitwise? Where do we specify to process bitwise?
  out.puts("@dc.dataclass")
  out.puts(s"class ${className}_t:")
  out.inc
  val docStr = PythonOps.compileUniversalDocs(cs.doc)
  if (docStr.nonEmpty) {
    out.putsLines("", "\"\"\"" + docStr + "\"\"\"")
  }

  classList.append(s"'$className'") // parser
  classList.append(s"'${className}_t'") // dataclass

  // Must be set after the nested types were compiled (they change it) and
  // before any attribute of this type is translated.
  provider.nowClass = cs
  cs.seq.foreach(seqAttribute => compileAttribute(seqAttribute))
  cs.instances.foreach {
    case (identifier, vis: ValueInstanceSpec) => compileComputedField(identifier, vis)
    case (_, pis: ParseInstanceSpec) => compilePointerField(pis)
  }

  // A type without docstring, seq attributes and instances would otherwise
  // produce a class header with an empty body, which is a Python syntax error.
  if (docStr.isEmpty && cs.seq.isEmpty && cs.instances.isEmpty) {
    out.puts("pass")
  }

  out.dec
  out.puts("\n")
  out.puts(s"$className = DataclassStruct(${className}_t)\n")
}

/**
 * Writes an `enum.IntEnum` subclass for `enumSpec` and registers its name in
 * `classList`. Also makes sure `import enum` is part of the import list.
 *
 * @param enumSpec the enum definition to emit
 */
private def compileEnum(enumSpec: EnumSpec): Unit = {
  val enumName = convertTypeToClass(enumSpec.name)

  importList.add("import enum")
  out.puts(s"class $enumName(enum.IntEnum):")
  out.inc
  classList.append(s"'$enumName'")

  // One `NAME = <int literal>` member per value, in the order given by sortedSeq.
  for ((number, valueSpec) <- enumSpec.sortedSeq) {
    out.puts(s"${valueSpec.name} = ${translator.doIntLiteral(number)}")
  }
  out.dec
  out.puts
}

/**
 * Writes one dataclass field of the form
 * `name: <type hint> = <field function>(<construct expression>)`, followed by the
 * attribute's docstring if it has one.
 *
 * @param attributeSpec the attribute to emit
 */
private def compileAttribute(attributeSpec: AttrLikeSpec): Unit = {
  val name = getName(attributeSpec.id)
  // getAttributeTypeHint always yields a string (falling back to "t.Any"), so
  // no null handling is needed here.
  val typeHint = correctListTypeHint(getAttributeTypeHint(attributeSpec), attributeSpec)
  // fieldType is only the opening part of the call, e.g. "csfield("; the
  // closing parenthesis is appended below.
  val fieldType = getDataclassFieldType(attributeSpec.dataType)
  val body = compileAttributeBody(attributeSpec)
  out.puts(s"$name: $typeHint = $fieldType$body)")

  val docStr = PythonOps.compileUniversalDocs(attributeSpec.doc)
  if (docStr.nonEmpty) {
    out.putsLines("", "\"\"\"" + docStr + "\"\"\"")
  }
}

/** One indentation step of the generated code. Uses Scala's `*` on strings instead of
 *  `String.repeat`, which only exists on Java 11 and later. */
private def indent: String = " " * 4

/** File name of the generated module: the top-level class name plus the `.py` extension. */
private def outFileName(topClassName: String): String = topClassName + ".py"

/**
 * Turns an identifier into the Python attribute name used in the generated code.
 * Numbered identifiers become `_<template><index>`; all other handled kinds keep
 * their name unchanged.
 */
private def getName(identifier: Identifier): String = identifier match {
  case NumberedIdentifier(index) => "_" + NumberedIdentifier.TEMPLATE + index
  case NamedIdentifier(plainName) => plainName
  case InstanceIdentifier(plainName) => plainName
  case SpecialIdentifier(plainName) => plainName
}

/**
 * Wraps `repeat` into `If(<condition>, ...)` when a condition is present;
 * otherwise returns `repeat` untouched.
 *
 * @param repeat construct expression to guard
 * @param ifExpr optional condition expression
 */
private def withConditional(repeat: String, ifExpr: Option[Ast.expr]): String =
  ifExpr.fold(repeat)(condition => s"If(${translator.translate(condition)}, $repeat)")

/**
 * Builds the construct expression of an attribute: the plain type expression,
 * wrapped by its repetition (if any) and then by its `if` condition (if any).
 */
private def compileAttributeBody(spec: AttrLikeSpec): String = {
  val repeated = withRepeat(toString(spec.dataType), spec)
  withConditional(repeated, spec.cond.ifExpr)
}

/**
 * Applies the attribute's repetition to the construct expression `repr`.
 *
 * @param repr construct expression of a single element
 * @param spec attribute whose `cond.repeat` decides the wrapper
 * @return `repr` itself for non-repeated attributes, otherwise `repr` wrapped in
 *         `Array`, `RepeatUntil` or `GreedyRange`
 */
private def withRepeat(repr: String, spec: AttrLikeSpec): String = {
  val repetition = spec.cond.repeat
  repetition match {
    case NoRepeat =>
      repr
    case RepeatEos =>
      s"GreedyRange($repr)"
    case RepeatExpr(countExpr) =>
      val count = translator.translate(countExpr)
      s"Array($count, $repr)"
    case RepeatUntil(stopExpr) =>
      // The iterator type has to be set before the stop condition is translated.
      provider._currentIteratorType = Some(spec.dataType)
      val stopCondition = translator.translate(stopExpr)
      s"RepeatUntil(lambda obj_, list_, this: $stopCondition, $repr)"
  }
}

/**
 * Translates a data type into the construct expression that parses it.
 *
 * Types without a dedicated case fall back to the data type's own `toString`,
 * and user types read from anything but a size-limited byte array yield the
 * placeholder `"???"`.
 *
 * @param dataType the type to translate
 * @return Python source text of the construct expression
 */
private def toString(dataType: DataType): String = dataType match {
  case Int1Type(isSigned) =>
    "Int8" + correctSignedType(isSigned) + "b"
  case IntMultiType(isSigned, size, endianOpt) =>
    val bits = size.width * 8
    s"Int$bits${correctSignedType(isSigned)}${endianToSting(endianOpt.get)}"
  case FloatMultiType(size, endianOpt) =>
    val bits = size.width * 8
    s"Float$bits${endianToSting(endianOpt.get)}"
  case _: BytesEosType =>
    "GreedyBytes"

  case StrFromBytesType(rawBytes, encoding) =>
    val greedyString = s"GreedyString(encoding='$encoding')"
    rawBytes match {
      case _: BytesEosType =>
        greedyString
      case limited: BytesLimitType =>
        // Plain fixed-size strings map directly onto PaddedString.
        if (limited.terminator.isEmpty && limited.padRight.isEmpty) {
          s"PaddedString(${translator.translate(limited.size)}, encoding='$encoding')"
        } else {
          createBytesLimitTypeConstruct(limited, greedyString)
        }
      case terminated: BytesTerminatedType =>
        createBytesTerminatedTypeConstruct(terminated, greedyString)
    }
  case limited: BytesLimitType =>
    // Plain fixed-size byte arrays map directly onto Bytes.
    if (limited.terminator.isEmpty && limited.padRight.isEmpty) {
      s"Bytes(${translator.translate(limited.size)})"
    } else {
      createBytesLimitTypeConstruct(limited)
    }
  case terminated: BytesTerminatedType =>
    createBytesTerminatedTypeConstruct(terminated, "GreedyBytes")
  case userType: UserTypeInstream =>
    s"LazyBound(lambda: ${convertTypeToClass(userType.classSpec.get.name)})"
  case userType: UserTypeFromBytes =>
    userType.bytes match {
      case limited: BytesLimitType =>
        val parser = convertTypeToClass(userType.classSpec.get.name)
        s"FixedSized(${translator.translate(limited.size)}, LazyBound(lambda: $parser))"
      // NOTE(review): other byte sources are not supported yet and emit a placeholder.
      case _ => "???"
    }
  case BitsType1(bitEndian) =>
    val swapped = bitEndian match {
      case BigBitEndian => "False"
      case LittleBitEndian => "True"
    }
    s"Bitwise(BitsInteger(1, swapped=$swapped))"
  case BitsType(bitCount, bitEndian) =>
    val swapped = bitEndian match {
      case BigBitEndian => "False"
      case LittleBitEndian => "True"
    }
    s"Bitwise(BitsInteger($bitCount, swapped=$swapped))"
  case switchType: SwitchType =>
    createSwitchConstruct(switchType)
  case enumType: EnumType =>
    val underlying = toString(enumType.basedOn)
    s"Enum($underlying, ${convertTypeToClass(enumType.enumSpec.get.name)})"
  case other =>
    other.toString
}

/** Signedness letter used in construct integer names: "s" for signed, "u" for unsigned. */
private def correctSignedType(signed: Boolean): String = signed match {
  case true => "s"
  case false => "u"
}

/**
 * Endianness letter used in construct number names: "l" for little endian,
 * "b" for big endian and "n" for anything else.
 */
private def endianToSting(endian: FixedEndian): String =
  endian match {
    case BigEndian => "b"
    case LittleEndian => "l"
    case _ => "n"
  }

/**
 * Returns the opening part of the dataclass field call for `dataType`; the caller
 * appends the construct expression and the closing parenthesis.
 *
 * Enum types use `tfield(<EnumClass>, `, everything else uses `csfield(`.
 */
private def getDataclassFieldType(dataType: DataType): String = dataType match {
  // NOTE: Since version 1.1.9 there is a shortcut ('csenum') for this case, but
  // 'tfield' is used for consistency. 'csenum' takes the enum type and the subcon
  // instance as parameters; since only the start of the field call is returned
  // here, the structure of the converted enum string can't be controlled.
  case enumType: EnumType =>
    val enumClass = convertTypeToClass(enumType.enumSpec.get.name)
    s"tfield($enumClass, "
  case _ =>
    "csfield("
}

/**
 * Determines the Python type hint for a single (non-repeated) value of the
 * attribute's data type.
 *
 * User types are hinted with their `<Name>_t` dataclass, both when read in-stream
 * and when read from a byte array. Types without a dedicated case fall back to
 * `t.Any`.
 *
 * @param spec the attribute whose data type is inspected
 * @return the type hint as Python source text
 */
private def getAttributeTypeHint(spec: AttrLikeSpec): String = {
  spec.dataType match {
    case enumType: EnumType => convertTypeToClass(enumType.enumSpec.get.name)
    case _: Int1Type | _: IntMultiType => "int"
    case _: FloatMultiType | _: FloatType => "float"
    case _: BitsType1 | _: BitsType => "int"
    case _: BytesEosType | _: BytesLimitType | _: BytesTerminatedType => "bytes"
    // An earlier duplicate UserTypeFromBytes case returned the parser name without
    // the `_t` suffix and made the case below unreachable; it has been removed.
    case ut: UserTypeInstream => s"${convertTypeToClass(ut.classSpec.get.name)}_t"
    case utb: UserTypeFromBytes => s"${convertTypeToClass(utb.classSpec.get.name)}_t"
    case StrFromBytesType(_, _) => "str"
    case _ => "t.Any"
  }
}

/**
 * Wraps `typeHint` into `t.List[...]` for repeated attributes and leaves it
 * unchanged for attributes without repetition.
 */
private def correctListTypeHint(typeHint: String, spec: AttrLikeSpec): String =
  if (spec.cond.repeat == NoRepeat) typeHint else "t.List[" + typeHint + "]"

/**
 * Builds a `NullTerminated(...)` expression around `subCon` for a terminated
 * byte array.
 *
 * @param btt    terminated byte type supplying terminator, include and consume flags
 * @param subCon construct expression parsing the data before the terminator
 */
private def createBytesTerminatedTypeConstruct(btt: BytesTerminatedType, subCon: String): String = {
  // The terminator byte is rendered as a two-digit upper-case hex escape.
  val termEscape = "\\x%02X".format(btt.terminator & 0xFF)
  val includeFlag = translator.doBoolLiteral(btt.include)
  val consumeFlag = translator.doBoolLiteral(btt.consume)
  s"NullTerminated($subCon, term=b'$termEscape', include=$includeFlag, consume=$consumeFlag)"
}

/**
 * Builds the construct expression for a size-limited byte array that has a
 * terminator and/or right padding.
 *
 * The wrappers are nested from the inside out: `NullTerminated` (if a terminator
 * is set), then `NullStripped` (if right padding is set), and always `FixedSized`
 * on the outside so that exactly `blt.size` bytes are consumed.
 *
 * Previously the match arms used `return`, which left the method early: the
 * `FixedSized` wrapper was never emitted for those cases and the padding was
 * ignored whenever a terminator was present.
 *
 * @param blt    the size-limited byte type
 * @param subCon construct expression for the innermost payload
 * @return Python source text of the combined construct expression
 */
private def createBytesLimitTypeConstruct(blt: BytesLimitType, subCon: String = "GreedyBytes"): String = {
  val terminated = blt.terminator match {
    case None => subCon
    case Some(value) =>
      val term = "\\x%02X".format(value & 0xFF)
      s"NullTerminated($subCon, term=b'$term', include=${translator.doBoolLiteral(blt.include)})"
  }
  val stripped = blt.padRight match {
    case None => terminated
    case Some(value) =>
      val padding = "\\x%02X".format(value & 0xFF)
      s"NullStripped($terminated, pad=b'$padding')"
  }

  s"FixedSized(${translator.translate(blt.size)}, $stripped)"
}

/**
 * Builds a `Switch(<on>, {<case>: <type>, ...}[, default=<type>])` expression.
 * The else-case of the switch type, if present, becomes the `default` argument
 * instead of a dictionary entry.
 */
private def createSwitchConstruct(st: SwitchType): String = {
  // Regular cases, each rendered as "key: value, ".
  val entries = st.cases.collect {
    case (caseExpr, caseType) if caseExpr != SwitchType.ELSE_CONST =>
      s"${translator.translate(caseExpr)}: ${toString(caseType)}, "
  }

  val defaultArg = st.cases.get(SwitchType.ELSE_CONST) match {
    case Some(elseType) => s", default=${toString(elseType)}"
    case None => ""
  }

  val selector = translator.translate(st.on)
  s"Switch($selector, {${entries.mkString}}$defaultArg)"
}

/**
 * Writes a value instance as a `Computed(lambda this: ...)` field typed `t.Any`,
 * guarded by `If(...)` when the instance has a condition, followed by its
 * docstring if it has one.
 *
 * @param id  identifier of the instance
 * @param vis the value instance to emit
 */
private def compileComputedField(id: Identifier, vis: ValueInstanceSpec): Unit = {
  val computed = "Computed(lambda this: " + translator.translate(vis.value) + ")"
  val fieldBody = withConditional(computed, vis.ifExpr)
  out.puts(s"${getName(id)}: t.Any = csfield($fieldBody)")

  val documentation = PythonOps.compileUniversalDocs(vis.doc)
  if (documentation.nonEmpty) {
    out.putsLines("", "\"\"\"" + documentation + "\"\"\"")
  }
}

/**
 * Writes a parse instance. Instances without a position are emitted like regular
 * attributes; instances with a position are wrapped in `Pointer(<pos>, ...)`.
 *
 * The type hint is passed through `correctListTypeHint`, so repeated instances
 * are hinted as `t.List[...]` just like repeated sequence attributes.
 *
 * @param pis the parse instance to emit
 */
private def compilePointerField(pis: ParseInstanceSpec): Unit = {
  pis.pos match {
    case None => compileAttribute(pis)
    case Some(position) =>
      val typeHint = correctListTypeHint(getAttributeTypeHint(pis), pis)
      // Only the opening part of the field call, e.g. "csfield(".
      val fieldType = getDataclassFieldType(pis.dataType)
      // TODO: subcsfield and tfield both need an extra argument
      out.puts(s"${getName(pis.id)}: $typeHint = ${fieldType}Pointer(${translator.translate(position)}, ${compileAttributeBody(pis)}))")
      val docStr = PythonOps.compileUniversalDocs(pis.doc)
      if (docStr.nonEmpty) {
        out.putsLines("", "\"\"\"" + docStr + "\"\"\"")
      }
  }
}

}

/** Static companion of the construct-dataclasses compiler. */
object ConstructDataclassCompiler extends LanguageCompilerStatic {
  // Not implemented: calling this throws NotImplementedError.
  // NOTE(review): the compiler class appears to be instantiated directly by the
  // caller instead of through this factory - confirm.
  override def getCompiler(
    tp: ClassTypeProvider,
    config: RuntimeConfig
  ): LanguageCompiler = ???

  /**
   * Converts a type path into a Python class name by joining its segments with
   * a double underscore.
   */
  // REVISIT:
  def convertTypeToClass(name: List[String]): String =
    name.mkString("__")

}
2 changes: 2 additions & 0 deletions shared/src/main/scala/io/kaitai/struct/Main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ object Main {
new RustClassCompiler(specs, spec, config)
case ConstructClassCompiler =>
new ConstructClassCompiler(specs, spec)
case ConstructDataclassCompiler =>
new ConstructDataclassCompiler(specs, spec)
case NimCompiler =>
new NimClassCompiler(specs, spec, config)
case HtmlClassCompiler =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ object LanguageCompilerStatic {
"php" -> PHPCompiler,
"python" -> PythonCompiler,
"ruby" -> RubyCompiler,
"rust" -> RustCompiler
"rust" -> RustCompiler,
"cs_dataclass" -> ConstructDataclassCompiler
)

val CLASS_TO_NAME: Map[LanguageCompilerStatic, String] = NAME_TO_CLASS.map(_.swap)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package io.kaitai.struct.languages.components

import io.kaitai.struct.StringLanguageOutputWriter
import io.kaitai.struct.format.{DocSpec, TextRef, UrlRef}

// Taken from pull request #243
/** Helpers shared by the Python-based code generators. */
object PythonOps {
  /**
   * Renders a documentation spec as docstring text (without the surrounding
   * triple quotes).
   *
   * The summary gets a trailing period unless it already ends with a period or a
   * newline. Every reference is appended as a `.. seealso::` block. A missing or
   * empty summary contributes nothing.
   *
   * @param doc the documentation spec to render
   * @return the docstring content; empty when there is neither summary nor reference
   */
  def compileUniversalDocs(doc: DocSpec): String = {
    val docStr = doc.summary match {
      // The nonEmpty guard keeps an empty summary from being inspected for its
      // last character, which would throw.
      case Some(summary) if summary.nonEmpty =>
        if (summary.endsWith(".") || summary.endsWith("\n")) summary else summary + "."
      case _ =>
        ""
    }

    // Separates the summary from the first see-also block when needed.
    val extraNewline = if (docStr.isEmpty || docStr.endsWith("\n")) "" else "\n"

    // Renders one see-also block with the given (possibly multi-line) text.
    def seeAlsoBlock(text: String): String = {
      val seeAlso = new StringLanguageOutputWriter("")
      seeAlso.putsLines(" ", text)
      s"$extraNewline\n.. seealso::\n${seeAlso.result}"
    }

    val refStr = doc.ref.map {
      case TextRef(text) => seeAlsoBlock(text)
      case ref: UrlRef => seeAlsoBlock(s"${ref.text} - ${ref.url}")
    }.mkString("\n")

    docStr + refStr
  }
}
Loading