
PyTorch Benchmark

Created On: Dec 02, 2020 | Last Updated: Sep 23, 2025 | Last Verified: Nov 05, 2024

This recipe provides a quick-start guide to using PyTorch's benchmark module to measure and compare code performance.

Introduction

Benchmarking is an important step in writing code. It helps us validate that our code meets performance expectations, compare different approaches to solving the same problem, and prevent performance regressions.

There are many options when it comes to benchmarking PyTorch code, including the Python builtin timeit module. However, benchmarking PyTorch code has many easily overlooked caveats, such as managing the number of threads and synchronizing CUDA devices. Moreover, generating Tensor inputs for benchmarking can be quite tedious.
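To give a sense of those caveats, the following sketch shows the kind of boilerplate one ends up writing by hand without the benchmark module (our own illustration, assuming a CUDA device is available; the warm-up and iteration counts are arbitrary):

import time
import torch

torch.set_num_threads(1)                   # pin the thread count explicitly

x = torch.randn(10000, 64, device='cuda')

for _ in range(10):                        # warm-up (e.g. lets cuBLAS load)
    x.mul(x).sum(-1)
torch.cuda.synchronize()                   # drain queued kernels before timing

start = time.perf_counter()
for _ in range(100):
    x.mul(x).sum(-1)
torch.cuda.synchronize()                   # wait for the kernels we just launched
print(f'{(time.perf_counter() - start) / 100 * 1e6:.1f} us per run')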

This recipe demonstrates how to use PyTorch's benchmark module to avoid common mistakes while making it easier to compare the performance of different code, generate inputs for benchmarking, and more.

Setup

Before we begin, install torch if it isn't already available.

pip install torch

Steps

  1. Defining functions to benchmark

  2. Benchmarking with timeit.Timer

  3. Benchmarking with torch.utils.benchmark.Timer

  4. Benchmarking with Blocked Autorange

  5. Comparing benchmark results

  6. Saving/Loading benchmark results

  7. Generating inputs with Fuzzed Parameters

  8. Collecting instruction counts with Callgrind

1. Defining functions to benchmark

As of the time of writing, torch.dot does not support batched mode, so we will compare two approaches to implementing it using existing torch operators: one approach uses a combination of mul and sum, while the other reduces the problem to bmm.

import torch


def batched_dot_mul_sum(a, b):
    '''Computes batched dot by multiplying and summing'''
    return a.mul(b).sum(-1)


def batched_dot_bmm(a, b):
    '''Computes batched dot by reducing to ``bmm``'''
    a = a.reshape(-1, 1, a.shape[-1])
    b = b.reshape(-1, b.shape[-1], 1)
    return torch.bmm(a, b).flatten(-3)


# Input for benchmarking
x = torch.randn(10000, 64)

# Ensure that both functions compute the same output
assert batched_dot_mul_sum(x, x).allclose(batched_dot_bmm(x, x))

2. Benchmarking with timeit.Timer

First, let's benchmark the code using Python's builtin timeit module. We keep the benchmark code here simple so that we can compare the defaults of timeit and torch.utils.benchmark.

import timeit

t0 = timeit.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = timeit.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

print(f'mul_sum(x, x):  {t0.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'bmm(x, x):      {t1.timeit(100) / 100 * 1e6:>5.1f} us')
Output:
 mul_sum(x, x):  111.6 us
 bmm(x, x):       70.0 us

3. Benchmarking with torch.utils.benchmark.Timer

PyTorch's benchmark module was designed to be familiar to those who have used the timeit module before. However, its defaults make it easier and safer to use for benchmarking PyTorch code. Let's first compare the same basic API as above.

import torch.utils.benchmark as benchmark

t0 = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = benchmark.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

print(t0.timeit(100))
print(t1.timeit(100))
Output:
 <torch.utils.benchmark.utils.common.Measurement object at 0x7fb10400d0f0>
 batched_dot_mul_sum(x, x)
 setup: from __main__ import batched_dot_mul_sum
   379.29 us
   1 measurement, 100 runs , 1 thread
 <torch.utils.benchmark.utils.common.Measurement object at 0x7fb103d67048>
 batched_dot_bmm(x, x)
 setup: from __main__ import batched_dot_bmm
   716.42 us
   1 measurement, 100 runs , 1 thread

Even though the basic functionality of the API is the same, there are some important differences. benchmark.Timer.timeit() returns the time per run as opposed to the total runtime like timeit.Timer.timeit() does. PyTorch's benchmark module also provides formatted string representations for printing the results.
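To make the difference concrete, here is a small sketch of our own showing how the two return values relate (reusing x and batched_dot_mul_sum from above):

import timeit
import torch.utils.benchmark as benchmark

# ``timeit`` returns the TOTAL time for all runs; divide to get per-run time:
t_py = timeit.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})
per_run_us = t_py.timeit(100) / 100 * 1e6

# A ``Measurement`` already stores the per-run mean:
t_pt = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})
per_run_us = t_pt.timeit(100).mean * 1e6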

Another important difference, and the reason why the results diverge, is that PyTorch's benchmark module runs in a single thread by default. We can change the number of threads with the num_threads argument.

torch.utils.benchmark.Timer takes several additional arguments, including label, sub_label, description, and env, which change the __repr__ of the returned measurement object and are used for grouping the results (more on this later).

num_threads = torch.get_num_threads()
print(f'Benchmarking on {num_threads} threads')

t0 = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x},
    num_threads=num_threads,
    label='Multithreaded batch dot',
    sub_label='Implemented using mul and sum')

t1 = benchmark.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x},
    num_threads=num_threads,
    label='Multithreaded batch dot',
    sub_label='Implemented using bmm')

print(t0.timeit(100))
print(t1.timeit(100))
Output:
 Benchmarking on 40 threads
 <torch.utils.benchmark.utils.common.Measurement object at 0x7fb103d54080>
 Multithreaded batch dot: Implemented using mul and sum
 setup: from __main__ import batched_dot_mul_sum
   118.47 us
   1 measurement, 100 runs , 40 threads
 <torch.utils.benchmark.utils.common.Measurement object at 0x7fb16935d2e8>
 Multithreaded batch dot: Implemented using bmm
 setup: from __main__ import batched_dot_bmm
   68.21 us
   1 measurement, 100 runs , 40 threads

Running the benchmark with all available threads gives similar results to the timeit module. More importantly, which version is faster depends on how many threads we run the code with. That is why it's important to benchmark the code with thread settings that are representative of the real use case. Another important thing to remember is to synchronize CPU and CUDA when benchmarking on a GPU. Let's run the above benchmarks again on CUDA tensors and see what happens.

x = torch.randn(10000, 1024, device='cuda')

t0 = timeit.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = timeit.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

# Ran each twice to show difference before/after warm-up
print(f'mul_sum(x, x):  {t0.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'mul_sum(x, x):  {t0.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'bmm(x, x):      {t1.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'bmm(x, x):      {t1.timeit(100) / 100 * 1e6:>5.1f} us')
Output:
 mul_sum(x, x):   27.6 us
 mul_sum(x, x):   25.3 us
 bmm(x, x):      2775.5 us
 bmm(x, x):       22.4 us
t0 = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = benchmark.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

# Run only once since benchmark module does warm-up for us
print(t0.timeit(100))
print(t1.timeit(100))
Output:
 <torch.utils.benchmark.utils.common.Measurement object at 0x7fb10400d080>
 batched_dot_mul_sum(x, x)
 setup: from __main__ import batched_dot_mul_sum
   232.93 us
   1 measurement, 100 runs , 1 thread
 <torch.utils.benchmark.utils.common.Measurement object at 0x7fb10400d0f0>
 batched_dot_bmm(x, x)
 setup: from __main__ import batched_dot_bmm
   181.04 us
   1 measurement, 100 runs , 1 thread

The results reveal something interesting. The first run of the bmm version using the timeit module takes much longer than the second. This is because bmm calls into cuBLAS, which needs to be loaded the first time it's called, and that loading takes some time. This is why it's important to do a warm-up run before benchmarking; luckily, PyTorch's benchmark module takes care of that for us.

timeitbenchmark 模組之間的結果差異是因為 timeit 模組沒有同步 CUDA,因此只對核心啟動時間進行計時。PyTorch 的 benchmark 模組為我們進行了同步。

4. Benchmarking with Blocked Autorange

While timeit.Timer.autorange takes a single continuous measurement of at least 0.2 seconds, torch.utils.benchmark.Timer.blocked_autorange takes many measurements whose times total at least 0.2 seconds (which can be changed by the min_run_time parameter), subject to the constraint that measurement overhead is a small fraction of the overall measurement. This is accomplished by first running with an increasing number of runs per loop until the runtime is much larger than the measurement overhead (which also serves as a warm-up), and then taking measurements until the target time is reached. This has the useful property that it wastes less data and allows us to compute statistics to estimate the reliability of the measurements.
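Conceptually, the procedure looks roughly like the following sketch (a deliberate simplification for illustration, not the actual implementation; the overhead threshold is arbitrary):

import time

def blocked_autorange_sketch(fn, min_run_time=0.2):
    # Phase 1: grow the block size until timing overhead is negligible
    # relative to the block runtime (this doubles as a warm-up).
    number = 1
    while True:
        start = time.perf_counter()
        for _ in range(number):
            fn()
        if time.perf_counter() - start > 1e-4:   # illustrative threshold
            break
        number *= 10

    # Phase 2: time whole blocks until the total reaches ``min_run_time``.
    times, total = [], 0.0
    while total < min_run_time:
        start = time.perf_counter()
        for _ in range(number):
            fn()
        t = time.perf_counter() - start
        times.append(t / number)                 # per-run time for this block
        total += t
    return times

In practice, of course, we simply call the method: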

m0 = t0.blocked_autorange()
m1 = t1.blocked_autorange()

print(m0)
print(m1)
Output:
 <torch.utils.benchmark.utils.common.Measurement object at 0x7fb10400d0f0>
 batched_dot_mul_sum(x, x)
 setup: from __main__ import batched_dot_mul_sum
   231.79 us
   1 measurement, 1000 runs , 1 thread
 <torch.utils.benchmark.utils.common.Measurement object at 0x7fb10400d080>
 batched_dot_bmm(x, x)
 setup: from __main__ import batched_dot_bmm
   Median: 162.08 us
   2 measurements, 1000 runs per measurement, 1 thread

We can also inspect the individual statistics from the returned measurement object.

print(f"Mean:   {m0.mean * 1e6:6.2f} us")
print(f"Median: {m0.median * 1e6:6.2f} us")
Output:
 Mean:   231.79 us
 Median: 231.79 us
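A Measurement also carries the underlying per-block times, so you can compute your own statistics when the built-in mean and median aren't enough (continuing with m0 from above; times and iqr are Measurement attributes):

print(f"Blocks measured: {len(m0.times)}")   # per-block per-run times, in seconds
print(f"IQR:    {m0.iqr * 1e6:6.2f} us")     # spread across the measurements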

5. Comparing benchmark results

So far we've been comparing our two versions of batched dot against a single input. In practice, we want to try a combination of inputs as well as different numbers of threads. The Compare class helps display the results of many measurements in a formatted table. It uses the annotations described above (label, sub_label, num_threads, etc.) as well as description to group and organize the table. Let's use Compare to see how our functions perform for different input sizes and numbers of threads.

from itertools import product

# Compare takes a list of measurements which we'll save in results.
results = []

sizes = [1, 64, 1024, 10000]
for b, n in product(sizes, sizes):
    # label and sub_label are the rows
    # description is the column
    label = 'Batched dot'
    sub_label = f'[{b}, {n}]'
    x = torch.ones((b, n))
    for num_threads in [1, 4, 16, 32]:
        results.append(benchmark.Timer(
            stmt='batched_dot_mul_sum(x, x)',
            setup='from __main__ import batched_dot_mul_sum',
            globals={'x': x},
            num_threads=num_threads,
            label=label,
            sub_label=sub_label,
            description='mul/sum',
        ).blocked_autorange(min_run_time=1))
        results.append(benchmark.Timer(
            stmt='batched_dot_bmm(x, x)',
            setup='from __main__ import batched_dot_bmm',
            globals={'x': x},
            num_threads=num_threads,
            label=label,
            sub_label=sub_label,
            description='bmm',
        ).blocked_autorange(min_run_time=1))

compare = benchmark.Compare(results)
compare.print()
Output:
 [--------------- Batched dot ----------------]
                       |  mul/sum   |    bmm
 1 threads: -----------------------------------
       [1, 1]          |       5.9  |      11.2
       [1, 64]         |       6.4  |      11.4
       [1, 1024]       |       6.7  |      14.2
       [1, 10000]      |      10.2  |      23.7
       [64, 1]         |       6.3  |      11.5
       [64, 64]        |       8.6  |      15.4
       [64, 1024]      |      39.4  |     204.4
       [64, 10000]     |     274.9  |     748.5
       [1024, 1]       |       7.7  |      17.8
       [1024, 64]      |      40.3  |      76.4
       [1024, 1024]    |     432.4  |    2795.9
       [1024, 10000]   |   22657.3  |   11899.5
       [10000, 1]      |      16.9  |      74.8
       [10000, 64]     |     300.3  |     609.4
       [10000, 1024]   |   23098.6  |   27246.1
       [10000, 10000]  |  267073.7  |  118823.7
 4 threads: -----------------------------------
       [1, 1]          |       6.0  |      11.5
       [1, 64]         |       6.2  |      11.2
       [1, 1024]       |       6.8  |      14.3
       [1, 10000]      |      10.2  |      23.7
       [64, 1]         |       6.3  |      16.2
       [64, 64]        |       8.8  |      18.2
       [64, 1024]      |      41.5  |     189.1
       [64, 10000]     |      91.7  |     849.1
       [1024, 1]       |       7.6  |      17.4
       [1024, 64]      |      43.5  |      33.5
       [1024, 1024]    |     135.4  |    2782.3
       [1024, 10000]   |    7471.1  |   11874.0
       [10000, 1]      |      16.8  |      33.9
       [10000, 64]     |     118.7  |     173.2
       [10000, 1024]   |    7264.6  |   27824.7
       [10000, 10000]  |  100060.9  |  121499.0
 16 threads: ----------------------------------
       [1, 1]          |       6.0  |      11.3
       [1, 64]         |       6.2  |      11.2
       [1, 1024]       |       6.9  |      14.2
       [1, 10000]      |      10.3  |      23.8
       [64, 1]         |       6.4  |      24.1
       [64, 64]        |       9.0  |      23.8
       [64, 1024]      |      54.1  |     188.5
       [64, 10000]     |      49.9  |     748.0
       [1024, 1]       |       7.6  |      23.4
       [1024, 64]      |      55.5  |      28.2
       [1024, 1024]    |      66.9  |    2773.9
       [1024, 10000]   |    6111.5  |   12833.7
       [10000, 1]      |      16.9  |      27.5
       [10000, 64]     |      59.5  |      73.7
       [10000, 1024]   |    6295.9  |   27062.0
       [10000, 10000]  |   71804.5  |  120365.8
 32 threads: ----------------------------------
       [1, 1]          |       5.9  |      11.3
       [1, 64]         |       6.2  |      11.3
       [1, 1024]       |       6.7  |      14.2
       [1, 10000]      |      10.5  |      23.8
       [64, 1]         |       6.3  |      31.7
       [64, 64]        |       9.1  |      30.4
       [64, 1024]      |      72.0  |     190.4
       [64, 10000]     |     103.1  |     746.9
       [1024, 1]       |       7.6  |      28.4
       [1024, 64]      |      70.5  |      31.9
       [1024, 1024]    |      65.6  |    2804.6
       [1024, 10000]   |    6764.0  |   11871.4
       [10000, 1]      |      17.8  |      31.8
       [10000, 64]     |     110.3  |      56.0
       [10000, 1024]   |    6640.2  |   27592.2
       [10000, 10000]  |   73003.4  |  120083.2

 Times are in microseconds (us).

The results above indicate that the version which reduces to bmm is better for larger tensors running on multiple threads, while for smaller and/or single-threaded code the other version is better.

Compare also provides functions for changing the table format.

compare.trim_significant_figures()
compare.colorize()
compare.print()

6. Saving/Loading benchmark results

Measurements (and CallgrindStats, which are described in section 8) can be serialized by the pickle module. This makes A/B testing easy, as you can collect measurements from two separate environments, pickle them, and then load both in a single environment. Timer even takes an env constructor argument so that such A/B testing works seamlessly.

Let's imagine that rather than two Python functions, the add/sum and bmm approaches were in two different builds of PyTorch. The example below demonstrates how one might A/B test them. For simplicity, we only use a subset of shapes, and simply round trip the results through pickle rather than actually using multiple environments and writing the results to disk.

import pickle

ab_test_results = []
for env in ('environment A: mul/sum', 'environment B: bmm'):
    for b, n in ((1, 1), (1024, 10000), (10000, 1)):
        x = torch.ones((b, n))
        dot_fn = (batched_dot_mul_sum if env == 'environment A: mul/sum' else batched_dot_bmm)
        m = benchmark.Timer(
            stmt='batched_dot(x, x)',
            globals={'x': x, 'batched_dot': dot_fn},
            num_threads=1,
            label='Batched dot',
            description=f'[{b}, {n}]',
            env=env,
        ).blocked_autorange(min_run_time=1)
        ab_test_results.append(pickle.dumps(m))

ab_results = [pickle.loads(i) for i in ab_test_results]
compare = benchmark.Compare(ab_results)
compare.trim_significant_figures()
compare.colorize()
compare.print()
Output:
 [------------------------------------- Batched dot -------------------------------------]
                                                |  [1, 1]  |  [1024, 10000]  |  [10000, 1]
 1 threads: ------------------------------------------------------------------------------
   (environment A: mul/sum)  batched_dot(x, x)  |     7    |      36000      |      21
   (environment B: bmm)      batched_dot(x, x)  |    14    |      40000      |      85

 Times are in microseconds (us).
# And just to show that we can round trip all of the results from earlier:
round_tripped_results = pickle.loads(pickle.dumps(results))
assert(str(benchmark.Compare(results)) == str(benchmark.Compare(round_tripped_results)))
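In a real A/B test, each environment would of course persist its measurements to disk rather than keep them in memory; a minimal sketch (the file name is our own):

import pickle

# In environment A: write the pickled measurements to a file.
with open('env_a_results.pkl', 'wb') as f:
    pickle.dump(ab_results, f)

# Later, wherever the comparison happens: load them back.
with open('env_a_results.pkl', 'rb') as f:
    loaded_results = pickle.load(f)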

7. Generating inputs with Fuzzed Parameters

As we saw in the previous section, there can be significant performance differences depending on the input tensors. Hence, it is a good idea to run benchmarks on a number of different inputs. However, creating all of these input tensors can be tedious, which is where torch.utils.benchmark.Fuzzer and related classes come in. Let's take a look at how we can use the Fuzzer to create some test cases for the benchmark.

from torch.utils.benchmark import Fuzzer, FuzzedParameter, FuzzedTensor, ParameterAlias

# Generates random tensors with 128 to 10000000 elements and sizes k0 and k1 chosen from a
# ``loguniform`` distribution in [1, 10000], 40% of which will be discontiguous on average.
example_fuzzer = Fuzzer(
    parameters = [
        FuzzedParameter('k0', minval=1, maxval=10000, distribution='loguniform'),
        FuzzedParameter('k1', minval=1, maxval=10000, distribution='loguniform'),
    ],
    tensors = [
        FuzzedTensor('x', size=('k0', 'k1'), min_elements=128, max_elements=10000000, probability_contiguous=0.6)
    ],
    seed=0,
)

results = []
for tensors, tensor_params, params in example_fuzzer.take(10):
    # description is the column label
    sub_label=f"{params['k0']:<6} x {params['k1']:<4} {'' if tensor_params['x']['is_contiguous'] else '(discontiguous)'}"
    results.append(benchmark.Timer(
        stmt='batched_dot_mul_sum(x, x)',
        setup='from __main__ import batched_dot_mul_sum',
        globals=tensors,
        label='Batched dot',
        sub_label=sub_label,
        description='mul/sum',
    ).blocked_autorange(min_run_time=1))
    results.append(benchmark.Timer(
        stmt='batched_dot_bmm(x, x)',
        setup='from __main__ import batched_dot_bmm',
        globals=tensors,
        label='Batched dot',
        sub_label=sub_label,
        description='bmm',
    ).blocked_autorange(min_run_time=1))

compare = benchmark.Compare(results)
compare.trim_significant_figures()
compare.print()
Output:
 [--------------------- Batched dot ---------------------]
                                      |  mul/sum  |   bmm
 1 threads: ----------------------------------------------
       725    x 257                   |      87   |    180
       49     x 383                   |      15   |     30
       34     x 1468                  |      30   |    118
       187    x 5039                  |     400   |   1200
       2140   x 1296 (discontiguous)  |    2000   |  41000
       78     x 1598                  |      74   |    310
       519    x 763                   |     190   |   1500
       141    x 1082                  |      87   |    500
       78     x 5    (discontiguous)  |       9   |     20
       187    x 1                     |      12   |     10

 Times are in microseconds (us).

There is a lot of flexibility in defining custom fuzzers, which is great for creating a powerful set of inputs to benchmark. But to make things even simpler, the PyTorch benchmark module comes with some built-in fuzzers for common benchmarking needs. Let's take a look at how we can use one of these built-in fuzzers.

from torch.utils.benchmark.op_fuzzers import binary

results = []
for tensors, tensor_params, params in binary.BinaryOpFuzzer(seed=0).take(10):
    sub_label=f"{params['k0']:<6} x {params['k1']:<4} {'' if tensor_params['x']['is_contiguous'] else '(discontiguous)'}"
    results.append(benchmark.Timer(
        stmt='batched_dot_mul_sum(x, x)',
        setup='from __main__ import batched_dot_mul_sum',
        globals=tensors,
        label='Batched dot',
        sub_label=sub_label,
        description='mul/sum',
    ).blocked_autorange(min_run_time=1))
    results.append(benchmark.Timer(
        stmt='batched_dot_bmm(x, x)',
        setup='from __main__ import batched_dot_bmm',
        globals=tensors,
        label='Batched dot',
        sub_label=sub_label,
        description='bmm',
    ).blocked_autorange(min_run_time=1))

compare = benchmark.Compare(results)
compare.trim_significant_figures()
compare.colorize(rowwise=True)
compare.print()
Output:
 [----------------------- Batched dot ------------------------]
                                          |  mul/sum  |   bmm
 1 threads: ---------------------------------------------------
       64     x 473  (discontiguous)      |    10000  |   40000
       16384  x 12642115 (discontiguous)  |       31  |      78
       8192   x 892                       |     4800  |   20400
       512    x 64   (discontiguous)      |   110000  |  400000
       493    x 27   (discontiguous)      |     1100  |    2440
       118    x 32   (discontiguous)      |      870  |    2030
       16     x 495  (discontiguous)      |    23600  |   24000
       488    x 62374                     |    90000  |  100000
       240372 x 69                        |    40000  |   16000
       40156  x 32   (discontiguous)      |     2670  |    5000

 Times are in microseconds (us).

8. Collecting instruction counts with Callgrind

One of the challenges of optimizing code is the variation and opacity of wall time. There are many sources of non-determinism, from adaptive clock speeds to resource contention with other processes. Furthermore, end-to-end time gives no insight into where time is being spent, which is really what we're interested in when optimizing code.

A complementary approach is to also collect instruction counts. These counts are a proxy metric and do not capture all aspects of performance (e.g. memory- or I/O-bound tasks), but they do have several useful properties. Instruction counts are reproducible, insensitive to environmental variation, and offer fine-grained insight into where a program is spending cycles.

To see the utility of instruction counts, let's look at how we might reduce the overhead of batched_dot_mul_sum. The obvious solution is to move it to C++, so that we avoid going back and forth between Python and C++ multiple times.

Fortunately, the source is nearly identical. One of the questions we have to ask in C++ is whether we should take the arguments by value or by reference.

batched_dot_src = """\
/* ---- Python ---- */
// def batched_dot_mul_sum(a, b):
//     return a.mul(b).sum(-1)

torch::Tensor batched_dot_mul_sum_v0(
    const torch::Tensor a,
    const torch::Tensor b) {
  return a.mul(b).sum(-1);
}

torch::Tensor batched_dot_mul_sum_v1(
    const torch::Tensor& a,
    const torch::Tensor& b) {
  return a.mul(b).sum(-1);
}
"""


# PyTorch makes it easy to test our C++ implementations by providing a utility
# to JIT compile C++ source into Python extensions:
import os
from torch.utils import cpp_extension
cpp_lib = cpp_extension.load_inline(
    name='cpp_lib',
    cpp_sources=batched_dot_src,
    extra_cflags=['-O3'],
    extra_include_paths=[
        # `load_inline` needs to know where to find ``pybind11`` headers.
        os.path.join(os.getenv('CONDA_PREFIX'), 'include')
    ],
    functions=['batched_dot_mul_sum_v0', 'batched_dot_mul_sum_v1']
)

# `load_inline` will create a shared object that is loaded into Python. When we collect
# instruction counts Timer will create a subprocess, so we need to re-import it. The
# import process is slightly more complicated for C extensions, but that's all we're
# doing here.
module_import_str = f"""\
# https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path
import importlib.util
spec = importlib.util.spec_from_file_location("cpp_lib", {repr(cpp_lib.__file__)})
cpp_lib = importlib.util.module_from_spec(spec)
spec.loader.exec_module(cpp_lib)"""

import textwrap
def pretty_print(result):
    """Import machinery for ``cpp_lib.so`` can get repetitive to look at."""
    print(repr(result).replace(textwrap.indent(module_import_str, "  "), "  import cpp_lib"))


t_baseline = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='''\
from __main__ import batched_dot_mul_sum
x = torch.randn(2, 2)''')

t0 = benchmark.Timer(
    stmt='cpp_lib.batched_dot_mul_sum_v0(x, x)',
    setup=f'''\
{module_import_str}
x = torch.randn(2, 2)''')

t1 = benchmark.Timer(
    stmt='cpp_lib.batched_dot_mul_sum_v1(x, x)',
    setup=f'''\
{module_import_str}
x = torch.randn(2, 2)''')

# Moving to C++ did indeed reduce overhead, but it's hard to tell which
# calling convention is more efficient. v1 (call with references) seems to
# be a bit faster, but it's within measurement error.
pretty_print(t_baseline.blocked_autorange())
pretty_print(t0.blocked_autorange())
pretty_print(t1.blocked_autorange())
Output:
 <torch.utils.benchmark.utils.common.Measurement object at 0x7fb16935d2e8>
 batched_dot_mul_sum(x, x)
 setup:
   from __main__ import batched_dot_mul_sum
   x = torch.randn(2, 2)

   6.92 us
   1 measurement, 100000 runs , 1 thread
 <torch.utils.benchmark.utils.common.Measurement object at 0x7fb16935d2e8>
 cpp_lib.batched_dot_mul_sum_v0(x, x)
 setup:
   import cpp_lib
   x = torch.randn(2, 2)

   5.29 us
   1 measurement, 100000 runs , 1 thread
 <torch.utils.benchmark.utils.common.Measurement object at 0x7fb16935d2e8>
 cpp_lib.batched_dot_mul_sum_v1(x, x)
 setup:
   import cpp_lib
   x = torch.randn(2, 2)

   5.22 us
   1 measurement, 100000 runs , 1 thread
# Let's use ``Callgrind`` to determine which is better.
stats_v0 = t0.collect_callgrind()
stats_v1 = t1.collect_callgrind()

pretty_print(stats_v0)
pretty_print(stats_v1)

# `.as_standardized` removes file names and some path prefixes, and makes
# it easier to read the function symbols.
stats_v0 = stats_v0.as_standardized()
stats_v1 = stats_v1.as_standardized()

# `.delta` diffs the instruction counts, and `.denoise` removes several
# functions in the Python interpreter that are known to have significant
# jitter.
delta = stats_v1.delta(stats_v0).denoise()

# `.transform` is a convenience API for transforming function names. It is
# useful for increasing cancelation when ``diff-ing`` instructions, as well as
# just generally improving readability.
replacements = (
    ("???:void pybind11", "pybind11"),
    ("batched_dot_mul_sum_v0", "batched_dot_mul_sum_v1"),
    ("at::Tensor, at::Tensor", "..."),
    ("at::Tensor const&, at::Tensor const&", "..."),
    ("auto torch::detail::wrap_pybind_function_impl_", "wrap_pybind_function_impl_"),
)
for before, after in replacements:
    delta = delta.transform(lambda l: l.replace(before, after))

# We can use print options to control how much of the function to display.
torch.set_printoptions(linewidth=160)

# Once parsed, the instruction counts make clear that passing `a` and `b`
# by reference is more efficient as it skips some ``c10::TensorImpl`` bookkeeping
# for the intermediate Tensors, and also works better with ``pybind11``. This
# is consistent with our noisy wall time observations.
print(delta)
Output:
 <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.CallgrindStats object at 0x7fb0f06e7630>
cpp_lib.batched_dot_mul_sum_v0(x, x)
setup:
  import cpp_lib
  x = torch.randn(2, 2)
                           All          Noisy symbols removed
    Instructions:      2392671                    2392671
    Baseline:             4367                       4367
100 runs per measurement, 1 thread
Warning: PyTorch was not built with debug symbols.
         Source information may be limited. Rebuild with
         REL_WITH_DEB_INFO=1 for more detailed results.
<torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.CallgrindStats object at 0x7fb10400d208>
cpp_lib.batched_dot_mul_sum_v1(x, x)
setup:
  import cpp_lib
  x = torch.randn(2, 2)
                           All          Noisy symbols removed
    Instructions:      2378978                    2378978
    Baseline:             4367                       4367
    100 runs per measurement, 1 thread
    Warning: PyTorch was not built with debug symbols.
             Source information may be limited. Rebuild with
             REL_WITH_DEB_INFO=1 for more detailed results.
    <torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7fb1000ab358>
      86  ???:0x000000000020d9e0
      56  ???:0x000000000020db10
   -1100  pybind11::cpp_function::initialize<wrap_pybind_function_impl_<at::Tensor ... r (&)(...), std::integer_sequence<unsigned long, 0ul, 1ul>)::{lambda(...)
   -1600  ???:wrap_pybind_function_impl_<at::Tensor (&)(...), 0ul, 1ul>(at::Tensor (&)(...), std::integer_sequence<unsigned long, 0ul, 1ul>)::{lambda(...)
   -5200  ???:c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::reset_()
   -5935  ???:0x000000000022c0e0
Total: -13693

Learn More

Take a look at these other recipes to continue your learning: