Add inferred type to AST nodes & implement type checking & inference

This commit is contained in:
Nikita Lisitsa 2025-12-23 12:17:42 +03:00
parent 97fb066e38
commit a36ba2610b
15 changed files with 891 additions and 0 deletions

View file

@ -111,6 +111,7 @@ int main(int argc, char ** argv)
filenames.push_back(argv[arg]);
auto ast = parser::parse(filenames.back());
ast::resolve_identifiers(ast);
ast::check_and_infer_types(ast);
parsed.push_back(std::move(ast));
}
catch (pslang::ast::parse_error const & error)

View file

@ -2,6 +2,7 @@
#include <pslang/ast/expression_fwd.hpp>
#include <pslang/ast/location.hpp>
#include <pslang/types/type_fwd.hpp>
#include <vector>
@ -12,6 +13,7 @@ namespace pslang::ast
{
std::vector<expression_ptr> elements;
ast::location location;
types::type_ptr inferred_type = nullptr;
};
struct array_access
@ -19,6 +21,7 @@ namespace pslang::ast
expression_ptr array;
expression_ptr index;
ast::location location;
types::type_ptr inferred_type = nullptr;
};
}

View file

@ -3,6 +3,7 @@
#include <pslang/ast/expression_fwd.hpp>
#include <pslang/ast/location.hpp>
#include <pslang/ast/type.hpp>
#include <pslang/types/type_fwd.hpp>
namespace pslang::ast
{
@ -12,6 +13,7 @@ namespace pslang::ast
expression_ptr expression;
type_ptr type;
ast::location location;
types::type_ptr inferred_type = nullptr;
};
}

View file

@ -9,6 +9,7 @@
#include <pslang/ast/array.hpp>
#include <pslang/ast/struct.hpp>
#include <pslang/ast/expression_fwd.hpp>
#include <pslang/types/type_fwd.hpp>
namespace pslang::ast
{
@ -18,6 +19,7 @@ namespace pslang::ast
unary_operation_type type;
expression_ptr arg1;
ast::location location;
types::type_ptr inferred_type = nullptr;
};
struct binary_operation
@ -26,6 +28,7 @@ namespace pslang::ast
expression_ptr arg1;
expression_ptr arg2;
ast::location location;
types::type_ptr inferred_type = nullptr;
};
using expression_impl = std::variant<
@ -47,5 +50,6 @@ namespace pslang::ast
};
location get_location(expression const & expression);
types::type_ptr get_type(expression const & expression);
}

View file

@ -4,6 +4,7 @@
#include <pslang/ast/expression_fwd.hpp>
#include <pslang/ast/location.hpp>
#include <pslang/ast/type_fwd.hpp>
#include <pslang/types/type_fwd.hpp>
#include <string>
#include <vector>
@ -32,6 +33,7 @@ namespace pslang::ast
expression_ptr function;
std::vector<expression_ptr> arguments;
ast::location location;
types::type_ptr inferred_type = nullptr;
};
}

View file

@ -1,6 +1,7 @@
#pragma once
#include <pslang/ast/location.hpp>
#include <pslang/types/type_fwd.hpp>
#include <string>
@ -12,6 +13,7 @@ namespace pslang::ast
std::string name;
ast::location location;
std::size_t level = 0;
types::type_ptr inferred_type = nullptr;
};
}

View file

@ -6,5 +6,6 @@ namespace pslang::ast
{
void resolve_identifiers(statement_list_ptr & statements);
void check_and_infer_types(statement_list_ptr & statements);
}

View file

@ -3,6 +3,7 @@
#include <pslang/ast/expression_fwd.hpp>
#include <pslang/ast/type_fwd.hpp>
#include <pslang/ast/location.hpp>
#include <pslang/types/type_fwd.hpp>
#include <string>
#include <vector>
@ -29,6 +30,7 @@ namespace pslang::ast
expression_ptr object;
std::string field_name;
ast::location location;
types::type_ptr inferred_type = nullptr;
};
}

View file

@ -17,12 +17,14 @@ namespace pslang::ast
{
type_ptr element_type;
std::uint64_t size;
types::type_ptr inferred_type = nullptr;
};
struct function_type
{
std::vector<type_ptr> arguments;
type_ptr result;
types::type_ptr inferred_type = nullptr;
};
struct type_identifier
@ -30,6 +32,7 @@ namespace pslang::ast
std::string name;
ast::location location;
std::size_t level = 0;
types::type_ptr inferred_type = nullptr;
};
using type_impl = std::variant<

