Struct types refactor v2: store AST node pointer in struct type

This commit is contained in:
Nikita Lisitsa 2026-03-22 13:34:35 +03:00
parent 78e6a5c0ee
commit 56c63d50ac
14 changed files with 366 additions and 413 deletions

View file

@ -15,6 +15,7 @@ namespace pslang::ast
std::size_t indent_level = 0;
};
void print(std::ostream & out, types::type const & type);
void print(std::ostream & out, type const & node);
void print(std::ostream & out, expression const & node, print_options const & options = {});
void print(std::ostream & out, statement const & node, print_options const & options = {});

View file

@ -3,7 +3,7 @@
#include <pslang/ast/type_visitor.hpp>
#include <pslang/ast/expression_visitor.hpp>
#include <pslang/ast/statement_visitor.hpp>
#include <pslang/types/print.hpp>
#include <pslang/types/type_visitor.hpp>
#include <pslang/types/type.hpp>
#include <iomanip>
@ -26,6 +26,112 @@ namespace pslang::ast
out << options.indent_string;
}
struct raw_type_print_visitor
: types::const_visitor<raw_type_print_visitor>
{
std::ostream & out;
using const_visitor::apply;
void apply(types::unit_type const & type)
{
out << "unit";
}
void apply(types::bool_type const &)
{
out << "bool";
}
void apply(types::i8_type const &)
{
out << "i8";
}
void apply(types::u8_type const &)
{
out << "u8";
}
void apply(types::i16_type const &)
{
out << "i16";
}
void apply(types::u16_type const &)
{
out << "u16";
}
void apply(types::i32_type const &)
{
out << "i32";
}
void apply(types::u32_type const &)
{
out << "u32";
}
void apply(types::i64_type const &)
{
out << "i64";
}
void apply(types::u64_type const &)
{
out << "u64";
}
void apply(types::f16_type const &)
{
out << "f16";
}
void apply(types::f32_type const &)
{
out << "f32";
}
void apply(types::f64_type const &)
{
out << "f64";
}
void apply(types::array_type const & type)
{
apply(*type.element_type);
out << "[" << type.size << "]";
}
void apply(types::function_type const & type)
{
if (type.arguments.size() == 1 && !is_function_type(*type.arguments[0]))
{
apply(*type.arguments.front());
out << " -> ";
apply(*type.result);
return;
}
out << '(';
bool first = true;
for (auto const & argument : type.arguments)
{
if (!first) out << ", ";
first = false;
apply(*argument);
}
out << ") -> ";
apply(*type.result);
}
void apply(types::struct_type const & type)
{
out << type.node->name;
}
};
struct type_print_visitor
: const_type_visitor<type_print_visitor>
{
@ -41,7 +147,7 @@ namespace pslang::ast
template <typename T>
void apply(types::primitive_type_base<T> const & type)
{
types::print(out, types::primitive_type{type});
print(out, types::type{types::primitive_type{type}});
}
void apply(array_type const & type)
@ -396,6 +502,11 @@ namespace pslang::ast
}
void print(std::ostream & out, types::type const & type)
{
raw_type_print_visitor{{}, out}.apply(type);
}
void print(std::ostream & out, type const & node)
{
type_print_visitor{{}, out}.apply(node);

View file

@ -3,7 +3,7 @@
#include <pslang/ast/expression_visitor.hpp>
#include <pslang/ast/statement_visitor.hpp>
#include <pslang/ast/error.hpp>
#include <pslang/types/print.hpp>
#include <pslang/ast/print.hpp>
#include <pslang/types/type.hpp>
#include <pslang/types/type_visitor.hpp>
@ -83,10 +83,10 @@ namespace pslang::ast
return {.size = 8, .alignment = 8};
}
size_and_alignment apply(types::named_type const & type)
size_and_alignment apply(types::struct_type const & type)
{
for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
if (auto jt = it->structs.find(type.name); jt != it->structs.end())
if (auto jt = it->structs.find(type.node->name); jt != it->structs.end())
{
// TODO: better error message (including the resursive inclusion path)
if (!jt->second.layout_ready && jt->second.layout_being_computed)
@ -98,7 +98,7 @@ namespace pslang::ast
return {.size = layout.size, .alignment = layout.alignment};
}
throw std::runtime_error("Unknown type \"" + type.name + "\"");
throw std::runtime_error("Unknown type \"" + type.node->name + "\"");
}
};
@ -162,10 +162,16 @@ namespace pslang::ast
void apply(ast::type_identifier & node)
{
auto type = types::named_type{};
type.name = node.name;
type.level = node.level;
node.inferred_type = std::make_unique<types::type>(std::move(type));
for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
{
if (auto jt = it->structs.find(node.name); jt != it->structs.end())
{
node.inferred_type = std::make_unique<types::type>(types::struct_type{jt->second.node});
return;
}
}
throw std::runtime_error(std::format("Unknown type \"{}\"", node.name));
}
};
@ -328,7 +334,7 @@ namespace pslang::ast
std::ostringstream os;
os << "Cannot apply " << node.type << " to a value of type ";
types::print(os, *arg1_type);
print(os, *arg1_type);
throw type_error(os.str(), node.location);
}
@ -460,9 +466,9 @@ namespace pslang::ast
std::ostringstream os;
os << "Cannot apply " << node.type << " to values of types ";
types::print(os, *arg1_type);
print(os, *arg1_type);
os << " and ";
types::print(os, *arg2_type);
print(os, *arg2_type);
throw type_error(os.str(), node.location);
}
@ -485,9 +491,9 @@ namespace pslang::ast
std::ostringstream os;
os << "Cannot cast a value of type ";
types::print(os, *source_type);
print(os, *source_type);
os << " to type ";
types::print(os, *target_type);
print(os, *target_type);
throw type_error(os.str(), node.location);
}
@ -514,7 +520,7 @@ namespace pslang::ast
{
std::ostringstream os;
os << "Cannot call a value of a non-function type ";
types::print(os, *function_type);
print(os, *function_type);
throw type_error(os.str(), node.location);
}
@ -522,7 +528,7 @@ namespace pslang::ast
{
std::ostringstream os;
os << "Cannot call function " << function_name << " of type ";
types::print(os, *function_type);
print(os, *function_type);
os << ": expected " << ftype->arguments.size() << " arguments, but got " << node.arguments.size();
throw type_error(os.str(), node.location);
}
@ -534,11 +540,11 @@ namespace pslang::ast
{
std::ostringstream os;
os << "Cannot call function " << function_name << " of type ";
types::print(os, *function_type);
print(os, *function_type);
os << ": argument #" << i << " expected to have type ";
types::print(os, *ftype->arguments[i]);
print(os, *ftype->arguments[i]);
os << " but got type ";
types::print(os, *arg_type);
print(os, *arg_type);
throw type_error(os.str(), get_location(*node.arguments[i]));
}
}
@ -559,19 +565,18 @@ namespace pslang::ast
std::ostringstream os;
os << "Cannot create built-in type ";
types::print(os, *type);
print(os, *type);
os << ": expected 0 arguments, but got " << node.arguments.size();
throw type_error(os.str(), node.location);
}
else if (auto named_type = std::get_if<types::named_type>(type.get()))
else if (auto struct_type = std::get_if<types::struct_type>(type.get()))
{
auto const & scope = scopes.at(named_type->level);
auto const & struct_node = *scope.structs.at(named_type->name).node;
auto const & struct_node = *struct_type->node;
if (!node.arguments.empty())
{
if (node.arguments.size() != struct_node.fields.size())
throw type_error(std::format("Cannot create struct {}: expected {} arguments, but got {}", named_type->name, struct_node.fields.size(), node.arguments.size()), node.location);
throw type_error(std::format("Cannot create struct {}: expected {} arguments, but got {}", struct_node.name, struct_node.fields.size(), node.arguments.size()), node.location);
for (std::size_t i = 0; i < node.arguments.size(); ++i)
{
@ -579,10 +584,10 @@ namespace pslang::ast
if (!types::equal(*arg_type, *struct_node.fields[i].inferred_type))
{
std::ostringstream os;
os << "Cannot create struct " << named_type->name << ": argument #" << i << " expected to have type ";
types::print(os, *struct_node.fields[i].inferred_type);
os << "Cannot create struct " << struct_node.name << ": argument #" << i << " expected to have type ";
print(os, *struct_node.fields[i].inferred_type);
os << " but got type ";
types::print(os, *arg_type);
print(os, *arg_type);
throw type_error(os.str(), node.location);
}
}
@ -617,9 +622,9 @@ namespace pslang::ast
{
std::ostringstream os;
os << "Failed to infer array type: element #0 has type ";
types::print(os, *element_type);
print(os, *element_type);
os << " but element #" << i << " has type ";
types::print(os, *current_type);
print(os, *current_type);
}
}
@ -641,7 +646,7 @@ namespace pslang::ast
{
std::ostringstream os;
os << "Expected an integer type as index, but got ";
types::print(os, *index_type);
print(os, *index_type);
throw type_error(os.str(), get_location(*node.index));
}
@ -651,7 +656,7 @@ namespace pslang::ast
{
std::ostringstream os;
os << "Expected an array to index, but got ";
types::print(os, *array_type);
print(os, *array_type);
throw type_error(os.str(), get_location(*node.array));
}
@ -663,17 +668,17 @@ namespace pslang::ast
apply(*node.object);
auto object_type = get_type(*node.object);
auto named_type = std::get_if<types::named_type>(object_type.get());
auto struct_type = std::get_if<types::struct_type>(object_type.get());
if (!named_type)
if (!struct_type)
{
std::ostringstream os;
os << "Expected a struct, but got ";
types::print(os, *object_type);
print(os, *object_type);
throw type_error(os.str(), get_location(*node.object));
}
auto const & struct_node = *scopes.at(named_type->level).structs.at(named_type->name).node;
auto const & struct_node = *struct_type->node;
for (auto const & field : struct_node.fields)
{
@ -684,7 +689,7 @@ namespace pslang::ast
}
}
throw type_error(std::format("Struct \"{}\" has no field named \"{}\"", named_type->name, node.field_name), node.location);
throw type_error(std::format("Struct \"{}\" has no field named \"{}\"", struct_node.name, node.field_name), node.location);
}
void apply(expression_ptr const & node)
@ -714,9 +719,9 @@ namespace pslang::ast
{
std::ostringstream os;
os << "Cannot assign a value of type ";
types::print(os, *rtype);
print(os, *rtype);
os << " to an expression of type ";
types::print(os, *ltype);
print(os, *ltype);
throw type_error(os.str(), node.location);
};
}
@ -733,9 +738,9 @@ namespace pslang::ast
{
std::ostringstream os;
os << "Cannot initialize a variable of type ";
types::print(os, *expected_type);
print(os, *expected_type);
os << " with an expression of type ";
types::print(os, *actual_type);
print(os, *actual_type);
throw type_error(os.str(), node.location);
}
}
@ -773,7 +778,7 @@ namespace pslang::ast
{
std::ostringstream os;
os << "if condition expects a bool type, but got ";
types::print(os, *actual_type);
print(os, *actual_type);
throw type_error(os.str(), get_location(*block.condition));
}
}
@ -792,7 +797,7 @@ namespace pslang::ast
{
std::ostringstream os;
os << "while condition expects a bool type, but got ";
types::print(os, *actual_type);
print(os, *actual_type);
throw type_error(os.str(), get_location(*node.condition));
}
@ -844,9 +849,9 @@ namespace pslang::ast
{
std::ostringstream os;
os << "Returning value of type ";
types::print(os, *actual_type);
print(os, *actual_type);
os << " from a function returning ";
types::print(os, *return_scope.expected_return_type);
print(os, *return_scope.expected_return_type);
throw type_error(os.str(), node.location);
}
}

View file

@ -1,6 +1,6 @@
#include <pslang/interpreter/context.hpp>
#include <pslang/types/type_visitor.hpp>
#include <pslang/types/print.hpp>
#include <pslang/ast/print.hpp>
namespace pslang::interpreter
{
@ -39,13 +39,11 @@ namespace pslang::interpreter
throw std::runtime_error("Cannot zero-initialize a function type");
}
value apply(types::named_type const & named_type)
value apply(types::struct_type const & struct_type)
{
auto const & struct_data = context.frame_stack.at(named_type.level).structs.at(named_type.name);
struct_value result{.struct_type = std::make_unique<types::type>(named_type)};
for (auto const & field : struct_data.fields)
result.fields[field.name] = std::make_unique<value>(apply(*field.type));
struct_value result{.struct_type = std::make_unique<types::type>(struct_type)};
for (auto const & field : struct_type.node->fields)
result.fields[field.name] = std::make_unique<value>(apply(*field.inferred_type));
return std::move(result);
}
};
@ -68,7 +66,7 @@ namespace pslang::interpreter
out << variable.first << " = ";
print(out, variable.second.value);
out << " (";
types::print(out, type_of(variable.second.value));
ast::print(out, type_of(variable.second.value));
out << ")\n";
}
}

View file

@ -4,7 +4,7 @@
#include <pslang/interpreter/error.hpp>
#include <pslang/interpreter/foreign.hpp>
#include <pslang/ast/expression.hpp>
#include <pslang/types/print.hpp>
#include <pslang/ast/print.hpp>
#include <sstream>
#include <optional>
@ -61,7 +61,7 @@ namespace pslang::interpreter
{
std::ostringstream os;
os << "Cannot index into an array with an expression of type ";
types::print(os, type_of(index));
ast::print(os, type_of(index));
throw internal_error(os.str(), location);
}
@ -108,7 +108,7 @@ namespace pslang::interpreter
{
std::ostringstream os;
os << "Cannot apply " << type << " to a value of type ";
types::print(os, type_of(value));
ast::print(os, type_of(value));
throw internal_error(os.str(), location);
}
@ -137,7 +137,7 @@ namespace pslang::interpreter
std::ostringstream os;
os << "Cannot apply " << type << " to a value of type ";
types::print(os, type_of(primitive_value(arg1)));
ast::print(os, type_of(primitive_value(arg1)));
throw internal_error(os.str(), location);
}
@ -267,9 +267,9 @@ namespace pslang::interpreter
std::ostringstream os;
os << "Cannot apply " << type << " to values of type ";
types::print(os, type_of(primitive_value(arg1)));
ast::print(os, type_of(primitive_value(arg1)));
os << " and ";
types::print(os, type_of(primitive_value(arg2)));
ast::print(os, type_of(primitive_value(arg2)));
throw internal_error(os.str(), location);
}
@ -278,9 +278,9 @@ namespace pslang::interpreter
{
std::ostringstream os;
os << "Cannot apply " << type << " to values of type ";
types::print(os, type_of(arg1));
ast::print(os, type_of(arg1));
os << " and ";
types::print(os, type_of(arg2));
ast::print(os, type_of(arg2));
throw internal_error(os.str(), location);
}
@ -299,7 +299,7 @@ namespace pslang::interpreter
{
if (!arg1.value)
return primitive_value(primitive_value_base<T>{false});
value const & arg2_generic = lazy_arg2();
primitive_value_base<T> const & arg2 = std::get<primitive_value_base<T>>(std::get<primitive_value>(arg2_generic));
return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value && arg2.value)});
@ -308,7 +308,7 @@ namespace pslang::interpreter
{
if (arg1.value == T{})
return primitive_value(primitive_value_base<T>{arg1.value});
value const & arg2_generic = lazy_arg2();
primitive_value_base<T> const & arg2 = std::get<primitive_value_base<T>>(std::get<primitive_value>(arg2_generic));
return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value && arg2.value)});
@ -319,7 +319,7 @@ namespace pslang::interpreter
{
if (arg1.value)
return primitive_value(primitive_value_base<T>{true});
value const & arg2_generic = lazy_arg2();
primitive_value_base<T> const & arg2 = std::get<primitive_value_base<T>>(std::get<primitive_value>(arg2_generic));
return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value || arg2.value)});
@ -328,7 +328,7 @@ namespace pslang::interpreter
{
if (arg1.value == ~T{})
return primitive_value(primitive_value_base<T>{arg1.value});
value const & arg2_generic = lazy_arg2();
primitive_value_base<T> const & arg2 = std::get<primitive_value_base<T>>(std::get<primitive_value>(arg2_generic));
return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value && arg2.value)});
@ -340,9 +340,9 @@ namespace pslang::interpreter
std::ostringstream os;
os << "Cannot apply " << type << " to values of type ";
types::print(os, type_of(primitive_value(arg1)));
ast::print(os, type_of(primitive_value(arg1)));
os << " and ";
types::print(os, type_of(lazy_arg2()));
ast::print(os, type_of(lazy_arg2()));
throw internal_error(os.str(), location);
}
@ -351,9 +351,9 @@ namespace pslang::interpreter
{
std::ostringstream os;
os << "Cannot apply " << type << " to values of type ";
types::print(os, type_of(arg1));
ast::print(os, type_of(arg1));
os << " and ";
types::print(os, type_of(lazy_arg2()));
ast::print(os, type_of(lazy_arg2()));
throw internal_error(os.str(), location);
}
@ -374,9 +374,9 @@ namespace pslang::interpreter
{
std::ostringstream os;
os << "Cannot apply " << binary_operation.type << " to values of type ";
types::print(os, *type1);
ast::print(os, *type1);
os << " and ";
types::print(os, *type2);
ast::print(os, *type2);
throw internal_error(os.str(), binary_operation.location);
}
@ -444,9 +444,9 @@ namespace pslang::interpreter
std::ostringstream os;
os << "Cannot cast value of type ";
types::print(os, type_of(primitive_value(value)));
ast::print(os, type_of(primitive_value(value)));
os << " to type ";
types::print(os, types::primitive_type(type));
ast::print(os, types::type{types::primitive_type(type)});
throw internal_error(os.str(), location);
}
@ -528,9 +528,9 @@ namespace pslang::interpreter
{
std::ostringstream os;
os << "Cannot call function: argument #" << (i + 1) << " expects type ";
types::print(os, *fcommon->arguments[i].type);
ast::print(os, *fcommon->arguments[i].type);
os << " but actual type is ";
types::print(os, actual_type);
ast::print(os, actual_type);
throw internal_error(os.str(), ast::get_location(*function_call.arguments[i]));
}
}
@ -572,22 +572,21 @@ namespace pslang::interpreter
std::ostringstream os;
os << "Cannot create built-in type ";
types::print(os, *type);
ast::print(os, *type);
os << ": expected 0 arguments, but got " << function_call.arguments.size();
throw internal_error(os.str(), function_call.location);
}
else if (auto named_type = std::get_if<types::named_type>(type.get()))
else if (auto struct_type = std::get_if<types::struct_type>(type.get()))
{
auto const & scope = context.frame_stack.at(named_type->level);
auto const & data = scope.structs.at(named_type->name);
auto const & struct_node = *struct_type->node;
if (function_call.arguments.empty())
return zero_value(context, *type);
if (data.fields.size() != function_call.arguments.size())
if (struct_node.fields.size() != function_call.arguments.size())
{
std::ostringstream os;
os << "Cannot create struct \"" << named_type->name << "\": expected " << data.fields.size() << " arguments, got " << function_call.arguments.size();
os << "Cannot create struct \"" << struct_node.name << "\": expected " << struct_node.fields.size() << " arguments, got " << function_call.arguments.size();
throw internal_error(os.str(), function_call.location);
}
@ -598,7 +597,7 @@ namespace pslang::interpreter
std::unordered_map<std::string, value_ptr> fields;
for (std::size_t i = 0; i < args.size(); ++i)
fields[data.fields[i].name] = std::make_unique<value>(std::move(args[i]));
fields[struct_node.fields[i].name] = std::make_unique<value>(std::move(args[i]));
return struct_value{
.struct_type = type,
@ -631,9 +630,9 @@ namespace pslang::interpreter
{
std::ostringstream os;
os << "Error forming array: inferred element type is ";
types::print(os, *element_type);
ast::print(os, *element_type);
os << " but element #" << i << " type is ";
types::print(os, new_type);
ast::print(os, new_type);
throw internal_error(os.str(), array.location);
}
}
@ -654,7 +653,7 @@ namespace pslang::interpreter
std::ostringstream os;
os << "Cannot index into a non-array of type ";
types::print(os, type_of(array));
ast::print(os, type_of(array));
throw internal_error(os.str(), array_access.location);
}
@ -668,14 +667,14 @@ namespace pslang::interpreter
std::ostringstream os;
os << "Struct ";
types::print(os, type_of(object));
ast::print(os, type_of(object));
os << " has no field named \"" << field_access.field_name << "\"";
throw internal_error(os.str(), field_access.location);
}
std::ostringstream os;
os << "Value of type ";
types::print(os, type_of(object));
ast::print(os, type_of(object));
os << " is not a struct";
throw internal_error(os.str(), field_access.location);
}
@ -741,7 +740,7 @@ namespace pslang::interpreter
std::ostringstream os;
os << "Cannot index into a non-array of type ";
types::print(os, type_of(*array_ref));
ast::print(os, type_of(*array_ref));
throw internal_error(os.str(), array_access.location);
}
@ -755,14 +754,14 @@ namespace pslang::interpreter
std::ostringstream os;
os << "Struct ";
types::print(os, type_of(*object_ref));
ast::print(os, type_of(*object_ref));
os << " has no field named \"" << field_access.field_name << "\"";
throw internal_error(os.str(), field_access.location);
}
std::ostringstream os;
os << "Value of type ";
types::print(os, type_of(*object_ref));
ast::print(os, type_of(*object_ref));
os << " is not a struct";
throw internal_error(os.str(), field_access.location);
}

