Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge with current state of work, for testing purposes #162

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,8 @@ class RustClassCompiler(
}

lang.readHeader(defEndian, false)

compileSeq(curClass.seq, defEndian)
lang.classConstructorFooter
lang.readFooter()
}

override def compileInstances(curClass: ClassSpec) = {
Expand Down
112 changes: 64 additions & 48 deletions shared/src/main/scala/io/kaitai/struct/languages/RustCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,13 @@ class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
outHeader.puts(s"// $headerComment")
outHeader.puts

outHeader.puts("#![allow(unused_parens, unused_imports)]")

importList.add("std::option::Option")
importList.add("std::boxed::Box")
importList.add("std::io::Result")
importList.add("std::io::Cursor")
importList.add("std::vec::Vec")
importList.add("std::rc::Rc")
importList.add("std::rc::Weak")
importList.add("std::default::Default")
importList.add("kaitai_struct::KaitaiStream")
importList.add("kaitai_struct::KaitaiStruct")
Expand Down Expand Up @@ -84,28 +86,22 @@ class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
val pRoot = paramName(RootIdentifier)

// Types
val tIo = kstreamName
val tParent = kaitaiType2NativeType(parentType)

out.puts(s"fn new<S: KaitaiStream>(stream: &mut S,")
out.puts(s" _parent: &Option<Box<KaitaiStruct>>,")
out.puts(s" _root: &Option<Box<KaitaiStruct>>)")
out.puts(s" -> Result<Self>")
out.puts(s"fn new(stream: Box<KaitaiStream>,")
out.puts(s" _parent: Option<Weak<KaitaiStruct>>,")
out.puts(s" _root: Option<Weak<KaitaiStruct>>)")
out.puts(s" -> Result<Rc<Self>>")
out.inc
out.puts(s"where Self: Sized {")

out.puts(s"let mut s: Self = Default::default();")
out.puts(s"s.stream = Some(stream);")
out.puts(s"s._parent = _parent;")
out.puts(s"s._root = _root;")
out.puts

out.puts(s"s.stream = stream;")

out.puts(s"s.read(stream, _parent, _root)?;")
out.puts

out.puts("Ok(s)")
out.dec
out.puts("}")
out.puts
out.puts(s"let mut rc_s = Rc::new(s);")
}

override def runRead(): Unit = {
Expand All @@ -117,29 +113,21 @@ class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
}

override def readHeader(endian: Option[FixedEndian], isEmpty: Boolean) = {
out.puts
out.puts(s"fn read<S: KaitaiStream>(&mut self,")
out.puts(s" stream: &mut S,")
out.puts(s" _parent: &Option<Box<KaitaiStruct>>,")
out.puts(s" _root: &Option<Box<KaitaiStruct>>)")
out.puts(s" -> Result<()>")
out.inc
out.puts(s"where Self: Sized {")

}

override def readFooter(): Unit = {
out.puts
out.puts("Ok(())")
out.puts("Ok(rc_s)")
out.dec
out.puts("}")
}

