torch.cuda.make_graphed_callables#
- torch.cuda.make_graphed_callables(callables: Union[Module, Callable[[...], object]], sample_args: tuple[torch.Tensor, ...], num_warmup_iters: int = 3, allow_unused_input: bool = False, pool: Optional[_POOL_HANDLE] = None) Union[Module, Callable[[...], object]][source]#
- torch.cuda.make_graphed_callables(callables: tuple[Union[torch.nn.modules.module.Module, Callable[..., object]], ...], sample_args: tuple[tuple[torch.Tensor, ...], ...], num_warmup_iters: int = 3, allow_unused_input: bool = False, pool: Optional[_POOL_HANDLE] = None) tuple[Union[torch.nn.modules.module.Module, Callable[..., object]], ...]
接受可呼叫物件(函式或
nn.Module)並返回圖化版本。每個圖化可呼叫物件的正向傳播將源可呼叫物件的正向 CUDA 工作作為單個 autograd 節點中的 CUDA 圖執行。
圖化可呼叫物件的前向傳播還將一個反向節點附加到 autograd 圖中。在反向傳播期間,此節點將可呼叫物件的反向工作作為 CUDA 圖執行。
因此,每個圖化可呼叫物件都應成為 autograd 啟用的訓練迴圈中其源可呼叫物件的即插即用替代品。
有關詳細用法和限制,請參閱部分網路捕獲。
如果傳遞多個可呼叫物件的元組,則它們的捕獲將使用相同的記憶體池。請參閱圖記憶體管理瞭解何時適用。
- 引數
callables (torch.nn.Module 或 Python 函式,或 元組 中的 這些)– 要圖化的可呼叫物件或可呼叫物件。有關傳遞可呼叫物件元組何時適用的資訊,請參閱圖記憶體管理。如果傳遞可呼叫物件元組,則元組中的順序必須與即時工作負載中的執行順序相同。
sample_args (元組 中的 Tensor,或 元組 中的 元組 中的 Tensor)– 為每個可呼叫物件取樣引數。如果傳遞單個可呼叫物件,則
sample_args必須是引數 Tensor 的單個元組。如果傳遞了可呼叫物件元組,則sample_args必須是引數 Tensor 的元組元組。num_warmup_iters (int)– 預熱迭代次數。目前,
DataDistributedParallel需要 11 次預熱迭代。預設值:3。allow_unused_input (bool)– 如果為 False,則指定未在計算輸出時使用的輸入(因此其 grad 始終為零)將引發錯誤。預設為 False。
pool (可選)– Token(由
graph_pool_handle()或other_Graph_instance.pool()返回)提示此圖可能與指示的池共享記憶體。請參閱圖記憶體管理。
注意
sample_args中每個 Tensor 的requires_grad狀態必須與訓練迴圈中相應真實輸入的預期狀態匹配。警告
此 API 處於 Beta 版,未來版本中可能會更改。
警告
每個可呼叫物件的
sample_args必須只包含 Tensor。不允許其他型別。警告
返回的可呼叫物件不支援高階微分(例如,二次反向傳播)。
警告
在傳遞給
make_graphed_callables()的任何Module中,只有引數可以是可訓練的。緩衝區必須具有requires_grad=False。警告
在透過
make_graphed_callables()傳遞torch.nn.Module後,您不能新增或刪除該 Module 的任何引數或緩衝區。警告
傳遞給
make_graphed_callables()的torch.nn.Module在傳遞時不得在其上註冊模組掛鉤。但是,在傳遞給make_graphed_callables()之後在模組上註冊掛鉤是允許的。警告
執行圖化可呼叫物件時,必須以其
sample_args中出現的相同順序和格式傳遞其引數。警告
make_graphed_callables()中的自動混合精度僅在停用快取時受支援。上下文管理器 torch.cuda.amp.autocast() 必須將 cache_enabled=False。