注意
跳轉至末尾 下載完整的示例程式碼。
PyTorch 中的 Channels Last 記憶體格式#
建立日期:2020 年 4 月 20 日 | 最後更新:2025 年 7 月 9 日 | 最後驗證:2024 年 11 月 5 日
PyTorch 中的 Channels Last 記憶體格式是什麼?
它如何用於提高某些運算元的效能?
PyTorch v1.5.0
支援 CUDA 的 GPU
Channels Last 記憶體格式是 NCHW 張量在記憶體中排序的一種替代方式,它保留了維度的順序。Channels Last 張量排序的方式使得通道(channels)成為最密集(densest)的維度(即逐畫素儲存影像)。
例如,NCHW 張量(在本例中是兩個 4x4 的具有 3 個顏色通道的影像)的經典(連續)儲存方式如下:
Channels Last 記憶體格式的排序方式不同
PyTorch 透過利用現有的 stride 結構來支援記憶體格式。例如,Channels Last 格式下的 10x3x16x16 的 batch 張量將具有等於 (768, 1, 48, 3) 的 stride。
Channels Last 記憶體格式僅為 4D NCHW 張量實現。
記憶體格式 API#
以下是如何在連續(contiguous)和 Channels Last 記憶體格式之間轉換張量的方法。
經典的 PyTorch 連續張量
import torch
N, C, H, W = 10, 3, 32, 32
x = torch.empty(N, C, H, W)
print(x.stride()) # Outputs: (3072, 1024, 32, 1)
(3072, 1024, 32, 1)
轉換運算子
x = x.to(memory_format=torch.channels_last)
print(x.shape) # Outputs: (10, 3, 32, 32) as dimensions order preserved
print(x.stride()) # Outputs: (3072, 1, 96, 3)
torch.Size([10, 3, 32, 32])
(3072, 1, 96, 3)
返回連續
x = x.to(memory_format=torch.contiguous_format)
print(x.stride()) # Outputs: (3072, 1024, 32, 1)
(3072, 1024, 32, 1)
替代選項
x = x.contiguous(memory_format=torch.channels_last)
print(x.stride()) # Outputs: (3072, 1, 96, 3)
(3072, 1, 96, 3)
格式檢查
print(x.is_contiguous(memory_format=torch.channels_last)) # Outputs: True
True
to 和 contiguous 這兩個 API 之間存在細微差別。我們建議在明確轉換張量記憶體格式時堅持使用 to。
對於一般情況,這兩個 API 的行為相同。但在特殊情況下,對於一個尺寸為 NCHW 的 4D 張量,當滿足以下條件之一時:C==1 或 H==1 && W==1,只有 to 會生成一個合適的 stride 來表示 Channels Last 記憶體格式。
這是因為在上述兩種情況下,張量的記憶體格式是模糊的,即尺寸為 N1HW 的連續張量在記憶體儲存上既是 contiguous 也是 Channels Last。因此,它們已被視為給定記憶體格式的 is_contiguous,並且 contiguous 呼叫成為一個無操作(no-op),不會更新 stride。相反,to 會使用有意義的 stride 來重新排序張量,以正確表示所需的記憶體格式。
special_x = torch.empty(4, 1, 4, 4)
print(special_x.is_contiguous(memory_format=torch.channels_last)) # Outputs: True
print(special_x.is_contiguous(memory_format=torch.contiguous_format)) # Outputs: True
True
True
同樣的情況也適用於顯式置換 API permute。在可能發生模糊的特殊情況下,permute 不能保證產生一個能正確承載所需記憶體格式的 stride。我們建議使用 to 和顯式記憶體格式來避免意外行為。
另外需要注意的是,在極端情況下,當三個非 batch 維度都等於 1 時(C==1 && H==1 && W==1),當前的實現無法將張量標記為 Channels Last 記憶體格式。
建立為 Channels Last
x = torch.empty(N, C, H, W, memory_format=torch.channels_last)
print(x.stride()) # Outputs: (3072, 1, 96, 3)
(3072, 1, 96, 3)
clone 保留記憶體格式
(3072, 1, 96, 3)
to, cuda, float … 保留記憶體格式
if torch.cuda.is_available():
y = x.cuda()
print(y.stride()) # Outputs: (3072, 1, 96, 3)
(3072, 1, 96, 3)
empty_like, *_like 運算子保留記憶體格式
y = torch.empty_like(x)
print(y.stride()) # Outputs: (3072, 1, 96, 3)
(3072, 1, 96, 3)
逐點運算子保留記憶體格式
(3072, 1, 96, 3)
使用 cudnn 後端的 Conv, Batchnorm 模組支援 Channels Last(僅適用於 cuDNN >= 7.6)。卷積模組不像二元逐點運算子那樣,Channels Last 是其主要的記憶體格式。如果所有輸入都為連續記憶體格式,則運算子輸出為連續記憶體格式。否則,輸出將為 Channels Last 記憶體格式。
if torch.backends.cudnn.is_available() and torch.backends.cudnn.version() >= 7603:
model = torch.nn.Conv2d(8, 4, 3).cuda().half()
model = model.to(memory_format=torch.channels_last) # Module parameters need to be channels last
input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, requires_grad=True)
input = input.to(device="cuda", memory_format=torch.channels_last, dtype=torch.float16)
out = model(input)
print(out.is_contiguous(memory_format=torch.channels_last)) # Outputs: True
True
當輸入張量遇到不支援 Channels Last 的運算子時,核心會自動應用置換以恢復輸入張量的連續性。這會引入開銷並停止 Channels Last 記憶體格式的傳播。儘管如此,它保證了正確的輸出。
效能提升#
Channels Last 記憶體格式的最佳化在 GPU 和 CPU 上均可用。在 GPU 上,在支援 Tensor Cores 的 NVIDIA 硬體上執行低精度(torch.float16)時,觀察到最顯著的效能提升。我們使用 NVIDIA 提供的 AMP(自動混合精度)訓練指令碼,在 Channels Last 格式下相比連續格式獲得了超過 22% 的效能提升。我們的指令碼使用了 NVIDIA 的 AMP NVIDIA/apex。
python main_amp.py -a resnet50 --b 200 --workers 16 --opt-level O2 ./data
# opt_level = O2
# keep_batchnorm_fp32 = None <class 'NoneType'>
# loss_scale = None <class 'NoneType'>
# CUDNN VERSION: 7603
# => creating model 'resnet50'
# Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights.
# Defaults for this optimization level are:
# enabled : True
# opt_level : O2
# cast_model_type : torch.float16
# patch_torch_functions : False
# keep_batchnorm_fp32 : True
# master_weights : True
# loss_scale : dynamic
# Processing user overrides (additional kwargs that are not None)...
# After processing overrides, optimization options are:
# enabled : True
# opt_level : O2
# cast_model_type : torch.float16
# patch_torch_functions : False
# keep_batchnorm_fp32 : True
# master_weights : True
# loss_scale : dynamic
# Epoch: [0][10/125] Time 0.866 (0.866) Speed 230.949 (230.949) Loss 0.6735125184 (0.6735) Prec@1 61.000 (61.000) Prec@5 100.000 (100.000)
# Epoch: [0][20/125] Time 0.259 (0.562) Speed 773.481 (355.693) Loss 0.6968704462 (0.6852) Prec@1 55.000 (58.000) Prec@5 100.000 (100.000)
# Epoch: [0][30/125] Time 0.258 (0.461) Speed 775.089 (433.965) Loss 0.7877287269 (0.7194) Prec@1 51.500 (55.833) Prec@5 100.000 (100.000)
# Epoch: [0][40/125] Time 0.259 (0.410) Speed 771.710 (487.281) Loss 0.8285319805 (0.7467) Prec@1 48.500 (54.000) Prec@5 100.000 (100.000)
# Epoch: [0][50/125] Time 0.260 (0.380) Speed 770.090 (525.908) Loss 0.7370464802 (0.7447) Prec@1 56.500 (54.500) Prec@5 100.000 (100.000)
# Epoch: [0][60/125] Time 0.258 (0.360) Speed 775.623 (555.728) Loss 0.7592862844 (0.7472) Prec@1 51.000 (53.917) Prec@5 100.000 (100.000)
# Epoch: [0][70/125] Time 0.258 (0.345) Speed 774.746 (579.115) Loss 1.9698858261 (0.9218) Prec@1 49.500 (53.286) Prec@5 100.000 (100.000)
# Epoch: [0][80/125] Time 0.260 (0.335) Speed 770.324 (597.659) Loss 2.2505953312 (1.0879) Prec@1 50.500 (52.938) Prec@5 100.000 (100.000)
傳遞 --channels-last true 可以使模型以 Channels Last 格式執行,並獲得 22% 的效能提升。
python main_amp.py -a resnet50 --b 200 --workers 16 --opt-level O2 --channels-last true ./data
# opt_level = O2
# keep_batchnorm_fp32 = None <class 'NoneType'>
# loss_scale = None <class 'NoneType'>
#
# CUDNN VERSION: 7603
#
# => creating model 'resnet50'
# Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights.
#
# Defaults for this optimization level are:
# enabled : True
# opt_level : O2
# cast_model_type : torch.float16
# patch_torch_functions : False
# keep_batchnorm_fp32 : True
# master_weights : True
# loss_scale : dynamic
# Processing user overrides (additional kwargs that are not None)...
# After processing overrides, optimization options are:
# enabled : True
# opt_level : O2
# cast_model_type : torch.float16
# patch_torch_functions : False
# keep_batchnorm_fp32 : True
# master_weights : True
# loss_scale : dynamic
#
# Epoch: [0][10/125] Time 0.767 (0.767) Speed 260.785 (260.785) Loss 0.7579724789 (0.7580) Prec@1 53.500 (53.500) Prec@5 100.000 (100.000)
# Epoch: [0][20/125] Time 0.198 (0.482) Speed 1012.135 (414.716) Loss 0.7007197738 (0.7293) Prec@1 49.000 (51.250) Prec@5 100.000 (100.000)
# Epoch: [0][30/125] Time 0.198 (0.387) Speed 1010.977 (516.198) Loss 0.7113101482 (0.7233) Prec@1 55.500 (52.667) Prec@5 100.000 (100.000)
# Epoch: [0][40/125] Time 0.197 (0.340) Speed 1013.023 (588.333) Loss 0.8943189979 (0.7661) Prec@1 54.000 (53.000) Prec@5 100.000 (100.000)
# Epoch: [0][50/125] Time 0.198 (0.312) Speed 1010.541 (641.977) Loss 1.7113249302 (0.9551) Prec@1 51.000 (52.600) Prec@5 100.000 (100.000)
# Epoch: [0][60/125] Time 0.198 (0.293) Speed 1011.163 (683.574) Loss 5.8537774086 (1.7716) Prec@1 50.500 (52.250) Prec@5 100.000 (100.000)
# Epoch: [0][70/125] Time 0.198 (0.279) Speed 1011.453 (716.767) Loss 5.7595844269 (2.3413) Prec@1 46.500 (51.429) Prec@5 100.000 (100.000)
# Epoch: [0][80/125] Time 0.198 (0.269) Speed 1011.827 (743.883) Loss 2.8196096420 (2.4011) Prec@1 47.500 (50.938) Prec@5 100.000 (100.000)
以下模型列表完全支援 Channels Last,並在 Volta 裝置上實現了 8%-35% 的效能提升:alexnet, mnasnet0_5, mnasnet0_75, mnasnet1_0, mnasnet1_3, mobilenet_v2, resnet101, resnet152, resnet18, resnet34, resnet50, resnext50_32x4d, shufflenet_v2_x0_5, shufflenet_v2_x1_0, shufflenet_v2_x1_5, shufflenet_v2_x2_0, squeezenet1_0, squeezenet1_1, vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19, vgg19_bn, wide_resnet101_2, wide_resnet50_2
以下模型列表完全支援 Channels Last,並在 Intel(R) Xeon(R) Ice Lake (或更新版本) CPU 上實現了 26%-76% 的效能提升:alexnet, densenet121, densenet161, densenet169, googlenet, inception_v3, mnasnet0_5, mnasnet1_0, resnet101, resnet152, resnet18, resnet34, resnet50, resnext101_32x8d, resnext50_32x4d, shufflenet_v2_x0_5, shufflenet_v2_x1_0, squeezenet1_0, squeezenet1_1, vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19, vgg19_bn, wide_resnet101_2, wide_resnet50_2
轉換現有模型#
Channels Last 的支援並不侷限於現有模型,因為只要輸入(或某些權重)格式正確,任何模型都可以轉換為 Channels Last 並透過圖傳播格式。
# Need to be done once, after model initialization (or load)
model = model.to(memory_format=torch.channels_last) # Replace with your model
# Need to be done for every input
input = input.to(memory_format=torch.channels_last) # Replace with your input
output = model(input)
然而,並非所有運算元都完全支援 Channels Last(通常返回連續輸出)。在上面給出的示例中,不支援 Channels Last 的層會停止記憶體格式的傳播。儘管如此,由於我們將模型轉換為 Channels Last 格式,這意味著每個具有 Channels Last 記憶體格式的 4 維權重的卷積層都將恢復 Channels Last 記憶體格式並受益於更快的核心。
但是,不支援 Channels Last 的運算元會透過置換引入開銷。如果想提高轉換後模型的效能,可以選擇檢查並識別模型中不支援 Channels Last 的運算元。
這意味著你需要對照支援的運算元列表 pytorch/pytorch 來驗證使用的運算元列表,或者在 eager execution 模式下引入記憶體格式檢查並執行你的模型。
執行以下程式碼後,如果運算元的輸出與輸入的記憶體格式不匹配,運算元將引發異常。
def contains_cl(args):
for t in args:
if isinstance(t, torch.Tensor):
if t.is_contiguous(memory_format=torch.channels_last) and not t.is_contiguous():
return True
elif isinstance(t, list) or isinstance(t, tuple):
if contains_cl(list(t)):
return True
return False
def print_inputs(args, indent=""):
for t in args:
if isinstance(t, torch.Tensor):
print(indent, t.stride(), t.shape, t.device, t.dtype)
elif isinstance(t, list) or isinstance(t, tuple):
print(indent, type(t))
print_inputs(list(t), indent=indent + " ")
else:
print(indent, t)
def check_wrapper(fn):
name = fn.__name__
def check_cl(*args, **kwargs):
was_cl = contains_cl(args)
try:
result = fn(*args, **kwargs)
except Exception as e:
print("`{}` inputs are:".format(name))
print_inputs(args)
print("-------------------")
raise e
failed = False
if was_cl:
if isinstance(result, torch.Tensor):
if result.dim() == 4 and not result.is_contiguous(memory_format=torch.channels_last):
print(
"`{}` got channels_last input, but output is not channels_last:".format(name),
result.shape,
result.stride(),
result.device,
result.dtype,
)
failed = True
if failed and True:
print("`{}` inputs are:".format(name))
print_inputs(args)
raise Exception("Operator `{}` lost channels_last property".format(name))
return result
return check_cl
old_attrs = dict()
def attribute(m):
old_attrs[m] = dict()
for i in dir(m):
e = getattr(m, i)
exclude_functions = ["is_cuda", "has_names", "numel", "stride", "Tensor", "is_contiguous", "__class__"]
if i not in exclude_functions and not i.startswith("_") and "__call__" in dir(e):
try:
old_attrs[m][i] = e
setattr(m, i, check_wrapper(e))
except Exception as e:
print(i)
print(e)
attribute(torch.Tensor)
attribute(torch.nn.functional)
attribute(torch)
如果你發現不支援 Channels Last 張量的運算元,並且想貢獻程式碼,請隨時使用以下開發者指南 pytorch/pytorch。
以下程式碼用於恢復 torch 的屬性。
for (m, attrs) in old_attrs.items():
for (k, v) in attrs.items():
setattr(m, k, v)
待辦事項#
還有很多工作要做,例如:
解決
N1HW和NC11張量的歧義;分散式訓練支援的測試;
提高運算元覆蓋率。
如果您有反饋和/或改進建議,請透過建立一個 issue 來告知我們。
結論#
本教程介紹了“Channels Last”記憶體格式,並演示瞭如何利用它來提升效能。有關在 CPU 上使用 Channels Last 加速視覺模型的實用示例,請參閱此處的博文:這裡。
指令碼總執行時間: (0 分鐘 0.330 秒)