Week 9 - PyTorch Internals: Tensor, Dispatcher, ATen

9.1 Conceptual Core

  • PyTorch is a layered system:
      • Python frontend: the torch.* namespace; what users write.
      • Dispatcher: routes ops to backend implementations based on device, dtype, layout, autograd state, and other "keys."
      • ATen: the C++ tensor library. Each op (add, matmul, softmax) has device-specific implementations (CPU, CUDA, MPS, XPU).
      • Backends: cuBLAS, cuDNN, oneDNN, custom kernels.
  • Every Python tensor op is, fundamentally, a dispatcher call: a + b → torch.add(a, b) → aten::add → CPU/CUDA add kernel. Understanding this is the foundation for the rest of the month.
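The routing idea can be sketched in a few lines of plain Python. This is a toy table, not PyTorch's real C++ dispatcher; all names here are illustrative:

```python
# Toy dispatch table: (op name, dispatch key) -> kernel.
# The real dispatcher lives in C++ (c10); this only mimics the idea.
kernels = {}

def register(op, key, fn):
    kernels[(op, key)] = fn

def dispatch(op, key, *args):
    # Pick the kernel registered for this backend key and call it.
    return kernels[(op, key)](*args)

register("aten::add", "CPU", lambda a, b: [x + y for x, y in zip(a, b)])
register("aten::add", "CUDA", lambda a, b: NotImplemented)  # stand-in

print(dispatch("aten::add", "CPU", [1, 2], [3, 4]))  # [4, 6]
```

The real dispatcher adds a priority ordering over keys (autocast, autograd, backend) and redispatching, but the table-lookup core is the same.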

9.2 Mechanical Detail

  • Read aten/src/ATen/core/dispatch/Dispatcher.h and DispatchKey.h. The DispatchKey enum names every backend, every layer (autograd, autocast, named tensors, vmap, ...).
  • Dispatch keys stack: a tensor's "key set" determines which dispatcher entries fire and in what order, for example AutocastCUDA → AutogradCUDA → CUDA: autocast casts inputs first, autograd records the graph, then the backend kernel runs.
  • Ops are registered through the torch::Library API, typically via the TORCH_LIBRARY_IMPL macro:
    TORCH_LIBRARY_IMPL(aten, CUDA, m) {
        m.impl("add.Tensor", &my_add_cuda);
    }
    
  • The Python tensor object is a thin wrapper around at::Tensor, which is a thin wrapper around c10::TensorImpl, which holds a c10::Storage and view metadata (sizes, strides, offset, dtype, device).
  • Strides are critical. A "tensor view" (transpose, slice, narrow) shares storage but rewrites strides. The dispatcher and most ops handle strided tensors transparently; some kernels require contiguous input (call tensor.contiguous()).
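The stride mechanics above can be checked with plain-Python arithmetic. This is a sketch of the TensorImpl metadata (storage + sizes/strides/offset), not the real C++ types:

```python
# A 2x3 row-major "tensor": flat storage plus (sizes, strides, offset) metadata.
storage = [0, 1, 2, 3, 4, 5]
sizes, strides, offset = (2, 3), (3, 1), 0

def item(storage, strides, offset, index):
    # Flat position = offset + sum(index[i] * strides[i])
    return storage[offset + sum(i * s for i, s in zip(index, strides))]

# A transpose view shares storage and only swaps the metadata.
t_sizes, t_strides = (3, 2), (1, 3)

assert item(storage, strides, 0, (1, 2)) == 5    # element [1][2]
assert item(storage, t_strides, 0, (2, 1)) == 5  # same storage slot, transposed index
```

This is why a transpose is free (metadata rewrite only) and why kernels that assume unit stride in the last dimension need a contiguous copy first.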

9.3 Lab-"Trace an Op"

  1. From Python, run a + b for two CUDA tensors. Use TORCH_SHOW_DISPATCH_TRACE=1 (or torch._C._dispatch_print_registrations()) to see the dispatcher's path.
  2. Read aten/src/ATen/native/cuda/BinaryOps.cu and find the actual CUDA kernel for add.
  3. Trace torch.matmul(a, b) similarly. Note that for BF16 it routes to cuBLAS.
  4. Document the call chain in TRACE.md.
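One way to kick off step 1 from the command line (assumes a local PyTorch install; note that some release builds compile the dispatch trace out, in which case nothing prints and torch._C._dispatch_print_registrations() is the fallback):

```shell
# Print every dispatcher hop to stderr while evaluating a + b on CPU
# (swap in device="cuda" on a GPU machine for the lab proper).
TORCH_SHOW_DISPATCH_TRACE=1 python - <<'EOF'
import torch
a = torch.randn(4, device="cpu")
b = torch.randn(4, device="cpu")
c = a + b   # look for aten::add.Tensor in the trace output
EOF
```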

9.4 Idiomatic & Diagnostic Drill

  • torch.profiler.profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) with record_shapes=True and with_stack=True. Read the table; identify any op spending more than 5% of total time.
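A minimal harness for this drill (CPU-only here so it runs without a GPU; add ProfilerActivity.CUDA to the activities list on a GPU box):

```python
import torch
from torch.profiler import profile, ProfilerActivity

a, b = torch.randn(512, 512), torch.randn(512, 512)

# record_shapes attributes time per input shape; with_stack maps ops
# back to Python source lines (extra overhead, keep it off in production).
with profile(activities=[ProfilerActivity.CPU],
             record_shapes=True, with_stack=True) as prof:
    for _ in range(10):
        (a @ b).relu_()

# Sort by self CPU time and scan for any op above ~5% of the total.
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
```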

9.5 Production Slice

  • Add torch.cuda.synchronize() discipline: every benchmark must sync before timing. CUDA is asynchronous; without sync, you'll measure queue insertion, not execution.
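The pitfall is general to any asynchronous queue and can be demonstrated without a GPU. Below, a single-worker thread pool stands in for a CUDA stream, and future.result() plays the role of torch.cuda.synchronize():

```python
import time
from concurrent.futures import ThreadPoolExecutor

def kernel():
    time.sleep(0.05)  # stand-in for real GPU work

pool = ThreadPoolExecutor(max_workers=1)  # one worker ~ one CUDA stream

t0 = time.perf_counter()
fut = pool.submit(kernel)           # "launch": returns almost immediately
t_launch = time.perf_counter() - t0

fut.result()                        # "synchronize": block until work finishes
t_synced = time.perf_counter() - t0

# Timing only the launch measures queue insertion, not execution.
assert t_launch < t_synced
print(f"launch-only: {t_launch*1e3:.3f} ms, after sync: {t_synced*1e3:.1f} ms")
```

The same shape applies to CUDA benchmarks: warm up, call torch.cuda.synchronize(), start the timer, run the op, synchronize again, stop the timer.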
