pslang/libs/interpreter/source/eval.cpp

783 lines
24 KiB
C++

#include <pslang/interpreter/eval.hpp>
#include <pslang/interpreter/exec.hpp>
#include <pslang/interpreter/value.hpp>
#include <pslang/ast/expression.hpp>
#include <pslang/type/print.hpp>
#include <pslang/parser/error.hpp>
#include <cmath>
#include <limits>
#include <optional>
#include <sstream>
#include <type_traits>
namespace pslang::interpreter
{
namespace
{
// The unit type names nothing that could need a scope lookup, so it resolves to itself.
type::type resolve_type_impl(context &, type::unit_type const & unit)
{
    return unit;
}
// Primitive types are built in and never refer to user declarations; resolution is the identity.
type::type resolve_type_impl(context &, type::primitive_type const & primitive)
{
    return primitive;
}
// Arrays are resolved structurally: the element type is resolved recursively, the size is kept as-is.
type::type resolve_type_impl(context & context, type::array_type const & type)
{
    auto resolved_element = resolve_type(context, *type.element_type);
    return type::array_type{std::make_unique<type::type>(std::move(resolved_element)), type.size};
}
// Function types are resolved structurally: every argument type (in declaration order) and then the
// result type are resolved recursively into a fresh function_type.
type::type resolve_type_impl(context & context, type::function_type const & type)
{
    type::function_type resolved;
    for (auto const & argument_type : type.arguments)
    {
        auto resolved_argument = resolve_type(context, *argument_type);
        resolved.arguments.push_back(std::make_unique<type::type>(std::move(resolved_argument)));
    }
    auto resolved_result = resolve_type(context, *type.result);
    resolved.result = std::make_unique<type::type>(std::move(resolved_result));
    return resolved;
}
// Resolves a struct name lexically. The innermost scope (highest index in context.scope_stack) that
// declares a struct with this name wins. The result is an identifier whose name is prefixed with one
// '/' per scope index, presumably so that same-named structs from different scopes compare unequal.
// Throws std::runtime_error if no scope declares the struct.
type::type resolve_type_impl(context & context, type::identifier const & type)
{
    auto & scopes = context.scope_stack;
    for (auto index = scopes.size(); index-- > 0;)
    {
        if (scopes[index].structs.count(type.name))
            return type::identifier{std::string(index, '/') + type.name};
    }
    throw std::runtime_error("Type \"" + type.name + "\" is not defined");
}
// Dispatches on the active alternative of `type` to the matching resolve_type_impl overload above.
type::type resolve_type_impl(context & context, type::type const & type)
{
    auto const resolve_alternative = [&context](auto const & alternative) { return resolve_type_impl(context, alternative); };
    return std::visit(resolve_alternative, type);
}
// Writes the source-level symbol of a unary operator ("-" or "!"), or "(unknown)" for any value that
// is not one of the handled enumerators.
void print(std::ostream & out, ast::unary_operation_type type)
{
    char const * symbol = "(unknown)";
    switch (type)
    {
    case ast::unary_operation_type::negation:
        symbol = "-";
        break;
    case ast::unary_operation_type::logical_not:
        symbol = "!";
        break;
    }
    out << symbol;
}
// Writes the source-level symbol of a binary operator. Note that the logical operators are spelled
// "&", "|" and "^". Any value that is not one of the handled enumerators prints as "(unknown)".
void print(std::ostream & out, ast::binary_operation_type type)
{
    char const * symbol = nullptr;
    switch (type)
    {
    case ast::binary_operation_type::addition:       symbol = "+";  break;
    case ast::binary_operation_type::subtraction:    symbol = "-";  break;
    case ast::binary_operation_type::multiplication: symbol = "*";  break;
    case ast::binary_operation_type::division:       symbol = "/";  break;
    case ast::binary_operation_type::remainder:      symbol = "%";  break;
    case ast::binary_operation_type::logical_and:    symbol = "&";  break;
    case ast::binary_operation_type::logical_or:     symbol = "|";  break;
    case ast::binary_operation_type::logical_xor:    symbol = "^";  break;
    case ast::binary_operation_type::equals:         symbol = "=="; break;
    case ast::binary_operation_type::not_equals:     symbol = "!="; break;
    case ast::binary_operation_type::less:           symbol = "<";  break;
    case ast::binary_operation_type::greater:        symbol = ">";  break;
    case ast::binary_operation_type::less_equals:    symbol = "<="; break;
    case ast::binary_operation_type::greater_equals: symbol = ">="; break;
    }
    out << (symbol != nullptr ? symbol : "(unknown)");
}
// Converts an index value into a bounds-checked position in an array of `size` elements.
// Accepts only the fixed-width integer primitives i8/u8/i16/u16/i32/u32/i64/u64.
// Throws std::runtime_error when the index has any other type, is negative, or is not below `size`.
std::uint64_t get_array_index(value const & index, std::uint64_t size)
{
    std::optional<std::int64_t> signed_index;
    std::optional<std::uint64_t> unsigned_index;
    if (auto const * primitive = std::get_if<primitive_value>(&index))
    {
        if (auto const * as_i8 = std::get_if<i8_value>(primitive))
            signed_index = as_i8->value;
        else if (auto const * as_u8 = std::get_if<u8_value>(primitive))
            unsigned_index = as_u8->value;
        else if (auto const * as_i16 = std::get_if<i16_value>(primitive))
            signed_index = as_i16->value;
        else if (auto const * as_u16 = std::get_if<u16_value>(primitive))
            unsigned_index = as_u16->value;
        else if (auto const * as_i32 = std::get_if<i32_value>(primitive))
            signed_index = as_i32->value;
        else if (auto const * as_u32 = std::get_if<u32_value>(primitive))
            unsigned_index = as_u32->value;
        else if (auto const * as_i64 = std::get_if<i64_value>(primitive))
            signed_index = as_i64->value;
        else if (auto const * as_u64 = std::get_if<u64_value>(primitive))
            unsigned_index = as_u64->value;
    }

    if (!signed_index.has_value() && !unsigned_index.has_value())
    {
        std::ostringstream os;
        os << "Cannot index into an array with an expression of type ";
        type::print(os, type_of(index));
        throw std::runtime_error(os.str());
    }

    if (signed_index.has_value())
    {
        std::int64_t const position = *signed_index;
        // Negative positions are rejected first, so the conversion to unsigned is value-preserving.
        if (position < 0 || static_cast<std::uint64_t>(position) >= size)
        {
            std::ostringstream os;
            os << "Array index " << position << " out of bounds " << size;
            throw std::runtime_error(os.str());
        }
        return static_cast<std::uint64_t>(position);
    }

    std::uint64_t const unsigned_position = *unsigned_index;
    if (unsigned_position >= size)
    {
        std::ostringstream os;
        os << "Array index " << unsigned_position << " out of bounds " << size;
        throw std::runtime_error(os.str());
    }
    return unsigned_position;
}
// Forward declaration: composite expressions recurse into their sub-expressions through this overload,
// which is defined after all node-specific overloads.
value eval_impl(context & context, ast::expression_ptr const & expression);

// A numeric literal of C++ type T evaluates to the primitive value of that same type.
template <typename T>
value eval_impl(context &, ast::numeric_literal_base<T> const & literal)
{
    primitive_value_base<T> const payload{literal.value};
    return primitive_value(payload);
}
// Evaluates a literal by dispatching on its concrete literal alternative.
value eval_impl(context & context, ast::literal const & literal)
{
    auto const evaluate_literal = [&context](auto const & typed_literal) { return eval_impl(context, typed_literal); };
    return std::visit(evaluate_literal, literal);
}
// Looks a variable up from the innermost scope outwards and returns a copy of its current value.
// Every scope on the stack is searched. Throws std::runtime_error if no scope defines the name.
value eval_impl(context & context, ast::identifier const & identifier)
{
    auto const & scopes = context.scope_stack;
    for (auto frame = scopes.rbegin(); frame != scopes.rend(); ++frame)
    {
        auto const found = frame->variables.find(identifier.name);
        if (found != frame->variables.end())
            return found->second.value;
    }
    throw std::runtime_error("Identifier \"" + identifier.name + "\" is not defined");
}
// Fallback for non-primitive operands (unit, array, struct values): no unary operator applies, so this
// always throws std::runtime_error naming the operator and the operand's type.
template <typename Value>
value unary_operation_impl(ast::unary_operation_type type, Value const & value)
{
    std::ostringstream message;
    message << "Cannot apply unary operator \"";
    print(message, type);
    message << "\" to a value of type ";
    type::print(message, type_of(value));
    throw std::runtime_error(message.str());
}
// Applies a unary operator to a primitive value of C++ type T.
//  - negation: defined for integral and floating-point types other than bool. Signed integers wrap
//    (two's complement), so negating the minimum value yields the minimum value again.
//  - logical_not: boolean NOT for bool, bitwise complement for other integral types.
// Any other combination throws std::runtime_error naming the operator and the operand type.
template <typename T>
value unary_operation_impl(ast::unary_operation_type type, primitive_value_base<T> const & arg1)
{
    switch (type)
    {
    case ast::unary_operation_type::negation:
        if constexpr (std::is_integral_v<T> && std::is_signed_v<T> && !std::is_same_v<T, bool>)
        {
            // `-arg1.value` is signed-overflow UB for the minimum value of 32/64-bit types. Negate in an
            // unsigned type at least as wide as `unsigned int` (so no promotion back to `int`) and
            // convert back, which wraps modulo 2^N.
            using unsigned_type = std::common_type_t<std::make_unsigned_t<T>, unsigned int>;
            return primitive_value(primitive_value_base<T>{
                static_cast<T>(static_cast<unsigned_type>(0) - static_cast<unsigned_type>(arg1.value))});
        }
        else if constexpr ((std::is_integral_v<T> || std::is_floating_point_v<T>) && !std::is_same_v<T, bool>)
        {
            // Unsigned integers (modular) and floating point: plain negation is well-defined.
            return primitive_value(primitive_value_base<T>{static_cast<T>(-arg1.value)});
        }
        break;
    case ast::unary_operation_type::logical_not:
        if constexpr (std::is_same_v<T, bool>)
        {
            return primitive_value(primitive_value_base<T>{static_cast<T>(!arg1.value)});
        }
        else if constexpr (std::is_integral_v<T>)
        {
            return primitive_value(primitive_value_base<T>{static_cast<T>(~arg1.value)});
        }
        break;
    }
    std::ostringstream os;
    os << "Cannot apply unary operator \"";
    print(os, type);
    os << "\" to a value of type ";
    type::print(os, type_of(primitive_value(arg1)));
    throw std::runtime_error(os.str());
}
// Unwraps a primitive_value and forwards its typed payload to the matching unary_operation_impl.
value unary_operation_impl(ast::unary_operation_type type, primitive_value const & arg1)
{
    auto const apply = [type](auto const & typed_operand) { return unary_operation_impl(type, typed_operand); };
    return std::visit(apply, arg1);
}
// Evaluates the operand and then applies the unary operator to whichever value alternative it produced.
value eval_impl(context & context, ast::unary_operation const & unary_operation)
{
    auto const operator_type = unary_operation.type;
    auto operand = eval_impl(context, unary_operation.arg1);
    auto const apply = [operator_type](auto const & alternative) { return unary_operation_impl(operator_type, alternative); };
    return std::visit(apply, operand);
}
// Whether both operands of a binary operator must have identical types. Currently true for every operator.
bool requires_same_argument_type(ast::binary_operation_type)
{
    // TODO: shift operators should return false
    constexpr bool same_type_required = true;
    return same_type_required;
}
// Unsigned type used for wrapping integer arithmetic on T. It is at least as wide as `unsigned int`,
// so narrow operands are never promoted to (signed) `int`; for example, u16 * u16 would otherwise
// overflow `int`.
template <typename T>
using wrapping_unsigned_t = std::common_type_t<std::make_unsigned_t<T>, unsigned int>;

// a + b modulo 2^N for integral, non-bool T (no signed-overflow UB).
template <typename T>
T wrapping_add(T lhs, T rhs)
{
    using unsigned_type = wrapping_unsigned_t<T>;
    return static_cast<T>(static_cast<unsigned_type>(lhs) + static_cast<unsigned_type>(rhs));
}

// a - b modulo 2^N for integral, non-bool T (no signed-overflow UB).
template <typename T>
T wrapping_sub(T lhs, T rhs)
{
    using unsigned_type = wrapping_unsigned_t<T>;
    return static_cast<T>(static_cast<unsigned_type>(lhs) - static_cast<unsigned_type>(rhs));
}

// a * b modulo 2^N for integral, non-bool T (no signed-overflow or promotion UB).
template <typename T>
T wrapping_mul(T lhs, T rhs)
{
    using unsigned_type = wrapping_unsigned_t<T>;
    return static_cast<T>(static_cast<unsigned_type>(lhs) * static_cast<unsigned_type>(rhs));
}

// Applies a binary operator to two primitive values that both have C++ type T. The caller has
// already verified that arg2_generic holds a primitive_value_base<T>.
//  - + - * : non-bool types. Integers wrap modulo 2^N; floating point uses IEEE arithmetic.
//  - /     : non-bool types. Integer division by zero throws "Division by zero"; MIN / -1 wraps to MIN.
//  - %     : integers only. A zero divisor throws "Division by zero"; x % -1 is 0.
//  - & | ^ : logical for bool, bitwise for other integers.
//  - comparisons: all types; the result is a bool value.
// Every other combination throws std::runtime_error naming the operator and the operand types.
template <typename T>
value binary_operation_impl_same_type(ast::binary_operation_type type, primitive_value_base<T> const & arg1, value const & arg2_generic)
{
    primitive_value_base<T> const & arg2 = std::get<primitive_value_base<T>>(std::get<primitive_value>(arg2_generic));
    switch (type)
    {
    case ast::binary_operation_type::addition:
        if constexpr (std::is_integral_v<T> && !std::is_same_v<T, bool>)
            return primitive_value(primitive_value_base<T>{wrapping_add<T>(arg1.value, arg2.value)});
        else if constexpr (!std::is_same_v<T, bool>)
            return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value + arg2.value)});
        break;
    case ast::binary_operation_type::subtraction:
        if constexpr (std::is_integral_v<T> && !std::is_same_v<T, bool>)
            return primitive_value(primitive_value_base<T>{wrapping_sub<T>(arg1.value, arg2.value)});
        else if constexpr (!std::is_same_v<T, bool>)
            return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value - arg2.value)});
        break;
    case ast::binary_operation_type::multiplication:
        if constexpr (std::is_integral_v<T> && !std::is_same_v<T, bool>)
            return primitive_value(primitive_value_base<T>{wrapping_mul<T>(arg1.value, arg2.value)});
        else if constexpr (!std::is_same_v<T, bool>)
            return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value * arg2.value)});
        break;
    case ast::binary_operation_type::division:
        if constexpr (std::is_integral_v<T> && !std::is_same_v<T, bool>)
        {
            if (arg2.value == static_cast<T>(0))
                throw std::runtime_error("Division by zero");
            if constexpr (std::is_signed_v<T>)
            {
                // MIN / -1 overflows (UB) for 32/64-bit types; x / -1 == -x, computed with wrap-around.
                if (arg2.value == static_cast<T>(-1))
                    return primitive_value(primitive_value_base<T>{wrapping_sub<T>(static_cast<T>(0), arg1.value)});
            }
            return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value / arg2.value)});
        }
        else if constexpr (!std::is_same_v<T, bool>)
        {
            return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value / arg2.value)});
        }
        break;
    case ast::binary_operation_type::remainder:
        if constexpr (std::is_integral_v<T> && !std::is_same_v<T, bool>)
        {
            if (arg2.value == static_cast<T>(0))
                throw std::runtime_error("Division by zero");
            if constexpr (std::is_signed_v<T>)
            {
                // MIN % -1 is UB for 32/64-bit types; x % -1 is 0 for every x.
                if (arg2.value == static_cast<T>(-1))
                    return primitive_value(primitive_value_base<T>{static_cast<T>(0)});
            }
            return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value % arg2.value)});
        }
        break;
    case ast::binary_operation_type::logical_and:
        if constexpr (std::is_same_v<T, bool>)
        {
            return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value && arg2.value)});
        }
        else if constexpr (std::is_integral_v<T>)
        {
            return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value & arg2.value)});
        }
        break;
    case ast::binary_operation_type::logical_or:
        if constexpr (std::is_same_v<T, bool>)
        {
            return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value || arg2.value)});
        }
        else if constexpr (std::is_integral_v<T>)
        {
            return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value | arg2.value)});
        }
        break;
    case ast::binary_operation_type::logical_xor:
        if constexpr (std::is_integral_v<T>)
        {
            // Same expression for bool (logical xor) and other integers (bitwise xor).
            return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value ^ arg2.value)});
        }
        break;
    case ast::binary_operation_type::equals:
        return primitive_value(primitive_value_base<bool>{arg1.value == arg2.value});
    case ast::binary_operation_type::not_equals:
        return primitive_value(primitive_value_base<bool>{arg1.value != arg2.value});
    case ast::binary_operation_type::less:
        return primitive_value(primitive_value_base<bool>{arg1.value < arg2.value});
    case ast::binary_operation_type::greater:
        return primitive_value(primitive_value_base<bool>{arg1.value > arg2.value});
    case ast::binary_operation_type::less_equals:
        return primitive_value(primitive_value_base<bool>{arg1.value <= arg2.value});
    case ast::binary_operation_type::greater_equals:
        return primitive_value(primitive_value_base<bool>{arg1.value >= arg2.value});
    }
    std::ostringstream os;
    os << "Cannot apply binary operator \"";
    print(os, type);
    os << "\" to values of type ";
    type::print(os, type_of(primitive_value(arg1)));
    os << " and ";
    type::print(os, type_of(primitive_value(arg2)));
    throw std::runtime_error(os.str());
}
// Fallback for non-primitive left operands (unit, array, struct values): no binary operator applies,
// so this always throws std::runtime_error naming the operator and both operand types.
template <typename Value>
value binary_operation_impl_same_type(ast::binary_operation_type type, Value const & arg1, value const & arg2)
{
    std::ostringstream message;
    message << "Cannot apply binary operator \"";
    print(message, type);
    message << "\" to values of type ";
    type::print(message, type_of(arg1));
    message << " and ";
    type::print(message, type_of(arg2));
    throw std::runtime_error(message.str());
}
// Unwraps the left primitive operand and forwards its typed payload (with the untouched right
// operand) to the typed implementation.
value binary_operation_impl_same_type(ast::binary_operation_type type, primitive_value const & arg1, value const & arg2)
{
    auto const apply = [&](auto const & typed_lhs) { return binary_operation_impl_same_type(type, typed_lhs, arg2); };
    return std::visit(apply, arg1);
}
// Evaluates a binary operation. Both operands are evaluated (left, then right) and must have equal
// types; the operator is then applied by binary_operation_impl_same_type. Mixed-type operators are
// not implemented yet.
value eval_impl(context & context, ast::binary_operation const & binary_operation)
{
    auto const operation = binary_operation.type;
    auto lhs = eval_impl(context, binary_operation.arg1);
    auto rhs = eval_impl(context, binary_operation.arg2);
    if (!requires_same_argument_type(operation))
        throw std::runtime_error("eval(binary_operation) for different argument types not implemented");

    auto lhs_type = type_of(lhs);
    auto rhs_type = type_of(rhs);
    if (type::equal(lhs_type, rhs_type))
    {
        auto const apply = [&](auto const & lhs_alternative) { return binary_operation_impl_same_type(operation, lhs_alternative, rhs); };
        return std::visit(apply, lhs);
    }

    std::ostringstream os;
    os << "Cannot apply binary operator \"";
    print(os, operation);
    os << "\" to values of type ";
    type::print(os, lhs_type);
    os << " and ";
    type::print(os, rhs_type);
    throw std::runtime_error(os.str());
}
// The unit value can only be "cast" to the unit type itself; every other target throws.
value cast_impl(unit_value const & value, type::type const & type)
{
    if (!type::equal(type, type::unit_type{}))
        throw std::runtime_error("Cannot cast unit type to anything");
    return value;
}
// Arrays support only the identity cast (target equal to the array's own type); every other target throws.
value cast_impl(array_value const & value, type::type const & type)
{
    bool const is_identity = type::equal(type, type_of(value));
    if (!is_identity)
        throw std::runtime_error("Cannot cast array type to anything");
    return value;
}
// A primitive value can never be cast to the unit type.
template <typename T>
value cast_impl(primitive_value_base<T> const &, type::unit_type const &)
{
    throw std::runtime_error("Cannot cast anything to unit type");
}
// Casts a primitive value of C++ type T to the primitive type H.
//  - T == H: returned unchanged.
//  - neither side is bool: converted with static_cast. A float -> integer conversion is rejected when
//    the truncated value is not representable in H (including NaN and infinities), because such a
//    static_cast is undefined behaviour.
//  - any other bool involvement: throws std::runtime_error.
template <typename T, typename H>
value cast_impl(primitive_value_base<T> const & value, type::primitive_type_base<H> const & type)
{
    if constexpr (std::is_same_v<T, H>)
    {
        return primitive_value(value);
    }
    else if constexpr (!std::is_same_v<T, bool> && !std::is_same_v<bool, H>)
    {
        if constexpr (std::is_floating_point_v<T> && std::is_integral_v<H>)
        {
            // Representable truncated values lie in [lower, upper), where upper = 2^digits(H) and
            // lower = -upper for signed H or 0 for unsigned H. Both bounds are powers of two, hence
            // exact in T. The negated comparison also rejects NaN.
            T const truncated = std::trunc(value.value);
            T const upper = std::ldexp(static_cast<T>(1), std::numeric_limits<H>::digits);
            T const lower = std::numeric_limits<H>::is_signed ? -upper : static_cast<T>(0);
            if (!(truncated >= lower && truncated < upper))
            {
                std::ostringstream os;
                os << "Cannot cast value " << value.value << " to type ";
                type::print(os, type::primitive_type(type));
                os << ": value is out of range";
                throw std::runtime_error(os.str());
            }
        }
        return primitive_value(primitive_value_base<H>{static_cast<H>(value.value)});
    }
    std::ostringstream os;
    os << "Cannot cast value of type ";
    type::print(os, type_of(primitive_value(value)));
    os << " to type ";
    type::print(os, type::primitive_type(type));
    throw std::runtime_error(os.str());
}
// Dispatches a primitive-to-primitive cast on the concrete target primitive type.
template <typename T>
value cast_impl(primitive_value_base<T> const & value, type::primitive_type const & type)
{
    auto const cast_to = [&](auto const & target) { return cast_impl(value, target); };
    return std::visit(cast_to, type);
}
// Dispatches a cast of a typed primitive value on the alternative held by the target type.
template <typename T>
value cast_impl(primitive_value_base<T> const & value, type::type const & type)
{
    auto const cast_to = [&](auto const & target) { return cast_impl(value, target); };
    return std::visit(cast_to, type);
}
// Unwraps a primitive_value and casts its typed payload to `type`.
value cast_impl(primitive_value const & value, type::type const & type)
{
    auto const cast_payload = [&](auto const & payload) { return cast_impl(payload, type); };
    return std::visit(cast_payload, value);
}
// Structs support only the identity cast: the target must equal the struct's own type as reported by
// type_of. Every other target throws.
// (The previous check compared against the unit type, which returned the struct unchanged for a
// cast to unit and rejected identity casts. This now mirrors the array_value overload.)
// NOTE(review): type_of yields the scope-resolved struct identifier; confirm the cast target type is
// resolved the same way before it reaches this overload, otherwise identity casts still fail.
value cast_impl(struct_value const & value, type::type const & type)
{
    if (type::equal(type, type_of(value)))
        return value;
    throw std::runtime_error("Cannot cast struct type to anything");
}
// Evaluates the operand of a cast expression and converts it to the (unresolved) target type.
value eval_impl(context & context, ast::cast_operation const & cast_operation)
{
    auto operand = eval(context, cast_operation.expression);
    auto const & target_type = *cast_operation.type;
    auto const convert = [&target_type](auto const & alternative) { return cast_impl(alternative, target_type); };
    return std::visit(convert, operand);
}
// Scope-stack guard used by function calls. On destruction it pops entries off `stack` until at most
// `depth` remain, so a callee's scope is discarded on every exit path (normal return or exception).
template <typename Stack>
struct truncate_on_exit
{
    Stack & stack;
    std::size_t depth;

