diff --git a/dev/sft/sft-from-file.py b/dev/sft/sft-from-file.py index df66e61cd..284c0ae35 100644 --- a/dev/sft/sft-from-file.py +++ b/dev/sft/sft-from-file.py @@ -4,12 +4,12 @@ import random import art -from art.local import LocalBackend +from art.megatron import MegatronBackend from art.utils.sft import train_sft_from_file async def main(): - backend = LocalBackend() + backend = MegatronBackend() model_name = "run-" + "".join( random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=8) @@ -17,7 +17,7 @@ async def main(): model = art.TrainableModel( name=model_name, project="sft-from-file", - base_model="meta-llama/Llama-3.1-8B-Instruct", + base_model="Qwen/Qwen3.6-35B-A3B", ) await model.register(backend) diff --git a/dev/sft/sft-warmup.py b/dev/sft/sft-warmup.py index 7a0244039..e44e54139 100644 --- a/dev/sft/sft-warmup.py +++ b/dev/sft/sft-warmup.py @@ -7,7 +7,7 @@ from dotenv import load_dotenv import art -from art.local import LocalBackend +from art.megatron import MegatronBackend from art.utils.sft import create_sft_dataset_iterator # Simple SFT trajectories - teach model to respond "maybe" @@ -43,7 +43,7 @@ async def rl_rollout(model: art.TrainableModel, prompt: str) -> art.Trajectory: async def main(): load_dotenv() - backend = LocalBackend() + backend = MegatronBackend() model_name = "sft-warmup-" + "".join( random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=8) ) diff --git a/src/art/dev/train.py b/src/art/dev/train.py index d22bdfee6..b4a797355 100644 --- a/src/art/dev/train.py +++ b/src/art/dev/train.py @@ -2,6 +2,8 @@ from typing_extensions import TypedDict +from art.types import SFTMetricLoggingConfig + if TYPE_CHECKING: from art.megatron.routing_replay import MoeRoutingReplayBundle @@ -40,3 +42,5 @@ class TrainConfig(TypedDict, total=False): class TrainSFTConfig(TypedDict, total=False): """Experimental SFT configuration options. Use at your own risk.""" + + metric_logging: SFTMetricLoggingConfig diff --git a/src/art/metrics_taxonomy.py b/src/art/metrics_taxonomy.py index 5e5245e2b..68a749d3b 100644 --- a/src/art/metrics_taxonomy.py +++ b/src/art/metrics_taxonomy.py @@ -6,6 +6,9 @@ from .trajectories import TrajectoryGroup TRAIN_GRADIENT_STEPS_KEY = "data/step_num_gradient_steps" +SFT_METRIC_PREFIX = "sft" +SFT_GRADIENT_STEP_KEY = "gradient_step" +SFT_WANDB_GRADIENT_STEP_KEY = f"{SFT_METRIC_PREFIX}/{SFT_GRADIENT_STEP_KEY}" _INVARIANT_METRIC_KEYS = frozenset({TRAIN_GRADIENT_STEPS_KEY}) diff --git a/src/art/model.py b/src/art/model.py index 182207458..06f1f88e1 100644 --- a/src/art/model.py +++ b/src/art/model.py @@ -17,6 +17,9 @@ from .costs import CostCalculator from .metrics import MetricsBuilder, is_builder_managed_metric from .metrics_taxonomy import ( + SFT_GRADIENT_STEP_KEY, + SFT_METRIC_PREFIX, + SFT_WANDB_GRADIENT_STEP_KEY, TRAIN_GRADIENT_STEPS_KEY, average_metric_samples, build_data_metrics_from_summary, @@ -24,7 +27,7 @@ ) from .preprocessing.moe_routing import attach_moe_routing_metadata_to_choice from .trajectories import Trajectory, TrajectoryGroup -from .types import TrainSFTConfig +from .types import SFTMetricLoggingConfig, TrainSFTConfig from .utils.trajectory_logging import write_trajectory_groups_parquet if TYPE_CHECKING: @@ -625,6 +628,7 @@ def _get_wandb_run(self) -> Optional["Run"]: self, "_wandb_defined_metrics", { + SFT_WANDB_GRADIENT_STEP_KEY, "training_step", "time/wall_clock_sec", }, @@ -634,12 +638,16 @@ def _get_wandb_run(self) -> Optional["Run"]: # This allows out-of-order logging (e.g., async validation for previous steps). run.define_metric("training_step") run.define_metric("time/wall_clock_sec") + run.define_metric(SFT_WANDB_GRADIENT_STEP_KEY) run.define_metric("reward/*", step_metric="training_step") run.define_metric("loss/*", step_metric="training_step") run.define_metric("throughput/*", step_metric="training_step") run.define_metric("costs/*", step_metric="training_step") run.define_metric("time/*", step_metric="training_step") run.define_metric("data/*", step_metric="training_step") + run.define_metric( + f"{SFT_METRIC_PREFIX}/*", step_metric=SFT_WANDB_GRADIENT_STEP_KEY + ) run.define_metric("train/*", step_metric="training_step") run.define_metric("val/*", step_metric="training_step") run.define_metric("test/*", step_metric="training_step") @@ -1230,6 +1238,7 @@ async def train_sft( config: TrainSFTConfig | None = None, _config: dev.TrainSFTConfig | None = None, verbose: bool = False, + log_metrics: bool = True, ) -> None: """ Supervised fine-tune the model with an iterable of trajectories. @@ -1241,16 +1250,35 @@ async def train_sft( _config: Additional experimental configuration that is subject to change and not yet part of the public API. Use at your own risk. verbose: Whether to print verbose output. + log_metrics: Whether to log SFT optimizer metrics. Defaults to True. """ if config is None: config = TrainSFTConfig() + backend = self.backend() + backend_logs_sft_metrics = ( + log_metrics and self._backend_logs_sft_metrics_remotely(backend) + ) + + _config = cast(dev.TrainSFTConfig, {**(_config or {})}) + if log_metrics: + metric_logging_config: SFTMetricLoggingConfig = { + "enabled": True, + } + if backend_logs_sft_metrics: + metric_logging_config["target_training_step"] = ( + await self.get_step() + ) + 1 + _config["metric_logging"] = metric_logging_config + else: + _config["metric_logging"] = {"enabled": False} + # Train (backend yields metrics for each batch without logging) - # Collect all metrics and aggregate them at the end (same as RL) - _config = _config or {} # ty:ignore[invalid-assignment] + # Collect all metrics and aggregate them at the end for the checkpoint summary. training_metrics: list[dict[str, float]] = [] + local_sft_checkpoint_step: int | None = None trainer_started = time.monotonic() - async for metrics in self.backend()._train_sft( + async for metrics in backend._train_sft( self, trajectories, config, @@ -1258,10 +1286,20 @@ async def train_sft( verbose, ): training_metrics.append(metrics) + gradient_step = len(training_metrics) + if log_metrics and not backend_logs_sft_metrics: + if local_sft_checkpoint_step is None: + local_sft_checkpoint_step = await self.get_step() + 1 + await self._log_sft_metric_sample( + metrics, + checkpoint_step=local_sft_checkpoint_step, + gradient_step=gradient_step, + ) trainer_elapsed = time.monotonic() - trainer_started - # Log aggregated training metrics once (same as RL) - if training_metrics: + # Log aggregated training metrics once at the checkpoint step. For + # remote-logging backends, the remote SFT job owns this row too. + if training_metrics and log_metrics and not backend_logs_sft_metrics: avg_metrics = average_metric_samples(training_metrics) avg_metrics["time/step_trainer_s"] = trainer_elapsed # Get the current step after training @@ -1269,3 +1307,27 @@ async def train_sft( await self.log( trajectories=None, split="train", metrics=avg_metrics, step=step ) + + @staticmethod + def _backend_logs_sft_metrics_remotely(backend: "Backend") -> bool: + remote_logger = getattr(type(backend), "logs_sft_metrics_remotely", None) + if not callable(remote_logger): + return False + return bool(remote_logger(backend)) + + async def _log_sft_metric_sample( + self, + metrics: dict[str, float], + *, + checkpoint_step: int, + gradient_step: int, + ) -> None: + await self.log( + trajectories=None, + split=SFT_METRIC_PREFIX, + metrics={ + SFT_GRADIENT_STEP_KEY: float(gradient_step), + **metrics, + }, + step=checkpoint_step, + ) diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py index 4ab10742b..9600f8bef 100644 --- a/src/art/serverless/backend.py +++ b/src/art/serverless/backend.py @@ -1,7 +1,7 @@ import asyncio from contextlib import asynccontextmanager import time -from typing import TYPE_CHECKING, Any, AsyncIterator, Iterable, Literal +from typing import TYPE_CHECKING, Any, AsyncIterator, Iterable, Literal, cast import warnings from openai._types import NOT_GIVEN @@ -22,7 +22,12 @@ summarize_trajectory_groups, ) from ..trajectories import Trajectory, TrajectoryGroup -from ..types import ServerlessTrainResult, TrainConfig, TrainSFTConfig +from ..types import ( + ServerlessTrainResult, + SFTMetricLoggingConfig, + TrainConfig, + TrainSFTConfig, +) from ..utils.record_provenance import record_provenance if TYPE_CHECKING: @@ -88,6 +93,9 @@ def __init__( self._base_url = str(client.base_url) self._client = client + def logs_sft_metrics_remotely(self) -> bool: + return True + async def close(self) -> None: await self._client.close() # ty:ignore[possibly-missing-attribute] @@ -607,6 +615,12 @@ async def _train_sft( ) sft_config["batch_size"] = batch_size sft_config["learning_rate"] = config.learning_rate + metric_logging = cast( + SFTMetricLoggingConfig, + dict(dev_config.get("metric_logging", {}) or {}), + ) + if metric_logging.get("enabled"): + sft_config["metric_logging"] = metric_logging sft_training_job = await self._client.sft_training_jobs.create( model_id=model.id, diff --git a/src/art/serverless/client.py b/src/art/serverless/client.py index 19d724e7d..be1b935b3 100644 --- a/src/art/serverless/client.py +++ b/src/art/serverless/client.py @@ -20,6 +20,7 @@ from typing_extensions import override from ..trajectories import TrajectoryGroup +from ..types import SFTMetricLoggingConfig ResponseT = TypeVar("ResponseT") @@ -80,6 +81,7 @@ class ExperimentalTrainingConfig(TypedDict, total=False): class SFTTrainingConfig(TypedDict, total=False): batch_size: int | None learning_rate: float | list[float] | None + metric_logging: SFTMetricLoggingConfig | None class TrainingJob(BaseModel): diff --git a/src/art/types.py b/src/art/types.py index 389d513ff..e5282665b 100644 --- a/src/art/types.py +++ b/src/art/types.py @@ -6,6 +6,7 @@ from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam import pydantic from pydantic import SkipValidation +from typing_extensions import TypedDict Message = Annotated[ChatCompletionMessageParam, SkipValidation] MessageOrChoice = Message | Choice @@ -25,6 +26,11 @@ class TrainSFTConfig(pydantic.BaseModel): batch_size: int | Literal["auto"] = "auto" +class SFTMetricLoggingConfig(TypedDict, total=False): + enabled: bool + target_training_step: int + + Verbosity = Literal[0, 1, 2] diff --git a/src/art/utils/sft.py b/src/art/utils/sft.py index 73db8cd28..6a9a50872 100644 --- a/src/art/utils/sft.py +++ b/src/art/utils/sft.py @@ -352,6 +352,7 @@ async def train_sft_from_file( _config: "DevTrainSFTConfig | None" = None, verbose: bool = False, shuffle_buffer_size: int = 10000, + log_metrics: bool = True, ) -> None: """ Train a model using supervised fine-tuning from a JSONL file. @@ -375,6 +376,7 @@ async def train_sft_from_file( verbose: Whether to print verbose output. Default: False shuffle_buffer_size: Size of shuffle buffer. Default: 10000. Larger values give better shuffling but use more memory. + log_metrics: Whether to log SFT optimizer metrics. Default: True. Example: await train_sft_from_file( @@ -449,4 +451,5 @@ async def train_sft_from_file( config, _config=_config, verbose=verbose, + log_metrics=log_metrics, ) diff --git a/tests/unit/test_frontend_logging.py b/tests/unit/test_frontend_logging.py index 9e72f82d3..8014c98b2 100644 --- a/tests/unit/test_frontend_logging.py +++ b/tests/unit/test_frontend_logging.py @@ -11,12 +11,14 @@ import json import os from pathlib import Path +from typing import cast from unittest.mock import AsyncMock, MagicMock, patch import polars as pl import pytest from art import Model, TrainableModel, Trajectory, TrajectoryGroup +from art.backend import Backend from art.local.backend import LocalBackend from art.metrics_taxonomy import TRAIN_GRADIENT_STEPS_KEY from art.utils.trajectory_logging import read_trajectory_groups_parquet @@ -1011,11 +1013,11 @@ def test_report_metrics_custom(self): class TestTrainSFTMetricsAggregation: - """Test that train_sft aggregates metrics and logs once (same as RL).""" + """Test that train_sft logs SFT progress and one checkpoint summary.""" @pytest.mark.asyncio async def test_train_sft_aggregates_metrics(self, tmp_path: Path): - """Verify train_sft aggregates metrics from multiple batches into one log entry.""" + """Verify train_sft logs batch metrics plus an aggregate checkpoint row.""" model = TrainableModel( name="test-sft", project="test-project", @@ -1045,7 +1047,7 @@ async def mock_train_sft(*args, **kwargs): } mock_backend._train_sft = mock_train_sft - mock_backend._get_step = AsyncMock(return_value=1) # Step after training + mock_backend._get_step = AsyncMock(side_effect=[0, 1]) model._backend = mock_backend # Create dummy trajectories @@ -1063,25 +1065,29 @@ async def mock_train_sft(*args, **kwargs): # Run train_sft await model.train_sft(trajectories) - # Verify history.jsonl has exactly ONE entry (not 3) history_path = tmp_path / "test-project/models/test-sft/history.jsonl" assert history_path.exists(), "history.jsonl should be created" with open(history_path) as f: lines = f.readlines() - assert len(lines) == 1, f"Expected 1 log entry, got {len(lines)}" - entries = [json.loads(line) for line in lines] - merged: dict[str, float] = {} - for entry in entries: - merged.update(entry) - + sft_entries = [entry for entry in entries if "sft/gradient_step" in entry] + summary_entries = [entry for entry in entries if "loss/train" in entry] + + assert len(entries) == 4 + assert len(sft_entries) == 3 + assert len(summary_entries) == 1 + assert [entry["sft/gradient_step"] for entry in sft_entries] == [1.0, 2.0, 3.0] + assert [entry["sft/loss/train"] for entry in sft_entries] == [1.0, 0.8, 0.6] assert all(entry["step"] == 1 for entry in entries) - assert merged["loss/train"] == pytest.approx(0.8) # (1.0 + 0.8 + 0.6) / 3 - assert merged["loss/grad_norm"] == pytest.approx(0.4) # (0.5 + 0.4 + 0.3) / 3 - assert merged["time/step_trainer_s"] >= 0 - assert merged["time/cum/trainer_s"] >= 0 + assert all(entry["training_step"] == 1 for entry in entries) + + summary = summary_entries[0] + assert summary["loss/train"] == pytest.approx(0.8) # (1.0 + 0.8 + 0.6) / 3 + assert summary["loss/grad_norm"] == pytest.approx(0.4) # (0.5 + 0.4 + 0.3) / 3 + assert summary["time/step_trainer_s"] >= 0 + assert summary["time/cum/trainer_s"] >= 0 @pytest.mark.asyncio async def test_train_sft_single_step_increment(self, tmp_path: Path): @@ -1101,7 +1107,7 @@ async def mock_train_sft(*args, **kwargs): yield {"loss": 1.0 - i * 0.1} mock_backend._train_sft = mock_train_sft - mock_backend._get_step = AsyncMock(return_value=1) # Step is 1 after training + mock_backend._get_step = AsyncMock(side_effect=[0, 1]) model._backend = mock_backend trajectories = [ @@ -1114,12 +1120,19 @@ async def mock_train_sft(*args, **kwargs): await model.train_sft(trajectories) - # Verify only one log entry at step 1 history_path = tmp_path / "test-project/models/test-sft-step/history.jsonl" df = pl.read_ndjson(str(history_path)) - assert len(df) == 1, "Should have exactly 1 log entry" + assert len(df) == 6, "Should log 5 SFT rows plus 1 summary row" assert set(df["step"].to_list()) == {1}, "Step should be 1 (single increment)" + assert set(df["training_step"].to_list()) == {1} + assert df.drop_nulls("sft/gradient_step")["sft/gradient_step"].to_list() == [ + 1.0, + 2.0, + 3.0, + 4.0, + 5.0, + ] @pytest.mark.asyncio async def test_train_sft_no_metrics_when_empty(self, tmp_path: Path): @@ -1151,6 +1164,84 @@ async def mock_train_sft(*args, **kwargs): "No history.jsonl should be created for empty training" ) + @pytest.mark.asyncio + async def test_train_sft_logs_every_gradient_step(self, tmp_path: Path): + """Verify train_sft logs every SFT optimizer metric row.""" + model = TrainableModel( + name="test-sft-every-step", + project="test-project", + base_model="gpt-4", + base_path=str(tmp_path), + ) + + mock_backend = MagicMock() + + async def mock_train_sft(*args, **kwargs): + for i in range(5): + yield {"loss/train": 1.0 - i * 0.1} + + mock_backend._train_sft = mock_train_sft + mock_backend._get_step = AsyncMock(side_effect=[0, 1]) + model._backend = mock_backend + + await model.train_sft([]) + + history_path = ( + tmp_path / "test-project/models/test-sft-every-step/history.jsonl" + ) + rows = [json.loads(line) for line in history_path.read_text().splitlines()] + sft_rows = [row for row in rows if "sft/gradient_step" in row] + + assert [row["sft/gradient_step"] for row in sft_rows] == [ + 1.0, + 2.0, + 3.0, + 4.0, + 5.0, + ] + assert len([row for row in rows if "loss/train" in row]) == 1 + + @pytest.mark.asyncio + async def test_train_sft_remote_logging_does_not_write_local_history( + self, tmp_path: Path + ): + """Verify remote SFT metric owners suppress client-side metric logging.""" + + class RemoteLoggingBackend: + def __init__(self) -> None: + self.dev_config = None + + def logs_sft_metrics_remotely(self) -> bool: + return True + + async def _get_step(self, _model): + return 0 + + async def _train_sft( + self, _model, _trajectories, _config, dev_config, _verbose + ): + self.dev_config = dev_config + yield {"loss/train": 1.0} + yield {"loss/train": 0.5} + + backend = RemoteLoggingBackend() + model = TrainableModel( + name="test-sft-remote", + project="test-project", + base_model="gpt-4", + base_path=str(tmp_path), + ) + model._backend = cast(Backend, backend) + + await model.train_sft([]) + + history_path = tmp_path / "test-project/models/test-sft-remote/history.jsonl" + assert not history_path.exists() + assert backend.dev_config is not None + metric_logging = backend.dev_config["metric_logging"] + assert metric_logging["enabled"] is True + assert metric_logging["target_training_step"] == 1 + class TestGradientStepMetrics: @pytest.mark.asyncio diff --git a/tests/unit/test_metric_routing.py b/tests/unit/test_metric_routing.py index 5a290ebfb..529cdf14a 100644 --- a/tests/unit/test_metric_routing.py +++ b/tests/unit/test_metric_routing.py @@ -67,12 +67,14 @@ def test_get_wandb_run_registers_taxonomy_sections(self, tmp_path: Path) -> None assert define_calls == [ (("training_step",), {}), (("time/wall_clock_sec",), {}), + (("sft/gradient_step",), {}), (("reward/*",), {"step_metric": "training_step"}), (("loss/*",), {"step_metric": "training_step"}), (("throughput/*",), {"step_metric": "training_step"}), (("costs/*",), {"step_metric": "training_step"}), (("time/*",), {"step_metric": "training_step"}), (("data/*",), {"step_metric": "training_step"}), + (("sft/*",), {"step_metric": "sft/gradient_step"}), (("train/*",), {"step_metric": "training_step"}), (("val/*",), {"step_metric": "training_step"}), (("test/*",), {"step_metric": "training_step"}), diff --git a/tests/unit/test_serverless_pipeline_trainer_compat.py b/tests/unit/test_serverless_pipeline_trainer_compat.py index fec8d23f7..89898ce7b 100644 --- a/tests/unit/test_serverless_pipeline_trainer_compat.py +++ b/tests/unit/test_serverless_pipeline_trainer_compat.py @@ -1,3 +1,4 @@ +import sys from types import SimpleNamespace from typing import Any from unittest.mock import AsyncMock, MagicMock, patch @@ -6,7 +7,7 @@ from art import TrainableModel, Trajectory, TrajectoryGroup from art.serverless.backend import ServerlessBackend -from art.types import TrainConfig +from art.types import TrainConfig, TrainSFTConfig def _make_group() -> TrajectoryGroup: @@ -187,3 +188,84 @@ async def no_sleep(_seconds: float) -> None: assert payload["kl_ref_adapter_path"] == "/tmp/ref" assert payload["allow_training_without_logprobs"] is True assert payload["scale_learning_rate_by_reward_std_dev"] is True + + +@pytest.mark.asyncio +async def test_serverless_train_sft_forwards_metric_logging_config() -> None: + backend = _make_backend() + model = TrainableModel( + name="serverless-sft-config-payload", + project="pipeline-tests", + base_model="test-model", + ) + model.id = "model-id" + model.entity = "entity" + model.run_id = "canonical-run-id" + + captured: dict[str, Any] = {} + backend._client.sft_training_jobs.create = AsyncMock( # type: ignore[attr-defined] + side_effect=lambda **kwargs: ( + captured.update(kwargs) or SimpleNamespace(id="sft-training-job-id") + ) + ) + + async def events_list(**_kwargs: Any): + yield SimpleNamespace(id="event-id", type="training_ended", data={}) + + backend._client.sft_training_jobs.events.list = events_list # type: ignore[attr-defined] + + async def no_sleep(_seconds: float) -> None: + return None + + class FakeArtifact: + def __init__(self, *_args: Any, **_kwargs: Any) -> None: + pass + + def add_file(self, *_args: Any, **_kwargs: Any) -> None: + pass + + def wait(self): + return self + + class FakeRun: + def log_artifact(self, artifact): + return artifact + + def finish(self) -> None: + pass + + fake_wandb = SimpleNamespace( + Artifact=FakeArtifact, + init=MagicMock(return_value=FakeRun()), + Settings=lambda **kwargs: kwargs, + ) + + trajectory = Trajectory( + messages_and_choices=[ + {"role": "user", "content": "prompt"}, + {"role": "assistant", "content": "answer"}, + ], + ) + + with patch.object(model, "_get_wandb_run", return_value=None): + with patch.dict(sys.modules, {"wandb": fake_wandb}): + with patch("art.serverless.backend.asyncio.sleep", no_sleep): + async for _ in backend._train_sft( + model, + [trajectory], + TrainSFTConfig(learning_rate=[1e-4], batch_size=2), + { + "metric_logging": { + "enabled": True, + "target_training_step": 1, + }, + }, + ): + pass + + config = captured["config"] + metric_logging = config["metric_logging"] + assert config["learning_rate"] == [1e-4] + assert config["batch_size"] == 2 + assert metric_logging["enabled"] is True + assert metric_logging["target_training_step"] == 1