如何為 PyTorch 2 匯出量化編寫 Quantizer¶
作者: Leslie Fang, Weiwen Xia, Jiong Gong, Kimish Patel, Jerry Zhang
先決條件:¶
必需
可選
介紹¶
(原型) PyTorch 2 匯出訓練後量化 介紹了 PyTorch 2 匯出量化的整體 API。與 FX 圖模式量化在 API 上的主要區別在於,我們明確了量化是針對特定後端的。因此,要使用新的流程,後端需要實現一個 Quantizer 類,該類編碼了:(1) 後端支援的量化運算元或模式;(2) 使用者可以如何表達他們希望浮點模型被量化的方式,例如,將整個模型量化為 int8 對稱量化,或僅量化線性層等。
有關新 API 和 Quantizer 的動機,請參閱此處。
XNNPACK 的現有量化器物件在 XNNPackQuantizer 中
註解 API¶
Quantizer 使用註解 API 來傳達不同運算元/模式的量化意圖。註解 API 主要由 QuantizationSpec 和 QuantizationAnnotation 組成。
QuantizationSpec 用於傳達張量如何被量化的意圖,例如,資料型別、位寬、最小值、最大值、對稱與否等。此外,QuantizationSpec 還允許量化器指定如何觀察張量值,例如,MinMaxObserver、HistogramObserver 或一些自定義觀察器。
QuantizationAnnotation 由 QuantizationSpec 物件組成,用於註解模式的輸入張量和輸出張量。註解輸入張量等同於註解輸入邊,而註解輸出張量等同於註解節點。QuantizationAnnotation 是一個 dataclass,包含幾個欄位:
input_qspec_map欄位是Dict類,用於將每個輸入張量(作為輸入邊)對映到QuantizationSpec。output_qspec欄位表示用於註解輸出張量的QuantizationSpec;_annotated欄位指示該節點是否已被量化器註解。
總而言之,註解 API 要求量化器註解圖的邊(輸入張量)或節點(輸出張量)。現在,我們將逐步介紹如何使用具有不同型別 QuantizationSpec 的註解 API。
1. 註解常見運算元模式¶
為了使用量化模式/運算元,例如 quantized add,後端開發者會有量化(由 QuantizationSpec 表達)輸入和模式輸出的意圖。以下是一個示例流程(以 add 運算元為例),說明此意圖如何在量化工作流中使用註解 API 傳達。
步驟 1:在 FX 圖中識別原始浮點模式。識別此模式有幾種方法:量化器可以使用模式匹配器來匹配運算元模式;量化器可以從頭到尾遍歷節點,並將節點的 target 型別與運算元模式進行匹配。在此示例中,我們可以使用 get_source_partitions 來匹配此模式。原始浮點
add模式僅包含一個add節點。
add_partitions = get_source_partitions(gm.graph, [operator.add, torch.add])
add_partitions = list(itertools.chain(*add_partitions.values()))
for add_partition in add_partitions:
add_node = add_partition.output_nodes[0]
步驟 2:為模式的輸入和輸出定義
QuantizationSpec。QuantizationSpec定義了使用者關於如何觀察或假量化張量的意圖的資料型別、qscheme和其他量化引數。
act_quantization_spec = QuantizationSpec(
dtype=torch.int8,
quant_min=-128,
quant_max=127,
qscheme=torch.per_tensor_affine,
is_dynamic=False,
observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
)
input_act_qspec = act_quantization_spec
output_act_qspec = act_quantization_spec
步驟 3:使用
QuantizationAnnotation註解模式的輸入和輸出。在此示例中,我們將為add節點(兩個輸入和一個輸出)建立帶有上面步驟 2 中建立的QuantizationSpec的QuantizationAnnotation物件。
input_qspec_map = {}
input_act0 = add_node.args[0]
input_qspec_map[input_act0] = input_act_qspec
input_act1 = add_node.args[1]
input_qspec_map[input_act1] = input_act_qspec
add_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=output_act_qspec,
_annotated=True,
)
在這樣註解了 add 節點後,在後續的量化流程中,HistogramObserver 將在準備階段插入到其兩個輸入節點和一個輸出節點。在轉換階段,HistogramObserver 將被 quantize 節點和 dequantize 節點替換。
3. 註解具有固定量化引數的運算元¶
另一種典型的註解量化模型用例是針對量化引數預先已知的張量。例如,像 sigmoid 這樣的運算元,其輸入和輸出張量具有預定義且固定的 scale/zero_point。 FixedQParamsQuantizationSpec 就是為這種情況設計的。要使用 FixedQParamsQuantizationSpec,使用者需要顯式傳入 scale 和 zero_point 引數。
步驟 1:在 FX 圖中識別原始浮點模式。我們可以使用在
QuantizationSpec示例中介紹的相同方法來識別sigmoid模式。步驟 2:使用固定的
scale、zero_point值建立FixedQParamsQuantizationSpec物件。這些值將用於在轉換階段建立quantize節點和dequantize節點。步驟 3:註解輸入和輸出以使用此
FixedQParamsQuantizationSpec物件。
act_qspec = FixedQParamsQuantizationSpec(
dtype=torch.uint8,
quant_min=0,
quant_max=255,
qscheme=torch.per_tensor_affine,
scale=1.0 / 256.0,
zero_point=0,
)
sigmoid_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map={input_act: act_qspec},
output_qspec=act_qspec,
_annotated=True,
)
4. 註解具有派生量化引數的張量¶
另一個用例是為量化引數從其他張量派生的張量定義約束。例如,如果我們想註解一個卷積節點,並將啟用張量的 scale 和權重張量的 scale 相乘來定義其偏置輸入張量的 scale。我們可以使用 DerivedQuantizationSpec 來註解這個 conv 節點。
步驟 1:在 FX 圖中識別原始浮點模式。我們可以使用在
QuantizationSpec示例中介紹的相同方法來識別convolution模式。步驟 2:定義
derive_qparams_fn函式,它接受ObserverOrFakeQuantize(ObserverBase 或 FakeQuantizeBase)的列表作為輸入。從每個ObserverOrFakeQuantize物件中,使用者可以獲取scale、zero point值。使用者可以定義其啟發式方法,根據從觀察器或假量化例項計算出的量化引數來派生新的scale、zero point值。步驟 3:定義
DerivedQuantizationSpec物件,它接受以下輸入:EdgeOrNode物件的列表。與每個EdgeOrNode物件對應的觀察器將傳遞給derive_qparams_fn函式;derive_qparams_fn函式;其他幾個量化引數,如dtype、qscheme。步驟 4:使用
QuantizationAnnotation註解此 conv 節點的輸入和輸出。
def derive_qparams_fn(obs_or_fqs: List[ObserverOrFakeQuantize]) -> Tuple[Tensor, Tensor]:
assert len(obs_or_fqs) == 2, \
"Expecting two obs/fqs, one for activation and one for weight, got: {}".format(len(obs_or_fq))
act_obs_or_fq = obs_or_fqs[0]
weight_obs_or_fq = obs_or_fqs[1]
act_scale, act_zp = act_obs_or_fq.calculate_qparams()
weight_scale, weight_zp = weight_obs_or_fq.calculate_qparams()
return torch.tensor([act_scale * weight_scale]).to(torch.float32), torch.tensor([0]).to(torch.int32)
bias_qspec = DerivedQuantizationSpec(
derived_from=[(input_act, node), (weight, node)],
derive_qparams_fn=derive_qparams_fn,
dtype=torch.int32,
quant_min=-2**31,
quant_max=2**31 - 1,
qscheme=torch.per_tensor_symmetric,
)
input_qspec_map = {input_act: act_quantization_spec, weight: weight_quantization_spec, bias: bias_qspec}
node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=act_quantization_spec,
_annotated=True,
)
5. Resnet18 的玩具示例¶
在定義了以上使用 QuantizationAnnotation API 的註解方法後,我們現在可以將它們組合起來,構建一個 BackendQuantizer,並執行一個使用 Torchvision Resnet18 的玩具示例。為了更好地理解最終示例,以下是示例中使用的類和實用函式:
QuantizationConfig 由啟用、權重和偏置的
QuantizationSpec分別組成。在註解模型時,可以使用 get_input_act_qspec、get_output_act_qspec、get_weight_qspec 和 get_bias_qspec 從特定模式的
QuantizationConfig中獲取QuantizationSpec。
關於 PT2E 量化流程 IR 的說明¶
IR 指的是模型的中間表示,例如 torch IR(torch.nn 模組,torch.nn.functional ops)或 aten IR(torch.ops.aten.linear,…)。PT2E 量化流程使用預自動微分的 aten IR(torch.export API 的輸出),以便我們支援訓練。如前所述,我們需要匹配運算元或運算元模式才能在其上附加註解。那麼問題是,我們如何匹配模式?
動機:直接匹配 aten IR 的問題¶
最直接的方法可能是直接匹配 aten IR。
示例
for n in gm.graph.nodes:
if n.op != "call_function" or n.target not in [
torch.ops.aten.relu.default,
torch.ops.aten.relu_.default,
]:
continue
relu_node = n
maybe_conv_node = n.args[0]
if (
not isinstance(maybe_conv_node, Node)
or maybe_conv_node.op != "call_function"
or maybe_conv_node.target
not in [
torch.ops.aten.conv1d.default,
torch.ops.aten.conv2d.default,
]
):
continue
# annotate conv and relu nodes
...
然而,使用此 IR 的一個問題是,如果 PyTorch 對模組或函式式操作的實現發生更改,表示形式可能會改變。但這可能是意料之外的,因為建模使用者通常假設當命令式模型程式碼不變時,他們在程式捕獲後也應該得到相同的模型表示。此問題的一個具體影響是,如果一個 Quantizer 基於識別 aten IR 模式來進行註解,那麼在 PyTorch 版本更新後,它可能無法識別該模式,並且相同的命令式浮點模型可能會保持未被量化。
建議:使用 SubgraphMatcherWithNameNodeMap 進行模式匹配¶
因此,我們建議人們透過捕獲 torch IR 模式(與捕獲浮點模型使用的程式捕獲相同)來透過 SubgraphMatcherWithNameNodeMap(SubgraphMatcher 的改進版本,使其更容易查詢要註解的節點)來識別模式,而不是直接使用 aten IR 模式。
示例
def conv_relu_pattern(input, weight, bias):
conv = torch.nn.functional.conv2d(input, weight, bias)
output = torch.nn.functional.relu(conv)
# returns an additional dict that includes a map from name to node that we want to annotate
return relu, {"input": input, "weight": weight, "bias": bias, "output": output}
matcher = SubgraphMatcherWithNameNodeMap(conv_relu_pattern)
matches = matcher.match(model)
for match in matches:
# find input and output of the pattern
# annotate the nodes
name_node_map = match.name_node_map
input_node = name_node_map["input"]
weight_node = name_node_map["weight"]
bias_node = name_node_map["bias"]
output_node = name_node_map["relu"]
input_node.users[0].meta["quantization_annotation"] = ...
weight_node.users[0].meta["quantization_annotation"] = ...
bias_node.users[0].meta["quantization_annotation"] = ...
output_node.meta["quantization_annotation"] = ...
這樣,即使在神經網路模組和函式式的實現發生變化時,Quantizer 仍然有效。aten IR 對於浮點模型來說會發生變化,但由於我們重新捕獲模式而不是硬編碼模式的 aten IR,我們將獲得更新的 aten IR,並且仍然能夠匹配模式。
一個注意事項是,如果模式的輸入有多個使用者,除了檢查 aten op 目標外,我們沒有好方法來識別我們想要註解的哪個使用者節點。
另一個注意事項是,我們需要確保有一個詳盡的示例列表(例如,2D、3D、4D 輸入,真實輸入 vs. 符號輸入,training=True vs. training=False 等)來確保覆蓋從 torch IR 模式捕獲的各種可能的 aten IR 結果。
注意:我們將來可能會提供一些(模式,示例輸入列表)或一些預生成的匹配器物件,以便人們可以直接使用它們。
結論¶
透過本教程,我們介紹了 PyTorch 2 中的新量化路徑。使用者可以學習如何使用 QuantizationAnnotation API 定義 BackendQuantizer 並將其整合到 PyTorch 2 匯出量化流程中。給出了 QuantizationSpec、SharedQuantizationSpec、FixedQParamsQuantizationSpec 和 DerivedQuantizationSpec 的示例,用於特定的註解用例。您可以將 XNNPACKQuantizer 作為示例來開始實現您自己的 Quantizer。之後,請按照此教程實際量化您的模型。