diff --git a/Expr.hpp b/Expr.hpp index 50402fd..6868df2 100644 --- a/Expr.hpp +++ b/Expr.hpp @@ -38,6 +38,11 @@ class Logical; template class Call; +template +class Get; + +template +class Set; template class Visitor { @@ -51,7 +56,8 @@ class Visitor { virtual R visitVariableExpr(shared_ptr> expr) = 0; virtual R visitLogicalExpr(shared_ptr> expr) = 0; virtual R visitCallExpr(shared_ptr> expr) = 0; - // virtual string visitSetExpr(Set& expr) = 0; + virtual R visitGetExpr(shared_ptr> expr) = 0; + virtual R visitSetExpr(shared_ptr> expr) = 0; // virtual string visitSuperExpr(Super& expr) = 0; // virtual string visitThisExpr(This& expr) = 0; }; @@ -194,28 +200,36 @@ class Call: vector>> arguments; }; -// template -// class Get: public Expr { -// public: -// Get(shared_ptr> object, Token name); -// R accept(shared_ptr> visitor) override { -// return visitor->visitGetExpr(*this); -// } -// Token name; -// shared_ptr> object; -// }; +template +class Get: + public Expr, + public std::enable_shared_from_this> +{ + public: + Get(shared_ptr> object_, Token name_): + object(object_), name(name_) { } + R accept(shared_ptr> visitor) override { + return visitor->visitGetExpr(this->shared_from_this()); + } + shared_ptr> object; + Token name; +}; -// template -// class Set: public Expr { -// public: -// Set(shared_ptr> object, Token name, shared_ptr> value); -// R accept(shared_ptr> visitor) override { -// return visitor->visitSetExpr(*this); -// }; -// shared_ptr> object; -// Token name; -// shared_ptr> value; -// }; +template +class Set: + public Expr, + public std::enable_shared_from_this> +{ + public: + Set(shared_ptr> object_, Token name_, shared_ptr> value_): + object(object_), name(name_), value(value_) { } + R accept(shared_ptr> visitor) override { + return visitor->visitSetExpr(this->shared_from_this()); + }; + shared_ptr> object; + Token name; + shared_ptr> value; +}; // template // class Super: public Expr { diff --git a/Interpreter.cpp b/Interpreter.cpp index f8798c5..59bd116 100644 --- a/Interpreter.cpp +++ b/Interpreter.cpp @@ -13,9 +13,10 @@ #include "./RuntimeError.hpp" #include "./Environment.hpp" #include "./LoxCallable.hpp" +#include "./LoxInstance.hpp" +#include "./LoxClass.hpp" #include "./LoxFunction.hpp" #include "./ReturnError.hpp" -#include "./LoxClass.hpp" #include "./lox.hpp" using std::stod; @@ -168,7 +169,8 @@ Object Interpreter::visitVariableExpr(shared_ptr> expr) { return lookUpVariable(expr->name, expr); } -Object Interpreter::lookUpVariable(Token name, shared_ptr> expr) { +Object Interpreter::lookUpVariable( + Token name, shared_ptr> expr) { auto distance = locals.find(expr); if (distance != locals.end()) { return environment->getAt(distance->second, name.lexeme); @@ -185,7 +187,8 @@ Object Interpreter::visitCallExpr(shared_ptr> expr) { arguments.push_back(evaluate(argument)); } - if (callee.type != Object::Object_fun && callee.type != Object::Object_class) { + if (callee.type != Object::Object_fun && + callee.type != Object::Object_class) { throw RuntimeError(expr->paren, "Can only call functions and classes."); } @@ -205,6 +208,29 @@ Object Interpreter::visitCallExpr(shared_ptr> expr) { return callable->call(shared_from_this(), arguments); } +Object Interpreter::visitGetExpr(shared_ptr> expr) { + Object object = evaluate(expr->object); + if (object.type == Object::Object_instance) { + return (object.instance)->get(expr->name); + } + + throw RuntimeError(expr->name, + "Only instances have properties."); +} + +Object Interpreter::visitSetExpr(shared_ptr> expr) { + Object object = evaluate(expr->object); + + if (object.type != Object::Object_instance) { + throw RuntimeError(expr->name, "Only instances have fields."); + } + + Object value = evaluate(expr->value); + value.instance->set(expr->name, value); + return value; +} + + void Interpreter::visitExpressionStmt(const Expression& stmt) { evaluate(stmt.expression); } @@ -235,7 +261,14 @@ void Interpreter::visitBlockStmt(const Block& stmt) { void Interpreter::visitClassStmt(const Class& stmt) { environment->define(stmt.name.lexeme, Object::make_nil_obj()); - auto klass = shared_ptr(new LoxClass(stmt.name.lexeme)); + + map> methods; + for (auto method : stmt.methods) { + shared_ptr function(new LoxFunction(method, environment)); + methods[method->name.lexeme] = function; + } + + auto klass = shared_ptr(new LoxClass(stmt.name.lexeme, methods)); environment->assign(stmt.name, Object::make_class_obj(klass)); } @@ -254,10 +287,10 @@ void Interpreter::visitIfStmt(const If& stmt) { } } -void Interpreter::visitFunctionStmt(const Function& stmt) { +void Interpreter::visitFunctionStmt(shared_ptr stmt) { shared_ptr function(new LoxFunction(stmt, environment)); Object obj = Object::make_fun_obj(function); - environment->define(stmt.name.lexeme, obj); + environment->define(stmt->name.lexeme, obj); } Object Interpreter::evaluate(shared_ptr> expr) { diff --git a/Interpreter.hpp b/Interpreter.hpp index a282aee..7085fae 100644 --- a/Interpreter.hpp +++ b/Interpreter.hpp @@ -34,6 +34,8 @@ class Interpreter: Object visitVariableExpr(shared_ptr> expr); Object visitLogicalExpr(shared_ptr> expr); Object visitCallExpr(shared_ptr> expr); + Object visitGetExpr(shared_ptr> expr); + Object visitSetExpr(shared_ptr> expr); void visitExpressionStmt(const Expression& stmt); void visitPrintStmt(const Print& stmt); void visitVarStmt(const Var& stmt); @@ -41,7 +43,7 @@ class Interpreter: void visitClassStmt(const Class& stmt); void visitIfStmt(const If& stmt); void visitWhileStmt(const While& stmt); - void visitFunctionStmt(const Function& stmt); + void visitFunctionStmt(shared_ptr stmt); void visitReturnStmt(const Return& stmt); void executeBlock( vector> statements, diff --git a/LoxClass.cpp b/LoxClass.cpp index 604ce2d..636709e 100644 --- a/LoxClass.cpp +++ b/LoxClass.cpp @@ -7,6 +7,7 @@ #include "./LoxClass.hpp" #include "./LoxInstance.hpp" #include "./Interpreter.hpp" +#include "./LoxFunction.hpp" #include "./Token.hpp" using std::string; @@ -14,7 +15,8 @@ using std::shared_ptr; using std::vector; -LoxClass::LoxClass(string name_): name(name_) {} +LoxClass::LoxClass(string name_, map> methods_): + name(name_), methods(methods_) {} Object LoxClass::call( shared_ptr interpreter, @@ -29,3 +31,11 @@ int LoxClass::arity() { string LoxClass::toString() { return name; } + +shared_ptr LoxClass::findMethod(string name) { + auto searched = methods.find(name); + if (searched != methods.end()) { + return searched->second; + } + return nullptr; +} \ No newline at end of file diff --git a/LoxClass.hpp b/LoxClass.hpp index 9dfa1bc..d1e144f 100644 --- a/LoxClass.hpp +++ b/LoxClass.hpp @@ -6,18 +6,26 @@ #include #include #include +#include #include "./LoxCallable.hpp" #include "./Interpreter.hpp" +#include "./LoxFunction.hpp" #include "./Token.hpp" using std::string; using std::shared_ptr; using std::vector; +using std::map; class LoxClass: public LoxCallable { public: string name; - explicit LoxClass(string name_); + map> methods; + shared_ptr findMethod(string name); + explicit LoxClass( + string name_, + map> methods_); + Object call( shared_ptr interpreter, vector arguments); diff --git a/LoxFunction.cpp b/LoxFunction.cpp index 16db7aa..7f73d9c 100644 --- a/LoxFunction.cpp +++ b/LoxFunction.cpp @@ -16,12 +16,12 @@ using std::vector; using std::string; LoxFunction::LoxFunction( - Function declaration_, + shared_ptr declaration_, shared_ptr closure_ ): declaration(declaration_), closure(closure_) {} int LoxFunction::arity() { - return declaration.params.size(); + return declaration->params.size(); } Object LoxFunction::call( @@ -30,14 +30,14 @@ Object LoxFunction::call( ) { shared_ptr environment(new Environment(closure)); - for (int i = 0; i < declaration.params.size(); i++) { + for (int i = 0; i < declaration->params.size(); i++) { environment->define( - declaration.params[i].lexeme, + declaration->params[i].lexeme, arguments[i]); } try { - interpreter->executeBlock(declaration.body, environment); + interpreter->executeBlock(declaration->body, environment); } catch (ReturnError returnValue) { return returnValue.value; } @@ -45,5 +45,5 @@ Object LoxFunction::call( } string LoxFunction::toString() { - return ""; + return "name.lexeme + ">"; } diff --git a/LoxFunction.hpp b/LoxFunction.hpp index 391548b..d3551fc 100644 --- a/LoxFunction.hpp +++ b/LoxFunction.hpp @@ -18,9 +18,12 @@ using std::string; class LoxFunction: public LoxCallable { public: - Function declaration; + shared_ptr declaration; shared_ptr closure; - explicit LoxFunction(Function declaration_, shared_ptr closure_); + explicit LoxFunction( + shared_ptr declaration_, + shared_ptr closure_ + ); int arity(); diff --git a/LoxInstance.cpp b/LoxInstance.cpp index c2e0d07..663a210 100644 --- a/LoxInstance.cpp +++ b/LoxInstance.cpp @@ -4,6 +4,8 @@ #include "./LoxInstance.hpp" #include "./LoxClass.hpp" +#include "./RuntimeError.hpp" +#include "./Token.hpp" using std::string; @@ -12,3 +14,22 @@ LoxInstance::LoxInstance(LoxClass klass_): klass(klass_) {} string LoxInstance::toString() { return klass.name + " instance"; } + +Object LoxInstance::get(Token name) { + auto searched = fields.find(name.lexeme); + if (searched != fields.end()) { + return searched->second; + } + + shared_ptr method = klass.findMethod(name.lexeme); + if (method != nullptr) { + return Object::make_fun_obj(method); + } + + throw RuntimeError(name, + "Undefined property '" + name.lexeme + "'."); +} + +void LoxInstance::set(Token name, Object value) { + fields[name.lexeme] = value; +} \ No newline at end of file diff --git a/LoxInstance.hpp b/LoxInstance.hpp index 749dcc6..7f836de 100644 --- a/LoxInstance.hpp +++ b/LoxInstance.hpp @@ -4,13 +4,19 @@ #define LOXINSTANCE_HPP_ #include +#include #include "./LoxClass.hpp" +#include "./Token.hpp" using std::string; +using std::map; class LoxInstance { public: LoxClass klass; + map fields; + Object get(Token name); + void set(Token name, Object value); explicit LoxInstance(LoxClass klass_); string toString(); }; diff --git a/Makefile b/Makefile index 081b0e2..2f8b49f 100644 --- a/Makefile +++ b/Makefile @@ -3,10 +3,10 @@ CXX = g++ all: main +Token.o: Token.cpp Token.hpp LoxInstance.o: LoxInstance.cpp LoxInstance.hpp LoxClass.o: LoxClass.cpp LoxClass.hpp Resolver.o: Resolver.cpp Resolver.hpp -Token.o: Token.cpp Token.hpp LoxFunction.o: LoxFunction.cpp LoxFunction.hpp Environment.o: Environment.cpp Environment.hpp Interpreter.o: Interpreter.cpp Interpreter.hpp diff --git a/Parser.cpp b/Parser.cpp index 083e512..6ffa81b 100644 --- a/Parser.cpp +++ b/Parser.cpp @@ -7,15 +7,15 @@ comparison → addition ( ( ">" | ">=" | "<" | "<=" ) addition )* ; addition → multiplication ( ( "-" | "+" ) multiplication )* ; multiplication → unary ( ( "/" | "*" ) unary )* ; unary → ( "!" | "-" ) unary | call ; -call → primary ( "(" arguments? ")" )* ; +call → primary ( "(" arguments? ")" | "." IDENTIFIER )* ; arguments → expression ( "," expression )* ; primary → "false" | "true" | "nil" | NUMBER | STRING | "(" expression ")" | IDENTIFIER ; expression → assignment ; -assignment → identifier "=" assignment - | logic_or ; +assignment → ( call "." )? IDENTIFIER "=" assignment + | logic_or; logic_or → logic_and ( "or" logic_and )* ; logic_and → equality ( "and" equality )* ; */ @@ -49,9 +49,15 @@ shared_ptr> Parser::assignment() { shared_ptr> value = assignment(); auto variable = dynamic_cast*>(expr.get()); + auto get = dynamic_cast*>(expr.get()); if (variable != nullptr) { Token name = variable->name; - return shared_ptr>(new Assign(name, value)); + return shared_ptr>( + new Assign(name, value)); + } + if (get != nullptr) { + return shared_ptr>( + new Set(get->object, get->name, value)); } error(equals, "Invalid assignment target."); } @@ -193,7 +199,7 @@ shared_ptr Parser::expressionStatement() { return expression; } -shared_ptr Parser::function(string kind) { +shared_ptr Parser::function(string kind) { Token name = consume(IDENTIFIER, "Expect " + kind + " name."); consume(LEFT_PAREN, "Expect '(' after " + kind + " name."); vector parameters; @@ -208,7 +214,7 @@ shared_ptr Parser::function(string kind) { consume(RIGHT_PAREN, "Expect ')' after parameters."); consume(LEFT_BRACE, "Expect '{' before " + kind + " body."); vector> body = block(); - return shared_ptr(new Function(name, parameters, body)); + return shared_ptr(new Function(name, parameters, body)); } vector> Parser::block() { @@ -297,6 +303,9 @@ shared_ptr> Parser::call() { while (true) { if (match({ LEFT_PAREN })) { expr = finishCall(expr); + } else if (match({ DOT })) { + Token name = consume(IDENTIFIER, "Expect property name after '.'."); + expr = shared_ptr>(new Get(expr, name)); } else { break; } @@ -437,7 +446,7 @@ shared_ptr Parser::classDeclaration() { Token name = consume(IDENTIFIER, "Expect class name."); consume(LEFT_BRACE, "Expect '{' before class body."); - vector> methods; + vector> methods; while (!check(RIGHT_BRACE) && !isAtEnd()) { methods.push_back(function("method")); } diff --git a/Parser.hpp b/Parser.hpp index 8e9d800..4360e49 100644 --- a/Parser.hpp +++ b/Parser.hpp @@ -54,7 +54,7 @@ class Parser { shared_ptr printStatement(); shared_ptr returnStatement(); shared_ptr expressionStatement(); - shared_ptr function(string kind); + shared_ptr function(string kind); shared_ptr declaration(); shared_ptr classDeclaration(); shared_ptr varDeclaration(); diff --git a/Resolver.cpp b/Resolver.cpp index 2bf4222..9820eb5 100644 --- a/Resolver.cpp +++ b/Resolver.cpp @@ -72,6 +72,16 @@ Object Resolver::visitCallExpr(shared_ptr> expr) { } return Object::make_nil_obj(); } +Object Resolver::visitGetExpr(shared_ptr> expr) { + resolve(expr->object); + return Object::make_nil_obj(); +} + +Object Resolver::visitSetExpr(shared_ptr> expr) { + resolve(expr->value); + resolve(expr->object); + return Object::make_nil_obj(); +} void Resolver::visitExpressionStmt(const Expression& stmt) { resolve(stmt.expression); @@ -98,6 +108,11 @@ void Resolver::visitBlockStmt(const Block& stmt) { void Resolver::visitClassStmt(const Class& stmt) { declare(stmt.name); define(stmt.name); + + for (auto method : stmt.methods) { + FunctionType declaration = METHOD; + resolveFunction(method, declaration); + } } void Resolver::visitIfStmt(const If& stmt) { @@ -113,9 +128,9 @@ void Resolver::visitWhileStmt(const While& stmt) { resolve(stmt.body); } -void Resolver::visitFunctionStmt(const Function& stmt) { - declare(stmt.name); - define(stmt.name); +void Resolver::visitFunctionStmt(shared_ptr stmt) { + declare(stmt->name); + define(stmt->name); resolveFunction(stmt, FUNCTION); } @@ -185,18 +200,18 @@ void Resolver::resolveLocal(shared_ptr> expr, Token name) { } void Resolver::resolveFunction( - Function function, + shared_ptr function, FunctionType type ) { FunctionType enclosingFunction = currentFunction; currentFunction = type; beginScope(); - for (Token param : function.params) { + for (Token param : function->params) { declare(param); define(param); } - resolve(function.body); + resolve(function->body); endScope(); currentFunction = enclosingFunction; } \ No newline at end of file diff --git a/Resolver.hpp b/Resolver.hpp index 02de04b..550a7f0 100644 --- a/Resolver.hpp +++ b/Resolver.hpp @@ -33,6 +33,8 @@ class Resolver: Object visitVariableExpr(shared_ptr> expr); Object visitLogicalExpr(shared_ptr> expr); Object visitCallExpr(shared_ptr> expr); + Object visitGetExpr(shared_ptr> expr); + Object visitSetExpr(shared_ptr> expr); void visitExpressionStmt(const Expression& stmt); void visitPrintStmt(const Print& stmt); void visitVarStmt(const Var& stmt); @@ -40,7 +42,7 @@ class Resolver: void visitClassStmt(const Class& stmt); void visitIfStmt(const If& stmt); void visitWhileStmt(const While& stmt); - void visitFunctionStmt(const Function& stmt); + void visitFunctionStmt(shared_ptr stmt); void visitReturnStmt(const Return& stmt); void resolve(vector> statements); void resolve(shared_ptr stmt); @@ -48,7 +50,8 @@ class Resolver: private: enum FunctionType { NONE, - FUNCTION + FUNCTION, + METHOD }; FunctionType currentFunction = NONE; void beginScope(); @@ -56,7 +59,7 @@ class Resolver: void declare(Token name); void define(Token name); void resolveLocal(shared_ptr>, Token name); - void resolveFunction(Function function, FunctionType type); + void resolveFunction(shared_ptr function, FunctionType type); }; #endif // RESOLVER_HPP_ diff --git a/Stmt.hpp b/Stmt.hpp index 2f4d3ec..95cd9da 100644 --- a/Stmt.hpp +++ b/Stmt.hpp @@ -73,7 +73,7 @@ class Visitor_Stmt { virtual void visitBlockStmt(const Block& stmt) = 0; virtual void visitIfStmt(const If& stmt) = 0; virtual void visitWhileStmt(const While& stmt) = 0; - virtual void visitFunctionStmt(const Function& stmt) = 0; + virtual void visitFunctionStmt(shared_ptr stmt) = 0; virtual void visitReturnStmt(const Return& stmt) = 0; virtual void visitClassStmt(const Class& stmt) = 0; }; @@ -158,7 +158,10 @@ class While: public Stmt { shared_ptr body; }; -class Function: public Stmt { +class Function: + public Stmt, + public std::enable_shared_from_this +{ public: Function( Token name_, @@ -168,7 +171,7 @@ class Function: public Stmt { params(params_), body(body_) { } void accept(shared_ptr visitor) override { - visitor->visitFunctionStmt(*this); + visitor->visitFunctionStmt(this->shared_from_this()); } Token name; vector params; @@ -188,13 +191,13 @@ class Return: public Stmt { class Class: public Stmt { public: - Class(Token name_, vector> function_): - name(name_), function(function_) { } + Class(Token name_, vector> methods_): + name(name_), methods(methods_) { } void accept(shared_ptr visitor) override { visitor->visitClassStmt(*this); } Token name; - vector> function; + vector> methods; }; #endif // STMT_HPP_ diff --git a/token.cpp b/Token.cpp similarity index 100% rename from token.cpp rename to Token.cpp diff --git a/Token.hpp b/Token.hpp index ff5e174..2c96e09 100644 --- a/Token.hpp +++ b/Token.hpp @@ -10,8 +10,8 @@ using std::string; using std::shared_ptr; class LoxCallable; -class LoxInstance; class LoxClass; +class LoxInstance; typedef enum { LEFT_PAREN, RIGHT_PAREN, LEFT_BRACE, RIGHT_BRACE, diff --git a/test.lox b/test.lox index 2b8ac60..2d7df01 100644 --- a/test.lox +++ b/test.lox @@ -1,3 +1,7 @@ -class Bagel {} -var bagel = Bagel(); -print bagel; \ No newline at end of file +class Bacon { + eat() { + print "Crunch crunch crunch!"; + } +} + +Bacon().eat(); \ No newline at end of file