Add f16 type & literals support in parser, type checker & interpreter

This commit is contained in:
Nikita Lisitsa 2026-01-05 00:36:11 +03:00
parent 0d87d35c47
commit 7ddc8ba25d
10 changed files with 87 additions and 3 deletions

View file

@ -1,5 +1,6 @@
#pragma once #pragma once
#include <pslang/types/half_float.hpp>
#include <pslang/ast/location.hpp> #include <pslang/ast/location.hpp>
#include <variant> #include <variant>
@ -26,6 +27,7 @@ namespace pslang::ast
using i64_literal = primitive_literal_base<std::int64_t>; using i64_literal = primitive_literal_base<std::int64_t>;
using u64_literal = primitive_literal_base<std::uint64_t>; using u64_literal = primitive_literal_base<std::uint64_t>;
using f16_literal = primitive_literal_base<types::half_float>;
using f32_literal = primitive_literal_base<float>; using f32_literal = primitive_literal_base<float>;
using f64_literal = primitive_literal_base<double>; using f64_literal = primitive_literal_base<double>;
@ -39,6 +41,7 @@ namespace pslang::ast
u32_literal, u32_literal,
i64_literal, i64_literal,
u64_literal, u64_literal,
f16_literal,
f32_literal, f32_literal,
f64_literal f64_literal
>; >;

View file

@ -148,6 +148,12 @@ namespace pslang::ast
out << "u64 literal { value = " << node.value << " }\n"; out << "u64 literal { value = " << node.value << " }\n";
} }
void apply(f16_literal const & node)
{
put_indent(out, options);
out << "f32 literal { value = " << std::setprecision(3) << node.value.repr << " }\n";
}
void apply(f32_literal const & node) void apply(f32_literal const & node)
{ {
put_indent(out, options); put_indent(out, options);

View file

@ -34,6 +34,7 @@ namespace pslang::interpreter
using i64_value = primitive_value_base<std::int64_t>; using i64_value = primitive_value_base<std::int64_t>;
using u64_value = primitive_value_base<std::uint64_t>; using u64_value = primitive_value_base<std::uint64_t>;
using f16_value = primitive_value_base<types::half_float>;
using f32_value = primitive_value_base<float>; using f32_value = primitive_value_base<float>;
using f64_value = primitive_value_base<double>; using f64_value = primitive_value_base<double>;
@ -47,6 +48,7 @@ namespace pslang::interpreter
u32_value, u32_value,
i64_value, i64_value,
u64_value, u64_value,
f16_value,
f32_value, f32_value,
f64_value f64_value
>; >;

View file

@ -336,7 +336,18 @@ namespace pslang::interpreter
} }
else if constexpr (!std::is_same_v<T, bool> && !std::is_same_v<bool, H>) else if constexpr (!std::is_same_v<T, bool> && !std::is_same_v<bool, H>)
{ {
return primitive_value(primitive_value_base<H>{static_cast<H>(value.value)}); if constexpr (std::is_same_v<T, types::half_float>)
{
return primitive_value(primitive_value_base<H>{static_cast<H>(value.value.repr)});
}
else if constexpr (std::is_same_v<H, types::half_float>)
{
return primitive_value(primitive_value_base<H>{{static_cast<float>(value.value)}});
}
else
{
return primitive_value(primitive_value_base<H>{static_cast<H>(value.value)});
}
} }
std::ostringstream os; std::ostringstream os;

View file

@ -46,6 +46,7 @@ i32 { return bp::make_i32(ctx.location); }
u32 { return bp::make_u32(ctx.location); } u32 { return bp::make_u32(ctx.location); }
i64 { return bp::make_i64(ctx.location); } i64 { return bp::make_i64(ctx.location); }
u64 { return bp::make_u64(ctx.location); } u64 { return bp::make_u64(ctx.location); }
f16 { return bp::make_f16(ctx.location); }
f32 { return bp::make_f32(ctx.location); } f32 { return bp::make_f32(ctx.location); }
f64 { return bp::make_f64(ctx.location); } f64 { return bp::make_f64(ctx.location); }
@ -88,6 +89,7 @@ f64 { return bp::make_f64(ctx.location); }
[0-9]+ul { return bp::make_lit_u64(yytext, ctx.location); } [0-9]+ul { return bp::make_lit_u64(yytext, ctx.location); }
[0-9]+\.[0-9]+ { return bp::make_lit_f32(yytext, ctx.location); } [0-9]+\.[0-9]+ { return bp::make_lit_f32(yytext, ctx.location); }
[0-9]+\.[0-9]+h { return bp::make_lit_f16(yytext, ctx.location); }
[0-9]+\.[0-9]+l { return bp::make_lit_f64(yytext, ctx.location); } [0-9]+\.[0-9]+l { return bp::make_lit_f64(yytext, ctx.location); }
'.' { return bp::make_lit_char8(yytext[1], ctx.location); } '.' { return bp::make_lit_char8(yytext[1], ctx.location); }

View file

