Implement array types & array values

This commit is contained in:
Nikita Lisitsa 2025-12-17 15:46:51 +03:00
parent bc7d24ad62
commit 435aa61fe4
15 changed files with 273 additions and 18 deletions

View file

@ -0,0 +1,21 @@
#pragma once
#include <pslang/ast/expression_fwd.hpp>
#include <vector>
namespace pslang::ast
{
struct array
{
std::vector<expression_ptr> elements;
};
struct array_access
{
expression_ptr array;
expression_ptr index;
};
}

View file

@ -5,6 +5,7 @@
#include <pslang/ast/operation.hpp> #include <pslang/ast/operation.hpp>
#include <pslang/ast/cast.hpp> #include <pslang/ast/cast.hpp>
#include <pslang/ast/function.hpp> #include <pslang/ast/function.hpp>
#include <pslang/ast/array.hpp>
#include <pslang/ast/expression_fwd.hpp> #include <pslang/ast/expression_fwd.hpp>
namespace pslang::ast namespace pslang::ast
@ -29,7 +30,9 @@ namespace pslang::ast
unary_operation, unary_operation,
binary_operation, binary_operation,
cast_operation, cast_operation,
function_call function_call,
array,
array_access
>; >;
struct expression struct expression

View file

@ -30,6 +30,8 @@ namespace pslang::ast
void print(std::ostream & out, binary_operation const & node, print_options const & options = {}); void print(std::ostream & out, binary_operation const & node, print_options const & options = {});
void print(std::ostream & out, cast_operation const & node, print_options const & options = {}); void print(std::ostream & out, cast_operation const & node, print_options const & options = {});
void print(std::ostream & out, function_call const & node, print_options const & options = {}); void print(std::ostream & out, function_call const & node, print_options const & options = {});
void print(std::ostream & out, array const & node, print_options const & options = {});
void print(std::ostream & out, array_access const & node, print_options const & options = {});
void print(std::ostream & out, expression_ptr const & node, print_options const & options = {}); void print(std::ostream & out, expression_ptr const & node, print_options const & options = {});
void print(std::ostream & out, assignment const & node, print_options const & options = {}); void print(std::ostream & out, assignment const & node, print_options const & options = {});
void print(std::ostream & out, variable_declaration const & node, print_options const & options = {}); void print(std::ostream & out, variable_declaration const & node, print_options const & options = {});

View file

