LCOV - code coverage report
Current view: top level - home/runner/.local/lib/python3.10/site-packages/torch/include/ATen - TracerMode.h (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 1 1 100.0 %
Date: 2026-06-05 17:04:24 Functions: 0 0 -

          Line data    Source code
       1             : #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
       2             : #pragma once
       3             : 
       4             : #include <c10/core/impl/LocalDispatchKeySet.h>
       5             : #include <c10/macros/Export.h>
       6             : #include <c10/macros/Macros.h>
       7             : 
       8             : // NOTE [Tracing Mode Switches]
       9             : //
      10             : // Historically, tracing function was controlled by two switches:
      11             : //
      12             : // - `AutoDispatchBelowADInplaceOrView` guard
      13             : //
      14             : //    Tracing function used to be script-generated inside `VariableType_*.cpp`
      15             : //    kernels, sharing the same `Autograd` dispatch key with autograd function.
      16             : //    Therefore, before tracing function was moved out of VariableType,
      17             : //    `AutoDispatchBelowADInplaceOrView` guard can also disable tracing as a
      18             : //    side effect of disabling `Autograd` dispatching.
      19             : //
      20             : // - `setTracingState()` API in `torch/csrc/jit/frontend/tracer.h`
      21             : //
      22             : //    It stores tracing data in a `TracingState` object in TLS. If the
      23             : //    `TracingState` object in TLS is `null`, then tracing is paused.
      24             : //
      25             : //    The `TracingState` object is created in `tracer::trace()` - the main
      26             : //    entrance of tracing function. It's temporarily set to `null` inside
      27             : //    generated VariableType (now TraceType) to bypass tracing for intermediate
      28             : //    ops (ops being called by other ops). After the intermediate op call
      29             : //    finishes it's set back to the original `TracingState` object.
      30             : //
      31             : //    The `TracingState` object in TLS can also be read/written via its Python
      32             : //    binding in `python_tracer.cpp`, and `get/setTracingState()` C++ APIs,
      33             : //    which are also exposed as `TORCH_API`.
      34             : //
      35             : // Two new switches were introduced since tracing function was moved out of
      36             : // VariableType:
      37             : //
      38             : // - `tracer::impl::set_dispatch_enabled()` API
      39             : //
      40             : //    Unlike the special `Autograd` dispatch key which is included in dispatch
      41             : //    key set by default, `Tracer` dispatch key is off by default. The
      42             : //    dispatching switch can be toggled via this new API.
      43             : //
      44             : // - `tracer::impl::NoTracerDispatchMode` guard
      45             : //
      46             : //    It's used to cover the old semantics of `AutoDispatchBelowADInplaceOrView`
      47             : //    after tracing was moved out of VariableType.
      48             : //
      49             : // Before tracing function was moved out of VariableType, tracing was enabled
      50             : // when the following conditions are satisfied:
      51             : //
      52             : //    1) `TracingState` object in TLS != null;
      53             : //       - Either inside the execution scope of `tracer::trace()`, or
      54             : //       - Eagerly called `setTracingState()` with non-null object.
      55             : //    2) Not inside `AutoDispatchBelowADInplaceOrView` scope;
      56             : //
      57             : // After:
      58             : //
      59             : //    1) `TracingState` object in TLS != null;
      60             : //    2) Has called `tracer::impl::set_dispatch_enabled(true)`;
       61             : //    3) Not inside `tracer::impl::NoTracerDispatchMode` scope;
      62             : //
      63             : // [TODOs]
      64             : //
       65             : // - `setTracingState()` vs. `tracer::impl::set_dispatch_enabled()`
      66             : //
      67             : //   Currently `set_dispatch_enabled()` is set/unset inside `setTracingState()`
      68             : //   to keep the semantics exactly the same as before - it's confusing to keep
      69             : //   both switches, though. We should consider simplifying/limiting the exposed
      70             : //   `setTracingState()` Python/C++ APIs (and other APIs calling it) so that
      71             : //   these two can be unified.
      72             : //
       73             : // - `AutoDispatchBelowADInplaceOrView` vs.
       74             : //   `tracer::impl::NoTracerDispatchMode`
      75             : //
      76             : //   We don't need to always set both guards together to keep semantics
       77             : //   unchanged. For the following use cases of `AutoDispatchBelowADInplaceOrView`
       78             : //   we don't need to set the new tracer guard:
      79             : //
      80             : //   * Script-generated VariableType kernels. The guard is not necessary as
      81             : //     tracing is already disabled explicitly by `setTracingState(null)` in
      82             : //     generated TraceType kernels - we could keep it as is or use the new guard
      83             : //     instead.
      84             : //
      85             : //   * Custom ops. Will be handled by fallback kernel for `Tracer`.
      86             : //
      87             : //   * Functions that are not likely to be called in tracing context (no python
      88             : //     binding / not an operator), e.g.: all mobile forward() wrappers, test
       89             : //     binaries, etc.
      90             : //
      91             : //   * Where new threads are spawned, e.g.: ATen/native/ConvolutionMM2d.cpp.
      92             : //     It's not necessary as tracing is off by default.
      93             : //
       94             : //   For the rest of the cases we might need to have both:
      95             : //
      96             : //   * Functions that might be reachable from eager mode python (especially
      97             : //     factory methods), e.g.:
      98             : //     `internal_new_from_data()` in `torch/csrc/utils/tensor_new.cpp`.
      99             : //     Without the new guard it will add `aten::empty` to the traced graph.
     100             : //
     101             : //   * Some manually maintained functions, e.g.:
     102             : //     `torch/csrc/autograd/VariableTypeManual.cpp`.
     103             : //     Set the new guard if it's not obvious whether `setTracingState(null)`
     104             : //     has been called before it reaches the `AutoDispatchBelowADInplaceOrView`
     105             : //     guard.
     106             : //
      107             : //   We might need to tweak the usage of the new guard to optimize/fix things.
      108             : //   It should only affect the correctness of tracing function, because the
      109             : //   guard is essentially a no-op when the master `setTracingState()` switch is
     110             : //   off.
     111             : 
     112             : // TODO: move this from `at::` to `jit::torch::` after
     113             : // `aten/src/ATen/cpp_custom_type_hack.h` is removed.
     114             : 
     115             : namespace at::tracer::impl {
     116             : 
     117             : inline bool is_dispatch_enabled() {
     118             :   return c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Tracer) &&
     119             :       !c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Tracer);
     120             : }
     121             : 
     122             : inline void set_dispatch_enabled(bool enabled) {
     123             :   TORCH_INTERNAL_ASSERT(
     124             :       !c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Tracer),
     125             :       "Cannot enable tracing within the scope of NoTracerDispatchMode!");
     126             :   c10::impl::tls_set_dispatch_key_included(at::DispatchKey::Tracer, enabled);
     127             : }
     128             : 
     129          52 : struct NoTracerDispatchMode {
     130             :   c10::impl::ExcludeDispatchKeyGuard guard_{at::DispatchKey::Tracer};
     131             : };
     132             : 
     133             : } // namespace at::tracer::impl
     134             : 
     135             : #else
     136             : #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
     137             : #endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)

Generated by: LCOV version 1.16