715 lines
24 KiB
C++
715 lines
24 KiB
C++
#include <pslang/ir/compiler.hpp>
#include <pslang/ir/node.hpp>
#include <pslang/ast/statement_visitor.hpp>
#include <pslang/ast/expression_visitor.hpp>
#include <pslang/types/type_visitor.hpp>

#include <stdexcept>
#include <string>
#include <unordered_map>
|
|
|
|
namespace pslang::ir
|
|
{
|
|
|
|
namespace
|
|
{
|
|
|
|
struct local_context
|
|
{
|
|
std::unordered_map<ast::function_definition const *, std::pair<node_ref, node_ref>> functions;
|
|
std::unordered_map<ast::variable_base const *, node_ref> variables;
|
|
|
|
struct scope
|
|
{
|
|
std::string label_prefix;
|
|
};
|
|
|
|
std::vector<scope> scopes;
|
|
|
|
struct resolve_address_data
|
|
{
|
|
node_ref address;
|
|
ast::function_definition const * target;
|
|
};
|
|
|
|
struct resolve_call_data
|
|
{
|
|
node_ref call;
|
|
ast::function_definition const * target;
|
|
};
|
|
|
|
std::vector<resolve_address_data> resolve_address;
|
|
std::vector<resolve_call_data> resolve_call;
|
|
};
|
|
|
|
struct zero_literal_visitor
|
|
: types::const_visitor<zero_literal_visitor>
|
|
{
|
|
module_context & mcontext;
|
|
|
|
using const_visitor::apply;
|
|
|
|
template <typename T>
|
|
void apply(types::primitive_type_base<T>)
|
|
{
|
|
mcontext.nodes->emplace_back(literal{ast::literal{ast::primitive_literal_base<T>{.value = {}}}},
|
|
std::make_shared<types::type>(types::primitive_type{types::primitive_type_base<T>{}}));
|
|
}
|
|
|
|
template <typename T>
|
|
void apply(T const &)
|
|
{
|
|
throw std::runtime_error("Invalid type for zero_literal_visitor");
|
|
}
|
|
};
|
|
|
|
// Compile a single function and store the entry point node_ref
|
|
// in local_context
|
|
struct compile_function_visitor
|
|
: ast::const_statement_visitor<compile_function_visitor>
|
|
, ast::const_expression_visitor<compile_function_visitor>
|
|
{
|
|
using const_statement_visitor::apply;
|
|
using const_expression_visitor::apply;
|
|
|
|
module_context & mcontext;
|
|
local_context & lcontext;
|
|
|
|
template <typename Node>
|
|
node_ref apply(Node const & node)
|
|
{
|
|
throw std::runtime_error(std::string("IR compile visitor not implemented for ") + typeid(Node).name());
|
|
}
|
|
|
|
// Expressions
|
|
|
|
template <typename T>
|
|
node_ref apply(ast::primitive_literal_base<T> const & node)
|
|
{
|
|
mcontext.nodes->emplace_back(literal{node}, ast::get_type(node));
|
|
return last();
|
|
}
|
|
|
|
// Identifier expression.
//  - variable: returns the node already registered for it (no new node);
//  - function: emits an instruction_address node and queues it for target
//    resolution at the end of compilation;
//  - foreign function: emits an extern_symbol node carrying the name.
// Throws std::runtime_error if the identifier is bound to none of these
// (and std::out_of_range, from `at`, if the variable has no registered node).
node_ref apply(ast::identifier const & node)
{
    if (node.variable_node)
        return lcontext.variables.at(node.variable_node);

    if (node.function_node)
    {
        mcontext.nodes->emplace_back(instruction_address{}, node.inferred_type);
        auto address = last();
        lcontext.resolve_address.emplace_back(address, node.function_node);
        return address;
    }

    if (node.foreign_function_node)
    {
        mcontext.nodes->emplace_back(extern_symbol{node.name}, node.inferred_type);
        return last();
    }

    throw std::runtime_error("Unknown identifier \"" + node.name + "\"");
}
|
|
|
|
// Unary expression: a dereference becomes a load node, every other operator
// becomes a unary_operation node. The operand is compiled first.
node_ref apply(ast::unary_operation const & node)
{
    auto operand = apply(*node.arg1);

    if (node.type == ast::unary_operation_type::dereference)
        mcontext.nodes->emplace_back(load{operand}, node.inferred_type);
    else
        mcontext.nodes->emplace_back(unary_operation{node.type, operand}, node.inferred_type);

    return last();
}
|
|
|
|
// Binary expression. Returns the node holding the value of the expression.
//
// Handled in this order:
//  1. Short-circuit `logical_and` / `logical_or`: the right operand is only
//     evaluated when the left one does not already decide the result.
//  2. Pointer arithmetic (`+` / `-` with at least one pointer operand): integer
//     operands are scaled by the size of the pointed-to type; the difference of
//     two pointers is divided by it.
//  3. Comparison of two integers of different types: same signedness widens the
//     narrower operand; mixed signedness first tests the signed operand against
//     zero and only then compares both operands as unsigned values.
//  4. Everything else: a plain binary_operation node.
//
// Throws std::runtime_error for `integer - pointer`, and for comparison
// operators / integer sizes the mixed-signedness lowering does not know.
node_ref apply(ast::binary_operation const & node)
{
    auto arg1 = apply(*node.arg1);

    // Short-circuit operators

    bool const is_logical_and = node.type == ast::binary_operation_type::logical_and;
    bool const is_logical_or = node.type == ast::binary_operation_type::logical_or;

    if (is_logical_and || is_logical_or)
    {
        // The result is written with an assignment node below. When arg1 is a plain
        // variable reference its node *is* the variable, so assigning into arg1 would
        // overwrite that variable. Work on a copy of the left operand instead.
        mcontext.nodes->emplace_back(copy{arg1}, ast::get_type(*node.arg1));
        auto result = last();

        if (is_logical_and)
        {
            // Left operand is zero: result is already decided, skip the right operand
            mcontext.nodes->emplace_back(jump_if_zero{result});
        }
        else
        {
            // Left operand is nonzero (its negation is zero): skip the right operand
            mcontext.nodes->emplace_back(unary_operation{ast::unary_operation_type::logical_not, result}, node.inferred_type);
            mcontext.nodes->emplace_back(jump_if_zero{last()});
        }
        auto skip_rhs = last();

        auto arg2 = apply(*node.arg2);
        auto const combine = is_logical_and
            ? ast::binary_operation_type::binary_and
            : ast::binary_operation_type::binary_or;
        mcontext.nodes->emplace_back(binary_operation{combine, result, arg2}, node.inferred_type);
        mcontext.nodes->emplace_back(assignment{result, last()});
        mcontext.nodes->emplace_back(label{});
        std::get<jump_if_zero>(skip_rhs->instruction).target = last();
        return result;
    }

    // Pointer arithmetic

    auto arg2 = apply(*node.arg2);

    auto arg1_type = ast::get_type(*node.arg1);
    auto arg2_type = ast::get_type(*node.arg2);

    bool const arg1_is_pointer = types::is_pointer_type(*arg1_type);
    bool const arg2_is_pointer = types::is_pointer_type(*arg2_type);

    bool const is_addition = node.type == ast::binary_operation_type::addition;
    bool const is_subtraction = node.type == ast::binary_operation_type::subtraction;

    if ((is_addition || is_subtraction) && (arg1_is_pointer || arg2_is_pointer))
    {
        // Pointer types are equal and referenced types are non-empty - guaranteed by type checker
        auto const & pointer_type = arg1_is_pointer ? arg1_type : arg2_type;
        std::int64_t const element_size = ast::type_size(*std::get<types::pointer_type>(*pointer_type).referenced_type);

        auto i64_type = std::make_shared<types::type>(types::primitive_type{types::i64_type{}});

        mcontext.nodes->emplace_back(literal{ast::i64_literal{element_size}}, i64_type);
        auto element_size_node = last();

        // Emits `i64(count) * element_size` and returns the product node
        auto emit_scaled = [&](node_ref count)
        {
            mcontext.nodes->emplace_back(cast_operation{count, i64_type}, i64_type);
            mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::multiplication, last(), element_size_node}, i64_type);
            return last();
        };

        if (is_addition)
        {
            // pointer + integer or integer + pointer
            auto pointer = arg1_is_pointer ? arg1 : arg2;
            auto offset = emit_scaled(arg1_is_pointer ? arg2 : arg1);
            mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::addition, pointer, offset}, node.inferred_type);
        }
        else if (arg1_is_pointer && arg2_is_pointer)
        {
            // pointer - pointer: byte distance divided by the element size
            mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::subtraction, arg1, arg2}, node.inferred_type);
            mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::division, last(), element_size_node}, i64_type);
        }
        else if (arg1_is_pointer)
        {
            // pointer - integer: move the pointer back by `integer` elements
            auto offset = emit_scaled(arg2);
            mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::subtraction, arg1, offset}, node.inferred_type);
        }
        else
        {
            throw std::runtime_error("Cannot subtract a pointer from a non-pointer value");
        }

        return last();
    }

    // Different-type integer comparison

    if (!types::equal(*arg1_type, *arg2_type)
        && types::is_integer_type(*arg1_type)
        && types::is_integer_type(*arg2_type)
        && ast::is_comparison(node.type))
    {
        bool const arg1_unsigned = types::is_unsigned_integer_type(*arg1_type);
        bool const arg2_unsigned = types::is_unsigned_integer_type(*arg2_type);
        std::size_t const arg1_size = types::type_size(*arg1_type);
        std::size_t const arg2_size = types::type_size(*arg2_type);
        std::size_t const max_size = std::max(arg1_size, arg2_size);

        if (arg1_unsigned == arg2_unsigned)
        {
            // Both signed or both unsigned: just cast the smaller one to the larger type
            types::type_ptr max_type = (arg1_size > arg2_size) ? arg1_type : arg2_type;

            if (arg1_size < arg2_size)
            {
                mcontext.nodes->emplace_back(cast_operation{arg1, max_type}, max_type);
                mcontext.nodes->emplace_back(binary_operation{node.type, last(), arg2}, node.inferred_type);
            }
            else
            {
                mcontext.nodes->emplace_back(cast_operation{arg2, max_type}, max_type);
                mcontext.nodes->emplace_back(binary_operation{node.type, arg1, last()}, node.inferred_type);
            }
            return last();
        }

        // Different signedness

        // Make arg1 the signed operand; swapping the operands mirrors the ordering operators
        auto type = node.type;
        if (arg1_unsigned)
        {
            std::swap(arg1, arg2);
            std::swap(arg1_type, arg2_type);

            switch (type)
            {
            case ast::binary_operation_type::less:
                type = ast::binary_operation_type::greater;
                break;
            case ast::binary_operation_type::greater:
                type = ast::binary_operation_type::less;
                break;
            case ast::binary_operation_type::less_equals:
                type = ast::binary_operation_type::greater_equals;
                break;
            case ast::binary_operation_type::greater_equals:
                type = ast::binary_operation_type::less_equals;
                break;
            default:
                // equals / not_equals are symmetric
                break;
            }
        }

        // Pick the comparison of the signed operand with zero that already decides the
        // result, and whether that decision is signalled by a zero or a nonzero value.
        ast::binary_operation_type zero_comparison{};
        bool decided_when_zero = false;
        switch (type)
        {
        case ast::binary_operation_type::equals:
            zero_comparison = ast::binary_operation_type::greater_equals;
            decided_when_zero = true;
            break;
        case ast::binary_operation_type::not_equals:
            zero_comparison = ast::binary_operation_type::less;
            decided_when_zero = false;
            break;
        case ast::binary_operation_type::less:
            zero_comparison = ast::binary_operation_type::less;
            decided_when_zero = false;
            break;
        case ast::binary_operation_type::greater:
            zero_comparison = ast::binary_operation_type::greater;
            decided_when_zero = true;
            break;
        case ast::binary_operation_type::less_equals:
            zero_comparison = ast::binary_operation_type::less_equals;
            decided_when_zero = false;
            break;
        case ast::binary_operation_type::greater_equals:
            zero_comparison = ast::binary_operation_type::greater_equals;
            decided_when_zero = true;
            break;
        default:
            // Previously fell through and used unrelated nodes as result / jump
            throw std::runtime_error("Unexpected comparison operator in mixed-signedness integer comparison");
        }

        // Largest unsigned type both operands fit in once arg1 is known to be non-negative.
        // Resolved before any node is emitted so that a failure leaves no partial sequence.
        types::type_ptr max_unsigned_type;
        switch (max_size)
        {
        case 1:
            max_unsigned_type = std::make_shared<types::type>(types::primitive_type{types::u8_type{}});
            break;
        case 2:
            max_unsigned_type = std::make_shared<types::type>(types::primitive_type{types::u16_type{}});
            break;
        case 4:
            max_unsigned_type = std::make_shared<types::type>(types::primitive_type{types::u32_type{}});
            break;
        case 8:
            max_unsigned_type = std::make_shared<types::type>(types::primitive_type{types::u64_type{}});
            break;
        default:
            // Previously left the type null and dereferenced it later
            throw std::runtime_error("Unsupported integer size in mixed-signedness integer comparison");
        }

        // Compare with zero first
        zero_literal_visitor{{}, mcontext}.apply(*arg1_type);
        mcontext.nodes->emplace_back(binary_operation{zero_comparison, arg1, last()}, node.inferred_type);
        auto result = last();

        if (decided_when_zero)
            mcontext.nodes->emplace_back(jump_if_zero{result, {}});
        else
            mcontext.nodes->emplace_back(jump_if_nonzero{result, {}});
        auto jump_node = last();

        // Less-than-zero case handled, here arg1 is nonnegative - cast both to the largest unsigned type
        mcontext.nodes->emplace_back(cast_operation{arg1, max_unsigned_type}, max_unsigned_type);
        auto new_arg1 = last();
        mcontext.nodes->emplace_back(cast_operation{arg2, max_unsigned_type}, max_unsigned_type);
        auto new_arg2 = last();
        mcontext.nodes->emplace_back(binary_operation{type, new_arg1, new_arg2}, node.inferred_type);
        mcontext.nodes->emplace_back(assignment{result, last()});
        mcontext.nodes->emplace_back(label{});
        auto end_label = last();

        if (decided_when_zero)
            std::get<jump_if_zero>(jump_node->instruction).target = end_label;
        else
            std::get<jump_if_nonzero>(jump_node->instruction).target = end_label;

        return result;
    }

    // General case

    mcontext.nodes->emplace_back(binary_operation{node.type, arg1, arg2}, node.inferred_type);
    return last();
}
|
|
|
|
// Cast expression: compile the operand, then append a cast_operation node whose
// target type and result type are both the expression's inferred type.
node_ref apply(ast::cast_operation const & node)
{
    auto const & target_type = node.inferred_type;
    auto operand = apply(*node.expression);
    mcontext.nodes->emplace_back(cast_operation{operand, target_type}, target_type);
    return last();
}
|
|
|
|
// Call expression. Two shapes are distinguished by `node.function`:
//  - unset: the call is a type constructor. Only struct types are supported: an
//    alloc node is emitted and each argument is assigned to the field with the
//    same index. Returns the alloc node.
//  - set: arguments are compiled left to right; a callee that is an identifier
//    bound to a function definition becomes a `call` node (target resolved at the
//    end of compilation), anything else is compiled as a value and invoked through
//    `call_pointer`.
// Throws std::runtime_error for constructors of non-struct types.
node_ref apply(ast::function_call const & node)
{
    if (!node.function)
    {
        auto type = ast::get_type(*node.type);
        auto struct_type = std::get_if<types::struct_type>(type.get());
        if (!struct_type)
            throw std::runtime_error("Type constructors are not implemented for non-struct types");

        mcontext.nodes->emplace_back(alloc{}, node.inferred_type);
        auto object = last();

        std::size_t index = 0;
        for (auto const & argument : node.arguments)
        {
            auto const & field = struct_type->node->fields[index];
            auto value = apply(*argument);
            mcontext.nodes->emplace_back(assignment{object, value, {index}}, field.inferred_type);
            ++index;
        }
        return object;
    }

    std::vector<node_ref> arguments;
    arguments.reserve(node.arguments.size());
    for (auto const & argument : node.arguments)
        arguments.push_back(apply(*argument));

    // Direct call of a known function definition
    auto callee = std::get_if<ast::identifier>(node.function.get());
    if (callee && callee->function_node)
    {
        mcontext.nodes->emplace_back(call{{}, std::move(arguments)}, node.inferred_type);
        auto call_node = last();
        lcontext.resolve_call.emplace_back(call_node, callee->function_node);
        return call_node;
    }

    // Indirect call through whatever the callee expression evaluates to
    auto function = apply(*node.function);
    mcontext.nodes->emplace_back(call_pointer{function, std::move(arguments)}, node.inferred_type);
    return last();
}
|
|
|
|
// Indexing expression. Only pointers can be indexed: the element address is
// compiled as the pointer addition `array + index` and then loaded.
// Throws std::runtime_error for any other left-hand side type.
node_ref apply(ast::array_access const & node)
{
    auto array_type = ast::get_type(*node.array);
    if (!types::is_pointer_type(*array_type))
        throw std::runtime_error("Unknown array access left-hand side");

    ast::binary_operation const element_address{ast::binary_operation_type::addition, node.array, node.index, {}, array_type};
    auto element_ptr = apply(element_address);
    mcontext.nodes->emplace_back(load{element_ptr}, node.inferred_type);
    return last();
}
|
|
|
|
// Field read: compile the object, then append a copy node that selects the field
// by its index in the struct definition.
// Throws std::runtime_error if the object is not of struct type (previously a
// null pointer dereference) or if the struct has no field with that name.
node_ref apply(ast::field_access const & node)
{
    auto object = apply(*node.object);
    auto object_type = ast::get_type(*node.object);

    auto struct_type = std::get_if<types::struct_type>(object_type.get());
    if (!struct_type)
        throw std::runtime_error("Field access on a non-struct value");

    auto const & fields = struct_type->node->fields;
    for (std::size_t i = 0; i < fields.size(); ++i)
    {
        if (fields[i].name == node.field_name)
        {
            mcontext.nodes->emplace_back(copy{object, {i}}, node.inferred_type);
            return last();
        }
    }

    throw std::runtime_error("Unknown field name");
}
|
|
|
|
// TODO: array (array_access and field_access are implemented above)
|
|
|
|
// Statements
|
|
|
|
// Expression used as a statement: compile the pointed-to expression and hand
// back its node.
node_ref apply(ast::expression_ptr const & node)
{
    auto const & expression = *node;
    return apply(expression);
}
|
|
|
|
// Tries to compile an assignment whose left-hand side is a chain of field
// accesses ending in an identifier (`a`, `a.b`, `a.b.c`, ...).
//
// lhs_node: remaining left-hand side, peeled one field access per recursion step.
// rhs:      node holding the value to store.
// path:     field indices collected so far.
//
// Returns the emitted assignment node, or std::nullopt (with no node emitted by
// this function) when the left-hand side does not have that shape: it is neither
// an identifier nor a field access, the accessed object is not a struct, or the
// field name is unknown. The caller then falls back to an explicit memory store.
//
// NOTE(review): indices are appended while walking from the outermost access
// inwards, so for `a.b.c` the path is {index of c, index of b}. Confirm that the
// consumer of `assignment` expects this order.
std::optional<node_ref> apply_field_chain_assignment(ast::expression_ptr const & lhs_node, node_ref rhs, std::vector<std::size_t> path)
{
    if (auto identifier = std::get_if<ast::identifier>(lhs_node.get()))
    {
        auto lhs = apply(*lhs_node);
        mcontext.nodes->emplace_back(assignment{lhs, rhs, std::move(path)}, identifier->inferred_type);
        return last();
    }

    auto field_access = std::get_if<ast::field_access>(lhs_node.get());
    if (!field_access)
        return std::nullopt;

    auto object_type = ast::get_type(*field_access->object);
    auto struct_type = std::get_if<types::struct_type>(object_type.get());
    if (!struct_type)
    {
        // Previously dereferenced a null pointer here
        return std::nullopt;
    }

    auto const & fields = struct_type->node->fields;
    for (std::size_t i = 0; i < fields.size(); ++i)
    {
        if (fields[i].name == field_access->field_name)
        {
            path.push_back(i);
            return apply_field_chain_assignment(field_access->object, rhs, std::move(path));
        }
    }

    return std::nullopt;
}
|
|
|
|
// Compiles `node` as an address (a node holding a pointer to the denoted object).
//
// Supported shapes:
//  - identifier:            address_of applied to the identifier's node;
//  - dereference `*p`:      the value of `p` itself;
//  - `p[i]` with pointer p: the pointer addition `p + i`;
//  - `object.field`:        address of `object` plus the field's layout offset.
//
// Returns std::nullopt when the expression has none of these shapes, when the
// indexed value is not a pointer, when the object of a field access is not a
// struct (previously a null pointer dereference) or is itself not addressable.
// Throws std::runtime_error("Unknown field name") for a field the struct lacks.
std::optional<node_ref> apply_get_address(ast::expression_ptr const & node)
{
    auto result_type = std::make_shared<types::type>(types::pointer_type{ast::get_type(*node), true});

    if (std::get_if<ast::identifier>(node.get()))
    {
        auto object = apply(*node);
        mcontext.nodes->emplace_back(unary_operation{ast::unary_operation_type::address_of, object}, result_type);
        return last();
    }

    if (auto unary = std::get_if<ast::unary_operation>(node.get()))
    {
        // &*p is just p
        if (unary->type == ast::unary_operation_type::dereference)
            return apply(*unary->arg1);
    }

    if (auto array_access = std::get_if<ast::array_access>(node.get()))
    {
        auto array_type = ast::get_type(*array_access->array);
        if (types::is_pointer_type(*array_type))
            return apply(ast::binary_operation{ast::binary_operation_type::addition, array_access->array, array_access->index, {}, array_type});
    }

    if (auto field_access = std::get_if<ast::field_access>(node.get()))
    {
        auto object_type = ast::get_type(*field_access->object);
        auto struct_type = std::get_if<types::struct_type>(object_type.get());
        if (!struct_type)
            return std::nullopt;

        if (auto object_ptr = apply_get_address(field_access->object))
        {
            for (auto const & field : struct_type->node->fields)
            {
                if (field.name != field_access->field_name)
                    continue;

                // Field address = object address + byte offset of the field
                mcontext.nodes->emplace_back(literal{ast::literal{ast::u64_literal{field.layout.offset}}},
                    std::make_shared<types::type>(types::primitive_type{types::u64_type{}}));
                mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::addition, *object_ptr, last()},
                    result_type);
                return last();
            }
            throw std::runtime_error("Unknown field name");
        }
    }

    return std::nullopt;
}
|
|
|
|
// Assignment statement. The right-hand side is compiled first. A left-hand side
// that is an identifier or a chain of field accesses ending in one (like
// a.b.c = 1) becomes a single assignment node; anything else is compiled to an
// address and written through a store node.
// Throws std::runtime_error if the left-hand side is not addressable either.
node_ref apply(ast::assignment const & node)
{
    auto value = apply(*node.rhs);

    if (auto direct = apply_field_chain_assignment(node.lhs, value, {}))
        return *direct;

    auto target_address = apply_get_address(node.lhs);
    if (!target_address)
        throw std::runtime_error("Unknown assignment left-hand side");

    mcontext.nodes->emplace_back(store{*target_address, value}, ast::get_type(*node.rhs));
    return last();
}
|
|
|
|
// Variable declaration: compile the initializer and register the node that
// holds its value as the variable's node.
//
// The variable must get a node of its own. If the initializer merely returned a
// node that existed before it was compiled (e.g. it is a reference to another
// variable), a copy node is introduced to prevent accidental variable coalescing.
// The previous check (`result == before`) only caught this when the referenced
// node happened to be the most recently emitted one.
node_ref apply(ast::variable_declaration const & node)
{
    auto before = last();
    auto result = apply(*node.initializer);

    // Is `result` one of the nodes appended while compiling the initializer?
    bool is_fresh = false;
    for (auto it = std::next(before); it != mcontext.nodes->end(); ++it)
    {
        if (it == result)
        {
            is_fresh = true;
            break;
        }
    }

    if (!is_fresh)
    {
        // NOTE(review): this copy node is emitted without a type, as before - confirm
        // that later stages do not need one here.
        mcontext.nodes->emplace_back(copy{result});
        result = last();
    }

    lcontext.variables[&node] = result;
    return result;
}
|
|
|
|
// if / else-if / else chain. For every block, in order:
//   [condition; jump_if_zero -> label after this block]   (only with a condition)
//   block body (compiled inside its own scope)
//   jump -> end of chain
//   label
// The label emitted after the last block doubles as the end of the chain, and
// every block's trailing jump is pointed at it. Returns that last node.
node_ref apply(ast::if_chain const & node)
{
    std::vector<node_ref> exit_jumps;

    for (auto const & block : node.blocks)
    {
        std::optional<node_ref> skip_block;
        if (block.condition)
        {
            auto condition = apply(*block.condition);
            mcontext.nodes->emplace_back(jump_if_zero{condition, {}});
            skip_block = last();
        }

        lcontext.scopes.emplace_back();
        apply(*block.statements);
        lcontext.scopes.pop_back();

        mcontext.nodes->emplace_back(jump{});
        exit_jumps.push_back(last());

        mcontext.nodes->emplace_back(label{});
        auto after_block = last();
        if (skip_block)
            std::get<jump_if_zero>((*skip_block)->instruction).target = after_block;
    }

    auto end = last();
    for (auto const & exit_jump : exit_jumps)
        std::get<jump>(exit_jump->instruction).target = end;

    return end;
}
|
|
|
|
// while loop:
//   head:  label
//          condition; jump_if_zero -> exit
//          body (compiled inside its own scope)
//          jump -> head
//   exit:  label
// Returns the exit label.
node_ref apply(ast::while_block const & node)
{
    mcontext.nodes->emplace_back(label{});
    auto loop_head = last();

    auto condition = apply(*node.condition);
    mcontext.nodes->emplace_back(jump_if_zero{condition, {}});
    auto exit_jump = last();

    lcontext.scopes.emplace_back();
    apply(*node.statements);
    lcontext.scopes.pop_back();

    mcontext.nodes->emplace_back(jump{loop_head});
    mcontext.nodes->emplace_back(label{});
    auto loop_exit = last();

    std::get<jump_if_zero>(exit_jump->instruction).target = loop_exit;
    return loop_exit;
}
|
|
|
|
// A function definition met as a statement emits nothing here; functions are
// compiled through do_apply. Returns the current last node.
node_ref apply(ast::function_definition const &)
{
    auto current = last();
    return current;
}
|
|
|
|
// Foreign function declarations emit no nodes. Returns the current last node.
node_ref apply(ast::foreign_function_declaration const &)
{
    auto current = last();
    return current;
}
|
|
|
|
// Return statement: emits a return_value node. With a value, the node carries
// the compiled value and that value's type; without one, it is empty and typed
// as unit. Returns the return_value node.
node_ref apply(ast::return_statement const & node)
{
    if (!node.value)
    {
        mcontext.nodes->emplace_back(return_value{}, std::make_shared<types::type>(types::unit_type{}));
        return last();
    }

    auto result = apply(*node.value);
    mcontext.nodes->emplace_back(return_value{result}, result->inferred_type);
    return last();
}
|
|
|
|
// Struct definitions emit no nodes. Returns the current last node.
node_ref apply(ast::struct_definition const &)
{
    auto current = last();
    return current;
}
|
|
|
|
// Statement list
|
|
|
|
// Compiles one function:
//   label (function entry)
//   one `argument` node per parameter, registered as that parameter's variable
//   the function body
//   an implicit empty return_value when the return type is unit
// Records {entry label, last node} in lcontext.functions and names the entry
// label `<enclosing scope's label_prefix><function name>` in mcontext.labels.
void do_apply(ast::function_definition const & node)
{
    // The prefix belongs to the scope the function is defined in, so it has to be
    // read before this function pushes its own (empty) scope. Reading scopes.back()
    // afterwards always yielded an empty prefix.
    std::string label_prefix;
    if (!lcontext.scopes.empty())
        label_prefix = lcontext.scopes.back().label_prefix;

    mcontext.nodes->emplace_back(label{});
    auto begin = last();

    lcontext.scopes.emplace_back();

    for (std::size_t i = 0; i < node.arguments.size(); ++i)
    {
        auto const & parameter = node.arguments[i];
        mcontext.nodes->emplace_back(argument{i}, ast::get_type(*parameter.type));
        lcontext.variables[&parameter] = last();
    }

    apply(*node.statements);

    // Unit functions may end without an explicit return statement
    if (types::equal(*ast::get_type(*node.return_type), types::unit_type{}))
        mcontext.nodes->emplace_back(return_value{}, std::make_shared<types::type>(types::unit_type{}));

    auto end = last();

    lcontext.functions[&node] = {begin, end};
    mcontext.labels[&(*begin)] = label_prefix + node.name;

    lcontext.scopes.pop_back();
}
|
|
|
|
private:
|
|
// Reference to the most recently appended node. The node list must not be empty.
node_ref last()
{
    auto position = mcontext.nodes->end();
    return std::prev(position);
}
|
|
};
|
|
|
|
// Main compilation visitor
|
|
struct compile_visitor
|
|
: ast::const_statement_visitor<compile_visitor>
|
|
{
|
|
module_context & mcontext;
|
|
local_context & lcontext;
|
|
|
|
using const_statement_visitor::apply;
|
|
|
|
void apply(ast::expression_ptr const &) {}
|
|
|
|
void apply(ast::assignment const &) {}
|
|
|
|
void apply(ast::variable_declaration const &) {}
|
|
|
|
void apply(ast::if_chain const & node)
|
|
{
|
|
for (auto const & block : node.blocks)
|
|
{
|
|
lcontext.scopes.emplace_back();
|
|
apply(*block.statements);
|
|
lcontext.scopes.pop_back();
|
|
}
|
|
}
|
|
|
|
void apply(ast::while_block const & node)
|
|
{
|
|
lcontext.scopes.emplace_back();
|
|
apply(*node.statements);
|
|
lcontext.scopes.pop_back();
|
|
}
|
|
|
|
void apply(ast::function_definition const & node)
|
|
{
|
|
compile_function_visitor{{}, {}, mcontext, lcontext}.do_apply(node);
|
|
std::string label_prefix;
|
|
if (!lcontext.scopes.empty())
|
|
label_prefix = lcontext.scopes.back().label_prefix + node.name + ".";
|
|
lcontext.scopes.emplace_back(std::move(label_prefix));
|
|
apply(*node.statements);
|
|
|
|
// Don't pop_back entry point scope
|
|
if (lcontext.scopes.size() > 1)
|
|
lcontext.scopes.pop_back();
|
|
}
|
|
|
|
void apply(ast::foreign_function_declaration const &) {}
|
|
|
|
void apply(ast::return_statement const &) {}
|
|
|
|
void apply(ast::struct_definition const &) {}
|
|
};
|
|
|
|
}
|
|
|
|
// Compiles the statement tree rooted at `root` into IR nodes appended to
// `mcontext.nodes` (the list is created if it does not exist yet).
//
// `root` must hold a function definition; it becomes the entry point. On return,
// mcontext.symbols maps every compiled function to its node range (end is one
// past the function's last node), all queued call / address targets point at the
// first node of their function, and mcontext.entry_point is the first node of
// the root function.
//
// Throws std::out_of_range if `root` is not a function definition. This is the
// exception type the final lookup used to raise, but it is now raised before
// `mcontext` is touched and carries a message.
void compile(module_context & mcontext, ast::statement_ptr const & root)
{
    auto const * entry_function = std::get_if<ast::function_definition>(root.get());
    if (!entry_function)
        throw std::out_of_range("IR compile: root statement must be a function definition");

    if (!mcontext.nodes)
        mcontext.nodes = std::make_shared<node_list>();

    local_context lcontext;

    // Temporary placeholder node, removed again once the tree has been compiled.
    // NOTE(review): presumably keeps std::prev(end()) valid for the visitors - confirm.
    mcontext.nodes->emplace_back(label{});
    auto extra_label = std::prev(mcontext.nodes->end());
    compile_visitor{{}, mcontext, lcontext}.apply(*root);
    mcontext.nodes->erase(extra_label);

    // Turn each recorded "last node" into a one-past-the-end position
    for (auto & [definition, bounds] : lcontext.functions)
        ++bounds.second;

    for (auto const & pending : lcontext.resolve_address)
        std::get<instruction_address>(pending.address->instruction).target = lcontext.functions.at(pending.target).first;

    for (auto const & pending : lcontext.resolve_call)
        std::get<call>(pending.call->instruction).target = lcontext.functions.at(pending.target).first;

    for (auto const & [definition, bounds] : lcontext.functions)
        mcontext.symbols[definition] = {.begin = bounds.first, .end = bounds.second};

    mcontext.entry_point = lcontext.functions.at(entry_function).first;
}
|
|
|
|
}
|