Line data Source code
1 : // Copyright 2004-present Facebook. All Rights Reserved. 2 : 3 : #pragma once 4 : 5 : #include <c10/util/TypeSafeSignMath.h> 6 : 7 : #include <algorithm> 8 : #include <cstddef> 9 : #include <iterator> 10 : #include <type_traits> 11 : 12 : namespace c10 { 13 : 14 : namespace detail { 15 : 16 : template < 17 : typename I, 18 : bool one_sided = false, 19 : std::enable_if_t<std::is_integral_v<I>, int> = 0> 20 : struct integer_iterator { 21 : using iterator_category = std::input_iterator_tag; 22 : using value_type = I; 23 : using difference_type = std::ptrdiff_t; 24 : using pointer = I*; 25 : using reference = I&; 26 : 27 : explicit constexpr integer_iterator(I value) : value(value) {} 28 : 29 : constexpr I operator*() const { 30 : return value; 31 : } 32 : 33 : constexpr I const* operator->() const { 34 : return &value; 35 : } 36 : 37 : constexpr integer_iterator& operator++() { 38 0 : ++value; 39 0 : return *this; 40 : } 41 : 42 : constexpr integer_iterator operator++(int) { 43 : const auto copy = *this; 44 : ++*this; 45 : return copy; 46 : } 47 : 48 : constexpr bool operator==(const integer_iterator& other) const { 49 : if constexpr (one_sided) { 50 : // Range-for loops' end test is `begin != end`, not `begin < 51 : // end`. To handle `c10::irange(n)` where n < 0 (which should be 52 : // empty), we just make `begin != end` fail whenever `end` is 53 : // negative. 54 0 : return is_negative(other.value) || value == other.value; 55 : } else { 56 : return value == other.value; 57 : } 58 : // Suppress "warning: missing return statement at end of non-void function" 59 : // which Nvidia's Robert Crovella confirms is an NVCC compiler error 60 : // here https://stackoverflow.com/a/64561686/752843 on 2020-10-27 61 : // `__builtin_unreachable();` would be best here, but it's not 62 : // available with all compilers. So we instead return an arbitrary 63 : // value trusting that this line will, in fact, never be reached. 64 : return false; // Horrible hack 65 : } 66 : 67 : constexpr bool operator!=(const integer_iterator& other) const { 68 : return !(*this == other); 69 : } 70 : 71 : protected: 72 : I value; 73 : }; 74 : 75 : } // namespace detail 76 : 77 : template < 78 : typename I, 79 : bool one_sided = false, 80 : std::enable_if_t<std::is_integral_v<I>, bool> = true> 81 : struct integer_range { 82 : public: 83 : constexpr integer_range(I begin, I end) : begin_(begin), end_(end) {} 84 : using iterator = detail::integer_iterator<I, one_sided>; 85 : constexpr iterator begin() const { 86 : return begin_; 87 : } 88 : constexpr iterator end() const { 89 : return end_; 90 : } 91 : 92 : private: 93 : iterator begin_; 94 : iterator end_; 95 : }; 96 : 97 : /// Creates an integer range for the half-open interval [begin, end) 98 : /// If end<=begin, then the range is empty. 99 : /// The range has the type of the `end` integer; `begin` integer is 100 : /// cast to this type. 101 : template < 102 : typename Integer1, 103 : typename Integer2, 104 : std::enable_if_t<std::is_integral_v<Integer1>, bool> = true, 105 : std::enable_if_t<std::is_integral_v<Integer2>, bool> = true> 106 : constexpr integer_range<Integer2> irange(Integer1 begin, Integer2 end) { 107 : // If end<=begin then the range is empty; we can achieve this effect by 108 : // choosing the larger of {begin, end} as the loop terminator 109 : return { 110 : static_cast<Integer2>(begin), 111 : std::max(static_cast<Integer2>(begin), end)}; 112 : } 113 : 114 : /// Creates an integer range for the half-open interval [0, end) 115 : /// If end<=begin, then the range is empty 116 : template < 117 : typename Integer, 118 : std::enable_if_t<std::is_integral_v<Integer>, bool> = true> 119 : constexpr integer_range<Integer, true> irange(Integer end) { 120 : return {Integer(), end}; 121 : } 122 : 123 : } // namespace c10