Aarch64 JIT compiler wip: partial struct support

This commit is contained in:
Nikita Lisitsa 2026-03-17 18:08:23 +03:00
parent 53be7b92df
commit 1a0dd2a48f
8 changed files with 570 additions and 79 deletions

View file

@ -15,8 +15,6 @@ foreign func putchar(c: i32) -> i32
func print(c: u8): func print(c: u8):
putchar(c as i32) putchar(c as i32)
let x = 0
func test1(): func test1():
print('H') print('H')
print('e') print('e')
@ -44,20 +42,28 @@ func test() -> i32:
else: else:
return b() return b()
struct interval: struct vec2:
min: f32 x : f32
max: f32 y : f32
struct box2f: struct vecX:
x: interval a: f32
y: interval b: f32
c: f64
struct weird: func test2():
a: i32 //let v = vec2(1.0, 2.0)
b: u8 vec2(1.0, 2.0)
c: f32 -> f32
d: f64 //test2()
e: f16
mut v = vecX(6.0, 9.0, 4.0l)
v = vecX(1.0, 2.0, 5.0l)
v.a = 1.0
//v.b = 2.0
v.c = 5.0l
print('0' + (v.b as u8))
print('\n')
//func test1(): //func test1():
// let str = ['H', 'e', 'l', 'l', 'o', ',', ' ', 'w', 'o', 'r', 'l', 'd', '!', '\n'] // let str = ['H', 'e', 'l', 'l', 'o', ',', ' ', 'w', 'o', 'r', 'l', 'd', '!', '\n']

View file

@ -63,12 +63,12 @@ namespace pslang::ast
size_and_alignment apply(types::unit_type const &) size_and_alignment apply(types::unit_type const &)
{ {
return {.size = 1, .alignment = 1}; return {.size = 0, .alignment = 1};
} }
size_and_alignment apply(types::primitive_type const & type) size_and_alignment apply(types::primitive_type const & type)
{ {
auto size = types::builtin_type_size(type); auto size = types::type_size(type);
return {.size = size, .alignment = size}; return {.size = size, .alignment = size};
} }
@ -561,7 +561,7 @@ namespace pslang::ast
os << "Cannot create built-in type "; os << "Cannot create built-in type ";
types::print(os, *type); types::print(os, *type);
os << ": expected 0 arguments, but got " << node.arguments.size(); os << ": expected 0 arguments, but got " << node.arguments.size();
throw std::runtime_error(os.str()); throw type_error(os.str(), node.location);
} }
else if (auto named_type = std::get_if<types::named_type>(type.get())) else if (auto named_type = std::get_if<types::named_type>(type.get()))
{ {
@ -702,6 +702,17 @@ namespace pslang::ast
{ {
apply(node.lhs); apply(node.lhs);
apply(node.rhs); apply(node.rhs);
if (auto lvalue = classify_lvalue(node.lhs))
{
if (*lvalue != ast::value_category::_mutable)
throw type_error("Cannot assign to a non-mutable object", node.location);
}
else
{
throw type_error("Cannot assign to a non-lvalue", node.location);
}
auto ltype = get_type(*node.lhs); auto ltype = get_type(*node.lhs);
auto rtype = get_type(*node.rhs); auto rtype = get_type(*node.rhs);
// TODO: check lvalue // TODO: check lvalue
@ -863,6 +874,30 @@ namespace pslang::ast
for (auto & struct_data : scopes.back().structs) for (auto & struct_data : scopes.back().structs)
compute_layout(struct_data.second, scopes); compute_layout(struct_data.second, scopes);
} }
private:
std::optional<ast::value_category> classify_lvalue(expression_ptr const & node)
{
if (auto identifier = std::get_if<ast::identifier>(node.get()))
{
auto const & scope = scopes[identifier->level];
if (scope.functions.contains(identifier->name))
return ast::value_category::constant;
if (auto it = scope.variables.find(identifier->name); it != scope.variables.end())
return it->second.category;
return std::nullopt;
}
else if (auto field_access = std::get_if<ast::field_access>(node.get()))
{
return classify_lvalue(field_access->object);
}
else if (auto array_access = std::get_if<ast::array_access>(node.get()))
{
return classify_lvalue(array_access->array);
}
return std::nullopt;
}
}; };
} }

View file

