• 文件 >
  • 模型和預訓練權重
快捷方式

模型和預訓練權重

torchvision.models 子包包含用於處理不同任務的模型定義,包括:影像分類、畫素級語義分割、目標檢測、例項分割、人體關鍵點檢測、影片分類和光流。

預訓練權重的通用資訊

TorchVision 透過 PyTorch 的 torch.hub 為提供的每種架構提供預訓練權重。例項化預訓練模型時,其權重將被下載到快取目錄。可以使用 TORCH_HOME 環境變數設定此目錄。有關詳細資訊,請參閱 torch.hub.load_state_dict_from_url()

注意

本庫中提供的預訓練模型可能有其自己的許可證或來自訓練所用資料集的條款和條件。您有責任確定您是否有權在您的用例中使用這些模型。

注意

使用舊版 PyTorch 建立的模型載入序列化的 state_dict 可保證向後相容。相反,載入整個儲存的模型或序列化的 ScriptModules(使用舊版 PyTorch 序列化)可能無法保留歷史行為。請參閱以下 文件

初始化預訓練模型

從 v0.13 開始,TorchVision 提供了一個新的 多權重支援 API,用於將不同的權重載入到現有的模型構建器方法中。

from torchvision.models import resnet50, ResNet50_Weights

# Old weights with accuracy 76.130%
resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

# New weights with accuracy 80.858%
resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# Best available weights (currently alias for IMAGENET1K_V2)
# Note that these weights may change across versions
resnet50(weights=ResNet50_Weights.DEFAULT)

# Strings are also supported
resnet50(weights="IMAGENET1K_V2")

# No weights - random initialization
resnet50(weights=None)

遷移到新 API 非常簡單。以下兩種 API 之間的呼叫方法都是等效的。

from torchvision.models import resnet50, ResNet50_Weights

# Using pretrained weights:
resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
resnet50(weights="IMAGENET1K_V1")
resnet50(pretrained=True)  # deprecated
resnet50(True)  # deprecated

# Using no weights:
resnet50(weights=None)
resnet50()
resnet50(pretrained=False)  # deprecated
resnet50(False)  # deprecated

請注意,pretrained 引數現已棄用,使用它將發出警告,並將在 v0.15 中刪除。

使用預訓練模型

在使用預訓練模型之前,必須對影像進行預處理(以正確的​​解析度/插值進行調整大小、應用推理轉換、重新縮放值等)。沒有標準的方法可以做到這一點,因為它取決於給定模型是如何訓練的。它可能因模型系列、變體或甚至權重版本而異。使用正確的預處理方法至關重要,否則可能導致準確性下降或輸出不正確。

每種預訓練模型的推理轉換所需的所有必要資訊均在其權重文件中提供。為了簡化推理,TorchVision 將必要的預處理轉換捆綁到每個模型權重中。這些可以透過 weight.transforms 屬性訪問。

# Initialize the Weight Transforms
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()

# Apply it to the input image
img_transformed = preprocess(img)

某些模型使用的模組在訓練和評估行為上有所不同,例如批次歸一化。要在這兩種模式之間切換,請根據需要使用 model.train()model.eval()。有關詳細資訊,請參閱 train()eval()

# Initialize model
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)

# Set model to eval mode
model.eval()

列出和檢索可用模型

從 v0.14 開始,TorchVision 提供了一種新的機制,允許按名稱列出和檢索模型和權重。以下是一些使用示例。

# List available models
all_models = list_models()
classification_models = list_models(module=torchvision.models)

# Initialize models
m1 = get_model("mobilenet_v3_large", weights=None)
m2 = get_model("quantized_mobilenet_v3_large", weights="DEFAULT")

# Fetch weights
weights = get_weight("MobileNet_V3_Large_QuantizedWeights.DEFAULT")
assert weights == MobileNet_V3_Large_QuantizedWeights.DEFAULT

weights_enum = get_model_weights("quantized_mobilenet_v3_large")
assert weights_enum == MobileNet_V3_Large_QuantizedWeights

weights_enum2 = get_model_weights(torchvision.models.quantization.mobilenet_v3_large)
assert weights_enum == weights_enum2

以下是用於檢索模型及其對應權重的可用公共函式:

get_model(name, **config)

獲取模型名稱和配置,並返回一個例項化的模型。

get_model_weights(name)

返回與給定模型關聯的權重列舉類。

get_weight(name)

透過完整名稱獲取權重列舉值。

list_models([module, include, exclude])

返回一個包含已註冊模型名稱的列表。

使用 Hub 中的模型

大多數預訓練模型可以直接透過 PyTorch Hub 訪問,而無需安裝 TorchVision。

import torch

# Option 1: passing weights param as string
model = torch.hub.load("pytorch/vision", "resnet50", weights="IMAGENET1K_V2")

# Option 2: passing weights param as enum
weights = torch.hub.load(
    "pytorch/vision",
    "get_weight",
    weights="ResNet50_Weights.IMAGENET1K_V2",
)
model = torch.hub.load("pytorch/vision", "resnet50", weights=weights)

您還可以透過以下方式使用 PyTorch Hub 檢索特定模型的所有可用權重:

import torch

weight_enum = torch.hub.load("pytorch/vision", "get_model_weights", name="resnet50")
print([weight for weight in weight_enum])

上述情況的唯一例外是包含在 torchvision.models.detection 中的檢測模型。這些模型需要安裝 TorchVision,因為它們依賴於自定義 C++ 運算子。

分類

以下分類模型可用,帶或不帶預訓練權重:


以下是如何使用預訓練影像分類模型的示例:

from torchvision.io import decode_image
from torchvision.models import resnet50, ResNet50_Weights

img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

# Step 1: Initialize model with the best available weights
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)

# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score:.1f}%")

預訓練模型的輸出類別可以在 weights.meta["categories"] 中找到。

所有可用分類權重的表格

準確率在 ImageNet-1K 上使用單張裁剪報告

權重

Acc@1

Acc@5

引數

GFLOPS

例項

AlexNet_Weights.IMAGENET1K_V1

56.522

79.066

61.1M

0.71

連結

ConvNeXt_Base_Weights.IMAGENET1K_V1

84.062

96.87

88.6M

15.36

連結

ConvNeXt_Large_Weights.IMAGENET1K_V1

84.414

96.976

197.8M

34.36

連結

ConvNeXt_Small_Weights.IMAGENET1K_V1

83.616

96.65

50.2M

8.68

連結

ConvNeXt_Tiny_Weights.IMAGENET1K_V1

82.52

96.146

28.6M

4.46

連結

DenseNet121_Weights.IMAGENET1K_V1

74.434

91.972

8.0M

2.83

連結

DenseNet161_Weights.IMAGENET1K_V1

77.138

93.56

28.7M

7.73

連結

DenseNet169_Weights.IMAGENET1K_V1

75.6

92.806

14.1M

3.36

連結

DenseNet201_Weights.IMAGENET1K_V1

76.896

93.37

20.0M

4.29

連結

EfficientNet_B0_Weights.IMAGENET1K_V1

77.692

93.532

5.3M

0.39

連結

EfficientNet_B1_Weights.IMAGENET1K_V1

78.642

94.186

7.8M

0.69

連結

EfficientNet_B1_Weights.IMAGENET1K_V2

79.838

94.934

7.8M

0.69

連結

EfficientNet_B2_Weights.IMAGENET1K_V1

80.608

95.31

9.1M

1.09

連結

EfficientNet_B3_Weights.IMAGENET1K_V1

82.008

96.054

12.2M

1.83

連結

EfficientNet_B4_Weights.IMAGENET1K_V1

83.384

96.594

19.3M

4.39

連結

EfficientNet_B5_Weights.IMAGENET1K_V1

83.444

96.628

30.4M

10.27

連結

EfficientNet_B6_Weights.IMAGENET1K_V1

84.008

96.916

43.0M

19.07

連結

EfficientNet_B7_Weights.IMAGENET1K_V1

84.122

96.908

66.3M

37.75

連結

EfficientNet_V2_L_Weights.IMAGENET1K_V1

85.808

97.788

118.5M

56.08

連結

