CudaGraphModule¶
- class tensordict.nn.CudaGraphModule(module: Callable[[Union[List[Tensor], TensorDictBase]], None], warmup: int = 2, in_keys: Optional[List[NestedKey]] = None, out_keys: Optional[List[NestedKey]] = None, device: Optional[device] = None)¶
PyTorch 可呼叫物件的 cudagraph 包裝器。
CudaGraphModule是一個包裝器類,它為 PyTorch 可呼叫物件提供使用者友好的 CUDA 圖介面。警告
CudaGraphModule是一個原型功能,其 API 限制在未來可能會發生變化。此類為 CUDA 圖提供了使用者友好的介面,允許 GPU 上操作的快速、無 CPU 開銷執行。它執行對函式輸入的必要檢查,並提供類似 nn.Module 的 API 來執行
警告
此模組要求包裝的函式滿足一些要求。使用者有責任確保所有這些要求都已滿足。
函式不能有動態控制流。例如,以下程式碼片段將在 CudaGraphModule 中包裝失敗
>>> def func(x): ... if x.norm() > 1: ... return x + 1 ... else: ... return x - 1
幸運的是,PyTorch 在大多數情況下都提供瞭解決方案
>>> def func(x): ... return torch.where(x.norm() > 1, x + 1, x - 1)
函式必須執行一個可以使用相同緩衝區精確重放的程式碼。這意味著不支援動態形狀(輸入中或程式碼執行期間形狀的變化)。換句話說,輸入必須具有恆定的形狀。
函式的輸出必須是分離的。如果需要呼叫最佳化器,請將其放在輸入函式中。例如,以下函式是一個有效的運算子
>>> def func(x, y): ... optim.zero_grad() ... loss_val = loss_fn(x, y) ... loss_val.backward() ... optim.step() ... return loss_val.detach()
輸入不應可微分。如果你需要使用 nn.Parameters(或一般可微分張量),只需編寫一個將它們用作全域性值而不是將它們作為輸入傳遞的函式
>>> x = nn.Parameter(torch.randn(())) >>> optim = Adam([x], lr=1) >>> def func(): # right ... optim.zero_grad() ... (x+1).backward() ... optim.step() >>> def func(x): # wrong ... optim.zero_grad() ... (x+1).backward() ... optim.step()
作為張量或 tensordict 的 args 和 kwargs 可能會改變(前提是裝置和形狀匹配),但非張量 args 和 kwargs 不應改變。例如,如果函式接收一個字串輸入並且該輸入在任何時候都被更改,模組將靜默地使用捕獲 cudagraph 時使用的字串執行程式碼。唯一支援的關鍵字引數是 tensordict_out,以防輸入是 tensordict。
如果模組是
TensorDictModuleBase例項,並且輸出 ID 與輸入 ID 匹配,那麼在呼叫CudaGraphModule時將保留此身份。在所有其他情況下,輸出將被克隆,無論其元素是否匹配輸入中的一個或多個。
警告
CudaGraphModule不是Module,其設計目的是為了阻止收集輸入模組的引數並將其傳遞給最佳化器。- 引數:
module (Callable) – 接收張量(或 tensordict)作為輸入並輸出 PyTreeable 張量集合的函式。如果提供 tensordict,則模組也可以使用關鍵字引數執行(請參見下面的示例)。
warmup (int, optional) – 如果模組已編譯(編譯後的模組應在被 cudagraphs 捕獲之前執行幾次),則進行預熱的次數。預設為所有模組的
2。in_keys (list of NestedKeys) –
輸入鍵,如果模組以 TensorDict 作為輸入。如果此值存在,則預設為
module.in_keys,否則為None。注意
如果提供了
in_keys但為空,則假定模組接收 tensordict 作為輸入。這足以讓CudaGraphModule意識到該函式應被視為 TensorDictModule,但關鍵字引數不會被分派。請參見下面的示例。out_keys (list of NestedKeys) – 輸出鍵,如果模組以 TensorDict 作為輸出。如果此值存在,則預設為
module.out_keys,否則為None。device (torch.device, optional) – 要使用的流的裝置。
示例
>>> # Wrap a simple function >>> def func(x): ... return x + 1 >>> func = CudaGraphModule(func) >>> x = torch.rand((), device='cuda') >>> out = func(x) >>> assert isinstance(out, torch.Tensor) >>> assert out == x+1 >>> # Wrap a tensordict module >>> func = TensorDictModule(lambda x: x+1, in_keys=["x"], out_keys=["y"]) >>> func = CudaGraphModule(func) >>> # This can be called either with a TensorDict or regular keyword arguments alike >>> y = func(x=x) >>> td = TensorDict(x=x) >>> td = func(td)
注意
關於除錯 CudaGraphModule 錯誤的提示
- 諸如operation would make the legacy stream depend on a capturing blocking stream(操作將使舊流依賴於捕獲阻塞流)之類的錯誤
應首先使用非編譯版本進行除錯(編譯程式碼將隱藏負責跨流依賴的程式碼部分)。這可能是因為您正在進行跨裝置操作,導致捕獲流依賴於其他流。
- 諸如Cannot call CUDAGeneratorImpl::current_seed during CUDA graph capture(在 CUDA 圖捕獲期間無法呼叫 CUDAGeneratorImpl::current_seed)之類的錯誤,或其他起源於
編譯器(而非編譯的程式碼!)的錯誤,可能指向在重新編譯時發生的圖捕獲。使用 TORCH_LOGS=”+recompiles” python myscrip.py 捕獲重新編譯,並嘗試修復它們。通常,請確保您使用了足夠多的預熱步驟。如果您在此類問題上遇到困難,請提交一個 issue。