@ -20,7 +20,7 @@ namespace pslang::jit::aarch64
// @shift must be 0, 1, 2 or 3 // @shift must be 0, 1, 2 or 3
void movk(std::uint8_t reg, std::uint16_t val, std::uint8_t shift = 0); void movk(std::uint8_t reg, std::uint16_t val, std::uint8_t shift = 0);
// Load the value of the register @reg_src at the address specified by the value of // Store the value of the register @reg_src at the address specified by the value of
// register @reg_addr plus an unsigned 12-bit offset multiplied by 8. // register @reg_addr plus an unsigned 12-bit offset multiplied by 8.
void str(std::uint8_t reg_src, std::uint8_t reg_addr, std::uint16_t offset); void str(std::uint8_t reg_src, std::uint8_t reg_addr, std::uint16_t offset);
@ -28,6 +28,22 @@ namespace pslang::jit::aarch64
// plus a signed 9-bit offset. Store the new address value in @reg_addr // plus a signed 9-bit offset. Store the new address value in @reg_addr
void str_pre(std::uint8_t reg_src, std::uint8_t reg_addr, std::int16_t offset); void str_pre(std::uint8_t reg_src, std::uint8_t reg_addr, std::int16_t offset);
// Store the value of the register @reg_src at the address specified by the value of
// register @reg_addr plus an signed 9-bit offset.
void stur(std::uint8_t reg_src, std::uint8_t reg_addr, std::int16_t offset);
// Store the lowest 32 bits of the value of the register @reg_src at the address specified by the value of
// register @reg_addr plus an signed 9-bit offset.
void sturw(std::uint8_t reg_src, std::uint8_t reg_addr, std::int16_t offset);
// Store the lowest 16 bits of the value of the register @reg_src at the address specified by the value of
// register @reg_addr plus an signed 9-bit offset.
void sturh(std::uint8_t reg_src, std::uint8_t reg_addr, std::int16_t offset);
// Store the lowest 9 bits of the value of the register @reg_src at the address specified by the value of
// register @reg_addr plus an signed 9-bit offset.
void sturb(std::uint8_t reg_src, std::uint8_t reg_addr, std::int16_t offset);
// Load the value at address specified by the value of register @reg_addr // Load the value at address specified by the value of register @reg_addr
// plus an unsigned 12-bit offset multiplied by 8 and store it in the register @reg_dst. // plus an unsigned 12-bit offset multiplied by 8 and store it in the register @reg_dst.
void ldr(std::uint8_t reg_dst, std::uint8_t reg_addr, std::uint16_t offset); void ldr(std::uint8_t reg_dst, std::uint8_t reg_addr, std::uint16_t offset);
@ -42,6 +58,22 @@ namespace pslang::jit::aarch64
// Store the address plus a signed 9-bit offset in @reg_addr // Store the address plus a signed 9-bit offset in @reg_addr
void ldr_post(std::uint8_t reg_dst, std::uint8_t reg_addr, std::int16_t offset); void ldr_post(std::uint8_t reg_dst, std::uint8_t reg_addr, std::int16_t offset);
// Load the value at address specified by the value of register @reg_addr
// plus an signed 9-bit offset and store it in the register @reg_dst.
void ldur(std::uint8_t reg_dst, std::uint8_t reg_addr, std::int16_t offset);
// Load the lowest 32 bits of the value at address specified by the value of register @reg_addr
// plus an signed 9-bit offset and store it in the register @reg_dst.
void ldurw(std::uint8_t reg_dst, std::uint8_t reg_addr, std::int16_t offset);
// Load the lowest 16 bits of the value at address specified by the value of register @reg_addr
// plus an signed 9-bit offset and store it in the register @reg_dst.
void ldurh(std::uint8_t reg_dst, std::uint8_t reg_addr, std::int16_t offset);
// Load the lowest 8 bits of the value at address specified by the value of register @reg_addr
// plus an signed 9-bit offset and store it in the register @reg_dst.
void ldurb(std::uint8_t reg_dst, std::uint8_t reg_addr, std::int16_t offset);
// Load the value at address specified by the value of the program counter (PC) // Load the value at address specified by the value of the program counter (PC)
// plus a signed 19-bit @offset multiplied by 4, and store it into register @reg_dst // plus a signed 19-bit @offset multiplied by 4, and store it into register @reg_dst
void ldr_pc(std::uint8_t reg_dst, std::int32_t offset); void ldr_pc(std::uint8_t reg_dst, std::int32_t offset);
@ -135,15 +167,25 @@ namespace pslang::jit::aarch64
void ldr_fp_pc(std::uint8_t reg_dst, std::uint8_t mode, std::int32_t offset); void ldr_fp_pc(std::uint8_t reg_dst, std::uint8_t mode, std::int32_t offset);
// Load a floating-point value from the address stored in @reg_addr plus a // Load a floating-point value from the address stored in @reg_addr plus a
// 12-bit unsigned @offset multiplied by 4, and store it in floating-point // 12-bit unsigned @offset multiplied by sizeof the type (2, 4 or 8), and store it in floating-point
// register @reg_dst. @mode should be 1 for 16-bit values, 2 for 32-bit, 3 for 64-bit // register @reg_dst. @mode should be 1 for 16-bit values, 2 for 32-bit, 3 for 64-bit
void ldr_fp(std::uint8_t reg_dst, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset); void ldr_fp(std::uint8_t reg_dst, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset);
// Load a floating-point value from the address stored in @reg_addr plus a
// 9-bit signed @offset, and store it in floating-point
// register @reg_dst. @mode should be 1 for 16-bit values, 2 for 32-bit, 3 for 64-bit
void ldur_fp(std::uint8_t reg_dst, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset);
// Store a floating-point value from @reg_src into the address stored in // Store a floating-point value from @reg_src into the address stored in
// @reg_addr plus a 12-bit unsigned @offset multiplied by 4. // @reg_addr plus a 12-bit unsigned @offset multiplied by sizeof the type (2, 4 or 8).
// @mode should be 1 for 16-bit values, 2 for 32-bit, 3 for 64-bit // @mode should be 1 for 16-bit values, 2 for 32-bit, 3 for 64-bit
void str_fp(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset); void str_fp(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset);
// Store a floating-point value from @reg_src into the address stored in
// @reg_addr plus a 9-bit signed @offset.
// @mode should be 1 for 16-bit values, 2 for 32-bit, 3 for 64-bit
void stur_fp(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_addr, std::int16_t offset);
// Convert a floating-point value from @reg_src stored in format @mode_src // Convert a floating-point value from @reg_src stored in format @mode_src
// into a floating-point value in @reg_dst in format @mode_dst // into a floating-point value in @reg_dst in format @mode_dst
// Modes are the same as for ldr_fp/str_fp // Modes are the same as for ldr_fp/str_fp
@ -161,7 +203,6 @@ namespace pslang::jit::aarch64
// If @op is 0, move the value of a floating-point register @reg_src in format @mode into a general register @reg_dst // If @op is 0, move the value of a floating-point register @reg_src in format @mode into a general register @reg_dst
// If @op is 1, move the value of a general register @reg_src into a floating-point register @reg_dst in format @mode // If @op is 1, move the value of a general register @reg_src into a floating-point register @reg_dst in format @mode
// NB: mode = 2 (single precision) isn't supported
void fmov(std::uint8_t reg_src, std::uint8_t reg_dst, std::uint8_t mode, std::uint8_t op); void fmov(std::uint8_t reg_src, std::uint8_t reg_dst, std::uint8_t mode, std::uint8_t op);
// Convert signed integer in floating-point register @reg_src into floating-point format @mode and // Convert signed integer in floating-point register @reg_src into floating-point format @mode and

