环境
- Time 2023-07-11
- Java 17
前言
说明
参考:
- https://craftinginterpreters.com/contents.html
- https://github.com/GuoYaxiang/craftinginterpreters_zh
- https://space.bilibili.com/44550904
目标
接上一节，实现函数的定义和调用：包括函数声明（`fun`）、调用表达式、`return` 语句，以及原生函数 `clock`。
GenerateAst
package com.jiangbo.tool;

import java.io.IOException;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;

/**
 * Code generator for the Lox syntax-tree classes.
 *
 * <p>Writes {@code Expr.java} and {@code Stmt.java} into the {@code com.jiangbo.lox} package.
 * Each generated file contains:
 * <ul>
 *   <li>an abstract base class;</li>
 *   <li>a nested {@code Visitor<R>} interface with one {@code visitXxxBase} method per node type;</li>
 *   <li>one static nested subclass per node type, whose constructor stores every field and whose
 *       {@code accept} dispatches to the matching visitor method.</li>
 * </ul>
 */
public class GenerateAst {

  /**
   * Output directory used when no command-line argument is supplied. Adjust it to the actual
   * location of the {@code com/jiangbo/lox} source folder on your machine.
   */
  private static final String DEFAULT_OUTPUT_DIR =
      "C:\\work\\workspace\\demo\\src\\main\\java\\com\\jiangbo\\lox";

  /**
   * Generates {@code Expr.java} and {@code Stmt.java}.
   *
   * @param args optional; when present, {@code args[0]} is used as the output directory
   *     instead of {@link #DEFAULT_OUTPUT_DIR}
   * @throws IOException if either file cannot be created or written
   */
  public static void main(String[] args) throws IOException {
    String outputDir = args.length > 0 ? args[0] : DEFAULT_OUTPUT_DIR;

    // Each entry is "ClassName : field list". The field list becomes both the constructor's
    // parameter list and the set of final fields of the generated node class.
    defineAst(outputDir, "Expr", Arrays.asList(
        "Assign   : Token name, Expr value",
        "Binary   : Expr left, Token operator, Expr right",
        "Call     : Expr callee, Token paren, List<Expr> arguments",
        "Grouping : Expr expression",
        "Literal  : Object value",
        "Logical  : Expr left, Token operator, Expr right",
        "Unary    : Token operator, Expr right",
        "Variable : Token name"
    ));

    defineAst(outputDir, "Stmt", Arrays.asList(
        "Block      : List<Stmt> statements",
        "Expression : Expr expression",
        "Function   : Token name, List<Token> params, List<Stmt> body",
        "If         : Expr condition, Stmt thenBranch, Stmt elseBranch",
        "Print      : Expr expression",
        "Return     : Token keyword, Expr value",
        "Var        : Token name, Expr initializer",
        "While      : Expr condition, Stmt body"
    ));
  }

  /**
   * Writes {@code <outputDir>/<baseName>.java} containing the base class, its visitor
   * interface, and one nested subclass per entry in {@code types}.
   *
   * @param outputDir directory the file is written into
   * @param baseName name of the abstract base class (also the file name)
   * @param types node descriptions of the form {@code "ClassName : Type field, Type field"}
   * @throws IOException if the file cannot be opened or a write fails
   */
  private static void defineAst(String outputDir, String baseName, List<String> types)
      throws IOException {
    String path = outputDir + "/" + baseName + ".java";
    // try-with-resources closes the file even if generation fails part-way.
    try (PrintWriter writer = new PrintWriter(path, StandardCharsets.UTF_8)) {
      writer.println("package com.jiangbo.lox;");
      writer.println();
      writer.println("import java.util.List;");
      writer.println();
      writer.println("abstract class " + baseName + " {");

      defineVisitor(writer, baseName, types);

      for (String type : types) {
        String[] parts = type.split(":");
        String className = parts[0].trim();
        String fields = parts[1].trim();
        defineType(writer, baseName, className, fields);
      }

      // The base accept() method.
      writer.println();
      writer.println("  abstract <R> R accept(Visitor<R> visitor);");
      writer.println("}");

      // PrintWriter never throws on write failures; it only records them. Surface the
      // failure so a truncated file is not left behind as if generation had succeeded.
      if (writer.checkError()) {
        throw new IOException("Failed to write " + path);
      }
    }
  }

