Add foreign functions (stub in interpreter, implemented in aarch64 compiler)

This commit is contained in:
Nikita Lisitsa 2026-03-13 11:01:53 +03:00
parent cb433d87bf
commit c8cd86dd0b
24 changed files with 269 additions and 27 deletions

View file

@ -7,6 +7,7 @@
#include <pslang/ast/print.hpp> #include <pslang/ast/print.hpp>
#include <pslang/jit/jit.hpp> #include <pslang/jit/jit.hpp>
#include <pslang/jit/executable.hpp> #include <pslang/jit/executable.hpp>
#include <pslang/jit/foreign.hpp>
#include <filesystem> #include <filesystem>
#include <iostream> #include <iostream>
@ -74,6 +75,7 @@ int main(int argc, char ** argv)
using namespace pslang; using namespace pslang;
auto context = interpreter::empty_context(); auto context = interpreter::empty_context();
context.foreign_resolver = &jit::load_foreign;
bool dump = false; bool dump = false;
bool dump_ast = false; bool dump_ast = false;
@ -169,13 +171,19 @@ int main(int argc, char ** argv)
for (auto const & ast : parsed) for (auto const & ast : parsed)
jit::compile(pcontext, ast); jit::compile(pcontext, ast);
for (auto const & resolve : pcontext.foreign_resolve)
{
auto fptr = jit::load_foreign(resolve.name);
std::copy_n((std::uint8_t const *)(&fptr), 8, pcontext.code.data() + resolve.offset);
}
auto executable = jit::make_host_executable(pcontext.code); auto executable = jit::make_host_executable(pcontext.code);
{ {
// TODO: remove, testing-only code; should execute entry point instead // TODO: remove, testing-only code; should execute entry point instead
auto offset = pcontext.symbols.at("test"); auto offset = pcontext.symbols.at("test");
auto fptr = (std::uint32_t(*)())(executable.data.get() + offset); auto fptr = (double(*)(double))(executable.data.get() + offset);
auto x = fptr(); auto x = fptr(3.14159265358979323846 / 6.0);
std::cout << "Result: " << std::boolalpha << x << std::endl; std::cout << "Result: " << std::boolalpha << x << std::endl;
} }
} }

7
examples/foreign.psl Normal file
View file

@ -0,0 +1,7 @@
foreign func sin(x: f64) -> f64
foreign func cos(x: f64) -> f64
func test(x: f64) -> f64:
let s = sin(x)
let c = cos(x)
return s * s + c * c

View file