@ -152,6 +152,24 @@ namespace pslang::ast
print(out, argument, child(options)); print(out, argument, child(options));
} }
void print(std::ostream & out, array const & node, print_options const & options)
{
put_indent(out, options);
out << "array";
newline(out);
for (auto const & element : node.elements)
print(out, element, child(options));
}
void print(std::ostream & out, array_access const & node, print_options const & options)
{
put_indent(out, options);
out << "array access";
newline(out);
print(out, node.array, child(options));
print(out, node.index, child(options));
}
void print(std::ostream & out, expression_ptr const & node, print_options const & options) void print(std::ostream & out, expression_ptr const & node, print_options const & options)
{ {
std::visit([&](auto const & value){ print(out, value, options); }, *node); std::visit([&](auto const & value){ print(out, value, options); }, *node);

View file

@ -1,10 +1,12 @@
#pragma once #pragma once
#include <pslang/type/type.hpp> #include <pslang/type/type.hpp>
#include <pslang/interpreter/value_fwd.hpp>
#include <cstdint> #include <cstdint>
#include <variant> #include <variant>
#include <iostream> #include <iostream>
#include <vector>
namespace pslang::interpreter namespace pslang::interpreter
{ {
@ -53,9 +55,17 @@ namespace pslang::interpreter
using primitive_value_impl::primitive_value_impl; using primitive_value_impl::primitive_value_impl;
}; };
struct array_value
{
// Can't infer type from elements in case of zero-sized array
type::type_ptr element_type;
std::vector<value_ptr> elements;
};
using value_impl = std::variant< using value_impl = std::variant<
unit_value, unit_value,
primitive_value primitive_value,
array_value
>; >;
struct value struct value

View file

@ -0,0 +1,12 @@
#pragma once
#include <memory>
namespace pslang::interpreter
{
struct value;
using value_ptr = std::shared_ptr<value>;
}

View file

@ -5,6 +5,7 @@
#include <pslang/type/print.hpp> #include <pslang/type/print.hpp>
#include <sstream> #include <sstream>
#include <optional>
namespace pslang::interpreter namespace pslang::interpreter
{ {
@ -102,12 +103,14 @@ namespace pslang::interpreter
throw std::runtime_error("Identifier \"" + identifier.name + "\" is not defined"); throw std::runtime_error("Identifier \"" + identifier.name + "\" is not defined");
} }
value unary_operation_impl(ast::unary_operation_type type, unit_value const &) template <typename Value>
value unary_operation_impl(ast::unary_operation_type type, Value const & value)
{ {
std::ostringstream os; std::ostringstream os;
os << "Cannot apply unary operator \""; os << "Cannot apply unary operator \"";
print(os, type); print(os, type);
os << "\" to a value of type unit"; os << "\" to a value of type ";
type::print(os, type_of(value));
throw std::runtime_error(os.str()); throw std::runtime_error(os.str());
} }
@ -260,12 +263,15 @@ namespace pslang::interpreter
throw std::runtime_error(os.str()); throw std::runtime_error(os.str());
} }
value binary_operation_impl_same_type(ast::binary_operation_type type, unit_value const &, value const & arg2) template <typename Value>
value binary_operation_impl_same_type(ast::binary_operation_type type, Value const & arg1, value const & arg2)
{ {
std::ostringstream os; std::ostringstream os;
os << "Cannot apply binary operator \""; os << "Cannot apply binary operator \"";
print(os, type); print(os, type);
os << "\" to values of type unit and "; os << "\" to values of type ";
type::print(os, type_of(arg1));
os << " and ";
type::print(os, type_of(arg2)); type::print(os, type_of(arg2));
throw std::runtime_error(os.str()); throw std::runtime_error(os.str());
} }
@ -311,6 +317,14 @@ namespace pslang::interpreter
throw std::runtime_error("Cannot cast unit type to anything"); throw std::runtime_error("Cannot cast unit type to anything");
} }
value cast_impl(array_value const & value, type::type const & type)
{
if (type::equal(type, type_of(value)))
return value;
throw std::runtime_error("Cannot cast array type to anything");
}
template <typename T> template <typename T>
value cast_impl(primitive_value_base<T> const & value, type::unit_type const &) value cast_impl(primitive_value_base<T> const & value, type::unit_type const &)
{ {
@ -421,6 +435,124 @@ namespace pslang::interpreter
throw std::runtime_error("Function \"" + function_call.name + "\" is not defined"); throw std::runtime_error("Function \"" + function_call.name + "\" is not defined");
} }
value eval_impl(context & context, ast::array const & array)
{
if (array.elements.empty())
throw std::runtime_error("Internal error: array ast node cannot have zero elements");
type::type_ptr element_type;
std::vector<value_ptr> elements;
for (std::size_t i = 0; i < array.elements.size(); ++i)
{
auto element = std::make_unique<value>(eval(context, array.elements[i]));
if (i == 0)
element_type = std::make_unique<type::type>(type_of(*element));
else
{
auto new_type = type_of(*element);
if (!type::equal(*element_type, new_type))
{
std::ostringstream os;
os << "Error forming array: inferred element type is ";
type::print(os, *element_type);
os << " but element #" << i << " type is ";
type::print(os, new_type);
throw std::runtime_error(os.str());
}
}
elements.push_back(std::move(element));
}
return array_value{.element_type = std::move(element_type), .elements = std::move(elements)};
}
value eval_impl(context & context, ast::array_access const & array_access)
{
auto array = eval(context, array_access.array);
auto index = eval(context, array_access.index);
if (auto avalue = std::get_if<array_value>(&array))
{
std::optional<std::uint64_t> index_unsigned;
std::optional<std::int64_t> index_signed;
if (auto pvalue = std::get_if<primitive_value>(&index))
{
if (auto i8 = std::get_if<i8_value>(pvalue))
{
index_signed = i8->value;
}
else if (auto u8 = std::get_if<u8_value>(pvalue))
{
index_unsigned = u8->value;
}
else if (auto i16 = std::get_if<i16_value>(pvalue))
{
index_signed = i16->value;
}
else if (auto u16 = std::get_if<u16_value>(pvalue))
{
index_unsigned = u16->value;
}
else if (auto i32 = std::get_if<i32_value>(pvalue))
{
index_signed = i32->value;
}
else if (auto u32 = std::get_if<u32_value>(pvalue))
{
index_unsigned = u32->value;
}
else if (auto i64 = std::get_if<i64_value>(pvalue))
{
index_signed = i64->value;
}
else if (auto u64 = std::get_if<u64_value>(pvalue))
{
index_unsigned = u64->value;
}
}
if (!index_signed && !index_unsigned)
{
std::ostringstream os;
os << "Cannot index into an array with an expression of type ";
type::print(os, type_of(index));
throw std::runtime_error(os.str());
}
std::uint64_t final_index;
if (index_unsigned)
{
if (*index_unsigned >= avalue->elements.size())
{
std::ostringstream os;
os << "Array index " << *index_unsigned << " out of bounds " << avalue->elements.size();
throw std::runtime_error(os.str());
}
final_index = *index_unsigned;
}
else // if (index_signed)
{
if (*index_signed < 0 || *index_signed >= avalue->elements.size())
{
std::ostringstream os;
os << "Array index " << *index_signed << " out of bounds " << avalue->elements.size();
throw std::runtime_error(os.str());
}
final_index = *index_signed;
}
return *avalue->elements[final_index];
}
std::ostringstream os;
os << "Cannot index into a non-array of type ";
type::print(os, type_of(array));
throw std::runtime_error(os.str());
}
value eval_impl(context & context, ast::expression_ptr const & expression) value eval_impl(context & context, ast::expression_ptr const & expression)
{ {
return std::visit([&](auto const & expression){ return eval_impl(context, expression); }, *expression); return std::visit([&](auto const & expression){ return eval_impl(context, expression); }, *expression);

View file

@ -24,6 +24,11 @@ namespace pslang::interpreter
return std::visit([](auto const & value){ return type_of_impl(value); }, value); return std::visit([](auto const & value){ return type_of_impl(value); }, value);
} }
type::type type_of_impl(array_value const & value)
{
return type::array_type{.element_type = value.element_type, .size = value.elements.size()};
}
type::type type_of_impl(value const & value) type::type type_of_impl(value const & value)
{ {
return std::visit([](auto const & value){ return type_of_impl(value); }, value); return std::visit([](auto const & value){ return type_of_impl(value); }, value);
@ -68,6 +73,18 @@ namespace pslang::interpreter
std::visit([&](auto const & value){ return print_impl(out, value); }, value); std::visit([&](auto const & value){ return print_impl(out, value); }, value);
} }
void print_impl(std::ostream & out, array_value const & value)
{
out << "[";
for (std::size_t i = 0; i < value.elements.size(); ++i)
{
if (i > 0)
out << ", ";
print(out, *value.elements[i]);
}
out << "]";
}
void print_impl(std::ostream & out, value const & value) void print_impl(std::ostream & out, value const & value)
{ {
std::visit([&](auto const & value){ return print_impl(out, value); }, value); std::visit([&](auto const & value){ return print_impl(out, value); }, value);

View file

@ -24,7 +24,7 @@ bison_target(
${PSLANG_PARSER_RULES_FILE} ${PSLANG_PARSER_RULES_FILE}
${PSLANG_PARSER_SOURCE_FILE} ${PSLANG_PARSER_SOURCE_FILE}
DEFINES_FILE ${PSLANG_PARSER_HEADER_FILE} DEFINES_FILE ${PSLANG_PARSER_HEADER_FILE}
COMPILE_FLAGS -Wcounterexamples # COMPILE_FLAGS -Wcounterexamples
) )
add_flex_bison_dependency(generate-pslang-lexer generate-pslang-parser) add_flex_bison_dependency(generate-pslang-lexer generate-pslang-parser)

View file

@ -55,6 +55,8 @@ f64 { return bp::make_f64(ctx.location); }
"," { return bp::make_comma(ctx.location); } "," { return bp::make_comma(ctx.location); }
"(" { return bp::make_lparen(ctx.location); } "(" { return bp::make_lparen(ctx.location); }
")" { return bp::make_rparen(ctx.location); } ")" { return bp::make_rparen(ctx.location); }
"[" { return bp::make_lbracket(ctx.location); }
"]" { return bp::make_rbracket(ctx.location); }
"+" { return bp::make_plus(ctx.location); } "+" { return bp::make_plus(ctx.location); }
"-" { return bp::make_minus(ctx.location); } "-" { return bp::make_minus(ctx.location); }
"*" { return bp::make_asterisk(ctx.location); } "*" { return bp::make_asterisk(ctx.location); }

View file

@ -78,6 +78,8 @@ template <typename T>
%token comma "," %token comma ","
%token lparen "(" %token lparen "("
%token rparen ")" %token rparen ")"
%token lbracket "["
%token rbracket "]"
%token plus "+" %token plus "+"
%token minus "-" %token minus "-"
%token asterisk "*" %token asterisk "*"
@ -156,9 +158,10 @@ template <typename T>
%type <ast::expression> sum_expression %type <ast::expression> sum_expression
%type <ast::expression> mult_expression %type <ast::expression> mult_expression
%type <ast::expression> not_expression %type <ast::expression> not_expression
%type <ast::expression> postfix_expression
%type <ast::expression> base_expression %type <ast::expression> base_expression
%type <std::vector<ast::expression_ptr>> function_call_argument_list %type <std::vector<ast::expression_ptr>> comma_separated_expression_list
%type <std::vector<ast::expression_ptr>> nonempty_function_call_argument_list %type <std::vector<ast::expression_ptr>> nonempty_comma_separated_expression_list
%type <ast::expression> literal %type <ast::expression> literal
%% %%
@ -220,6 +223,7 @@ variable_keyword
type_expression type_expression
: unit { $$ = type::type(type::unit_type{}); } : unit { $$ = type::type(type::unit_type{}); }
| primitive_type { $$ = type::type($1); } | primitive_type { $$ = type::type($1); }
| type_expression lbracket lit_i32 rbracket { $$ = type::array_type{std::make_unique<type::type>($1), std::stoull($3)}; }
; ;
primitive_type primitive_type
@ -281,25 +285,31 @@ mult_expression
; ;
not_expression not_expression
: postfix_expression
| exclamation postfix_expression { $$ = ast::unary_operation{ast::unary_operation_type::logical_not, std::make_unique<ast::expression>($2) }; }
;
postfix_expression
: base_expression : base_expression
| exclamation base_expression { $$ = ast::unary_operation{ast::unary_operation_type::logical_not, std::make_unique<ast::expression>($2) }; } | postfix_expression lbracket expression rbracket { $$ = ast::array_access{std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3)}; }
; ;
base_expression base_expression
: literal : literal
| name { $$ = ast::identifier{$1}; } | name { $$ = ast::identifier{$1}; }
| lparen expression rparen { $$ = $2; } | lparen expression rparen { $$ = $2; }
| name lparen function_call_argument_list rparen { $$ = ast::function_call{$1, $3}; } | name lparen comma_separated_expression_list rparen { $$ = ast::function_call{$1, $3}; }
| lbracket nonempty_comma_separated_expression_list rbracket { $$ = ast::array{$2}; }
; ;
function_call_argument_list comma_separated_expression_list
: %empty { std::vector<ast::expression_ptr> tmp; $$ = std::move(tmp); } : %empty { std::vector<ast::expression_ptr> tmp; $$ = std::move(tmp); }
| nonempty_function_call_argument_list | nonempty_comma_separated_expression_list
; ;
nonempty_function_call_argument_list nonempty_comma_separated_expression_list
: expression { std::vector<ast::expression_ptr> tmp; tmp.push_back(std::make_unique<ast::expression>($1)); $$ = std::move(tmp); } : expression { std::vector<ast::expression_ptr> tmp; tmp.push_back(std::make_unique<ast::expression>($1)); $$ = std::move(tmp); }
| nonempty_function_call_argument_list comma expression { auto tmp = $1; tmp.push_back(std::make_unique<ast::expression>($3)); $$ = std::move(tmp); } | nonempty_comma_separated_expression_list comma expression { auto tmp = $1; tmp.push_back(std::make_unique<ast::expression>($3)); $$ = std::move(tmp); }
; ;
literal literal

