Line data Source code
1 : #pragma once
2 :
3 : #include <ATen/ExpandUtils.h>
4 : #include <ATen/ScalarOps.h>
5 : #include <ATen/core/Tensor.h>
6 : #include <ATen/core/TensorBody.h>
7 : #include <c10/core/SymInt.h>
8 : #include <c10/util/irange.h>
9 : #include <optional>
10 :
11 : #ifndef AT_PER_OPERATOR_HEADERS
12 : #include <ATen/Functions.h>
13 : #include <ATen/NativeFunctions.h>
14 : #else
15 : #include <ATen/ops/alias.h>
16 : #include <ATen/ops/empty.h>
17 : #include <ATen/ops/scalar_tensor.h>
18 : #include <ATen/ops/zeros.h>
19 : #endif
20 :
21 : #include <ATen/core/List.h>
22 :
23 : #include <utility>
24 :
25 : namespace at::indexing {
26 :
27 : constexpr int64_t INDEX_MIN = c10::SymInt::min_representable_int();
28 : constexpr int64_t INDEX_MAX = -(INDEX_MIN + 1);
29 :
30 : enum class TensorIndexType { None, Ellipsis, SymInt, Boolean, Slice, Tensor };
31 :
32 : constexpr std::nullopt_t None = std::nullopt;
33 :
34 : struct TORCH_API EllipsisIndexType final {
35 : EllipsisIndexType() = default;
36 : };
37 : TORCH_API extern const EllipsisIndexType Ellipsis;
38 :
39 35 : struct TORCH_API Slice final {
40 : public:
41 110 : Slice(
42 : std::optional<c10::SymInt> start_index = std::nullopt,
43 : std::optional<c10::SymInt> stop_index = std::nullopt,
44 110 : std::optional<c10::SymInt> step_index = std::nullopt) {
45 110 : if (!step_index.has_value()) {
46 220 : step_ = c10::SymInt(1);
47 : } else {
48 0 : step_ = std::move(step_index).value();
49 : }
50 :
51 110 : TORCH_CHECK_VALUE(
52 : step_.sym_ne(0).expect_true(__FILE__, __LINE__),
53 : "slice step cannot be zero");
54 :
55 110 : if (!start_index.has_value()) {
56 330 : start_ = c10::SymInt(step_ < 0 ? INDEX_MAX : 0);
57 : } else {
58 0 : start_ = std::move(start_index).value();
59 : }
60 :
61 110 : if (!stop_index.has_value()) {
62 330 : stop_ = c10::SymInt(step_ < 0 ? INDEX_MIN : INDEX_MAX);
63 : } else {
64 0 : stop_ = std::move(stop_index).value();
65 : }
66 110 : }
67 :
68 : inline c10::SymInt start() const {
69 : return start_;
70 : }
71 :
72 : inline c10::SymInt stop() const {
73 : return stop_;
74 : }
75 :
76 : inline c10::SymInt step() const {
77 : return step_;
78 : }
79 :
80 : private:
81 : c10::SymInt start_;
82 : c10::SymInt stop_;
83 : c10::SymInt step_;
84 : };
85 :
86 : TORCH_API std::ostream& operator<<(std::ostream& stream, const Slice& slice);
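// A minimal usage sketch (assumes libtorch and an example tensor `t`; the
// `Tensor::index` overload taking a `TensorIndex` initializer list is the
// intended consumer of `Slice`):
//
//   using namespace at::indexing;
//   auto t = torch::arange(10);
//   t.index({Slice()});            // t[:]    -> all elements
//   t.index({Slice(2)});           // t[2:]   -> elements 2..9
//   t.index({Slice(None, 5)});     // t[:5]   -> elements 0..4
//   t.index({Slice(1, None, 2)});  // t[1::2] -> elements 1, 3, 5, 7, 9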
87 :
88 : // `at::indexing::TensorIndex` is used for converting C++ tensor indices such as
89 : // `{None, "...", Ellipsis, 0, true, Slice(1, None, 2), torch::tensor({1, 2})}`
90 : // into its equivalent `std::vector<TensorIndex>`, so that further tensor
91 : // indexing operations can be performed using the supplied indices.
92 : //
93 : // There is a one-to-one correspondence between Python and C++ tensor index types:
94 : // Python | C++
95 : // -----------------------------------------------------
96 : // `None` | `at::indexing::None`
97 : // `Ellipsis` | `at::indexing::Ellipsis`
98 : // `...` | `"..."`
99 : // `123` | `123`
100 : // `True` / `False` | `true` / `false`
101 : // `:` | `Slice()` / `Slice(None, None)`
102 : // `::` | `Slice()` / `Slice(None, None, None)`
103 : // `1:` | `Slice(1, None)`
104 : // `1::` | `Slice(1, None, None)`
105 : // `:3` | `Slice(None, 3)`
106 : // `:3:` | `Slice(None, 3, None)`
107 : // `::2` | `Slice(None, None, 2)`
108 : // `1:3` | `Slice(1, 3)`
109 : // `1::2` | `Slice(1, None, 2)`
110 : // `:3:2` | `Slice(None, 3, 2)`
111 : // `1:3:2` | `Slice(1, 3, 2)`
112 : // `torch.tensor([1, 2])` | `torch::tensor({1, 2})`
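//
// For example, the Python expression
//   x[None, ..., 0, True, 1:3:2, torch.tensor([1, 2])]
// can be written in C++ (a sketch; `x` is an assumed tensor) as
//   x.index({None, "...", 0, true, Slice(1, 3, 2), torch::tensor({1, 2})});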
113 : struct TORCH_API TensorIndex final {
114 : // Case 1: `at::indexing::None`
115 : TensorIndex(std::nullopt_t) : type_(TensorIndexType::None) {}
116 :
117 : // Case 2: "..." / `at::indexing::Ellipsis`
118 : TensorIndex(at::indexing::EllipsisIndexType)
119 : : type_(TensorIndexType::Ellipsis) {}
120 : TensorIndex(const char* str) : TensorIndex(at::indexing::Ellipsis) {
121 : TORCH_CHECK_VALUE(
122 : strcmp(str, "...") == 0,
123 : "Expected \"...\" to represent an ellipsis index, but got \"",
124 : str,
125 : "\"");
126 : }
127 :
128 : // Case 3: (Sym) Integer value
129 35 : TensorIndex(SymInt integer)
130 70 : : integer_(std::move(integer)), type_(TensorIndexType::SymInt) {}
131 22 : TensorIndex(int64_t integer) : TensorIndex(SymInt(integer)) {}
132 48 : TensorIndex(int integer) : TensorIndex(SymInt(integer)) {}
133 :
134 : // Case 4: Boolean value
135 : template <class T, class = std::enable_if_t<std::is_same_v<bool, T>>>
136 : TensorIndex(T boolean) : boolean_(boolean), type_(TensorIndexType::Boolean) {}
137 :
138 : // Case 5: Slice represented in `at::indexing::Slice` form
139 35 : TensorIndex(Slice slice)
140 35 : : slice_(std::move(slice)), type_(TensorIndexType::Slice) {}
141 :
142 : // Case 6: Tensor value
143 : TensorIndex(Tensor tensor)
144 : : tensor_(std::move(tensor)), type_(TensorIndexType::Tensor) {}
145 :
146 : inline bool is_none() const {
147 : return type_ == TensorIndexType::None;
148 : }
149 :
150 : inline bool is_ellipsis() const {
151 : return type_ == TensorIndexType::Ellipsis;
152 : }
153 :
154 : inline bool is_integer() const {
155 : return type_ == TensorIndexType::SymInt;
156 : }
157 :
158 : inline SymInt integer() const {
159 : return integer_;
160 : }
161 :
162 : inline bool is_boolean() const {
163 : return type_ == TensorIndexType::Boolean;
164 : }
165 :
166 : inline bool boolean() const {
167 : return boolean_;
168 : }
169 :
170 : inline bool is_slice() const {
171 : return type_ == TensorIndexType::Slice;
172 : }
173 :
174 : inline const Slice& slice() const {
175 : return slice_;
176 : }
177 :
178 : inline bool is_tensor() const {
179 : return type_ == TensorIndexType::Tensor;
180 : }
181 :
182 : inline const Tensor& tensor() const {
183 : return tensor_;
184 : }
185 :
186 : private:
187 : SymInt integer_ = 0;
188 : bool boolean_ = false;
189 : Slice slice_;
190 : Tensor tensor_;
191 : TensorIndexType type_;
192 : };
193 :
194 : TORCH_API std::ostream& operator<<(
195 : std::ostream& stream,
196 : const TensorIndex& tensor_index);
197 : TORCH_API std::ostream& operator<<(
198 : std::ostream& stream,
199 : const std::vector<TensorIndex>& tensor_indices);
200 :
201 : namespace impl {
202 : inline Tensor applySlice(
203 : const Tensor& self,
204 : int64_t dim,
205 : c10::SymInt start,
206 : c10::SymInt stop,
207 : c10::SymInt step,
208 : bool disable_slice_optimization,
209 : const at::Device& self_device,
210 : const std::optional<SymIntArrayRef>& self_sizes) {
211 : // TODO: implement negative step
212 : TORCH_CHECK_VALUE(
213 : step.sym_gt(0).expect_true(__FILE__, __LINE__),
214 : "step must be greater than zero");
215 :
216 : // See NOTE [nested tensor size for indexing]
217 : if (self_sizes.has_value()) {
218 : // Skip this optimization if we are tracing, as the trace may be polymorphic
219 : // over the shape of the `self` tensor, and we still want to record
220 : // the slice.
221 : SymInt length = (self_device == at::kCPU || self_device == at::kCUDA)
222 : ? (*self_sizes)[dim]
223 : : self.sym_size(dim);
224 : if (!disable_slice_optimization &&
225 : TORCH_GUARD_SIZE_OBLIVIOUS(start.sym_eq(0)) &&
226 : TORCH_GUARD_SIZE_OBLIVIOUS(length.sym_eq(stop)) && step == 1) {
227 : return self;
228 : }
229 : }
230 : return self.slice_symint(
231 : dim, std::move(start), std::move(stop), std::move(step));
232 : }
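// For illustration (a sketch, not upstream documentation): with a CPU tensor of
// size {5, 4}, applySlice(self, /*dim=*/0, /*start=*/0, /*stop=*/5, /*step=*/1,
// /*disable_slice_optimization=*/false, ...) returns `self` unchanged, because
// the slice covers the whole dimension; any other bounds dispatch to
// self.slice_symint(dim, start, stop, step).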
233 :
234 : inline Tensor applySelect(
235 : const Tensor& self,
236 : int64_t dim,
237 : SymInt index,
238 : int64_t real_dim,
239 : const at::Device& /*self_device*/,
240 : const std::optional<SymIntArrayRef>& self_sizes) {
241 : // See NOTE [nested tensor size for indexing]
242 : if (self_sizes.has_value()) {
243 : auto maybe_index = index.maybe_as_int();
244 : if (maybe_index.has_value()) {
245 : TORCH_CHECK_INDEX(
246 : !(maybe_index.value() == 0 && dim == 0 && self_sizes->empty()),
247 : "invalid index of a 0-dim tensor. ",
248 : "Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number");
249 : }
250 :
251 : auto size = (*self_sizes)[dim];
252 : // Note: `size >= -index` is not equivalent to `size > -1 - index` if index
253 : // is INT64_MIN. For std::numeric_limits<int64_t>::min(), the result of unary
254 : // minus is undefined by the standard, but in practice it is equal to itself.
255 : // On the other hand, index wrapping is valid for all negative int64_t
256 : // values, as x[INT64_MIN] is the same as x[INT64_MAX].
257 : TORCH_CHECK_INDEX(
258 : size.sym_gt(-1 - index)
259 : .sym_and(size.sym_gt(index))
260 : .expect_true(__FILE__, __LINE__),
261 : "index ",
262 : index,
263 : " is out of bounds for dimension ",
264 : real_dim,
265 : " with size ",
266 : size);
267 : }
268 :
269 : // If the index is negative, do not normalize it, because that would hard-code
270 : // the index against the current tensor size in the tracer. aten::select also
271 : // works on negative indices.
272 : return self.select_symint(dim, std::move(index));
273 : }
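// For example (illustrative only): on a tensor of size {3, 4},
// applySelect(x, /*dim=*/0, /*index=*/-1, ...) returns x[-1] with shape {4};
// the negative index is forwarded to aten::select unmodified, so the traced
// graph remains valid if the tensor's size changes.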
274 :
275 : inline Tensor boolToIndexingTensorCPUOrCUDA(const Tensor& self, bool value) {
276 : // Booleans add a dimension of size 1. `true` indexes this dimension as if it
277 : // were `0:`; `false` indexes it with an empty tensor.
278 : if (value) {
279 : return at::empty({1}, self.options().dtype(kLong)).fill_(0.);
280 : } else {
281 : return at::empty({0}, self.options().dtype(kLong));
282 : }
283 : }
284 :
285 : inline Tensor boolToIndexingTensorNonNativeDeviceType(
286 : const Tensor& self,
287 : bool value) {
288 : // Booleans add a dimension of size 1. `true` indexes this dimension as if it
289 : // were `0:`; `false` indexes it with an empty tensor.
290 : if (value) {
291 : return at::zeros({1}, self.options().dtype(kLong));
292 : } else {
293 : return at::empty({0}, self.options().dtype(kLong));
294 : }
295 : }
296 :
297 : inline Tensor boolToIndexingTensor(
298 : const Tensor& self,
299 : bool value,
300 : const at::Device& self_device) {
301 : if (self_device == at::kCPU || self_device == at::kCUDA) {
302 : return boolToIndexingTensorCPUOrCUDA(self, value);
303 : } else {
304 : return boolToIndexingTensorNonNativeDeviceType(self, value);
305 : }
306 : }
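// Resulting behavior sketch (assuming an example tensor `x` of shape {2, 3}):
//   x.index({true})   // shape {1, 2, 3}: new unit dim, selected as if `0:`
//   x.index({false})  // shape {0, 2, 3}: new unit dim, indexed by an empty tensor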
307 :
308 : inline Tensor scalarToTensorNonNativeDeviceType(
309 : const Scalar& v,
310 : const TensorOptions& options) {
311 : return at::scalar_tensor(v, options);
312 : }
313 :
314 : inline void recordTensorIndex(
315 : const Tensor& tensor,
316 : std::vector<Tensor>& outIndices,
317 : int64_t* dim_ptr) {
318 : // TODO: check scalarType
319 : outIndices.resize(*dim_ptr + 1);
320 : outIndices[*dim_ptr] = tensor;
321 : (*dim_ptr)++;
322 : }
323 :
324 : inline c10::List<::std::optional<Tensor>> typeConvertIndices(
325 : const Tensor& /*self*/,
326 : std::vector<Tensor>&& indices) {
327 : c10::List<::std::optional<Tensor>> converted_inds;
328 : converted_inds.reserve(indices.size());
329 : for (auto&& i : std::move(indices)) {
330 : converted_inds.push_back(std::move(i));
331 : }
332 : return converted_inds;
333 : }
334 :
335 : // NOTE: Why do we mirror instead of replace the `count_specified_dimensions`
336 : // function in torch/csrc/autograd/python_variable_indexing.cpp? It's because
337 : // `count_specified_dimensions` is on the hot path of Python tensor multi-dim
338 : // indexing (i.e. it's called by `applySlicing` which is called by
339 : // `THPVariable_getitem` / `THPVariable_setitem` when handling indexing of more
340 : // than one dimension). If we were to merge the Python/C++
341 : // `count_specified_dimensions` functions, the Python side would have to
342 : // construct a `std::vector` container to be consumed by the C++
343 : // `count_specified_dimensions` function, which adds hundreds of nanoseconds
344 : // of overhead and is undesirable.
345 : inline int64_t count_specified_dimensions(
346 : const ArrayRef<TensorIndex>& indices) {
347 : // Count the number of indexed dimensions (everything but ellipsis and None)
348 : int64_t count = 0;
349 : for (auto& obj : indices) {
350 : if (obj.is_tensor()) {
351 : auto& tensor = obj.tensor();
352 : if (tensor.scalar_type() == kByte || tensor.scalar_type() == kBool) {
353 : count += tensor.dim();
354 : } else {
355 : count++;
356 : }
357 : } else if (!obj.is_none() && !obj.is_ellipsis() && !obj.is_boolean()) {
358 : count++;
359 : }
360 : }
361 : return count;
362 : }
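// For example (illustrative only): for the index list
//   {None, 0, "...", torch::tensor({true, false})}
// the integer `0` counts 1 and the 1-D boolean mask counts tensor.dim() == 1,
// while `None` and the ellipsis are skipped, so the result is 2.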
363 : } // namespace impl
364 :
365 : // NOTE: Many functions below are only for consumption by the Python indexing
366 : // implementation; they include:
367 : //
368 : // - `Tensor scalarToTensor(...)`
369 : // - `IntArrayRef slicePrefix1sSize(...)`
370 : // - `void copy_to(...)`
371 : // - `Tensor handleDimInMultiDimIndexing(...)`
372 : // - `Tensor dispatch_index(...)`
373 : // - `Tensor dispatch_index_put_(...)`
374 : // - `Tensor get_item(...)`
375 : // - `void set_item(...)`
376 : //
377 : // The rest of the functions are in the `at::indexing::impl` namespace,
378 : // signifying that they shouldn't be used from the Python indexing implementation.
379 : inline Tensor scalarToTensor(
380 : const Scalar& v,
381 : const TensorOptions& options,
382 : const at::Device& self_device) {
383 : if (self_device == at::kCPU && !v.isSymbolic()) {
384 : return at::detail::scalar_tensor_static(
385 : v,
386 : // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
387 : options.dtype_opt()->toScalarType(),
388 : self_device);
389 : } else {
390 : return impl::scalarToTensorNonNativeDeviceType(v, options);
391 : }
392 : }
393 :
394 : // To match numpy semantics:
395 : // As a special case for backwards compatibility,
396 : // strip away unit dimensions from the left of 'src'
397 : inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) {
398 : size_t first_non1_src = sizes.size();
399 : for (const auto i : c10::irange(sizes.size())) {
400 : // Unbacked SymInt has different behavior, but this is sound because
401 : // failing to slice will only ever cause an error, not divergent
402 : // behavior
403 : if (!sizes[i].has_hint() || sizes[i] != 1) {
404 : first_non1_src = i;
405 : break;
406 : }
407 : }
408 :
409 : return sizes.slice(first_non1_src);
410 : }
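// For example (illustrative only): sizes {1, 1, 3, 1, 4} become {3, 1, 4};
// only the *leading* unit dimensions are stripped, so `copy_to` below can
// broadcast `src` into `dst` the way numpy does for `dst[idx] = src`.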
411 :
412 : inline void copy_to(const Tensor& dst, const Tensor& src) {
413 : if (dst.sym_sizes().equals(src.sym_sizes())) {
414 : // A shortcut to avoid generating hard-coded constant sizes during tracing.
415 : // This is not a perfect solution: when src & dst have different shapes,
416 : // constants will still appear. Users can work around that case by
417 : // dst[index..] = src.reshape(..)
418 : dst.copy_(src);
419 : return;
420 : } else if (src.dim() == 0 && src.device().type() == at::kCPU) {
421 : dst.fill_(src);
422 : return;
423 : }
424 : auto src_view = src.view_symint(slicePrefix1sSize(src.sym_sizes()));
425 : c10::MaybeOwned<Tensor> b_src = expand_inplace(dst, src_view, "setitem");
426 : dst.copy_(*b_src);
427 : }
428 :
429 : // See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
430 : // indexing functions from Python ]
431 : inline Tensor handleDimInMultiDimIndexing(
432 : const Tensor& prev_dim_result,
433 : const Tensor& original_tensor,
434 : const TensorIndex& index,
435 : int64_t* dim_ptr,
436 : int64_t* specified_dims_ptr,
437 : int64_t real_dim,
438 : std::vector<Tensor>& outIndices,
439 : bool disable_slice_optimization,
440 : const at::Device& original_tensor_device,
441 : const std::optional<SymIntArrayRef>& prev_dim_result_sizes) {
442 : if (index.is_integer()) {
443 : return impl::applySelect(
444 : prev_dim_result,
445 : *dim_ptr,
446 : index.integer(),
447 : real_dim,
448 : original_tensor_device,
449 : prev_dim_result_sizes);
450 : } else if (index.is_slice()) {
451 : Tensor result = impl::applySlice(
452 : prev_dim_result,
453 : *dim_ptr,
454 : index.slice().start(),
455 : index.slice().stop(),
456 : index.slice().step(),
457 : /*disable_slice_optimization=*/disable_slice_optimization,
458 : original_tensor_device,
459 : prev_dim_result_sizes);
460 : (*dim_ptr)++;
461 : return result;
462 : } else if (index.is_ellipsis()) {
463 : (*dim_ptr) += original_tensor.dim() - (*specified_dims_ptr);
464 : return prev_dim_result;
465 : } else if (index.is_none()) {
466 : Tensor result = prev_dim_result.unsqueeze(*dim_ptr);
467 : (*dim_ptr)++;
468 : return result;
469 : } else if (index.is_boolean()) {
470 : Tensor result = prev_dim_result.unsqueeze(*dim_ptr);
471 : impl::recordTensorIndex(
472 : impl::boolToIndexingTensor(
473 : result, index.boolean(), original_tensor_device),
474 : outIndices,
475 : dim_ptr);
476 : return result;
477 : } else if (index.is_tensor()) {
478 : Tensor result = prev_dim_result;
479 : const Tensor& tensor = index.tensor();
480 : auto scalar_type = tensor.scalar_type();
481 : if (tensor.dim() == 0 &&
482 : at::isIntegralType(scalar_type, /*includeBool=*/true)) {
483 : if (scalar_type != at::kByte && scalar_type != at::kBool) {
484 : result = impl::applySelect(
485 : result,
486 : *dim_ptr,
487 : tensor.item<int64_t>(),
488 : real_dim,
489 : original_tensor_device,
490 : prev_dim_result_sizes);
491 : } else {
492 : result = result.unsqueeze(*dim_ptr);
493 : if (scalar_type == at::kBool) {
494 : impl::recordTensorIndex(
495 : impl::boolToIndexingTensor(
496 : result, tensor.item<bool>() != 0, original_tensor_device),
497 : outIndices,
498 : dim_ptr);
499 : } else {
500 : impl::recordTensorIndex(
501 : impl::boolToIndexingTensor(
502 : result, tensor.item<uint8_t>() != 0, original_tensor_device),
503 : outIndices,
504 : dim_ptr);
505 : }
506 : }
507 : } else {
508 : impl::recordTensorIndex(tensor, outIndices, dim_ptr);
509 : }
510 : return result;
511 : } else {
512 : TORCH_INTERNAL_ASSERT(false, "Invalid TensorIndex type");
513 : }
514 : }
515 :
516 : namespace impl {
517 : // This mirrors `applySlicing` in
518 : // torch/csrc/autograd/python_variable_indexing.cpp
519 : inline Tensor applySlicing(
520 : const Tensor& self,
521 : const ArrayRef<TensorIndex>& indices,
522 : std::vector<Tensor>& outIndices,
523 : bool disable_slice_optimization,
524 : const at::Device& self_device,
525 : const std::optional<SymIntArrayRef>& self_sizes) {
526 : int64_t dim = 0;
527 : int64_t specified_dims = impl::count_specified_dimensions(indices);
528 :
529 : // See NOTE [nested tensor size for indexing]
530 : if (self_sizes.has_value()) {
531 : TORCH_CHECK_INDEX(
532 : specified_dims <= (int64_t)self_sizes->size(),
533 : "too many indices for tensor of dimension ",
534 : (int)self_sizes->size());
535 : }
536 :
537 : Tensor result = self;
538 : for (const auto i : c10::irange(indices.size())) {
539 : auto& obj = indices[i];
540 : // See NOTE [nested tensor size for indexing]
541 : std::optional<SymIntArrayRef> result_sizes = result.is_nested()
542 : ? std::optional<SymIntArrayRef>(std::nullopt)
543 : : std::optional<SymIntArrayRef>(result.sym_sizes());
544 : result = handleDimInMultiDimIndexing(
545 : /*prev_dim_result=*/result,
546 : /*original_tensor=*/self,
547 : /*index=*/obj,
548 : /*dim_ptr=*/&dim,
549 : /*specified_dims_ptr=*/&specified_dims,
550 : /*real_dim=*/static_cast<int64_t>(i),
551 : /*outIndices=*/outIndices,
552 : /*disable_slice_optimization=*/disable_slice_optimization,
553 : /*original_tensor_device=*/self_device,
554 : /*prev_dim_result_sizes=*/result_sizes);
555 : }
556 : return result;
557 : }
558 : } // namespace impl
559 :
560 : inline Tensor dispatch_index(
561 : const Tensor& self,
562 : std::vector<Tensor>&& indices) {
563 : return self.index(impl::typeConvertIndices(self, std::move(indices)));
564 : }
565 :
566 : inline Tensor dispatch_index_put_(
567 : Tensor& self,
568 : std::vector<Tensor>&& indices,
569 : const Tensor& value) {
570 : return self.index_put_(
571 : impl::typeConvertIndices(self, std::move(indices)), value);
572 : }
573 :
574 : // NOTE [ Setting `disable_slice_optimization` when calling C++ tensor indexing
575 : // functions from Python ]
576 : //
577 : // Question: When should we set `disable_slice_optimization` to `true` when
578 : // calling C++ tensor indexing functions from Python indexing code?
579 : //
580 : // Answer: First, what "slice optimization" means: for a slicing expression
581 : // such as `x[0:5, 0]`, where the sliced tensor has size 5 in dimension 0, we
582 : // skip dispatching the actual slice call as an optimization. However, here
583 : // are the cases where we DON'T want this optimization:
584 : //
585 : // 1. When we are doing 1-D slicing (e.g. `tensor[:]`).
586 : // Reason: we always return a shallow copy for expressions such as
587 : // `tensor[:]` / `tensor[...]` / `tensor[:, :]`. (Note that for `tensor[:,
588 : // :]`, we return an alias of `tensor` by doing the following:
589 : // ```
590 : // Tensor sliced = impl::applySlicing(self, indices, tensorIndices,
591 : //     disable_slice_optimization, self_device, self_sizes);
592 : // if (tensorIndices.empty()) {
593 : //   if (sliced.is_same(self)) {
594 : //     // ensure we return a shallow copy for things like x[...]
595 : //     sliced = at::alias(sliced);
596 : //   }
597 : //   return sliced;
598 : // }
599 : // ```)
600 : // 2. When we are doing JIT tracing.
601 : // Reason: JIT tracing needs the `self.slice(...)` call to properly trace the
602 : // slice operation.
603 :
604 : // This mirrors `THPVariable_getitem` in
605 : // torch/csrc/autograd/python_variable_indexing.cpp See NOTE [ Setting
606 : // `disable_slice_optimization` when calling C++ tensor indexing functions from
607 : // Python ]
608 : inline Tensor get_item(
609 : const Tensor& self,
610 : const ArrayRef<TensorIndex>& indices,
611 : bool disable_slice_optimization = false) {
612 : at::Device self_device = self.device();
613 : // NOTE [nested tensor size for indexing]
614 : // Nested tensors do not have sizes (yet), so for now we represent the size
615 : // as null; this may need to change once there is a better solution for
616 : // nested tensor sizes.
617 : std::optional<SymIntArrayRef> self_sizes = self.is_nested()
618 : ? std::optional<SymIntArrayRef>(std::nullopt)
619 : : std::optional<SymIntArrayRef>(self.sym_sizes());
620 :
621 : // handle simple types: integers, slices, none, ellipsis, bool
622 : if (indices.size() == 1) {
623 : const TensorIndex& index = indices[0];
624 : if (index.is_integer()) {
625 : return impl::applySelect(
626 : self, 0, index.integer(), 0, self_device, self_sizes);
627 : } else if (index.is_slice()) {
628 : return impl::applySlice(
629 : self,
630 : 0,
631 : index.slice().start(),
632 : index.slice().stop(),
633 : index.slice().step(),
634 : /*disable_slice_optimization=*/true,
635 : self_device,
636 : self_sizes);
637 : } else if (index.is_none()) {
638 : return self.unsqueeze(0);
639 : } else if (index.is_ellipsis()) {
640 : return at::alias(self);
641 : } else if (index.is_boolean()) {
642 : Tensor result = self.unsqueeze(0);
643 : return dispatch_index(
644 : result,
645 : std::vector<Tensor>{impl::boolToIndexingTensor(
646 : result, index.boolean(), self_device)});
647 : }
648 : }
649 :
650 : std::vector<Tensor> tensorIndices;
651 : Tensor sliced = impl::applySlicing(
652 : self,
653 : indices,
654 : tensorIndices,
655 : disable_slice_optimization,
656 : self_device,
657 : self_sizes);
658 : if (tensorIndices.empty()) {
659 : if (sliced.is_same(self)) {
660 : // ensure we return a shallow copy for things like x[...]
661 : sliced = at::alias(sliced);
662 : }
663 : return sliced;
664 : }
665 :
666 : // indexing by tensors ("advanced" indexing)
667 : return dispatch_index(sliced, std::move(tensorIndices));
668 : }
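// A usage sketch (assumes libtorch; `x` is an example tensor):
//
//   auto x = torch::arange(12).reshape({3, 4});
//   // Equivalent to the Python expression `x[1, 2:]`:
//   auto y = get_item(x, {1, Slice(2)});  // tensor([6, 7])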
669 :
670 : // This mirrors `THPVariable_setitem` in
671 : // torch/csrc/autograd/python_variable_indexing.cpp for "the assigned value is a
672 : // Tensor" case See NOTE [ Setting `disable_slice_optimization` when calling C++
673 : // tensor indexing functions from Python ]
674 : inline void set_item(
675 : const Tensor& self,
676 : const ArrayRef<TensorIndex>& indices,
677 : const Tensor& value,
678 : bool disable_slice_optimization = false) {
679 : at::Device self_device = self.device();
680 : SymIntArrayRef self_sizes = self.sym_sizes();
681 :
682 : // handle simple types: integers, slices, ellipsis, bool
683 : if (indices.size() == 1) {
684 : const TensorIndex& index = indices[0];
685 : if (index.is_boolean() && !index.boolean()) {
686 : // do nothing for false (technically we should check the size, but we
687 : // don't have real 0-sized shapes).
688 : return;
689 : } else if (index.is_ellipsis()) {
690 : copy_to(self, value);
691 : return;
692 : } else if (index.is_none() || (index.is_boolean() && index.boolean())) {
693 : copy_to(self.unsqueeze(0), value);
694 : return;
695 : } else if (index.is_integer()) {
696 : copy_to(
697 : impl::applySelect(
698 : self, 0, index.integer(), 0, self_device, self_sizes),
699 : value);
700 : return;
701 : } else if (index.is_slice()) {
702 : copy_to(
703 : impl::applySlice(
704 : self,
705 : 0,
706 : index.slice().start(),
707 : index.slice().stop(),
708 : index.slice().step(),
709 : /*disable_slice_optimization=*/disable_slice_optimization,
710 : self_device,
711 : self_sizes),
712 : value);
713 : return;
714 : }
715 : }
716 :
717 : std::vector<Tensor> tensorIndices;
718 : Tensor sliced = impl::applySlicing(
719 : self,
720 : indices,
721 : tensorIndices,
722 : disable_slice_optimization,
723 : self_device,
724 : self_sizes);
725 : if (tensorIndices.empty()) {
726 : copy_to(sliced, value);
727 : return;
728 : }
729 :
730 : SymIntArrayRef valueSizes = value.sym_sizes();
731 : SymIntArrayRef slicedValueSizes = slicePrefix1sSize(valueSizes);
732 : Tensor valuesSliced;
733 : if (!valueSizes.equals(slicedValueSizes)) {
734 : valuesSliced = value.view_symint(slicedValueSizes);
735 : } else {
736 : valuesSliced = value;
737 : }
738 : dispatch_index_put_(sliced, std::move(tensorIndices), valuesSliced);
739 : return;
740 : }
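// A usage sketch (assumes libtorch; `x` is an example tensor):
//
//   auto x = torch::zeros({3, 4});
//   // Equivalent to the Python statement `x[0, :] = torch.ones(4)`:
//   set_item(x, {0, Slice()}, torch::ones({4}));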
741 :
742 : } // namespace at::indexing