
(beta) Building a Simple CPU Performance Profiler with FX

Created On: Mar 04, 2021 | Last Updated: Jul 14, 2025 | Last Verified: Not Verified

Author: James Reed

In this tutorial, we are going to use FX to do the following:

  1. Capture PyTorch Python code in a way that lets us inspect and gather statistics about the structure and execution of the code.

  2. Build out a small class that will serve as a simple performance "profiler", collecting runtime statistics about each part of the model from actual runs.

For this tutorial, we are going to use torchvision's ResNet18 model for demonstration purposes.

import torch
import torch.fx
import torchvision.models as models

rn18 = models.resnet18()
rn18.eval()
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)

Now that we have our model, we want to inspect its performance more deeply. That is, for the following invocation, which parts of the model are taking the longest?

input = torch.randn(5, 3, 224, 224)
output = rn18(input)

A common way of answering that question is to go through the program source, add code that collects timestamps at various points in the program, and compare the differences between those timestamps to see how long the regions between them take.
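For example, a manual version of this approach might look like the sketch below. Note that timed_forward is our own illustrative helper (not part of torchvision), and it only times the first two layers:

import time

# Hypothetical manual instrumentation: bracket each sub-module call
# with wall-clock timestamps and report the elapsed time.
def timed_forward(model, x):
    t0 = time.time()
    x = model.conv1(x)   # first convolution
    t1 = time.time()
    x = model.bn1(x)     # first batch norm
    t2 = time.time()
    print(f'conv1: {t1 - t0:.6f}s, bn1: {t2 - t1:.6f}s')
    return x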

That technique is certainly applicable to PyTorch code, but it would be nicer if we didn't have to copy over model code and edit it, especially code we haven't written (like this torchvision model). Instead, we are going to use FX to automate this "instrumentation" process without needing to modify any source.

First, let's get some imports out of the way (we will be using all of these later in the code).

import statistics, tabulate, time
from typing import Any, Dict, List
from torch.fx import Interpreter

Note

tabulate is an external library that is not a dependency of PyTorch. We will be using it to more easily visualize performance data. Please make sure you've installed it from your favorite Python package source.
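If it is not already available in your environment, it can typically be installed with pip:

pip install tabulate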

Capturing the Model with Symbolic Tracing

Next, we are going to use FX's symbolic tracing mechanism to capture the definition of our model in a data structure we can manipulate and examine.

traced_rn18 = torch.fx.symbolic_trace(rn18)
print(traced_rn18.graph)
graph():
    %x : torch.Tensor [num_users=1] = placeholder[target=x]
    %conv1 : [num_users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
    %bn1 : [num_users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {})
    %relu : [num_users=1] = call_module[target=relu](args = (%bn1,), kwargs = {})
    %maxpool : [num_users=2] = call_module[target=maxpool](args = (%relu,), kwargs = {})
    %layer1_0_conv1 : [num_users=1] = call_module[target=layer1.0.conv1](args = (%maxpool,), kwargs = {})
    %layer1_0_bn1 : [num_users=1] = call_module[target=layer1.0.bn1](args = (%layer1_0_conv1,), kwargs = {})
    %layer1_0_relu : [num_users=1] = call_module[target=layer1.0.relu](args = (%layer1_0_bn1,), kwargs = {})
    %layer1_0_conv2 : [num_users=1] = call_module[target=layer1.0.conv2](args = (%layer1_0_relu,), kwargs = {})
    %layer1_0_bn2 : [num_users=1] = call_module[target=layer1.0.bn2](args = (%layer1_0_conv2,), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (%layer1_0_bn2, %maxpool), kwargs = {})
    %layer1_0_relu_1 : [num_users=2] = call_module[target=layer1.0.relu](args = (%add,), kwargs = {})
    %layer1_1_conv1 : [num_users=1] = call_module[target=layer1.1.conv1](args = (%layer1_0_relu_1,), kwargs = {})
    %layer1_1_bn1 : [num_users=1] = call_module[target=layer1.1.bn1](args = (%layer1_1_conv1,), kwargs = {})
    %layer1_1_relu : [num_users=1] = call_module[target=layer1.1.relu](args = (%layer1_1_bn1,), kwargs = {})
    %layer1_1_conv2 : [num_users=1] = call_module[target=layer1.1.conv2](args = (%layer1_1_relu,), kwargs = {})
    %layer1_1_bn2 : [num_users=1] = call_module[target=layer1.1.bn2](args = (%layer1_1_conv2,), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=operator.add](args = (%layer1_1_bn2, %layer1_0_relu_1), kwargs = {})
    %layer1_1_relu_1 : [num_users=2] = call_module[target=layer1.1.relu](args = (%add_1,), kwargs = {})
    %layer2_0_conv1 : [num_users=1] = call_module[target=layer2.0.conv1](args = (%layer1_1_relu_1,), kwargs = {})
    %layer2_0_bn1 : [num_users=1] = call_module[target=layer2.0.bn1](args = (%layer2_0_conv1,), kwargs = {})
    %layer2_0_relu : [num_users=1] = call_module[target=layer2.0.relu](args = (%layer2_0_bn1,), kwargs = {})
    %layer2_0_conv2 : [num_users=1] = call_module[target=layer2.0.conv2](args = (%layer2_0_relu,), kwargs = {})
    %layer2_0_bn2 : [num_users=1] = call_module[target=layer2.0.bn2](args = (%layer2_0_conv2,), kwargs = {})
    %layer2_0_downsample_0 : [num_users=1] = call_module[target=layer2.0.downsample.0](args = (%layer1_1_relu_1,), kwargs = {})
    %layer2_0_downsample_1 : [num_users=1] = call_module[target=layer2.0.downsample.1](args = (%layer2_0_downsample_0,), kwargs = {})
    %add_2 : [num_users=1] = call_function[target=operator.add](args = (%layer2_0_bn2, %layer2_0_downsample_1), kwargs = {})
    %layer2_0_relu_1 : [num_users=2] = call_module[target=layer2.0.relu](args = (%add_2,), kwargs = {})
    %layer2_1_conv1 : [num_users=1] = call_module[target=layer2.1.conv1](args = (%layer2_0_relu_1,), kwargs = {})
    %layer2_1_bn1 : [num_users=1] = call_module[target=layer2.1.bn1](args = (%layer2_1_conv1,), kwargs = {})
    %layer2_1_relu : [num_users=1] = call_module[target=layer2.1.relu](args = (%layer2_1_bn1,), kwargs = {})
    %layer2_1_conv2 : [num_users=1] = call_module[target=layer2.1.conv2](args = (%layer2_1_relu,), kwargs = {})
    %layer2_1_bn2 : [num_users=1] = call_module[target=layer2.1.bn2](args = (%layer2_1_conv2,), kwargs = {})
    %add_3 : [num_users=1] = call_function[target=operator.add](args = (%layer2_1_bn2, %layer2_0_relu_1), kwargs = {})
    %layer2_1_relu_1 : [num_users=2] = call_module[target=layer2.1.relu](args = (%add_3,), kwargs = {})
    %layer3_0_conv1 : [num_users=1] = call_module[target=layer3.0.conv1](args = (%layer2_1_relu_1,), kwargs = {})
    %layer3_0_bn1 : [num_users=1] = call_module[target=layer3.0.bn1](args = (%layer3_0_conv1,), kwargs = {})
    %layer3_0_relu : [num_users=1] = call_module[target=layer3.0.relu](args = (%layer3_0_bn1,), kwargs = {})
    %layer3_0_conv2 : [num_users=1] = call_module[target=layer3.0.conv2](args = (%layer3_0_relu,), kwargs = {})
    %layer3_0_bn2 : [num_users=1] = call_module[target=layer3.0.bn2](args = (%layer3_0_conv2,), kwargs = {})
    %layer3_0_downsample_0 : [num_users=1] = call_module[target=layer3.0.downsample.0](args = (%layer2_1_relu_1,), kwargs = {})
    %layer3_0_downsample_1 : [num_users=1] = call_module[target=layer3.0.downsample.1](args = (%layer3_0_downsample_0,), kwargs = {})
    %add_4 : [num_users=1] = call_function[target=operator.add](args = (%layer3_0_bn2, %layer3_0_downsample_1), kwargs = {})
    %layer3_0_relu_1 : [num_users=2] = call_module[target=layer3.0.relu](args = (%add_4,), kwargs = {})
    %layer3_1_conv1 : [num_users=1] = call_module[target=layer3.1.conv1](args = (%layer3_0_relu_1,), kwargs = {})
    %layer3_1_bn1 : [num_users=1] = call_module[target=layer3.1.bn1](args = (%layer3_1_conv1,), kwargs = {})
    %layer3_1_relu : [num_users=1] = call_module[target=layer3.1.relu](args = (%layer3_1_bn1,), kwargs = {})
    %layer3_1_conv2 : [num_users=1] = call_module[target=layer3.1.conv2](args = (%layer3_1_relu,), kwargs = {})
    %layer3_1_bn2 : [num_users=1] = call_module[target=layer3.1.bn2](args = (%layer3_1_conv2,), kwargs = {})
    %add_5 : [num_users=1] = call_function[target=operator.add](args = (%layer3_1_bn2, %layer3_0_relu_1), kwargs = {})
    %layer3_1_relu_1 : [num_users=2] = call_module[target=layer3.1.relu](args = (%add_5,), kwargs = {})
    %layer4_0_conv1 : [num_users=1] = call_module[target=layer4.0.conv1](args = (%layer3_1_relu_1,), kwargs = {})
    %layer4_0_bn1 : [num_users=1] = call_module[target=layer4.0.bn1](args = (%layer4_0_conv1,), kwargs = {})
    %layer4_0_relu : [num_users=1] = call_module[target=layer4.0.relu](args = (%layer4_0_bn1,), kwargs = {})
    %layer4_0_conv2 : [num_users=1] = call_module[target=layer4.0.conv2](args = (%layer4_0_relu,), kwargs = {})
    %layer4_0_bn2 : [num_users=1] = call_module[target=layer4.0.bn2](args = (%layer4_0_conv2,), kwargs = {})
    %layer4_0_downsample_0 : [num_users=1] = call_module[target=layer4.0.downsample.0](args = (%layer3_1_relu_1,), kwargs = {})
    %layer4_0_downsample_1 : [num_users=1] = call_module[target=layer4.0.downsample.1](args = (%layer4_0_downsample_0,), kwargs = {})
    %add_6 : [num_users=1] = call_function[target=operator.add](args = (%layer4_0_bn2, %layer4_0_downsample_1), kwargs = {})
    %layer4_0_relu_1 : [num_users=2] = call_module[target=layer4.0.relu](args = (%add_6,), kwargs = {})
    %layer4_1_conv1 : [num_users=1] = call_module[target=layer4.1.conv1](args = (%layer4_0_relu_1,), kwargs = {})
    %layer4_1_bn1 : [num_users=1] = call_module[target=layer4.1.bn1](args = (%layer4_1_conv1,), kwargs = {})
    %layer4_1_relu : [num_users=1] = call_module[target=layer4.1.relu](args = (%layer4_1_bn1,), kwargs = {})
    %layer4_1_conv2 : [num_users=1] = call_module[target=layer4.1.conv2](args = (%layer4_1_relu,), kwargs = {})
    %layer4_1_bn2 : [num_users=1] = call_module[target=layer4.1.bn2](args = (%layer4_1_conv2,), kwargs = {})
    %add_7 : [num_users=1] = call_function[target=operator.add](args = (%layer4_1_bn2, %layer4_0_relu_1), kwargs = {})
    %layer4_1_relu_1 : [num_users=1] = call_module[target=layer4.1.relu](args = (%add_7,), kwargs = {})
    %avgpool : [num_users=1] = call_module[target=avgpool](args = (%layer4_1_relu_1,), kwargs = {})
    %flatten : [num_users=1] = call_function[target=torch.flatten](args = (%avgpool, 1), kwargs = {})
    %fc : [num_users=1] = call_module[target=fc](args = (%flatten,), kwargs = {})
    return fc

This gives us a Graph representation of the ResNet18 model. A Graph consists of a series of Nodes connected to each other. Each Node represents a call-site in the Python code (whether to a function, a module, or a method), and the edges (represented as args and kwargs on each Node) represent the values passed between these call-sites. More information about the Graph representation and the rest of FX's APIs can be found in the FX documentation: https://pytorch.com.tw/docs/stable/fx.html
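To get a feel for this data structure, we can walk the Graph ourselves and print a few attributes of each Node. This is purely an exploratory snippet and is not needed for the rest of the tutorial:

# Print each node's opcode, call target, and input arguments.
for node in traced_rn18.graph.nodes:
    print(node.op, node.target, node.args)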

Creating a Profiling Interpreter

Next, we are going to create a class that inherits from torch.fx.Interpreter. Though the GraphModule that symbolic_trace produces compiles Python code that runs when you call the GraphModule, an alternative way to run a GraphModule is to execute each Node in the Graph one by one. That is the functionality Interpreter provides: it interprets the Graph node-by-node.

By inheriting from Interpreter, we can override various pieces of functionality and install the profiling behavior we want. The goal is to have an object to which we can pass a model, invoke the model one or more times, then get statistics about how long the model and its parts took during those runs.

Let's define our ProfilingInterpreter class:

class ProfilingInterpreter(Interpreter):
    def __init__(self, mod : torch.nn.Module):
        # Rather than have the user symbolically trace their model,
        # we're going to do it in the constructor. As a result, the
        # user can pass in any ``Module`` without having to worry about
        # symbolic tracing APIs
        gm = torch.fx.symbolic_trace(mod)
        super().__init__(gm)

        # We are going to store away two things here:
        #
        # 1. A list of total runtimes for ``mod``. In other words, we are
        #    storing away the time ``mod(...)`` took each time this
        #    interpreter is called.
        self.total_runtime_sec : List[float] = []
        # 2. A map from ``Node`` to a list of times (in seconds) that
        #    node took to run. This can be seen as similar to (1) but
        #    for specific sub-parts of the model.
        self.runtimes_sec : Dict[torch.fx.Node, List[float]] = {}

    ######################################################################
    # Next, let's override our first method: ``run()``. ``Interpreter``'s ``run``
    # method is the top-level entry point for execution of the model. We will
    # want to intercept this so that we can record the total runtime of the
    # model.

    def run(self, *args) -> Any:
        # Record the time we started running the model
        t_start = time.time()
        # Run the model by delegating back into Interpreter.run()
        return_val = super().run(*args)
        # Record the time we finished running the model
        t_end = time.time()
        # Store the total elapsed time this model execution took in the
        # ``ProfilingInterpreter``
        self.total_runtime_sec.append(t_end - t_start)
        return return_val

    ######################################################################
    # Now, let's override ``run_node``. ``Interpreter`` calls ``run_node`` each
    # time it executes a single node. We will intercept this so that we
    # can measure and record the time taken for each individual call in
    # the model.

    def run_node(self, n : torch.fx.Node) -> Any:
        # Record the time we started running the op
        t_start = time.time()
        # Run the op by delegating back into Interpreter.run_node()
        return_val = super().run_node(n)
        # Record the time we finished running the op
        t_end = time.time()
        # If we don't have an entry for this node in our runtimes_sec
        # data structure, add one with an empty list value.
        self.runtimes_sec.setdefault(n, [])
        # Record the total elapsed time for this single invocation
        # in the runtimes_sec data structure
        self.runtimes_sec[n].append(t_end - t_start)
        return return_val

    ######################################################################
    # Finally, we are going to define a method (one which doesn't override
    # any ``Interpreter`` method) that provides us a nice, organized view of
    # the data we have collected.

    def summary(self, should_sort : bool = False) -> str:
        # Build up a list of summary information for each node
        node_summaries : List[List[Any]] = []
        # Calculate the mean runtime for the whole network. Because the
        # network may have been called multiple times during profiling,
        # we need to summarize the runtimes. We choose to use the
        # arithmetic mean for this.
        mean_total_runtime = statistics.mean(self.total_runtime_sec)

        # For each node, record summary statistics
        for node, runtimes in self.runtimes_sec.items():
            # Similarly, compute the mean runtime for ``node``
            mean_runtime = statistics.mean(runtimes)
            # For easier understanding, we also compute the percentage
            # time each node took with respect to the whole network.
            pct_total = mean_runtime / mean_total_runtime * 100
            # Record the node's type, name of the node, mean runtime, and
            # percent runtime.
            node_summaries.append(
                [node.op, str(node), mean_runtime, pct_total])

        # One of the most important questions to answer when doing performance
        # profiling is "Which op(s) took the longest?". We can make this easy
        # to see by providing sorting functionality in our summary view
        if should_sort:
            node_summaries.sort(key=lambda s: s[2], reverse=True)

        # Use the ``tabulate`` library to create a well-formatted table
        # presenting our summary information
        headers : List[str] = [
            'Op type', 'Op', 'Average runtime (s)', 'Pct total runtime'
        ]
        return tabulate.tabulate(node_summaries, headers=headers)

Note

We use Python's time.time function to pull wall-clock timestamps and compare them. This is not the most accurate way to measure performance, and will only give us a first-order approximation. We use this simple technique only for demonstration purposes in this tutorial.
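If you want a slightly better clock while keeping the same structure, time.perf_counter is a monotonic, high-resolution alternative that can be dropped into the same places. A minimal sketch (our variation, not what the tutorial code above uses):

t_start = time.perf_counter()  # monotonic, high-resolution timer
output = rn18(input)
t_end = time.perf_counter()
print(f'forward pass took {t_end - t_start:.6f} seconds')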

Investigating the Performance of ResNet18

We can now use ProfilingInterpreter to inspect the performance characteristics of our ResNet18 model:

interp = ProfilingInterpreter(rn18)
interp.run(input)
print(interp.summary(True))
Op type        Op                       Average runtime (s)    Pct total runtime
-------------  ---------------------  ---------------------  -------------------
call_module    maxpool                          0.00485635             8.16969
call_module    conv1                            0.00463676             7.80029
call_module    layer1_0_conv1                   0.00338912             5.70142
call_module    layer1_0_conv2                   0.00322509             5.42547
call_module    layer4_0_conv2                   0.00317121             5.33483
call_module    layer4_1_conv1                   0.00294328             4.95139
call_module    layer1_1_conv1                   0.00290251             4.8828
call_module    layer1_1_conv2                   0.00289297             4.86676
call_module    layer4_1_conv2                   0.00289154             4.86435
call_module    layer2_1_conv2                   0.00269008             4.52544
call_module    layer2_1_conv1                   0.00249147             4.19133
call_module    layer3_1_conv1                   0.00236368             3.97635
call_module    layer2_0_conv2                   0.00229502             3.86084
call_module    layer3_0_conv2                   0.00229478             3.86044
call_module    layer3_1_conv2                   0.00209951             3.53195
call_module    layer4_0_conv1                   0.00189781             3.19263
call_module    layer3_0_conv1                   0.00145459             2.44702
call_module    bn1                              0.00137854             2.31907
call_module    layer2_0_conv1                   0.00126791             2.13297
call_module    layer2_0_downsample_0            0.00077939             1.31115
call_module    layer4_0_downsample_0            0.00050211             0.844684
call_module    layer3_0_downsample_0            0.000460625            0.774895
call_function  add                              0.000433922            0.729974
call_function  add_1                            0.000392914            0.660987
call_module    layer1_0_bn1                     0.000322819            0.543068
call_module    layer1_1_bn2                     0.000308275            0.518602
call_module    layer1_0_bn2                     0.000286818            0.482505
call_module    relu                             0.000284195            0.478093
call_function  add_3                            0.000205278            0.345334
call_module    fc                               0.000194311            0.326884
call_module    layer2_1_bn2                     0.000166893            0.280759
call_module    layer1_1_bn1                     0.000156403            0.263111
call_module    layer1_0_relu_1                  0.000153542            0.258298
call_module    layer2_0_downsample_1            0.000129938            0.218591
call_module    avgpool                          0.000120878            0.20335
call_module    layer4_1_bn2                     0.000114679            0.192922
call_module    layer3_1_bn2                     0.000114202            0.192119
call_module    layer2_1_bn1                     0.000109196            0.183697
call_module    layer2_0_relu                    9.98974e-05            0.168054
call_module    layer1_0_relu                    9.77516e-05            0.164445
call_module    layer2_0_bn2                     9.39369e-05            0.158027
call_module    layer4_0_bn2                     9.39369e-05            0.158027
call_module    layer2_0_bn1                     9.10759e-05            0.153214
call_module    layer4_1_bn1                     8.34465e-05            0.14038
call_module    layer1_1_relu_1                  8.32081e-05            0.139979
call_module    layer3_0_bn2                     8.32081e-05            0.139979
call_module    layer1_1_relu                    7.98702e-05            0.134363
call_function  add_2                            7.96318e-05            0.133962
call_module    layer3_1_bn1                     7.82013e-05            0.131556
call_function  add_5                            7.67708e-05            0.129149
call_module    layer4_0_downsample_1            7.53403e-05            0.126743
output         output                           7.36713e-05            0.123935
call_module    layer4_0_bn1                     6.8903e-05             0.115913
call_module    layer3_0_downsample_1            6.84261e-05            0.115111
call_module    layer3_0_bn1                     6.4373e-05             0.108293
call_function  add_7                            6.41346e-05            0.107892
call_function  add_6                            6.03199e-05            0.101474
call_function  add_4                            5.53131e-05            0.0930516
call_module    layer4_1_relu                    5.24521e-05            0.0882386
call_module    layer4_0_relu                    4.88758e-05            0.0822223
call_module    layer2_0_relu_1                  4.69685e-05            0.0790137
call_module    layer2_1_relu_1                  4.64916e-05            0.0782115
call_module    layer4_0_relu_1                  4.50611e-05            0.075805
call_module    layer4_1_relu_1                  4.43459e-05            0.0746017
call_module    layer2_1_relu                    4.19617e-05            0.0705909
call_module    layer3_1_relu                    3.79086e-05            0.0637724
call_module    layer3_1_relu_1                  3.60012e-05            0.0605638
call_module    layer3_0_relu                    3.55244e-05            0.0597616
call_module    layer3_0_relu_1                  3.52859e-05            0.0593605
call_function  flatten                          2.52724e-05            0.042515
placeholder    x                                1.81198e-05            0.0304824

There are two things we should call out here:

  1. MaxPool2d takes up the most time of any single op in this run, edging out even the largest convolutions.

  2. The convolution modules collectively account for the large majority of the total runtime, which is what we would expect for a convolutional network like ResNet18.
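Beyond these observations, because ProfilingInterpreter appends to its runtime lists on every invocation, we can call run several times and let summary average across the runs for more stable numbers. A short usage sketch (the exact timings will differ from the table above):

# Profile ten forward passes; summary() reports per-node mean runtimes.
for _ in range(10):
    interp.run(input)
print(interp.summary(True))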

Conclusion

As we can see, using FX we can easily capture PyTorch programs (even ones we don't have the source code for!) in a machine-interpretable format and use that for analysis, such as the performance analysis we've done here. FX opens up an exciting world of possibilities for working with PyTorch programs.

Finally, since FX is still in beta, we would be happy to hear any feedback you have about using it. Please feel free to use the PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker (https://github.com/pytorch/pytorch/issues) to provide any feedback you may have.

Total running time of the script: (0 minutes 0.327 seconds)