LCOV - code coverage report
Current view: top level - home/runner/.local/lib/python3.9/site-packages/torch/include/c10/util - BFloat16-inl.h (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 0 1 0.0 %
Date: 2025-11-25 13:55:50 Functions: 0 0 -

          Line data    Source code
       1             : #pragma once
       2             : 
       3             : #include <c10/macros/Macros.h>
       4             : #include <c10/util/bit_cast.h>
       5             : 
       6             : #include <limits>
       7             : 
       8             : C10_CLANG_DIAGNOSTIC_PUSH()
       9             : #if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
      10             : C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
      11             : #endif
      12             : 
      13             : #if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
      14             : #if defined(CL_SYCL_LANGUAGE_VERSION)
      15             : #include <CL/sycl.hpp> // for SYCL 1.2.1
      16             : #else
      17             : #include <sycl/sycl.hpp> // for SYCL 2020
      18             : #endif
      19             : #include <ext/oneapi/bfloat16.hpp>
      20             : #endif
      21             : 
      22             : namespace c10 {
      23             : 
      24             : /// Constructors
      25             : inline C10_HOST_DEVICE BFloat16::BFloat16(float value)
      26             :     :
      27             : #if defined(__CUDACC__) && !defined(USE_ROCM) && defined(__CUDA_ARCH__) && \
      28             :     __CUDA_ARCH__ >= 800
      29             :       x(__bfloat16_as_ushort(__float2bfloat16(value)))
      30             : #elif defined(__SYCL_DEVICE_ONLY__) && \
      31             :     defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
      32             :       x(c10::bit_cast<uint16_t>(sycl::ext::oneapi::bfloat16(value)))
      33             : #else
      34             :       // RNE by default
      35             :       x(detail::round_to_nearest_even(value))
      36             : #endif
      37             : {
      38             : }
      39             : 
      40             : /// Implicit conversions
      41             : inline C10_HOST_DEVICE BFloat16::operator float() const {
      42             : #if defined(__CUDACC__) && !defined(USE_ROCM)
      43             :   return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&x));
      44             : #elif defined(__SYCL_DEVICE_ONLY__) && \
      45             :     defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
      46             :   return float(*reinterpret_cast<const sycl::ext::oneapi::bfloat16*>(&x));
      47             : #else
      48           0 :   return detail::f32_from_bits(x);
      49             : #endif
      50             : }
      51             : 
      52             : #if defined(__CUDACC__) && !defined(USE_ROCM)
      53             : inline C10_HOST_DEVICE BFloat16::BFloat16(const __nv_bfloat16& value) {
      54             :   x = *reinterpret_cast<const unsigned short*>(&value);
      55             : }
      56             : inline C10_HOST_DEVICE BFloat16::operator __nv_bfloat16() const {
      57             :   return *reinterpret_cast<const __nv_bfloat16*>(&x);
      58             : }
      59             : #endif
      60             : 
      61             : #if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
      62             : inline C10_HOST_DEVICE BFloat16::BFloat16(
      63             :     const sycl::ext::oneapi::bfloat16& value) {
      64             :   x = *reinterpret_cast<const unsigned short*>(&value);
      65             : }
      66             : inline C10_HOST_DEVICE BFloat16::operator sycl::ext::oneapi::bfloat16() const {
      67             :   return *reinterpret_cast<const sycl::ext::oneapi::bfloat16*>(&x);
      68             : }
      69             : #endif
      70             : 
      71             : // CUDA intrinsics
      72             : 
      73             : #if defined(__CUDACC__) || defined(__HIPCC__)
      74             : inline C10_DEVICE BFloat16 __ldg(const BFloat16* ptr) {
      75             : #if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
      76             :   return __ldg(reinterpret_cast<const __nv_bfloat16*>(ptr));
      77             : #else
      78             :   return *ptr;
      79             : #endif
      80             : }
      81             : #endif
      82             : 
      83             : /// Arithmetic
      84             : 
      85             : inline C10_HOST_DEVICE BFloat16
      86             : operator+(const BFloat16& a, const BFloat16& b) {
      87             :   return static_cast<float>(a) + static_cast<float>(b);
      88             : }
      89             : 
      90             : inline C10_HOST_DEVICE BFloat16
      91             : operator-(const BFloat16& a, const BFloat16& b) {
      92             :   return static_cast<float>(a) - static_cast<float>(b);
      93             : }
      94             : 
      95             : inline C10_HOST_DEVICE BFloat16
      96             : operator*(const BFloat16& a, const BFloat16& b) {
      97             :   return static_cast<float>(a) * static_cast<float>(b);
      98             : }
      99             : 
     100             : inline C10_HOST_DEVICE BFloat16 operator/(const BFloat16& a, const BFloat16& b)
     101             :     __ubsan_ignore_float_divide_by_zero__ {
     102             :   return static_cast<float>(a) / static_cast<float>(b);
     103             : }
     104             : 
     105             : inline C10_HOST_DEVICE BFloat16 operator-(const BFloat16& a) {
     106             :   return -static_cast<float>(a);
     107             : }
     108             : 
     109             : inline C10_HOST_DEVICE BFloat16& operator+=(BFloat16& a, const BFloat16& b) {
     110             :   a = a + b;
     111             :   return a;
     112             : }
     113             : 
     114             : inline C10_HOST_DEVICE BFloat16& operator-=(BFloat16& a, const BFloat16& b) {
     115             :   a = a - b;
     116             :   return a;
     117             : }
     118             : 
     119             : inline C10_HOST_DEVICE BFloat16& operator*=(BFloat16& a, const BFloat16& b) {
     120             :   a = a * b;
     121             :   return a;
     122             : }
     123             : 
     124             : inline C10_HOST_DEVICE BFloat16& operator/=(BFloat16& a, const BFloat16& b) {
     125             :   a = a / b;
     126             :   return a;
     127             : }
     128             : 
     129             : inline C10_HOST_DEVICE BFloat16& operator|(BFloat16& a, const BFloat16& b) {
     130             :   a.x = a.x | b.x;
     131             :   return a;
     132             : }
     133             : 
     134             : inline C10_HOST_DEVICE BFloat16& operator^(BFloat16& a, const BFloat16& b) {
     135             :   a.x = a.x ^ b.x;
     136             :   return a;
     137             : }
     138             : 
     139             : inline C10_HOST_DEVICE BFloat16& operator&(BFloat16& a, const BFloat16& b) {
     140             :   a.x = a.x & b.x;
     141             :   return a;
     142             : }
     143             : 
     144             : /// Arithmetic with floats
     145             : 
     146             : inline C10_HOST_DEVICE float operator+(BFloat16 a, float b) {
     147             :   return static_cast<float>(a) + b;
     148             : }
     149             : inline C10_HOST_DEVICE float operator-(BFloat16 a, float b) {
     150             :   return static_cast<float>(a) - b;
     151             : }
     152             : inline C10_HOST_DEVICE float operator*(BFloat16 a, float b) {
     153             :   return static_cast<float>(a) * b;
     154             : }
     155             : inline C10_HOST_DEVICE float operator/(BFloat16 a, float b) {
     156             :   return static_cast<float>(a) / b;
     157             : }
     158             : 
     159             : inline C10_HOST_DEVICE float operator+(float a, BFloat16 b) {
     160             :   return a + static_cast<float>(b);
     161             : }
     162             : inline C10_HOST_DEVICE float operator-(float a, BFloat16 b) {
     163             :   return a - static_cast<float>(b);
     164             : }
     165             : inline C10_HOST_DEVICE float operator*(float a, BFloat16 b) {
     166             :   return a * static_cast<float>(b);
     167             : }
     168             : inline C10_HOST_DEVICE float operator/(float a, BFloat16 b) {
     169             :   return a / static_cast<float>(b);
     170             : }
     171             : 
     172             : inline C10_HOST_DEVICE float& operator+=(float& a, const BFloat16& b) {
     173             :   return a += static_cast<float>(b);
     174             : }
     175             : inline C10_HOST_DEVICE float& operator-=(float& a, const BFloat16& b) {
     176             :   return a -= static_cast<float>(b);
     177             : }
     178             : inline C10_HOST_DEVICE float& operator*=(float& a, const BFloat16& b) {
     179             :   return a *= static_cast<float>(b);
     180             : }
     181             : inline C10_HOST_DEVICE float& operator/=(float& a, const BFloat16& b) {
     182             :   return a /= static_cast<float>(b);
     183             : }
     184             : 
     185             : /// Arithmetic with doubles
     186             : 
     187             : inline C10_HOST_DEVICE double operator+(BFloat16 a, double b) {
     188             :   return static_cast<double>(a) + b;
     189             : }
     190             : inline C10_HOST_DEVICE double operator-(BFloat16 a, double b) {
     191             :   return static_cast<double>(a) - b;
     192             : }
     193             : inline C10_HOST_DEVICE double operator*(BFloat16 a, double b) {
     194             :   return static_cast<double>(a) * b;
     195             : }
     196             : inline C10_HOST_DEVICE double operator/(BFloat16 a, double b) {
     197             :   return static_cast<double>(a) / b;
     198             : }
     199             : 
     200             : inline C10_HOST_DEVICE double operator+(double a, BFloat16 b) {
     201             :   return a + static_cast<double>(b);
     202             : }
     203             : inline C10_HOST_DEVICE double operator-(double a, BFloat16 b) {
     204             :   return a - static_cast<double>(b);
     205             : }
     206             : inline C10_HOST_DEVICE double operator*(double a, BFloat16 b) {
     207             :   return a * static_cast<double>(b);
     208             : }
     209             : inline C10_HOST_DEVICE double operator/(double a, BFloat16 b) {
     210             :   return a / static_cast<double>(b);
     211             : }
     212             : 
     213             : /// Arithmetic with ints
     214             : 
     215             : inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int b) {
     216             :   return a + static_cast<BFloat16>(b);
     217             : }
     218             : inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int b) {
     219             :   return a - static_cast<BFloat16>(b);
     220             : }
     221             : inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int b) {
     222             :   return a * static_cast<BFloat16>(b);
     223             : }
     224             : inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int b) {
     225             :   return a / static_cast<BFloat16>(b);
     226             : }
     227             : 
     228             : inline C10_HOST_DEVICE BFloat16 operator+(int a, BFloat16 b) {
     229             :   return static_cast<BFloat16>(a) + b;
     230             : }
     231             : inline C10_HOST_DEVICE BFloat16 operator-(int a, BFloat16 b) {
     232             :   return static_cast<BFloat16>(a) - b;
     233             : }
     234             : inline C10_HOST_DEVICE BFloat16 operator*(int a, BFloat16 b) {
     235             :   return static_cast<BFloat16>(a) * b;
     236             : }
     237             : inline C10_HOST_DEVICE BFloat16 operator/(int a, BFloat16 b) {
     238             :   return static_cast<BFloat16>(a) / b;
     239             : }
     240             : 
     241             : /// Arithmetic with int64_t
     242             : 
     243             : inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int64_t b) {
     244             :   return a + static_cast<BFloat16>(b);
     245             : }
     246             : inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int64_t b) {
     247             :   return a - static_cast<BFloat16>(b);
     248             : }
     249             : inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int64_t b) {
     250             :   return a * static_cast<BFloat16>(b);
     251             : }
     252             : inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int64_t b) {
     253             :   return a / static_cast<BFloat16>(b);
     254             : }
     255             : 
     256             : inline C10_HOST_DEVICE BFloat16 operator+(int64_t a, BFloat16 b) {
     257             :   return static_cast<BFloat16>(a) + b;
     258             : }
     259             : inline C10_HOST_DEVICE BFloat16 operator-(int64_t a, BFloat16 b) {
     260             :   return static_cast<BFloat16>(a) - b;
     261             : }
     262             : inline C10_HOST_DEVICE BFloat16 operator*(int64_t a, BFloat16 b) {
     263             :   return static_cast<BFloat16>(a) * b;
     264             : }
     265             : inline C10_HOST_DEVICE BFloat16 operator/(int64_t a, BFloat16 b) {
     266             :   return static_cast<BFloat16>(a) / b;
     267             : }
     268             : 
     269             : // Overloading < and > operators, because std::max and std::min use them.
     270             : 
     271             : inline C10_HOST_DEVICE bool operator>(BFloat16& lhs, BFloat16& rhs) {
     272             :   return float(lhs) > float(rhs);
     273             : }
     274             : 
     275             : inline C10_HOST_DEVICE bool operator<(BFloat16& lhs, BFloat16& rhs) {
     276             :   return float(lhs) < float(rhs);
     277             : }
     278             : 
     279             : } // namespace c10
     280             : 
     281             : namespace std {
     282             : 
     283             : template <>
     284             : class numeric_limits<c10::BFloat16> {
     285             :  public:
     286             :   static constexpr bool is_signed = true;
     287             :   static constexpr bool is_specialized = true;
     288             :   static constexpr bool is_integer = false;
     289             :   static constexpr bool is_exact = false;
     290             :   static constexpr bool has_infinity = true;
     291             :   static constexpr bool has_quiet_NaN = true;
     292             :   static constexpr bool has_signaling_NaN = true;
     293             :   static constexpr auto has_denorm = numeric_limits<float>::has_denorm;
     294             :   static constexpr auto has_denorm_loss =
     295             :       numeric_limits<float>::has_denorm_loss;
     296             :   static constexpr auto round_style = numeric_limits<float>::round_style;
     297             :   static constexpr bool is_iec559 = false;
     298             :   static constexpr bool is_bounded = true;
     299             :   static constexpr bool is_modulo = false;
     300             :   static constexpr int digits = 8;
     301             :   static constexpr int digits10 = 2;
     302             :   static constexpr int max_digits10 = 4;
     303             :   static constexpr int radix = 2;
     304             :   static constexpr int min_exponent = -125;
     305             :   static constexpr int min_exponent10 = -37;
     306             :   static constexpr int max_exponent = 128;
     307             :   static constexpr int max_exponent10 = 38;
     308             :   static constexpr auto traps = numeric_limits<float>::traps;
     309             :   static constexpr auto tinyness_before =
     310             :       numeric_limits<float>::tinyness_before;
     311             : 
     312             :   static constexpr c10::BFloat16 min() {
     313             :     return c10::BFloat16(0x0080, c10::BFloat16::from_bits());
     314             :   }
     315             :   static constexpr c10::BFloat16 lowest() {
     316             :     return c10::BFloat16(0xFF7F, c10::BFloat16::from_bits());
     317             :   }
     318             :   static constexpr c10::BFloat16 max() {
     319             :     return c10::BFloat16(0x7F7F, c10::BFloat16::from_bits());
     320             :   }
     321             :   static constexpr c10::BFloat16 epsilon() {
     322             :     return c10::BFloat16(0x3C00, c10::BFloat16::from_bits());
     323             :   }
     324             :   static constexpr c10::BFloat16 round_error() {
     325             :     return c10::BFloat16(0x3F00, c10::BFloat16::from_bits());
     326             :   }
     327             :   static constexpr c10::BFloat16 infinity() {
     328             :     return c10::BFloat16(0x7F80, c10::BFloat16::from_bits());
     329             :   }
     330             :   static constexpr c10::BFloat16 quiet_NaN() {
     331             :     return c10::BFloat16(0x7FC0, c10::BFloat16::from_bits());
     332             :   }
     333             :   static constexpr c10::BFloat16 signaling_NaN() {
     334             :     return c10::BFloat16(0x7F80, c10::BFloat16::from_bits());
     335             :   }
     336             :   static constexpr c10::BFloat16 denorm_min() {
     337             :     return c10::BFloat16(0x0001, c10::BFloat16::from_bits());
     338             :   }
     339             : };
     340             : 
     341             : } // namespace std
     342             : 
     343             : C10_CLANG_DIAGNOSTIC_POP()

Generated by: LCOV version 1.16