View file

@ -2,7 +2,7 @@
#include <pslang/interpreter/eval.hpp>
#include <pslang/interpreter/error.hpp>
#include <pslang/ast/statement.hpp>
#include <pslang/types/print.hpp>
#include <pslang/ast/print.hpp>
#include <stdexcept>
#include <sstream>
@ -48,9 +48,9 @@ namespace pslang::interpreter
{
std::ostringstream os;
os << "Cannot assign a value of type ";
types::print(os, new_type);
ast::print(os, new_type);
os << " to a variable of type ";
types::print(os, existing_type);
ast::print(os, existing_type);
throw internal_error(os.str(), assignment.location);
}
@ -72,9 +72,9 @@ namespace pslang::interpreter
{
std::ostringstream os;
os << "Cannot initialize a variable of type ";
types::print(os, *expected_type);
ast::print(os, *expected_type);
os << " with an expression of type ";
types::print(os, actual_type);
ast::print(os, actual_type);
throw internal_error(os.str(), variable_declaration.location);
}
}
@ -110,7 +110,7 @@ namespace pslang::interpreter
{
std::ostringstream os;
os << "Expected type bool, got type ";
types::print(os, actual_type);
ast::print(os, actual_type);
os << " in if block condition";
throw internal_error(os.str(), get_location(*block.condition));
}
@ -138,7 +138,7 @@ namespace pslang::interpreter
{
std::ostringstream os;
os << "Expected type bool, got type ";
types::print(os, actual_type);
ast::print(os, actual_type);
os << " in while block condition";
throw internal_error(os.str(), get_location(*while_block.condition));
}
@ -212,9 +212,9 @@ namespace pslang::interpreter
{
std::ostringstream os;
os << "Returning value of type ";
types::print(os, actual_type);
ast::print(os, actual_type);
os << " from a function returning ";
types::print(os, *frame.expected_return_type);
ast::print(os, *frame.expected_return_type);
throw internal_error(os.str(), return_statement.location);
}

