快捷方式

pad

class tensordict.pad(tensordict: T, pad_size: Sequence[int], value: float = 0.0)

使用常量值填充 tensordict 中的所有張量(沿批處理維度),並返回一個新的 tensordict。

引數:
  • tensordict (TensorDict) – 需要填充的 tensordict

  • pad_size (Sequence[int]) – 用於填充 tensordict 的某些批處理維度的填充大小,從第一個維度開始向前移動。[pad_size 的長度 / 2] 個批處理大小維度將被填充。例如,要僅填充第一個維度,pad 的形式為(左填充,右填充)。要填充兩個維度,則為(上左填充,上右填充,下左填充,下右填充)等等。pad_size 必須是偶數,並且小於或等於批處理維度的兩倍。

  • value (float, optional) – 用於填充的值,預設為 0.0

返回:

沿批處理維度填充後的新 TensorDict

示例

>>> from tensordict import TensorDict, pad
>>> import torch
>>> td = TensorDict({'a': torch.ones(3, 4, 1),
...     'b': torch.ones(3, 4, 1, 1)}, batch_size=[3, 4])
>>> dim0_left, dim0_right, dim1_left, dim1_right = [0, 1, 0, 2]
>>> padded_td = pad(td, [dim0_left, dim0_right, dim1_left, dim1_right], value=0.0)
>>> print(padded_td.batch_size)
torch.Size([4, 6])
>>> print(padded_td.get("a").shape)
torch.Size([4, 6, 1])
>>> print(padded_td.get("b").shape)
torch.Size([4, 6, 1, 1])

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源