Line data Source code
1 : #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2 : #pragma once
3 :
4 : #include <ATen/ExpandUtils.h>
5 : #include <ATen/ScalarOps.h>
6 : #include <ATen/core/Tensor.h>
7 : #include <ATen/core/TensorBody.h>
8 : #include <c10/core/SymInt.h>
9 : #include <c10/util/irange.h>
10 : #include <optional>
11 :
12 : #ifndef AT_PER_OPERATOR_HEADERS
13 : #include <ATen/Functions.h>
14 : #include <ATen/NativeFunctions.h>
15 : #else
16 : #include <ATen/ops/alias.h>
17 : #include <ATen/ops/empty.h>
18 : #include <ATen/ops/scalar_tensor.h>
19 : #include <ATen/ops/zeros.h>
20 : #endif
21 :
22 : #include <ATen/core/List.h>
23 :
24 : #include <utility>
25 :
26 : namespace at::indexing {
27 :
// Sentinel bounds for slice endpoints. INDEX_MIN is taken from
// c10::SymInt::min_representable_int(); INDEX_MAX is derived from it as
// -(INDEX_MIN + 1).
28 : constexpr int64_t INDEX_MIN = c10::SymInt::min_representable_int();
29 : constexpr int64_t INDEX_MAX = -(INDEX_MIN + 1);
30 :
// Tag identifying which kind of index a `TensorIndex` holds.
31 : enum class TensorIndexType { None, Ellipsis, SymInt, Boolean, Slice, Tensor };
32 :
// C++ spelling of the Python `None` index; it is simply `std::nullopt`.
33 : constexpr std::nullopt_t None = std::nullopt;
34 :
// Empty tag type used to represent an ellipsis (`...`) index.
35 : struct TORCH_API EllipsisIndexType final {
36 : EllipsisIndexType() = default;
37 : };
// The shared ellipsis tag object; only declared here, defined elsewhere.
38 : TORCH_API extern const EllipsisIndexType Ellipsis;
39 :
40 38 : struct TORCH_API Slice final {
41 : public:
42 116 : Slice(
43 : std::optional<c10::SymInt> start_index = std::nullopt,
44 : std::optional<c10::SymInt> stop_index = std::nullopt,
45 116 : std::optional<c10::SymInt> step_index = std::nullopt) {
46 116 : if (!step_index.has_value()) {
47 232 : step_ = c10::SymInt(1);
48 : } else {
49 0 : step_ = std::move(step_index).value();
50 : }
51 :
52 116 : TORCH_CHECK_VALUE(
53 : step_.sym_ne(0).expect_true(__FILE__, __LINE__),
54 : "slice step cannot be zero");
55 :
56 116 : if (!start_index.has_value()) {
57 348 : start_ = c10::SymInt(step_ < 0 ? INDEX_MAX : 0);
58 : } else {
59 0 : start_ = std::move(start_index).value();
60 : }
61 :
62 116 : if (!stop_index.has_value()) {
63 348 : stop_ = c10::SymInt(step_ < 0 ? INDEX_MIN : INDEX_MAX);
64 : } else {
65 0 : stop_ = std::move(stop_index).value();
66 : }
67 116 : }
68 :
69 : inline c10::SymInt start() const {
70 : return start_;
71 : }
72 :
73 : inline c10::SymInt stop() const {
74 : return stop_;
75 : }
76 :
77 : inline c10::SymInt step() const {
78 : return step_;
79 : }
80 :
81 : private:
82 : c10::SymInt start_;
83 : c10::SymInt stop_;
84 : c10::SymInt step_;
85 : };
86 :
// Stream insertion for `Slice`; only declared in this header.
87 : TORCH_API std::ostream& operator<<(std::ostream& stream, const Slice& slice);
88 :
89 : // `at::indexing::TensorIndex` is used for converting C++ tensor indices such as
90 : // `{None, "...", Ellipsis, 0, true, Slice(1, None, 2), torch::tensor({1, 2})}`
91 : // into its equivalent `std::vector<TensorIndex>`, so that further tensor
92 : // indexing operations can be performed using the supplied indices.
93 : //
94 : // There is one-to-one correspondence between Python and C++ tensor index types:
95 : // Python | C++
96 : // -----------------------------------------------------
97 : // `None` | `at::indexing::None`
98 : // `Ellipsis` | `at::indexing::Ellipsis`
99 : // `...` | `"..."`
100 : // `123` | `123`
101 : // `True` / `False` | `true` / `false`
102 : // `:` | `Slice()` / `Slice(None, None)`
103 : // `::` | `Slice()` / `Slice(None, None, None)`
104 : // `1:` | `Slice(1, None)`
105 : // `1::` | `Slice(1, None, None)`
106 : // `:3` | `Slice(None, 3)`
107 : // `:3:` | `Slice(None, 3, None)`
108 : // `::2` | `Slice(None, None, 2)`
109 : // `1:3` | `Slice(1, 3)`
110 : // `1::2` | `Slice(1, None, 2)`
111 : // `:3:2` | `Slice(None, 3, 2)`
112 : // `1:3:2` | `Slice(1, 3, 2)`
113 : // `torch.tensor([1, 2])` | `torch::tensor({1, 2})`
114 : struct TORCH_API TensorIndex final {
115 : // Case 1: `at::indexing::None`
116 : TensorIndex(std::nullopt_t /*unused*/) : type_(TensorIndexType::None) {}
117 :
118 : // Case 2: "..." / `at::indexing::Ellipsis`
119 : TensorIndex(at::indexing::EllipsisIndexType /*unused*/)
120 : : type_(TensorIndexType::Ellipsis) {}
121 : TensorIndex(const char* str) : TensorIndex(at::indexing::Ellipsis) {
122 : TORCH_CHECK_VALUE(
123 : strcmp(str, "...") == 0,
124 : "Expected \"...\" to represent an ellipsis index, but got \"",
125 : str,
126 : "\"");
127 : }
128 :
129 : // Case 3: (Sym) Integer value
130 38 : TensorIndex(SymInt integer)
131 76 : : integer_(std::move(integer)), type_(TensorIndexType::SymInt) {}
132 22 : TensorIndex(int64_t integer) : TensorIndex(SymInt(integer)) {}
133 54 : TensorIndex(int integer) : TensorIndex(SymInt(integer)) {}
134 :
135 : // Case 4: Boolean value
136 : template <class T, class = std::enable_if_t<std::is_same_v<bool, T>>>
137 : TensorIndex(T boolean) : boolean_(boolean), type_(TensorIndexType::Boolean) {}
138 :
139 : // Case 5: Slice represented in `at::indexing::Slice` form
140 38 : TensorIndex(Slice slice)
141 38 : : slice_(std::move(slice)), type_(TensorIndexType::Slice) {}
142 :
143 : // Case 6: Tensor value
144 : TensorIndex(Tensor tensor)
145 : : tensor_(std::move(tensor)), type_(TensorIndexType::Tensor) {}
146 :
147 : inline bool is_none() const {
148 : return type_ == TensorIndexType::None;
149 : }
150 :
151 : inline bool is_ellipsis() const {
152 : return type_ == TensorIndexType::Ellipsis;
153 : }
154 :
155 : inline bool is_integer() const {
156 : return type_ == TensorIndexType::SymInt;
157 : }
158 :
159 : inline SymInt integer() const {
160 : return integer_;
161 : }
162 :
163 : inline bool is_boolean() const {
164 : return type_ == TensorIndexType::Boolean;
165 : }
166 :
167 : inline bool boolean() const {
168 : return boolean_;
169 : }
170 :
171 : inline bool is_slice() const {
172 : return type_ == TensorIndexType::Slice;
173 : }
174 :
175 : inline const Slice& slice() const {
176 : return slice_;
177 : }
178 :
179 : inline bool is_tensor() const {
180 : return type_ == TensorIndexType::Tensor;
181 : }
182 :
183 : inline const Tensor& tensor() const {
184 : return tensor_;
185 : }
186 :
187 : private:
188 : SymInt integer_ = 0;
189 : bool boolean_ = false;
190 : Slice slice_;
191 : Tensor tensor_;
192 : TensorIndexType type_;
193 : };
194 :
// Stream insertion for a single `TensorIndex`; only declared in this header.
195 : TORCH_API std::ostream& operator<<(
196 : std::ostream& stream,
197 : const TensorIndex& tensor_index);
// Stream insertion for a list of `TensorIndex`; only declared in this header.
198 : TORCH_API std::ostream& operator<<(
199 : std::ostream& stream,
200 : const std::vector<TensorIndex>& tensor_indices);
201 :
202 : namespace impl {
203 : inline Tensor applySlice(
204 : const Tensor& self,
205 : int64_t dim,
206 : c10::SymInt start,
207 : c10::SymInt stop,
208 : c10::SymInt step,
209 : bool disable_slice_optimization,
210 : const at::Device& self_device,
211 : const std::optional<SymIntArrayRef>& self_sizes) {
212 : // TODO: implement negative step
213 : TORCH_CHECK_VALUE(
214 : step.sym_gt(0).expect_true(__FILE__, __LINE__),
215 : "step must be greater than zero");
216 :
217 : // See NOTE [nested tensor size for indexing]
218 : if (self_sizes.has_value() && !self_sizes.value().empty()) {
219 : // Skip this optimization if we are tracing, as the trace may be polymorphic
220 : // over the shape of the `self` tensor, and we still want to record
221 : // the slice.
222 : SymInt length = (self_device == at::kCPU || self_device == at::kCUDA)
223 : ? (*self_sizes)[dim]
224 : : self.sym_size(dim);
225 : if (!disable_slice_optimization &&
226 : TORCH_STATICALLY_KNOWN_TRUE(start.sym_eq(0)) &&
227 : TORCH_STATICALLY_KNOWN_TRUE(length.sym_le(stop)) && step == 1) {
228 : return self;
229 : }
230 : }
231 : return self.slice_symint(
232 : dim, std::move(start), std::move(stop), std::move(step));
233 : }
234 :
235 : inline Tensor applySelect(
236 : const Tensor& self,
237 : int64_t dim,
238 : SymInt index,
239 : int64_t real_dim,
240 : const at::Device& /*self_device*/,
241 : const std::optional<SymIntArrayRef>& self_sizes) {
242 : // See NOTE [nested tensor size for indexing]
243 : if (self_sizes.has_value()) {
244 : auto maybe_index = index.maybe_as_int();
245 : if (maybe_index.has_value()) {
246 : TORCH_CHECK_INDEX(
247 : !(maybe_index.value() == 0 && dim == 0 && self_sizes->empty()),
248 : "invalid index of a 0-dim tensor. ",
249 : "Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number");
250 : }
251 :
252 : auto size = (*self_sizes)[dim];
253 : // Note: `size >= -index` is not equivalent to `size > -1 - index` if index
254 : // is INT64_MIN For std::numeric_limits<int64_t>::min() result of unary
255 : // minus is undefined by the standard but in practice is equal to self. On
256 : // the other hand, indexing wrapping is valid for all negative int64_t
257 : // values, as x[INT64_MIN] is the same as x[INT64_MAX]
258 : TORCH_CHECK_INDEX(
259 : size.sym_gt(-1 - index)
260 : .sym_and(size.sym_gt(index))
261 : .expect_true(__FILE__, __LINE__),
262 : "index ",
263 : index,
264 : " is out of bounds for dimension ",
265 : real_dim,
266 : " with size ",
267 : size);
268 : }
269 :
270 : // if the index is negative, do not normalize it because that would fix the
271 : // index on the current tensor size in the tracer. aten::select also works on
272 : // negative indices
273 : return self.select_symint(dim, std::move(index));
274 : }
275 :
276 : inline Tensor boolToIndexingTensorCPUOrCUDA(const Tensor& self, bool value) {
277 : // booleans add a dimension of size 1. true indexes this dimension as if 0:,
278 : // false as empty.
279 : if (value) {
280 : return at::empty({1}, self.options().dtype(kLong)).fill_(0.);
281 : } else {
282 : return at::empty({0}, self.options().dtype(kLong));
283 : }
284 : }
285 :
286 : inline Tensor boolToIndexingTensorNonNativeDeviceType(
287 : const Tensor& self,
288 : bool value) {
289 : // booleans add a dimension of size 1. true indexes this dimension as if 0:,
290 : // false as empty.
291 : if (value) {
292 : return at::zeros({1}, self.options().dtype(kLong));
293 : } else {
294 : return at::empty({0}, self.options().dtype(kLong));
295 : }
296 : }
297 :
298 : inline Tensor boolToIndexingTensor(
299 : const Tensor& self,
300 : bool value,
301 : const at::Device& self_device) {
302 : if (self_device == at::kCPU || self_device == at::kCUDA) {
303 : return boolToIndexingTensorCPUOrCUDA(self, value);
304 : } else {
305 : return boolToIndexingTensorNonNativeDeviceType(self, value);
306 : }
307 : }
308 :
309 : inline Tensor scalarToTensorNonNativeDeviceType(
310 : const Scalar& v,
311 : const TensorOptions& options) {
312 : return at::scalar_tensor(v, options);
313 : }
314 :
315 : inline void recordTensorIndex(
316 : const Tensor& tensor,
317 : std::vector<Tensor>& outIndices,
318 : int64_t* dim_ptr) {
319 : if (outIndices.empty()) {
320 : outIndices.resize(*dim_ptr + 1);
321 : outIndices[*dim_ptr] = tensor;
322 : } else {
323 : outIndices.push_back(tensor);
324 : }
325 : if (tensor.scalar_type() == kByte || tensor.scalar_type() == kBool) {
326 : *dim_ptr += tensor.dim();
327 : } else {
328 : *dim_ptr += 1;
329 : }
330 : }
331 :
332 : inline c10::List<::std::optional<Tensor>> typeConvertIndices(
333 : const Tensor& /*self*/,
334 : std::vector<Tensor>&& indices) {
335 : c10::List<::std::optional<Tensor>> converted_inds;
336 : converted_inds.reserve(indices.size());
337 : for (auto&& i : std::move(indices)) {
338 : converted_inds.push_back(std::move(i));
339 : }
340 : return converted_inds;
341 : }
342 :
343 : // NOTE: Why do we mirror instead of replace the `count_specified_dimensions`
344 : // function in torch/csrc/autograd/python_variable_indexing.cpp? It's because
345 : // `count_specified_dimensions` is on the hot path of Python tensor multi-dim
346 : // indexing (i.e. it's called by `applySlicing` which is called by
347 : // `THPVariable_getitem` / `THPVariable_setitem` when handling indexing of more
348 : // than one dimension). If we were to merge the Python/C++
349 : // `count_specified_dimensions` function, on the Python side we would have to
350 : // construct a `std::vector` container to be consumed by the C++
351 : // `count_specified_dimensions` function, which adds 100s of nanoseconds
352 : // overhead and is undesirable.
353 : inline int64_t count_specified_dimensions(
354 : const ArrayRef<TensorIndex>& indices) {
355 : // Count the number of indexed dimensions (everything but ellipsis and None)
356 : int64_t count = 0;
357 : for (auto& obj : indices) {
358 : if (obj.is_tensor()) {
359 : auto& tensor = obj.tensor();
360 : if (tensor.scalar_type() == kByte || tensor.scalar_type() == kBool) {
361 : count += tensor.dim();
362 : } else {
363 : count++;
364 : }
365 : } else if (!obj.is_none() && !obj.is_ellipsis() && !obj.is_boolean()) {
366 : count++;
367 : }
368 : }
369 : return count;
370 : }
371 : } // namespace impl
372 :
373 : // NOTE: Many functions below are only for consumption from Python indexing
374 : // implementation, they include:
375 : //
376 : // - `Tensor scalarToTensor(...)`
377 : // - `SymIntArrayRef slicePrefix1sSize(...)`
378 : // - `void copy_to(...)`
379 : // - `Tensor handleDimInMultiDimIndexing(...)`
380 : // - `Tensor dispatch_index(...)`
381 : // - `Tensor dispatch_index_put_(...)`
382 : // - `Tensor get_item(...)`
383 : // - `void set_item(...)`
384 : //
385 : // The rest of the functions are in `at::indexing::impl` namespace, signifying
386 : // that they shouldn't be used from Python indexing implementation.
387 : inline Tensor scalarToTensor(
388 : const Scalar& v,
389 : const TensorOptions& options,
390 : const at::Device& self_device) {
391 : if (self_device == at::kCPU && !v.isSymbolic()) {
392 : return at::detail::scalar_tensor_static(
393 : v,
394 : // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
395 : options.dtype_opt()->toScalarType(),
396 : self_device);
397 : } else {
398 : return impl::scalarToTensorNonNativeDeviceType(v, options);
399 : }
400 : }
401 :
402 : // To match numpy semantics:
403 : // As a special case for backwards compatibility,
404 : // strip away unit dimensions from the left of 'src'
405 : inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) {
406 : size_t first_non1_src = sizes.size();
407 : for (const auto i : c10::irange(sizes.size())) {
408 : // Unbacked SymInt has different behavior, but this is sound because
409 : // failing to slice will only ever cause an error, not divergent
410 : // behavior
411 : if (!sizes[i].has_hint() || sizes[i] != 1) {
412 : first_non1_src = i;
413 : break;
414 : }
415 : }
416 :
417 : return sizes.slice(first_non1_src);
418 : }
419 :
420 : inline void copy_to(const Tensor& dst, const Tensor& src) {
421 : if (dst.sym_sizes().equals(src.sym_sizes())) {
422 : // A shortcut to avoid generating hard-coded constant sizes during tracing.
423 : // This is not a perfect solution: when src & dst have different shapes,
424 : // constants will still appear. Users can workaround that case by
425 : // dst[index..] = src.reshape(..)
426 : dst.copy_(src);
427 : return;
428 : } else if (src.dim() == 0 && src.device().type() == at::kCPU) {
429 : dst.fill_(src);
430 : return;
431 : }
432 : auto src_view = src.view_symint(slicePrefix1sSize(src.sym_sizes()));
433 : c10::MaybeOwned<Tensor> b_src = expand_inplace(dst, src_view, "setitem");
434 : dst.copy_(*b_src);
435 : }
436 :
437 : // See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
438 : // indexing functions from Python ]
439 : inline Tensor handleDimInMultiDimIndexing(
440 : const Tensor& prev_dim_result,
441 : const Tensor& original_tensor,
442 : const TensorIndex& index,
443 : int64_t* dim_ptr,
444 : int64_t* specified_dims_ptr,
445 : int64_t real_dim,
446 : std::vector<Tensor>& outIndices,
447 : bool disable_slice_optimization,
448 : const at::Device& original_tensor_device,
449 : const std::optional<SymIntArrayRef>& prev_dim_result_sizes) {
450 : if (index.is_integer()) {
451 : return impl::applySelect(
452 : prev_dim_result,
453 : *dim_ptr,
454 : index.integer(),
455 : real_dim,
456 : original_tensor_device,
457 : prev_dim_result_sizes);
458 : } else if (index.is_slice()) {
459 : Tensor result = impl::applySlice(
460 : prev_dim_result,
461 : *dim_ptr,
462 : index.slice().start(),
463 : index.slice().stop(),
464 : index.slice().step(),
465 : /*disable_slice_optimization=*/disable_slice_optimization,
466 : original_tensor_device,
467 : prev_dim_result_sizes);
468 : (*dim_ptr)++;
469 : if (!outIndices.empty()) {
470 : outIndices.resize(outIndices.size() + 1);
471 : }
472 : return result;
473 : } else if (index.is_ellipsis()) {
474 : auto ellipsis_ndims = original_tensor.dim() - *specified_dims_ptr;
475 : (*dim_ptr) += ellipsis_ndims;
476 : if (!outIndices.empty()) {
477 : outIndices.resize(outIndices.size() + ellipsis_ndims);
478 : }
479 : return prev_dim_result;
480 : } else if (index.is_none()) {
481 : Tensor result = prev_dim_result.unsqueeze(*dim_ptr);
482 : (*dim_ptr)++;
483 : if (!outIndices.empty()) {
484 : outIndices.resize(outIndices.size() + 1);
485 : }
486 : return result;
487 : } else if (index.is_boolean()) {
488 : Tensor result = prev_dim_result.unsqueeze(*dim_ptr);
489 : impl::recordTensorIndex(
490 : impl::boolToIndexingTensor(
491 : result, index.boolean(), original_tensor_device),
492 : outIndices,
493 : dim_ptr);
494 : return result;
495 : } else if (index.is_tensor()) {
496 : Tensor result = prev_dim_result;
497 : const Tensor& tensor = index.tensor();
498 : auto scalar_type = tensor.scalar_type();
499 : bool is_batched = tensor.key_set().has_any(c10::DispatchKeySet(
500 : {c10::DispatchKey::FuncTorchBatched,
501 : c10::DispatchKey::BatchedNestedTensor}));
502 : if (tensor.dim() == 0 && !is_batched &&
503 : at::isIntegralType(scalar_type, /*includeBool=*/true)) {
504 : if (scalar_type != at::kByte && scalar_type != at::kBool) {
505 : result = impl::applySelect(
506 : result,
507 : *dim_ptr,
508 : tensor.item<int64_t>(),
509 : real_dim,
510 : original_tensor_device,
511 : prev_dim_result_sizes);
512 : } else {
513 : result = result.unsqueeze(*dim_ptr);
514 : if (scalar_type == at::kBool) {
515 : impl::recordTensorIndex(
516 : impl::boolToIndexingTensor(
517 : result, tensor.item<bool>() != 0, original_tensor_device),
518 : outIndices,
519 : dim_ptr);
520 : } else {
521 : impl::recordTensorIndex(
522 : impl::boolToIndexingTensor(
523 : result, tensor.item<uint8_t>() != 0, original_tensor_device),
524 : outIndices,
525 : dim_ptr);
526 : }
527 : }
528 : } else {
529 : impl::recordTensorIndex(tensor, outIndices, dim_ptr);
530 : }
531 : return result;
532 : } else {
533 : TORCH_INTERNAL_ASSERT(false, "Invalid TensorIndex type");
534 : }
535 : }
536 :
537 : namespace impl {
538 : // This mirrors `applySlicing` in
539 : // torch/csrc/autograd/python_variable_indexing.cpp
540 : inline Tensor applySlicing(
541 : const Tensor& self,
542 : const ArrayRef<TensorIndex>& indices,
543 : std::vector<Tensor>& outIndices,
544 : bool disable_slice_optimization,
545 : const at::Device& self_device,
546 : const std::optional<SymIntArrayRef>& self_sizes) {
547 : int64_t dim = 0;
548 : int64_t specified_dims = impl::count_specified_dimensions(indices);
549 :
550 : // See NOTE [nested tensor size for indexing]
551 : if (self_sizes.has_value()) {
552 : TORCH_CHECK_INDEX(
553 : specified_dims <= (int64_t)self_sizes->size(),
554 : "too many indices for tensor of dimension ",
555 : (int)self_sizes->size());
556 : }
557 :
558 : Tensor result = self;
559 : for (const auto i : c10::irange(indices.size())) {
560 : auto& obj = indices[i];
561 : // See NOTE [nested tensor size for indexing]
562 : std::optional<SymIntArrayRef> result_sizes = result.is_nested()
563 : ? std::optional<SymIntArrayRef>(std::nullopt)
564 : : std::optional<SymIntArrayRef>(result.sym_sizes());
565 : result = handleDimInMultiDimIndexing(
566 : /*prev_dim_result=*/result,
567 : /*original_tensor=*/self,
568 : /*index=*/obj,
569 : /*dim_ptr=*/&dim,
570 : /*specified_dims_ptr=*/&specified_dims,
571 : /*real_dim=*/static_cast<int64_t>(i),
572 : /*outIndices=*/outIndices,
573 : /*disable_slice_optimization=*/disable_slice_optimization,
574 : /*original_tensor_device=*/self_device,
575 : /*prev_dim_result_sizes=*/result_sizes);
576 : }
577 : return result;
578 : }
579 : } // namespace impl
580 :
581 : inline Tensor dispatch_index(
582 : const Tensor& self,
583 : std::vector<Tensor>&& indices) {
584 : // Remove trailing null elements from indices
585 : while (!indices.empty() && !indices.back().defined()) {
586 : indices.pop_back();
587 : }
588 : return self.index(impl::typeConvertIndices(self, std::move(indices)));
589 : }
590 :
591 : inline Tensor dispatch_index_put_(
592 : Tensor& self,
593 : std::vector<Tensor>&& indices,
594 : const Tensor& value) {
595 : // Remove trailing null elements from indices
596 : while (!indices.empty() && !indices.back().defined()) {
597 : indices.pop_back();
598 : }
599 : return self.index_put_(
600 : impl::typeConvertIndices(self, std::move(indices)), value);
601 : }
602 :
603 : // NOTE [ Setting `disable_slice_optimization` when calling C++ tensor indexing
604 : // functions from Python ]
605 : //
606 : // Question: When should we set `disable_slice_optimization` to `true` when
607 : // calling C++ tensor indexing functions from Python indexing code?
608 : //
609 : // Answer: What "slice optimization" means: when we have a slicing expression
610 : // like `x[0:5, 0]`, where the sliced tensor was of size 5 in dimension 0, we
611 : // would skip dispatching the actual slice call as an optimization. However,
612 : // here are the cases where we DON'T want this optimization:
613 : //
614 : // 1. When we are doing 1-D slicing (e.g. `tensor[:]`).
615 : // Reason: we always return a shallow copy for expressions such as
616 : // `tensor[:]` / `tensor[...]` / `tensor[:, :]`. (Note that for
617 : // `tensor[:, :]`, we return an alias of `tensor` by doing the following:
618 : // ```
619 : // Tensor sliced = impl::applySlicing(self, indices, tensorIndices,
620 : //     disable_slice_optimization, self_device, self_sizes);
621 : // if (tensorIndices.empty()) {
622 : // if (sliced.is_same(self)) {
623 : // // ensure we return a shallow copy for things like x[...]
624 : // sliced = at::alias(sliced);
625 : // }
626 : // return sliced;
627 : // }
628 : // ```)
629 : // 2. When we are doing JIT tracing.
630 : // Reason: JIT tracing needs the `self.slice(...)` call to properly trace the
631 : // slice operation.
632 :
633 : // This mirrors `THPVariable_getitem` in
634 : // torch/csrc/autograd/python_variable_indexing.cpp See NOTE [ Setting
635 : // `disable_slice_optimization` when calling C++ tensor indexing functions from
636 : // Python ]
637 : inline Tensor get_item(
638 : const Tensor& self,
639 : const ArrayRef<TensorIndex>& indices,
640 : bool disable_slice_optimization = false) {
641 : at::Device self_device = self.device();
642 : // NOTE [nested tensor size for indexing]
643 : // nested tensor does not have a size (yet) so for now we represent its size
644 : // as null may need to be changed after we reach a better solution for nested
645 : // tensor size
646 : std::optional<SymIntArrayRef> self_sizes = self.is_nested()
647 : ? std::optional<SymIntArrayRef>(std::nullopt)
648 : : std::optional<SymIntArrayRef>(self.sym_sizes());
649 :
650 : // handle simple types: integers, slices, none, ellipsis, bool
651 : if (indices.size() == 1) {
652 : const TensorIndex& index = indices[0];
653 : if (index.is_integer()) {
654 : return impl::applySelect(
655 : self, 0, index.integer(), 0, self_device, self_sizes);
656 : } else if (index.is_slice()) {
657 : return impl::applySlice(
658 : self,
659 : 0,
660 : index.slice().start(),
661 : index.slice().stop(),
662 : index.slice().step(),
663 : /*disable_slice_optimization=*/true,
664 : self_device,
665 : self_sizes);
666 : } else if (index.is_none()) {
667 : return self.unsqueeze(0);
668 : } else if (index.is_ellipsis()) {
669 : return at::alias(self);
670 : } else if (index.is_boolean()) {
671 : Tensor result = self.unsqueeze(0);
672 : return dispatch_index(
673 : result,
674 : std::vector<Tensor>{impl::boolToIndexingTensor(
675 : result, index.boolean(), self_device)});
676 : }
677 : }
678 :
679 : std::vector<Tensor> tensorIndices;
680 : Tensor sliced = impl::applySlicing(
681 : self,
682 : indices,
683 : tensorIndices,
684 : disable_slice_optimization,
685 : self_device,
686 : self_sizes);
687 : if (tensorIndices.empty()) {
688 : if (sliced.is_same(self)) {
689 : // ensure we return a shallow copy for things like x[...]
690 : sliced = at::alias(sliced);
691 : }
692 : return sliced;
693 : }
694 :
695 : // indexing by tensors ("advanced" indexing)
696 : return dispatch_index(sliced, std::move(tensorIndices));
697 : }
698 :
699 : // This mirrors `THPVariable_setitem` in
700 : // torch/csrc/autograd/python_variable_indexing.cpp for "the assigned value is a
701 : // Tensor" case See NOTE [ Setting `disable_slice_optimization` when calling C++
702 : // tensor indexing functions from Python ]
703 : inline void set_item(
704 : const Tensor& self,
705 : const ArrayRef<TensorIndex>& indices,
706 : const Tensor& value,
707 : bool disable_slice_optimization = false) {
708 : at::Device self_device = self.device();
709 : SymIntArrayRef self_sizes = self.sym_sizes();
710 :
711 : // handle simple types: integers, slices, ellipsis, bool
712 : if (indices.size() == 1) {
713 : const TensorIndex& index = indices[0];
714 : if (index.is_boolean() && !index.boolean()) {
715 : // do nothing for false (technically we should check the size, but we
716 : // don't have real 0-sized shapes.
717 : return;
718 : } else if (index.is_ellipsis()) {
719 : copy_to(self, value);
720 : return;
721 : } else if (index.is_none() || (index.is_boolean() && index.boolean())) {
722 : copy_to(self.unsqueeze(0), value);
723 : return;
724 : } else if (index.is_integer()) {
725 : copy_to(
726 : impl::applySelect(
727 : self, 0, index.integer(), 0, self_device, self_sizes),
728 : value);
729 : return;
730 : } else if (index.is_slice()) {
731 : copy_to(
732 : impl::applySlice(
733 : self,
734 : 0,
735 : index.slice().start(),
736 : index.slice().stop(),
737 : index.slice().step(),
738 : /*disable_slice_optimization=*/disable_slice_optimization,
739 : self_device,
740 : self_sizes),
741 : value);
742 : return;
743 : }
744 : }
745 :
746 : std::vector<Tensor> tensorIndices;
747 : Tensor sliced = impl::applySlicing(
748 : self,
749 : indices,
750 : tensorIndices,
751 : disable_slice_optimization,
752 : self_device,
753 : self_sizes);
754 : if (tensorIndices.empty()) {
755 : copy_to(sliced, value);
756 : return;
757 : }
758 :
759 : SymIntArrayRef valueSizes = value.sym_sizes();
760 : SymIntArrayRef slicedValueSizes = slicePrefix1sSize(valueSizes);
761 : Tensor valuesSliced;
762 : if (!valueSizes.equals(slicedValueSizes)) {
763 : valuesSliced = value.view_symint(slicedValueSizes);
764 : } else {
765 : valuesSliced = value;
766 : }
767 : dispatch_index_put_(sliced, std::move(tensorIndices), valuesSliced);
768 : return;
769 : }
770 :
771 : } // namespace at::indexing
772 :
773 : #else
774 : #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
775 : #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|