Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 30 additions & 9 deletions QEfficient/base/modeling_qeff.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
self.model = model
self.config = model.config
self.hash_params = create_model_params(self, **kwargs)
self.hash_params["num_kv_heads_repeat"] = kwargs.get("num_kv_heads_repeat", 1)
self.onnx_path: Optional[str] = None
self.qpc_path: Optional[str] = None
self.qpc_session: Optional[QAICInferenceSession] = None
Expand Down Expand Up @@ -440,23 +441,43 @@ def transform(
**compiler_options,
):
# Apply the transformations that are dependent on compilation parameters
def _transform_tracking_root(module: torch.nn.Module) -> torch.nn.Module:
"""
Use the shared wrapped model as transform-tracking root when available.
This lets encoder/decoder wrappers coordinate one-time transforms.
"""
wrapped = getattr(module, "model", None)
return wrapped if isinstance(wrapped, torch.nn.Module) else module

qaic_config = qaic_config if qaic_config else getattr(self.model, "qaic_config", None)

model_config = getattr(self.model, "config", None) or getattr(self.model.model, "config", None)
model_config = getattr(self.model, "config", None) or getattr(
getattr(self.model, "model", None), "config", None
)

if model_config:
if "DeepseekV3ForCausalLM" in (getattr(model_config, "architectures", None) or []):
if qaic_config:
if qaic_config.get("blocking_mode", None) == "h":
qaic_config["head_block_size"] = qaic_config.get("head_block_size", num_devices)
num_kv_heads_repeat = qaic_config.get("num_kv_heads_repeat", 1)
architectures = getattr(model_config, "architectures", None) or []
is_deepseek_v3 = "DeepseekV3ForCausalLM" in architectures
Copy link
Copy Markdown
Contributor

@quic-mamta quic-mamta May 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove lines 459-463; they are not needed.

if qaic_config:
if is_deepseek_v3 and (qaic_config.get("blocking_mode", None) == "h"):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: For models with MLA and a single KV head (e.g., DeepseekV3), we do not want to replicate the KV heads. Is that what is being done here? It is not clear.

qaic_config["head_block_size"] = qaic_config.get("head_block_size", num_devices)
Comment thread
quic-dhirajku marked this conversation as resolved.
num_kv_heads_repeat = qaic_config.get("num_kv_heads_repeat", 1)
transform_root = _transform_tracking_root(self.model)
applied_transforms = getattr(transform_root, "_qeff_runtime_transforms_applied", set())

if ReplicateKVHeadTransform.__name__ in applied_transforms:
replicate_kv_transformed = False
logger.warning("Skipping RepeatKVTransform: already applied on this model instance.")
else:
self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(
self.model, num_kv_heads_repeat
self.model,
num_kv_heads_repeat=num_kv_heads_repeat,
)
if replicate_kv_transformed:
self.hash_params["config"] = self.model.config.to_diff_dict()

applied_transforms.add(ReplicateKVHeadTransform.__name__)
setattr(transform_root, "_qeff_runtime_transforms_applied", applied_transforms)
if replicate_kv_transformed:
self.hash_params["config"] = self.model.config.to_diff_dict()
blocking_config = build_transformer_blocking_config_for_transform(
model_config,
ctx_len=ctx_len,
Expand Down
1 change: 1 addition & 0 deletions QEfficient/transformers/models/gemma3/modeling_gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,7 @@ class QEffGemma3EncoderWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
self.config = self.model.config
self.model.vision_model = self.model.vision_tower

def get_submodules_for_export(self) -> Type[nn.Module]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class QEffInternEncoderWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
self.config = self.model.config

def get_submodules_for_export(self) -> Type[nn.Module]:
"""
Expand Down
1 change: 1 addition & 0 deletions QEfficient/transformers/models/llama4/modeling_llama4.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,7 @@ class QEffLlama4EncoderWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
self.config = self.model.config

def get_submodules_for_export(self) -> Type[nn.Module]:
"""
Expand Down
1 change: 1 addition & 0 deletions QEfficient/transformers/models/llava/modeling_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(self, model):
super().__init__()
self.model = model
self.model.vision_model = self.model.vision_tower
self.config = self.model.config

def get_submodules_for_export(self) -> Type[nn.Module]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(self, model):
super().__init__()
self.model = model
self.model.vision_model = self.model.vision_tower
self.config = self.model.config

def get_submodules_for_export(self) -> Type[nn.Module]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ class QEFFMistral3EncoderWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
self.config = self.model.config
self.model.vision_model = self.model.vision_tower

def get_submodules_for_export(self) -> Type[nn.Module]:
Expand Down
19 changes: 18 additions & 1 deletion QEfficient/transformers/models/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -1238,6 +1238,7 @@ def __init__(
self.ccl_enabled = qaic_config.get("ccl_enabled", False)
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None
self.input_shapes, self.output_names = None, None
# self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs)
# ---Sampling---
# Note: SamplerTransform should be applied after all other transforms
# are done. The role of the sampler is to just add nodes at the output of the
Expand Down Expand Up @@ -1273,6 +1274,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Option

kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})