  /**
   * Emits the nested {@code Visitor<R>} interface with one {@code visit<Type><Base>} method
   * per node type.
   */
  private static void defineVisitor(PrintWriter writer, String baseName, List<String> types) {
    writer.println("  interface Visitor<R> {");

    for (String type : types) {
      String typeName = type.split(":")[0].trim();
      // Locale.ROOT keeps the parameter name stable regardless of the JVM's default locale.
      writer.println("    R visit" + typeName + baseName + "(" +
          typeName + " " + baseName.toLowerCase(Locale.ROOT) + ");");
    }

    writer.println("  }");
  }

  /**
   * Emits one static nested node class.
   *
   * <p>The generated class has a constructor taking {@code fieldList}, an {@code accept}
   * override that calls {@code visitor.visit<ClassName><BaseName>(this)}, and one
   * {@code final} field per entry in {@code fieldList}.
   *
   * @param fieldList comma-and-space separated {@code "Type name"} pairs; each type must be a
   *     single token without spaces (e.g. {@code List<Stmt>}, not {@code Map<K, V>})
   */
  private static void defineType(PrintWriter writer, String baseName,
      String className, String fieldList) {
    writer.println("  static class " + className + " extends " +
        baseName + " {");

    // Constructor.
    writer.println("    " + className + "(" + fieldList + ") {");

    // Store parameters in fields.
    String[] fields = fieldList.split(", ");
    for (String field : fields) {
      String name = field.split(" ")[1];
      writer.println("      this." + name + " = " + name + ";");
    }

    writer.println("    }");

    // Visitor pattern.
    writer.println();
    writer.println("    @Override");
    writer.println("    <R> R accept(Visitor<R> visitor) {");
    writer.println("      return visitor.visit" +
        className + baseName + "(this);");
    writer.println("    }");

    // Fields.
    writer.println();
    for (String field : fields) {
      writer.println("    final " + field + ";");
    }

    writer.println("  }");
  }
}
Parser
package com.jiangbo.lox;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import static com.jiangbo.lox.TokenType.*;

/**
 * Recursive-descent parser that turns the scanner's token list into a list of statements.
 *
 * <p>Each grammar rule is one method. Expression methods are ordered from lowest to highest
 * precedence:
 * <pre>
 * assignment -> or -> and -> equality -> comparison -> term -> factor -> unary -> call -> primary
 * </pre>
 *
 * <p>Syntax errors are reported through {@code Lox.error}. Errors that abort a rule throw
 * {@link ParseError}, which {@link #declaration()} catches so parsing can resume at the next
 * statement boundary.
 */
class Parser {
  private final List<Token> tokens;
  // Index of the next token to be consumed.
  private int current = 0;

  Parser(List<Token> tokens) {
    this.tokens = tokens;
  }

  /**
   * Parses the whole token list.
   *
   * @return one entry per top-level declaration. An entry is {@code null} where a syntax error
   *     was recovered from. NOTE(review): presumably the caller checks Lox's error flag before
   *     interpreting; confirm.
   */
  List<Stmt> parse() {
    List<Stmt> statements = new ArrayList<>();
    while (!isAtEnd()) {
      statements.add(declaration());
    }
    return statements;
  }

  // declaration -> funDecl | varDecl | statement
  // This is the error-recovery point: on a ParseError, skip to the next statement
  // boundary and return null.
  private Stmt declaration() {
    try {
      if (match(FUN)) return function("function");
      if (match(VAR)) return varDeclaration();
      return statement();
    } catch (ParseError error) {
      synchronize();
      return null;
    }
  }

  // function -> IDENTIFIER "(" parameters? ")" block
  // `kind` is only used in error messages (e.g. "function").
  private Stmt.Function function(String kind) {
    Token name = consume(IDENTIFIER, "Expect " + kind + " name.");
    consume(LEFT_PAREN, "Expect '(' after " + kind + " name.");
    List<Token> parameters = new ArrayList<>();
    if (!check(RIGHT_PAREN)) {
      do {
        // Reported but not thrown: the parser is still in a valid state, so keep going.
        if (parameters.size() >= 255) {
          error(peek(), "Can't have more than 255 parameters.");
        }
        parameters.add(consume(IDENTIFIER, "Expect parameter name."));
      } while (match(COMMA));
    }
    consume(RIGHT_PAREN, "Expect ')' after parameters.");
    // block() expects the '{' to have been consumed already.
    consume(LEFT_BRACE, "Expect '{' before " + kind + " body.");
    List<Stmt> body = block();
    return new Stmt.Function(name, parameters, body);
  }

  // varDecl -> "var" IDENTIFIER ( "=" expression )? ";"   (the "var" is already consumed)
  private Stmt varDeclaration() {
    Token name = consume(IDENTIFIER, "Expect variable name.");
    Expr initializer = null;
    if (match(EQUAL)) {
      initializer = expression();
    }
    consume(SEMICOLON, "Expect ';' after variable declaration.");
    return new Stmt.Var(name, initializer);
  }

  // statement -> forStmt | ifStmt | printStmt | returnStmt | whileStmt | block | exprStmt
  private Stmt statement() {
    if (match(FOR)) return forStatement();
    if (match(IF)) return ifStatement();
    if (match(PRINT)) return printStatement();
    if (match(RETURN)) return returnStatement();
    if (match(WHILE)) return whileStatement();
    if (match(LEFT_BRACE)) return new Stmt.Block(block());
    return expressionStatement();
  }

  // Desugars `for (init; cond; incr) body` into
  // `{ init; while (cond) { body; incr; } }`, so there is no dedicated For node.
  private Stmt forStatement() {
    consume(LEFT_PAREN, "Expect '(' after 'for'.");
    Stmt initializer;
    if (match(SEMICOLON)) {
      initializer = null;
    } else if (match(VAR)) {
      initializer = varDeclaration();
    } else {
      initializer = expressionStatement();
    }
    Expr condition = null;
    if (!check(SEMICOLON)) {
      condition = expression();
    }
    consume(SEMICOLON, "Expect ';' after loop condition.");
    Expr increment = null;
    if (!check(RIGHT_PAREN)) {
      increment = expression();
    }
    consume(RIGHT_PAREN, "Expect ')' after for clauses.");
    Stmt body = statement();
    if (increment != null) {
      body = new Stmt.Block(Arrays.asList(body, new Stmt.Expression(increment)));
    }
    // A missing condition means "loop forever".
    if (condition == null) condition = new Expr.Literal(true);
    body = new Stmt.While(condition, body);
    if (initializer != null) {
      // The outer block scopes a `var` initializer to the loop.
      body = new Stmt.Block(Arrays.asList(initializer, body));
    }
    return body;
  }

  // whileStmt -> "while" "(" expression ")" statement
  private Stmt whileStatement() {
    consume(LEFT_PAREN, "Expect '(' after 'while'.");
    Expr condition = expression();
    consume(RIGHT_PAREN, "Expect ')' after condition.");
    Stmt body = statement();
    return new Stmt.While(condition, body);
  }

