快捷方式

tensordict.nn.set_skip_existing

class tensordict.nn.set_skip_existing(mode: bool | None = True, in_key_attr='in_keys', out_key_attr='out_keys')

用於在 TensorDict 圖中跳過現有節點的上下文管理器。

用作上下文管理器時,它會將 `skip_existing()` 的值設定為指定的 `mode`,讓使用者能夠編寫相應的程式碼來檢查全域性值並據此執行程式碼。

用作方法裝飾器時,它會檢查 tensordict 的輸入鍵,如果 `skip_existing()` 呼叫返回 `True`,則當所有輸出鍵都已存在時,將跳過該方法。此裝飾器不適用於不遵循以下簽名的函式:`def fun(self, tensordict, *args, **kwargs)`。

引數:
  • mode (bool, optional) – 如果為 `True`,則表示圖中的現有條目不會被覆蓋,除非它們是部分存在的。`skip_existing()` 將返回 `True`。如果為 `False`,則不會執行檢查。如果為 `None`,則 `skip_existing()` 的值不會改變。這僅用於裝飾方法,並允許它們的行為依賴於上下文管理器中的同一類(參見下面的示例)。預設為 `True`。

  • in_key_attr (str, optional) – 被裝飾模組方法中的輸入鍵列表屬性的名稱。預設為 `'in_keys'`。

  • out_key_attr (str, optional) – 被裝飾模組方法中的輸出鍵列表屬性的名稱。預設為 `'out_keys'`。

示例

>>> with set_skip_existing():
...     if skip_existing():
...         print("True")
...     else:
...         print("False")
...
True
>>> print("calling from outside:", skip_existing())
calling from outside: False

此類也可作為裝飾器使用

示例

>>> from tensordict import TensorDict
>>> from tensordict.nn import set_skip_existing, skip_existing, TensorDictModuleBase
>>> class MyModule(TensorDictModuleBase):
...     in_keys = []
...     out_keys = ["out"]
...     @set_skip_existing()
...     def forward(self, tensordict):
...         print("hello")
...         tensordict.set("out", torch.zeros(()))
...         return tensordict
>>> module = MyModule()
>>> module(TensorDict({"out": torch.zeros(())}, []))  # does not print anything
TensorDict(
    fields={
        out: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> module(TensorDict())  # prints hello
hello
TensorDict(
    fields={
        out: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

用 `mode` 設定為 `None` 來裝飾一個方法,當你想讓上下文管理器從外部負責跳過內容時非常有用。

示例

>>> from tensordict import TensorDict
>>> from tensordict.nn import set_skip_existing, skip_existing, TensorDictModuleBase
>>> class MyModule(TensorDictModuleBase):
...     in_keys = []
...     out_keys = ["out"]
...     @set_skip_existing(None)
...     def forward(self, tensordict):
...         print("hello")
...         tensordict.set("out", torch.zeros(()))
...         return tensordict
>>> module = MyModule()
>>> _ = module(TensorDict({"out": torch.zeros(())}, []))  # prints "hello"
hello
>>> with set_skip_existing(True):
...     _ = module(TensorDict({"out": torch.zeros(())}, []))  # no print

注意

為了允許模組具有相同的輸入和輸出鍵而不至於錯誤地忽略子圖,當輸出鍵也是輸入鍵時,`@set_skip_existing(True)` 將被停用。

>>> class MyModule(TensorDictModuleBase):
...     in_keys = ["out"]
...     out_keys = ["out"]
...     @set_skip_existing()
...     def forward(self, tensordict):
...         print("calling the method!")
...         return tensordict
...
>>> module = MyModule()
>>> module(TensorDict({"out": torch.zeros(())}, []))  # does not print anything
calling the method!
TensorDict(
    fields={
        out: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源