View file

@ -1,5 +1,7 @@
#pragma once
#include <pslang/types/type_fwd.hpp>
#include <memory>
namespace pslang::ast
@ -9,4 +11,6 @@ namespace pslang::ast
using type_ptr = std::shared_ptr<type>;
types::type_ptr get_type(type const & type);
}

View file

@ -1,5 +1,6 @@
#include <pslang/ast/expression.hpp>
#include <pslang/ast/expression_visitor.hpp>
#include <pslang/types/type.hpp>
namespace pslang::ast
{
@ -59,6 +60,58 @@ namespace pslang::ast
}
};
struct get_type_visitor
: const_expression_visitor<get_type_visitor>
{
using const_expression_visitor::apply;
template <typename T>
types::type_ptr apply(primitive_literal_base<T> const & node)
{
return std::make_unique<types::type>(types::primitive_type{types::primitive_type_base<T>{}});
}
types::type_ptr apply(identifier const & node)
{
return node.inferred_type;
}
types::type_ptr apply(unary_operation const & node)
{
return node.inferred_type;
}
types::type_ptr apply(binary_operation const & node)
{
return node.inferred_type;
}
types::type_ptr apply(cast_operation const & node)
{
return node.inferred_type;
}
types::type_ptr apply(function_call const & node)
{
return node.inferred_type;
}
types::type_ptr apply(array const & node)
{
return node.inferred_type;
}
types::type_ptr apply(array_access const & node)
{
return node.inferred_type;
}
types::type_ptr apply(field_access const & node)
{
return node.inferred_type;
}
};
}
location get_location(expression const & expression)
@ -66,4 +119,9 @@ namespace pslang::ast
return get_location_visitor{}.apply(expression);
}
types::type_ptr get_type(expression const & expression)
{
return get_type_visitor{}.apply(expression);
}
}

49
libs/ast/source/type.cpp Normal file
View file

@ -0,0 +1,49 @@
#include <pslang/ast/type.hpp>
#include <pslang/ast/type_visitor.hpp>
#include <pslang/types/type.hpp>
namespace pslang::ast
{
namespace
{
struct get_type_visitor
: const_type_visitor<get_type_visitor>
{
using const_type_visitor::apply;
types::type_ptr apply(types::unit_type const & type)
{
return std::make_unique<types::type>(type);
}
types::type_ptr apply(types::primitive_type const & type)
{
return std::make_unique<types::type>(type);
}
types::type_ptr apply(ast::array_type const & type)
{
return type.inferred_type;
}
types::type_ptr apply(ast::function_type const & type)
{
return type.inferred_type;
}
types::type_ptr apply(ast::type_identifier const & type)
{
return type.inferred_type;
}
};
}
types::type_ptr get_type(type const & type)
{
return get_type_visitor{}.apply(type);
}
}

View file

