From 34de0e7acfb6712021c10d9476fdcc8d5004cee5 Mon Sep 17 00:00:00 2001 From: Kovbo Date: Fri, 5 Jun 2026 02:00:21 +0000 Subject: [PATCH] fix: improve Megatron SFT runtime startup --- src/art/megatron/service.py | 65 ++++++++++++++++++++++++++++++++++--- 1 file changed, 60 insertions(+), 5 deletions(-) diff --git a/src/art/megatron/service.py b/src/art/megatron/service.py index dae9a26fb..704dbb3ef 100644 --- a/src/art/megatron/service.py +++ b/src/art/megatron/service.py @@ -10,6 +10,7 @@ import subprocess import sys from typing import Any, AsyncIterator, Literal, TypedDict, cast +import warnings from peft.tuners.lora.config import LoraConfig import torch @@ -115,7 +116,7 @@ def create_identity_lora( model_config = handler.identity_lora_model_config(base_config) with init_empty_weights(): model = AutoModelForCausalLM.from_config( - model_config, torch_dtype=torch.bfloat16, trust_remote_code=True + model_config, dtype=torch.bfloat16, trust_remote_code=True ) model.name_or_path = base_model @@ -142,8 +143,21 @@ def _skip_meta_to( return module return orig_to(module, *args, **kwargs) - with patch.object(torch.nn.Module, "to", _skip_meta_to): - peft_model = get_peft_model(model, peft_lora_config) + # PEFT does not recognize fused MoE expert modules, but our handler + # converts the resulting identity LoRA checkpoint into supported tensors. + with warnings.catch_warnings(): + if bool(getattr(handler, "is_moe", False)): + warnings.filterwarnings( + "ignore", + message=( + r"Unsupported layer type '.*MoeExperts.*' encountered, " + r"proceed at your own risk\." + ), + category=UserWarning, + module=r"peft\.tuners\.tuners_utils", + ) + with patch.object(torch.nn.Module, "to", _skip_meta_to): + peft_model = get_peft_model(model, peft_lora_config) os.makedirs(lora_path, exist_ok=True) peft_model.save_pretrained(lora_path) @@ -209,6 +223,13 @@ def _on_child_process_exit(self, _error: RuntimeError) -> None: def _raise_if_child_failed(self) -> None: self._child_processes.raise_if_failed() + def _status(self, message: str) -> None: + print(f"[ART Megatron] {message}", flush=True) + + @staticmethod + def _display_path(path: str | os.PathLike[str]) -> str: + return str(Path(path).resolve()) + @property def is_dedicated(self) -> bool: return is_dedicated_mode(self.config) @@ -410,6 +431,10 @@ def _adapter_exists_and_loads(self, lora_path: str) -> bool: return True def _create_identity_lora(self, lora_path: str) -> None: + self._status( + "Preparing initial LoRA adapter " + f"for {self.base_model} at {self._display_path(lora_path)}" + ) create_identity_lora( self.base_model, lora_path, @@ -531,6 +556,10 @@ async def _start_vllm_subprocess( os.makedirs(log_dir, exist_ok=True) self._vllm_log_path = os.path.join(log_dir, "vllm-runtime.log") self._vllm_log_file = open(self._vllm_log_path, "w", buffering=1) + self._status( + "Starting vLLM runtime " + f"for {self.base_model}. Logs: {self._display_path(self._vllm_log_path)}" + ) self._vllm_process = subprocess.Popen( managed_process_cmd(cmd), cwd=str(get_vllm_runtime_working_dir()), @@ -581,6 +610,7 @@ async def _start_vllm_subprocess( ) from exc assert self._vllm_process is not None assert self._vllm_log_path is not None + self._status(f"vLLM runtime is ready at {self._vllm_base_url}") self._child_processes.watch_popen( "vLLM runtime", self._vllm_process, @@ -663,6 +693,7 @@ async def _sleep_runtime(self) -> None: import httpx self._raise_if_child_failed() + self._status("Sleeping vLLM runtime to free GPU memory for training") async with httpx.AsyncClient() as client: response = await client.post( f"{self._vllm_base_url}/sleep", @@ -672,11 +703,13 @@ async def _sleep_runtime(self) -> None: ) response.raise_for_status() self._is_sleeping = True + self._status("vLLM runtime is sleeping") async def _wake_runtime(self) -> None: import httpx self._raise_if_child_failed() + self._status("Waking vLLM runtime") async with httpx.AsyncClient() as client: response = await client.post( f"{self._vllm_base_url}/wake_up", @@ -685,6 +718,7 @@ async def _wake_runtime(self) -> None: ) response.raise_for_status() self._is_sleeping = False + self._status("vLLM runtime is awake") async def register_lora_for_step(self, step: int, checkpoint_dir: str) -> None: self._raise_if_child_failed() @@ -764,6 +798,10 @@ async def _ensure_megatron_running(self) -> None: "w", buffering=1, ) + self._status( + f"Starting Megatron worker on {num_gpus} GPU(s). " + f"Logs: {self._display_path(megatron_log_path)}" + ) self._megatron_process = await asyncio.create_subprocess_exec( *managed_process_cmd(command), cwd=str(project_root), @@ -778,6 +816,7 @@ async def _ensure_megatron_running(self) -> None: self._megatron_process, log_path=megatron_log_path, ) + self._status("Megatron worker is initializing") def _clear_pending_jobs(self) -> None: jobs_dir, _training_log_dir, _wake_lock_path = self._megatron_runtime_paths() @@ -805,9 +844,10 @@ def _resolve_training_lora_path(self) -> str: async def _prepare_for_training(self) -> str: self._raise_if_child_failed() self._validate_megatron_dependencies() - await self._ensure_megatron_running() + # Shared-GPU Megatron must start after vLLM has released GPU memory. await self._sleep_runtime() gc_and_empty_cuda_cache() + await self._ensure_megatron_running() lora_path = self._resolve_training_lora_path() self._clear_pending_jobs() @@ -820,6 +860,10 @@ async def _publish_training_checkpoint( ) -> None: next_step = self._latest_step + 1 new_checkpoint_dir = get_step_checkpoint_dir(self.output_dir, next_step) + self._status( + f"Publishing training checkpoint {next_step} " + f"to {self._display_path(new_checkpoint_dir)}" + ) os.makedirs(new_checkpoint_dir, exist_ok=True) shutil.copy( f"{lora_path}/adapter_model.safetensors", @@ -837,6 +881,7 @@ async def _publish_training_checkpoint( os.remove(wake_lock_path) await self._reload_adapter(new_checkpoint_dir, next_step) + self._status(f"Loaded checkpoint {next_step} into vLLM") async def start_openai_server( self, config: dev.OpenAIServerConfig | None @@ -1015,6 +1060,11 @@ async def train_sft( log_path=log_path, ) write_megatron_job(job, job_path=job_path) + self._status( + f"Starting Megatron SFT job with {serialized_batches.num_batches} " + f"batch(es). First batch may take a few minutes while kernels compile. " + f"Training log: {self._display_path(log_path)}" + ) async for result in stream_megatron_job( job, @@ -1022,11 +1072,16 @@ async def train_sft( process=self._megatron_process, process_log_path=self._megatron_log_path, ): - yield { + metrics = { "loss/train": float(result["loss"]), "loss/learning_rate": float(result["learning_rate"]), "loss/grad_norm": float(result["grad_norm"]), } + if "tokens_per_second" in result: + metrics["throughput/step_trainer_tok_per_s"] = float( + result["tokens_per_second"] + ) + yield metrics await self._publish_training_checkpoint(lora_path=lora_path) except BaseException: