Line data Source code
1 : #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) 2 : // Copyright 2004-present Facebook. All Rights Reserved. 3 : 4 : #pragma once 5 : 6 : #include <c10/util/TypeSafeSignMath.h> 7 : 8 : #include <algorithm> 9 : #include <cstddef> 10 : #include <iterator> 11 : #include <type_traits> 12 : 13 : namespace c10 { 14 : 15 : namespace detail { 16 : 17 : template < 18 : typename I, 19 : bool one_sided = false, 20 : std::enable_if_t<std::is_integral_v<I>, int> = 0> 21 : struct integer_iterator { 22 : using iterator_category = std::input_iterator_tag; 23 : using value_type = I; 24 : using difference_type = std::ptrdiff_t; 25 : using pointer = I*; 26 : using reference = I&; 27 : 28 : explicit constexpr integer_iterator(I val) : value(val) {} 29 : 30 : constexpr I operator*() const { 31 : return value; 32 : } 33 : 34 : constexpr I const* operator->() const { 35 : return &value; 36 : } 37 : 38 : constexpr integer_iterator& operator++() { 39 0 : ++value; 40 0 : return *this; 41 : } 42 : 43 : constexpr integer_iterator operator++(int) { 44 : const auto copy = *this; 45 : ++*this; 46 : return copy; 47 : } 48 : 49 : constexpr bool operator==(const integer_iterator& other) const { 50 : if constexpr (one_sided) { 51 : // Range-for loops' end test is `begin != end`, not `begin < 52 : // end`. To handle `c10::irange(n)` where n < 0 (which should be 53 : // empty), we just make `begin != end` fail whenever `end` is 54 : // negative. 55 0 : return is_negative(other.value) || value == other.value; 56 : } else { 57 : return value == other.value; 58 : } 59 : // Suppress "warning: missing return statement at end of non-void function" 60 : // which Nvidia's Robert Crovella confirms is an NVCC compiler error 61 : // here https://stackoverflow.com/a/64561686/752843 on 2020-10-27 62 : // `__builtin_unreachable();` would be best here, but it's not 63 : // available with all compilers. So we instead return an arbitrary 64 : // value trusting that this line will, in fact, never be reached. 65 : return false; // Horrible hack 66 : } 67 : 68 : constexpr bool operator!=(const integer_iterator& other) const { 69 : return !(*this == other); 70 : } 71 : 72 : protected: 73 : I value; 74 : }; 75 : 76 : } // namespace detail 77 : 78 : template < 79 : typename I, 80 : bool one_sided = false, 81 : std::enable_if_t<std::is_integral_v<I>, bool> = true> 82 : struct integer_range { 83 : public: 84 : constexpr integer_range(I begin, I end) : begin_(begin), end_(end) {} 85 : using iterator = detail::integer_iterator<I, one_sided>; 86 : constexpr iterator begin() const { 87 : return begin_; 88 : } 89 : constexpr iterator end() const { 90 : return end_; 91 : } 92 : 93 : private: 94 : iterator begin_; 95 : iterator end_; 96 : }; 97 : 98 : /// Creates an integer range for the half-open interval [begin, end) 99 : /// If end<=begin, then the range is empty. 100 : /// The range has the type of the `end` integer; `begin` integer is 101 : /// cast to this type. 102 : template < 103 : typename Integer1, 104 : typename Integer2, 105 : std::enable_if_t<std::is_integral_v<Integer1>, bool> = true, 106 : std::enable_if_t<std::is_integral_v<Integer2>, bool> = true> 107 : constexpr integer_range<Integer2> irange(Integer1 begin, Integer2 end) { 108 : // If end<=begin then the range is empty; we can achieve this effect by 109 : // choosing the larger of {begin, end} as the loop terminator 110 : return { 111 : static_cast<Integer2>(begin), 112 : std::max(static_cast<Integer2>(begin), end)}; 113 : } 114 : 115 : /// Creates an integer range for the half-open interval [0, end) 116 : /// If end<=begin, then the range is empty 117 : template < 118 : typename Integer, 119 : std::enable_if_t<std::is_integral_v<Integer>, bool> = true> 120 : constexpr integer_range<Integer, true> irange(Integer end) { 121 : return {Integer(), end}; 122 : } 123 : 124 : } // namespace c10 125 : 126 : #else 127 : #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 128 : #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)