LCOV - code coverage report
Current view: top level - home/runner/.local/lib/python3.10/site-packages/torch/include/ATen - TensorIndexing.h (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 17 20 85.0 %
Date: 2026-06-05 17:04:24 Functions: 5 5 100.0 %

          Line data    Source code
       1             : #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
       2             : #pragma once
       3             : 
       4             : #include <ATen/ExpandUtils.h>
       5             : #include <ATen/ScalarOps.h>
       6             : #include <ATen/core/Tensor.h>
       7             : #include <ATen/core/TensorBody.h>
       8             : #include <c10/core/SymInt.h>
       9             : #include <c10/util/irange.h>
      10             : #include <optional>
      11             : 
      12             : #ifndef AT_PER_OPERATOR_HEADERS
      13             : #include <ATen/Functions.h>
      14             : #include <ATen/NativeFunctions.h>
      15             : #else
      16             : #include <ATen/ops/alias.h>
      17             : #include <ATen/ops/empty.h>
      18             : #include <ATen/ops/scalar_tensor.h>
      19             : #include <ATen/ops/zeros.h>
      20             : #endif
      21             : 
      22             : #include <ATen/core/List.h>
      23             : 
      24             : #include <utility>
      25             : 
      26             : namespace at::indexing {
      27             : 
      28             : constexpr int64_t INDEX_MIN = c10::SymInt::min_representable_int();
      29             : constexpr int64_t INDEX_MAX = -(INDEX_MIN + 1);
      30             : 
      31             : enum class TensorIndexType { None, Ellipsis, SymInt, Boolean, Slice, Tensor };
      32             : 
      33             : constexpr std::nullopt_t None = std::nullopt;
      34             : 
      35             : struct TORCH_API EllipsisIndexType final {
      36             :   EllipsisIndexType() = default;
      37             : };
      38             : TORCH_API extern const EllipsisIndexType Ellipsis;
      39             : 
      40          38 : struct TORCH_API Slice final {
      41             :  public:
      42         116 :   Slice(
      43             :       std::optional<c10::SymInt> start_index = std::nullopt,
      44             :       std::optional<c10::SymInt> stop_index = std::nullopt,
      45         116 :       std::optional<c10::SymInt> step_index = std::nullopt) {
      46         116 :     if (!step_index.has_value()) {
      47         232 :       step_ = c10::SymInt(1);
      48             :     } else {
      49           0 :       step_ = std::move(step_index).value();
      50             :     }
      51             : 
      52         116 :     TORCH_CHECK_VALUE(
      53             :         step_.sym_ne(0).expect_true(__FILE__, __LINE__),
      54             :         "slice step cannot be zero");
      55             : 
      56         116 :     if (!start_index.has_value()) {
      57         348 :       start_ = c10::SymInt(step_ < 0 ? INDEX_MAX : 0);
      58             :     } else {
      59           0 :       start_ = std::move(start_index).value();
      60             :     }
      61             : 
      62         116 :     if (!stop_index.has_value()) {
      63         348 :       stop_ = c10::SymInt(step_ < 0 ? INDEX_MIN : INDEX_MAX);
      64             :     } else {
      65           0 :       stop_ = std::move(stop_index).value();
      66             :     }
      67         116 :   }
      68             : 
      69             :   inline c10::SymInt start() const {
      70             :     return start_;
      71             :   }
      72             : 
      73             :   inline c10::SymInt stop() const {
      74             :     return stop_;
      75             :   }
      76             : 
      77             :   inline c10::SymInt step() const {
      78             :     return step_;
      79             :   }
      80             : 
      81             :  private:
      82             :   c10::SymInt start_;
      83             :   c10::SymInt stop_;
      84             :   c10::SymInt step_;
      85             : };
      86             : 
      87             : TORCH_API std::ostream& operator<<(std::ostream& stream, const Slice& slice);
      88             : 
      89             : // `at::indexing::TensorIndex` is used for converting C++ tensor indices such as
      90             : // `{None, "...", Ellipsis, 0, true, Slice(1, None, 2), torch::tensor({1, 2})}`
      91             : // into its equivalent `std::vector<TensorIndex>`, so that further tensor
      92             : // indexing operations can be performed using the supplied indices.
      93             : //
      94             : // There is one-to-one correspondence between Python and C++ tensor index types:
      95             : // Python                  | C++
      96             : // -----------------------------------------------------
      97             : // `None`                  | `at::indexing::None`
      98             : // `Ellipsis`              | `at::indexing::Ellipsis`
      99             : // `...`                   | `"..."`
     100             : // `123`                   | `123`
     101             : // `True` / `False`        | `true` / `false`
     102             : // `:`                     | `Slice()` / `Slice(None, None)`
     103             : // `::`                    | `Slice()` / `Slice(None, None, None)`
     104             : // `1:`                    | `Slice(1, None)`
     105             : // `1::`                   | `Slice(1, None, None)`
     106             : // `:3`                    | `Slice(None, 3)`
     107             : // `:3:`                   | `Slice(None, 3, None)`
     108             : // `::2`                   | `Slice(None, None, 2)`
     109             : // `1:3`                   | `Slice(1, 3)`
     110             : // `1::2`                  | `Slice(1, None, 2)`
     111             : // `:3:2`                  | `Slice(None, 3, 2)`
     112             : // `1:3:2`                 | `Slice(1, 3, 2)`
     113             : // `torch.tensor([1, 2])`) | `torch::tensor({1, 2})`
     114             : struct TORCH_API TensorIndex final {
     115             :   // Case 1: `at::indexing::None`
     116             :   TensorIndex(std::nullopt_t /*unused*/) : type_(TensorIndexType::None) {}
     117             : 
     118             :   // Case 2: "..." / `at::indexing::Ellipsis`
     119             :   TensorIndex(at::indexing::EllipsisIndexType /*unused*/)
     120             :       : type_(TensorIndexType::Ellipsis) {}
     121             :   TensorIndex(const char* str) : TensorIndex(at::indexing::Ellipsis) {
     122             :     TORCH_CHECK_VALUE(
     123             :         strcmp(str, "...") == 0,
     124             :         "Expected \"...\" to represent an ellipsis index, but got \"",
     125             :         str,
     126             :         "\"");
     127             :   }
     128             : 
     129             :   // Case 3: (Sym) Integer value
     130          38 :   TensorIndex(SymInt integer)
     131          76 :       : integer_(std::move(integer)), type_(TensorIndexType::SymInt) {}
     132          22 :   TensorIndex(int64_t integer) : TensorIndex(SymInt(integer)) {}
     133          54 :   TensorIndex(int integer) : TensorIndex(SymInt(integer)) {}
     134             : 
     135             :   // Case 4: Boolean value
     136             :   template <class T, class = std::enable_if_t<std::is_same_v<bool, T>>>
     137             :   TensorIndex(T boolean) : boolean_(boolean), type_(TensorIndexType::Boolean) {}
     138             : 
     139             :   // Case 5: Slice represented in `at::indexing::Slice` form
     140          38 :   TensorIndex(Slice slice)
     141          38 :       : slice_(std::move(slice)), type_(TensorIndexType::Slice) {}
     142             : 
     143             :   // Case 6: Tensor value
     144             :   TensorIndex(Tensor tensor)
     145             :       : tensor_(std::move(tensor)), type_(TensorIndexType::Tensor) {}
     146             : 
     147             :   inline bool is_none() const {
     148             :     return type_ == TensorIndexType::None;
     149             :   }
     150             : 
     151             :   inline bool is_ellipsis() const {
     152             :     return type_ == TensorIndexType::Ellipsis;
     153             :   }
     154             : 
     155             :   inline bool is_integer() const {
     156             :     return type_ == TensorIndexType::SymInt;
     157             :   }
     158             : 
     159             :   inline SymInt integer() const {
     160             :     return integer_;
     161             :   }
     162             : 
     163             :   inline bool is_boolean() const {
     164             :     return type_ == TensorIndexType::Boolean;
     165             :   }
     166             : 
     167             :   inline bool boolean() const {
     168             :     return boolean_;
     169             :   }
     170             : 
     171             :   inline bool is_slice() const {
     172             :     return type_ == TensorIndexType::Slice;
     173             :   }
     174             : 
     175             :   inline const Slice& slice() const {
     176             :     return slice_;
     177             :   }
     178             : 
     179             :   inline bool is_tensor() const {
     180             :     return type_ == TensorIndexType::Tensor;
     181             :   }
     182             : 
     183             :   inline const Tensor& tensor() const {
     184             :     return tensor_;
     185             :   }
     186             : 
     187             :  private:
     188             :   SymInt integer_ = 0;
     189             :   bool boolean_ = false;
     190             :   Slice slice_;
     191             :   Tensor tensor_;
     192             :   TensorIndexType type_;
     193             : };
     194             : 
     195             : TORCH_API std::ostream& operator<<(
     196             :     std::ostream& stream,
     197             :     const TensorIndex& tensor_index);
     198             : TORCH_API std::ostream& operator<<(
     199             :     std::ostream& stream,
     200             :     const std::vector<TensorIndex>& tensor_indices);
     201             : 
     202             : namespace impl {
     203             : inline Tensor applySlice(
     204             :     const Tensor& self,
     205             :     int64_t dim,
     206             :     c10::SymInt start,
     207             :     c10::SymInt stop,
     208             :     c10::SymInt step,
     209             :     bool disable_slice_optimization,
     210             :     const at::Device& self_device,
     211             :     const std::optional<SymIntArrayRef>& self_sizes) {
     212             :   // TODO: implement negative step
     213             :   TORCH_CHECK_VALUE(
     214             :       step.sym_gt(0).expect_true(__FILE__, __LINE__),
     215             :       "step must be greater than zero");
     216             : 
     217             :   // See NOTE [nested tensor size for indexing]
     218             :   if (self_sizes.has_value() && !self_sizes.value().empty()) {
     219             :     // Skip this optimization if we are tracing, as the trace may be polymorphic
     220             :     // over the shape of the `self` tensor, and we still want to record
     221             :     // the slice.
     222             :     SymInt length = (self_device == at::kCPU || self_device == at::kCUDA)
     223             :         ? (*self_sizes)[dim]
     224             :         : self.sym_size(dim);
     225             :     if (!disable_slice_optimization &&
     226             :         TORCH_STATICALLY_KNOWN_TRUE(start.sym_eq(0)) &&
     227             :         TORCH_STATICALLY_KNOWN_TRUE(length.sym_le(stop)) && step == 1) {
     228             :       return self;
     229             :     }
     230             :   }
     231             :   return self.slice_symint(
     232             :       dim, std::move(start), std::move(stop), std::move(step));
     233             : }
     234             : 
     235             : inline Tensor applySelect(
     236             :     const Tensor& self,
     237             :     int64_t dim,
     238             :     SymInt index,
     239             :     int64_t real_dim,
     240             :     const at::Device& /*self_device*/,
     241             :     const std::optional<SymIntArrayRef>& self_sizes) {
     242             :   // See NOTE [nested tensor size for indexing]
     243             :   if (self_sizes.has_value()) {
     244             :     auto maybe_index = index.maybe_as_int();
     245             :     if (maybe_index.has_value()) {
     246             :       TORCH_CHECK_INDEX(
     247             :           !(maybe_index.value() == 0 && dim == 0 && self_sizes->empty()),
     248             :           "invalid index of a 0-dim tensor. ",
     249             :           "Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number");
     250             :     }
     251             : 
     252             :     auto size = (*self_sizes)[dim];
     253             :     // Note: `size >= -index` is not equivalent to `size > -1 - index` if index
     254             :     // is INT64_MIN For std::numeric_limits<int64_t>::min() result of unary
     255             :     // minus is undefined by the standard but in practice is equal to self. On
     256             :     // the other hand, indexing wrapping is valid for all negative int64_t
     257             :     // values, as x[INT64_MIN] is the same as x[INT64_MAX]
     258             :     TORCH_CHECK_INDEX(
     259             :         size.sym_gt(-1 - index)
     260             :             .sym_and(size.sym_gt(index))
     261             :             .expect_true(__FILE__, __LINE__),
     262             :         "index ",
     263             :         index,
     264             :         " is out of bounds for dimension ",
     265             :         real_dim,
     266             :         " with size ",
     267             :         size);
     268             :   }
     269             : 
     270             :   // if the index is negative, do not normalize it because that would fix the
     271             :   // index on the current tensor size in the tracer. aten::select also works on
     272             :   // negative indices
     273             :   return self.select_symint(dim, std::move(index));
     274             : }
     275             : 
     276             : inline Tensor boolToIndexingTensorCPUOrCUDA(const Tensor& self, bool value) {
     277             :   // booleans add a dimension of size 1. true indexes this dimension as if 0:,
     278             :   // false as empty.
     279             :   if (value) {
     280             :     return at::empty({1}, self.options().dtype(kLong)).fill_(0.);
     281             :   } else {
     282             :     return at::empty({0}, self.options().dtype(kLong));
     283             :   }
     284             : }
     285             : 
     286             : inline Tensor boolToIndexingTensorNonNativeDeviceType(
     287             :     const Tensor& self,
     288             :     bool value) {
     289             :   // booleans add a dimension of size 1. true indexes this dimension as if 0:,
     290             :   // false as empty.
     291             :   if (value) {
     292             :     return at::zeros({1}, self.options().dtype(kLong));
     293             :   } else {
     294             :     return at::empty({0}, self.options().dtype(kLong));
     295             :   }
     296             : }
     297             : 
     298             : inline Tensor boolToIndexingTensor(
     299             :     const Tensor& self,
     300             :     bool value,
     301             :     const at::Device& self_device) {
     302             :   if (self_device == at::kCPU || self_device == at::kCUDA) {
     303             :     return boolToIndexingTensorCPUOrCUDA(self, value);
     304             :   } else {
     305             :     return boolToIndexingTensorNonNativeDeviceType(self, value);
     306             :   }
     307             : }
     308             : 
     309             : inline Tensor scalarToTensorNonNativeDeviceType(
     310             :     const Scalar& v,
     311             :     const TensorOptions& options) {
     312             :   return at::scalar_tensor(v, options);
     313             : }
     314             : 
     315             : inline void recordTensorIndex(
     316             :     const Tensor& tensor,
     317             :     std::vector<Tensor>& outIndices,
     318             :     int64_t* dim_ptr) {
     319             :   if (outIndices.empty()) {
     320             :     outIndices.resize(*dim_ptr + 1);
     321             :     outIndices[*dim_ptr] = tensor;
     322             :   } else {
     323             :     outIndices.push_back(tensor);
     324             :   }
     325             :   if (tensor.scalar_type() == kByte || tensor.scalar_type() == kBool) {
     326             :     *dim_ptr += tensor.dim();
     327             :   } else {
     328             :     *dim_ptr += 1;
     329             :   }
     330             : }
     331             : 
     332             : inline c10::List<::std::optional<Tensor>> typeConvertIndices(
     333             :     const Tensor& /*self*/,
     334             :     std::vector<Tensor>&& indices) {
     335             :   c10::List<::std::optional<Tensor>> converted_inds;
     336             :   converted_inds.reserve(indices.size());
     337             :   for (auto&& i : std::move(indices)) {
     338             :     converted_inds.push_back(std::move(i));
     339             :   }
     340             :   return converted_inds;
     341             : }
     342             : 
     343             : // NOTE: Why do we mirror instead of replace the `count_specified_dimensions`
     344             : // function in torch/csrc/autograd/python_variable_indexing.cpp? It's because
     345             : // `count_specified_dimensions` is on the hot path of Python tensor multi-dim
     346             : // indexing (i.e. it's called by `applySlicing` which is called by
     347             : // `THPVariable_getitem` / `THPVariable_setitem` when handling indexing of more
     348             : // than one dimension). If we were to merge the Python/C++
     349             : // `count_specified_dimensions` function, on the Python side we would have to
     350             : // construct a `std::vector` container to be consumed by the C++
     351             : // `count_specified_dimensions` function, which adds 100s of nanoseconds
     352             : // overhead and is undesirable.
     353             : inline int64_t count_specified_dimensions(
     354             :     const ArrayRef<TensorIndex>& indices) {
     355             :   // Count the number of indexed dimensions (everything but ellipsis and None)
     356             :   int64_t count = 0;
     357             :   for (auto& obj : indices) {
     358             :     if (obj.is_tensor()) {
     359             :       auto& tensor = obj.tensor();
     360             :       if (tensor.scalar_type() == kByte || tensor.scalar_type() == kBool) {
     361             :         count += tensor.dim();
     362             :       } else {
     363             :         count++;
     364             :       }
     365             :     } else if (!obj.is_none() && !obj.is_ellipsis() && !obj.is_boolean()) {
     366             :       count++;
     367             :     }
     368             :   }
     369             :   return count;
     370             : }
     371             : } // namespace impl
     372             : 
     373             : // NOTE: Many functions below are only for consumption from Python indexing
     374             : // implementation, they include:
     375             : //
     376             : // - `Tensor scalarToTensor(...)`
     377             : // - `IntArrayRef slicePrefix1sSize(...)`
     378             : // - `void copy_to(...)`
     379             : // - `Tensor handleDimInMultiDimIndexing(...)`
     380             : // - `Tensor dispatch_index(...)`
     381             : // - `Tensor dispatch_index_put_(...)`
     382             : // - `Tensor get_item(...)`
     383             : // - `void set_item(...)`
     384             : //
     385             : // The rest of the functions are in `at::indexing::impl` namespace, signifying
     386             : // that they shouldn't be used from Python indexing implementation.
     387             : inline Tensor scalarToTensor(
     388             :     const Scalar& v,
     389             :     const TensorOptions& options,
     390             :     const at::Device& self_device) {
     391             :   if (self_device == at::kCPU && !v.isSymbolic()) {
     392             :     return at::detail::scalar_tensor_static(
     393             :         v,
     394             :         // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
     395             :         options.dtype_opt()->toScalarType(),
     396             :         self_device);
     397             :   } else {
     398             :     return impl::scalarToTensorNonNativeDeviceType(v, options);
     399             :   }
     400             : }
     401             : 
     402             : // To match numpy semantics:
     403             : // As a special case for backwards compatibility,
     404             : // strip away unit dimensions from the left of 'src'
     405             : inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) {
     406             :   size_t first_non1_src = sizes.size();
     407             :   for (const auto i : c10::irange(sizes.size())) {
     408             :     // Unbacked SymInt has different behavior, but this is sound because
     409             :     // failing to slice will only ever cause an error, not divergent
     410             :     // behavior
     411             :     if (!sizes[i].has_hint() || sizes[i] != 1) {
     412             :       first_non1_src = i;
     413             :       break;
     414             :     }
     415             :   }
     416             : 
     417             :   return sizes.slice(first_non1_src);
     418             : }
     419             : 
     420             : inline void copy_to(const Tensor& dst, const Tensor& src) {
     421             :   if (dst.sym_sizes().equals(src.sym_sizes())) {
     422             :     // A shortcut to avoid generating hard-coded constant sizes during tracing.
     423             :     // This is not a perfect solution: when src & dst have different shapes,
     424             :     // constants will still appear. Users can workaround that case by
     425             :     // dst[index..] = src.reshape(..)
     426             :     dst.copy_(src);
     427             :     return;
     428             :   } else if (src.dim() == 0 && src.device().type() == at::kCPU) {
     429             :     dst.fill_(src);
     430             :     return;
     431             :   }
     432             :   auto src_view = src.view_symint(slicePrefix1sSize(src.sym_sizes()));
     433             :   c10::MaybeOwned<Tensor> b_src = expand_inplace(dst, src_view, "setitem");
     434             :   dst.copy_(*b_src);
     435             : }
     436             : 
     437             : // See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
     438             : // indexing functions from Python ]
     439             : inline Tensor handleDimInMultiDimIndexing(
     440             :     const Tensor& prev_dim_result,
     441             :     const Tensor& original_tensor,
     442             :     const TensorIndex& index,
     443             :     int64_t* dim_ptr,
     444             :     int64_t* specified_dims_ptr,
     445             :     int64_t real_dim,
     446             :     std::vector<Tensor>& outIndices,
     447             :     bool disable_slice_optimization,
     448             :     const at::Device& original_tensor_device,
     449             :     const std::optional<SymIntArrayRef>& prev_dim_result_sizes) {
     450             :   if (index.is_integer()) {
     451             :     return impl::applySelect(
     452             :         prev_dim_result,
     453             :         *dim_ptr,
     454             :         index.integer(),
     455             :         real_dim,
     456             :         original_tensor_device,
     457             :         prev_dim_result_sizes);
     458             :   } else if (index.is_slice()) {
     459             :     Tensor result = impl::applySlice(
     460             :         prev_dim_result,
     461             :         *dim_ptr,
     462             :         index.slice().start(),
     463             :         index.slice().stop(),
     464             :         index.slice().step(),
     465             :         /*disable_slice_optimization=*/disable_slice_optimization,
     466             :         original_tensor_device,
     467             :         prev_dim_result_sizes);
     468             :     (*dim_ptr)++;
     469             :     if (!outIndices.empty()) {
     470             :       outIndices.resize(outIndices.size() + 1);
     471             :     }
     472             :     return result;
     473             :   } else if (index.is_ellipsis()) {
     474             :     auto ellipsis_ndims = original_tensor.dim() - *specified_dims_ptr;
     475             :     (*dim_ptr) += ellipsis_ndims;
     476             :     if (!outIndices.empty()) {
     477             :       outIndices.resize(outIndices.size() + ellipsis_ndims);
     478             :     }
     479             :     return prev_dim_result;
     480             :   } else if (index.is_none()) {
     481             :     Tensor result = prev_dim_result.unsqueeze(*dim_ptr);
     482             :     (*dim_ptr)++;
     483             :     if (!outIndices.empty()) {
     484             :       outIndices.resize(outIndices.size() + 1);
     485             :     }
     486             :     return result;
     487             :   } else if (index.is_boolean()) {
     488             :     Tensor result = prev_dim_result.unsqueeze(*dim_ptr);
     489             :     impl::recordTensorIndex(
     490             :         impl::boolToIndexingTensor(
     491             :             result, index.boolean(), original_tensor_device),
     492             :         outIndices,
     493             :         dim_ptr);
     494             :     return result;
     495             :   } else if (index.is_tensor()) {
     496             :     Tensor result = prev_dim_result;
     497             :     const Tensor& tensor = index.tensor();
     498             :     auto scalar_type = tensor.scalar_type();
     499             :     bool is_batched = tensor.key_set().has_any(c10::DispatchKeySet(
     500             :         {c10::DispatchKey::FuncTorchBatched,
     501             :          c10::DispatchKey::BatchedNestedTensor}));
     502             :     if (tensor.dim() == 0 && !is_batched &&
     503             :         at::isIntegralType(scalar_type, /*includeBool=*/true)) {
     504             :       if (scalar_type != at::kByte && scalar_type != at::kBool) {
     505             :         result = impl::applySelect(
     506             :             result,
     507             :             *dim_ptr,
     508             :             tensor.item<int64_t>(),
     509             :             real_dim,
     510             :             original_tensor_device,
     511             :             prev_dim_result_sizes);
     512             :       } else {
     513             :         result = result.unsqueeze(*dim_ptr);
     514             :         if (scalar_type == at::kBool) {
     515             :           impl::recordTensorIndex(
     516             :               impl::boolToIndexingTensor(
     517             :                   result, tensor.item<bool>() != 0, original_tensor_device),
     518             :               outIndices,
     519             :               dim_ptr);
     520             :         } else {
     521             :           impl::recordTensorIndex(
     522             :               impl::boolToIndexingTensor(
     523             :                   result, tensor.item<uint8_t>() != 0, original_tensor_device),
     524             :               outIndices,
     525             :               dim_ptr);
     526             :         }
     527             :       }
     528             :     } else {
     529             :       impl::recordTensorIndex(tensor, outIndices, dim_ptr);
     530             :     }
     531             :     return result;
     532             :   } else {
     533             :     TORCH_INTERNAL_ASSERT(false, "Invalid TensorIndex type");
     534             :   }
     535             : }
     536             : 
     537             : namespace impl {
     538             : // This mirrors `applySlicing` in
     539             : // torch/csrc/autograd/python_variable_indexing.cpp
     540             : inline Tensor applySlicing(
     541             :     const Tensor& self,
     542             :     const ArrayRef<TensorIndex>& indices,
     543             :     std::vector<Tensor>& outIndices,
     544             :     bool disable_slice_optimization,
     545             :     const at::Device& self_device,
     546             :     const std::optional<SymIntArrayRef>& self_sizes) {
     547             :   int64_t dim = 0;
     548             :   int64_t specified_dims = impl::count_specified_dimensions(indices);
     549             : 
     550             :   // See NOTE [nested tensor size for indexing]
     551             :   if (self_sizes.has_value()) {
     552             :     TORCH_CHECK_INDEX(
     553             :         specified_dims <= (int64_t)self_sizes->size(),
     554             :         "too many indices for tensor of dimension ",
     555             :         (int)self_sizes->size());
     556             :   }
     557             : 
     558             :   Tensor result = self;
     559             :   for (const auto i : c10::irange(indices.size())) {
     560             :     auto& obj = indices[i];
     561             :     // See NOTE [nested tensor size for indexing]
     562             :     std::optional<SymIntArrayRef> result_sizes = result.is_nested()
     563             :         ? std::optional<SymIntArrayRef>(std::nullopt)
     564             :         : std::optional<SymIntArrayRef>(result.sym_sizes());
     565             :     result = handleDimInMultiDimIndexing(
     566             :         /*prev_dim_result=*/result,
     567             :         /*original_tensor=*/self,
     568             :         /*index=*/obj,
     569             :         /*dim_ptr=*/&dim,
     570             :         /*specified_dims_ptr=*/&specified_dims,
     571             :         /*real_dim=*/static_cast<int64_t>(i),
     572             :         /*outIndices=*/outIndices,
     573             :         /*disable_slice_optimization=*/disable_slice_optimization,
     574             :         /*original_tensor_device=*/self_device,
     575             :         /*prev_dim_result_sizes=*/result_sizes);
     576             :   }
     577             :   return result;
     578             : }
     579             : } // namespace impl
     580             : 
     581             : inline Tensor dispatch_index(
     582             :     const Tensor& self,
     583             :     std::vector<Tensor>&& indices) {
     584             :   // Remove trailing null elements from indices
     585             :   while (!indices.empty() && !indices.back().defined()) {
     586             :     indices.pop_back();
     587             :   }
     588             :   return self.index(impl::typeConvertIndices(self, std::move(indices)));
     589             : }
     590             : 
     591             : inline Tensor dispatch_index_put_(
     592             :     Tensor& self,
     593             :     std::vector<Tensor>&& indices,
     594             :     const Tensor& value) {
     595             :   // Remove trailing null elements from indices
     596             :   while (!indices.empty() && !indices.back().defined()) {
     597             :     indices.pop_back();
     598             :   }
     599             :   return self.index_put_(
     600             :       impl::typeConvertIndices(self, std::move(indices)), value);
     601             : }
     602             : 
     603             : // NOTE [ Setting `disable_slice_optimization` when calling C++ tensor indexing
     604             : // functions from Python ]
     605             : //
     606             : // Question: When should we set `disable_slice_optimization` to `true` when
     607             : // calling C++ tensor indexing functions from Python indexing code?
     608             : //
     609             : // Answer: What "slice optimization" means: when we have a slicing expression
     610             : // like `x[0:5, 0]`, where the sliced tensor was of size 5 in dimension 0, we
     611             : // would skip dispatching the actual slice call as an optimization. However,
     612             : // here are the cases where we DON'T want this optimization:
     613             : //
     614             : // 1. When we are doing 1-D slicing (e.g. `tensor[:]`).
     615             : //    Reason: we always return a shallow copy for expressions such as
     616             : //    `tensor[:]` / `tensor[...]` / `tensor[:, :]`. (Note that for `tensor[:,
     617             : //    :]`, we return an alias of `tensor` by doing the following:
     618             : //    ```
     619             : //    Tensor sliced = impl::applySlicing(self, indices, tensorIndices,
     620             : //    disable_slice_optimization, self_device, self_sizes); if
     621             : //    (tensorIndices.empty()) {
     622             : //      if (sliced.is_same(self)) {
     623             : //        // ensure we return a shallow copy for things like x[...]
     624             : //        sliced = at::alias(sliced);
     625             : //      }
     626             : //      return sliced;
     627             : //    }
     628             : //    ```)
     629             : // 2. When we are doing JIT tracing.
     630             : //    Reason: JIT tracing needs the `self.slice(...)` call to properly trace the
     631             : //    slice operation.
     632             : 
     633             : // This mirrors `THPVariable_getitem` in
     634             : // torch/csrc/autograd/python_variable_indexing.cpp See NOTE [ Setting
     635             : // `disable_slice_optimization` when calling C++ tensor indexing functions from
     636             : // Python ]
     637             : inline Tensor get_item(
     638             :     const Tensor& self,
     639             :     const ArrayRef<TensorIndex>& indices,
     640             :     bool disable_slice_optimization = false) {
     641             :   at::Device self_device = self.device();
     642             :   // NOTE [nested tensor size for indexing]
     643             :   // nested tensor does not have a size (yet) so for now we represent its size
     644             :   // as null may need to be changed after we reach a better solution for nested
     645             :   // tensor size
     646             :   std::optional<SymIntArrayRef> self_sizes = self.is_nested()
     647             :       ? std::optional<SymIntArrayRef>(std::nullopt)
     648             :       : std::optional<SymIntArrayRef>(self.sym_sizes());
     649             : 
     650             :   // handle simple types: integers, slices, none, ellipsis, bool
     651             :   if (indices.size() == 1) {
     652             :     const TensorIndex& index = indices[0];
     653             :     if (index.is_integer()) {
     654             :       return impl::applySelect(
     655             :           self, 0, index.integer(), 0, self_device, self_sizes);
     656             :     } else if (index.is_slice()) {
     657             :       return impl::applySlice(
     658             :           self,
     659             :           0,
     660             :           index.slice().start(),
     661             :           index.slice().stop(),
     662             :           index.slice().step(),
     663             :           /*disable_slice_optimization=*/true,
     664             :           self_device,
     665             :           self_sizes);
     666             :     } else if (index.is_none()) {
     667             :       return self.unsqueeze(0);
     668             :     } else if (index.is_ellipsis()) {
     669             :       return at::alias(self);
     670             :     } else if (index.is_boolean()) {
     671             :       Tensor result = self.unsqueeze(0);
     672             :       return dispatch_index(
     673             :           result,
     674             :           std::vector<Tensor>{impl::boolToIndexingTensor(
     675             :               result, index.boolean(), self_device)});
     676             :     }
     677             :   }
     678             : 
     679             :   std::vector<Tensor> tensorIndices;
     680             :   Tensor sliced = impl::applySlicing(
     681             :       self,
     682             :       indices,
     683             :       tensorIndices,
     684             :       disable_slice_optimization,
     685             :       self_device,
     686             :       self_sizes);
     687             :   if (tensorIndices.empty()) {
     688             :     if (sliced.is_same(self)) {
     689             :       // ensure we return a shallow copy for things like x[...]
     690             :       sliced = at::alias(sliced);
     691             :     }
     692             :     return sliced;
     693             :   }
     694             : 
     695             :   // indexing by tensors ("advanced" indexing)
     696             :   return dispatch_index(sliced, std::move(tensorIndices));
     697             : }
     698             : 
     699             : // This mirrors `THPVariable_setitem` in
     700             : // torch/csrc/autograd/python_variable_indexing.cpp for "the assigned value is a
     701             : // Tensor" case See NOTE [ Setting `disable_slice_optimization` when calling C++
     702             : // tensor indexing functions from Python ]
     703             : inline void set_item(
     704             :     const Tensor& self,
     705             :     const ArrayRef<TensorIndex>& indices,
     706             :     const Tensor& value,
     707             :     bool disable_slice_optimization = false) {
     708             :   at::Device self_device = self.device();
     709             :   SymIntArrayRef self_sizes = self.sym_sizes();
     710             : 
     711             :   // handle simple types: integers, slices, ellipsis, bool
     712             :   if (indices.size() == 1) {
     713             :     const TensorIndex& index = indices[0];
     714             :     if (index.is_boolean() && !index.boolean()) {
     715             :       // do nothing for false (technically we should check the size, but we
     716             :       // don't have real 0-sized shapes.
     717             :       return;
     718             :     } else if (index.is_ellipsis()) {
     719             :       copy_to(self, value);
     720             :       return;
     721             :     } else if (index.is_none() || (index.is_boolean() && index.boolean())) {
     722             :       copy_to(self.unsqueeze(0), value);
     723             :       return;
     724             :     } else if (index.is_integer()) {
     725             :       copy_to(
     726             :           impl::applySelect(
     727             :               self, 0, index.integer(), 0, self_device, self_sizes),
     728             :           value);
     729             :       return;
     730             :     } else if (index.is_slice()) {
     731             :       copy_to(
     732             :           impl::applySlice(
     733             :               self,
     734             :               0,
     735             :               index.slice().start(),
     736             :               index.slice().stop(),
     737             :               index.slice().step(),
     738             :               /*disable_slice_optimization=*/disable_slice_optimization,
     739             :               self_device,
     740             :               self_sizes),
     741             :           value);
     742             :       return;
     743             :     }
     744             :   }
     745             : 
     746             :   std::vector<Tensor> tensorIndices;
     747             :   Tensor sliced = impl::applySlicing(
     748             :       self,
     749             :       indices,
     750             :       tensorIndices,
     751             :       disable_slice_optimization,
     752             :       self_device,
     753             :       self_sizes);
     754             :   if (tensorIndices.empty()) {
     755             :     copy_to(sliced, value);
     756             :     return;
     757             :   }
     758             : 
     759             :   SymIntArrayRef valueSizes = value.sym_sizes();
     760             :   SymIntArrayRef slicedValueSizes = slicePrefix1sSize(valueSizes);
     761             :   Tensor valuesSliced;
     762             :   if (!valueSizes.equals(slicedValueSizes)) {
     763             :     valuesSliced = value.view_symint(slicedValueSizes);
     764             :   } else {
     765             :     valuesSliced = value;
     766             :   }
     767             :   dispatch_index_put_(sliced, std::move(tensorIndices), valuesSliced);
     768             :   return;
     769             : }
     770             : 
     771             : } // namespace at::indexing
     772             : 
     773             : #else
     774             : #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
     775             : #endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)

Generated by: LCOV version 1.16