評價此頁

torch.cuda.make_graphed_callables#

torch.cuda.make_graphed_callables(callables: Union[Module, Callable[[...], object]], sample_args: tuple[torch.Tensor, ...], num_warmup_iters: int = 3, allow_unused_input: bool = False, pool: Optional[_POOL_HANDLE] = None) Union[Module, Callable[[...], object]][source]#
torch.cuda.make_graphed_callables(callables: tuple[Union[torch.nn.modules.module.Module, Callable[..., object]], ...], sample_args: tuple[tuple[torch.Tensor, ...], ...], num_warmup_iters: int = 3, allow_unused_input: bool = False, pool: Optional[_POOL_HANDLE] = None) tuple[Union[torch.nn.modules.module.Module, Callable[..., object]], ...]

接受可呼叫物件(函式或nn.Module)並返回圖化版本。

每個圖化可呼叫物件的正向傳播將源可呼叫物件的正向 CUDA 工作作為單個 autograd 節點中的 CUDA 圖執行。

圖化可呼叫物件的前向傳播還將一個反向節點附加到 autograd 圖中。在反向傳播期間,此節點將可呼叫物件的反向工作作為 CUDA 圖執行。

因此,每個圖化可呼叫物件都應成為 autograd 啟用的訓練迴圈中其源可呼叫物件的即插即用替代品。

有關詳細用法和限制,請參閱部分網路捕獲

如果傳遞多個可呼叫物件的元組,則它們的捕獲將使用相同的記憶體池。請參閱圖記憶體管理瞭解何時適用。

引數
  • callablestorch.nn.ModulePython 函式,或 元組 中的 這些)– 要圖化的可呼叫物件或可呼叫物件。有關傳遞可呼叫物件元組何時適用的資訊,請參閱圖記憶體管理。如果傳遞可呼叫物件元組,則元組中的順序必須與即時工作負載中的執行順序相同。

  • sample_args元組 中的 Tensor,或 元組 中的 元組 中的 Tensor)– 為每個可呼叫物件取樣引數。如果傳遞單個可呼叫物件,則 sample_args 必須是引數 Tensor 的單個元組。如果傳遞了可呼叫物件元組,則 sample_args 必須是引數 Tensor 的元組元組。

  • num_warmup_itersint)– 預熱迭代次數。目前,DataDistributedParallel 需要 11 次預熱迭代。預設值:3

  • allow_unused_inputbool)– 如果為 False,則指定未在計算輸出時使用的輸入(因此其 grad 始終為零)將引發錯誤。預設為 False。

  • pool可選)– Token(由 graph_pool_handle()other_Graph_instance.pool() 返回)提示此圖可能與指示的池共享記憶體。請參閱圖記憶體管理

注意

sample_args 中每個 Tensor 的 requires_grad 狀態必須與訓練迴圈中相應真實輸入的預期狀態匹配。

警告

此 API 處於 Beta 版,未來版本中可能會更改。

警告

每個可呼叫物件的 sample_args 必須只包含 Tensor。不允許其他型別。

警告

返回的可呼叫物件不支援高階微分(例如,二次反向傳播)。

警告

在傳遞給 make_graphed_callables() 的任何 Module 中,只有引數可以是可訓練的。緩衝區必須具有 requires_grad=False

警告

在透過 make_graphed_callables() 傳遞 torch.nn.Module 後,您不能新增或刪除該 Module 的任何引數或緩衝區。

警告

傳遞給 make_graphed_callables()torch.nn.Module 在傳遞時不得在其上註冊模組掛鉤。但是,在傳遞給 make_graphed_callables() 之後在模組上註冊掛鉤是允許的。

警告

執行圖化可呼叫物件時,必須以其 sample_args 中出現的相同順序和格式傳遞其引數。

警告

make_graphed_callables() 中的自動混合精度僅在停用快取時受支援。上下文管理器 torch.cuda.amp.autocast() 必須將 cache_enabled=False