    ~truncate_on_exit()
    {
        while (stack.size() > depth)
            stack.pop_back();
    }
};
template <typename Stack>
truncate_on_exit(Stack &, std::size_t) -> truncate_on_exit<Stack>;

// Evaluates `name(args...)`. Only plain identifiers are supported as callees.
// Scopes are searched from the innermost outwards. Within one scope a function named `name` takes
// precedence over a struct named `name`:
//  - function: checks arity, evaluates the arguments left to right and checks each argument's type.
//    It then runs the body in a new function scope whose arguments are constant variables, checks the
//    returned value's type against the declared return type, and returns that value.
//  - struct: checks the field count and field types, and builds a struct_value whose type is the
//    scope-resolved struct identifier.
// Throws std::runtime_error on any mismatch or if the name is not defined.
value eval_impl(context & context, ast::function_call const & function_call)
{
    auto identifier = std::get_if<ast::identifier>(function_call.function.get());
    if (!identifier)
        throw std::runtime_error("Calling non-identifier functions is not implemented");
    for (auto it = context.scope_stack.rbegin(); it != context.scope_stack.rend(); ++it)
    {
        // NOTE(review): `jt` is used after the scope stack may have grown (argument evaluation,
        // emplace_back below). This relies on map nodes surviving a relocation of their scope;
        // confirm against scope_stack's container and scope's move semantics.
        if (auto jt = it->functions.find(identifier->name); jt != it->functions.end())
        {
            if (jt->second.arguments.size() != function_call.arguments.size())
            {
                std::ostringstream os;
                os << "Cannot call function \"" << identifier->name << "\": expected " << jt->second.arguments.size() << " arguments, got " << function_call.arguments.size();
                throw std::runtime_error(os.str());
            }
            std::vector<value> args;
            for (auto const & expression : function_call.arguments)
                args.push_back(eval(context, expression));
            for (std::size_t i = 0; i < args.size(); ++i)
            {
                auto actual_type = type_of(args[i]);
                if (!type::equal(actual_type, *jt->second.arguments[i].type))
                {
                    std::ostringstream os;
                    os << "Cannot call function \"" << identifier->name << "\": argument #" << (i + 1) << " expects type ";
                    type::print(os, *jt->second.arguments[i].type);
                    os << " but actual type is ";
                    type::print(os, actual_type);
                    throw std::runtime_error(os.str());
                }
            }
            // Restore the caller's stack depth on every exit path. Without this, a throwing body or
            // a return-type mismatch would leave the callee's scope (and any block scopes it still
            // had open) on the stack, visible to later evaluation.
            truncate_on_exit guard{context.scope_stack, context.scope_stack.size()};
            auto & function_scope = context.scope_stack.emplace_back();
            function_scope.is_function_scope = true;
            for (std::size_t i = 0; i < args.size(); ++i)
                function_scope.variables[jt->second.arguments[i].name] = {.category = ast::value_category::constant, .value = std::move(args[i])};
            auto expected_return_type = jt->second.return_type;
            exec(context, jt->second.statements);
            auto actual_return_type = type_of(context.scope_stack.back().return_value);
            if (!type::equal(actual_return_type, *expected_return_type))
            {
                std::ostringstream os;
                os << "Error returning from function \"" << identifier->name << "\": expected return type is ";
                type::print(os, *expected_return_type);
                os << " but actual type is ";
                type::print(os, actual_return_type);
                throw std::runtime_error(os.str());
            }
            auto result = std::move(context.scope_stack.back().return_value);
            return result; // `guard` pops the function scope after the result has been moved out
        }
        if (auto jt = it->structs.find(identifier->name); jt != it->structs.end())
        {
            if (jt->second.fields.size() != function_call.arguments.size())
            {
                std::ostringstream os;
                os << "Cannot create struct \"" << identifier->name << "\": expected " << jt->second.fields.size() << " fields, got " << function_call.arguments.size();
                throw std::runtime_error(os.str());
            }
            std::vector<value> args;
            for (auto const & expression : function_call.arguments)
                args.push_back(eval(context, expression));
            std::unordered_map<std::string, value_ptr> fields;
            for (std::size_t i = 0; i < args.size(); ++i)
            {
                auto actual_type = type_of(args[i]);
                if (!type::equal(actual_type, *jt->second.fields[i].type))
                {
                    std::ostringstream os;
                    os << "Cannot create struct \"" << identifier->name << "\": field " << jt->second.fields[i].name << " expects type ";
                    type::print(os, *jt->second.fields[i].type);
                    os << " but actual type is ";
                    type::print(os, actual_type);
                    throw std::runtime_error(os.str());
                }
                fields[jt->second.fields[i].name] = std::make_unique<value>(std::move(args[i]));
            }
            return struct_value{
                .struct_type = std::make_unique<type::type>(resolve_type(context, type::identifier{identifier->name})),
                .fields = std::move(fields),
            };
        }
    }
    throw std::runtime_error("Function \"" + identifier->name + "\" is not defined");
}
// Evaluates an array literal. The elements are evaluated in order, and the element type is inferred
// from the first element. Every later element must have exactly that type, otherwise
// std::runtime_error reports the zero-based position of the offending element.
value eval_impl(context & context, ast::array const & array)
{
    if (array.elements.empty())
        throw std::runtime_error("Internal error: array ast node cannot have zero elements");

    std::vector<value_ptr> elements;
    type::type_ptr element_type;
    std::size_t position = 0;
    for (auto const & element_expression : array.elements)
    {
        auto element = std::make_unique<value>(eval(context, element_expression));
        if (position == 0)
        {
            element_type = std::make_unique<type::type>(type_of(*element));
        }
        else if (auto new_type = type_of(*element); !type::equal(*element_type, new_type))
        {
            std::ostringstream os;
            os << "Error forming array: inferred element type is ";
            type::print(os, *element_type);
            os << " but element #" << position << " type is ";
            type::print(os, new_type);
            throw std::runtime_error(os.str());
        }
        elements.push_back(std::move(element));
        ++position;
    }
    return array_value{.element_type = std::move(element_type), .elements = std::move(elements)};
}
// Evaluates `array[index]` (array first, then index) and returns a copy of the selected element.
// Throws std::runtime_error if the indexed value is not an array, or if the index is rejected by get_array_index.
value eval_impl(context & context, ast::array_access const & array_access)
{
    auto container = eval(context, array_access.array);
    auto index = eval(context, array_access.index);
    auto const * array_ptr = std::get_if<array_value>(&container);
    if (array_ptr == nullptr)
    {
        std::ostringstream os;
        os << "Cannot index into a non-array of type ";
        type::print(os, type_of(container));
        throw std::runtime_error(os.str());
    }
    auto const position = get_array_index(index, array_ptr->elements.size());
    return *array_ptr->elements[position];
}
// Evaluates `object.field` and returns a copy of the named field's value.
// Throws std::runtime_error if the object is not a struct or has no field with that name.
value eval_impl(context & context, ast::field_access const & field_access)
{
    auto object = eval(context, field_access.object);
    auto const * structure = std::get_if<struct_value>(&object);
    if (structure == nullptr)
    {
        std::ostringstream os;
        os << "Value of type ";
        type::print(os, type_of(object));
        os << " is not a struct";
        throw std::runtime_error(os.str());
    }
    if (auto const field = structure->fields.find(field_access.field_name); field != structure->fields.end())
        return *field->second;
    std::ostringstream os;
    os << "Struct ";
    type::print(os, type_of(object));
    os << " has no field named \"" << field_access.field_name << "\"";
    throw std::runtime_error(os.str());
}
// Evaluates any expression node by dispatching on its concrete AST alternative.
value eval_impl(context & context, ast::expression_ptr const & expression)
{
    auto const evaluate_node = [&context](auto const & node) { return eval_impl(context, node); };
    return std::visit(evaluate_node, *expression);
}
// A literal does not denote a storage location, so it can never be an assignment target.
value * eval_ref_impl(context &, ast::literal const &)
{
    throw std::runtime_error("Literal cannot be on the left-hand-side of assignment");
}
// Returns a pointer to the storage of the innermost variable with this name, searching scopes from
// the innermost outwards. Throws std::runtime_error if the nearest such variable is not mutable, or
// if no scope defines the name.
value * eval_ref_impl(context & context, ast::identifier const & identifier)
{
    auto & scopes = context.scope_stack;
    for (auto frame = scopes.rbegin(); frame != scopes.rend(); ++frame)
    {
        auto const found = frame->variables.find(identifier.name);
        if (found == frame->variables.end())
            continue;
        if (found->second.category != ast::value_category::_mutable)
            throw std::runtime_error("Cannot assign a value to a non-mutable variable");
        return &found->second.value;
    }
    throw std::runtime_error("Identifier \"" + identifier.name + "\" is not defined");
}
// None of the following expression kinds denotes an assignable location, so each overload rejects its
// node as an assignment target with a std::runtime_error.
value * eval_ref_impl(context &, ast::unary_operation const &)
{
    throw std::runtime_error("Unary operation cannot be on the left-hand-side of assignment");
}
value * eval_ref_impl(context &, ast::binary_operation const &)
{
    throw std::runtime_error("Binary operation cannot be on the left-hand-side of assignment");
}
value * eval_ref_impl(context &, ast::cast_operation const &)
{
    throw std::runtime_error("Cast operation cannot be on the left-hand-side of assignment");
}
value * eval_ref_impl(context &, ast::function_call const &)
{
    throw std::runtime_error("Function call cannot be on the left-hand-side of assignment");
}
value * eval_ref_impl(context &, ast::array const &)
{
    throw std::runtime_error("Array cannot be on the left-hand-side of assignment");
}
// Returns a pointer to the storage of `array[index]`. The index is evaluated before the array
// reference is taken; presumably this keeps index evaluation from invalidating that reference.
// Throws std::runtime_error if the target is not an array or the index is rejected by get_array_index.
value * eval_ref_impl(context & context, ast::array_access const & array_access)
{
    auto index = eval(context, array_access.index);
    value * const target = eval_ref(context, array_access.array);
    auto * const array_ptr = std::get_if<array_value>(target);
    if (array_ptr == nullptr)
    {
        std::ostringstream os;
        os << "Cannot index into a non-array of type ";
        type::print(os, type_of(*target));
        throw std::runtime_error(os.str());
    }
    auto const position = get_array_index(index, array_ptr->elements.size());
    return array_ptr->elements[position].get();
}
// Returns a pointer to the storage of `object.field`, where `object` must itself be assignable.
// Throws std::runtime_error if the object is not a struct or has no field with that name.
value * eval_ref_impl(context & context, ast::field_access const & field_access)
{
    value * const target = eval_ref(context, field_access.object);
    auto * const structure = std::get_if<struct_value>(target);
    if (structure == nullptr)
    {
        std::ostringstream os;
        os << "Value of type ";
        type::print(os, type_of(*target));
        os << " is not a struct";
        throw std::runtime_error(os.str());
    }
    auto const field = structure->fields.find(field_access.field_name);
    if (field == structure->fields.end())
    {
        std::ostringstream os;
        os << "Struct ";
        type::print(os, type_of(*target));
        os << " has no field named \"" << field_access.field_name << "\"";
        throw std::runtime_error(os.str());
    }
    return field->second.get();
}
// Resolves an assignment target by dispatching on the concrete AST alternative of the expression.
value * eval_ref_impl(context & context, ast::expression_ptr const & expression)
{
    auto const locate = [&context](auto const & node) { return eval_ref_impl(context, node); };
    return std::visit(locate, *expression);
}
}
// Public entry point: resolves struct identifiers inside `type` against the current scope stack.
type::type resolve_type(context & context, type::type const & type)
{
    auto const & unresolved = type;
    return resolve_type_impl(context, unresolved);
}

// Public entry point: evaluates `expression` in `context` and returns the resulting value.
value eval(context & context, ast::expression_ptr const & expression)
{
    auto const & root = expression;
    return eval_impl(context, root);
}

// Public entry point: returns a pointer to the storage denoted by `expression`. Throws
// std::runtime_error if the expression is not an assignable location or names a non-mutable variable.
value * eval_ref(context & context, ast::expression_ptr const & expression)
{
    auto const & target = expression;
    return eval_ref_impl(context, target);
}
}