diff --git a/libs/ast/include/pslang/ast/print.hpp b/libs/ast/include/pslang/ast/print.hpp index db01a8f..841e2f4 100644 --- a/libs/ast/include/pslang/ast/print.hpp +++ b/libs/ast/include/pslang/ast/print.hpp @@ -15,6 +15,7 @@ namespace pslang::ast std::size_t indent_level = 0; }; + void print(std::ostream & out, types::type const & type); void print(std::ostream & out, type const & node); void print(std::ostream & out, expression const & node, print_options const & options = {}); void print(std::ostream & out, statement const & node, print_options const & options = {}); diff --git a/libs/ast/source/print.cpp b/libs/ast/source/print.cpp index 3cb639b..5a7fbbc 100644 --- a/libs/ast/source/print.cpp +++ b/libs/ast/source/print.cpp @@ -3,7 +3,7 @@ #include #include #include -#include +#include #include #include @@ -26,6 +26,112 @@ namespace pslang::ast out << options.indent_string; } + struct raw_type_print_visitor + : types::const_visitor + { + std::ostream & out; + + using const_visitor::apply; + + void apply(types::unit_type const & type) + { + out << "unit"; + } + + void apply(types::bool_type const &) + { + out << "bool"; + } + + void apply(types::i8_type const &) + { + out << "i8"; + } + + void apply(types::u8_type const &) + { + out << "u8"; + } + + void apply(types::i16_type const &) + { + out << "i16"; + } + + void apply(types::u16_type const &) + { + out << "u16"; + } + + void apply(types::i32_type const &) + { + out << "i32"; + } + + void apply(types::u32_type const &) + { + out << "u32"; + } + + void apply(types::i64_type const &) + { + out << "i64"; + } + + void apply(types::u64_type const &) + { + out << "u64"; + } + + void apply(types::f16_type const &) + { + out << "f16"; + } + + void apply(types::f32_type const &) + { + out << "f32"; + } + + void apply(types::f64_type const &) + { + out << "f64"; + } + + void apply(types::array_type const & type) + { + apply(*type.element_type); + out << "[" << type.size << "]"; + } + + void apply(types::function_type const & type) + { + if (type.arguments.size() == 1 && !is_function_type(*type.arguments[0])) + { + apply(*type.arguments.front()); + out << " -> "; + apply(*type.result); + return; + } + + out << '('; + bool first = true; + for (auto const & argument : type.arguments) + { + if (!first) out << ", "; + first = false; + apply(*argument); + } + out << ") -> "; + apply(*type.result); + } + + void apply(types::struct_type const & type) + { + out << type.node->name; + } + }; + struct type_print_visitor : const_type_visitor { @@ -41,7 +147,7 @@ namespace pslang::ast template void apply(types::primitive_type_base const & type) { - types::print(out, types::primitive_type{type}); + print(out, types::type{types::primitive_type{type}}); } void apply(array_type const & type) @@ -396,6 +502,11 @@ namespace pslang::ast } + void print(std::ostream & out, types::type const & type) + { + raw_type_print_visitor{{}, out}.apply(type); + } + void print(std::ostream & out, type const & node) { type_print_visitor{{}, out}.apply(node); diff --git a/libs/ast/source/type_check.cpp b/libs/ast/source/type_check.cpp index 5dca68b..e8c370d 100644 --- a/libs/ast/source/type_check.cpp +++ b/libs/ast/source/type_check.cpp @@ -3,7 +3,7 @@ #include #include #include -#include +#include #include #include @@ -83,10 +83,10 @@ namespace pslang::ast return {.size = 8, .alignment = 8}; } - size_and_alignment apply(types::named_type const & type) + size_and_alignment apply(types::struct_type const & type) { for (auto it = scopes.rbegin(); it != scopes.rend(); ++it) - if (auto jt = it->structs.find(type.name); jt != it->structs.end()) + if (auto jt = it->structs.find(type.node->name); jt != it->structs.end()) { // TODO: better error message (including the resursive inclusion path) if (!jt->second.layout_ready && jt->second.layout_being_computed) @@ -98,7 +98,7 @@ namespace pslang::ast return {.size = layout.size, .alignment = layout.alignment}; } - throw std::runtime_error("Unknown type \"" + type.name + "\""); + throw std::runtime_error("Unknown type \"" + type.node->name + "\""); } }; @@ -162,10 +162,16 @@ namespace pslang::ast void apply(ast::type_identifier & node) { - auto type = types::named_type{}; - type.name = node.name; - type.level = node.level; - node.inferred_type = std::make_unique(std::move(type)); + for (auto it = scopes.rbegin(); it != scopes.rend(); ++it) + { + if (auto jt = it->structs.find(node.name); jt != it->structs.end()) + { + node.inferred_type = std::make_unique(types::struct_type{jt->second.node}); + return; + } + } + + throw std::runtime_error(std::format("Unknown type \"{}\"", node.name)); } }; @@ -328,7 +334,7 @@ namespace pslang::ast std::ostringstream os; os << "Cannot apply " << node.type << " to a value of type "; - types::print(os, *arg1_type); + print(os, *arg1_type); throw type_error(os.str(), node.location); } @@ -460,9 +466,9 @@ namespace pslang::ast std::ostringstream os; os << "Cannot apply " << node.type << " to values of types "; - types::print(os, *arg1_type); + print(os, *arg1_type); os << " and "; - types::print(os, *arg2_type); + print(os, *arg2_type); throw type_error(os.str(), node.location); } @@ -485,9 +491,9 @@ namespace pslang::ast std::ostringstream os; os << "Cannot cast a value of type "; - types::print(os, *source_type); + print(os, *source_type); os << " to type "; - types::print(os, *target_type); + print(os, *target_type); throw type_error(os.str(), node.location); } @@ -514,7 +520,7 @@ namespace pslang::ast { std::ostringstream os; os << "Cannot call a value of a non-function type "; - types::print(os, *function_type); + print(os, *function_type); throw type_error(os.str(), node.location); } @@ -522,7 +528,7 @@ namespace pslang::ast { std::ostringstream os; os << "Cannot call function " << function_name << " of type "; - types::print(os, *function_type); + print(os, *function_type); os << ": expected " << ftype->arguments.size() << " arguments, but got " << node.arguments.size(); throw type_error(os.str(), node.location); } @@ -534,11 +540,11 @@ namespace pslang::ast { std::ostringstream os; os << "Cannot call function " << function_name << " of type "; - types::print(os, *function_type); + print(os, *function_type); os << ": argument #" << i << " expected to have type "; - types::print(os, *ftype->arguments[i]); + print(os, *ftype->arguments[i]); os << " but got type "; - types::print(os, *arg_type); + print(os, *arg_type); throw type_error(os.str(), get_location(*node.arguments[i])); } } @@ -559,19 +565,18 @@ namespace pslang::ast std::ostringstream os; os << "Cannot create built-in type "; - types::print(os, *type); + print(os, *type); os << ": expected 0 arguments, but got " << node.arguments.size(); throw type_error(os.str(), node.location); } - else if (auto named_type = std::get_if(type.get())) + else if (auto struct_type = std::get_if(type.get())) { - auto const & scope = scopes.at(named_type->level); - auto const & struct_node = *scope.structs.at(named_type->name).node; + auto const & struct_node = *struct_type->node; if (!node.arguments.empty()) { if (node.arguments.size() != struct_node.fields.size()) - throw type_error(std::format("Cannot create struct {}: expected {} arguments, but got {}", named_type->name, struct_node.fields.size(), node.arguments.size()), node.location); + throw type_error(std::format("Cannot create struct {}: expected {} arguments, but got {}", struct_node.name, struct_node.fields.size(), node.arguments.size()), node.location); for (std::size_t i = 0; i < node.arguments.size(); ++i) { @@ -579,10 +584,10 @@ namespace pslang::ast if (!types::equal(*arg_type, *struct_node.fields[i].inferred_type)) { std::ostringstream os; - os << "Cannot create struct " << named_type->name << ": argument #" << i << " expected to have type "; - types::print(os, *struct_node.fields[i].inferred_type); + os << "Cannot create struct " << struct_node.name << ": argument #" << i << " expected to have type "; + print(os, *struct_node.fields[i].inferred_type); os << " but got type "; - types::print(os, *arg_type); + print(os, *arg_type); throw type_error(os.str(), node.location); } } @@ -617,9 +622,9 @@ namespace pslang::ast { std::ostringstream os; os << "Failed to infer array type: element #0 has type "; - types::print(os, *element_type); + print(os, *element_type); os << " but element #" << i << " has type "; - types::print(os, *current_type); + print(os, *current_type); } } @@ -641,7 +646,7 @@ namespace pslang::ast { std::ostringstream os; os << "Expected an integer type as index, but got "; - types::print(os, *index_type); + print(os, *index_type); throw type_error(os.str(), get_location(*node.index)); } @@ -651,7 +656,7 @@ namespace pslang::ast { std::ostringstream os; os << "Expected an array to index, but got "; - types::print(os, *array_type); + print(os, *array_type); throw type_error(os.str(), get_location(*node.array)); } @@ -663,17 +668,17 @@ namespace pslang::ast apply(*node.object); auto object_type = get_type(*node.object); - auto named_type = std::get_if(object_type.get()); + auto struct_type = std::get_if(object_type.get()); - if (!named_type) + if (!struct_type) { std::ostringstream os; os << "Expected a struct, but got "; - types::print(os, *object_type); + print(os, *object_type); throw type_error(os.str(), get_location(*node.object)); } - auto const & struct_node = *scopes.at(named_type->level).structs.at(named_type->name).node; + auto const & struct_node = *struct_type->node; for (auto const & field : struct_node.fields) { @@ -684,7 +689,7 @@ namespace pslang::ast } } - throw type_error(std::format("Struct \"{}\" has no field named \"{}\"", named_type->name, node.field_name), node.location); + throw type_error(std::format("Struct \"{}\" has no field named \"{}\"", struct_node.name, node.field_name), node.location); } void apply(expression_ptr const & node) @@ -714,9 +719,9 @@ namespace pslang::ast { std::ostringstream os; os << "Cannot assign a value of type "; - types::print(os, *rtype); + print(os, *rtype); os << " to an expression of type "; - types::print(os, *ltype); + print(os, *ltype); throw type_error(os.str(), node.location); }; } @@ -733,9 +738,9 @@ namespace pslang::ast { std::ostringstream os; os << "Cannot initialize a variable of type "; - types::print(os, *expected_type); + print(os, *expected_type); os << " with an expression of type "; - types::print(os, *actual_type); + print(os, *actual_type); throw type_error(os.str(), node.location); } } @@ -773,7 +778,7 @@ namespace pslang::ast { std::ostringstream os; os << "if condition expects a bool type, but got "; - types::print(os, *actual_type); + print(os, *actual_type); throw type_error(os.str(), get_location(*block.condition)); } } @@ -792,7 +797,7 @@ namespace pslang::ast { std::ostringstream os; os << "while condition expects a bool type, but got "; - types::print(os, *actual_type); + print(os, *actual_type); throw type_error(os.str(), get_location(*node.condition)); } @@ -844,9 +849,9 @@ namespace pslang::ast { std::ostringstream os; os << "Returning value of type "; - types::print(os, *actual_type); + print(os, *actual_type); os << " from a function returning "; - types::print(os, *return_scope.expected_return_type); + print(os, *return_scope.expected_return_type); throw type_error(os.str(), node.location); } } diff --git a/libs/interpreter/source/context.cpp b/libs/interpreter/source/context.cpp index 53d4743..09b5123 100644 --- a/libs/interpreter/source/context.cpp +++ b/libs/interpreter/source/context.cpp @@ -1,6 +1,6 @@ #include #include -#include +#include namespace pslang::interpreter { @@ -39,13 +39,11 @@ namespace pslang::interpreter throw std::runtime_error("Cannot zero-initialize a function type"); } - value apply(types::named_type const & named_type) + value apply(types::struct_type const & struct_type) { - auto const & struct_data = context.frame_stack.at(named_type.level).structs.at(named_type.name); - - struct_value result{.struct_type = std::make_unique(named_type)}; - for (auto const & field : struct_data.fields) - result.fields[field.name] = std::make_unique(apply(*field.type)); + struct_value result{.struct_type = std::make_unique(struct_type)}; + for (auto const & field : struct_type.node->fields) + result.fields[field.name] = std::make_unique(apply(*field.inferred_type)); return std::move(result); } }; @@ -68,7 +66,7 @@ namespace pslang::interpreter out << variable.first << " = "; print(out, variable.second.value); out << " ("; - types::print(out, type_of(variable.second.value)); + ast::print(out, type_of(variable.second.value)); out << ")\n"; } } diff --git a/libs/interpreter/source/eval.cpp b/libs/interpreter/source/eval.cpp index 52d7f12..284ecf2 100644 --- a/libs/interpreter/source/eval.cpp +++ b/libs/interpreter/source/eval.cpp @@ -4,7 +4,7 @@ #include #include #include -#include +#include #include #include @@ -61,7 +61,7 @@ namespace pslang::interpreter { std::ostringstream os; os << "Cannot index into an array with an expression of type "; - types::print(os, type_of(index)); + ast::print(os, type_of(index)); throw internal_error(os.str(), location); } @@ -108,7 +108,7 @@ namespace pslang::interpreter { std::ostringstream os; os << "Cannot apply " << type << " to a value of type "; - types::print(os, type_of(value)); + ast::print(os, type_of(value)); throw internal_error(os.str(), location); } @@ -137,7 +137,7 @@ namespace pslang::interpreter std::ostringstream os; os << "Cannot apply " << type << " to a value of type "; - types::print(os, type_of(primitive_value(arg1))); + ast::print(os, type_of(primitive_value(arg1))); throw internal_error(os.str(), location); } @@ -267,9 +267,9 @@ namespace pslang::interpreter std::ostringstream os; os << "Cannot apply " << type << " to values of type "; - types::print(os, type_of(primitive_value(arg1))); + ast::print(os, type_of(primitive_value(arg1))); os << " and "; - types::print(os, type_of(primitive_value(arg2))); + ast::print(os, type_of(primitive_value(arg2))); throw internal_error(os.str(), location); } @@ -278,9 +278,9 @@ namespace pslang::interpreter { std::ostringstream os; os << "Cannot apply " << type << " to values of type "; - types::print(os, type_of(arg1)); + ast::print(os, type_of(arg1)); os << " and "; - types::print(os, type_of(arg2)); + ast::print(os, type_of(arg2)); throw internal_error(os.str(), location); } @@ -299,7 +299,7 @@ namespace pslang::interpreter { if (!arg1.value) return primitive_value(primitive_value_base{false}); - + value const & arg2_generic = lazy_arg2(); primitive_value_base const & arg2 = std::get>(std::get(arg2_generic)); return primitive_value(primitive_value_base{static_cast(arg1.value && arg2.value)}); @@ -308,7 +308,7 @@ namespace pslang::interpreter { if (arg1.value == T{}) return primitive_value(primitive_value_base{arg1.value}); - + value const & arg2_generic = lazy_arg2(); primitive_value_base const & arg2 = std::get>(std::get(arg2_generic)); return primitive_value(primitive_value_base{static_cast(arg1.value && arg2.value)}); @@ -319,7 +319,7 @@ namespace pslang::interpreter { if (arg1.value) return primitive_value(primitive_value_base{true}); - + value const & arg2_generic = lazy_arg2(); primitive_value_base const & arg2 = std::get>(std::get(arg2_generic)); return primitive_value(primitive_value_base{static_cast(arg1.value || arg2.value)}); @@ -328,7 +328,7 @@ namespace pslang::interpreter { if (arg1.value == ~T{}) return primitive_value(primitive_value_base{arg1.value}); - + value const & arg2_generic = lazy_arg2(); primitive_value_base const & arg2 = std::get>(std::get(arg2_generic)); return primitive_value(primitive_value_base{static_cast(arg1.value && arg2.value)}); @@ -340,9 +340,9 @@ namespace pslang::interpreter std::ostringstream os; os << "Cannot apply " << type << " to values of type "; - types::print(os, type_of(primitive_value(arg1))); + ast::print(os, type_of(primitive_value(arg1))); os << " and "; - types::print(os, type_of(lazy_arg2())); + ast::print(os, type_of(lazy_arg2())); throw internal_error(os.str(), location); } @@ -351,9 +351,9 @@ namespace pslang::interpreter { std::ostringstream os; os << "Cannot apply " << type << " to values of type "; - types::print(os, type_of(arg1)); + ast::print(os, type_of(arg1)); os << " and "; - types::print(os, type_of(lazy_arg2())); + ast::print(os, type_of(lazy_arg2())); throw internal_error(os.str(), location); } @@ -374,9 +374,9 @@ namespace pslang::interpreter { std::ostringstream os; os << "Cannot apply " << binary_operation.type << " to values of type "; - types::print(os, *type1); + ast::print(os, *type1); os << " and "; - types::print(os, *type2); + ast::print(os, *type2); throw internal_error(os.str(), binary_operation.location); } @@ -444,9 +444,9 @@ namespace pslang::interpreter std::ostringstream os; os << "Cannot cast value of type "; - types::print(os, type_of(primitive_value(value))); + ast::print(os, type_of(primitive_value(value))); os << " to type "; - types::print(os, types::primitive_type(type)); + ast::print(os, types::type{types::primitive_type(type)}); throw internal_error(os.str(), location); } @@ -528,9 +528,9 @@ namespace pslang::interpreter { std::ostringstream os; os << "Cannot call function: argument #" << (i + 1) << " expects type "; - types::print(os, *fcommon->arguments[i].type); + ast::print(os, *fcommon->arguments[i].type); os << " but actual type is "; - types::print(os, actual_type); + ast::print(os, actual_type); throw internal_error(os.str(), ast::get_location(*function_call.arguments[i])); } } @@ -572,22 +572,21 @@ namespace pslang::interpreter std::ostringstream os; os << "Cannot create built-in type "; - types::print(os, *type); + ast::print(os, *type); os << ": expected 0 arguments, but got " << function_call.arguments.size(); throw internal_error(os.str(), function_call.location); } - else if (auto named_type = std::get_if(type.get())) + else if (auto struct_type = std::get_if(type.get())) { - auto const & scope = context.frame_stack.at(named_type->level); - auto const & data = scope.structs.at(named_type->name); + auto const & struct_node = *struct_type->node; if (function_call.arguments.empty()) return zero_value(context, *type); - if (data.fields.size() != function_call.arguments.size()) + if (struct_node.fields.size() != function_call.arguments.size()) { std::ostringstream os; - os << "Cannot create struct \"" << named_type->name << "\": expected " << data.fields.size() << " arguments, got " << function_call.arguments.size(); + os << "Cannot create struct \"" << struct_node.name << "\": expected " << struct_node.fields.size() << " arguments, got " << function_call.arguments.size(); throw internal_error(os.str(), function_call.location); } @@ -598,7 +597,7 @@ namespace pslang::interpreter std::unordered_map fields; for (std::size_t i = 0; i < args.size(); ++i) - fields[data.fields[i].name] = std::make_unique(std::move(args[i])); + fields[struct_node.fields[i].name] = std::make_unique(std::move(args[i])); return struct_value{ .struct_type = type, @@ -631,9 +630,9 @@ namespace pslang::interpreter { std::ostringstream os; os << "Error forming array: inferred element type is "; - types::print(os, *element_type); + ast::print(os, *element_type); os << " but element #" << i << " type is "; - types::print(os, new_type); + ast::print(os, new_type); throw internal_error(os.str(), array.location); } } @@ -654,7 +653,7 @@ namespace pslang::interpreter std::ostringstream os; os << "Cannot index into a non-array of type "; - types::print(os, type_of(array)); + ast::print(os, type_of(array)); throw internal_error(os.str(), array_access.location); } @@ -668,14 +667,14 @@ namespace pslang::interpreter std::ostringstream os; os << "Struct "; - types::print(os, type_of(object)); + ast::print(os, type_of(object)); os << " has no field named \"" << field_access.field_name << "\""; throw internal_error(os.str(), field_access.location); } std::ostringstream os; os << "Value of type "; - types::print(os, type_of(object)); + ast::print(os, type_of(object)); os << " is not a struct"; throw internal_error(os.str(), field_access.location); } @@ -741,7 +740,7 @@ namespace pslang::interpreter std::ostringstream os; os << "Cannot index into a non-array of type "; - types::print(os, type_of(*array_ref)); + ast::print(os, type_of(*array_ref)); throw internal_error(os.str(), array_access.location); } @@ -755,14 +754,14 @@ namespace pslang::interpreter std::ostringstream os; os << "Struct "; - types::print(os, type_of(*object_ref)); + ast::print(os, type_of(*object_ref)); os << " has no field named \"" << field_access.field_name << "\""; throw internal_error(os.str(), field_access.location); } std::ostringstream os; os << "Value of type "; - types::print(os, type_of(*object_ref)); + ast::print(os, type_of(*object_ref)); os << " is not a struct"; throw internal_error(os.str(), field_access.location); } diff --git a/libs/interpreter/source/exec.cpp b/libs/interpreter/source/exec.cpp index 2978ae4..561dae7 100644 --- a/libs/interpreter/source/exec.cpp +++ b/libs/interpreter/source/exec.cpp @@ -2,7 +2,7 @@ #include #include #include -#include +#include #include #include @@ -48,9 +48,9 @@ namespace pslang::interpreter { std::ostringstream os; os << "Cannot assign a value of type "; - types::print(os, new_type); + ast::print(os, new_type); os << " to a variable of type "; - types::print(os, existing_type); + ast::print(os, existing_type); throw internal_error(os.str(), assignment.location); } @@ -72,9 +72,9 @@ namespace pslang::interpreter { std::ostringstream os; os << "Cannot initialize a variable of type "; - types::print(os, *expected_type); + ast::print(os, *expected_type); os << " with an expression of type "; - types::print(os, actual_type); + ast::print(os, actual_type); throw internal_error(os.str(), variable_declaration.location); } } @@ -110,7 +110,7 @@ namespace pslang::interpreter { std::ostringstream os; os << "Expected type bool, got type "; - types::print(os, actual_type); + ast::print(os, actual_type); os << " in if block condition"; throw internal_error(os.str(), get_location(*block.condition)); } @@ -138,7 +138,7 @@ namespace pslang::interpreter { std::ostringstream os; os << "Expected type bool, got type "; - types::print(os, actual_type); + ast::print(os, actual_type); os << " in while block condition"; throw internal_error(os.str(), get_location(*while_block.condition)); } @@ -212,9 +212,9 @@ namespace pslang::interpreter { std::ostringstream os; os << "Returning value of type "; - types::print(os, actual_type); + ast::print(os, actual_type); os << " from a function returning "; - types::print(os, *frame.expected_return_type); + ast::print(os, *frame.expected_return_type); throw internal_error(os.str(), return_statement.location); } diff --git a/libs/interpreter/source/value.cpp b/libs/interpreter/source/value.cpp index 712f8ee..29a784e 100644 --- a/libs/interpreter/source/value.cpp +++ b/libs/interpreter/source/value.cpp @@ -1,6 +1,5 @@ #include #include -#include #include diff --git a/libs/ir/source/print.cpp b/libs/ir/source/print.cpp index 665febc..b3f025f 100644 --- a/libs/ir/source/print.cpp +++ b/libs/ir/source/print.cpp @@ -1,7 +1,7 @@ #include #include #include -#include +#include #include #include @@ -24,7 +24,7 @@ namespace pslang::ir break; } } - + void print(std::ostream & out, ast::binary_operation_type type) { switch (type) @@ -79,7 +79,7 @@ namespace pslang::ir break; } } - + struct print_literal_visitor { std::ostream & out; @@ -268,7 +268,7 @@ namespace pslang::ir { print_visitor visitor{out}; visitor.fill_index(nodes); - + std::size_t index = 0; for (auto const & node : nodes) { @@ -286,7 +286,7 @@ namespace pslang::ir if (node.inferred_type) { out << " : "; - types::print(out, *node.inferred_type); + ast::print(out, *node.inferred_type); } out << '\n'; } @@ -304,4 +304,4 @@ namespace pslang::ir print_impl(out, context.nodes, &context); } -} \ No newline at end of file +} diff --git a/libs/jit/source/arch/aarch64/compiler.cpp b/libs/jit/source/arch/aarch64/compiler.cpp index 403bd7f..da5be1a 100644 --- a/libs/jit/source/arch/aarch64/compiler.cpp +++ b/libs/jit/source/arch/aarch64/compiler.cpp @@ -318,24 +318,19 @@ namespace pslang::jit::aarch64 type = field.inferred_type; ++count; } - else if (auto named_type = std::get_if(field.inferred_type.get())) + else if (auto struct_type = std::get_if(field.inferred_type.get())) { // NB: recursion must be impossible due to prior checks in type checker - if (auto struct_node = lcontext.is_struct(named_type->name)) + if (auto subdata = get_hfa_data(*struct_type->node, lcontext)) { - if (auto subdata = get_hfa_data(*struct_node, lcontext)) - { - if (type && !types::equal(*type, *subdata->type)) - return std::nullopt; - - type = subdata->type; - count += subdata->count; - } - else + if (type && !types::equal(*type, *subdata->type)) return std::nullopt; + + type = subdata->type; + count += subdata->count; } else - throw std::runtime_error("Unknown named type: \"" + named_type->name + "\""); + return std::nullopt; } else return std::nullopt; @@ -477,7 +472,7 @@ namespace pslang::jit::aarch64 }; std::vector scopes; - + template void apply(Node const &) { @@ -548,25 +543,20 @@ namespace pslang::jit::aarch64 { if (auto jt = it->variables.find(node.name); jt != it->variables.end()) { - if (auto named_type = std::get_if(node.inferred_type.get())) + if (auto struct_type = std::get_if(node.inferred_type.get())) { - if (auto struct_node = lcontext.is_struct(named_type->name)) + std::size_t stack_size = ((struct_type->node->layout.size + 15) / 16) * 16; + builder.sub_imm(31, 31, stack_size); + stack_offset += stack_size; + scopes.back().stack_offset += stack_size; + std::size_t variable_offset = stack_offset - jt->second.frame_offset; + for (std::size_t offset = 0; offset < stack_size; offset += 16) { - std::size_t stack_size = ((struct_node->layout.size + 15) / 16) * 16; - builder.sub_imm(31, 31, stack_size); - stack_offset += stack_size; - scopes.back().stack_offset += stack_size; - std::size_t variable_offset = stack_offset - jt->second.frame_offset; - for (std::size_t offset = 0; offset < stack_size; offset += 16) - { - builder.ldr(0, 31, (variable_offset + offset) / 8); - builder.ldr(1, 31, (variable_offset + offset) / 8 + 1); - builder.str(0, 31, offset / 8); - builder.str(1, 31, offset / 8 + 1); - } + builder.ldr(0, 31, (variable_offset + offset) / 8); + builder.ldr(1, 31, (variable_offset + offset) / 8 + 1); + builder.str(0, 31, offset / 8); + builder.str(1, 31, offset / 8 + 1); } - else - throw std::runtime_error("Unknown type \"" + named_type->name + "\""); } else if (types::is_unit_type(*node.inferred_type)) {} @@ -944,119 +934,110 @@ namespace pslang::jit::aarch64 builder.xor_reg(0, 0, 0); builder.fmov(0, 0, fp_mode_for(*node.inferred_type), 1); } - else if (auto named_type = std::get_if(node.inferred_type.get())) + else if (auto struct_type = std::get_if(node.inferred_type.get())) { - if (auto struct_node = lcontext.is_struct(named_type->name)) + auto & struct_node = *struct_type->node; + + // Allocate stack space for the struct + std::size_t stack_size = ((struct_node.layout.size + 15) / 16) * 16; + auto offset = stack_offset; + stack_offset += stack_size; + scopes.back().stack_offset += stack_size; + builder.sub_imm(31, 31, stack_size); + + // Evaluate each field of the struct (i.e. each constructor argument) + // and copy it to the corresponding place in the struct + for (std::size_t i = 0; i < node.arguments.size(); ++i) { - // Allocate stack space for the struct - std::size_t stack_size = ((struct_node->layout.size + 15) / 16) * 16; - auto offset = stack_offset; - stack_offset += stack_size; - scopes.back().stack_offset += stack_size; - builder.sub_imm(31, 31, stack_size); + auto type = ast::get_type(*node.arguments[i]); + apply(*node.arguments[i]); - // Evaluate each field of the struct (i.e. each constructor argument) - // and copy it to the corresponding place in the struct - for (std::size_t i = 0; i < node.arguments.size(); ++i) + if (std::get_if(type.get())) { - auto type = ast::get_type(*node.arguments[i]); - apply(*node.arguments[i]); + // TODO: struct field + throw std::runtime_error("Not implemented"); + } + else if (types::is_floating_point_type(*type)) + { + builder.stur_fp(0, fp_mode_for(*type), 31, struct_node.fields[i].layout.offset); + } + else + { + auto size = types::type_size(*type); - if (std::get_if(type.get())) - { - // TODO: struct field - throw std::runtime_error("Not implemented"); - } - else if (types::is_floating_point_type(*type)) - { - builder.stur_fp(0, fp_mode_for(*type), 31, struct_node->fields[i].layout.offset); - } - else - { - auto size = types::type_size(*type); - - if (size == 1) - builder.sturb(0, 31, struct_node->fields[i].layout.offset); - else if (size == 2) - builder.sturh(0, 31, struct_node->fields[i].layout.offset); - else if (size == 4) - builder.sturw(0, 31, struct_node->fields[i].layout.offset); - else if (size == 8) - builder.stur(0, 31, struct_node->fields[i].layout.offset); - } + if (size == 1) + builder.sturb(0, 31, struct_node.fields[i].layout.offset); + else if (size == 2) + builder.sturh(0, 31, struct_node.fields[i].layout.offset); + else if (size == 4) + builder.sturw(0, 31, struct_node.fields[i].layout.offset); + else if (size == 8) + builder.stur(0, 31, struct_node.fields[i].layout.offset); } } - else - throw std::runtime_error("Unknown type \"" + named_type->name + "\""); } } } void apply(ast::field_access const & node) { - auto struct_type = get_type(*node.object); - if (auto named_type = std::get_if(struct_type.get())) + auto object_type = get_type(*node.object); + if (auto struct_type = std::get_if(object_type.get())) { - if (auto struct_node = lcontext.is_struct(named_type->name)) + auto & struct_node = *struct_type->node; + + std::optional field_id; + for (std::size_t i = 0; i < struct_node.fields.size(); ++i) { - std::size_t field_id = -1; - for (std::size_t i = 0; i < struct_node->fields.size(); ++i) + if (struct_node.fields[i].name == node.field_name) { - if (struct_node->fields[i].name == node.field_name) - { - field_id = i; - break; - } + field_id = i; + break; } - - if (field_id == -1) - throw std::runtime_error("Unknown field \"" + node.field_name + "\" in struct \"" + named_type->name + "\""); - - apply(*node.object); - - auto stack_size = ((struct_node->layout.size + 15) / 16) * 16; - - auto const & field = struct_node->fields[field_id]; - if (types::is_unit_type(*field.inferred_type)) - {} - else if (types::is_floating_point_type(*field.inferred_type)) - { - builder.ldur_fp(0, fp_mode_for(*field.inferred_type), 31, field.layout.offset); - - builder.add_imm(31, 31, stack_size); - stack_offset -= stack_size; - scopes.back().stack_offset -= stack_size; - } - else if (types::is_bool_type(*field.inferred_type) || types::is_integer_type(*field.inferred_type) || types::is_function_type(*field.inferred_type)) - { - auto size = types::type_size(*field.inferred_type); - if (size == 1) - builder.ldurb(0, 31, field.layout.offset); - else if (size == 2) - builder.ldurh(0, 31, field.layout.offset); - else if (size == 4) - builder.ldurw(0, 31, field.layout.offset); - else if (size == 8) - builder.ldur(0, 31, field.layout.offset); - - builder.add_imm(31, 31, stack_size); - stack_offset -= stack_size; - scopes.back().stack_offset -= stack_size; - } - else if (auto named_type = std::get_if(field.inferred_type.get())) - { - if (auto field_struct_node = lcontext.is_struct(named_type->name)) - { - // TODO: copy the struct-typed field on stack, overriding - // the struct itself, and update the stack offset - throw std::runtime_error("Not implemented"); - } - else - throw std::runtime_error("Unknown type \"" + named_type->name + "\""); - } - - return; } + + if (!field_id) + throw std::runtime_error("Unknown field \"" + node.field_name + "\" in struct \"" + struct_node.name + "\""); + + apply(*node.object); + + auto stack_size = ((struct_node.layout.size + 15) / 16) * 16; + + auto const & field = struct_node.fields[*field_id]; + if (types::is_unit_type(*field.inferred_type)) + {} + else if (types::is_floating_point_type(*field.inferred_type)) + { + builder.ldur_fp(0, fp_mode_for(*field.inferred_type), 31, field.layout.offset); + + builder.add_imm(31, 31, stack_size); + stack_offset -= stack_size; + scopes.back().stack_offset -= stack_size; + } + else if (types::is_bool_type(*field.inferred_type) || types::is_integer_type(*field.inferred_type) || types::is_function_type(*field.inferred_type)) + { + auto size = types::type_size(*field.inferred_type); + if (size == 1) + builder.ldurb(0, 31, field.layout.offset); + else if (size == 2) + builder.ldurh(0, 31, field.layout.offset); + else if (size == 4) + builder.ldurw(0, 31, field.layout.offset); + else if (size == 8) + builder.ldur(0, 31, field.layout.offset); + + builder.add_imm(31, 31, stack_size); + stack_offset -= stack_size; + scopes.back().stack_offset -= stack_size; + } + else if (auto struct_type = std::get_if(field.inferred_type.get())) + { + // TODO: copy the struct-typed field on stack, overriding + // the struct itself, and update the stack offset + throw std::runtime_error("Not implemented"); + } + + return; } throw std::runtime_error("Unknown object in field access"); @@ -1100,15 +1081,10 @@ namespace pslang::jit::aarch64 else if (size == 8) builder.stur(0, 31, stack_offset - frame_offset); } - else if (auto named_type = std::get_if(type.get())) + else if (auto struct_type = std::get_if(type.get())) { - if (auto struct_node = lcontext.is_struct(named_type->name)) - { - // TODO: whole-struct assignment - throw std::runtime_error("Not implemented"); - } - else - throw std::runtime_error("Unknown type \"" + named_type->name + "\""); + // TODO: whole-struct assignment + throw std::runtime_error("Not implemented"); } } @@ -1116,7 +1092,7 @@ namespace pslang::jit::aarch64 { apply(*node.initializer); auto type = ast::get_type(*node.initializer); - if (std::get_if(type.get())) + if (std::get_if(type.get())) { // Nothing to be done: the struct is already on the stack // Just record the stack offset as variable location @@ -1180,7 +1156,7 @@ namespace pslang::jit::aarch64 apply(*node.condition); std::int32_t skip = pcontext.code.size(); builder.cbz(0, 0); - + scopes.emplace_back(); apply(*node.statements); scope_cleanup(); @@ -1344,25 +1320,21 @@ namespace pslang::jit::aarch64 { auto base_offset = lvalue_offset(field_access->object); auto type = ast::get_type(*field_access->object); - if (auto named_type = std::get_if(type.get())) + if (auto struct_type = std::get_if(type.get())) { - if (auto struct_node = lcontext.is_struct(named_type->name)) - { - std::size_t field_id = -1; - for (std::size_t i = 0; i < struct_node->fields.size(); ++i) - if (struct_node->fields[i].name == field_access->field_name) - { - field_id = i; - break; - } + auto & struct_node = *struct_type->node; + std::optional field_id; + for (std::size_t i = 0; i < struct_node.fields.size(); ++i) + if (struct_node.fields[i].name == field_access->field_name) + { + field_id = i; + break; + } - if (field_id == -1) - throw std::runtime_error("Invalid field \"" + field_access->field_name + "\""); + if (!field_id) + throw std::runtime_error("Invalid field \"" + field_access->field_name + "\""); - return base_offset - struct_node->fields[field_id].layout.offset; - } - else - throw std::runtime_error("Invalid struct \"" + named_type->name + "\""); + return base_offset - struct_node.fields[*field_id].layout.offset; } else throw std::runtime_error("Invalid field access node"); @@ -1472,4 +1444,4 @@ namespace pslang::jit::aarch64 pcontext.entry_point = lcontext.functions.at(std::get_if(root.get())); } -} \ No newline at end of file +} diff --git a/libs/types/include/pslang/types/named.hpp b/libs/types/include/pslang/types/named.hpp deleted file mode 100644 index 1e825f3..0000000 --- a/libs/types/include/pslang/types/named.hpp +++ /dev/null @@ -1,19 +0,0 @@ -#pragma once - -#include - -namespace pslang::types -{ - - struct named_type - { - std::string name; - std::size_t level; - }; - - inline bool operator == (named_type const & t1, named_type const & t2) - { - return (t1.level == t2.level) && (t1.name == t2.name); - } - -} diff --git a/libs/types/include/pslang/types/print.hpp b/libs/types/include/pslang/types/print.hpp deleted file mode 100644 index 7e3c3f1..0000000 --- a/libs/types/include/pslang/types/print.hpp +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once - -#include - -#include - -namespace pslang::types -{ - - void print(std::ostream & out, type const & type); - -} diff --git a/libs/types/include/pslang/types/struct.hpp b/libs/types/include/pslang/types/struct.hpp new file mode 100644 index 0000000..9c38394 --- /dev/null +++ b/libs/types/include/pslang/types/struct.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include + +namespace pslang::ast +{ + + struct struct_definition; + +} + +namespace pslang::types +{ + + struct struct_type + { + ast::struct_definition * node; + + friend bool operator == (struct_type const &, struct_type const &) = default; + }; + +} diff --git a/libs/types/include/pslang/types/type.hpp b/libs/types/include/pslang/types/type.hpp index 767a477..02b9c0f 100644 --- a/libs/types/include/pslang/types/type.hpp +++ b/libs/types/include/pslang/types/type.hpp @@ -4,7 +4,7 @@ #include #include #include -#include +#include #include #include @@ -17,7 +17,7 @@ namespace pslang::types primitive_type, array_type, function_type, - named_type + struct_type >; struct type diff --git a/libs/types/source/print.cpp b/libs/types/source/print.cpp deleted file mode 100644 index 774e997..0000000 --- a/libs/types/source/print.cpp +++ /dev/null @@ -1,123 +0,0 @@ -#include -#include - -namespace pslang::types -{ - - namespace - { - - struct print_visitor - : const_visitor - { - std::ostream & out; - - using const_visitor::apply; - - void apply(unit_type const & type) - { - out << "unit"; - } - - void apply(bool_type const &) - { - out << "bool"; - } - - void apply(i8_type const &) - { - out << "i8"; - } - - void apply(u8_type const &) - { - out << "u8"; - } - - void apply(i16_type const &) - { - out << "i16"; - } - - void apply(u16_type const &) - { - out << "u16"; - } - - void apply(i32_type const &) - { - out << "i32"; - } - - void apply(u32_type const &) - { - out << "u32"; - } - - void apply(i64_type const &) - { - out << "i64"; - } - - void apply(u64_type const &) - { - out << "u64"; - } - - void apply(f16_type const &) - { - out << "f16"; - } - - void apply(f32_type const &) - { - out << "f32"; - } - - void apply(f64_type const &) - { - out << "f64"; - } - - void apply(array_type const & type) - { - apply(*type.element_type); - out << "[" << type.size << "]"; - } - - void apply(function_type const & type) - { - if (type.arguments.size() == 1 && !is_function_type(*type.arguments[0])) - { - apply(*type.arguments.front()); - out << " -> "; - apply(*type.result); - return; - } - - out << '('; - bool first = true; - for (auto const & argument : type.arguments) - { - if (!first) out << ", "; - first = false; - apply(*argument); - } - out << ") -> "; - apply(*type.result); - } - - void apply(named_type const & type) - { - out << type.name; - } - }; - - } - - void print(std::ostream & out, type const & type) - { - print_visitor{{}, out}.apply(type); - } - -}