torch.overrides#
創建於:2020 年 11 月 30 日 | 最後更新於:2025 年 6 月 6 日
此模組公開了用於 __torch_function__ 協議的各種輔助函式。有關 __torch_function__ 協議的更多詳細資訊,請參閱 擴充套件 PyTorch Python API。
函式#
- torch.overrides.get_ignored_functions()[source]#
返回無法被
__torch_function__覆蓋的公共函式。- 返回
torch API 中公開可用但不能使用
__torch_function__覆蓋的函式元組。主要是因為這些函式中的任何引數都不是張量或類似張量的值。- 返回型別
set[Callable]
示例
>>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions() True >>> torch.add in torch.overrides.get_ignored_functions() False
- torch.overrides.get_overridable_functions()[source]#
列出可透過 __torch_function__ 覆蓋的函式
- 返回
一個字典,它將包含可覆蓋函式的名稱空間對映到該名稱空間中可以被覆蓋的函式。
- 返回型別
Dict[Any, List[Callable]]
- torch.overrides.resolve_name(f)[source]#
獲取傳遞給 __torch_function__ 的函式的易於閱讀的字串名稱
- 引數
f (Callable) – 要解析名稱的函式。
- 返回
函式的名稱;如果求值,它應該返回輸入函式。
- 返回型別
- torch.overrides.get_testing_overrides()[source]#
返回一個包含所有可覆蓋函式的虛擬覆蓋的字典
- 返回
一個字典,它將 PyTorch API 中的可覆蓋函式對映到具有與實際函式相同簽名的 lambda 函式,並無條件地返回 -1。這些 lambda 函式對於測試支援
__torch_function__的型別的 API 覆蓋率非常有用。- 返回型別
Dict[Callable, Callable]
示例
>>> import inspect >>> my_add = torch.overrides.get_testing_overrides()[torch.add] >>> inspect.signature(my_add) <Signature (input, other, out=None)>
- torch.overrides.handle_torch_function(public_api, relevant_args, *args, **kwargs)[source]#
實現一個帶有
__torch_function__覆蓋檢查的函式。有關此函式在 C++ 實現中的等效項,請參閱 torch::autograd::handle_torch_function。
- 引數
- 返回
呼叫
implementation或__torch_function__方法的結果,視情況而定。- 返回型別
:raises TypeError : 如果找不到實現。
示例
>>> def func(a): ... if has_torch_function_unary(a): ... return handle_torch_function(func, (a,), a) ... return a + 0
- torch.overrides.has_torch_function()#
檢查可迭代元素中的 __torch_function__ 實現,或者檢查是否啟用了 __torch_function__ 模式。將精確的
Tensor和Parameter視為不可分派。使用此函式來保護對handle_torch_function()的呼叫;不要用它來測試一個東西是否是類似張量的,而是使用is_tensor_like()。 :param relevant_args: 用於檢查 __torch_function__ 方法的引數的可迭代物件。 :type relevant_args: iterable- 返回
如果 relevant_args 的任何元素具有 __torch_function__ 實現,則返回 True,否則返回 False。
- 返回型別
另請參閱
torch.is_tensor_like檢查一個東西是否是類似張量的,包括精確的
Tensor。
- torch.overrides.is_tensor_like(inp)[source]#
如果傳入的輸入是類似張量的,則返回
True。目前,這發生在輸入型別上存在
__torch_function__屬性時。示例
張量的子類通常是類似張量的。
>>> class SubTensor(torch.Tensor): ... >>> is_tensor_like(SubTensor([0])) True
內建型別或使用者定義的型別通常不是類似張量的。
>>> is_tensor_like(6) False >>> is_tensor_like(None) False >>> class NotATensor: ... >>> is_tensor_like(NotATensor()) False
但是,可以透過實現 __torch_function__ 來使其成為類似張量的。
>>> class TensorLike: ... @classmethod ... def __torch_function__(cls, func, types, args, kwargs): ... return -1 >>> is_tensor_like(TensorLike()) True
- torch.overrides.is_tensor_method_or_property(func)[source]#
如果傳入的函式是
__torch_function__傳遞給torch.Tensor的方法或屬性的處理程式,則返回 True。注意
對於屬性,必須傳入它們的
__get__方法。這可能需要,特別是出於以下原因:
方法/屬性有時不包含 __module__ 插槽。
它們要求傳入的第一個引數是
torch.Tensor的例項。
示例
>>> is_tensor_method_or_property(torch.Tensor.add) True >>> is_tensor_method_or_property(torch.add) False
- 返回型別
- torch.overrides.wrap_torch_function(dispatcher)[source]#
使用
__torch_function__相關功能包裝給定的函式。- 引數
dispatcher (Callable) – 一個可呼叫物件,它返回傳遞給函式的類張量 (Tensor-like) 的可迭代物件。
注意
此裝飾器可能會降低程式碼的效能。通常,將程式碼表示為一系列本身支援 __torch_function__ 的函式就足夠了。如果您發現自己處於這種情況很少見,例如,如果您正在包裝一個底層庫,並且您也希望它適用於類張量,那麼此函式可用。
示例
>>> def dispatcher(a): # Must have the same signature as func ... return (a,) >>> @torch.overrides.wrap_torch_function(dispatcher) >>> def func(a): # This will make func dispatchable by __torch_function__ ... return a + 0