LCOV - code coverage report
Current view: top level - home/runner/.local/lib/python3.9/site-packages/torch/include/c10/util - Half-inl.h (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 0 2 0.0 %
Date: 2025-12-04 11:19:34 Functions: 0 0 -

          Line data    Source code
       1             : #pragma once
       2             : 
       3             : #include <c10/macros/Macros.h>
       4             : #include <c10/util/bit_cast.h>
       5             : 
       6             : #include <cstring>
       7             : #include <limits>
       8             : 
       9             : #ifdef __CUDACC__
      10             : #include <cuda_fp16.h>
      11             : #endif
      12             : 
      13             : #ifdef __HIPCC__
      14             : #include <hip/hip_fp16.h>
      15             : #endif
      16             : 
      17             : #if defined(CL_SYCL_LANGUAGE_VERSION)
      18             : #include <CL/sycl.hpp> // for SYCL 1.2.1
      19             : #elif defined(SYCL_LANGUAGE_VERSION)
      20             : #include <sycl/sycl.hpp> // for SYCL 2020
      21             : #endif
      22             : 
      23             : #if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \
      24             :     !defined(__APPLE__)
      25             : #include <ATen/cpu/vec/vec_half.h>
      26             : #endif
      27             : 
      28             : C10_CLANG_DIAGNOSTIC_PUSH()
      29             : #if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
      30             : C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
      31             : #endif
      32             : 
      33             : namespace c10 {
      34             : 
      35             : #if defined(__aarch64__) && !defined(__CUDACC__)
      36             : /// Constructors
      37             : inline Half::Half(float16_t value) : x(detail::fp16_to_bits(value)) {}
      38             : inline Half::operator float16_t() const {
      39             :   return detail::fp16_from_bits(x);
      40             : }
      41             : #else
      42             : 
      43             : inline C10_HOST_DEVICE Half::Half(float value)
      44             :     :
      45             : #if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
      46             :       x(__half_as_short(__float2half(value)))
      47             : #elif defined(__SYCL_DEVICE_ONLY__)
      48             :       x(c10::bit_cast<uint16_t>(sycl::half(value)))
      49             : #elif (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \
      50             :     !defined(__APPLE__)
      51             :       x(at::vec::float2half_scalar(value))
      52             : #else
      53           0 :       x(detail::fp16_ieee_from_fp32_value(value))
      54             : #endif
      55             : {
      56             : }
      57             : 
      58             : /// Implicit conversions
      59             : 
      60             : inline C10_HOST_DEVICE Half::operator float() const {
      61             : #if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
      62             :   return __half2float(*reinterpret_cast<const __half*>(&x));
      63             : #elif defined(__SYCL_DEVICE_ONLY__)
      64             :   return float(c10::bit_cast<sycl::half>(x));
      65             : #elif (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \
      66             :     !defined(__APPLE__)
      67             :   return at::vec::half2float_scalar(x);
      68             : #elif defined(__aarch64__) && !defined(__CUDACC__)
      69             :   return detail::native_fp16_to_fp32_value(x);
      70             : #else
      71           0 :   return detail::fp16_ieee_to_fp32_value(x);
      72             : #endif
      73             : }
      74             : 
      75             : #endif /* !defined(__aarch64__) || defined(__CUDACC__) \
      76             :         */
      77             : 
      78             : #if defined(__CUDACC__) || defined(__HIPCC__)
      79             : inline C10_HOST_DEVICE Half::Half(const __half& value) {
      80             :   x = *reinterpret_cast<const unsigned short*>(&value);
      81             : }
      82             : inline C10_HOST_DEVICE Half::operator __half() const {
      83             :   return *reinterpret_cast<const __half*>(&x);
      84             : }
      85             : #endif
      86             : 
      87             : #ifdef SYCL_LANGUAGE_VERSION
      88             : inline C10_HOST_DEVICE Half::Half(const sycl::half& value) {
      89             :   x = *reinterpret_cast<const unsigned short*>(&value);
      90             : }
      91             : inline C10_HOST_DEVICE Half::operator sycl::half() const {
      92             :   return *reinterpret_cast<const sycl::half*>(&x);
      93             : }
      94             : #endif
      95             : 
      96             : // CUDA intrinsics
      97             : 
      98             : #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 350)) || \
      99             :     (defined(__clang__) && defined(__CUDA__))
     100             : inline __device__ Half __ldg(const Half* ptr) {
     101             :   return __ldg(reinterpret_cast<const __half*>(ptr));
     102             : }
     103             : #endif
     104             : 
     105             : /// Arithmetic
     106             : 
     107             : inline C10_HOST_DEVICE Half operator+(const Half& a, const Half& b) {
     108             :   return static_cast<float>(a) + static_cast<float>(b);
     109             : }
     110             : 
     111             : inline C10_HOST_DEVICE Half operator-(const Half& a, const Half& b) {
     112             :   return static_cast<float>(a) - static_cast<float>(b);
     113             : }
     114             : 
     115             : inline C10_HOST_DEVICE Half operator*(const Half& a, const Half& b) {
     116             :   return static_cast<float>(a) * static_cast<float>(b);
     117             : }
     118             : 
     119             : inline C10_HOST_DEVICE Half operator/(const Half& a, const Half& b)
     120             :     __ubsan_ignore_float_divide_by_zero__ {
     121             :   return static_cast<float>(a) / static_cast<float>(b);
     122             : }
     123             : 
     124             : inline C10_HOST_DEVICE Half operator-(const Half& a) {
     125             : #if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \
     126             :     defined(__HIP_DEVICE_COMPILE__)
     127             :   return __hneg(a);
     128             : #elif defined(__SYCL_DEVICE_ONLY__)
     129             :   return -c10::bit_cast<sycl::half>(a);
     130             : #else
     131             :   return -static_cast<float>(a);
     132             : #endif
     133             : }
     134             : 
     135             : inline C10_HOST_DEVICE Half& operator+=(Half& a, const Half& b) {
     136             :   a = a + b;
     137             :   return a;
     138             : }
     139             : 
     140             : inline C10_HOST_DEVICE Half& operator-=(Half& a, const Half& b) {
     141             :   a = a - b;
     142             :   return a;
     143             : }
     144             : 
     145             : inline C10_HOST_DEVICE Half& operator*=(Half& a, const Half& b) {
     146             :   a = a * b;
     147             :   return a;
     148             : }
     149             : 
     150             : inline C10_HOST_DEVICE Half& operator/=(Half& a, const Half& b) {
     151             :   a = a / b;
     152             :   return a;
     153             : }
     154             : 
     155             : /// Arithmetic with floats
     156             : 
     157             : inline C10_HOST_DEVICE float operator+(Half a, float b) {
     158             :   return static_cast<float>(a) + b;
     159             : }
     160             : inline C10_HOST_DEVICE float operator-(Half a, float b) {
     161             :   return static_cast<float>(a) - b;
     162             : }
     163             : inline C10_HOST_DEVICE float operator*(Half a, float b) {
     164             :   return static_cast<float>(a) * b;
     165             : }
     166             : inline C10_HOST_DEVICE float operator/(Half a, float b)
     167             :     __ubsan_ignore_float_divide_by_zero__ {
     168             :   return static_cast<float>(a) / b;
     169             : }
     170             : 
     171             : inline C10_HOST_DEVICE float operator+(float a, Half b) {
     172             :   return a + static_cast<float>(b);
     173             : }
     174             : inline C10_HOST_DEVICE float operator-(float a, Half b) {
     175             :   return a - static_cast<float>(b);
     176             : }
     177             : inline C10_HOST_DEVICE float operator*(float a, Half b) {
     178             :   return a * static_cast<float>(b);
     179             : }
     180             : inline C10_HOST_DEVICE float operator/(float a, Half b)
     181             :     __ubsan_ignore_float_divide_by_zero__ {
     182             :   return a / static_cast<float>(b);
     183             : }
     184             : 
     185             : inline C10_HOST_DEVICE float& operator+=(float& a, const Half& b) {
     186             :   return a += static_cast<float>(b);
     187             : }
     188             : inline C10_HOST_DEVICE float& operator-=(float& a, const Half& b) {
     189             :   return a -= static_cast<float>(b);
     190             : }
     191             : inline C10_HOST_DEVICE float& operator*=(float& a, const Half& b) {
     192             :   return a *= static_cast<float>(b);
     193             : }
     194             : inline C10_HOST_DEVICE float& operator/=(float& a, const Half& b) {
     195             :   return a /= static_cast<float>(b);
     196             : }
     197             : 
     198             : /// Arithmetic with doubles
     199             : 
     200             : inline C10_HOST_DEVICE double operator+(Half a, double b) {
     201             :   return static_cast<double>(a) + b;
     202             : }
     203             : inline C10_HOST_DEVICE double operator-(Half a, double b) {
     204             :   return static_cast<double>(a) - b;
     205             : }
     206             : inline C10_HOST_DEVICE double operator*(Half a, double b) {
     207             :   return static_cast<double>(a) * b;
     208             : }
     209             : inline C10_HOST_DEVICE double operator/(Half a, double b)
     210             :     __ubsan_ignore_float_divide_by_zero__ {
     211             :   return static_cast<double>(a) / b;
     212             : }
     213             : 
     214             : inline C10_HOST_DEVICE double operator+(double a, Half b) {
     215             :   return a + static_cast<double>(b);
     216             : }
     217             : inline C10_HOST_DEVICE double operator-(double a, Half b) {
     218             :   return a - static_cast<double>(b);
     219             : }
     220             : inline C10_HOST_DEVICE double operator*(double a, Half b) {
     221             :   return a * static_cast<double>(b);
     222             : }
     223             : inline C10_HOST_DEVICE double operator/(double a, Half b)
     224             :     __ubsan_ignore_float_divide_by_zero__ {
     225             :   return a / static_cast<double>(b);
     226             : }
     227             : 
     228             : /// Arithmetic with ints
     229             : 
     230             : inline C10_HOST_DEVICE Half operator+(Half a, int b) {
     231             :   return a + static_cast<Half>(b);
     232             : }
     233             : inline C10_HOST_DEVICE Half operator-(Half a, int b) {
     234             :   return a - static_cast<Half>(b);
     235             : }
     236             : inline C10_HOST_DEVICE Half operator*(Half a, int b) {
     237             :   return a * static_cast<Half>(b);
     238             : }
     239             : inline C10_HOST_DEVICE Half operator/(Half a, int b) {
     240             :   return a / static_cast<Half>(b);
     241             : }
     242             : 
     243             : inline C10_HOST_DEVICE Half operator+(int a, Half b) {
     244             :   return static_cast<Half>(a) + b;
     245             : }
     246             : inline C10_HOST_DEVICE Half operator-(int a, Half b) {
     247             :   return static_cast<Half>(a) - b;
     248             : }
     249             : inline C10_HOST_DEVICE Half operator*(int a, Half b) {
     250             :   return static_cast<Half>(a) * b;
     251             : }
     252             : inline C10_HOST_DEVICE Half operator/(int a, Half b) {
     253             :   return static_cast<Half>(a) / b;
     254             : }
     255             : 
     256             : //// Arithmetic with int64_t
     257             : 
     258             : inline C10_HOST_DEVICE Half operator+(Half a, int64_t b) {
     259             :   return a + static_cast<Half>(b);
     260             : }
     261             : inline C10_HOST_DEVICE Half operator-(Half a, int64_t b) {
     262             :   return a - static_cast<Half>(b);
     263             : }
     264             : inline C10_HOST_DEVICE Half operator*(Half a, int64_t b) {
     265             :   return a * static_cast<Half>(b);
     266             : }
     267             : inline C10_HOST_DEVICE Half operator/(Half a, int64_t b) {
     268             :   return a / static_cast<Half>(b);
     269             : }
     270             : 
     271             : inline C10_HOST_DEVICE Half operator+(int64_t a, Half b) {
     272             :   return static_cast<Half>(a) + b;
     273             : }
     274             : inline C10_HOST_DEVICE Half operator-(int64_t a, Half b) {
     275             :   return static_cast<Half>(a) - b;
     276             : }
     277             : inline C10_HOST_DEVICE Half operator*(int64_t a, Half b) {
     278             :   return static_cast<Half>(a) * b;
     279             : }
     280             : inline C10_HOST_DEVICE Half operator/(int64_t a, Half b) {
     281             :   return static_cast<Half>(a) / b;
     282             : }
     283             : 
     284             : /// NOTE: we do not define comparisons directly and instead rely on the implicit
     285             : /// conversion from c10::Half to float.
     286             : 
     287             : } // namespace c10
     288             : 
     289             : namespace std {
     290             : 
     291             : template <>
     292             : class numeric_limits<c10::Half> {
     293             :  public:
     294             :   static constexpr bool is_specialized = true;
     295             :   static constexpr bool is_signed = true;
     296             :   static constexpr bool is_integer = false;
     297             :   static constexpr bool is_exact = false;
     298             :   static constexpr bool has_infinity = true;
     299             :   static constexpr bool has_quiet_NaN = true;
     300             :   static constexpr bool has_signaling_NaN = true;
     301             :   static constexpr auto has_denorm = numeric_limits<float>::has_denorm;
     302             :   static constexpr auto has_denorm_loss =
     303             :       numeric_limits<float>::has_denorm_loss;
     304             :   static constexpr auto round_style = numeric_limits<float>::round_style;
     305             :   static constexpr bool is_iec559 = true;
     306             :   static constexpr bool is_bounded = true;
     307             :   static constexpr bool is_modulo = false;
     308             :   static constexpr int digits = 11;
     309             :   static constexpr int digits10 = 3;
     310             :   static constexpr int max_digits10 = 5;
     311             :   static constexpr int radix = 2;
     312             :   static constexpr int min_exponent = -13;
     313             :   static constexpr int min_exponent10 = -4;
     314             :   static constexpr int max_exponent = 16;
     315             :   static constexpr int max_exponent10 = 4;
     316             :   static constexpr auto traps = numeric_limits<float>::traps;
     317             :   static constexpr auto tinyness_before =
     318             :       numeric_limits<float>::tinyness_before;
     319             :   static constexpr c10::Half min() {
     320             :     return c10::Half(0x0400, c10::Half::from_bits());
     321             :   }
     322             :   static constexpr c10::Half lowest() {
     323             :     return c10::Half(0xFBFF, c10::Half::from_bits());
     324             :   }
     325             :   static constexpr c10::Half max() {
     326             :     return c10::Half(0x7BFF, c10::Half::from_bits());
     327             :   }
     328             :   static constexpr c10::Half epsilon() {
     329             :     return c10::Half(0x1400, c10::Half::from_bits());
     330             :   }
     331             :   static constexpr c10::Half round_error() {
     332             :     return c10::Half(0x3800, c10::Half::from_bits());
     333             :   }
     334             :   static constexpr c10::Half infinity() {
     335             :     return c10::Half(0x7C00, c10::Half::from_bits());
     336             :   }
     337             :   static constexpr c10::Half quiet_NaN() {
     338             :     return c10::Half(0x7E00, c10::Half::from_bits());
     339             :   }
     340             :   static constexpr c10::Half signaling_NaN() {
     341             :     return c10::Half(0x7D00, c10::Half::from_bits());
     342             :   }
     343             :   static constexpr c10::Half denorm_min() {
     344             :     return c10::Half(0x0001, c10::Half::from_bits());
     345             :   }
     346             : };
     347             : 
     348             : } // namespace std
     349             : 
     350             : C10_CLANG_DIAGNOSTIC_POP()

Generated by: LCOV version 1.16