評價此頁

torch.nn.utils.parametrize.cached#

torch.nn.utils.parametrize.cached()[原始碼]#

當使用 register_parametrization() 註冊的引數化時,啟用快取系統的上下文管理器。

當此上下文管理器處於活動狀態時,引數化物件的值將在首次需要時計算並快取。快取的值將在離開上下文管理器時被丟棄。

當在前向傳播中使用引數化引數超過一次時,此功能很有用。例如,當引數化 RNN 的迴圈核或共享權重時。

啟用快取的最簡單方法是在神經網路的前向傳播中包裝。

import torch.nn.utils.parametrize as P

...
with P.cached():
    output = model(inputs)

在訓練和評估中。也可以包裝模組中多次使用引數化張量的部分。例如,具有引數化迴圈核的 RNN 的迴圈。

with P.cached():
    for x in xs:
        out_rnn = self.rnn_cell(x, out_rnn)