Add foreign functions (stub in interpreter, implemented in aarch64 compiler)

This commit is contained in:
Nikita Lisitsa 2026-03-13 11:01:53 +03:00
parent cb433d87bf
commit c8cd86dd0b
24 changed files with 269 additions and 27 deletions

View file

@ -7,6 +7,7 @@
#include <pslang/ast/print.hpp>
#include <pslang/jit/jit.hpp>
#include <pslang/jit/executable.hpp>
#include <pslang/jit/foreign.hpp>
#include <filesystem>
#include <iostream>
@ -74,6 +75,7 @@ int main(int argc, char ** argv)
using namespace pslang;
auto context = interpreter::empty_context();
context.foreign_resolver = &jit::load_foreign;
bool dump = false;
bool dump_ast = false;
@ -169,13 +171,19 @@ int main(int argc, char ** argv)
for (auto const & ast : parsed)
jit::compile(pcontext, ast);
for (auto const & resolve : pcontext.foreign_resolve)
{
auto fptr = jit::load_foreign(resolve.name);
std::copy_n((std::uint8_t const *)(&fptr), 8, pcontext.code.data() + resolve.offset);
}
auto executable = jit::make_host_executable(pcontext.code);
{
// TODO: remove, testing-only code; should execute entry point instead
auto offset = pcontext.symbols.at("test");
auto fptr = (std::uint32_t(*)())(executable.data.get() + offset);
auto x = fptr();
auto fptr = (double(*)(double))(executable.data.get() + offset);
auto x = fptr(3.14159265358979323846 / 6.0);
std::cout << "Result: " << std::boolalpha << x << std::endl;
}
}

7
examples/foreign.psl Normal file
View file

@ -0,0 +1,7 @@
foreign func sin(x: f64) -> f64
foreign func cos(x: f64) -> f64
func test(x: f64) -> f64:
let s = sin(x)
let c = cos(x)
return s * s + c * c

View file

