Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions llama_cpp/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class Llama:
def __init__(
self,
model_path: str,
mmproj_path: Optional[str] = None,
*,
# Model Params
n_gpu_layers: Union[int, Literal["auto", "all"]] = "auto",
Expand Down Expand Up @@ -172,6 +173,7 @@ def __init__(
log_filters: Optional[Sequence[str]] = None,
log_filters_case_sensitive: bool = True,
# Extra Params
chat_handler_kwargs: Dict[str, Any] = {},
**kwargs, # type: ignore
):
"""Load a llama.cpp model from `model_path`.
Expand Down Expand Up @@ -711,6 +713,18 @@ def __init__(
print(f"Failed to load metadata: {e}", file=sys.stderr)

if self.verbose:
print(f"Model metadata: {self.metadata}", file=sys.stderr)

if mmproj_path is not None:
if self.chat_handler is not None and self.verbose:
print("Warning: Both `chat_handler` and `mmproj_path` are not null. Chat handler will be overwritten.", flush = True)

self.chat_handler = llama_chat_format.GenericMTMDChatHandler(
chat_format = self.metadata.get("tokenizer.chat_template", None),
mmproj_path = mmproj_path,
verbose = self.verbose,
**chat_handler_kwargs
)
print(f"Model desc: {self.model_desc}, "
f"Model size: {self.model_size / (1024 * 1024):.2f} MB, "
f"Model metadata: {self.metadata}",
Expand Down
81 changes: 67 additions & 14 deletions llama_cpp/llama_chat_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -3089,11 +3089,12 @@ class MTMDChatHandler:

def __init__(
self,
clip_model_path: str,
mmproj_path: str,
verbose: bool = True,
use_gpu: bool = True,
image_min_tokens: int = -1,
image_max_tokens: int = -1,
chat_template_override: Optional[str] = None,
batch_max_tokens: int = 1024,
**kwargs
):
Expand All @@ -3106,7 +3107,7 @@ def __init__(
f"If you are passing model-specific parameters, ensure they are supported by {self.log_prefix}."
)

self.clip_model_path = clip_model_path
self.mmproj_path = mmproj_path
self.image_min_tokens = image_min_tokens
self.image_max_tokens = image_max_tokens
self.batch_max_tokens = batch_max_tokens
Expand All @@ -3122,16 +3123,25 @@ def __init__(
self.is_support_audio = False
self.is_support_video = False

if not os.path.exists(clip_model_path):
raise ValueError(f"{self.log_prefix}(__init__): Clip model path does not exist: {clip_model_path}")
if not os.path.exists(mmproj_path):
raise ValueError(f"{self.log_prefix}(__init__): Clip model path does not exist: {mmproj_path}")

# Pre-compile Jinja template
self.chat_template = ImmutableSandboxedEnvironment(
trim_blocks=True,
lstrip_blocks=True,
).from_string(self.CHAT_FORMAT)
if (not hasattr(self, "chat_format") or self.chat_format is None) and chat_template_override is None:
self.chat_format = self.CHAT_FORMAT
elif chat_template_override is not None:
self.chat_format = chat_template_override

self._chat_format_parser_tags = []
self.change_chat_template(self.chat_format)

self._exit_stack = ExitStack()

def change_chat_template(self, new_template: str):
self.chat_template = ImmutableSandboxedEnvironment(
trim_blocks=True,
lstrip_blocks=True
).from_string(new_template)

def _init_mtmd_context(self, llama_model: llama_core.Llama):
"""Initialize mtmd context with the llama model."""
Expand Down Expand Up @@ -3165,13 +3175,13 @@ def _init_mtmd_context(self, llama_model: llama_core.Llama):

# Initialize mtmd context
self.mtmd_ctx = self._mtmd_cpp.mtmd_init_from_file(
self.clip_model_path.encode(),
self.mmproj_path.encode(),
llama_model.model,
self.mctx_params
)

if self.mtmd_ctx is None:
raise ValueError(f"{self.log_prefix}(_init_mtmd_context): Failed to load mtmd context from: {self.clip_model_path}")
raise ValueError(f"{self.log_prefix}(_init_mtmd_context): Failed to load mtmd context from: {self.mmproj_path}")

# Check if vision is supported
self.is_support_vision = self._mtmd_cpp.mtmd_support_vision(self.mtmd_ctx)
Expand Down Expand Up @@ -3241,13 +3251,13 @@ def _get_media_items(self, messages: List[llama_types.ChatCompletionRequestMessa
media_items.append({"url": url, "type": "image"})

# 2. Audio Processing
elif content_type in ["audio_url", "input_audio"]:
elif content_type in ["audio", "audio_url", "input_audio"]:
if not self.is_support_audio:
raise ValueError(f"{self.log_prefix}: This mmproj model instance does not support audio inputs.")

# Case A: Handle custom/forward-compatible audio_url format
if content_type == "audio_url":
audio_url = content["audio_url"]
if content_type == "audio_url" or content_type == "audio":
audio_url = content[content_type]
url = audio_url if isinstance(audio_url, str) else audio_url["url"]
media_items.append({"url": url, "type": "audio"})
# Case B: Handle OpenAI standard input_audio format
Expand Down Expand Up @@ -3407,6 +3417,13 @@ def _process_mtmd_prompt(
tool_choice=tool_choice,
**getattr(self, 'extra_template_arguments', {})
)

for tag in self._chat_format_parser_tags:
if tag not in text:
continue

text = text.replace(tag, media_marker)

# Replace image_url by media_marker in text
for item in media_items:
text = text.replace(item["url"], media_marker)
Expand Down Expand Up @@ -4142,10 +4159,46 @@ def from_pretrained(
model_path = os.path.join(local_dir, filename)

return cls(
clip_model_path=model_path,
mmproj_path=model_path,
**kwargs,
)

class GenericMTMDChatHandler(MTMDChatHandler):
KNOWN_MEDIA_TAGS = [
"<|image_pad|>",
"<|audio_pad|>",
"<|video_pad|>",
"<|image|>",
"<|audio|>",
"<|video|>",
"[IMG]"
]

def __init__(
self,
chat_format: str,
mmproj_path: str,
verbose: bool = True,
**kwargs
) -> None:
self.chat_format = chat_format

if verbose:
print(f"Got chat template from model:\n```jinja\n{self.chat_format}\n```", flush = True)

if self.chat_format is None:
raise ValueError("Failed to get model chat template automatically.")

super().__init__(mmproj_path = mmproj_path, verbose = verbose, **kwargs)

def __call__(self, **kwargs):
self._chat_format_parser_tags = [tag for tag in self.KNOWN_MEDIA_TAGS if tag in self.chat_format]

if self.verbose:
print(f"{self.log_prefix} - Start processing")

# Use parent implementation
return super().__call__(**kwargs)

class Llava15ChatHandler(MTMDChatHandler):
CHAT_FORMAT = (
Expand Down
Loading