@ -48,7 +48,9 @@ template <typename T>
{ {
T value; T value;
#ifdef __clang__ #ifdef __clang__
if constexpr (std::is_floating_point_v<T>) if constexpr (std::is_same_v<T, ::pslang::types::half_float>)
value = {float(std::atof(str.data()))};
else if constexpr (std::is_floating_point_v<T>)
value = std::atof(str.data()); value = std::atof(str.data());
else if constexpr (std::is_signed_v<T>) else if constexpr (std::is_signed_v<T>)
value = std::strtoll(str.data(), nullptr, 10); value = std::strtoll(str.data(), nullptr, 10);
@ -137,6 +139,7 @@ template <typename T>
%token u32 %token u32
%token i64 %token i64
%token u64 %token u64
%token f16
%token f32 %token f32
%token f64 %token f64
@ -259,6 +262,7 @@ primitive_type
| u32 { $$ = types::u32_type{}; } | u32 { $$ = types::u32_type{}; }
| i64 { $$ = types::i64_type{}; } | i64 { $$ = types::i64_type{}; }
| u64 { $$ = types::u64_type{}; } | u64 { $$ = types::u64_type{}; }
| f16 { $$ = types::f16_type{}; }
| f32 { $$ = types::f32_type{}; } | f32 { $$ = types::f32_type{}; }
| f64 { $$ = types::f64_type{}; } | f64 { $$ = types::f64_type{}; }
; ;
@ -309,6 +313,7 @@ postfix_expression
| u32 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::u32_type{}), $3, @$}; } | u32 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::u32_type{}), $3, @$}; }
| i64 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::i64_type{}), $3, @$}; } | i64 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::i64_type{}), $3, @$}; }
| u64 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::u64_type{}), $3, @$}; } | u64 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::u64_type{}), $3, @$}; }
| f16 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::f16_type{}), $3, @$}; }
| f32 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::f32_type{}), $3, @$}; } | f32 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::f32_type{}), $3, @$}; }
| f64 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::f64_type{}), $3, @$}; } | f64 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::f64_type{}), $3, @$}; }
; ;
@ -341,6 +346,7 @@ literal
| lit_u32 { $$ = parse_primitive_literal<std::uint32_t>($1, ctx.location); } | lit_u32 { $$ = parse_primitive_literal<std::uint32_t>($1, ctx.location); }
| lit_i64 { $$ = parse_primitive_literal<std::int64_t>($1, ctx.location); } | lit_i64 { $$ = parse_primitive_literal<std::int64_t>($1, ctx.location); }
| lit_u64 { $$ = parse_primitive_literal<std::uint64_t>($1, ctx.location); } | lit_u64 { $$ = parse_primitive_literal<std::uint64_t>($1, ctx.location); }
| lit_f16 { $$ = parse_primitive_literal<types::half_float>($1, ctx.location); }
| lit_f32 { $$ = parse_primitive_literal<float>($1, ctx.location); } | lit_f32 { $$ = parse_primitive_literal<float>($1, ctx.location); }
| lit_f64 { $$ = parse_primitive_literal<double>($1, ctx.location); } | lit_f64 { $$ = parse_primitive_literal<double>($1, ctx.location); }
| lit_char8 { $$ = ast::u8_literal{static_cast<std::uint8_t>($1), ctx.location}; } | lit_char8 { $$ = ast::u8_literal{static_cast<std::uint8_t>($1), ctx.location}; }

View file

@ -0,0 +1,45 @@
#pragma once
#include <pslang/types/half_float.hpp>
#include <compare>
namespace pslang::types
{
// TODO: actual half-float operations? Maybe C++23 std::float16_t?
struct half_float
{
float repr = 0.f;
friend bool operator == (half_float const &, half_float const &) = default;
friend auto operator <=> (half_float const &, half_float const &) = default;
};
inline half_float operator - (half_float f)
{
return {-f.repr};
}
inline half_float operator + (half_float f1, half_float f2)
{
return {f1.repr + f2.repr};
}
inline half_float operator - (half_float f1, half_float f2)
{
return {f1.repr - f2.repr};
}
inline half_float operator * (half_float f1, half_float f2)
{
return {f1.repr * f2.repr};
}
inline half_float operator / (half_float f1, half_float f2)
{
return {f1.repr / f2.repr};
}
}

View file

@ -1,5 +1,7 @@
#pragma once #pragma once
#include <pslang/types/half_float.hpp>
#include <variant> #include <variant>
#include <cstdint> #include <cstdint>
@ -29,6 +31,7 @@ namespace pslang::types
using i64_type = primitive_type_base<std::int64_t>; using i64_type = primitive_type_base<std::int64_t>;
using u64_type = primitive_type_base<std::uint64_t>; using u64_type = primitive_type_base<std::uint64_t>;
using f16_type = primitive_type_base<half_float>;
using f32_type = primitive_type_base<float>; using f32_type = primitive_type_base<float>;
using f64_type = primitive_type_base<double>; using f64_type = primitive_type_base<double>;
@ -42,6 +45,7 @@ namespace pslang::types
u32_type, u32_type,
i64_type, i64_type,
u64_type, u64_type,
f16_type,
f32_type, f32_type,
f64_type f64_type
>; >;

View file

@ -64,6 +64,11 @@ namespace pslang::types
out << "u64"; out << "u64";
} }
void apply(f16_type const &)
{
out << "f16";
}
void apply(f32_type const &) void apply(f32_type const &)
{ {
out << "f32"; out << "f32";

View file

@ -93,7 +93,7 @@ namespace pslang::types
{ {
return std::visit([]<typename T>(primitive_type_base<T> const &) return std::visit([]<typename T>(primitive_type_base<T> const &)
{ {
return std::is_floating_point_v<T>; return std::is_floating_point_v<T> || std::is_same_v<T, types::half_float>;
}, *ptype); }, *ptype);
} }