評價此頁

torch.cond#

torch.cond(pred, true_fn, false_fn, operands=())[原始碼]#

有條件地應用 true_fnfalse_fn

警告

torch.cond 是 PyTorch 中的一個原型(prototype)功能。它對輸入和輸出型別有有限的支援,目前不支援訓練。請期待 PyTorch 未來版本中更穩定的實現。有關功能分類的更多資訊,請參閱:https://pytorch.com.tw/blog/pytorch-feature-classification-changes/#prototype

cond 是一個結構化控制流運算元。也就是說,它類似於 Python 的 if 語句,但在 true_fnfalse_fnoperands 上有限制,這使得它可以使用 torch.compile 和 torch.export 進行捕獲。

假設 cond 引數的約束條件已滿足,cond 等價於以下程式碼:

def cond(pred, true_branch, false_branch, operands):
    if pred:
        return true_branch(*operands)
    else:
        return false_branch(*operands)
引數
  • pred (Union[bool, torch.Tensor]) – 一個布林表示式或一個只有一個元素的張量,指示應用哪個分支函式。

  • true_fn (Callable) – 一個在正在追蹤的範圍內可呼叫的函式(a -> b)。

  • false_fn (Callable) – 一個在正在追蹤的範圍內可呼叫的函式(a -> b)。真分支和假分支必須具有一致的輸入和輸出,這意味著輸入必須相同,輸出必須是相同的型別和形狀。也允許 int 輸出。我們將透過將其轉換為 symint 來使輸出動態化。

  • operands (Tuple of possibly nested dict/list/tuple of torch.Tensor) – 輸入到 true/false 函式的元組。如果 true_fn/false_fn 不需要輸入,則可以為空。預設為 ()。

返回型別

任何

示例

def true_fn(x: torch.Tensor):
    return x.cos()


def false_fn(x: torch.Tensor):
    return x.sin()


return cond(x.shape[0] > 4, true_fn, false_fn, (x,))
約束條件
  • 條件語句(也稱為 pred)必須滿足以下任一約束條件:

    • 它是一個只有一個元素的 torch.Tensor,且 dtype 為 torch.bool。

    • 它是一個布林表示式,例如 x.shape[0] > 10x.dim() > 1 and x.shape[1] > 10

  • 分支函式(也稱為 true_fn/false_fn)必須滿足以下所有約束條件:

    • 函式簽名必須與 operands 匹配。

    • 函式必須返回一個具有相同元資料(例如,形狀、dtype 等)的張量。

    • 函式不能對輸入或全域性變數進行就地(in-place)修改。(注意:對於中間結果的就地張量操作,如 add_,在分支中是允許的)。