From 4726d2031b7bfef61ff5ceb2053f720801b375a3 Mon Sep 17 00:00:00 2001 From: Kovbo Date: Thu, 4 Jun 2026 00:01:39 +0000 Subject: [PATCH] Avoid pickle in LangGraph rollout logging --- src/art/langgraph/llm_wrapper.py | 13 ++++--- src/art/langgraph/logging.py | 22 +++-------- tests/unit/test_langgraph_logging.py | 55 ++++++++++++++++++++++++++++ 3 files changed, 67 insertions(+), 23 deletions(-) create mode 100644 tests/unit/test_langgraph_logging.py diff --git a/src/art/langgraph/llm_wrapper.py b/src/art/langgraph/llm_wrapper.py index 36b5314b3..e33733962 100644 --- a/src/art/langgraph/llm_wrapper.py +++ b/src/art/langgraph/llm_wrapper.py @@ -26,19 +26,20 @@ def add_thread(thread_id, base_url, api_key, model): log_path = f".art/langgraph/{thread_id}" os.makedirs(os.path.dirname(log_path), exist_ok=True) + logger = FileLogger(log_path) CURRENT_CONFIG.set( { - "logger": FileLogger(log_path), + "logger": logger, "base_url": base_url, "api_key": api_key, "model": model, } ) - return log_path + return logger -def create_messages_from_logs(log_path: str, trajectory: Trajectory): - logs = FileLogger(log_path).load_logs() +def create_messages_from_logs(logger: FileLogger, trajectory: Trajectory): + logs = logger.load_logs() conversations = [] tools = [] @@ -95,14 +96,14 @@ def create_messages_from_logs(log_path: str, trajectory: Trajectory): def wrap_rollout(model, fn): async def wrapper(*args, **kwargs): thread_id = str(uuid.uuid4()) - log_path = add_thread( + logger = add_thread( thread_id, model.inference_base_url, model.inference_api_key, model.inference_model_name, ) result = await fn(*args, **kwargs) - return create_messages_from_logs(log_path, result) + return create_messages_from_logs(logger, result) return wrapper diff --git a/src/art/langgraph/logging.py b/src/art/langgraph/logging.py index 4b50a9530..d3eae7624 100644 --- a/src/art/langgraph/logging.py +++ b/src/art/langgraph/logging.py @@ -1,30 +1,18 @@ -import os -import pickle +from typing import Any class FileLogger: def __init__(self, filepath): self.text_path = filepath - self.pickle_path = filepath + ".pkl" + self._logs: list[tuple[str, Any]] = [] def log(self, name, entry): # Log as readable text with open(self.text_path, "a") as f: f.write(f"{name}: {entry}\n") - # Append to pickle log - with open(self.pickle_path, "ab") as pf: - pickle.dump((name, entry), pf) + self._logs.append((name, entry)) def load_logs(self): - """Load all logs from the pickle file.""" - if not os.path.exists(self.pickle_path): - return [] - logs = [] - with open(self.pickle_path, "rb") as pf: - try: - while True: - logs.append(pickle.load(pf)) - except EOFError: - pass - return logs + """Load all structured logs captured by this logger.""" + return list(self._logs) diff --git a/tests/unit/test_langgraph_logging.py b/tests/unit/test_langgraph_logging.py new file mode 100644 index 000000000..af59b0691 --- /dev/null +++ b/tests/unit/test_langgraph_logging.py @@ -0,0 +1,55 @@ +from pathlib import Path + +import pytest + +pytest.importorskip("langchain_openai") +from langchain_core.messages import AIMessage, HumanMessage # noqa: E402 + +from art import Trajectory # noqa: E402 +from art.langgraph.llm_wrapper import create_messages_from_logs # noqa: E402 +from art.langgraph.logging import FileLogger + + +class NonSerializable: + pass + + +def test_file_logger_keeps_structured_logs_in_memory(tmp_path: Path): + log_path = tmp_path / "rollout" + logger = FileLogger(str(log_path)) + entry = {"input": NonSerializable(), "output": NonSerializable()} + + logger.log("completion-id", entry) + + assert logger.load_logs() == [("completion-id", entry)] + assert not log_path.with_suffix(".pkl").exists() + assert log_path.read_text().startswith("completion-id: ") + + +def test_file_logger_load_logs_returns_copy(tmp_path: Path): + logger = FileLogger(str(tmp_path / "rollout")) + logger.log("completion-id", {"output": "ok"}) + + logs = logger.load_logs() + logs.append(("other-id", {"output": "mutated"})) + + assert logger.load_logs() == [("completion-id", {"output": "ok"})] + + +def test_create_messages_from_logs_reads_in_memory_entries(tmp_path: Path): + logger = FileLogger(str(tmp_path / "rollout")) + logger.log( + "completion-id", + { + "input": [HumanMessage(content="hello")], + "output": AIMessage(content="hi"), + "tools": None, + }, + ) + + trajectory = create_messages_from_logs(logger, Trajectory()) + + assert trajectory.messages() == [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hi"}, + ]