@ -0,0 +1,677 @@
#include <pslang/ast/preprocess.hpp>
#include <pslang/ast/type_visitor.hpp>
#include <pslang/ast/expression_visitor.hpp>
#include <pslang/ast/statement_visitor.hpp>
#include <pslang/ast/error.hpp>
#include <pslang/types/print.hpp>
#include <pslang/types/type.hpp>
#include <unordered_map>
#include <sstream>
namespace pslang::ast
{
namespace
{
struct variable_data
{
value_category category;
types::type_ptr type;
};
struct function_data
{
std::vector<types::type_ptr> arguments;
types::type_ptr result_type;
};
struct struct_data
{
struct field_data
{
std::string name;
types::type_ptr type;
};
std::vector<field_data> fields;
};
struct scope
{
std::unordered_map<std::string, variable_data> variables;
std::unordered_map<std::string, function_data> functions;
std::unordered_map<std::string, struct_data> structs;
bool is_function_scope = false;
bool is_global_scope = false;
types::type_ptr expected_return_type = nullptr;
};
struct check_visitor
: type_visitor<check_visitor>
, expression_visitor<check_visitor>
, const_statement_visitor<check_visitor>
{
std::vector<scope> scopes;
using type_visitor::apply;
using expression_visitor::apply;
using const_statement_visitor::apply;
void apply(types::unit_type const &)
{}
void apply(types::primitive_type const &)
{}
void apply(ast::array_type & node)
{
apply(*node.element_type);
auto type = types::array_type{};
type.element_type = get_type(*node.element_type);
type.size = node.size;
node.inferred_type = std::make_unique<types::type>(std::move(type));
}
void apply(ast::function_type & node)
{
for (auto const & argument : node.arguments)
apply(*argument);
apply(*node.result);
auto type = types::function_type{};
for (auto const & argument : node.arguments)
type.arguments.push_back(get_type(*argument));
type.result = get_type(*node.result);
node.inferred_type = std::make_unique<types::type>(std::move(type));
}
void apply(ast::type_identifier & node)
{
auto type = types::named_type{};
type.name = node.name;
type.level = node.level;
node.inferred_type = std::make_unique<types::type>(std::move(type));
}
void apply(literal &)
{}
void apply(identifier & node)
{
if (auto type = types::builtin_type(node.name))
{
node.inferred_type = type;
return;
}
auto & scope = scopes.at(node.level);
if (auto it = scope.variables.find(node.name); it != scope.variables.end())
{
node.inferred_type = it->second.type;
}
else if (auto it = scope.functions.find(node.name); it != scope.functions.end())
{
auto type = types::function_type{};
for (auto const & argument : it->second.arguments)
type.arguments.push_back(argument);
type.result = it->second.result_type;
node.inferred_type = std::make_unique<types::type>(std::move(type));
}
}
void apply(unary_operation & node)
{
apply(*node.arg1);
auto arg1_type = get_type(*node.arg1);
bool good = false;
switch (node.type)
{
case unary_operation_type::negation:
if (types::is_integer_type(*arg1_type) || types::is_floating_point_type(*arg1_type))
{
node.inferred_type = arg1_type;
return;
}
break;
case unary_operation_type::logical_not:
if (types::is_bool_type(*arg1_type) || types::is_integer_type(*arg1_type))
{
node.inferred_type = arg1_type;
return;
}
break;
}
std::ostringstream os;
os << "Cannot apply " << node.type << " to a value of type ";
types::print(os, *arg1_type);
throw type_error(os.str(), node.location);
}
void apply(binary_operation & node)
{
apply(*node.arg1);
apply(*node.arg2);
auto arg1_type = get_type(*node.arg1);
auto arg2_type = get_type(*node.arg2);
bool equal = types::equal(*arg1_type, *arg2_type);
switch (node.type)
{
case binary_operation_type::addition:
if (equal && types::is_numeric_type(*arg1_type))
{
node.inferred_type = arg1_type;
return;
}
break;
case binary_operation_type::subtraction:
if (equal && types::is_numeric_type(*arg1_type))
{
node.inferred_type = arg1_type;
return;
}
break;
case binary_operation_type::multiplication:
if (equal && types::is_numeric_type(*arg1_type))
{
node.inferred_type = arg1_type;
return;
}
break;
case binary_operation_type::division:
if (equal && types::is_numeric_type(*arg1_type))
{
node.inferred_type = arg1_type;
return;
}
break;
case binary_operation_type::remainder:
if (equal && types::is_integer_type(*arg1_type))
{
node.inferred_type = arg1_type;
return;
}
break;
case binary_operation_type::logical_and:
if (equal && (types::is_bool_type(*arg1_type) || types::is_integer_type(*arg1_type)))
{
node.inferred_type = arg1_type;
return;
}
break;
case binary_operation_type::logical_or:
if (equal && (types::is_bool_type(*arg1_type) || types::is_integer_type(*arg1_type)))
{
node.inferred_type = arg1_type;
return;
}
break;
case binary_operation_type::logical_xor:
if (equal && (types::is_bool_type(*arg1_type) || types::is_integer_type(*arg1_type)))
{
node.inferred_type = arg1_type;
return;
}
break;
case binary_operation_type::equals:
if (equal)
{
node.inferred_type = std::make_unique<types::type>(types::bool_type{});
return;
}
break;
case binary_operation_type::not_equals:
if (equal)
return;
break;
case binary_operation_type::less:
if (equal)
{
node.inferred_type = std::make_unique<types::type>(types::bool_type{});
return;
}
break;
case binary_operation_type::greater:
if (equal)
{
node.inferred_type = std::make_unique<types::type>(types::bool_type{});
return;
}
break;
case binary_operation_type::less_equals:
if (equal)
{
node.inferred_type = std::make_unique<types::type>(types::bool_type{});
return;
}
break;
case binary_operation_type::greater_equals:
if (equal)
{
node.inferred_type = std::make_unique<types::type>(types::bool_type{});
return;
}
break;
}
std::ostringstream os;
os << "Cannot apply " << node.type << " to values of types ";
types::print(os, *arg1_type);
os << " and ";
types::print(os, *arg2_type);
throw type_error(os.str(), node.location);
}
void apply(cast_operation & node)
{
apply(*node.expression);
apply(*node.type);
auto source_type = get_type(*node.expression);
auto target_type = get_type(*node.type);
node.inferred_type = target_type;
if (types::equal(*source_type, *target_type))
return;
if (std::get_if<types::primitive_type>(target_type.get()))
if (!types::is_bool_type(*source_type) && !types::is_bool_type(*target_type))
return;
std::ostringstream os;
os << "Cannot cast a value of type ";
types::print(os, *source_type);
os << " to type ";
types::print(os, *target_type);
throw type_error(os.str(), node.location);
}
void apply(function_call & node)
{
apply(*node.function);
for (auto const & argument : node.arguments)
apply(*argument);
std::string function_name;
if (auto identifier = std::get_if<ast::identifier>(node.function.get()))
{
if (types::type_ptr type = types::builtin_type(identifier->name))
{
if (node.arguments.empty())
{
node.inferred_type = type;
return;
}
std::ostringstream os;
os << "Cannot create built-in type ";
types::print(os, *type);
os << ": expected 0 arguments, but got " << node.arguments.size();
throw std::runtime_error(os.str());
}
auto & scope = scopes.at(identifier->level);
if (auto it = scope.structs.find(identifier->name); it != scope.structs.end())
{
if (!node.arguments.empty())
{
if (node.arguments.size() != it->second.fields.size())
{
std::ostringstream os;
os << "Cannot create struct " << identifier->name << ": expected " << it->second.fields.size() << " arguments, but got " << node.arguments.size();
throw type_error(os.str(), node.location);
}
for (std::size_t i = 0; i < node.arguments.size(); ++i)
{
auto arg_type = get_type(*node.arguments[i]);
if (!types::equal(*arg_type, *it->second.fields[i].type))
{
std::ostringstream os;
os << "Cannot create struct " << identifier->name << ": argument #" << i << " expected to have type ";
types::print(os, *it->second.fields[i].type);
os << " but got type ";
types::print(os, *arg_type);
throw type_error(os.str(), node.location);
}
}
}
types::named_type type;
type.name = identifier->name;
type.level = identifier->level;
node.inferred_type = std::make_unique<types::type>(std::move(type));
return;
}
function_name = identifier->name + " ";
}
auto function_type = get_type(*node.function);
auto ftype = std::get_if<types::function_type>(function_type.get());
if (!ftype)
{
std::ostringstream os;
os << "Cannot call a value of a non-function type ";
types::print(os, *function_type);
throw type_error(os.str(), node.location);
}
if (ftype->arguments.size() != node.arguments.size())
{
std::ostringstream os;
os << "Cannot call function " << function_name << " of type ";
types::print(os, *function_type);
os << ": expected " << ftype->arguments.size() << " arguments, but got " << node.arguments.size();
throw type_error(os.str(), node.location);
}
for (std::size_t i = 0; i < node.arguments.size(); ++i)
{
auto arg_type = get_type(*node.arguments[i]);
if (!types::equal(*arg_type, *ftype->arguments[i]))
{
std::ostringstream os;
os << "Cannot call function " << function_name << " of type ";
types::print(os, *function_type);
os << ": argument #" << i << " expected to have type ";
types::print(os, *ftype->arguments[i]);
os << " but got type ";
types::print(os, *arg_type);
throw type_error(os.str(), node.location);
}
}
node.inferred_type = ftype->result;
}
void apply(array & node)
{
if (node.elements.empty())
throw invalid_ast_error("Empty array", node.location);
for (auto const & element : node.elements)
apply(*element);
types::type_ptr element_type = nullptr;
for (std::size_t i = 0; i < node.elements.size(); ++i)
{
auto current_type = get_type(*node.elements[0]);
if (i == 0)
{
element_type = current_type;
}
else if (!types::equal(*element_type, *current_type))
{
std::ostringstream os;
os << "Failed to infer array type: element #0 has type ";
types::print(os, *element_type);
os << " but element #" << i << " has type ";
types::print(os, *current_type);
}
}
types::array_type type;
type.element_type = element_type;
type.size = node.elements.size();
node.inferred_type = std::make_unique<types::type>(std::move(type));
}
void apply(array_access & node)
{
apply(*node.array);
apply(*node.index);
auto array_type = get_type(*node.array);
auto index_type = get_type(*node.index);
if (!types::is_integer_type(*index_type))
{
std::ostringstream os;
os << "Expected an integer type as index, but got ";
types::print(os, *index_type);
throw type_error(os.str(), get_location(*node.index));
}
auto atype = std::get_if<types::array_type>(array_type.get());
if (!atype)
{
std::ostringstream os;
os << "Expected an array to index, but got ";
types::print(os, *array_type);
throw type_error(os.str(), get_location(*node.array));
}
node.inferred_type = atype->element_type;
}
void apply(field_access & node)
{
apply(*node.object);
auto object_type = get_type(*node.object);
auto named_type = std::get_if<types::named_type>(object_type.get());
if (!named_type)
{
std::ostringstream os;
os << "Expected a struct, but got ";
types::print(os, *object_type);
throw type_error(os.str(), get_location(*node.object));
}
auto const & struct_data = scopes.at(named_type->level).structs.at(named_type->name);
for (auto const & field : struct_data.fields)
{
if (field.name == node.field_name)
{
node.inferred_type = field.type;
return;
}
}
std::ostringstream os;
os << "Struct \"" << named_type->name << "\" has no field named \"" << node.field_name << "\"";
throw type_error(os.str(), node.location);
}
void apply(expression_ptr const & node)
{
apply(*node);
}
void apply(assignment const & node)
{
apply(node.lhs);
apply(node.rhs);
auto ltype = get_type(*node.lhs);
auto rtype = get_type(*node.rhs);
// TODO: check lvalue
if (!types::equal(*ltype, *rtype))
{
std::ostringstream os;
os << "Cannot assign a value of type ";
types::print(os, *rtype);
os << " to an expression of type ";
types::print(os, *ltype);
throw type_error(os.str(), node.location);
};
}
void apply(variable_declaration const & node)
{
apply(node.initializer);
auto actual_type = get_type(*node.initializer);
if (node.type)
{
apply(*node.type);
auto expected_type = get_type(*node.type);
if (!types::equal(*expected_type, *actual_type))
{
std::ostringstream os;
os << "Cannot initialize a variable of type ";
types::print(os, *expected_type);
os << " with an expression of type ";
types::print(os, *actual_type);
throw type_error(os.str(), node.location);
}
}
scopes.back().variables[node.name] = {
.category = node.category,
.type = actual_type,
};
}
void apply(if_block const & node)
{
throw invalid_ast_error("if blocks cannot be present in the final AST", node.location);
}
void apply(else_if_block const & node)
{
throw invalid_ast_error("else if blocks cannot be present in the final AST", node.location);
}
void apply(else_block const & node)
{
throw invalid_ast_error("else blocks cannot be present in the final AST", node.location);
}
void apply(if_chain const & node)
{
for (auto const & block : node.blocks)
{
if (block.condition)
{
apply(block.condition);
auto actual_type = get_type(*block.condition);
if (!types::is_bool_type(*actual_type))
{
std::ostringstream os;
os << "if condition expects a bool type, but got ";
types::print(os, *actual_type);
throw type_error(os.str(), get_location(*block.condition));
}
}
scopes.emplace_back();
apply(*block.statements);
scopes.pop_back();
}
}
void apply(while_block const & node)
{
apply(node.condition);
auto actual_type = get_type(*node.condition);
if (!types::is_bool_type(*actual_type))
{
std::ostringstream os;
os << "while condition expects a bool type, but got ";
types::print(os, *actual_type);
throw type_error(os.str(), get_location(*node.condition));
}
scopes.emplace_back();
apply(*node.statements);
scopes.pop_back();
}
void apply(function_definition const & node)
{
for (auto const & argument : node.arguments)
apply(*argument.type);
apply(*node.return_type);
auto & data = scopes.back().functions[node.name];
for (auto const & argument : node.arguments)
data.arguments.push_back(get_type(*argument.type));
data.result_type = get_type(*node.return_type);
scopes.emplace_back().is_function_scope = true;
scopes.back().expected_return_type = get_type(*node.return_type);
for (auto const & argument : node.arguments)
{
scopes.back().variables[argument.name] = {
.category = value_category::constant,
.type = get_type(*argument.type),
};
}
apply(*node.statements);
scopes.pop_back();
}
void apply(return_statement const & node)
{
types::type_ptr actual_type;
if (node.value)
{
apply(node.value);
actual_type = get_type(*node.value);
}
else
{
actual_type = std::make_unique<types::type>(types::unit_type{});
}
auto & return_scope = scopes.at(node.level);
if (!return_scope.expected_return_type)
throw invalid_ast_error("Unexpected return level", node.location);
if (!types::equal(*return_scope.expected_return_type, *actual_type))
{
std::ostringstream os;
os << "Returning value of type ";
types::print(os, *actual_type);
os << " from a function returning ";
types::print(os, *return_scope.expected_return_type);
throw type_error(os.str(), node.location);
}
}
void apply(field_definition const & node)
{
apply(*node.type);
}
void apply(struct_definition const & node)
{
for (auto const & field : node.fields)
apply(field);
auto & data = scopes.back().structs[node.name];
for (auto const & field : node.fields)
data.fields.push_back({.name = field.name, .type = get_type(*field.type)});
}
};
}
void check_and_infer_types(statement_list_ptr & statements)
{
check_visitor visitor;
visitor.scopes.emplace_back().is_global_scope = true;
visitor.apply(*statements);
}
}