@ -33,6 +33,10 @@ namespace pslang::ast
statement_list_ptr statements; statement_list_ptr statements;
}; };
// AST node for a `foreign func` declaration. It reuses every field of
// function_declaration and adds no members of its own.
struct foreign_function_declaration
: function_declaration
{};
struct return_statement struct return_statement
{ {
// can be null, which means "return unit" // can be null, which means "return unit"

View file

@ -40,6 +40,7 @@ namespace pslang::ast
if_chain, if_chain,
while_block, while_block,
function_definition, function_definition,
foreign_function_declaration,
return_statement, return_statement,
field_definition, field_definition,
struct_definition struct_definition

View file

@ -352,6 +352,23 @@ namespace pslang::ast
--options.indent_level; --options.indent_level;
} }
// Prints a foreign function declaration: one header line with the name and
// return type, followed by one line per argument at a deeper indent level.
void apply(foreign_function_declaration const & node)
{
    put_indent(out, options);
    out << "foreign function { name = \"" << node.name << "\", return type = ";
    print(out, *node.return_type);
    out << " }\n";

    // Arguments are nested one level below the header line
    options.indent_level += 1;
    for (auto const & argument : node.arguments)
    {
        put_indent(out, options);
        out << "argument { name = \"" << argument.name << "\", type = ";
        print(out, *argument.type);
        out << " }\n";
    }
    options.indent_level -= 1;
}
void apply(return_statement const & node) void apply(return_statement const & node)
{ {
put_indent(out, options); put_indent(out, options);

View file

@ -95,6 +95,14 @@ namespace pslang::ast
scopes.back().functions.insert(function_definition.name); scopes.back().functions.insert(function_definition.name);
} }
// Registers the name of a foreign function in the innermost scope.
// Throws parse_error if the identifier is already defined in that scope.
// The node is taken by const reference (not by value) so that visiting it
// does not copy the whole declaration.
void apply(foreign_function_declaration const & foreign_function_declaration)
{
    auto & scope = scopes.back();
    if (scope.contains(foreign_function_declaration.name))
        throw parse_error("Identifier \"" + foreign_function_declaration.name + "\" is already defined at this scope", foreign_function_declaration.location);

    scope.functions.insert(foreign_function_declaration.name);
}
void apply(return_statement const &) void apply(return_statement const &)
{} {}
@ -322,6 +330,22 @@ namespace pslang::ast
scopes.pop_back(); scopes.pop_back();
} }
// Checks the signature of a foreign function: argument names must be unique,
// and every argument type and the return type are visited in turn.
void apply(foreign_function_declaration const & foreign_function_declaration)
{
    // The function name itself was already added to scope by populate_globals_visitor

    std::unordered_set<std::string> seen_arguments;
    for (auto const & argument : foreign_function_declaration.arguments)
    {
        // insert() reports whether the name was new, so one lookup is enough
        bool const is_new = seen_arguments.insert(argument.name).second;
        if (!is_new)
            throw parse_error("Duplicate argument name \"" + argument.name + "\" in function \"" + foreign_function_declaration.name + "\"", argument.location);

        apply(*argument.type);
    }

    apply(*foreign_function_declaration.return_type);
}
void apply(return_statement const & return_statement) void apply(return_statement const & return_statement)
{ {
if (return_statement.value) if (return_statement.value)

View file

@ -57,6 +57,11 @@ namespace pslang::ast
return node.location; return node.location;
} }
// Returns the source location stored in a foreign function declaration node.
location apply(foreign_function_declaration const & node)
{
return node.location;
}
location apply(return_statement const & node) location apply(return_statement const & node)
{ {
return node.location; return node.location;

View file

@ -159,6 +159,18 @@ namespace pslang::ast
scopes.pop_back(); scopes.pop_back();
} }
// Records the signature of a foreign function in the innermost scope.
void apply(foreign_function_declaration const & node)
{
    // First resolve every type mentioned in the signature...
    for (auto const & parameter : node.arguments)
        resolve_types(scopes, *parameter.type);
    resolve_types(scopes, *node.return_type);

    // ...and only then create/update the scope entry for this function
    auto & signature = scopes.back().functions[node.name];
    for (auto const & parameter : node.arguments)
        signature.arguments.push_back(get_type(*parameter.type));
    signature.result_type = get_type(*node.return_type);
}
void apply(return_statement const &) void apply(return_statement const &)
{} {}
@ -731,6 +743,9 @@ namespace pslang::ast
scopes.pop_back(); scopes.pop_back();
} }
// Intentionally empty: this visitor does nothing for a foreign function
// declaration (it has no statements to visit).
void apply(foreign_function_declaration const &)
{}
void apply(return_statement const & node) void apply(return_statement const & node)
{ {
types::type_ptr actual_type; types::type_ptr actual_type;

View file

@ -9,6 +9,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <iostream> #include <iostream>
#include <functional>
namespace pslang::interpreter namespace pslang::interpreter
{ {
@ -44,9 +45,12 @@ namespace pslang::interpreter
value return_value = unit_value{}; value return_value = unit_value{};
}; };
using foreign_resolver = std::function<void*(std::string const &)>;
struct context struct context
{ {
bool trace = false; bool trace = false;
foreign_resolver foreign_resolver;
std::vector<frame> frame_stack; std::vector<frame> frame_stack;
}; };

View file

@ -0,0 +1,11 @@
#pragma once
#include <pslang/interpreter/value.hpp>
#include <pslang/interpreter/context.hpp>
namespace pslang::interpreter
{
// Interpreter entry point for calling a foreign function: receives the
// function's address in @pointer, its declared @return_type and the already
// evaluated @args. NOTE(review): the interpreter implementation visible in
// this change is a stub that always throws.
value exec_foreign(context & context, void * pointer, types::type const & return_type, std::vector<value> args);
}

View file

@ -72,7 +72,7 @@ namespace pslang::interpreter
std::unordered_map<std::string, value_ptr> fields; std::unordered_map<std::string, value_ptr> fields;
}; };
struct function_value struct function_common
{ {
struct argument struct argument
{ {
@ -82,15 +82,27 @@ namespace pslang::interpreter
std::vector<argument> arguments; std::vector<argument> arguments;
types::type_ptr return_type; types::type_ptr return_type;
};
struct function_value
: function_common
{
ast::statement_list_ptr statements; ast::statement_list_ptr statements;
}; };
// Interpreter value of a `foreign func`: the shared signature data from
// function_common plus the address of the function.
struct foreign_function_value
    : function_common
{
    // Address obtained from the context's foreign resolver.
    // Defaults to null so a default-constructed value never holds garbage.
    void * pointer = nullptr;
};
using value_impl = std::variant< using value_impl = std::variant<
unit_value, unit_value,
primitive_value, primitive_value,
array_value, array_value,
struct_value, struct_value,
function_value function_value,
foreign_function_value
>; >;
struct value struct value

View file

@ -2,6 +2,7 @@
#include <pslang/interpreter/exec.hpp> #include <pslang/interpreter/exec.hpp>
#include <pslang/interpreter/value.hpp> #include <pslang/interpreter/value.hpp>
#include <pslang/interpreter/error.hpp> #include <pslang/interpreter/error.hpp>
#include <pslang/interpreter/foreign.hpp>
#include <pslang/ast/expression.hpp> #include <pslang/ast/expression.hpp>
#include <pslang/types/print.hpp> #include <pslang/types/print.hpp>
@ -498,6 +499,11 @@ namespace pslang::interpreter
throw internal_error("Cannot cast function type to anything", location); throw internal_error("Cannot cast function type to anything", location);
} }
// Cast overload for foreign function values: casting is not supported, so
// this unconditionally throws internal_error at the given location.
value cast_impl(foreign_function_value const &, types::type const &, ast::location const & location)
{
throw internal_error("Cannot cast function type to anything", location);
}
value eval_impl(context & context, ast::cast_operation const & cast_operation) value eval_impl(context & context, ast::cast_operation const & cast_operation)
{ {
auto arg = eval(context, cast_operation.expression); auto arg = eval(context, cast_operation.expression);
@ -511,12 +517,14 @@ namespace pslang::interpreter
{ {
auto lvalue = eval(context, function_call.function); auto lvalue = eval(context, function_call.function);
auto fvalue = std::get_if<function_value>(&lvalue); auto fvalue = std::get_if<function_value>(&lvalue);
if (fvalue) auto ffvalue = std::get_if<foreign_function_value>(&lvalue);
auto fcommon = fvalue ? static_cast<function_common *>(fvalue) : ffvalue;
if (fcommon)
{ {
if (fvalue->arguments.size() != function_call.arguments.size()) if (fcommon->arguments.size() != function_call.arguments.size())
{ {
std::ostringstream os; std::ostringstream os;
os << "Cannot call function: expected " << fvalue->arguments.size() << " arguments, got " << function_call.arguments.size(); os << "Cannot call function: expected " << fcommon->arguments.size() << " arguments, got " << function_call.arguments.size();
throw internal_error(os.str(), function_call.location); throw internal_error(os.str(), function_call.location);
} }
@ -527,29 +535,36 @@ namespace pslang::interpreter
for (std::size_t i = 0; i < args.size(); ++i) for (std::size_t i = 0; i < args.size(); ++i)
{ {
auto actual_type = type_of(args[i]); auto actual_type = type_of(args[i]);
if (!types::equal(actual_type, *fvalue->arguments[i].type)) if (!types::equal(actual_type, *fcommon->arguments[i].type))
{ {
std::ostringstream os; std::ostringstream os;
os << "Cannot call function: argument #" << (i + 1) << " expects type "; os << "Cannot call function: argument #" << (i + 1) << " expects type ";
types::print(os, *fvalue->arguments[i].type); types::print(os, *fcommon->arguments[i].type);
os << " but actual type is "; os << " but actual type is ";
types::print(os, actual_type); types::print(os, actual_type);
throw internal_error(os.str(), ast::get_location(*function_call.arguments[i])); throw internal_error(os.str(), ast::get_location(*function_call.arguments[i]));
} }
} }
auto & function_scope = context.frame_stack.emplace_back(); if (fvalue)
{
auto & function_scope = context.frame_stack.emplace_back();
for (std::size_t i = 0; i < args.size(); ++i) for (std::size_t i = 0; i < args.size(); ++i)
function_scope.variables[fvalue->arguments[i].name] = {.category = ast::value_category::constant, .value = std::move(args[i])}; function_scope.variables[fvalue->arguments[i].name] = {.category = ast::value_category::constant, .value = std::move(args[i])};
function_scope.expected_return_type = fvalue->return_type; function_scope.expected_return_type = fvalue->return_type;
exec(context, fvalue->statements); exec(context, fvalue->statements);
auto result = std::move(context.frame_stack.back().return_value); auto result = std::move(context.frame_stack.back().return_value);
context.frame_stack.pop_back(); context.frame_stack.pop_back();
return result; return result;
}
else if (ffvalue)
{
return exec_foreign(context, ffvalue->pointer, *fcommon->return_type, std::move(args));
}
} }
std::ostringstream os; std::ostringstream os;

View file

@ -170,6 +170,23 @@ namespace pslang::interpreter
frame.variables[function_definition.name] = {.category = ast::value_category::constant, .value = std::move(value)}; frame.variables[function_definition.name] = {.category = ast::value_category::constant, .value = std::move(value)};
} }
// Executes a foreign function declaration: builds a foreign_function_value
// (argument names/types, return type, resolved address) and binds it as a
// constant in the current frame.
//
// Throws std::runtime_error if the name is already defined in the current
// frame, if the context has no foreign resolver, or if the resolver cannot
// find the symbol (returns null).
void exec_impl(context & context, ast::foreign_function_declaration const & foreign_function_declaration)
{
    auto const & name = foreign_function_declaration.name;

    auto & frame = context.frame_stack.back();
    if (frame.contains(name))
        throw std::runtime_error("Identifier \"" + name + "\" is already defined in this scope");

    foreign_function_value value;
    for (auto const & argument : foreign_function_declaration.arguments)
        value.arguments.push_back({.name = argument.name, .type = get_type(*argument.type)});
    value.return_type = get_type(*foreign_function_declaration.return_type);

    // Calling an empty std::function would throw std::bad_function_call,
    // which says nothing about which declaration caused it
    if (!context.foreign_resolver)
        throw std::runtime_error("Cannot resolve foreign function \"" + name + "\": no foreign resolver is set");

    // A null address must not be stored as a callable value
    value.pointer = context.foreign_resolver(name);
    if (value.pointer == nullptr)
        throw std::runtime_error("Cannot resolve foreign function \"" + name + "\": symbol not found");

    frame.variables[name] = {.category = ast::value_category::constant, .value = std::move(value)};
}
void exec_impl(context & context, ast::return_statement const & return_statement) void exec_impl(context & context, ast::return_statement const & return_statement)
{ {
// NB: cannot use return_statement.level here because lexical scope stack // NB: cannot use return_statement.level here because lexical scope stack

View file

@ -0,0 +1,12 @@
#include <pslang/interpreter/foreign.hpp>
#include <pslang/interpreter/error.hpp>
namespace pslang::interpreter
{
// Interpreter-side foreign call. Not implemented yet: none of the parameters
// are used and the function always throws std::runtime_error.
value exec_foreign(context & context, void * pointer, types::type const & return_type, std::vector<value> args)
{
throw std::runtime_error("Not implemented");
}
}

View file

@ -41,6 +41,15 @@ namespace pslang::interpreter
result.result = value.return_type; result.result = value.return_type;
return result; return result;
} }
// Type of a foreign function value: a function type made of the declared
// argument types and the declared return type.
types::type operator()(foreign_function_value const & value)
{
    types::function_type signature;
    signature.result = value.return_type;
    for (auto const & parameter : value.arguments)
        signature.arguments.push_back(parameter.type);
    return signature;
}
}; };
struct print_visitor struct print_visitor
@ -112,6 +121,11 @@ namespace pslang::interpreter
{ {
out << "func"; out << "func";
} }
// Foreign function values are printed as a fixed tag; the value itself is
// not inspected.
void operator()(foreign_function_value const &)
{
    out << "foreign func";
}
}; };
} }

