評價此頁

控制流 - Cond#

創建於: 2023年10月03日 | 最後更新於: 2025年06月13日

torch.cond 是一個結構化控制流運算元。它可以用於指定 if-else 類的控制流,並且在邏輯上可以看作是如下實現的。

def cond(
    pred: Union[bool, torch.Tensor],
    true_fn: Callable,
    false_fn: Callable,
    operands: Tuple[torch.Tensor]
):
    if pred:
        return true_fn(*operands)
    else:
        return false_fn(*operands)

它獨特的力量在於其表達 **資料依賴型控制流** 的能力:它會降低為一個條件運算元(torch.ops.higher_order.cond),該運算元保留了謂詞、真函式和假函式。這極大地提高了編寫和部署那些根據張量操作的輸入或中間輸出的 **值** 或 **形狀** 來改變模型架構的模型所帶來的靈活性。

警告

torch.cond 是 PyTorch 中的一個原型功能。它對輸入和輸出型別支援有限,並且目前不支援訓練。請期待 PyTorch 未來版本中更穩定的實現。有關功能分類的更多資訊,請參閱:https://pytorch.com.tw/blog/pytorch-feature-classification-changes/#prototype

示例#

下面是一個使用 cond 根據輸入形狀進行分支的示例

    import torch

    def true_fn(x: torch.Tensor):
        return x.cos() + x.sin()

    def false_fn(x: torch.Tensor):
        return x.sin()

    class DynamicShapeCondPredicate(torch.nn.Module):
        """
        A basic usage of cond based on dynamic shape predicate.
        """

        def __init__(self):
            super().__init__()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            def true_fn(x: torch.Tensor):
                return x.cos()

            def false_fn(x: torch.Tensor):
                return x.sin()

            return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,))

    dyn_shape_mod = DynamicShapeCondPredicate()

我們可以立即執行模型,並期望結果根據輸入形狀而變化

    inp = torch.randn(3)
    inp2 = torch.randn(5)
    assert torch.equal(dyn_shape_mod(inp), false_fn(inp))
    assert torch.equal(dyn_shape_mod(inp2), true_fn(inp2))

我們可以匯出模型以進行進一步的轉換和部署

    inp = torch.randn(4, 3)
    dim_batch = torch.export.Dim("batch", min=2)
    ep = torch.export.export(DynamicShapeCondPredicate(), (inp,), {}, dynamic_shapes={"x": {0: dim_batch}})
    print(ep)

這將為我們提供一個匯出的程式,如下所示

    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            sym_size: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 0)
            gt: Sym(s0 > 4) = sym_size > 4;  sym_size = None
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = None
            return (conditional,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
                return add

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                return sin

請注意,torch.cond 被降低為 torch.ops.higher_order.cond,其謂詞成為輸入形狀上的符號表達式,分支函式成為頂級圖模組的兩個子圖屬性。

這是另一個示例,展示瞭如何表達資料依賴型控制流

    class DataDependentCondPredicate(torch.nn.Module):
        """
        A basic usage of cond based on data dependent predicate.
        """
        def __init__(self):
            super().__init__()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.cond(x.sum() > 4.0, true_fn, false_fn, (x,))

匯出後我們得到的匯出程式

    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 3]):
            sum_1: f32[] = torch.ops.aten.sum.default(arg0_1)
            gt: b8[] = torch.ops.aten.gt.Scalar(sum_1, 4.0);  sum_1 = None

            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]);  gt = true_graph_0 = false_graph_0 = arg0_1 = None
            return (conditional,)

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
                return add

        class <lambda>(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 3]):
                sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                return sin

torch.ops.higher_order.cond 的不變式#

對於 torch.ops.higher_order.cond 有幾個有用的不變式

  • 對於謂詞

    • 謂詞的動態性被保留(例如,上面示例中所示的 gt

    • 如果使用者程式中的謂詞是常量(例如,一個 Python 的 bool 常量),則運算元的 pred 將是一個常量。

  • 對於分支

    • 輸入和輸出簽名將是一個展平的元組。

    • 它們是 torch.fx.GraphModule

    • 原始函式中的閉包成為顯式輸入。沒有閉包。

    • 不允許對輸入或全域性變數進行修改。

  • 對於運算元

    • 它也將是一個扁平的元組。

  • 使用者程式中 torch.cond 的巢狀將成為巢狀的圖模組。

API 參考#

torch._higher_order_ops.cond.cond(pred, true_fn, false_fn, operands=())[source]#

有條件地應用 true_fnfalse_fn

警告

torch.cond 是 PyTorch 中的一個原型功能。它對輸入和輸出型別支援有限,並且目前不支援訓練。請期待 PyTorch 未來版本中更穩定的實現。有關功能分類的更多資訊,請參閱:https://pytorch.com.tw/blog/pytorch-feature-classification-changes/#prototype

cond 是結構化控制流運算元。也就是說,它類似於 Python 的 if 語句,但對 true_fnfalse_fnoperands 有限制,這些限制使其能夠被 torch.compile 和 torch.export 捕獲。

假設滿足 cond 引數的約束條件,cond 等價於以下內容

def cond(pred, true_branch, false_branch, operands):
    if pred:
        return true_branch(*operands)
    else:
        return false_branch(*operands)
引數
  • pred (Union[bool, torch.Tensor]) – 一個布林表示式或一個只有一個元素的張量,指示應用哪個分支函式。

  • true_fn (Callable) – 一個可呼叫函式(a -> b),它在被跟蹤的範圍內。

  • false_fn (Callable) – 一個可呼叫函式(a -> b),它在被跟蹤的範圍內。真分支和假分支必須具有一致的輸入和輸出,這意味著輸入必須相同,輸出必須是相同的型別和形狀。也允許 int 輸出。我們將透過將其轉換為 symint 來使輸出動態化。

  • operands (Tuple of possibly nested dict/list/tuple of torch.Tensor) – 一個輸入真/假函式的元組。如果 true_fn/false_fn 不需要輸入,則可以為空。預設為 ()。

返回型別

任何

示例

def true_fn(x: torch.Tensor):
    return x.cos()


def false_fn(x: torch.Tensor):
    return x.sin()


return cond(x.shape[0] > 4, true_fn, false_fn, (x,))
限制
  • 條件語句(又名 pred)必須滿足以下任一約束:

    • 它是一個只有一個元素的 torch.Tensor,且 dtype 為 torch.bool

    • 它是一個布林表示式,例如 x.shape[0] > 10x.dim() > 1 and x.shape[1] > 10

  • 分支函式(又名 true_fn/false_fn)必須滿足以下所有約束:

    • 函式簽名必須與運算元匹配。

    • 函式必須返回具有相同元資料(例如,形狀、dtype 等)的張量。

    • 函式不能對輸入或全域性變數進行原地修改。(注意:像 add_ 這樣的原地張量操作用於中間結果是被允許在分支中使用的)