Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,12 @@
from .search import (
KVCachePerHeadValueHook,
KVCacheValueHook,
KVMSEProfileCollector,
KVPerHeadMSEProfileCollector,
KVScaleSearcher,
KVScaleSearcherPerHead,
get_kv_mse_profile_results,
get_kv_mse_profile_results_perhead,
get_kv_scale_search_results,
get_kv_scale_search_results_perhead,
remove_kv_scale_search_hooks,
Expand Down Expand Up @@ -88,12 +92,16 @@
"KVCacheValueHook",
"setup_kvcache_value_hooks",
"KVScaleSearcher",
"KVMSEProfileCollector",
"get_kv_scale_search_results",
"get_kv_mse_profile_results",
"remove_kv_scale_search_hooks",
# KV scale search (per-head)
"KVCachePerHeadValueHook",
"setup_kvcache_perhead_value_hooks",
"remove_kvcache_perhead_value_hooks",
"KVScaleSearcherPerHead",
"KVPerHeadMSEProfileCollector",
"get_kv_scale_search_results_perhead",
"get_kv_mse_profile_results_perhead",
]
274 changes: 274 additions & 0 deletions angelslim/compressor/quant/core/vllm_calibrate_utils/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,19 @@
# Per-tensor scale search
"KVCacheValueHook",
"KVScaleSearcher",
"KVMSEProfileCollector",
"setup_kvcache_value_hooks",
"get_kv_scale_search_results",
"get_kv_mse_profile_results",
"remove_kv_scale_search_hooks",
# Per-head scale search
"KVCachePerHeadValueHook",
"setup_kvcache_perhead_value_hooks",
"remove_kvcache_perhead_value_hooks",
"KVScaleSearcherPerHead",
"KVPerHeadMSEProfileCollector",
"get_kv_scale_search_results_perhead",
"get_kv_mse_profile_results_perhead",
]


Expand Down Expand Up @@ -410,6 +414,151 @@ def get_kv_scale_search_results(results_list: list) -> dict:
return first


def _make_log_uniform_multipliers(
min_multiplier: float,
max_multiplier: float,
num_steps: int,
) -> list[float]:
"""Build the log-uniform multiplier grid used by KV scale search."""
import math

if num_steps <= 1:
return [float(min_multiplier)]

log_min = math.log(min_multiplier)
log_max = math.log(max_multiplier)
return [
math.exp(log_min + (log_max - log_min) * i / (num_steps - 1))
for i in range(num_steps)
]


def _compute_kv_mse_profile_flat(
    flat: torch.Tensor,
    base_scale: float,
    min_multiplier: float,
    max_multiplier: float,
    num_steps: int,
) -> dict:
    """Compute an additive SSE/numel profile for DP KV search aggregation.

    For every multiplier on the log-uniform grid, ``flat`` is divided by
    ``base_scale * multiplier``, clamped to the ``torch.float8_e4m3fn``
    range, cast to FP8 and back to float32, and rescaled. The sum of squared
    differences against ``flat`` is accumulated in float64.

    Args:
        flat: Values to profile. The collectors in this file pass a
            flattened float32 tensor; it may live on CPU or CUDA.
        base_scale: Scale that the multipliers are applied to.
        min_multiplier: First multiplier of the grid.
        max_multiplier: Last multiplier of the grid.
        num_steps: Number of grid points.

    Returns:
        A dict with keys ``"multipliers"`` (the grid), ``"sse"`` (one float
        per multiplier), ``"numel"`` (element count of ``flat``) and
        ``"base_scale"`` (echoed back unchanged). The SSE is a raw sum, not
        a mean. For an empty ``flat`` every SSE entry is ``0.0``.
    """
    multipliers = _make_log_uniform_multipliers(
        min_multiplier,
        max_multiplier,
        num_steps,
    )
    numel = int(flat.numel())
    if numel == 0:
        return {
            "multipliers": multipliers,
            "sse": [0.0] * len(multipliers),
            "numel": 0,
            "base_scale": base_scale,
        }

    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_min = fp8_info.min
    fp8_max = fp8_info.max

    # Accumulate every step's SSE on ``flat``'s own device and read the
    # results back once at the end; on CUDA this avoids a host sync per step.
    sse_vals = torch.empty(len(multipliers), dtype=torch.float64, device=flat.device)
    for i, multiplier in enumerate(multipliers):
        scale = base_scale * multiplier
        q_fp8 = (
            (flat / scale).clamp(fp8_min, fp8_max).to(torch.float8_e4m3fn).to(torch.float32)
        )
        # Upcast the residual once so the squared sum is done in float64.
        err = (flat - q_fp8 * scale).double()
        sse_vals[i] = (err * err).sum()
    sse = sse_vals.cpu().tolist()
    del sse_vals

    return {
        "multipliers": multipliers,
        "sse": sse,
        "numel": numel,
        "base_scale": base_scale,
    }


