Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 60 additions & 5 deletions src/art/megatron/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import subprocess
import sys
from typing import Any, AsyncIterator, Literal, TypedDict, cast
import warnings

from peft.tuners.lora.config import LoraConfig
import torch
Expand Down Expand Up @@ -115,7 +116,7 @@ def create_identity_lora(
model_config = handler.identity_lora_model_config(base_config)
with init_empty_weights():
model = AutoModelForCausalLM.from_config(
model_config, torch_dtype=torch.bfloat16, trust_remote_code=True
model_config, dtype=torch.bfloat16, trust_remote_code=True
)
model.name_or_path = base_model

Expand All @@ -142,8 +143,21 @@ def _skip_meta_to(
return module
return orig_to(module, *args, **kwargs)

with patch.object(torch.nn.Module, "to", _skip_meta_to):
peft_model = get_peft_model(model, peft_lora_config)
# PEFT does not recognize fused MoE expert modules, but our handler
# converts the resulting identity LoRA checkpoint into supported tensors.
with warnings.catch_warnings():
if bool(getattr(handler, "is_moe", False)):
warnings.filterwarnings(
"ignore",
message=(
r"Unsupported layer type '.*MoeExperts.*' encountered, "
r"proceed at your own risk\."
),
category=UserWarning,
module=r"peft\.tuners\.tuners_utils",
)
with patch.object(torch.nn.Module, "to", _skip_meta_to):
peft_model = get_peft_model(model, peft_lora_config)

os.makedirs(lora_path, exist_ok=True)
peft_model.save_pretrained(lora_path)
Expand Down Expand Up @@ -209,6 +223,13 @@ def _on_child_process_exit(self, _error: RuntimeError) -> None:
def _raise_if_child_failed(self) -> None:
self._child_processes.raise_if_failed()

def _status(self, message: str) -> None:
print(f"[ART Megatron] {message}", flush=True)

@staticmethod
def _display_path(path: str | os.PathLike[str]) -> str:
return str(Path(path).resolve())

@property
def is_dedicated(self) -> bool:
return is_dedicated_mode(self.config)
Expand Down Expand Up @@ -410,6 +431,10 @@ def _adapter_exists_and_loads(self, lora_path: str) -> bool:
return True

def _create_identity_lora(self, lora_path: str) -> None:
self._status(
"Preparing initial LoRA adapter "
f"for {self.base_model} at {self._display_path(lora_path)}"
)
create_identity_lora(
self.base_model,
lora_path,
Expand Down Expand Up @@ -531,6 +556,10 @@ async def _start_vllm_subprocess(
os.makedirs(log_dir, exist_ok=True)
self._vllm_log_path = os.path.join(log_dir, "vllm-runtime.log")
self._vllm_log_file = open(self._vllm_log_path, "w", buffering=1)
self._status(
"Starting vLLM runtime "
f"for {self.base_model}. Logs: {self._display_path(self._vllm_log_path)}"
)
self._vllm_process = subprocess.Popen(
managed_process_cmd(cmd),
cwd=str(get_vllm_runtime_working_dir()),
Expand Down Expand Up @@ -581,6 +610,7 @@ async def _start_vllm_subprocess(
) from exc
assert self._vllm_process is not None
assert self._vllm_log_path is not None
self._status(f"vLLM runtime is ready at {self._vllm_base_url}")
self._child_processes.watch_popen(
"vLLM runtime",
self._vllm_process,
Expand Down Expand Up @@ -663,6 +693,7 @@ async def _sleep_runtime(self) -> None:
import httpx

self._raise_if_child_failed()
self._status("Sleeping vLLM runtime to free GPU memory for training")
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self._vllm_base_url}/sleep",
Expand All @@ -672,11 +703,13 @@ async def _sleep_runtime(self) -> None:
)
response.raise_for_status()
self._is_sleeping = True
self._status("vLLM runtime is sleeping")

async def _wake_runtime(self) -> None:
import httpx

self._raise_if_child_failed()
self._status("Waking vLLM runtime")
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self._vllm_base_url}/wake_up",
Expand All @@ -685,6 +718,7 @@ async def _wake_runtime(self) -> None:
)
response.raise_for_status()
self._is_sleeping = False
self._status("vLLM runtime is awake")

async def register_lora_for_step(self, step: int, checkpoint_dir: str) -> None:
self._raise_if_child_failed()
Expand Down Expand Up @@ -764,6 +798,10 @@ async def _ensure_megatron_running(self) -> None:
"w",
buffering=1,
)
self._status(
f"Starting Megatron worker on {num_gpus} GPU(s). "
f"Logs: {self._display_path(megatron_log_path)}"
)
self._megatron_process = await asyncio.create_subprocess_exec(
*managed_process_cmd(command),
cwd=str(project_root),
Expand All @@ -778,6 +816,7 @@ async def _ensure_megatron_running(self) -> None:
self._megatron_process,
log_path=megatron_log_path,
)
self._status("Megatron worker is initializing")

def _clear_pending_jobs(self) -> None:
jobs_dir, _training_log_dir, _wake_lock_path = self._megatron_runtime_paths()
Expand Down Expand Up @@ -805,9 +844,10 @@ def _resolve_training_lora_path(self) -> str:
async def _prepare_for_training(self) -> str:
self._raise_if_child_failed()
self._validate_megatron_dependencies()
await self._ensure_megatron_running()
# Shared-GPU Megatron must start after vLLM has released GPU memory.
await self._sleep_runtime()
gc_and_empty_cuda_cache()
await self._ensure_megatron_running()

lora_path = self._resolve_training_lora_path()
self._clear_pending_jobs()
Expand All @@ -820,6 +860,10 @@ async def _publish_training_checkpoint(
) -> None:
next_step = self._latest_step + 1
new_checkpoint_dir = get_step_checkpoint_dir(self.output_dir, next_step)
self._status(
f"Publishing training checkpoint {next_step} "
f"to {self._display_path(new_checkpoint_dir)}"
)
os.makedirs(new_checkpoint_dir, exist_ok=True)
shutil.copy(
f"{lora_path}/adapter_model.safetensors",
Expand All @@ -837,6 +881,7 @@ async def _publish_training_checkpoint(
os.remove(wake_lock_path)

await self._reload_adapter(new_checkpoint_dir, next_step)
self._status(f"Loaded checkpoint {next_step} into vLLM")

async def start_openai_server(
self, config: dev.OpenAIServerConfig | None
Expand Down Expand Up @@ -1015,18 +1060,28 @@ async def train_sft(
log_path=log_path,
)
write_megatron_job(job, job_path=job_path)
self._status(
f"Starting Megatron SFT job with {serialized_batches.num_batches} "
f"batch(es). First batch may take a few minutes while kernels compile. "
f"Training log: {self._display_path(log_path)}"
)

async for result in stream_megatron_job(
job,
job_path=job_path,
process=self._megatron_process,
process_log_path=self._megatron_log_path,
):
yield {
metrics = {
"loss/train": float(result["loss"]),
"loss/learning_rate": float(result["learning_rate"]),
"loss/grad_norm": float(result["grad_norm"]),
}
if "tokens_per_second" in result:
metrics["throughput/step_trainer_tok_per_s"] = float(
result["tokens_per_second"]
)
yield metrics

await self._publish_training_checkpoint(lora_path=lora_path)
except BaseException:
Expand Down
Loading