View file

@ -42,6 +42,10 @@ namespace pslang::jit::aarch64
// Store the address plus a signed 9-bit offset in @reg_addr // Store the address plus a signed 9-bit offset in @reg_addr
void ldr_post(std::uint8_t reg_dst, std::uint8_t reg_addr, std::int16_t offset); void ldr_post(std::uint8_t reg_dst, std::uint8_t reg_addr, std::int16_t offset);
// Load the value at address specified by the value of the program counter (PC)
// plus a signed 19-bit @offset multiplied by 4, and store it into register @reg_dst
void ldr_pc(std::uint8_t reg_dst, std::int32_t offset);
// Add a 12-bit @value to the register @reg_src and store the result in @reg_dst // Add a 12-bit @value to the register @reg_src and store the result in @reg_dst
void add_imm(std::uint8_t reg_src, std::uint8_t reg_dst, std::uint16_t value); void add_imm(std::uint8_t reg_src, std::uint8_t reg_dst, std::uint16_t value);

View file

@ -0,0 +1,10 @@
#pragma once
#include <string>
namespace pslang::jit
{
// Returns the address of the foreign symbol called @name.
// NOTE(review): the implementation visible in this change forwards to
// dlsym(RTLD_DEFAULT, ...), so the result is presumably null when the symbol
// cannot be found - callers should check before using the address.
void * load_foreign(std::string const & name);
}

