Line data Source code
1 : #pragma once 2 : 3 : // Defines the bloat16 type (brain floating-point). This representation uses 4 : // 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa. 5 : 6 : #include <c10/macros/Macros.h> 7 : #include <cmath> 8 : #include <cstdint> 9 : #include <cstring> 10 : #include <iosfwd> 11 : #include <ostream> 12 : 13 : #if defined(__CUDACC__) && !defined(USE_ROCM) 14 : #include <cuda_bf16.h> 15 : #endif 16 : 17 : #if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) 18 : #if defined(CL_SYCL_LANGUAGE_VERSION) 19 : #include <CL/sycl.hpp> // for SYCL 1.2.1 20 : #else 21 : #include <sycl/sycl.hpp> // for SYCL 2020 22 : #endif 23 : #include <ext/oneapi/bfloat16.hpp> 24 : #endif 25 : 26 : namespace c10 { 27 : 28 : namespace detail { 29 : inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) { 30 : float res = 0; 31 0 : uint32_t tmp = src; 32 0 : tmp <<= 16; 33 : 34 : #if defined(USE_ROCM) 35 : float* tempRes; 36 : 37 : // We should be using memcpy in order to respect the strict aliasing rule 38 : // but it fails in the HIP environment. 39 : tempRes = reinterpret_cast<float*>(&tmp); 40 : res = *tempRes; 41 : #else 42 : std::memcpy(&res, &tmp, sizeof(tmp)); 43 : #endif 44 : 45 : return res; 46 : } 47 : 48 : inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) { 49 : uint32_t res = 0; 50 : 51 : #if defined(USE_ROCM) 52 : // We should be using memcpy in order to respect the strict aliasing rule 53 : // but it fails in the HIP environment. 54 : uint32_t* tempRes = reinterpret_cast<uint32_t*>(&src); 55 : res = *tempRes; 56 : #else 57 : std::memcpy(&res, &src, sizeof(res)); 58 : #endif 59 : 60 : return res >> 16; 61 : } 62 : 63 : inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) { 64 : #if defined(USE_ROCM) 65 : if (src != src) { 66 : #elif defined(_MSC_VER) 67 : if (isnan(src)) { 68 : #else 69 0 : if (std::isnan(src)) { 70 : #endif 71 : return UINT16_C(0x7FC0); 72 : } else { 73 : // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) 74 : union { 75 : uint32_t U32; // NOLINT(facebook-hte-BadMemberName) 76 : float F32; // NOLINT(facebook-hte-BadMemberName) 77 : }; 78 : 79 0 : F32 = src; 80 0 : uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); 81 0 : return static_cast<uint16_t>((U32 + rounding_bias) >> 16); 82 : } 83 : } 84 : } // namespace detail 85 : 86 : struct alignas(2) BFloat16 { 87 : uint16_t x; 88 : 89 : // HIP wants __host__ __device__ tag, CUDA does not 90 : #if defined(USE_ROCM) 91 : C10_HOST_DEVICE BFloat16() = default; 92 : #else 93 : BFloat16() = default; 94 : #endif 95 : 96 : struct from_bits_t {}; 97 : static constexpr C10_HOST_DEVICE from_bits_t from_bits() { 98 : return from_bits_t(); 99 : } 100 : 101 : constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t) 102 : : x(bits) {} 103 : /* implicit */ inline C10_HOST_DEVICE BFloat16(float value); 104 : inline C10_HOST_DEVICE operator float() const; 105 : 106 : #if defined(__CUDACC__) && !defined(USE_ROCM) 107 : inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value); 108 : explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const; 109 : #endif 110 : 111 : #if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) 112 : inline C10_HOST_DEVICE BFloat16(const sycl::ext::oneapi::bfloat16& value); 113 : explicit inline C10_HOST_DEVICE operator sycl::ext::oneapi::bfloat16() const; 114 : #endif 115 : }; 116 : 117 : C10_API inline std::ostream& operator<<( 118 : std::ostream& out, 119 : const BFloat16& value) { 120 : out << (float)value; 121 0 : return out; 122 : } 123 : 124 : } // namespace c10 125 : 126 : #include <c10/util/BFloat16-inl.h> // IWYU pragma: keep