Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions dev/sft/sft-from-file.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,20 @@
import random

import art
from art.local import LocalBackend
from art.megatron import MegatronBackend
from art.utils.sft import train_sft_from_file


async def main():
backend = LocalBackend()
backend = MegatronBackend()

model_name = "run-" + "".join(
random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=8)
)
model = art.TrainableModel(
name=model_name,
project="sft-from-file",
base_model="meta-llama/Llama-3.1-8B-Instruct",
base_model="Qwen/Qwen3.6-35B-A3B",
)
await model.register(backend)

Expand Down
4 changes: 2 additions & 2 deletions dev/sft/sft-warmup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from dotenv import load_dotenv

import art
from art.local import LocalBackend
from art.megatron import MegatronBackend
from art.utils.sft import create_sft_dataset_iterator

# Simple SFT trajectories - teach model to respond "maybe"
Expand Down Expand Up @@ -43,7 +43,7 @@ async def rl_rollout(model: art.TrainableModel, prompt: str) -> art.Trajectory:
async def main():
load_dotenv()

backend = LocalBackend()
backend = MegatronBackend()
model_name = "sft-warmup-" + "".join(
random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=8)
)
Expand Down
4 changes: 4 additions & 0 deletions src/art/dev/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from typing_extensions import TypedDict

from art.types import SFTMetricLoggingConfig

if TYPE_CHECKING:
from art.megatron.routing_replay import MoeRoutingReplayBundle

Expand Down Expand Up @@ -40,3 +42,5 @@ class TrainConfig(TypedDict, total=False):

class TrainSFTConfig(TypedDict, total=False):
    """Experimental SFT configuration options. Use at your own risk."""

    # Settings for SFT metric logging. The key may be omitted because the
    # TypedDict is declared with ``total=False``.
    metric_logging: SFTMetricLoggingConfig
3 changes: 3 additions & 0 deletions src/art/metrics_taxonomy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from .trajectories import TrajectoryGroup

# Metric key that is the sole member of the invariant-key set below.
TRAIN_GRADIENT_STEPS_KEY = "data/step_num_gradient_steps"

# Namespace for SFT metrics, and the key naming the gradient step inside it.
SFT_METRIC_PREFIX = "sft"
SFT_GRADIENT_STEP_KEY = "gradient_step"
# Fully-qualified "<prefix>/<key>" form of the gradient-step key.
SFT_WANDB_GRADIENT_STEP_KEY = f"{SFT_METRIC_PREFIX}/{SFT_GRADIENT_STEP_KEY}"

_INVARIANT_METRIC_KEYS = frozenset((TRAIN_GRADIENT_STEPS_KEY,))


Expand Down
74 changes: 68 additions & 6 deletions src/art/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@
from .costs import CostCalculator
from .metrics import MetricsBuilder, is_builder_managed_metric
from .metrics_taxonomy import (
SFT_GRADIENT_STEP_KEY,
SFT_METRIC_PREFIX,
SFT_WANDB_GRADIENT_STEP_KEY,
TRAIN_GRADIENT_STEPS_KEY,
average_metric_samples,
build_data_metrics_from_summary,
summarize_trajectory_groups,
)
from .preprocessing.moe_routing import attach_moe_routing_metadata_to_choice
from .trajectories import Trajectory, TrajectoryGroup
from .types import TrainSFTConfig
from .types import SFTMetricLoggingConfig, TrainSFTConfig
from .utils.trajectory_logging import write_trajectory_groups_parquet