View file

@ -18,6 +18,14 @@ namespace pslang::jit::aarch64
namespace namespace
{ {
// Homogeneous floating-point aggregate: up to 4 floating-point members
// of the same type (after struct flattening)
struct hfa_data
{
types::type_ptr type;
std::size_t count;
};
struct local_context struct local_context
{ {
std::unordered_map<float, std::int32_t> f16_constants; std::unordered_map<float, std::int32_t> f16_constants;
@ -28,10 +36,18 @@ namespace pslang::jit::aarch64
std::unordered_map<ast::function_definition const *, std::int32_t> functions; std::unordered_map<ast::function_definition const *, std::int32_t> functions;
struct struct_data
{
std::optional<hfa_data> hfa = {};
};
std::unordered_map<ast::struct_definition const *, struct_data> structs;
struct scope struct scope
{ {
std::unordered_set<std::string> foreign_functions; std::unordered_set<std::string> foreign_functions;
std::unordered_map<std::string, ast::function_definition const *> functions; std::unordered_map<std::string, ast::function_definition const *> functions;
std::unordered_map<std::string, ast::struct_definition const *> structs;
}; };
std::vector<scope> scopes; std::vector<scope> scopes;
@ -54,6 +70,9 @@ namespace pslang::jit::aarch64
if (it->functions.contains(name)) if (it->functions.contains(name))
return false; return false;
if (it->structs.contains(name))
return false;
} }
return false; return false;
} }
@ -62,11 +81,30 @@ namespace pslang::jit::aarch64
{ {
for (auto it = scopes.rbegin(); it != scopes.rend(); ++it) for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
{ {
if (auto jt = it->functions.find(name); jt != it->functions.end())
return jt->second;
if (it->foreign_functions.contains(name)) if (it->foreign_functions.contains(name))
return nullptr; return nullptr;
if (auto jt = it->functions.find(name); jt != it->functions.end()) if (it->structs.contains(name))
return nullptr;
}
return nullptr;
}
ast::struct_definition const * is_struct(std::string const & name)
{
for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
{
if (auto jt = it->structs.find(name); jt != it->structs.end())
return jt->second; return jt->second;
if (it->foreign_functions.contains(name))
return nullptr;
if (it->functions.contains(name))
return nullptr;
} }
return nullptr; return nullptr;
} }
@ -158,7 +196,8 @@ namespace pslang::jit::aarch64
void apply(ast::function_call const & node) void apply(ast::function_call const & node)
{ {
apply(*node.function); if (node.function)
apply(*node.function);
for (auto const & argument : node.arguments) for (auto const & argument : node.arguments)
apply(*argument); apply(*argument);
} }
@ -258,6 +297,56 @@ namespace pslang::jit::aarch64
} }
}; };
std::optional<hfa_data> get_hfa_data(ast::struct_definition const & node, local_context & lcontext)
{
if (auto it = lcontext.structs.find(&node); it != lcontext.structs.end())
return it->second.hfa;
types::type_ptr type = nullptr;
std::size_t count = 0;
for (auto const & field : node.fields)
{
if (types::is_builtin_type(*field.inferred_type))
{
if (!types::is_floating_point_type(*field.inferred_type))
return std::nullopt;
if (type && !types::equal(*type, *field.inferred_type))
return std::nullopt;
type = field.inferred_type;
++count;
}
else if (auto named_type = std::get_if<types::named_type>(field.inferred_type.get()))
{
// NB: recursion must be impossible due to prior checks in type checker
if (auto struct_node = lcontext.is_struct(named_type->name))
{
if (auto subdata = get_hfa_data(*struct_node, lcontext))
{
if (type && !types::equal(*type, *subdata->type))
return std::nullopt;
type = subdata->type;
count += subdata->count;
}
else
return std::nullopt;
}
else
throw std::runtime_error("Unknown named type: \"" + named_type->name + "\"");
}
else
return std::nullopt;
}
if (count <= 4)
return hfa_data{type, count};
return std::nullopt;
}
// Iterate over a single scope (i.e. not visiting subscopes recursively) // Iterate over a single scope (i.e. not visiting subscopes recursively)
// and add all defined functions & foreign functions to the current scope // and add all defined functions & foreign functions to the current scope
struct populate_symbols_visitor struct populate_symbols_visitor
@ -297,7 +386,16 @@ namespace pslang::jit::aarch64
void apply(ast::field_definition const &) {} void apply(ast::field_definition const &) {}
void apply(ast::struct_definition const &) {} void apply(ast::struct_definition const & node)
{
lcontext.scopes.back().structs[node.name] = &node;
if (!lcontext.structs.contains(&node))
{
// NB: make sure not to add struct to lcontext.structs before computing hfa data
auto hfa = get_hfa_data(node, lcontext);
lcontext.structs[&node].hfa = hfa;
}
}
}; };
struct reg_extend_visitor struct reg_extend_visitor
@ -363,9 +461,9 @@ namespace pslang::jit::aarch64
struct variable_data struct variable_data
{ {
// Difference between initial stack pointer at scope enter // Difference between initial stack pointer at function enter
// and the variable address // and the variable address
// Must be a multiple of 8 // Must be a multiple of 16
std::uint32_t frame_offset; std::uint32_t frame_offset;
}; };
@ -450,8 +548,30 @@ namespace pslang::jit::aarch64
{ {
if (auto jt = it->variables.find(node.name); jt != it->variables.end()) if (auto jt = it->variables.find(node.name); jt != it->variables.end())
{ {
if (types::is_floating_point_type(*node.inferred_type)) if (auto named_type = std::get_if<types::named_type>(node.inferred_type.get()))
builder.ldr_fp(0, fp_mode_for(*node.inferred_type), 31, (stack_offset - jt->second.frame_offset) / builtin_type_size(*node.inferred_type)); {
if (auto struct_node = lcontext.is_struct(named_type->name))
{
std::size_t stack_size = ((struct_node->layout.size + 15) / 16) * 16;
builder.sub_imm(31, 31, stack_size);
stack_offset += stack_size;
scopes.back().stack_offset += stack_size;
std::size_t variable_offset = stack_offset - jt->second.frame_offset;
for (std::size_t offset = 0; offset < stack_size; offset += 16)
{
builder.ldr(0, 31, (variable_offset + offset) / 8);
builder.ldr(1, 31, (variable_offset + offset) / 8 + 1);
builder.str(0, 31, offset / 8);
builder.str(1, 31, offset / 8 + 1);
}
}
else
throw std::runtime_error("Unknown type \"" + named_type->name + "\"");
}
else if (types::is_unit_type(*node.inferred_type))
{}
else if (types::is_floating_point_type(*node.inferred_type))
builder.ldr_fp(0, fp_mode_for(*node.inferred_type), 31, (stack_offset - jt->second.frame_offset) / type_size(*node.inferred_type));
else else
builder.ldr(0, 31, (stack_offset - jt->second.frame_offset) / 8); builder.ldr(0, 31, (stack_offset - jt->second.frame_offset) / 8);
return; return;
@ -767,71 +887,228 @@ namespace pslang::jit::aarch64
void apply(ast::function_call const & node) void apply(ast::function_call const & node)
{ {
apply(*node.function); if (node.function)
push(0);
for (std::size_t i = node.arguments.size(); i --> 0;)
{ {
auto const & arg = node.arguments[i]; apply(*node.function);
apply(*arg); push(0);
auto type = ast::get_type(*arg);
if (types::is_bool_type(*type) || types::is_integer_type(*type)) for (std::size_t i = node.arguments.size(); i --> 0;)
{ {
push(0); auto const & arg = node.arguments[i];
apply(*arg);
auto type = ast::get_type(*arg);
if (types::is_bool_type(*type) || types::is_integer_type(*type))
{
push(0);
}
else if (types::is_floating_point_type(*type))
{
push_fp(0, fp_mode_for(*type));
}
} }
else if (types::is_floating_point_type(*type))
std::uint8_t reg = 0;
std::uint8_t fp_reg = 0;
for (auto const & arg : node.arguments)
{ {
push_fp(0, fp_mode_for(*type)); auto type = ast::get_type(*arg);
if (types::is_bool_type(*type) || types::is_integer_type(*type))
{
pop(reg);
++reg;
}
else if (types::is_floating_point_type(*type))
{
pop_fp(fp_reg, fp_mode_for(*type));
++fp_reg;
}
}
pop(reg);
push(30);
builder.b_reg(reg);
pop(30);
}
else // if (node.type)
{
if (types::is_unit_type(*node.inferred_type))
{
// Do nothing
}
else if (types::is_bool_type(*node.inferred_type) || types::is_integer_type(*node.inferred_type) || types::is_function_type(*node.inferred_type))
{
builder.xor_reg(0, 0, 0);
}
else if (types::is_floating_point_type(*node.inferred_type))
{
builder.xor_reg(0, 0, 0);
builder.fmov(0, 0, fp_mode_for(*node.inferred_type), 1);
}
else if (auto named_type = std::get_if<types::named_type>(node.inferred_type.get()))
{
if (auto struct_node = lcontext.is_struct(named_type->name))
{
// Allocate stack space for the struct
std::size_t stack_size = ((struct_node->layout.size + 15) / 16) * 16;
auto offset = stack_offset;
stack_offset += stack_size;
scopes.back().stack_offset += stack_size;
builder.sub_imm(31, 31, stack_size);
// Evaluate each field of the struct (i.e. each constructor argument)
// and copy it to the corresponding place in the struct
for (std::size_t i = 0; i < node.arguments.size(); ++i)
{
auto type = ast::get_type(*node.arguments[i]);
apply(*node.arguments[i]);
if (std::get_if<types::named_type>(type.get()))
{
// TODO: struct field
throw std::runtime_error("Not implemented");
}
else if (types::is_floating_point_type(*type))
{
builder.stur_fp(0, fp_mode_for(*type), 31, struct_node->fields[i].layout.offset);
}
else
{
auto size = types::type_size(*type);
if (size == 1)
builder.sturb(0, 31, struct_node->fields[i].layout.offset);
else if (size == 2)
builder.sturh(0, 31, struct_node->fields[i].layout.offset);
else if (size == 4)
builder.sturw(0, 31, struct_node->fields[i].layout.offset);
else if (size == 8)
builder.stur(0, 31, struct_node->fields[i].layout.offset);
}
}
}
else
throw std::runtime_error("Unknown type \"" + named_type->name + "\"");
}
}
}
void apply(ast::field_access const & node)
{
auto struct_type = get_type(*node.object);
if (auto named_type = std::get_if<types::named_type>(struct_type.get()))
{
if (auto struct_node = lcontext.is_struct(named_type->name))
{
std::size_t field_id = -1;
for (std::size_t i = 0; i < struct_node->fields.size(); ++i)
{
if (struct_node->fields[i].name == node.field_name)
{
field_id = i;
break;
}
}
if (field_id == -1)
throw std::runtime_error("Unknown field \"" + node.field_name + "\" in struct \"" + named_type->name + "\"");
apply(*node.object);
auto stack_size = ((struct_node->layout.size + 15) / 16) * 16;
auto const & field = struct_node->fields[field_id];
if (types::is_unit_type(*field.inferred_type))
{}
else if (types::is_floating_point_type(*field.inferred_type))
{
builder.ldur_fp(0, fp_mode_for(*field.inferred_type), 31, field.layout.offset);
builder.add_imm(31, 31, stack_size);
stack_offset -= stack_size;
scopes.back().stack_offset -= stack_size;
}
else if (types::is_bool_type(*field.inferred_type) || types::is_integer_type(*field.inferred_type) || types::is_function_type(*field.inferred_type))
{
auto size = types::type_size(*field.inferred_type);
if (size == 1)
builder.ldurb(0, 31, field.layout.offset);
else if (size == 2)
builder.ldurh(0, 31, field.layout.offset);
else if (size == 4)
builder.ldurw(0, 31, field.layout.offset);
else if (size == 8)
builder.ldur(0, 31, field.layout.offset);
builder.add_imm(31, 31, stack_size);
stack_offset -= stack_size;
scopes.back().stack_offset -= stack_size;
}
else if (auto named_type = std::get_if<types::named_type>(field.inferred_type.get()))
{
if (auto field_struct_node = lcontext.is_struct(named_type->name))
{
// TODO: copy the struct-typed field on stack, overriding
// the struct itself, and update the stack offset
throw std::runtime_error("Not implemented");
}
else
throw std::runtime_error("Unknown type \"" + named_type->name + "\"");
}
return;
} }
} }
std::uint8_t reg = 0; throw std::runtime_error("Unknown object in field access");
std::uint8_t fp_reg = 0;
for (auto const & arg : node.arguments)
{
auto type = ast::get_type(*arg);
if (types::is_bool_type(*type) || types::is_integer_type(*type))
{
pop(reg);
++reg;
}
else if (types::is_floating_point_type(*type))
{
pop_fp(fp_reg, fp_mode_for(*type));
++fp_reg;
}
}
pop(reg);
push(30);
builder.b_reg(reg);
pop(30);
} }
void apply(ast::expression_ptr const & node) void apply(ast::expression_ptr const & node)
{ {
auto stack_offset_before = stack_offset;
apply(*node); apply(*node);
// Restore stack offset in case the expression evaluated to a struct
// (in which case the struct would be placed on the stack)
auto stack_delta = stack_offset - stack_offset_before;
if (stack_delta > 0)
{
builder.add_imm(31, 31, stack_delta);
stack_offset -= stack_delta;
scopes.back().stack_offset -= stack_delta;
}
} }
void apply(ast::assignment const & node) void apply(ast::assignment const & node)
{ {
auto identifier = std::get_if<ast::identifier>(node.lhs.get()); auto frame_offset = lvalue_offset(node.lhs);
if (!identifier)
throw std::runtime_error("assignment for non-identifier lhs is not implemented");
apply(*node.rhs); apply(*node.rhs);
auto type = ast::get_type(*node.rhs);
for (auto it = scopes.rbegin(); it != scopes.rend(); ++it) if (types::is_unit_type(*type))
{}
else if (types::is_floating_point_type(*type))
builder.str_fp(0, fp_mode_for(*type), 31, (stack_offset - frame_offset) / type_size(*type));
else if (types::is_bool_type(*type) || types::is_integer_type(*type) || types::is_function_type(*type))
{ {
if (auto jt = it->variables.find(identifier->name); jt != it->variables.end()) auto size = types::type_size(*type);
if (size == 1)
builder.sturb(0, 31, stack_offset - frame_offset);
else if (size == 2)
builder.sturh(0, 31, stack_offset - frame_offset);
else if (size == 4)
builder.sturw(0, 31, stack_offset - frame_offset);
else if (size == 8)
builder.stur(0, 31, stack_offset - frame_offset);
}
else if (auto named_type = std::get_if<types::named_type>(type.get()))
{
if (auto struct_node = lcontext.is_struct(named_type->name))
{ {
auto type = ast::get_type(*node.rhs); // TODO: whole-struct assignment
if (types::is_floating_point_type(*type)) throw std::runtime_error("Not implemented");
builder.str_fp(0, fp_mode_for(*type), 31, (stack_offset - jt->second.frame_offset) / builtin_type_size(*type));
else
builder.str(0, 31, (stack_offset - jt->second.frame_offset) / 8);
break;
} }
else
throw std::runtime_error("Unknown type \"" + named_type->name + "\"");
} }
} }
@ -839,8 +1116,18 @@ namespace pslang::jit::aarch64
{ {
apply(*node.initializer); apply(*node.initializer);
auto type = ast::get_type(*node.initializer); auto type = ast::get_type(*node.initializer);
if (types::is_floating_point_type(*type)) if (std::get_if<types::named_type>(type.get()))
{
// Nothing to be done: the struct is already on the stack
// Just record the stack offset as variable location
}
else if (types::is_floating_point_type(*type))
push_fp(0, fp_mode_for(*type)); push_fp(0, fp_mode_for(*type));
else if (types::is_unit_type(*type))
{
// Nothing to be done: unit type has zero size
// Its stack position is recorded but de facto unused
}
else else
push(0); push(0);
scopes.back().variables[node.name] = {.frame_offset = stack_offset}; scopes.back().variables[node.name] = {.frame_offset = stack_offset};
@ -909,6 +1196,7 @@ namespace pslang::jit::aarch64
void apply(ast::return_statement const & node) void apply(ast::return_statement const & node)
{ {
// TODO: struct return value
if (node.value) if (node.value)
apply(*node.value); apply(*node.value);
do_return(); do_return();
@ -924,6 +1212,11 @@ namespace pslang::jit::aarch64
// Must be handled prior to that in populate_symbols_visitor // Must be handled prior to that in populate_symbols_visitor
} }
void apply(ast::struct_definition const &)
{
// Must be handled prior to that in populate_symbols_visitor
}
void apply(ast::statement_list const & node) void apply(ast::statement_list const & node)
{ {
lcontext.scopes.emplace_back(); lcontext.scopes.emplace_back();
@ -1036,6 +1329,51 @@ namespace pslang::jit::aarch64
stack_offset -= scopes.back().stack_offset; stack_offset -= scopes.back().stack_offset;
} }
} }
// Returns offset from function entry stack frame
std::size_t lvalue_offset(ast::expression_ptr const & node)
{
if (auto identifier = std::get_if<ast::identifier>(node.get()))
{
auto const & scope = scopes[identifier->level];
if (auto it = scope.variables.find(identifier->name); it != scope.variables.end())
return it->second.frame_offset;
throw std::runtime_error("Non-lvalue identifier: \"" + identifier->name + "\"");
}
else if (auto field_access = std::get_if<ast::field_access>(node.get()))
{
auto base_offset = lvalue_offset(field_access->object);
auto type = ast::get_type(*field_access->object);
if (auto named_type = std::get_if<types::named_type>(type.get()))
{
if (auto struct_node = lcontext.is_struct(named_type->name))
{
std::size_t field_id = -1;
for (std::size_t i = 0; i < struct_node->fields.size(); ++i)
if (struct_node->fields[i].name == field_access->field_name)
{
field_id = i;
break;
}
if (field_id == -1)
throw std::runtime_error("Invalid field \"" + field_access->field_name + "\"");
return base_offset - struct_node->fields[field_id].layout.offset;
}
else
throw std::runtime_error("Invalid struct \"" + named_type->name + "\"");
}
else
throw std::runtime_error("Invalid field access node");
}
else if (auto array_access = std::get_if<ast::array_access>(node.get()))
{
throw std::runtime_error("Not implemented");
}
throw std::runtime_error("Unknown lvalue node");
}
}; };
// Main compilation visitor // Main compilation visitor

View file

@ -25,6 +25,26 @@ namespace pslang::jit::aarch64
do_push(0xf8000c00u | (reg_src & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0x1ffu) << 12)); do_push(0xf8000c00u | (reg_src & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0x1ffu) << 12));
} }
void instruction_builder::stur(std::uint8_t reg_src, std::uint8_t reg_addr, std::int16_t offset)
{
do_push(0xf8000000u | (reg_src & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0x1ffu) << 12));
}
void instruction_builder::sturw(std::uint8_t reg_src, std::uint8_t reg_addr, std::int16_t offset)
{
do_push(0xb8000000u | (reg_src & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0x1ffu) << 12));
}
void instruction_builder::sturh(std::uint8_t reg_src, std::uint8_t reg_addr, std::int16_t offset)
{
do_push(0x78000000u | (reg_src & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0x1ffu) << 12));
}
void instruction_builder::sturb(std::uint8_t reg_src, std::uint8_t reg_addr, std::int16_t offset)
{
do_push(0x38000000u | (reg_src & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0x1ffu) << 12));
}
void instruction_builder::ldr(std::uint8_t reg_dst, std::uint8_t reg_addr, std::uint16_t offset) void instruction_builder::ldr(std::uint8_t reg_dst, std::uint8_t reg_addr, std::uint16_t offset)
{ {
do_push(0xf9400000u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0xfffu) << 10)); do_push(0xf9400000u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0xfffu) << 10));
@ -40,6 +60,26 @@ namespace pslang::jit::aarch64
do_push(0xf8400400u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0x1ffu) << 12)); do_push(0xf8400400u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0x1ffu) << 12));
} }
void instruction_builder::ldur(std::uint8_t reg_dst, std::uint8_t reg_addr, std::int16_t offset)
{
do_push(0xf8400000u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0x1ffu) << 12));
}
void instruction_builder::ldurw(std::uint8_t reg_dst, std::uint8_t reg_addr, std::int16_t offset)
{
do_push(0xb8400000u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0x1ffu) << 12));
}
void instruction_builder::ldurh(std::uint8_t reg_dst, std::uint8_t reg_addr, std::int16_t offset)
{
do_push(0x78400000u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0x1ffu) << 12));
}
void instruction_builder::ldurb(std::uint8_t reg_dst, std::uint8_t reg_addr, std::int16_t offset)
{
do_push(0x38400000u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0x1ffu) << 12));
}
void instruction_builder::ldr_pc(std::uint8_t reg_dst, std::int32_t offset) void instruction_builder::ldr_pc(std::uint8_t reg_dst, std::int32_t offset)
{ {
do_push(0x58000000u | (reg_dst & REG_MASK) | ((((std::uint32_t)offset) & 0x7ffffu) << 5)); do_push(0x58000000u | (reg_dst & REG_MASK) | ((((std::uint32_t)offset) & 0x7ffffu) << 5));
@ -194,11 +234,21 @@ namespace pslang::jit::aarch64
do_push(0x3d400000u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((offset & 0xfffu) << 10) | ((mode & 0x3u) << 30)); do_push(0x3d400000u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((offset & 0xfffu) << 10) | ((mode & 0x3u) << 30));
} }
void instruction_builder::ldur_fp(std::uint8_t reg_dst, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset)
{
do_push(0x3c400000u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0x1ffu) << 12) | ((mode & 0x3u) << 30));
}
void instruction_builder::str_fp(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset) void instruction_builder::str_fp(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset)
{ {
do_push(0x3d000000u | (reg_src & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((offset & 0xfffu) << 10) | ((mode & 0x3u) << 30)); do_push(0x3d000000u | (reg_src & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((offset & 0xfffu) << 10) | ((mode & 0x3u) << 30));
} }
void instruction_builder::stur_fp(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_addr, std::int16_t offset)
{
do_push(0x3c000000u | (reg_src & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0x1ffu) << 12) | ((mode & 0x3u) << 30));
}
void instruction_builder::fcvt(std::uint8_t reg_src, std::uint8_t mode_src, std::uint8_t reg_dst, std::uint8_t mode_dst) void instruction_builder::fcvt(std::uint8_t reg_src, std::uint8_t mode_src, std::uint8_t reg_dst, std::uint8_t mode_dst)
{ {
do_push(0x1e224000u | (reg_dst & REG_MASK) | ((reg_src & REG_MASK) << 5) | (((mode_dst & 0x3u) ^ 0x2u) << 15) | (((mode_src & 0x3u) ^ 0x2u) << 22)); do_push(0x1e224000u | (reg_dst & REG_MASK) | ((reg_src & REG_MASK) << 5) | (((mode_dst & 0x3u) ^ 0x2u) << 15) | (((mode_src & 0x3u) ^ 0x2u) << 22));
@ -236,7 +286,7 @@ namespace pslang::jit::aarch64
void instruction_builder::fmov(std::uint8_t reg_src, std::uint8_t reg_dst, std::uint8_t mode, std::uint8_t op) void instruction_builder::fmov(std::uint8_t reg_src, std::uint8_t reg_dst, std::uint8_t mode, std::uint8_t op)
{ {
do_push(0x9e260000u | (reg_dst & REG_MASK) | ((reg_src & REG_MASK) << 5) | (((mode & 0x3u) ^ 0x2u) << 22) | ((op & 0x1u) << 16)); do_push(0x1e260000u | (reg_dst & REG_MASK) | ((reg_src & REG_MASK) << 5) | (mode == 2 ? 0u : 0x8000000u) | (((mode & 0x3u) ^ 0x2u) << 22) | ((op & 0x1u) << 16));
} }
void instruction_builder::scvtf(std::uint8_t reg_src, std::uint8_t reg_dst, std::uint8_t mode) void instruction_builder::scvtf(std::uint8_t reg_src, std::uint8_t reg_dst, std::uint8_t mode)

View file

@ -24,6 +24,6 @@ namespace pslang::types
bool is_builtin_type(type const & type); bool is_builtin_type(type const & type);
bool is_function_type(type const & type); bool is_function_type(type const & type);
std::size_t builtin_type_size(type const & type); std::size_t type_size(type const & type);
} }

View file

@ -130,10 +130,13 @@ namespace pslang::types
return false; return false;
} }
std::size_t builtin_type_size(type const & type) std::size_t type_size(type const & type)
{ {
if (std::get_if<unit_type>(&type)) if (std::get_if<unit_type>(&type))
return 1; return 0;
if (std::get_if<function_type>(&type))
return 8;
if (auto ptype = std::get_if<primitive_type>(&type)) if (auto ptype = std::get_if<primitive_type>(&type))
{ {
@ -148,7 +151,7 @@ namespace pslang::types
}, *ptype); }, *ptype);
} }
return 0; throw std::runtime_error("Unknown type");
} }
} }