EfficientNet_V2_M_Weights.IMAGENET1K_V1

85.112

97.156

54.1M

24.58

連結

EfficientNet_V2_S_Weights.IMAGENET1K_V1

84.228

96.878

21.5M

8.37

連結

GoogLeNet_Weights.IMAGENET1K_V1

69.778

89.53

6.6M

1.5

連結

Inception_V3_Weights.IMAGENET1K_V1

77.294

93.45

27.2M

5.71

連結

MNASNet0_5_Weights.IMAGENET1K_V1

67.734

87.49

2.2M

0.1

連結

MNASNet0_75_Weights.IMAGENET1K_V1

71.18

90.496

3.2M

0.21

連結

MNASNet1_0_Weights.IMAGENET1K_V1

73.456

91.51

4.4M

0.31

連結

MNASNet1_3_Weights.IMAGENET1K_V1

76.506

93.522

6.3M

0.53

連結

MaxVit_T_Weights.IMAGENET1K_V1

83.7

96.722

30.9M

5.56

連結

MobileNet_V2_Weights.IMAGENET1K_V1

71.878

90.286

3.5M

0.3

連結

MobileNet_V2_Weights.IMAGENET1K_V2

72.154

90.822

3.5M

0.3

連結

MobileNet_V3_Large_Weights.IMAGENET1K_V1

74.042

91.34

5.5M

0.22

連結

MobileNet_V3_Large_Weights.IMAGENET1K_V2

75.274

92.566

5.5M

0.22

連結

MobileNet_V3_Small_Weights.IMAGENET1K_V1

67.668

87.402

2.5M

0.06

連結

RegNet_X_16GF_Weights.IMAGENET1K_V1

80.058

94.944

54.3M

15.94

連結

RegNet_X_16GF_Weights.IMAGENET1K_V2

82.716

96.196

54.3M

15.94

連結

RegNet_X_1_6GF_Weights.IMAGENET1K_V1

77.04

93.44

9.2M

1.6

連結

RegNet_X_1_6GF_Weights.IMAGENET1K_V2

79.668

94.922

9.2M

1.6

連結

RegNet_X_32GF_Weights.IMAGENET1K_V1

80.622

95.248

107.8M

31.74

連結

RegNet_X_32GF_Weights.IMAGENET1K_V2

83.014

96.288

107.8M

31.74

連結

RegNet_X_3_2GF_Weights.IMAGENET1K_V1

78.364

93.992

15.3M

3.18

連結

RegNet_X_3_2GF_Weights.IMAGENET1K_V2

81.196

95.43

15.3M

3.18

連結

RegNet_X_400MF_Weights.IMAGENET1K_V1

72.834

90.95

5.5M

0.41

連結

RegNet_X_400MF_Weights.IMAGENET1K_V2

74.864

92.322

5.5M

0.41

連結

RegNet_X_800MF_Weights.IMAGENET1K_V1

75.212

92.348

7.3M

0.8

連結

RegNet_X_800MF_Weights.IMAGENET1K_V2

77.522

93.826

7.3M

0.8

連結

RegNet_X_8GF_Weights.IMAGENET1K_V1

79.344

94.686

39.6M

8

連結

RegNet_X_8GF_Weights.IMAGENET1K_V2

81.682

95.678

39.6M

8

連結

RegNet_Y_128GF_Weights.IMAGENET1K_SWAG_E2E_V1

88.228

98.682

644.8M

374.57

連結

RegNet_Y_128GF_Weights.IMAGENET1K_SWAG_LINEAR_V1

86.068

97.844

644.8M

127.52

連結

RegNet_Y_16GF_Weights.IMAGENET1K_V1

80.424

95.24

83.6M

15.91

連結

RegNet_Y_16GF_Weights.IMAGENET1K_V2

82.886

96.328

83.6M

15.91

連結

RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_E2E_V1

86.012

98.054

83.6M

46.73

連結

RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_LINEAR_V1

83.976

97.244

83.6M

15.91

連結

RegNet_Y_1_6GF_Weights.IMAGENET1K_V1

77.95

93.966

11.2M

1.61

連結