View file

@ -14,4 +14,12 @@ namespace pslang::types
type_ptr builtin_type(std::string const & name);
bool is_unit_type(type const & type);
bool is_bool_type(type const & type);
bool is_integer_type(type const & type);
bool is_signed_integer_type(type const & type);
bool is_unsigned_integer_type(type const & type);
bool is_floating_point_type(type const & type);
bool is_numeric_type(type const & type);
}

View file

@ -38,4 +38,79 @@ namespace pslang::types
return nullptr;
}
bool is_unit_type(type const & type)
{
return equal(type, unit_type{});
}
bool is_bool_type(type const & type)
{
return equal(type, bool_type{});
}
bool is_integer_type(type const & type)
{
if (auto ptype = std::get_if<primitive_type>(&type))
{
return std::visit([]<typename T>(primitive_type_base<T> const &)
{
return std::is_integral_v<T> && !std::is_same_v<T, bool>;
}, *ptype);
}
return false;
}
bool is_signed_integer_type(type const & type)
{
if (auto ptype = std::get_if<primitive_type>(&type))
{
return std::visit([]<typename T>(primitive_type_base<T> const &)
{
return std::is_integral_v<T> && std::is_signed_v<T> && !std::is_same_v<T, bool>;
}, *ptype);
}
return false;
}
bool is_unsigned_integer_type(type const & type)
{
if (auto ptype = std::get_if<primitive_type>(&type))
{
return std::visit([]<typename T>(primitive_type_base<T> const &)
{
return std::is_integral_v<T> && std::is_unsigned_v<T> && !std::is_same_v<T, bool>;
}, *ptype);
}
return false;
}
bool is_floating_point_type(type const & type)
{
if (auto ptype = std::get_if<primitive_type>(&type))
{
return std::visit([]<typename T>(primitive_type_base<T> const &)
{
return std::is_floating_point_v<T>;
}, *ptype);
}
return false;
}
bool is_numeric_type(type const & type)
{
if (auto ptype = std::get_if<primitive_type>(&type))
{
return std::visit([]<typename T>(primitive_type_base<T> const &)
{
return !std::is_same_v<T, bool>;
}, *ptype);
}
return false;
}
}