評價此頁

torch.masked#

創建於: 2022年8月15日 | 最後更新於: 2025年6月17日

引言#

動機#

警告

Masked tensors 的 PyTorch API 處於原型階段,未來可能會更改。

MaskedTensor 作為 torch.Tensor 的擴充套件,賦予使用者能夠

  • 使用任何掩碼語義(例如,可變長度張量、NaN* 運算子等)

  • 區分 0 和 NaN 梯度

  • 各種稀疏應用(請參閱下面的教程)

“指定”和“未指定”在 PyTorch 中有著悠久的歷史,但沒有正式的語義,當然也沒有一致性;事實上,MaskedTensor 的誕生源於 vanilla torch.Tensor 類無法妥善解決的一系列問題。因此,MaskedTensor 的主要目標是成為 PyTorch 中所述“指定”和“未指定”值的真實來源,其中它們是一等公民,而不是事後才考慮。反過來,這應該進一步釋放稀疏性的潛力,實現更安全、更一致的運算子,併為使用者和開發人員提供更順暢、更直觀的體驗。

什麼是 MaskedTensor?#

MaskedTensor 是一個張量子類,它包含 1) 一個輸入(資料)和 2) 一個掩碼。掩碼指示應包含或忽略輸入中的哪些條目。

舉個例子,假設我們想掩蓋所有等於 0 的值(用灰色表示)並取最大值

_images/tensor_comparison.jpg

上面是 vanilla 張量示例,下面是 MaskedTensor,其中所有 0 都被掩蓋了。這顯然會產生不同的結果,具體取決於我們是否有掩碼,但這種靈活的結構允許使用者在計算過程中系統地忽略任何他們想要的元素。

我們已經編寫了許多現有教程來幫助使用者上手,例如

支援的運算子#

一元運算子#

一元運算子是僅包含單個輸入的運算子。將它們應用於 MaskedTensor 相對簡單:如果給定索引的資料被掩蓋,我們將應用運算子,否則將繼續掩蓋資料。

可用的 unary operators 是

abs

計算 input 中每個元素的絕對值。

absolute

torch.abs() 的別名torch.abs()

acos

計算input中每個元素的反正弦。

arccos

torch.acos() 的別名torch.acos()

acosh

返回一個新張量,其中包含input元素的雙曲反餘弦。

arccosh

torch.acosh() 的別名torch.acosh()

angle

計算給定input張量的每個元素的角度(以弧度表示)。

asin

返回一個新張量,其中包含input元素的反正弦。

arcsin

torch.asin() 的別名torch.asin()

asinh

返回一個新張量,其中包含input元素的雙曲反正弦。

arcsinh

torch.asinh() 的別名torch.asinh()

atan

返回一個新張量,其中包含input元素的反正切。

arctan

torch.atan() 的別名torch.atan()

atanh

返回一個新張量,其中包含input元素的雙曲反正切。

arctanh

torch.atanh() 的別名torch.atanh()

bitwise_not

計算給定輸入張量的按位 NOT。

ceil

返回一個新張量,其中包含input元素的向上取整值,即大於或等於每個元素的最小整數。

clamp

input 中的所有元素限制在 [ min, max ] 範圍內。

clip

torch.clamp() 的別名torch.clamp()

conj_physical

計算給定 input 張量的逐元素共軛。

cos

返回一個新張量,其中包含input元素的餘弦。

cosh

返回一個新張量,其中包含input元素的雙曲餘弦。

deg2rad

返回一個新張量,其中input的每個元素都從角度(度)轉換為弧度。

digamma

torch.special.digamma() 的別名torch.special.digamma()

erf

torch.special.erf() 的別名torch.special.erf()

erfc

torch.special.erfc() 的別名torch.special.erfc()

erfinv

torch.special.erfinv() 的別名torch.special.erfinv()

exp

返回一個新張量,其元素是輸入張量input的指數。

exp2

torch.special.exp2() 的別名torch.special.exp2()

expm1

torch.special.expm1() 的別名torch.special.expm1()

fix

torch.trunc() 的別名torch.trunc()

floor

返回一個新張量,其中包含 input 元素的向下取整值,即小於或等於每個元素的最大整數。

frac

計算input中每個元素的小數部分。

lgamma

計算 input 上伽馬函式絕對值的自然對數。

log

返回一個新張量,其中包含 input 元素對應的自然對數。

log10

返回一個新張量,其中包含input元素的以10為底的對數。

