From Harshit, 4 Months ago, written in Plain Text.
Embed
  1. #pragma once
  2.  
  3. #include <ATen/cuda/detail/CUDAGuardImpl.h>
  4. #include <c10/DeviceType.h>
  5. #include <c10/impl/InlineDeviceGuard.h>
  6. #include <c10/impl/InlineStreamGuard.h>
  7.  
  8. #include <cstddef>
  9.  
  10. namespace at { namespace cuda {
  11.  
  12. // This code is kind of boilerplatey.  See Note [Whither the DeviceGuard boilerplate]
  13.  
  14. /// A variant of DeviceGuard that is specialized for CUDA.  It accepts
  15. /// integer indices (interpreting them as CUDA devices) and is a little
  16. /// more efficient than DeviceGuard (it compiles to straight line
  17. /// cudaSetDevice/cudaGetDevice calls); however, it can only be used
  18. /// from code that links against CUDA directly.
  19. struct CUDAGuard {
  20.   /// No default constructor; see Note [Omitted default constructor from RAII]
  21.   explicit CUDAGuard() = delete;
  22.  
  23.   /// Set the current CUDA device to the passed device index.
  24.   explicit CUDAGuard(DeviceIndex device_index) : guard_(device_index) {}
  25.  
  26.   /// Sets the current CUDA device to the passed device.  Errors if the passed
  27.   /// device is not a CUDA device.
  28.   explicit CUDAGuard(Device device) : guard_(device) {}
  29.  
  30.   // Copy is not allowed
  31.   CUDAGuard(const CUDAGuard&) = delete;
  32.   CUDAGuard& operator=(const CUDAGuard&) = delete;
  33.  
  34.   // Move is not allowed (there is no uninitialized state)
  35.   CUDAGuard(CUDAGuard&& other) = delete;
  36.   CUDAGuard& operator=(CUDAGuard&& other) = delete;
  37.  
  38.   /// Sets the CUDA device to the given device.  Errors if the given device
  39.   /// is not a CUDA device.
  40.   void set_device(Device device) { guard_.set_device(device); }
  41.  
  42.   /// Sets the CUDA device to the given device.  Errors if the given device
  43.   /// is not a CUDA device.  (This method is provided for uniformity with
  44.   /// DeviceGuard).
  45.   void reset_device(Device device) { guard_.reset_device(device); }
  46.  
  47.   /// Sets the CUDA device to the given device index.
  48.   void set_index(DeviceIndex device_index) { guard_.set_index(device_index); }
  49.  
  50.   /// Returns the device that was set upon construction of the guard
  51.   Device original_device() const { return guard_.original_device(); }
  52.  
  53.   /// Returns the last device that was set via `set_device`, if any, otherwise the
  54.   /// device passed during construction.
  55.   Device current_device() const { return guard_.current_device(); }
  56.  
  57.  private:
  58.   /// The guard for the current device.
  59.   c10::impl::InlineDeviceGuard<at::cuda::impl::CUDAGuardImpl> guard_;
  60. };
  61.  
  62. /// A variant of OptionalDeviceGuard that is specialized for CUDA.  See
  63. /// CUDAGuard for when you can use this.
  64. struct OptionalCUDAGuard {
  65.   /// Create an uninitialized OptionalCUDAGuard.
  66.   explicit OptionalCUDAGuard() : guard_() {}
  67.  
  68.   /// Set the current CUDA device to the passed Device, if it is not nullopt.
  69.   explicit OptionalCUDAGuard(optional<Device> device_opt) : guard_(device_opt) {}
  70.  
  71.   /// Set the current CUDA device to the passed device index, if it is not
  72.   /// nullopt
  73.   explicit OptionalCUDAGuard(optional<DeviceIndex> device_index_opt) : guard_(device_index_opt) {}
  74.  
  75.   // Copy is not allowed
  76.   OptionalCUDAGuard(const OptionalCUDAGuard&) = delete;
  77.   OptionalCUDAGuard& operator=(const OptionalCUDAGuard&) = delete;
  78.  
  79.   // See Note [Move construction for RAII guards is tricky]
  80.   OptionalCUDAGuard(OptionalCUDAGuard&& other) = delete;
  81.  
  82.   // See Note [Move assignment for RAII guards is tricky]
  83.   OptionalCUDAGuard& operator=(OptionalCUDAGuard&& other) = delete;
  84.  
  85.   /// Sets the CUDA device to the given device, initializing the guard if it
  86.   /// is not already initialized.  Errors if the given device is not a CUDA device.
  87.   void set_device(Device device) { guard_.set_device(device); }
  88.  
  89.   /// Sets the CUDA device to the given device, initializing the guard if it is
  90.   /// not already initialized.  Errors if the given device is not a CUDA device.
  91.   /// (This method is provided for uniformity with OptionalDeviceGuard).
  92.   void reset_device(Device device) { guard_.reset_device(device); }
  93.  
  94.   /// Sets the CUDA device to the given device index, initializing the guard if
  95.   /// it is not already initialized.
  96.   void set_index(DeviceIndex device_index) { guard_.set_index(device_index); }
  97.  
  98.   /// Returns the device that was set immediately prior to initialization of the
  99.   /// guard, or nullopt if the guard is uninitialized.
  100.   optional<Device> original_device() const { return guard_.original_device(); }
  101.  
  102.   /// Returns the most recent device that was set using this device guard,
  103.   /// either from construction, or via set_device, if the guard is initialized,
  104.   /// or nullopt if the guard is uninitialized.
  105.   optional<Device> current_device() const { return guard_.current_device(); }
  106.  
  107.   /// Restore the original CUDA device, resetting this guard to uninitialized state.
  108.   void reset() { guard_.reset(); }
  109.  
  110. private:
  111.   c10::impl::InlineOptionalDeviceGuard<at::cuda::impl::CUDAGuardImpl> guard_;
  112. };
  113.  
  114. /// A variant of StreamGuard that is specialized for CUDA.  See CUDAGuard
  115. /// for when you can use this.
  116. struct CUDAStreamGuard {
  117.   /// No default constructor, see Note [Omitted default constructor from RAII]
  118.   explicit CUDAStreamGuard() = delete;
  119.  
  120.   /// Set the current CUDA device to the device associated with the passed stream,
  121.   /// and set the current CUDA stream on that device to the passed stream.
  122.   /// Errors if the Stream is not a CUDA stream.
  123.   explicit CUDAStreamGuard(Stream stream) : guard_(stream) {}
  124.  
  125.   /// Copy is disallowed
  126.   CUDAStreamGuard(const CUDAStreamGuard&) = delete;
  127.   CUDAStreamGuard& operator=(const CUDAStreamGuard&) = delete;
  128.  
  129.   /// Move is disallowed, as CUDAStreamGuard does not have an uninitialized state,
  130.   /// which is required for moves on types with nontrivial destructors.
  131.   CUDAStreamGuard(CUDAStreamGuard&& other) = delete;
  132.   CUDAStreamGuard& operator=(CUDAStreamGuard&& other) = delete;
  133.  
  134.   /// Resets the currently set stream to the original stream and
  135.   /// the currently set device to the original device.  Then,
  136.   /// set the current device to the device associated with the passed stream,
  137.   /// and set the current stream on that device to the passed stream.
  138.   /// Errors if the stream passed is not a CUDA stream.
  139.   ///
  140.   /// NOTE: this implementation may skip some stream/device setting if
  141.   /// it can prove that it is unnecessary.
  142.   ///
  143.   /// WARNING: reset_stream does NOT preserve previously set streams on
  144.   /// different devices.  If you need to set streams on multiple devices
  145.   /// on CUDA, use CUDAMultiStreamGuard instead.
  146.   void reset_stream(Stream stream) { guard_.reset_stream(stream); }
  147.  
  148.   /// Returns the CUDA stream that was set at the time the guard was constructed.
  149.   CUDAStream original_stream() const {
  150.     return CUDAStream(CUDAStream::UNCHECKED, guard_.original_stream());
  151.   }
  152.  
  153.   /// Returns the most recent CUDA stream that was set using this device guard,
  154.   /// either from construction, or via set_stream.
  155.   CUDAStream current_stream() const {
  156.     return CUDAStream(CUDAStream::UNCHECKED, guard_.current_stream());
  157.   }
  158.  
  159.   /// Returns the most recent CUDA device that was set using this device guard,
  160.   /// either from construction, or via set_device/reset_device/set_index.
  161.   Device current_device() const { return guard_.current_device(); }
  162.  
  163.   /// Returns the CUDA device that was set at the most recent reset_stream(),
  164.   /// or otherwise the device at construction time.
  165.   Device original_device() const { return guard_.original_device(); }
  166.  
  167. private:
  168.   c10::impl::InlineStreamGuard<at::cuda::impl::CUDAGuardImpl> guard_;
  169. };
  170.  
  171. /// A variant of OptionalStreamGuard that is specialized for CUDA.  See CUDAGuard
  172. /// for when you can use this.
  173. struct OptionalCUDAStreamGuard {
  174.   /// Create an uninitialized guard.
  175.   explicit OptionalCUDAStreamGuard() : guard_() {}
  176.  
  177.   /// Set the current CUDA device to the device associated with the passed stream,
  178.   /// and set the current CUDA stream on that device to the passed stream.
  179.   /// Errors if the Stream is not a CUDA stream.
  180.   explicit OptionalCUDAStreamGuard(Stream stream) : guard_(stream) {}
  181.  
  182.   /// Set the current device to the device associated with the passed stream,
  183.   /// and set the current stream on that device to the passed stream,
  184.   /// if the passed stream is not nullopt.
  185.   explicit OptionalCUDAStreamGuard(optional<Stream> stream_opt) : guard_(stream_opt) {}
  186.  
  187.   /// Copy is disallowed
  188.   OptionalCUDAStreamGuard(const OptionalCUDAStreamGuard&) = delete;
  189.   OptionalCUDAStreamGuard& operator=(const OptionalCUDAStreamGuard&) = delete;
  190.  
  191.   // See Note [Move construction for RAII guards is tricky]
  192.   OptionalCUDAStreamGuard(OptionalCUDAStreamGuard&& other) = delete;
  193.  
  194.   // See Note [Move assignment for RAII guards is tricky]
  195.   OptionalCUDAStreamGuard& operator=(OptionalCUDAStreamGuard&& other) = delete;
  196.  
  197.   /// Resets the currently set CUDA stream to the original stream and
  198.   /// the currently set device to the original device.  Then,
  199.   /// set the current device to the device associated with the passed stream,
  200.   /// and set the current stream on that device to the passed stream.
  201.   /// Initializes the guard if it was not previously initialized.
  202.   void reset_stream(Stream stream) { guard_.reset_stream(stream); }
  203.  
  204.   /// Returns the CUDA stream that was set at the time the guard was most recently
  205.   /// initialized, or nullopt if the guard is uninitialized.
  206.   optional<CUDAStream> original_stream() const {
  207.     auto r = guard_.original_stream();
  208.     if (r.has_value()) {
  209.       return make_optional(CUDAStream(CUDAStream::UNCHECKED, r.value()));
  210.     } else {
  211.       return nullopt;
  212.     }
  213.   }
  214.  
  215.   /// Returns the most recent CUDA stream that was set using this stream guard,
  216.   /// either from construction, or via reset_stream, if the guard is initialized,
  217.   /// or nullopt if the guard is uninitialized.
  218.   optional<CUDAStream> current_stream() const {
  219.     auto r = guard_.current_stream();
  220.     if (r.has_value()) {
  221.       return make_optional(CUDAStream(CUDAStream::UNCHECKED, r.value()));
  222.     } else {
  223.       return nullopt;
  224.     }
  225.   }
  226.  
  227.   /// Restore the original CUDA device and stream, resetting this guard to uninitialized state.
  228.   void reset() { guard_.reset(); }
  229.  
  230. private:
  231.   c10::impl::InlineOptionalStreamGuard<at::cuda::impl::CUDAGuardImpl> guard_;
  232. };
  233.  
  234. } // namespace cuda
  235. } // namespace at