torch.cond#
- torch.cond(pred, true_fn, false_fn, operands=())[原始碼]#
有條件地應用 true_fn 或 false_fn。
警告
torch.cond 是 PyTorch 中的一個原型(prototype)功能。它對輸入和輸出型別有有限的支援,目前不支援訓練。請期待 PyTorch 未來版本中更穩定的實現。有關功能分類的更多資訊,請參閱:https://pytorch.com.tw/blog/pytorch-feature-classification-changes/#prototype
cond 是一個結構化控制流運算元。也就是說,它類似於 Python 的 if 語句,但在 true_fn、false_fn 和 operands 上有限制,這使得它可以使用 torch.compile 和 torch.export 進行捕獲。
假設 cond 引數的約束條件已滿足,cond 等價於以下程式碼:
def cond(pred, true_branch, false_branch, operands): if pred: return true_branch(*operands) else: return false_branch(*operands)
- 引數
pred (Union[bool, torch.Tensor]) – 一個布林表示式或一個只有一個元素的張量,指示應用哪個分支函式。
true_fn (Callable) – 一個在正在追蹤的範圍內可呼叫的函式(a -> b)。
false_fn (Callable) – 一個在正在追蹤的範圍內可呼叫的函式(a -> b)。真分支和假分支必須具有一致的輸入和輸出,這意味著輸入必須相同,輸出必須是相同的型別和形狀。也允許 int 輸出。我們將透過將其轉換為 symint 來使輸出動態化。
operands (Tuple of possibly nested dict/list/tuple of torch.Tensor) – 輸入到 true/false 函式的元組。如果 true_fn/false_fn 不需要輸入,則可以為空。預設為 ()。
- 返回型別
示例
def true_fn(x: torch.Tensor): return x.cos() def false_fn(x: torch.Tensor): return x.sin() return cond(x.shape[0] > 4, true_fn, false_fn, (x,))
- 約束條件
條件語句(也稱為 pred)必須滿足以下任一約束條件:
它是一個只有一個元素的 torch.Tensor,且 dtype 為 torch.bool。
它是一個布林表示式,例如 x.shape[0] > 10 或 x.dim() > 1 and x.shape[1] > 10。
分支函式(也稱為 true_fn/false_fn)必須滿足以下所有約束條件:
函式簽名必須與 operands 匹配。
函式必須返回一個具有相同元資料(例如,形狀、dtype 等)的張量。
函式不能對輸入或全域性變數進行就地(in-place)修改。(注意:對於中間結果的就地張量操作,如 add_,在分支中是允許的)。