diff --git a/ordo/src/emit.rs b/ordo/src/emit.rs index 263289f..691b1cd 100644 --- a/ordo/src/emit.rs +++ b/ordo/src/emit.rs @@ -18,7 +18,7 @@ pub fn emit(expr: ExprTypedAt) -> Result { #[derive(Clone)] enum WasmType { - // I32, + I32, I64, // F32, // F64, @@ -28,6 +28,8 @@ impl From<&Type> for WasmType { fn from(value: &Type) -> Self { if value == &Type::int() { WasmType::I64 + } else if value == &Type::bool() { + WasmType::I32 } else { todo!("{:?}", value) } @@ -43,7 +45,7 @@ impl From for WasmType { impl std::fmt::Display for WasmType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let t = match self { - // WasmType::I32 => "i32", + WasmType::I32 => "i32", WasmType::I64 => "i64", // WasmType::F32 => "f32", // WasmType::F64 => "f64", @@ -132,7 +134,7 @@ impl std::fmt::Display for WasmLocal { enum WasmCode { LocalGet(String, WasmType), - Add(WasmType), + Instruction(&'static str, WasmType, WasmType), Const(WasmValue), Call(String, WasmType), LocalSet(String, WasmType), @@ -142,7 +144,7 @@ impl std::fmt::Display for WasmCode { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { WasmCode::LocalGet(name, _) => write!(f, "local.get ${}", name), - WasmCode::Add(ty) => write!(f, "{}.add", ty), + WasmCode::Instruction(instruction, in_ty, _) => write!(f, "{}.{}", in_ty, instruction), WasmCode::Const(v) => write!(f, "{}.const {}", v.ty(), v), WasmCode::Call(name, _) => write!(f, "call ${}", name), WasmCode::LocalSet(name, _) => { @@ -156,7 +158,7 @@ impl WasmCode { fn ty(&self) -> WasmType { match self { WasmCode::LocalGet(_, ty) => ty.clone(), - WasmCode::Add(ty) => ty.clone(), + WasmCode::Instruction(_, _, out_ty) => out_ty.clone(), WasmCode::Const(value) => value.ty(), WasmCode::Call(_, ty) => ty.clone(), WasmCode::LocalSet(_, ty) => ty.clone(), @@ -269,23 +271,25 @@ impl Module { self.function.code.push(WasmCode::Const(i.into())); Ok(()) } - Expr::IntBinOp(op, a, b) => match op { - IntBinOp::Plus => { - self.emit(a)?; - self.emit(b)?; - self.function - .code - .push(WasmCode::Add(op.output_ty().into())); - Ok(()) - } - IntBinOp::Minus => todo!(), - IntBinOp::Multiply => todo!(), - IntBinOp::Divide => todo!(), - IntBinOp::LessThan => todo!(), - IntBinOp::LessThanOrEqual => todo!(), - IntBinOp::GreaterThan => todo!(), - IntBinOp::GreaterThanOrEqual => todo!(), - }, + Expr::IntBinOp(op, a, b) => { + let out = op.output_ty().into(); + let i = match op { + IntBinOp::Plus => "add", + IntBinOp::Minus => "sub", + IntBinOp::Multiply => "mul", + IntBinOp::Divide => "div_s", + IntBinOp::LessThan => "lt_s", + IntBinOp::LessThanOrEqual => "le_s", + IntBinOp::GreaterThan => "gt_s", + IntBinOp::GreaterThanOrEqual => "ge_s", + }; + self.emit(a)?; + self.emit(b)?; + self.function + .code + .push(WasmCode::Instruction(i, WasmType::I64, out)); + Ok(()) + } Expr::Var(name) => { self.function.code.push(WasmCode::LocalGet(name, ty)); Ok(()) @@ -300,6 +304,15 @@ impl Module { } _ => todo!(), }, + Expr::EqualEqual(a, b) => { + let in_ty = WasmType::from(&a.context.ty.ty); + self.emit(a)?; + self.emit(b)?; + self.function + .code + .push(WasmCode::Instruction("eq", in_ty, WasmType::I32)); + Ok(()) + } _ => todo!(), } } diff --git a/ordo/src/tests/emit_tests.rs b/ordo/src/tests/emit_tests.rs index d498c72..8818b3c 100644 --- a/ordo/src/tests/emit_tests.rs +++ b/ordo/src/tests/emit_tests.rs @@ -6,7 +6,7 @@ use crate::{ use wasmi::*; #[track_caller] -fn pass(source: &str, expected: i64) { +fn pass(source: &str, expected: T) { let mut env = infer::Env::default(); let expr = Parser::expr(source).unwrap(); let typed_expr = env.infer(expr.clone()).unwrap(); @@ -21,15 +21,16 @@ fn pass(source: &str, expected: i64) { .unwrap() .start(&mut store) .unwrap(); - let load = instance - .get_typed_func::<(), i64>(&store, LOAD_NAME) - .unwrap(); + let load = instance.get_typed_func::<(), T>(&store, LOAD_NAME).unwrap(); let actual = load.call(&mut store, ()).unwrap(); assert_eq!(expected, actual); } #[test] fn test() { - pass("let add(a, b) = a + b in add(1, 2)", 3); - pass("let a = 1 in a", 1); + pass("let add(a, b) = a + b in add(1, 2)", 3_u64); + pass("let a = 1 in a", 1_u64); + pass("4 * 5 / 2 - 1 + 4", 13_u64); + pass("2 == 2", 1_i32); + pass("3 > 4", 0_i32); }