"""Shared helpers for running PydanticAI structured-output agents.
Extracted from ``lalandre_rag.agentic.tools`` so that any package
(extraction, RAG, summaries) can reuse the same ``FunctionModel`` bridge
without depending on the RAG layer.
"""
from __future__ import annotations
import json
from collections.abc import Callable
from typing import Any, Optional, TypeVar, cast
from langchain_core.output_parsers import StrOutputParser
from pydantic import BaseModel
from pydantic_ai import Agent
from pydantic_ai.messages import (
ModelMessage,
ModelRequest,
ModelResponse,
RetryPromptPart,
SystemPromptPart,
TextPart,
ToolCallPart,
ToolReturnPart,
UserPromptPart,
)
from pydantic_ai.models.function import AgentInfo, FunctionModel
T = TypeVar("T", bound=BaseModel)
def json_payload_from_text(raw: str) -> Optional[dict[str, Any]]:
    """Extract a JSON object from potentially noisy LLM text.

    Handles markdown code fences and JSON objects embedded in surrounding
    prose. Returns ``None`` when no JSON *object* (dict) can be recovered.
    """
    cleaned = raw.strip()
    if not cleaned:
        return None
    # Strip markdown code fences, keeping only the inner payload lines.
    if cleaned.startswith("```"):
        inner = (line for line in cleaned.splitlines() if not line.strip().startswith("```"))
        cleaned = "\n".join(inner).strip()
    # First attempt: the whole remaining text parses as JSON.
    try:
        parsed = json.loads(cleaned)
    except json.JSONDecodeError:
        parsed = None
    if isinstance(parsed, dict):
        return parsed
    # Fallback: salvage the outermost {...} span embedded in prose.
    open_idx = cleaned.find("{")
    close_idx = cleaned.rfind("}")
    if open_idx < 0 or close_idx <= open_idx:
        return None
    try:
        candidate = json.loads(cleaned[open_idx : close_idx + 1])
    except json.JSONDecodeError:
        return None
    return candidate if isinstance(candidate, dict) else None
def _coerce_prompt_part_content(value: Any) -> str:
if isinstance(value, str):
return value
if isinstance(value, list):
return "\n".join(_coerce_prompt_part_content(item) for item in value)
return str(value)
def _render_model_messages(messages: list[ModelMessage]) -> str:
    """Render PydanticAI request/response messages into a plain-text transcript.

    Each part becomes a labelled block (SYSTEM / USER / RETRY / TOOL /
    ASSISTANT); blocks are separated by blank lines. Unknown part types
    are silently skipped.
    """
    rendered: list[str] = []
    for msg in messages:
        if isinstance(msg, ModelRequest):
            for part in msg.parts:
                if isinstance(part, SystemPromptPart):
                    rendered.append(f"SYSTEM:\n{part.content}")
                elif isinstance(part, UserPromptPart):
                    rendered.append(f"USER:\n{_coerce_prompt_part_content(part.content)}")
                elif isinstance(part, RetryPromptPart):
                    rendered.append(f"RETRY:\n{_coerce_prompt_part_content(part.content)}")
                elif isinstance(part, ToolReturnPart):
                    rendered.append(f"TOOL {part.tool_name}:\n{part.content}")
        elif isinstance(msg, ModelResponse):
            # Only plain-text parts contribute; tool-call parts are skipped.
            assistant_text = "".join(
                part.content
                for part in msg.parts
                if isinstance(part, TextPart) and isinstance(part.content, str)
            )
            if assistant_text:
                rendered.append(f"ASSISTANT:\n{assistant_text}")
    return "\n\n".join(rendered).strip()
def build_structured_prompt(
    *,
    messages: list[ModelMessage],
    agent_info: AgentInfo,
) -> str:
    """Build a single text prompt from PydanticAI messages + output schema.

    The prompt combines the agent's instructions, French JSON-only
    directives, the first output tool's JSON schema (``{}`` when no
    output tool is registered), and the rendered conversation.
    """
    tools = agent_info.output_tools
    first_tool = tools[0] if tools else None
    if first_tool is None:
        schema_json = "{}"
    else:
        schema_json = json.dumps(
            first_tool.parameters_json_schema,
            ensure_ascii=False,
            indent=2,
            sort_keys=True,
        )
    instructions = (agent_info.instructions or "").strip()
    transcript = _render_model_messages(messages)
    # Assemble line-by-line; joining with "\n" reproduces the exact layout
    # (blank entries become the blank separator lines).
    sections = [
        instructions,
        "",
        "Réponds uniquement avec un objet JSON valide.",
        "N'ajoute ni markdown, ni explication, ni texte hors JSON.",
        "Le JSON doit respecter strictement ce schéma :",
        schema_json,
        "",
        "Conversation utile :",
        transcript,
    ]
    return "\n".join(sections).strip()
def to_text_generator(llm_or_generate: Any) -> Callable[[str], str]:
    """Normalize an LLM object or callable into a simple ``str -> str`` function."""
    # A plain callable without LangChain's ``invoke`` is already usable as-is.
    is_langchain_like = hasattr(llm_or_generate, "invoke")
    if callable(llm_or_generate) and not is_langchain_like:
        return cast(Callable[[str], str], llm_or_generate)

    def _run(prompt: str) -> str:
        # Pipe through StrOutputParser so chat-model outputs become plain text.
        pipeline = llm_or_generate | StrOutputParser()
        return str(pipeline.invoke(prompt))

    return _run
def run_structured_agent(
    *,
    agent: Agent[Any, T],
    prompt: str,
    llm_or_generate: Any,
    model_name: str,
) -> tuple[T, int]:
    """Run a PydanticAI agent using a FunctionModel bridge to any LLM.

    Returns ``(output, retries)`` where *retries* is the number of
    output-validation retries triggered.
    """
    call_count = 0
    generate_text = to_text_generator(llm_or_generate)

    def _bridge(messages: list[ModelMessage], agent_info: AgentInfo) -> ModelResponse:
        nonlocal call_count
        call_count += 1
        raw = generate_text(build_structured_prompt(messages=messages, agent_info=agent_info))
        tools = agent_info.output_tools
        output_tool = tools[0] if tools else None
        if output_tool is None:
            # No structured-output tool registered: hand back the raw text.
            return ModelResponse(parts=[TextPart(raw.strip())], model_name=model_name)
        payload = json_payload_from_text(raw) or {}
        call_part = ToolCallPart(
            output_tool.name,
            payload,
            tool_call_id=f"pyd_ai_tool_{call_count}",
        )
        return ModelResponse(parts=[call_part], model_name=model_name)

    run_result = agent.run_sync(
        prompt,
        model=FunctionModel(function=_bridge, model_name=model_name),
        infer_name=False,
    )
    # The first call is the initial attempt; everything beyond it is a retry.
    retries = max(call_count - 1, 0)
    # NOTE(review): assumes the output model declares an
    # ``output_validation_retries`` field — confirm against the schemas used.
    structured = cast(T, run_result.output.model_copy(update={"output_validation_retries": retries}))
    return structured, retries