graph#
- class torch.cuda.graph(cuda_graph, pool=None, stream=None, capture_error_mode='global')[source]#
上下文管理器,它將 CUDA 工作捕獲到
torch.cuda.CUDAGraph物件中以供以後重放。有關通用介紹、詳細用法和限制,請參閱 CUDA Graphs。
- 引數
cuda_graph (torch.cuda.CUDAGraph) – 用於捕獲的 Graph 物件。
pool (optional) – 不透明令牌(透過呼叫
graph_pool_handle()或other_Graph_instance.pool()返回),指示此 graph 的捕獲可能共享指定池的記憶體。請參閱 Graph 記憶體管理。stream (torch.cuda.Stream, optional) – 如果提供,將在上下文中設定為當前流。如果未提供,則
graph會將其自身的內部輔助流設定為上下文中的當前流。capture_error_mode (str, optional) – 指定 graph 捕獲流的 cudaStreamCaptureMode。可以是“global”、“thread_local”或“relaxed”。在 cuda graph 捕獲期間,某些操作(如 cudaMalloc)可能不安全。“global”會因其他執行緒中的操作而報錯,“thread_local”只會因當前執行緒中的操作而報錯,“relaxed”則不會因操作而報錯。除非您熟悉 cudaStreamCaptureMode,否則請勿更改此設定。
注意
為了有效地共享記憶體,如果您傳遞一個由先前捕獲使用的
pool,並且先前的捕獲使用了顯式的stream引數,那麼您應該將相同的stream引數傳遞給此次捕獲。警告
此 API 處於 Beta 版,未來版本中可能會更改。