RegNet_Y_1_6GF_Weights.IMAGENET1K_V2

80.876

95.444

11.2M

1.61

連結

RegNet_Y_32GF_Weights.IMAGENET1K_V1

80.878

95.34

145.0M

32.28

連結

RegNet_Y_32GF_Weights.IMAGENET1K_V2

83.368

96.498

145.0M

32.28

連結

RegNet_Y_32GF_Weights.IMAGENET1K_SWAG_E2E_V1

86.838

98.362

145.0M

94.83

連結

RegNet_Y_32GF_Weights.IMAGENET1K_SWAG_LINEAR_V1

84.622

97.48

145.0M

32.28

連結

RegNet_Y_3_2GF_Weights.IMAGENET1K_V1

78.948

94.576

19.4M

3.18

連結

RegNet_Y_3_2GF_Weights.IMAGENET1K_V2

81.982

95.972

19.4M

3.18

連結

RegNet_Y_400MF_Weights.IMAGENET1K_V1

74.046

91.716

4.3M

0.4

連結

RegNet_Y_400MF_Weights.IMAGENET1K_V2

75.804

92.742

4.3M

0.4

連結

RegNet_Y_800MF_Weights.IMAGENET1K_V1

76.42

93.136

6.4M

0.83

連結

RegNet_Y_800MF_Weights.IMAGENET1K_V2

78.828

94.502

6.4M

0.83

連結

RegNet_Y_8GF_Weights.IMAGENET1K_V1

80.032

95.048

39.4M

8.47

連結

RegNet_Y_8GF_Weights.IMAGENET1K_V2

82.828

96.33

39.4M

8.47

連結

ResNeXt101_32X8D_Weights.IMAGENET1K_V1

79.312

94.526

88.8M

16.41

連結

ResNeXt101_32X8D_Weights.IMAGENET1K_V2

82.834

96.228

88.8M

16.41

連結

ResNeXt101_64X4D_Weights.IMAGENET1K_V1

83.246

96.454

83.5M

15.46

連結

ResNeXt50_32X4D_Weights.IMAGENET1K_V1

77.618

93.698

25.0M

4.23

連結

ResNeXt50_32X4D_Weights.IMAGENET1K_V2

81.198

95.34

25.0M

4.23

連結

ResNet101_Weights.IMAGENET1K_V1

77.374

93.546

44.5M

7.8

連結

ResNet101_Weights.IMAGENET1K_V2

81.886

95.78

44.5M

7.8

連結

ResNet152_Weights.IMAGENET1K_V1

78.312

94.046

60.2M

11.51

連結

ResNet152_Weights.IMAGENET1K_V2

82.284

96.002

60.2M

11.51

連結

ResNet18_Weights.IMAGENET1K_V1

69.758

89.078

11.7M

1.81

連結

ResNet34_Weights.IMAGENET1K_V1

73.314

91.42

21.8M

3.66

連結

ResNet50_Weights.IMAGENET1K_V1

76.13

92.862

25.6M

4.09

連結

ResNet50_Weights.IMAGENET1K_V2

80.858

95.434

25.6M

4.09

連結

ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1

60.552

81.746

1.4M

0.04

連結

ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1

69.362

88.316

2.3M

0.14

連結

ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1

72.996

91.086

3.5M

0.3

連結

ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1

76.23

93.006

7.4M

0.58

連結

SqueezeNet1_0_Weights.IMAGENET1K_V1

58.092

80.42

1.2M

0.82

連結

SqueezeNet1_1_Weights.IMAGENET1K_V1

58.178

80.624

1.2M

0.35

連結

Swin_B_Weights.IMAGENET1K_V1

83.582

96.64

87.8M

15.43

連結

Swin_S_Weights.IMAGENET1K_V1

83.196

96.36

49.6M

8.74

連結

Swin_T_Weights.IMAGENET1K_V1

81.474

95.776

28.3M

4.49

連結

Swin_V2_B_Weights.IMAGENET1K_V1

84.112

96.864

87.9M

20.32

連結

Swin_V2_S_Weights.IMAGENET1K_V1

83.712

96.816

49.7M

11.55

連結