class KVMSEProfileCollector:
    """Callable that builds one FP8 error profile per KV cache tensor.

    An instance is called with a model. For every ``(layer, kv_slot)`` pair
    that has collected values and a matching ``"<layer>.<kv_slot>_cache"``
    entry in ``activation_stats``, it derives a base scale from the recorded
    ``"min"``/``"max"`` and delegates to ``_compute_kv_mse_profile_flat``.
    The resulting profiles are intended for DP aggregation.
    """

    def __init__(
        self,
        activation_stats: dict,
        min_multiplier: float = 0.8,
        max_multiplier: float = 16.0,
        num_steps: int = 100,
    ):
        """Store the calibration stats and the multiplier-grid settings."""
        self.activation_stats = activation_stats
        self.min_multiplier = min_multiplier
        self.max_multiplier = max_multiplier
        self.num_steps = num_steps

    def __call__(self, model):
        """Return ``{stats_key: profile}`` for the values collected on ``model``.

        Returns an empty dict (after printing a warning) when no values were
        collected. Keys missing from ``activation_stats`` are reported and
        skipped; slots with no tensors are skipped silently.
        """
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        collected = _get_kv_search_values(model)
        if not collected:
            print(
                "[KVMSEProfileCollector] WARNING: No kv values collected. "
                "Did you call setup_kvcache_value_hooks before inference?"
            )
            return {}

        # Run the profiling on the current CUDA device when one is available.
        on_gpu = torch.cuda.is_available()
        if on_gpu:
            device = torch.device("cuda", torch.cuda.current_device())
        else:
            device = torch.device("cpu")

        result = {}
        for layer_name, slots in collected.items():
            for kv_slot, chunks in slots.items():
                stats_key = f"{layer_name}.{kv_slot}_cache"
                if stats_key not in self.activation_stats:
                    print(
                        f"[KVMSEProfileCollector] WARNING: {stats_key} not found in "
                        "activation_stats, skipping."
                    )
                    continue
                if not chunks:
                    continue

                entry = self.activation_stats[stats_key]
                low_mag = abs(entry["min"])
                high_mag = abs(entry["max"])
                abs_max = max(low_mag, high_mag)
                # An all-zero range gets a tiny positive scale instead of 0.
                if abs_max != 0:
                    base_scale = abs_max / fp8_max * 2.0
                else:
                    base_scale = 1e-8

                gathered = torch.cat([chunk.reshape(-1).float() for chunk in chunks])
                flat = gathered.to(device, non_blocking=True) if on_gpu else gathered
                profile = _compute_kv_mse_profile_flat(
                    flat=flat,
                    base_scale=base_scale,
                    min_multiplier=self.min_multiplier,
                    max_multiplier=self.max_multiplier,
                    num_steps=self.num_steps,
                )
                if on_gpu:
                    # Drop the device copy before handling the next tensor.
                    del flat
                result[stats_key] = profile
                print(
                    f"[KVMSEProfileCollector] {stats_key}: "
                    f"numel={profile['numel']} base_scale={base_scale:.6f}"
                )

        return result


