torch.nn.utils.parametrize.cached#
- torch.nn.utils.parametrize.cached()[原始碼]#
當使用
register_parametrization()註冊的引數化時,啟用快取系統的上下文管理器。當此上下文管理器處於活動狀態時,引數化物件的值將在首次需要時計算並快取。快取的值將在離開上下文管理器時被丟棄。
當在前向傳播中使用引數化引數超過一次時,此功能很有用。例如,當引數化 RNN 的迴圈核或共享權重時。
啟用快取的最簡單方法是在神經網路的前向傳播中包裝。
import torch.nn.utils.parametrize as P ... with P.cached(): output = model(inputs)
在訓練和評估中。也可以包裝模組中多次使用引數化張量的部分。例如,具有引數化迴圈核的 RNN 的迴圈。
with P.cached(): for x in xs: out_rnn = self.rnn_cell(x, out_rnn)