View file

@ -1,6 +1,5 @@
#include <pslang/interpreter/value.hpp>
#include <pslang/interpreter/value_visitor.hpp>
#include <pslang/types/print.hpp>
#include <iomanip>

View file

@ -1,7 +1,7 @@
#include <pslang/ir/print.hpp>
#include <pslang/ir/node.hpp>
#include <pslang/ir/compiler.hpp>
#include <pslang/types/print.hpp>
#include <pslang/ast/print.hpp>
#include <unordered_map>
#include <iomanip>
@ -24,7 +24,7 @@ namespace pslang::ir
break;
}
}
void print(std::ostream & out, ast::binary_operation_type type)
{
switch (type)
@ -79,7 +79,7 @@ namespace pslang::ir
break;
}
}
struct print_literal_visitor
{
std::ostream & out;
@ -268,7 +268,7 @@ namespace pslang::ir
{
print_visitor visitor{out};
visitor.fill_index(nodes);
std::size_t index = 0;
for (auto const & node : nodes)
{
@ -286,7 +286,7 @@ namespace pslang::ir
if (node.inferred_type)
{
out << " : ";
types::print(out, *node.inferred_type);
ast::print(out, *node.inferred_type);
}
out << '\n';
}
@ -304,4 +304,4 @@ namespace pslang::ir
print_impl(out, context.nodes, &context);
}
}
}

