From 68fce9a05caa51071128d0dec13b816b9bf04248 Mon Sep 17 00:00:00 2001 From: aym-n Date: Wed, 10 Jan 2024 11:19:16 +0530 Subject: [PATCH] =?UTF-8?q?=F0=9F=A4=96=20added=20classes=20and=20methods?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- generate_ast/mod.rs | 3 ++ main.arc | 23 +++---------- src/expr.rs | 21 ++++++++++++ src/instance.rs | 38 ++++++++++++++++++++++ src/interpreter.rs | 79 +++++++++++++++++++++++++++++++++++++++------ src/main.rs | 2 ++ src/parser.rs | 36 +++++++++++++++++++-- src/resolver.rs | 29 +++++++++++++++++ src/stmt.rs | 10 ++++++ src/tokens.rs | 47 +++++++++++++++++++++++++++ 10 files changed, 258 insertions(+), 30 deletions(-) create mode 100644 src/instance.rs diff --git a/generate_ast/mod.rs b/generate_ast/mod.rs index 50b0c9b..c3d80fd 100644 --- a/generate_ast/mod.rs +++ b/generate_ast/mod.rs @@ -17,9 +17,11 @@ pub fn generate_ast(output_dir: &str) -> io::Result<()> { "Assign : Token name, Rc value", "Binary : Rc left, Token operator, Rc right", "Call : Rc callee, Token paren, Vec> arguments", + "Get : Rc object, Token name", "Grouping : Rc expression", "Literal : Option value", "Logical : Rc left, Token operator, Rc right", + "Set : Rc object, Token name, Rc value", "Unary : Token operator, Rc right", "Variable : Token name", ], @@ -30,6 +32,7 @@ pub fn generate_ast(output_dir: &str) -> io::Result<()> { &["errors", "expr", "tokens", "rc"], &[ "Block : Rc>> statements", + "Class : Token name, Rc>> methods", "Expression : Rc expression", "Function : Token name, Rc> params, Rc>> body", "If : Rc condition, Rc then_branch, Option> else_branch", diff --git a/main.arc b/main.arc index eb77f8d..2d7df01 100644 --- a/main.arc +++ b/main.arc @@ -1,20 +1,7 @@ - -var a = "global a"; -var b = "global b"; -var c = "global c"; -{ - var a = "outer a"; - var b = "outer b"; - { - var a = "inner a"; - print a; - print b; - print c; +class Bacon { + eat() { + print "Crunch crunch crunch!"; } - print a; - print b; - print c; } -print a; -print b; -print c; \ No newline at end of file + +Bacon().eat(); \ No newline at end of file diff --git a/src/expr.rs b/src/expr.rs index aacdc8d..7baeee5 100644 --- a/src/expr.rs +++ b/src/expr.rs @@ -6,9 +6,11 @@ pub enum Expr { Assign(Rc), Binary(Rc), Call(Rc), + Get(Rc), Grouping(Rc), Literal(Rc), Logical(Rc), + Set(Rc), Unary(Rc), Variable(Rc), } @@ -19,9 +21,11 @@ impl PartialEq for Expr { (Expr::Assign(a), Expr::Assign(b)) => Rc::ptr_eq(a, b), (Expr::Binary(a), Expr::Binary(b)) => Rc::ptr_eq(a, b), (Expr::Call(a), Expr::Call(b)) => Rc::ptr_eq(a, b), + (Expr::Get(a), Expr::Get(b)) => Rc::ptr_eq(a, b), (Expr::Grouping(a), Expr::Grouping(b)) => Rc::ptr_eq(a, b), (Expr::Literal(a), Expr::Literal(b)) => Rc::ptr_eq(a, b), (Expr::Logical(a), Expr::Logical(b)) => Rc::ptr_eq(a, b), + (Expr::Set(a), Expr::Set(b)) => Rc::ptr_eq(a, b), (Expr::Unary(a), Expr::Unary(b)) => Rc::ptr_eq(a, b), (Expr::Variable(a), Expr::Variable(b)) => Rc::ptr_eq(a, b), _ => false, @@ -39,9 +43,11 @@ impl Hash for Expr { Expr::Assign(a) => { hasher.write_usize(Rc::as_ptr(a) as usize); } Expr::Binary(a) => { hasher.write_usize(Rc::as_ptr(a) as usize); } Expr::Call(a) => { hasher.write_usize(Rc::as_ptr(a) as usize); } + Expr::Get(a) => { hasher.write_usize(Rc::as_ptr(a) as usize); } Expr::Grouping(a) => { hasher.write_usize(Rc::as_ptr(a) as usize); } Expr::Literal(a) => { hasher.write_usize(Rc::as_ptr(a) as usize); } Expr::Logical(a) => { hasher.write_usize(Rc::as_ptr(a) as usize); } + Expr::Set(a) => { hasher.write_usize(Rc::as_ptr(a) as usize); } Expr::Unary(a) => { hasher.write_usize(Rc::as_ptr(a) as usize); } Expr::Variable(a) => { hasher.write_usize(Rc::as_ptr(a) as usize); } } @@ -54,9 +60,11 @@ impl Expr { Expr::Assign(v) => expr_visitor.visit_assign_expr(wrapper, v), Expr::Binary(v) => expr_visitor.visit_binary_expr(wrapper, v), Expr::Call(v) => expr_visitor.visit_call_expr(wrapper, v), + Expr::Get(v) => expr_visitor.visit_get_expr(wrapper, v), Expr::Grouping(v) => expr_visitor.visit_grouping_expr(wrapper, v), Expr::Literal(v) => expr_visitor.visit_literal_expr(wrapper, v), Expr::Logical(v) => expr_visitor.visit_logical_expr(wrapper, v), + Expr::Set(v) => expr_visitor.visit_set_expr(wrapper, v), Expr::Unary(v) => expr_visitor.visit_unary_expr(wrapper, v), Expr::Variable(v) => expr_visitor.visit_variable_expr(wrapper, v), } @@ -80,6 +88,11 @@ pub struct CallExpr { pub arguments: Vec>, } +pub struct GetExpr { + pub object: Rc, + pub name: Token, +} + pub struct GroupingExpr { pub expression: Rc, } @@ -94,6 +107,12 @@ pub struct LogicalExpr { pub right: Rc, } +pub struct SetExpr { + pub object: Rc, + pub name: Token, + pub value: Rc, +} + pub struct UnaryExpr { pub operator: Token, pub right: Rc, @@ -107,9 +126,11 @@ pub trait ExprVisitor { fn visit_assign_expr(&self, wrapper: Rc, expr: &AssignExpr) -> Result; fn visit_binary_expr(&self, wrapper: Rc, expr: &BinaryExpr) -> Result; fn visit_call_expr(&self, wrapper: Rc, expr: &CallExpr) -> Result; + fn visit_get_expr(&self, wrapper: Rc, expr: &GetExpr) -> Result; fn visit_grouping_expr(&self, wrapper: Rc, expr: &GroupingExpr) -> Result; fn visit_literal_expr(&self, wrapper: Rc, expr: &LiteralExpr) -> Result; fn visit_logical_expr(&self, wrapper: Rc, expr: &LogicalExpr) -> Result; + fn visit_set_expr(&self, wrapper: Rc, expr: &SetExpr) -> Result; fn visit_unary_expr(&self, wrapper: Rc, expr: &UnaryExpr) -> Result; fn visit_variable_expr(&self, wrapper: Rc, expr: &VariableExpr) -> Result; } diff --git a/src/instance.rs b/src/instance.rs new file mode 100644 index 0000000..2a43977 --- /dev/null +++ b/src/instance.rs @@ -0,0 +1,38 @@ +use crate::errors::*; +use crate::tokens::*; +use std::cell::RefCell; +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use std::rc::Rc; + +#[derive(Debug, PartialEq, Clone)] +pub struct InstanceStruct { + pub class: Rc, + fields: RefCell>, +} + +impl InstanceStruct { + pub fn new(class: Rc) -> Self { + InstanceStruct { + class: Rc::clone(&class), + fields: RefCell::new(HashMap::new()), + } + } + + pub fn get(&self, name: &Token) -> Result { + if let Entry::Occupied(entry) = self.fields.borrow_mut().entry(name.lexeme.clone()) { + Ok(entry.get().clone()) + } else if let Some(method) = self.class.find_method(name.lexeme.clone()) { + Ok(method) + } else { + Err(Error::runtime_error( + name, + &format!("Undefined property '{}'.", name.lexeme), + )) + } + } + + pub fn set(&self, name: &Token, value: Object) { + self.fields.borrow_mut().insert(name.lexeme.clone(), value); + } +} diff --git a/src/interpreter.rs b/src/interpreter.rs index 04d9316..7e44f94 100644 --- a/src/interpreter.rs +++ b/src/interpreter.rs @@ -18,6 +18,36 @@ pub struct Interpreter { } impl StmtVisitor<()> for Interpreter { + fn visit_class_stmt(&self, _: Rc, stmt: &ClassStmt) -> Result<(), Error> { + self.environment + .borrow() + .borrow_mut() + .define(stmt.name.lexeme.clone(), Object::Nil); + + let mut methods = HashMap::new(); + for method in stmt.methods.deref() { + if let Stmt::Function(func) = method.deref() { + let function = Object::Function(Callable { + func: Rc::new(Function::new(func, self.environment.borrow().deref())), + }); + methods.insert(func.name.lexeme.clone(), function); + } else { + return Err(Error::runtime_error( + &stmt.name, + "Class method did not resolve to a function.", + )); + }; + } + + let cls = Object::Class(Rc::new(ClassStruct::new(stmt.name.lexeme.clone(), methods))); + + self.environment + .borrow() + .borrow_mut() + .assign(&stmt.name, cls.clone())?; + Ok(()) + } + fn visit_return_stmt(&self, _: Rc, stmt: &ReturnStmt) -> Result<(), Error> { if let Some(value) = stmt.value.clone() { let value = self.evaluate(value)?; @@ -56,11 +86,7 @@ impl StmtVisitor<()> for Interpreter { Ok(()) } - fn visit_expression_stmt( - &self, - _: Rc, - stmt: &ExpressionStmt, - ) -> Result<(), Error> { + fn visit_expression_stmt(&self, _: Rc, stmt: &ExpressionStmt) -> Result<(), Error> { self.evaluate(stmt.expression.clone())?; Ok(()) } @@ -88,6 +114,31 @@ impl StmtVisitor<()> for Interpreter { } impl ExprVisitor for Interpreter { + fn visit_set_expr(&self, wrapper: Rc, expr: &SetExpr) -> Result { + let object = self.evaluate(expr.object.clone())?; + + if let Object::Instance(inst) = object { + let value = self.evaluate(expr.value.clone())?; + inst.set(&expr.name, value.clone()); + return Ok(value); + } + + Err(Error::runtime_error( + &expr.name, + "Only instances have fields.", + )) + } + + fn visit_get_expr(&self, wrapper: Rc, expr: &GetExpr) -> Result { + let object = self.evaluate(expr.object.clone())?; + if let Object::Instance(inst) = object { + return inst.get(&expr.name); + } + Err(Error::runtime_error( + &expr.name, + "Only instances have properties.", + )) + } fn visit_logical_expr(&self, _: Rc, expr: &LogicalExpr) -> Result { let left = self.evaluate(expr.left.clone())?; @@ -124,11 +175,7 @@ impl ExprVisitor for Interpreter { Ok(expr.value.clone().unwrap()) } - fn visit_grouping_expr( - &self, - _: Rc, - expr: &GroupingExpr, - ) -> Result { + fn visit_grouping_expr(&self, _: Rc, expr: &GroupingExpr) -> Result { Ok(self.evaluate(expr.expression.clone())?) } @@ -247,6 +294,18 @@ impl ExprVisitor for Interpreter { )); } function.func.call(self, &arguments) + } else if let Object::Class(cls) = callee { + if arguments.len() != cls.arity() { + return Err(Error::runtime_error( + &expr.paren, + &format!( + "Expected {} arguments but got {}", + cls.arity(), + arguments.len() + ), + )); + } + cls.instantiate(self, arguments, Rc::clone(&cls)) } else { return Err(Error::runtime_error( &expr.paren, diff --git a/src/main.rs b/src/main.rs index 149d0f9..f6f4eb1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -35,6 +35,8 @@ use resolver::*; use std::rc::Rc; +mod instance; + fn eval(source: &str) -> Result<(), Error> { let lexer = Lexer::new(source.to_string()); let mut tokens: Vec = lexer.collect(); diff --git a/src/parser.rs b/src/parser.rs index cd0faf8..7461796 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -27,7 +27,9 @@ impl Parser { } fn declaration(&mut self) -> Result, Error> { - let result = if self.match_token(vec![TokenKind::Fn]) { + let result = if self.match_token(vec![TokenKind::Class]){ + self.class_declaration() + }else if self.match_token(vec![TokenKind::Fn]) { self.function("function") } else if self.match_token(vec![TokenKind::Var]) { self.var_declaration() @@ -42,6 +44,23 @@ impl Parser { result } + fn class_declaration(&mut self) -> Result, Error> { + let name = self.consume(TokenKind::Identifier, "Expect class name.")?; + + let mut methods = Vec::new(); + + self.consume(TokenKind::LeftBrace, "Expect '{' before class body.")?; + while !self.check(TokenKind::RightBrace) && !self.is_at_end() { + methods.push(self.function("method")?); + } + self.consume(TokenKind::RightBrace, "Expect '}' after class body.")?; + + Ok(Rc::new(Stmt::Class(Rc::new(ClassStmt { + name, + methods: Rc::new(methods), + })))) + } + fn statement(&mut self) -> Result, Error> { if self.match_token(vec![TokenKind::For]) { return self.for_statement(); @@ -253,6 +272,13 @@ impl Parser { value: Rc::new(value), }))); } + Expr::Get(g) => { + return Ok(Expr::Set(Rc::new(SetExpr { + object: Rc::clone(&g.object), + name: g.name.clone(), + value: Rc::new(value), + }))); + } _ => { return Err(Error::parse_error( &equals, @@ -412,7 +438,13 @@ impl Parser { loop { if self.match_token(vec![TokenKind::LeftParen]) { expr = self.finish_call(expr)?; - } else { + }else if self.match_token(vec![TokenKind::Dot]) { + let name = self.consume(TokenKind::Identifier, "Expect property name after '.'.")?; + expr = Expr::Get(Rc::new(GetExpr { + object: Rc::new(expr), + name, + })); + }else { break; } } diff --git a/src/resolver.rs b/src/resolver.rs index 5d5a914..558ee9e 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -20,9 +20,26 @@ pub struct Resolver<'a>{ enum FunctionType { None, Function, + Method, } impl<'a> StmtVisitor<()> for Resolver<'a>{ + fn visit_class_stmt(&self , _: Rc, stmt: &ClassStmt) -> Result<(), Error> { + self.declare(&stmt.name); + self.define(&stmt.name); + + for method in stmt.methods.deref() { + let declaration = FunctionType::Method; + if let Stmt::Function(method) = method.deref() { + self.resolve_function(method, declaration); + }else{ + return Err(Error::runtime_error(&stmt.name, "Class method did not resolve to a function.")); + } + } + + Ok(()) + } + fn visit_return_stmt(&self, _: Rc ,stmt: &ReturnStmt) -> Result<(), Error> { if *self.current_function.borrow() == FunctionType::None { return Err(Error::runtime_error(&stmt.keyword, "Cannot return from top-level code.")); @@ -81,6 +98,18 @@ impl<'a> StmtVisitor<()> for Resolver<'a>{ } impl<'a> ExprVisitor<()> for Resolver<'a>{ + + fn visit_set_expr(&self, _: Rc, expr: &SetExpr) -> Result<(), Error> { + self.resolve_expr(expr.value.clone()); + self.resolve_expr(expr.object.clone()); + Ok(()) + } + + fn visit_get_expr(&self, _: Rc, expr: &GetExpr) -> Result<(), Error> { + self.resolve_expr(expr.object.clone()); + Ok(()) + } + fn visit_call_expr(&self, _: Rc, expr: &CallExpr) -> Result<(), Error> { self.resolve_expr(expr.callee.clone()); diff --git a/src/stmt.rs b/src/stmt.rs index c1bdc25..0b622dd 100644 --- a/src/stmt.rs +++ b/src/stmt.rs @@ -5,6 +5,7 @@ use std::rc::Rc; pub enum Stmt { Block(Rc), + Class(Rc), Expression(Rc), Function(Rc), If(Rc), @@ -18,6 +19,7 @@ impl PartialEq for Stmt { fn eq(&self, other: &Self) -> bool { match (self, other) { (Stmt::Block(a), Stmt::Block(b)) => Rc::ptr_eq(a, b), + (Stmt::Class(a), Stmt::Class(b)) => Rc::ptr_eq(a, b), (Stmt::Expression(a), Stmt::Expression(b)) => Rc::ptr_eq(a, b), (Stmt::Function(a), Stmt::Function(b)) => Rc::ptr_eq(a, b), (Stmt::If(a), Stmt::If(b)) => Rc::ptr_eq(a, b), @@ -38,6 +40,7 @@ impl Hash for Stmt { where H: Hasher, { match self { Stmt::Block(a) => { hasher.write_usize(Rc::as_ptr(a) as usize); } + Stmt::Class(a) => { hasher.write_usize(Rc::as_ptr(a) as usize); } Stmt::Expression(a) => { hasher.write_usize(Rc::as_ptr(a) as usize); } Stmt::Function(a) => { hasher.write_usize(Rc::as_ptr(a) as usize); } Stmt::If(a) => { hasher.write_usize(Rc::as_ptr(a) as usize); } @@ -53,6 +56,7 @@ impl Stmt { pub fn accept(&self, wrapper: Rc, stmt_visitor: &dyn StmtVisitor) -> Result { match self { Stmt::Block(v) => stmt_visitor.visit_block_stmt(wrapper, v), + Stmt::Class(v) => stmt_visitor.visit_class_stmt(wrapper, v), Stmt::Expression(v) => stmt_visitor.visit_expression_stmt(wrapper, v), Stmt::Function(v) => stmt_visitor.visit_function_stmt(wrapper, v), Stmt::If(v) => stmt_visitor.visit_if_stmt(wrapper, v), @@ -68,6 +72,11 @@ pub struct BlockStmt { pub statements: Rc>>, } +pub struct ClassStmt { + pub name: Token, + pub methods: Rc>>, +} + pub struct ExpressionStmt { pub expression: Rc, } @@ -105,6 +114,7 @@ pub struct WhileStmt { pub trait StmtVisitor { fn visit_block_stmt(&self, wrapper: Rc, stmt: &BlockStmt) -> Result; + fn visit_class_stmt(&self, wrapper: Rc, stmt: &ClassStmt) -> Result; fn visit_expression_stmt(&self, wrapper: Rc, stmt: &ExpressionStmt) -> Result; fn visit_function_stmt(&self, wrapper: Rc, stmt: &FunctionStmt) -> Result; fn visit_if_stmt(&self, wrapper: Rc, stmt: &IfStmt) -> Result; diff --git a/src/tokens.rs b/src/tokens.rs index 11650a6..2f9bd1e 100644 --- a/src/tokens.rs +++ b/src/tokens.rs @@ -1,5 +1,11 @@ use std::fmt; use crate::callable::Callable; +use crate::errors::*; +use crate::interpreter::Interpreter; +use crate::callable::*; +use std::rc::Rc; +use crate::instance::*; +use std::collections::HashMap; #[derive(Debug, Default, PartialEq, Clone)] pub enum TokenKind { @@ -64,6 +70,8 @@ pub enum Object { Str(String), Bool(bool), Function(Callable), + Class(Rc), + Instance(Rc), Nil, ArithmeticError, } @@ -77,6 +85,8 @@ impl fmt::Display for Object { Object::Bool(x) => write!(f, "{x}"), Object::ArithmeticError => write!(f, "Arithmetic Error"), Object::Function(_) => write!(f, ""), + Object::Class(c) => write!(f, "", c.name), + Object::Instance(i) => write!(f, "", i.class.name), } } } @@ -88,6 +98,43 @@ impl fmt::Display for Token { } } +#[derive(Debug, PartialEq, Clone)] +pub struct ClassStruct { + pub name: String, + methods: HashMap, +} + +impl ClassStruct { + pub fn new(name: String, methods: HashMap) -> Self { + ClassStruct { + name, + methods, + } + } + + pub fn instantiate(&self, _interpreter: &Interpreter, _arguments: Vec, cls: Rc) -> Result { + Ok(Object::Instance(Rc::new(InstanceStruct::new(cls)))) + } + + pub fn find_method(&self, name: String) -> Option { + self.methods.get(&name).cloned() + } +} + +impl CallableTrait for ClassStruct { + fn call(&self, interpreter: &Interpreter, arguments: &Vec) -> Result { + Ok(Object::Num(237.0)) + } + + fn arity(&self) -> usize { + 0 + } + + fn stringify(&self) -> String { + self.name.clone() + } +} + #[derive(Debug, PartialEq, Clone)] pub struct Token { pub kind: TokenKind,