diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index d89f4c361..ec202568f 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -96,6 +96,7 @@ class Llama: def __init__( self, model_path: str, + mmproj_path: Optional[str] = None, *, # Model Params n_gpu_layers: Union[int, Literal["auto", "all"]] = "auto", @@ -172,6 +173,7 @@ def __init__( log_filters: Optional[Sequence[str]] = None, log_filters_case_sensitive: bool = True, # Extra Params + chat_handler_kwargs: Dict[str, Any] = {}, **kwargs, # type: ignore ): """Load a llama.cpp model from `model_path`. @@ -711,6 +713,18 @@ def __init__( print(f"Failed to load metadata: {e}", file=sys.stderr) if self.verbose: + print(f"Model metadata: {self.metadata}", file=sys.stderr) + + if mmproj_path is not None: + if self.chat_handler is not None and self.verbose: + print("Warning: Both `chat_handler` and `mmproj_path` are not null. Chat handler will be overwritten.", flush = True) + + self.chat_handler = llama_chat_format.GenericMTMDChatHandler( + chat_format = self.metadata.get("tokenizer.chat_template", None), + mmproj_path = mmproj_path, + verbose = self.verbose, + **chat_handler_kwargs + ) print(f"Model desc: {self.model_desc}, " f"Model size: {self.model_size / (1024 * 1024):.2f} MB, " f"Model metadata: {self.metadata}", diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index aadec4600..0e5c9d490 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -3089,11 +3089,12 @@ class MTMDChatHandler: def __init__( self, - clip_model_path: str, + mmproj_path: str, verbose: bool = True, use_gpu: bool = True, image_min_tokens: int = -1, image_max_tokens: int = -1, + chat_template_override: Optional[str] = None, batch_max_tokens: int = 1024, **kwargs ): @@ -3106,7 +3107,7 @@ def __init__( f"If you are passing model-specific parameters, ensure they are supported by {self.log_prefix}." ) - self.clip_model_path = clip_model_path + self.mmproj_path = mmproj_path self.image_min_tokens = image_min_tokens self.image_max_tokens = image_max_tokens self.batch_max_tokens = batch_max_tokens @@ -3122,16 +3123,25 @@ def __init__( self.is_support_audio = False self.is_support_video = False - if not os.path.exists(clip_model_path): - raise ValueError(f"{self.log_prefix}(__init__): Clip model path does not exist: {clip_model_path}") + if not os.path.exists(mmproj_path): + raise ValueError(f"{self.log_prefix}(__init__): Clip model path does not exist: {mmproj_path}") # Pre-compile Jinja template - self.chat_template = ImmutableSandboxedEnvironment( - trim_blocks=True, - lstrip_blocks=True, - ).from_string(self.CHAT_FORMAT) + if (not hasattr(self, "chat_format") or self.chat_format is None) and chat_template_override is None: + self.chat_format = self.CHAT_FORMAT + elif chat_template_override is not None: + self.chat_format = chat_template_override + + self._chat_format_parser_tags = [] + self.change_chat_template(self.chat_format) self._exit_stack = ExitStack() + + def change_chat_template(self, new_template: str): + self.chat_template = ImmutableSandboxedEnvironment( + trim_blocks=True, + lstrip_blocks=True + ).from_string(new_template) def _init_mtmd_context(self, llama_model: llama_core.Llama): """Initialize mtmd context with the llama model.""" @@ -3165,13 +3175,13 @@ def _init_mtmd_context(self, llama_model: llama_core.Llama): # Initialize mtmd context self.mtmd_ctx = self._mtmd_cpp.mtmd_init_from_file( - self.clip_model_path.encode(), + self.mmproj_path.encode(), llama_model.model, self.mctx_params ) if self.mtmd_ctx is None: - raise ValueError(f"{self.log_prefix}(_init_mtmd_context): Failed to load mtmd context from: {self.clip_model_path}") + raise ValueError(f"{self.log_prefix}(_init_mtmd_context): Failed to load mtmd context from: {self.mmproj_path}") # Check if vision is supported self.is_support_vision = self._mtmd_cpp.mtmd_support_vision(self.mtmd_ctx) @@ -3241,13 +3251,13 @@ def _get_media_items(self, messages: List[llama_types.ChatCompletionRequestMessa media_items.append({"url": url, "type": "image"}) # 2. Audio Processing - elif content_type in ["audio_url", "input_audio"]: + elif content_type in ["audio", "audio_url", "input_audio"]: if not self.is_support_audio: raise ValueError(f"{self.log_prefix}: This mmproj model instance does not support audio inputs.") # Case A: Handle custom/forward-compatible audio_url format - if content_type == "audio_url": - audio_url = content["audio_url"] + if content_type == "audio_url" or content_type == "audio": + audio_url = content[content_type] url = audio_url if isinstance(audio_url, str) else audio_url["url"] media_items.append({"url": url, "type": "audio"}) # Case B: Handle OpenAI standard input_audio format @@ -3407,6 +3417,13 @@ def _process_mtmd_prompt( tool_choice=tool_choice, **getattr(self, 'extra_template_arguments', {}) ) + + for tag in self._chat_format_parser_tags: + if tag not in text: + continue + + text = text.replace(tag, media_marker) + # Replace image_url by media_marker in text for item in media_items: text = text.replace(item["url"], media_marker) @@ -4142,10 +4159,46 @@ def from_pretrained( model_path = os.path.join(local_dir, filename) return cls( - clip_model_path=model_path, + mmproj_path=model_path, **kwargs, ) +class GenericMTMDChatHandler(MTMDChatHandler): + KNOWN_MEDIA_TAGS = [ + "<|image_pad|>", + "<|audio_pad|>", + "<|video_pad|>", + "<|image|>", + "<|audio|>", + "<|video|>", + "[IMG]" + ] + + def __init__( + self, + chat_format: str, + mmproj_path: str, + verbose: bool = True, + **kwargs + ) -> None: + self.chat_format = chat_format + + if verbose: + print(f"Got chat template from model:\n```jinja\n{self.chat_format}\n```", flush = True) + + if self.chat_format is None: + raise ValueError("Failed to get model chat template automatically.") + + super().__init__(mmproj_path = mmproj_path, verbose = verbose, **kwargs) + + def __call__(self, **kwargs): + self._chat_format_parser_tags = [tag for tag in self.KNOWN_MEDIA_TAGS if tag in self.chat_format] + + if self.verbose: + print(f"{self.log_prefix} - Start processing") + + # Use parent implementation + return super().__call__(**kwargs) class Llava15ChatHandler(MTMDChatHandler): CHAT_FORMAT = (