評價此頁

torch.mps.compile_shader#

torch.mps.compile_shader(source)[source]#

從源編譯計算著色器,並允許您從 Python 執行時舒適地呼叫其中定義的核心。示例

>>> lib = torch.mps.compile_shader(
... "kernel void full(device float* out, constant float& val, uint idx [[thread_position_in_grid]]) { out[idx] = val; }"
...  )
>>> x = torch.zeros(16, device="mps")
>>> lib.full(x, 3.14)