LCOV - code coverage report
Current view: top level - home/runner/.local/lib/python3.9/site-packages/torch/include/ATen - TensorIndexing.h (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 17 20 85.0 %
Date: 2025-12-04 11:19:34 Functions: 5 5 100.0 %

          Line data    Source code
       1             : #pragma once
       2             : 
       3             : #include <ATen/ExpandUtils.h>
       4             : #include <ATen/ScalarOps.h>
       5             : #include <ATen/core/Tensor.h>
       6             : #include <ATen/core/TensorBody.h>
       7             : #include <c10/core/SymInt.h>
       8             : #include <c10/util/irange.h>
       9             : #include <optional>
      10             : 
      11             : #ifndef AT_PER_OPERATOR_HEADERS
      12             : #include <ATen/Functions.h>
      13             : #include <ATen/NativeFunctions.h>
      14             : #else
      15             : #include <ATen/ops/alias.h>
      16             : #include <ATen/ops/empty.h>
      17             : #include <ATen/ops/scalar_tensor.h>
      18             : #include <ATen/ops/zeros.h>
      19             : #endif
      20             : 
      21             : #include <ATen/core/List.h>
      22             : 
      23             : #include <utility>
      24             : 
      25             : namespace at::indexing {
      26             : 
      27             : constexpr int64_t INDEX_MIN = c10::SymInt::min_representable_int();
      28             : constexpr int64_t INDEX_MAX = -(INDEX_MIN + 1);
      29             : 
      30             : enum class TensorIndexType { None, Ellipsis, SymInt, Boolean, Slice, Tensor };
      31             : 
      32             : constexpr std::nullopt_t None = std::nullopt;
      33             : 
      34             : struct TORCH_API EllipsisIndexType final {
      35             :   EllipsisIndexType() = default;
      36             : };
      37             : TORCH_API extern const EllipsisIndexType Ellipsis;
      38             : 
      39          35 : struct TORCH_API Slice final {
      40             :  public:
      41         110 :   Slice(
      42             :       std::optional<c10::SymInt> start_index = std::nullopt,
      43             :       std::optional<c10::SymInt> stop_index = std::nullopt,
      44         110 :       std::optional<c10::SymInt> step_index = std::nullopt) {
      45         110 :     if (!step_index.has_value()) {
      46         220 :       step_ = c10::SymInt(1);
      47             :     } else {
      48           0 :       step_ = std::move(step_index).value();
      49             :     }
      50             : 
      51         110 :     TORCH_CHECK_VALUE(
      52             :         step_.sym_ne(0).expect_true(__FILE__, __LINE__),
      53             :         "slice step cannot be zero");
      54             : 
      55         110 :     if (!start_index.has_value()) {
      56         330 :       start_ = c10::SymInt(step_ < 0 ? INDEX_MAX : 0);
      57             :     } else {
      58           0 :       start_ = std::move(start_index).value();
      59             :     }
      60             : 
      61         110 :     if (!stop_index.has_value()) {
      62         330 :       stop_ = c10::SymInt(step_ < 0 ? INDEX_MIN : INDEX_MAX);
      63             :     } else {
      64           0 :       stop_ = std::move(stop_index).value();
      65             :     }
      66         110 :   }
      67             : 
      68             :   inline c10::SymInt start() const {
      69             :     return start_;
      70             :   }
      71             : 
      72             :   inline c10::SymInt stop() const {
      73             :     return stop_;
      74             :   }
      75             : 
      76             :   inline c10::SymInt step() const {
      77             :     return step_;
      78             :   }
      79             : 
      80             :  private:
      81             :   c10::SymInt start_;
      82             :   c10::SymInt stop_;
      83             :   c10::SymInt step_;
      84             : };
      85             : 
      86             : TORCH_API std::ostream& operator<<(std::ostream& stream, const Slice& slice);
      87             : 
      88             : // `at::indexing::TensorIndex` is used for converting C++ tensor indices such as
      89             : // `{None, "...", Ellipsis, 0, true, Slice(1, None, 2), torch::tensor({1, 2})}`
      90             : // into its equivalent `std::vector<TensorIndex>`, so that further tensor
      91             : // indexing operations can be performed using the supplied indices.
      92             : //
      93             : // There is a one-to-one correspondence between Python and C++ tensor index types:
      94             : // Python                  | C++
      95             : // -----------------------------------------------------
      96             : // `None`                  | `at::indexing::None`
      97             : // `Ellipsis`              | `at::indexing::Ellipsis`
      98             : // `...`                   | `"..."`
      99             : // `123`                   | `123`
     100             : // `True` / `False`        | `true` / `false`
     101             : // `:`                     | `Slice()` / `Slice(None, None)`
     102             : // `::`                    | `Slice()` / `Slice(None, None, None)`
     103             : // `1:`                    | `Slice(1, None)`
     104             : // `1::`                   | `Slice(1, None, None)`
     105             : // `:3`                    | `Slice(None, 3)`
     106             : // `:3:`                   | `Slice(None, 3, None)`
     107             : // `::2`                   | `Slice(None, None, 2)`
     108             : // `1:3`                   | `Slice(1, 3)`
     109             : // `1::2`                  | `Slice(1, None, 2)`
     110             : // `:3:2`                  | `Slice(None, 3, 2)`
     111             : // `1:3:2`                 | `Slice(1, 3, 2)`
     112             : // `torch.tensor([1, 2])`  | `torch::tensor({1, 2})`
     113             : struct TORCH_API TensorIndex final {
     114             :   // Case 1: `at::indexing::None`
     115             :   TensorIndex(std::nullopt_t) : type_(TensorIndexType::None) {}
     116             : 
     117             :   // Case 2: "..." / `at::indexing::Ellipsis`
     118             :   TensorIndex(at::indexing::EllipsisIndexType)
     119             :       : type_(TensorIndexType::Ellipsis) {}
     120             :   TensorIndex(const char* str) : TensorIndex(at::indexing::Ellipsis) {
     121             :     TORCH_CHECK_VALUE(
     122             :         strcmp(str, "...") == 0,
     123             :         "Expected \"...\" to represent an ellipsis index, but got \"",
     124             :         str,
     125             :         "\"");
     126             :   }
     127             : 
     128             :   // Case 3: (Sym) Integer value
     129          35 :   TensorIndex(SymInt integer)
     130          70 :       : integer_(std::move(integer)), type_(TensorIndexType::SymInt) {}
     131          22 :   TensorIndex(int64_t integer) : TensorIndex(SymInt(integer)) {}
     132          48 :   TensorIndex(int integer) : TensorIndex(SymInt(integer)) {}
     133             : 
     134             :   // Case 4: Boolean value
     135             :   template <class T, class = std::enable_if_t<std::is_same_v<bool, T>>>
     136             :   TensorIndex(T boolean) : boolean_(boolean), type_(TensorIndexType::Boolean) {}
     137             : 
     138             :   // Case 5: Slice represented in `at::indexing::Slice` form
     139          35 :   TensorIndex(Slice slice)
     140          35 :       : slice_(std::move(slice)), type_(TensorIndexType::Slice) {}
     141             : 
     142             :   // Case 6: Tensor value
     143             :   TensorIndex(Tensor tensor)
     144             :       : tensor_(std::move(tensor)), type_(TensorIndexType::Tensor) {}
     145             : 
     146             :   inline bool is_none() const {
     147             :     return type_ == TensorIndexType::None;
     148             :   }
     149             : 
     150             :   inline bool is_ellipsis() const {
     151             :     return type_ == TensorIndexType::Ellipsis;
     152             :   }
     153             : 
     154             :   inline bool is_integer() const {
     155             :     return type_ == TensorIndexType::SymInt;
     156             :   }
     157             : 
     158             :   inline SymInt integer() const {
     159             :     return integer_;
     160             :   }
     161             : 
     162             :   inline bool is_boolean() const {
     163             :     return type_ == TensorIndexType::Boolean;
     164             :   }
     165             : 
     166             :   inline bool boolean() const {
     167             :     return boolean_;
     168             :   }
     169             : 
     170             :   inline bool is_slice() const {
     171             :     return type_ == TensorIndexType::Slice;
     172             :   }
     173             : 
     174             :   inline const Slice& slice() const {
     175             :     return slice_;
     176             :   }
     177             : 
     178             :   inline bool is_tensor() const {
     179             :     return type_ == TensorIndexType::Tensor;
     180             :   }
     181             : 
     182             :   inline const Tensor& tensor() const {
     183             :     return tensor_;
     184             :   }
     185             : 
     186             :  private:
     187             :   SymInt integer_ = 0;
     188             :   bool boolean_ = false;
     189             :   Slice slice_;
     190             :   Tensor tensor_;
     191             :   TensorIndexType type_;
     192             : };
     193             : 
     194             : TORCH_API std::ostream& operator<<(
     195             :     std::ostream& stream,
     196             :     const TensorIndex& tensor_index);
     197             : TORCH_API std::ostream& operator<<(
     198             :     std::ostream& stream,
     199             :     const std::vector<TensorIndex>& tensor_indices);
     200             : 
     201             : namespace impl {
     202             : inline Tensor applySlice(
     203             :     const Tensor& self,
     204             :     int64_t dim,
     205             :     c10::SymInt start,
     206             :     c10::SymInt stop,
     207             :     c10::SymInt step,
     208             :     bool disable_slice_optimization,
     209             :     const at::Device& self_device,
     210             :     const std::optional<SymIntArrayRef>& self_sizes) {
     211             :   // TODO: implement negative step
     212             :   TORCH_CHECK_VALUE(
     213             :       step.sym_gt(0).expect_true(__FILE__, __LINE__),
     214             :       "step must be greater than zero");
     215             : 
     216             :   // See NOTE [nested tensor size for indexing]
     217             :   if (self_sizes.has_value()) {
     218             :     // Skip this optimization if we are tracing, as the trace may be polymorphic
     219             :     // over the shape of the `self` tensor, and we still want to record
     220             :     // the slice.
     221             :     SymInt length = (self_device == at::kCPU || self_device == at::kCUDA)
     222             :         ? (*self_sizes)[dim]
     223             :         : self.sym_size(dim);
     224             :     if (!disable_slice_optimization &&
     225             :         TORCH_GUARD_SIZE_OBLIVIOUS(start.sym_eq(0)) &&
     226             :         TORCH_GUARD_SIZE_OBLIVIOUS(length.sym_eq(stop)) && step == 1) {
     227             :       return self;
     228             :     }
     229             :   }
     230             :   return self.slice_symint(
     231             :       dim, std::move(start), std::move(stop), std::move(step));
     232             : }
     233             : 
     234             : inline Tensor applySelect(
     235             :     const Tensor& self,
     236             :     int64_t dim,
     237             :     SymInt index,
     238             :     int64_t real_dim,
     239             :     const at::Device& /*self_device*/,
     240             :     const std::optional<SymIntArrayRef>& self_sizes) {
     241             :   // See NOTE [nested tensor size for indexing]
     242             :   if (self_sizes.has_value()) {
     243             :     auto maybe_index = index.maybe_as_int();
     244             :     if (maybe_index.has_value()) {
     245             :       TORCH_CHECK_INDEX(
     246             :           !(maybe_index.value() == 0 && dim == 0 && self_sizes->empty()),
     247             :           "invalid index of a 0-dim tensor. ",
     248             :           "Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number");
     249             :     }
     250             : 
     251             :     auto size = (*self_sizes)[dim];
     252             :     // Note: `size >= -index` is not equivalent to `size > -1 - index` if index
     253             :     // is INT64_MIN. For std::numeric_limits<int64_t>::min() the result of unary
     254             :     // minus is undefined by the standard but in practice is equal to itself. On
     255             :     // the other hand, index wrapping is valid for all negative int64_t
     256             :     // values, as x[INT64_MIN] is the same as x[INT64_MAX].
     257             :     TORCH_CHECK_INDEX(
     258             :         size.sym_gt(-1 - index)
     259             :             .sym_and(size.sym_gt(index))
     260             :             .expect_true(__FILE__, __LINE__),
     261             :         "index ",
     262             :         index,
     263             :         " is out of bounds for dimension ",
     264             :         real_dim,
     265             :         " with size ",
     266             :         size);
     267             :   }
     268             : 
     269             :   // if the index is negative, do not normalize it because that would fix the
     270             :   // index on the current tensor size in the tracer. aten::select also works on
     271             :   // negative indices
     272             :   return self.select_symint(dim, std::move(index));
     273             : }
     274             : 
     275             : inline Tensor boolToIndexingTensorCPUOrCUDA(const Tensor& self, bool value) {
     276             :   // booleans add a dimension of size 1. true indexes this dimension as if 0:,
     277             :   // false as empty.
     278             :   if (value) {
     279             :     return at::empty({1}, self.options().dtype(kLong)).fill_(0.);
     280             :   } else {
     281             :     return at::empty({0}, self.options().dtype(kLong));
     282             :   }
     283             : }
     284             : 
     285             : inline Tensor boolToIndexingTensorNonNativeDeviceType(
     286             :     const Tensor& self,
     287             :     bool value) {
     288             :   // booleans add a dimension of size 1. true indexes this dimension as if 0:,
     289             :   // false as empty.
     290             :   if (value) {
     291             :     return at::zeros({1}, self.options().dtype(kLong));
     292             :   } else {
     293             :     return at::empty({0}, self.options().dtype(kLong));
     294             :   }
     295             : }
     296             : 
     297             : inline Tensor boolToIndexingTensor(
     298             :     const Tensor& self,
     299             :     bool value,
     300             :     const at::Device& self_device) {
     301             :   if (self_device == at::kCPU || self_device == at::kCUDA) {
     302             :     return boolToIndexingTensorCPUOrCUDA(self, value);
     303             :   } else {
     304             :     return boolToIndexingTensorNonNativeDeviceType(self, value);
     305             :   }
     306             : }
     307             : 
     308             : inline Tensor scalarToTensorNonNativeDeviceType(
     309             :     const Scalar& v,
     310             :     const TensorOptions& options) {
     311             :   return at::scalar_tensor(v, options);
     312             : }
     313             : 
     314             : inline void recordTensorIndex(
     315             :     const Tensor& tensor,
     316             :     std::vector<Tensor>& outIndices,
     317             :     int64_t* dim_ptr) {
     318             :   // TODO: check scalarType
     319             :   outIndices.resize(*dim_ptr + 1);
     320             :   outIndices[*dim_ptr] = tensor;
     321             :   (*dim_ptr)++;
     322             : }
     323             : 
     324             : inline c10::List<::std::optional<Tensor>> typeConvertIndices(
     325             :     const Tensor& /*self*/,
     326             :     std::vector<Tensor>&& indices) {
     327             :   c10::List<::std::optional<Tensor>> converted_inds;
     328             :   converted_inds.reserve(indices.size());
     329             :   for (auto&& i : std::move(indices)) {
     330             :     converted_inds.push_back(std::move(i));
     331             :   }
     332             :   return converted_inds;
     333             : }
     334             : 
     335             : // NOTE: Why do we mirror instead of replace the `count_specified_dimensions`
     336             : // function in torch/csrc/autograd/python_variable_indexing.cpp? It's because
     337             : // `count_specified_dimensions` is on the hot path of Python tensor multi-dim
     338             : // indexing (i.e. it's called by `applySlicing` which is called by
     339             : // `THPVariable_getitem` / `THPVariable_setitem` when handling indexing of more
     340             : // than one dimension). If we were to merge the Python/C++
     341             : // `count_specified_dimensions` function, on the Python side we would have to
     342             : // construct a `std::vector` container to be consumed by the C++
     343             : // `count_specified_dimensions` function, which adds 100s of nanoseconds
     344             : // overhead and is undesirable.
     345             : inline int64_t count_specified_dimensions(
     346             :     const ArrayRef<TensorIndex>& indices) {
     347             :   // Count the number of indexed dimensions (everything but ellipsis and None)
     348             :   int64_t count = 0;
     349             :   for (auto& obj : indices) {
     350             :     if (obj.is_tensor()) {
     351             :       auto& tensor = obj.tensor();
     352             :       if (tensor.scalar_type() == kByte || tensor.scalar_type() == kBool) {
     353             :         count += tensor.dim();
     354             :       } else {
     355             :         count++;
     356             :       }
     357             :     } else if (!obj.is_none() && !obj.is_ellipsis() && !obj.is_boolean()) {
     358             :       count++;
     359             :     }
     360             :   }
     361             :   return count;
     362             : }
     363             : } // namespace impl
     364             : 
     365             : // NOTE: Many functions below are only for consumption from Python indexing
     366             : // implementation. They include:
     367             : //
     368             : // - `Tensor scalarToTensor(...)`
     369             : // - `SymIntArrayRef slicePrefix1sSize(...)`
     370             : // - `void copy_to(...)`
     371             : // - `Tensor handleDimInMultiDimIndexing(...)`
     372             : // - `Tensor dispatch_index(...)`
     373             : // - `Tensor dispatch_index_put_(...)`
     374             : // - `Tensor get_item(...)`
     375             : // - `void set_item(...)`
     376             : //
     377             : // The rest of the functions are in `at::indexing::impl` namespace, signifying
     378             : // that they shouldn't be used from Python indexing implementation.
     379             : inline Tensor scalarToTensor(
     380             :     const Scalar& v,
     381             :     const TensorOptions& options,
     382             :     const at::Device& self_device) {
     383             :   if (self_device == at::kCPU && !v.isSymbolic()) {
     384             :     return at::detail::scalar_tensor_static(
     385             :         v,
     386             :         // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
     387             :         options.dtype_opt()->toScalarType(),
     388             :         self_device);
     389             :   } else {
     390             :     return impl::scalarToTensorNonNativeDeviceType(v, options);
     391             :   }
     392             : }
     393             : 
     394             : // To match numpy semantics:
     395             : // As a special case for backwards compatibility,
     396             : // strip away unit dimensions from the left of 'src'
     397             : inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) {
     398             :   size_t first_non1_src = sizes.size();
     399             :   for (const auto i : c10::irange(sizes.size())) {
     400             :     // Unbacked SymInt has different behavior, but this is sound because
     401             :     // failing to slice will only ever cause an error, not divergent
     402             :     // behavior
     403             :     if (!sizes[i].has_hint() || sizes[i] != 1) {
     404             :       first_non1_src = i;
     405             :       break;
     406             :     }
     407             :   }
     408             : 
     409             :   return sizes.slice(first_non1_src);
     410             : }
     411             : 
     412             : inline void copy_to(const Tensor& dst, const Tensor& src) {
     413             :   if (dst.sym_sizes().equals(src.sym_sizes())) {
     414             :     // A shortcut to avoid generating hard-coded constant sizes during tracing.
     415             :     // This is not a perfect solution: when src & dst have different shapes,
     416             :     // constants will still appear. Users can workaround that case by
     417             :     // dst[index..] = src.reshape(..)
     418             :     dst.copy_(src);
     419             :     return;
     420             :   } else if (src.dim() == 0 && src.device().type() == at::kCPU) {
     421             :     dst.fill_(src);
     422             :     return;
     423             :   }
     424             :   auto src_view = src.view_symint(slicePrefix1sSize(src.sym_sizes()));
     425             :   c10::MaybeOwned<Tensor> b_src = expand_inplace(dst, src_view, "setitem");
     426             :   dst.copy_(*b_src);
     427             : }
     428             : 
     429             : // See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
     430             : // indexing functions from Python ]
     431             : inline Tensor handleDimInMultiDimIndexing(
     432             :     const Tensor& prev_dim_result,
     433             :     const Tensor& original_tensor,
     434             :     const TensorIndex& index,
     435             :     int64_t* dim_ptr,
     436             :     int64_t* specified_dims_ptr,
     437             :     int64_t real_dim,
     438             :     std::vector<Tensor>& outIndices,
     439             :     bool disable_slice_optimization,
     440             :     const at::Device& original_tensor_device,
     441             :     const std::optional<SymIntArrayRef>& prev_dim_result_sizes) {
     442             :   if (index.is_integer()) {
     443             :     return impl::applySelect(
     444             :         prev_dim_result,
     445             :         *dim_ptr,
     446             :         index.integer(),
     447             :         real_dim,
     448             :         original_tensor_device,
     449             :         prev_dim_result_sizes);
     450             :   } else if (index.is_slice()) {
     451             :     Tensor result = impl::applySlice(
     452             :         prev_dim_result,
     453             :         *dim_ptr,
     454             :         index.slice().start(),
     455             :         index.slice().stop(),
     456             :         index.slice().step(),
     457             :         /*disable_slice_optimization=*/disable_slice_optimization,
     458             :         original_tensor_device,
     459             :         prev_dim_result_sizes);
     460             :     (*dim_ptr)++;
     461             :     return result;
     462             :   } else if (index.is_ellipsis()) {
     463             :     (*dim_ptr) += original_tensor.dim() - (*specified_dims_ptr);
     464             :     return prev_dim_result;
     465             :   } else if (index.is_none()) {
     466             :     Tensor result = prev_dim_result.unsqueeze(*dim_ptr);
     467             :     (*dim_ptr)++;
     468             :     return result;
     469             :   } else if (index.is_boolean()) {
     470             :     Tensor result = prev_dim_result.unsqueeze(*dim_ptr);
     471             :     impl::recordTensorIndex(
     472             :         impl::boolToIndexingTensor(
     473             :             result, index.boolean(), original_tensor_device),
     474             :         outIndices,
     475             :         dim_ptr);
     476             :     return result;
     477             :   } else if (index.is_tensor()) {
     478             :     Tensor result = prev_dim_result;
     479             :     const Tensor& tensor = index.tensor();
     480             :     auto scalar_type = tensor.scalar_type();
     481             :     if (tensor.dim() == 0 &&
     482             :         at::isIntegralType(scalar_type, /*includeBool=*/true)) {
     483             :       if (scalar_type != at::kByte && scalar_type != at::kBool) {
     484             :         result = impl::applySelect(
     485             :             result,
     486             :             *dim_ptr,
     487             :             tensor.item<int64_t>(),
     488             :             real_dim,
     489             :             original_tensor_device,
     490             :             prev_dim_result_sizes);
     491             :       } else {
     492             :         result = result.unsqueeze(*dim_ptr);
     493             :         if (scalar_type == at::kBool) {
     494             :           impl::recordTensorIndex(
     495             :               impl::boolToIndexingTensor(
     496             :                   result, tensor.item<bool>() != 0, original_tensor_device),
     497             :               outIndices,
     498             :               dim_ptr);
     499             :         } else {
     500             :           impl::recordTensorIndex(
     501             :               impl::boolToIndexingTensor(
     502             :                   result, tensor.item<uint8_t>() != 0, original_tensor_device),
     503             :               outIndices,
     504             :               dim_ptr);
     505             :         }
     506             :       }
     507             :     } else {
     508             :       impl::recordTensorIndex(tensor, outIndices, dim_ptr);
     509             :     }
     510             :     return result;
     511             :   } else {
     512             :     TORCH_INTERNAL_ASSERT(false, "Invalid TensorIndex type");
     513             :   }
     514             : }
     515             : 
     516             : namespace impl {
     517             : // This mirrors `applySlicing` in
     518             : // torch/csrc/autograd/python_variable_indexing.cpp
     519             : inline Tensor applySlicing(
     520             :     const Tensor& self,
     521             :     const ArrayRef<TensorIndex>& indices,
     522             :     std::vector<Tensor>& outIndices,
     523             :     bool disable_slice_optimization,
     524             :     const at::Device& self_device,
     525             :     const std::optional<SymIntArrayRef>& self_sizes) {
     526             :   int64_t dim = 0;
     527             :   int64_t specified_dims = impl::count_specified_dimensions(indices);
     528             : 
     529             :   // See NOTE [nested tensor size for indexing]
     530             :   if (self_sizes.has_value()) {
     531             :     TORCH_CHECK_INDEX(
     532             :         specified_dims <= (int64_t)self_sizes->size(),
     533             :         "too many indices for tensor of dimension ",
     534             :         (int)self_sizes->size());
     535             :   }
     536             : 
     537             :   Tensor result = self;
     538             :   for (const auto i : c10::irange(indices.size())) {
     539             :     auto& obj = indices[i];
     540             :     // See NOTE [nested tensor size for indexing]
     541             :     std::optional<SymIntArrayRef> result_sizes = result.is_nested()
     542             :         ? std::optional<SymIntArrayRef>(std::nullopt)
     543             :         : std::optional<SymIntArrayRef>(result.sym_sizes());
     544             :     result = handleDimInMultiDimIndexing(
     545             :         /*prev_dim_result=*/result,
     546             :         /*original_tensor=*/self,
     547             :         /*index=*/obj,
     548             :         /*dim_ptr=*/&dim,
     549             :         /*specified_dims_ptr=*/&specified_dims,
     550             :         /*real_dim=*/static_cast<int64_t>(i),
     551             :         /*outIndices=*/outIndices,
     552             :         /*disable_slice_optimization=*/disable_slice_optimization,
     553             :         /*original_tensor_device=*/self_device,
     554             :         /*prev_dim_result_sizes=*/result_sizes);
     555             :   }
     556             :   return result;
     557             : }
     558             : } // namespace impl
     559             : 
     560             : inline Tensor dispatch_index(
     561             :     const Tensor& self,
     562             :     std::vector<Tensor>&& indices) {
     563             :   return self.index(impl::typeConvertIndices(self, std::move(indices)));
     564             : }
     565             : 
     566             : inline Tensor dispatch_index_put_(
     567             :     Tensor& self,
     568             :     std::vector<Tensor>&& indices,
     569             :     const Tensor& value) {
     570             :   return self.index_put_(
     571             :       impl::typeConvertIndices(self, std::move(indices)), value);
     572             : }
     573             : 
     574             : // NOTE [ Setting `disable_slice_optimization` when calling C++ tensor indexing
     575             : // functions from Python ]
     576             : //
     577             : // Question: When should we set `disable_slice_optimization` to `true` when
     578             : // calling C++ tensor indexing functions from Python indexing code?
     579             : //
     580             : // Answer: What "slice optimization" means: when we have a slicing expression
     581             : // like `x[0:5, 0]`, where the sliced tensor was of size 5 in dimension 0, we
     582             : // would skip dispatching the actual slice call as an optimization. However,
     583             : // here are the cases where we DON'T want this optimization:
     584             : //
     585             : // 1. When we are doing 1-D slicing (e.g. `tensor[:]`).
     586             : //    Reason: we always return a shallow copy for expressions such as
     587             : //    `tensor[:]` / `tensor[...]` / `tensor[:, :]`. (Note that for
     588             : //    `tensor[:, :]`, we return an alias of `tensor` by doing the following:
     589             : //    ```
     590             : //    Tensor sliced = impl::applySlicing(self, indices, tensorIndices,
     591             : //        disable_slice_optimization, self_device, self_sizes);
     592             : //    if (tensorIndices.empty()) {
     593             : //      if (sliced.is_same(self)) {
     594             : //        // ensure we return a shallow copy for things like x[...]
     595             : //        sliced = at::alias(sliced);
     596             : //      }
     597             : //      return sliced;
     598             : //    }
     599             : //    ```)
     600             : // 2. When we are doing JIT tracing.
     601             : //    Reason: JIT tracing needs the `self.slice(...)` call to properly trace the
     602             : //    slice operation.
     603             : 
     604             : // C++ counterpart of `THPVariable_getitem` in
     605             : // torch/csrc/autograd/python_variable_indexing.cpp. See NOTE [ Setting
     606             : // `disable_slice_optimization` when calling C++ tensor indexing functions from
     607             : // Python ].
     608             : inline Tensor get_item(
     609             :     const Tensor& self,
     610             :     const ArrayRef<TensorIndex>& indices,
     611             :     bool disable_slice_optimization = false) {
     612             :   at::Device device = self.device();
     613             :   // NOTE [nested tensor size for indexing]
     614             :   // A nested tensor does not have a size (yet), so for now its size is
     615             :   // represented as null. This may need to change once there is a better
     616             :   // solution for nested tensor sizes.
     617             :   std::optional<SymIntArrayRef> sizes;
     618             :   if (!self.is_nested()) {
     619             :     sizes = self.sym_sizes();
     620             :   }
     621             : 
     622             :   // A lone integer, slice, none, ellipsis or bool index is handled directly.
     623             :   if (indices.size() == 1) {
     624             :     const TensorIndex& only = indices[0];
     625             :     if (only.is_integer()) {
     626             :       return impl::applySelect(self, 0, only.integer(), 0, device, sizes);
     627             :     }
     628             :     if (only.is_slice()) {
     629             :       // The slice optimization is always disabled on this path.
     630             :       const auto& spec = only.slice();
     631             :       return impl::applySlice(
     632             :           self,
     633             :           0,
     634             :           spec.start(),
     635             :           spec.stop(),
     636             :           spec.step(),
     637             :           /*disable_slice_optimization=*/true,
     638             :           device,
     639             :           sizes);
     640             :     }
     641             :     if (only.is_none()) {
     642             :       return self.unsqueeze(0);
     643             :     }
     644             :     if (only.is_ellipsis()) {
     645             :       return at::alias(self);
     646             :     }
     647             :     if (only.is_boolean()) {
     648             :       Tensor lifted = self.unsqueeze(0);
     649             :       Tensor idx = impl::boolToIndexingTensor(lifted, only.boolean(), device);
     650             :       return dispatch_index(lifted, std::vector<Tensor>{std::move(idx)});
     651             :     }
     652             :   }
     653             : 
     654             :   // General path: `impl::applySlicing` handles the indices and uses
     655             :   // `tensor_indices` as its output list for the ones left to tensor indexing.
     656             :   std::vector<Tensor> tensor_indices;
     657             :   Tensor partial = impl::applySlicing(
     658             :       self, indices, tensor_indices, disable_slice_optimization, device, sizes);
     659             :   if (!tensor_indices.empty()) {
     660             :     // indexing by tensors ("advanced" indexing)
     661             :     return dispatch_index(partial, std::move(tensor_indices));
     662             :   }
     663             :   if (partial.is_same(self)) {
     664             :     // ensure we return a shallow copy for things like x[...]
     665             :     partial = at::alias(partial);
     666             :   }
     667             :   return partial;
     668             : }
     669             : 
     670             : // C++ counterpart of `THPVariable_setitem` in
     671             : // torch/csrc/autograd/python_variable_indexing.cpp, for the case where the
     672             : // assigned value is a Tensor. See NOTE [ Setting `disable_slice_optimization`
     673             : // when calling C++ tensor indexing functions from Python ].
     674             : inline void set_item(
     675             :     const Tensor& self,
     676             :     const ArrayRef<TensorIndex>& indices,
     677             :     const Tensor& value,
     678             :     bool disable_slice_optimization = false) {
     679             :   at::Device device = self.device();
     680             :   SymIntArrayRef sizes = self.sym_sizes();
     681             : 
     682             :   // A lone ellipsis, none, bool, integer or slice index is handled directly.
     683             :   if (indices.size() == 1) {
     684             :     const TensorIndex& only = indices[0];
     685             :     const bool is_bool = only.is_boolean();
     686             :     if (is_bool && !only.boolean()) {
     687             :       // do nothing for false (technically we should check the size, but we
     688             :       // don't have real 0-sized shapes).
     689             :       return;
     690             :     }
     691             :     if (only.is_ellipsis()) {
     692             :       copy_to(self, value);
     693             :       return;
     694             :     }
     695             :     if (only.is_none() || (is_bool && only.boolean())) {
     696             :       copy_to(self.unsqueeze(0), value);
     697             :       return;
     698             :     }
     699             :     if (only.is_integer()) {
     700             :       Tensor dst = impl::applySelect(self, 0, only.integer(), 0, device, sizes);
     701             :       copy_to(dst, value);
     702             :       return;
     703             :     }
     704             :     if (only.is_slice()) {
     705             :       const auto& spec = only.slice();
     706             :       // Unlike `get_item`, the caller's optimization flag is honored here.
     707             :       Tensor dst = impl::applySlice(
     708             :           self,
     709             :           0,
     710             :           spec.start(),
     711             :           spec.stop(),
     712             :           spec.step(),
     713             :           disable_slice_optimization,
     714             :           device,
     715             :           sizes);
     716             :       copy_to(dst, value);
     717             :       return;
     718             :     }
     719             :   }
     720             : 
     721             :   // General path: apply the indices via `impl::applySlicing`; if it leaves no
     722             :   // tensor indices behind, a plain copy into the result finishes the job.
     723             :   std::vector<Tensor> tensor_indices;
     724             :   Tensor dst = impl::applySlicing(
     725             :       self, indices, tensor_indices, disable_slice_optimization, device, sizes);
     726             :   if (tensor_indices.empty()) {
     727             :     copy_to(dst, value);
     728             :     return;
     729             :   }
     730             : 
     731             :   // Otherwise the remaining tensor indices go through `dispatch_index_put_`.
     732             :   // `value` is re-viewed first, but only if `slicePrefix1sSize` returned sizes
     733             :   // different from its own.
     734             :   SymIntArrayRef value_sizes = value.sym_sizes();
     735             :   SymIntArrayRef trimmed_sizes = slicePrefix1sSize(value_sizes);
     736             :   Tensor source = value_sizes.equals(trimmed_sizes)
     737             :       ? value
     738             :       : value.view_symint(trimmed_sizes);
     739             :   dispatch_index_put_(dst, std::move(tensor_indices), source);
     740             : }
     741             : 
     742             : } // namespace at::indexing

Generated by: LCOV version 1.16