def get_kv_mse_profile_results(results_list: list) -> dict:
    """Return the first worker's per-tensor MSE profile dict.

    ``results_list`` is the list returned by ``llm.apply_model``. An empty
    or missing list, or a falsy first entry, yields an empty dict.
    """
    first = results_list[0] if results_list else None
    return first if first else {}


def remove_kv_scale_search_hooks(model):
"""
Clean up kv-value hooks after search. Pass to ``llm.apply_model``.
Expand Down Expand Up @@ -717,6 +866,102 @@ def __call__(self, model):
return multipliers


class KVPerHeadMSEProfileCollector:
    """Collect local per-head KV MSE profiles for DP aggregation.

    An instance is called with a model. For every collected
    ``(layer, kv_slot)`` pair that this rank is responsible for, it builds
    one profile per local head via ``_compute_kv_mse_profile_flat`` and
    stores it under the head's global index.

    ``activation_stats`` maps ``"<layer>.<kv_slot>_cache"`` to a dict whose
    ``"min"`` and ``"max"`` entries are indexable by global head index
    (a list or a tensor).
    """

    def __init__(
        self,
        activation_stats: dict,
        min_multiplier: float = 0.8,
        max_multiplier: float = 16.0,
        num_steps: int = 100,
    ):
        """Store the calibration stats and the multiplier-grid settings."""
        self.activation_stats = activation_stats
        self.min_multiplier = min_multiplier
        self.max_multiplier = max_multiplier
        self.num_steps = num_steps

    def _infer_num_kv_heads(self) -> int:
        """Return the head count implied by the first list/tensor ``"min"`` entry.

        Falls back to 1 when no such entry exists or when it is empty.
        """
        for stats in self.activation_stats.values():
            min_vals = stats.get("min")
            if isinstance(min_vals, list):
                return max(len(min_vals), 1)
            if isinstance(min_vals, torch.Tensor):
                return max(min_vals.numel(), 1)
        return 1

    def __call__(self, model):
        """Return ``{stats_key: {global_head_idx: profile}}`` for this rank.

        Returns an empty dict (after printing a warning) when no per-head
        values were collected. Slots that do not match this rank's role,
        keys missing from ``activation_stats`` and slots without tensors are
        skipped silently.
        """
        kv_values = _get_kv_perhead_search_values(model)
        if not kv_values:
            print(
                "[KVPerHeadMSEProfileCollector] WARNING: No per-head kv values collected. "
                "Did you call setup_kvcache_perhead_value_hooks before inference?"
            )
            return {}

        use_gpu = torch.cuda.is_available()
        search_device = (
            torch.device("cuda", torch.cuda.current_device()) if use_gpu else torch.device("cpu")
        )
        fp8_max = torch.finfo(torch.float8_e4m3fn).max

        rank, world_size = _get_dist_info()
        num_kv_heads_total = self._infer_num_kv_heads()

        # NOTE(review): the second element of the layout tuple is unused here;
        # its meaning is defined by _compute_perhead_layout, not visible in
        # this block.
        role, _, global_head_offset, replication = _compute_perhead_layout(
            rank,
            world_size,
            num_kv_heads_total,
        )
        profiles: dict[str, dict] = {}

        for layer_name, tensors_dict in kv_values.items():
            for kv_slot, tensors in tensors_dict.items():
                # A rank whose role is not "both" only profiles its own slot.
                if role != "both" and kv_slot != role:
                    continue

                stats_key = f"{layer_name}.{kv_slot}_cache"
                if stats_key not in self.activation_stats or not tensors:
                    continue

                stats = self.activation_stats[stats_key]
                min_vals = stats["min"]
                max_vals = stats["max"]
                # Dim 0 is iterated as the local head index below; the
                # collected chunks are joined along dim 1.
                stacked = torch.cat(tensors, dim=1)
                head_profiles = profiles.setdefault(stats_key, {})

                for local_head_idx in range(stacked.shape[0]):
                    global_head_idx = global_head_offset + local_head_idx
                    if global_head_idx >= num_kv_heads_total:
                        # Indices only grow, so every later head is out of
                        # range as well.
                        break

                    # Coerce to a plain float: with tensor-valued stats the
                    # magnitudes are 0-dim tensors, which would otherwise be
                    # passed as ``base_scale`` and stored in the returned
                    # profile.
                    abs_max = float(
                        max(
                            abs(min_vals[global_head_idx]),
                            abs(max_vals[global_head_idx]),
                        )
                    )
                    # An all-zero range gets a tiny positive scale instead of 0.
                    base_scale = abs_max / fp8_max * 2.0 if abs_max != 0 else 1e-8
                    head_flat = stacked[local_head_idx].reshape(-1).float()
                    flat = head_flat.to(search_device, non_blocking=True) if use_gpu else head_flat
                    profile = _compute_kv_mse_profile_flat(
                        flat=flat,
                        base_scale=base_scale,
                        min_multiplier=self.min_multiplier,
                        max_multiplier=self.max_multiplier,
                        num_steps=self.num_steps,
                    )
                    if use_gpu:
                        # Drop the device copy before handling the next head.
                        del flat
                    head_profiles[global_head_idx] = profile

        print(
            f"[KVPerHeadMSEProfileCollector] rank={rank}/{world_size}, role={role}, "
            f"head_offset={global_head_offset}, H_total={num_kv_heads_total}, "
            f"replication={replication}"
        )
        return profiles


def get_kv_scale_search_results_perhead(results_list: list) -> dict:
"""
Merge per-head multiplier dicts from all TP workers.
Expand Down Expand Up @@ -749,3 +994,32 @@ def get_kv_scale_search_results_perhead(results_list: list) -> dict:
final[stats_key] = [head_dict[i] for i in sorted_indices]

