評價此頁

在 torch.compile/torch.export 中支援自定義 C++ 類#

本教程是《自定義 C++ 類》教程的後續,它介紹了支援自定義 C++ 類在 torch.compile/torch.export 中所需的附加步驟。

警告

此功能處於原型狀態,並且可能發生向後相容性破壞性更改。本教程提供了 PyTorch 2.8 的快照。如果您遇到任何問題,請在 Github 上聯絡我們!

具體來說,有幾個步驟:

  1. 在 C++ 自定義類的實現中實現一個 __obj_flatten__ 方法,以便我們能夠檢查其狀態並保護更改。該方法應返回一個屬性名、值的元組 ( tuple[tuple[str, value] * n] )。

  2. 使用 @torch._library.register_fake_class 註冊一個 Python 虛類。

    1. 實現類中每個 C++ 方法的“虛方法”,這些方法應該與 C++ 實現具有相同的模式。

    2. 此外,在 Python 虛類中實現一個 __obj_unflatten__ 類方法,以告訴我們如何從 __obj_flatten__ 返回的扁平化狀態建立虛類。

以下是差異的分解。遵循《使用自定義 C++ 類擴充套件 TorchScript》中的指南,我們可以建立一個執行緒安全的張量佇列並構建它。

// Thread-safe Tensor Queue

#include <torch/custom_class.h>
#include <torch/script.h>

#include <iostream>
#include <string>
#include <vector>

using namespace torch::jit;

// Thread-safe Tensor Queue
struct TensorQueue : torch::CustomClassHolder {
explicit TensorQueue(at::Tensor t) : init_tensor_(t) {}

explicit TensorQueue(c10::Dict<std::string, at::Tensor> dict) {
    init_tensor_ = dict.at(std::string("init_tensor"));
    const std::string key = "queue";
    at::Tensor size_tensor;
    size_tensor = dict.at(std::string(key + "/size")).cpu();
    const auto* size_tensor_acc = size_tensor.const_data_ptr<int64_t>();
    int64_t queue_size = size_tensor_acc[0];

    for (const auto index : c10::irange(queue_size)) {
        at::Tensor val;
        queue_[index] = dict.at(key + "/" + std::to_string(index));
        queue_.push_back(val);
    }
}

// Push the element to the rear of queue.
// Lock is added for thread safe.
void push(at::Tensor x) {
    std::lock_guard<std::mutex> guard(mutex_);
    queue_.push_back(x);
}
// Pop the front element of queue and return it.
// If empty, return init_tensor_.
// Lock is added for thread safe.
at::Tensor pop() {
    std::lock_guard<std::mutex> guard(mutex_);
    if (!queue_.empty()) {
        auto val = queue_.front();
        queue_.pop_front();
        return val;
    } else {
        return init_tensor_;
    }
}

std::vector<at::Tensor> get_raw_queue() {
    std::vector<at::Tensor> raw_queue(queue_.begin(), queue_.end());
    return raw_queue;
}

private:
    std::deque<at::Tensor> queue_;
    std::mutex mutex_;
    at::Tensor init_tensor_;
};

// The torch binding code
TORCH_LIBRARY(MyCustomClass, m) {
    m.class_<TensorQueue>("TensorQueue")
        .def(torch::init<at::Tensor>())
        .def("push", &TensorQueue::push)
        .def("pop", &TensorQueue::pop)
        .def("get_raw_queue", &TensorQueue::get_raw_queue);
}

步驟 1:在 C++ 自定義類實現中新增 __obj_flatten__ 方法

// Thread-safe Tensor Queue
struct TensorQueue : torch::CustomClassHolder {
...
std::tuple<std::tuple<std::string, std::vector<at::Tensor>>, std::tuple<std::string, at::Tensor>> __obj_flatten__() {
    return std::tuple(std::tuple("queue", this->get_raw_queue()), std::tuple("init_tensor_", this->init_tensor_.clone()));
}
...
};

TORCH_LIBRARY(MyCustomClass, m) {
    m.class_<TensorQueue>("TensorQueue")
        .def(torch::init<at::Tensor>())
        ...
        .def("__obj_flatten__", &TensorQueue::__obj_flatten__);
}

步驟 2a:在 Python 中註冊一個實現每個方法的虛類。

