LCOV - code coverage report
Current view: top level - home/runner/.local/lib/python3.9/site-packages/torch/include/c10/util - BFloat16.h (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 0 7 0.0 %
Date: 2025-12-04 11:19:34 Functions: 0 0 -

          Line data    Source code
       1             : #pragma once
       2             : 
       3             : // Defines the bfloat16 type (brain floating-point). This representation uses
       4             : // 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa.
       5             : 
       6             : #include <c10/macros/Macros.h>
       7             : #include <cmath>
       8             : #include <cstdint>
       9             : #include <cstring>
      10             : #include <iosfwd>
      11             : #include <ostream>
      12             : 
      13             : #if defined(__CUDACC__) && !defined(USE_ROCM)
      14             : #include <cuda_bf16.h>
      15             : #endif
      16             : 
      17             : #if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
      18             : #if defined(CL_SYCL_LANGUAGE_VERSION)
      19             : #include <CL/sycl.hpp> // for SYCL 1.2.1
      20             : #else
      21             : #include <sycl/sycl.hpp> // for SYCL 2020
      22             : #endif
      23             : #include <ext/oneapi/bfloat16.hpp>
      24             : #endif
      25             : 
      26             : namespace c10 {
      27             : 
      28             : namespace detail {
      29             : inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) {
      30             :   float res = 0;
      31           0 :   uint32_t tmp = src;
      32           0 :   tmp <<= 16;
      33             : 
      34             : #if defined(USE_ROCM)
      35             :   float* tempRes;
      36             : 
      37             :   // We should be using memcpy in order to respect the strict aliasing rule
      38             :   // but it fails in the HIP environment.
      39             :   tempRes = reinterpret_cast<float*>(&tmp);
      40             :   res = *tempRes;
      41             : #else
      42             :   std::memcpy(&res, &tmp, sizeof(tmp));
      43             : #endif
      44             : 
      45             :   return res;
      46             : }
      47             : 
      48             : inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) {
      49             :   uint32_t res = 0;
      50             : 
      51             : #if defined(USE_ROCM)
      52             :   // We should be using memcpy in order to respect the strict aliasing rule
      53             :   // but it fails in the HIP environment.
      54             :   uint32_t* tempRes = reinterpret_cast<uint32_t*>(&src);
      55             :   res = *tempRes;
      56             : #else
      57             :   std::memcpy(&res, &src, sizeof(res));
      58             : #endif
      59             : 
      60             :   return res >> 16;
      61             : }
      62             : 
      63             : inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) {
      64             : #if defined(USE_ROCM)
      65             :   if (src != src) {
      66             : #elif defined(_MSC_VER)
      67             :   if (isnan(src)) {
      68             : #else
      69           0 :   if (std::isnan(src)) {
      70             : #endif
      71             :     return UINT16_C(0x7FC0);
      72             :   } else {
      73             :     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
      74             :     union {
      75             :       uint32_t U32; // NOLINT(facebook-hte-BadMemberName)
      76             :       float F32; // NOLINT(facebook-hte-BadMemberName)
      77             :     };
      78             : 
      79           0 :     F32 = src;
      80           0 :     uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
      81           0 :     return static_cast<uint16_t>((U32 + rounding_bias) >> 16);
      82             :   }
      83             : }
      84             : } // namespace detail
      85             : 
      86             : struct alignas(2) BFloat16 {
      87             :   uint16_t x;
      88             : 
      89             :   // HIP wants __host__ __device__ tag, CUDA does not
      90             : #if defined(USE_ROCM)
      91             :   C10_HOST_DEVICE BFloat16() = default;
      92             : #else
      93             :   BFloat16() = default;
      94             : #endif
      95             : 
      96             :   struct from_bits_t {};
      97             :   static constexpr C10_HOST_DEVICE from_bits_t from_bits() {
      98             :     return from_bits_t();
      99             :   }
     100             : 
     101             :   constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t)
     102             :       : x(bits) {}
     103             :   /* implicit */ inline C10_HOST_DEVICE BFloat16(float value);
     104             :   inline C10_HOST_DEVICE operator float() const;
     105             : 
     106             : #if defined(__CUDACC__) && !defined(USE_ROCM)
     107             :   inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value);
     108             :   explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const;
     109             : #endif
     110             : 
     111             : #if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
     112             :   inline C10_HOST_DEVICE BFloat16(const sycl::ext::oneapi::bfloat16& value);
     113             :   explicit inline C10_HOST_DEVICE operator sycl::ext::oneapi::bfloat16() const;
     114             : #endif
     115             : };
     116             : 
     117             : C10_API inline std::ostream& operator<<(
     118             :     std::ostream& out,
     119             :     const BFloat16& value) {
     120             :   out << (float)value;
     121           0 :   return out;
     122             : }
     123             : 
     124             : } // namespace c10
     125             : 
     126             : #include <c10/util/BFloat16-inl.h> // IWYU pragma: keep

Generated by: LCOV version 1.16