@ -33,6 +33,10 @@ namespace pslang::ast
statement_list_ptr statements;
};
// AST node for a foreign function declaration. It inherits every field of
// function_declaration and adds no members of its own; the distinct type
// lets visitors overload on it.
struct foreign_function_declaration
: function_declaration
{};
struct return_statement
{
// can be null, which means "return unit"

View file

@ -40,6 +40,7 @@ namespace pslang::ast
if_chain,
while_block,
function_definition,
foreign_function_declaration,
return_statement,
field_definition,
struct_definition

View file

@ -352,6 +352,23 @@ namespace pslang::ast
--options.indent_level;
}
// Prints a foreign function declaration: one header line with the name and
// return type, followed by one line per argument at a deeper indent level.
void apply(foreign_function_declaration const & node)
{
    // Header line.
    put_indent(out, options);
    out << "foreign function { name = \"" << node.name << "\"";
    out << ", return type = ";
    print(out, *node.return_type);
    out << " }\n";

    // Argument lines, nested one level below the header.
    ++options.indent_level;
    for (auto const & parameter : node.arguments)
    {
        put_indent(out, options);
        out << "argument { name = \"" << parameter.name << "\"";
        out << ", type = ";
        print(out, *parameter.type);
        out << " }\n";
    }
    --options.indent_level;
}
void apply(return_statement const & node)
{
put_indent(out, options);

View file

@ -95,6 +95,14 @@ namespace pslang::ast
scopes.back().functions.insert(function_definition.name);
}
// Registers the name of a foreign function in the innermost scope.
// Throws parse_error if that scope already contains the identifier.
//
// The node is taken by const reference (it was previously taken by value,
// which copied the whole declaration on every visit).
void apply(foreign_function_declaration const & foreign_function_declaration)
{
    auto & current_scope = scopes.back();

    if (current_scope.contains(foreign_function_declaration.name))
        throw parse_error("Identifier \"" + foreign_function_declaration.name + "\" is already defined at this scope", foreign_function_declaration.location);

    current_scope.functions.insert(foreign_function_declaration.name);
}
void apply(return_statement const &)
{}
@ -322,6 +330,22 @@ namespace pslang::ast
scopes.pop_back();
}
// Checks the signature of a foreign function declaration: argument names
// must be unique, and every argument type and the return type is visited.
// The function name itself is already added to the scope by
// populate_globals_visitor.
void apply(foreign_function_declaration const & foreign_function_declaration)
{
    std::unordered_set<std::string> seen_names;

    for (auto const & parameter : foreign_function_declaration.arguments)
    {
        // insert() reports through .second whether the name was new.
        bool const is_new = seen_names.insert(parameter.name).second;
        if (!is_new)
            throw parse_error("Duplicate argument name \"" + parameter.name + "\" in function \"" + foreign_function_declaration.name + "\"", parameter.location);

        apply(*parameter.type);
    }

    apply(*foreign_function_declaration.return_type);
}
void apply(return_statement const & return_statement)
{
if (return_statement.value)

View file

@ -57,6 +57,11 @@ namespace pslang::ast
return node.location;
}
// Returns the source location stored on a foreign function declaration node.
location apply(foreign_function_declaration const & node)
{
return node.location;
}
location apply(return_statement const & node)
{
return node.location;

View file

@ -159,6 +159,18 @@ namespace pslang::ast
scopes.pop_back();
}
// Records the signature of a foreign function in the innermost scope:
// all types named in the signature are resolved first, then the argument
// types and the result type are stored under the function's name.
void apply(foreign_function_declaration const & node)
{
    // Pass 1: resolve every type that appears in the signature.
    for (auto const & parameter : node.arguments)
        resolve_types(scopes, *parameter.type);
    resolve_types(scopes, *node.return_type);

    // Pass 2: fill in the function entry for this name.
    auto & signature = scopes.back().functions[node.name];
    for (auto const & parameter : node.arguments)
        signature.arguments.push_back(get_type(*parameter.type));
    signature.result_type = get_type(*node.return_type);
}
void apply(return_statement const &)
{}
@ -731,6 +743,9 @@ namespace pslang::ast
scopes.pop_back();
}
// No-op: this visitor does nothing for a foreign function declaration.
void apply(foreign_function_declaration const &)
{}
void apply(return_statement const & node)
{
types::type_ptr actual_type;

View file

@ -9,6 +9,7 @@
#include <string>
#include <vector>
#include <iostream>
#include <functional>
namespace pslang::interpreter
{
@ -44,9 +45,12 @@ namespace pslang::interpreter
value return_value = unit_value{};
};
// Callback that maps the name of a foreign symbol to a raw pointer.
using foreign_resolver = std::function<void*(std::string const &)>;

// Interpreter state.
struct context
{
    // Defaults to false.
    bool trace = false;

    // Resolver invoked when a foreign function declaration is executed.
    // Default-constructed, i.e. empty, until the embedder assigns one.
    //
    // The type is spelled out instead of written as the `foreign_resolver`
    // alias on purpose: using the alias name unqualified here and then
    // redeclaring the same name as a member is ill-formed (no diagnostic
    // required), and GCC rejects it with "declaration changes meaning".
    // The member's type is identical to the alias above.
    std::function<void*(std::string const &)> foreign_resolver;

    // Stack of frames; code in this change reads and pushes the back element.
    std::vector<frame> frame_stack;
};

View file

@ -0,0 +1,11 @@
#pragma once
#include <pslang/interpreter/value.hpp>
#include <pslang/interpreter/context.hpp>
namespace pslang::interpreter
{
// Executes the foreign function referenced by @pointer with the already
// evaluated arguments @args (taken by value); @return_type describes the
// expected result type.
// NOTE(review): the definition added alongside this header is a stub that
// unconditionally throws std::runtime_error("Not implemented").
value exec_foreign(context & context, void * pointer, types::type const & return_type, std::vector<value> args);
}

View file

@ -72,7 +72,7 @@ namespace pslang::interpreter
std::unordered_map<std::string, value_ptr> fields;
};
struct function_value
struct function_common
{
struct argument
{
@ -82,15 +82,27 @@ namespace pslang::interpreter
std::vector<argument> arguments;
types::type_ptr return_type;
};
// Value of a function defined in the interpreted program: the shared
// signature data from function_common plus the statements of its body.
struct function_value
: function_common
{
// Statement list executed when the function is called.
ast::statement_list_ptr statements;
};
// Value of a foreign function: the shared signature data from
// function_common plus the raw pointer produced by the foreign resolver.
struct foreign_function_value
: function_common
{
    // Pointer to the foreign symbol. Initialised to nullptr so that a
    // default-constructed value never carries an indeterminate pointer.
    void * pointer = nullptr;
};
using value_impl = std::variant<
unit_value,
primitive_value,
array_value,
struct_value,
function_value
function_value,
foreign_function_value
>;
struct value

View file

@ -2,6 +2,7 @@
#include <pslang/interpreter/exec.hpp>
#include <pslang/interpreter/value.hpp>
#include <pslang/interpreter/error.hpp>
#include <pslang/interpreter/foreign.hpp>
#include <pslang/ast/expression.hpp>
#include <pslang/types/print.hpp>
@ -498,6 +499,11 @@ namespace pslang::interpreter
throw internal_error("Cannot cast function type to anything", location);
}
// Cast overload for foreign function values: casting is never allowed, so
// this always throws internal_error carrying the given @location.
value cast_impl(foreign_function_value const &, types::type const &, ast::location const & location)
{
throw internal_error("Cannot cast function type to anything", location);
}
value eval_impl(context & context, ast::cast_operation const & cast_operation)
{
auto arg = eval(context, cast_operation.expression);
@ -511,12 +517,14 @@ namespace pslang::interpreter
{
auto lvalue = eval(context, function_call.function);
auto fvalue = std::get_if<function_value>(&lvalue);
if (fvalue)
auto ffvalue = std::get_if<foreign_function_value>(&lvalue);
auto fcommon = fvalue ? static_cast<function_common *>(fvalue) : ffvalue;
if (fcommon)
{
if (fvalue->arguments.size() != function_call.arguments.size())
if (fcommon->arguments.size() != function_call.arguments.size())
{
std::ostringstream os;
os << "Cannot call function: expected " << fvalue->arguments.size() << " arguments, got " << function_call.arguments.size();
os << "Cannot call function: expected " << fcommon->arguments.size() << " arguments, got " << function_call.arguments.size();
throw internal_error(os.str(), function_call.location);
}
@ -527,29 +535,36 @@ namespace pslang::interpreter
for (std::size_t i = 0; i < args.size(); ++i)
{
auto actual_type = type_of(args[i]);
if (!types::equal(actual_type, *fvalue->arguments[i].type))
if (!types::equal(actual_type, *fcommon->arguments[i].type))
{
std::ostringstream os;
os << "Cannot call function: argument #" << (i + 1) << " expects type ";
types::print(os, *fvalue->arguments[i].type);
types::print(os, *fcommon->arguments[i].type);
os << " but actual type is ";
types::print(os, actual_type);
throw internal_error(os.str(), ast::get_location(*function_call.arguments[i]));
}
}
auto & function_scope = context.frame_stack.emplace_back();
if (fvalue)
{
auto & function_scope = context.frame_stack.emplace_back();
for (std::size_t i = 0; i < args.size(); ++i)
function_scope.variables[fvalue->arguments[i].name] = {.category = ast::value_category::constant, .value = std::move(args[i])};
for (std::size_t i = 0; i < args.size(); ++i)
function_scope.variables[fvalue->arguments[i].name] = {.category = ast::value_category::constant, .value = std::move(args[i])};
function_scope.expected_return_type = fvalue->return_type;
function_scope.expected_return_type = fvalue->return_type;
exec(context, fvalue->statements);
exec(context, fvalue->statements);
auto result = std::move(context.frame_stack.back().return_value);
context.frame_stack.pop_back();
return result;
auto result = std::move(context.frame_stack.back().return_value);
context.frame_stack.pop_back();
return result;
}
else if (ffvalue)
{
return exec_foreign(context, ffvalue->pointer, *fcommon->return_type, std::move(args));
}
}
std::ostringstream os;

View file

@ -170,6 +170,23 @@ namespace pslang::interpreter
frame.variables[function_definition.name] = {.category = ast::value_category::constant, .value = std::move(value)};
}
// Executes a foreign function declaration: builds a foreign_function_value
// from the declared signature, resolves the symbol through
// context.foreign_resolver and binds the value as a constant in the
// current frame.
//
// Throws std::runtime_error when
//  * the identifier is already defined in the current frame,
//  * no foreign resolver is installed in the context (previously this
//    surfaced as an unrelated std::bad_function_call),
//  * the resolver returns a null pointer (previously the null pointer was
//    stored without any diagnostic).
void exec_impl(context & context, ast::foreign_function_declaration const & foreign_function_declaration)
{
    auto & frame = context.frame_stack.back();
    auto const & name = foreign_function_declaration.name;

    if (frame.contains(name))
        throw std::runtime_error("Identifier \"" + name + "\" is already defined in this scope");

    foreign_function_value value;
    for (auto const & argument : foreign_function_declaration.arguments)
        value.arguments.push_back({.name = argument.name, .type = get_type(*argument.type)});
    value.return_type = get_type(*foreign_function_declaration.return_type);

    // An empty std::function would throw std::bad_function_call when invoked.
    if (!context.foreign_resolver)
        throw std::runtime_error("Cannot declare foreign function \"" + name + "\": no foreign resolver is set");

    value.pointer = context.foreign_resolver(name);
    if (value.pointer == nullptr)
        throw std::runtime_error("Cannot resolve foreign function \"" + name + "\"");

    frame.variables[name] = {.category = ast::value_category::constant, .value = std::move(value)};
}
void exec_impl(context & context, ast::return_statement const & return_statement)
{
// NB: cannot use return_statement.level here because lexical scope stack

View file

@ -0,0 +1,12 @@
#include <pslang/interpreter/foreign.hpp>
#include <pslang/interpreter/error.hpp>
namespace pslang::interpreter
{
// Stub: unconditionally throws std::runtime_error("Not implemented").
// None of the parameters is used.
value exec_foreign(context & context, void * pointer, types::type const & return_type, std::vector<value> args)
{
throw std::runtime_error("Not implemented");
}
}

View file

@ -41,6 +41,15 @@ namespace pslang::interpreter
result.result = value.return_type;
return result;
}
// Builds the function type of a foreign function value from its argument
// types and its return type.
types::type operator()(foreign_function_value const & value)
{
    types::function_type signature;

    for (auto const & parameter : value.arguments)
        signature.arguments.push_back(parameter.type);
    signature.result = value.return_type;

    return signature;
}
};
struct print_visitor
@ -112,6 +121,11 @@ namespace pslang::interpreter
{
out << "func";
}
// Prints the fixed marker "foreign func"; the value itself is not inspected.
void operator()(foreign_function_value const &)
{
    out << "foreign func";
}
};
}

View file

@ -42,6 +42,10 @@ namespace pslang::jit::aarch64
// Store the address plus a signed 9-bit offset in @reg_addr
void ldr_post(std::uint8_t reg_dst, std::uint8_t reg_addr, std::int16_t offset);
// Load the value at address specified by the value of the program counter (PC)
// plus a signed 19-bit @offset multiplied by 4, and store it into register @reg_dst.
// Because of that scaling, @offset counts 4-byte units, not bytes.
void ldr_pc(std::uint8_t reg_dst, std::int32_t offset);
// Add a 12-bit @value to the register @reg_src and store the result in @reg_dst
void add_imm(std::uint8_t reg_src, std::uint8_t reg_dst, std::uint16_t value);

View file

@ -0,0 +1,10 @@
#pragma once
#include <string>
namespace pslang::jit
{
// Looks up the foreign symbol called @name and returns a pointer to it.
// NOTE(review): the definition added with this header forwards to
// dlsym(RTLD_DEFAULT, ...), which yields a null pointer when the symbol
// cannot be found - callers should check the result.
void * load_foreign(std::string const & name);
}

View file

@ -12,9 +12,18 @@ namespace pslang::jit
// State accumulated while compiling a program: the emitted code blob, the
// symbol offsets and the list of foreign symbols still to be patched in.
struct program_context
{
// One foreign symbol whose address has to be written into the code blob
// after compilation.
struct foreign_resolve_info
{
// Name of the foreign symbol to look up.
std::string name;
// Offset in bytes to the place in code blob
// containing the 64-bit address of the foreign symbol
std::int32_t offset;
};
// ABI description (type declared elsewhere); has no default here.
jit::abi abi;
// Emitted bytes: instructions plus the inline address slots mentioned above.
std::vector<std::uint8_t> code = {};
// Symbol name -> offset into the code blob.
std::unordered_map<std::string, std::int32_t> symbols = {};
// Foreign symbols whose address slots still need to be filled in.
std::vector<foreign_resolve_info> foreign_resolve = {};
};
}

View file

@ -10,8 +10,6 @@
#include <unordered_map>
#include <vector>
#include <iostream>
namespace pslang::jit::aarch64
{
@ -32,6 +30,8 @@ namespace pslang::jit::aarch64
};
std::vector<resolve_info> resolve;
std::unordered_map<std::string, std::int32_t> foreign_address;
};
std::uint8_t fp_mode_for(types::type const & type)
@ -186,6 +186,15 @@ namespace pslang::jit::aarch64
apply(*node.statements);
}
// Reserves a pointer-sized, null-filled slot in the code blob for a foreign
// function the first time its name is seen, and remembers the slot's offset
// in lcontext.foreign_address. Repeated declarations of the same name reuse
// the existing slot.
void apply(ast::foreign_function_declaration const & foreign_function_declaration)
{
    auto const & name = foreign_function_declaration.name;

    // A slot for this name already exists - nothing to do.
    if (lcontext.foreign_address.contains(name))
        return;

    // The slot starts at the current end of the code blob.
    lcontext.foreign_address[name] = pcontext.code.size();
    push_bytes<void *>(nullptr);
}
void apply(ast::return_statement const & node)
{
apply(*node.value);
@ -365,8 +374,15 @@ namespace pslang::jit::aarch64
}
// Not a variable - must be a function!
lcontext.resolve.push_back({node.name, (std::int32_t)pcontext.code.size()});
builder.adr(0, 0);
if (lcontext.foreign_address.contains(node.name))
{
builder.ldr_pc(0, (lcontext.foreign_address.at(node.name) - (std::int32_t)pcontext.code.size()) / 4);
}
else
{
lcontext.resolve.push_back({node.name, (std::int32_t)pcontext.code.size()});
builder.adr(0, 0);
}
}
void apply(ast::unary_operation const & node)
@ -789,9 +805,14 @@ namespace pslang::jit::aarch64
builder.ret();
}
void apply(ast::function_definition const & node)
void apply(ast::function_definition const &)
{
// Don't handle internal functions
// Must be handled prior to that
}
// No-op in this visitor: nothing is emitted for a foreign function
// declaration here.
void apply(ast::foreign_function_declaration const &)
{
// Must be handled prior to that
}
void apply(ast::statement_list const & node)
@ -909,6 +930,11 @@ namespace pslang::jit::aarch64
compile_function_visitor visitor{{}, {}, pcontext, lcontext};
visitor.do_apply(node);
}
// No-op in this visitor: foreign function declarations need no work here.
void apply(ast::foreign_function_declaration const &)
{
// Already handled by populate_globals
}
};
}
@ -924,6 +950,9 @@ namespace pslang::jit::aarch64
{
builder.adr_inject(pcontext.code.data() + resolve.instruction_offset, pcontext.symbols.at(resolve.name) - resolve.instruction_offset);
}
for (auto const & foreign : lcontext.foreign_address)
pcontext.foreign_resolve.push_back({foreign.first, foreign.second});
}
}