num_kv_heads_repeat = kwargs.pop("num_kv_heads_repeat", 1)
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)

kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {})
Expand All @@ -1281,6 +1283,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Option
model,
pretrained_model_name_or_path=pretrained_model_name_or_path,
qaic_config=qaic_config,
num_kv_heads_repeat=num_kv_heads_repeat,
**kwargs,
)

Expand Down Expand Up @@ -1371,7 +1374,12 @@ def export(
if prefill_only and prefill_seq_len > 1:
offload_pt_weights = False # to keep weight for decode onnx
else:
offload_pt_weights = kwargs.get("offload_pt_weights", True)
num_kv_heads_repeat = (
(self.lang_model.model.qaic_config or {}).get("num_kv_heads_repeat", 1)
if hasattr(self.lang_model.model, "qaic_config")
else 1
)
offload_pt_weights = kwargs.get("offload_pt_weights", num_kv_heads_repeat <= 1)

if not skip_lang:
self.lang_model.export(
Expand Down Expand Up @@ -2037,6 +2045,7 @@ def __init__(
self.model.config.text_config.use_cache = True
else:
self.model.config.use_cache = True
# self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Remove the commented-out code.

self.hash_params["qeff_auto_class"] = self.__class__.__name__
self.ccl_enabled = False
if qaic_config:
Expand Down Expand Up @@ -2086,6 +2095,7 @@ def from_pretrained(
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
config._attn_implementation = "eager"
config.vision_config.use_flash_attn = "false"
num_kv_heads_repeat = kwargs.pop("num_kv_heads_repeat", 1)
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, config, *args, **kwargs)

kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {})
Expand All @@ -2094,6 +2104,7 @@ def from_pretrained(
model,
pretrained_model_name_or_path=pretrained_model_name_or_path,
qaic_config=qaic_config,
num_kv_heads_repeat=num_kv_heads_repeat,
**kwargs,
)

Expand Down Expand Up @@ -2698,6 +2709,7 @@ def from_pretrained(
logger.warning("Updating low_cpu_mem_usage=False")

kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
num_kv_heads_repeat = kwargs.pop("num_kv_heads_repeat", 1)
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)

kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {})
Expand All @@ -2708,6 +2720,7 @@ def from_pretrained(
continuous_batching=continuous_batching,
pretrained_model_name_or_path=pretrained_model_name_or_path,
qaic_config=qaic_config,
num_kv_heads_repeat=num_kv_heads_repeat,
**kwargs,
)

Expand Down Expand Up @@ -2867,6 +2880,7 @@ def __init__(
setattr(self.model, "mla_absorption", mla_absorption)
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None
self.hash_params["max_seq_len_cached"] = max_seq_len_cached
# self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs)

# ---Sampling---
# Note: SamplerTransform should be applied after all other transforms
Expand Down Expand Up @@ -2950,6 +2964,7 @@ def from_pretrained(
kv_offload = kwargs.pop("kv_offload", None)

kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
num_kv_heads_repeat = kwargs.pop("num_kv_heads_repeat", 1)
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
if qaic_config is not None:
qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path
Expand All @@ -2963,6 +2978,7 @@ def from_pretrained(
pretrained_model_name_or_path=pretrained_model_name_or_path,
qaic_config=qaic_config,
continuous_batching=continuous_batching,
num_kv_heads_repeat=num_kv_heads_repeat,
**kwargs,
)
return cls(
Expand All @@ -2971,6 +2987,7 @@ def from_pretrained(
qaic_config=qaic_config,
pretrained_model_name_or_path=pretrained_model_name_or_path,
max_seq_len_cached=max_seq_len_cached,
num_kv_heads_repeat=num_kv_heads_repeat,
**kwargs,
)

Expand Down
1 change: 1 addition & 0 deletions QEfficient/transformers/models/molmo/modeling_molmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,7 @@ class QEffMolmoEncoderWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
self.config = self.model.config

def get_submodules_for_export(self) -> Type[nn.Module]:
"""
Expand Down
Loading
Loading