Add support for function types

This commit is contained in:
Nikita Lisitsa 2025-12-18 15:37:38 +03:00
parent 53d4e12a09
commit 61f1a9c079
5 changed files with 73 additions and 1 deletions

View file

@ -30,6 +30,15 @@ namespace pslang::interpreter
return type::array_type{std::make_unique<type::type>(resolve_type(context, *type.element_type)), type.size}; return type::array_type{std::make_unique<type::type>(resolve_type(context, *type.element_type)), type.size};
} }
type::type resolve_type_impl(context & context, type::function_type const & type)
{
type::function_type result;
for (auto const & argument : type.arguments)
result.arguments.push_back(std::make_unique<type::type>(resolve_type(context, *argument)));
result.result = std::make_unique<type::type>(resolve_type(context, *type.result));
return result;
}
type::type resolve_type_impl(context & context, type::identifier const & type) type::type resolve_type_impl(context & context, type::identifier const & type)
{ {
for (auto it = context.scope_stack.rbegin(); it != context.scope_stack.rend(); ++it) for (auto it = context.scope_stack.rbegin(); it != context.scope_stack.rend(); ++it)

View file

@ -131,6 +131,7 @@ template <typename T>
%token end 0 %token end 0
%right arrow
%left ampersand vertical_bar circumflex %left ampersand vertical_bar circumflex
%left equals not_equals less greater less_equals greater_equals %left equals not_equals less greater less_equals greater_equals
%nonassoc as %nonassoc as
@ -138,6 +139,7 @@ template <typename T>
%left plus minus %left plus minus
%left asterisk slash percent %left asterisk slash percent
%precedence NOT %precedence NOT
%precedence lbracket
%type <indented_statement_list> indented_statement_list %type <indented_statement_list> indented_statement_list
%type <std::size_t> indentation %type <std::size_t> indentation
@ -149,6 +151,8 @@ template <typename T>
%type <ast::value_category> variable_keyword %type <ast::value_category> variable_keyword
%type <type::type> type_expression %type <type::type> type_expression
%type <type::primitive_type> primitive_type %type <type::primitive_type> primitive_type
%type <std::vector<type::type_ptr>> function_paren_type_list
%type <std::vector<type::type_ptr>> two_or_more_type_list
%type <ast::expression> expression %type <ast::expression> expression
%type <ast::expression> postfix_expression %type <ast::expression> postfix_expression
%type <ast::expression> base_expression %type <ast::expression> base_expression
@ -219,6 +223,9 @@ type_expression
| primitive_type { $$ = type::type($1); } | primitive_type { $$ = type::type($1); }
| name { $$ = type::identifier{$1}; } | name { $$ = type::identifier{$1}; }
| type_expression lbracket lit_i32 rbracket { $$ = type::array_type{std::make_unique<type::type>($1), std::stoull($3)}; } | type_expression lbracket lit_i32 rbracket { $$ = type::array_type{std::make_unique<type::type>($1), std::stoull($3)}; }
| type_expression arrow type_expression { std::vector<type::type_ptr> args; args.push_back(std::make_unique<type::type>($1)); $$ = type::function_type{std::move(args), std::make_unique<type::type>($3)}; }
| lparen function_paren_type_list rparen arrow type_expression { $$ = type::function_type{$2, std::make_unique<type::type>($5)}; }
| lparen type_expression rparen { $$ = $2; }
; ;
primitive_type primitive_type
@ -235,6 +242,16 @@ primitive_type
| f64 { $$ = type::f64_type{}; } | f64 { $$ = type::f64_type{}; }
; ;
function_paren_type_list
: %empty { std::vector<type::type_ptr> tmp; $$ = std::move(tmp); }
| two_or_more_type_list { $$ = $1; }
;
two_or_more_type_list
: type_expression comma type_expression { std::vector<type::type_ptr> tmp; tmp.push_back(std::make_unique<type::type>($1)); tmp.push_back(std::make_unique<type::type>($3)); $$ = std::move(tmp); }
| two_or_more_type_list comma type_expression { auto tmp = $1; tmp.push_back(std::make_unique<type::type>($3)); $$ = std::move(tmp); }
;
expression expression
: expression ampersand expression { $$ = ast::binary_operation{ast::binary_operation_type::logical_and, std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$ }; } : expression ampersand expression { $$ = ast::binary_operation{ast::binary_operation_type::logical_and, std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$ }; }
| expression vertical_bar expression { $$ = ast::binary_operation{ast::binary_operation_type::logical_or, std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$ }; } | expression vertical_bar expression { $$ = ast::binary_operation{ast::binary_operation_type::logical_or, std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$ }; }
@ -260,7 +277,6 @@ postfix_expression
: base_expression : base_expression
| postfix_expression lbracket expression rbracket { $$ = ast::array_access{std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$}; } | postfix_expression lbracket expression rbracket { $$ = ast::array_access{std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$}; }
| postfix_expression dot name { $$ = ast::field_access{std::make_unique<ast::expression>($1), $3, @$}; } | postfix_expression dot name { $$ = ast::field_access{std::make_unique<ast::expression>($1), $3, @$}; }
//| postfix_expression dot name lparen comma_separated_expression_list rparen { auto args = $5; args.insert(args.begin(), std::make_unique<ast::expression>($1)); $$ = ast::function_call{$3, std::move(args), @$}; }
| postfix_expression lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique<ast::expression>($1), $3, @$}; } | postfix_expression lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique<ast::expression>($1), $3, @$}; }
; ;

View file

@ -0,0 +1,31 @@
#pragma once
#include <pslang/type/type_fwd.hpp>
#include <vector>
namespace pslang::type
{
struct function_type
{
std::vector<type_ptr> arguments;
type_ptr result;
};
inline bool operator == (function_type const & f1, function_type const & f2)
{
if (f1.arguments.size() != f2.arguments.size())
return false;
for (std::size_t i = 0; i < f1.arguments.size(); ++i)
if (!equal(*f1.arguments[i], *f2.arguments[i]))
return false;
if (!equal(*f1.result, *f2.result))
return false;
return true;
}
}

View file

@ -3,6 +3,7 @@
#include <pslang/type/unit.hpp> #include <pslang/type/unit.hpp>
#include <pslang/type/primitive.hpp> #include <pslang/type/primitive.hpp>
#include <pslang/type/array.hpp> #include <pslang/type/array.hpp>
#include <pslang/type/function.hpp>
#include <pslang/type/identifier.hpp> #include <pslang/type/identifier.hpp>
#include <pslang/type/type_fwd.hpp> #include <pslang/type/type_fwd.hpp>
@ -15,6 +16,7 @@ namespace pslang::type
unit_type, unit_type,
primitive_type, primitive_type,
array_type, array_type,
function_type,
identifier identifier
>; >;

View file

@ -78,6 +78,20 @@ namespace pslang::type
out << "[" << type.size << "]"; out << "[" << type.size << "]";
} }
void print_impl(std::ostream & out, function_type const & type)
{
out << '(';
bool first = true;
for (auto const & argument : type.arguments)
{
if (!first) out << ", ";
first = false;
print(out, *argument);
}
out << ") -> ";
print(out, *type.result);
}
void print_impl(std::ostream & out, identifier const & type) void print_impl(std::ostream & out, identifier const & type)
{ {
out << type.name; out << type.name;