  // ifStmt -> "if" "(" expression ")" statement ( "else" statement )?
  // The else binds to the nearest if, because it is matched eagerly here.
  private Stmt ifStatement() {
    consume(LEFT_PAREN, "Expect '(' after 'if'.");
    Expr condition = expression();
    consume(RIGHT_PAREN, "Expect ')' after if condition.");
    Stmt thenBranch = statement();
    Stmt elseBranch = null;
    if (match(ELSE)) {
      elseBranch = statement();
    }
    return new Stmt.If(condition, thenBranch, elseBranch);
  }

  // Parses declarations up to the closing '}'. The opening '{' must already be consumed.
  private List<Stmt> block() {
    List<Stmt> statements = new ArrayList<>();
    while (!check(RIGHT_BRACE) && !isAtEnd()) {
      statements.add(declaration());
    }
    consume(RIGHT_BRACE, "Expect '}' after block.");
    return statements;
  }

  private Stmt printStatement() {
    Expr value = expression();
    consume(SEMICOLON, "Expect ';' after value.");
    return new Stmt.Print(value);
  }

  private Stmt expressionStatement() {
    Expr expr = expression();
    consume(SEMICOLON, "Expect ';' after expression.");
    return new Stmt.Expression(expr);
  }

  // returnStmt -> "return" expression? ";"
  // The `return` token is kept so runtime errors can report its location. A missing value
  // leaves `value` null.
  private Stmt returnStatement() {
    Token keyword = previous();
    Expr value = null;
    if (!check(SEMICOLON)) {
      value = expression();
    }
    consume(SEMICOLON, "Expect ';' after return value.");
    return new Stmt.Return(keyword, value);
  }

  private Expr expression() {
    return assignment();
  }

  // assignment -> IDENTIFIER "=" assignment | or
  // The left side is parsed as an ordinary expression first and only then checked to be a
  // valid target. Recursing into assignment() for the value makes `=` right-associative.
  private Expr assignment() {
    Expr expr = or();
    if (match(EQUAL)) {
      Token equals = previous();
      Expr value = assignment();
      if (expr instanceof Expr.Variable) {
        Token name = ((Expr.Variable)expr).name;
        return new Expr.Assign(name, value);
      }
      // Reported without throwing: the parser is not confused, so no synchronization needed.
      error(equals, "Invalid assignment target.");
    }
    return expr;
  }

  private Expr or() {
    Expr expr = and();
    while (match(OR)) {
      Token operator = previous();
      Expr right = and();
      expr = new Expr.Logical(expr, operator, right);
    }
    return expr;
  }

  private Expr and() {
    Expr expr = equality();
    while (match(AND)) {
      Token operator = previous();
      Expr right = equality();
      expr = new Expr.Logical(expr, operator, right);
    }
    return expr;
  }

  // Binary levels below loop instead of recursing on the left, which makes them
  // left-associative.
  private Expr equality() {
    Expr expr = comparison();
    while (match(BANG_EQUAL, EQUAL_EQUAL)) {
      Token operator = previous();
      Expr right = comparison();
      expr = new Expr.Binary(expr, operator, right);
    }
    return expr;
  }

  private Expr comparison() {
    Expr expr = term();
    while (match(GREATER, GREATER_EQUAL, LESS, LESS_EQUAL)) {
      Token operator = previous();
      Expr right = term();
      expr = new Expr.Binary(expr, operator, right);
    }
    return expr;
  }

  private Expr term() {
    Expr expr = factor();
    while (match(MINUS, PLUS)) {
      Token operator = previous();
      Expr right = factor();
      expr = new Expr.Binary(expr, operator, right);
    }
    return expr;
  }

  private Expr factor() {
    Expr expr = unary();
    while (match(SLASH, STAR)) {
      Token operator = previous();
      Expr right = unary();
      expr = new Expr.Binary(expr, operator, right);
    }
    return expr;
  }

  private Expr unary() {
    if (match(BANG, MINUS)) {
      Token operator = previous();
      Expr right = unary();
      return new Expr.Unary(operator, right);
    }
    return call();
  }