log1p

返回一個新張量,其中包含(1 + input)的自然對數。

log2

返回一個新張量,其中包含input元素的以2為底的對數。

logit

torch.special.logit() 的別名torch.special.logit()

i0

torch.special.i0() 的別名torch.special.i0()

isnan

返回一個新張量,其中包含布林元素,表示 input 中的每個元素是否為 NaN。

nan_to_num

nanposinfneginf 指定的值分別替換 input 中的 NaN、正無窮和負無窮。

neg

返回一個新張量,其中包含input元素的負值。

negative

torch.neg() 的別名torch.neg()

positive

返回 input

pow

計算 input 中每個元素以 exponent 為指數的冪,並返回結果張量。

rad2deg

返回一個新張量,其中input的每個元素都從角度(弧度)轉換為度。

reciprocal

返回一個新張量,其中包含input元素的倒數。

round

input的元素四捨五入到最近的整數。

rsqrt

返回一個新張量,其中包含input每個元素的平方根的倒數。

sigmoid

torch.special.expit() 的別名torch.special.expit()

sign

返回一個新張量,其中包含input元素的符號。

sgn

此函式是 torch.sign() 對複數張量的擴充套件。

signbit

測試input的每個元素的符號位是否已設定。

sin

返回一個新張量,其中包含input元素的正弦。

sinc

torch.special.sinc() 的別名torch.special.sinc()

sinh

返回一個新張量,其中包含input元素的雙曲正弦。

sqrt

返回一個新張量,其中包含 input 元素的平方根。

square

返回一個新張量,其中包含input元素的平方。

tan

返回一個新張量,其中包含input元素的正切。

tanh

返回一個新張量,其元素是 input 的雙曲正切值。

trunc

返回一個新張量,其中包含input元素的截斷整數值。

可用的 inplace unary operators 是上面所有運算子,**但**

angle

計算給定input張量的每個元素的角度(以弧度表示)。

positive

返回 input

signbit

測試input的每個元素的符號位是否已設定。

isnan

返回一個新張量,其中包含布林元素,表示 input 中的每個元素是否為 NaN。

二元運算子#

正如您可能在教程中看到的,MaskedTensor 還實現了二元運算子,但有一個條件:兩個 MaskedTensor 的掩碼必須匹配,否則將引發錯誤。如錯誤中所述,如果您需要特定運算子的支援或對如何處理它們有建議的語義,請在 GitHub 上開一個 issue。目前,我們已決定採用最保守的實現方式,以確保使用者確切地知道正在發生什麼,並對他們使用掩碼語義的決定保持謹慎。

可用的 binary operators 是

add

other(縮放 alpha)加到 input

atan2

考慮象限的 inputi/otheri\text{input}_{i} / \text{other}_{i} 的逐元素反正切。

arctan2

torch.atan2() 的別名torch.atan2()

bitwise_and

計算 inputother 的按位 AND。

bitwise_or

計算 inputother 的按位 OR。

bitwise_xor

計算 inputother 的按位 XOR。

bitwise_left_shift

計算 inputother 位左移。

bitwise_right_shift

計算 inputother 位右移。

div

將輸入input的每個元素除以other的相應元素。

divide

torch.div() 的別名torch.div()

floor_divide

fmod

逐元素應用 C++ 的 std::fmod

logaddexp

輸入指數和的對數。

logaddexp2

以2為底的輸入指數和的對數。

mul

input乘以other

multiply

torch.mul() 的別名torch.mul()

nextafter

返回 input 之後,趨向於 other 的下一個浮點值,逐元素進行。

remainder

逐元素計算Python 的模運算

sub

input 中減去 other(縮放 alpha)。

subtract

torch.sub() 的別名torch.sub()

true_divide

torch.div() 帶有 rounding_mode=None 的別名。

eq

計算逐元素相等

ne

逐元素計算 inputother\text{input} \neq \text{other}

le

逐元素計算 inputother\text{input} \leq \text{other}

ge

逐元素計算 inputother\text{input} \geq \text{other}

greater

torch.gt() 的別名torch.gt()

greater_equal

torch.ge() 的別名torch.ge()

gt

逐元素計算 input>other\text{input} > \text{other}

less_equal

torch.le() 的別名torch.le()

lt

逐元素計算 input<other\text{input} < \text{other}

less

torch.lt() 的別名torch.lt()

maximum

計算 inputother 的逐元素最大值。

minimum

