Add inferred type to AST nodes & implement type checking & inference
This commit is contained in:
parent
97fb066e38
commit
a36ba2610b
15 changed files with 891 additions and 0 deletions
|
|
@ -111,6 +111,7 @@ int main(int argc, char ** argv)
|
|||
filenames.push_back(argv[arg]);
|
||||
auto ast = parser::parse(filenames.back());
|
||||
ast::resolve_identifiers(ast);
|
||||
ast::check_and_infer_types(ast);
|
||||
parsed.push_back(std::move(ast));
|
||||
}
|
||||
catch (pslang::ast::parse_error const & error)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
#include <pslang/ast/expression_fwd.hpp>
|
||||
#include <pslang/ast/location.hpp>
|
||||
#include <pslang/types/type_fwd.hpp>
|
||||
|
||||
#include <vector>
|
||||
|
||||
|
|
@ -12,6 +13,7 @@ namespace pslang::ast
|
|||
{
|
||||
std::vector<expression_ptr> elements;
|
||||
ast::location location;
|
||||
types::type_ptr inferred_type = nullptr;
|
||||
};
|
||||
|
||||
struct array_access
|
||||
|
|
@ -19,6 +21,7 @@ namespace pslang::ast
|
|||
expression_ptr array;
|
||||
expression_ptr index;
|
||||
ast::location location;
|
||||
types::type_ptr inferred_type = nullptr;
|
||||
};
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
#include <pslang/ast/expression_fwd.hpp>
|
||||
#include <pslang/ast/location.hpp>
|
||||
#include <pslang/ast/type.hpp>
|
||||
#include <pslang/types/type_fwd.hpp>
|
||||
|
||||
namespace pslang::ast
|
||||
{
|
||||
|
|
@ -12,6 +13,7 @@ namespace pslang::ast
|
|||
expression_ptr expression;
|
||||
type_ptr type;
|
||||
ast::location location;
|
||||
types::type_ptr inferred_type = nullptr;
|
||||
};
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@
|
|||
#include <pslang/ast/array.hpp>
|
||||
#include <pslang/ast/struct.hpp>
|
||||
#include <pslang/ast/expression_fwd.hpp>
|
||||
#include <pslang/types/type_fwd.hpp>
|
||||
|
||||
namespace pslang::ast
|
||||
{
|
||||
|
|
@ -18,6 +19,7 @@ namespace pslang::ast
|
|||
unary_operation_type type;
|
||||
expression_ptr arg1;
|
||||
ast::location location;
|
||||
types::type_ptr inferred_type = nullptr;
|
||||
};
|
||||
|
||||
struct binary_operation
|
||||
|
|
@ -26,6 +28,7 @@ namespace pslang::ast
|
|||
expression_ptr arg1;
|
||||
expression_ptr arg2;
|
||||
ast::location location;
|
||||
types::type_ptr inferred_type = nullptr;
|
||||
};
|
||||
|
||||
using expression_impl = std::variant<
|
||||
|
|
@ -47,5 +50,6 @@ namespace pslang::ast
|
|||
};
|
||||
|
||||
location get_location(expression const & expression);
|
||||
types::type_ptr get_type(expression const & expression);
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
#include <pslang/ast/expression_fwd.hpp>
|
||||
#include <pslang/ast/location.hpp>
|
||||
#include <pslang/ast/type_fwd.hpp>
|
||||
#include <pslang/types/type_fwd.hpp>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
|
@ -32,6 +33,7 @@ namespace pslang::ast
|
|||
expression_ptr function;
|
||||
std::vector<expression_ptr> arguments;
|
||||
ast::location location;
|
||||
types::type_ptr inferred_type = nullptr;
|
||||
};
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <pslang/ast/location.hpp>
|
||||
#include <pslang/types/type_fwd.hpp>
|
||||
|
||||
#include <string>
|
||||
|
||||
|
|
@ -12,6 +13,7 @@ namespace pslang::ast
|
|||
std::string name;
|
||||
ast::location location;
|
||||
std::size_t level = 0;
|
||||
types::type_ptr inferred_type = nullptr;
|
||||
};
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,5 +6,6 @@ namespace pslang::ast
|
|||
{
|
||||
|
||||
void resolve_identifiers(statement_list_ptr & statements);
|
||||
void check_and_infer_types(statement_list_ptr & statements);
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
#include <pslang/ast/expression_fwd.hpp>
|
||||
#include <pslang/ast/type_fwd.hpp>
|
||||
#include <pslang/ast/location.hpp>
|
||||
#include <pslang/types/type_fwd.hpp>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
|
@ -29,6 +30,7 @@ namespace pslang::ast
|
|||
expression_ptr object;
|
||||
std::string field_name;
|
||||
ast::location location;
|
||||
types::type_ptr inferred_type = nullptr;
|
||||
};
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,12 +17,14 @@ namespace pslang::ast
|
|||
{
|
||||
type_ptr element_type;
|
||||
std::uint64_t size;
|
||||
types::type_ptr inferred_type = nullptr;
|
||||
};
|
||||
|
||||
struct function_type
|
||||
{
|
||||
std::vector<type_ptr> arguments;
|
||||
type_ptr result;
|
||||
types::type_ptr inferred_type = nullptr;
|
||||
};
|
||||
|
||||
struct type_identifier
|
||||
|
|
@ -30,6 +32,7 @@ namespace pslang::ast
|
|||
std::string name;
|
||||
ast::location location;
|
||||
std::size_t level = 0;
|
||||
types::type_ptr inferred_type = nullptr;
|
||||
};
|
||||
|
||||
using type_impl = std::variant<
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <pslang/types/type_fwd.hpp>
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace pslang::ast
|
||||
|
|
@ -9,4 +11,6 @@ namespace pslang::ast
|
|||
|
||||
using type_ptr = std::shared_ptr<type>;
|
||||
|
||||
types::type_ptr get_type(type const & type);
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
#include <pslang/ast/expression.hpp>
|
||||
#include <pslang/ast/expression_visitor.hpp>
|
||||
#include <pslang/types/type.hpp>
|
||||
|
||||
namespace pslang::ast
|
||||
{
|
||||
|
|
@ -59,6 +60,58 @@ namespace pslang::ast
|
|||
}
|
||||
};
|
||||
|
||||
struct get_type_visitor
|
||||
: const_expression_visitor<get_type_visitor>
|
||||
{
|
||||
using const_expression_visitor::apply;
|
||||
|
||||
template <typename T>
|
||||
types::type_ptr apply(primitive_literal_base<T> const & node)
|
||||
{
|
||||
return std::make_unique<types::type>(types::primitive_type{types::primitive_type_base<T>{}});
|
||||
}
|
||||
|
||||
types::type_ptr apply(identifier const & node)
|
||||
{
|
||||
return node.inferred_type;
|
||||
}
|
||||
|
||||
types::type_ptr apply(unary_operation const & node)
|
||||
{
|
||||
return node.inferred_type;
|
||||
}
|
||||
|
||||
types::type_ptr apply(binary_operation const & node)
|
||||
{
|
||||
return node.inferred_type;
|
||||
}
|
||||
|
||||
types::type_ptr apply(cast_operation const & node)
|
||||
{
|
||||
return node.inferred_type;
|
||||
}
|
||||
|
||||
types::type_ptr apply(function_call const & node)
|
||||
{
|
||||
return node.inferred_type;
|
||||
}
|
||||
|
||||
types::type_ptr apply(array const & node)
|
||||
{
|
||||
return node.inferred_type;
|
||||
}
|
||||
|
||||
types::type_ptr apply(array_access const & node)
|
||||
{
|
||||
return node.inferred_type;
|
||||
}
|
||||
|
||||
types::type_ptr apply(field_access const & node)
|
||||
{
|
||||
return node.inferred_type;
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
location get_location(expression const & expression)
|
||||
|
|
@ -66,4 +119,9 @@ namespace pslang::ast
|
|||
return get_location_visitor{}.apply(expression);
|
||||
}
|
||||
|
||||
types::type_ptr get_type(expression const & expression)
|
||||
{
|
||||
return get_type_visitor{}.apply(expression);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
49
libs/ast/source/type.cpp
Normal file
49
libs/ast/source/type.cpp
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
#include <pslang/ast/type.hpp>
|
||||
#include <pslang/ast/type_visitor.hpp>
|
||||
#include <pslang/types/type.hpp>
|
||||
|
||||
namespace pslang::ast
|
||||
{
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
struct get_type_visitor
|
||||
: const_type_visitor<get_type_visitor>
|
||||
{
|
||||
using const_type_visitor::apply;
|
||||
|
||||
types::type_ptr apply(types::unit_type const & type)
|
||||
{
|
||||
return std::make_unique<types::type>(type);
|
||||
}
|
||||
|
||||
types::type_ptr apply(types::primitive_type const & type)
|
||||
{
|
||||
return std::make_unique<types::type>(type);
|
||||
}
|
||||
|
||||
types::type_ptr apply(ast::array_type const & type)
|
||||
{
|
||||
return type.inferred_type;
|
||||
}
|
||||
|
||||
types::type_ptr apply(ast::function_type const & type)
|
||||
{
|
||||
return type.inferred_type;
|
||||
}
|
||||
|
||||
types::type_ptr apply(ast::type_identifier const & type)
|
||||
{
|
||||
return type.inferred_type;
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
types::type_ptr get_type(type const & type)
|
||||
{
|
||||
return get_type_visitor{}.apply(type);
|
||||
}
|
||||
|
||||
}
|
||||
677
libs/ast/source/type_check.cpp
Normal file
677
libs/ast/source/type_check.cpp
Normal file
|
|
@ -0,0 +1,677 @@
|
|||
#include <pslang/ast/preprocess.hpp>
|
||||
#include <pslang/ast/type_visitor.hpp>
|
||||
#include <pslang/ast/expression_visitor.hpp>
|
||||
#include <pslang/ast/statement_visitor.hpp>
|
||||
#include <pslang/ast/error.hpp>
|
||||
#include <pslang/types/print.hpp>
|
||||
#include <pslang/types/type.hpp>
|
||||
|
||||
#include <unordered_map>
|
||||
#include <sstream>
|
||||
|
||||
namespace pslang::ast
|
||||
{
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
struct variable_data
|
||||
{
|
||||
value_category category;
|
||||
types::type_ptr type;
|
||||
};
|
||||
|
||||
struct function_data
|
||||
{
|
||||
std::vector<types::type_ptr> arguments;
|
||||
types::type_ptr result_type;
|
||||
};
|
||||
|
||||
struct struct_data
|
||||
{
|
||||
struct field_data
|
||||
{
|
||||
std::string name;
|
||||
types::type_ptr type;
|
||||
};
|
||||
|
||||
std::vector<field_data> fields;
|
||||
};
|
||||
|
||||
struct scope
|
||||
{
|
||||
std::unordered_map<std::string, variable_data> variables;
|
||||
std::unordered_map<std::string, function_data> functions;
|
||||
std::unordered_map<std::string, struct_data> structs;
|
||||
|
||||
bool is_function_scope = false;
|
||||
bool is_global_scope = false;
|
||||
|
||||
types::type_ptr expected_return_type = nullptr;
|
||||
};
|
||||
|
||||
struct check_visitor
|
||||
: type_visitor<check_visitor>
|
||||
, expression_visitor<check_visitor>
|
||||
, const_statement_visitor<check_visitor>
|
||||
{
|
||||
std::vector<scope> scopes;
|
||||
|
||||
using type_visitor::apply;
|
||||
using expression_visitor::apply;
|
||||
using const_statement_visitor::apply;
|
||||
|
||||
void apply(types::unit_type const &)
|
||||
{}
|
||||
|
||||
void apply(types::primitive_type const &)
|
||||
{}
|
||||
|
||||
void apply(ast::array_type & node)
|
||||
{
|
||||
apply(*node.element_type);
|
||||
|
||||
auto type = types::array_type{};
|
||||
type.element_type = get_type(*node.element_type);
|
||||
type.size = node.size;
|
||||
node.inferred_type = std::make_unique<types::type>(std::move(type));
|
||||
}
|
||||
|
||||
void apply(ast::function_type & node)
|
||||
{
|
||||
for (auto const & argument : node.arguments)
|
||||
apply(*argument);
|
||||
apply(*node.result);
|
||||
|
||||
auto type = types::function_type{};
|
||||
for (auto const & argument : node.arguments)
|
||||
type.arguments.push_back(get_type(*argument));
|
||||
type.result = get_type(*node.result);
|
||||
|
||||
node.inferred_type = std::make_unique<types::type>(std::move(type));
|
||||
}
|
||||
|
||||
void apply(ast::type_identifier & node)
|
||||
{
|
||||
auto type = types::named_type{};
|
||||
type.name = node.name;
|
||||
type.level = node.level;
|
||||
node.inferred_type = std::make_unique<types::type>(std::move(type));
|
||||
}
|
||||
|
||||
void apply(literal &)
|
||||
{}
|
||||
|
||||
void apply(identifier & node)
|
||||
{
|
||||
if (auto type = types::builtin_type(node.name))
|
||||
{
|
||||
node.inferred_type = type;
|
||||
return;
|
||||
}
|
||||
|
||||
auto & scope = scopes.at(node.level);
|
||||
if (auto it = scope.variables.find(node.name); it != scope.variables.end())
|
||||
{
|
||||
node.inferred_type = it->second.type;
|
||||
}
|
||||
else if (auto it = scope.functions.find(node.name); it != scope.functions.end())
|
||||
{
|
||||
auto type = types::function_type{};
|
||||
for (auto const & argument : it->second.arguments)
|
||||
type.arguments.push_back(argument);
|
||||
type.result = it->second.result_type;
|
||||
node.inferred_type = std::make_unique<types::type>(std::move(type));
|
||||
}
|
||||
}
|
||||
|
||||
void apply(unary_operation & node)
|
||||
{
|
||||
apply(*node.arg1);
|
||||
auto arg1_type = get_type(*node.arg1);
|
||||
|
||||
bool good = false;
|
||||
|
||||
switch (node.type)
|
||||
{
|
||||
case unary_operation_type::negation:
|
||||
if (types::is_integer_type(*arg1_type) || types::is_floating_point_type(*arg1_type))
|
||||
{
|
||||
node.inferred_type = arg1_type;
|
||||
return;
|
||||
}
|
||||
break;
|
||||
case unary_operation_type::logical_not:
|
||||
if (types::is_bool_type(*arg1_type) || types::is_integer_type(*arg1_type))
|
||||
{
|
||||
node.inferred_type = arg1_type;
|
||||
return;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
std::ostringstream os;
|
||||
os << "Cannot apply " << node.type << " to a value of type ";
|
||||
types::print(os, *arg1_type);
|
||||
throw type_error(os.str(), node.location);
|
||||
}
|
||||
|
||||
void apply(binary_operation & node)
|
||||
{
|
||||
apply(*node.arg1);
|
||||
apply(*node.arg2);
|
||||
|
||||
auto arg1_type = get_type(*node.arg1);
|
||||
auto arg2_type = get_type(*node.arg2);
|
||||
|
||||
bool equal = types::equal(*arg1_type, *arg2_type);
|
||||
|
||||
switch (node.type)
|
||||
{
|
||||
case binary_operation_type::addition:
|
||||
if (equal && types::is_numeric_type(*arg1_type))
|
||||
{
|
||||
node.inferred_type = arg1_type;
|
||||
return;
|
||||
}
|
||||
break;
|
||||
case binary_operation_type::subtraction:
|
||||
if (equal && types::is_numeric_type(*arg1_type))
|
||||
{
|
||||
node.inferred_type = arg1_type;
|
||||
return;
|
||||
}
|
||||
break;
|
||||
case binary_operation_type::multiplication:
|
||||
if (equal && types::is_numeric_type(*arg1_type))
|
||||
{
|
||||
node.inferred_type = arg1_type;
|
||||
return;
|
||||
}
|
||||
break;
|
||||
case binary_operation_type::division:
|
||||
if (equal && types::is_numeric_type(*arg1_type))
|
||||
{
|
||||
node.inferred_type = arg1_type;
|
||||
return;
|
||||
}
|
||||
break;
|
||||
case binary_operation_type::remainder:
|
||||
if (equal && types::is_integer_type(*arg1_type))
|
||||
{
|
||||
node.inferred_type = arg1_type;
|
||||
return;
|
||||
}
|
||||
break;
|
||||
case binary_operation_type::logical_and:
|
||||
if (equal && (types::is_bool_type(*arg1_type) || types::is_integer_type(*arg1_type)))
|
||||
{
|
||||
node.inferred_type = arg1_type;
|
||||
return;
|
||||
}
|
||||
break;
|
||||
case binary_operation_type::logical_or:
|
||||
if (equal && (types::is_bool_type(*arg1_type) || types::is_integer_type(*arg1_type)))
|
||||
{
|
||||
node.inferred_type = arg1_type;
|
||||
return;
|
||||
}
|
||||
break;
|
||||
case binary_operation_type::logical_xor:
|
||||
if (equal && (types::is_bool_type(*arg1_type) || types::is_integer_type(*arg1_type)))
|
||||
{
|
||||
node.inferred_type = arg1_type;
|
||||
return;
|
||||
}
|
||||
break;
|
||||
case binary_operation_type::equals:
|
||||
if (equal)
|
||||
{
|
||||
node.inferred_type = std::make_unique<types::type>(types::bool_type{});
|
||||
return;
|
||||
}
|
||||
break;
|
||||
case binary_operation_type::not_equals:
|
||||
if (equal)
|
||||
return;
|
||||
break;
|
||||
case binary_operation_type::less:
|
||||
if (equal)
|
||||
{
|
||||
node.inferred_type = std::make_unique<types::type>(types::bool_type{});
|
||||
return;
|
||||
}
|
||||
break;
|
||||
case binary_operation_type::greater:
|
||||
if (equal)
|
||||
{
|
||||
node.inferred_type = std::make_unique<types::type>(types::bool_type{});
|
||||
return;
|
||||
}
|
||||
break;
|
||||
case binary_operation_type::less_equals:
|
||||
if (equal)
|
||||
{
|
||||
node.inferred_type = std::make_unique<types::type>(types::bool_type{});
|
||||
return;
|
||||
}
|
||||
break;
|
||||
case binary_operation_type::greater_equals:
|
||||
if (equal)
|
||||
{
|
||||
node.inferred_type = std::make_unique<types::type>(types::bool_type{});
|
||||
return;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
std::ostringstream os;
|
||||
os << "Cannot apply " << node.type << " to values of types ";
|
||||
types::print(os, *arg1_type);
|
||||
os << " and ";
|
||||
types::print(os, *arg2_type);
|
||||
throw type_error(os.str(), node.location);
|
||||
}
|
||||
|
||||
void apply(cast_operation & node)
|
||||
{
|
||||
apply(*node.expression);
|
||||
apply(*node.type);
|
||||
|
||||
auto source_type = get_type(*node.expression);
|
||||
auto target_type = get_type(*node.type);
|
||||
|
||||
node.inferred_type = target_type;
|
||||
|
||||
if (types::equal(*source_type, *target_type))
|
||||
return;
|
||||
|
||||
if (std::get_if<types::primitive_type>(target_type.get()))
|
||||
if (!types::is_bool_type(*source_type) && !types::is_bool_type(*target_type))
|
||||
return;
|
||||
|
||||
std::ostringstream os;
|
||||
os << "Cannot cast a value of type ";
|
||||
types::print(os, *source_type);
|
||||
os << " to type ";
|
||||
types::print(os, *target_type);
|
||||
throw type_error(os.str(), node.location);
|
||||
}
|
||||
|
||||
void apply(function_call & node)
|
||||
{
|
||||
apply(*node.function);
|
||||
for (auto const & argument : node.arguments)
|
||||
apply(*argument);
|
||||
|
||||
std::string function_name;
|
||||
|
||||
if (auto identifier = std::get_if<ast::identifier>(node.function.get()))
|
||||
{
|
||||
if (types::type_ptr type = types::builtin_type(identifier->name))
|
||||
{
|
||||
if (node.arguments.empty())
|
||||
{
|
||||
node.inferred_type = type;
|
||||
return;
|
||||
}
|
||||
|
||||
std::ostringstream os;
|
||||
os << "Cannot create built-in type ";
|
||||
types::print(os, *type);
|
||||
os << ": expected 0 arguments, but got " << node.arguments.size();
|
||||
throw std::runtime_error(os.str());
|
||||
}
|
||||
|
||||
auto & scope = scopes.at(identifier->level);
|
||||
if (auto it = scope.structs.find(identifier->name); it != scope.structs.end())
|
||||
{
|
||||
if (!node.arguments.empty())
|
||||
{
|
||||
if (node.arguments.size() != it->second.fields.size())
|
||||
{
|
||||
std::ostringstream os;
|
||||
os << "Cannot create struct " << identifier->name << ": expected " << it->second.fields.size() << " arguments, but got " << node.arguments.size();
|
||||
throw type_error(os.str(), node.location);
|
||||
}
|
||||
|
||||
for (std::size_t i = 0; i < node.arguments.size(); ++i)
|
||||
{
|
||||
auto arg_type = get_type(*node.arguments[i]);
|
||||
if (!types::equal(*arg_type, *it->second.fields[i].type))
|
||||
{
|
||||
std::ostringstream os;
|
||||
os << "Cannot create struct " << identifier->name << ": argument #" << i << " expected to have type ";
|
||||
types::print(os, *it->second.fields[i].type);
|
||||
os << " but got type ";
|
||||
types::print(os, *arg_type);
|
||||
throw type_error(os.str(), node.location);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
types::named_type type;
|
||||
type.name = identifier->name;
|
||||
type.level = identifier->level;
|
||||
node.inferred_type = std::make_unique<types::type>(std::move(type));
|
||||
return;
|
||||
}
|
||||
|
||||
function_name = identifier->name + " ";
|
||||
}
|
||||
|
||||
auto function_type = get_type(*node.function);
|
||||
|
||||
auto ftype = std::get_if<types::function_type>(function_type.get());
|
||||
if (!ftype)
|
||||
{
|
||||
std::ostringstream os;
|
||||
os << "Cannot call a value of a non-function type ";
|
||||
types::print(os, *function_type);
|
||||
throw type_error(os.str(), node.location);
|
||||
}
|
||||
|
||||
if (ftype->arguments.size() != node.arguments.size())
|
||||
{
|
||||
std::ostringstream os;
|
||||
os << "Cannot call function " << function_name << " of type ";
|
||||
types::print(os, *function_type);
|
||||
os << ": expected " << ftype->arguments.size() << " arguments, but got " << node.arguments.size();
|
||||
throw type_error(os.str(), node.location);
|
||||
}
|
||||
|
||||
for (std::size_t i = 0; i < node.arguments.size(); ++i)
|
||||
{
|
||||
auto arg_type = get_type(*node.arguments[i]);
|
||||
if (!types::equal(*arg_type, *ftype->arguments[i]))
|
||||
{
|
||||
std::ostringstream os;
|
||||
os << "Cannot call function " << function_name << " of type ";
|
||||
types::print(os, *function_type);
|
||||
os << ": argument #" << i << " expected to have type ";
|
||||
types::print(os, *ftype->arguments[i]);
|
||||
os << " but got type ";
|
||||
types::print(os, *arg_type);
|
||||
throw type_error(os.str(), node.location);
|
||||
}
|
||||
}
|
||||
|
||||
node.inferred_type = ftype->result;
|
||||
}
|
||||
|
||||
void apply(array & node)
|
||||
{
|
||||
if (node.elements.empty())
|
||||
throw invalid_ast_error("Empty array", node.location);
|
||||
|
||||
for (auto const & element : node.elements)
|
||||
apply(*element);
|
||||
|
||||
types::type_ptr element_type = nullptr;
|
||||
for (std::size_t i = 0; i < node.elements.size(); ++i)
|
||||
{
|
||||
auto current_type = get_type(*node.elements[0]);
|
||||
|
||||
if (i == 0)
|
||||
{
|
||||
element_type = current_type;
|
||||
}
|
||||
else if (!types::equal(*element_type, *current_type))
|
||||
{
|
||||
std::ostringstream os;
|
||||
os << "Failed to infer array type: element #0 has type ";
|
||||
types::print(os, *element_type);
|
||||
os << " but element #" << i << " has type ";
|
||||
types::print(os, *current_type);
|
||||
}
|
||||
}
|
||||
|
||||
types::array_type type;
|
||||
type.element_type = element_type;
|
||||
type.size = node.elements.size();
|
||||
node.inferred_type = std::make_unique<types::type>(std::move(type));
|
||||
}
|
||||
|
||||
void apply(array_access & node)
|
||||
{
|
||||
apply(*node.array);
|
||||
apply(*node.index);
|
||||
|
||||
auto array_type = get_type(*node.array);
|
||||
auto index_type = get_type(*node.index);
|
||||
|
||||
if (!types::is_integer_type(*index_type))
|
||||
{
|
||||
std::ostringstream os;
|
||||
os << "Expected an integer type as index, but got ";
|
||||
types::print(os, *index_type);
|
||||
throw type_error(os.str(), get_location(*node.index));
|
||||
}
|
||||
|
||||
auto atype = std::get_if<types::array_type>(array_type.get());
|
||||
|
||||
if (!atype)
|
||||
{
|
||||
std::ostringstream os;
|
||||
os << "Expected an array to index, but got ";
|
||||
types::print(os, *array_type);
|
||||
throw type_error(os.str(), get_location(*node.array));
|
||||
}
|
||||
|
||||
node.inferred_type = atype->element_type;
|
||||
}
|
||||
|
||||
void apply(field_access & node)
|
||||
{
|
||||
apply(*node.object);
|
||||
|
||||
auto object_type = get_type(*node.object);
|
||||
auto named_type = std::get_if<types::named_type>(object_type.get());
|
||||
|
||||
if (!named_type)
|
||||
{
|
||||
std::ostringstream os;
|
||||
os << "Expected a struct, but got ";
|
||||
types::print(os, *object_type);
|
||||
throw type_error(os.str(), get_location(*node.object));
|
||||
}
|
||||
|
||||
auto const & struct_data = scopes.at(named_type->level).structs.at(named_type->name);
|
||||
|
||||
for (auto const & field : struct_data.fields)
|
||||
{
|
||||
if (field.name == node.field_name)
|
||||
{
|
||||
node.inferred_type = field.type;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
std::ostringstream os;
|
||||
os << "Struct \"" << named_type->name << "\" has no field named \"" << node.field_name << "\"";
|
||||
throw type_error(os.str(), node.location);
|
||||
}
|
||||
|
||||
void apply(expression_ptr const & node)
|
||||
{
|
||||
apply(*node);
|
||||
}
|
||||
|
||||
void apply(assignment const & node)
|
||||
{
|
||||
apply(node.lhs);
|
||||
apply(node.rhs);
|
||||
auto ltype = get_type(*node.lhs);
|
||||
auto rtype = get_type(*node.rhs);
|
||||
// TODO: check lvalue
|
||||
if (!types::equal(*ltype, *rtype))
|
||||
{
|
||||
std::ostringstream os;
|
||||
os << "Cannot assign a value of type ";
|
||||
types::print(os, *rtype);
|
||||
os << " to an expression of type ";
|
||||
types::print(os, *ltype);
|
||||
throw type_error(os.str(), node.location);
|
||||
};
|
||||
}
|
||||
|
||||
void apply(variable_declaration const & node)
|
||||
{
|
||||
apply(node.initializer);
|
||||
auto actual_type = get_type(*node.initializer);
|
||||
if (node.type)
|
||||
{
|
||||
apply(*node.type);
|
||||
auto expected_type = get_type(*node.type);
|
||||
if (!types::equal(*expected_type, *actual_type))
|
||||
{
|
||||
std::ostringstream os;
|
||||
os << "Cannot initialize a variable of type ";
|
||||
types::print(os, *expected_type);
|
||||
os << " with an expression of type ";
|
||||
types::print(os, *actual_type);
|
||||
throw type_error(os.str(), node.location);
|
||||
}
|
||||
}
|
||||
|
||||
scopes.back().variables[node.name] = {
|
||||
.category = node.category,
|
||||
.type = actual_type,
|
||||
};
|
||||
}
|
||||
|
||||
void apply(if_block const & node)
|
||||
{
|
||||
throw invalid_ast_error("if blocks cannot be present in the final AST", node.location);
|
||||
}
|
||||
|
||||
void apply(else_if_block const & node)
|
||||
{
|
||||
throw invalid_ast_error("else if blocks cannot be present in the final AST", node.location);
|
||||
}
|
||||
|
||||
void apply(else_block const & node)
|
||||
{
|
||||
throw invalid_ast_error("else blocks cannot be present in the final AST", node.location);
|
||||
}
|
||||
|
||||
void apply(if_chain const & node)
|
||||
{
|
||||
for (auto const & block : node.blocks)
|
||||
{
|
||||
if (block.condition)
|
||||
{
|
||||
apply(block.condition);
|
||||
auto actual_type = get_type(*block.condition);
|
||||
if (!types::is_bool_type(*actual_type))
|
||||
{
|
||||
std::ostringstream os;
|
||||
os << "if condition expects a bool type, but got ";
|
||||
types::print(os, *actual_type);
|
||||
throw type_error(os.str(), get_location(*block.condition));
|
||||
}
|
||||
}
|
||||
|
||||
scopes.emplace_back();
|
||||
apply(*block.statements);
|
||||
scopes.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
void apply(while_block const & node)
|
||||
{
|
||||
apply(node.condition);
|
||||
auto actual_type = get_type(*node.condition);
|
||||
if (!types::is_bool_type(*actual_type))
|
||||
{
|
||||
std::ostringstream os;
|
||||
os << "while condition expects a bool type, but got ";
|
||||
types::print(os, *actual_type);
|
||||
throw type_error(os.str(), get_location(*node.condition));
|
||||
}
|
||||
|
||||
scopes.emplace_back();
|
||||
apply(*node.statements);
|
||||
scopes.pop_back();
|
||||
}
|
||||
|
||||
void apply(function_definition const & node)
|
||||
{
|
||||
for (auto const & argument : node.arguments)
|
||||
apply(*argument.type);
|
||||
apply(*node.return_type);
|
||||
|
||||
auto & data = scopes.back().functions[node.name];
|
||||
for (auto const & argument : node.arguments)
|
||||
data.arguments.push_back(get_type(*argument.type));
|
||||
data.result_type = get_type(*node.return_type);
|
||||
|
||||
scopes.emplace_back().is_function_scope = true;
|
||||
scopes.back().expected_return_type = get_type(*node.return_type);
|
||||
|
||||
for (auto const & argument : node.arguments)
|
||||
{
|
||||
scopes.back().variables[argument.name] = {
|
||||
.category = value_category::constant,
|
||||
.type = get_type(*argument.type),
|
||||
};
|
||||
}
|
||||
|
||||
apply(*node.statements);
|
||||
scopes.pop_back();
|
||||
}
|
||||
|
||||
void apply(return_statement const & node)
|
||||
{
|
||||
types::type_ptr actual_type;
|
||||
if (node.value)
|
||||
{
|
||||
apply(node.value);
|
||||
actual_type = get_type(*node.value);
|
||||
}
|
||||
else
|
||||
{
|
||||
actual_type = std::make_unique<types::type>(types::unit_type{});
|
||||
}
|
||||
|
||||
auto & return_scope = scopes.at(node.level);
|
||||
if (!return_scope.expected_return_type)
|
||||
throw invalid_ast_error("Unexpected return level", node.location);
|
||||
|
||||
if (!types::equal(*return_scope.expected_return_type, *actual_type))
|
||||
{
|
||||
std::ostringstream os;
|
||||
os << "Returning value of type ";
|
||||
types::print(os, *actual_type);
|
||||
os << " from a function returning ";
|
||||
types::print(os, *return_scope.expected_return_type);
|
||||
throw type_error(os.str(), node.location);
|
||||
}
|
||||
}
|
||||
|
||||
void apply(field_definition const & node)
|
||||
{
|
||||
apply(*node.type);
|
||||
}
|
||||
|
||||
void apply(struct_definition const & node)
|
||||
{
|
||||
for (auto const & field : node.fields)
|
||||
apply(field);
|
||||
|
||||
auto & data = scopes.back().structs[node.name];
|
||||
for (auto const & field : node.fields)
|
||||
data.fields.push_back({.name = field.name, .type = get_type(*field.type)});
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
void check_and_infer_types(statement_list_ptr & statements)
|
||||
{
|
||||
check_visitor visitor;
|
||||
visitor.scopes.emplace_back().is_global_scope = true;
|
||||
visitor.apply(*statements);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -14,4 +14,12 @@ namespace pslang::types
|
|||
|
||||
type_ptr builtin_type(std::string const & name);
|
||||
|
||||
bool is_unit_type(type const & type);
|
||||
bool is_bool_type(type const & type);
|
||||
bool is_integer_type(type const & type);
|
||||
bool is_signed_integer_type(type const & type);
|
||||
bool is_unsigned_integer_type(type const & type);
|
||||
bool is_floating_point_type(type const & type);
|
||||
bool is_numeric_type(type const & type);
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -38,4 +38,79 @@ namespace pslang::types
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
bool is_unit_type(type const & type)
|
||||
{
|
||||
return equal(type, unit_type{});
|
||||
}
|
||||
|
||||
bool is_bool_type(type const & type)
|
||||
{
|
||||
return equal(type, bool_type{});
|
||||
}
|
||||
|
||||
bool is_integer_type(type const & type)
|
||||
{
|
||||
if (auto ptype = std::get_if<primitive_type>(&type))
|
||||
{
|
||||
return std::visit([]<typename T>(primitive_type_base<T> const &)
|
||||
{
|
||||
return std::is_integral_v<T> && !std::is_same_v<T, bool>;
|
||||
}, *ptype);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool is_signed_integer_type(type const & type)
|
||||
{
|
||||
if (auto ptype = std::get_if<primitive_type>(&type))
|
||||
{
|
||||
return std::visit([]<typename T>(primitive_type_base<T> const &)
|
||||
{
|
||||
return std::is_integral_v<T> && std::is_signed_v<T> && !std::is_same_v<T, bool>;
|
||||
}, *ptype);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool is_unsigned_integer_type(type const & type)
|
||||
{
|
||||
if (auto ptype = std::get_if<primitive_type>(&type))
|
||||
{
|
||||
return std::visit([]<typename T>(primitive_type_base<T> const &)
|
||||
{
|
||||
return std::is_integral_v<T> && std::is_unsigned_v<T> && !std::is_same_v<T, bool>;
|
||||
}, *ptype);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool is_floating_point_type(type const & type)
|
||||
{
|
||||
if (auto ptype = std::get_if<primitive_type>(&type))
|
||||
{
|
||||
return std::visit([]<typename T>(primitive_type_base<T> const &)
|
||||
{
|
||||
return std::is_floating_point_v<T>;
|
||||
}, *ptype);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool is_numeric_type(type const & type)
|
||||
{
|
||||
if (auto ptype = std::get_if<primitive_type>(&type))
|
||||
{
|
||||
return std::visit([]<typename T>(primitive_type_base<T> const &)
|
||||
{
|
||||
return !std::is_same_v<T, bool>;
|
||||
}, *ptype);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue