定製#
創建於:2021年5月4日 | 最後更新:2021年5月4日
本節介紹如何根據您的需求定製 TorchElastic。
啟動器#
TorchElastic 附帶的啟動器程式應該足以滿足大多數用例(請參閱 torchrun (Elastic Launch))。您可以像下面一樣,透過以程式設計方式建立代理並將工作節點的配置傳遞給它來實現自定義啟動器。
# my_launcher.py
if __name__ == "__main__":
args = parse_args(sys.argv[1:])
rdzv_handler = RendezvousHandler(...)
spec = WorkerSpec(
local_world_size=args.nproc_per_node,
fn=trainer_entrypoint_fn,
args=(trainer_entrypoint_fn args.fn_args,...),
rdzv_handler=rdzv_handler,
max_restarts=args.max_restarts,
monitor_interval=args.monitor_interval,
)
agent = LocalElasticAgent(spec, start_method="spawn")
try:
run_result = agent.run()
if run_result.is_failed():
print(f"worker 0 failed with: run_result.failures[0]")
else:
print(f"worker 0 return value is: run_result.return_values[0]")
except Exception ex:
# handle exception
集合點處理程式#
要實現自己的集合點,請擴充套件 torch.distributed.elastic.rendezvous.RendezvousHandler 並實現其方法。
警告
集合點處理程式的實現很棘手。在開始之前,請確保您完全理解集合點的屬性。有關更多資訊,請參閱 集合點。
實現後,您可以在建立代理時將自定義集合點處理程式傳遞給工作節點配置。
spec = WorkerSpec(
rdzv_handler=MyRendezvousHandler(params),
...
)
elastic_agent = LocalElasticAgent(spec, start_method=start_method)
elastic_agent.run(spec.role)
指標處理程式#
TorchElastic 會發出平臺級指標(請參閱 指標)。預設情況下,指標會發送到 /dev/null,因此您將看不到它們。要將指標推送到您基礎架構中的指標處理服務,請實現 torch.distributed.elastic.metrics.MetricHandler 並在自定義啟動器中對其進行 配置。
# my_launcher.py
import torch.distributed.elastic.metrics as metrics
class MyMetricHandler(metrics.MetricHandler):
def emit(self, metric_data: metrics.MetricData):
# push metric_data to your metric sink
def main():
metrics.configure(MyMetricHandler())
spec = WorkerSpec(...)
agent = LocalElasticAgent(spec)
agent.run()
事件處理程式#
TorchElastic 支援事件記錄(請參閱 事件)。事件模組定義了一個 API,允許您記錄事件並實現自定義 EventHandler。EventHandler 用於將 torchelastic 執行期間生成的事件釋出到不同的源,例如 AWS CloudWatch。預設情況下,它使用 torch.distributed.elastic.events.NullEventHandler,該處理程式會忽略事件。要配置自定義事件處理程式,您需要實現 torch.distributed.elastic.events.EventHandler 介面並在自定義啟動器中對其進行 配置。
# my_launcher.py
import torch.distributed.elastic.events as events
class MyEventHandler(events.EventHandler):
def record(self, event: events.Event):
# process event
def main():
events.configure(MyEventHandler())
spec = WorkerSpec(...)
agent = LocalElasticAgent(spec)
agent.run()