Add f16 type & literals support in parser, type checker & interpreter
This commit is contained in:
parent
0d87d35c47
commit
7ddc8ba25d
10 changed files with 87 additions and 3 deletions
|
|
@ -1,5 +1,6 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <pslang/types/half_float.hpp>
|
||||||
#include <pslang/ast/location.hpp>
|
#include <pslang/ast/location.hpp>
|
||||||
|
|
||||||
#include <variant>
|
#include <variant>
|
||||||
|
|
@ -26,6 +27,7 @@ namespace pslang::ast
|
||||||
using i64_literal = primitive_literal_base<std::int64_t>;
|
using i64_literal = primitive_literal_base<std::int64_t>;
|
||||||
using u64_literal = primitive_literal_base<std::uint64_t>;
|
using u64_literal = primitive_literal_base<std::uint64_t>;
|
||||||
|
|
||||||
|
using f16_literal = primitive_literal_base<types::half_float>;
|
||||||
using f32_literal = primitive_literal_base<float>;
|
using f32_literal = primitive_literal_base<float>;
|
||||||
using f64_literal = primitive_literal_base<double>;
|
using f64_literal = primitive_literal_base<double>;
|
||||||
|
|
||||||
|
|
@ -39,6 +41,7 @@ namespace pslang::ast
|
||||||
u32_literal,
|
u32_literal,
|
||||||
i64_literal,
|
i64_literal,
|
||||||
u64_literal,
|
u64_literal,
|
||||||
|
f16_literal,
|
||||||
f32_literal,
|
f32_literal,
|
||||||
f64_literal
|
f64_literal
|
||||||
>;
|
>;
|
||||||
|
|
|
||||||
|
|
@ -148,6 +148,12 @@ namespace pslang::ast
|
||||||
out << "u64 literal { value = " << node.value << " }\n";
|
out << "u64 literal { value = " << node.value << " }\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void apply(f16_literal const & node)
|
||||||
|
{
|
||||||
|
put_indent(out, options);
|
||||||
|
out << "f32 literal { value = " << std::setprecision(3) << node.value.repr << " }\n";
|
||||||
|
}
|
||||||
|
|
||||||
void apply(f32_literal const & node)
|
void apply(f32_literal const & node)
|
||||||
{
|
{
|
||||||
put_indent(out, options);
|
put_indent(out, options);
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,7 @@ namespace pslang::interpreter
|
||||||
using i64_value = primitive_value_base<std::int64_t>;
|
using i64_value = primitive_value_base<std::int64_t>;
|
||||||
using u64_value = primitive_value_base<std::uint64_t>;
|
using u64_value = primitive_value_base<std::uint64_t>;
|
||||||
|
|
||||||
|
using f16_value = primitive_value_base<types::half_float>;
|
||||||
using f32_value = primitive_value_base<float>;
|
using f32_value = primitive_value_base<float>;
|
||||||
using f64_value = primitive_value_base<double>;
|
using f64_value = primitive_value_base<double>;
|
||||||
|
|
||||||
|
|
@ -47,6 +48,7 @@ namespace pslang::interpreter
|
||||||
u32_value,
|
u32_value,
|
||||||
i64_value,
|
i64_value,
|
||||||
u64_value,
|
u64_value,
|
||||||
|
f16_value,
|
||||||
f32_value,
|
f32_value,
|
||||||
f64_value
|
f64_value
|
||||||
>;
|
>;
|
||||||
|
|
|
||||||
|
|
@ -336,7 +336,18 @@ namespace pslang::interpreter
|
||||||
}
|
}
|
||||||
else if constexpr (!std::is_same_v<T, bool> && !std::is_same_v<bool, H>)
|
else if constexpr (!std::is_same_v<T, bool> && !std::is_same_v<bool, H>)
|
||||||
{
|
{
|
||||||
return primitive_value(primitive_value_base<H>{static_cast<H>(value.value)});
|
if constexpr (std::is_same_v<T, types::half_float>)
|
||||||
|
{
|
||||||
|
return primitive_value(primitive_value_base<H>{static_cast<H>(value.value.repr)});
|
||||||
|
}
|
||||||
|
else if constexpr (std::is_same_v<H, types::half_float>)
|
||||||
|
{
|
||||||
|
return primitive_value(primitive_value_base<H>{{static_cast<float>(value.value)}});
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
return primitive_value(primitive_value_base<H>{static_cast<H>(value.value)});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::ostringstream os;
|
std::ostringstream os;
|
||||||
|
|
|
||||||
|
|
@ -46,6 +46,7 @@ i32 { return bp::make_i32(ctx.location); }
|
||||||
u32 { return bp::make_u32(ctx.location); }
|
u32 { return bp::make_u32(ctx.location); }
|
||||||
i64 { return bp::make_i64(ctx.location); }
|
i64 { return bp::make_i64(ctx.location); }
|
||||||
u64 { return bp::make_u64(ctx.location); }
|
u64 { return bp::make_u64(ctx.location); }
|
||||||
|
f16 { return bp::make_f16(ctx.location); }
|
||||||
f32 { return bp::make_f32(ctx.location); }
|
f32 { return bp::make_f32(ctx.location); }
|
||||||
f64 { return bp::make_f64(ctx.location); }
|
f64 { return bp::make_f64(ctx.location); }
|
||||||
|
|
||||||
|
|
@ -88,6 +89,7 @@ f64 { return bp::make_f64(ctx.location); }
|
||||||
[0-9]+ul { return bp::make_lit_u64(yytext, ctx.location); }
|
[0-9]+ul { return bp::make_lit_u64(yytext, ctx.location); }
|
||||||
|
|
||||||
[0-9]+\.[0-9]+ { return bp::make_lit_f32(yytext, ctx.location); }
|
[0-9]+\.[0-9]+ { return bp::make_lit_f32(yytext, ctx.location); }
|
||||||
|
[0-9]+\.[0-9]+h { return bp::make_lit_f16(yytext, ctx.location); }
|
||||||
[0-9]+\.[0-9]+l { return bp::make_lit_f64(yytext, ctx.location); }
|
[0-9]+\.[0-9]+l { return bp::make_lit_f64(yytext, ctx.location); }
|
||||||
|
|
||||||
'.' { return bp::make_lit_char8(yytext[1], ctx.location); }
|
'.' { return bp::make_lit_char8(yytext[1], ctx.location); }
|
||||||
|
|
|
||||||
|
|
@ -48,7 +48,9 @@ template <typename T>
|
||||||
{
|
{
|
||||||
T value;
|
T value;
|
||||||
#ifdef __clang__
|
#ifdef __clang__
|
||||||
if constexpr (std::is_floating_point_v<T>)
|
if constexpr (std::is_same_v<T, ::pslang::types::half_float>)
|
||||||
|
value = {float(std::atof(str.data()))};
|
||||||
|
else if constexpr (std::is_floating_point_v<T>)
|
||||||
value = std::atof(str.data());
|
value = std::atof(str.data());
|
||||||
else if constexpr (std::is_signed_v<T>)
|
else if constexpr (std::is_signed_v<T>)
|
||||||
value = std::strtoll(str.data(), nullptr, 10);
|
value = std::strtoll(str.data(), nullptr, 10);
|
||||||
|
|
@ -137,6 +139,7 @@ template <typename T>
|
||||||
%token u32
|
%token u32
|
||||||
%token i64
|
%token i64
|
||||||
%token u64
|
%token u64
|
||||||
|
%token f16
|
||||||
%token f32
|
%token f32
|
||||||
%token f64
|
%token f64
|
||||||
|
|
||||||
|
|
@ -259,6 +262,7 @@ primitive_type
|
||||||
| u32 { $$ = types::u32_type{}; }
|
| u32 { $$ = types::u32_type{}; }
|
||||||
| i64 { $$ = types::i64_type{}; }
|
| i64 { $$ = types::i64_type{}; }
|
||||||
| u64 { $$ = types::u64_type{}; }
|
| u64 { $$ = types::u64_type{}; }
|
||||||
|
| f16 { $$ = types::f16_type{}; }
|
||||||
| f32 { $$ = types::f32_type{}; }
|
| f32 { $$ = types::f32_type{}; }
|
||||||
| f64 { $$ = types::f64_type{}; }
|
| f64 { $$ = types::f64_type{}; }
|
||||||
;
|
;
|
||||||
|
|
@ -309,6 +313,7 @@ postfix_expression
|
||||||
| u32 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::u32_type{}), $3, @$}; }
|
| u32 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::u32_type{}), $3, @$}; }
|
||||||
| i64 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::i64_type{}), $3, @$}; }
|
| i64 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::i64_type{}), $3, @$}; }
|
||||||
| u64 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::u64_type{}), $3, @$}; }
|
| u64 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::u64_type{}), $3, @$}; }
|
||||||
|
| f16 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::f16_type{}), $3, @$}; }
|
||||||
| f32 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::f32_type{}), $3, @$}; }
|
| f32 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::f32_type{}), $3, @$}; }
|
||||||
| f64 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::f64_type{}), $3, @$}; }
|
| f64 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::f64_type{}), $3, @$}; }
|
||||||
;
|
;
|
||||||
|
|
@ -341,6 +346,7 @@ literal
|
||||||
| lit_u32 { $$ = parse_primitive_literal<std::uint32_t>($1, ctx.location); }
|
| lit_u32 { $$ = parse_primitive_literal<std::uint32_t>($1, ctx.location); }
|
||||||
| lit_i64 { $$ = parse_primitive_literal<std::int64_t>($1, ctx.location); }
|
| lit_i64 { $$ = parse_primitive_literal<std::int64_t>($1, ctx.location); }
|
||||||
| lit_u64 { $$ = parse_primitive_literal<std::uint64_t>($1, ctx.location); }
|
| lit_u64 { $$ = parse_primitive_literal<std::uint64_t>($1, ctx.location); }
|
||||||
|
| lit_f16 { $$ = parse_primitive_literal<types::half_float>($1, ctx.location); }
|
||||||
| lit_f32 { $$ = parse_primitive_literal<float>($1, ctx.location); }
|
| lit_f32 { $$ = parse_primitive_literal<float>($1, ctx.location); }
|
||||||
| lit_f64 { $$ = parse_primitive_literal<double>($1, ctx.location); }
|
| lit_f64 { $$ = parse_primitive_literal<double>($1, ctx.location); }
|
||||||
| lit_char8 { $$ = ast::u8_literal{static_cast<std::uint8_t>($1), ctx.location}; }
|
| lit_char8 { $$ = ast::u8_literal{static_cast<std::uint8_t>($1), ctx.location}; }
|
||||||
|
|
|
||||||
45
libs/types/include/pslang/types/half_float.hpp
Normal file
45
libs/types/include/pslang/types/half_float.hpp
Normal file
|
|
@ -0,0 +1,45 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <pslang/types/half_float.hpp>
|
||||||
|
|
||||||
|
#include <compare>
|
||||||
|
|
||||||
|
namespace pslang::types
|
||||||
|
{
|
||||||
|
|
||||||
|
// TODO: actual half-float operations? Maybe C++23 std::float16_t?
|
||||||
|
|
||||||
|
struct half_float
|
||||||
|
{
|
||||||
|
float repr = 0.f;
|
||||||
|
|
||||||
|
friend bool operator == (half_float const &, half_float const &) = default;
|
||||||
|
friend auto operator <=> (half_float const &, half_float const &) = default;
|
||||||
|
};
|
||||||
|
|
||||||
|
inline half_float operator - (half_float f)
|
||||||
|
{
|
||||||
|
return {-f.repr};
|
||||||
|
}
|
||||||
|
|
||||||
|
inline half_float operator + (half_float f1, half_float f2)
|
||||||
|
{
|
||||||
|
return {f1.repr + f2.repr};
|
||||||
|
}
|
||||||
|
|
||||||
|
inline half_float operator - (half_float f1, half_float f2)
|
||||||
|
{
|
||||||
|
return {f1.repr - f2.repr};
|
||||||
|
}
|
||||||
|
|
||||||
|
inline half_float operator * (half_float f1, half_float f2)
|
||||||
|
{
|
||||||
|
return {f1.repr * f2.repr};
|
||||||
|
}
|
||||||
|
|
||||||
|
inline half_float operator / (half_float f1, half_float f2)
|
||||||
|
{
|
||||||
|
return {f1.repr / f2.repr};
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <pslang/types/half_float.hpp>
|
||||||
|
|
||||||
#include <variant>
|
#include <variant>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
|
||||||
|
|
@ -29,6 +31,7 @@ namespace pslang::types
|
||||||
using i64_type = primitive_type_base<std::int64_t>;
|
using i64_type = primitive_type_base<std::int64_t>;
|
||||||
using u64_type = primitive_type_base<std::uint64_t>;
|
using u64_type = primitive_type_base<std::uint64_t>;
|
||||||
|
|
||||||
|
using f16_type = primitive_type_base<half_float>;
|
||||||
using f32_type = primitive_type_base<float>;
|
using f32_type = primitive_type_base<float>;
|
||||||
using f64_type = primitive_type_base<double>;
|
using f64_type = primitive_type_base<double>;
|
||||||
|
|
||||||
|
|
@ -42,6 +45,7 @@ namespace pslang::types
|
||||||
u32_type,
|
u32_type,
|
||||||
i64_type,
|
i64_type,
|
||||||
u64_type,
|
u64_type,
|
||||||
|
f16_type,
|
||||||
f32_type,
|
f32_type,
|
||||||
f64_type
|
f64_type
|
||||||
>;
|
>;
|
||||||
|
|
|
||||||
|
|
@ -64,6 +64,11 @@ namespace pslang::types
|
||||||
out << "u64";
|
out << "u64";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void apply(f16_type const &)
|
||||||
|
{
|
||||||
|
out << "f16";
|
||||||
|
}
|
||||||
|
|
||||||
void apply(f32_type const &)
|
void apply(f32_type const &)
|
||||||
{
|
{
|
||||||
out << "f32";
|
out << "f32";
|
||||||
|
|
|
||||||
|
|
@ -93,7 +93,7 @@ namespace pslang::types
|
||||||
{
|
{
|
||||||
return std::visit([]<typename T>(primitive_type_base<T> const &)
|
return std::visit([]<typename T>(primitive_type_base<T> const &)
|
||||||
{
|
{
|
||||||
return std::is_floating_point_v<T>;
|
return std::is_floating_point_v<T> || std::is_same_v<T, types::half_float>;
|
||||||
}, *ptype);
|
}, *ptype);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue