快捷方式

TorchAOBaseTensor

class torchao.utils.TorchAOBaseTensor[原始碼]
一個工具張量子類,提供常用函式

新的張量子類可以繼承它來獲得所有實用功能

class MyTensor(TorchAOBaseTensor)

pass

這包括
_get_to_kwargs 可以獲取 to 的 kwargs
class MyTensor(TorchAOBaseTensor)
def to(self, *args, **kwargs)

kwargs = _get_to_kwargs(*args, **kwargs) …

實現了:

implements = MyTensor.implements

@implements(torch.nn.functional.linear): def _(func, types, args, kwargs)

register_layout:

register_layout = MyTensor.register_layout

@register_layout(PlainLayout) class PlainAQTTensorImpl(…)

get_tensor_impl_constructor:

get_tensor_impl_constructor = MyTensor.get_tensor_impl_constructor # 在 MyTensor 的建構函式中: tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)

用於簡化張量子類實現的類變數
tensor_data_names (List[str]): 所有必需的 tensor_data 的名稱列表,順序應匹配

張量子類的 __init__ 列表

tensor_attribute_names (List[str]): 非 Tensor 屬性的名稱列表,

順序應與張量子類的 __init__ 列表匹配,後跟所有 tensor_data_names 引數

optional_tensor_data_names (List[str]): 可選定義此欄位以為您實現額外的樣板函式,但這在有任何可選的 Tensor 資料屬性時是必需的,定義時,這將是可選項的 Tensor 的名稱列表 optional_tensor_attribute_names (List[str]): 可選定義此欄位以為您實現額外的樣板函式,但這在有任何可選的非 Tensor 屬性時是必需的,定義時,這將是可選項的屬性的名稱列表 注意:`__init__` 和 `__new__` 中的引數順序應與 `tensor_data_names` + `tensor_attribute_names` + `optional_tensor_data_names` (如果存在) + `optional_tensor_attribute_names` (如果存在) 完全匹配。

如果定義了 tensor_data_namestensor_attribute_names,則會新增一些額外的函式,包括: __tensor_flatten__:展平子類化的張量例項,返回一個元組,第一個元素是有效張量資料的名稱,

第二個元素是非 Tensor 屬性的列表

__tensor_unflatten__:接受一個 `tensor_data_dict`(張量名稱到張量的對映)和非張量屬性列表,返回子類化張量的新例項 _apply_fn_to_data:接受一個函式(Tensor -> Tensor),將函式應用於所有張量資料並

用轉換後的張量資料重新建立一個子類化張量

__repr__:子類化張量例項的字串表示形式 _same_metadata:返回 cls 例項之間元資料是否相同 __setstate__:載入序列化的張量子類檢查點時,它會將舊檢查點中儲存的新可選張量和張量屬性設定為 None,以在將新的可選張量資料或屬性新增到張量子類時保持舊檢查點的向後相容性。 PyTorch 操作:torch.Tensor.contiguous ATen 操作:aten.detach.default, aten.clone.default, aten.alias,default, aten.contiguous.default, aten.copy_.default, aten._to_copy.default (啟用 t.to)

示例

class MyTensor(torch.Tensor)

tensor_data_names = [“a”, “b”] tensor_attribute_names = [“c”, “d”] optional_tensor_data_names = [“e”, “f”] optional_tensor_attribute_names = [“g”, “h”]

def __new__(

cls, a: Tensor, b: Tensor, c: int, d: str, e: Optional[Tensor] = None, f: Optional[Tensor] = None, g: Optional[int] = None, h: Optional[int] = None,

):

pass

def __init__(

self, a: Tensor, b: Tensor, c: int, d: str e: Optional[Tensor] = None, f: Optional[Tensor] = None, g: Optional[int] = None, h: Optional[int] = None,

):

pass

classmethod get_tensor_impl_constructor(layout_class: Callable) Callable

獲取 tensor_class 的 TensorImpl 類建構函式 (TensorImplClass.from_plain),基於 layout_class layout_class 表示 `Layout` 的子類型別,例如 PlainLayout

引數:
  • tensor_class – 張量子類型別

  • layout_class – `Layout` 的子類型別,例如 PlainLayout

返回:

layout_class 的 tensor impl 子類建構函式

classmethod implements(aten_ops_or_torch_fns)

使用此裝飾器為 `__torch_dispatch__` 中的 aten 操作(如果使用者傳入了一個操作列表)或 `__torch_function__` 中的 torch 函式(如果使用者傳入了一個單一物件)實現一個函式。

class MyTensor(torch.Tensor)

… implements = classmethod(_implements)

implements = MyTensor.implements

@implements(torch.nn.functional.linear): def _(func, types, args, kwargs)

classmethod register_layout(layout_class: Callable)

佈局註冊的輔助函式,用於實現每個張量子類的 `register_layout` 裝飾器,請參閱 aqt.py 中的示例用法

引數:
  • tensor_class – 張量子類型別

  • layout_class – `Layout` 的子類型別,例如 PlainLayout

返回:

一個在表中註冊 tensor impl 建構函式的裝飾器

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源