diff --git a/angelslim/compressor/quant/core/vllm_calibrate_utils/__init__.py b/angelslim/compressor/quant/core/vllm_calibrate_utils/__init__.py index 18e530dd..0908655d 100644 --- a/angelslim/compressor/quant/core/vllm_calibrate_utils/__init__.py +++ b/angelslim/compressor/quant/core/vllm_calibrate_utils/__init__.py @@ -46,8 +46,12 @@ from .search import ( KVCachePerHeadValueHook, KVCacheValueHook, + KVMSEProfileCollector, + KVPerHeadMSEProfileCollector, KVScaleSearcher, KVScaleSearcherPerHead, + get_kv_mse_profile_results, + get_kv_mse_profile_results_perhead, get_kv_scale_search_results, get_kv_scale_search_results_perhead, remove_kv_scale_search_hooks, @@ -88,12 +92,16 @@ "KVCacheValueHook", "setup_kvcache_value_hooks", "KVScaleSearcher", + "KVMSEProfileCollector", "get_kv_scale_search_results", + "get_kv_mse_profile_results", "remove_kv_scale_search_hooks", # KV scale search (per-head) "KVCachePerHeadValueHook", "setup_kvcache_perhead_value_hooks", "remove_kvcache_perhead_value_hooks", "KVScaleSearcherPerHead", + "KVPerHeadMSEProfileCollector", "get_kv_scale_search_results_perhead", + "get_kv_mse_profile_results_perhead", ] diff --git a/angelslim/compressor/quant/core/vllm_calibrate_utils/search.py b/angelslim/compressor/quant/core/vllm_calibrate_utils/search.py index 6ddf28c3..060c08ee 100644 --- a/angelslim/compressor/quant/core/vllm_calibrate_utils/search.py +++ b/angelslim/compressor/quant/core/vllm_calibrate_utils/search.py @@ -25,15 +25,19 @@ # Per-tensor scale search "KVCacheValueHook", "KVScaleSearcher", + "KVMSEProfileCollector", "setup_kvcache_value_hooks", "get_kv_scale_search_results", + "get_kv_mse_profile_results", "remove_kv_scale_search_hooks", # Per-head scale search "KVCachePerHeadValueHook", "setup_kvcache_perhead_value_hooks", "remove_kvcache_perhead_value_hooks", "KVScaleSearcherPerHead", + "KVPerHeadMSEProfileCollector", "get_kv_scale_search_results_perhead", + "get_kv_mse_profile_results_perhead", ] @@ -410,6 +414,151 @@ def 
get_kv_scale_search_results(results_list: list) -> dict: return first +def _make_log_uniform_multipliers( + min_multiplier: float, + max_multiplier: float, + num_steps: int, +) -> list[float]: + """Build the log-uniform multiplier grid used by KV scale search.""" + import math + + if num_steps <= 1: + return [float(min_multiplier)] + + log_min = math.log(min_multiplier) + log_max = math.log(max_multiplier) + return [ + math.exp(log_min + (log_max - log_min) * i / (num_steps - 1)) + for i in range(num_steps) + ] + + +def _compute_kv_mse_profile_flat( + flat: torch.Tensor, + base_scale: float, + min_multiplier: float, + max_multiplier: float, + num_steps: int, +) -> dict: + """Compute an additive SSE/numel profile for DP KV search aggregation.""" + multipliers = _make_log_uniform_multipliers( + min_multiplier, + max_multiplier, + num_steps, + ) + if flat.numel() == 0: + return { + "multipliers": multipliers, + "sse": [0.0 for _ in multipliers], + "numel": 0, + "base_scale": base_scale, + } + + fp8_min = torch.finfo(torch.float8_e4m3fn).min + fp8_max = torch.finfo(torch.float8_e4m3fn).max + sse = [] + + if flat.is_cuda: + sse_vals = torch.empty(len(multipliers), dtype=torch.float64, device=flat.device) + for i, multiplier in enumerate(multipliers): + scale = base_scale * multiplier + q_fp8 = ( + (flat / scale).clamp(fp8_min, fp8_max).to(torch.float8_e4m3fn).to(torch.float32) + ) + diff = flat - q_fp8 * scale + sse_vals[i] = (diff.double() * diff.double()).sum() + sse = [float(v) for v in sse_vals.cpu().tolist()] + del sse_vals + else: + for multiplier in multipliers: + scale = base_scale * multiplier + q_fp8 = ( + (flat / scale).clamp(fp8_min, fp8_max).to(torch.float8_e4m3fn).to(torch.float32) + ) + diff = flat - q_fp8 * scale + sse.append(float((diff.double() * diff.double()).sum().item())) + + return { + "multipliers": multipliers, + "sse": sse, + "numel": int(flat.numel()), + "base_scale": base_scale, + } + + +class KVMSEProfileCollector: + """Collect local 
per-tensor KV MSE profiles for DP aggregation.""" + + def __init__( + self, + activation_stats: dict, + min_multiplier: float = 0.8, + max_multiplier: float = 16.0, + num_steps: int = 100, + ): + self.activation_stats = activation_stats + self.min_multiplier = min_multiplier + self.max_multiplier = max_multiplier + self.num_steps = num_steps + + def __call__(self, model): + fp8_max = torch.finfo(torch.float8_e4m3fn).max + kv_values = _get_kv_search_values(model) + if not kv_values: + print( + "[KVMSEProfileCollector] WARNING: No kv values collected. " + "Did you call setup_kvcache_value_hooks before inference?" + ) + return {} + + use_gpu = torch.cuda.is_available() + search_device = ( + torch.device("cuda", torch.cuda.current_device()) if use_gpu else torch.device("cpu") + ) + profiles = {} + + for layer_name, tensors_dict in kv_values.items(): + for kv_slot, tensors in tensors_dict.items(): + stats_key = f"{layer_name}.{kv_slot}_cache" + if stats_key not in self.activation_stats: + print( + f"[KVMSEProfileCollector] WARNING: {stats_key} not found in " + "activation_stats, skipping." 
+ ) + continue + if not tensors: + continue + + stats = self.activation_stats[stats_key] + abs_max = max(abs(stats["min"]), abs(stats["max"])) + base_scale = abs_max / fp8_max * 2.0 if abs_max != 0 else 1e-8 + flat_cpu = torch.cat([t.reshape(-1).float() for t in tensors]) + flat = flat_cpu.to(search_device, non_blocking=True) if use_gpu else flat_cpu + profile = _compute_kv_mse_profile_flat( + flat=flat, + base_scale=base_scale, + min_multiplier=self.min_multiplier, + max_multiplier=self.max_multiplier, + num_steps=self.num_steps, + ) + if use_gpu: + del flat + profiles[stats_key] = profile + print( + f"[KVMSEProfileCollector] {stats_key}: " + f"numel={profile['numel']} base_scale={base_scale:.6f}" + ) + + return profiles + + +def get_kv_mse_profile_results(results_list: list) -> dict: + """Extract the per-tensor MSE profile dict from ``llm.apply_model`` results.""" + if not results_list: + return {} + return results_list[0] or {} + + def remove_kv_scale_search_hooks(model): """ Clean up kv-value hooks after search. Pass to ``llm.apply_model``. @@ -717,6 +866,102 @@ def __call__(self, model): return multipliers +class KVPerHeadMSEProfileCollector: + """Collect local per-head KV MSE profiles for DP aggregation.""" + + def __init__( + self, + activation_stats: dict, + min_multiplier: float = 0.8, + max_multiplier: float = 16.0, + num_steps: int = 100, + ): + self.activation_stats = activation_stats + self.min_multiplier = min_multiplier + self.max_multiplier = max_multiplier + self.num_steps = num_steps + + def __call__(self, model): + kv_values = _get_kv_perhead_search_values(model) + if not kv_values: + print( + "[KVPerHeadMSEProfileCollector] WARNING: No per-head kv values collected. " + "Did you call setup_kvcache_perhead_value_hooks before inference?" 
+ ) + return {} + + use_gpu = torch.cuda.is_available() + search_device = ( + torch.device("cuda", torch.cuda.current_device()) if use_gpu else torch.device("cpu") + ) + fp8_max = torch.finfo(torch.float8_e4m3fn).max + + rank, world_size = _get_dist_info() + num_kv_heads_total = None + for stats in self.activation_stats.values(): + min_vals = stats.get("min") + if isinstance(min_vals, list): + num_kv_heads_total = len(min_vals) + break + if isinstance(min_vals, torch.Tensor): + num_kv_heads_total = min_vals.numel() + break + if num_kv_heads_total is None or num_kv_heads_total <= 0: + num_kv_heads_total = 1 + + role, _, global_head_offset, replication = _compute_perhead_layout( + rank, + world_size, + num_kv_heads_total, + ) + profiles: dict[str, dict] = {} + + for layer_name, tensors_dict in kv_values.items(): + for kv_slot, tensors in tensors_dict.items(): + if role != "both" and kv_slot != role: + continue + + stats_key = f"{layer_name}.{kv_slot}_cache" + if stats_key not in self.activation_stats or not tensors: + continue + + stats = self.activation_stats[stats_key] + min_vals = stats["min"] + max_vals = stats["max"] + stacked = torch.cat(tensors, dim=1) + profiles.setdefault(stats_key, {}) + + for local_head_idx in range(stacked.shape[0]): + global_head_idx = global_head_offset + local_head_idx + if global_head_idx >= num_kv_heads_total: + continue + + abs_max = max( + abs(min_vals[global_head_idx]), + abs(max_vals[global_head_idx]), + ) + base_scale = abs_max / fp8_max * 2.0 if abs_max != 0 else 1e-8 + flat_cpu = stacked[local_head_idx].reshape(-1).float() + flat = flat_cpu.to(search_device, non_blocking=True) if use_gpu else flat_cpu + profile = _compute_kv_mse_profile_flat( + flat=flat, + base_scale=base_scale, + min_multiplier=self.min_multiplier, + max_multiplier=self.max_multiplier, + num_steps=self.num_steps, + ) + if use_gpu: + del flat + profiles[stats_key][global_head_idx] = profile + + print( + f"[KVPerHeadMSEProfileCollector] 
rank={rank}/{world_size}, role={role}, " + f"head_offset={global_head_offset}, H_total={num_kv_heads_total}, " + f"replication={replication}" + ) + return profiles + + def get_kv_scale_search_results_perhead(results_list: list) -> dict: """ Merge per-head multiplier dicts from all TP workers. @@ -749,3 +994,32 @@ def get_kv_scale_search_results_perhead(results_list: list) -> dict: final[stats_key] = [head_dict[i] for i in sorted_indices] return final + + +def get_kv_mse_profile_results_perhead(results_list: list) -> dict: + """ + Merge per-head MSE profile dicts from all TP workers. + + Each worker's result is ``{stats_key: {global_head_idx: profile}}``. + The returned format is ``{stats_key: [profile_head0, profile_head1, ...]}``. + Missing heads are kept as ``None`` placeholders so the list index always + matches the global KV head index during cross-DP aggregation. + """ + if not results_list: + return {} + + merged: dict = {} + for worker_result in results_list: + if not worker_result: + continue + for stats_key, head_profiles in worker_result.items(): + merged.setdefault(stats_key, {}).update(head_profiles) + + final: dict = {} + for stats_key, head_profiles in merged.items(): + if not head_profiles: + final[stats_key] = [] + continue + max_head_idx = max(int(i) for i in head_profiles.keys()) + final[stats_key] = [head_profiles.get(i) for i in range(max_head_idx + 1)] + return final diff --git a/angelslim/utils/__init__.py b/angelslim/utils/__init__.py index a12f24b5..70617fc8 100644 --- a/angelslim/utils/__init__.py +++ b/angelslim/utils/__init__.py @@ -30,4 +30,6 @@ from .utils import print_with_rank # noqa: F401 from .utils import rank0_print # noqa: F401 from .utils import set_op_by_name # noqa: F401 +from .vllm_calibration_dp import run_vllm_calibration_with_dp # noqa: F401 +from .vllm_calibration_dp import validate_vllm_calibration_dp_args # noqa: F401 from .zero3_io import * # noqa: F401 F403 diff --git a/angelslim/utils/vllm_calibration_dp.py 
b/angelslim/utils/vllm_calibration_dp.py new file mode 100644 index 00000000..fa451eb0 --- /dev/null +++ b/angelslim/utils/vllm_calibration_dp.py @@ -0,0 +1,755 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Offline vLLM calibration DP helpers. + +The calibration entrypoint uses this module to implement launcher-level data +parallelism with explicit Ray actors: + +- the driver process connects to an existing Ray cluster once; +- the driver creates one placement group and one long-lived actor per DP replica; +- each actor creates exactly one vLLM instance through the existing worker_fn; +- every vLLM instance still uses its own TP workers via + ``distributed_executor_backend=ray``; +- vLLM reuses the actor's placement group and captures TP child tasks there; +- the driver merges all partial calibration JSON payloads back into the + standard stage-1 filenames. 
+""" + +import argparse +import copy +import json +import os +import shutil +import socket +import time +from collections.abc import Callable +from typing import Any + +# Native vLLM DP environment variable names that must NOT be set +_NATIVE_VLLM_DP_ENV_NAMES = ( + "VLLM_DP_SIZE", + "VLLM_DP_SIZE_LOCAL", + "VLLM_DP_RANK", + "VLLM_DP_RANK_LOCAL", + "VLLM_DP_MASTER_IP", + "VLLM_DP_MASTER_PORT", +) + + +def validate_vllm_calibration_dp_args(parser, args): + """Validate Ray actor-managed calibration DP arguments after YAML overrides.""" + if args.tp_size < 1: + parser.error("--tp-size must be >= 1") + if args.dp_size < 1: + parser.error("--dp-size must be >= 1") + if args.dp_timeout < 1: + parser.error("--dp-timeout must be >= 1") + + if args.dp_size > 1 and args.distributed_executor_backend != "ray": + parser.error( + "Ray actor-managed calibration DP requires --distributed-executor-backend=ray" + ) + + if ( + args.dp_size > 1 + and args.search_kv_scale + and args.kv_granularity != "none" + and args.search_kv_num_samples < args.dp_size + ): + parser.error( + "DP KV scale search requires --search-kv-num-samples to be >= --dp-size " + "so every DP rank receives at least one search prompt" + ) + + # Legacy parameters - should be 1 and 0 respectively + if args.dp_num_nodes != 1: + parser.error( + "Ray actor-managed calibration DP launches all replicas from one driver process; " + "please keep --dp-num-nodes=1" + ) + if args.dp_node_rank != 0: + parser.error( + "Ray actor-managed calibration DP launches all replicas from one driver process; " + "please keep --dp-node-rank=0" + ) + return args + + +def run_vllm_calibration_with_dp(args, worker_fn: Callable[[Any], Any]) -> None: + """Run calibration directly or launch Ray actor-managed DP replicas.""" + args.dp_rank = 0 + + if args.dp_size == 1: + # Single replica - no extra actors needed + worker_fn(args) + return + + print("\n" + "=" * 80) + print("Launching Ray actor-managed calibration DP") + print(f"Ray address : 
{args.ray_address or os.environ.get('RAY_ADDRESS') or 'auto'}") + print(f"DP size : {args.dp_size}") + print(f"TP size : {args.tp_size}") + print(f"Required GPUs : {args.dp_size * args.tp_size}") + print(f"Placement strategy : {args.placement_strategy}") + print("Driver creates one placement group and one long-lived actor per DP rank;") + print("vLLM Ray executor manages TP workers inside each rank actor.") + print("=" * 80) + + try: + _run_ray_actor_calibration_dp(args, worker_fn) + except Exception as e: + print(f"\nDP calibration failed: {e}") + raise + + +def _run_ray_actor_calibration_dp(args, worker_fn: Callable[[Any], Any]) -> list[dict[str, Any]]: + """Launch one explicit Ray actor per calibration DP replica.""" + import ray + from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + + ray_address = args.ray_address or os.environ.get("RAY_ADDRESS") or "auto" + owns_ray = not ray.is_initialized() + + # Collect environment variables to pass to all Ray workers + # This ensures consistent vLLM configuration across all DP ranks + common_runtime_env = { + "env_vars": {} + } + + # List of vLLM environment variables that should be consistent across all DP ranks + vllm_env_vars = [ + "VLLM_ALLOW_INSECURE_SERIALIZATION", + "VLLM_MOE_COLLECT_STATS", + "VLLM_MOE_COLLECT_STATS_VERBOSE", + "VLLM_MOE_COLLECT_PER_EXPERT_STATS", + "VLLM_ENABLE_CHUNKED_PREFILL", + "VLLM_ATTENTION_BACKEND", + "ASYNC_SCHEDULING", + "VLLM_ENABLE_PREFIX_CACHING", + "PRECISIONMODE", + "RAY_DEDUP_LOGS", + "PYTHONDONTWRITEBYTECODE", + # Add any other environment variables that might be needed + ] + + for env_var in vllm_env_vars: + if env_var in os.environ: + common_runtime_env["env_vars"][env_var] = os.environ[env_var] + print(f"[DP] Will pass {env_var}={os.environ[env_var]} via runtime_env") + + if owns_ray: + ray.init( + address=ray_address, + runtime_env=common_runtime_env, + ignore_reinit_error=True, + log_to_driver=True, + ) + + placement_groups = [] + actors = [] + + try: 
+ # Validate cluster resources + validate_cluster_resources(args) + + # Create placement groups for each DP rank + for dp_rank in range(args.dp_size): + bundles = create_replica_bundles(args.tp_size) + pg = ray.util.placement_group( + bundles=bundles, + strategy=args.placement_strategy, + name=f"calibration_dp_{dp_rank}", + ) + placement_groups.append(pg) + print( + f"[DP] Created placement group for rank {dp_rank}: " + f"{args.tp_size} GPU bundle(s)" + ) + + # Wait for all placement groups to be ready + ready_refs = [pg.ready() for pg in placement_groups] + ready, pending = ray.wait( + ready_refs, + num_returns=len(ready_refs), + timeout=args.dp_timeout, + ) + if pending: + raise TimeoutError( + "Timed out waiting for calibration DP placement groups: " + f"ready={len(ready)}, pending={len(pending)}, " + f"timeout={args.dp_timeout}s. " + "Check whether enough free GPUs are available for " + "dp_size * tp_size." + ) + + # Create actors for each DP rank + ReplicaActor = ray.remote( + num_cpus=1, + num_gpus=0, + max_restarts=0, + max_task_retries=0, + )(CalibrationReplica) + + for dp_rank, pg in enumerate(placement_groups): + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=0, + placement_group_capture_child_tasks=True, + ) + actor = ReplicaActor.options( + scheduling_strategy=scheduling_strategy, + name=f"calibration-replica-{dp_rank}", + ).remote(worker_fn) + actors.append(actor) + print(f"[DP] Submitted calibration actor for rank {dp_rank}") + + # Submit tasks to actors + args_dict = vars(copy.deepcopy(args)) + result_refs = [ + actor.run.remote(args_dict, dp_rank) + for dp_rank, actor in enumerate(actors) + ] + + # Wait for results with timeout + ready, pending = ray.wait( + result_refs, + num_returns=len(result_refs), + timeout=args.dp_timeout, + ) + if pending: + raise TimeoutError( + "Timed out waiting for calibration DP actors: " + f"finished={len(ready)}, pending={len(pending)}, " + 
f"timeout={args.dp_timeout}s." + ) + + results = ray.get(result_refs) + _merge_dp_payloads(args.output_dir, results) + + if args.search_kv_scale and args.kv_granularity != "none": + activation_stats_path = os.path.join(args.output_dir, "activation_stats.json") + if not os.path.exists(activation_stats_path): + raise FileNotFoundError( + "Merged activation_stats.json is required for DP KV scale search: " + f"{activation_stats_path}" + ) + + print("\n" + "=" * 80) + print("Launching DP KV-cache scale search with merged activation stats...") + print("=" * 80) + args_dict = vars(copy.deepcopy(args)) + kv_refs = [ + actor.run_kv_search.remote(args_dict, dp_rank, activation_stats_path) + for dp_rank, actor in enumerate(actors) + ] + ready, pending = ray.wait( + kv_refs, + num_returns=len(kv_refs), + timeout=args.dp_timeout, + ) + if pending: + raise TimeoutError( + "Timed out waiting for DP KV search actors: " + f"finished={len(ready)}, pending={len(pending)}, " + f"timeout={args.dp_timeout}s." 
+ ) + kv_results = ray.get(kv_refs) + _merge_dp_kv_search_payloads(args.output_dir, kv_results, args.kv_granularity) + + return results + + finally: + # Clean up actors and placement groups + for actor in actors: + try: + ray.kill(actor, no_restart=True) + except Exception: + pass + + for pg in placement_groups: + try: + ray.util.remove_placement_group(pg) + except Exception: + pass + + if owns_ray: + ray.shutdown() + + +class CalibrationReplica: + """Ray actor that runs one calibration DP replica.""" + + def __init__(self, worker_fn: Callable[[Any], Any]) -> None: + self._worker_fn = worker_fn + self._has_calibrated = False + self._dp_rank: int | None = None + self._llm: Any | None = None + + def _prepare_args( + self, + args_dict: dict[str, Any], + dp_rank: int, + output_suffix: str, + ): + args = argparse.Namespace(**args_dict) + args.dp_rank = dp_rank + args.dp_size = args_dict.get("dp_size", 1) + + for name in _NATIVE_VLLM_DP_ENV_NAMES: + os.environ.pop(name, None) + + os.environ["ANGELSLIM_CALIBRATION_DP_RANK"] = str(dp_rank) + os.environ["ANGELSLIM_CALIBRATION_DP_SIZE"] = str(args.dp_size) + + args.output_dir = os.path.join(args.output_dir, output_suffix) + if os.path.isdir(args.output_dir): + shutil.rmtree(args.output_dir) + os.makedirs(args.output_dir, exist_ok=True) + return args + + def run( + self, + args_dict: dict[str, Any], + dp_rank: int, + ) -> dict[str, Any]: + if self._has_calibrated: + raise RuntimeError( + "CalibrationReplica calibration stage can only run once. 
" + f"existing_dp_rank={self._dp_rank}, " + f"new_dp_rank={dp_rank}" + ) + + self._has_calibrated = True + self._dp_rank = dp_rank + + args = self._prepare_args(args_dict, dp_rank, f"dp_rank_{dp_rank}") + args.search_kv_scale = False + + dp_log(args, "actor_start", f"Actor started on host={socket.gethostname()}") + result = self._worker_fn(args, return_llm=True) + if isinstance(result, tuple) and len(result) == 2: + payload, self._llm = result + else: + payload = result + self._llm = None + dp_log(args, "actor_done", "Calibration completed") + + return { + "dp_rank": dp_rank, + "output_dir": args.output_dir, + "payload": payload, + } + + def run_kv_search( + self, + args_dict: dict[str, Any], + dp_rank: int, + activation_stats_path: str, + ) -> dict[str, Any]: + args = self._prepare_args(args_dict, dp_rank, f"dp_rank_{dp_rank}_kv_search") + args.kv_search_only = True + args.kv_search_activation_stats_path = activation_stats_path + args.search_kv_scale = True + + if self._llm is None: + raise RuntimeError( + "Cannot run DP KV search because the first-stage vLLM instance " + "was not retained in the actor." 
+ ) + + dp_log(args, "kv_search_start", f"KV search actor started on host={socket.gethostname()}") + payload = self._worker_fn(args, llm=self._llm) + dp_log(args, "kv_search_done", "KV search completed") + + return { + "dp_rank": dp_rank, + "output_dir": args.output_dir, + "payload": payload, + } + + +def create_replica_bundles(tp_size: int) -> list[dict[str, float]]: + """Create bundles for a single DP replica.""" + bundles = [ + { + "CPU": 1, + "GPU": 1, + } + ] + + bundles.extend( + { + "GPU": 1, + } + for _ in range(tp_size - 1) + ) + + return bundles + + +def validate_cluster_resources(args) -> None: + """Validate that the Ray cluster has enough resources for the DP/TP configuration.""" + import ray + + required_gpus = args.dp_size * args.tp_size + cluster_resources = ray.cluster_resources() + available_gpus = int(cluster_resources.get("GPU", 0)) + + if available_gpus < required_gpus: + raise RuntimeError( + f"Need {required_gpus} GPUs for DP={args.dp_size}, TP={args.tp_size}, " + f"but Ray reports {available_gpus} available GPUs." + ) + + # Validate STRICT_PACK capacity if needed + if args.placement_strategy == "STRICT_PACK": + validate_strict_pack_capacity(args) + + +def validate_strict_pack_capacity(args) -> None: + """Validate that each DP replica can fit entirely on a single node with STRICT_PACK.""" + import ray + + capacity = 0 + node_infos = [] + + for node in ray.nodes(): + if not node.get("Alive", False): + continue + + resources = node.get("Resources", {}) + gpus = int(resources.get("GPU", 0)) + cpus = int(resources.get("CPU", 0)) + + replicas_on_node = min( + gpus // args.tp_size, + cpus, + ) + + capacity += replicas_on_node + + node_infos.append( + { + "address": node.get("NodeManagerAddress"), + "gpus": gpus, + "cpus": cpus, + "replica_capacity": replicas_on_node, + } + ) + + if capacity < args.dp_size: + raise RuntimeError( + "STRICT_PACK placement is infeasible. 
" + f"Need {args.dp_size} replicas with " + f"TP={args.tp_size}, but capacity={capacity}. " + f"nodes={node_infos}" + ) + + +def dp_log( + args, + stage: str, + message: str, +) -> None: + """Log a message with DP rank and host information.""" + print( + f"[DP {getattr(args, 'dp_rank', 0)}/{getattr(args, 'dp_size', 1)}] " + f"[host={socket.gethostname()}] " + f"[pid={os.getpid()}] " + f"[stage={stage}] " + f"{message}", + flush=True, + ) + + +def _merge_dp_payloads(output_dir: str, results: list[dict[str, Any]]) -> None: + """Merge calibration statistics from all DP ranks.""" + print("\n" + "=" * 80) + print("Merging DP calibration statistics...") + print("=" * 80) + + _MINMAX_FILES = ( + "activation_stats.json", + "moe_expert_stats.json", + "mtp_activation_stats.json", + "mtp_moe_expert_stats.json", + ) + + # Collect payloads from all ranks + payloads = [result["payload"] for result in results] + dp_size = len(payloads) + + # Initialize merged statistics + merged_stats = {} + for filename in _MINMAX_FILES: + merged_stats[filename] = {} + + # Merge min/max statistics from all ranks + for filename in _MINMAX_FILES: + print(f"\nMerging {filename}...") + + # Collect all entries for this file from all ranks + all_entries = {} + for rank, payload in enumerate(payloads): + if filename not in payload: + print(f" Rank {rank}: File {filename} not found") + continue + + rank_stats = payload[filename] + for key, stats in rank_stats.items(): + if key not in all_entries: + all_entries[key] = [] + all_entries[key].append({"rank": rank, "stats": stats}) + + # Merge each key + for key, rank_data in all_entries.items(): + if len(rank_data) < dp_size: + print(f" Key {key}: Only found in {len(rank_data)}/{dp_size} ranks") + + # Validate data types + first_stats = rank_data[0]["stats"] + if not isinstance(first_stats, dict) or "min" not in first_stats or "max" not in first_stats: + print(f" Key {key}: Invalid stats format, skipping") + continue + + # Check if min/max are scalars or 
lists + first_min = first_stats["min"] + first_max = first_stats["max"] + is_list = isinstance(first_min, list) + + # Validate consistency across ranks + skip_key = False + for entry in rank_data[1:]: + stats = entry["stats"] + if isinstance(stats["min"], list) != is_list: + print(f" Key {key}: Mixed scalar/list types across ranks, skipping") + skip_key = True + break + if is_list and len(stats["min"]) != len(first_min): + print(f" Key {key}: List length mismatch, skipping") + skip_key = True + break + if skip_key: + continue + + # Merge + if is_list: + # List min/max + list_len = len(first_min) + merged_min = [float("inf")] * list_len + merged_max = [float("-inf")] * list_len + + for entry in rank_data: + stats = entry["stats"] + for i in range(list_len): + merged_min[i] = min(merged_min[i], stats["min"][i]) + merged_max[i] = max(merged_max[i], stats["max"][i]) + + merged_stats[filename][key] = { + "min": merged_min, + "max": merged_max, + } + else: + # Scalar min/max + merged_min = float("inf") + merged_max = float("-inf") + + for entry in rank_data: + stats = entry["stats"] + merged_min = min(merged_min, stats["min"]) + merged_max = max(merged_max, stats["max"]) + + merged_stats[filename][key] = { + "min": merged_min, + "max": merged_max, + } + + print(f" Merged {len(merged_stats[filename])} keys") + + # Save merged statistics + print("\n" + "=" * 80) + print("Saving merged statistics...") + print("=" * 80) + + os.makedirs(output_dir, exist_ok=True) + + for filename, stats in merged_stats.items(): + if not stats: + continue + + output_path = os.path.join(output_dir, filename) + import json + with open(output_path, "w") as f: + json.dump(stats, f, indent=2) + print(f"Saved merged {filename} to {output_path}") + + print("\n" + "=" * 80) + print("DP calibration merge completed!") + print("=" * 80) + + +def _same_multiplier_grid(lhs: list[float], rhs: list[float]) -> bool: + if len(lhs) != len(rhs): + return False + return all(abs(float(a) - float(b)) <= 1e-12 for 
a, b in zip(lhs, rhs)) + + +def _select_best_multiplier_from_profiles(profile_entries: list[dict[str, Any]]) -> tuple[float, float]: + multipliers = profile_entries[0]["multipliers"] + base_scale = profile_entries[0].get("base_scale") + total_sse = [0.0 for _ in multipliers] + total_numel = 0 + + for profile in profile_entries: + if not _same_multiplier_grid(multipliers, profile["multipliers"]): + raise ValueError("KV search multiplier grids differ across DP ranks") + if base_scale is not None and profile.get("base_scale") is not None: + if abs(float(base_scale) - float(profile["base_scale"])) > 1e-12: + raise ValueError("KV search base_scale differs across DP ranks") + if len(profile["sse"]) != len(multipliers): + raise ValueError("KV search SSE length does not match multiplier grid") + for i, sse in enumerate(profile["sse"]): + total_sse[i] += float(sse) + total_numel += int(profile.get("numel", 0)) + + if total_numel <= 0: + return 1.0, float("inf") + + global_mse = [sse / total_numel for sse in total_sse] + best_idx = min(range(len(global_mse)), key=lambda i: global_mse[i]) + return float(multipliers[best_idx]), float(global_mse[best_idx]) + + +def _merge_dp_kv_search_payloads( + output_dir: str, + results: list[dict[str, Any]], + kv_granularity: str, +) -> None: + """Merge DP-local KV MSE profiles and save final scale files.""" + print("\n" + "=" * 80) + print("Merging DP KV-cache scale search profiles...") + print("=" * 80) + + activation_stats_path = os.path.join(output_dir, "activation_stats.json") + with open(activation_stats_path, "r", encoding="utf8") as f: + activation_stats = json.load(f) + + payloads = [result["payload"] for result in results] + fp8_max = 448.0 + + if kv_granularity == "per-tensor": + profile_key = "kv_scale_mse_profiles.json" + all_profiles: dict[str, list[dict[str, Any]]] = {} + for rank, payload in enumerate(payloads): + profiles = payload.get(profile_key, {}) + if not profiles: + print(f" Rank {rank}: no {profile_key} found") + 
continue + for stats_key, profile in profiles.items(): + all_profiles.setdefault(stats_key, []).append(profile) + + kv_multipliers = {} + for stats_key, entries in all_profiles.items(): + best_multiplier, best_mse = _select_best_multiplier_from_profiles(entries) + kv_multipliers[stats_key] = best_multiplier + print( + f" {stats_key}: best_multiplier={best_multiplier:.6f}, " + f"global_mse={best_mse:.6e}, dp_parts={len(entries)}" + ) + + multipliers_path = os.path.join(output_dir, "kv_scale_multipliers.json") + with open(multipliers_path, "w", encoding="utf8") as f: + json.dump(kv_multipliers, f, indent=2) + + tuned_kv_scales = {} + for stats_key, multiplier in kv_multipliers.items(): + stats = activation_stats[stats_key] + abs_max = max(abs(stats["min"]), abs(stats["max"])) + base_scale = abs_max / fp8_max * 2.0 if abs_max != 0 else 1e-8 + tuned_scale = base_scale * multiplier + save_key = f"{stats_key.replace('attn.attn', 'attn')}.scale" + tuned_kv_scales[save_key] = tuned_scale + tuned_scales_path = os.path.join(output_dir, "kv_cache_tuned_scales.json") + with open(tuned_scales_path, "w", encoding="utf8") as f: + json.dump(tuned_kv_scales, f, indent=2) + + print(f"Saved DP KV multipliers to {multipliers_path}") + print(f"Saved DP KV tuned scales to {tuned_scales_path}") + + elif kv_granularity == "per-head": + profile_key = "kv_scale_mse_profiles_per_head.json" + all_profiles: dict[str, list[list[dict[str, Any]]]] = {} + for rank, payload in enumerate(payloads): + profiles = payload.get(profile_key, {}) + if not profiles: + print(f" Rank {rank}: no {profile_key} found") + continue + for stats_key, head_profiles in profiles.items(): + all_profiles.setdefault(stats_key, []).append(head_profiles) + + kv_multipliers_perhead = {} + for stats_key, rank_head_profiles in all_profiles.items(): + stats = activation_stats.get(stats_key) + if not stats or not isinstance(stats.get("min"), list): + print(f" {stats_key}: missing per-head activation stats, skipping") + continue 
+ num_heads = len(stats["min"]) + multipliers = [] + for head_idx in range(num_heads): + entries = [ + head_profiles[head_idx] + for head_profiles in rank_head_profiles + if head_idx < len(head_profiles) and head_profiles[head_idx] + ] + if not entries: + multipliers.append(1.0) + continue + best_multiplier, _ = _select_best_multiplier_from_profiles(entries) + multipliers.append(best_multiplier) + kv_multipliers_perhead[stats_key] = multipliers + print( + f" {stats_key}: multipliers min={min(multipliers):.6f} " + f"max={max(multipliers):.6f} over {len(multipliers)} heads" + ) + + multipliers_path = os.path.join(output_dir, "kv_scale_multipliers_per_head.json") + with open(multipliers_path, "w", encoding="utf8") as f: + json.dump(kv_multipliers_perhead, f, indent=2) + + tuned_kv_scales_perhead = {} + for stats_key, multipliers in kv_multipliers_perhead.items(): + stats = activation_stats[stats_key] + min_vals = stats["min"] + max_vals = stats["max"] + for head_idx, multiplier in enumerate(multipliers): + abs_max = max(abs(min_vals[head_idx]), abs(max_vals[head_idx])) + base_scale = abs_max / fp8_max * 2.0 if abs_max != 0 else 1e-8 + tuned_scale = base_scale * multiplier + base_key = stats_key.replace("attn.attn", "attn") + save_key = f"{base_key}.head_{head_idx}.scale" + tuned_kv_scales_perhead[save_key] = tuned_scale + tuned_scales_path = os.path.join(output_dir, "kv_cache_tuned_scales_per_head.json") + with open(tuned_scales_path, "w", encoding="utf8") as f: + json.dump(tuned_kv_scales_perhead, f, indent=2) + + print(f"Saved DP per-head KV multipliers to {multipliers_path}") + print(f"Saved DP per-head KV tuned scales to {tuned_scales_path}") + + else: + print(f"Skipping KV search merge for kv_granularity={kv_granularity}") + + print("\n" + "=" * 80) + print("DP KV-cache scale search merge completed!") + print("=" * 80) diff --git a/configs/Hy3/ptq/fp8/Hy3_vllm_ptq_kv_per_head.yaml b/configs/Hy3/ptq/fp8/Hy3_vllm_ptq_kv_per_head.yaml index 3f4fd4e1..8f0d1afc 100644 
def shard_prompts(
    prompts: list[str],
    dp_rank: int,
    dp_size: int,
) -> list[str]:
    """Return the calibration prompts owned by one data-parallel rank.

    Prompts are dealt out round-robin: rank ``r`` receives ``prompts[r]``,
    ``prompts[r + dp_size]``, ... so the shards of all ranks are disjoint,
    cover every prompt, and differ in length by at most one.

    Args:
        prompts: Full, ordered list of calibration prompts.
        dp_rank: Zero-based index of the calling replica.
        dp_size: Total number of data-parallel replicas.

    Returns:
        The prompts assigned to ``dp_rank``. When ``dp_size`` is 1 the input
        list itself is returned (not a copy) and ``dp_rank`` is ignored.

    Raises:
        ValueError: If ``prompts`` is empty, ``dp_size`` is not positive,
            there are fewer prompts than replicas, or ``dp_rank`` lies
            outside ``[0, dp_size)``.
    """
    if not prompts:
        raise ValueError("No calibration prompts were prepared.")

    # A zero step makes the slice below raise an unhelpful error and a
    # negative step silently walks the list backwards, so reject both here.
    if dp_size < 1:
        raise ValueError(f"dp_size must be a positive integer, got {dp_size}.")

    if dp_size == 1:
        return prompts

    if len(prompts) < dp_size:
        raise ValueError(
            f"Number of prompts ({len(prompts)}) "
            f"is smaller than dp_size ({dp_size})."
        )

    # Without this check a negative rank indexes from the end of the list and
    # a rank >= dp_size duplicates part of another rank's shard, both of which
    # would corrupt the merged calibration statistics without any error.
    if not 0 <= dp_rank < dp_size:
        raise ValueError(
            f"dp_rank ({dp_rank}) must be in the range [0, {dp_size})."
        )

    # The checks above guarantee dp_rank < dp_size <= len(prompts), so the
    # shard always contains at least prompts[dp_rank].
    return prompts[dp_rank::dp_size]
Ray Data-managed DP launches all replicas " + "from one driver process, so this must stay 0.", + ) # MTP (Multi-Token Prediction) configuration parser.add_argument( @@ -259,7 +338,7 @@ def parse_args(): + ", ".join("--" + n.replace("_", "-") for n in missing) ) - return args + return validate_vllm_calibration_dp_args(parser, args) def save_stats_to_json( @@ -295,10 +374,13 @@ def save_stats_to_json( print(f"\n{stats_type.capitalize()} saved to: {output_file}") -def main(): - """Main function to run calibration.""" - args = parse_args() +def run_one_calibration(args, llm=None, return_llm: bool = False): + """Run one calibration worker (single-process or one DP rank). + When ``llm`` is provided, reuse the existing vLLM instance instead of + loading the model again. ``return_llm`` is used by the DP actor to keep the + first-stage instance alive locally for the second-stage KV search. + """ # Verify environment variables are set print(f"VLLM_MOE_COLLECT_STATS: {os.environ.get('VLLM_MOE_COLLECT_STATS')}") print("\nConfiguration:") @@ -306,6 +388,8 @@ def main(): print(f" PTQ Data: {args.ptq_data_path}") print(f" Output Dir: {args.output_dir}") print(f" TP Size: {args.tp_size}") + print(f" DP Size: {args.dp_size}") + print(f" DP Rank: {getattr(args, 'dp_rank', 0)}") print(f" Batch Size: {args.batch_size}") print(f" Num Samples: {args.num_samples}") print(f" Skip Weight Loading: {args.skip_weight_loading}") @@ -322,29 +406,40 @@ def main(): else: print(" MTP Enabled: False") - # Create LLM instance - llm = LLM( - model=args.model_path, - load_format="dummy" if args.skip_weight_loading else "auto", - disable_log_stats=False, - enforce_eager=True, - enable_chunked_prefill=True, - max_num_batched_tokens=16384, - gpu_memory_utilization=0.75, - tensor_parallel_size=args.tp_size, - distributed_executor_backend=args.distributed_executor_backend, - enable_expert_parallel=False, - max_num_seqs=args.batch_size, - max_model_len=args.max_length + 16, - 
speculative_config=speculative_config, - # Force the Triton MoE backend so the AngelSlim fused_moe.py patch - # (which inserts collect_fused_moe_internal_stats hooks inside - # fused_experts_impl) is actually exercised. Without this vLLM may - # auto-select FlashInfer CUTLASS / TRTLLM, which run the entire - # gate_up -> activation -> down_proj pipeline inside a single - # opaque C++ kernel and bypass our Python-level hooks. - moe_backend="triton", - ) + # Environment variables should be set in the shell script for consistency + # For DP calibration, they are passed through Ray runtime_env + # For single DP (dp_size=1), they should be inherited from the shell environment + # Only set VLLM_ALLOW_INSECURE_SERIALIZATION if it's not already set + if "VLLM_ALLOW_INSECURE_SERIALIZATION" not in os.environ: + os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1" + print(f" Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 (not set in environment)") + + if llm is None: + print("\nCreating vLLM instance...") + llm = LLM( + model=args.model_path, + load_format="dummy" if args.skip_weight_loading else "auto", + disable_log_stats=False, + enforce_eager=True, + enable_chunked_prefill=True, + max_num_batched_tokens=16384, + gpu_memory_utilization=0.75, + tensor_parallel_size=args.tp_size, + distributed_executor_backend=args.distributed_executor_backend, + enable_expert_parallel=False, + max_num_seqs=args.batch_size, + max_model_len=args.max_length + 16, + speculative_config=speculative_config, + # Force the Triton MoE backend so the AngelSlim fused_moe.py patch + # (which inserts collect_fused_moe_internal_stats hooks inside + # fused_experts_impl) is actually exercised. Without this vLLM may + # auto-select FlashInfer CUTLASS / TRTLLM, which run the entire + # gate_up -> activation -> down_proj pipeline inside a single + # opaque C++ kernel and bypass our Python-level hooks. 
+ moe_backend="triton", + ) + else: + print("\nReusing existing vLLM instance for this stage.") if args.skip_weight_loading: print("\n" + "!" * 80) @@ -353,6 +448,116 @@ def main(): print("Use --skip-weight-loading flag to enable this mode.") print("!" * 80 + "\n") + if getattr(args, "kv_search_only", False): + print("\n" + "=" * 80) + print("Running KV-cache scale search profile collection only...") + print("=" * 80) + + activation_stats_path = getattr(args, "kv_search_activation_stats_path", None) + if not activation_stats_path: + raise ValueError("kv_search_activation_stats_path is required in kv_search_only mode") + with open(activation_stats_path, "r", encoding="utf8") as f: + activation_stats = json.load(f) + print(f"Loaded merged activation stats from: {activation_stats_path}") + + tokenizer = llm.get_tokenizer() + slim_engine = Engine() + slim_engine.series = "LLM" + from types import SimpleNamespace + + slim_engine.slim_model = SimpleNamespace( + tokenizer=tokenizer, + model=SimpleNamespace(device="cpu"), + ) + dataset = slim_engine.prepare_data( + data_path=args.ptq_data_path, + max_length=args.max_length, + num_samples=args.num_samples, + shuffle=False, + inference_settings=None, + use_audio_in_video=False, + ) + all_prompts = [tokenizer.decode(data["input_ids"][0]) for data in dataset] + search_prompt_pool = all_prompts[: args.search_kv_num_samples] + search_prompts = shard_prompts( + search_prompt_pool, + dp_rank=getattr(args, "dp_rank", 0), + dp_size=args.dp_size, + ) + print( + f"[DP {getattr(args, 'dp_rank', 0)}/{args.dp_size}] " + f"total_prompts={len(all_prompts)}, " + f"global_search_prompts={len(search_prompt_pool)}, " + f"local_search_prompts={len(search_prompts)}" + ) + + os.makedirs(args.output_dir, exist_ok=True) + payload = {} + + if args.kv_granularity == "per-tensor": + print("\nRegistering KV-value capture hooks...") + hook_results = llm.apply_model(setup_kvcache_value_hooks) + for i, result in enumerate(hook_results): + print(f" Worker 
{i}: {result}") + + llm.generate( + search_prompts, + SamplingParams(temperature=0.8, top_p=0.95, max_tokens=1), + ) + + collector = KVMSEProfileCollector( + activation_stats=activation_stats, + min_multiplier=args.search_kv_min_multiplier, + max_multiplier=args.search_kv_max_multiplier, + num_steps=args.search_kv_num_steps, + ) + profile_results = llm.apply_model(collector) + kv_profiles = get_kv_mse_profile_results(profile_results) + llm.apply_model(remove_kv_scale_search_hooks) + + profiles_path = os.path.join(args.output_dir, "kv_scale_mse_profiles.json") + with open(profiles_path, "w", encoding="utf8") as f: + json.dump(kv_profiles, f, indent=2) + print(f"KV-cache local MSE profiles saved to: {profiles_path}") + payload["kv_scale_mse_profiles.json"] = kv_profiles + + elif args.kv_granularity == "per-head": + print("\nRegistering per-head KV-value capture hooks...") + hook_results = llm.apply_model(setup_kvcache_perhead_value_hooks) + for i, result in enumerate(hook_results): + print(f" Worker {i}: {result}") + + llm.generate( + search_prompts, + SamplingParams(temperature=0.8, top_p=0.95, max_tokens=1), + ) + + collector_ph = KVPerHeadMSEProfileCollector( + activation_stats=activation_stats, + min_multiplier=args.search_kv_min_multiplier, + max_multiplier=args.search_kv_max_multiplier, + num_steps=args.search_kv_num_steps, + ) + profile_results_ph = llm.apply_model(collector_ph) + kv_profiles_perhead = get_kv_mse_profile_results_perhead(profile_results_ph) + llm.apply_model(remove_kvcache_perhead_value_hooks) + + profiles_ph_path = os.path.join( + args.output_dir, + "kv_scale_mse_profiles_per_head.json", + ) + with open(profiles_ph_path, "w", encoding="utf8") as f: + json.dump(kv_profiles_perhead, f, indent=2) + print(f"Per-head KV-cache local MSE profiles saved to: {profiles_ph_path}") + payload["kv_scale_mse_profiles_per_head.json"] = kv_profiles_perhead + + else: + print("KV search skipped because kv_granularity=none") + + if return_llm: + return payload, 
llm + return payload + # Setup activation hooks on all workers # kv_granularity controls which KV hooks are registered alongside Linear hooks: # 'none' -> no KV hooks @@ -413,9 +618,13 @@ def main(): slim_engine = Engine() slim_engine.slim_model = llm slim_engine.series = "LLM" - slim_engine.slim_model.tokenizer = tokenizer - slim_engine.slim_model.model = llm - slim_engine.slim_model.model.device = "cpu" + from types import SimpleNamespace + slim_engine.slim_model = SimpleNamespace( + tokenizer=tokenizer, + model=SimpleNamespace( + device="cpu", + ), + ) dataset = slim_engine.prepare_data( data_path=args.ptq_data_path, max_length=args.max_length, @@ -425,8 +634,23 @@ def main(): use_audio_in_video=False, ) - prompts = [tokenizer.decode(data["input_ids"][0]) for data in dataset] - print(f"Loaded {len(prompts)} prompts from dataset") + all_prompts = [tokenizer.decode(data["input_ids"][0]) for data in dataset] + print(f"Loaded {len(all_prompts)} prompts from dataset before DP sharding") + + # Apply DP sharding + prompts = shard_prompts( + all_prompts, + dp_rank=getattr(args, "dp_rank", 0), + dp_size=args.dp_size, + ) + + # prompts = prompts[:10] + # prompts = prompts[-10:] + print( + f"[DP {getattr(args, 'dp_rank', 0)}/{args.dp_size}] " + f"total_prompts={len(all_prompts)}, " + f"local_prompts={len(prompts)}" + ) # Create sampling params (fixed values for calibration) # When MTP is enabled, we need to generate more tokens to trigger @@ -496,11 +720,15 @@ def main(): # Create output directory os.makedirs(args.output_dir, exist_ok=True) + payload = {} + # Save activation statistics stats_list = llm.apply_model(get_activation_stats) save_stats_to_json( stats_list, args.output_dir, "activation_stats.json", stats_type="activation statistics" ) + if stats_list and stats_list[0] is not None: + payload["activation_stats.json"] = stats_list[0] # Save MoE expert statistics moe_stats_dict = llm.apply_model(get_moe_stats) @@ -510,6 +738,8 @@ def main(): "moe_expert_stats.json", 
stats_type="MoE expert statistics", ) + if moe_stats_dict and moe_stats_dict[0] is not None: + payload["moe_expert_stats.json"] = moe_stats_dict[0] # Save MTP draft model statistics (if MTP is enabled) if args.enable_mtp: @@ -522,6 +752,10 @@ def main(): "mtp_activation_stats.json", stats_type="MTP activation statistics", ) + if mtp_stats_list: + mtp_stats = next((r for r in mtp_stats_list if r), None) + if mtp_stats is not None: + payload["mtp_activation_stats.json"] = mtp_stats mtp_moe_stats_dict = llm.llm_engine.collective_rpc( lambda w: _apply_on_draft_model(w, get_mtp_moe_stats) @@ -532,6 +766,10 @@ def main(): "mtp_moe_expert_stats.json", stats_type="MTP MoE expert statistics", ) + if mtp_moe_stats_dict: + mtp_moe_stats = next((r for r in mtp_moe_stats_dict if r), None) + if mtp_moe_stats is not None: + payload["mtp_moe_expert_stats.json"] = mtp_moe_stats # --------------------------------------------------------------- # Per-head KV-cache stats for the MTP draft model. @@ -575,6 +813,7 @@ def main(): f"[MTP] Merged {len(mtp_ph_stats)} per-head KV-cache entries " f"into {mtp_act_path} (per-tensor scalars overwritten)." ) + payload["mtp_activation_stats.json"] = merged_mtp # Clean up the per-head hooks on the draft model. 
def main():
    """Parse the command line and dispatch calibration.

    A single replica (``dp_size == 1``) runs in this process as rank 0. Any
    other value hands the parsed arguments and the per-rank worker function
    to ``run_vllm_calibration_with_dp``.
    """
    cli_args = parse_args()

    if cli_args.dp_size != 1:
        # Multi-replica run: the DP launcher invokes the worker per rank.
        run_vllm_calibration_with_dp(cli_args, run_one_calibration)
        return

    # Single replica: this process is the only rank.
    cli_args.dp_rank = 0
    run_one_calibration(cli_args)


if __name__ == "__main__":
    main()