torch.func.vmap#
- torch.func.vmap(func, in_dims=0, out_dims=0, randomness='error', *, chunk_size=None)[原始碼]#
vmap 是向量化對映;
vmap(func)返回一個新函式,該函式將func對映到輸入的某些維度上。從語義上講,vmap 將 map 推入func呼叫的 PyTorch 操作中,從而有效地將這些操作向量化。vmap 對於處理批處理維度很有用:您可以編寫一個在單個示例上執行的函式
func,然後使用vmap(func)將其提升為一個可以接受示例批次的函式。vmap 還可以與 autograd 組合以計算批處理梯度。注意
torch.vmap()為了方便起見,別名為torch.func.vmap()。你可以隨意使用其中一個。- 引數
func (function) – 接受一個或多個引數的 Python 函式。必須返回一個或多個 Tensor。
in_dims (int 或 巢狀結構) – 指定輸入應該在哪一個維度上進行對映。
in_dims的結構應該與輸入相匹配。如果某個輸入的in_dim為 None,則表示沒有對映維度。預設為 0。out_dims (int 或 Tuple[int]) – 指定對映維度應該出現在輸出的哪個位置。如果
out_dims是一個 Tuple,那麼它應該為每個輸出包含一個元素。預設為 0。randomness (str) – 指定此 vmap 中的隨機性在批次之間是相同還是不同。如果為 ‘different’,則每個批次的隨機性將不同。如果為 ‘same’,則批次之間的隨機性將相同。如果為 ‘error’,則對隨機函式的任何呼叫都將報錯。預設為 ‘error’。警告:此標誌僅適用於 PyTorch 的隨機操作,不適用於 Python 的 random 模組或 numpy 隨機性。
chunk_size (None 或 int) – 如果為 None(預設),則在輸入上應用單個 vmap。如果非 None,則一次計算
chunk_size個樣本的 vmap。請注意,chunk_size=1等同於使用 for 迴圈計算 vmap。如果您在計算 vmap 時遇到記憶體問題,請嘗試使用非 None 的 chunk_size。
- 返回
返回一個新的“批處理”函式。它接受與
func相同的輸入,只是每個輸入的指定in_dims索引處會多出一個維度。它返回與func相同的輸出,只是每個輸出的指定out_dims索引處會多出一個維度。- 返回型別
使用
vmap()的一個例子是計算批次點積。PyTorch 不提供批次torch.dotAPI;與其徒勞地翻閱文件,不如使用vmap()來構造一個新函式。>>> torch.dot # [D], [D] -> [] >>> batched_dot = torch.func.vmap(torch.dot) # [N, D], [N, D] -> [N] >>> x, y = torch.randn(2, 5), torch.randn(2, 5) >>> batched_dot(x, y)
vmap()有助於隱藏批次維度,從而提供更簡單的模型編寫體驗。>>> batch_size, feature_size = 3, 5 >>> weights = torch.randn(feature_size, requires_grad=True) >>> >>> def model(feature_vec): >>> # Very simple linear model with activation >>> return feature_vec.dot(weights).relu() >>> >>> examples = torch.randn(batch_size, feature_size) >>> result = torch.vmap(model)(examples)
vmap()還可以幫助向量化以前難以或不可能進行批處理的計算。一個例子是高階梯度計算。PyTorch 的 autograd 引擎計算 vjps(向量-雅可比乘積)。通常需要 N 次呼叫autograd.grad(每呼叫一次對應一個雅可比矩陣的行)才能計算某個函式 f: R^N -> R^N 的完整雅可比矩陣。使用vmap(),我們可以向量化整個計算,在一次autograd.grad呼叫中計算雅可比矩陣。>>> # Setup >>> N = 5 >>> f = lambda x: x**2 >>> x = torch.randn(N, requires_grad=True) >>> y = f(x) >>> I_N = torch.eye(N) >>> >>> # Sequential approach >>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0] >>> for v in I_N.unbind()] >>> jacobian = torch.stack(jacobian_rows) >>> >>> # vectorized gradient computation >>> def get_vjp(v): >>> return torch.autograd.grad(y, x, v) >>> jacobian = torch.vmap(get_vjp)(I_N)
vmap()也可以巢狀使用,產生一個具有多個批處理維度的輸出。>>> torch.dot # [D], [D] -> [] >>> batched_dot = torch.vmap( ... torch.vmap(torch.dot) ... ) # [N1, N0, D], [N1, N0, D] -> [N1, N0] >>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5) >>> batched_dot(x, y) # tensor of size [2, 3]
如果輸入不在第一個維度上進行批處理,
in_dims會指定每個輸入的批處理維度為>>> torch.dot # [N], [N] -> [] >>> batched_dot = torch.vmap(torch.dot, in_dims=1) # [N, D], [N, D] -> [D] >>> x, y = torch.randn(2, 5), torch.randn(2, 5) >>> batched_dot( ... x, y ... ) # output is [5] instead of [2] if batched along the 0th dimension
如果存在多個輸入,且每個輸入在不同維度上進行批處理,
in_dims必須是一個元組,其中包含每個輸入的批處理維度為>>> torch.dot # [D], [D] -> [] >>> batched_dot = torch.vmap(torch.dot, in_dims=(0, None)) # [N, D], [D] -> [N] >>> x, y = torch.randn(2, 5), torch.randn(5) >>> batched_dot( ... x, y ... ) # second arg doesn't have a batch dim because in_dim[1] was None
如果輸入是 Python 結構,
in_dims必須是一個元組,其中包含一個與輸入形狀匹配的結構。>>> f = lambda dict: torch.dot(dict["x"], dict["y"]) >>> x, y = torch.randn(2, 5), torch.randn(5) >>> input = {"x": x, "y": y} >>> batched_dot = torch.vmap(f, in_dims=({"x": 0, "y": None},)) >>> batched_dot(input)
預設情況下,輸出在第一個維度上進行批處理。但是,可以使用
out_dims在任何維度上進行批處理。>>> f = lambda x: x**2 >>> x = torch.randn(2, 5) >>> batched_pow = torch.vmap(f, out_dims=1) >>> batched_pow(x) # [5, 2]
對於任何使用 kwargs 的函式,返回的函式不會批處理 kwargs,但會接受 kwargs。
>>> x = torch.randn([2, 5]) >>> def fn(x, scale=4.): >>> return x * scale >>> >>> batched_pow = torch.vmap(fn) >>> assert torch.allclose(batched_pow(x), x * 4) >>> batched_pow(x, scale=x) # scale is not batched, output has shape [2, 2, 5]
注意
vmap 不提供開箱即用的通用自動批處理或處理可變長度序列。