評價此頁

UX 限制#

建立日期:2025 年 6 月 12 日 | 最後更新日期:2025 年 6 月 12 日

torch.func 像 JAX 一樣,在可轉換內容方面存在限制。總的來說,JAX 的限制是轉換僅適用於純函式:即,輸出完全由輸入決定且不包含副作用(如變異)的函式。

我們有類似的保證:我們的轉換對純函式效果很好。但是,我們也支援某些原地操作。一方面,編寫與函式轉換相容的程式碼可能需要更改您編寫 PyTorch 程式碼的方式;另一方面,您可能會發現我們的轉換能夠表達以前在 PyTorch 中難以表達的內容。

通用限制#

所有 torch.func 轉換都共享一個限制,即函式不應分配給全域性變數。相反,函式的所有輸出都必須從函式返回。此限制源於 torch.func 的實現方式:每個轉換都將 Tensor 輸入包裝在特殊的 torch.func Tensor 子類中,以促進轉換。

所以,不要這樣做

import torch
from torch.func import grad

# Don't do this
intermediate = None

def f(x):
  global intermediate
  intermediate = x.sin()
  z = intermediate.sin()
  return z

x = torch.randn([])
grad_x = grad(f)(x)

請重寫 f 以返回 intermediate

def f(x):
  intermediate = x.sin()
  z = intermediate.sin()
  return z, intermediate

grad_x, intermediate = grad(f, has_aux=True)(x)

torch.autograd API#

如果您嘗試在由 vmap() 或 torch.func 的 AD 轉換(vjp()jvp()jacrev()jacfwd())轉換的函式內部使用 torch.autograd API(如 torch.autograd.gradtorch.autograd.backward)時,轉換可能無法對其進行轉換。如果無法這樣做,您將收到錯誤訊息。

這是 PyTorch AD 支援實現方式上的根本設計限制,也是我們設計 torch.func 庫的原因。請改用 torch.autograd API 的 torch.func 等效項

  • torch.autograd.grad, Tensor.backward -> torch.func.vjptorch.func.grad

  • torch.autograd.functional.jvp -> torch.func.jvp

  • torch.autograd.functional.jacobian -> torch.func.jacrevtorch.func.jacfwd

  • torch.autograd.functional.hessian -> torch.func.hessian

vmap 限制#

注意

vmap() 是我們限制最多的轉換。與 grad 相關的轉換(grad()vjp()jvp())沒有這些限制。jacfwd()(以及 hessian(),它使用 jacfwd() 實現)是 vmap()jvp() 的組合,因此也具有這些限制。

vmap(func) 是一個返回函式的轉換,該函式在每個輸入 Tensor 的某個新維度上對映 func。vmap 的思維模型是,它就像執行一個 for 迴圈:對於純函式(即,在沒有副作用的情況下),vmap(f)(x) 等同於

torch.stack([f(x_i) for x_i in x.unbind(0)])

變異:任意變異 Python 資料結構#

在存在副作用的情況下,vmap() 不再像執行 for 迴圈一樣。例如,以下函式

def f(x, list):
  list.pop()
  print("hello!")
  return x.sum(0)

x = torch.randn(3, 1)
lst = [0, 1, 2, 3]

result = vmap(f, in_dims=(0, None))(x, lst)

將列印“hello!”一次,並從 lst 中彈出(pop)一個元素。

vmap() 只執行 f 一次,所以所有副作用只發生一次。

這是 vmap 實現方式的結果。torch.func 有一個特殊的內部 BatchedTensor 類。vmap(f)(*inputs) 獲取所有 Tensor 輸入,將它們轉換為 BatchedTensor,並呼叫 f(*batched_tensor_inputs)。BatchedTensor 重寫了 PyTorch API,為每個 PyTorch 運算子生成批處理(即向量化)行為。

變異:原地 PyTorch 操作#

您可能因為收到關於 vmap 不相容的原地操作的錯誤而在此處。當 vmap() 遇到不支援的 PyTorch 原地操作時會引發錯誤,否則會成功。不支援的操作是那些會將元素更多的 Tensor 寫入到元素更少的 Tensor 中的操作。以下是如何發生的示例

def f(x, y):
  x.add_(y)
  return x

x = torch.randn(1)
y = torch.randn(3, 1)  # When vmapped over, looks like it has shape [1]

# Raises an error because `x` has fewer elements than `y`.
vmap(f, in_dims=(None, 0))(x, y)

x 是一個只有一個元素的 Tensor,y 是一個有三個元素的 Tensor。x + y 有三個元素(由於廣播),但嘗試將三個元素寫回 x(它只有一個元素)會引發錯誤,因為試圖將三個元素寫入只有一個元素的 Tensor。

如果正在寫入的 Tensor 在 vmap() 下被批處理(即,它被 vmap 了),則沒有問題。

def f(x, y):
  x.add_(y)
  return x

x = torch.randn(3, 1)
y = torch.randn(3, 1)
expected = x + y

# Does not raise an error because x is being vmapped over.
vmap(f, in_dims=(0, 0))(x, y)
assert torch.allclose(x, expected)

一個常見的修復方法是將工廠函式的呼叫替換為它們的“new_*”等效項。例如

要了解為什麼這有幫助,請看以下示例。

def diag_embed(vec):
  assert vec.dim() == 1
  result = torch.zeros(vec.shape[0], vec.shape[0])
  result.diagonal().copy_(vec)
  return result

