注意
跳轉到末尾 以下載完整的示例程式碼。
切片、索引和掩碼¶
作者: Tom Begley
在本教程中,您將學習如何對 TensorDict 進行切片、索引和掩碼操作。
正如在教程 Manipulating the shape of a TensorDict 中討論的那樣,當我們建立一個 TensorDict 時,我們會指定一個 batch_size,它必須與 TensorDict 中的所有條目的前導維度一致。由於我們保證所有條目都共享這些公共維度,因此我們可以像索引 torch.Tensor 一樣來索引和掩碼這些批處理維度。這些索引沿著批處理維度應用於 TensorDict 中的所有條目。
例如,給定一個具有兩個批處理維度的 TensorDict,tensordict[0] 將返回一個結構相同的新的 TensorDict,其值對應於原始 TensorDict 中每個條目的第一個“行”。
import torch
from tensordict import TensorDict
tensordict = TensorDict(
{"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4]
)
print(tensordict[0])
TensorDict(
fields={
a: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([4]),
device=None,
is_shared=False)
與常規張量相同的語法適用。例如,如果我們想刪除每個條目的第一行,可以進行如下索引:
print(tensordict[1:])
TensorDict(
fields={
a: Tensor(shape=torch.Size([2, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2, 4]),
device=None,
is_shared=False)
我們可以同時索引多個維度
print(tensordict[:, 2:])
TensorDict(
fields={
a: Tensor(shape=torch.Size([3, 2, 5]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3, 2]),
device=None,
is_shared=False)
我們還可以使用 Ellipsis 來表示任意數量的 :,以使選擇元組的長度與 tensordict.batch_dims 的長度相同。
print(tensordict[..., 2:])
TensorDict(
fields={
a: Tensor(shape=torch.Size([3, 2, 5]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3, 2]),
device=None,
is_shared=False)
使用索引設定值¶
通常,只要批處理大小相容,tensordict[index] = new_tensordict 就會起作用。
tensordict = TensorDict(
{"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4]
)
td2 = TensorDict({"a": torch.ones(2, 4, 5), "b": torch.ones(2, 4)}, batch_size=[2, 4])
tensordict[:-1] = td2
print(tensordict["a"], tensordict["b"])
tensor([[[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]],
[[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]],
[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]]) tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[0., 0., 0., 0.]])
掩碼¶
我們像掩碼張量一樣掩碼 TensorDict。
TensorDict(
fields={
a: Tensor(shape=torch.Size([6, 5]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([6]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([6]),
device=None,
is_shared=False)
指令碼總執行時間: (0 分鐘 0.005 秒)