Swin_V2_T_Weights.IMAGENET1K_V1

82.072

96.132

28.4M

5.94

連結

VGG11_BN_Weights.IMAGENET1K_V1

70.37

89.81

132.9M

7.61

連結

VGG11_Weights.IMAGENET1K_V1

69.02

88.628

132.9M

7.61

連結

VGG13_BN_Weights.IMAGENET1K_V1

71.586

90.374

133.1M

11.31

連結

VGG13_Weights.IMAGENET1K_V1

69.928

89.246

133.0M

11.31

連結

VGG16_BN_Weights.IMAGENET1K_V1

73.36

91.516

138.4M

15.47

連結

VGG16_Weights.IMAGENET1K_V1

71.592

90.382

138.4M

15.47

連結

VGG16_Weights.IMAGENET1K_FEATURES

nan

nan

138.4M

15.47

連結

VGG19_BN_Weights.IMAGENET1K_V1

74.218

91.842

143.7M

19.63

連結

VGG19_Weights.IMAGENET1K_V1

72.376

90.876

143.7M

19.63

連結

ViT_B_16_Weights.IMAGENET1K_V1

81.072

95.318

86.6M

17.56

連結

ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1

85.304

97.65

86.9M

55.48

連結

ViT_B_16_Weights.IMAGENET1K_SWAG_LINEAR_V1

81.886

96.18

86.6M

17.56

連結

ViT_B_32_Weights.IMAGENET1K_V1

75.912

92.466

88.2M

4.41

連結

ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1

88.552

98.694

633.5M

1016.72

連結

ViT_H_14_Weights.IMAGENET1K_SWAG_LINEAR_V1

85.708

97.73

632.0M

167.29

連結

ViT_L_16_Weights.IMAGENET1K_V1

79.662

94.638

304.3M

61.55

連結

ViT_L_16_Weights.IMAGENET1K_SWAG_E2E_V1

88.064

98.512

305.2M

361.99

連結

ViT_L_16_Weights.IMAGENET1K_SWAG_LINEAR_V1

85.146

97.422

304.3M

61.55

連結

ViT_L_32_Weights.IMAGENET1K_V1

76.972

93.07

306.5M

15.38

連結

Wide_ResNet101_2_Weights.IMAGENET1K_V1

78.848

94.284

126.9M

22.75

連結

Wide_ResNet101_2_Weights.IMAGENET1K_V2

82.51

96.02

126.9M

22.75

連結

Wide_ResNet50_2_Weights.IMAGENET1K_V1

78.468

94.086

68.9M

11.4

連結

Wide_ResNet50_2_Weights.IMAGENET1K_V2

81.602

95.758

68.9M

11.4

連結

量化模型

以下架構支援 INT8 量化模型,帶或不帶預訓練權重:


以下是如何使用預訓練量化影像分類模型的示例:

from torchvision.io import decode_image
from torchvision.models.quantization import resnet50, ResNet50_QuantizedWeights

img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

# Step 1: Initialize model with the best available weights
weights = ResNet50_QuantizedWeights.DEFAULT
model = resnet50(weights=weights, quantize=True)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)

# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score}%")

預訓練模型的輸出類別可以在 weights.meta["categories"] 中找到。

所有可用量化分類權重的表格

準確率在 ImageNet-1K 上使用單張裁剪報告

權重

Acc@1

Acc@5

引數

GIPS

例項

GoogLeNet_QuantizedWeights.IMAGENET1K_FBGEMM_V1

69.826

89.404

6.6M

1.5

連結

Inception_V3_QuantizedWeights.IMAGENET1K_FBGEMM_V1

77.176

93.354

27.2M

5.71

連結

MobileNet_V2_QuantizedWeights.IMAGENET1K_QNNPACK_V1

71.658

90.15

3.5M

0.3

連結

MobileNet_V3_Large_QuantizedWeights.IMAGENET1K_QNNPACK_V1

73.004

90.858

5.5M

0.22

連結

ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V1

78.986

94.48

88.8M

16.41

連結

ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V2

82.574

96.132

88.8M

16.41

連結

