Add f16 type & literals support in parser, type checker & interpreter
This commit is contained in:
parent
0d87d35c47
commit
7ddc8ba25d
10 changed files with 87 additions and 3 deletions
|
|
@ -1,5 +1,6 @@
|
|||
#pragma once
|
||||
|
||||
#include <pslang/types/half_float.hpp>
|
||||
#include <pslang/ast/location.hpp>
|
||||
|
||||
#include <variant>
|
||||
|
|
@ -26,6 +27,7 @@ namespace pslang::ast
|
|||
using i64_literal = primitive_literal_base<std::int64_t>;
|
||||
using u64_literal = primitive_literal_base<std::uint64_t>;
|
||||
|
||||
using f16_literal = primitive_literal_base<types::half_float>;
|
||||
using f32_literal = primitive_literal_base<float>;
|
||||
using f64_literal = primitive_literal_base<double>;
|
||||
|
||||
|
|
@ -39,6 +41,7 @@ namespace pslang::ast
|
|||
u32_literal,
|
||||
i64_literal,
|
||||
u64_literal,
|
||||
f16_literal,
|
||||
f32_literal,
|
||||
f64_literal
|
||||
>;
|
||||
|
|
|
|||
|
|
@ -148,6 +148,12 @@ namespace pslang::ast
|
|||
out << "u64 literal { value = " << node.value << " }\n";
|
||||
}
|
||||
|
||||
void apply(f16_literal const & node)
|
||||
{
|
||||
put_indent(out, options);
|
||||
out << "f32 literal { value = " << std::setprecision(3) << node.value.repr << " }\n";
|
||||
}
|
||||
|
||||
void apply(f32_literal const & node)
|
||||
{
|
||||
put_indent(out, options);
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ namespace pslang::interpreter
|
|||
using i64_value = primitive_value_base<std::int64_t>;
|
||||
using u64_value = primitive_value_base<std::uint64_t>;
|
||||
|
||||
using f16_value = primitive_value_base<types::half_float>;
|
||||
using f32_value = primitive_value_base<float>;
|
||||
using f64_value = primitive_value_base<double>;
|
||||
|
||||
|
|
@ -47,6 +48,7 @@ namespace pslang::interpreter
|
|||
u32_value,
|
||||
i64_value,
|
||||
u64_value,
|
||||
f16_value,
|
||||
f32_value,
|
||||
f64_value
|
||||
>;
|
||||
|
|
|
|||
|
|
@ -336,7 +336,18 @@ namespace pslang::interpreter
|
|||
}
|
||||
else if constexpr (!std::is_same_v<T, bool> && !std::is_same_v<bool, H>)
|
||||
{
|
||||
return primitive_value(primitive_value_base<H>{static_cast<H>(value.value)});
|
||||
if constexpr (std::is_same_v<T, types::half_float>)
|
||||
{
|
||||
return primitive_value(primitive_value_base<H>{static_cast<H>(value.value.repr)});
|
||||
}
|
||||
else if constexpr (std::is_same_v<H, types::half_float>)
|
||||
{
|
||||
return primitive_value(primitive_value_base<H>{{static_cast<float>(value.value)}});
|
||||
}
|
||||
else
|
||||
{
|
||||
return primitive_value(primitive_value_base<H>{static_cast<H>(value.value)});
|
||||
}
|
||||
}
|
||||
|
||||
std::ostringstream os;
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ i32 { return bp::make_i32(ctx.location); }
|
|||
u32 { return bp::make_u32(ctx.location); }
|
||||
i64 { return bp::make_i64(ctx.location); }
|
||||
u64 { return bp::make_u64(ctx.location); }
|
||||
f16 { return bp::make_f16(ctx.location); }
|
||||
f32 { return bp::make_f32(ctx.location); }
|
||||
f64 { return bp::make_f64(ctx.location); }
|
||||
|
||||
|
|
@ -88,6 +89,7 @@ f64 { return bp::make_f64(ctx.location); }
|
|||
[0-9]+ul { return bp::make_lit_u64(yytext, ctx.location); }
|
||||
|
||||
[0-9]+\.[0-9]+ { return bp::make_lit_f32(yytext, ctx.location); }
|
||||
[0-9]+\.[0-9]+h { return bp::make_lit_f16(yytext, ctx.location); }
|
||||
[0-9]+\.[0-9]+l { return bp::make_lit_f64(yytext, ctx.location); }
|
||||
|
||||
'.' { return bp::make_lit_char8(yytext[1], ctx.location); }
|
||||
|
|
|
|||
|
|
@ -48,7 +48,9 @@ template <typename T>
|
|||
{
|
||||
T value;
|
||||
#ifdef __clang__
|
||||
if constexpr (std::is_floating_point_v<T>)
|
||||
if constexpr (std::is_same_v<T, ::pslang::types::half_float>)
|
||||
value = {float(std::atof(str.data()))};
|
||||
else if constexpr (std::is_floating_point_v<T>)
|
||||
value = std::atof(str.data());
|
||||
else if constexpr (std::is_signed_v<T>)
|
||||
value = std::strtoll(str.data(), nullptr, 10);
|
||||
|
|
@ -137,6 +139,7 @@ template <typename T>
|
|||
%token u32
|
||||
%token i64
|
||||
%token u64
|
||||
%token f16
|
||||
%token f32
|
||||
%token f64
|
||||
|
||||
|
|
@ -259,6 +262,7 @@ primitive_type
|
|||
| u32 { $$ = types::u32_type{}; }
|
||||
| i64 { $$ = types::i64_type{}; }
|
||||
| u64 { $$ = types::u64_type{}; }
|
||||
| f16 { $$ = types::f16_type{}; }
|
||||
| f32 { $$ = types::f32_type{}; }
|
||||
| f64 { $$ = types::f64_type{}; }
|
||||
;
|
||||
|
|
@ -309,6 +313,7 @@ postfix_expression
|
|||
| u32 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::u32_type{}), $3, @$}; }
|
||||
| i64 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::i64_type{}), $3, @$}; }
|
||||
| u64 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::u64_type{}), $3, @$}; }
|
||||
| f16 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::f16_type{}), $3, @$}; }
|
||||
| f32 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::f32_type{}), $3, @$}; }
|
||||
| f64 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::f64_type{}), $3, @$}; }
|
||||
;
|
||||
|
|
@ -341,6 +346,7 @@ literal
|
|||
| lit_u32 { $$ = parse_primitive_literal<std::uint32_t>($1, ctx.location); }
|
||||
| lit_i64 { $$ = parse_primitive_literal<std::int64_t>($1, ctx.location); }
|
||||
| lit_u64 { $$ = parse_primitive_literal<std::uint64_t>($1, ctx.location); }
|
||||
| lit_f16 { $$ = parse_primitive_literal<types::half_float>($1, ctx.location); }
|
||||
| lit_f32 { $$ = parse_primitive_literal<float>($1, ctx.location); }
|
||||
| lit_f64 { $$ = parse_primitive_literal<double>($1, ctx.location); }
|
||||
| lit_char8 { $$ = ast::u8_literal{static_cast<std::uint8_t>($1), ctx.location}; }
|
||||
|
|
|
|||
45
libs/types/include/pslang/types/half_float.hpp
Normal file
45
libs/types/include/pslang/types/half_float.hpp
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
#pragma once
|
||||
|
||||
#include <pslang/types/half_float.hpp>
|
||||
|
||||
#include <compare>
|
||||
|
||||
namespace pslang::types
|
||||
{
|
||||
|
||||
// TODO: actual half-float operations? Maybe C++23 std::float16_t?
|
||||
|
||||
struct half_float
|
||||
{
|
||||
float repr = 0.f;
|
||||
|
||||
friend bool operator == (half_float const &, half_float const &) = default;
|
||||
friend auto operator <=> (half_float const &, half_float const &) = default;
|
||||
};
|
||||
|
||||
inline half_float operator - (half_float f)
|
||||
{
|
||||
return {-f.repr};
|
||||
}
|
||||
|
||||
inline half_float operator + (half_float f1, half_float f2)
|
||||
{
|
||||
return {f1.repr + f2.repr};
|
||||
}
|
||||
|
||||
inline half_float operator - (half_float f1, half_float f2)
|
||||
{
|
||||
return {f1.repr - f2.repr};
|
||||
}
|
||||
|
||||
inline half_float operator * (half_float f1, half_float f2)
|
||||
{
|
||||
return {f1.repr * f2.repr};
|
||||
}
|
||||
|
||||
inline half_float operator / (half_float f1, half_float f2)
|
||||
{
|
||||
return {f1.repr / f2.repr};
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -1,5 +1,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <pslang/types/half_float.hpp>
|
||||
|
||||
#include <variant>
|
||||
#include <cstdint>
|
||||
|
||||
|
|
@ -29,6 +31,7 @@ namespace pslang::types
|
|||
using i64_type = primitive_type_base<std::int64_t>;
|
||||
using u64_type = primitive_type_base<std::uint64_t>;
|
||||
|
||||
using f16_type = primitive_type_base<half_float>;
|
||||
using f32_type = primitive_type_base<float>;
|
||||
using f64_type = primitive_type_base<double>;
|
||||
|
||||
|
|
@ -42,6 +45,7 @@ namespace pslang::types
|
|||
u32_type,
|
||||
i64_type,
|
||||
u64_type,
|
||||
f16_type,
|
||||
f32_type,
|
||||
f64_type
|
||||
>;
|
||||
|
|
|
|||
|
|
@ -64,6 +64,11 @@ namespace pslang::types
|
|||
out << "u64";
|
||||
}
|
||||
|
||||
void apply(f16_type const &)
|
||||
{
|
||||
out << "f16";
|
||||
}
|
||||
|
||||
void apply(f32_type const &)
|
||||
{
|
||||
out << "f32";
|
||||
|
|
|
|||
|
|
@ -93,7 +93,7 @@ namespace pslang::types
|
|||
{
|
||||
return std::visit([]<typename T>(primitive_type_base<T> const &)
|
||||
{
|
||||
return std::is_floating_point_v<T>;
|
||||
return std::is_floating_point_v<T> || std::is_same_v<T, types::half_float>;
|
||||
}, *ptype);
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue