LLMMaskedCategorical¶
- class torchrl.modules.LLMMaskedCategorical(logits: Tensor, mask: Tensor, ignore_index: int = - 100)[原始碼]¶
LLM 最佳化的掩碼分類分佈。
此類透過以下方式為 LLM 訓練提供了更節省記憶體的方法:1. 在 log_prob 計算中使用 ignore_index=-100(無掩碼開銷)2. 在取樣操作中使用傳統掩碼
這對於詞彙量大的情況特別有益,因為掩蓋所有 logits 可能會佔用大量記憶體。
- 引數:
logits (torch.Tensor) – 事件對數機率(未歸一化),形狀為 [B, T, C]。 - B:批次大小(可選) - T:序列長度 - C:詞彙量大小(類別數)
mask (torch.Tensor) –
布林掩碼,指示有效位置/標記。 - 如果形狀為 [*B, T]:位置級掩碼。True 表示該位置有效(所有標記都允許)。 - 如果形狀為 [*B, T, C]:標記級掩碼。True 表示該標記在該位置有效。
警告
標記級掩碼比位置級掩碼佔用更多記憶體。僅在需要掩蓋標記時使用。
ignore_index (int, optional) – log_prob 計算中要忽略的索引。預設為 -100。
- 輸入形狀
- 使用場景
- 位置級掩碼
>>> logits = torch.randn(2, 10, 50000) # [B=2, T=10, C=50000] >>> mask = torch.ones(2, 10, dtype=torch.bool) # [B, T] >>> mask[0, :5] = False # mask first 5 positions of first sequence >>> dist = LLMMaskedCategorical(logits=logits, mask=mask) >>> tokens = torch.randint(0, 50000, (2, 10)) # [B, T] >>> tokens[0, :5] = -100 # set masked positions to ignore_index >>> log_probs = dist.log_prob(tokens) >>> samples = dist.sample() # [B, T]
- 標記級掩碼
>>> logits = torch.randn(2, 10, 50000) >>> mask = torch.ones(2, 10, 50000, dtype=torch.bool) # [B, T, C] >>> mask[0, :5, :1000] = False # mask first 1000 tokens for first 5 positions >>> dist = LLMMaskedCategorical(logits=logits, mask=mask) >>> tokens = torch.randint(0, 50000, (2, 10)) >>> # Optionally, set tokens at fully-masked positions to ignore_index >>> log_probs = dist.log_prob(tokens) >>> samples = dist.sample() # [B, T]
注意事項
對於 log_prob,tokens 的形狀必須為 [B, T],幷包含有效的標記索引(0 <= token < C),或者對於被掩蓋/忽略的位置使用 ignore_index。
對於標記級掩碼,如果某個位置的標記被掩蓋,log_prob 將為該條目返回 -inf。
對於位置級掩碼,如果某個位置被掩蓋(ignore_index),log_prob 將為該條目返回 0.0(對於交叉熵損失是正確的)。
取樣始終遵守掩碼(被掩蓋的標記/位置永遠不會被取樣)。
所有文件化的使用場景均包含在 test_distributions.py 中的測試中。
- log_prob(value: Tensor) Tensor[原始碼]¶
使用 ignore_index 方法計算對數機率。
這很節省記憶體,因為它不需要掩蓋 logits。value 張量應為被掩蓋的位置使用 ignore_index。
- property masked_dist: Categorical¶
獲取用於取樣操作的掩碼分佈。
- property position_level_masking: bool¶
掩碼是位置級的(True)還是標記級的(False)。
- rsample(sample_shape: torch.Size | Sequence[int] | None = None) torch.Tensor[原始碼]¶
使用掩碼 logits 進行重引數化取樣。
- sample(sample_shape: torch.Size | Sequence[int] | None = None) torch.Tensor[原始碼]¶
使用掩碼 logits 從分佈中取樣。