快捷方式

OneHotDiscreteTensorSpec

class torchrl.data.OneHotDiscreteTensorSpec(*args, **kwargs)[原始碼]

Deprecated version of torchrl.data.OneHot.

assert_is_in(value: Tensor) None

斷言一個張量是否屬於該區域(box),否則丟擲異常。

引數:

value (torch.Tensor) – 要檢查的值。

cardinality() int

規格的基數。

這指的是規格中可能出現的結果的數量。假設複合規格的基數是所有可能結果的笛卡爾積。

clear_device_() T

對於所有葉子規格(必須有裝置),此方法無操作。

對於 Composite 規格,此方法將擦除裝置。

clone() OneHot

建立 TensorSpec 的副本。

contains(item: torch.Tensor | TensorDictBase) bool

如果值 val 可以由 TensorSpec 生成,則返回 True,否則返回 False

See is_in() for more information.

cpu()

將 TensorSpec 轉換為“cpu”裝置。

cuda(device=None)

將 TensorSpec 轉換為“cuda”裝置。

device: torch.device | None = None
encode(val: np.ndarray | list | torch.Tensor | TensorDictBase, *, ignore_device: bool = False) torch.Tensor | TensorDictBase

使用指定的規格對值進行編碼,並返回相應的張量。

此方法用於返回易於對映到 TorchRL 所需域的值(例如 numpy 陣列)的環境。如果值已經是張量,則規格不會更改其值,而是按原樣返回。

引數:

val (np.ndarraytorch.Tensor) – 要編碼為張量的值。

關鍵字引數:

ignore_device (bool, 可選) – 如果為 True,則將忽略規格裝置。這用於在呼叫 TensorDict(..., device="cuda") 時將張量轉換分組,這樣更快速。

返回:

符合所需張量規格的 torch.Tensor。

enumerate(use_mask: bool = False) Tensor

返回可以從 TensorSpec 獲得的所有樣本。

樣本將沿第一個維度堆疊。

此方法僅為離散規格實現。

引數:

use_mask (bool, 可選) – 如果為 True 且規格有掩碼,則排除被掩碼的樣本。預設為 False

erase_memoize_cache() None

清除用於快取 encode 執行的快取。

另請參閱

memoize_encode().

expand(*shape)

返回一個具有擴充套件形狀的新 Spec。

引數:

*shape (tupleiterable of int) – Spec 的新形狀。必須可廣播到當前形狀:其長度至少等於當前形狀的長度,並且其最後的值也必須相容;即,只有噹噹前維度是單例時,它們才能與當前維度不同。

flatten(start_dim: int, end_dim: int) T

展平一個 TensorSpec

有關此方法的更多資訊,請檢視 flatten()

classmethod implements_for_spec(torch_function: Callable) Callable

為 TensorSpec 註冊一個 torch 函式覆蓋。

index(index: Union[int, Tensor, ndarray, slice, list], tensor_to_index: Tensor) Tensor

索引輸入張量。

此方法用於索引那些編碼一個或多個分類變數的規格(例如,OneHotCategorical),以便在不關心索引的實際表示的情況下對張量進行索引。

引數:
  • index (int, torch.Tensor, slicelist) – 張量的索引

  • tensor_to_index – 要索引的張量

返回:

被索引的張量

示例
>>> from torchrl.data import OneHot
>>> import torch
>>>
>>> one_hot = OneHot(n=100)
>>> categ = one_hot.to_categorical_spec()
>>> idx_one_hot = torch.zeros((100,), dtype=torch.bool)
>>> idx_one_hot[50] = 1
>>> print(one_hot.index(idx_one_hot, torch.arange(100)))
tensor(50)
>>> idx_categ = one_hot.to_categorical(idx_one_hot)
>>> print(categ.index(idx_categ, torch.arange(100)))
tensor(50)
is_in(val: Tensor) bool

如果值 val 可以由 TensorSpec 生成,則返回 True,否則返回 False

更具體地說,is_in 方法檢查值 val 是否在 space 屬性(盒子)定義的限制內,並且 dtypedeviceshape 以及其他可能的元資料是否與規格匹配。如果任何檢查失敗,is_in 方法將返回 False

引數:

val (torch.Tensor) – 要檢查的值。

返回:

布林值,指示值是否屬於 TensorSpec 區域。

make_neg_dim(dim: int) T

將特定維度轉換為 -1

memoize_encode(mode: bool = True) None

建立 encode 方法的快取可呼叫序列,以加快其執行速度。

這應該只在輸入型別、形狀等在給定規格的呼叫之間預期一致時使用。

引數:

mode (bool, optional) – 是否使用快取。預設為 True

另請參閱

快取可以透過 erase_memoize_cache() 擦除。

property ndim: int

規格形狀的維數。

相當於 len(spec.shape)

ndimension() int

規格形狀的維數。

相當於 len(spec.shape)

one(shape: torch.Size = None) torch.Tensor | TensorDictBase

返回盒中的一個填充一的張量。

注意

儘管不能保證 1 屬於規格域,但當此條件被違反時,此方法不會引發異常。 one 的主要用例是生成空的(資料)緩衝區,而不是有意義的資料。

引數:

shape (torch.Size) – one-tensor 的形狀

返回:

在 TensorSpec 區域中取樣的填充一的張量。

