#include #include #include #include #include #include #include namespace pslang::interpreter { namespace { type::type resolve_type_impl(context &, type::unit_type const & type) { return type; } type::type resolve_type_impl(context &, type::primitive_type const & type) { return type; } type::type resolve_type_impl(context & context, type::array_type const & type) { return type::array_type{std::make_unique(resolve_type(context, *type.element_type)), type.size}; } type::type resolve_type_impl(context & context, type::identifier const & type) { for (auto it = context.scope_stack.rbegin(); it != context.scope_stack.rend(); ++it) { if (it->structs.count(type.name)) { return type::identifier{std::string(it.base() - context.scope_stack.begin() - 1, '/') + type.name}; } } throw std::runtime_error("Type \"" + type.name + "\" is not defined"); } type::type resolve_type_impl(context & context, type::type const & type) { return std::visit([&](auto const & type){ return resolve_type_impl(context, type); }, type); } void print(std::ostream & out, ast::unary_operation_type type) { switch (type) { case ast::unary_operation_type::negation: out << "-"; return; case ast::unary_operation_type::logical_not: out << "!"; return; } out << "(unknown)"; } void print(std::ostream & out, ast::binary_operation_type type) { switch (type) { case ast::binary_operation_type::addition: out << "+"; return; case ast::binary_operation_type::subtraction: out << "-"; return; case ast::binary_operation_type::multiplication: out << "*"; return; case ast::binary_operation_type::division: out << "/"; return; case ast::binary_operation_type::remainder: out << "%"; return; case ast::binary_operation_type::logical_and: out << "&"; return; case ast::binary_operation_type::logical_or: out << "|"; return; case ast::binary_operation_type::logical_xor: out << "^"; return; case ast::binary_operation_type::equals: out << "=="; return; case ast::binary_operation_type::not_equals: out << "!="; return; case ast::binary_operation_type::less: out << "<"; return; case ast::binary_operation_type::greater: out << ">"; return; case ast::binary_operation_type::less_equals: out << "<="; return; case ast::binary_operation_type::greater_equals: out << ">="; return; } out << "(unknown)"; } std::uint64_t get_array_index(value const & index, std::uint64_t size) { std::optional index_unsigned; std::optional index_signed; if (auto pvalue = std::get_if(&index)) { if (auto i8 = std::get_if(pvalue)) { index_signed = i8->value; } else if (auto u8 = std::get_if(pvalue)) { index_unsigned = u8->value; } else if (auto i16 = std::get_if(pvalue)) { index_signed = i16->value; } else if (auto u16 = std::get_if(pvalue)) { index_unsigned = u16->value; } else if (auto i32 = std::get_if(pvalue)) { index_signed = i32->value; } else if (auto u32 = std::get_if(pvalue)) { index_unsigned = u32->value; } else if (auto i64 = std::get_if(pvalue)) { index_signed = i64->value; } else if (auto u64 = std::get_if(pvalue)) { index_unsigned = u64->value; } } if (!index_signed && !index_unsigned) { std::ostringstream os; os << "Cannot index into an array with an expression of type "; type::print(os, type_of(index)); throw std::runtime_error(os.str()); } if (index_unsigned) { if (*index_unsigned >= size) { std::ostringstream os; os << "Array index " << *index_unsigned << " out of bounds " << size; throw std::runtime_error(os.str()); } return *index_unsigned; } else // if (index_signed) { if (*index_signed < 0 || *index_signed >= size) { std::ostringstream os; os << "Array index " << *index_signed << " out of bounds " << size; throw std::runtime_error(os.str()); } return *index_signed; } } value eval_impl(context & context, ast::expression_ptr const & expression); template value eval_impl(context & context, ast::numeric_literal_base const & literal) { return primitive_value(primitive_value_base{literal.value}); } value eval_impl(context & context, ast::literal const & literal) { return std::visit([&](auto const & expression){ return eval_impl(context, expression); }, literal); } value eval_impl(context & context, ast::identifier const & identifier) { for (auto it = context.scope_stack.rbegin(); it != context.scope_stack.rend(); ++it) { if (auto jt = it->variables.find(identifier.name); jt != it->variables.end()) return jt->second.value; } throw std::runtime_error("Identifier \"" + identifier.name + "\" is not defined"); } template value unary_operation_impl(ast::unary_operation_type type, Value const & value) { std::ostringstream os; os << "Cannot apply unary operator \""; print(os, type); os << "\" to a value of type "; type::print(os, type_of(value)); throw std::runtime_error(os.str()); } template value unary_operation_impl(ast::unary_operation_type type, primitive_value_base const & arg1) { switch (type) { case ast::unary_operation_type::negation: if constexpr ((std::is_integral_v || std::is_floating_point_v) && !std::is_same_v) { return primitive_value(primitive_value_base{static_cast(-arg1.value)}); } break; case ast::unary_operation_type::logical_not: if constexpr (std::is_same_v) { return primitive_value(primitive_value_base{static_cast(!arg1.value)}); } else if constexpr (std::is_integral_v) { return primitive_value(primitive_value_base{static_cast(~arg1.value)}); } break; } std::ostringstream os; os << "Cannot apply unary operator \""; print(os, type); os << "\" to a value of type "; type::print(os, type_of(primitive_value(arg1))); throw std::runtime_error(os.str()); } value unary_operation_impl(ast::unary_operation_type type, primitive_value const & arg1) { return std::visit([&](auto const & value){ return unary_operation_impl(type, value); }, arg1); } value eval_impl(context & context, ast::unary_operation const & unary_operation) { auto arg1 = eval_impl(context, unary_operation.arg1); return std::visit([&](auto const & value){ return unary_operation_impl(unary_operation.type, value); }, arg1); } bool requires_same_argument_type(ast::binary_operation_type) { // TODO: shift operators should return false return true; } template value binary_operation_impl_same_type(ast::binary_operation_type type, primitive_value_base const & arg1, value const & arg2_generic) { primitive_value_base const & arg2 = std::get>(std::get(arg2_generic)); switch (type) { case ast::binary_operation_type::addition: if constexpr (!std::is_same_v) { return primitive_value(primitive_value_base{static_cast(arg1.value + arg2.value)}); } break; case ast::binary_operation_type::subtraction: if constexpr (!std::is_same_v) { return primitive_value(primitive_value_base{static_cast(arg1.value - arg2.value)}); } break; case ast::binary_operation_type::multiplication: if constexpr (!std::is_same_v) { return primitive_value(primitive_value_base{static_cast(arg1.value * arg2.value)}); } break; case ast::binary_operation_type::division: if constexpr (!std::is_same_v) { if constexpr (std::is_integral_v) { if (arg2.value == static_cast(0)) throw std::runtime_error("Division by zero"); } return primitive_value(primitive_value_base{static_cast(arg1.value / arg2.value)}); } break; case ast::binary_operation_type::remainder: if constexpr (!std::is_same_v && std::is_integral_v) { if constexpr (std::is_integral_v) { if (arg2.value == static_cast(0)) throw std::runtime_error("Division by zero"); } return primitive_value(primitive_value_base{static_cast(arg1.value % arg2.value)}); } break; case ast::binary_operation_type::logical_and: if constexpr (std::is_same_v) { return primitive_value(primitive_value_base{static_cast(arg1.value && arg2.value)}); } else if constexpr (std::is_integral_v) { return primitive_value(primitive_value_base{static_cast(arg1.value & arg2.value)}); } break; case ast::binary_operation_type::logical_or: if constexpr (std::is_same_v) { return primitive_value(primitive_value_base{static_cast(arg1.value || arg2.value)}); } else if constexpr (std::is_integral_v) { return primitive_value(primitive_value_base{static_cast(arg1.value | arg2.value)}); } break; case ast::binary_operation_type::logical_xor: if constexpr (std::is_same_v) { return primitive_value(primitive_value_base{static_cast(arg1.value ^ arg2.value)}); } else if constexpr (std::is_integral_v) { return primitive_value(primitive_value_base{static_cast(arg1.value ^ arg2.value)}); } break; case ast::binary_operation_type::equals: return primitive_value(primitive_value_base{arg1.value == arg2.value}); case ast::binary_operation_type::not_equals: return primitive_value(primitive_value_base{arg1.value != arg2.value}); case ast::binary_operation_type::less: return primitive_value(primitive_value_base{arg1.value < arg2.value}); case ast::binary_operation_type::greater: return primitive_value(primitive_value_base{arg1.value > arg2.value}); case ast::binary_operation_type::less_equals: return primitive_value(primitive_value_base{arg1.value <= arg2.value}); case ast::binary_operation_type::greater_equals: return primitive_value(primitive_value_base{arg1.value >= arg2.value}); } std::ostringstream os; os << "Cannot apply binary operator \""; print(os, type); os << "\" to values of type "; type::print(os, type_of(primitive_value(arg1))); os << " and "; type::print(os, type_of(primitive_value(arg2))); throw std::runtime_error(os.str()); } template value binary_operation_impl_same_type(ast::binary_operation_type type, Value const & arg1, value const & arg2) { std::ostringstream os; os << "Cannot apply binary operator \""; print(os, type); os << "\" to values of type "; type::print(os, type_of(arg1)); os << " and "; type::print(os, type_of(arg2)); throw std::runtime_error(os.str()); } value binary_operation_impl_same_type(ast::binary_operation_type type, primitive_value const & arg1, value const & arg2) { return std::visit([&](auto const & value){ return binary_operation_impl_same_type(type, value, arg2); }, arg1); } value eval_impl(context & context, ast::binary_operation const & binary_operation) { auto arg1 = eval_impl(context, binary_operation.arg1); auto arg2 = eval_impl(context, binary_operation.arg2); if (requires_same_argument_type(binary_operation.type)) { auto type1 = type_of(arg1); auto type2 = type_of(arg2); if (!type::equal(type1, type2)) { std::ostringstream os; os << "Cannot apply binary operator \""; print(os, binary_operation.type); os << "\" to values of type "; type::print(os, type1); os << " and "; type::print(os, type2); throw std::runtime_error(os.str()); } return std::visit([&](auto const & value){ return binary_operation_impl_same_type(binary_operation.type, value, arg2); }, arg1); } throw std::runtime_error("eval(binary_operation) for different argument types not implemented"); } value cast_impl(unit_value const & value, type::type const & type) { if (type::equal(type, type::unit_type{})) return value; throw std::runtime_error("Cannot cast unit type to anything"); } value cast_impl(array_value const & value, type::type const & type) { if (type::equal(type, type_of(value))) return value; throw std::runtime_error("Cannot cast array type to anything"); } template value cast_impl(primitive_value_base const & value, type::unit_type const &) { throw std::runtime_error("Cannot cast anything to unit type"); } template value cast_impl(primitive_value_base const & value, type::primitive_type_base const & type) { if constexpr (std::is_same_v) { return primitive_value(value); } else if constexpr (!std::is_same_v && !std::is_same_v) { return primitive_value(primitive_value_base{static_cast(value.value)}); } std::ostringstream os; os << "Cannot cast value of type "; type::print(os, type_of(primitive_value(value))); os << " to type "; type::print(os, type::primitive_type(type)); throw std::runtime_error(os.str()); } template value cast_impl(primitive_value_base const & value, type::primitive_type const & type) { return std::visit([&](auto const & type){ return cast_impl(value, type); }, type); } template value cast_impl(primitive_value_base const & value, type::type const & type) { return std::visit([&](auto const & type){ return cast_impl(value, type); }, type); } value cast_impl(primitive_value const & value, type::type const & type) { return std::visit([&](auto const & value){ return cast_impl(value, type); }, value); } value cast_impl(struct_value const & value, type::type const & type) { if (type::equal(type, type::unit_type{})) return value; throw std::runtime_error("Cannot cast struct type to anything"); } value eval_impl(context & context, ast::cast_operation const & cast_operation) { auto arg = eval(context, cast_operation.expression); return std::visit([&](auto const & value){ return cast_impl(value, *cast_operation.type); }, arg); } value eval_impl(context & context, ast::function_call const & function_call) { for (auto it = context.scope_stack.rbegin(); it != context.scope_stack.rend(); ++it) { if (auto jt = it->functions.find(function_call.name); jt != it->functions.end()) { if (jt->second.arguments.size() != function_call.arguments.size()) { std::ostringstream os; os << "Cannot call function \"" << function_call.name << "\": expected " << jt->second.arguments.size() << " arguments, got " << function_call.arguments.size(); throw std::runtime_error(os.str()); } std::vector args; for (auto const & expression : function_call.arguments) args.push_back(eval(context, expression)); for (std::size_t i = 0; i < args.size(); ++i) { auto actual_type = type_of(args[i]); if (!type::equal(actual_type, *jt->second.arguments[i].type)) { std::ostringstream os; os << "Cannot call function \"" << function_call.name << "\": argument #" << (i + 1) << " expects type "; type::print(os, *jt->second.arguments[i].type); os << " but actual type is "; type::print(os, actual_type); throw std::runtime_error(os.str()); } } auto & function_scope = context.scope_stack.emplace_back(); function_scope.is_function_scope = true; for (std::size_t i = 0; i < args.size(); ++i) function_scope.variables[jt->second.arguments[i].name] = {.category = ast::value_category::constant, .value = std::move(args[i])}; auto expected_return_type = jt->second.return_type; exec(context, jt->second.statements); auto actual_return_type = type_of(context.scope_stack.back().return_value); if (!type::equal(actual_return_type, *expected_return_type)) { std::ostringstream os; os << "Error returning from function \"" << function_call.name << "\": expected return type is "; type::print(os, *expected_return_type); os << " but actual type is "; type::print(os, actual_return_type); throw std::runtime_error(os.str()); } auto result = std::move(context.scope_stack.back().return_value); context.scope_stack.pop_back(); return result; } if (auto jt = it->structs.find(function_call.name); jt != it->structs.end()) { if (jt->second.fields.size() != function_call.arguments.size()) { std::ostringstream os; os << "Cannot create struct \"" << function_call.name << "\": expected " << jt->second.fields.size() << " fields, got " << function_call.arguments.size(); throw std::runtime_error(os.str()); } std::vector args; for (auto const & expression : function_call.arguments) args.push_back(eval(context, expression)); std::unordered_map fields; for (std::size_t i = 0; i < args.size(); ++i) { auto actual_type = type_of(args[i]); if (!type::equal(actual_type, *jt->second.fields[i].type)) { std::ostringstream os; os << "Cannot create struct \"" << function_call.name << "\": field " << jt->second.fields[i].name << " expects type "; type::print(os, *jt->second.fields[i].type); os << " but actual type is "; type::print(os, actual_type); throw std::runtime_error(os.str()); } fields[jt->second.fields[i].name] = std::make_unique(std::move(args[i])); } return struct_value{ .struct_type = std::make_unique(resolve_type(context, type::identifier{function_call.name})), .fields = std::move(fields), }; } } throw std::runtime_error("Function \"" + function_call.name + "\" is not defined"); } value eval_impl(context & context, ast::array const & array) { if (array.elements.empty()) throw std::runtime_error("Internal error: array ast node cannot have zero elements"); type::type_ptr element_type; std::vector elements; for (std::size_t i = 0; i < array.elements.size(); ++i) { auto element = std::make_unique(eval(context, array.elements[i])); if (i == 0) element_type = std::make_unique(type_of(*element)); else { auto new_type = type_of(*element); if (!type::equal(*element_type, new_type)) { std::ostringstream os; os << "Error forming array: inferred element type is "; type::print(os, *element_type); os << " but element #" << i << " type is "; type::print(os, new_type); throw std::runtime_error(os.str()); } } elements.push_back(std::move(element)); } return array_value{.element_type = std::move(element_type), .elements = std::move(elements)}; } value eval_impl(context & context, ast::array_access const & array_access) { auto array = eval(context, array_access.array); auto index = eval(context, array_access.index); if (auto avalue = std::get_if(&array)) { return *avalue->elements[get_array_index(index, avalue->elements.size())]; } std::ostringstream os; os << "Cannot index into a non-array of type "; type::print(os, type_of(array)); throw std::runtime_error(os.str()); } value eval_impl(context & context, ast::field_access const & field_access) { auto object = eval(context, field_access.object); if (auto value = std::get_if(&object)) { if (auto it = value->fields.find(field_access.field_name); it != value->fields.end()) return *it->second; std::ostringstream os; os << "Struct "; type::print(os, type_of(object)); os << " has no field named \"" << field_access.field_name << "\""; throw std::runtime_error(os.str()); } std::ostringstream os; os << "Value of type "; type::print(os, type_of(object)); os << " is not a struct"; throw std::runtime_error(os.str()); } value eval_impl(context & context, ast::expression_ptr const & expression) { return std::visit([&](auto const & expression){ return eval_impl(context, expression); }, *expression); } value * eval_ref_impl(context & context, ast::literal const &) { throw std::runtime_error("Literal cannot be on the left-hand-side of assignment"); } value * eval_ref_impl(context & context, ast::identifier const & identifier) { for (auto it = context.scope_stack.rbegin(); it != context.scope_stack.rend(); ++it) { if (auto jt = it->variables.find(identifier.name); jt != it->variables.end()) { if (jt->second.category != ast::value_category::_mutable) throw std::runtime_error("Cannot assign a value to a non-mutable variable"); return &jt->second.value; } } throw std::runtime_error("Identifier \"" + identifier.name + "\" is not defined"); } value * eval_ref_impl(context & context, ast::unary_operation const &) { throw std::runtime_error("Unary operation cannot be on the left-hand-side of assignment"); } value * eval_ref_impl(context & context, ast::binary_operation const &) { throw std::runtime_error("Binary operation cannot be on the left-hand-side of assignment"); } value * eval_ref_impl(context & context, ast::cast_operation const &) { throw std::runtime_error("Cast operation cannot be on the left-hand-side of assignment"); } value * eval_ref_impl(context & context, ast::function_call const &) { throw std::runtime_error("Function call cannot be on the left-hand-side of assignment"); } value * eval_ref_impl(context & context, ast::array const &) { throw std::runtime_error("Array cannot be on the left-hand-side of assignment"); } value * eval_ref_impl(context & context, ast::array_access const & array_access) { auto index = eval(context, array_access.index); auto array_ref = eval_ref(context, array_access.array); if (auto avalue = std::get_if(array_ref)) { return avalue->elements[get_array_index(index, avalue->elements.size())].get(); } std::ostringstream os; os << "Cannot index into a non-array of type "; type::print(os, type_of(*array_ref)); throw std::runtime_error(os.str()); } value * eval_ref_impl(context & context, ast::field_access const & field_access) { auto object_ref = eval_ref(context, field_access.object); if (auto value = std::get_if(object_ref)) { if (auto it = value->fields.find(field_access.field_name); it != value->fields.end()) return it->second.get(); std::ostringstream os; os << "Struct "; type::print(os, type_of(*object_ref)); os << " has no field named \"" << field_access.field_name << "\""; throw std::runtime_error(os.str()); } std::ostringstream os; os << "Value of type "; type::print(os, type_of(*object_ref)); os << " is not a struct"; throw std::runtime_error(os.str()); } value * eval_ref_impl(context & context, ast::expression_ptr const & expression) { return std::visit([&](auto const & expression){ return eval_ref_impl(context, expression); }, *expression); } } type::type resolve_type(context & context, type::type const & type) { return resolve_type_impl(context, type); } value eval(context & context, ast::expression_ptr const & expression) { return eval_impl(context, expression); } value * eval_ref(context & context, ast::expression_ptr const & expression) { return eval_ref_impl(context, expression); } }