TorchAOBaseTensor¶
- class torchao.utils.TorchAOBaseTensor[原始碼]¶
- 一個工具張量子類,提供常用函式
新的張量子類可以繼承它來獲得所有實用功能
- class MyTensor(TorchAOBaseTensor)
pass
- 這包括
- _get_to_kwargs 可以獲取 to 的 kwargs
- 實現了:
implements = MyTensor.implements
@implements(torch.nn.functional.linear): def _(func, types, args, kwargs)
…
- register_layout:
register_layout = MyTensor.register_layout
@register_layout(PlainLayout) class PlainAQTTensorImpl(…)
…
- get_tensor_impl_constructor:
get_tensor_impl_constructor = MyTensor.get_tensor_impl_constructor # 在 MyTensor 的建構函式中: tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)
- 用於簡化張量子類實現的類變數
- tensor_data_names (List[str]): 所有必需的 tensor_data 的名稱列表,順序應匹配
張量子類的 __init__ 列表
- tensor_attribute_names (List[str]): 非 Tensor 屬性的名稱列表,
順序應與張量子類的 __init__ 列表匹配,後跟所有 tensor_data_names 引數
optional_tensor_data_names (List[str]): 可選定義此欄位以為您實現額外的樣板函式,但這在有任何可選的 Tensor 資料屬性時是必需的,定義時,這將是可選項的 Tensor 的名稱列表 optional_tensor_attribute_names (List[str]): 可選定義此欄位以為您實現額外的樣板函式,但這在有任何可選的非 Tensor 屬性時是必需的,定義時,這將是可選項的屬性的名稱列表 注意:`__init__` 和 `__new__` 中的引數順序應與 `tensor_data_names` + `tensor_attribute_names` + `optional_tensor_data_names` (如果存在) + `optional_tensor_attribute_names` (如果存在) 完全匹配。
如果定義了 tensor_data_names 和 tensor_attribute_names,則會新增一些額外的函式,包括: __tensor_flatten__:展平子類化的張量例項,返回一個元組,第一個元素是有效張量資料的名稱,
第二個元素是非 Tensor 屬性的列表
__tensor_unflatten__:接受一個 `tensor_data_dict`(張量名稱到張量的對映)和非張量屬性列表,返回子類化張量的新例項 _apply_fn_to_data:接受一個函式(Tensor -> Tensor),將函式應用於所有張量資料並
用轉換後的張量資料重新建立一個子類化張量
__repr__:子類化張量例項的字串表示形式 _same_metadata:返回 cls 例項之間元資料是否相同 __setstate__:載入序列化的張量子類檢查點時,它會將舊檢查點中儲存的新可選張量和張量屬性設定為 None,以在將新的可選張量資料或屬性新增到張量子類時保持舊檢查點的向後相容性。 PyTorch 操作:torch.Tensor.contiguous ATen 操作:aten.detach.default, aten.clone.default, aten.alias,default, aten.contiguous.default, aten.copy_.default, aten._to_copy.default (啟用 t.to)
示例
- class MyTensor(torch.Tensor)
tensor_data_names = [“a”, “b”] tensor_attribute_names = [“c”, “d”] optional_tensor_data_names = [“e”, “f”] optional_tensor_attribute_names = [“g”, “h”]
- def __new__(
cls, a: Tensor, b: Tensor, c: int, d: str, e: Optional[Tensor] = None, f: Optional[Tensor] = None, g: Optional[int] = None, h: Optional[int] = None,
- ):
pass
- def __init__(
self, a: Tensor, b: Tensor, c: int, d: str e: Optional[Tensor] = None, f: Optional[Tensor] = None, g: Optional[int] = None, h: Optional[int] = None,
- ):
pass
- classmethod get_tensor_impl_constructor(layout_class: Callable) Callable¶
獲取 tensor_class 的 TensorImpl 類建構函式 (TensorImplClass.from_plain),基於 layout_class layout_class 表示 `Layout` 的子類型別,例如 PlainLayout
- 引數:
tensor_class – 張量子類型別
layout_class – `Layout` 的子類型別,例如 PlainLayout
- 返回:
layout_class 的 tensor impl 子類建構函式
- classmethod implements(aten_ops_or_torch_fns)¶
使用此裝飾器為 `__torch_dispatch__` 中的 aten 操作(如果使用者傳入了一個操作列表)或 `__torch_function__` 中的 torch 函式(如果使用者傳入了一個單一物件)實現一個函式。
- class MyTensor(torch.Tensor)
… implements = classmethod(_implements)
implements = MyTensor.implements
@implements(torch.nn.functional.linear): def _(func, types, args, kwargs)
…