View file

@ -318,24 +318,19 @@ namespace pslang::jit::aarch64
type = field.inferred_type;
++count;
}
else if (auto named_type = std::get_if<types::named_type>(field.inferred_type.get()))
else if (auto struct_type = std::get_if<types::struct_type>(field.inferred_type.get()))
{
// NB: recursion must be impossible due to prior checks in type checker
if (auto struct_node = lcontext.is_struct(named_type->name))
if (auto subdata = get_hfa_data(*struct_type->node, lcontext))
{
if (auto subdata = get_hfa_data(*struct_node, lcontext))
{
if (type && !types::equal(*type, *subdata->type))
return std::nullopt;
type = subdata->type;
count += subdata->count;
}
else
if (type && !types::equal(*type, *subdata->type))
return std::nullopt;
type = subdata->type;
count += subdata->count;
}
else
throw std::runtime_error("Unknown named type: \"" + named_type->name + "\"");
return std::nullopt;
}
else
return std::nullopt;
@ -477,7 +472,7 @@ namespace pslang::jit::aarch64
};
std::vector<scope> scopes;
template <typename Node>
void apply(Node const &)
{
@ -548,25 +543,20 @@ namespace pslang::jit::aarch64
{
if (auto jt = it->variables.find(node.name); jt != it->variables.end())
{
if (auto named_type = std::get_if<types::named_type>(node.inferred_type.get()))
if (auto struct_type = std::get_if<types::struct_type>(node.inferred_type.get()))
{
if (auto struct_node = lcontext.is_struct(named_type->name))
std::size_t stack_size = ((struct_type->node->layout.size + 15) / 16) * 16;
builder.sub_imm(31, 31, stack_size);
stack_offset += stack_size;
scopes.back().stack_offset += stack_size;
std::size_t variable_offset = stack_offset - jt->second.frame_offset;
for (std::size_t offset = 0; offset < stack_size; offset += 16)
{
std::size_t stack_size = ((struct_node->layout.size + 15) / 16) * 16;
builder.sub_imm(31, 31, stack_size);
stack_offset += stack_size;
scopes.back().stack_offset += stack_size;
std::size_t variable_offset = stack_offset - jt->second.frame_offset;
for (std::size_t offset = 0; offset < stack_size; offset += 16)
{
builder.ldr(0, 31, (variable_offset + offset) / 8);
builder.ldr(1, 31, (variable_offset + offset) / 8 + 1);
builder.str(0, 31, offset / 8);
builder.str(1, 31, offset / 8 + 1);
}
builder.ldr(0, 31, (variable_offset + offset) / 8);
builder.ldr(1, 31, (variable_offset + offset) / 8 + 1);
builder.str(0, 31, offset / 8);
builder.str(1, 31, offset / 8 + 1);
}
else
throw std::runtime_error("Unknown type \"" + named_type->name + "\"");
}
else if (types::is_unit_type(*node.inferred_type))
{}
@ -944,119 +934,110 @@ namespace pslang::jit::aarch64
builder.xor_reg(0, 0, 0);
builder.fmov(0, 0, fp_mode_for(*node.inferred_type), 1);
}
else if (auto named_type = std::get_if<types::named_type>(node.inferred_type.get()))
else if (auto struct_type = std::get_if<types::struct_type>(node.inferred_type.get()))
{
if (auto struct_node = lcontext.is_struct(named_type->name))
auto & struct_node = *struct_type->node;
// Allocate stack space for the struct
std::size_t stack_size = ((struct_node.layout.size + 15) / 16) * 16;
auto offset = stack_offset;
stack_offset += stack_size;
scopes.back().stack_offset += stack_size;
builder.sub_imm(31, 31, stack_size);
// Evaluate each field of the struct (i.e. each constructor argument)
// and copy it to the corresponding place in the struct
for (std::size_t i = 0; i < node.arguments.size(); ++i)
{
// Allocate stack space for the struct
std::size_t stack_size = ((struct_node->layout.size + 15) / 16) * 16;
auto offset = stack_offset;
stack_offset += stack_size;
scopes.back().stack_offset += stack_size;
builder.sub_imm(31, 31, stack_size);
auto type = ast::get_type(*node.arguments[i]);
apply(*node.arguments[i]);
// Evaluate each field of the struct (i.e. each constructor argument)
// and copy it to the corresponding place in the struct
for (std::size_t i = 0; i < node.arguments.size(); ++i)
if (std::get_if<types::struct_type>(type.get()))
{
auto type = ast::get_type(*node.arguments[i]);
apply(*node.arguments[i]);
// TODO: struct field
throw std::runtime_error("Not implemented");
}
else if (types::is_floating_point_type(*type))
{
builder.stur_fp(0, fp_mode_for(*type), 31, struct_node.fields[i].layout.offset);
}
else
{
auto size = types::type_size(*type);
if (std::get_if<types::named_type>(type.get()))
{
// TODO: struct field
throw std::runtime_error("Not implemented");
}
else if (types::is_floating_point_type(*type))
{
builder.stur_fp(0, fp_mode_for(*type), 31, struct_node->fields[i].layout.offset);
}
else
{
auto size = types::type_size(*type);
if (size == 1)
builder.sturb(0, 31, struct_node->fields[i].layout.offset);
else if (size == 2)
builder.sturh(0, 31, struct_node->fields[i].layout.offset);
else if (size == 4)
builder.sturw(0, 31, struct_node->fields[i].layout.offset);
else if (size == 8)
builder.stur(0, 31, struct_node->fields[i].layout.offset);
}
if (size == 1)
builder.sturb(0, 31, struct_node.fields[i].layout.offset);
else if (size == 2)
builder.sturh(0, 31, struct_node.fields[i].layout.offset);
else if (size == 4)
builder.sturw(0, 31, struct_node.fields[i].layout.offset);
else if (size == 8)
builder.stur(0, 31, struct_node.fields[i].layout.offset);
}
}
else
throw std::runtime_error("Unknown type \"" + named_type->name + "\"");
}
}
}
void apply(ast::field_access const & node)
{
auto struct_type = get_type(*node.object);
if (auto named_type = std::get_if<types::named_type>(struct_type.get()))
auto object_type = get_type(*node.object);
if (auto struct_type = std::get_if<types::struct_type>(object_type.get()))
{
if (auto struct_node = lcontext.is_struct(named_type->name))
auto & struct_node = *struct_type->node;
std::optional<std::size_t> field_id;
for (std::size_t i = 0; i < struct_node.fields.size(); ++i)
{
std::size_t field_id = -1;
for (std::size_t i = 0; i < struct_node->fields.size(); ++i)
if (struct_node.fields[i].name == node.field_name)
{
if (struct_node->fields[i].name == node.field_name)
{
field_id = i;
break;
}
field_id = i;
break;
}
if (field_id == -1)
throw std::runtime_error("Unknown field \"" + node.field_name + "\" in struct \"" + named_type->name + "\"");
apply(*node.object);
auto stack_size = ((struct_node->layout.size + 15) / 16) * 16;
auto const & field = struct_node->fields[field_id];
if (types::is_unit_type(*field.inferred_type))
{}
else if (types::is_floating_point_type(*field.inferred_type))
{
builder.ldur_fp(0, fp_mode_for(*field.inferred_type), 31, field.layout.offset);
builder.add_imm(31, 31, stack_size);
stack_offset -= stack_size;
scopes.back().stack_offset -= stack_size;
}
else if (types::is_bool_type(*field.inferred_type) || types::is_integer_type(*field.inferred_type) || types::is_function_type(*field.inferred_type))
{
auto size = types::type_size(*field.inferred_type);
if (size == 1)
builder.ldurb(0, 31, field.layout.offset);
else if (size == 2)
builder.ldurh(0, 31, field.layout.offset);
else if (size == 4)
builder.ldurw(0, 31, field.layout.offset);
else if (size == 8)
builder.ldur(0, 31, field.layout.offset);
builder.add_imm(31, 31, stack_size);
stack_offset -= stack_size;
scopes.back().stack_offset -= stack_size;
}
else if (auto named_type = std::get_if<types::named_type>(field.inferred_type.get()))
{
if (auto field_struct_node = lcontext.is_struct(named_type->name))
{
// TODO: copy the struct-typed field on stack, overriding
// the struct itself, and update the stack offset
throw std::runtime_error("Not implemented");
}
else
throw std::runtime_error("Unknown type \"" + named_type->name + "\"");
}
return;
}
if (!field_id)
throw std::runtime_error("Unknown field \"" + node.field_name + "\" in struct \"" + struct_node.name + "\"");
apply(*node.object);
auto stack_size = ((struct_node.layout.size + 15) / 16) * 16;
auto const & field = struct_node.fields[*field_id];
if (types::is_unit_type(*field.inferred_type))
{}
else if (types::is_floating_point_type(*field.inferred_type))
{
builder.ldur_fp(0, fp_mode_for(*field.inferred_type), 31, field.layout.offset);
builder.add_imm(31, 31, stack_size);
stack_offset -= stack_size;
scopes.back().stack_offset -= stack_size;
}
else if (types::is_bool_type(*field.inferred_type) || types::is_integer_type(*field.inferred_type) || types::is_function_type(*field.inferred_type))
{
auto size = types::type_size(*field.inferred_type);
if (size == 1)
builder.ldurb(0, 31, field.layout.offset);
else if (size == 2)
builder.ldurh(0, 31, field.layout.offset);
else if (size == 4)
builder.ldurw(0, 31, field.layout.offset);
else if (size == 8)
builder.ldur(0, 31, field.layout.offset);
builder.add_imm(31, 31, stack_size);
stack_offset -= stack_size;
scopes.back().stack_offset -= stack_size;
}
else if (auto struct_type = std::get_if<types::struct_type>(field.inferred_type.get()))
{
// TODO: copy the struct-typed field on stack, overriding
// the struct itself, and update the stack offset
throw std::runtime_error("Not implemented");
}
return;
}
throw std::runtime_error("Unknown object in field access");
@ -1100,15 +1081,10 @@ namespace pslang::jit::aarch64
else if (size == 8)
builder.stur(0, 31, stack_offset - frame_offset);
}
else if (auto named_type = std::get_if<types::named_type>(type.get()))
else if (auto struct_type = std::get_if<types::struct_type>(type.get()))
{
if (auto struct_node = lcontext.is_struct(named_type->name))
{
// TODO: whole-struct assignment
throw std::runtime_error("Not implemented");
}
else
throw std::runtime_error("Unknown type \"" + named_type->name + "\"");
// TODO: whole-struct assignment
throw std::runtime_error("Not implemented");
}
}
@ -1116,7 +1092,7 @@ namespace pslang::jit::aarch64
{
apply(*node.initializer);
auto type = ast::get_type(*node.initializer);
if (std::get_if<types::named_type>(type.get()))
if (std::get_if<types::struct_type>(type.get()))
{
// Nothing to be done: the struct is already on the stack
// Just record the stack offset as variable location
@ -1180,7 +1156,7 @@ namespace pslang::jit::aarch64
apply(*node.condition);
std::int32_t skip = pcontext.code.size();
builder.cbz(0, 0);
scopes.emplace_back();
apply(*node.statements);
scope_cleanup();
@ -1344,25 +1320,21 @@ namespace pslang::jit::aarch64
{
auto base_offset = lvalue_offset(field_access->object);
auto type = ast::get_type(*field_access->object);
if (auto named_type = std::get_if<types::named_type>(type.get()))
if (auto struct_type = std::get_if<types::struct_type>(type.get()))
{
if (auto struct_node = lcontext.is_struct(named_type->name))
{
std::size_t field_id = -1;
for (std::size_t i = 0; i < struct_node->fields.size(); ++i)
if (struct_node->fields[i].name == field_access->field_name)
{
field_id = i;
break;
}
auto & struct_node = *struct_type->node;
std::optional<std::size_t> field_id;
for (std::size_t i = 0; i < struct_node.fields.size(); ++i)
if (struct_node.fields[i].name == field_access->field_name)
{
field_id = i;
break;
}
if (field_id == -1)
throw std::runtime_error("Invalid field \"" + field_access->field_name + "\"");
if (!field_id)
throw std::runtime_error("Invalid field \"" + field_access->field_name + "\"");
return base_offset - struct_node->fields[field_id].layout.offset;
}
else
throw std::runtime_error("Invalid struct \"" + named_type->name + "\"");
return base_offset - struct_node.fields[*field_id].layout.offset;
}
else
throw std::runtime_error("Invalid field access node");
@ -1472,4 +1444,4 @@ namespace pslang::jit::aarch64
pcontext.entry_point = lcontext.functions.at(std::get_if<ast::function_definition>(root.get()));
}
}
}

