#include #include #include #include #include #include #include #include #include namespace pslang::interpreter { namespace { std::uint64_t get_array_index(value const & index, std::uint64_t size, ast::location const & location) { std::optional index_unsigned; std::optional index_signed; if (auto pvalue = std::get_if(&index)) { if (auto i8 = std::get_if(pvalue)) { index_signed = i8->value; } else if (auto u8 = std::get_if(pvalue)) { index_unsigned = u8->value; } else if (auto i16 = std::get_if(pvalue)) { index_signed = i16->value; } else if (auto u16 = std::get_if(pvalue)) { index_unsigned = u16->value; } else if (auto i32 = std::get_if(pvalue)) { index_signed = i32->value; } else if (auto u32 = std::get_if(pvalue)) { index_unsigned = u32->value; } else if (auto i64 = std::get_if(pvalue)) { index_signed = i64->value; } else if (auto u64 = std::get_if(pvalue)) { index_unsigned = u64->value; } } if (!index_signed && !index_unsigned) { std::ostringstream os; os << "Cannot index into an array with an expression of type "; types::print(os, type_of(index)); throw internal_error(os.str(), location); } if (index_unsigned) { if (*index_unsigned >= size) { std::ostringstream os; os << "Array index " << *index_unsigned << " out of bounds " << size; throw runtime_error(os.str(), location); } return *index_unsigned; } else // if (index_signed) { if (*index_signed < 0 || *index_signed >= size) { std::ostringstream os; os << "Array index " << *index_signed << " out of bounds " << size; throw runtime_error(os.str(), location); } return *index_signed; } } value eval_impl(context & context, ast::expression_ptr const & expression); template value eval_impl(context & context, ast::primitive_literal_base const & literal) { return primitive_value(primitive_value_base{literal.value}); } value eval_impl(context & context, ast::literal const & literal) { return std::visit([&](auto const & expression){ return eval_impl(context, expression); }, literal); } value eval_impl(context & context, ast::identifier const & identifier) { // NB: cannot use identifier.level here because lexical scope stack // almost always differs from execution frame stack for (auto it = context.frame_stack.rbegin(); it != context.frame_stack.rend(); ++it) if (auto jt = it->variables.find(identifier.name); jt != it->variables.end()) return jt->second.value; throw internal_error("Identifier \"" + identifier.name + "\" is not defined", identifier.location); } template value unary_operation_impl(ast::unary_operation_type type, Value const & value, ast::location const & location) { std::ostringstream os; os << "Cannot apply " << type << " to a value of type "; types::print(os, type_of(value)); throw internal_error(os.str(), location); } template value unary_operation_impl(ast::unary_operation_type type, primitive_value_base const & arg1, ast::location const & location) { switch (type) { case ast::unary_operation_type::negation: if constexpr ((std::is_integral_v || std::is_floating_point_v) && !std::is_same_v) { return primitive_value(primitive_value_base{static_cast(-arg1.value)}); } break; case ast::unary_operation_type::logical_not: if constexpr (std::is_same_v) { return primitive_value(primitive_value_base{static_cast(!arg1.value)}); } else if constexpr (std::is_integral_v) { return primitive_value(primitive_value_base{static_cast(~arg1.value)}); } break; } std::ostringstream os; os << "Cannot apply " << type << " to a value of type "; types::print(os, type_of(primitive_value(arg1))); throw internal_error(os.str(), location); } value unary_operation_impl(ast::unary_operation_type type, primitive_value const & arg1, ast::location const & location) { return std::visit([&](auto const & value){ return unary_operation_impl(type, value, location); }, arg1); } value eval_impl(context & context, ast::unary_operation const & unary_operation) { auto arg1 = eval_impl(context, unary_operation.arg1); return std::visit([&](auto const & value){ return unary_operation_impl(unary_operation.type, value, unary_operation.location); }, arg1); } bool requires_same_argument_type(ast::binary_operation_type) { // TODO: shift operators should return false return true; } bool is_short_circuiting(ast::binary_operation_type type) { switch (type) { case ast::binary_operation_type::logical_and: case ast::binary_operation_type::logical_or: return true; default: return false; } } template value binary_operation_impl_same_type(ast::binary_operation_type type, primitive_value_base const & arg1, value const & arg2_generic, ast::location const & location) { primitive_value_base const & arg2 = std::get>(std::get(arg2_generic)); switch (type) { case ast::binary_operation_type::addition: if constexpr (!std::is_same_v) { return primitive_value(primitive_value_base{static_cast(arg1.value + arg2.value)}); } break; case ast::binary_operation_type::subtraction: if constexpr (!std::is_same_v) { return primitive_value(primitive_value_base{static_cast(arg1.value - arg2.value)}); } break; case ast::binary_operation_type::multiplication: if constexpr (!std::is_same_v) { return primitive_value(primitive_value_base{static_cast(arg1.value * arg2.value)}); } break; case ast::binary_operation_type::division: if constexpr (!std::is_same_v) { if constexpr (std::is_integral_v) { if (arg2.value == static_cast(0)) throw std::runtime_error("Division by zero"); } return primitive_value(primitive_value_base{static_cast(arg1.value / arg2.value)}); } break; case ast::binary_operation_type::remainder: if constexpr (!std::is_same_v && std::is_integral_v) { if constexpr (std::is_integral_v) { if (arg2.value == static_cast(0)) throw std::runtime_error("Division by zero"); } return primitive_value(primitive_value_base{static_cast(arg1.value % arg2.value)}); } break; case ast::binary_operation_type::binary_and: if constexpr (std::is_same_v) { return primitive_value(primitive_value_base{static_cast(arg1.value && arg2.value)}); } else if constexpr (std::is_integral_v) { return primitive_value(primitive_value_base{static_cast(arg1.value & arg2.value)}); } break; case ast::binary_operation_type::logical_and: throw internal_error("logical_and must be handled separately", location); case ast::binary_operation_type::binary_or: if constexpr (std::is_same_v) { return primitive_value(primitive_value_base{static_cast(arg1.value || arg2.value)}); } else if constexpr (std::is_integral_v) { return primitive_value(primitive_value_base{static_cast(arg1.value | arg2.value)}); } break; case ast::binary_operation_type::logical_or: throw internal_error("logical_or must be handled separately", location); case ast::binary_operation_type::logical_xor: if constexpr (std::is_same_v) { return primitive_value(primitive_value_base{static_cast(arg1.value ^ arg2.value)}); } else if constexpr (std::is_integral_v) { return primitive_value(primitive_value_base{static_cast(arg1.value ^ arg2.value)}); } break; case ast::binary_operation_type::equals: return primitive_value(primitive_value_base{arg1.value == arg2.value}); case ast::binary_operation_type::not_equals: return primitive_value(primitive_value_base{arg1.value != arg2.value}); case ast::binary_operation_type::less: return primitive_value(primitive_value_base{arg1.value < arg2.value}); case ast::binary_operation_type::greater: return primitive_value(primitive_value_base{arg1.value > arg2.value}); case ast::binary_operation_type::less_equals: return primitive_value(primitive_value_base{arg1.value <= arg2.value}); case ast::binary_operation_type::greater_equals: return primitive_value(primitive_value_base{arg1.value >= arg2.value}); } std::ostringstream os; os << "Cannot apply " << type << " to values of type "; types::print(os, type_of(primitive_value(arg1))); os << " and "; types::print(os, type_of(primitive_value(arg2))); throw internal_error(os.str(), location); } template value binary_operation_impl_same_type(ast::binary_operation_type type, Value const & arg1, value const & arg2, ast::location const & location) { std::ostringstream os; os << "Cannot apply " << type << " to values of type "; types::print(os, type_of(arg1)); os << " and "; types::print(os, type_of(arg2)); throw internal_error(os.str(), location); } value binary_operation_impl_same_type(ast::binary_operation_type type, primitive_value const & arg1, value const & arg2, ast::location const & location) { return std::visit([&](auto const & value){ return binary_operation_impl_same_type(type, value, arg2, location); }, arg1); } template value short_circuiting_impl_same_type(ast::binary_operation_type type, primitive_value_base const & arg1, LazyArg2 const & lazy_arg2, ast::location const & location) { switch (type) { case ast::binary_operation_type::logical_and: if constexpr (std::is_same_v) { if (!arg1.value) return primitive_value(primitive_value_base{false}); value const & arg2_generic = lazy_arg2(); primitive_value_base const & arg2 = std::get>(std::get(arg2_generic)); return primitive_value(primitive_value_base{static_cast(arg1.value && arg2.value)}); } else if constexpr (std::is_integral_v) { if (arg1.value == T{}) return primitive_value(primitive_value_base{arg1.value}); value const & arg2_generic = lazy_arg2(); primitive_value_base const & arg2 = std::get>(std::get(arg2_generic)); return primitive_value(primitive_value_base{static_cast(arg1.value && arg2.value)}); } break; case ast::binary_operation_type::logical_or: if constexpr (std::is_same_v) { if (arg1.value) return primitive_value(primitive_value_base{true}); value const & arg2_generic = lazy_arg2(); primitive_value_base const & arg2 = std::get>(std::get(arg2_generic)); return primitive_value(primitive_value_base{static_cast(arg1.value || arg2.value)}); } else if constexpr (std::is_integral_v) { if (arg1.value == ~T{}) return primitive_value(primitive_value_base{arg1.value}); value const & arg2_generic = lazy_arg2(); primitive_value_base const & arg2 = std::get>(std::get(arg2_generic)); return primitive_value(primitive_value_base{static_cast(arg1.value && arg2.value)}); } break; default: throw internal_error("invalid operator type in short-circuiting branch", location); } std::ostringstream os; os << "Cannot apply " << type << " to values of type "; types::print(os, type_of(primitive_value(arg1))); os << " and "; types::print(os, type_of(lazy_arg2())); throw internal_error(os.str(), location); } template value short_circuiting_impl_same_type(ast::binary_operation_type type, Value const & arg1, LazyArg2 const & lazy_arg2, ast::location const & location) { std::ostringstream os; os << "Cannot apply " << type << " to values of type "; types::print(os, type_of(arg1)); os << " and "; types::print(os, type_of(lazy_arg2())); throw internal_error(os.str(), location); } template value short_circuiting_impl_same_type(ast::binary_operation_type type, primitive_value const & arg1, LazyArg2 const & lazy_arg2, ast::location const & location) { return std::visit([&](auto const & value){ return short_circuiting_impl_same_type(type, value, lazy_arg2, location); }, arg1); } value eval_impl(context & context, ast::binary_operation const & binary_operation) { auto type1 = ast::get_type(*binary_operation.arg1); auto type2 = ast::get_type(*binary_operation.arg2); bool const is_same_type = types::equal(*type1, *type2); if (requires_same_argument_type(binary_operation.type) && !is_same_type) { std::ostringstream os; os << "Cannot apply " << binary_operation.type << " to values of type "; types::print(os, *type1); os << " and "; types::print(os, *type2); throw internal_error(os.str(), binary_operation.location); } auto arg1 = eval_impl(context, binary_operation.arg1); if (is_short_circuiting(binary_operation.type)) { auto lazy_arg2 = [&]{ return eval_impl(context, binary_operation.arg2); }; return std::visit([&](auto const & value){ return short_circuiting_impl_same_type(binary_operation.type, value, lazy_arg2, binary_operation.location); }, arg1); } if (is_same_type) { auto arg2 = eval_impl(context, binary_operation.arg2); return std::visit([&](auto const & value){ return binary_operation_impl_same_type(binary_operation.type, value, arg2, binary_operation.location); }, arg1); } throw internal_error("eval(binary_operation) for different argument types not implemented", binary_operation.location); } value cast_impl(unit_value const & value, types::type const & type, ast::location const & location) { if (types::equal(type, types::unit_type{})) return value; throw internal_error("Cannot cast unit type to anything", location); } value cast_impl(array_value const & value, types::type const & type, ast::location const & location) { if (types::equal(type, type_of(value))) return value; throw internal_error("Cannot cast array type to anything", location); } template value cast_impl(primitive_value_base const & value, types::unit_type const &, ast::location const & location) { throw internal_error("Cannot cast anything to unit type", location); } template value cast_impl(primitive_value_base const & value, types::primitive_type_base const & type, ast::location const & location) { if constexpr (std::is_same_v) { return primitive_value(value); } else if constexpr (!std::is_same_v && !std::is_same_v) { if constexpr (std::is_same_v) { return primitive_value(primitive_value_base{static_cast(value.value.repr)}); } else if constexpr (std::is_same_v) { return primitive_value(primitive_value_base{{static_cast(value.value)}}); } else { return primitive_value(primitive_value_base{static_cast(value.value)}); } } std::ostringstream os; os << "Cannot cast value of type "; types::print(os, type_of(primitive_value(value))); os << " to type "; types::print(os, types::primitive_type(type)); throw internal_error(os.str(), location); } template value cast_impl(primitive_value_base const & value, types::primitive_type const & type, ast::location const & location) { return std::visit([&](auto const & type){ return cast_impl(value, type, location); }, type); } template value cast_impl(primitive_value_base const &, types::array_type const &, ast::location const & location) { throw internal_error("Cannot cast anything to array type", location); } template value cast_impl(primitive_value_base const &, types::function_type const &, ast::location const & location) { throw internal_error("Cannot cast anything to function type", location); } template value cast_impl(primitive_value_base const & value, types::type const & type, ast::location const & location) { return std::visit([&](auto const & type){ return cast_impl(value, type, location); }, type); } value cast_impl(primitive_value const & value, types::type const & type, ast::location const & location) { return std::visit([&](auto const & value){ return cast_impl(value, type, location); }, value); } value cast_impl(struct_value const & value, types::type const & type, ast::location const & location) { if (types::equal(type, types::unit_type{})) return value; throw internal_error("Cannot cast struct type to anything", location); } value cast_impl(function_value const &, types::type const &, ast::location const & location) { throw internal_error("Cannot cast function type to anything", location); } value cast_impl(foreign_function_value const &, types::type const &, ast::location const & location) { throw internal_error("Cannot cast function type to anything", location); } value eval_impl(context & context, ast::cast_operation const & cast_operation) { auto arg = eval(context, cast_operation.expression); auto type = get_type(*cast_operation.type); return std::visit([&](auto const & value){ return cast_impl(value, *type, cast_operation.location); }, arg); } value eval_impl(context & context, ast::function_call const & function_call) { if (function_call.function) { auto lvalue = eval(context, function_call.function); auto fvalue = std::get_if(&lvalue); auto ffvalue = std::get_if(&lvalue); auto fcommon = fvalue ? static_cast(fvalue) : ffvalue; if (fcommon) { if (fcommon->arguments.size() != function_call.arguments.size()) { std::ostringstream os; os << "Cannot call function: expected " << fcommon->arguments.size() << " arguments, got " << function_call.arguments.size(); throw internal_error(os.str(), function_call.location); } std::vector args; for (auto const & expression : function_call.arguments) args.push_back(eval(context, expression)); for (std::size_t i = 0; i < args.size(); ++i) { auto actual_type = type_of(args[i]); if (!types::equal(actual_type, *fcommon->arguments[i].type)) { std::ostringstream os; os << "Cannot call function: argument #" << (i + 1) << " expects type "; types::print(os, *fcommon->arguments[i].type); os << " but actual type is "; types::print(os, actual_type); throw internal_error(os.str(), ast::get_location(*function_call.arguments[i])); } } if (fvalue) { auto & function_scope = context.frame_stack.emplace_back(); for (std::size_t i = 0; i < args.size(); ++i) function_scope.variables[fvalue->arguments[i].name] = {.category = ast::value_category::constant, .value = std::move(args[i])}; function_scope.expected_return_type = fvalue->return_type; exec(context, fvalue->statements); auto result = std::move(context.frame_stack.back().return_value); context.frame_stack.pop_back(); return result; } else if (ffvalue) { return exec_foreign(context, ffvalue->pointer, *fcommon->return_type, std::move(args)); } } std::ostringstream os; os << "Cannot call "; print(os, lvalue); os << ": not a function"; throw internal_error(os.str(), function_call.location); } else if (function_call.type) { auto type = get_type(*function_call.type); if (types::is_builtin_type(*type)) { if (function_call.arguments.empty()) return zero_value(context, *type); std::ostringstream os; os << "Cannot create built-in type "; types::print(os, *type); os << ": expected 0 arguments, but got " << function_call.arguments.size(); throw internal_error(os.str(), function_call.location); } else if (auto named_type = std::get_if(type.get())) { auto const & scope = context.frame_stack.at(named_type->level); auto const & data = scope.structs.at(named_type->name); if (function_call.arguments.empty()) return zero_value(context, *type); if (data.fields.size() != function_call.arguments.size()) { std::ostringstream os; os << "Cannot create struct \"" << named_type->name << "\": expected " << data.fields.size() << " arguments, got " << function_call.arguments.size(); throw internal_error(os.str(), function_call.location); } std::vector args; for (auto const & expression : function_call.arguments) args.push_back(eval(context, expression)); std::unordered_map fields; for (std::size_t i = 0; i < args.size(); ++i) fields[data.fields[i].name] = std::make_unique(std::move(args[i])); return struct_value{ .struct_type = type, .fields = std::move(fields), }; } else throw internal_error("Unknown type in constructor", function_call.location); } else throw internal_error("Function call node has neither function nor type", function_call.location); } value eval_impl(context & context, ast::array const & array) { if (array.elements.empty()) throw internal_error("Array node cannot have zero elements", array.location); types::type_ptr element_type; std::vector elements; for (std::size_t i = 0; i < array.elements.size(); ++i) { auto element = std::make_unique(eval(context, array.elements[i])); if (i == 0) element_type = std::make_unique(type_of(*element)); else { auto new_type = type_of(*element); if (!types::equal(*element_type, new_type)) { std::ostringstream os; os << "Error forming array: inferred element type is "; types::print(os, *element_type); os << " but element #" << i << " type is "; types::print(os, new_type); throw internal_error(os.str(), array.location); } } elements.push_back(std::move(element)); } return array_value{.element_type = std::move(element_type), .elements = std::move(elements)}; } value eval_impl(context & context, ast::array_access const & array_access) { auto array = eval(context, array_access.array); auto index = eval(context, array_access.index); if (auto avalue = std::get_if(&array)) return *avalue->elements[get_array_index(index, avalue->elements.size(), array_access.location)]; std::ostringstream os; os << "Cannot index into a non-array of type "; types::print(os, type_of(array)); throw internal_error(os.str(), array_access.location); } value eval_impl(context & context, ast::field_access const & field_access) { auto object = eval(context, field_access.object); if (auto value = std::get_if(&object)) { if (auto it = value->fields.find(field_access.field_name); it != value->fields.end()) return *it->second; std::ostringstream os; os << "Struct "; types::print(os, type_of(object)); os << " has no field named \"" << field_access.field_name << "\""; throw internal_error(os.str(), field_access.location); } std::ostringstream os; os << "Value of type "; types::print(os, type_of(object)); os << " is not a struct"; throw internal_error(os.str(), field_access.location); } value eval_impl(context & context, ast::expression_ptr const & expression) { return std::visit([&](auto const & expression){ return eval_impl(context, expression); }, *expression); } value * eval_ref_impl(context & context, ast::literal const & literal) { throw internal_error("Literal cannot be on the left-hand-side of assignment", ast::get_location(literal)); } value * eval_ref_impl(context & context, ast::identifier const & identifier) { if (identifier.level >= context.frame_stack.size()) throw internal_error("Bad identifier level", identifier.location); auto & scope = context.frame_stack[identifier.level]; auto it = scope.variables.find(identifier.name); if (it == scope.variables.end()) throw internal_error("Identifier \"" + identifier.name + "\" is not defined", identifier.location); return &it->second.value; } value * eval_ref_impl(context & context, ast::unary_operation const & unary_operation) { throw internal_error("Unary operation cannot be on the left-hand-side of assignment", unary_operation.location); } value * eval_ref_impl(context & context, ast::binary_operation const & binary_operation) { throw internal_error("Binary operation cannot be on the left-hand-side of assignment", binary_operation.location); } value * eval_ref_impl(context & context, ast::cast_operation const & cast_operation) { throw internal_error("Cast operation cannot be on the left-hand-side of assignment", cast_operation.location); } value * eval_ref_impl(context & context, ast::function_call const & function_call) { throw internal_error("Function call cannot be on the left-hand-side of assignment", function_call.location); } value * eval_ref_impl(context & context, ast::array const & array) { throw internal_error("Array cannot be on the left-hand-side of assignment", array.location); } value * eval_ref_impl(context & context, ast::array_access const & array_access) { auto index = eval(context, array_access.index); auto array_ref = eval_ref(context, array_access.array); if (auto avalue = std::get_if(array_ref)) { return avalue->elements[get_array_index(index, avalue->elements.size(), ast::get_location(*array_access.index))].get(); } std::ostringstream os; os << "Cannot index into a non-array of type "; types::print(os, type_of(*array_ref)); throw internal_error(os.str(), array_access.location); } value * eval_ref_impl(context & context, ast::field_access const & field_access) { auto object_ref = eval_ref(context, field_access.object); if (auto value = std::get_if(object_ref)) { if (auto it = value->fields.find(field_access.field_name); it != value->fields.end()) return it->second.get(); std::ostringstream os; os << "Struct "; types::print(os, type_of(*object_ref)); os << " has no field named \"" << field_access.field_name << "\""; throw internal_error(os.str(), field_access.location); } std::ostringstream os; os << "Value of type "; types::print(os, type_of(*object_ref)); os << " is not a struct"; throw internal_error(os.str(), field_access.location); } value * eval_ref_impl(context & context, ast::expression_ptr const & expression) { return std::visit([&](auto const & expression){ return eval_ref_impl(context, expression); }, *expression); } } value eval(context & context, ast::expression_ptr const & expression) { return eval_impl(context, expression); } value * eval_ref(context & context, ast::expression_ptr const & expression) { return eval_ref_impl(context, expression); } }