Line data Source code
1 : #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) 2 : #pragma once 3 : 4 : #include <c10/core/impl/LocalDispatchKeySet.h> 5 : #include <c10/macros/Export.h> 6 : #include <c10/macros/Macros.h> 7 : 8 : // NOTE [Tracing Mode Switches] 9 : // 10 : // Historically, tracing function was controlled by two switches: 11 : // 12 : // - `AutoDispatchBelowADInplaceOrView` guard 13 : // 14 : // Tracing function used to be script-generated inside `VariableType_*.cpp` 15 : // kernels, sharing the same `Autograd` dispatch key with autograd function. 16 : // Therefore, before tracing function was moved out of VariableType, 17 : // `AutoDispatchBelowADInplaceOrView` guard can also disable tracing as a 18 : // side effect of disabling `Autograd` dispatching. 19 : // 20 : // - `setTracingState()` API in `torch/csrc/jit/frontend/tracer.h` 21 : // 22 : // It stores tracing data in a `TracingState` object in TLS. If the 23 : // `TracingState` object in TLS is `null`, then tracing is paused. 24 : // 25 : // The `TracingState` object is created in `tracer::trace()` - the main 26 : // entrance of tracing function. It's temporarily set to `null` inside 27 : // generated VariableType (now TraceType) to bypass tracing for intermediate 28 : // ops (ops being called by other ops). After the intermediate op call 29 : // finishes it's set back to the original `TracingState` object. 30 : // 31 : // The `TracingState` object in TLS can also be read/written via its Python 32 : // binding in `python_tracer.cpp`, and `get/setTracingState()` C++ APIs, 33 : // which are also exposed as `TORCH_API`. 34 : // 35 : // Two new switches were introduced since tracing function was moved out of 36 : // VariableType: 37 : // 38 : // - `tracer::impl::set_dispatch_enabled()` API 39 : // 40 : // Unlike the special `Autograd` dispatch key which is included in dispatch 41 : // key set by default, `Tracer` dispatch key is off by default. The 42 : // dispatching switch can be toggled via this new API. 
43 : // 44 : // - `tracer::impl::NoTracerDispatchMode` guard 45 : // 46 : // It's used to cover the old semantics of `AutoDispatchBelowADInplaceOrView` 47 : // after tracing was moved out of VariableType. 48 : // 49 : // Before tracing function was moved out of VariableType, tracing was enabled 50 : // when the following conditions were satisfied: 51 : // 52 : // 1) `TracingState` object in TLS != null; 53 : // - Either inside the execution scope of `tracer::trace()`, or 54 : // - Eagerly called `setTracingState()` with a non-null object. 55 : // 2) Not inside `AutoDispatchBelowADInplaceOrView` scope; 56 : // 57 : // After: 58 : // 59 : // 1) `TracingState` object in TLS != null; 60 : // 2) Has called `tracer::impl::set_dispatch_enabled(true)`; 61 : // 3) Not inside `tracer::impl::NoTracerDispatchMode` scope; 62 : // 63 : // [TODOs] 64 : // 65 : // - `setTracingState()` vs. `tracer::impl::set_dispatch_enabled()` 66 : // 67 : // Currently `set_dispatch_enabled()` is set/unset inside `setTracingState()` 68 : // to keep the semantics exactly the same as before - it's confusing to keep 69 : // both switches, though. We should consider simplifying/limiting the exposed 70 : // `setTracingState()` Python/C++ APIs (and other APIs calling it) so that 71 : // these two can be unified. 72 : // 73 : // - `AutoDispatchBelowADInplaceOrView` vs. 74 : // `tracer::impl::NoTracerDispatchMode` 75 : // 76 : // We don't need to always set both guards together to keep semantics 77 : // unchanged. For the following use cases of `AutoDispatchBelowADInplaceOrView` 78 : // we don't need to set the new tracer guard: 79 : // 80 : // * Script-generated VariableType kernels. The guard is not necessary as 81 : // tracing is already disabled explicitly by `setTracingState(null)` in 82 : // generated TraceType kernels - we could keep it as is or use the new guard 83 : // instead. 84 : // 85 : // * Custom ops. These will be handled by the fallback kernel for `Tracer`. 
86 : // 87 : // * Functions that are not likely to be called in a tracing context (no Python 88 : // binding / not an operator), e.g.: all mobile forward() wrappers, test 89 : // binaries, etc. 90 : // 91 : // * Where new threads are spawned, e.g.: ATen/native/ConvolutionMM2d.cpp. 92 : // It's not necessary as tracing is off by default. 93 : // 94 : // For the rest of the cases we might need to have both: 95 : // 96 : // * Functions that might be reachable from eager-mode Python (especially 97 : // factory methods), e.g.: 98 : // `internal_new_from_data()` in `torch/csrc/utils/tensor_new.cpp`. 99 : // Without the new guard it will add `aten::empty` to the traced graph. 100 : // 101 : // * Some manually maintained functions, e.g.: 102 : // `torch/csrc/autograd/VariableTypeManual.cpp`. 103 : // Set the new guard if it's not obvious whether `setTracingState(null)` 104 : // has been called before it reaches the `AutoDispatchBelowADInplaceOrView` 105 : // guard. 106 : // 107 : // We might need to tweak the usage of the new guard to optimize/fix things. 108 : // It should only affect the correctness of the tracing function, because the 109 : // guard is essentially a no-op when the master `setTracingState()` switch is 110 : // off. 111 : 112 : // TODO: move this from `at::` to `torch::jit::` after 113 : // `aten/src/ATen/cpp_custom_type_hack.h` is removed. 
114 : 115 : namespace at::tracer::impl { 116 : 117 : inline bool is_dispatch_enabled() { 118 : return c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Tracer) && 119 : !c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Tracer); 120 : } 121 : 122 : inline void set_dispatch_enabled(bool enabled) { 123 : TORCH_INTERNAL_ASSERT( 124 : !c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Tracer), 125 : "Cannot enable tracing within the scope of NoTracerDispatchMode!"); 126 : c10::impl::tls_set_dispatch_key_included(at::DispatchKey::Tracer, enabled); 127 : } 128 : 129 52 : struct NoTracerDispatchMode { 130 : c10::impl::ExcludeDispatchKeyGuard guard_{at::DispatchKey::Tracer}; 131 : }; 132 : 133 : } // namespace at::tracer::impl 134 : 135 : #else 136 : #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 137 : #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)