pslang/libs/ir/source/print.cpp

307 lines
6.5 KiB
C++

#include <pslang/ir/print.hpp>
#include <pslang/ir/node.hpp>
#include <pslang/ir/compiler.hpp>
#include <pslang/ast/print.hpp>
#include <unordered_map>
#include <iomanip>
namespace pslang::ir
{
namespace
{
// Writes the IR mnemonic of a unary operator: "neg" for negation,
// "not" for logical negation. A value outside the enumeration writes nothing.
void print(std::ostream & out, ast::unary_operation_type type)
{
    char const * mnemonic = nullptr;
    switch (type)
    {
    case ast::unary_operation_type::negation:    mnemonic = "neg"; break;
    case ast::unary_operation_type::logical_not: mnemonic = "not"; break;
    }
    if (mnemonic != nullptr)
        out << mnemonic;
}
// Writes the IR mnemonic of a binary operator (e.g. "add", "eq", "geq").
// logical_and / logical_or print as "and sc" / "or sc"; the " sc" suffix
// presumably marks short-circuit evaluation (NOTE(review): confirm).
// A value outside the enumeration writes nothing.
void print(std::ostream & out, ast::binary_operation_type type)
{
    char const * mnemonic = nullptr;
    switch (type)
    {
    case ast::binary_operation_type::addition:       mnemonic = "add";    break;
    case ast::binary_operation_type::subtraction:    mnemonic = "sub";    break;
    case ast::binary_operation_type::multiplication: mnemonic = "mul";    break;
    case ast::binary_operation_type::division:       mnemonic = "div";    break;
    case ast::binary_operation_type::remainder:      mnemonic = "rem";    break;
    case ast::binary_operation_type::binary_and:     mnemonic = "and";    break;
    case ast::binary_operation_type::logical_and:    mnemonic = "and sc"; break;
    case ast::binary_operation_type::binary_or:      mnemonic = "or";     break;
    case ast::binary_operation_type::logical_or:     mnemonic = "or sc";  break;
    case ast::binary_operation_type::logical_xor:    mnemonic = "xor";    break;
    case ast::binary_operation_type::equals:         mnemonic = "eq";     break;
    case ast::binary_operation_type::not_equals:     mnemonic = "neq";    break;
    case ast::binary_operation_type::less:           mnemonic = "lt";     break;
    case ast::binary_operation_type::greater:        mnemonic = "gt";     break;
    case ast::binary_operation_type::less_equals:    mnemonic = "leq";    break;
    case ast::binary_operation_type::greater_equals: mnemonic = "geq";    break;
    }
    if (mnemonic != nullptr)
        out << mnemonic;
}
// std::visit visitor that writes the value of one AST literal to `out`
// (used for the operand of the "lit" instruction).
//  - bool literals print as "true" / "false".
//  - 8-bit integer literals are widened to int before printing, so that
//    (presumably char-sized) values appear as numbers, not characters.
//  - f16 literals print their `repr` member as-is.
//  - f32 / f64 literals use the stream's current floating-point formatting.
struct print_literal_visitor
{
    std::ostream & out;

    void operator()(ast::bool_literal const & literal)
    {
        if (literal.value)
            out << "true";
        else
            out << "false";
    }

    void operator()(ast::i8_literal const & literal)  { out << static_cast<int>(literal.value); }
    void operator()(ast::i16_literal const & literal) { out << literal.value; }
    void operator()(ast::i32_literal const & literal) { out << literal.value; }
    void operator()(ast::i64_literal const & literal) { out << literal.value; }

    void operator()(ast::u8_literal const & literal)  { out << static_cast<int>(literal.value); }
    void operator()(ast::u16_literal const & literal) { out << literal.value; }
    void operator()(ast::u32_literal const & literal) { out << literal.value; }
    void operator()(ast::u64_literal const & literal) { out << literal.value; }

    void operator()(ast::f16_literal const & literal) { out << literal.value.repr; }
    void operator()(ast::f32_literal const & literal) { out << literal.value; }
    void operator()(ast::f64_literal const & literal) { out << literal.value; }
};
// std::visit visitor that writes one IR instruction in textual form
// ("<mnemonic> <operands>"). Operands that refer to other nodes are printed
// as "$<index>", where the index is that node's position in the list that
// was registered with fill_index(). fill_index() must be called before any
// instruction is visited.
struct print_visitor
{
    std::ostream & out;
    // Position of each registered node, keyed by the node's address.
    std::unordered_map<node const *, std::size_t> node_index;
    // Width of the index column: the number of decimal digits of the largest
    // index (0 when at most one node is registered).
    std::size_t indent = 0;

    // Records the position of every node in `nodes` and sizes the index column.
    void fill_index(node_list const & nodes)
    {
        std::size_t position = 0;
        for (auto const & entry : nodes)
            node_index[&entry] = position++;

        // Computed into a local and assigned, so the width does not depend on
        // whatever value `indent` held before.
        std::size_t digits = 0;
        if (!node_index.empty())
        {
            for (std::size_t max_index = node_index.size() - 1; max_index > 0; max_index /= 10)
                ++digits;
        }
        indent = digits;
    }