override def attributeDeclaration(attrName: Identifier, attrType: DataType, isNullable: Boolean): Unit = {
attrName match {
case ParentIdentifier | RootIdentifier | IoIdentifier =>
// just ignore it for now
case IoIdentifier =>
out.puts(s" stream: ${kaitaiType2NativeType(attrType)},")
case IoIdentifier => out.puts(s" stream: Option<Box<KaitaiStream>>,")
case ParentIdentifier => out.puts(s" _parent: Option<Weak<KaitaiStruct>>,")
case RootIdentifier => out.puts(s" _root: Option<Weak<KaitaiStruct>>,")
case _ =>
out.puts(s" pub ${idToStr(attrName)}: ${kaitaiType2NativeType(attrType)},")
}
Expand Down Expand Up @@ -180,19 +168,25 @@ class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
proc match {
case ProcessXor(xorValue) =>
val procName = translator.detectType(xorValue) match {
case _: IntType => "processXorOne"
case _: BytesType => "processXorMany"
case _: IntType => "process_xor_one"
case _: BytesType => "process_xor_many"
}
out.puts(s"$destName = $kstreamName::$procName($srcName, ${expression(xorValue)});")
out.puts(s"$destName = self.stream.expect(" +
"\"This should never be None\")" +
s".$procName($srcName, ${expression(xorValue)});")
case ProcessZlib =>
out.puts(s"$destName = $kstreamName::processZlib($srcName);")
out.puts(s"$destName = self.stream.expect(" +
"\"This should never be None\")" +
s".process_zlib($srcName)?;")
case ProcessRotate(isLeft, rotValue) =>
val expr = if (isLeft) {
expression(rotValue)
} else {
s"8 - (${expression(rotValue)})"
}
out.puts(s"$destName = $kstreamName::processRotateLeft($srcName, $expr, 1);")
out.puts(s"$destName = self.stream.expect(" +
"\"This should never be None\")" +
s".process_rotate_left($srcName, $expr, 1);")
case ProcessCustom(name, args) =>
val procClass = if (name.length == 1) {
val onlyName = name.head
Expand Down Expand Up @@ -220,6 +214,8 @@ class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
case NoRepeat => memberName
}

importList.add("std::io::Cursor")

out.puts(s"let mut io = Cursor::new($args);")
"io"
}
Expand Down Expand Up @@ -263,6 +259,8 @@ class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
}

override def condRepeatExprHeader(id: Identifier, io: String, dataType: DataType, needRaw: Boolean, repeatExpr: Ast.expr): Unit = {
importList.add("std::vec::Vec")

if (needRaw)
out.puts(s"${privateMemberName(RawIdentifier(id))} = vec!();")
out.puts(s"${privateMemberName(id)} = vec!();")
Expand All @@ -275,6 +273,8 @@ class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
}

override def condRepeatUntilHeader(id: Identifier, io: String, dataType: DataType, needRaw: Boolean, untilExpr: Ast.expr): Unit = {
importList.add("std::vec::Vec")

if (needRaw)
out.puts(s"${privateMemberName(RawIdentifier(id))} = vec!();")
out.puts(s"${privateMemberName(id)} = vec!();")
Expand All @@ -300,7 +300,12 @@ class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
}

override def handleAssignmentSimple(id: Identifier, expr: String): Unit = {
out.puts(s"${privateMemberName(id)} = $expr;")
id match {
case _: InstanceIdentifier =>
out.puts(s"${privateMemberName(id)} = Some($expr);")
case _ =>
out.puts(s"${privateMemberName(id)} = $expr;")
}
}

override def parseExpr(dataType: DataType, assignType: DataType, io: String, defEndian: Option[FixedEndian]): String = {
Expand Down Expand Up @@ -334,17 +339,21 @@ class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
s", $parent, ${privateMemberName(RootIdentifier)}$addEndian"
}

s"Box::new(${translator.types2classAbs(t.classSpec.get.name)}::new(self.stream, self, _root)?)"
s"Box::new(${translator.types2classAbs(t.classSpec.get.name)}::new(stream, Some(Rc::downgrade(&rc_s as Rc<KaitaiStruct>)), _root)?)"
}
}

override def bytesPadTermExpr(expr0: String, padRight: Option[Int], terminator: Option[Int], include: Boolean): String = {
val expr1 = padRight match {
case Some(padByte) => s"$kstreamName::bytesStripRight($expr0, $padByte)"
case Some(padByte) => s"self.stream.expect(" +
"\"This should never be None\")" +
s".bytes_strip_right($expr0, $padByte)"
case None => expr0
}
val expr2 = terminator match {
case Some(term) => s"$kstreamName::bytesTerminate($expr1, $term, $include)"
case Some(term) => s"self.stream.expect(" +
"\"This should never be None\")" +
s".bytes_terminate($expr1, $term, $include)"
case None => expr1
}
expr2
Expand Down Expand Up @@ -450,7 +459,7 @@ class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
}