ones(shape: torch.Size = None) torch.Tensor | TensorDictBase

Proxy to one().

project(val: torch.Tensor | TensorDictBase) torch.Tensor | TensorDictBase

如果輸入張量不在 TensorSpec 區域內,則根據定義的啟發式方法將其映射回該區域。

引數:

val (torch.Tensor) – 要對映到盒子的張量。

返回:

屬於 TensorSpec 區域的 torch.Tensor。

rand(shape: Optional[Size] = None) Tensor

返回規格定義的區域中的隨機張量。

取樣將在區域內均勻進行,除非區域無界,在這種情況下將繪製正態值。

引數:

shape (torch.Size) – 隨機張量的形狀

返回:

在 TensorSpec 區域中取樣的隨機張量。

reshape(*shape) T

重塑一個 TensorSpec

有關此方法的更多資訊,請檢視 reshape()

sample(shape: torch.Size = None) torch.Tensor | TensorDictBase

返回規格定義的區域中的隨機張量。

See rand() for details.

squeeze(dim=None)

返回一個新 Spec,其中所有大小為 1 的維度都已刪除。

當給定 dim 時,僅在該維度上執行擠壓操作。

引數:

dim (intNone) – 應用擠壓操作的維度

to(dest: torch.dtype | DEVICE_TYPING) OneHot

將 TensorSpec 轉換為裝置或 dtype。

如果未進行更改,則返回相同的規格。

to_categorical(val: torch.Tensor, safe: bool | None = None) torch.Tensor

將給定的獨熱張量轉換為分類格式。

引數:
  • val (torch.Tensor, 可選) – 要轉換為分類格式的 one-hot 張量。

  • safe (bool) – 布林值,指示是否應檢查值與規格域的匹配程度。預設為 CHECK_SPEC_ENCODE 環境變數的值。

返回:

分類張量。

示例

>>> one_hot = OneHot(3, shape=(2, 3))
>>> one_hot_sample = one_hot.rand()
>>> one_hot_sample
tensor([[False,  True, False],
        [False,  True, False]])
>>> categ_sample = one_hot.to_categorical(one_hot_sample)
>>> categ_sample
tensor([1, 1])
to_categorical_spec() Categorical

將 spec 轉換為等效的分類 spec。

示例

>>> one_hot = OneHot(3, shape=(2, 3))
>>> one_hot.to_categorical_spec()
Categorical(
    shape=torch.Size([2]),
    space=CategoricalBox(n=3),
    device=cpu,
    dtype=torch.int64,
    domain=discrete)
to_numpy(val: torch.Tensor, safe: bool | None = None) np.ndarray

返回輸入張量的 np.ndarray 對應項。

This is intended to be the inverse operation of encode().

引數:
  • val (torch.Tensor) – 要轉換為 numpy 的張量。

  • safe (bool) – 布林值,指示是否應檢查值與規格域的匹配程度。預設為 CHECK_SPEC_ENCODE 環境變數的值。

返回:

一個 np.ndarray。

to_one_hot(val: torch.Tensor, safe: bool | None = None) torch.Tensor

No-op for OneHot.

to_one_hot_spec() OneHot

No-op for OneHot.

type_check(value: Tensor, key: Optional[NestedKey] = None) None

檢查輸入值 dtype 是否與 TensorSpecdtype 匹配,如果不匹配則引發異常。

引數:
  • value (torch.Tensor) – 需要檢查 dtype 的張量。

  • key (str, optional) – 如果 TensorSpec 具有鍵,則將檢查值 dtype 是否與指定鍵指向的規格匹配。

unflatten(dim: int, sizes: tuple[int]) T

解展一個 TensorSpec

有關此方法的更多資訊,請檢視 unflatten()

unsqueeze(dim: int)

返回一個新 Spec,其中在 dim 指定的位置增加了一個單例維度。

引數:

dim (intNone) – 應用 unsqueeze 操作的維度。

update_mask(mask)

設定一個掩碼,以防止在取樣時出現某些可能的輸出。

掩碼也可以在 spec 初始化期間設定。

引數:

mask (torch.Tensor or None) – boolean mask. If None, the mask is disabled. Otherwise, the shape of the mask must be expandable to the shape of the spec. False masks an outcome and True leaves the outcome unmasked. If all the possible outcomes are masked, then an error is raised when a sample is taken.

示例

>>> mask = torch.tensor([True, False, False])
>>> ts = OneHot(3, (2, 3,), dtype=torch.int64, mask=mask)
>>> # All but one of the three possible outcomes are masked
>>> ts.rand()
tensor([[1, 0, 0],
        [1, 0, 0]])
view(*shape) T

重塑一個 TensorSpec

有關此方法的更多資訊,請檢視 reshape()

zero(shape: torch.Size = None) torch.Tensor | TensorDictBase

返回盒中的零填充張量。

注意

儘管不能保證 0 屬於規格域,但當此條件被違反時,此方法不會引發異常。 zero 的主要用例是生成空的(資料)緩衝區,而不是有意義的資料。

引數:

shape (torch.Size) – zero-tensor 的形狀

返回:

在 TensorSpec 框中取樣的零填充張量。

zeros(shape: torch.Size = None) torch.Tensor | TensorDictBase

Proxy to zero().

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源