View file

@ -1,19 +0,0 @@
#pragma once
#include <string>
namespace pslang::types
{
struct named_type
{
std::string name;
std::size_t level;
};
inline bool operator == (named_type const & t1, named_type const & t2)
{
return (t1.level == t2.level) && (t1.name == t2.name);
}
}

View file

@ -1,12 +0,0 @@
#pragma once
#include <pslang/types/type_fwd.hpp>
#include <iostream>
namespace pslang::types
{
void print(std::ostream & out, type const & type);
}

View file

@ -0,0 +1,22 @@
#pragma once
#include <pslang/types/type_fwd.hpp>
namespace pslang::ast
{
struct struct_definition;
}
namespace pslang::types
{
struct struct_type
{
ast::struct_definition * node;
friend bool operator == (struct_type const &, struct_type const &) = default;
};
}

View file

@ -4,7 +4,7 @@
#include <pslang/types/primitive.hpp>
#include <pslang/types/array.hpp>
#include <pslang/types/function.hpp>
#include <pslang/types/named.hpp>
#include <pslang/types/struct.hpp>
#include <pslang/types/type_fwd.hpp>
#include <variant>
@ -17,7 +17,7 @@ namespace pslang::types
primitive_type,
array_type,
function_type,
named_type
struct_type
>;
struct type