View file

@ -0,0 +1,21 @@
#pragma once
#include <pslang/type/type_fwd.hpp>
#include <cstdint>
namespace pslang::type
{
struct array_type
{
type_ptr element_type;
std::uint64_t size;
};
inline bool operator == (array_type const & t1, array_type const & t2)
{
return equal(*t1.element_type, *t2.element_type) && (t1.size == t2.size);
}
}

View file

@ -2,6 +2,7 @@
#include <pslang/type/unit.hpp> #include <pslang/type/unit.hpp>
#include <pslang/type/primitive.hpp> #include <pslang/type/primitive.hpp>
#include <pslang/type/array.hpp>
#include <pslang/type/type_fwd.hpp> #include <pslang/type/type_fwd.hpp>
#include <variant> #include <variant>
@ -11,7 +12,8 @@ namespace pslang::type
using type_impl = std::variant< using type_impl = std::variant<
unit_type, unit_type,
primitive_type primitive_type,
array_type
>; >;
struct type struct type
@ -20,6 +22,4 @@ namespace pslang::type
using type_impl::type_impl; using type_impl::type_impl;
}; };
bool equal(type const & t1, type const & t2);
} }

View file

@ -9,4 +9,5 @@ namespace pslang::type
using type_ptr = std::shared_ptr<type>; using type_ptr = std::shared_ptr<type>;
bool equal(type const & t1, type const & t2);
} }

View file

@ -72,6 +72,12 @@ namespace pslang::type
std::visit([&](auto const & value){ print_impl(out, value); }, type); std::visit([&](auto const & value){ print_impl(out, value); }, type);
} }
void print_impl(std::ostream & out, array_type const & type)
{
print(out, *type.element_type);
out << "[" << type.size << "]";
}
void print_impl(std::ostream & out, type const & type) void print_impl(std::ostream & out, type const & type)
{ {
std::visit([&](auto const & value){ print_impl(out, value); }, type); std::visit([&](auto const & value){ print_impl(out, value); }, type);