View file

@ -12,9 +12,18 @@ namespace pslang::jit
struct program_context struct program_context
{ {
// Describes one foreign symbol whose address still has to be written into
// the code blob.
struct foreign_resolve_info
{
// Name of the foreign symbol to look up
std::string name;
// Offset in bytes to the place in code blob
// containing the 64-bit address of the foreign symbol
std::int32_t offset;
};
jit::abi abi; jit::abi abi;
std::vector<std::uint8_t> code = {}; std::vector<std::uint8_t> code = {};
std::unordered_map<std::string, std::int32_t> symbols = {}; std::unordered_map<std::string, std::int32_t> symbols = {};
std::vector<foreign_resolve_info> foreign_resolve = {};
}; };
} }

View file

@ -10,8 +10,6 @@
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include <iostream>
namespace pslang::jit::aarch64 namespace pslang::jit::aarch64
{ {
@ -32,6 +30,8 @@ namespace pslang::jit::aarch64
}; };
std::vector<resolve_info> resolve; std::vector<resolve_info> resolve;
std::unordered_map<std::string, std::int32_t> foreign_address;
}; };
std::uint8_t fp_mode_for(types::type const & type) std::uint8_t fp_mode_for(types::type const & type)
@ -186,6 +186,15 @@ namespace pslang::jit::aarch64
apply(*node.statements); apply(*node.statements);
} }
// Gives each foreign function (once per name) a pointer-sized, null-filled
// slot in the code blob and remembers the slot's offset.
void apply(ast::foreign_function_declaration const & foreign_function_declaration)
{
    auto const & name = foreign_function_declaration.name;

    // A repeated declaration reuses the slot that already exists
    if (lcontext.foreign_address.contains(name))
        return;

    // The offset is taken before the placeholder is pushed, so it points at
    // the slot itself
    lcontext.foreign_address[name] = pcontext.code.size();
    push_bytes<void *>(nullptr);
}
void apply(ast::return_statement const & node) void apply(ast::return_statement const & node)
{ {
apply(*node.value); apply(*node.value);
@ -365,8 +374,15 @@ namespace pslang::jit::aarch64
} }
// Not a variable - must be a function! // Not a variable - must be a function!
lcontext.resolve.push_back({node.name, (std::int32_t)pcontext.code.size()}); if (lcontext.foreign_address.contains(node.name))
builder.adr(0, 0); {
builder.ldr_pc(0, (lcontext.foreign_address.at(node.name) - (std::int32_t)pcontext.code.size()) / 4);
}
else
{
lcontext.resolve.push_back({node.name, (std::int32_t)pcontext.code.size()});
builder.adr(0, 0);
}
} }
void apply(ast::unary_operation const & node) void apply(ast::unary_operation const & node)
@ -789,9 +805,14 @@ namespace pslang::jit::aarch64
builder.ret(); builder.ret();
} }
void apply(ast::function_definition const & node) void apply(ast::function_definition const &)
{ {
// Don't handle internal functions // Must be handled prior to that
}
void apply(ast::foreign_function_declaration const &)
{
// Must be handled prior to that
} }
void apply(ast::statement_list const & node) void apply(ast::statement_list const & node)
@ -909,6 +930,11 @@ namespace pslang::jit::aarch64
compile_function_visitor visitor{{}, {}, pcontext, lcontext}; compile_function_visitor visitor{{}, {}, pcontext, lcontext};
visitor.do_apply(node); visitor.do_apply(node);
} }
void apply(ast::foreign_function_declaration const &)
{
// Already handled by populate_globals
}
}; };
} }
@ -924,6 +950,9 @@ namespace pslang::jit::aarch64
{ {
builder.adr_inject(pcontext.code.data() + resolve.instruction_offset, pcontext.symbols.at(resolve.name) - resolve.instruction_offset); builder.adr_inject(pcontext.code.data() + resolve.instruction_offset, pcontext.symbols.at(resolve.name) - resolve.instruction_offset);
} }
for (auto const & foreign : lcontext.foreign_address)
pcontext.foreign_resolve.push_back({foreign.first, foreign.second});
} }
} }