計算 inputother 的逐元素最小值。

fmax

計算 inputother 的逐元素最大值。

fmin

計算 inputother 的逐元素最小值。

not_equal

torch.ne() 的別名torch.ne()

可用的 inplace binary operators 是上面所有運算子,**但**

logaddexp

輸入指數和的對數。

logaddexp2

以2為底的輸入指數和的對數。

equal

如果兩個張量具有相同的大小和元素,則為True,否則為False

fmin

計算 inputother 的逐元素最小值。

minimum

計算 inputother 的逐元素最小值。

fmax

計算 inputother 的逐元素最大值。

歸約#

以下歸約可用(支援 autograd)。有關更多資訊,概述教程詳細介紹了一些歸約示例,而高階語義教程對我們如何決定某些歸約語義進行了更深入的討論。

sum

返回 input 張量中所有元素的和。

mean

amin

在給定維度 dim 下,返回 input 張量每個切片的最小值。

amax

在給定維度 dim 下,返回 input 張量每個切片的最大值。

argmin

返回扁平張量或沿某一維度的最小值的索引

argmax

返回 input 張量中所有元素最大值的索引。

prod

返回input張量中所有元素的乘積。

all

測試 input 中的所有元素是否都評估為 True

norm

返回給定張量的矩陣範數或向量範數。

var

在由 dim 指定的維度上計算方差。

std

在由 dim 指定的維度上計算標準差。

檢視和選擇函式#

我們也包含了一些檢視和選擇函式;直觀地說,這些運算子將同時應用於資料和掩碼,然後將結果包裝在 MaskedTensor 中。舉個快速示例,請考慮 select()

    >>> data = torch.arange(12, dtype=torch.float).reshape(3, 4)
    >>> data
    tensor([[ 0.,  1.,  2.,  3.],
            [ 4.,  5.,  6.,  7.],
            [ 8.,  9., 10., 11.]])
    >>> mask = torch.tensor([[True, False, False, True], [False, True, False, False], [True, True, True, True]])
    >>> mt = masked_tensor(data, mask)
    >>> data.select(0, 1)
    tensor([4., 5., 6., 7.])
    >>> mask.select(0, 1)
    tensor([False,  True, False, False])
    >>> mt.select(0, 1)
    MaskedTensor(
      [      --,   5.0000,       --,       --]
    )

當前支援以下 ops

atleast_1d

返回每個輸入張量的一維檢視,其中零維度。

broadcast_tensors

根據 廣播語義 廣播給定的張量。

broadcast_to

input 廣播到 shape 的形狀。

cat

將給定的張量序列 tensors 在給定維度上連線起來。

chunk

嘗試將張量分割成指定的塊數。

column_stack

透過水平堆疊tensors中的張量來建立新張量。

dsplit

根據 indices_or_sections,將三維或更多維的張量 input 深度分割成多個張量。

flatten

透過將 input 重塑為一維張量來展平。

hsplit

根據 indices_or_sections,將具有一個或多個維度的張量 input 水平分割成多個張量。

hstack

按水平(列方向)順序堆疊張量。

kron

計算 inputother 的克羅內克積,表示為 \otimes

meshgrid

建立由 attr:tensors 中一維輸入指定的座標網格。

narrow

返回一個新張量,它是 input 張量的縮小版本。

nn.functional.unfold

從批次輸入張量中提取滑動區域性塊。

ravel

返回一個連續的展平張量。

select

沿選定維度在給定索引處對 input 張量進行切片。

split

將張量分割成塊。

stack

沿新維度連線一系列張量。

t

期望input為小於等於2維的張量,並轉置維度0和1。

轉置

返回一個轉置版本的 input 張量。

vsplit

根據 indices_or_sections,將二維或更多維的張量 input 垂直分割成多個張量。

vstack

按垂直(行方向)順序堆疊張量。

Tensor.expand

返回 self 張量的新檢視,其中單例維度已擴充套件到更大的大小。

Tensor.expand_as

將此張量擴充套件到與 other 相同的尺寸。

Tensor.reshape

返回一個具有與 self 相同的資料和相同數量的元素,但具有指定形狀的張量。

Tensor.reshape_as

將此張量返回為與 other 相同的形狀。

Tensor.unfold

返回原始張量的檢視,該檢視包含 self 張量在 dimension 維度上的所有大小為 size 的切片。

Tensor.view

返回一個新張量,它具有與 self 張量相同的資料,但形狀不同。