diff --git a/dev/run_yes_no_maybe_kl_advantage_tinker.py b/dev/run_yes_no_maybe_kl_advantage_tinker.py new file mode 100644 index 000000000..468001d30 --- /dev/null +++ b/dev/run_yes_no_maybe_kl_advantage_tinker.py @@ -0,0 +1,104 @@ +"""Launch yes-no-maybe-kl-advantage-tinker training on SkyPilot (Kubernetes). + +Usage: + uv run dev/run_yes_no_maybe_kl_advantage_tinker.py + uv run dev/run_yes_no_maybe_kl_advantage_tinker.py --fast + uv run dev/run_yes_no_maybe_kl_advantage_tinker.py --base-model Qwen/Qwen2.5-7B-Instruct +""" + +import argparse +import os +import textwrap + +from dotenv import load_dotenv +import sky +from sky import ClusterStatus + +load_dotenv() + +parser = argparse.ArgumentParser( + description="Launch yes-no-maybe KL advantage training (Tinker) on SkyPilot." +) +parser.add_argument( + "--fast", action="store_true", help="Skip setup (for re-runs on existing cluster)." +) +parser.add_argument( + "--base-model", type=str, default="meta-llama/Llama-3.1-8B-Instruct" +) +parser.add_argument("--num-steps", type=int, default=20) +parser.add_argument("--kl-penalty-coef", type=float, default=0.1) +parser.add_argument("--accelerator", type=str, default="H200:1") +parser.add_argument("--cluster-name", type=str, default=None) +parser.add_argument( + "--kl-ref-step", + type=int, + default=None, + help="Checkpoint step of training model to use as KL reference", +) +args = parser.parse_args() + +cluster_name = args.cluster_name or f"ynm-tinker-kl-{args.kl_penalty_coef}" +cluster_prefix = os.environ.get("CLUSTER_PREFIX") +if cluster_prefix: + cluster_name = f"{cluster_prefix}-{cluster_name}" + +setup_script = textwrap.dedent("""\ + echo 'Setting up environment...' + apt install -y nvtop + curl -LsSf https://astral.sh/uv/install.sh | sh + source $HOME/.local/bin/env +""") + +kl_ref_env = "" +if args.kl_ref_step is not None: + kl_ref_env = f"KL_REF_STEP={args.kl_ref_step} " + +run_script = textwrap.dedent(f"""\ + source $HOME/.local/bin/env + cd ~/sky_workdir + {kl_ref_env}BASE_MODEL={args.base_model} NUM_STEPS={args.num_steps} KL_PENALTY_COEF={args.kl_penalty_coef} uv run --python 3.11 --extra tinker dev/yes-no-maybe-kl-advantage-tinker.py +""") + +task = sky.Task( + name="yes-no-maybe-kl-advantage-tinker", + setup=setup_script, + run=run_script, + workdir=".", +) +task.set_resources( + sky.Resources(accelerators=args.accelerator, cloud=sky.clouds.Kubernetes()) +) +task.set_file_mounts( + { + "~/sky_workdir/.env": ".env", + } +) + +print(f"Launching on cluster: {cluster_name}") +print(f" base_model: {args.base_model}") +print(f" accelerator: {args.accelerator}") +print(f" num_steps: {args.num_steps}") +print(f" kl_penalty_coef: {args.kl_penalty_coef}") +if args.kl_ref_step is not None: + print(f" kl_ref_step: {args.kl_ref_step}") + +# Cancel any existing jobs on this cluster +cluster_status = sky.stream_and_get(sky.status(cluster_names=[cluster_name])) +if len(cluster_status) > 0 and cluster_status[0]["status"] == ClusterStatus.UP: + print(f"Cluster {cluster_name} is UP. Canceling any active jobs...") + sky.stream_and_get(sky.cancel(cluster_name, all=True)) + +job_id, _ = sky.stream_and_get( + sky.launch( + task, + cluster_name=cluster_name, + retry_until_up=True, + idle_minutes_to_autostop=60, + down=True, + fast=args.fast, + ) +) + +print(f"Job submitted (ID: {job_id}). Streaming logs...") +exit_code = sky.tail_logs(cluster_name=cluster_name, job_id=job_id, follow=True) +print(f"Job {job_id} finished with exit code {exit_code}.") diff --git a/dev/yes-no-maybe-kl-advantage-tinker.py b/dev/yes-no-maybe-kl-advantage-tinker.py new file mode 100644 index 000000000..5983a1f2d --- /dev/null +++ b/dev/yes-no-maybe-kl-advantage-tinker.py @@ -0,0 +1,111 @@ +"""Yes-no-maybe training with KL-penalized advantage adjustment (Tinker backend). + +Demonstrates the kl_penalty_coef feature: tokens where the policy has drifted +more from the reference model get reduced advantages, while tokens that have +drifted less get increased advantages. + +Uses meta-llama/Llama-3.1-8B-Instruct as the base model (trained via Tinker). +""" + +import asyncio +from itertools import permutations +import os +import random +import string + +from dotenv import load_dotenv +import openai + +import art +from art.tinker_native import TinkerNativeBackend + + +async def rollout( + client: openai.AsyncOpenAI, model: art.TrainableModel, prompt: str +) -> art.Trajectory: + messages: art.Messages = [ + { + "role": "user", + "content": prompt, + } + ] + chat_completion = await client.chat.completions.create( + messages=messages, model=model.get_inference_name(), max_tokens=100, timeout=100 + ) + choice = chat_completion.choices[0] + content = choice.message.content + assert isinstance(content, str) + if content == "yes": + reward = 0.5 + elif content == "no": + reward = 0.75 + elif content == "maybe": + reward = 1.0 + else: + reward = 0.0 + return art.Trajectory(messages_and_choices=[*messages, choice], reward=reward) + + +def with_quotes(w: str) -> str: + return f"'{w}'" + + +async def main(): + load_dotenv() + + backend = TinkerNativeBackend() + base_model = os.environ.get("BASE_MODEL", "meta-llama/Llama-3.1-8B-Instruct") + kl_penalty_coef = float(os.environ.get("KL_PENALTY_COEF", "0.1")) + random_suffix = "".join(random.choices(string.ascii_lowercase, k=4)) + model = art.TrainableModel( + name=os.environ.get("MODEL_NAME", f"tinker-{random_suffix}-{kl_penalty_coef}"), + project="yes-no-maybe", + base_model=base_model, + ) + await model.register(backend) + + kl_penalty_reference_step: int | None = ( + int(os.environ["KL_REF_STEP"]) + if os.environ.get("KL_REF_STEP") is not None + else None + ) + + prompts = [ + f"{prefix} with {', '.join([with_quotes(w) if use_quotes else w for w in words]) if len(words) == 3 else f'{words[0]}' + (f' or {words[1]}' if len(words) > 1 else '')}" + for prefix in ["respond", "just respond"] + for use_quotes in [True, False] + for words in ( + list(p) for n in [3, 2] for p in permutations(["yes", "no", "maybe"], n) + ) + ] + + openai_client = model.openai_client() + max_steps = int(os.environ.get("NUM_STEPS", "20")) + start_step = await model.get_step() + for step in range(start_step, start_step + max_steps): + train_groups = await art.gather_trajectory_groups( + ( + art.TrajectoryGroup( + rollout(openai_client, model, prompt) for _ in range(32) + ) + for prompt in prompts + ) + ) + result = await backend.train( + model, + train_groups, + learning_rate=1e-4, + kl_penalty_coef=kl_penalty_coef, + kl_penalty_reference_step=kl_penalty_reference_step, + ) + await model.log( + train_groups, + metrics=result.metrics, + step=result.step, + split="train", + ) + print(f"step {result.step}: {result.metrics}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/dev/yes-no-maybe-kl-advantage.py b/dev/yes-no-maybe-kl-advantage.py index 41ce0b119..ccd21b243 100644 --- a/dev/yes-no-maybe-kl-advantage.py +++ b/dev/yes-no-maybe-kl-advantage.py @@ -10,6 +10,8 @@ import asyncio from itertools import permutations import os +import random +import string from dotenv import load_dotenv import openai @@ -54,8 +56,9 @@ async def main(): backend = LocalBackend() base_model = os.environ.get("BASE_MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct") kl_penalty_coef = float(os.environ.get("KL_PENALTY_COEF", "0.1")) + random_suffix = "".join(random.choices(string.ascii_lowercase, k=4)) model = art.TrainableModel( - name=os.environ.get("MODEL_NAME", f"kl-{kl_penalty_coef}"), + name=os.environ.get("MODEL_NAME", f"local-{random_suffix}-{kl_penalty_coef}"), project="yes-no-maybe", base_model=base_model, ) diff --git a/src/art/_backend_training.py b/src/art/_backend_training.py index 92e013f00..a2feb8133 100644 --- a/src/art/_backend_training.py +++ b/src/art/_backend_training.py @@ -28,6 +28,7 @@ def build_rl_train_configs( max_negative_advantage_importance_sampling_weight: float | None = None, kimi_k2_tau: float | None = None, kl_penalty_coef: float = 0.0, + kl_penalty_source: Literal["current_learner", "sample"] = "current_learner", allow_training_without_logprobs: bool | None = None, plot_tensors: bool | None = None, truncated_importance_sampling: float | None = None, @@ -41,11 +42,13 @@ def build_rl_train_configs( config = TrainConfig( learning_rate=learning_rate, kl_penalty_coef=kl_penalty_coef, + kl_penalty_source=kl_penalty_source, ) dev_config: dev.TrainConfig = { "advantage_balance": advantage_balance, "importance_sampling_level": importance_sampling_level, "kl_penalty_coef": kl_penalty_coef, + "kl_penalty_source": kl_penalty_source, "mask_prob_ratio": mask_prob_ratio, "ppo": ppo, "precalculate_logprobs": precalculate_logprobs, diff --git a/src/art/dev/train.py b/src/art/dev/train.py index c9819b4b3..98c0abdcb 100644 --- a/src/art/dev/train.py +++ b/src/art/dev/train.py @@ -21,6 +21,9 @@ class TrainConfig(TypedDict, total=False): ] kimi_k2_tau: float | None kl_penalty_coef: float + kl_penalty_reference_step: int | None + kl_penalty_source: Literal["current_learner", "sample"] + kl_penalty_step_lag: int | None kl_ref_adapter_path: str | None logprob_calculation_chunk_size: int mask_prob_ratio: bool diff --git a/src/art/local/backend.py b/src/art/local/backend.py index 37cb6f882..431de179b 100644 --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -705,6 +705,7 @@ async def train( # type: ignore[override] kl_penalty_coef: float = 0.0, kl_penalty_reference_step: int | None = None, kl_ref_adapter_path: str | None = None, + kl_penalty_source: Literal["current_learner", "sample"] = "current_learner", epsilon: float | None = None, epsilon_high: float | None = None, # Advantage computation @@ -761,6 +762,11 @@ async def train( # type: ignore[override] kl_ref_adapter_path: Direct filesystem path to a LoRA adapter checkpoint to use as the KL reference. Alternative to kl_penalty_reference_step. + kl_penalty_source: Which policy's logprobs to compare against the + reference when building the centered KL penalty. Use + "current_learner" to match the original ART implementation, or + "sample" to shape from the rollout policy logprobs, which is + usually better for async/off-policy workloads. epsilon: Clip epsilon for importance sampling. Defaults based on loss_fn. epsilon_high: Asymmetric upper clip bound. Defaults to epsilon. advantage_balance: Balance between negative and positive advantages @@ -814,6 +820,7 @@ async def train( # type: ignore[override] scale_rewards = False if adam_params is not None: raise ValueError("LocalBackend requires adam_params=None.") + assert kl_penalty_source in {"current_learner", "sample"} if ( self._requires_explicit_packed_sequence_length and packed_sequence_length is None @@ -831,6 +838,15 @@ async def train( # type: ignore[override] get_model_dir(model=model, art_path=self._path), kl_penalty_reference_step, ) + elif ( + resolved_kl_ref_adapter_path is None + and kl_penalty_coef > 0.0 + and self._requires_explicit_packed_sequence_length + ): + resolved_kl_ref_adapter_path = get_step_checkpoint_dir( + get_model_dir(model=model, art_path=self._path), + 0, + ) config, dev_config = build_rl_train_configs( learning_rate=learning_rate, advantage_balance=advantage_balance, @@ -844,6 +860,7 @@ async def train( # type: ignore[override] max_negative_advantage_importance_sampling_weight=max_negative_advantage_importance_sampling_weight, kimi_k2_tau=kimi_k2_tau, kl_penalty_coef=kl_penalty_coef, + kl_penalty_source=kl_penalty_source, allow_training_without_logprobs=allow_training_without_logprobs, plot_tensors=plot_tensors, truncated_importance_sampling=truncated_importance_sampling, diff --git a/src/art/loss.py b/src/art/loss.py index 6a4096e68..99719be19 100644 --- a/src/art/loss.py +++ b/src/art/loss.py @@ -167,7 +167,14 @@ def loss_fn( kl_policy_ref: torch.Tensor | None = None kl_penalty_coef = experimental_config.get("kl_penalty_coef", 0.0) if kl_penalty_coef > 0 and ref_logprobs is not None: - kl_per_token = (new_logprobs - ref_logprobs).detach() * assistant_mask + match experimental_config.get("kl_penalty_source", "current_learner"): + case "sample": + kl_source_logprobs = old_logprobs.detach() + case "current_learner": + kl_source_logprobs = new_logprobs.detach() + case other: + raise AssertionError(other) + kl_per_token = (kl_source_logprobs - ref_logprobs).detach() * assistant_mask avg_kl = aligned_inputs.masked_mean(kl_per_token, assistant_mask) kl_penalty = kl_penalty_coef * (avg_kl - kl_per_token) * assistant_mask advantages = advantages + kl_penalty diff --git a/src/art/megatron/train.py b/src/art/megatron/train.py index 134dec74f..992a98c2e 100644 --- a/src/art/megatron/train.py +++ b/src/art/megatron/train.py @@ -74,6 +74,7 @@ ) from art.megatron.training.microbatches import ( CpBatchLookaheadState, + PreparedRLMicroInputs, PreparedSFTMicroInputs, _causal_attention_state, _clone_packed_tensors, @@ -173,6 +174,7 @@ class TrainStepResult(BaseModel): reduced_loss: torch.Tensor probs_corr: float + kl_policy_ref: float | None = None new_logprobs: list[torch.Tensor] | None = None update_successful: bool grad_norm: float @@ -486,8 +488,10 @@ def run_megatron_rl_job( adapter_model = None template = None zero_template = None + ref_logprobs_by_index = None cp_lookahead_state = None next_step_first_micro = None + next_step_first_ref_logprobs = None step_result = None job_completed = False @@ -516,6 +520,15 @@ def run_megatron_rl_job( job.config.grad_accumulation_sequences ) num_steps = math.ceil(num_sequences / global_grad_accumulation_sequences) + ref_logprobs_by_index = _prepare_kl_reference_logprobs( + runtime=runtime, + job=job, + adapter_model=adapter_model, + packed_tensors=packed_tensors, + num_sequences=num_sequences, + num_steps=num_steps, + global_grad_accumulation_sequences=global_grad_accumulation_sequences, + ) topology = _infer_parallel_topology(runtime.model) cp_lookahead_state = CpBatchLookaheadState() if int(topology.cp) > 1 else None for step_index in range(num_steps): @@ -529,6 +542,15 @@ def run_megatron_rl_job( micro_indices, zero_template, ) + ref_logprobs = ( + select_micro_ref_logprobs( + ref_logprobs_by_index, + micro_indices, + zero_template, + ) + if ref_logprobs_by_index is not None + else None + ) next_step_first_micro = ( _select_next_step_first_micro( packed_tensors=packed_tensors, @@ -541,6 +563,18 @@ def run_megatron_rl_job( if cp_lookahead_state is not None else None ) + next_step_first_ref_logprobs = ( + _select_next_step_first_ref_logprobs( + ref_logprobs_by_index=ref_logprobs_by_index, + zero_template=zero_template, + step_index=step_index, + num_steps=num_steps, + num_sequences=num_sequences, + global_grad_accumulation_sequences=global_grad_accumulation_sequences, + ) + if cp_lookahead_state is not None and ref_logprobs_by_index is not None + else None + ) step_result = run_training_step( model_chunks=runtime.model, provider=runtime.provider, @@ -550,12 +584,13 @@ def run_megatron_rl_job( inputs=micro_inputs, config=job.config, experimental_config=cast(dev.TrainConfig, job.experimental_config), - ref_logprobs=None, + ref_logprobs=ref_logprobs, step_index=step_index, sample_index=micro_indices, moe_routing_replay_controller=runtime.moe_routing_replay_controller, cp_lookahead_state=cp_lookahead_state, next_step_first_micro=next_step_first_micro, + next_step_first_ref_logprobs=next_step_first_ref_logprobs, ) print0( runtime.rank, @@ -565,14 +600,15 @@ def run_megatron_rl_job( if runtime.rank == 0: with open(job.log_path, "a+", encoding="utf-8") as log_file: - log_msg = json.dumps( - { - "loss": step_result.reduced_loss.item(), - "grad_norm": step_result.grad_norm, - "probs_corr": step_result.probs_corr, - TRAIN_GRADIENT_STEPS_KEY: num_steps, - } - ) + metrics = { + "loss": step_result.reduced_loss.item(), + "grad_norm": step_result.grad_norm, + "probs_corr": step_result.probs_corr, + TRAIN_GRADIENT_STEPS_KEY: num_steps, + } + if step_result.kl_policy_ref is not None: + metrics["kl_policy_ref"] = step_result.kl_policy_ref + log_msg = json.dumps(metrics) print("Logging", log_msg) log_file.write(log_msg + "\n") @@ -593,10 +629,14 @@ def run_megatron_rl_job( del template if zero_template is not None: del zero_template + if ref_logprobs_by_index is not None: + del ref_logprobs_by_index if "micro_inputs" in locals(): del micro_inputs if next_step_first_micro is not None: del next_step_first_micro + if next_step_first_ref_logprobs is not None: + del next_step_first_ref_logprobs if step_result is not None: del step_result if cp_lookahead_state is not None: @@ -941,6 +981,300 @@ def _infer_parallel_topology(model_chunks: ModelChunks) -> ParallelTopology: ) +def select_micro_ref_logprobs( + ref_logprobs_by_index: dict[int, torch.Tensor], + sample_indices: list[int | None], + zero_template: PackedTensors, +) -> list[torch.Tensor]: + zero_ref_logprobs = torch.zeros_like(zero_template["tokens"], dtype=torch.float32) + return [ + zero_ref_logprobs.clone() + if sample_index is None + else ref_logprobs_by_index[sample_index] + for sample_index in sample_indices + ] + + +def _select_next_step_first_ref_logprobs( + *, + ref_logprobs_by_index: dict[int, torch.Tensor], + zero_template: PackedTensors, + step_index: int, + num_steps: int, + num_sequences: int, + global_grad_accumulation_sequences: int, +) -> torch.Tensor | None: + next_step_index = step_index + 1 + if next_step_index >= num_steps: + return None + next_micro_indices = build_micro_sample_indices( + step_index=next_step_index, + num_sequences=num_sequences, + global_grad_accumulation_sequences=global_grad_accumulation_sequences, + ) + return select_micro_ref_logprobs( + ref_logprobs_by_index, + [next_micro_indices[0]], + zero_template, + )[0] + + +def _select_ref_logprobs( + ref_logprobs: torch.Tensor | list[torch.Tensor] | None, + micro_order: int, +) -> torch.Tensor | None: + if isinstance(ref_logprobs, list): + return ref_logprobs[micro_order] + return ref_logprobs + + +def _select_next_ref_logprobs( + ref_logprobs: torch.Tensor | list[torch.Tensor] | None, + *, + micro_order: int, + micro_count: int, + next_step_first_ref_logprobs: torch.Tensor | None, +) -> torch.Tensor | None: + if isinstance(ref_logprobs, list): + if micro_order + 1 < len(ref_logprobs): + return ref_logprobs[micro_order + 1] + return next_step_first_ref_logprobs + if micro_order + 1 >= micro_count and next_step_first_ref_logprobs is not None: + return next_step_first_ref_logprobs + return ref_logprobs + + +def _forward_prepared_rl_micro( + *, + model_chunks: ModelChunks, + model_support_handler: Any, + prepared_micro: PreparedRLMicroInputs, + device: torch.device, +) -> torch.Tensor: + model_forward_kwargs = dict( + input_ids=prepared_micro.model_tokens, + position_ids=prepared_micro.model_input_pos, + attention_mask=_placeholder_attention_mask(device), + packed_seq_params=prepared_micro.packed_seq_params, + **model_support_handler.get_forward_kwargs( + model_chunks[0], + attention_bias=prepared_micro.attention_state, + ), + ) + with attach_trace_token_uids(model_chunks, prepared_micro.local_token_uids): + if int(prepared_micro.model_tokens.numel()) == 0: + logits = model_chunks[0](**model_forward_kwargs, labels=None) + return _empty_new_logprobs_from_logits(logits, prepared_micro.model_labels) + return -model_chunks[0]( + **model_forward_kwargs, + labels=prepared_micro.model_labels, + ) + + +def _globalize_context_parallel_logprobs( + *, + local_logprobs: torch.Tensor, + attention_state: Any, + seq_len: int, +) -> torch.Tensor: + rank_plan = getattr(attention_state, "rank_plan", None) + cp_group = getattr(attention_state, "cp_group", None) + if rank_plan is None or cp_group is None: + raise RuntimeError("Context-parallel reference logprobs require a rank plan") + + global_logprobs = local_logprobs.new_zeros((1, seq_len)) + local_values = local_logprobs.reshape(-1) + cursor = 0 + for range_ in rank_plan.local_row_ranges: + if range_ is None: + continue + size = int(range_.size()) + if size <= 0: + continue + global_logprobs[0, int(range_.start) : int(range_.end)] = local_values[ + cursor : cursor + size + ] + cursor += size + + torch.distributed.all_reduce( # ty: ignore[possibly-missing-attribute] + global_logprobs, + group=cp_group, + ) + return global_logprobs + + +@torch.no_grad() +def _calculate_megatron_logprobs( + *, + model_chunks: ModelChunks, + provider: Any, + model_support_handler: Any, + inputs: PackedTensors, + moe_routing_replay_controller: MoeRoutingReplayController | None = None, + step_index: int | None = None, + sample_index: int | None = None, + global_grad_accumulation_sequences: int | None = None, +) -> torch.Tensor: + if moe_routing_replay_controller is not None: + if step_index is None or sample_index is None: + raise ValueError( + "step_index and sample_index are required for routing replay" + ) + moe_routing_replay_controller.set_step( + step_index=step_index, + sample_index=sample_index, + global_grad_accumulation_sequences=global_grad_accumulation_sequences, + ) + moe_routing_replay_controller.begin_micro(sample_index, 0) + + device = next(model_chunks[0].parameters()).device + topology = _infer_parallel_topology(model_chunks) + trace_token_uids = context_parallel_trace_token_uids_enabled( + topology, + moe_routing_replay_controller, + ) + previous_training_modes = [chunk.training for chunk in model_chunks] + for chunk in model_chunks: + chunk.eval() + forward_succeeded = False + try: + prepared_micro, _pending_prepared_micro = _prepare_current_rl_micro( + inputs, + device=device, + topology=topology, + provider=provider, + model_support_handler=model_support_handler, + ref_logprobs=None, + trace_token_uids=trace_token_uids, + pending_prepared_micro=None, + ) + prepare_replay_local_input_token_uids( + moe_routing_replay_controller, + prepared_micro.local_token_uids, + prepared_micro.attention_state, + ) + logprobs = _forward_prepared_rl_micro( + model_chunks=model_chunks, + model_support_handler=model_support_handler, + prepared_micro=prepared_micro, + device=device, + ) + if int(topology.cp) > 1: + logprobs = _globalize_context_parallel_logprobs( + local_logprobs=logprobs, + attention_state=prepared_micro.attention_state, + seq_len=int(inputs["tokens"].shape[1]), + ) + forward_succeeded = True + finally: + for chunk, was_training in zip(model_chunks, previous_training_modes): + chunk.train(was_training) + if moe_routing_replay_controller is not None and forward_succeeded: + moe_routing_replay_controller.finalize_step() + return logprobs.detach().cpu() + + +def _precompute_reference_logprobs( + *, + runtime: TrainingRuntime, + packed_tensors: PackedTensors, + sample_step_indices: dict[int, int], + global_grad_accumulation_sequences: int, +) -> dict[int, torch.Tensor]: + print0( + runtime.rank, + "Precomputing KL reference logprobs for", + len(sample_step_indices), + "local sequences", + ) + return { + sample_index: _calculate_megatron_logprobs( + model_chunks=runtime.model, + provider=runtime.provider, + model_support_handler=runtime.model_support_handler, + inputs=select_indexed_inputs(packed_tensors, sample_index), + moe_routing_replay_controller=runtime.moe_routing_replay_controller, + step_index=step_index, + sample_index=sample_index, + global_grad_accumulation_sequences=global_grad_accumulation_sequences, + ) + for sample_index, step_index in sorted(sample_step_indices.items()) + } + + +def _reference_sample_step_indices( + *, + num_sequences: int, + num_steps: int, + global_grad_accumulation_sequences: int, +) -> dict[int, int]: + return { + sample_index: step_index + for step_index in range(num_steps) + for sample_index in build_micro_sample_indices( + step_index=step_index, + num_sequences=num_sequences, + global_grad_accumulation_sequences=global_grad_accumulation_sequences, + ) + if sample_index is not None + } + + +def _prepare_kl_reference_logprobs( + *, + runtime: TrainingRuntime, + job: MegatronTrainingJob | MegatronMergedTrainingJob, + adapter_model: dict[str, torch.Tensor], + packed_tensors: PackedTensors, + num_sequences: int, + num_steps: int, + global_grad_accumulation_sequences: int, +) -> dict[int, torch.Tensor] | None: + if job.config.kl_penalty_coef <= 0.0: + return None + + ref_adapter_path = cast(dev.TrainConfig, job.experimental_config).get( + "kl_ref_adapter_path" + ) + if ref_adapter_path is None: + raise RuntimeError( + "KL penalty is enabled but no kl_ref_adapter_path was provided. " + "Megatron training requires an explicit reference LoRA path; pass " + "kl_penalty_reference_step=0 for the identity/base reference or " + "provide kl_ref_adapter_path." + ) + + adapter_swapped = os.path.abspath(ref_adapter_path) != os.path.abspath( + job.lora_path + ) + loaded_ref_adapter = False + try: + if adapter_swapped: + _load_adapter_into_model( + runtime.model, + ref_adapter_path, + runtime.rank, + handler=runtime.model_support_handler, + ) + loaded_ref_adapter = True + return _precompute_reference_logprobs( + runtime=runtime, + packed_tensors=packed_tensors, + sample_step_indices=_reference_sample_step_indices( + num_sequences=num_sequences, + num_steps=num_steps, + global_grad_accumulation_sequences=global_grad_accumulation_sequences, + ), + global_grad_accumulation_sequences=global_grad_accumulation_sequences, + ) + finally: + if loaded_ref_adapter: + assert runtime.optimizer is not None + load_adapter_into_model(runtime.model, adapter_model, runtime.optimizer) + gc.collect() + torch.cuda.empty_cache() + + def run_megatron_sft_step( *, model_chunks: ModelChunks, @@ -1092,10 +1426,11 @@ def run_training_step( experimental_config: dev.TrainConfig, step_index: int, sample_index: int | list[int | None], - ref_logprobs: torch.Tensor | None = None, + ref_logprobs: torch.Tensor | list[torch.Tensor] | None = None, moe_routing_replay_controller: MoeRoutingReplayController | None = None, cp_lookahead_state: CpBatchLookaheadState | None = None, next_step_first_micro: PackedTensors | None = None, + next_step_first_ref_logprobs: torch.Tensor | None = None, ) -> TrainStepResult: micro_inputs = inputs if isinstance(inputs, list) else [inputs] if not micro_inputs: @@ -1145,6 +1480,8 @@ def run_training_step( raw_loss_sum: torch.Tensor | None = None loss_inputs_for_count: list[LossInputs | DispatchedPackedTensors] = [] probs_corr_total: torch.Tensor | None = None + kl_policy_ref_sum = 0.0 + kl_policy_ref_count = 0 new_logprobs_gpu: list[torch.Tensor] = [] def begin_micro(micro_order: int) -> None: @@ -1156,13 +1493,16 @@ def begin_micro(micro_order: int) -> None: for micro_order in range(micro_count): begin_micro(micro_order) + micro_ref_logprobs = _select_ref_logprobs(ref_logprobs, micro_order) + if micro_ref_logprobs is not None and int(topology.cp) <= 1: + micro_ref_logprobs = micro_ref_logprobs.to(device) prepared_micro, pending_prepared_micro = _prepare_current_rl_micro( micro_inputs[micro_order], device=device, topology=topology, provider=provider, model_support_handler=model_support_handler, - ref_logprobs=ref_logprobs, + ref_logprobs=micro_ref_logprobs, trace_token_uids=trace_token_uids, pending_prepared_micro=pending_prepared_micro, ) @@ -1172,27 +1512,12 @@ def begin_micro(micro_order: int) -> None: prepared_micro.attention_state, ) - model_forward_kwargs = dict( - input_ids=prepared_micro.model_tokens, - position_ids=prepared_micro.model_input_pos, - attention_mask=_placeholder_attention_mask(device), - packed_seq_params=prepared_micro.packed_seq_params, - **model_support_handler.get_forward_kwargs( - model_chunks[0], - attention_bias=prepared_micro.attention_state, - ), + new_logprobs = _forward_prepared_rl_micro( + model_chunks=model_chunks, + model_support_handler=model_support_handler, + prepared_micro=prepared_micro, + device=device, ) - with attach_trace_token_uids(model_chunks, prepared_micro.local_token_uids): - if int(prepared_micro.model_tokens.numel()) == 0: - logits = model_chunks[0](**model_forward_kwargs, labels=None) - new_logprobs = _empty_new_logprobs_from_logits( - logits, prepared_micro.model_labels - ) - else: - new_logprobs = -model_chunks[0]( - **model_forward_kwargs, - labels=prepared_micro.model_labels, - ) loss_info = loss_fn( prepared_micro.loss_inputs, @@ -1225,7 +1550,6 @@ def begin_micro(micro_order: int) -> None: ) micro_loss.backward() loss_inputs_for_count.append(prepared_micro.loss_inputs) - del model_forward_kwargs del prepared_micro pending_prepared_micro = _prepare_next_rl_cp_micro( _next_micro_lookahead( @@ -1237,13 +1561,21 @@ def begin_micro(micro_order: int) -> None: topology=topology, model_support_handler=model_support_handler, trace_token_uids=trace_token_uids, - ref_logprobs=ref_logprobs, + ref_logprobs=_select_next_ref_logprobs( + ref_logprobs, + micro_order=micro_order, + micro_count=micro_count, + next_step_first_ref_logprobs=next_step_first_ref_logprobs, + ), ) detached_probs_corr = loss_info.probs_corr.detach() if probs_corr_total is None: probs_corr_total = detached_probs_corr else: probs_corr_total = probs_corr_total + detached_probs_corr + if loss_info.kl_policy_ref is not None: + kl_policy_ref_sum += float(loss_info.kl_policy_ref.item()) + kl_policy_ref_count += 1 detached_micro_loss = micro_loss.detach() if raw_loss_sum is None: raw_loss_sum = detached_micro_loss @@ -1287,6 +1619,9 @@ def begin_micro(micro_order: int) -> None: return TrainStepResult( reduced_loss=reduced_loss, probs_corr=float((probs_corr_total / micro_count).item()), + kl_policy_ref=( + kl_policy_ref_sum / kl_policy_ref_count if kl_policy_ref_count > 0 else None + ), new_logprobs=[ tensor.to(device="cpu", non_blocking=True) for tensor in new_logprobs_gpu ], diff --git a/src/art/pipeline_trainer/trainer.py b/src/art/pipeline_trainer/trainer.py index d056ecb11..da9aa921a 100644 --- a/src/art/pipeline_trainer/trainer.py +++ b/src/art/pipeline_trainer/trainer.py @@ -92,6 +92,8 @@ def __init__( normalize_advantages: bool = True, adam_params: object | None = None, packed_sequence_length: int | None = None, + kl_penalty_coef: float = 0.0, + kl_penalty_step_lag: int | None = None, megatron_topology: art.MegatronTopologyConfig | None = None, max_steps: int | None = None, # Discard handling @@ -131,6 +133,8 @@ def __init__( raise ValueError("discard_queue_multiplier must be > 0") if checkpoint_retention_interval <= 0: raise ValueError("checkpoint_retention_interval must be > 0") + if kl_penalty_step_lag is not None and kl_penalty_step_lag < 1: + raise ValueError("kl_penalty_step_lag must be >= 1") self.model = model self.backend = backend self.rollout_fn = rollout_fn @@ -149,6 +153,8 @@ def __init__( self.normalize_advantages = normalize_advantages self.adam_params = adam_params self.packed_sequence_length = packed_sequence_length + self.kl_penalty_coef = kl_penalty_coef + self.kl_penalty_step_lag = kl_penalty_step_lag self.megatron_topology = megatron_topology self.max_steps = max_steps self._status_log_interval_seconds = log_interval_seconds @@ -448,6 +454,11 @@ def _retained_adapter_steps(self, current_step: int) -> set[int]: min_step = max(0, current_step - self.max_steps_off_policy) return set(range(min_step, current_step + 1)) + def _kl_penalty_reference_step(self, current_step: int) -> int: + if self.kl_penalty_step_lag is None: + return 0 + return max(0, current_step - self.kl_penalty_step_lag) + async def _prune_model_adapters(self, current_step: int) -> None: if not hasattr(type(self.backend), "prune_model_adapters"): return @@ -574,6 +585,15 @@ async def _training_stage(self) -> None: } if self.packed_sequence_length is not None: train_kwargs["packed_sequence_length"] = self.packed_sequence_length + if self.kl_penalty_coef > 0.0: + kl_penalty_reference_step = self._kl_penalty_reference_step( + current_step + ) + train_kwargs["kl_penalty_coef"] = self.kl_penalty_coef + train_kwargs["kl_penalty_source"] = "sample" + train_kwargs["kl_penalty_reference_step"] = ( + kl_penalty_reference_step + ) if self.megatron_topology is not None: train_kwargs["megatron_topology"] = self.megatron_topology result = await self.backend.train( @@ -1066,7 +1086,22 @@ def _checkpoint_infos(self) -> list[CheckpointInfo]: return sorted(checkpoints, key=lambda checkpoint: checkpoint.step) def _protected_checkpoint_steps(self, current_step: int) -> set[int]: - return {current_step} | set(self._checkpoint_lease_counts) + protected_steps = ( + {current_step} + | set(self._checkpoint_lease_counts) + | set(self._scheduled_eval_steps) + ) + if self.kl_penalty_coef > 0.0: + if self.kl_penalty_step_lag is None: + protected_steps.add(0) + else: + kl_penalty_reference_step = self._kl_penalty_reference_step( + current_step + ) + protected_steps.update( + range(kl_penalty_reference_step, current_step + 1) + ) + return protected_steps async def _run_checkpoint_retention(self, current_step: int) -> None: strategy = self.checkpoint_retention_strategy diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py index 4ab10742b..acdef16da 100644 --- a/src/art/serverless/backend.py +++ b/src/art/serverless/backend.py @@ -217,6 +217,9 @@ async def train( # type: ignore[override] adam_params: object | None = None, # KL-penalized advantage adjustment kl_penalty_coef: float = 0.0, + kl_penalty_reference_step: int | None = None, + kl_penalty_source: Literal["current_learner", "sample"] | None = None, + kl_penalty_step_lag: int | None = None, kl_ref_adapter_path: str | None = None, # RL algorithm settings ppo: bool | None = None, @@ -267,6 +270,15 @@ async def train( # type: ignore[override] ServerlessBackend. kl_penalty_coef: Coefficient for KL-penalized advantage adjustment. Defaults to 0.0 (disabled). + kl_penalty_reference_step: Checkpoint step of the training model to + use as the KL reference. When omitted, the backend may use + kl_ref_adapter_path or its default reference policy. + kl_penalty_source: Which policy's logprobs to compare against the + reference policy. When omitted, defaults to "sample" if KL is + enabled and "current_learner" otherwise. + kl_penalty_step_lag: Moving KL reference lag. The serverless + backend resolves this as max(0, current_step - lag). Mutually + exclusive with kl_penalty_reference_step. kl_ref_adapter_path: Direct filesystem path to a LoRA adapter checkpoint to use as the KL reference. ppo: Legacy flag for PPO clipping. Prefer loss_fn="ppo". @@ -327,6 +339,21 @@ async def train( # type: ignore[override] scale_rewards = False if adam_params is not None: raise ValueError("ServerlessBackend requires adam_params=None.") + if kl_penalty_reference_step is not None and kl_penalty_reference_step < 0: + raise ValueError("kl_penalty_reference_step must be >= 0.") + if kl_penalty_step_lag is not None: + if kl_penalty_step_lag < 1: + raise ValueError("kl_penalty_step_lag must be >= 1.") + if kl_penalty_reference_step is not None: + raise ValueError( + "Only one of kl_penalty_reference_step and " + "kl_penalty_step_lag may be set." + ) + resolved_kl_penalty_source: Literal["current_learner", "sample"] = ( + kl_penalty_source + if kl_penalty_source is not None + else ("sample" if kl_penalty_coef > 0.0 else "current_learner") + ) _ = save_checkpoint config, dev_config = build_rl_train_configs( @@ -342,6 +369,7 @@ async def train( # type: ignore[override] max_negative_advantage_importance_sampling_weight=max_negative_advantage_importance_sampling_weight, kimi_k2_tau=kimi_k2_tau, kl_penalty_coef=kl_penalty_coef, + kl_penalty_source=resolved_kl_penalty_source, allow_training_without_logprobs=allow_training_without_logprobs, plot_tensors=plot_tensors, truncated_importance_sampling=truncated_importance_sampling, @@ -351,6 +379,10 @@ async def train( # type: ignore[override] num_trajectories_learning_rate_multiplier_power=num_trajectories_learning_rate_multiplier_power, kl_ref_adapter_path=kl_ref_adapter_path, ) + if kl_penalty_reference_step is not None: + dev_config["kl_penalty_reference_step"] = kl_penalty_reference_step + if kl_penalty_step_lag is not None: + dev_config["kl_penalty_step_lag"] = kl_penalty_step_lag # Collect metrics from training training_metrics: list[dict[str, float]] = [] @@ -410,6 +442,9 @@ async def _train_model( importance_sampling_level=dev_config.get("importance_sampling_level"), kimi_k2_tau=dev_config.get("kimi_k2_tau"), kl_penalty_coef=dev_config.get("kl_penalty_coef"), + kl_penalty_reference_step=dev_config.get("kl_penalty_reference_step"), + kl_penalty_source=dev_config.get("kl_penalty_source"), + kl_penalty_step_lag=dev_config.get("kl_penalty_step_lag"), kl_ref_adapter_path=dev_config.get("kl_ref_adapter_path"), learning_rate=config.learning_rate, logprob_calculation_chunk_size=dev_config.get( diff --git a/src/art/serverless/client.py b/src/art/serverless/client.py index 19d724e7d..249862c13 100644 --- a/src/art/serverless/client.py +++ b/src/art/serverless/client.py @@ -60,6 +60,9 @@ class ExperimentalTrainingConfig(TypedDict, total=False): ) kimi_k2_tau: float | None kl_penalty_coef: float | None + kl_penalty_reference_step: int | None + kl_penalty_source: Literal["current_learner", "sample"] | None + kl_penalty_step_lag: int | None kl_ref_adapter_path: str | None learning_rate: float | None logprob_calculation_chunk_size: int | None diff --git a/src/art/test/test_kl_advantage.py b/src/art/test/test_kl_advantage.py index 796ae69be..916cd9152 100644 --- a/src/art/test/test_kl_advantage.py +++ b/src/art/test/test_kl_advantage.py @@ -2,7 +2,7 @@ import torch -from art.loss import LossInputs, loss_fn +from art.loss import LossInputs, loss_fn, shift_tensor def _make_inputs( @@ -86,7 +86,7 @@ def test_kl_advantage_zero_mean_penalty(): # Compute what the penalty should be kl_per_token = (new_logprobs - ref_logprobs).detach() * assistant_mask - avg_kl = kl_per_token.sum() / (assistant_mask.sum() + 1e-6) + avg_kl = kl_per_token.sum() / (assistant_mask.sum() + 1e-18) kl_penalty = kl_penalty_coef * (avg_kl - kl_per_token) * assistant_mask # Sum of penalty across tokens should be ~0 @@ -138,3 +138,51 @@ def test_kl_advantage_does_not_affect_when_no_ref(): {"kl_penalty_coef": 0.5}, ) assert loss.kl_policy_ref is None + + +def test_kl_advantage_can_use_sample_logprobs() -> None: + """Sample-source KL should use stored rollout logprobs rather than learner logprobs.""" + inputs = _make_inputs(seq_len=8) + inputs["logprobs"] = torch.tensor( + [[0.0, -0.2, -0.4, -0.6, -0.8, -1.0, -1.2, -1.4]], dtype=torch.float32 + ) + new_logprobs = torch.tensor( + [[0.0, -1.2, -1.1, -1.0, -0.9, -0.8, -0.7, -0.6]], dtype=torch.float32 + ) + ref_logprobs = torch.full((1, 8), -0.5) + assistant_mask = shift_tensor(inputs["assistant_mask"], False).to( + new_logprobs.dtype + ) + shifted_logprobs = shift_tensor(inputs["logprobs"], float("nan")) + sampled_logprobs = torch.where( + torch.isnan(shifted_logprobs), + new_logprobs.detach(), + shifted_logprobs, + ) + expected_sample_kl = ((sampled_logprobs - ref_logprobs) * assistant_mask).sum() / ( + assistant_mask.sum() + 1e-18 + ) + expected_current_kl = ((new_logprobs - ref_logprobs) * assistant_mask).sum() / ( + assistant_mask.sum() + 1e-18 + ) + + sample_loss = loss_fn( + LossInputs(inputs=inputs), + new_logprobs, + ref_logprobs, + None, + {"kl_penalty_coef": 0.5, "kl_penalty_source": "sample"}, + ) + learner_loss = loss_fn( + LossInputs(inputs=inputs), + new_logprobs, + ref_logprobs, + None, + {"kl_penalty_coef": 0.5, "kl_penalty_source": "current_learner"}, + ) + + assert sample_loss.kl_policy_ref is not None + assert learner_loss.kl_policy_ref is not None + assert torch.isclose(sample_loss.kl_policy_ref, expected_sample_kl) + assert torch.isclose(learner_loss.kl_policy_ref, expected_current_kl) + assert not torch.isclose(sample_loss.kl_policy_ref, learner_loss.kl_policy_ref) diff --git a/src/art/tinker_native/backend.py b/src/art/tinker_native/backend.py index 4a4d89291..9728a290d 100644 --- a/src/art/tinker_native/backend.py +++ b/src/art/tinker_native/backend.py @@ -25,6 +25,7 @@ from openai.types.completion_usage import CompletionUsage import tinker from tinker_cookbook import renderers, tokenizer_utils +import torch import uvicorn from .. import dev @@ -81,6 +82,76 @@ def _canonicalize_upstream_metric_key(metric: str) -> str: return _UPSTREAM_TRAIN_METRIC_KEYS.get(metric, metric) +async def _apply_kl_penalty( + datums: list[tinker.Datum], + reference_sampling_client: tinker.SamplingClient, + kl_penalty_coef: float, +) -> dict[str, float]: + assert datums + assert kl_penalty_coef > 0.0 + + full_sequences: list[tinker.ModelInput] = [] + sampled_logprobs_by_datum: list[torch.Tensor] = [] + masks_by_datum: list[torch.Tensor] = [] + advantages_by_datum: list[torch.Tensor] = [] + for datum in datums: + target_tokens = datum.loss_fn_inputs["target_tokens"].to_torch() + assert target_tokens.numel() > 0 + full_sequences.append( + datum.model_input.append_int(int(target_tokens[-1].item())) + ) + sampled_logprobs_by_datum.append(datum.loss_fn_inputs["logprobs"].to_torch()) + masks_by_datum.append(datum.loss_fn_inputs["mask"].to_torch().float()) + advantages_by_datum.append(datum.loss_fn_inputs["advantages"].to_torch()) + + reference_logprobs_by_datum = await asyncio.gather( + *[ + reference_sampling_client.compute_logprobs_async(full_sequence) + for full_sequence in full_sequences + ] + ) + + logprob_diffs_by_datum: list[torch.Tensor] = [] + for reference_logprobs, sampled_logprobs, mask in zip( + reference_logprobs_by_datum, + sampled_logprobs_by_datum, + masks_by_datum, + strict=True, + ): + reference_values = reference_logprobs[1:] + assert len(reference_values) == sampled_logprobs.numel() + assert all(value is not None for value in reference_values) + reference_logprobs_tensor = torch.tensor( + reference_values, + dtype=sampled_logprobs.dtype, + ) + logprob_diffs_by_datum.append( + (sampled_logprobs - reference_logprobs_tensor) * mask + ) + + total_tokens = torch.stack([mask.sum() for mask in masks_by_datum]).sum() + assert total_tokens.item() > 0 + avg_logprob_diff = ( + torch.stack( + [logprob_diff.sum() for logprob_diff in logprob_diffs_by_datum] + ).sum() + / total_tokens + ) + + for datum, advantages, mask, logprob_diff in zip( + datums, + advantages_by_datum, + masks_by_datum, + logprob_diffs_by_datum, + strict=True, + ): + datum.loss_fn_inputs["advantages"] = tinker.TensorData.from_torch( + advantages + kl_penalty_coef * (avg_logprob_diff - logprob_diff) * mask + ) + + return {"loss/kl_policy_ref": float(avg_logprob_diff)} + + @dataclass class ModelState: service_client: tinker.ServiceClient @@ -243,8 +314,15 @@ async def train( save_checkpoint: bool = False, loss_fn_config: dict | None = None, adam_params: tinker.AdamParams | None = None, + kl_penalty_coef: float = 0.0, + kl_penalty_reference_step: int | None = None, + kl_penalty_source: Literal["sample"] = "sample", **kwargs: Any, ) -> TrainResult: + assert kl_penalty_source == "sample", ( + "TinkerNativeBackend only supports kl_penalty_source='sample'." + ) + state = self._model_state[model.name] groups_list = list(trajectory_groups) summary = summarize_trajectory_groups(groups_list) @@ -277,6 +355,23 @@ async def train( train_tokens, pricing ) trainer_started = time.monotonic() + sampled_kl_policy_ref: float | None = None + + if kl_penalty_coef > 0: + kl_metrics = await self._tinker_sample_call( + "apply_kl_penalty", + _apply_kl_penalty( + datums, + await self._get_kl_reference_sampling_client( + state, + model.base_model, + kl_penalty_reference_step, + ), + kl_penalty_coef, + ), + ) + sampled_kl_policy_ref = kl_metrics["loss/kl_policy_ref"] + metrics.update(kl_metrics) if adam_params is None: adam_params = tinker.AdamParams( @@ -315,6 +410,11 @@ def remove_mask(datum: tinker.Datum) -> tinker.Datum: if value is None: continue canonical_key = _canonicalize_upstream_metric_key(key) + if ( + sampled_kl_policy_ref is not None + and canonical_key == "loss/kl_policy_ref" + ): + continue if canonical_key: metrics[canonical_key] = float(value) if optim_output.metrics: @@ -322,6 +422,11 @@ def remove_mask(datum: tinker.Datum) -> tinker.Datum: if value is None: continue canonical_key = _canonicalize_upstream_metric_key(key) + if ( + sampled_kl_policy_ref is not None + and canonical_key == "loss/kl_policy_ref" + ): + continue if canonical_key: metrics[canonical_key] = float(value) @@ -715,6 +820,19 @@ async def _get_sampler_client( state.sampler_clients[actual_step] = sampler_client return sampler_client + async def _get_kl_reference_sampling_client( + self, + state: ModelState, + base_model: str, + step: int | None, + ) -> tinker.SamplingClient: + if step is not None: + return await self._get_sampler_client(state, step) + return await self._tinker_sample_call( + "create_sampling_client_async", + state.service_client.create_sampling_client_async(base_model=base_model), + ) + def _normalize_messages(self, messages: Iterable[Any]) -> list[dict[str, Any]]: normalized: list[dict[str, Any]] = [] for message in messages: diff --git a/src/art/types.py b/src/art/types.py index db04390ad..0fbc732f7 100644 --- a/src/art/types.py +++ b/src/art/types.py @@ -25,6 +25,7 @@ def _visible_device_count() -> int: class TrainConfig(pydantic.BaseModel): learning_rate: float = 5e-6 kl_penalty_coef: float = 0.0 + kl_penalty_source: Literal["current_learner", "sample"] = "current_learner" grad_accumulation_sequences: int | None = pydantic.Field(default=None, ge=1) diff --git a/tests/unit/test_megatron_reference_logprobs.py b/tests/unit/test_megatron_reference_logprobs.py new file mode 100644 index 000000000..262136709 --- /dev/null +++ b/tests/unit/test_megatron_reference_logprobs.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, cast + +import torch +from torch import nn + +from art import types +from art.megatron import train as megatron_train +from art.megatron.training import microbatches as megatron_microbatches +from art.preprocessing.pack import PackedTensors + + +def _packed_inputs(seq_len: int = 4) -> PackedTensors: + return cast( + PackedTensors, + { + "tokens": torch.arange(seq_len, dtype=torch.long).unsqueeze(0), + "input_pos": torch.arange(seq_len, dtype=torch.long).unsqueeze(0), + "assistant_mask": torch.ones((1, seq_len), dtype=torch.bool), + "group_ids": torch.zeros((1, seq_len), dtype=torch.long), + "parent_ids": torch.zeros((1, seq_len), dtype=torch.long), + }, + ) + + +def test_precompute_reference_logprobs_preserves_sample_steps(monkeypatch) -> None: + calls: list[tuple[int, int, int]] = [] + + def fake_select_indexed_inputs( + packed_tensors: dict[str, torch.Tensor], sample_index: int + ) -> dict[str, torch.Tensor]: + del packed_tensors + return {"sample_index": torch.tensor(sample_index)} + + def fake_calculate_megatron_logprobs( + *, + model_chunks: Any, + provider: Any, + model_support_handler: Any, + inputs: dict[str, torch.Tensor], + moe_routing_replay_controller: Any, + step_index: int, + sample_index: int, + global_grad_accumulation_sequences: int, + ) -> torch.Tensor: + del ( + model_chunks, + provider, + model_support_handler, + inputs, + moe_routing_replay_controller, + ) + calls.append((sample_index, step_index, global_grad_accumulation_sequences)) + return torch.tensor([[float(sample_index)]]) + + monkeypatch.setattr( + megatron_train, "select_indexed_inputs", fake_select_indexed_inputs + ) + monkeypatch.setattr( + megatron_train, + "_calculate_megatron_logprobs", + fake_calculate_megatron_logprobs, + ) + runtime = SimpleNamespace( + rank=0, + model=[], + provider=object(), + model_support_handler=object(), + moe_routing_replay_controller=object(), + ) + + result = megatron_train._precompute_reference_logprobs( + runtime=cast(megatron_train.TrainingRuntime, runtime), + packed_tensors=_packed_inputs(), + sample_step_indices={3: 1, 0: 0}, + global_grad_accumulation_sequences=4, + ) + + assert calls == [(0, 0, 4), (3, 1, 4)] + assert sorted(result) == [0, 3] + + +def test_prepare_kl_reference_logprobs_requires_reference_path() -> None: + runtime = SimpleNamespace(rank=0) + job = SimpleNamespace( + config=types.TrainConfig(kl_penalty_coef=0.25), + experimental_config={}, + lora_path="/tmp/current", + ) + + try: + megatron_train._prepare_kl_reference_logprobs( + runtime=cast(megatron_train.TrainingRuntime, runtime), + job=cast(megatron_train.MegatronTrainingJob, job), + adapter_model={}, + packed_tensors=_packed_inputs(), + num_sequences=1, + num_steps=1, + global_grad_accumulation_sequences=1, + ) + except RuntimeError as exc: + assert "kl_ref_adapter_path" in str(exc) + else: + raise AssertionError("Expected missing reference path to raise") + + +class _ReplayController: + def __init__(self) -> None: + self.events: list[tuple[str, int, int | None, int | None]] = [] + + def set_step( + self, + *, + step_index: int, + sample_index: int, + global_grad_accumulation_sequences: int | None = None, + ) -> None: + self.events.append( + ("set_step", step_index, sample_index, global_grad_accumulation_sequences) + ) + + def begin_micro(self, sample_index: int, micro_order: int) -> None: + self.events.append(("begin_micro", micro_order, sample_index, None)) + + def finalize_step(self) -> None: + self.events.append(("finalize_step", 0, None, None)) + + +class _Chunk(nn.Module): + def __init__(self, controller: _ReplayController) -> None: + super().__init__() + self.weight = nn.Parameter(torch.zeros(())) + self.controller = controller + self.training_modes_seen: list[bool] = [] + + def forward( + self, + *, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + attention_mask: torch.Tensor, + labels: torch.Tensor, + packed_seq_params: Any | None = None, + ) -> torch.Tensor: + del input_ids, position_ids, attention_mask, packed_seq_params + self.training_modes_seen.append(self.training) + assert self.controller.events == [ + ("set_step", 2, 5, 8), + ("begin_micro", 0, 5, None), + ] + return torch.full(labels.shape, 0.25, dtype=torch.float32, device=labels.device) + + +class _Handler: + def get_forward_kwargs(self, _chunk: nn.Module, *, attention_bias: Any) -> dict: + del attention_bias + return {} + + +def test_calculate_megatron_logprobs_replays_routes(monkeypatch) -> None: + controller = _ReplayController() + chunk = _Chunk(controller) + monkeypatch.setattr( + megatron_microbatches, + "create_shared_prefix_state", + lambda **kwargs: (kwargs["group_ids"], kwargs["parent_ids"]), + ) + monkeypatch.setattr( + megatron_train, + "_infer_parallel_topology", + lambda _model_chunks: megatron_train.ParallelTopology(), + ) + + logprobs = megatron_train._calculate_megatron_logprobs( + model_chunks=cast(megatron_train.ModelChunks, [chunk]), + provider=object(), + model_support_handler=_Handler(), + inputs=_packed_inputs(), + moe_routing_replay_controller=cast( + megatron_train.MoeRoutingReplayController, controller + ), + step_index=2, + sample_index=5, + global_grad_accumulation_sequences=8, + ) + + assert controller.events == [ + ("set_step", 2, 5, 8), + ("begin_micro", 0, 5, None), + ("finalize_step", 0, None, None), + ] + assert chunk.training_modes_seen == [False] + assert chunk.training is True + assert torch.equal(logprobs, torch.full((1, 4), -0.25)) diff --git a/tests/unit/test_pipeline_trainer_local_backend.py b/tests/unit/test_pipeline_trainer_local_backend.py index fd0fc530c..ab0d5765e 100644 --- a/tests/unit/test_pipeline_trainer_local_backend.py +++ b/tests/unit/test_pipeline_trainer_local_backend.py @@ -23,7 +23,7 @@ ) from art.pipeline_trainer.trainer import PipelineTrainer from art.preprocessing.tokenize import TokenizedResult -from art.utils.output_dirs import get_model_dir +from art.utils.output_dirs import get_model_dir, get_step_checkpoint_dir def _make_group(rewards: list[float]) -> TrajectoryGroup: @@ -128,6 +128,120 @@ async def test_pipeline_trainer_forwards_packed_sequence_length_when_set( assert backend.train.await_args.kwargs["packed_sequence_length"] == 4096 +@pytest.mark.asyncio +async def test_pipeline_trainer_forwards_default_kl_step_zero_for_generic_backend( + tmp_path: Path, +) -> None: + model = TrainableModel( + name="pipeline-generic-backend-kl-kwargs", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + ) + backend = MagicMock() + backend.train = AsyncMock(return_value=SimpleNamespace(step=1, metrics={})) + + trainer = _make_trainer( + model=model, + backend=backend, + kl_penalty_coef=0.25, + ) + trainer._output_queue = asyncio.Queue() + await trainer._output_queue.put(_make_group([0.0, 1.0])) + await trainer._output_queue.put(None) + + await trainer._training_stage() + + assert backend.train.await_args.kwargs == { + "learning_rate": 1e-5, + "loss_fn": "cispo", + "loss_fn_config": None, + "normalize_advantages": True, + "save_checkpoint": False, + "adam_params": None, + "kl_penalty_coef": 0.25, + "kl_penalty_reference_step": 0, + "kl_penalty_source": "sample", + } + + +@pytest.mark.asyncio +async def test_pipeline_trainer_kl_step_lag_floors_at_zero( + tmp_path: Path, +) -> None: + model = TrainableModel( + name="pipeline-kl-step-lag-floor", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + ) + backend = MagicMock() + backend.train = AsyncMock(return_value=SimpleNamespace(step=2, metrics={})) + + trainer = _make_trainer( + model=model, + backend=backend, + kl_penalty_coef=0.25, + kl_penalty_step_lag=5, + ) + trainer._output_queue = asyncio.Queue() + await trainer._output_queue.put(_make_group([0.0, 1.0])) + await trainer._output_queue.put(None) + + trainer.state.next_training_step = 1 + + await trainer._training_stage() + + assert backend.train.await_args.kwargs["kl_penalty_reference_step"] == 0 + + +@pytest.mark.asyncio +async def test_pipeline_trainer_kl_step_lag_computes_reference( + tmp_path: Path, +) -> None: + model = TrainableModel( + name="pipeline-kl-step-lag", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + ) + backend = MagicMock() + backend.train = AsyncMock(return_value=SimpleNamespace(step=4, metrics={})) + + trainer = _make_trainer( + model=model, + backend=backend, + kl_penalty_coef=0.25, + kl_penalty_step_lag=2, + ) + trainer._output_queue = asyncio.Queue() + await trainer._output_queue.put(_make_group([0.0, 1.0])) + await trainer._output_queue.put(None) + + trainer.state.next_training_step = 3 + + await trainer._training_stage() + + assert backend.train.await_args.kwargs["kl_penalty_reference_step"] == 1 + + +def test_pipeline_trainer_rejects_zero_kl_step_lag(tmp_path: Path) -> None: + model = TrainableModel( + name="pipeline-kl-zero-step-lag", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + ) + + with pytest.raises(ValueError, match="kl_penalty_step_lag must be >= 1"): + _make_trainer( + model=model, + backend=MagicMock(), + kl_penalty_coef=0.25, + kl_penalty_step_lag=0, + ) + + @pytest.mark.asyncio async def test_pipeline_trainer_uses_same_train_kwargs_for_local_backend( tmp_path: Path, @@ -207,6 +321,89 @@ async def fake_train_model( assert seen["dev_config"]["packed_sequence_length"] == 2048 +@pytest.mark.asyncio +async def test_local_backend_train_passes_kl_penalty_source(tmp_path: Path) -> None: + model = TrainableModel( + name="local-backend-kl-source", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + ) + backend = LocalBackend(path=str(tmp_path)) + seen: dict[str, Any] = {} + + async def fake_train_model( + _model: TrainableModel, + _groups: list[TrajectoryGroup], + config: Any, + dev_config: dict[str, Any], + verbose: bool = False, + ): + seen["config"] = config + seen["dev_config"] = dev_config + seen["verbose"] = verbose + yield {} + + backend._train_model = fake_train_model # type: ignore[method-assign] + backend._get_step = AsyncMock(return_value=1) # type: ignore[method-assign] + with patch.object(model, "_get_wandb_run", return_value=None): + result = await backend.train( + model, + [_make_group([1.0])], + kl_penalty_coef=0.25, + kl_penalty_source="sample", + save_checkpoint=False, + ) + + assert result.step == 1 + assert seen["config"].kl_penalty_source == "sample" + assert seen["dev_config"]["kl_penalty_source"] == "sample" + + +@pytest.mark.asyncio +async def test_megatron_backend_defaults_kl_reference_to_step_zero( + tmp_path: Path, +) -> None: + model = TrainableModel( + name="megatron-default-kl-reference", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + ) + backend = LocalBackend(path=str(tmp_path)) + backend._requires_explicit_packed_sequence_length = True + seen: dict[str, Any] = {} + + async def fake_train_model( + _model: TrainableModel, + _groups: list[TrajectoryGroup], + _config: Any, + dev_config: dict[str, Any], + verbose: bool = False, + ): + del verbose + seen["dev_config"] = dev_config + yield {} + + backend._train_model = fake_train_model # type: ignore[method-assign] + backend._get_step = AsyncMock(return_value=1) # type: ignore[method-assign] + + with patch.object(model, "_get_wandb_run", return_value=None): + await backend.train( + model, + [_make_group([1.0])], + kl_penalty_coef=0.25, + packed_sequence_length=4096, + save_checkpoint=False, + ) + + expected_ref_path = get_step_checkpoint_dir( + get_model_dir(model=model, art_path=str(tmp_path)), + 0, + ) + assert seen["dev_config"]["kl_ref_adapter_path"] == expected_ref_path + + @pytest.mark.asyncio async def test_local_backend_train_maps_normalize_advantages_to_scale_rewards( tmp_path: Path, @@ -305,6 +502,119 @@ def strategy(context: CheckpointRetentionContext) -> set[int]: ) +@pytest.mark.asyncio +async def test_pipeline_trainer_checkpoint_retention_protects_default_kl_reference( + tmp_path: Path, +) -> None: + model = TrainableModel( + name="pipeline-checkpoint-retention-default-kl-ref", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + ) + checkpoint_dir = Path(model._get_output_dir()) / "checkpoints" + for step in range(4): + (checkpoint_dir / f"{step:04d}").mkdir(parents=True) + + backend = MagicMock() + backend._delete_checkpoint_files = AsyncMock() + contexts: list[CheckpointRetentionContext] = [] + + def strategy(context: CheckpointRetentionContext) -> set[int]: + contexts.append(context) + return set() + + trainer = _make_trainer( + model=model, + backend=backend, + checkpoint_retention_strategy=strategy, + kl_penalty_coef=0.25, + ) + + await trainer._run_checkpoint_retention(3) + + assert [checkpoint.step for checkpoint in contexts[0].checkpoints] == [1, 2] + backend._delete_checkpoint_files.assert_awaited_once_with( # type: ignore[attr-defined] + model, + [0, 3], + ) + + +@pytest.mark.asyncio +async def test_pipeline_trainer_checkpoint_retention_protects_lagged_kl_reference( + tmp_path: Path, +) -> None: + model = TrainableModel( + name="pipeline-checkpoint-retention-lagged-kl-ref", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + ) + checkpoint_dir = Path(model._get_output_dir()) / "checkpoints" + for step in range(7): + (checkpoint_dir / f"{step:04d}").mkdir(parents=True) + + backend = MagicMock() + backend._delete_checkpoint_files = AsyncMock() + contexts: list[CheckpointRetentionContext] = [] + + def strategy(context: CheckpointRetentionContext) -> set[int]: + contexts.append(context) + return set() + + trainer = _make_trainer( + model=model, + backend=backend, + checkpoint_retention_strategy=strategy, + kl_penalty_coef=0.25, + kl_penalty_step_lag=5, + ) + + await trainer._run_checkpoint_retention(6) + + assert [checkpoint.step for checkpoint in contexts[0].checkpoints] == [0] + backend._delete_checkpoint_files.assert_awaited_once_with( # type: ignore[attr-defined] + model, + [1, 2, 3, 4, 5, 6], + ) + + +@pytest.mark.asyncio +async def test_pipeline_trainer_checkpoint_retention_lag_warmup_protects_window( + tmp_path: Path, +) -> None: + model = TrainableModel( + name="pipeline-checkpoint-retention-lag-floor-zero", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + ) + checkpoint_dir = Path(model._get_output_dir()) / "checkpoints" + for step in range(5): + (checkpoint_dir / f"{step:04d}").mkdir(parents=True) + + backend = MagicMock() + backend._delete_checkpoint_files = AsyncMock() + contexts: list[CheckpointRetentionContext] = [] + + def strategy(context: CheckpointRetentionContext) -> set[int]: + contexts.append(context) + return set() + + trainer = _make_trainer( + model=model, + backend=backend, + checkpoint_retention_strategy=strategy, + kl_penalty_coef=0.25, + kl_penalty_step_lag=5, + ) + + await trainer._run_checkpoint_retention(4) + + assert contexts == [] + backend._delete_checkpoint_files.assert_not_awaited() # type: ignore[attr-defined] + + @pytest.mark.asyncio async def test_pipeline_trainer_checkpoint_retention_honors_interval( tmp_path: Path, @@ -426,9 +736,7 @@ def test_local_backend_get_packed_tensors_warns_and_drops_overlong_results( "art.local.backend.AutoTokenizer.from_pretrained", return_value=short_result._tokenizer, ), - patch( - "art.local.backend.AutoImageProcessor.from_pretrained", return_value=None - ), + patch("transformers.AutoImageProcessor.from_pretrained", return_value=None), patch( "art.local.backend.tokenize_trajectory_groups", return_value=iter([short_result, long_result]), diff --git a/tests/unit/test_serverless_pipeline_trainer_compat.py b/tests/unit/test_serverless_pipeline_trainer_compat.py index fec8d23f7..8c1482b02 100644 --- a/tests/unit/test_serverless_pipeline_trainer_compat.py +++ b/tests/unit/test_serverless_pipeline_trainer_compat.py @@ -91,6 +91,7 @@ async def fake_train_model( "allow_training_without_logprobs": True, "importance_sampling_level": "token", "kl_penalty_coef": 0.1, + "kl_penalty_source": "sample", "kl_ref_adapter_path": "/tmp/ref-adapter", "logprob_calculation_chunk_size": 512, "mask_prob_ratio": False, @@ -123,6 +124,17 @@ async def test_serverless_train_rejects_unsupported_pipeline_kwargs() -> None: with pytest.raises(ValueError, match="conflicting loss_fn and ppo"): await backend.train(model, [_make_group()], loss_fn="ppo", ppo=False) + with pytest.raises(ValueError, match="kl_penalty_step_lag must be >= 1"): + await backend.train(model, [_make_group()], kl_penalty_step_lag=0) + + with pytest.raises(ValueError, match="Only one of"): + await backend.train( + model, + [_make_group()], + kl_penalty_reference_step=0, + kl_penalty_step_lag=1, + ) + @pytest.mark.asyncio async def test_serverless_train_model_forwards_experimental_config() -> None: @@ -162,6 +174,8 @@ async def no_sleep(_seconds: float) -> None: "importance_sampling_level": "sequence", "kimi_k2_tau": 0.4, "kl_penalty_coef": 0.2, + "kl_penalty_reference_step": 0, + "kl_penalty_source": "sample", "kl_ref_adapter_path": "/tmp/ref", "logprob_calculation_chunk_size": 512, "mask_prob_ratio": True, @@ -184,6 +198,48 @@ async def no_sleep(_seconds: float) -> None: assert payload["normalize_advantages"] is False assert payload["packed_sequence_length"] == 4096 assert payload["kl_penalty_coef"] == 0.2 + assert payload["kl_penalty_reference_step"] == 0 + assert payload["kl_penalty_source"] == "sample" assert payload["kl_ref_adapter_path"] == "/tmp/ref" assert payload["allow_training_without_logprobs"] is True assert payload["scale_learning_rate_by_reward_std_dev"] is True + + +@pytest.mark.asyncio +async def test_serverless_train_forwards_kl_step_lag() -> None: + backend = _make_backend() + model = TrainableModel( + name="serverless-kl-step-lag", + project="pipeline-tests", + base_model="test-model", + ) + model.id = "model-id" + + seen: dict[str, Any] = {} + + async def fake_train_model( + _model: TrainableModel, + _groups: list[TrajectoryGroup], + _config: TrainConfig, + dev_config: dict[str, Any], + verbose: bool = False, + ): + del verbose + seen["dev_config"] = dev_config + yield {} + + backend._train_model = fake_train_model # type: ignore[method-assign] + backend._get_step = AsyncMock(return_value=1) # type: ignore[method-assign] + + with patch.object(model, "_get_wandb_run", return_value=None): + await backend.train( + model, + [_make_group()], + kl_penalty_coef=0.2, + kl_penalty_source="sample", + kl_penalty_step_lag=3, + ) + + assert seen["dev_config"]["kl_penalty_coef"] == 0.2 + assert seen["dev_config"]["kl_penalty_source"] == "sample" + assert seen["dev_config"]["kl_penalty_step_lag"] == 3 diff --git a/tests/unit/test_tinker_native_kl.py b/tests/unit/test_tinker_native_kl.py new file mode 100644 index 000000000..a2d16d01f --- /dev/null +++ b/tests/unit/test_tinker_native_kl.py @@ -0,0 +1,77 @@ +import pytest +import tinker + +from art import TrainableModel +from art.tinker_native.backend import TinkerNativeBackend, _apply_kl_penalty +from art.tinker_native.data import build_datum + + +class FakeSamplingClient(tinker.SamplingClient): + def __init__(self, responses: dict[tuple[int, ...], list[float | None]]) -> None: + self._responses = responses + + async def compute_logprobs_async( + self, prompt: tinker.ModelInput + ) -> list[float | None]: + return self._responses[tuple(prompt.to_ints())] + + +@pytest.mark.asyncio +async def test_incorporate_kl_penalty_rewrites_advantages_in_place() -> None: + datum_a = build_datum( + prompt_tokens=[101, 102], + completion_tokens=[201, 202], + logprobs=[-0.4, -0.8], + advantage=1.0, + ) + datum_b = build_datum( + prompt_tokens=[301, 302], + completion_tokens=[401], + logprobs=[-0.2], + advantage=2.0, + ) + assert datum_a is not None + assert datum_b is not None + + sampling_client = FakeSamplingClient( + { + (101, 102, 201, 202): [None, -9.0, -0.1, -0.5], + (301, 302, 401): [None, -7.0, -0.05], + } + ) + + metrics = await _apply_kl_penalty( + [datum_a, datum_b], + sampling_client, + kl_penalty_coef=2.0, + ) + + assert metrics == {"loss/kl_policy_ref": pytest.approx(-0.25)} + assert datum_a.loss_fn_inputs["advantages"].tolist() == pytest.approx( + [0.0, 1.1, 1.1] + ) + assert datum_b.loss_fn_inputs["advantages"].tolist() == pytest.approx([0.0, 1.8]) + + +@pytest.mark.asyncio +async def test_tinker_native_backend_rejects_current_learner_kl_source( + tmp_path, +) -> None: + backend = TinkerNativeBackend(tinker_api_key="test-key", path=str(tmp_path)) + model = TrainableModel( + name="tinker-native-kl-source", + project="pipeline-tests", + base_model="test-model", + base_path=str(tmp_path), + ) + + with pytest.raises( + AssertionError, + match="only supports kl_penalty_source='sample'", + ): + await backend.train( + model, + [], + kl_penalty_coef=0.25, + kl_penalty_source="current_learner", # ty:ignore[invalid-argument-type] + )