View file

@ -40,6 +40,11 @@ namespace pslang::jit::aarch64
do_push(0xf8400400u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0x1ffu) << 12)); do_push(0xf8400400u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0x1ffu) << 12));
} }
// Emits a PC-relative load into @reg_dst. @offset is truncated to 19 bits
// and placed at bit 5; the destination register occupies the low bits.
void instruction_builder::ldr_pc(std::uint8_t reg_dst, std::int32_t offset)
{
    std::uint32_t const opcode = 0x58000000u;
    std::uint32_t const target = reg_dst & REG_MASK;
    std::uint32_t const imm19 = static_cast<std::uint32_t>(offset) & 0x7ffffu;

    do_push(opcode | target | (imm19 << 5));
}
void instruction_builder::add_imm(std::uint8_t reg_src, std::uint8_t reg_dst, std::uint16_t value) void instruction_builder::add_imm(std::uint8_t reg_src, std::uint8_t reg_dst, std::uint16_t value)
{ {
do_push(0x91000000u | (reg_dst & REG_MASK) | ((reg_src & REG_MASK) << 5) | ((value & 0xfffu) << 10)); do_push(0x91000000u | (reg_dst & REG_MASK) | ((reg_src & REG_MASK) << 5) | ((value & 0xfffu) << 10));

View file

@ -0,0 +1,13 @@
#include <pslang/jit/foreign.hpp>
#include <dlfcn.h>
namespace pslang::jit
{
// Looks up @name with dlsym using the RTLD_DEFAULT handle and returns
// whatever dlsym returns (null when the symbol is not found).
void * load_foreign(std::string const & name)
{
    char const * const symbol = name.c_str();
    return ::dlsym(RTLD_DEFAULT, symbol);
}
}

View file

@ -31,6 +31,7 @@ else { return bp::make_else(ctx.location); }
while { return bp::make_while(ctx.location); } while { return bp::make_while(ctx.location); }
as { return bp::make_as(ctx.location); } as { return bp::make_as(ctx.location); }
func { return bp::make_func(ctx.location); } func { return bp::make_func(ctx.location); }
foreign { return bp::make_foreign(ctx.location); }
return { return bp::make_return(ctx.location); } return { return bp::make_return(ctx.location); }
struct { return bp::make_struct(ctx.location); } struct { return bp::make_struct(ctx.location); }
true { return bp::make_true(ctx.location); } true { return bp::make_true(ctx.location); }

View file

@ -130,6 +130,7 @@ template <typename T>
%token while %token while
%token as %token as
%token func %token func
%token foreign
%token return %token return
%token struct %token struct
%token true %token true
@ -212,6 +213,7 @@ statement
| else if expression colon { $$ = ast::else_if_block{std::make_unique<ast::expression>($3), @$}; } | else if expression colon { $$ = ast::else_if_block{std::make_unique<ast::expression>($3), @$}; }
| while expression colon { $$ = ast::while_block{std::make_unique<ast::expression>($2), {}, @$}; } | while expression colon { $$ = ast::while_block{std::make_unique<ast::expression>($2), {}, @$}; }
| func name lparen function_declaration_argument_list rparen function_return_type colon { $$ = ast::function_definition{{$2, $4, $6, @$}, {}}; } | func name lparen function_declaration_argument_list rparen function_return_type colon { $$ = ast::function_definition{{$2, $4, $6, @$}, {}}; }
| foreign func name lparen function_declaration_argument_list rparen function_return_type { $$ = ast::foreign_function_declaration{{$3, $5, $7, @$}}; }
| return expression { $$ = ast::return_statement{std::make_unique<ast::expression>($2), @$}; } | return expression { $$ = ast::return_statement{std::make_unique<ast::expression>($2), @$}; }
| return { $$ = ast::return_statement{nullptr, @$}; } | return { $$ = ast::return_statement{nullptr, @$}; }
| struct name colon { $$ = ast::struct_definition{$2, {}, @$}; } | struct name colon { $$ = ast::struct_definition{$2, {}, @$}; }

View file

@ -1,8 +1,6 @@
Future plans: Future plans:
* Fix interpreter identifier resolution for functions & variables
* Pointers: pointer types, address-of operator (&), dereferencing, scope-based lifetime tracking in interpreter * Pointers: pointer types, address-of operator (&), dereferencing, scope-based lifetime tracking in interpreter
* Function overloading: separate functions from values (again) in interpreter, allow casting to specific function type to take function value * Function overloading: separate functions from values (again) in interpreter, allow casting to specific function type to take function value
* C FFI: `foreign func sin(x : f32) -> f32`
* Const propagation: annotate expression AST nodes that are computable in compile-time * Const propagation: annotate expression AST nodes that are computable in compile-time
* Generic parameters: can be either values or `t : type`, but always compile-time * Generic parameters: can be either values or `t : type`, but always compile-time
* Generic structs: `struct <t : type, n : u64> array:`, require explicitly specifying parameters when instantiated (used as a type or creating a value) * Generic structs: `struct <t : type, n : u64> array:`, require explicitly specifying parameters when instantiated (used as a type or creating a value)
@ -11,10 +9,15 @@ Future plans:
* Extension functions: operator overloading, destructors, iterators & for loop, move assignment (replaces built-in copy) * Extension functions: operator overloading, destructors, iterators & for loop, move assignment (replaces built-in copy)
* Metaprogramming: assigning types to variables (of type `type`), functions can take `type` as regular arguments and return `type`, all type computations are compile-time only * Metaprogramming: assigning types to variables (of type `type`), functions can take `type` as regular arguments and return `type`, all type computations are compile-time only
Aarch64 compiler: Interpreter backlog:
* Fix identifier resolution for functions & variables
* C FFI (foreign functions)
Aarch64 compiler backlog:
* Inner functions
* Struct values * Struct values
Backlog: General backlog:
* Mutually recursive structs (relevant only with pointers) * Mutually recursive structs (relevant only with pointers)
* Empty array expression * Empty array expression
* Calling functions as methods * Calling functions as methods