torch.testing#
創建於: 2021年5月7日 | 最後更新於: 2025年6月10日
- torch.testing.assert_close(actual, expected, *, allow_subclasses=True, rtol=None, atol=None, equal_nan=False, check_device=True, check_dtype=True, check_layout=True, check_stride=False, msg=None)[source]#
斷言
actual和expected非常接近。如果
actual和expected是 strided、非量化的、實值且有限的,則它們被認為是接近的,如果非有限值(
-inf和inf)僅在它們相等時才被視為接近。NaN僅當equal_nan為True時才被視為相等。此外,它們僅在具有相同的
device(如果check_device為True),dtype(如果check_dtype為True),layout(如果check_layout為True),以及stride (如果
check_stride為True)時才被視為接近。
如果
actual或expected是元張量,則僅執行屬性檢查。如果
actual和expected是稀疏的(具有 COO、CSR、CSC、BSR 或 BSC 佈局),則會單獨檢查它們的 strided 成員。索引,即 COO 的indices,CSR 和 BSR 的crow_indices和col_indices,或者 CSC 和 BSC 佈局的ccol_indices和row_indices,總是被檢查相等性,而值根據上述定義檢查接近度。如果
actual和expected是量化的,則當它們具有相同的qscheme()並且dequantize()的結果根據上述定義接近時,它們才被視為接近。actual和expected可以是Tensor或任何可以從torch.Tensor構建的張量或標量型別。除了 Python 標量外,輸入型別必須直接相關。此外,actual和expected可以是Sequence或Mapping,在這種情況下,如果它們的結構匹配並且所有元素都根據上述定義被認為是接近的,那麼它們就被認為是接近的。注意
Python 標量是型別關係要求的例外,因為它們的
type(),即int、float和complex,等同於張量型別的dtype。因此,可以檢查不同型別的 Python 標量,但這需要check_dtype=False。- 引數
actual (Any) – 實際輸入。
expected (Any) – 預期輸入。
allow_subclasses (bool) – 如果為
True(預設),並且除了 Python 標量之外,允許使用直接相關的型別的輸入。否則需要型別相等。rtol (Optional[float]) – 相對容差。如果指定了
atol,則也必須指定。如果省略,則使用下表基於dtype選擇的預設值。atol (Optional[float]) – 絕對容差。如果指定了
rtol,則也必須指定。如果省略,則使用下表基於dtype選擇的預設值。check_device (bool) – 如果為
True(預設),則斷言相應的張量在相同的device上。如果停用此檢查,則將不同device上的張量移至 CPU 再進行比較。check_dtype (bool) – 如果為
True(預設),則斷言相應的張量具有相同的dtype。如果停用此檢查,則將具有不同dtype的張量提升到共同的dtype(根據torch.promote_types())後再進行比較。check_layout (bool) – 如果為
True(預設),則斷言相應的張量具有相同的layout。如果停用此檢查,則將具有不同layout的張量轉換為 strided 張量後再進行比較。check_stride (bool) – 如果為
True且相應的張量是 strided 的,則斷言它們具有相同的 stride。msg (Optional[Union[str, Callable[[str], str]]]) – 在比較期間發生失敗時可用於錯誤訊息的可選引數。也可以作為可呼叫物件傳遞,在這種情況下,它將使用生成的
msg進行呼叫,並應返回新的訊息。
- 引發
ValueError – 如果無法從輸入構造任何
torch.Tensor。ValueError – 如果僅指定了
rtol或atol。AssertionError – 如果相應的輸入不是 Python 標量且不直接相關。
AssertionError – 如果
allow_subclasses為False,但相應的輸入不是 Python 標量且型別不同。AssertionError – 如果輸入是
Sequence,但它們的長度不匹配。AssertionError – 如果輸入是
Mapping,但它們的鍵集不匹配。AssertionError – 如果相應的張量不具有相同的
shape。AssertionError – 如果
check_layout為True,但相應的張量不具有相同的layout。AssertionError – 如果只有其中一個相應張量是量化的。
AssertionError – 如果相應的張量是量化的,但具有不同的
qscheme()。AssertionError – 如果
check_device為True,但相應的張量不在相同的device上。AssertionError – 如果
check_dtype為True,但相應的張量不具有相同的dtype。AssertionError – 如果
check_stride為True,但相應的 strided 張量不具有相同的 stride。AssertionError – 如果相應張量的值根據上述定義不接近。
下表顯示了不同
dtype的預設rtol和atol。如果dtype不匹配,則使用兩種容差的最大值。dtypertolatolfloat161e-31e-5bfloat161.6e-21e-5float321.3e-61e-5float641e-71e-7complex321e-31e-5complex641.3e-61e-5complex1281e-71e-7quint81.3e-61e-5quint2x41.3e-61e-5quint4x21.3e-61e-5qint81.3e-61e-5qint321.3e-61e-5other
0.00.0注意
assert_close()具有高度可配置性,並帶有嚴格的預設設定。鼓勵使用者使用partial()來適應他們的用例。例如,如果需要相等性檢查,可以定義一個assert_equal,該函式預設對所有dtype使用零容差。>>> import functools >>> assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0) >>> assert_equal(1e-9, 1e-10) Traceback (most recent call last): ... AssertionError: Scalars are not equal! Expected 1e-10 but got 1e-09. Absolute difference: 9.000000000000001e-10 Relative difference: 9.0
示例
>>> # tensor to tensor comparison >>> expected = torch.tensor([1e0, 1e-1, 1e-2]) >>> actual = torch.acos(torch.cos(expected)) >>> torch.testing.assert_close(actual, expected)
>>> # scalar to scalar comparison >>> import math >>> expected = math.sqrt(2.0) >>> actual = 2.0 / math.sqrt(2.0) >>> torch.testing.assert_close(actual, expected)
>>> # numpy array to numpy array comparison >>> import numpy as np >>> expected = np.array([1e0, 1e-1, 1e-2]) >>> actual = np.arccos(np.cos(expected)) >>> torch.testing.assert_close(actual, expected)
>>> # sequence to sequence comparison >>> import numpy as np >>> # The types of the sequences do not have to match. They only have to have the same >>> # length and their elements have to match. >>> expected = [torch.tensor([1.0]), 2.0, np.array(3.0)] >>> actual = tuple(expected) >>> torch.testing.assert_close(actual, expected)
>>> # mapping to mapping comparison >>> from collections import OrderedDict >>> import numpy as np >>> foo = torch.tensor(1.0) >>> bar = 2.0 >>> baz = np.array(3.0) >>> # The types and a possible ordering of mappings do not have to match. They only >>> # have to have the same set of keys and their elements have to match. >>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)]) >>> actual = {"baz": baz, "bar": bar, "foo": foo} >>> torch.testing.assert_close(actual, expected)
>>> expected = torch.tensor([1.0, 2.0, 3.0]) >>> actual = expected.clone() >>> # By default, directly related instances can be compared >>> torch.testing.assert_close(torch.nn.Parameter(actual), expected) >>> # This check can be made more strict with allow_subclasses=False >>> torch.testing.assert_close( ... torch.nn.Parameter(actual), expected, allow_subclasses=False ... ) Traceback (most recent call last): ... TypeError: No comparison pair was able to handle inputs of type <class 'torch.nn.parameter.Parameter'> and <class 'torch.Tensor'>. >>> # If the inputs are not directly related, they are never considered close >>> torch.testing.assert_close(actual.numpy(), expected) Traceback (most recent call last): ... TypeError: No comparison pair was able to handle inputs of type <class 'numpy.ndarray'> and <class 'torch.Tensor'>. >>> # Exceptions to these rules are Python scalars. They can be checked regardless of >>> # their type if check_dtype=False. >>> torch.testing.assert_close(1.0, 1, check_dtype=False)
>>> # NaN != NaN by default. >>> expected = torch.tensor(float("Nan")) >>> actual = expected.clone() >>> torch.testing.assert_close(actual, expected) Traceback (most recent call last): ... AssertionError: Scalars are not close! Expected nan but got nan. Absolute difference: nan (up to 1e-05 allowed) Relative difference: nan (up to 1.3e-06 allowed) >>> torch.testing.assert_close(actual, expected, equal_nan=True)
>>> expected = torch.tensor([1.0, 2.0, 3.0]) >>> actual = torch.tensor([1.0, 4.0, 5.0]) >>> # The default error message can be overwritten. >>> torch.testing.assert_close( ... actual, expected, msg="Argh, the tensors are not close!" ... ) Traceback (most recent call last): ... AssertionError: Argh, the tensors are not close! >>> # If msg is a callable, it can be used to augment the generated message with >>> # extra information >>> torch.testing.assert_close( ... actual, expected, msg=lambda msg: f"Header\n\n{msg}\n\nFooter" ... ) Traceback (most recent call last): ... AssertionError: Header Tensor-likes are not close! Mismatched elements: 2 / 3 (66.7%) Greatest absolute difference: 2.0 at index (1,) (up to 1e-05 allowed) Greatest relative difference: 1.0 at index (1,) (up to 1.3e-06 allowed) Footer
- torch.testing.make_tensor(*shape, dtype, device, low=None, high=None, requires_grad=False, noncontiguous=False, exclude_zero=False, memory_format=None)[source]#
使用給定的
shape、device和dtype建立一個張量,並填充從[low, high)中均勻抽取的隨機值。如果指定了
low或high並且它們超出了dtype的可表示有限值的範圍,則它們將分別被裁剪到最低或最高可表示的有限值。如果為None,則下表描述了low和high的預設值,這些值取決於dtype。dtypelowhighboolean type
02unsigned integral type
010signed integral types
-910floating types
-99complex types
-99- 引數
shape (Tuple[int, ...]) – 定義輸出張量形狀的單個整數或整數序列。
dtype (
torch.dtype) – 返回張量的資料型別。device (Union[str, torch.device]) – 返回張量的裝置。
low (Optional[Number]) – 設定給定範圍的下限(包含)。如果提供數字,則會將其裁剪到給定 dtype 的可表示的最小有限值。當為
None(預設)時,此值根據dtype確定(請參閱上表)。預設值:None。high (Optional[Number]) –
設定給定範圍的上限(不包含)。如果提供數字,則會將其裁剪到給定 dtype 的可表示的最大有限值。當為
None(預設)時,此值根據dtype確定(請參閱上表)。預設值:None。自 2.1 版本起已棄用: 將
low==high傳遞給make_tensor()以用於浮點或複數型別,自 2.1 版本起已棄用,並將在 2.3 版本中刪除。請改用torch.full()。requires_grad (Optional[bool]) – 是否應自動記錄返回張量上的操作。預設值:
False。noncontiguous (Optional[bool]) – 如果為 True,則返回的張量將是非連續的。如果構造的張量少於兩個元素,則忽略此引數。與
memory_format互斥。exclude_zero (Optional[bool]) – 如果為
True,則零將被替換為根據dtype的小正值。對於布林和整數型別,零被替換為一。對於浮點型別,它被替換為 dtype 的最小正正常數(dtype 的finfo()物件的“微小”值),對於複數型別,它被替換為一個實部和虛部都表示為該複數型別可表示的最小正正常數的複數。預設值False。memory_format (Optional[torch.memory_format]) – 返回張量的記憶體格式。與
noncontiguous互斥。
- 引發
ValueError – 如果為整數 dtype 傳遞了
requires_grad=True。ValueError – 如果
low >= high。ValueError – 如果
low或high為nan。ValueError – 如果同時傳遞了
noncontiguous和memory_format。TypeError – 如果
dtype不被此函式支援。
- 返回型別
示例
>>> from torch.testing import make_tensor >>> # Creates a float tensor with values in [-1, 1) >>> make_tensor((3,), device="cpu", dtype=torch.float32, low=-1, high=1) tensor([ 0.1205, 0.2282, -0.6380]) >>> # Creates a bool tensor on CUDA >>> make_tensor((2, 2), device="cuda", dtype=torch.bool) tensor([[False, False], [False, True]], device='cuda:0')
- torch.testing.assert_allclose(actual, expected, rtol=None, atol=None, equal_nan=True, msg='')[source]#
警告
torch.testing.assert_allclose()自1.12版本起已棄用,並將在未來版本中刪除。請改用torch.testing.assert_close()。您可以在 此處 找到詳細的升級說明。