View file

@ -1,123 +0,0 @@
#include <pslang/types/print.hpp>
#include <pslang/types/type_visitor.hpp>
namespace pslang::types
{
namespace
{
struct print_visitor
: const_visitor<print_visitor>
{
std::ostream & out;
using const_visitor::apply;
void apply(unit_type const & type)
{
out << "unit";
}
void apply(bool_type const &)
{
out << "bool";
}
void apply(i8_type const &)
{
out << "i8";
}
void apply(u8_type const &)
{
out << "u8";
}
void apply(i16_type const &)
{
out << "i16";
}
void apply(u16_type const &)
{
out << "u16";
}
void apply(i32_type const &)
{
out << "i32";
}
void apply(u32_type const &)
{
out << "u32";
}
void apply(i64_type const &)
{
out << "i64";
}
void apply(u64_type const &)
{
out << "u64";
}
void apply(f16_type const &)
{
out << "f16";
}
void apply(f32_type const &)
{
out << "f32";
}
void apply(f64_type const &)
{
out << "f64";
}
void apply(array_type const & type)
{
apply(*type.element_type);
out << "[" << type.size << "]";
}
void apply(function_type const & type)
{
if (type.arguments.size() == 1 && !is_function_type(*type.arguments[0]))
{
apply(*type.arguments.front());
out << " -> ";
apply(*type.result);
return;
}
out << '(';
bool first = true;
for (auto const & argument : type.arguments)
{
if (!first) out << ", ";
first = false;
apply(*argument);
}
out << ") -> ";
apply(*type.result);
}
void apply(named_type const & type)
{
out << type.name;
}
};
}
void print(std::ostream & out, type const & type)
{
print_visitor{{}, out}.apply(type);
}
}