override def instanceReturn(instName: InstanceIdentifier, attrType: DataType): Unit = {
out.puts(s"return ${privateMemberName(instName)};")
out.puts(s"return ${privateMemberName(instName)}.expect(" + "\"Something went very wrong\");")
}

override def enumDeclaration(curClass: List[String], enumName: String, enumColl: Seq[(Long, EnumValueSpec)]): Unit = {
Expand All @@ -466,6 +475,13 @@ class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)

out.dec
out.puts("}")
out.puts

out.puts(s"impl Default for $enumClass {")
out.inc
out.puts(s"fn default() -> Self { $enumClass::${value2Const(enumColl.head._2.name)} }")
out.dec
out.puts("}")
}

def value2Const(label: String) = label.toUpperCase
Expand All @@ -475,16 +491,16 @@ class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)
case SpecialIdentifier(name) => name
case NamedIdentifier(name) => Utils.lowerCamelCase(name)
case NumberedIdentifier(idx) => s"_${NumberedIdentifier.TEMPLATE}$idx"
case InstanceIdentifier(name) => Utils.lowerCamelCase(name)
case InstanceIdentifier(name) => s"${Utils.lowerCamelCase(name)}"
case RawIdentifier(innerId) => "_raw_" + idToStr(innerId)
}
}

override def privateMemberName(id: Identifier): String = {
id match {
case IoIdentifier => s"self.stream"
case RootIdentifier => s"_root"
case ParentIdentifier => s"_parent"
case IoIdentifier => s"self.stream.expect(" + "\"This should never be None\")"
case RootIdentifier => s"self._root"
case ParentIdentifier => s"self._parent"
case _ => s"self.${idToStr(id)}"
}
}
Expand Down Expand Up @@ -531,7 +547,7 @@ class RustCompiler(typeProvider: ClassTypeProvider, config: RuntimeConfig)

case ArrayType(inType) => s"Vec<${kaitaiType2NativeType(inType)}>"

case KaitaiStreamType => s"Option<Box<KaitaiStream>>"
case KaitaiStreamType => s"Box<KaitaiStream>"
case KaitaiStructType | CalcKaitaiStructType => s"Option<Box<KaitaiStruct>>"

case st: SwitchType => kaitaiType2NativeType(st.combinedType)
Expand Down Expand Up @@ -590,16 +606,16 @@ object RustCompiler extends LanguageCompilerStatic
): LanguageCompiler = new RustCompiler(tp, config)

override def kstructName = "&Option<Box<KaitaiStruct>>"
override def kstreamName = "&mut S"
override def kstreamName = "stream"

def types2class(typeName: Ast.typeId) = {
typeName.names.map(type2class).mkString(
if (typeName.absolute) "__" else "",
"__",
"",
"",
""
)
}

def types2classRel(names: List[String]) =
names.map(type2class).mkString("__")
names.map(type2class).mkString("")
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import io.kaitai.struct.datatype.DataType._
import io.kaitai.struct.exprlang.Ast
import io.kaitai.struct.exprlang.Ast.expr
import io.kaitai.struct.format.Identifier
import io.kaitai.struct.format.NamedIdentifier
import io.kaitai.struct.languages.RustCompiler
import io.kaitai.struct.{RuntimeConfig, Utils}

Expand Down Expand Up @@ -46,7 +47,7 @@ class RustTranslator(provider: TypeProvider, config: RuntimeConfig) extends Base
}
}

override def doName(s: String) = s
override def doName(s: String) = s"${Utils.lowerCamelCase(s)}"

override def doEnumByLabel(enumTypeAbs: List[String], label: String): String = {
val enumClass = types2classAbs(enumTypeAbs)
Expand All @@ -61,7 +62,7 @@ class RustTranslator(provider: TypeProvider, config: RuntimeConfig) extends Base
override def doIfExp(condition: expr, ifTrue: expr, ifFalse: expr): String =
"if " + translate(condition) +
" { " + translate(ifTrue) + " } else { " +
translate(ifFalse) + "}"
translate(ifFalse) + " }"

// Predefined methods of various types
override def strConcat(left: Ast.expr, right: Ast.expr): String =
Expand Down