評價此頁

torch.set_default_dtype#

torch.set_default_dtype(d, /)[原始碼]#

將預設浮點資料型別設定為 d。支援浮點資料型別作為輸入。其他資料型別將導致torch引發異常。

當PyTorch初始化時,其預設浮點資料型別為torch.float32,而set_default_dtype(torch.float64)的目的是為了便於NumPy風格的型別推斷。預設浮點資料型別用於

  1. 隱式確定預設複數資料型別。當預設浮點型別為float16時,預設複數資料型別為complex32。對於float32,預設複數資料型別為complex64。對於float64,它是complex128。對於bfloat16,將引發異常,因為bfloat16沒有對應的複數型別。

  2. 推斷使用Python浮點數或複數構建的張量的資料型別。請參閱下面的示例。

  3. 確定布林值和整數張量與Python浮點數和複數之間的型別提升結果。

引數

d (torch.dtype) – 要設定為預設值的浮點資料型別。

示例

>>> # initial default for floating point is torch.float32
>>> # Python floats are interpreted as float32
>>> torch.tensor([1.2, 3]).dtype
torch.float32
>>> # initial default for floating point is torch.complex64
>>> # Complex Python numbers are interpreted as complex64
>>> torch.tensor([1.2, 3j]).dtype
torch.complex64
>>> torch.set_default_dtype(torch.float64)
>>> # Python floats are now interpreted as float64
>>> torch.tensor([1.2, 3]).dtype  # a new floating point tensor
torch.float64
>>> # Complex Python numbers are now interpreted as complex128
>>> torch.tensor([1.2, 3j]).dtype  # a new complex tensor
torch.complex128
>>> torch.set_default_dtype(torch.float16)
>>> # Python floats are now interpreted as float16
>>> torch.tensor([1.2, 3]).dtype  # a new floating point tensor
torch.float16
>>> # Complex Python numbers are now interpreted as complex128
>>> torch.tensor([1.2, 3j]).dtype  # a new complex tensor
torch.complex32