View file

@ -1,4 +1,5 @@
Future plans: Future plans:
* Split compiler into IR generation, IR optimization, and target-specific bytecode emitter
* Globals (requires a separate mmaped segment in JIT compiler) * Globals (requires a separate mmaped segment in JIT compiler)
* Pointers: pointer types, address-of operator (&), dereferencing, scope-based lifetime tracking in interpreter * Pointers: pointer types, address-of operator (&), dereferencing, scope-based lifetime tracking in interpreter
* Function overloading: separate functions from values (again) in interpreter, allow casting to specific function type to take function value * Function overloading: separate functions from values (again) in interpreter, allow casting to specific function type to take function value
@ -15,9 +16,26 @@ Interpreter backlog:
* C FFI (foreign functions) * C FFI (foreign functions)
Aarch64 compiler backlog: Aarch64 compiler backlog:
* Structs * Struct fields in structs (initialization & field access)
* Struct function arguments & return values
* Arrays * Arrays
IR outline:
* Doubly-linked list (std::list will do) of instruction nodes
* Nodes reference other nodes for input data or branching target
* No temporary variables, nodes themselves represent data flow
* Assignment expressions turn into nodes that assign to previously-defined nodes
* Jump targets (first `while` node, node after an `if`, etc) are generated as nop nodes, to prevent them from being removed by the optimizer
* Structs & arrays can be split into several independent nodes per-field (before pointers are introduced)
IR optimizations:
* Inlining: track function IR size (number of nodes will do), substitute its code instead of calling if it is small enough (pay attention to recursion)
* Constant folding & propagation: if all node arguments are `const` nodes, replace the current node with the computed value
* Arithmetic simplification: replace a+0 with a, etc
* Branch removal: `if` nodes with condition nodes being `const` are replaced with unconditional jumps
* Jump removal: `jump` that jumps to the immediate successor node is removed
* Dead code elimination: DFS from function `return` nodes & impure function calls, remove all nodes not visited (make sure to not remove function arguments)
General backlog: General backlog:
* Mutually recursive structs (relevant only with pointers) * Mutually recursive structs (relevant only with pointers)
* Empty array expression * Empty array expression