LCOV - code coverage report
Current view: top level - home/runner/.local/lib/python3.9/site-packages/torch/include/c10/util - complex.h (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 0 2 0.0 %
Date: 2025-11-25 13:55:50 Functions: 0 0 -

          Line data    Source code
       1             : #pragma once
       2             : 
       3             : #include <complex>
       4             : 
       5             : #include <c10/macros/Macros.h>
       6             : #include <c10/util/Half.h>
       7             : 
       8             : #if defined(__CUDACC__) || defined(__HIPCC__)
       9             : #include <thrust/complex.h>
      10             : #endif
      11             : 
      12             : C10_CLANG_DIAGNOSTIC_PUSH()
      13             : #if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
      14             : C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
      15             : #endif
      16             : #if C10_CLANG_HAS_WARNING("-Wfloat-conversion")
      17             : C10_CLANG_DIAGNOSTIC_IGNORE("-Wfloat-conversion")
      18             : #endif
      19             : 
      20             : namespace c10 {
      21             : 
      22             : // c10::complex is an implementation of complex numbers that aims
      23             : // to work on all devices supported by PyTorch
      24             : //
       25             : // Most of the APIs duplicate std::complex
      26             : // Reference: https://en.cppreference.com/w/cpp/numeric/complex
      27             : //
      28             : // [NOTE: Complex Operator Unification]
      29             : // Operators currently use a mix of std::complex, thrust::complex, and
      30             : // c10::complex internally. The end state is that all operators will use
      31             : // c10::complex internally.  Until then, there may be some hacks to support all
      32             : // variants.
      33             : //
      34             : //
      35             : // [Note on Constructors]
      36             : //
      37             : // The APIs of constructors are mostly copied from C++ standard:
      38             : //   https://en.cppreference.com/w/cpp/numeric/complex/complex
      39             : //
      40             : // Since C++14, all constructors are constexpr in std::complex
      41             : //
      42             : // There are three types of constructors:
      43             : // - initializing from real and imag:
      44             : //     `constexpr complex( const T& re = T(), const T& im = T() );`
      45             : // - implicitly-declared copy constructor
      46             : // - converting constructors
      47             : //
      48             : // Converting constructors:
       49             : // - std::complex defines converting constructors between
       50             : //   float/double/long double, while we only define converting
       51             : //   constructors between float/double.
      52             : // - For these converting constructors, upcasting is implicit, downcasting is
      53             : //   explicit.
      54             : // - We also define explicit casting from std::complex/thrust::complex
      55             : //   - Note that the conversion from thrust is not constexpr, because
       56             : //     thrust does not appear to define them as constexpr.
      57             : //
      58             : //
      59             : // [Operator =]
      60             : //
      61             : // The APIs of operator = are mostly copied from C++ standard:
      62             : //   https://en.cppreference.com/w/cpp/numeric/complex/operator%3D
      63             : //
      64             : // Since C++20, all operator= are constexpr. Although we are not building with
      65             : // C++20, we also obey this behavior.
      66             : //
      67             : // There are three types of assign operator:
      68             : // - Assign a real value from the same scalar type
      69             : //   - In std, this is templated as complex& operator=(const T& x)
      70             : //     with specialization `complex& operator=(T x)` for float/double/long
       71             : //     double. Since we only support float and double, we will use
       72             : //     `complex& operator=(T x)`.
      73             : // - Copy assignment operator and converting assignment operator
       74             : //   - There is no specialization of converting assignment operators;
       75             : //     which types are convertible depends solely on whether the
       76             : //     scalar type is convertible.
      77             : //
      78             : // In addition to the standard assignment, we also provide assignment operators
      79             : // with std and thrust
      80             : //
      81             : //
      82             : // [Casting operators]
      83             : //
      84             : // std::complex does not have casting operators. We define casting operators
      85             : // casting to std::complex and thrust::complex
      86             : //
      87             : //
      88             : // [Operator ""]
      89             : //
      90             : // std::complex has custom literals `i`, `if` and `il` defined in namespace
      91             : // `std::literals::complex_literals`. We define our own custom literals in the
       92             : // namespace `c10::complex_literals`. Our custom literals do not follow the
       93             : // same behavior as in std::complex; instead, we define _if, _id to construct
      94             : // float/double complex literals.
      95             : //
      96             : //
      97             : // [real() and imag()]
      98             : //
       99             : // In C++20, there are two overloads of these functions: one to return the
      100             : // real/imag, and another to set real/imag; both are constexpr. We follow
     101             : // this design.
     102             : //
     103             : //
     104             : // [Operator +=,-=,*=,/=]
     105             : //
     106             : // Since C++20, these operators become constexpr. In our implementation, they
     107             : // are also constexpr.
     108             : //
     109             : // There are two types of such operators: operating with a real number, or
     110             : // operating with another complex number. For the operating with a real number,
     111             : // the generic template form has argument type `const T &`, while the overload
     112             : // for float/double/long double has `T`. We will follow the same type as
     113             : // float/double/long double in std.
     114             : //
     115             : // [Unary operator +-]
     116             : //
      117             : // Since C++20, they are constexpr. We also make them constexpr.
     118             : //
     119             : // [Binary operators +-*/]
     120             : //
     121             : // Each operator has three versions (taking + as example):
     122             : // - complex + complex
     123             : // - complex + real
     124             : // - real + complex
     125             : //
     126             : // [Operator ==, !=]
     127             : //
     128             : // Each operator has three versions (taking == as example):
     129             : // - complex == complex
     130             : // - complex == real
     131             : // - real == complex
     132             : //
      133             : // Some of them are removed in C++20, but we decided to keep them.
     134             : //
     135             : // [Operator <<, >>]
     136             : //
     137             : // These are implemented by casting to std::complex
     138             : //
     139             : //
     140             : //
      141             : // TODO(@zasdfgbnm): c10::complex<c10::Half> is not fully supported (only a
      142             : // minimal specialization with +=, -= and *= is defined), because:
      143             : //  - lots of members and functions of c10::Half are not constexpr
      144             : //  - thrust::complex only supports float and double
     145             : 
     146             : template <typename T>
     147             : struct alignas(sizeof(T) * 2) complex {
     148             :   using value_type = T;
     149             : 
     150             :   T real_ = T(0);
     151             :   T imag_ = T(0);
     152             : 
     153             :   constexpr complex() = default;
     154             :   C10_HOST_DEVICE constexpr complex(const T& re, const T& im = T())
     155             :       : real_(re), imag_(im) {}
     156             :   template <typename U>
     157             :   explicit constexpr complex(const std::complex<U>& other)
     158             :       : complex(other.real(), other.imag()) {}
     159             : #if defined(__CUDACC__) || defined(__HIPCC__)
     160             :   template <typename U>
     161             :   explicit C10_HOST_DEVICE complex(const thrust::complex<U>& other)
     162             :       : real_(other.real()), imag_(other.imag()) {}
     163             : // NOTE can not be implemented as follow due to ROCm bug:
     164             : //   explicit C10_HOST_DEVICE complex(const thrust::complex<U> &other):
     165             : //   complex(other.real(), other.imag()) {}
     166             : #endif
     167             : 
     168             :   // Use SFINAE to specialize casting constructor for c10::complex<float> and
     169             :   // c10::complex<double>
     170             :   template <typename U = T>
     171             :   C10_HOST_DEVICE explicit constexpr complex(
     172             :       const std::enable_if_t<std::is_same_v<U, float>, complex<double>>& other)
     173             :       : real_(other.real_), imag_(other.imag_) {}
     174             :   template <typename U = T>
     175             :   C10_HOST_DEVICE constexpr complex(
     176             :       const std::enable_if_t<std::is_same_v<U, double>, complex<float>>& other)
     177             :       : real_(other.real_), imag_(other.imag_) {}
     178             : 
     179             :   constexpr complex<T>& operator=(T re) {
     180             :     real_ = re;
     181             :     imag_ = 0;
     182             :     return *this;
     183             :   }
     184             : 
     185             :   constexpr complex<T>& operator+=(T re) {
     186             :     real_ += re;
     187             :     return *this;
     188             :   }
     189             : 
     190             :   constexpr complex<T>& operator-=(T re) {
     191             :     real_ -= re;
     192             :     return *this;
     193             :   }
     194             : 
     195             :   constexpr complex<T>& operator*=(T re) {
     196             :     real_ *= re;
     197             :     imag_ *= re;
     198             :     return *this;
     199             :   }
     200             : 
     201             :   constexpr complex<T>& operator/=(T re) {
     202             :     real_ /= re;
     203             :     imag_ /= re;
     204             :     return *this;
     205             :   }
     206             : 
     207             :   template <typename U>
     208             :   constexpr complex<T>& operator=(const complex<U>& rhs) {
     209             :     real_ = rhs.real();
     210             :     imag_ = rhs.imag();
     211             :     return *this;
     212             :   }
     213             : 
     214             :   template <typename U>
     215             :   constexpr complex<T>& operator+=(const complex<U>& rhs) {
     216             :     real_ += rhs.real();
     217             :     imag_ += rhs.imag();
     218             :     return *this;
     219             :   }
     220             : 
     221             :   template <typename U>
     222             :   constexpr complex<T>& operator-=(const complex<U>& rhs) {
     223             :     real_ -= rhs.real();
     224             :     imag_ -= rhs.imag();
     225             :     return *this;
     226             :   }
     227             : 
     228             :   template <typename U>
     229             :   constexpr complex<T>& operator*=(const complex<U>& rhs) {
     230             :     // (a + bi) * (c + di) = (a*c - b*d) + (a * d + b * c) i
     231             :     T a = real_;
     232             :     T b = imag_;
     233             :     U c = rhs.real();
     234             :     U d = rhs.imag();
     235             :     real_ = a * c - b * d;
     236             :     imag_ = a * d + b * c;
     237             :     return *this;
     238             :   }
     239             : 
     240             : #ifdef __APPLE__
     241             : #define FORCE_INLINE_APPLE __attribute__((always_inline))
     242             : #else
     243             : #define FORCE_INLINE_APPLE
     244             : #endif
     245             :   template <typename U>
     246             :   constexpr FORCE_INLINE_APPLE complex<T>& operator/=(const complex<U>& rhs)
     247             :       __ubsan_ignore_float_divide_by_zero__ {
     248             :     // (a + bi) / (c + di) = (ac + bd)/(c^2 + d^2) + (bc - ad)/(c^2 + d^2) i
     249             :     // the calculation below follows numpy's complex division
     250             :     T a = real_;
     251             :     T b = imag_;
     252             :     U c = rhs.real();
     253             :     U d = rhs.imag();
     254             : 
     255             : #if defined(__GNUC__) && !defined(__clang__)
     256             :     // std::abs is already constexpr by gcc
     257             :     auto abs_c = std::abs(c);
     258             :     auto abs_d = std::abs(d);
     259             : #else
     260             :     auto abs_c = c < 0 ? -c : c;
     261             :     auto abs_d = d < 0 ? -d : d;
     262             : #endif
     263             : 
     264             :     if (abs_c >= abs_d) {
     265             :       if (abs_c == U(0) && abs_d == U(0)) {
     266             :         /* divide by zeros should yield a complex inf or nan */
     267             :         real_ = a / abs_c;
     268             :         imag_ = b / abs_d;
     269             :       } else {
     270             :         auto rat = d / c;
     271             :         auto scl = U(1.0) / (c + d * rat);
     272             :         real_ = (a + b * rat) * scl;
     273             :         imag_ = (b - a * rat) * scl;
     274             :       }
     275             :     } else {
     276             :       auto rat = c / d;
     277             :       auto scl = U(1.0) / (d + c * rat);
     278             :       real_ = (a * rat + b) * scl;
     279             :       imag_ = (b * rat - a) * scl;
     280             :     }
     281             :     return *this;
     282             :   }
     283             : #undef FORCE_INLINE_APPLE
     284             : 
     285             :   template <typename U>
     286             :   constexpr complex<T>& operator=(const std::complex<U>& rhs) {
     287             :     real_ = rhs.real();
     288             :     imag_ = rhs.imag();
     289             :     return *this;
     290             :   }
     291             : 
     292             : #if defined(__CUDACC__) || defined(__HIPCC__)
     293             :   template <typename U>
     294             :   C10_HOST_DEVICE complex<T>& operator=(const thrust::complex<U>& rhs) {
     295             :     real_ = rhs.real();
     296             :     imag_ = rhs.imag();
     297             :     return *this;
     298             :   }
     299             : #endif
     300             : 
     301             :   template <typename U>
     302             :   explicit constexpr operator std::complex<U>() const {
     303             :     return std::complex<U>(std::complex<T>(real(), imag()));
     304             :   }
     305             : 
     306             : #if defined(__CUDACC__) || defined(__HIPCC__)
     307             :   template <typename U>
     308             :   C10_HOST_DEVICE explicit operator thrust::complex<U>() const {
     309             :     return static_cast<thrust::complex<U>>(thrust::complex<T>(real(), imag()));
     310             :   }
     311             : #endif
     312             : 
     313             :   // consistent with NumPy behavior
     314             :   explicit constexpr operator bool() const {
     315             :     return real() || imag();
     316             :   }
     317             : 
     318             :   C10_HOST_DEVICE constexpr T real() const {
     319           0 :     return real_;
     320             :   }
     321             :   constexpr void real(T value) {
     322             :     real_ = value;
     323             :   }
     324             :   C10_HOST_DEVICE constexpr T imag() const {
     325           0 :     return imag_;
     326             :   }
     327             :   constexpr void imag(T value) {
     328             :     imag_ = value;
     329             :   }
     330             : };
     331             : 
     332             : namespace complex_literals {
     333             : 
     334             : constexpr complex<float> operator""_if(long double imag) {
     335             :   return complex<float>(0.0f, static_cast<float>(imag));
     336             : }
     337             : 
     338             : constexpr complex<double> operator""_id(long double imag) {
     339             :   return complex<double>(0.0, static_cast<double>(imag));
     340             : }
     341             : 
     342             : constexpr complex<float> operator""_if(unsigned long long imag) {
     343             :   return complex<float>(0.0f, static_cast<float>(imag));
     344             : }
     345             : 
     346             : constexpr complex<double> operator""_id(unsigned long long imag) {
     347             :   return complex<double>(0.0, static_cast<double>(imag));
     348             : }
     349             : 
     350             : } // namespace complex_literals
     351             : 
     352             : template <typename T>
     353             : constexpr complex<T> operator+(const complex<T>& val) {
     354             :   return val;
     355             : }
     356             : 
     357             : template <typename T>
     358             : constexpr complex<T> operator-(const complex<T>& val) {
     359             :   return complex<T>(-val.real(), -val.imag());
     360             : }
     361             : 
     362             : template <typename T>
     363             : constexpr complex<T> operator+(const complex<T>& lhs, const complex<T>& rhs) {
     364             :   complex<T> result = lhs;
     365             :   return result += rhs;
     366             : }
     367             : 
     368             : template <typename T>
     369             : constexpr complex<T> operator+(const complex<T>& lhs, const T& rhs) {
     370             :   complex<T> result = lhs;
     371             :   return result += rhs;
     372             : }
     373             : 
     374             : template <typename T>
     375             : constexpr complex<T> operator+(const T& lhs, const complex<T>& rhs) {
     376             :   return complex<T>(lhs + rhs.real(), rhs.imag());
     377             : }
     378             : 
     379             : template <typename T>
     380             : constexpr complex<T> operator-(const complex<T>& lhs, const complex<T>& rhs) {
     381             :   complex<T> result = lhs;
     382             :   return result -= rhs;
     383             : }
     384             : 
     385             : template <typename T>
     386             : constexpr complex<T> operator-(const complex<T>& lhs, const T& rhs) {
     387             :   complex<T> result = lhs;
     388             :   return result -= rhs;
     389             : }
     390             : 
     391             : template <typename T>
     392             : constexpr complex<T> operator-(const T& lhs, const complex<T>& rhs) {
     393             :   complex<T> result = -rhs;
     394             :   return result += lhs;
     395             : }
     396             : 
     397             : template <typename T>
     398             : constexpr complex<T> operator*(const complex<T>& lhs, const complex<T>& rhs) {
     399             :   complex<T> result = lhs;
     400             :   return result *= rhs;
     401             : }
     402             : 
     403             : template <typename T>
     404             : constexpr complex<T> operator*(const complex<T>& lhs, const T& rhs) {
     405             :   complex<T> result = lhs;
     406             :   return result *= rhs;
     407             : }
     408             : 
     409             : template <typename T>
     410             : constexpr complex<T> operator*(const T& lhs, const complex<T>& rhs) {
     411             :   complex<T> result = rhs;
     412             :   return result *= lhs;
     413             : }
     414             : 
     415             : template <typename T>
     416             : constexpr complex<T> operator/(const complex<T>& lhs, const complex<T>& rhs) {
     417             :   complex<T> result = lhs;
     418             :   return result /= rhs;
     419             : }
     420             : 
     421             : template <typename T>
     422             : constexpr complex<T> operator/(const complex<T>& lhs, const T& rhs) {
     423             :   complex<T> result = lhs;
     424             :   return result /= rhs;
     425             : }
     426             : 
     427             : template <typename T>
     428             : constexpr complex<T> operator/(const T& lhs, const complex<T>& rhs) {
     429             :   complex<T> result(lhs, T());
     430             :   return result /= rhs;
     431             : }
     432             : 
     433             : // Define operators between integral scalars and c10::complex. std::complex does
     434             : // not support this when T is a floating-point number. This is useful because it
     435             : // saves a lot of "static_cast" when operate a complex and an integer. This
     436             : // makes the code both less verbose and potentially more efficient.
     437             : #define COMPLEX_INTEGER_OP_TEMPLATE_CONDITION                 \
     438             :   typename std::enable_if_t<                                  \
     439             :       std::is_floating_point_v<fT> && std::is_integral_v<iT>, \
     440             :       int> = 0
     441             : 
     442             : template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
     443             : constexpr c10::complex<fT> operator+(const c10::complex<fT>& a, const iT& b) {
     444             :   return a + static_cast<fT>(b);
     445             : }
     446             : 
     447             : template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
     448             : constexpr c10::complex<fT> operator+(const iT& a, const c10::complex<fT>& b) {
     449             :   return static_cast<fT>(a) + b;
     450             : }
     451             : 
     452             : template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
     453             : constexpr c10::complex<fT> operator-(const c10::complex<fT>& a, const iT& b) {
     454             :   return a - static_cast<fT>(b);
     455             : }
     456             : 
     457             : template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
     458             : constexpr c10::complex<fT> operator-(const iT& a, const c10::complex<fT>& b) {
     459             :   return static_cast<fT>(a) - b;
     460             : }
     461             : 
     462             : template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
     463             : constexpr c10::complex<fT> operator*(const c10::complex<fT>& a, const iT& b) {
     464             :   return a * static_cast<fT>(b);
     465             : }
     466             : 
     467             : template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
     468             : constexpr c10::complex<fT> operator*(const iT& a, const c10::complex<fT>& b) {
     469             :   return static_cast<fT>(a) * b;
     470             : }
     471             : 
     472             : template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
     473             : constexpr c10::complex<fT> operator/(const c10::complex<fT>& a, const iT& b) {
     474             :   return a / static_cast<fT>(b);
     475             : }
     476             : 
     477             : template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
     478             : constexpr c10::complex<fT> operator/(const iT& a, const c10::complex<fT>& b) {
     479             :   return static_cast<fT>(a) / b;
     480             : }
     481             : 
     482             : #undef COMPLEX_INTEGER_OP_TEMPLATE_CONDITION
     483             : 
     484             : template <typename T>
     485             : constexpr bool operator==(const complex<T>& lhs, const complex<T>& rhs) {
     486             :   return (lhs.real() == rhs.real()) && (lhs.imag() == rhs.imag());
     487             : }
     488             : 
     489             : template <typename T>
     490             : constexpr bool operator==(const complex<T>& lhs, const T& rhs) {
     491             :   return (lhs.real() == rhs) && (lhs.imag() == T());
     492             : }
     493             : 
     494             : template <typename T>
     495             : constexpr bool operator==(const T& lhs, const complex<T>& rhs) {
     496             :   return (lhs == rhs.real()) && (T() == rhs.imag());
     497             : }
     498             : 
     499             : template <typename T>
     500             : constexpr bool operator!=(const complex<T>& lhs, const complex<T>& rhs) {
     501             :   return !(lhs == rhs);
     502             : }
     503             : 
     504             : template <typename T>
     505             : constexpr bool operator!=(const complex<T>& lhs, const T& rhs) {
     506             :   return !(lhs == rhs);
     507             : }
     508             : 
     509             : template <typename T>
     510             : constexpr bool operator!=(const T& lhs, const complex<T>& rhs) {
     511             :   return !(lhs == rhs);
     512             : }
     513             : 
     514             : template <typename T, typename CharT, typename Traits>
     515             : std::basic_ostream<CharT, Traits>& operator<<(
     516             :     std::basic_ostream<CharT, Traits>& os,
     517             :     const complex<T>& x) {
     518             :   return (os << static_cast<std::complex<T>>(x));
     519             : }
     520             : 
     521             : template <typename T, typename CharT, typename Traits>
     522             : std::basic_istream<CharT, Traits>& operator>>(
     523             :     std::basic_istream<CharT, Traits>& is,
     524             :     complex<T>& x) {
     525             :   std::complex<T> tmp;
     526             :   is >> tmp;
     527             :   x = tmp;
     528             :   return is;
     529             : }
     530             : 
     531             : } // namespace c10
     532             : 
     533             : // std functions
     534             : //
     535             : // The implementation of these functions also follow the design of C++20
     536             : 
     537             : namespace std {
     538             : 
     539             : template <typename T>
     540             : constexpr T real(const c10::complex<T>& z) {
     541             :   return z.real();
     542             : }
     543             : 
     544             : template <typename T>
     545             : constexpr T imag(const c10::complex<T>& z) {
     546             :   return z.imag();
     547             : }
     548             : 
     549             : template <typename T>
     550             : C10_HOST_DEVICE T abs(const c10::complex<T>& z) {
     551             : #if defined(__CUDACC__) || defined(__HIPCC__)
     552             :   return thrust::abs(static_cast<thrust::complex<T>>(z));
     553             : #else
     554             :   return std::abs(static_cast<std::complex<T>>(z));
     555             : #endif
     556             : }
     557             : 
     558             : #if defined(USE_ROCM)
     559             : #define ROCm_Bug(x)
     560             : #else
     561             : #define ROCm_Bug(x) x
     562             : #endif
     563             : 
     564             : template <typename T>
     565             : C10_HOST_DEVICE T arg(const c10::complex<T>& z) {
     566             :   return ROCm_Bug(std)::atan2(std::imag(z), std::real(z));
     567             : }
     568             : 
     569             : #undef ROCm_Bug
     570             : 
     571             : template <typename T>
     572             : constexpr T norm(const c10::complex<T>& z) {
     573             :   return z.real() * z.real() + z.imag() * z.imag();
     574             : }
     575             : 
     576             : // For std::conj, there are other versions of it:
     577             : //   constexpr std::complex<float> conj( float z );
     578             : //   template< class DoubleOrInteger >
     579             : //   constexpr std::complex<double> conj( DoubleOrInteger z );
     580             : //   constexpr std::complex<long double> conj( long double z );
     581             : // These are not implemented
     582             : // TODO(@zasdfgbnm): implement them as c10::conj
     583             : template <typename T>
     584             : constexpr c10::complex<T> conj(const c10::complex<T>& z) {
     585             :   return c10::complex<T>(z.real(), -z.imag());
     586             : }
     587             : 
     588             : // Thrust does not have complex --> complex version of thrust::proj,
     589             : // so this function is not implemented at c10 right now.
     590             : // TODO(@zasdfgbnm): implement it by ourselves
     591             : 
     592             : // There is no c10 version of std::polar, because std::polar always
     593             : // returns std::complex. Use c10::polar instead;
     594             : 
     595             : } // namespace std
     596             : 
     597             : namespace c10 {
     598             : 
     599             : template <typename T>
     600             : C10_HOST_DEVICE complex<T> polar(const T& r, const T& theta = T()) {
     601             : #if defined(__CUDACC__) || defined(__HIPCC__)
     602             :   return static_cast<complex<T>>(thrust::polar(r, theta));
     603             : #else
     604             :   // std::polar() requires r >= 0, so spell out the explicit implementation to
     605             :   // avoid a branch.
     606             :   return complex<T>(r * std::cos(theta), r * std::sin(theta));
     607             : #endif
     608             : }
     609             : 
     610             : template <>
     611             : struct alignas(4) complex<Half> {
     612             :   Half real_;
     613             :   Half imag_;
     614             : 
     615             :   // Constructors
     616             :   complex() = default;
     617             :   // Half constructor is not constexpr so the following constructor can't
     618             :   // be constexpr
     619             :   C10_HOST_DEVICE explicit inline complex(const Half& real, const Half& imag)
     620             :       : real_(real), imag_(imag) {}
     621             :   C10_HOST_DEVICE inline complex(const c10::complex<float>& value)
     622             :       : real_(value.real()), imag_(value.imag()) {}
     623             : 
     624             :   // Conversion operator
     625             :   inline C10_HOST_DEVICE operator c10::complex<float>() const {
     626             :     return {real_, imag_};
     627             :   }
     628             : 
     629             :   constexpr C10_HOST_DEVICE Half real() const {
     630             :     return real_;
     631             :   }
     632             :   constexpr C10_HOST_DEVICE Half imag() const {
     633             :     return imag_;
     634             :   }
     635             : 
     636             :   C10_HOST_DEVICE complex<Half>& operator+=(const complex<Half>& other) {
     637             :     real_ = static_cast<float>(real_) + static_cast<float>(other.real_);
     638             :     imag_ = static_cast<float>(imag_) + static_cast<float>(other.imag_);
     639             :     return *this;
     640             :   }
     641             : 
     642             :   C10_HOST_DEVICE complex<Half>& operator-=(const complex<Half>& other) {
     643             :     real_ = static_cast<float>(real_) - static_cast<float>(other.real_);
     644             :     imag_ = static_cast<float>(imag_) - static_cast<float>(other.imag_);
     645             :     return *this;
     646             :   }
     647             : 
     648             :   C10_HOST_DEVICE complex<Half>& operator*=(const complex<Half>& other) {
     649             :     auto a = static_cast<float>(real_);
     650             :     auto b = static_cast<float>(imag_);
     651             :     auto c = static_cast<float>(other.real());
     652             :     auto d = static_cast<float>(other.imag());
     653             :     real_ = a * c - b * d;
     654             :     imag_ = a * d + b * c;
     655             :     return *this;
     656             :   }
     657             : };
     658             : 
     659             : } // namespace c10
     660             : 
     661             : C10_CLANG_DIAGNOSTIC_POP()
     662             : 
     663             : #define C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H
     664             : // math functions are included in a separate file
     665             : #include <c10/util/complex_math.h> // IWYU pragma: keep
     666             : // utilities for complex types
     667             : #include <c10/util/complex_utils.h> // IWYU pragma: keep
     668             : #undef C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H

Generated by: LCOV version 1.16