diff --git a/apps/interpreter/source/main.cpp b/apps/interpreter/source/main.cpp index 4932c36..45f425b 100644 --- a/apps/interpreter/source/main.cpp +++ b/apps/interpreter/source/main.cpp @@ -14,8 +14,6 @@ #include #include -#include - std::string extract_nth_line(std::filesystem::path const & path, std::size_t n) { std::ifstream file(path); @@ -171,9 +169,9 @@ int main(int argc, char ** argv) for (auto const & module : modules) { // TODO: remove, testing-only code; should execute entry point instead - auto offset = module.code.symbol_table.at("test"); - auto fptr = (float(*)(float))(module.code.memory.data.get() + offset); - auto x = fptr(6.f); + auto offset = module.code.symbol_table.at("pow"); + auto fptr = (float(*)(float, unsigned))(module.code.memory.data.get() + offset); + auto x = fptr(1.5f, 12); std::cout << "Result: " << std::boolalpha << x << std::endl; } } diff --git a/examples/jit_test.psl b/examples/jit_test.psl index 6d87861..2f9beda 100644 --- a/examples/jit_test.psl +++ b/examples/jit_test.psl @@ -1,2 +1,10 @@ -func test(x : f32) -> f32: - return x * 1.5 +func pow(x : f32, n : u32) -> f32: + mut r = 1.0 + mut f = x // x^k + mut k = 1u + while k <= n: + if (n & k) != 0u: + r = r * f + f = f * f + k = k + k + return r diff --git a/libs/jit/source/arch/aarch64/compiler.cpp b/libs/jit/source/arch/aarch64/compiler.cpp index 5c9188d..2dcbde7 100644 --- a/libs/jit/source/arch/aarch64/compiler.cpp +++ b/libs/jit/source/arch/aarch64/compiler.cpp @@ -333,7 +333,10 @@ namespace pslang::jit::aarch64 { if (auto jt = it->variables.find(node.name); jt != it->variables.end()) { - builder.ldr(0, 31, (stack_offset - jt->second.frame_offset) / 8); + if (types::is_floating_point_type(*node.inferred_type)) + builder.ldr_fp(0, fp_mode_for(*node.inferred_type), 31, (stack_offset - jt->second.frame_offset) / 4); + else + builder.ldr(0, 31, (stack_offset - jt->second.frame_offset) / 8); break; } } @@ -537,7 +540,11 @@ namespace pslang::jit::aarch64 { if (auto jt = it->variables.find(identifier->name); jt != it->variables.end()) { - builder.str(0, 31, (stack_offset - jt->second.frame_offset) / 8); + auto type = ast::get_type(*node.rhs); + if (types::is_floating_point_type(*type)) + builder.str_fp(0, fp_mode_for(*type), 31, (stack_offset - jt->second.frame_offset) / 4); + else + builder.str(0, 31, (stack_offset - jt->second.frame_offset) / 8); break; } } @@ -546,7 +553,11 @@ namespace pslang::jit::aarch64 void apply(ast::variable_declaration const & node) { apply(*node.initializer); - push(0); + auto type = ast::get_type(*node.initializer); + if (types::is_floating_point_type(*type)) + push_fp(0, fp_mode_for(*type)); + else + push(0); scopes.back().variables[node.name] = {.frame_offset = stack_offset}; }