快捷方式

LLMMaskedCategorical

class torchrl.modules.LLMMaskedCategorical(logits: Tensor, mask: Tensor, ignore_index: int = - 100)[原始碼]

LLM 最佳化的掩碼分類分佈。

此類透過以下方式為 LLM 訓練提供了更節省記憶體的方法:1. 在 log_prob 計算中使用 ignore_index=-100(無掩碼開銷)2. 在取樣操作中使用傳統掩碼

這對於詞彙量大的情況特別有益,因為掩蓋所有 logits 可能會佔用大量記憶體。

引數:
  • logits (torch.Tensor) – 事件對數機率(未歸一化),形狀為 [B, T, C]。 - B:批次大小(可選) - T:序列長度 - C:詞彙量大小(類別數)

  • mask (torch.Tensor) –

    布林掩碼,指示有效位置/標記。 - 如果形狀為 [*B, T]:位置級掩碼。True 表示該位置有效(所有標記都允許)。 - 如果形狀為 [*B, T, C]:標記級掩碼。True 表示該標記在該位置有效。

    警告

    標記級掩碼比位置級掩碼佔用更多記憶體。僅在需要掩蓋標記時使用。

  • ignore_index (int, optional) – log_prob 計算中要忽略的索引。預設為 -100。

輸入形狀
  • logits: [*B, T, C](必需)

  • mask: [*B, T](位置級)或 [*B, T, C](標記級)

  • tokens (for log_prob): [*B, T](標記索引,帶 ignore_index 用於被掩蓋的位置)

使用場景
  1. 位置級掩碼
    >>> logits = torch.randn(2, 10, 50000)  # [B=2, T=10, C=50000]
    >>> mask = torch.ones(2, 10, dtype=torch.bool)  # [B, T]
    >>> mask[0, :5] = False  # mask first 5 positions of first sequence
    >>> dist = LLMMaskedCategorical(logits=logits, mask=mask)
    >>> tokens = torch.randint(0, 50000, (2, 10))  # [B, T]
    >>> tokens[0, :5] = -100  # set masked positions to ignore_index
    >>> log_probs = dist.log_prob(tokens)
    >>> samples = dist.sample()  # [B, T]
    
  2. 標記級掩碼
    >>> logits = torch.randn(2, 10, 50000)
    >>> mask = torch.ones(2, 10, 50000, dtype=torch.bool)  # [B, T, C]
    >>> mask[0, :5, :1000] = False  # mask first 1000 tokens for first 5 positions
    >>> dist = LLMMaskedCategorical(logits=logits, mask=mask)
    >>> tokens = torch.randint(0, 50000, (2, 10))
    >>> # Optionally, set tokens at fully-masked positions to ignore_index
    >>> log_probs = dist.log_prob(tokens)
    >>> samples = dist.sample()  # [B, T]
    

注意事項

  • 對於 log_prob,tokens 的形狀必須為 [B, T],幷包含有效的標記索引(0 <= token < C),或者對於被掩蓋/忽略的位置使用 ignore_index。

  • 對於標記級掩碼,如果某個位置的標記被掩蓋,log_prob 將為該條目返回 -inf。

  • 對於位置級掩碼,如果某個位置被掩蓋(ignore_index),log_prob 將為該條目返回 0.0(對於交叉熵損失是正確的)。

  • 取樣始終遵守掩碼(被掩蓋的標記/位置永遠不會被取樣)。

所有文件化的使用場景均包含在 test_distributions.py 中的測試中。

clear_cache()[原始碼]

清除快取的掩碼張量以釋放記憶體。

entropy() Tensor[原始碼]

使用掩碼 logits 計算熵。

log_prob(value: Tensor) Tensor[原始碼]

使用 ignore_index 方法計算對數機率。

這很節省記憶體,因為它不需要掩蓋 logits。value 張量應為被掩蓋的位置使用 ignore_index。

property logits: Tensor

獲取原始 logits。

property mask: Tensor

獲取掩碼。

property masked_dist: Categorical

獲取用於取樣操作的掩碼分佈。

property masked_logits: Tensor

獲取用於取樣操作的掩碼 logits。

property mode: Tensor

使用掩碼 logits 獲取模式。

property position_level_masking: bool

掩碼是位置級的(True)還是標記級的(False)。

property probs: Tensor

從原始 logits 獲取機率。

rsample(sample_shape: torch.Size | Sequence[int] | None = None) torch.Tensor[原始碼]

使用掩碼 logits 進行重引數化取樣。

sample(sample_shape: torch.Size | Sequence[int] | None = None) torch.Tensor[原始碼]

使用掩碼 logits 從分佈中取樣。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源