LCOV - code coverage report
Current view: top level - home/runner/.local/lib/python3.9/site-packages/torch/include/c10/util - Half.h (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 0 23 0.0 %
Date: 2025-11-25 13:55:50 Functions: 0 3 0.0 %

          Line data    Source code
       1             : #pragma once
       2             : 
       3             : /// Defines the Half type (half-precision floating-point) including conversions
       4             : /// to standard C types and basic arithmetic operations. Note that arithmetic
       5             : /// operations are implemented by converting to floating point and
       6             : /// performing the operation in float32, instead of using CUDA half intrinsics.
       7             : /// Most uses of this type within ATen are memory bound, including the
       8             : /// element-wise kernels, and the half intrinsics aren't efficient on all GPUs.
       9             : /// If you are writing a compute bound kernel, you can use the CUDA half
      10             : /// intrinsics directly on the Half type from device code.
      11             : 
      12             : #include <c10/macros/Export.h>
      13             : #include <c10/macros/Macros.h>
      14             : #include <c10/util/bit_cast.h>
      15             : #include <c10/util/floating_point_utils.h>
      16             : #include <type_traits>
      17             : 
      18             : #if defined(__cplusplus)
      19             : #include <cmath>
      20             : #elif !defined(__OPENCL_VERSION__)
      21             : #include <math.h>
      22             : #endif
      23             : 
      24             : #ifdef _MSC_VER
      25             : #include <intrin.h>
      26             : #endif
      27             : 
      28             : #include <cstdint>
      29             : #include <cstring>
      30             : #include <iosfwd>
      31             : #include <limits>
      32             : #include <ostream>
      33             : 
      34             : #ifdef __CUDACC__
      35             : #include <cuda_fp16.h>
      36             : #endif
      37             : 
      38             : #ifdef __HIPCC__
      39             : #include <hip/hip_fp16.h>
      40             : #endif
      41             : 
      42             : #if defined(CL_SYCL_LANGUAGE_VERSION)
      43             : #include <CL/sycl.hpp> // for SYCL 1.2.1
      44             : #elif defined(SYCL_LANGUAGE_VERSION)
      45             : #include <sycl/sycl.hpp> // for SYCL 2020
      46             : #endif
      47             : 
      48             : #if defined(__aarch64__) && !defined(__CUDACC__)
      49             : #include <arm_neon.h>
      50             : #endif
      51             : 
      52             : #if defined(__GNUC__) || defined(__clang__)
      53             : #if defined(__x86_64__) || defined(_M_X64) || defined(__i386) || \
      54             :     defined(_M_IX86)
      55             : #if defined(__F16C__) &&                               \
      56             :     !(defined(__CUDA_ARCH__) || defined(__CUDACC__) || \
      57             :       defined(__HIP_DEVICE_COMPILE__))
      58             : #define C10_X86_F16 1
      59             : #include <immintrin.h> // import conversion ops from f16cintrin.h
      60             : #endif // defined(__F16C__) && !(defined(__CUDA_ARCH__) || defined(__CUDACC__)
      61             :        // || defined(__HIP_DEVICE_COMPILE__))
      62             : #endif // __x86_64__ || _M_X64 || __i386 || _M_IX86
      63             : #endif // __GNUC__ || __clang__
      64             : 
      65             : namespace c10 {
      66             : 
      67             : namespace detail {
      68             : 
      69             : /*
      70             :  * Convert a 16-bit floating-point number in IEEE half-precision format, in bit
      71             :  * representation, to a 32-bit floating-point number in IEEE single-precision
      72             :  * format, in bit representation.
      73             :  *
      74             :  * @note The implementation doesn't use any floating-point operations.
      75             :  */
      76             : inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) {
      77             :   /*
      78             :    * Extend the half-precision floating-point number to 32 bits and shift to the
      79             :    * upper part of the 32-bit word:
      80             :    *      +---+-----+------------+-------------------+
      81             :    *      | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
      82             :    *      +---+-----+------------+-------------------+
      83             :    * Bits  31  26-30    16-25            0-15
      84             :    *
      85             :    * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0
      86             :    * - zero bits.
      87             :    */
      88             :   const uint32_t w = (uint32_t)h << 16;
      89             :   /*
      90             :    * Extract the sign of the input number into the high bit of the 32-bit word:
      91             :    *
      92             :    *      +---+----------------------------------+
      93             :    *      | S |0000000 00000000 00000000 00000000|
      94             :    *      +---+----------------------------------+
      95             :    * Bits  31                 0-31
      96             :    */
      97             :   const uint32_t sign = w & UINT32_C(0x80000000);
      98             :   /*
      99             :    * Extract mantissa and biased exponent of the input number into the bits 0-30
     100             :    * of the 32-bit word:
     101             :    *
     102             :    *      +---+-----+------------+-------------------+
     103             :    *      | 0 |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
     104             :    *      +---+-----+------------+-------------------+
     105             :    * Bits  30  27-31     17-26            0-16
     106             :    */
     107             :   const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
     108             :   /*
     109             :    * Renorm shift is the number of bits to shift mantissa left to make the
     110             :    * half-precision number normalized. If the initial number is normalized, some
     111             :    * of its high 6 bits (sign == 0 and 5-bit exponent) equals one. In this case
     112             :    * renorm_shift == 0. If the number is denormalize, renorm_shift > 0. Note
     113             :    * that if we shift denormalized nonsign by renorm_shift, the unit bit of
     114             :    * mantissa will shift into exponent, turning the biased exponent into 1, and
     115             :    * making mantissa normalized (i.e. without leading 1).
     116             :    */
     117             : #ifdef _MSC_VER
     118             :   unsigned long nonsign_bsr;
     119             :   _BitScanReverse(&nonsign_bsr, (unsigned long)nonsign);
     120             :   uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31;
     121             : #else
     122             :   uint32_t renorm_shift = __builtin_clz(nonsign);
     123             : #endif
     124             :   renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0;
     125             :   /*
     126             :    * Iff half-precision number has exponent of 15, the addition overflows
     127             :    * it into bit 31, and the subsequent shift turns the high 9 bits
     128             :    * into 1. Thus inf_nan_mask == 0x7F800000 if the half-precision number
     129             :    * had exponent of 15 (i.e. was NaN or infinity) 0x00000000 otherwise
     130             :    */
     131             :   const int32_t inf_nan_mask =
     132             :       ((int32_t)(nonsign + 0x04000000) >> 8) & INT32_C(0x7F800000);
     133             :   /*
     134             :    * Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31
     135             :    * into 1. Otherwise, bit 31 remains 0. The signed shift right by 31
     136             :    * broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask ==
     137             :    * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h)
     138             :    * 0x00000000 otherwise
     139             :    */
     140             :   const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31;
     141             :   /*
     142             :    * 1. Shift nonsign left by renorm_shift to normalize it (if the input
     143             :    * was denormal)
     144             :    * 2. Shift nonsign right by 3 so the exponent (5 bits originally)
     145             :    * becomes an 8-bit field and 10-bit mantissa shifts into the 10 high
     146             :    * bits of the 23-bit mantissa of IEEE single-precision number.
     147             :    * 3. Add 0x70 to the exponent (starting at bit 23) to compensate the
     148             :    * different in exponent bias (0x7F for single-precision number less 0xF
     149             :    * for half-precision number).
     150             :    * 4. Subtract renorm_shift from the exponent (starting at bit 23) to
     151             :    * account for renormalization. As renorm_shift is less than 0x70, this
     152             :    * can be combined with step 3.
     153             :    * 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the
     154             :    * input was NaN or infinity.
     155             :    * 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent
     156             :    * into zero if the input was zero.
     157             :    * 7. Combine with the sign of the input number.
     158             :    */
     159             :   return sign |
     160             :       ((((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) |
     161             :         inf_nan_mask) &
     162             :        ~zero_mask);
     163             : }
     164             : 
     165             : /*
     166             :  * Convert a 16-bit floating-point number in IEEE half-precision format, in bit
     167             :  * representation, to a 32-bit floating-point number in IEEE single-precision
     168             :  * format.
     169             :  *
     170             :  * @note The implementation relies on IEEE-like (no assumption about rounding
     171             :  * mode and no operations on denormals) floating-point operations and bitcasts
     172             :  * between integer and floating-point variables.
     173             :  */
     174           0 : C10_HOST_DEVICE inline float fp16_ieee_to_fp32_value(uint16_t h) {
     175             : #ifdef C10_X86_F16
     176             :   return _cvtsh_ss(h);
     177             : #else
     178             :   /*
     179             :    * Extend the half-precision floating-point number to 32 bits and shift to the
     180             :    * upper part of the 32-bit word:
     181             :    *      +---+-----+------------+-------------------+
     182             :    *      | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
     183             :    *      +---+-----+------------+-------------------+
     184             :    * Bits  31  26-30    16-25            0-15
     185             :    *
     186             :    * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0
     187             :    * - zero bits.
     188             :    */
     189           0 :   const uint32_t w = (uint32_t)h << 16;
     190             :   /*
     191             :    * Extract the sign of the input number into the high bit of the 32-bit word:
     192             :    *
     193             :    *      +---+----------------------------------+
     194             :    *      | S |0000000 00000000 00000000 00000000|
     195             :    *      +---+----------------------------------+
     196             :    * Bits  31                 0-31
     197             :    */
     198           0 :   const uint32_t sign = w & UINT32_C(0x80000000);
     199             :   /*
     200             :    * Extract mantissa and biased exponent of the input number into the high bits
     201             :    * of the 32-bit word:
     202             :    *
     203             :    *      +-----+------------+---------------------+
     204             :    *      |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000|
     205             :    *      +-----+------------+---------------------+
     206             :    * Bits  27-31    17-26            0-16
     207             :    */
     208           0 :   const uint32_t two_w = w + w;
     209             : 
     210             :   /*
     211             :    * Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become
     212             :    * mantissa and exponent of a single-precision floating-point number:
     213             :    *
     214             :    *       S|Exponent |          Mantissa
     215             :    *      +-+---+-----+------------+----------------+
     216             :    *      |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000|
     217             :    *      +-+---+-----+------------+----------------+
     218             :    * Bits   | 23-31   |           0-22
     219             :    *
     220             :    * Next, there are some adjustments to the exponent:
     221             :    * - The exponent needs to be corrected by the difference in exponent bias
     222             :    * between single-precision and half-precision formats (0x7F - 0xF = 0x70)
     223             :    * - Inf and NaN values in the inputs should become Inf and NaN values after
     224             :    * conversion to the single-precision number. Therefore, if the biased
     225             :    * exponent of the half-precision input was 0x1F (max possible value), the
     226             :    * biased exponent of the single-precision output must be 0xFF (max possible
     227             :    * value). We do this correction in two steps:
     228             :    *   - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset
     229             :    * below) rather than by 0x70 suggested by the difference in the exponent bias
     230             :    * (see above).
     231             :    *   - Then we multiply the single-precision result of exponent adjustment by
     232             :    * 2**(-112) to reverse the effect of exponent adjustment by 0xE0 less the
     233             :    * necessary exponent adjustment by 0x70 due to difference in exponent bias.
     234             :    *     The floating-point multiplication hardware would ensure than Inf and
     235             :    * NaN would retain their value on at least partially IEEE754-compliant
     236             :    * implementations.
     237             :    *
     238             :    * Note that the above operations do not handle denormal inputs (where biased
     239             :    * exponent == 0). However, they also do not operate on denormal inputs, and
     240             :    * do not produce denormal results.
     241             :    */
     242             :   constexpr uint32_t exp_offset = UINT32_C(0xE0) << 23;
     243             :   // const float exp_scale = 0x1.0p-112f;
     244             :   constexpr uint32_t scale_bits = (uint32_t)15 << 23;
     245             :   float exp_scale_val = 0;
     246             : #if defined(_MSC_VER) && defined(__clang__)
     247             :   __builtin_memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val));
     248             : #else
     249             :   std::memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val));
     250             : #endif
     251             : 
     252             :   const float exp_scale = exp_scale_val;
     253             :   const float normalized_value =
     254           0 :       fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
     255             : 
     256             :   /*
     257             :    * Convert denormalized half-precision inputs into single-precision results
     258             :    * (always normalized). Zero inputs are also handled here.
     259             :    *
     260             :    * In a denormalized number the biased exponent is zero, and mantissa has
     261             :    * on-zero bits. First, we shift mantissa into bits 0-9 of the 32-bit word.
     262             :    *
     263             :    *                  zeros           |  mantissa
     264             :    *      +---------------------------+------------+
     265             :    *      |0000 0000 0000 0000 0000 00|MM MMMM MMMM|
     266             :    *      +---------------------------+------------+
     267             :    * Bits             10-31                0-9
     268             :    *
     269             :    * Now, remember that denormalized half-precision numbers are represented as:
     270             :    *    FP16 = mantissa * 2**(-24).
     271             :    * The trick is to construct a normalized single-precision number with the
     272             :    * same mantissa and thehalf-precision input and with an exponent which would
     273             :    * scale the corresponding mantissa bits to 2**(-24). A normalized
     274             :    * single-precision floating-point number is represented as: FP32 = (1 +
     275             :    * mantissa * 2**(-23)) * 2**(exponent - 127) Therefore, when the biased
     276             :    * exponent is 126, a unit change in the mantissa of the input denormalized
     277             :    * half-precision number causes a change of the constructed single-precision
     278             :    * number by 2**(-24), i.e. the same amount.
     279             :    *
     280             :    * The last step is to adjust the bias of the constructed single-precision
     281             :    * number. When the input half-precision number is zero, the constructed
     282             :    * single-precision number has the value of FP32 = 1 * 2**(126 - 127) =
     283             :    * 2**(-1) = 0.5 Therefore, we need to subtract 0.5 from the constructed
     284             :    * single-precision number to get the numerical equivalent of the input
     285             :    * half-precision number.
     286             :    */
     287             :   constexpr uint32_t magic_mask = UINT32_C(126) << 23;
     288             :   constexpr float magic_bias = 0.5f;
     289             :   const float denormalized_value =
     290           0 :       fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
     291             : 
     292             :   /*
     293             :    * - Choose either results of conversion of input as a normalized number, or
     294             :    * as a denormalized number, depending on the input exponent. The variable
     295             :    * two_w contains input exponent in bits 27-31, therefore if its smaller than
     296             :    * 2**27, the input is either a denormal number, or zero.
     297             :    * - Combine the result of conversion of exponent and mantissa with the sign
     298             :    * of the input number.
     299             :    */
     300             :   constexpr uint32_t denormalized_cutoff = UINT32_C(1) << 27;
     301             :   const uint32_t result = sign |
     302           0 :       (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value)
     303           0 :                                    : fp32_to_bits(normalized_value));
     304           0 :   return fp32_from_bits(result);
     305             : #endif // C10_X86_F16
     306             : }
     307             : 
     308             : /*
     309             :  * Convert a 32-bit floating-point number in IEEE single-precision format to a
     310             :  * 16-bit floating-point number in IEEE half-precision format, in bit
     311             :  * representation.
     312             :  *
     313             :  * @note The implementation relies on IEEE-like (no assumption about rounding
     314             :  * mode and no operations on denormals) floating-point operations and bitcasts
     315             :  * between integer and floating-point variables.
     316             :  */
     317           0 : inline uint16_t fp16_ieee_from_fp32_value(float f) {
     318             : #ifdef C10_X86_F16
     319             :   return _cvtss_sh(f, _MM_FROUND_TO_NEAREST_INT);
     320             : #else
     321             :   // const float scale_to_inf = 0x1.0p+112f;
     322             :   // const float scale_to_zero = 0x1.0p-110f;
     323             :   constexpr uint32_t scale_to_inf_bits = (uint32_t)239 << 23;
     324             :   constexpr uint32_t scale_to_zero_bits = (uint32_t)17 << 23;
     325             :   float scale_to_inf_val = 0, scale_to_zero_val = 0;
     326             :   std::memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val));
     327             :   std::memcpy(
     328             :       &scale_to_zero_val, &scale_to_zero_bits, sizeof(scale_to_zero_val));
     329             :   const float scale_to_inf = scale_to_inf_val;
     330             :   const float scale_to_zero = scale_to_zero_val;
     331             : 
     332             : #if defined(_MSC_VER) && _MSC_VER == 1916
     333             :   float base = ((signbit(f) != 0 ? -f : f) * scale_to_inf) * scale_to_zero;
     334             : #else
     335           0 :   float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
     336             : #endif
     337             : 
     338           0 :   const uint32_t w = fp32_to_bits(f);
     339           0 :   const uint32_t shl1_w = w + w;
     340             :   const uint32_t sign = w & UINT32_C(0x80000000);
     341           0 :   uint32_t bias = shl1_w & UINT32_C(0xFF000000);
     342             :   if (bias < UINT32_C(0x71000000)) {
     343             :     bias = UINT32_C(0x71000000);
     344             :   }
     345             : 
     346           0 :   base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
     347           0 :   const uint32_t bits = fp32_to_bits(base);
     348           0 :   const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
     349           0 :   const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
     350           0 :   const uint32_t nonsign = exp_bits + mantissa_bits;
     351             :   return static_cast<uint16_t>(
     352           0 :       (sign >> 16) |
     353           0 :       (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign));
     354             : #endif // C10_X86_F16
     355             : }
     356             : 
     357             : #ifdef C10_X86_F16
     358             : #undef C10_X86_F16
     359             : #endif // C10_X86_F16
     360             : 
     361             : #if defined(__aarch64__) && !defined(__CUDACC__)
     362             : inline float16_t fp16_from_bits(uint16_t h) {
     363             :   return c10::bit_cast<float16_t>(h);
     364             : }
     365             : 
     366             : inline uint16_t fp16_to_bits(float16_t f) {
     367             :   return c10::bit_cast<uint16_t>(f);
     368             : }
     369             : 
     370             : // According to https://godbolt.org/z/frExdbsWG it would translate to single
     371             : // fcvt s0, h0
     372             : inline float native_fp16_to_fp32_value(uint16_t h) {
     373             :   return static_cast<float>(fp16_from_bits(h));
     374             : }
     375             : 
     376             : inline uint16_t native_fp16_from_fp32_value(float f) {
     377             :   return fp16_to_bits(static_cast<float16_t>(f));
     378             : }
     379             : #endif
     380             : 
     381             : } // namespace detail
     382             : 
     383             : struct alignas(2) Half {
     384             :   unsigned short x;
     385             : 
     386             :   struct from_bits_t {};
     387             :   C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
     388             :     return from_bits_t();
     389             :   }
     390             : 
     391             :   // HIP wants __host__ __device__ tag, CUDA does not
     392             : #if defined(USE_ROCM)
     393             :   C10_HOST_DEVICE Half() = default;
     394             : #else
     395             :   Half() = default;
     396             : #endif
     397             : 
     398             :   constexpr C10_HOST_DEVICE Half(unsigned short bits, from_bits_t) : x(bits) {}
     399             : #if defined(__aarch64__) && !defined(__CUDACC__)
     400             :   inline Half(float16_t value);
     401             :   inline operator float16_t() const;
     402             : #else
     403             :   inline C10_HOST_DEVICE Half(float value);
     404             :   inline C10_HOST_DEVICE operator float() const;
     405             : #endif
     406             : 
     407             : #if defined(__CUDACC__) || defined(__HIPCC__)
     408             :   inline C10_HOST_DEVICE Half(const __half& value);
     409             :   inline C10_HOST_DEVICE operator __half() const;
     410             : #endif
     411             : #ifdef SYCL_LANGUAGE_VERSION
     412             :   inline C10_HOST_DEVICE Half(const sycl::half& value);
     413             :   inline C10_HOST_DEVICE operator sycl::half() const;
     414             : #endif
     415             : };
     416             : 
     417           0 : C10_API inline std::ostream& operator<<(std::ostream& out, const Half& value) {
     418             :   out << (float)value;
     419           0 :   return out;
     420             : }
     421             : 
     422             : } // namespace c10
     423             : 
     424             : #include <c10/util/Half-inl.h> // IWYU pragma: keep

Generated by: LCOV version 1.16