Line data Source code
1 : #pragma once
2 :
3 : #include <c10/macros/Macros.h>
4 : #include <c10/util/bit_cast.h>
5 :
6 : #include <limits>
7 :
8 : C10_CLANG_DIAGNOSTIC_PUSH()
9 : #if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
10 : C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
11 : #endif
12 :
13 : #if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
14 : #if defined(CL_SYCL_LANGUAGE_VERSION)
15 : #include <CL/sycl.hpp> // for SYCL 1.2.1
16 : #else
17 : #include <sycl/sycl.hpp> // for SYCL 2020
18 : #endif
19 : #include <ext/oneapi/bfloat16.hpp>
20 : #endif
21 :
22 : namespace c10 {
23 :
24 : /// Constructors
25 : inline C10_HOST_DEVICE BFloat16::BFloat16(float value)
26 : :
27 : #if defined(__CUDACC__) && !defined(USE_ROCM) && defined(__CUDA_ARCH__) && \
28 : __CUDA_ARCH__ >= 800
29 : x(__bfloat16_as_ushort(__float2bfloat16(value)))
30 : #elif defined(__SYCL_DEVICE_ONLY__) && \
31 : defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
32 : x(c10::bit_cast<uint16_t>(sycl::ext::oneapi::bfloat16(value)))
33 : #else
34 : // RNE by default
35 : x(detail::round_to_nearest_even(value))
36 : #endif
37 : {
38 : }
39 :
40 : /// Implicit conversions
41 : inline C10_HOST_DEVICE BFloat16::operator float() const {
42 : #if defined(__CUDACC__) && !defined(USE_ROCM)
43 : return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&x));
44 : #elif defined(__SYCL_DEVICE_ONLY__) && \
45 : defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
46 : return float(*reinterpret_cast<const sycl::ext::oneapi::bfloat16*>(&x));
47 : #else
48 0 : return detail::f32_from_bits(x);
49 : #endif
50 : }
51 :
52 : #if defined(__CUDACC__) && !defined(USE_ROCM)
53 : inline C10_HOST_DEVICE BFloat16::BFloat16(const __nv_bfloat16& value) {
54 : x = *reinterpret_cast<const unsigned short*>(&value);
55 : }
56 : inline C10_HOST_DEVICE BFloat16::operator __nv_bfloat16() const {
57 : return *reinterpret_cast<const __nv_bfloat16*>(&x);
58 : }
59 : #endif
60 :
61 : #if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
62 : inline C10_HOST_DEVICE BFloat16::BFloat16(
63 : const sycl::ext::oneapi::bfloat16& value) {
64 : x = *reinterpret_cast<const unsigned short*>(&value);
65 : }
66 : inline C10_HOST_DEVICE BFloat16::operator sycl::ext::oneapi::bfloat16() const {
67 : return *reinterpret_cast<const sycl::ext::oneapi::bfloat16*>(&x);
68 : }
69 : #endif
70 :
71 : // CUDA intrinsics
72 :
73 : #if defined(__CUDACC__) || defined(__HIPCC__)
74 : inline C10_DEVICE BFloat16 __ldg(const BFloat16* ptr) {
75 : #if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
76 : return __ldg(reinterpret_cast<const __nv_bfloat16*>(ptr));
77 : #else
78 : return *ptr;
79 : #endif
80 : }
81 : #endif
82 :
83 : /// Arithmetic
84 :
85 : inline C10_HOST_DEVICE BFloat16
86 : operator+(const BFloat16& a, const BFloat16& b) {
87 : return static_cast<float>(a) + static_cast<float>(b);
88 : }
89 :
90 : inline C10_HOST_DEVICE BFloat16
91 : operator-(const BFloat16& a, const BFloat16& b) {
92 : return static_cast<float>(a) - static_cast<float>(b);
93 : }
94 :
95 : inline C10_HOST_DEVICE BFloat16
96 : operator*(const BFloat16& a, const BFloat16& b) {
97 : return static_cast<float>(a) * static_cast<float>(b);
98 : }
99 :
100 : inline C10_HOST_DEVICE BFloat16 operator/(const BFloat16& a, const BFloat16& b)
101 : __ubsan_ignore_float_divide_by_zero__ {
102 : return static_cast<float>(a) / static_cast<float>(b);
103 : }
104 :
105 : inline C10_HOST_DEVICE BFloat16 operator-(const BFloat16& a) {
106 : return -static_cast<float>(a);
107 : }
108 :
109 : inline C10_HOST_DEVICE BFloat16& operator+=(BFloat16& a, const BFloat16& b) {
110 : a = a + b;
111 : return a;
112 : }
113 :
114 : inline C10_HOST_DEVICE BFloat16& operator-=(BFloat16& a, const BFloat16& b) {
115 : a = a - b;
116 : return a;
117 : }
118 :
119 : inline C10_HOST_DEVICE BFloat16& operator*=(BFloat16& a, const BFloat16& b) {
120 : a = a * b;
121 : return a;
122 : }
123 :
124 : inline C10_HOST_DEVICE BFloat16& operator/=(BFloat16& a, const BFloat16& b) {
125 : a = a / b;
126 : return a;
127 : }
128 :
129 : inline C10_HOST_DEVICE BFloat16& operator|(BFloat16& a, const BFloat16& b) {
130 : a.x = a.x | b.x;
131 : return a;
132 : }
133 :
134 : inline C10_HOST_DEVICE BFloat16& operator^(BFloat16& a, const BFloat16& b) {
135 : a.x = a.x ^ b.x;
136 : return a;
137 : }
138 :
139 : inline C10_HOST_DEVICE BFloat16& operator&(BFloat16& a, const BFloat16& b) {
140 : a.x = a.x & b.x;
141 : return a;
142 : }
143 :
144 : /// Arithmetic with floats
145 :
146 : inline C10_HOST_DEVICE float operator+(BFloat16 a, float b) {
147 : return static_cast<float>(a) + b;
148 : }
149 : inline C10_HOST_DEVICE float operator-(BFloat16 a, float b) {
150 : return static_cast<float>(a) - b;
151 : }
152 : inline C10_HOST_DEVICE float operator*(BFloat16 a, float b) {
153 : return static_cast<float>(a) * b;
154 : }
155 : inline C10_HOST_DEVICE float operator/(BFloat16 a, float b) {
156 : return static_cast<float>(a) / b;
157 : }
158 :
159 : inline C10_HOST_DEVICE float operator+(float a, BFloat16 b) {
160 : return a + static_cast<float>(b);
161 : }
162 : inline C10_HOST_DEVICE float operator-(float a, BFloat16 b) {
163 : return a - static_cast<float>(b);
164 : }
165 : inline C10_HOST_DEVICE float operator*(float a, BFloat16 b) {
166 : return a * static_cast<float>(b);
167 : }
168 : inline C10_HOST_DEVICE float operator/(float a, BFloat16 b) {
169 : return a / static_cast<float>(b);
170 : }
171 :
172 : inline C10_HOST_DEVICE float& operator+=(float& a, const BFloat16& b) {
173 : return a += static_cast<float>(b);
174 : }
175 : inline C10_HOST_DEVICE float& operator-=(float& a, const BFloat16& b) {
176 : return a -= static_cast<float>(b);
177 : }
178 : inline C10_HOST_DEVICE float& operator*=(float& a, const BFloat16& b) {
179 : return a *= static_cast<float>(b);
180 : }
181 : inline C10_HOST_DEVICE float& operator/=(float& a, const BFloat16& b) {
182 : return a /= static_cast<float>(b);
183 : }
184 :
185 : /// Arithmetic with doubles
186 :
187 : inline C10_HOST_DEVICE double operator+(BFloat16 a, double b) {
188 : return static_cast<double>(a) + b;
189 : }
190 : inline C10_HOST_DEVICE double operator-(BFloat16 a, double b) {
191 : return static_cast<double>(a) - b;
192 : }
193 : inline C10_HOST_DEVICE double operator*(BFloat16 a, double b) {
194 : return static_cast<double>(a) * b;
195 : }
196 : inline C10_HOST_DEVICE double operator/(BFloat16 a, double b) {
197 : return static_cast<double>(a) / b;
198 : }
199 :
200 : inline C10_HOST_DEVICE double operator+(double a, BFloat16 b) {
201 : return a + static_cast<double>(b);
202 : }
203 : inline C10_HOST_DEVICE double operator-(double a, BFloat16 b) {
204 : return a - static_cast<double>(b);
205 : }
206 : inline C10_HOST_DEVICE double operator*(double a, BFloat16 b) {
207 : return a * static_cast<double>(b);
208 : }
209 : inline C10_HOST_DEVICE double operator/(double a, BFloat16 b) {
210 : return a / static_cast<double>(b);
211 : }
212 :
213 : /// Arithmetic with ints
214 :
215 : inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int b) {
216 : return a + static_cast<BFloat16>(b);
217 : }
218 : inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int b) {
219 : return a - static_cast<BFloat16>(b);
220 : }
221 : inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int b) {
222 : return a * static_cast<BFloat16>(b);
223 : }
224 : inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int b) {
225 : return a / static_cast<BFloat16>(b);
226 : }
227 :
228 : inline C10_HOST_DEVICE BFloat16 operator+(int a, BFloat16 b) {
229 : return static_cast<BFloat16>(a) + b;
230 : }
231 : inline C10_HOST_DEVICE BFloat16 operator-(int a, BFloat16 b) {
232 : return static_cast<BFloat16>(a) - b;
233 : }
234 : inline C10_HOST_DEVICE BFloat16 operator*(int a, BFloat16 b) {
235 : return static_cast<BFloat16>(a) * b;
236 : }
237 : inline C10_HOST_DEVICE BFloat16 operator/(int a, BFloat16 b) {
238 : return static_cast<BFloat16>(a) / b;
239 : }
240 :
241 : /// Arithmetic with int64_t
242 :
243 : inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int64_t b) {
244 : return a + static_cast<BFloat16>(b);
245 : }
246 : inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int64_t b) {
247 : return a - static_cast<BFloat16>(b);
248 : }
249 : inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int64_t b) {
250 : return a * static_cast<BFloat16>(b);
251 : }
252 : inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int64_t b) {
253 : return a / static_cast<BFloat16>(b);
254 : }
255 :
256 : inline C10_HOST_DEVICE BFloat16 operator+(int64_t a, BFloat16 b) {
257 : return static_cast<BFloat16>(a) + b;
258 : }
259 : inline C10_HOST_DEVICE BFloat16 operator-(int64_t a, BFloat16 b) {
260 : return static_cast<BFloat16>(a) - b;
261 : }
262 : inline C10_HOST_DEVICE BFloat16 operator*(int64_t a, BFloat16 b) {
263 : return static_cast<BFloat16>(a) * b;
264 : }
265 : inline C10_HOST_DEVICE BFloat16 operator/(int64_t a, BFloat16 b) {
266 : return static_cast<BFloat16>(a) / b;
267 : }
268 :
269 : // Overloading < and > operators, because std::max and std::min use them.
270 :
271 : inline C10_HOST_DEVICE bool operator>(BFloat16& lhs, BFloat16& rhs) {
272 : return float(lhs) > float(rhs);
273 : }
274 :
275 : inline C10_HOST_DEVICE bool operator<(BFloat16& lhs, BFloat16& rhs) {
276 : return float(lhs) < float(rhs);
277 : }
278 :
279 : } // namespace c10
280 :
281 : namespace std {
282 :
283 : template <>
284 : class numeric_limits<c10::BFloat16> {
285 : public:
286 : static constexpr bool is_signed = true;
287 : static constexpr bool is_specialized = true;
288 : static constexpr bool is_integer = false;
289 : static constexpr bool is_exact = false;
290 : static constexpr bool has_infinity = true;
291 : static constexpr bool has_quiet_NaN = true;
292 : static constexpr bool has_signaling_NaN = true;
293 : static constexpr auto has_denorm = numeric_limits<float>::has_denorm;
294 : static constexpr auto has_denorm_loss =
295 : numeric_limits<float>::has_denorm_loss;
296 : static constexpr auto round_style = numeric_limits<float>::round_style;
297 : static constexpr bool is_iec559 = false;
298 : static constexpr bool is_bounded = true;
299 : static constexpr bool is_modulo = false;
300 : static constexpr int digits = 8;
301 : static constexpr int digits10 = 2;
302 : static constexpr int max_digits10 = 4;
303 : static constexpr int radix = 2;
304 : static constexpr int min_exponent = -125;
305 : static constexpr int min_exponent10 = -37;
306 : static constexpr int max_exponent = 128;
307 : static constexpr int max_exponent10 = 38;
308 : static constexpr auto traps = numeric_limits<float>::traps;
309 : static constexpr auto tinyness_before =
310 : numeric_limits<float>::tinyness_before;
311 :
312 : static constexpr c10::BFloat16 min() {
313 : return c10::BFloat16(0x0080, c10::BFloat16::from_bits());
314 : }
315 : static constexpr c10::BFloat16 lowest() {
316 : return c10::BFloat16(0xFF7F, c10::BFloat16::from_bits());
317 : }
318 : static constexpr c10::BFloat16 max() {
319 : return c10::BFloat16(0x7F7F, c10::BFloat16::from_bits());
320 : }
321 : static constexpr c10::BFloat16 epsilon() {
322 : return c10::BFloat16(0x3C00, c10::BFloat16::from_bits());
323 : }
324 : static constexpr c10::BFloat16 round_error() {
325 : return c10::BFloat16(0x3F00, c10::BFloat16::from_bits());
326 : }
327 : static constexpr c10::BFloat16 infinity() {
328 : return c10::BFloat16(0x7F80, c10::BFloat16::from_bits());
329 : }
330 : static constexpr c10::BFloat16 quiet_NaN() {
331 : return c10::BFloat16(0x7FC0, c10::BFloat16::from_bits());
332 : }
333 : static constexpr c10::BFloat16 signaling_NaN() {
334 : return c10::BFloat16(0x7F80, c10::BFloat16::from_bits());
335 : }
336 : static constexpr c10::BFloat16 denorm_min() {
337 : return c10::BFloat16(0x0001, c10::BFloat16::from_bits());
338 : }
339 : };
340 :
341 : } // namespace std
342 :
343 : C10_CLANG_DIAGNOSTIC_POP()
|