# Source code for lalandre_core.llm.structured

"""Shared helpers for running PydanticAI structured-output agents.

Extracted from ``lalandre_rag.agentic.tools`` so that any package
(extraction, RAG, summaries) can reuse the same ``FunctionModel`` bridge
without depending on the RAG layer.
"""

from __future__ import annotations

import json
from collections.abc import Callable
from typing import Any, Optional, TypeVar, cast

from langchain_core.output_parsers import StrOutputParser
from pydantic import BaseModel
from pydantic_ai import Agent
from pydantic_ai.messages import (
    ModelMessage,
    ModelRequest,
    ModelResponse,
    RetryPromptPart,
    SystemPromptPart,
    TextPart,
    ToolCallPart,
    ToolReturnPart,
    UserPromptPart,
)
from pydantic_ai.models.function import AgentInfo, FunctionModel

T = TypeVar("T", bound=BaseModel)


def json_payload_from_text(raw: str) -> Optional[dict[str, Any]]:
    """Extract a JSON object from potentially noisy LLM text.

    Tries the whole text first (after stripping markdown code fences), then
    falls back to the outermost ``{...}`` span. Returns ``None`` when no JSON
    object can be recovered.
    """

    def _parse_dict(candidate: str) -> Optional[dict[str, Any]]:
        # Parse and keep the result only when it is a JSON object (dict).
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            return None
        return parsed if isinstance(parsed, dict) else None

    text = raw.strip()
    if not text:
        return None
    if text.startswith("```"):
        # Drop markdown fence delimiter lines, keep the fenced body.
        kept = (line for line in text.splitlines() if not line.strip().startswith("```"))
        text = "\n".join(kept).strip()

    direct = _parse_dict(text)
    if direct is not None:
        return direct

    # Fallback: the outermost brace span embedded in surrounding noise.
    start, end = text.find("{"), text.rfind("}")
    if start < 0 or end <= start:
        return None
    return _parse_dict(text[start : end + 1])
def _coerce_prompt_part_content(value: Any) -> str: if isinstance(value, str): return value if isinstance(value, list): return "\n".join(_coerce_prompt_part_content(item) for item in value) return str(value) def _render_model_messages(messages: list[ModelMessage]) -> str: blocks: list[str] = [] for message in messages: if isinstance(message, ModelRequest): for part in message.parts: if isinstance(part, SystemPromptPart): blocks.append(f"SYSTEM:\n{part.content}") elif isinstance(part, UserPromptPart): blocks.append(f"USER:\n{_coerce_prompt_part_content(part.content)}") elif isinstance(part, RetryPromptPart): blocks.append(f"RETRY:\n{_coerce_prompt_part_content(part.content)}") elif isinstance(part, ToolReturnPart): blocks.append(f"TOOL {part.tool_name}:\n{part.content}") elif isinstance(message, ModelResponse): text_chunks = [ part.content for part in message.parts if isinstance(part, TextPart) and isinstance(part.content, str) ] if text_chunks: blocks.append(f"ASSISTANT:\n{''.join(text_chunks)}") return "\n\n".join(blocks).strip()
def build_structured_prompt(
    *,
    messages: list[ModelMessage],
    agent_info: AgentInfo,
) -> str:
    """Build a single text prompt from PydanticAI messages + output schema.

    The prompt embeds the first output tool's JSON schema (or ``{}`` when the
    agent has no output tools) plus the rendered conversation, with strict
    JSON-only response instructions (in French, matching the project locale).
    """
    tools = agent_info.output_tools
    schema_text = "{}"
    if tools:
        # Sorted, indented dump keeps the schema stable across runs.
        schema_text = json.dumps(
            tools[0].parameters_json_schema,
            ensure_ascii=False,
            indent=2,
            sort_keys=True,
        )

    header = (agent_info.instructions or "").strip()
    transcript = _render_model_messages(messages)

    sections = [
        f"{header}\n\n",
        "Réponds uniquement avec un objet JSON valide.\n",
        "N'ajoute ni markdown, ni explication, ni texte hors JSON.\n",
        "Le JSON doit respecter strictement ce schéma :\n",
        f"{schema_text}\n\n",
        "Conversation utile :\n",
        f"{transcript}",
    ]
    return "".join(sections).strip()
def to_text_generator(llm_or_generate: Any) -> Callable[[str], str]:
    """Normalize an LLM object or callable into a simple ``str -> str`` function.

    A plain callable without an ``invoke`` attribute is returned unchanged;
    anything else is treated as a LangChain runnable and piped through
    ``StrOutputParser``.
    """
    is_plain_callable = callable(llm_or_generate) and not hasattr(llm_or_generate, "invoke")
    if is_plain_callable:
        return cast(Callable[[str], str], llm_or_generate)

    def _invoke_chain(prompt: str) -> str:
        # LangChain runnable path: compose lazily and coerce the result to str.
        pipeline = llm_or_generate | StrOutputParser()
        return str(pipeline.invoke(prompt))

    return _invoke_chain
def run_structured_agent(
    *,
    agent: Agent[Any, T],
    prompt: str,
    llm_or_generate: Any,
    model_name: str,
) -> tuple[T, int]:
    """Run a PydanticAI agent using a FunctionModel bridge to any LLM.

    Returns ``(output, retries)`` where *retries* is the number of
    output-validation retries triggered.
    """
    call_count = 0
    generate_text = to_text_generator(llm_or_generate)

    def _bridge(messages: list[ModelMessage], agent_info: AgentInfo) -> ModelResponse:
        # Invoked by PydanticAI for the initial attempt and each retry.
        nonlocal call_count
        call_count += 1
        rendered = build_structured_prompt(messages=messages, agent_info=agent_info)
        raw = generate_text(rendered)
        tools = agent_info.output_tools
        if not tools:
            # No structured-output tool registered: hand back plain text.
            return ModelResponse(parts=[TextPart(raw.strip())], model_name=model_name)
        payload = json_payload_from_text(raw) or {}
        return ModelResponse(
            parts=[
                ToolCallPart(
                    tools[0].name,
                    payload,
                    tool_call_id=f"pyd_ai_tool_{call_count}",
                )
            ],
            model_name=model_name,
        )

    result = agent.run_sync(
        prompt,
        model=FunctionModel(function=_bridge, model_name=model_name),
        infer_name=False,
    )
    # First call is the initial attempt; everything beyond it is a retry.
    retries = max(call_count - 1, 0)
    # NOTE(review): assumes the output model declares `output_validation_retries`
    # (model_copy(update=...) does not validate unknown fields) — confirm on T.
    output = cast(T, result.output.model_copy(update={"output_validation_retries": retries}))
    return output, retries