if TYPE_CHECKING:
Expand Down Expand Up @@ -625,6 +628,7 @@ def _get_wandb_run(self) -> Optional["Run"]:
self,
"_wandb_defined_metrics",
{
SFT_WANDB_GRADIENT_STEP_KEY,
"training_step",
"time/wall_clock_sec",
},
Expand All @@ -634,12 +638,16 @@ def _get_wandb_run(self) -> Optional["Run"]:
# This allows out-of-order logging (e.g., async validation for previous steps).
run.define_metric("training_step")
run.define_metric("time/wall_clock_sec")
run.define_metric(SFT_WANDB_GRADIENT_STEP_KEY)
run.define_metric("reward/*", step_metric="training_step")
run.define_metric("loss/*", step_metric="training_step")
run.define_metric("throughput/*", step_metric="training_step")
run.define_metric("costs/*", step_metric="training_step")
run.define_metric("time/*", step_metric="training_step")
run.define_metric("data/*", step_metric="training_step")
run.define_metric(
f"{SFT_METRIC_PREFIX}/*", step_metric=SFT_WANDB_GRADIENT_STEP_KEY
)
run.define_metric("train/*", step_metric="training_step")
run.define_metric("val/*", step_metric="training_step")
run.define_metric("test/*", step_metric="training_step")
Expand Down Expand Up @@ -1230,6 +1238,7 @@ async def train_sft(
config: TrainSFTConfig | None = None,
_config: dev.TrainSFTConfig | None = None,
verbose: bool = False,
log_metrics: bool = True,
) -> None:
"""
Supervised fine-tune the model with an iterable of trajectories.
Expand All @@ -1241,31 +1250,84 @@ async def train_sft(
_config: Additional experimental configuration that is subject to change and
not yet part of the public API. Use at your own risk.
verbose: Whether to print verbose output.
log_metrics: Whether to log SFT optimizer metrics. Defaults to True.
"""
if config is None:
config = TrainSFTConfig()

backend = self.backend()
backend_logs_sft_metrics = (
log_metrics and self._backend_logs_sft_metrics_remotely(backend)
)

_config = cast(dev.TrainSFTConfig, {**(_config or {})})
if log_metrics:
metric_logging_config: SFTMetricLoggingConfig = {
"enabled": True,
}
if backend_logs_sft_metrics:
metric_logging_config["target_training_step"] = (
await self.get_step()
) + 1
_config["metric_logging"] = metric_logging_config
else:
_config["metric_logging"] = {"enabled": False}

# Train (backend yields metrics for each batch without logging)
# Collect all metrics and aggregate them at the end (same as RL)
_config = _config or {} # ty:ignore[invalid-assignment]
# Collect all metrics and aggregate them at the end for the checkpoint summary.
training_metrics: list[dict[str, float]] = []
local_sft_checkpoint_step: int | None = None
trainer_started = time.monotonic()
async for metrics in self.backend()._train_sft(
async for metrics in backend._train_sft(
self,
trajectories,
config,
_config, # ty:ignore[invalid-argument-type]
verbose,
):
training_metrics.append(metrics)
gradient_step = len(training_metrics)
if log_metrics and not backend_logs_sft_metrics:
if local_sft_checkpoint_step is None:
local_sft_checkpoint_step = await self.get_step() + 1
await self._log_sft_metric_sample(
metrics,
checkpoint_step=local_sft_checkpoint_step,
gradient_step=gradient_step,
)
trainer_elapsed = time.monotonic() - trainer_started

# Log aggregated training metrics once (same as RL)
if training_metrics:
# Log aggregated training metrics once at the checkpoint step. For
# remote-logging backends, the remote SFT job owns this row too.
if training_metrics and log_metrics and not backend_logs_sft_metrics:
avg_metrics = average_metric_samples(training_metrics)
avg_metrics["time/step_trainer_s"] = trainer_elapsed
# Get the current step after training
step = await self.get_step()
await self.log(
trajectories=None, split="train", metrics=avg_metrics, step=step
)

@staticmethod
def _backend_logs_sft_metrics_remotely(backend: "Backend") -> bool:
    """Report whether ``backend`` declares that it logs SFT metrics itself.

    The hook is looked up as ``logs_sft_metrics_remotely`` on the backend's
    class (``type(backend)``), not on the instance, so an attribute set only
    on the instance is ignored.

    Args:
        backend: The backend object to inspect.

    Returns:
        ``True`` when the class-level attribute exists, is callable, and
        returns a truthy value when called with ``backend``. ``False`` when
        the attribute is missing, not callable, or returns a falsy value.
    """
    remote_logger = getattr(type(backend), "logs_sft_metrics_remotely", None)
    if not callable(remote_logger):
        return False
    # Fetched from the class, so the instance must be passed explicitly.
    return bool(remote_logger(backend))

async def _log_sft_metric_sample(
    self,
    metrics: dict[str, float],
    *,
    checkpoint_step: int,
    gradient_step: int,
) -> None:
    """Log one SFT metric sample through ``self.log``.

    The sample is logged with no trajectories, under the
    ``SFT_METRIC_PREFIX`` split, at step ``checkpoint_step``. The gradient
    step is added to the metrics as a float under ``SFT_GRADIENT_STEP_KEY``.

    Args:
        metrics: Metric values for this sample.
        checkpoint_step: Step value passed to ``self.log`` as ``step``.
        gradient_step: Gradient-step index recorded alongside the metrics.
    """
    await self.log(
        trajectories=None,
        split=SFT_METRIC_PREFIX,
        metrics={
            # Placed before the spread, so a same-named key in ``metrics``
            # would override this entry.
            SFT_GRADIENT_STEP_KEY: float(gradient_step),
            **metrics,
        },
        step=checkpoint_step,
    )
18 changes: 16 additions & 2 deletions src/art/serverless/backend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
from contextlib import asynccontextmanager
import time
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterable, Literal
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterable, Literal, cast
import warnings

from openai._types import NOT_GIVEN
Expand All @@ -22,7 +22,12 @@
summarize_trajectory_groups,
)
from ..trajectories import Trajectory, TrajectoryGroup
from ..types import ServerlessTrainResult, TrainConfig, TrainSFTConfig
from ..types import (
ServerlessTrainResult,
SFTMetricLoggingConfig,
TrainConfig,
TrainSFTConfig,
)
from ..utils.record_provenance import record_provenance

