Repeatkv transform#997
Conversation
…VLMs. Based on PR quic#625. Addressed most of the comments made on the previous PR. Repeat check is done on a subset of models during CI, primarily due to difference in configs of such models. Signed-off-by: Dhiraj Kumar Sah <dhirajku@qti.qualcomm.com>
…ng with changes made for the new transforms. TODO: Check for the ONNX directory path name being different. Check if the list of classes for mapping covers all the models that we support. Signed-off-by: Dhiraj Kumar Sah <dhirajku@qti.qualcomm.com>
…oder Wrappers were added to string mapping list to enable dummy model export for CI. Changes were made to prevent multiple application of ReplicateKVTransform if done in either Encoder or Decoder Wrapper already. Modeling files updated to access config in EncoderWrapper as well. Infra added for causalLM and VLM checks for repeatKV setup CI tests. CausalLM script APIRunner instantiation moved to allow updated input shapes to be made. Similarly commented export in VLM script since compile will call it with updated changes already. TODO: Confirm the changes that were made for DeepSeekV3 model for RepeatKV, currently they were removed for a generic approach. Signed-off-by: Dhiraj Kumar Sah <dhirajku@qti.qualcomm.com>
Made changes to allow generic name based transformation of heads (num_attention_heads, n_heads, n_head etc). Minor edits and utils created for this task. Signed-off-by: Dhiraj Kumar Sah <dhirajku@qti.qualcomm.com>
Edited the changes as suggested by quic-mamta. Signed-off-by: Dhiraj Kumar Sah <dhirajku@qti.qualcomm.com>
|
nit: should we rename this to |
| architectures = getattr(model_config, "architectures", None) or [] | ||
| is_deepseek_v3 = "DeepseekV3ForCausalLM" in architectures | ||
| if qaic_config: | ||
| if is_deepseek_v3 and (qaic_config.get("blocking_mode", None) == "h"): |
There was a problem hiding this comment.
nit: for models w/mla and single kv heads, we do not want to replicate, ex: deepseekv3 is this what is being done here? not clear.
| self.model.config.text_config.use_cache = True | ||
| else: | ||
| self.model.config.use_cache = True | ||
| # self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs) |
There was a problem hiding this comment.
nit: remove commented code.
| if cls._is_mla_attention(attn): | ||
| # Legacy MLA support: KV compression projection is organized as | ||
| # [kv_heads, kv_lora_rank + qk_rope_head_dim, hidden_size]. | ||
| mla_orig_kv_heads = 1 |
There was a problem hiding this comment.
nit: remove magic numbers, get it from the constants file
| # Generic config key aliases used across model families. | ||
| ATTENTION_HEAD_CONFIG_KEYS = ("num_attention_heads", "n_head", "n_heads", "num_heads") | ||
| KV_HEAD_CONFIG_KEYS = ("num_key_value_heads", "n_kv_heads", "num_kv_heads", "effective_n_kv_heads") | ||
| HIDDEN_SIZE_CONFIG_KEYS = ("hidden_size", "n_embd", "d_model") |
There was a problem hiding this comment.
does this cover all the models we support as of today?
| "meta-llama/Llama-3.2-1B", | ||
| # "unsloth/gemma-2b", | ||
| # "unsloth/gemma-2-2b", | ||
| # "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", |
There was a problem hiding this comment.
why is this commented? any known failures w/awq, gemma, mistral models?
|
@quic-dhirajku also added detailed pr desp. about the design and changes added and test plan validated. thanks |
| qaic_config["head_block_size"] = qaic_config.get("head_block_size", num_devices) | ||
| num_kv_heads_repeat = qaic_config.get("num_kv_heads_repeat", 1) | ||
| architectures = getattr(model_config, "architectures", None) or [] | ||
| is_deepseek_v3 = "DeepseekV3ForCausalLM" in architectures |
There was a problem hiding this comment.
Please remove the lines 459-463, not needed.
No description provided.