vecs = torch.tensor([[0., 1, 2], [3., 4, 5]])

# RuntimeError: vmap: inplace arithmetic(self, *extra_args) is not possible ...
vmap(diag_embed)(vecs)

vmap() 內部,result 是一個形狀為 [3, 3] 的 Tensor。但是,儘管 vec 的形狀看起來是 [3],但 vec 的底層形狀實際上是 [2, 3]。無法將 vec 複製到形狀為 [3] 的 result.diagonal() 中,因為它包含的元素太多。

def diag_embed(vec):
  assert vec.dim() == 1
  result = vec.new_zeros(vec.shape[0], vec.shape[0])
  result.diagonal().copy_(vec)
  return result

vecs = torch.tensor([[0., 1, 2], [3., 4, 5]])
vmap(diag_embed)(vecs)

torch.zeros() 替換為 Tensor.new_zeros() 會使 result 的底層 Tensor 形狀為 [2, 3, 3],因此現在可以將底層形狀為 [2, 3] 的 vec 複製到 result.diagonal() 中。

變異:out= PyTorch 操作#

vmap() 不支援 PyTorch 操作中的 out= 關鍵字引數。如果它在您的程式碼中遇到該引數,它將優雅地報錯。

這不是一個根本性的限制;理論上我們將來可以支援它,但目前我們選擇不這樣做。

資料相關 Python 控制流#

我們尚未支援在資料相關控制流上進行 vmap。資料相關控制流是指 if 語句、while 迴圈或 for 迴圈的條件是正在被 vmap 的 Tensor。例如,以下程式碼會引發錯誤訊息

def relu(x):
  if x > 0:
    return x
  return 0

x = torch.randn(3)
vmap(relu)(x)

但是,任何不依賴於 vmap 的 Tensor 中的值的控制流都可以正常工作

def custom_dot(x):
  if x.dim() == 1:
    return torch.dot(x, x)
  return (x * x).sum()

x = torch.randn(3)
vmap(custom_dot)(x)

JAX 支援使用特殊控制流運算子(例如 jax.lax.condjax.lax.while_loop)在資料相關控制流上進行轉換。我們正在研究為 PyTorch 新增這些的等效項。

資料相關操作(.item())#

我們不支援(也不會支援)對呼叫 Tensor 的 .item() 的使用者定義函式進行 vmap。例如,以下程式碼會引發錯誤訊息

def f(x):
  return x.item()

x = torch.randn(3)
vmap(f)(x)

請嘗試重寫您的程式碼以避免使用 .item() 呼叫。

您也可能遇到有關使用 .item() 的錯誤訊息,但您可能沒有使用它。在這種情況下,PyTorch 內部可能正在呼叫 .item() – 請在 GitHub 上提交一個 issue,我們將修復 PyTorch 內部。

動態形狀操作(nonzero 及類似操作)#

vmap(f) 要求 f 應用於輸入中的每個“示例”時返回的 Tensor 形狀相同。諸如 torch.nonzerotorch.is_nonzero 等操作不受支援,並將因此報錯。

要了解原因,請看以下示例

xs = torch.tensor([[0, 1, 2], [0, 0, 3]])
vmap(torch.nonzero)(xs)

torch.nonzero(xs[0]) 返回一個形狀為 2 的 Tensor;但 torch.nonzero(xs[1]) 返回一個形狀為 1 的 Tensor。我們無法構建一個單一的 Tensor 作為輸出;輸出需要是一個不規則 Tensor(而 PyTorch 還沒有不規則 Tensor 的概念)。

隨機性#

呼叫隨機操作時的使用者意圖可能不明確。具體來說,一些使用者可能希望隨機行為在批次之間保持一致,而另一些使用者可能希望它在批次之間有所不同。為了解決這個問題,vmap 接受一個隨機性標誌。

該標誌只能傳遞給 vmap,並且可以取三個值:“error”、“different”或“same”,預設為“error”。在“error”模式下,任何對隨機函式的呼叫都會產生一個錯誤,要求使用者根據其用例使用另外兩個標誌之一。

在“different”隨機性下,批次中的元素會產生不同的隨機值。例如,

def add_noise(x):
  y = torch.randn(())  # y will be different across the batch
  return x + y

x = torch.ones(3)
result = vmap(add_noise, randomness="different")(x)  # we get 3 different values

在“same”隨機性下,批次中的元素會產生相同的隨機值。例如,

def add_noise(x):
  y = torch.randn(())  # y will be the same across the batch
  return x + y

x = torch.ones(3)
result = vmap(add_noise, randomness="same")(x)  # we get the same value, repeated 3 times

警告

我們的系統只能確定 PyTorch 運算子的隨機行為,而不能控制 numpy 等其他庫的行為。這與 JAX 解決方案的限制類似。

注意

使用任一型別受支援的隨機性的多個 vmap 呼叫將不會產生相同的結果。與標準 PyTorch 一樣,使用者可以透過在 vmap 外部使用 torch.manual_seed() 或使用生成器來實現隨機性可重複性。

注意

最後,我們的隨機性與 JAX 不同,因為我們沒有使用無狀態 PRNG,部分原因是 PyTorch 沒有對無狀態 PRNG 的全面支援。相反,我們引入了一個標誌系統,以允許我們看到的最常見的隨機性形式。如果您的用例不適合這些隨機性形式,請提交一個 issue。