    // Returns the printed index of the node `ref` refers to.
    // Throws std::out_of_range if that node was not registered by fill_index().
    std::size_t get_index(node_ref ref) const
    {
        return node_index.at(&(*ref));
    }

    // Prepares `out` so the next value (the node index) is right-aligned to
    // the index column. std::right and std::setfill are sticky on the stream;
    // only std::setw is reset after the next output.
    void prelude()
    {
        out << std::right << std::setfill(' ') << std::setw(static_cast<int>(indent));
    }

    void operator()(nop const &)
    {
        out << "nop";
    }

    void operator()(literal const & instruction)
    {
        out << "lit ";
        std::visit(print_literal_visitor{out}, instruction.value);
    }

    void operator()(copy const & instruction)
    {
        out << "copy $" << get_index(instruction.source);
    }

    void operator()(unary_operation const & instruction)
    {
        print(out, instruction.type);
        out << " $" << get_index(instruction.arg1);
    }

    void operator()(binary_operation const & instruction)
    {
        print(out, instruction.type);
        out << " $" << get_index(instruction.arg1) << " $" << get_index(instruction.arg2);
    }

    void operator()(cast_operation const & instruction)
    {
        out << "cast $" << get_index(instruction.arg1);
    }

    void operator()(argument const & instruction)
    {
        out << "arg #" << instruction.index;
    }

    void operator()(address const & instruction)
    {
        out << "addr $" << get_index(instruction.target);
    }

    void operator()(extern_symbol const & instruction)
    {
        out << "extern \"" << instruction.name << "\"";
    }

    void operator()(assignment const & instruction)
    {
        out << "assign $" << get_index(instruction.lhs) << " $" << get_index(instruction.rhs);
    }

    void operator()(jump const & instruction)
    {
        out << "jump $" << get_index(instruction.target);
    }

    void operator()(jump_if_zero const & instruction)
    {
        out << "jz $" << get_index(instruction.condition) << " $" << get_index(instruction.target);
    }

    // "call $<callee>" followed by one " $<index>" per argument node.
    void operator()(call const & instruction)
    {
        out << "call $" << get_index(instruction.target);
        for (auto const & arg : instruction.arguments)
            out << " $" << get_index(arg);
    }

    // "callp $<pointer>" followed by one " $<index>" per argument node.
    void operator()(call_pointer const & instruction)
    {
        out << "callp $" << get_index(instruction.pointer);
        for (auto const & arg : instruction.arguments)
            out << " $" << get_index(arg);
    }

    // "ret " followed by "$<index>" only when a return value is present.
    void operator()(return_value const & instruction)
    {
        out << "ret ";
        if (instruction.value)
            out << "$" << get_index(*instruction.value);
    }
};
// Restores a stream's format flags and fill character when it goes out of
// scope, so formatting applied while printing does not leak to the caller.
struct stream_format_guard
{
    std::ostream & out;
    std::ios_base::fmtflags flags;
    std::ostream::char_type fill;

    explicit stream_format_guard(std::ostream & stream)
        : out(stream), flags(stream.flags()), fill(stream.fill())
    {}

    ~stream_format_guard()
    {
        out.flags(flags);
        out.fill(fill);
    }

    stream_format_guard(stream_format_guard const &) = delete;
    stream_format_guard & operator=(stream_format_guard const &) = delete;
};

// Prints every node of `nodes` on its own line as "<index>: <instruction>",
// followed by " : <type>" when the node has an inferred type. The index
// column is right-aligned to the width of the largest index.
//
// When `context` is non-null, a node that has an entry in `context->labels`
// is preceded by a "<label>:" line (and by an empty line unless it is the
// first node), and every instruction line is prefixed with extra spacing.
//
// The caller's format flags and fill character are restored on return, also
// when an operand refers to a node outside `nodes` and lookup throws.
void print_impl(std::ostream & out, node_list const & nodes, module_context const * context)
{
    stream_format_guard const guard{out};

    print_visitor visitor{out};
    visitor.fill_index(nodes);
    std::size_t index = 0;
    for (auto const & node : nodes)
    {
        if (context && context->labels.contains(&node))
        {
            // Separate labelled sections with an empty line.
            if (index > 0) out << '\n';
            out << context->labels.at(&node) << ":\n";
        }
        if (context)
            out << " ";
        visitor.prelude();
        out << index << ": ";
        ++index;
        std::visit(visitor, node.instruction);
        if (node.inferred_type)
        {
            out << " : ";
            ast::print(out, *node.inferred_type);
        }
        out << '\n';
    }
}
}
void print(std::ostream & out, node_list const & nodes)
{
print_impl(out, nodes, nullptr);
}
void print(std::ostream & out, module_context const & context)
{
print_impl(out, context.nodes, &context);
}
}