  // call -> primary ( "(" arguments? ")" )*
  // Looping allows chained calls such as f(1)(2).
  private Expr call() {
    Expr expr = primary();
    while (true) {
      if (match(LEFT_PAREN)) {
        expr = finishCall(expr);
      } else {
        break;
      }
    }
    return expr;
  }

  // Parses the argument list after '(' and records the closing ')' token, which is used
  // for error locations at runtime.
  private Expr finishCall(Expr callee) {
    List<Expr> arguments = new ArrayList<>();
    if (!check(RIGHT_PAREN)) {
      do {
        if (arguments.size() >= 255) {
          error(peek(), "Can't have more than 255 arguments.");
        }
        arguments.add(expression());
      } while (match(COMMA));
    }
    Token paren = consume(RIGHT_PAREN,"Expect ')' after arguments.");
    return new Expr.Call(callee, paren, arguments);
  }

  private Expr primary() {
    if (match(FALSE)) return new Expr.Literal(false);
    if (match(TRUE)) return new Expr.Literal(true);
    if (match(NIL)) return new Expr.Literal(null);
    if (match(NUMBER, STRING)) {
      return new Expr.Literal(previous().literal);
    }
    if (match(IDENTIFIER)) {
      return new Expr.Variable(previous());
    }
    if (match(LEFT_PAREN)) {
      Expr expr = expression();
      consume(RIGHT_PAREN, "Expect ')' after expression.");
      return new Expr.Grouping(expr);
    }
    throw error(peek(), "Expect expression.");
  }

  // Consumes the current token if it has any of the given types.
  private boolean match(TokenType... types) {
    for (TokenType type : types) {
      if (check(type)) {
        advance();
        return true;
      }
    }
    return false;
  }

  // Consumes a token of the required type, or reports `message` and throws ParseError.
  private Token consume(TokenType type, String message) {
    if (check(type)) return advance();
    throw error(peek(), message);
  }

  private boolean check(TokenType type) {
    if (isAtEnd()) return false;
    return peek().type == type;
  }

  private Token advance() {
    if (!isAtEnd()) current++;
    return previous();
  }

  // Relies on the token list ending with an EOF token; never advances past it.
  private boolean isAtEnd() {
    return peek().type == EOF;
  }

  private Token peek() {
    return tokens.get(current);
  }

  private Token previous() {
    return tokens.get(current - 1);
  }

  // Reports the error and returns (does not throw) a ParseError. The caller decides
  // whether to unwind.
  private ParseError error(Token token, String message) {
    Lox.error(token, message);
    return new ParseError();
  }

  // Discards tokens until just after a ';' or just before a keyword that starts a
  // statement, so one syntax error does not cascade into many.
  private void synchronize() {
    advance();
    while (!isAtEnd()) {
      if (previous().type == SEMICOLON) return;
      switch (peek().type) {
        case CLASS:
        case FUN:
        case VAR:
        case FOR:
        case IF:
        case WHILE:
        case PRINT:
        case RETURN:
          return;
      }
      advance();
    }
  }