# namespace::class_name
@torch._library.register_fake_class("MyCustomClass::TensorQueue")
class FakeTensorQueue:
    def __init__(
        self,
        queue: List[torch.Tensor],
        init_tensor_: torch.Tensor
    ) -> None:
        self.queue = queue
        self.init_tensor_ = init_tensor_

    def push(self, tensor: torch.Tensor) -> None:
        self.queue.append(tensor)

    def pop(self) -> torch.Tensor:
        if len(self.queue) > 0:
            return self.queue.pop(0)
        return self.init_tensor_

步驟 2b:在 Python 中實現一個 __obj_unflatten__ 類方法。

# namespace::class_name
@torch._library.register_fake_class("MyCustomClass::TensorQueue")
class FakeTensorQueue:
    ...
    @classmethod
    def __obj_unflatten__(cls, flattened_tq):
        return cls(**dict(flattened_tq))

就是這樣!現在我們可以建立一個使用此物件的模組,並使用 torch.compiletorch.export 執行它。

import torch

torch.classes.load_library("build/libcustom_class.so")
tq = torch.classes.MyCustomClass.TensorQueue(torch.empty(0).fill_(-1))

class Mod(torch.nn.Module):
    def forward(self, tq, x):
        tq.push(x.sin())
        tq.push(x.cos())
        poped_t = tq.pop()
        assert torch.allclose(poped_t, x.sin())
        return tq, poped_t

tq, poped_t = torch.compile(Mod(), backend="eager", fullgraph=True)(tq, torch.randn(2, 3))
assert tq.size() == 1

exported_program = torch.export.export(Mod(), (tq, torch.randn(2, 3),), strict=False)
exported_program.module()(tq, torch.randn(2, 3))

我們還可以實現接受自定義類作為輸入的自定義操作。例如,我們可以註冊一個自定義操作 for_each_add_(tq, tensor)

struct TensorQueue : torch::CustomClassHolder {
    ...
    void for_each_add_(at::Tensor inc) {
        for (auto& t : queue_) {
            t.add_(inc);
        }
    }
    ...
}


TORCH_LIBRARY_FRAGMENT(MyCustomClass, m) {
    m.class_<TensorQueue>("TensorQueue")
        ...
        .def("for_each_add_", &TensorQueue::for_each_add_);

    m.def(
        "for_each_add_(__torch__.torch.classes.MyCustomClass.TensorQueue foo, Tensor inc) -> ()");
}

void for_each_add_(c10::intrusive_ptr<TensorQueue> tq, at::Tensor inc) {
    tq->for_each_add_(inc);
}

TORCH_LIBRARY_IMPL(MyCustomClass, CPU, m) {
    m.impl("for_each_add_", for_each_add_);
}

由於虛類是在 Python 中實現的,因此我們要求自定義操作的虛實現也必須在 Python 中註冊。

@torch.library.register_fake("MyCustomClass::for_each_add_")
def fake_for_each_add_(tq, inc):
    tq.for_each_add_(inc)

重新編譯後,我們可以使用以下方式匯出自定義操作:

class ForEachAdd(torch.nn.Module):
    def forward(self, tq: torch.ScriptObject, a: torch.Tensor) -> torch.ScriptObject:
        torch.ops.MyCustomClass.for_each_add_(tq, a)
        return tq

mod = ForEachAdd()
tq = empty_tensor_queue()
qlen = 10
for i in range(qlen):
    tq.push(torch.zeros(1))

ep = torch.export.export(mod, (tq, torch.ones(1)), strict=False)

為什麼要建立虛類?#

使用真實自定義物件進行跟蹤有幾個主要的缺點:

  1. 真實物件上的操作可能耗時,例如,自定義物件可能正在從網路讀取或從磁碟載入資料。

  2. 在跟蹤時,我們不希望修改真實自定義物件或對環境產生副作用。

  3. 它不支援動態形狀。

然而,使用者可能難以編寫虛類,例如,如果原始類使用某些第三方庫來確定方法的輸出形狀,或者它很複雜並且由他人編寫。在這種情況下,使用者可以透過定義一個 tracing_mode 方法來返回 "real" 來停用虛化要求。

std::string tracing_mode() {
    return "real";
}

虛化的一個注意事項是關於**張量別名**。我們假設 torchbind 物件中的張量不會別名 torchbind 物件之外的張量。因此,修改其中一個張量將導致未定義的行為。