ResNeXt101_64X4D_QuantizedWeights.IMAGENET1K_FBGEMM_V1

82.898

96.326

83.5M

15.46

連結

ResNet18_QuantizedWeights.IMAGENET1K_FBGEMM_V1

69.494

88.882

11.7M

1.81

連結

ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1

75.92

92.814

25.6M

4.09

連結

ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2

80.282

94.976

25.6M

4.09

連結

ShuffleNet_V2_X0_5_QuantizedWeights.IMAGENET1K_FBGEMM_V1

57.972

79.78

1.4M

0.04

連結

ShuffleNet_V2_X1_0_QuantizedWeights.IMAGENET1K_FBGEMM_V1

68.36

87.582

2.3M

0.14

連結

ShuffleNet_V2_X1_5_QuantizedWeights.IMAGENET1K_FBGEMM_V1

72.052

90.7

3.5M

0.3

連結

ShuffleNet_V2_X2_0_QuantizedWeights.IMAGENET1K_FBGEMM_V1

75.354

92.488

7.4M

0.58

連結

語義分割

警告

The segmentation module is in Beta stage, and backward compatibility is not guaranteed.

以下語義分割模型可用,帶或不帶預訓練權重:


以下是如何使用預訓練語義分割模型的示例:

from torchvision.io.image import decode_image
from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
from torchvision.transforms.functional import to_pil_image

img = decode_image("gallery/assets/dog1.jpg")

# Step 1: Initialize model with the best available weights
weights = FCN_ResNet50_Weights.DEFAULT
model = fcn_resnet50(weights=weights)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)

# Step 4: Use the model and visualize the prediction
prediction = model(batch)["out"]
normalized_masks = prediction.softmax(dim=1)
class_to_idx = {cls: idx for (idx, cls) in enumerate(weights.meta["categories"])}
mask = normalized_masks[0, class_to_idx["dog"]]
to_pil_image(mask).show()

預訓練模型的輸出類別可以在 weights.meta["categories"] 中找到。模型的輸出格式在 語義分割模型 中進行了說明。

所有可用語義分割權重的表格

所有模型均在 COCO val2017 的子集上進行評估,針對 Pascal VOC 資料集中存在的 20 個類別。

權重

平均 IoU

畫素精度

引數

GFLOPS

例項

DeepLabV3_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1

60.3

91.2

11.0M

10.45

連結

DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1

67.4

92.4

61.0M

258.74

連結

DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1

66.4

92.4

42.0M

178.72

連結

FCN_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1

63.7

91.9

54.3M

232.74

連結

FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1

60.5

91.4

35.3M

152.72

連結

LRASPP_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1

57.9

91.2

3.2M

2.09

連結

目標檢測、例項分割和人體關鍵點檢測

用於檢測、例項分割和關鍵點檢測的預訓練模型透過 torchvision 中的分類模型進行初始化。這些模型需要一個 Tensor[C, H, W] 列表。有關更多資訊,請檢視模型的建構函式。

警告

檢測模組處於 Beta 階段,不保證向後相容。

目標檢測

以下目標檢測模型可用,帶或不帶預訓練權重:


以下是如何使用預訓練目標檢測模型的示例:

from torchvision.io.image import decode_image
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image

img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

# Step 1: Initialize model with the best available weights
weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = [preprocess(img)]

# Step 4: Use the model and visualize the prediction
prediction = model(batch)[0]
labels = [weights.meta["categories"][i] for i in prediction["labels"]]
box = draw_bounding_boxes(img, boxes=prediction["boxes"],
                          labels=labels,
                          colors="red",
                          width=4, font_size=30)
im = to_pil_image(box.detach())
im.show()

預訓練模型的輸出類別可以在 weights.meta["categories"] 中找到。要了解如何繪製模型的邊界框,您可以參考 例項分割模型

所有可用目標檢測權重的表格

邊界框 MAP 在 COCO val2017 上報告。

權重

邊界框 MAP

引數

GFLOPS

例項

FCOS_ResNet50_FPN_Weights.COCO_V1

39.2

32.3M

128.21

連結

FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1

22.8

19.4M

0.72

連結

FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1

32.8

19.4M

