diff --git a/examples/test.psl b/examples/test.psl index 09b5dce..4749d3d 100644 --- a/examples/test.psl +++ b/examples/test.psl @@ -1,6 +1,10 @@ -func factorial(n : u32) -> u32: - if n == 0u: - return 1u - return n * factorial(n - 1u) +struct vec2: + x : f32 + y : f32 -let x = factorial(10u) +func add(a : vec2, b : vec2) -> vec2: + return vec2(a.x + b.x, a.y + b.y) + +mut v = add(vec2(1.0, 2.0), vec2(3.0, 4.0)) +v.x = -v.x +v.y = -v.y diff --git a/libs/ast/include/pslang/ast/expression.hpp b/libs/ast/include/pslang/ast/expression.hpp index 3b82150..fb70dde 100644 --- a/libs/ast/include/pslang/ast/expression.hpp +++ b/libs/ast/include/pslang/ast/expression.hpp @@ -6,6 +6,7 @@ #include #include #include +#include #include namespace pslang::ast @@ -32,7 +33,8 @@ namespace pslang::ast cast_operation, function_call, array, - array_access + array_access, + field_access >; struct expression diff --git a/libs/ast/include/pslang/ast/statement.hpp b/libs/ast/include/pslang/ast/statement.hpp index f968469..e5198df 100644 --- a/libs/ast/include/pslang/ast/statement.hpp +++ b/libs/ast/include/pslang/ast/statement.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -42,7 +43,9 @@ namespace pslang::ast if_chain, while_block, function_definition, - return_statement + return_statement, + field_definition, + struct_definition >; struct statement diff --git a/libs/ast/include/pslang/ast/struct.hpp b/libs/ast/include/pslang/ast/struct.hpp new file mode 100644 index 0000000..adfa6cc --- /dev/null +++ b/libs/ast/include/pslang/ast/struct.hpp @@ -0,0 +1,30 @@ +#pragma once + +#include +#include + +#include +#include + +namespace pslang::ast +{ + + struct field_definition + { + std::string name; + type::type_ptr type; + }; + + struct struct_definition + { + std::string name; + std::vector fields; + }; + + struct field_access + { + expression_ptr object; + std::string field_name; + }; + +} diff --git a/libs/ast/source/print.cpp b/libs/ast/source/print.cpp index 575e860..d9eb338 100644 --- a/libs/ast/source/print.cpp +++ b/libs/ast/source/print.cpp @@ -146,6 +146,13 @@ namespace pslang::ast print(out, node.index, child(options)); } + void print_impl(std::ostream & out, field_access const & node, print_options const & options) + { + put_indent(out, options); + out << "field access { name = \"" << node.field_name << "\" }\n"; + print(out, node.object, child(options)); + } + void print_impl(std::ostream & out, expression_ptr const & node, print_options const & options) { std::visit([&](auto const & value){ print_impl(out, value, options); }, *node); @@ -251,6 +258,22 @@ namespace pslang::ast print(out, node.value, child(options)); } + void print_impl(std::ostream & out, field_definition const & node, print_options const & options) + { + put_indent(out, options); + out << "field { name = \"" << node.name << "\", type = "; + type::print(out, *node.type); + out << " }\n"; + } + + void print_impl(std::ostream & out, struct_definition const & node, print_options const & options) + { + put_indent(out, options); + out << "struct { name = \"" << node.name << "\" }\n"; + for (auto const & field : node.fields) + print_impl(out, field, child(options)); + } + void print_impl(std::ostream & out, statement_ptr const & node, print_options const & options) { std::visit([&](auto const & value){ print_impl(out, value, options); }, *node); diff --git a/libs/interpreter/include/pslang/interpreter/context.hpp b/libs/interpreter/include/pslang/interpreter/context.hpp index 9a554d2..42d515e 100644 --- a/libs/interpreter/include/pslang/interpreter/context.hpp +++ b/libs/interpreter/include/pslang/interpreter/context.hpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -25,10 +26,21 @@ namespace pslang::interpreter ast::statement_list_ptr statements; }; + struct struct_data + { + std::vector fields; + }; + struct scope { std::unordered_map variables; std::unordered_map functions; + std::unordered_map structs; + + bool contains(std::string const & name) + { + return variables.count(name) > 0 || functions.count(name) > 0 || structs.count(name) > 0; + } bool is_function_scope = false; value return_value = unit_value{}; diff --git a/libs/interpreter/include/pslang/interpreter/value.hpp b/libs/interpreter/include/pslang/interpreter/value.hpp index 33bf59f..f737f2c 100644 --- a/libs/interpreter/include/pslang/interpreter/value.hpp +++ b/libs/interpreter/include/pslang/interpreter/value.hpp @@ -7,6 +7,7 @@ #include #include #include +#include namespace pslang::interpreter { @@ -62,10 +63,17 @@ namespace pslang::interpreter std::vector elements; }; + struct struct_value + { + type::type_ptr struct_type; + std::unordered_map fields; + }; + using value_impl = std::variant< unit_value, primitive_value, - array_value + array_value, + struct_value >; struct value diff --git a/libs/interpreter/source/eval.cpp b/libs/interpreter/source/eval.cpp index d937310..4db25e0 100644 --- a/libs/interpreter/source/eval.cpp +++ b/libs/interpreter/source/eval.cpp @@ -439,6 +439,14 @@ namespace pslang::interpreter return std::visit([&](auto const & value){ return cast_impl(value, type); }, value); } + value cast_impl(struct_value const & value, type::type const & type) + { + if (type::equal(type, type::unit_type{})) + return value; + + throw std::runtime_error("Cannot cast struct type to anything"); + } + value eval_impl(context & context, ast::cast_operation const & cast_operation) { auto arg = eval(context, cast_operation.expression); @@ -501,6 +509,43 @@ namespace pslang::interpreter context.scope_stack.pop_back(); return result; } + + if (auto jt = it->structs.find(function_call.name); jt != it->structs.end()) + { + if (jt->second.fields.size() != function_call.arguments.size()) + { + std::ostringstream os; + os << "Cannot create struct \"" << function_call.name << "\": expected " << jt->second.fields.size() << " fields, got " << function_call.arguments.size(); + throw std::runtime_error(os.str()); + } + + std::vector args; + for (auto const & expression : function_call.arguments) + args.push_back(eval(context, expression)); + + std::unordered_map fields; + + for (std::size_t i = 0; i < args.size(); ++i) + { + auto actual_type = type_of(args[i]); + if (!type::equal(actual_type, *jt->second.fields[i].type)) + { + std::ostringstream os; + os << "Cannot create struct \"" << function_call.name << "\": field " << jt->second.fields[i].name << " expects type "; + type::print(os, *jt->second.fields[i].type); + os << " but actual type is "; + type::print(os, actual_type); + throw std::runtime_error(os.str()); + } + + fields[jt->second.fields[i].name] = std::make_unique(std::move(args[i])); + } + + return struct_value{ + .struct_type = std::make_unique(type::identifier{function_call.name}), + .fields = std::move(fields), + }; + } } throw std::runtime_error("Function \"" + function_call.name + "\" is not defined"); @@ -554,6 +599,28 @@ namespace pslang::interpreter throw std::runtime_error(os.str()); } + value eval_impl(context & context, ast::field_access const & field_access) + { + auto object = eval(context, field_access.object); + if (auto value = std::get_if(&object)) + { + if (auto it = value->fields.find(field_access.field_name); it != value->fields.end()) + return *it->second; + + std::ostringstream os; + os << "Struct "; + type::print(os, type_of(object)); + os << " has no field named \"" << field_access.field_name << "\""; + throw std::runtime_error(os.str()); + } + + std::ostringstream os; + os << "Value of type "; + type::print(os, type_of(object)); + os << " is not a struct"; + throw std::runtime_error(os.str()); + } + value eval_impl(context & context, ast::expression_ptr const & expression) { return std::visit([&](auto const & expression){ return eval_impl(context, expression); }, *expression); @@ -621,6 +688,28 @@ namespace pslang::interpreter throw std::runtime_error(os.str()); } + value * eval_ref_impl(context & context, ast::field_access const & field_access) + { + auto object_ref = eval_ref(context, field_access.object); + if (auto value = std::get_if(object_ref)) + { + if (auto it = value->fields.find(field_access.field_name); it != value->fields.end()) + return it->second.get(); + + std::ostringstream os; + os << "Struct "; + type::print(os, type_of(*object_ref)); + os << " has no field named \"" << field_access.field_name << "\""; + throw std::runtime_error(os.str()); + } + + std::ostringstream os; + os << "Value of type "; + type::print(os, type_of(*object_ref)); + os << " is not a struct"; + throw std::runtime_error(os.str()); + } + value * eval_ref_impl(context & context, ast::expression_ptr const & expression) { return std::visit([&](auto const & expression){ return eval_ref_impl(context, expression); }, *expression); diff --git a/libs/interpreter/source/exec.cpp b/libs/interpreter/source/exec.cpp index 1348703..d0fd8b9 100644 --- a/libs/interpreter/source/exec.cpp +++ b/libs/interpreter/source/exec.cpp @@ -30,6 +30,37 @@ namespace pslang::interpreter context & context_; }; + type::type resolve_type(context & context, type::type const & type); + + type::type resolve_type(context &, type::unit_type const & type) + { + return type; + } + + type::type resolve_type(context &, type::primitive_type const & type) + { + return type; + } + + type::type resolve_type(context & context, type::array_type const & type) + { + return type::array_type{std::make_unique(resolve_type(context, *type.element_type)), type.size}; + } + + type::type resolve_type(context & context, type::identifier const & type) + { + for (auto it = context.scope_stack.rbegin(); it != context.scope_stack.rend(); ++it) + if (it->structs.count(type.name)) + return type; + + throw std::runtime_error("Type \"" + type.name + "\" is not defined"); + } + + type::type resolve_type(context & context, type::type const & type) + { + return std::visit([&](auto const & type){ return resolve_type(context, type); }, type); + } + void exec_impl(context & context, ast::expression_ptr const & expression) { eval(context, expression); @@ -59,18 +90,19 @@ namespace pslang::interpreter void exec_impl(context & context, ast::variable_declaration const & variable_declaration) { auto & scope = context.scope_stack.back(); - if (scope.variables.count(variable_declaration.name) > 0) - throw std::runtime_error("Error: variable \"" + variable_declaration.name + "\" is already declared"); + if (scope.contains(variable_declaration.name)) + throw std::runtime_error("Identifier \"" + variable_declaration.name + "\" is already defined in this scope"); auto value = eval(context, variable_declaration.initializer); if (variable_declaration.type) { + auto expected_type = resolve_type(context, *variable_declaration.type); auto actual_type = type_of(value); - if (!type::equal(*variable_declaration.type, actual_type)) + if (!type::equal(expected_type, actual_type)) { std::ostringstream os; os << "Cannot initialize a variable of type "; - type::print(os, *variable_declaration.type); + type::print(os, expected_type); os << " with an expression of type "; type::print(os, actual_type); throw std::runtime_error(os.str()); @@ -155,12 +187,15 @@ namespace pslang::interpreter void exec_impl(context & context, ast::function_definition const & function_definition) { auto & scope = context.scope_stack.back(); - if (scope.functions.count(function_definition.name) > 0) - throw std::runtime_error("Function \"" + function_definition.name + "\" is already defined"); + if (scope.contains(function_definition.name)) + throw std::runtime_error("Identifier \"" + function_definition.name + "\" is already defined in this scope"); auto & function = scope.functions[function_definition.name]; - function.arguments = function_definition.arguments; - function.return_type = function_definition.return_type; + + for (auto const & argument : function_definition.arguments) + function.arguments.push_back({.name = argument.name, .type = std::make_unique(resolve_type(context, *argument.type))}); + + function.return_type = std::make_unique(resolve_type(context, *function_definition.return_type)); function.statements = function_definition.statements; } @@ -179,6 +214,21 @@ namespace pslang::interpreter throw std::runtime_error("Cannot return outside of a function"); } + void exec_impl(context & context, ast::field_definition const & field_definition) + { + throw std::runtime_error("Internal interpreter error: field definitions cannot be present in the final AST outside of struct definitions"); + } + + void exec_impl(context & context, ast::struct_definition const & struct_definition) + { + auto & scope = context.scope_stack.back(); + if (scope.contains(struct_definition.name)) + throw std::runtime_error("Identifier \"" + struct_definition.name + "\" is already defined in this scope"); + + auto & result = scope.structs[struct_definition.name]; + result.fields = struct_definition.fields; + } + void exec_impl(context & context, ast::statement_list_ptr const & statements) { for (auto const & statement : statements->statements) diff --git a/libs/interpreter/source/value.cpp b/libs/interpreter/source/value.cpp index 8a566b1..fb8f65c 100644 --- a/libs/interpreter/source/value.cpp +++ b/libs/interpreter/source/value.cpp @@ -29,6 +29,11 @@ namespace pslang::interpreter return type::array_type{.element_type = value.element_type, .size = value.elements.size()}; } + type::type type_of_impl(struct_value const & value) + { + return *value.struct_type; + } + type::type type_of_impl(value const & value) { return std::visit([](auto const & value){ return type_of_impl(value); }, value); @@ -85,6 +90,21 @@ namespace pslang::interpreter out << "]"; } + void print_impl(std::ostream & out, struct_value const & value) + { + out << "{"; + bool first = true; + for (auto const & field : value.fields) + { + if (!first) + out << ", "; + first = false; + out << field.first << " = "; + print(out, *field.second); + } + out << "}"; + } + void print_impl(std::ostream & out, value const & value) { std::visit([&](auto const & value){ return print_impl(out, value); }, value); diff --git a/libs/parser/rules/pslang.l b/libs/parser/rules/pslang.l index a154a27..21e53ca 100644 --- a/libs/parser/rules/pslang.l +++ b/libs/parser/rules/pslang.l @@ -30,6 +30,7 @@ while { return bp::make_while(ctx.location); } as { return bp::make_as(ctx.location); } func { return bp::make_func(ctx.location); } return { return bp::make_return(ctx.location); } +struct { return bp::make_struct(ctx.location); } true { return bp::make_true(ctx.location); } false { return bp::make_false(ctx.location); } @@ -53,6 +54,7 @@ f64 { return bp::make_f64(ctx.location); } "=" { return bp::make_assignment(ctx.location); } ":" { return bp::make_colon(ctx.location); } "," { return bp::make_comma(ctx.location); } +"." { return bp::make_dot(ctx.location); } "(" { return bp::make_lparen(ctx.location); } ")" { return bp::make_rparen(ctx.location); } "[" { return bp::make_lbracket(ctx.location); } diff --git a/libs/parser/rules/pslang.y b/libs/parser/rules/pslang.y index 7dc449d..6c3a9c4 100644 --- a/libs/parser/rules/pslang.y +++ b/libs/parser/rules/pslang.y @@ -76,6 +76,7 @@ template %token assignment "=" %token colon ":" %token comma "," +%token dot "." %token lparen "(" %token rparen ")" %token lbracket "[" @@ -122,6 +123,7 @@ template %token as %token func %token return +%token struct %token true %token false @@ -192,6 +194,8 @@ statement | func name lparen function_definition_argument_list rparen function_return_type colon { $$ = ast::function_definition{$2, $4, $6, {}}; } | return expression { $$ = ast::return_statement{std::make_unique($2)}; } | return { $$ = ast::return_statement{nullptr}; } +| struct name colon { $$ = ast::struct_definition{$2, {}}; } +| name colon type_expression { $$ = ast::field_definition{$1, std::make_unique($3)}; } ; function_definition_argument_list @@ -223,6 +227,7 @@ variable_keyword type_expression : unit { $$ = type::type(type::unit_type{}); } | primitive_type { $$ = type::type($1); } +| name { $$ = type::identifier{$1}; } | type_expression lbracket lit_i32 rbracket { $$ = type::array_type{std::make_unique($1), std::stoull($3)}; } ; @@ -292,6 +297,7 @@ not_expression postfix_expression : base_expression | postfix_expression lbracket expression rbracket { $$ = ast::array_access{std::make_unique($1), std::make_unique($3)}; } +| postfix_expression dot name { $$ = ast::field_access{std::make_unique($1), $3}; } ; base_expression diff --git a/libs/parser/source/finilize.cpp b/libs/parser/source/finilize.cpp index da5f755..5a39a74 100644 --- a/libs/parser/source/finilize.cpp +++ b/libs/parser/source/finilize.cpp @@ -66,6 +66,16 @@ namespace pslang::parser return nullptr; } + ast::statement_list * get_statement_list(ast::field_definition &) + { + return nullptr; + } + + ast::statement_list * get_statement_list(ast::struct_definition &) + { + return nullptr; + } + ast::statement_list * get_statement_list(ast::statement & statement) { return std::visit([](auto & value){ return get_statement_list(value); }, statement); @@ -77,10 +87,34 @@ namespace pslang::parser { ast::statement_list_ptr result = std::make_unique(); - std::vector stack; + using stack_entry = std::variant; + + std::vector stack; stack.push_back(result.get()); std::size_t current_indent = 0; + auto current_statement_list = [&]() -> ast::statement_list * + { + if (stack.empty()) + throw std::runtime_error("Internal error: empty finilization stack"); + + if (auto list = std::get_if(&stack.back())) + return *list; + + throw std::runtime_error("Unexpected statement inside struct definition"); + }; + + auto current_struct_definition = [&]() -> ast::struct_definition * + { + if (stack.empty()) + throw std::runtime_error("Internal error: empty finilization stack"); + + if (auto list = std::get_if(&stack.back())) + return *list; + + throw std::runtime_error("Unexpected statement outside struct definition"); + }; + for (auto & statement : statements.statements) { if (statement.indentation > current_indent) @@ -102,13 +136,13 @@ namespace pslang::parser { ast::if_chain chain; chain.blocks.push_back({.condition = std::move(if_block->condition), .statements = std::move(if_block->statements)}); - stack.back()->statements.push_back(std::make_unique(std::move(chain))); + current_statement_list()->statements.push_back(std::make_unique(std::move(chain))); } else if (auto else_block = std::get_if(statement.statement.get())) { - if (stack.back()->statements.empty()) + if (current_statement_list()->statements.empty()) throw std::runtime_error("Unexpected else block"); - auto chain = std::get_if(stack.back()->statements.back().get()); + auto chain = std::get_if(current_statement_list()->statements.back().get()); if (!chain || chain->blocks.empty() || !chain->blocks.back().condition) throw std::runtime_error("Unexpected else block"); @@ -116,17 +150,27 @@ namespace pslang::parser } else if (auto else_if_block = std::get_if(statement.statement.get())) { - if (stack.back()->statements.empty()) + if (current_statement_list()->statements.empty()) throw std::runtime_error("Unexpected else if block"); - auto chain = std::get_if(stack.back()->statements.back().get()); + auto chain = std::get_if(current_statement_list()->statements.back().get()); if (!chain || chain->blocks.empty() || !chain->blocks.back().condition) throw std::runtime_error("Unexpected else if block"); chain->blocks.push_back({.condition = std::move(else_if_block->condition), .statements = std::move(else_if_block->statements)}); } + else if (auto field_definition = std::get_if(statement.statement.get())) + { + current_struct_definition()->fields.push_back(*field_definition); + } + else if (std::get_if(statement.statement.get())) + { + current_statement_list()->statements.push_back(std::move(statement.statement)); + stack.push_back(std::get_if(current_statement_list()->statements.back().get())); + ++current_indent; + } else { - stack.back()->statements.push_back(std::move(statement.statement)); + current_statement_list()->statements.push_back(std::move(statement.statement)); } if (list) diff --git a/libs/type/include/pslang/type/identifier.hpp b/libs/type/include/pslang/type/identifier.hpp new file mode 100644 index 0000000..d02790d --- /dev/null +++ b/libs/type/include/pslang/type/identifier.hpp @@ -0,0 +1,18 @@ +#pragma once + +#include + +namespace pslang::type +{ + + struct identifier + { + std::string name; + }; + + inline bool operator == (identifier const & t1, identifier const & t2) + { + return t1.name == t2.name; + } + +} diff --git a/libs/type/include/pslang/type/type.hpp b/libs/type/include/pslang/type/type.hpp index 3232bdf..c605e80 100644 --- a/libs/type/include/pslang/type/type.hpp +++ b/libs/type/include/pslang/type/type.hpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -13,7 +14,8 @@ namespace pslang::type using type_impl = std::variant< unit_type, primitive_type, - array_type + array_type, + identifier >; struct type diff --git a/libs/type/source/print.cpp b/libs/type/source/print.cpp index f08e164..6ec2aae 100644 --- a/libs/type/source/print.cpp +++ b/libs/type/source/print.cpp @@ -78,6 +78,11 @@ namespace pslang::type out << "[" << type.size << "]"; } + void print_impl(std::ostream & out, identifier const & type) + { + out << type.name; + } + void print_impl(std::ostream & out, type const & type) { std::visit([&](auto const & value){ print_impl(out, value); }, type);