return final


def get_kv_mse_profile_results_perhead(results_list: list) -> dict:
    """
    Merge per-head MSE profile dicts from all TP workers.

    Each worker's result is ``{stats_key: {global_head_idx: profile}}``.
    The returned format is ``{stats_key: [profile_head0, profile_head1, ...]}``.
    Missing heads are kept as ``None`` placeholders so the list index always
    matches the global KV head index during cross-DP aggregation.

    Head indices are normalised with ``int()`` while merging, so keys such as
    ``"3"`` land in the same slot as ``3``. When several workers report the
    same head, the later worker's profile wins. Falsy worker results are
    ignored, and an empty or missing ``results_list`` yields ``{}``.
    """
    if not results_list:
        return {}

    merged: dict = {}
    for worker_result in results_list:
        if not worker_result:
            continue
        for stats_key, head_profiles in worker_result.items():
            per_head = merged.setdefault(stats_key, {})
            for head_idx, profile in (head_profiles or {}).items():
                # Store under an int key so the dense lookup below finds it.
                per_head[int(head_idx)] = profile

    final: dict = {}
    for stats_key, per_head in merged.items():
        if not per_head:
            final[stats_key] = []
            continue
        head_count = max(per_head) + 1
        final[stats_key] = [per_head.get(i) for i in range(head_count)]
    return final
2 changes: 2 additions & 0 deletions angelslim/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,6 @@
from .utils import print_with_rank # noqa: F401
from .utils import rank0_print # noqa: F401
from .utils import set_op_by_name # noqa: F401
from .vllm_calibration_dp import run_vllm_calibration_with_dp # noqa: F401
from .vllm_calibration_dp import validate_vllm_calibration_dp_args # noqa: F401
from .zero3_io import * # noqa: F401 F403
Loading
Loading