Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions QEfficient/transformers/models/gemma3/modeling_gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,12 @@ def get_specializations(
elif img_size is None:
img_size = 896 # FIXME based on gemma3 Image size
logger.warning("Setting img_size to be 336, as it was neither passed nor found in vision_config")
mm_tokens_per_image = getattr(self.config, "mm_tokens_per_image", 256)
user_vision_size = compiler_options.pop("vision_size", None)
if user_vision_size:
assert user_vision_size < ctx_len, "vision_size must be less than ctx_len"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: raise an exception instead of using `assert` in all modeling changes

vision_size = user_vision_size
else:
vision_size = getattr(self.config, "mm_tokens_per_image", 256)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a warning that `mm_tokens_per_image` will be deprecated in the next release. Having two input arguments for a single input introduces ambiguity.


vision = [
{
Expand All @@ -752,7 +757,7 @@ def get_specializations(
"comp_ctx_lengths": comp_ctx_lengths_prefill[i],
"sliding_window": self.language_model.config.sliding_window,
"img_size": img_size,
"mm_tokens_per_image": mm_tokens_per_image,
"vision_size": vision_size,
"vision_batch_size": batch_size,
}
if continuous_batching:
Expand All @@ -771,7 +776,7 @@ def get_specializations(
"comp_ctx_lengths": comp_ctx_lengths_decode[i],
"sliding_window": self.language_model.config.sliding_window,
"img_size": img_size,
"mm_tokens_per_image": mm_tokens_per_image,
"vision_size": vision_size,
"vision_batch_size": batch_size,
}
if continuous_batching:
Expand All @@ -787,7 +792,7 @@ def get_specializations(
"ctx_len": ctx_len,
"sliding_window": self.language_model.config.sliding_window,
"img_size": img_size,
"mm_tokens_per_image": mm_tokens_per_image,
"vision_size": vision_size,
"vision_batch_size": batch_size,
}
if continuous_batching:
Expand All @@ -803,7 +808,7 @@ def get_specializations(
"ctx_len": ctx_len,
"sliding_window": self.language_model.config.sliding_window,
"img_size": img_size,
"mm_tokens_per_image": mm_tokens_per_image,
"vision_size": vision_size,
"vision_batch_size": batch_size,
}
if continuous_batching:
Expand All @@ -829,7 +834,7 @@ def get_onnx_dynamic_axes(
lang_dynamic_axes = {}
lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
lang_dynamic_axes["vision_embeds"] = {0: "vision_batch_size", 1: "mm_tokens_per_image"}
lang_dynamic_axes["vision_embeds"] = {0: "vision_batch_size", 1: "vision_size"}
if continuous_batching:
lang_dynamic_axes["batch_index"] = {0: "batch_size"}
vision_dynamic_axes["pixel_values"] = {0: "batch_size", 2: "img_size", 3: "img_size"}
Expand Down Expand Up @@ -911,13 +916,13 @@ def get_dummy_inputs(
else:
img_size = 896

mm_tokens_per_image = getattr(self.config, "mm_tokens_per_image", 256)
vision_size = getattr(self.config, "mm_tokens_per_image", 256)
# Define shapes
inputs_shapes = {}
inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
inputs_shapes["vision_embeds"] = (
1, # constants.INTERN_NUM_PATCHES,
mm_tokens_per_image, # constants.INTERN_FEATURE_SIZE,
vision_size, # constants.INTERN_FEATURE_SIZE,
self.language_model.config.hidden_size, # 5120
)
inputs_shapes["position_ids"] = (
Expand Down
7 changes: 6 additions & 1 deletion QEfficient/transformers/models/internvl/modeling_internvl.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,12 @@ def get_specializations(
raise NotImplementedError("Image Size other than 448 is not supported for Intern models yet.")

per_patch_embed_size = (img_size // self.config.vision_config.patch_size * self.config.downsample_ratio) ** 2
vision_size = int(batch_size * num_patches * per_patch_embed_size)
user_vision_size = compiler_options.pop("vision_size", None)
if user_vision_size:
assert user_vision_size < ctx_len, "vision_size must be less than ctx_len"
vision_size = user_vision_size
else:
vision_size = int(batch_size * num_patches * per_patch_embed_size)
vision = [
{
"batch_size": batch_size,
Expand Down
7 changes: 6 additions & 1 deletion QEfficient/transformers/models/llama4/modeling_llama4.py
Original file line number Diff line number Diff line change
Expand Up @@ -1001,7 +1001,12 @@ def get_specializations(
* (img_size // self.config.vision_config.patch_size)
// downsample_ratio
)
vision_size = num_features_per_tile * max_num_tiles
user_vision_size = compiler_options.pop("vision_size", None)
if user_vision_size:
assert user_vision_size < ctx_len, "vision_size must be less than ctx_len"
vision_size = user_vision_size
else:
vision_size = num_features_per_tile * max_num_tiles

vision = [
{
Expand Down
7 changes: 6 additions & 1 deletion QEfficient/transformers/models/llava/modeling_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,12 @@ def get_specializations(
logger.warning("Setting img_size to be 336, as it was neither passed nor found in vision_config")
if img_size != 336 and kv_offload:
raise NotImplementedError("Image Size other than 336 is not supported for Llava models yet.")
vision_size = (img_size // self.config.vision_config.patch_size) ** 2
user_vision_size = compiler_options.pop("vision_size", None)
if user_vision_size:
assert user_vision_size < ctx_len, "vision_size must be less than ctx_len"
vision_size = user_vision_size
else:
vision_size = (img_size // self.config.vision_config.patch_size) ** 2
vision = [
{
"batch_size": batch_size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,12 @@ def get_specializations(
logger.warning("Setting img_size to be 384, as it was neither passed nor found in vision_config")
if img_size != constants.GRANITEVISION_IMG_SIZE and kv_offload:
logger.warning("Image Size other than 384 is not supported for LlavaNext models yet.")
vision_size = constants.GRANITEVISION_FEATURE_SIZE
user_vision_size = compiler_options.pop("vision_size", None)
if user_vision_size:
assert user_vision_size < ctx_len, "vision_size must be less than ctx_len"
vision_size = user_vision_size
else:
vision_size = constants.GRANITEVISION_FEATURE_SIZE
vision = [
{
"batch_size": batch_size,
Expand Down
11 changes: 8 additions & 3 deletions QEfficient/transformers/models/mistral3/modeling_mistral3.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,9 +370,14 @@ def get_specializations(
ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN
patch_size = self.config.vision_config.patch_size
kernel_size = self.config.spatial_merge_size
vision_size = (
((img_size // patch_size) * (img_size // patch_size)) * (batch_size) // (kernel_size * kernel_size)
)
user_vision_size = compiler_options.pop("vision_size", None)
if user_vision_size:
assert user_vision_size < ctx_len, "vision_size must be less than ctx_len"
vision_size = user_vision_size
else:
vision_size = (
((img_size // patch_size) * (img_size // patch_size)) * (batch_size) // (kernel_size * kernel_size)
)

vision = [
{
Expand Down
15 changes: 8 additions & 7 deletions QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1091,13 +1091,14 @@ def get_specializations(
"resolution."
)
else:
assert vision_size * f <= user_vision_size, (
f"Computed vision_size of {vision_size * f} tokens "
f"(vision_size={vision_size}, num_frames={f}) for image resolution "
f"(width={w}, height={h}) cannot exceed the provided "
f"vision_size={user_vision_size}. Please adjust the image resolution or "
"increase the vision_size."
)
if vision_size * f >= user_vision_size:
logger.warning_once(
f"Computed vision_size of {vision_size * f} tokens "
f"(vision_size={vision_size}, num_frames={f}) for image resolution "
f"(width={w}, height={h}) exceed the provided "
f"vision_size={user_vision_size}. "
f"Vision embedding need to be chunked during prefill."
)

vision.append(
{
Expand Down
15 changes: 8 additions & 7 deletions QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,13 +972,14 @@ def get_specializations(
"resolution."
)
else:
assert vision_size * f <= user_vision_size, (
f"Computed vision_size of {vision_size * f} tokens "
f"(vision_size={vision_size}, num_frames={f}) for image resolution "
f"(width={w}, height={h}) cannot exceed the provided "
f"vision_size={user_vision_size}. Please adjust the image resolution or "
"increase the vision_size."
)
if vision_size * f >= user_vision_size:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Do we want the warning to be raised even if vision_size * f == user_vision_size? I think we should raise the warning only when the user_vision_size is strictly less than the calculated size.

logger.warning_once(
f"Computed vision_size of {vision_size * f} tokens "
f"(vision_size={vision_size}, num_frames={f}) for image resolution "
f"(width={w}, height={h}) exceed the provided "
f"vision_size={user_vision_size}. "
f"Vision embedding need to be chunked during prefill."
)

vision.append(
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1001,13 +1001,14 @@ def get_specializations(
"resolution."
)
else:
assert vision_size * f <= user_vision_size, (
f"Computed vision_size of {vision_size * f} tokens "
f"(vision_size={vision_size}, num_frames={f}) for image resolution "
f"(width={w}, height={h}) cannot exceed the provided "
f"vision_size={user_vision_size}. Please adjust the image resolution or "
"increase the vision_size."
)
if vision_size * f >= user_vision_size:
logger.warning_once(
f"Computed vision_size of {vision_size * f} tokens "
f"(vision_size={vision_size}, num_frames={f}) for image resolution "
f"(width={w}, height={h}) exceed the provided "
f"vision_size={user_vision_size}. "
f"Vision embedding need to be chunked during prefill."
)

vision.append(
{
Expand Down
Loading