  // Unwinding signal for syntax errors. The message has already been reported via Lox.error.
  private static class ParseError extends RuntimeException {}
}
Interpreter
package com.jiangbo.lox;import java.util.ArrayList;
import java.util.List;class Interpreter implements Expr.Visitor<Object>,Stmt.Visitor<Void> {final Environment globals = new Environment();private Environment environment = globals;Interpreter() {globals.define("clock", new LoxCallable() {@Overridepublic int arity() { return 0; }@Overridepublic Object call(Interpreter interpreter, List<Object> arguments) {return (double)System.currentTimeMillis() / 1000.0;}@Overridepublic String toString() { return "<native fn>"; }});}void interpret(List<Stmt> statements) {try {for (Stmt statement : statements) {execute(statement);}} catch (RuntimeError error) {Lox.runtimeError(error);}}private void execute(Stmt stmt) {stmt.accept(this);}@Overridepublic Void visitBlockStmt(Stmt.Block stmt) {executeBlock(stmt.statements, new Environment(environment));return null;}void executeBlock(List<Stmt> statements,Environment environment) {Environment previous = this.environment;try {this.environment = environment;for (Stmt statement : statements) {execute(statement);}} finally {this.environment = previous;}}@Overridepublic Object visitLiteralExpr(Expr.Literal expr) {return expr.value;}@Overridepublic Object visitLogicalExpr(Expr.Logical expr) {Object left = evaluate(expr.left);if (expr.operator.type == TokenType.OR) {if (isTruthy(left)) return left;} else {if (!isTruthy(left)) return left;}return evaluate(expr.right);}@Overridepublic Object visitGroupingExpr(Expr.Grouping expr) {return evaluate(expr.expression);}@Overridepublic Void visitVarStmt(Stmt.Var stmt) {Object value = null;if (stmt.initializer != null) {value = evaluate(stmt.initializer);}environment.define(stmt.name.lexeme, value);return null;}@Overridepublic Void visitWhileStmt(Stmt.While stmt) {while (isTruthy(evaluate(stmt.condition))) {execute(stmt.body);}return null;}@Overridepublic Object visitAssignExpr(Expr.Assign expr) {Object value = evaluate(expr.value);environment.assign(expr.name, value);return value;}@Overridepublic Object visitVariableExpr(Expr.Variable expr) {return 
environment.get(expr.name);}@Overridepublic Object visitUnaryExpr(Expr.Unary expr) {Object right = evaluate(expr.right);checkNumberOperand(expr.operator, right);return switch (expr.operator.type) {case BANG -> !isTruthy(right);case MINUS -> -(double) right;default -> null;};}@Overridepublic Object visitBinaryExpr(Expr.Binary expr) {Object left = evaluate(expr.left);Object right = evaluate(expr.right);switch (expr.operator.type) {case GREATER:checkNumberOperands(expr.operator, left, right);return (double) left > (double) right;case GREATER_EQUAL:checkNumberOperands(expr.operator, left, right);return (double) left >= (double) right;case LESS:checkNumberOperands(expr.operator, left, right);return (double) left < (double) right;case LESS_EQUAL:checkNumberOperands(expr.operator, left, right);return (double) left <= (double) right;case MINUS:checkNumberOperands(expr.operator, left, right);return (double) left - (double) right;case PLUS:if (left instanceof Double && right instanceof Double) {return (double) left + (double) right;}if (left instanceof String && right instanceof String) {return left + (String) right;}throw new RuntimeError(expr.operator,"Operands must be two numbers or two strings.");case SLASH:checkNumberOperands(expr.operator, left, right);return (double) left / (double) right;case STAR:checkNumberOperands(expr.operator, left, right);return (double) left * (double) right;case BANG_EQUAL:return !isEqual(left, right);case EQUAL_EQUAL:return isEqual(left, right);}// Unreachable.return null;}@Overridepublic Object visitCallExpr(Expr.Call expr) {Object callee = evaluate(expr.callee);List<Object> arguments = new ArrayList<>();for (Expr argument : expr.arguments) {arguments.add(evaluate(argument));}if (!(callee instanceof LoxCallable)) {throw new RuntimeError(expr.paren,"Can only call functions and classes.");}LoxCallable function = (LoxCallable)callee;if (arguments.size() != function.arity()) {throw new RuntimeError(expr.paren, "Expected " +function.arity() + " 
arguments but got " +arguments.size() + ".");}return function.call(this, arguments);}private Object evaluate(Expr expr) {return expr.accept(this);}@Overridepublic Void visitExpressionStmt(Stmt.Expression stmt) {evaluate(stmt.expression);return null;}@Overridepublic Void visitFunctionStmt(Stmt.Function stmt) {LoxFunction function = new LoxFunction(stmt, environment);environment.define(stmt.name.lexeme, function);return null;}@Overridepublic Void visitIfStmt(Stmt.If stmt) {if (isTruthy(evaluate(stmt.condition))) {execute(stmt.thenBranch);} else if (stmt.elseBranch != null) {execute(stmt.elseBranch);}return null;}@Overridepublic Void visitPrintStmt(Stmt.Print stmt) {Object value = evaluate(stmt.expression);System.out.println(stringify(value));return null;}@Overridepublic Void visitReturnStmt(Stmt.Return stmt) {Object value = null;if (stmt.value != null) value = evaluate(stmt.value);throw new Return(value);}private boolean isTruthy(Object object) {if (object == null) return false;if (object instanceof Boolean) return (boolean) object;return true;}private boolean isEqual(Object a, Object b) {if (a == null && b == null) return true;if (a == null) return false;return a.equals(b);}private void checkNumberOperand(Token operator, Object operand) {if (operand instanceof Double) return;throw new RuntimeError(operator, "Operand must be a number.");}private void checkNumberOperands(Token operator, Object left, Object right) {if (left instanceof Double && right instanceof Double) return;throw new RuntimeError(operator, "Operands must be numbers.");}private String stringify(Object object) {if (object == null) return "nil";if (object instanceof Double) {String text = object.toString();if (text.endsWith(".0")) {text = text.substring(0, text.length() - 2);}return text;}return object.toString();}
}
LoxCallable
package com.jiangbo.lox;

import java.util.List;

/**
 * A runtime value that Lox code can invoke with call syntax.
 *
 * <p>Implemented by user-declared functions ({@code LoxFunction}) and by native functions
 * defined in the interpreter's global scope (such as {@code clock}).
 */
interface LoxCallable {

  /**
   * Invokes this callable.
   *
   * @param interpreter the interpreter performing the call
   * @param arguments the already-evaluated argument values. The interpreter verifies that
   *     their count equals {@link #arity()} before calling.
   * @return the call's result; {@code null} represents nil
   */
  Object call(Interpreter interpreter, List<Object> arguments);

  /**
   * Returns the exact number of arguments this callable expects.
   *
   * @return the expected argument count
   */
  int arity();
}
LoxFunction
package com.jiangbo.lox;import java.util.List;class LoxFunction implements LoxCallable {private final Stmt.Function declaration;private final Environment closure;LoxFunction(Stmt.Function declaration, Environment closure) {this.closure = closure;this.declaration = declaration;}@Overridepublic int arity() {return declaration.params.size();}@Overridepublic Object call(Interpreter interpreter,List<Object> arguments) {Environment environment = new Environment(closure);for (int i = 0; i < declaration.params.size(); i++) {environment.define(declaration.params.get(i).lexeme,arguments.get(i));}try {interpreter.executeBlock(declaration.body, environment);} catch (Return returnValue) {return returnValue.value;}return null;}@Overridepublic String toString() {return "<fn " + declaration.name.lexeme + ">";}
}
Return
package com.jiangbo.lox;

/**
 * Control-flow signal thrown by a Lox {@code return} statement.
 *
 * <p>The interpreter throws it, and {@code LoxFunction.call} catches it to deliver the value.
 * It is not an error, so it carries no message or cause, and it skips the cost of capturing a
 * stack trace.
 */
class Return extends RuntimeException {
  /** The value being returned; {@code null} stands for nil (a bare {@code return;}). */
  final Object value;

  Return(Object returnedValue) {
    // (message, cause, enableSuppression, writableStackTrace): no message, no cause,
    // suppression disabled, and no stack trace filled in.
    super(null, null, false, false);
    value = returnedValue;
  }
}
总结
其余未改动的类（如 `Lox`、`Token`、`TokenType`、`Environment`、`RuntimeError`，以及由 GenerateAst 生成的 `Expr`、`Stmt`）此处省略。现在程序已经支持函数的定义和调用。