torch.utils.module_tracker#
創建於:2024 年 5 月 4 日 | 最後更新於:2025 年 6 月 11 日
此實用工具可用於跟蹤 torch.nn.Module 層級結構中的當前位置。它可用於其他跟蹤工具中,以便輕鬆地將測量到的數量與使用者友好的名稱關聯起來。例如,FlopCounterMode 目前就使用了此工具。
- class torch.utils.module_tracker.ModuleTracker[source]#
ModuleTracker是一個上下文管理器,可在執行期間跟蹤 nn.Module 層級結構,以便其他系統可以查詢當前正在執行哪個 Module(或其反向傳播正在執行)。您可以透過此上下文管理器訪問
parents屬性,以獲取當前透過其 fqn(完全限定名,也用作 state_dict 中的鍵)執行的所有 Module 的集合。您可以透過訪問is_bw屬性來了解您當前是否正在執行反向傳播。請注意,
parents永不為空,並且始終包含“Global”鍵。is_bw標誌將在前向傳播完成後保持True,直到執行另一個 Module。如果您需要更精確,請提交一個問題請求此功能。新增一個從 fqn 到 module 例項的對映是可能的,但尚未實現,如果您需要,請提交一個問題請求此功能。使用示例
mod = torch.nn.Linear(2, 2) with ModuleTracker() as tracker: # Access anything during the forward pass def my_linear(m1, m2, bias): print(f"Current modules: {tracker.parents}") return torch.mm(m1, m2.t()) + bias torch.nn.functional.linear = my_linear mod(torch.rand(2, 2))