4.49

連結

FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1

46.7

43.7M

280.37

連結

FasterRCNN_ResNet50_FPN_Weights.COCO_V1

37

41.8M

134.38

連結

RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1

41.5

38.2M

152.24

連結

RetinaNet_ResNet50_FPN_Weights.COCO_V1

36.4

34.0M

151.54

連結

SSD300_VGG16_Weights.COCO_V1

25.1

35.6M

34.86

連結

SSDLite320_MobileNet_V3_Large_Weights.COCO_V1

21.3

3.4M

0.58

連結

例項分割

以下例項分割模型可用,帶或不帶預訓練權重:


要了解如何繪製模型的掩碼,您可以參考 例項分割模型

所有可用例項分割權重的表格

邊界框和掩碼 MAP 在 COCO val2017 上報告。

權重

邊界框 MAP

掩碼 MAP

引數

GFLOPS

例項

MaskRCNN_ResNet50_FPN_V2_Weights.COCO_V1

47.4

41.8

46.4M

333.58

連結

MaskRCNN_ResNet50_FPN_Weights.COCO_V1

37.9

34.6

44.4M

134.38

連結

關鍵點檢測

以下人體關鍵點檢測模型可用,帶或不帶預訓練權重:


預訓練模型的輸出類別可以在 weights.meta["keypoint_names"] 中找到。要了解如何繪製模型的邊界框,您可以參考 視覺化關鍵點

所有可用關鍵點檢測權重的表格

邊界框和關鍵點 MAP 在 COCO val2017 上報告。

權重

邊界框 MAP

關鍵點 MAP

引數

GFLOPS

例項

KeypointRCNN_ResNet50_FPN_Weights.COCO_LEGACY

50.6

61.1

59.1M

133.92

連結

KeypointRCNN_ResNet50_FPN_Weights.COCO_V1

54.6

65

59.1M

137.42

連結

影片分類

警告

The video module is in Beta stage, and backward compatibility is not guaranteed.

以下影片分類模型可用,帶或不帶預訓練權重:


以下是如何使用預訓練影片分類模型的示例:

from torchvision.io.video import read_video
from torchvision.models.video import r3d_18, R3D_18_Weights

vid, _, _ = read_video("test/assets/videos/v_SoccerJuggling_g23_c01.avi", output_format="TCHW")
vid = vid[:32]  # optionally shorten duration

# Step 1: Initialize model with the best available weights
weights = R3D_18_Weights.DEFAULT
model = r3d_18(weights=weights)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = preprocess(vid).unsqueeze(0)

# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
label = prediction.argmax().item()
score = prediction[label].item()
category_name = weights.meta["categories"][label]
print(f"{category_name}: {100 * score}%")

預訓練模型的輸出類別可以在 weights.meta["categories"] 中找到。

所有可用影片分類權重的表格

準確率在 Kinetics-400 上使用長度為 16 的單張裁剪報告

權重

Acc@1

Acc@5

引數

GFLOPS

例項

MC3_18_Weights.KINETICS400_V1

63.96

84.13

11.7M

43.34

連結

MViT_V1_B_Weights.KINETICS400_V1

78.477

93.582

36.6M

70.6

連結

MViT_V2_S_Weights.KINETICS400_V1

80.757

94.665

34.5M

64.22

連結

R2Plus1D_18_Weights.KINETICS400_V1

67.463

86.175

31.5M

40.52

連結

R3D_18_Weights.KINETICS400_V1

63.2

83.479

33.4M

40.7

連結

S3D_Weights.KINETICS400_V1

68.368

88.05

8.3M

17.98

連結

Swin3D_B_Weights.KINETICS400_V1

79.427

94.386

88.0M

140.67

連結

Swin3D_B_Weights.KINETICS400_IMAGENET22K_V1

81.643

95.574

88.0M

140.67

連結

Swin3D_S_Weights.KINETICS400_V1

79.521

94.158

49.8M

82.84

連結

Swin3D_T_Weights.KINETICS400_V1

77.715

93.519

28.2M

43.88

連結

光流

以下光流模型可用,帶或不帶預訓練:

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源