diff --git a/examples/ir_test.psl b/examples/ir_test.psl index 0780ec9..1d4969d 100644 --- a/examples/ir_test.psl +++ b/examples/ir_test.psl @@ -1,8 +1,13 @@ -func print(c: u8): - foreign func putchar(c: i32) -> i32 - putchar(c as i32) +func alloc(size: u64) -> unit mut*: + foreign func malloc(size: u64) -> unit mut* + return malloc(size) -mut a = 10ub -let b = a -a = 11ub -print(b) +foreign func free(ptr: unit*) + +let array = alloc(400ul) as i32 mut* +*array = 10 +*(array + 1) = 20 +let q = array + 10 +let n = q - array +array[5] = 50 +free(array as unit*) diff --git a/libs/ast/include/pslang/ast/operation.hpp b/libs/ast/include/pslang/ast/operation.hpp index 4e6bee2..abc8f76 100644 --- a/libs/ast/include/pslang/ast/operation.hpp +++ b/libs/ast/include/pslang/ast/operation.hpp @@ -7,6 +7,9 @@ namespace pslang::ast { negation, logical_not, + address_of, + mutable_address_of, + dereference, }; enum class binary_operation_type @@ -40,6 +43,15 @@ namespace pslang::ast case unary_operation_type::logical_not: out << "not"; break; + case unary_operation_type::address_of: + out << "address of"; + break; + case unary_operation_type::mutable_address_of: + out << "mutable address of"; + break; + case unary_operation_type::dereference: + out << "dereference"; + break; } return out; } diff --git a/libs/ast/include/pslang/ast/type.hpp b/libs/ast/include/pslang/ast/type.hpp index a09859f..789db70 100644 --- a/libs/ast/include/pslang/ast/type.hpp +++ b/libs/ast/include/pslang/ast/type.hpp @@ -29,6 +29,13 @@ namespace pslang::ast types::type_ptr inferred_type = nullptr; }; + struct pointer_type + { + type_ptr referenced_type; + bool is_mutable; + types::type_ptr inferred_type = nullptr; + }; + struct type_identifier { std::string name; @@ -42,6 +49,7 @@ namespace pslang::ast types::primitive_type, array_type, function_type, + pointer_type, type_identifier >; diff --git a/libs/ast/source/print.cpp b/libs/ast/source/print.cpp index bae351d..cd60b6b 100644 --- a/libs/ast/source/print.cpp +++ b/libs/ast/source/print.cpp @@ -130,6 +130,14 @@ namespace pslang::ast { out << type.node->name; } + + void apply(types::pointer_type const & type) + { + apply(*type.referenced_type); + if (type.is_mutable) + out << " mut"; + out << "*"; + } }; struct type_print_visitor @@ -178,6 +186,14 @@ namespace pslang::ast apply(*type.result); } + void apply(pointer_type const & type) + { + apply(*type.referenced_type); + if (type.is_mutable) + out << " mut"; + out << "*"; + } + void apply(type_identifier const & type) { out << type.name; diff --git a/libs/ast/source/resolve_identifiers.cpp b/libs/ast/source/resolve_identifiers.cpp index a738b6b..879d010 100644 --- a/libs/ast/source/resolve_identifiers.cpp +++ b/libs/ast/source/resolve_identifiers.cpp @@ -138,6 +138,11 @@ namespace pslang::ast apply(*function_type.result); } + void apply(pointer_type const & pointer_type) + { + apply(*pointer_type.referenced_type); + } + void apply(type_identifier & identifier) { for (auto it = scopes.rbegin(); it != scopes.rend(); ++it) diff --git a/libs/ast/source/type.cpp b/libs/ast/source/type.cpp index 2107fca..b25fe1c 100644 --- a/libs/ast/source/type.cpp +++ b/libs/ast/source/type.cpp @@ -33,6 +33,11 @@ namespace pslang::ast return type.inferred_type; } + types::type_ptr apply(ast::pointer_type const & type) + { + return type.inferred_type; + } + types::type_ptr apply(ast::type_identifier const & type) { return type.inferred_type; diff --git a/libs/ast/source/type_check.cpp b/libs/ast/source/type_check.cpp index bbb5d40..8d876f8 100644 --- a/libs/ast/source/type_check.cpp +++ b/libs/ast/source/type_check.cpp @@ -64,6 +64,11 @@ namespace pslang::ast return {.size = 8, .alignment = 8}; } + size_and_alignment apply(types::pointer_type const & type) + { + return {.size = 8, .alignment = 8}; + } + size_and_alignment apply(types::struct_type const & type) { if (auto jt = lcontext.structs.find(type.node); jt != lcontext.structs.end()) @@ -136,6 +141,12 @@ namespace pslang::ast node.inferred_type = std::make_unique(std::move(type)); } + void apply(ast::pointer_type & node) + { + apply(*node.referenced_type); + node.inferred_type = std::make_unique(types::pointer_type{get_type(*node.referenced_type), node.is_mutable}); + } + void apply(ast::type_identifier & node) { node.inferred_type = std::make_unique(types::struct_type{node.node}); @@ -272,6 +283,46 @@ namespace pslang::ast return; } break; + case unary_operation_type::address_of: + if (auto lvalue = classify_lvalue(node.arg1)) + { + switch (*lvalue) + { + case ast::value_category::compile_time: + throw type_error("Cannot take address of a compile-time value", node.location); + case ast::value_category::constant: + case ast::value_category::_mutable: + node.inferred_type = std::make_unique(types::pointer_type{arg1_type, false}); + } + return; + } + else + throw type_error("Cannot take address of a non-lvalue", node.location); + break; + case unary_operation_type::mutable_address_of: + if (auto lvalue = classify_lvalue(node.arg1)) + { + switch (*lvalue) + { + case ast::value_category::compile_time: + throw type_error("Cannot take address of a compile-time value", node.location); + case ast::value_category::constant: + throw type_error("Cannot take mutable address of an immutable value", node.location); + case ast::value_category::_mutable: + node.inferred_type = std::make_unique(types::pointer_type{arg1_type, true}); + } + return; + } + else + throw type_error("Cannot take address of a non-lvalue", node.location); + break; + case unary_operation_type::dereference: + if (auto pointer_type = std::get_if(arg1_type.get())) + { + node.inferred_type = pointer_type->referenced_type; + return; + } + break; } std::ostringstream os; @@ -298,6 +349,16 @@ namespace pslang::ast node.inferred_type = arg1_type; return; } + if (types::is_pointer_type(*arg1_type) && types::is_integer_type(*arg2_type)) + { + node.inferred_type = arg1_type; + return; + } + if (types::is_integer_type(*arg1_type) && types::is_pointer_type(*arg2_type)) + { + node.inferred_type = arg1_type; + return; + } break; case binary_operation_type::subtraction: if (equal && types::is_numeric_type(*arg1_type)) @@ -305,6 +366,16 @@ namespace pslang::ast node.inferred_type = arg1_type; return; } + if (types::is_pointer_type(*arg1_type) && types::is_integer_type(*arg2_type)) + { + node.inferred_type = arg1_type; + return; + } + if (types::is_pointer_type(*arg1_type) && types::is_pointer_type(*arg2_type)) + { + node.inferred_type = std::make_unique(types::primitive_type{types::i64_type{}}); + return; + } break; case binary_operation_type::multiplication: if (equal && types::is_numeric_type(*arg1_type)) @@ -431,6 +502,16 @@ namespace pslang::ast if (!types::is_bool_type(*source_type) && !types::is_bool_type(*target_type)) return; + if (auto source_pointer_type = std::get_if(source_type.get())) + { + if (auto target_pointer_type = std::get_if(target_type.get())) + { + if (!source_pointer_type->is_mutable && target_pointer_type->is_mutable) + throw type_error("Cannot cast an immutable pointer to a mutable pointer", node.location); + return; + } + } + std::ostringstream os; os << "Cannot cast a value of type "; print(os, *source_type); @@ -592,17 +673,22 @@ namespace pslang::ast throw type_error(os.str(), get_location(*node.index)); } - auto atype = std::get_if(array_type.get()); - - if (!atype) + if (auto atype = std::get_if(array_type.get())) { - std::ostringstream os; - os << "Expected an array to index, but got "; - print(os, *array_type); - throw type_error(os.str(), get_location(*node.array)); + node.inferred_type = atype->element_type; + return; } - node.inferred_type = atype->element_type; + if (auto ptype = std::get_if(array_type.get())) + { + node.inferred_type = ptype->referenced_type; + return; + } + + std::ostringstream os; + os << "Expected an array or a pointer, but got "; + print(os, *array_type); + throw type_error(os.str(), get_location(*node.array)); } void apply(field_access & node) @@ -785,7 +871,30 @@ namespace pslang::ast } else if (auto array_access = std::get_if(node.get())) { - return classify_lvalue(array_access->array); + auto array_type = get_type(*array_access->array); + if (std::get_if(array_type.get())) + return classify_lvalue(array_access->array); + if (auto pointer_type = std::get_if(array_type.get())) + { + if (pointer_type->is_mutable) + return ast::value_category::_mutable; + else + return ast::value_category::constant; + } + } + else if (auto unary_operation = std::get_if(node.get())) + { + if (unary_operation->type == ast::unary_operation_type::dereference) + { + auto arg_type = get_type(*unary_operation->arg1); + if (auto pointer_type = std::get_if(arg_type.get())) + { + if (pointer_type->is_mutable) + return ast::value_category::_mutable; + else + return ast::value_category::constant; + } + } } return std::nullopt; diff --git a/libs/interpreter/source/context.cpp b/libs/interpreter/source/context.cpp index 09b5123..f34a184 100644 --- a/libs/interpreter/source/context.cpp +++ b/libs/interpreter/source/context.cpp @@ -39,6 +39,11 @@ namespace pslang::interpreter throw std::runtime_error("Cannot zero-initialize a function type"); } + value apply(types::pointer_type const &) + { + throw std::runtime_error("Not implemented"); + } + value apply(types::struct_type const & struct_type) { struct_value result{.struct_type = std::make_unique(struct_type)}; diff --git a/libs/interpreter/source/eval.cpp b/libs/interpreter/source/eval.cpp index 25bc77a..95e5ef1 100644 --- a/libs/interpreter/source/eval.cpp +++ b/libs/interpreter/source/eval.cpp @@ -133,6 +133,10 @@ namespace pslang::interpreter return primitive_value(primitive_value_base{static_cast(~arg1.value)}); } break; + case ast::unary_operation_type::address_of: + case ast::unary_operation_type::mutable_address_of: + case ast::unary_operation_type::dereference: + throw std::runtime_error("Not implemented"); } std::ostringstream os; diff --git a/libs/ir/source/print.cpp b/libs/ir/source/print.cpp index b3f025f..58db033 100644 --- a/libs/ir/source/print.cpp +++ b/libs/ir/source/print.cpp @@ -22,6 +22,15 @@ namespace pslang::ir case ast::unary_operation_type::logical_not: out << "not"; break; + case ast::unary_operation_type::address_of: + out << "address"; + break; + case ast::unary_operation_type::mutable_address_of: + out << "address"; + break; + case ast::unary_operation_type::dereference: + out << "deref"; + break; } } diff --git a/libs/jit/source/arch/aarch64/compiler.cpp b/libs/jit/source/arch/aarch64/compiler.cpp index e9b83dd..4b273af 100644 --- a/libs/jit/source/arch/aarch64/compiler.cpp +++ b/libs/jit/source/arch/aarch64/compiler.cpp @@ -586,6 +586,10 @@ namespace pslang::jit::aarch64 if (types::is_integer_type(*node.inferred_type)) extend(0, node.inferred_type); break; + case ast::unary_operation_type::address_of: + case ast::unary_operation_type::mutable_address_of: + case ast::unary_operation_type::dereference: + throw std::runtime_error("Not implemented"); } } diff --git a/libs/parser/rules/pslang.y b/libs/parser/rules/pslang.y index b19b81f..ff9c914 100644 --- a/libs/parser/rules/pslang.y +++ b/libs/parser/rules/pslang.y @@ -160,6 +160,7 @@ template %precedence UMINUS %left asterisk slash percent %precedence NOT +%precedence ADDRESSOF DEREFERENCE %precedence lbracket %type indented_statement_list @@ -268,6 +269,8 @@ type_expression | type_expression arrow type_expression { std::vector args; args.push_back(std::make_unique($1)); $$ = ast::function_type{std::move(args), std::make_unique($3)}; } | lparen function_paren_type_list rparen arrow type_expression { $$ = ast::function_type{$2, std::make_unique($5)}; } | lparen type_expression rparen { $$ = $2; } +| type_expression asterisk { $$ = ast::pointer_type{std::make_unique($1), false}; } +| type_expression mut asterisk { $$ = ast::pointer_type{std::make_unique($1), true}; } ; primitive_type @@ -315,6 +318,9 @@ expression | expression slash expression { $$ = ast::binary_operation{ast::binary_operation_type::division, std::make_unique($1), std::make_unique($3), @$ }; } | expression percent expression { $$ = ast::binary_operation{ast::binary_operation_type::remainder, std::make_unique($1), std::make_unique($3), @$ }; } | exclamation expression %prec NOT { $$ = ast::unary_operation{ast::unary_operation_type::logical_not, std::make_unique($2), @$ }; } +| ampersand expression %prec ADDRESSOF { $$ = ast::unary_operation{ast::unary_operation_type::address_of, std::make_unique($2), @$ }; } +| ampersand mut expression %prec ADDRESSOF { $$ = ast::unary_operation{ast::unary_operation_type::mutable_address_of, std::make_unique($3), @$ }; } +| asterisk expression %prec DEREFERENCE { $$ = ast::unary_operation{ast::unary_operation_type::dereference, std::make_unique($2), @$ }; } | postfix_expression { $$ = $1; } ; diff --git a/libs/types/include/pslang/types/pointer.hpp b/libs/types/include/pslang/types/pointer.hpp new file mode 100644 index 0000000..c5210d9 --- /dev/null +++ b/libs/types/include/pslang/types/pointer.hpp @@ -0,0 +1,19 @@ +#pragma once + +#include + +namespace pslang::types +{ + + struct pointer_type + { + type_ptr referenced_type; + bool is_mutable; + + friend bool operator == (pointer_type const & t1, pointer_type const & t2) + { + return equal(*t1.referenced_type, *t2.referenced_type); + } + }; + +} diff --git a/libs/types/include/pslang/types/type.hpp b/libs/types/include/pslang/types/type.hpp index 02b9c0f..9d6edca 100644 --- a/libs/types/include/pslang/types/type.hpp +++ b/libs/types/include/pslang/types/type.hpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -17,7 +18,8 @@ namespace pslang::types primitive_type, array_type, function_type, - struct_type + struct_type, + pointer_type >; struct type diff --git a/libs/types/include/pslang/types/type_fwd.hpp b/libs/types/include/pslang/types/type_fwd.hpp index 2a1a519..06c8c3b 100644 --- a/libs/types/include/pslang/types/type_fwd.hpp +++ b/libs/types/include/pslang/types/type_fwd.hpp @@ -23,6 +23,7 @@ namespace pslang::types bool is_numeric_type(type const & type); bool is_builtin_type(type const & type); bool is_function_type(type const & type); + bool is_pointer_type(type const & type); std::size_t type_size(type const & type); diff --git a/libs/types/source/type.cpp b/libs/types/source/type.cpp index f7b343b..4b020c8 100644 --- a/libs/types/source/type.cpp +++ b/libs/types/source/type.cpp @@ -130,6 +130,13 @@ namespace pslang::types return false; } + bool is_pointer_type(type const & type) + { + if (std::get_if(&type)) + return true; + return false; + } + std::size_t type_size(type const & type) { if (std::get_if(&type))