Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 86 additions & 31 deletions src/art/langgraph/llm_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""LLM wrapper with logging functionality."""

import asyncio
from collections.abc import Callable
import contextvars
import json
import os
Expand All @@ -22,6 +23,9 @@

mappings = {}

DEFAULT_INVOKE_TIMEOUT = 10 * 60
OPENAI_COMPATIBLE_PROVIDERS = {None, "openai", "openai-compatible", "openai_compatible"}


def add_thread(thread_id, base_url, api_key, model):
log_path = f".art/langgraph/{thread_id}"
Expand Down Expand Up @@ -108,31 +112,82 @@ async def wrapper(*args, **kwargs):


def init_chat_model(
model: Literal[None] = None,
model: str | Runnable | None = None,
*,
model_provider: str | None = None,
configurable_fields: Literal[None] = None,
config_prefix: str | None = None,
invoke_timeout: float | None = DEFAULT_INVOKE_TIMEOUT,
**kwargs: Any,
):
"""Create a logged LangChain chat model for ART LangGraph rollouts.

By default ART constructs a ChatOpenAI client pointed at the
OpenAI-compatible endpoint from the active rollout context. For other
LangChain providers, pass an already constructed chat model instance as
``model``. Provider kwargs such as ``temperature`` and ``timeout`` are
forwarded to ChatOpenAI; ``invoke_timeout`` controls only ART's outer
``asyncio.wait_for`` timeout.
"""
config = CURRENT_CONFIG.get()

if configurable_fields is not None:
raise ValueError(
"configurable_fields is not supported by ART's init_chat_model"
)
if config_prefix is not None:
raise ValueError("config_prefix is not supported by ART's init_chat_model")

if model is not None and not isinstance(model, str):
return LoggingLLM(
model,
config["logger"],
invoke_timeout=invoke_timeout,
)

if model_provider not in OPENAI_COMPATIBLE_PROVIDERS:
raise ValueError(
"ART's init_chat_model can construct only OpenAI-compatible chat "
"models. Pass a LangChain chat model instance as `model` to use "
f"provider {model_provider!r}."
)

model_name = model

def chat_openai_factory(art_config: dict[str, Any]):
chat_model_kwargs: dict[str, Any] = {
"base_url": art_config["base_url"],
"api_key": art_config["api_key"],
"model": model_name or art_config["model"],
"temperature": 1.0,
}
chat_model_kwargs.update(kwargs)
return ChatOpenAI(**chat_model_kwargs)

return LoggingLLM(
ChatOpenAI(
base_url=config["base_url"], # ty:ignore[unknown-argument]
api_key=config["api_key"], # ty:ignore[unknown-argument]
model=config["model"], # ty:ignore[unknown-argument]
temperature=1.0,
),
chat_openai_factory(config),
config["logger"],
invoke_timeout=invoke_timeout,
chat_model_factory=chat_openai_factory,
)


class LoggingLLM(Runnable):
def __init__(self, llm, logger, structured_output=None, tools=None):
def __init__(
self,
llm,
logger,
structured_output=None,
tools=None,
invoke_timeout: float | None = DEFAULT_INVOKE_TIMEOUT,
chat_model_factory: Callable[[dict[str, Any]], Any] | None = None,
):
self.llm = llm
self.logger = logger
self.structured_output = structured_output
self.tools = [convert_to_openai_tool(t) for t in tools] if tools else None
self.invoke_timeout = invoke_timeout
self.chat_model_factory = chat_model_factory

def _log(self, completion_id, input, output):
if self.logger:
Expand All @@ -143,7 +198,7 @@ def invoke(self, input, config=None, **kwargs):
completion_id = str(uuid.uuid4())

def execute():
result = self.llm.invoke(input, config=config)
result = self.llm.invoke(input, config=config, **kwargs)
self._log(completion_id, input, result)
return result

Expand All @@ -166,9 +221,11 @@ async def ainvoke(self, input, config=None, **kwargs):

async def execute():
try:
result = await asyncio.wait_for(
self.llm.ainvoke(input, config=config), timeout=10 * 60
)
call = self.llm.ainvoke(input, config=config, **kwargs)
if self.invoke_timeout is None:
result = await call
else:
result = await asyncio.wait_for(call, timeout=self.invoke_timeout)
self._log(completion_id, input, result)
except asyncio.TimeoutError as e:
raise e
Expand All @@ -194,10 +251,18 @@ def with_structured_output(self, tools):
self.logger,
structured_output=tools,
tools=[tools],
invoke_timeout=self.invoke_timeout,
chat_model_factory=self.chat_model_factory,
)

def bind_tools(self, tools):
return LoggingLLM(self.llm.bind_tools(tools), self.logger, tools=tools)
return LoggingLLM(
self.llm.bind_tools(tools),
self.logger,
tools=tools,
invoke_timeout=self.invoke_timeout,
chat_model_factory=self.chat_model_factory,
)

def with_retry(
self,
Expand All @@ -217,23 +282,13 @@ def with_config(
art_config = CURRENT_CONFIG.get()
self.logger = art_config["logger"]

if hasattr(self.llm, "bound"):
setattr(
self.llm,
"bound",
ChatOpenAI(
base_url=art_config["base_url"], # ty:ignore[unknown-argument]
api_key=art_config["api_key"], # ty:ignore[unknown-argument]
model=art_config["model"], # ty:ignore[unknown-argument]
temperature=1.0,
),
)
else:
self.llm = ChatOpenAI(
base_url=art_config["base_url"], # ty:ignore[unknown-argument]
api_key=art_config["api_key"], # ty:ignore[unknown-argument]
model=art_config["model"], # ty:ignore[unknown-argument]
temperature=1.0,
)
if self.chat_model_factory is not None:
configured_llm = self.chat_model_factory(art_config)
if hasattr(self.llm, "bound"):
setattr(self.llm, "bound", configured_llm)
else:
self.llm = configured_llm
elif hasattr(self.llm, "with_config"):
self.llm = self.llm.with_config(config=config, **kwargs)

return self
3 changes: 3 additions & 0 deletions src/art/rewards/ruler.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ async def ruler(
- "openai/gpt-4o-mini" - Fast and cost-effective
- "openai/o3" - Most capable but expensive (default)
- "anthropic/claude-3-opus-20240229" - Alternative judge
- "ollama/qwen3:32b" - Local Ollama judge via LiteLLM
The default calls OpenAI through LiteLLM. Set this explicitly for
local or custom judge backends.
extra_litellm_params: Additional parameters to pass to LiteLLM completion.
Can include temperature, max_tokens, etc.
rubric: The grading rubric. The default rubric works well for most tasks.
Expand Down
Loading
Loading