diff --git a/generate_ast/mod.rs b/generate_ast/mod.rs index 33875ac..f2199a1 100644 --- a/generate_ast/mod.rs +++ b/generate_ast/mod.rs @@ -13,19 +13,20 @@ pub fn generate_ast(output_dir: &String) -> io::Result<()> { output_dir, &"Expr".to_string(), &vec![ - "Assign : Token name, Box value".to_string(), - "Binary : Box left, Token operator, Box right".to_string(), - "Call : Box callee, Token paren, Vec arguments".to_string(), - "Grouping : Box expression".to_string(), + "Assign : Token name, Rc value".to_string(), + "Binary : Rc left, Token operator, Rc right".to_string(), + "Call : Rc callee, Token paren, Vec> arguments".to_string(), + "Grouping : Rc expression".to_string(), "Literal : Option value".to_string(), - "Logical : Box left, Token operator, Box right".to_string(), - "Unary : Token operator, Box right".to_string(), + "Logical : Rc left, Token operator, Rc right".to_string(), + "Unary : Token operator, Rc right".to_string(), "Variable : Token name".to_string(), ], &vec![ "crate::tokens::*".to_string(), "crate::errors::*".to_string(), + "std::rc::Rc".to_string(), ], )?; @@ -33,14 +34,14 @@ pub fn generate_ast(output_dir: &String) -> io::Result<()> { output_dir, &"Stmt".to_string(), &vec![ - "Block : Vec statements".to_string(), - "Expression : Expr expression".to_string(), - "Function : Token name, Rc> params, Rc> body".to_string(), - "If : Expr condition, Box then_branch, Option> else_branch".to_string(), - "Print : Expr expression".to_string(), - "Var : Token name, Option initializer".to_string(), - "While : Expr condition, Box body".to_string(), - "Return : Token keyword, Option value".to_string(), + "Block : Rc>> statements".to_string(), + "Expression : Rc expression".to_string(), + "Function : Token name, Rc> params, Rc>> body".to_string(), + "If : Rc condition, Rc then_branch, Option> else_branch".to_string(), + "Print : Rc expression".to_string(), + "Var : Token name, Option> initializer".to_string(), + "While : Rc condition, Rc body".to_string(), + "Return : Token keyword, Option> value".to_string(), ], &vec![ "crate::expr::Expr".to_string(), @@ -84,19 +85,49 @@ fn define_ast( write!(file, "\npub enum {base_name} {{\n")?; for t in &tree_types { - write!(file, " {}({}),\n", t.base_class_name, t.class_name)?; + write!(file, " {}(Rc<{}>),\n", t.base_class_name, t.class_name)?; } write!(file, "}}\n\n")?; + writeln!(file, "impl PartialEq for {} {{", base_name)?; + writeln!(file, " fn eq(&self, other: &Self) -> bool {{")?; + writeln!(file, " match (self, other) {{")?; + for t in &tree_types { + writeln!( + file, + " ({0}::{1}(a), {0}::{1}(b)) => Rc::ptr_eq(a, b),", + base_name, t.base_class_name + )?; + } + writeln!(file, " _ => false,")?; + writeln!(file, " }}")?; + writeln!(file, " }}")?; + writeln!(file, "}}\n\nimpl Eq for {}{{}}\n", base_name)?; + + writeln!(file, "use std::hash::{{Hash, Hasher}};")?; + writeln!(file, "impl Hash for {} {{", base_name)?; + writeln!(file, " fn hash(&self, hasher: &mut H)")?; + writeln!(file, " where H: Hasher,")?; + writeln!(file, " {{ match self {{ ")?; + for t in &tree_types { + writeln!( + file, + " {}::{}(a) => {{ hasher.write_usize(Rc::as_ptr(a) as usize); }}", + base_name, t.base_class_name + )?; + } + writeln!(file, " }}\n }}\n}}\n")?; + write!(file, "impl {} {{\n", base_name)?; - write!(file, " pub fn accept(&self, {}_visitor: &dyn {base_name}Visitor) -> Result {{\n", base_name.to_lowercase())?; + write!(file, " pub fn accept(&self, wrapper: &Rc<{}>, {}_visitor: &dyn {base_name}Visitor) -> Result {{\n", base_name, base_name.to_lowercase())?; write!(file, " match self {{\n")?; for t in &tree_types { write!( file, - " {}::{}(v) => v.accept({}_visitor),\n", + " {0}::{1}(v) => {3}_visitor.visit_{2}_{3}(wrapper, &v),\n", base_name, t.base_class_name, + t.base_class_name.to_lowercase(), base_name.to_lowercase() )?; } @@ -116,29 +147,30 @@ fn define_ast( for t in &tree_types { write!( file, - " fn visit_{}_{}(&self, expr: &{}) -> Result;\n", + " fn visit_{0}_{1}(&self, wrapper: &Rc<{3}>, {1}: &{2}) -> Result;\n", t.base_class_name.to_lowercase(), base_name.to_lowercase(), - t.class_name + t.class_name, + base_name, )?; } write!(file, "}}\n\n")?; - for t in &tree_types { - write!(file, "impl {} {{\n", t.class_name)?; - write!( - file, - " pub fn accept(&self, visitor: &dyn {}Visitor) -> Result {{\n", - base_name - )?; - write!( - file, - " visitor.visit_{}_{}(self)\n", - t.base_class_name.to_lowercase(), - base_name.to_lowercase() - )?; - write!(file, " }}\n")?; - write!(file, "}}\n\n")?; - } + // for t in &tree_types { + // write!(file, "impl {} {{\n", t.class_name)?; + // write!( + // file, + // " pub fn accept(&self, visitor: &dyn {}Visitor) -> Result {{\n", + // base_name + // )?; + // write!( + // file, + // " visitor.visit_{}_{}(self)\n", + // t.base_class_name.to_lowercase(), + // base_name.to_lowercase() + // )?; + // write!(file, " }}\n")?; + // write!(file, "}}\n\n")?; + // } Ok(()) } diff --git a/main.arc b/main.arc index bd46d53..a028575 100644 --- a/main.arc +++ b/main.arc @@ -1,8 +1,8 @@ -var a = 0; -var temp; +fn fib(n) { + if (n <= 1) return n; + return fib(n - 2) + fib(n - 1); +} -for (var b = 1; a < 10000; b = temp + b) { - print a; - temp = a; - a = b; -} \ No newline at end of file +for(var i = 0; i < 100; i = i + 1){ + print fib(i); +} diff --git a/src/errors.rs b/src/errors.rs index 3afaf1c..2482b83 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -38,7 +38,7 @@ impl Error { err } - fn report(&self, loc: &str) { + fn report(&self, _loc: &str) { match self { Error::ParseError { token, message } | Error::RuntimeError { token, message } => { diff --git a/src/expr.rs b/src/expr.rs index 170d4dd..35fc2e9 100644 --- a/src/expr.rs +++ b/src/expr.rs @@ -1,51 +1,87 @@ use crate::tokens::*; use crate::errors::*; +use std::rc::Rc; pub enum Expr { - Assign(AssignExpr), - Binary(BinaryExpr), - Call(CallExpr), - Grouping(GroupingExpr), - Literal(LiteralExpr), - Logical(LogicalExpr), - Unary(UnaryExpr), - Variable(VariableExpr), + Assign(Rc), + Binary(Rc), + Call(Rc), + Grouping(Rc), + Literal(Rc), + Logical(Rc), + Unary(Rc), + Variable(Rc), +} + +impl PartialEq for Expr { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Expr::Assign(a), Expr::Assign(b)) => Rc::ptr_eq(a, b), + (Expr::Binary(a), Expr::Binary(b)) => Rc::ptr_eq(a, b), + (Expr::Call(a), Expr::Call(b)) => Rc::ptr_eq(a, b), + (Expr::Grouping(a), Expr::Grouping(b)) => Rc::ptr_eq(a, b), + (Expr::Literal(a), Expr::Literal(b)) => Rc::ptr_eq(a, b), + (Expr::Logical(a), Expr::Logical(b)) => Rc::ptr_eq(a, b), + (Expr::Unary(a), Expr::Unary(b)) => Rc::ptr_eq(a, b), + (Expr::Variable(a), Expr::Variable(b)) => Rc::ptr_eq(a, b), + _ => false, + } + } +} + +impl Eq for Expr{} + +use std::hash::{Hash, Hasher}; +impl Hash for Expr { + fn hash(&self, hasher: &mut H) + where H: Hasher, + { match self { + Expr::Assign(a) => { hasher.write_usize(Rc::as_ptr(a) as usize); } + Expr::Binary(a) => { hasher.write_usize(Rc::as_ptr(a) as usize); } + Expr::Call(a) => { hasher.write_usize(Rc::as_ptr(a) as usize); } + Expr::Grouping(a) => { hasher.write_usize(Rc::as_ptr(a) as usize); } + Expr::Literal(a) => { hasher.write_usize(Rc::as_ptr(a) as usize); } + Expr::Logical(a) => { hasher.write_usize(Rc::as_ptr(a) as usize); } + Expr::Unary(a) => { hasher.write_usize(Rc::as_ptr(a) as usize); } + Expr::Variable(a) => { hasher.write_usize(Rc::as_ptr(a) as usize); } + } + } } impl Expr { - pub fn accept(&self, expr_visitor: &dyn ExprVisitor) -> Result { + pub fn accept(&self, wrapper: &Rc, expr_visitor: &dyn ExprVisitor) -> Result { match self { - Expr::Assign(v) => v.accept(expr_visitor), - Expr::Binary(v) => v.accept(expr_visitor), - Expr::Call(v) => v.accept(expr_visitor), - Expr::Grouping(v) => v.accept(expr_visitor), - Expr::Literal(v) => v.accept(expr_visitor), - Expr::Logical(v) => v.accept(expr_visitor), - Expr::Unary(v) => v.accept(expr_visitor), - Expr::Variable(v) => v.accept(expr_visitor), + Expr::Assign(v) => expr_visitor.visit_assign_expr(wrapper, &v), + Expr::Binary(v) => expr_visitor.visit_binary_expr(wrapper, &v), + Expr::Call(v) => expr_visitor.visit_call_expr(wrapper, &v), + Expr::Grouping(v) => expr_visitor.visit_grouping_expr(wrapper, &v), + Expr::Literal(v) => expr_visitor.visit_literal_expr(wrapper, &v), + Expr::Logical(v) => expr_visitor.visit_logical_expr(wrapper, &v), + Expr::Unary(v) => expr_visitor.visit_unary_expr(wrapper, &v), + Expr::Variable(v) => expr_visitor.visit_variable_expr(wrapper, &v), } } } pub struct AssignExpr { pub name: Token, - pub value: Box, + pub value: Rc, } pub struct BinaryExpr { - pub left: Box, + pub left: Rc, pub operator: Token, - pub right: Box, + pub right: Rc, } pub struct CallExpr { - pub callee: Box, + pub callee: Rc, pub paren: Token, - pub arguments: Vec, + pub arguments: Vec>, } pub struct GroupingExpr { - pub expression: Box, + pub expression: Rc, } pub struct LiteralExpr { @@ -53,14 +89,14 @@ pub struct LiteralExpr { } pub struct LogicalExpr { - pub left: Box, + pub left: Rc, pub operator: Token, - pub right: Box, + pub right: Rc, } pub struct UnaryExpr { pub operator: Token, - pub right: Box, + pub right: Rc, } pub struct VariableExpr { @@ -68,61 +104,13 @@ pub struct VariableExpr { } pub trait ExprVisitor { - fn visit_assign_expr(&self, expr: &AssignExpr) -> Result; - fn visit_binary_expr(&self, expr: &BinaryExpr) -> Result; - fn visit_call_expr(&self, expr: &CallExpr) -> Result; - fn visit_grouping_expr(&self, expr: &GroupingExpr) -> Result; - fn visit_literal_expr(&self, expr: &LiteralExpr) -> Result; - fn visit_logical_expr(&self, expr: &LogicalExpr) -> Result; - fn visit_unary_expr(&self, expr: &UnaryExpr) -> Result; - fn visit_variable_expr(&self, expr: &VariableExpr) -> Result; -} - -impl AssignExpr { - pub fn accept(&self, visitor: &dyn ExprVisitor) -> Result { - visitor.visit_assign_expr(self) - } -} - -impl BinaryExpr { - pub fn accept(&self, visitor: &dyn ExprVisitor) -> Result { - visitor.visit_binary_expr(self) - } -} - -impl CallExpr { - pub fn accept(&self, visitor: &dyn ExprVisitor) -> Result { - visitor.visit_call_expr(self) - } -} - -impl GroupingExpr { - pub fn accept(&self, visitor: &dyn ExprVisitor) -> Result { - visitor.visit_grouping_expr(self) - } -} - -impl LiteralExpr { - pub fn accept(&self, visitor: &dyn ExprVisitor) -> Result { - visitor.visit_literal_expr(self) - } -} - -impl LogicalExpr { - pub fn accept(&self, visitor: &dyn ExprVisitor) -> Result { - visitor.visit_logical_expr(self) - } -} - -impl UnaryExpr { - pub fn accept(&self, visitor: &dyn ExprVisitor) -> Result { - visitor.visit_unary_expr(self) - } -} - -impl VariableExpr { - pub fn accept(&self, visitor: &dyn ExprVisitor) -> Result { - visitor.visit_variable_expr(self) - } + fn visit_assign_expr(&self, wrapper: &Rc, expr: &AssignExpr) -> Result; + fn visit_binary_expr(&self, wrapper: &Rc, expr: &BinaryExpr) -> Result; + fn visit_call_expr(&self, wrapper: &Rc, expr: &CallExpr) -> Result; + fn visit_grouping_expr(&self, wrapper: &Rc, expr: &GroupingExpr) -> Result; + fn visit_literal_expr(&self, wrapper: &Rc, expr: &LiteralExpr) -> Result; + fn visit_logical_expr(&self, wrapper: &Rc, expr: &LogicalExpr) -> Result; + fn visit_unary_expr(&self, wrapper: &Rc, expr: &UnaryExpr) -> Result; + fn visit_variable_expr(&self, wrapper: &Rc, expr: &VariableExpr) -> Result; } diff --git a/src/functions.rs b/src/functions.rs index 946c5ab..3270207 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -10,7 +10,7 @@ use crate::interpreter::Interpreter; pub struct Function { name : Token, params : Rc>, - body : Rc>, + body : Rc>>, closure: Rc>, } diff --git a/src/interpreter.rs b/src/interpreter.rs index c8f0b1c..2596907 100644 --- a/src/interpreter.rs +++ b/src/interpreter.rs @@ -7,16 +7,18 @@ use crate::native_functions::*; use crate::stmt::*; use crate::tokens::*; -use std::ops::Deref; use std::cell::RefCell; +use std::collections::HashMap; +use std::ops::Deref; use std::rc::Rc; pub struct Interpreter { pub globals: Rc>, environment: RefCell>>, + locals: RefCell, usize>>, } impl StmtVisitor<()> for Interpreter { - fn visit_return_stmt(&self, stmt: &ReturnStmt) -> Result<(), Error> { + fn visit_return_stmt(&self, _: &Rc, stmt: &ReturnStmt) -> Result<(), Error> { if let Some(value) = &stmt.value { let value = self.evaluate(value)?; Err(Error::return_value(value)) @@ -24,7 +26,7 @@ impl StmtVisitor<()> for Interpreter { Err(Error::return_value(Object::Nil)) } } - fn visit_function_stmt(&self, stmt: &FunctionStmt) -> Result<(), Error> { + fn visit_function_stmt(&self, _: &Rc, stmt: &FunctionStmt) -> Result<(), Error> { let function = Function::new(&stmt, self.environment.borrow().deref()); self.environment.borrow().borrow_mut().define( stmt.name.lexeme.clone(), @@ -35,7 +37,7 @@ impl StmtVisitor<()> for Interpreter { Ok(()) } - fn visit_if_stmt(&self, stmt: &IfStmt) -> Result<(), Error> { + fn visit_if_stmt(&self, _: &Rc, stmt: &IfStmt) -> Result<(), Error> { if self.is_truthy(self.evaluate(&stmt.condition)?) { self.execute(&stmt.then_branch)?; } else if let Some(else_branch) = &stmt.else_branch { @@ -43,23 +45,27 @@ impl StmtVisitor<()> for Interpreter { } Ok(()) } - fn visit_block_stmt(&self, stmt: &BlockStmt) -> Result<(), Error> { + fn visit_block_stmt(&self, _: &Rc, stmt: &BlockStmt) -> Result<(), Error> { let environment = Environment::new_with_enclosing(Rc::clone(&self.environment.borrow().clone())); self.execute_block(&stmt.statements, environment) } - fn visit_print_stmt(&self, stmt: &PrintStmt) -> Result<(), Error> { + fn visit_print_stmt(&self, _: &Rc, stmt: &PrintStmt) -> Result<(), Error> { let value = self.evaluate(&stmt.expression)?; println!("{}", value); Ok(()) } - fn visit_expression_stmt(&self, stmt: &ExpressionStmt) -> Result<(), Error> { + fn visit_expression_stmt( + &self, + _: &Rc, + stmt: &ExpressionStmt, + ) -> Result<(), Error> { self.evaluate(&stmt.expression)?; Ok(()) } - fn visit_var_stmt(&self, stmt: &VarStmt) -> Result<(), Error> { + fn visit_var_stmt(&self, _: &Rc, stmt: &VarStmt) -> Result<(), Error> { let value = if let Some(expr) = &stmt.initializer { self.evaluate(expr)? } else { @@ -73,7 +79,7 @@ impl StmtVisitor<()> for Interpreter { Ok(()) } - fn visit_while_stmt(&self, stmt: &WhileStmt) -> Result<(), Error> { + fn visit_while_stmt(&self, _: &Rc, stmt: &WhileStmt) -> Result<(), Error> { while self.is_truthy(self.evaluate(&stmt.condition)?) { self.execute(&stmt.body)?; } @@ -82,7 +88,7 @@ impl StmtVisitor<()> for Interpreter { } impl ExprVisitor for Interpreter { - fn visit_logical_expr(&self, expr: &LogicalExpr) -> Result { + fn visit_logical_expr(&self, _: &Rc, expr: &LogicalExpr) -> Result { let left = self.evaluate(&expr.left)?; if expr.operator.kind == TokenKind::Or { @@ -97,7 +103,7 @@ impl ExprVisitor for Interpreter { self.evaluate(&expr.right) } - fn visit_assign_expr(&self, expr: &AssignExpr) -> Result { + fn visit_assign_expr(&self, _: &Rc, expr: &AssignExpr) -> Result { let value = self.evaluate(&expr.value)?; self.environment .borrow() @@ -106,15 +112,19 @@ impl ExprVisitor for Interpreter { Ok(value) } - fn visit_literal_expr(&self, expr: &LiteralExpr) -> Result { + fn visit_literal_expr(&self, _: &Rc, expr: &LiteralExpr) -> Result { Ok(expr.value.clone().unwrap()) } - fn visit_grouping_expr(&self, expr: &GroupingExpr) -> Result { + fn visit_grouping_expr( + &self, + _: &Rc, + expr: &GroupingExpr, + ) -> Result { Ok(self.evaluate(&expr.expression)?) } - fn visit_unary_expr(&self, expr: &UnaryExpr) -> Result { + fn visit_unary_expr(&self, _: &Rc, expr: &UnaryExpr) -> Result { let right = self.evaluate(&expr.right)?; match expr.operator.kind { @@ -134,7 +144,7 @@ impl ExprVisitor for Interpreter { } } - fn visit_binary_expr(&self, expr: &BinaryExpr) -> Result { + fn visit_binary_expr(&self, _: &Rc, expr: &BinaryExpr) -> Result { let left = self.evaluate(&expr.left)?; let right = self.evaluate(&expr.right)?; let operator = &expr.operator.kind; @@ -208,7 +218,7 @@ impl ExprVisitor for Interpreter { } } - fn visit_call_expr(&self, expr: &CallExpr) -> Result { + fn visit_call_expr(&self, wrapper: &Rc, expr: &CallExpr) -> Result { let callee = self.evaluate(&expr.callee)?; let mut arguments = Vec::new(); @@ -237,8 +247,8 @@ impl ExprVisitor for Interpreter { } } - fn visit_variable_expr(&self, expr: &VariableExpr) -> Result { - self.environment.borrow().borrow().get(&expr.name) + fn visit_variable_expr(&self, wrapper: &Rc, expr: &VariableExpr) -> Result { + self.look_up_variable(&expr.name, wrapper) } } @@ -256,16 +266,17 @@ impl Interpreter { Interpreter { globals: Rc::clone(&global), environment: RefCell::new(Rc::clone(&global)), + locals: RefCell::new(HashMap::new()), } } - fn execute(&self, stmt: &Stmt) -> Result<(), Error> { - stmt.accept(self) + fn execute(&self, stmt: &Rc) -> Result<(), Error> { + stmt.accept(stmt, self) } pub fn execute_block( &self, - statements: &[Stmt], + statements: &Rc>>, environment: Environment, ) -> Result<(), Error> { let previous = self.environment.replace(Rc::new(RefCell::new(environment))); @@ -279,10 +290,10 @@ impl Interpreter { result } - pub fn interpret(&self, statements: &[Stmt]) -> bool { + pub fn interpret(&self, statements: &[Rc]) -> bool { let mut success = true; for statment in statements { - if let Err(e) = self.execute(statment) { + if let Err(_e) = self.execute(statment) { success = false; break; } @@ -290,8 +301,8 @@ impl Interpreter { success } - fn evaluate(&self, expr: &Expr) -> Result { - expr.accept(self) + fn evaluate(&self, expr: &Rc) -> Result { + expr.accept(expr, self) } fn is_truthy(&self, object: Object) -> bool { @@ -305,6 +316,21 @@ impl Interpreter { pub fn print_env(&self) { println!("{:?}", self.environment.borrow()); } + + pub fn resolve(&self, expr: Rc, depth: usize) { + self.locals.borrow_mut().insert(expr, depth); + } + + fn look_up_variable(&self, name: &Token, expr: &Rc) -> Result { + if let Some(distance) = self.locals.borrow().get(expr) { + self.environment + .borrow() + .borrow() + .get_at(*distance, &name.lexeme) + } else { + self.globals.borrow().get(&name) + } + } } #[cfg(test)] diff --git a/src/main.rs b/src/main.rs index 6ecac77..4c778fe 100644 --- a/src/main.rs +++ b/src/main.rs @@ -30,6 +30,8 @@ mod native_functions; mod functions; +mod resolver; + fn eval(source: &str) -> Result<(), Error> { let lexer = Lexer::new(source.to_string()); let mut tokens: Vec = lexer.collect(); diff --git a/src/parser.rs b/src/parser.rs index 871507d..cd0faf8 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -14,8 +14,8 @@ impl Parser { Parser { tokens, current: 0 } } - pub fn parse(&mut self) -> Result, Error> { - let mut statements: Vec = Vec::new(); + pub fn parse(&mut self) -> Result>, Error> { + let mut statements: Vec> = Vec::new(); while !self.is_at_end() { statements.push(self.declaration()?); } @@ -26,7 +26,7 @@ impl Parser { self.assignment() } - fn declaration(&mut self) -> Result { + fn declaration(&mut self) -> Result, Error> { let result = if self.match_token(vec![TokenKind::Fn]) { self.function("function") } else if self.match_token(vec![TokenKind::Var]) { @@ -42,47 +42,48 @@ impl Parser { result } - fn statement(&mut self) -> Result { + fn statement(&mut self) -> Result, Error> { if self.match_token(vec![TokenKind::For]) { return self.for_statement(); } if self.match_token(vec![TokenKind::If]) { - return self.if_statement(); + return Ok(Rc::new(self.if_statement()?)); } if self.match_token(vec![TokenKind::Print]) { - return self.print_statement(); + return Ok(Rc::new(self.print_statement()?)); } if self.match_token(vec![TokenKind::Return]) { - return self.return_statement(); + return Ok(self.return_statement()?); } if self.match_token(vec![TokenKind::While]) { - return self.while_statement(); + return Ok(Rc::new(self.while_statement()?)); } if self.match_token(vec![TokenKind::LeftBrace]) { - return Ok(Stmt::Block(BlockStmt { - statements: self.block()?, - })); + return Ok(Rc::new(Stmt::Block(Rc::new(BlockStmt { + statements: Rc::new(self.block()?), + })))); } - self.expression_statement() + + Ok(self.expression_statement()?) } - fn return_statement(&mut self) -> Result { + fn return_statement(&mut self) -> Result, Error> { let keyword = self.previous(); let value = if self.check(TokenKind::Semicolon) { None } else { - Some(self.expression()?) + Some(Rc::new(self.expression()?)) }; self.consume(TokenKind::Semicolon, "Expect ';' after return value.")?; - Ok(Stmt::Return(ReturnStmt { keyword, value })) + Ok(Rc::new(Stmt::Return(Rc::new(ReturnStmt { keyword, value })))) } - fn for_statement(&mut self) -> Result { + fn for_statement(&mut self) -> Result, Error> { self.consume(TokenKind::LeftParen, "Expect '(' after 'for'.")?; let initializer = if self.match_token(vec![TokenKind::Semicolon]) { @@ -112,26 +113,28 @@ impl Parser { let mut body = self.statement()?; if let Some(incr) = increment { - body = Stmt::Block(BlockStmt { - statements: vec![body, Stmt::Expression(ExpressionStmt { expression: incr })], - }); + body = Rc::new(Stmt::Block(Rc::new(BlockStmt { + statements: Rc::new(vec![body, Rc::new(Stmt::Expression(Rc::new(ExpressionStmt { + expression: Rc::new(incr), + })))]), + }))); } - body = Stmt::While(WhileStmt { + body = Rc::new(Stmt::While(Rc::new(WhileStmt { condition: if let Some(cond) = condition { - cond + Rc::new(cond) } else { - Expr::Literal(LiteralExpr { + Rc::new(Expr::Literal(Rc::new(LiteralExpr { value: Some(Object::Bool(true)), - }) + }))) }, - body: Box::new(body), - }); + body, + }))); if let Some(init) = initializer { - body = Stmt::Block(BlockStmt { - statements: vec![init, body], - }); + body = Rc::new(Stmt::Block(Rc::new(BlockStmt { + statements: Rc::new(vec![init, body]), + }))); } Ok(body) @@ -139,34 +142,34 @@ impl Parser { fn if_statement(&mut self) -> Result { self.consume(TokenKind::LeftParen, "Expect '(' after 'if'.")?; - let condition = self.expression()?; + let condition = Rc::new(self.expression()?); self.consume(TokenKind::RightParen, "Expect ')' after if condition.")?; - let then_branch = Box::new(self.statement()?); + let then_branch = self.statement()?; let else_branch = if self.match_token(vec![TokenKind::Else]) { - Some(Box::new(self.statement()?)) + Some(self.statement()?) } else { None }; - Ok(Stmt::If(IfStmt { + Ok(Stmt::If(Rc::new(IfStmt { condition, then_branch, else_branch, - })) + }))) } fn print_statement(&mut self) -> Result { - let value = self.expression()?; + let value = Rc::new(self.expression()?); self.consume(TokenKind::Semicolon, "Expect ';' after value.")?; - Ok(Stmt::Print(PrintStmt { expression: value })) + Ok(Stmt::Print(Rc::new(PrintStmt { expression: value }))) } - fn var_declaration(&mut self) -> Result { + fn var_declaration(&mut self) -> Result, Error> { let name = self.consume(TokenKind::Identifier, "Expect variable name.")?; let initializer = if self.match_token(vec![TokenKind::Equal]) { - Some(self.expression()?) + Some(Rc::new(self.expression()?)) } else { None }; @@ -176,25 +179,25 @@ impl Parser { "Expect ';' after variable declaration.", )?; - Ok(Stmt::Var(VarStmt { name, initializer })) + Ok(Rc::new(Stmt::Var(Rc::new(VarStmt { name, initializer })))) } fn while_statement(&mut self) -> Result { self.consume(TokenKind::LeftParen, "Expect '(' after 'while'.")?; - let condition = self.expression()?; + let condition = Rc::new(self.expression()?); self.consume(TokenKind::RightParen, "Expect ')' after condition.")?; - let body = Box::new(self.statement()?); + let body = self.statement()?; - Ok(Stmt::While(WhileStmt { condition, body })) + Ok(Stmt::While(Rc::new(WhileStmt { condition, body }))) } - fn expression_statement(&mut self) -> Result { - let value = self.expression()?; + fn expression_statement(&mut self) -> Result, Error> { + let value = Rc::new(self.expression()?); self.consume(TokenKind::Semicolon, "Expect ';' after value.")?; - Ok(Stmt::Expression(ExpressionStmt { expression: value })) + Ok(Rc::new(Stmt::Expression(Rc::new(ExpressionStmt { expression: value })))) } - fn function(&mut self, kind: &str) -> Result { + fn function(&mut self, kind: &str) -> Result, Error> { let name = self.consume(TokenKind::Identifier, &format!("Expect {kind} name"))?; self.consume(TokenKind::LeftParen, &format!("Expect '(' after {kind} name"))?; @@ -216,12 +219,16 @@ impl Parser { self.consume(TokenKind::RightParen, "Expect ')' after parameter")?; self.consume(TokenKind::LeftBrace, &format!("Expect '{{' after {kind} body"))?; - let body = self.block()?; - Ok(Stmt::Function(FunctionStmt { name, params: Rc::new(params), body: Rc::new(body) })) + let body = Rc::new(self.block()?); + Ok(Rc::new(Stmt::Function(Rc::new(FunctionStmt { + name, + params: Rc::new(params), + body, + })))) } - fn block(&mut self) -> Result, Error> { - let mut statements: Vec = Vec::new(); + fn block(&mut self) -> Result>, Error> { + let mut statements: Vec> = Vec::new(); while !self.check(TokenKind::RightBrace) && !self.is_at_end() { statements.push(self.declaration()?); @@ -241,10 +248,10 @@ impl Parser { match expr { Expr::Variable(v) => { - return Ok(Expr::Assign(AssignExpr { - name: v.name, - value: Box::new(value), - })); + return Ok(Expr::Assign(Rc::new(AssignExpr { + name: v.name.clone(), + value: Rc::new(value), + }))); } _ => { return Err(Error::parse_error( @@ -264,11 +271,11 @@ impl Parser { while self.match_token(vec![TokenKind::Or]) { let operator = self.previous(); let right = self.and()?; - expr = Expr::Logical(LogicalExpr { - left: Box::new(expr), + expr = Expr::Logical(Rc::new(LogicalExpr { + left: Rc::new(expr), operator, - right: Box::new(right), - }); + right: Rc::new(right), + })); } Ok(expr) @@ -280,11 +287,11 @@ impl Parser { while self.match_token(vec![TokenKind::And]) { let operator = self.previous(); let right = self.equality()?; - expr = Expr::Logical(LogicalExpr { - left: Box::new(expr), + expr = Expr::Logical(Rc::new(LogicalExpr { + left: Rc::new(expr), operator, - right: Box::new(right), - }); + right: Rc::new(right), + })); } Ok(expr) @@ -296,11 +303,11 @@ impl Parser { while self.match_token(vec![TokenKind::NotEqual, TokenKind::EqualEqual]) { let operator = self.previous(); let right = self.comparison()?; - expr = Expr::Binary(BinaryExpr { - left: Box::new(expr), + expr = Expr::Binary(Rc::new(BinaryExpr { + left: Rc::new(expr), operator, - right: Box::new(right), - }); + right: Rc::new(right), + })); } Ok(expr) @@ -317,11 +324,11 @@ impl Parser { ]) { let operator = self.previous(); let right = self.term()?; - expr = Expr::Binary(BinaryExpr { - left: Box::new(expr), + expr = Expr::Binary(Rc::new(BinaryExpr { + left: Rc::new(expr), operator, - right: Box::new(right), - }); + right: Rc::new(right), + })); } Ok(expr) @@ -333,11 +340,11 @@ impl Parser { while self.match_token(vec![TokenKind::Minus, TokenKind::Plus]) { let operator = self.previous(); let right = self.factor()?; - expr = Expr::Binary(BinaryExpr { - left: Box::new(expr), + expr = Expr::Binary(Rc::new(BinaryExpr { + left: Rc::new(expr), operator, - right: Box::new(right), - }); + right: Rc::new(right), + })); } Ok(expr) @@ -349,11 +356,11 @@ impl Parser { while self.match_token(vec![TokenKind::Slash, TokenKind::Asterisk]) { let operator = self.previous(); let right = self.unary()?; - expr = Expr::Binary(BinaryExpr { - left: Box::new(expr), + expr = Expr::Binary(Rc::new(BinaryExpr { + left: Rc::new(expr), operator, - right: Box::new(right), - }); + right: Rc::new(right), + })); } Ok(expr) @@ -363,17 +370,17 @@ impl Parser { if self.match_token(vec![TokenKind::Bang, TokenKind::Minus]) { let operator = self.previous(); let right = self.unary()?; - return Ok(Expr::Unary(UnaryExpr { + return Ok(Expr::Unary(Rc::new(UnaryExpr { operator, - right: Box::new(right), - })); + right: Rc::new(right), + }))); } Ok(self.call()?) } fn finish_call(&mut self, callee: Expr) -> Result { - let mut arguments: Vec = Vec::new(); + let mut arguments: Vec> = Vec::new(); if !self.check(TokenKind::RightParen) { loop { @@ -383,7 +390,7 @@ impl Parser { "Can't have more than 255 arguments.", )); } - arguments.push(self.expression()?); + arguments.push(Rc::new(self.expression()?)); if !self.match_token(vec![TokenKind::Comma]) { break; } @@ -392,11 +399,11 @@ impl Parser { let paren = self.consume(TokenKind::RightParen, "Expect ')' after arguments.")?; - Ok(Expr::Call(CallExpr { - callee: Box::new(callee), + Ok(Expr::Call(Rc::new(CallExpr { + callee: Rc::new(callee), paren, arguments, - })) + }))) } fn call(&mut self) -> Result { @@ -414,41 +421,41 @@ impl Parser { fn primary(&mut self) -> Result { if self.match_token(vec![TokenKind::False]) { - return Ok(Expr::Literal(LiteralExpr { + return Ok(Expr::Literal(Rc::new(LiteralExpr { value: Some(Object::Bool(false)), - })); + }))); } if self.match_token(vec![TokenKind::True]) { - return Ok(Expr::Literal(LiteralExpr { + return Ok(Expr::Literal(Rc::new(LiteralExpr { value: Some(Object::Bool(true)), - })); + }))); } if self.match_token(vec![TokenKind::Nil]) { - return Ok(Expr::Literal(LiteralExpr { + return Ok(Expr::Literal(Rc::new(LiteralExpr { value: Some(Object::Nil), - })); + }))); } if self.match_token(vec![TokenKind::Number, TokenKind::String]) { - return Ok(Expr::Literal(LiteralExpr { + return Ok(Expr::Literal(Rc::new(LiteralExpr { value: self.previous().literal, - })); + }))); } if self.match_token(vec![TokenKind::Identifier]) { - return Ok(Expr::Variable(VariableExpr { + return Ok(Expr::Variable(Rc::new(VariableExpr { name: self.previous(), - })); + }))); } if self.match_token(vec![TokenKind::LeftParen]) { let expr = self.expression()?; self.consume(TokenKind::RightParen, "Expect ')' after expression.")?; - return Ok(Expr::Grouping(GroupingExpr { - expression: Box::new(expr), - })); + return Ok(Expr::Grouping(Rc::new(GroupingExpr { + expression: Rc::new(expr), + }))); } Err(Error::parse_error(&self.peek(), "Expect expression.")) diff --git a/src/resolver.rs b/src/resolver.rs new file mode 100644 index 0000000..7c29218 --- /dev/null +++ b/src/resolver.rs @@ -0,0 +1,182 @@ +use std::cell::RefCell; +use std::collections::HashMap; +use std::rc::Rc; +use std::ops::Deref; + +use crate::errors::*; +use crate::expr::*; +use crate::interpreter::*; +use crate::stmt::*; +use crate::tokens::*; + +struct Resolver { + interpreter: Interpreter, + scopes: RefCell>>>, +} + +impl StmtVisitor<()> for Resolver { + fn visit_return_stmt(&self, _: &Rc ,stmt: &ReturnStmt) -> Result<(), Error> { + if let Some(value) = &stmt.value { + self.resolve_expr(value); + } + + Ok(()) + } + fn visit_function_stmt(&self, _: &Rc, stmt: &FunctionStmt) -> Result<(), Error> { + self.declare(&stmt.name); + self.define(&stmt.name); + self.resolve_function(stmt); + Ok(()) + } + fn visit_while_stmt(&self, _: &Rc, stmt: &WhileStmt) -> Result<(), Error> { + self.resolve_expr(&stmt.condition); + self.resolve_stmt(&stmt.body); + Ok(()) + } + fn visit_if_stmt(&self, _: &Rc, stmt: &IfStmt) -> Result<(), Error> { + self.resolve_expr(&stmt.condition); + self.resolve_stmt(&stmt.then_branch); + if let Some(else_branch) = &stmt.else_branch { + self.resolve_stmt(else_branch); + } + Ok(()) + } + + fn visit_block_stmt(&self, _: &Rc, stmt: &BlockStmt) -> Result<(), Error> { + self.begin_scope(); + self.resolve(&stmt.statements); + self.end_scope(); + Ok(()) + } + + fn visit_expression_stmt(&self, _: &Rc, stmt: &ExpressionStmt) -> Result<(), Error> { + self.resolve_expr(&stmt.expression); + Ok(()) + } + fn visit_print_stmt(&self, _: &Rc, stmt: &PrintStmt) -> Result<(), Error> { + self.resolve_expr(&stmt.expression); + Ok(()) + } + + fn visit_var_stmt(&self, _: &Rc, stmt: &VarStmt) -> Result<(), Error> { + self.declare(&stmt.name); + if let Some(init) = &stmt.initializer { + self.resolve_expr(&init); + } + self.define(&stmt.name); + Ok(()) + } +} + +impl ExprVisitor<()> for Resolver { + fn visit_call_expr(&self, _: &Rc, expr: &CallExpr) -> Result<(), Error> { + self.resolve_expr(&expr.callee); + + for arg in expr.arguments.iter() { + self.resolve_expr(arg); + } + Ok(()) + } + fn visit_logical_expr(&self, _: &Rc, expr: &LogicalExpr) -> Result<(), Error> { + self.resolve_expr(&expr.left); + self.resolve_expr(&expr.right); + Ok(()) + } + fn visit_assign_expr(&self, wrapper: &Rc, expr: &AssignExpr) -> Result<(), Error> { + self.resolve_expr(&expr.value); + self.resolve_local(wrapper, &expr.name); + Ok(()) + } + fn visit_literal_expr(&self, _: &Rc, _expr: &LiteralExpr) -> Result<(), Error> { + Ok(()) + } + fn visit_grouping_expr(&self, _: &Rc, expr: &GroupingExpr) -> Result<(), Error> { + self.resolve_expr(&expr.expression); + Ok(()) + } + fn visit_binary_expr(&self, _: &Rc, expr: &BinaryExpr) -> Result<(), Error> { + self.resolve_expr(&expr.left); + self.resolve_expr(&expr.right); + + Ok(()) + } + fn visit_unary_expr(&self, _: &Rc, expr: &UnaryExpr) -> Result<(), Error> { + self.resolve_expr(&expr.right); + Ok(()) + } + fn visit_variable_expr(&self, wrapper: &Rc, expr: &VariableExpr) -> Result<(), Error> { + if !self.scopes.borrow().is_empty() && !self.scopes.borrow().last().unwrap().borrow().get(&expr.name.lexeme).unwrap(){ + return Err(Error::runtime_error( + &expr.name, + "Cannot read local variable in its own initializer.", + )); + } + self.resolve_local(wrapper, &expr.name); + Ok(()) + } +} + +impl Resolver { + fn resolve(&self, statements: &Rc>>){ + for statement in statements.deref() { + self.resolve_stmt(statement); + } + } + + fn resolve_stmt(&self, stmt: &Rc){ + stmt.accept(stmt, self); + } + + fn resolve_expr(&self, expr: &Rc){ + expr.accept(expr, self); + } + + fn begin_scope(&self) { + self.scopes.borrow_mut().push(RefCell::new(HashMap::new())); + } + + fn end_scope(&self) { + self.scopes.borrow_mut().pop(); + } + + fn declare(&self, name: &Token) { + if !self.scopes.borrow().is_empty() { + self.scopes + .borrow() + .last() + .unwrap() + .borrow_mut() + .insert(name.lexeme.clone(), false); + } + } + + fn define(&self, name: &Token) { + if !self.scopes.borrow().is_empty() { + self.scopes + .borrow() + .last() + .unwrap() + .borrow_mut() + .insert(name.lexeme.clone(), true); + } + } + + fn resolve_local(&self, expr: &Rc, name: &Token) { + for (i, scope) in self.scopes.borrow().iter().enumerate().rev() { + if scope.borrow().contains_key(&name.lexeme) { + self.interpreter.resolve(expr.clone(), i); + return; + } + } + } + + fn resolve_function(&self, function: &FunctionStmt){ + self.begin_scope(); + for param in function.params.iter(){ + self.declare(param); + self.define(param); + } + self.resolve(&function.body); + self.end_scope(); + } +} \ No newline at end of file diff --git a/src/stmt.rs b/src/stmt.rs index 922132f..e25eb5f 100644 --- a/src/stmt.rs +++ b/src/stmt.rs @@ -4,126 +4,113 @@ use crate::tokens::*; use std::rc::Rc; pub enum Stmt { - Block(BlockStmt), - Expression(ExpressionStmt), - Function(FunctionStmt), - If(IfStmt), - Print(PrintStmt), - Var(VarStmt), - While(WhileStmt), - Return(ReturnStmt), + Block(Rc), + Expression(Rc), + Function(Rc), + If(Rc), + Print(Rc), + Var(Rc), + While(Rc), + Return(Rc), +} + +impl PartialEq for Stmt { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Stmt::Block(a), Stmt::Block(b)) => Rc::ptr_eq(a, b), + (Stmt::Expression(a), Stmt::Expression(b)) => Rc::ptr_eq(a, b), + (Stmt::Function(a), Stmt::Function(b)) => Rc::ptr_eq(a, b), + (Stmt::If(a), Stmt::If(b)) => Rc::ptr_eq(a, b), + (Stmt::Print(a), Stmt::Print(b)) => Rc::ptr_eq(a, b), + (Stmt::Var(a), Stmt::Var(b)) => Rc::ptr_eq(a, b), + (Stmt::While(a), Stmt::While(b)) => Rc::ptr_eq(a, b), + (Stmt::Return(a), Stmt::Return(b)) => Rc::ptr_eq(a, b), + _ => false, + } + } +} + +impl Eq for Stmt{} + +use std::hash::{Hash, Hasher}; +impl Hash for Stmt { + fn hash(&self, hasher: &mut H) + where H: Hasher, + { match self { + Stmt::Block(a) => { hasher.write_usize(Rc::as_ptr(a) as usize); } + Stmt::Expression(a) => { hasher.write_usize(Rc::as_ptr(a) as usize); } + Stmt::Function(a) => { hasher.write_usize(Rc::as_ptr(a) as usize); } + Stmt::If(a) => { hasher.write_usize(Rc::as_ptr(a) as usize); } + Stmt::Print(a) => { hasher.write_usize(Rc::as_ptr(a) as usize); } + Stmt::Var(a) => { hasher.write_usize(Rc::as_ptr(a) as usize); } + Stmt::While(a) => { hasher.write_usize(Rc::as_ptr(a) as usize); } + Stmt::Return(a) => { hasher.write_usize(Rc::as_ptr(a) as usize); } + } + } } impl Stmt { - pub fn accept(&self, stmt_visitor: &dyn StmtVisitor) -> Result { + pub fn accept(&self, wrapper: &Rc, stmt_visitor: &dyn StmtVisitor) -> Result { match self { - Stmt::Block(v) => v.accept(stmt_visitor), - Stmt::Expression(v) => v.accept(stmt_visitor), - Stmt::Function(v) => v.accept(stmt_visitor), - Stmt::If(v) => v.accept(stmt_visitor), - Stmt::Print(v) => v.accept(stmt_visitor), - Stmt::Var(v) => v.accept(stmt_visitor), - Stmt::While(v) => v.accept(stmt_visitor), - Stmt::Return(v) => v.accept(stmt_visitor), + Stmt::Block(v) => stmt_visitor.visit_block_stmt(wrapper, &v), + Stmt::Expression(v) => stmt_visitor.visit_expression_stmt(wrapper, &v), + Stmt::Function(v) => stmt_visitor.visit_function_stmt(wrapper, &v), + Stmt::If(v) => stmt_visitor.visit_if_stmt(wrapper, &v), + Stmt::Print(v) => stmt_visitor.visit_print_stmt(wrapper, &v), + Stmt::Var(v) => stmt_visitor.visit_var_stmt(wrapper, &v), + Stmt::While(v) => stmt_visitor.visit_while_stmt(wrapper, &v), + Stmt::Return(v) => stmt_visitor.visit_return_stmt(wrapper, &v), } } } pub struct BlockStmt { - pub statements: Vec, + pub statements: Rc>>, } pub struct ExpressionStmt { - pub expression: Expr, + pub expression: Rc, } pub struct FunctionStmt { pub name: Token, pub params: Rc>, - pub body: Rc>, + pub body: Rc>>, } pub struct IfStmt { - pub condition: Expr, - pub then_branch: Box, - pub else_branch: Option>, + pub condition: Rc, + pub then_branch: Rc, + pub else_branch: Option>, } pub struct PrintStmt { - pub expression: Expr, + pub expression: Rc, } pub struct VarStmt { pub name: Token, - pub initializer: Option, + pub initializer: Option>, } pub struct WhileStmt { - pub condition: Expr, - pub body: Box, + pub condition: Rc, + pub body: Rc, } pub struct ReturnStmt { pub keyword: Token, - pub value: Option, + pub value: Option>, } pub trait StmtVisitor { - fn visit_block_stmt(&self, expr: &BlockStmt) -> Result; - fn visit_expression_stmt(&self, expr: &ExpressionStmt) -> Result; - fn visit_function_stmt(&self, expr: &FunctionStmt) -> Result; - fn visit_if_stmt(&self, expr: &IfStmt) -> Result; - fn visit_print_stmt(&self, expr: &PrintStmt) -> Result; - fn visit_var_stmt(&self, expr: &VarStmt) -> Result; - fn visit_while_stmt(&self, expr: &WhileStmt) -> Result; - fn visit_return_stmt(&self, expr: &ReturnStmt) -> Result; -} - -impl BlockStmt { - pub fn accept(&self, visitor: &dyn StmtVisitor) -> Result { - visitor.visit_block_stmt(self) - } -} - -impl ExpressionStmt { - pub fn accept(&self, visitor: &dyn StmtVisitor) -> Result { - visitor.visit_expression_stmt(self) - } -} - -impl FunctionStmt { - pub fn accept(&self, visitor: &dyn StmtVisitor) -> Result { - visitor.visit_function_stmt(self) - } -} - -impl IfStmt { - pub fn accept(&self, visitor: &dyn StmtVisitor) -> Result { - visitor.visit_if_stmt(self) - } -} - -impl PrintStmt { - pub fn accept(&self, visitor: &dyn StmtVisitor) -> Result { - visitor.visit_print_stmt(self) - } -} - -impl VarStmt { - pub fn accept(&self, visitor: &dyn StmtVisitor) -> Result { - visitor.visit_var_stmt(self) - } -} - -impl WhileStmt { - pub fn accept(&self, visitor: &dyn StmtVisitor) -> Result { - visitor.visit_while_stmt(self) - } -} - -impl ReturnStmt { - pub fn accept(&self, visitor: &dyn StmtVisitor) -> Result { - visitor.visit_return_stmt(self) - } + fn visit_block_stmt(&self, wrapper: &Rc, stmt: &BlockStmt) -> Result; + fn visit_expression_stmt(&self, wrapper: &Rc, stmt: &ExpressionStmt) -> Result; + fn visit_function_stmt(&self, wrapper: &Rc, stmt: &FunctionStmt) -> Result; + fn visit_if_stmt(&self, wrapper: &Rc, stmt: &IfStmt) -> Result; + fn visit_print_stmt(&self, wrapper: &Rc, stmt: &PrintStmt) -> Result; + fn visit_var_stmt(&self, wrapper: &Rc, stmt: &VarStmt) -> Result; + fn visit_while_stmt(&self, wrapper: &Rc, stmt: &WhileStmt) -> Result; + fn visit_return_stmt(&self, wrapper: &Rc, stmt: &ReturnStmt) -> Result; }