From 41d3bb0f3d696ce7dffc5ff1b04381d23ff8cf93 Mon Sep 17 00:00:00 2001 From: lisyarus Date: Tue, 16 Dec 2025 19:19:46 +0300 Subject: [PATCH] Implement functions in parser & interpreter --- apps/interpreter/source/main.cpp | 16 ++++ examples/test.psl | 9 +- libs/ast/include/pslang/ast/expression.hpp | 4 +- .../ast/include/pslang/ast/expression_fwd.hpp | 2 +- libs/ast/include/pslang/ast/function.hpp | 33 +++++++ libs/ast/include/pslang/ast/print.hpp | 3 + libs/ast/include/pslang/ast/statement.hpp | 10 +- libs/ast/include/pslang/ast/statement_fwd.hpp | 4 +- libs/ast/source/print.cpp | 38 ++++++++ .../include/pslang/interpreter/context.hpp | 12 +++ .../include/pslang/interpreter/value.hpp | 4 + libs/interpreter/source/eval.cpp | 96 ++++++++++++++++++- libs/interpreter/source/interpreter.cpp | 77 +++++++++++++-- libs/interpreter/source/value.cpp | 10 ++ libs/parser/rules/pslang.l | 4 + libs/parser/rules/pslang.y | 37 +++++++ libs/parser/source/finilize.cpp | 11 +++ libs/type/include/pslang/type/type_fwd.hpp | 2 +- spec.txt | 4 - 19 files changed, 355 insertions(+), 21 deletions(-) create mode 100644 libs/ast/include/pslang/ast/function.hpp diff --git a/apps/interpreter/source/main.cpp b/apps/interpreter/source/main.cpp index 11286da..7d63da0 100644 --- a/apps/interpreter/source/main.cpp +++ b/apps/interpreter/source/main.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include @@ -14,6 +15,7 @@ int main(int argc, char ** argv) try std::cout << "Available options:\n"; std::cout << " -t, --trace Trace each line of execution\n"; std::cout << " -d, --dump Dump all variables after processing each file\n"; + std::cout << " -p, --print Print the AST after parsing each file\n"; return 0; } @@ -22,6 +24,7 @@ int main(int argc, char ** argv) try auto context = interpreter::empty_context(); bool dump = false; + bool dump_ast = false; for (int arg = 1; arg < argc; ++arg) { @@ -37,7 +40,20 @@ int main(int argc, char ** argv) try continue; } + if (std::strcmp(argv[arg], "-p") == 0 || std::strcmp(argv[arg], "--print") == 0) + { + dump_ast = true; + continue; + } + auto ast = parser::parse(argv[arg]); + + if (dump_ast) + { + ast::print(std::cout, ast); + std::cout << std::flush; + } + interpreter::execute(context, ast); if (dump) diff --git a/examples/test.psl b/examples/test.psl index 9f30fa9..09b5dce 100644 --- a/examples/test.psl +++ b/examples/test.psl @@ -1,3 +1,6 @@ -let n = 10 + (1u as i32) -let x = (n as f32) / 3.0 -x = 15.2 +func factorial(n : u32) -> u32: + if n == 0u: + return 1u + return n * factorial(n - 1u) + +let x = factorial(10u) diff --git a/libs/ast/include/pslang/ast/expression.hpp b/libs/ast/include/pslang/ast/expression.hpp index a4de813..b23db1e 100644 --- a/libs/ast/include/pslang/ast/expression.hpp +++ b/libs/ast/include/pslang/ast/expression.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include namespace pslang::ast @@ -27,7 +28,8 @@ namespace pslang::ast identifier, unary_operation, binary_operation, - cast_operation + cast_operation, + function_call >; struct expression diff --git a/libs/ast/include/pslang/ast/expression_fwd.hpp b/libs/ast/include/pslang/ast/expression_fwd.hpp index af492fd..f23ef2b 100644 --- a/libs/ast/include/pslang/ast/expression_fwd.hpp +++ b/libs/ast/include/pslang/ast/expression_fwd.hpp @@ -7,6 +7,6 @@ namespace pslang::ast struct expression; - using expression_ptr = std::unique_ptr; + using expression_ptr = std::shared_ptr; } diff --git a/libs/ast/include/pslang/ast/function.hpp b/libs/ast/include/pslang/ast/function.hpp new file mode 100644 index 0000000..f77a276 --- /dev/null +++ b/libs/ast/include/pslang/ast/function.hpp @@ -0,0 +1,33 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace pslang::ast +{ + + struct function_definition + { + struct argument + { + std::string name; + type::type_ptr type; + }; + + std::string name; + std::vector arguments; + type::type_ptr return_type; + statement_list_ptr statements; + }; + + struct function_call + { + std::string name; + std::vector arguments; + }; + +} diff --git a/libs/ast/include/pslang/ast/print.hpp b/libs/ast/include/pslang/ast/print.hpp index 6556fef..2cfbc8b 100644 --- a/libs/ast/include/pslang/ast/print.hpp +++ b/libs/ast/include/pslang/ast/print.hpp @@ -29,6 +29,7 @@ namespace pslang::ast void print(std::ostream & out, unary_operation const & node, print_options const & options = {}); void print(std::ostream & out, binary_operation const & node, print_options const & options = {}); void print(std::ostream & out, cast_operation const & node, print_options const & options = {}); + void print(std::ostream & out, function_call const & node, print_options const & options = {}); void print(std::ostream & out, expression_ptr const & node, print_options const & options = {}); void print(std::ostream & out, assignment const & node, print_options const & options = {}); void print(std::ostream & out, variable_declaration const & node, print_options const & options = {}); @@ -37,6 +38,8 @@ namespace pslang::ast void print(std::ostream & out, else_if_block const & node, print_options const & options = {}); void print(std::ostream & out, if_chain const & node, print_options const & options = {}); void print(std::ostream & out, while_block const & node, print_options const & options = {}); + void print(std::ostream & out, function_definition const & node, print_options const & options = {}); + void print(std::ostream & out, return_statement const & node, print_options const & options = {}); void print(std::ostream & out, statement_ptr const & node, print_options const & options = {}); void print(std::ostream & out, statement_list_ptr const & node, print_options const & options = {}); diff --git a/libs/ast/include/pslang/ast/statement.hpp b/libs/ast/include/pslang/ast/statement.hpp index ad02dbc..a5c3649 100644 --- a/libs/ast/include/pslang/ast/statement.hpp +++ b/libs/ast/include/pslang/ast/statement.hpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -25,6 +26,11 @@ namespace pslang::ast expression_ptr rhs; }; + struct return_statement + { + expression_ptr value; + }; + using statement_impl = std::variant< expression_ptr, assignment, @@ -33,7 +39,9 @@ namespace pslang::ast else_block, else_if_block, if_chain, - while_block + while_block, + function_definition, + return_statement >; struct statement diff --git a/libs/ast/include/pslang/ast/statement_fwd.hpp b/libs/ast/include/pslang/ast/statement_fwd.hpp index d2f4856..134aaa9 100644 --- a/libs/ast/include/pslang/ast/statement_fwd.hpp +++ b/libs/ast/include/pslang/ast/statement_fwd.hpp @@ -8,13 +8,13 @@ namespace pslang::ast struct statement; - using statement_ptr = std::unique_ptr; + using statement_ptr = std::shared_ptr; struct statement_list { std::vector statements; }; - using statement_list_ptr = std::unique_ptr; + using statement_list_ptr = std::shared_ptr; } diff --git a/libs/ast/source/print.cpp b/libs/ast/source/print.cpp index 392f878..3e50c46 100644 --- a/libs/ast/source/print.cpp +++ b/libs/ast/source/print.cpp @@ -143,6 +143,15 @@ namespace pslang::ast print(out, node.expression, child(options)); } + void print(std::ostream & out, function_call const & node, print_options const & options) + { + put_indent(out, options); + out << "call " << node.name; + newline(out); + for (auto const & argument : node.arguments) + print(out, argument, child(options)); + } + void print(std::ostream & out, expression_ptr const & node, print_options const & options) { std::visit([&](auto const & value){ print(out, value, options); }, *node); @@ -232,6 +241,35 @@ namespace pslang::ast print(out, node.statements, child(options)); } + void print(std::ostream & out, function_definition const & node, print_options const & options) + { + put_indent(out, options); + out << "function { name = \"" << node.name << "\", return type = "; + type::print(out, *node.return_type); + out << " }"; + newline(out); + for (auto const & arg : node.arguments) + { + put_indent(out, child(options)); + out << "argument { name = \"" << arg.name << "\", type = "; + type::print(out, *arg.type); + out << " }"; + newline(out); + } + put_indent(out, child(options)); + out << "body"; + newline(out); + print(out, node.statements, child(child(options))); + } + + void print(std::ostream & out, return_statement const & node, print_options const & options) + { + put_indent(out, options); + out << "return"; + newline(out); + print(out, node.value, child(options)); + } + void print(std::ostream & out, statement_ptr const & node, print_options const & options) { std::visit([&](auto const & value){ print(out, value, options); }, *node); diff --git a/libs/interpreter/include/pslang/interpreter/context.hpp b/libs/interpreter/include/pslang/interpreter/context.hpp index bdb6a20..9a554d2 100644 --- a/libs/interpreter/include/pslang/interpreter/context.hpp +++ b/libs/interpreter/include/pslang/interpreter/context.hpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -17,9 +18,20 @@ namespace pslang::interpreter interpreter::value value; }; + struct function_data + { + std::vector arguments; + type::type_ptr return_type; + ast::statement_list_ptr statements; + }; + struct scope { std::unordered_map variables; + std::unordered_map functions; + + bool is_function_scope = false; + value return_value = unit_value{}; }; struct context diff --git a/libs/interpreter/include/pslang/interpreter/value.hpp b/libs/interpreter/include/pslang/interpreter/value.hpp index 3369672..a980aa5 100644 --- a/libs/interpreter/include/pslang/interpreter/value.hpp +++ b/libs/interpreter/include/pslang/interpreter/value.hpp @@ -9,6 +9,9 @@ namespace pslang::interpreter { + struct unit_value + {}; + template struct primitive_value_base { @@ -51,6 +54,7 @@ namespace pslang::interpreter }; using value_impl = std::variant< + unit_value, primitive_value >; diff --git a/libs/interpreter/source/eval.cpp b/libs/interpreter/source/eval.cpp index 8079cc4..25d6a43 100644 --- a/libs/interpreter/source/eval.cpp +++ b/libs/interpreter/source/eval.cpp @@ -83,7 +83,7 @@ namespace pslang::interpreter template value eval_impl(context & context, ast::numeric_literal_base const & literal) { - return primitive_value(primitive_value_base{literal.value});; + return primitive_value(primitive_value_base{literal.value}); } value eval_impl(context & context, ast::literal const & literal) @@ -102,6 +102,15 @@ namespace pslang::interpreter throw std::runtime_error("Identifier \"" + identifier.name + "\" is not defined"); } + value unary_operation_impl(ast::unary_operation_type type, unit_value const &) + { + std::ostringstream os; + os << "Cannot apply unary operator \""; + print(os, type); + os << "\" to a value of type unit"; + throw std::runtime_error(os.str()); + } + template value unary_operation_impl(ast::unary_operation_type type, primitive_value_base const & arg1) { @@ -251,6 +260,16 @@ namespace pslang::interpreter throw std::runtime_error(os.str()); } + value binary_operation_impl_same_type(ast::binary_operation_type type, unit_value const &, value const & arg2) + { + std::ostringstream os; + os << "Cannot apply binary operator \""; + print(os, type); + os << "\" to values of type unit and "; + type::print(os, type_of(arg2)); + throw std::runtime_error(os.str()); + } + value binary_operation_impl_same_type(ast::binary_operation_type type, primitive_value const & arg1, value const & arg2) { return std::visit([&](auto const & value){ return binary_operation_impl_same_type(type, value, arg2); }, arg1); @@ -284,6 +303,20 @@ namespace pslang::interpreter throw std::runtime_error("eval(binary_operation) for different argument types not implemented"); } + value cast_impl(unit_value const & value, type::type const & type) + { + if (type::equal(type, type::unit_type{})) + return value; + + throw std::runtime_error("Cannot cast unit type to anything"); + } + + template + value cast_impl(primitive_value_base const & value, type::unit_type const &) + { + throw std::runtime_error("Cannot cast anything to unit type"); + } + template value cast_impl(primitive_value_base const & value, type::primitive_type_base const & type) { @@ -327,6 +360,67 @@ namespace pslang::interpreter return std::visit([&](auto const & value){ return cast_impl(value, *cast_operation.type); }, arg); } + value eval_impl(context & context, ast::function_call const & function_call) + { + for (auto it = context.scope_stack.rbegin(); it != context.scope_stack.rend(); ++it) + { + if (auto jt = it->functions.find(function_call.name); jt != it->functions.end()) + { + if (jt->second.arguments.size() != function_call.arguments.size()) + { + std::ostringstream os; + os << "Cannot call function \"" << function_call.name << "\": expected " << jt->second.arguments.size() << " arguments, got " << function_call.arguments.size(); + throw std::runtime_error(os.str()); + } + + std::vector args; + for (auto const & expression : function_call.arguments) + args.push_back(eval(context, expression)); + + for (std::size_t i = 0; i < args.size(); ++i) + { + auto actual_type = type_of(args[i]); + if (!type::equal(actual_type, *jt->second.arguments[i].type)) + { + std::ostringstream os; + os << "Cannot call function \"" << function_call.name << "\": argument #" << (i + 1) << " expects type "; + type::print(os, *jt->second.arguments[i].type); + os << " but actual type is "; + type::print(os, actual_type); + throw std::runtime_error(os.str()); + } + } + + auto & function_scope = context.scope_stack.emplace_back(); + function_scope.is_function_scope = true; + + for (std::size_t i = 0; i < args.size(); ++i) + function_scope.variables[jt->second.arguments[i].name] = {.category = ast::value_category::constant, .value = std::move(args[i])}; + + auto expected_return_type = jt->second.return_type; + + execute(context, jt->second.statements); + + auto actual_return_type = type_of(context.scope_stack.back().return_value); + if (!type::equal(actual_return_type, *expected_return_type)) + { + std::ostringstream os; + os << "Error returning from function \"" << function_call.name << "\": expected return type is "; + type::print(os, *expected_return_type); + os << " but actual type is "; + type::print(os, actual_return_type); + throw std::runtime_error(os.str()); + } + + auto result = std::move(context.scope_stack.back().return_value); + context.scope_stack.pop_back(); + return result; + } + } + + throw std::runtime_error("Function \"" + function_call.name + "\" is not defined"); + } + value eval_impl(context & context, ast::expression_ptr const & expression) { return std::visit([&](auto const & expression){ return eval_impl(context, expression); }, *expression); diff --git a/libs/interpreter/source/interpreter.cpp b/libs/interpreter/source/interpreter.cpp index 6579a96..19b7f13 100644 --- a/libs/interpreter/source/interpreter.cpp +++ b/libs/interpreter/source/interpreter.cpp @@ -12,6 +12,24 @@ namespace pslang::interpreter namespace { + struct return_exception + {}; + + struct stack_pop_guard + { + stack_pop_guard(context & context) + : context_(context) + {} + + ~stack_pop_guard() + { + context_.scope_stack.pop_back(); + } + + private: + context & context_; + }; + void execute_impl(context & context, ast::expression_ptr const & expression) { eval(context, expression); @@ -72,7 +90,7 @@ namespace pslang::interpreter throw std::runtime_error(os.str()); } } - scope.variables[variable_declaration.name] = {.category = variable_declaration.category, .value = value}; + context.scope_stack.back().variables[variable_declaration.name] = {.category = variable_declaration.category, .value = value}; } void execute_impl(context & context, ast::if_block const &) @@ -94,6 +112,8 @@ namespace pslang::interpreter { for (auto const & block : if_chain.blocks) { + bool do_execute = true; + if (block.condition) { auto value = eval(context, block.condition); @@ -107,12 +127,16 @@ namespace pslang::interpreter throw std::runtime_error(os.str()); } - if (std::get(std::get(value)).value) - { - execute(context, block.statements); - break; - } + do_execute = std::get(std::get(value)).value; } + + if (!do_execute) + continue; + + context.scope_stack.emplace_back(); + stack_pop_guard guard(context); + execute(context, block.statements); + break; } } @@ -133,6 +157,8 @@ namespace pslang::interpreter if (std::get(std::get(value)).value) { + context.scope_stack.emplace_back(); + stack_pop_guard guard(context); execute(context, while_block.statements); } else @@ -140,11 +166,48 @@ namespace pslang::interpreter } } + void execute_impl(context & context, ast::function_definition const & function_definition) + { + auto & scope = context.scope_stack.back(); + if (scope.functions.count(function_definition.name) > 0) + throw std::runtime_error("Function \"" + function_definition.name + "\" is already defined"); + + auto & function = scope.functions[function_definition.name]; + function.arguments = function_definition.arguments; + function.return_type = function_definition.return_type; + function.statements = function_definition.statements; + } + + void execute_impl(context & context, ast::return_statement const & return_statement) + { + auto value = eval(context, return_statement.value); + for (auto it = context.scope_stack.rbegin(); it != context.scope_stack.rend(); ++it) + { + if (it->is_function_scope) + { + it->return_value = std::move(value); + throw return_exception{}; + } + } + + throw std::runtime_error("Cannot return outside of a function"); + } + void execute_impl(context & context, ast::statement_list_ptr const & statements) { for (auto const & statement : statements->statements) { - std::visit([&](auto const & statement){ execute_impl(context, statement); }, *statement); + try + { + std::visit([&](auto const & statement){ execute_impl(context, statement); }, *statement); + } + catch (return_exception const &) + { + if (context.scope_stack.back().is_function_scope) + break; + else + throw; + } } } diff --git a/libs/interpreter/source/value.cpp b/libs/interpreter/source/value.cpp index c89570e..dad9e92 100644 --- a/libs/interpreter/source/value.cpp +++ b/libs/interpreter/source/value.cpp @@ -8,6 +8,11 @@ namespace pslang::interpreter namespace { + type::type type_of_impl(unit_value const &) + { + return type::unit_type{}; + } + template type::type type_of_impl(primitive_value_base const &) { @@ -24,6 +29,11 @@ namespace pslang::interpreter return std::visit([](auto const & value){ return type_of_impl(value); }, value); } + void print_impl(std::ostream & out, unit_value const &) + { + out << "unit"; + } + template void print_impl(std::ostream & out, primitive_value_base const & value) { diff --git a/libs/parser/rules/pslang.l b/libs/parser/rules/pslang.l index 0a19f0b..8d0206b 100644 --- a/libs/parser/rules/pslang.l +++ b/libs/parser/rules/pslang.l @@ -28,6 +28,8 @@ if { return bp::make_if(ctx.location); } else { return bp::make_else(ctx.location); } while { return bp::make_while(ctx.location); } as { return bp::make_as(ctx.location); } +func { return bp::make_func(ctx.location); } +return { return bp::make_return(ctx.location); } true { return bp::make_true(ctx.location); } false { return bp::make_false(ctx.location); } @@ -50,6 +52,7 @@ f64 { return bp::make_f64(ctx.location); } "\t" { return bp::make_indent(ctx.location); } "=" { return bp::make_assignment(ctx.location); } ":" { return bp::make_colon(ctx.location); } +"," { return bp::make_comma(ctx.location); } "(" { return bp::make_lparen(ctx.location); } ")" { return bp::make_rparen(ctx.location); } "+" { return bp::make_plus(ctx.location); } @@ -67,6 +70,7 @@ f64 { return bp::make_f64(ctx.location); } ">" { return bp::make_greater(ctx.location); } "<=" { return bp::make_less_equals(ctx.location); } ">=" { return bp::make_greater_equals(ctx.location); } +"->" { return bp::make_arrow(ctx.location); } [0-9]+b { return bp::make_lit_i8(yytext, ctx.location); } [0-9]+ub { return bp::make_lit_u8(yytext, ctx.location); } diff --git a/libs/parser/rules/pslang.y b/libs/parser/rules/pslang.y index 9874d9e..f1514d5 100644 --- a/libs/parser/rules/pslang.y +++ b/libs/parser/rules/pslang.y @@ -75,6 +75,7 @@ template %token indent "indentation" %token assignment "=" %token colon ":" +%token comma "," %token lparen "(" %token rparen ")" %token plus "+" @@ -92,6 +93,7 @@ template %token greater ">" %token less_equals "<=" %token greater_equals ">=" +%token arrow "->" %token lit_i8 %token lit_u8 @@ -116,6 +118,8 @@ template %token else %token while %token as +%token func +%token return %token true %token false @@ -137,6 +141,9 @@ template %type indented_statement_list %type indentation %type statement +%type > function_definition_argument_list +%type > nonempty_function_definition_argument_list +%type function_return_type %type variable_declaration %type variable_keyword %type type_expression @@ -150,6 +157,8 @@ template %type mult_expression %type not_expression %type base_expression +%type > function_call_argument_list +%type > nonempty_function_call_argument_list %type literal %% @@ -177,6 +186,23 @@ statement | else colon { $$ = ast::else_block{{}}; } | else if expression colon { $$ = ast::else_if_block{std::make_unique($3), {}}; } | while expression colon { $$ = ast::while_block{std::make_unique($2), {}}; } +| func name lparen function_definition_argument_list rparen function_return_type colon { $$ = ast::function_definition{$2, $4, $6, {}}; } +| return expression { $$ = ast::return_statement{std::make_unique($2)}; } +; + +function_definition_argument_list +: %empty { std::vector tmp; $$ = std::move(tmp); } +| nonempty_function_definition_argument_list { $$ = $1; } +; + +nonempty_function_definition_argument_list +: name colon type_expression { std::vector tmp; tmp.push_back({.name = $1, .type = std::make_unique($3)}); $$ = std::move(tmp); } +| nonempty_function_definition_argument_list comma name colon type_expression { auto tmp = $1; tmp.push_back({.name = $3, .type = std::make_unique($5)}); $$ = std::move(tmp); } +; + +function_return_type +: arrow type_expression { $$ = std::make_unique($2); } +| %empty { $$ = std::make_unique(type::unit_type{}); } ; variable_declaration @@ -262,6 +288,17 @@ base_expression : literal | name { $$ = ast::identifier{$1}; } | lparen expression rparen { $$ = $2; } +| name lparen function_call_argument_list rparen { $$ = ast::function_call{$1, $3}; } +; + +function_call_argument_list +: %empty { std::vector tmp; $$ = std::move(tmp); } +| nonempty_function_call_argument_list +; + +nonempty_function_call_argument_list +: expression { std::vector tmp; tmp.push_back(std::make_unique($1)); $$ = std::move(tmp); } +| nonempty_function_call_argument_list comma expression { auto tmp = $1; tmp.push_back(std::make_unique($3)); $$ = std::move(tmp); } ; literal diff --git a/libs/parser/source/finilize.cpp b/libs/parser/source/finilize.cpp index 895f61f..da5f755 100644 --- a/libs/parser/source/finilize.cpp +++ b/libs/parser/source/finilize.cpp @@ -55,6 +55,17 @@ namespace pslang::parser return node.statements.get(); } + ast::statement_list * get_statement_list(ast::function_definition & node) + { + node.statements = std::make_unique(); + return node.statements.get(); + } + + ast::statement_list * get_statement_list(ast::return_statement &) + { + return nullptr; + } + ast::statement_list * get_statement_list(ast::statement & statement) { return std::visit([](auto & value){ return get_statement_list(value); }, statement); diff --git a/libs/type/include/pslang/type/type_fwd.hpp b/libs/type/include/pslang/type/type_fwd.hpp index 26c647e..07b7084 100644 --- a/libs/type/include/pslang/type/type_fwd.hpp +++ b/libs/type/include/pslang/type/type_fwd.hpp @@ -7,6 +7,6 @@ namespace pslang::type struct type; - using type_ptr = std::unique_ptr; + using type_ptr = std::shared_ptr; } diff --git a/spec.txt b/spec.txt index a3963a2..2e42b97 100644 --- a/spec.txt +++ b/spec.txt @@ -146,10 +146,6 @@ Function types: (T1, T2, T3) -> i32 (T1, T2) -> unit // no return value -Function declaration (required for e.g. loops in call graph): - func foo(x: i32, y: i32) -> i32 - func bar(x: f32) // same as -> unit - Function definition: func foo(x: i32, y: i32) -> i32: return x * y