評價此頁

graph#

class torch.cuda.graph(cuda_graph, pool=None, stream=None, capture_error_mode='global')[source]#

上下文管理器,它將 CUDA 工作捕獲到 torch.cuda.CUDAGraph 物件中以供以後重放。

有關通用介紹、詳細用法和限制,請參閱 CUDA Graphs

引數
  • cuda_graph (torch.cuda.CUDAGraph) – 用於捕獲的 Graph 物件。

  • pool (optional) – 不透明令牌(透過呼叫 graph_pool_handle()other_Graph_instance.pool() 返回),指示此 graph 的捕獲可能共享指定池的記憶體。請參閱 Graph 記憶體管理

  • stream (torch.cuda.Stream, optional) – 如果提供,將在上下文中設定為當前流。如果未提供,則 graph 會將其自身的內部輔助流設定為上下文中的當前流。

  • capture_error_mode (str, optional) – 指定 graph 捕獲流的 cudaStreamCaptureMode。可以是“global”、“thread_local”或“relaxed”。在 cuda graph 捕獲期間,某些操作(如 cudaMalloc)可能不安全。“global”會因其他執行緒中的操作而報錯,“thread_local”只會因當前執行緒中的操作而報錯,“relaxed”則不會因操作而報錯。除非您熟悉 cudaStreamCaptureMode,否則請勿更改此設定。

注意

為了有效地共享記憶體,如果您傳遞一個由先前捕獲使用的 pool,並且先前的捕獲使用了顯式的 stream 引數,那麼您應該將相同的 stream 引數傳遞給此次捕獲。

警告

此 API 處於 Beta 版,未來版本中可能會更改。