if TYPE_CHECKING:
Expand Down Expand Up @@ -88,6 +93,9 @@ def __init__(
self._base_url = str(client.base_url)
self._client = client

def logs_sft_metrics_remotely(self) -> bool:
    """Declare that SFT metrics for this backend are logged remotely.

    Returns:
        Always ``True``.
    """
    return True

async def close(self) -> None:
await self._client.close() # ty:ignore[possibly-missing-attribute]

Expand Down Expand Up @@ -607,6 +615,12 @@ async def _train_sft(
)
sft_config["batch_size"] = batch_size
sft_config["learning_rate"] = config.learning_rate
metric_logging = cast(
SFTMetricLoggingConfig,
dict(dev_config.get("metric_logging", {}) or {}),
)
if metric_logging.get("enabled"):
sft_config["metric_logging"] = metric_logging

sft_training_job = await self._client.sft_training_jobs.create(
model_id=model.id,
Expand Down
2 changes: 2 additions & 0 deletions src/art/serverless/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing_extensions import override

from ..trajectories import TrajectoryGroup
from ..types import SFTMetricLoggingConfig

ResponseT = TypeVar("ResponseT")

Expand Down Expand Up @@ -80,6 +81,7 @@ class ExperimentalTrainingConfig(TypedDict, total=False):
class SFTTrainingConfig(TypedDict, total=False):
    """Optional settings for an SFT training job.

    Every key may be omitted (``total=False``) and every value may be ``None``.
    """

    batch_size: int | None
    # Either a single rate or a list of rates.
    learning_rate: float | list[float] | None
    metric_logging: SFTMetricLoggingConfig | None


class TrainingJob(BaseModel):
Expand Down
6 changes: 6 additions & 0 deletions src/art/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
import pydantic
from pydantic import SkipValidation
from typing_extensions import TypedDict

Message = Annotated[ChatCompletionMessageParam, SkipValidation]
MessageOrChoice = Message | Choice
Expand All @@ -25,6 +26,11 @@ class TrainSFTConfig(pydantic.BaseModel):
batch_size: int | Literal["auto"] = "auto"


class SFTMetricLoggingConfig(TypedDict, total=False):
    """Settings for SFT metric logging.

    Both keys are optional (``total=False``).
    """

    # Whether SFT metric logging is turned on.
    enabled: bool
    # Training step associated with the logged SFT metrics.
    # NOTE(review): presumably the step the resulting checkpoint will have;
    # confirm against the code that populates it.
    target_training_step: int


Verbosity = Literal[0, 1, 2]


Expand Down
3 changes: 3 additions & 0 deletions src/art/utils/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ async def train_sft_from_file(
_config: "DevTrainSFTConfig | None" = None,
verbose: bool = False,
shuffle_buffer_size: int = 10000,
log_metrics: bool = True,
) -> None:
"""
Train a model using supervised fine-tuning from a JSONL file.
Expand All @@ -375,6 +376,7 @@ async def train_sft_from_file(
verbose: Whether to print verbose output. Default: False
shuffle_buffer_size: Size of shuffle buffer. Default: 10000.
Larger values give better shuffling but use more memory.
log_metrics: Whether to log SFT optimizer metrics. Default: True.

Example:
await train_sft_from_file(
Expand Down Expand Up @@ -449,4 +451,5 @@ async def train_sft_from_file(
config,
_config=_config,
verbose=verbose,
log_metrics=log_metrics,
)
Loading
Loading