View file

@ -40,6 +40,11 @@ namespace pslang::jit::aarch64
do_push(0xf8400400u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0x1ffu) << 12));
}
// Emits a PC-relative load into @reg_dst.
// Encoding: base opcode 0x58000000, destination register in the low bits,
// and the low 19 bits of @offset placed at bit 5.
void instruction_builder::ldr_pc(std::uint8_t reg_dst, std::int32_t offset)
{
    // Keep only the 19 bits that fit the immediate field (two's complement
    // truncation for negative offsets).
    auto const imm19 = static_cast<std::uint32_t>(offset) & 0x7ffffu;

    do_push(0x58000000u | (reg_dst & REG_MASK) | (imm19 << 5));
}
void instruction_builder::add_imm(std::uint8_t reg_src, std::uint8_t reg_dst, std::uint16_t value)
{
do_push(0x91000000u | (reg_dst & REG_MASK) | ((reg_src & REG_MASK) << 5) | ((value & 0xfffu) << 10));

View file

@ -0,0 +1,13 @@
#include <pslang/jit/foreign.hpp>
#include <dlfcn.h>
namespace pslang::jit
{
// Resolves @name with dlsym(RTLD_DEFAULT, ...) and returns the symbol's
// address; dlsym yields a null pointer when the symbol is not found.
void * load_foreign(std::string const & name)
{
    void * const address = dlsym(RTLD_DEFAULT, name.c_str());
    return address;
}
}

View file

@ -31,6 +31,7 @@ else { return bp::make_else(ctx.location); }
while { return bp::make_while(ctx.location); }
as { return bp::make_as(ctx.location); }
func { return bp::make_func(ctx.location); }
foreign { return bp::make_foreign(ctx.location); }
return { return bp::make_return(ctx.location); }
struct { return bp::make_struct(ctx.location); }
true { return bp::make_true(ctx.location); }

View file

@ -130,6 +130,7 @@ template <typename T>
%token while
%token as
%token func
%token foreign
%token return
%token struct
%token true
@ -212,6 +213,7 @@ statement
| else if expression colon { $$ = ast::else_if_block{std::make_unique<ast::expression>($3), @$}; }
| while expression colon { $$ = ast::while_block{std::make_unique<ast::expression>($2), {}, @$}; }
| func name lparen function_declaration_argument_list rparen function_return_type colon { $$ = ast::function_definition{{$2, $4, $6, @$}, {}}; }
| foreign func name lparen function_declaration_argument_list rparen function_return_type { $$ = ast::foreign_function_declaration{{$3, $5, $7, @$}}; }
| return expression { $$ = ast::return_statement{std::make_unique<ast::expression>($2), @$}; }
| return { $$ = ast::return_statement{nullptr, @$}; }
| struct name colon { $$ = ast::struct_definition{$2, {}, @$}; }

View file

@ -1,8 +1,6 @@
Future plans:
* Fix interpreter identifier resolution for functions & variables
* Pointers: pointer types, address-of operator (&), dereferencing, scope-based lifetime tracking in interpreter
* Function overloading: separate functions from values (again) in interpreter, allow casting to specific function type to take function value
* C FFI: `foreign func sin(x : f32) -> f32`
* Const propagation: annotate expression AST nodes that are computable in compile-time
* Generic parameters: can be either values or `t : type`, but always compile-time
* Generic structs: `struct <t : type, n : u64> array:`, require explicitly specifying parameters when instantiated (used as a type or creating a value)
@ -11,10 +9,15 @@ Future plans:
* Extension functions: operator overloading, destructors, iterators & for loop, move assignment (replaces built-in copy)
* Metaprogramming: assigning types to variables (of type `type`), functions can take `type` as regular arguments and return `type`, all type computations are compile-time only
Aarch64 compiler:
Interpreter backlog:
* Fix identifier resolution for functions & variables
* C FFI (foreign functions)
Aarch64 compiler backlog:
* Inner functions
* Struct values
Backlog:
General backlog:
* Mutually recursive structs (relevant only with pointers)
* Empty array expression
* Calling functions as methods