pslang/libs/interpreter/source/eval.cpp

332 lines
10 KiB
C++

#include <pslang/interpreter/eval.hpp>
#include <pslang/interpreter/interpreter.hpp>
#include <pslang/interpreter/value.hpp>
#include <pslang/ast/expression.hpp>
#include <pslang/type/print.hpp>
#include <cmath>
#include <limits>
#include <sstream>
#include <stdexcept>
#include <type_traits>
namespace pslang::interpreter
{
namespace
{
// Writes the source-level spelling of a unary operator to `out`, or "(unknown)"
// for a value that matches none of the enumerators.
void print(std::ostream & out, ast::unary_operation_type type)
{
    char const * spelling = "(unknown)";
    switch (type)
    {
    case ast::unary_operation_type::negation:
        spelling = "-";
        break;
    case ast::unary_operation_type::logical_not:
        spelling = "!";
        break;
    }
    out << spelling;
}
// Writes the source-level spelling of a binary operator to `out`, or "(unknown)"
// for a value that matches none of the enumerators. The logical operators are
// spelled with the single-character forms &, | and ^.
void print(std::ostream & out, ast::binary_operation_type type)
{
    char const * spelling = "(unknown)";
    switch (type)
    {
    case ast::binary_operation_type::addition:       spelling = "+";  break;
    case ast::binary_operation_type::subtraction:    spelling = "-";  break;
    case ast::binary_operation_type::multiplication: spelling = "*";  break;
    case ast::binary_operation_type::division:       spelling = "/";  break;
    case ast::binary_operation_type::remainder:      spelling = "%";  break;
    case ast::binary_operation_type::logical_and:    spelling = "&";  break;
    case ast::binary_operation_type::logical_or:     spelling = "|";  break;
    case ast::binary_operation_type::logical_xor:    spelling = "^";  break;
    case ast::binary_operation_type::equals:         spelling = "=="; break;
    case ast::binary_operation_type::not_equals:     spelling = "!="; break;
    case ast::binary_operation_type::less:           spelling = "<";  break;
    case ast::binary_operation_type::greater:        spelling = ">";  break;
    case ast::binary_operation_type::less_equals:    spelling = "<="; break;
    case ast::binary_operation_type::greater_equals: spelling = ">="; break;
    }
    out << spelling;
}
// Forward declaration of the generic expression dispatcher (defined further down in
// this namespace). The unary- and binary-operation overloads below call it to
// evaluate their sub-expressions before its definition is visible.
value eval_impl(context & context, ast::expression_ptr const & expression);
// Evaluates a numeric literal of static type T. The literal's stored constant is
// wrapped unchanged in the matching primitive_value alternative; the evaluation
// context is not consulted.
template <typename T>
value eval_impl(context & /*context*/, ast::numeric_literal_base<T> const & literal)
{
    primitive_value_base<T> constant{literal.value};
    return primitive_value(constant);
}
// Evaluates a literal by visiting the alternative it currently holds and forwarding
// that alternative to the matching eval_impl overload.
value eval_impl(context & context, ast::literal const & literal)
{
    auto const evaluate_alternative = [&context](auto const & alternative) {
        return eval_impl(context, alternative);
    };
    return std::visit(evaluate_alternative, literal);
}
// Resolves an identifier by searching the scope stack from the innermost (last)
// scope outwards; the first scope that defines the name wins and a copy of its
// stored value is returned.
// Throws std::runtime_error if no scope defines the name.
value eval_impl(context & context, ast::identifier const & identifier)
{
    auto & scopes = context.scope_stack;
    for (auto scope = scopes.rbegin(); scope != scopes.rend(); ++scope)
    {
        auto const found = scope->variables.find(identifier.name);
        if (found == scope->variables.end())
        {
            continue;
        }
        return found->second.value;
    }
    throw std::runtime_error("Identifier \"" + identifier.name + "\" is not defined");
}
// Applies a unary operator to a primitive operand of static type T.
//
//  * negation: defined for non-bool integral and floating-point T. Integral values
//    are negated modulo 2^N, so negating the minimum value yields the minimum value
//    again. Plain `-x` would be signed overflow, which is undefined behaviour in C++.
//    Types narrower than int already behaved this way through promotion plus the
//    cast back to T.
//  * logical_not: boolean NOT for bool, bitwise complement for other integral T.
//
// Throws std::runtime_error naming the operator and the operand type when the
// operator is not defined for T.
template <typename T>
value unary_operation_impl(ast::unary_operation_type type, primitive_value_base<T> const & arg1)
{
    switch (type)
    {
    case ast::unary_operation_type::negation:
        if constexpr (std::is_integral_v<T> && !std::is_same_v<T, bool>)
        {
            // Work in an unsigned type at least as wide as `unsigned int`. It is never
            // promoted to (signed) int, and unsigned arithmetic wraps by definition.
            using wide_unsigned = std::common_type_t<std::make_unsigned_t<T>, unsigned int>;
            wide_unsigned const negated = wide_unsigned{0} - static_cast<wide_unsigned>(arg1.value);
            return primitive_value(primitive_value_base<T>{static_cast<T>(negated)});
        }
        else if constexpr (std::is_floating_point_v<T>)
        {
            return primitive_value(primitive_value_base<T>{static_cast<T>(-arg1.value)});
        }
        break;
    case ast::unary_operation_type::logical_not:
        if constexpr (std::is_same_v<T, bool>)
        {
            return primitive_value(primitive_value_base<T>{static_cast<T>(!arg1.value)});
        }
        else if constexpr (std::is_integral_v<T>)
        {
            return primitive_value(primitive_value_base<T>{static_cast<T>(~arg1.value)});
        }
        break;
    }
    std::ostringstream os;
    os << "Cannot apply unary operator \"";
    print(os, type);
    os << "\" to a value of type ";
    type::print(os, type_of(primitive_value(arg1)));
    throw std::runtime_error(os.str());
}
// Unwraps the primitive_value variant and forwards the operand, with its concrete
// static type, to the typed unary_operation_impl overload.
value unary_operation_impl(ast::unary_operation_type type, primitive_value const & arg1)
{
    auto const apply = [type](auto const & operand) { return unary_operation_impl(type, operand); };
    return std::visit(apply, arg1);
}
// Evaluates a unary operation: first the operand expression, then the operator
// applied to the operand's current alternative.
value eval_impl(context & context, ast::unary_operation const & unary_operation)
{
    auto operand = eval_impl(context, unary_operation.arg1);
    auto const op = unary_operation.type;
    return std::visit([op](auto const & primitive) { return unary_operation_impl(op, primitive); }, operand);
}
bool requires_same_argument_type(ast::binary_operation_type)
{
// TODO: shift operators should return false
return true;
}
// Computes `op(lhs, rhs)` for a non-bool arithmetic T and converts the result back
// to T. Integral operands are first widened to an unsigned type of at least
// `unsigned int` width, so +, - and * wrap modulo 2^N instead of causing undefined
// behaviour. Signed overflow is UB, and types narrower than int are otherwise
// promoted to (signed) int, where e.g. 65535 * 65535 overflows.
template <typename T, typename Op>
T binary_arithmetic_wrapping(T lhs, T rhs, Op op)
{
    if constexpr (std::is_integral_v<T>)
    {
        using wide_unsigned = std::common_type_t<std::make_unsigned_t<T>, unsigned int>;
        return static_cast<T>(op(static_cast<wide_unsigned>(lhs), static_cast<wide_unsigned>(rhs)));
    }
    else
    {
        return static_cast<T>(op(lhs, rhs));
    }
}

// Integral `/` (want_remainder == false) or `%` (want_remainder == true) for a
// non-bool integral T. The cases C++ leaves undefined are handled explicitly:
//  * a zero divisor throws std::runtime_error("Division by zero");
//  * for signed T, division by -1 is a wrapping negation, so MIN / -1 == MIN and
//    MIN % -1 == 0 rather than overflowing.
template <typename T>
T binary_integral_divide(T lhs, T rhs, bool want_remainder)
{
    if (rhs == T{0})
    {
        throw std::runtime_error("Division by zero");
    }
    if constexpr (std::is_signed_v<T>)
    {
        if (rhs == static_cast<T>(-1))
        {
            return want_remainder
                ? T{0}
                : binary_arithmetic_wrapping<T>(T{0}, lhs, [](auto a, auto b) { return a - b; });
        }
    }
    return static_cast<T>(want_remainder ? lhs % rhs : lhs / rhs);
}

// Applies a binary operator to two primitives of the same static type T.
// `arg2_generic` must hold a primitive_value_base<T>; the caller checks this, and
// std::get throws std::bad_variant_access otherwise.
//
// Arithmetic (+ - * / %) is rejected for bool. Integral arithmetic wraps on
// overflow, and integral division or remainder by zero throws std::runtime_error.
// &, | and ^ are logical on bool and bitwise on other integral types. Comparisons
// work for every T and yield bool.
//
// Throws std::runtime_error naming the operator and operand types when the
// operator is not defined for T.
template <typename T>
value binary_operation_impl_same_type(ast::binary_operation_type type, primitive_value_base<T> const & arg1, value const & arg2_generic)
{
    primitive_value_base<T> const & arg2 = std::get<primitive_value_base<T>>(std::get<primitive_value>(arg2_generic));
    switch (type)
    {
    case ast::binary_operation_type::addition:
        if constexpr (!std::is_same_v<T, bool>)
        {
            return primitive_value(primitive_value_base<T>{
                binary_arithmetic_wrapping<T>(arg1.value, arg2.value, [](auto a, auto b) { return a + b; })});
        }
        break;
    case ast::binary_operation_type::subtraction:
        if constexpr (!std::is_same_v<T, bool>)
        {
            return primitive_value(primitive_value_base<T>{
                binary_arithmetic_wrapping<T>(arg1.value, arg2.value, [](auto a, auto b) { return a - b; })});
        }
        break;
    case ast::binary_operation_type::multiplication:
        if constexpr (!std::is_same_v<T, bool>)
        {
            return primitive_value(primitive_value_base<T>{
                binary_arithmetic_wrapping<T>(arg1.value, arg2.value, [](auto a, auto b) { return a * b; })});
        }
        break;
    case ast::binary_operation_type::division:
        if constexpr (std::is_integral_v<T> && !std::is_same_v<T, bool>)
        {
            return primitive_value(primitive_value_base<T>{binary_integral_divide<T>(arg1.value, arg2.value, false)});
        }
        else if constexpr (!std::is_same_v<T, bool>)
        {
            // Floating-point division by zero follows the platform's floating-point
            // semantics (±inf / NaN on IEEE 754 targets).
            return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value / arg2.value)});
        }
        break;
    case ast::binary_operation_type::remainder:
        if constexpr (!std::is_same_v<T, bool> && std::is_integral_v<T>)
        {
            return primitive_value(primitive_value_base<T>{binary_integral_divide<T>(arg1.value, arg2.value, true)});
        }
        break;
    case ast::binary_operation_type::logical_and:
        if constexpr (std::is_same_v<T, bool>)
        {
            return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value && arg2.value)});
        }
        else if constexpr (std::is_integral_v<T>)
        {
            return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value & arg2.value)});
        }
        break;
    case ast::binary_operation_type::logical_or:
        if constexpr (std::is_same_v<T, bool>)
        {
            return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value || arg2.value)});
        }
        else if constexpr (std::is_integral_v<T>)
        {
            return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value | arg2.value)});
        }
        break;
    case ast::binary_operation_type::logical_xor:
        if constexpr (std::is_same_v<T, bool>)
        {
            return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value ^ arg2.value)});
        }
        else if constexpr (std::is_integral_v<T>)
        {
            return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value ^ arg2.value)});
        }
        break;
    case ast::binary_operation_type::equals:
        return primitive_value(primitive_value_base<bool>{arg1.value == arg2.value});
    case ast::binary_operation_type::not_equals:
        return primitive_value(primitive_value_base<bool>{arg1.value != arg2.value});
    case ast::binary_operation_type::less:
        return primitive_value(primitive_value_base<bool>{arg1.value < arg2.value});
    case ast::binary_operation_type::greater:
        return primitive_value(primitive_value_base<bool>{arg1.value > arg2.value});
    case ast::binary_operation_type::less_equals:
        return primitive_value(primitive_value_base<bool>{arg1.value <= arg2.value});
    case ast::binary_operation_type::greater_equals:
        return primitive_value(primitive_value_base<bool>{arg1.value >= arg2.value});
    }
    std::ostringstream os;
    os << "Cannot apply binary operator \"";
    print(os, type);
    os << "\" to values of type ";
    type::print(os, type_of(primitive_value(arg1)));
    os << " and ";
    type::print(os, type_of(primitive_value(arg2)));
    throw std::runtime_error(os.str());
}
// Unwraps the left operand's primitive_value variant and forwards it, with its
// concrete static type, to the typed overload. The right operand is passed through
// still wrapped.
value binary_operation_impl_same_type(ast::binary_operation_type type, primitive_value const & arg1, value const & arg2)
{
    auto const apply = [type, &arg2](auto const & lhs) { return binary_operation_impl_same_type(type, lhs, arg2); };
    return std::visit(apply, arg1);
}
// Evaluates a binary operation. Both operands are evaluated (left, then right),
// their types are required to be equal, and the operator is then applied by
// dispatching on the left operand's primitive alternative.
// Throws std::runtime_error when the operand types differ, or when the operator
// does not require matching types (mixed-type evaluation is not implemented yet).
value eval_impl(context & context, ast::binary_operation const & binary_operation)
{
    auto const op = binary_operation.type;
    auto lhs = eval_impl(context, binary_operation.arg1);
    auto rhs = eval_impl(context, binary_operation.arg2);

    if (!requires_same_argument_type(op))
    {
        throw std::runtime_error("eval(binary_operation) for different argument types not implemented");
    }

    auto lhs_type = type_of(lhs);
    auto rhs_type = type_of(rhs);
    if (type::equal(lhs_type, rhs_type))
    {
        return std::visit([&](auto const & typed_lhs) { return binary_operation_impl_same_type(op, typed_lhs, rhs); }, lhs);
    }

    std::ostringstream os;
    os << "Cannot apply binary operator \"";
    print(os, op);
    os << "\" to values of type ";
    type::print(os, lhs_type);
    os << " and ";
    type::print(os, rhs_type);
    throw std::runtime_error(os.str());
}
// Converts a primitive of static type T to primitive type H.
//
// An identity cast returns the value unchanged. Any other pair of non-bool
// primitive types converts with static_cast; floating-point to integral truncates
// toward zero. Casts to or from bool are rejected.
//
// Throws std::runtime_error when the cast is not allowed, or when a floating-point
// value (including NaN and infinities) does not fit in the integral target type.
// static_cast on such a value would be undefined behaviour.
template <typename T, typename H>
value cast_impl(primitive_value_base<T> const & value, type::primitive_type_base<H> const & type)
{
    // Builds the "Cannot cast ..." diagnostic; `detail` is appended verbatim.
    auto const cast_error = [&](char const * detail) {
        std::ostringstream os;
        os << "Cannot cast value of type ";
        type::print(os, type_of(primitive_value(value)));
        os << " to type ";
        type::print(os, type::primitive_type(type));
        os << detail;
        return std::runtime_error(os.str());
    };

    if constexpr (std::is_same_v<T, H>)
    {
        return primitive_value(value);
    }
    else if constexpr (!std::is_same_v<T, bool> && !std::is_same_v<bool, H>)
    {
        if constexpr (std::is_floating_point_v<T> && std::is_integral_v<H>)
        {
            // [lower, upper) is exactly the set of truncated values representable in H.
            // Both bounds are zero or (signed) powers of two, so they are exact in T.
            // NaN fails every comparison and is rejected as well.
            T const lower = static_cast<T>(std::numeric_limits<H>::min());
            T const upper = static_cast<T>(std::numeric_limits<H>::max() / 2 + 1) * T{2};
            T const truncated = std::trunc(value.value);
            if (!(truncated >= lower && truncated < upper))
            {
                throw cast_error(": value is out of range");
            }
        }
        return primitive_value(primitive_value_base<H>{static_cast<H>(value.value)});
    }
    throw cast_error("");
}
// Casts a primitive of static type T to the primitive type currently held by the
// primitive_type variant, dispatching to the typed cast_impl overload.
template <typename T>
value cast_impl(primitive_value_base<T> const & value, type::primitive_type const & type)
{
    auto const cast_to = [&value](auto const & target) { return cast_impl(value, target); };
    return std::visit(cast_to, type);
}
// Casts a primitive of static type T to a general target type by visiting the
// type::type variant and forwarding its alternative to the matching cast_impl overload.
template <typename T>
value cast_impl(primitive_value_base<T> const & value, type::type const & type)
{
    auto const cast_to = [&value](auto const & target) { return cast_impl(value, target); };
    return std::visit(cast_to, type);
}
// Unwraps the primitive_value variant so the cast can be resolved with the value's
// concrete static type.
value cast_impl(primitive_value const & value, type::type const & type)
{
    auto const cast_typed = [&type](auto const & typed_value) { return cast_impl(typed_value, type); };
    return std::visit(cast_typed, value);
}
// Evaluates a cast: first the operand expression, then a conversion of the
// resulting primitive to the target type stored on the cast node.
value eval_impl(context & context, ast::cast_operation const & cast_operation)
{
    auto operand = eval(context, cast_operation.expression);
    auto const & target_type = *cast_operation.type;
    return std::visit([&target_type](auto const & primitive) { return cast_impl(primitive, target_type); }, operand);
}
// Evaluates an expression node by visiting the alternative it holds and forwarding
// it to the matching eval_impl overload.
// NOTE(review): `expression` is dereferenced unconditionally; presumably it is never
// null here. Confirm against the parser.
value eval_impl(context & context, ast::expression_ptr const & expression)
{
    auto const evaluate = [&context](auto const & node) { return eval_impl(context, node); };
    return std::visit(evaluate, *expression);
}
}
// Public entry point: evaluates `expression` against the scopes in `context` and
// returns the resulting value. Evaluation errors detected in this file are
// reported by throwing std::runtime_error.
value eval(context & context, ast::expression_ptr const & expression)
{
    value result = eval_impl(context, expression);
    return result;
}
}