From 43cdf0dad2287734502e34c67385d8a1b57f37d1 Mon Sep 17 00:00:00 2001 From: quic-xiyushi Date: Mon, 18 May 2026 10:28:45 -0700 Subject: [PATCH] Add user_vision_size in VLM get_specializations for chunked embedding in vLLM v1 Signed-off-by: quic-xiyushi --- .../models/gemma3/modeling_gemma3.py | 21 ++++++++++++------- .../models/internvl/modeling_internvl.py | 7 ++++++- .../models/llama4/modeling_llama4.py | 7 ++++++- .../models/llava/modeling_llava.py | 7 ++++++- .../models/llava_next/modeling_llava_next.py | 7 ++++++- .../models/mistral3/modeling_mistral3.py | 11 +++++++--- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 15 ++++++------- .../models/qwen3_vl/modeling_qwen3_vl.py | 15 ++++++------- .../qwen3_vl_moe/modeling_qwen3_vl_moe.py | 15 ++++++------- 9 files changed, 69 insertions(+), 36 deletions(-) diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py index 8fb8cdbdda..1b28c38fc3 100644 --- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -731,7 +731,12 @@ def get_specializations( elif img_size is None: img_size = 896 # FIXME based on gemma3 Image size logger.warning("Setting img_size to be 336, as it was neither passed nor found in vision_config") - mm_tokens_per_image = getattr(self.config, "mm_tokens_per_image", 256) + user_vision_size = compiler_options.pop("vision_size", None) + if user_vision_size: + assert user_vision_size < ctx_len, "vision_size must be less than ctx_len" + vision_size = user_vision_size + else: + vision_size = getattr(self.config, "mm_tokens_per_image", 256) vision = [ { @@ -752,7 +757,7 @@ def get_specializations( "comp_ctx_lengths": comp_ctx_lengths_prefill[i], "sliding_window": self.language_model.config.sliding_window, "img_size": img_size, - "mm_tokens_per_image": mm_tokens_per_image, + "vision_size": vision_size, "vision_batch_size": batch_size, } if continuous_batching: @@ -771,7 +776,7 @@ def get_specializations( "comp_ctx_lengths": comp_ctx_lengths_decode[i], "sliding_window": self.language_model.config.sliding_window, "img_size": img_size, - "mm_tokens_per_image": mm_tokens_per_image, + "vision_size": vision_size, "vision_batch_size": batch_size, } if continuous_batching: @@ -787,7 +792,7 @@ def get_specializations( "ctx_len": ctx_len, "sliding_window": self.language_model.config.sliding_window, "img_size": img_size, - "mm_tokens_per_image": mm_tokens_per_image, + "vision_size": vision_size, "vision_batch_size": batch_size, } if continuous_batching: @@ -803,7 +808,7 @@ def get_specializations( "ctx_len": ctx_len, "sliding_window": self.language_model.config.sliding_window, "img_size": img_size, - "mm_tokens_per_image": mm_tokens_per_image, + "vision_size": vision_size, "vision_batch_size": batch_size, } if continuous_batching: @@ -829,7 +834,7 @@ def get_onnx_dynamic_axes( lang_dynamic_axes = {} lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"} lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"} - lang_dynamic_axes["vision_embeds"] = {0: "vision_batch_size", 1: "mm_tokens_per_image"} + lang_dynamic_axes["vision_embeds"] = {0: "vision_batch_size", 1: "vision_size"} if continuous_batching: lang_dynamic_axes["batch_index"] = {0: "batch_size"} vision_dynamic_axes["pixel_values"] = {0: "batch_size", 2: "img_size", 3: "img_size"} @@ -911,13 +916,13 @@ def get_dummy_inputs( else: img_size = 896 - mm_tokens_per_image = getattr(self.config, "mm_tokens_per_image", 256) + vision_size = getattr(self.config, "mm_tokens_per_image", 256) # Define shapes inputs_shapes = {} inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) inputs_shapes["vision_embeds"] = ( 1, # constants.INTERN_NUM_PATCHES, - mm_tokens_per_image, # constants.INTERN_FEATURE_SIZE, + vision_size, # constants.INTERN_FEATURE_SIZE, self.language_model.config.hidden_size, # 5120 ) inputs_shapes["position_ids"] = ( diff --git a/QEfficient/transformers/models/internvl/modeling_internvl.py b/QEfficient/transformers/models/internvl/modeling_internvl.py index e389e6a840..da4ec5758c 100644 --- a/QEfficient/transformers/models/internvl/modeling_internvl.py +++ b/QEfficient/transformers/models/internvl/modeling_internvl.py @@ -134,7 +134,12 @@ def get_specializations( raise NotImplementedError("Image Size other than 448 is not supported for Intern models yet.") per_patch_embed_size = (img_size // self.config.vision_config.patch_size * self.config.downsample_ratio) ** 2 - vision_size = int(batch_size * num_patches * per_patch_embed_size) + user_vision_size = compiler_options.pop("vision_size", None) + if user_vision_size: + assert user_vision_size < ctx_len, "vision_size must be less than ctx_len" + vision_size = user_vision_size + else: + vision_size = int(batch_size * num_patches * per_patch_embed_size) vision = [ { "batch_size": batch_size, diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py index a49f9a24be..7f6c160d19 100644 --- a/QEfficient/transformers/models/llama4/modeling_llama4.py +++ b/QEfficient/transformers/models/llama4/modeling_llama4.py @@ -1001,7 +1001,12 @@ def get_specializations( * (img_size // self.config.vision_config.patch_size) // downsample_ratio ) - vision_size = num_features_per_tile * max_num_tiles + user_vision_size = compiler_options.pop("vision_size", None) + if user_vision_size: + assert user_vision_size < ctx_len, "vision_size must be less than ctx_len" + vision_size = user_vision_size + else: + vision_size = num_features_per_tile * max_num_tiles vision = [ { diff --git a/QEfficient/transformers/models/llava/modeling_llava.py b/QEfficient/transformers/models/llava/modeling_llava.py index 48b002a31a..4a873bb0bd 100644 --- a/QEfficient/transformers/models/llava/modeling_llava.py +++ b/QEfficient/transformers/models/llava/modeling_llava.py @@ -237,7 +237,12 @@ def get_specializations( logger.warning("Setting img_size to be 336, as it was neither passed nor found in vision_config") if img_size != 336 and kv_offload: raise NotImplementedError("Image Size other than 336 is not supported for Llava models yet.") - vision_size = (img_size // self.config.vision_config.patch_size) ** 2 + user_vision_size = compiler_options.pop("vision_size", None) + if user_vision_size: + assert user_vision_size < ctx_len, "vision_size must be less than ctx_len" + vision_size = user_vision_size + else: + vision_size = (img_size // self.config.vision_config.patch_size) ** 2 vision = [ { "batch_size": batch_size, diff --git a/QEfficient/transformers/models/llava_next/modeling_llava_next.py b/QEfficient/transformers/models/llava_next/modeling_llava_next.py index 59d5cad229..fec5ad8253 100755 --- a/QEfficient/transformers/models/llava_next/modeling_llava_next.py +++ b/QEfficient/transformers/models/llava_next/modeling_llava_next.py @@ -326,7 +326,12 @@ def get_specializations( logger.warning("Setting img_size to be 384, as it was neither passed nor found in vision_config") if img_size != constants.GRANITEVISION_IMG_SIZE and kv_offload: logger.warning("Image Size other than 384 is not supported for LlavaNext models yet.") - vision_size = constants.GRANITEVISION_FEATURE_SIZE + user_vision_size = compiler_options.pop("vision_size", None) + if user_vision_size: + assert user_vision_size < ctx_len, "vision_size must be less than ctx_len" + vision_size = user_vision_size + else: + vision_size = constants.GRANITEVISION_FEATURE_SIZE vision = [ { "batch_size": batch_size, diff --git a/QEfficient/transformers/models/mistral3/modeling_mistral3.py b/QEfficient/transformers/models/mistral3/modeling_mistral3.py index a8fb34bafe..f5658025f3 100644 --- a/QEfficient/transformers/models/mistral3/modeling_mistral3.py +++ b/QEfficient/transformers/models/mistral3/modeling_mistral3.py @@ -370,9 +370,14 @@ def get_specializations( ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN patch_size = self.config.vision_config.patch_size kernel_size = self.config.spatial_merge_size - vision_size = ( - ((img_size // patch_size) * (img_size // patch_size)) * (batch_size) // (kernel_size * kernel_size) - ) + user_vision_size = compiler_options.pop("vision_size", None) + if user_vision_size: + assert user_vision_size < ctx_len, "vision_size must be less than ctx_len" + vision_size = user_vision_size + else: + vision_size = ( + ((img_size // patch_size) * (img_size // patch_size)) * (batch_size) // (kernel_size * kernel_size) + ) vision = [ { diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 45c6616018..dd0a23ccdd 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1091,13 +1091,14 @@ def get_specializations( "resolution." ) else: - assert vision_size * f <= user_vision_size, ( - f"Computed vision_size of {vision_size * f} tokens " - f"(vision_size={vision_size}, num_frames={f}) for image resolution " - f"(width={w}, height={h}) cannot exceed the provided " - f"vision_size={user_vision_size}. Please adjust the image resolution or " - "increase the vision_size." - ) + if vision_size * f >= user_vision_size: + logger.warning_once( + f"Computed vision_size of {vision_size * f} tokens " + f"(vision_size={vision_size}, num_frames={f}) for image resolution " + f"(width={w}, height={h}) exceed the provided " + f"vision_size={user_vision_size}. " + f"Vision embedding need to be chunked during prefill." + ) vision.append( { diff --git a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 2d834423f6..39ff9e0138 100644 --- a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -972,13 +972,14 @@ def get_specializations( "resolution." ) else: - assert vision_size * f <= user_vision_size, ( - f"Computed vision_size of {vision_size * f} tokens " - f"(vision_size={vision_size}, num_frames={f}) for image resolution " - f"(width={w}, height={h}) cannot exceed the provided " - f"vision_size={user_vision_size}. Please adjust the image resolution or " - "increase the vision_size." - ) + if vision_size * f >= user_vision_size: + logger.warning_once( + f"Computed vision_size of {vision_size * f} tokens " + f"(vision_size={vision_size}, num_frames={f}) for image resolution " + f"(width={w}, height={h}) exceed the provided " + f"vision_size={user_vision_size}. " + f"Vision embedding need to be chunked during prefill." + ) vision.append( { diff --git a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 078cb4afb7..2f5cd86ee2 100644 --- a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -1001,13 +1001,14 @@ def get_specializations( "resolution." ) else: - assert vision_size * f <= user_vision_size, ( - f"Computed vision_size of {vision_size * f} tokens " - f"(vision_size={vision_size}, num_frames={f}) for image resolution " - f"(width={w}, height={h}) cannot exceed the provided " - f"vision_size={user_vision_size}. Please adjust the image resolution or " - "increase the vision_size." - ) + if vision_size * f >= user_vision_size: + logger.warning_once( + f"Computed vision_size of {vision_size * f} tokens " + f"(vision_size={vision_size}, num_frames={f}) for image resolution " + f"(width={w}, height={h}) exceed the provided " + f"vision_size={user_vision_size}. " + f"Vision embedding need to be chunked during prefill." + ) vision.append( {