評價此頁

torch.cuda.jiterator._create_multi_output_jit_fn#

torch.cuda.jiterator._create_multi_output_jit_fn(code_string, num_outputs, **kwargs)[source]#

建立一個 jiterator 生成的 CUDA 核函式,用於支援返回一個或多個輸出的元素級操作。

引數
  • code_string (str) – 由 jiterator 編譯的 CUDA 程式碼字串。入口函式必須透過引用返回值。

  • num_outputs (int) – 核函式返回的輸出數量

  • kwargs (Dict, optional) – 生成函式的關鍵字引數

返回型別

Callable

示例

code_string = "template <typename T> void my_kernel(T x, T y, T alpha, T& out) { out = -x + alpha * y; }"
jitted_fn = create_jit_fn(code_string, alpha=1.0)
a = torch.rand(3, device="cuda")
b = torch.rand(3, device="cuda")
# invoke jitted function like a regular python function
result = jitted_fn(a, b, alpha=3.14)

警告

此 API 處於 Beta 版,未來版本中可能會更改。

警告

此 API 最多支援 8 個輸入和 8 個輸出。