# Source code for rag_service.conversation

"""
Conversation history management for multi-turn RAG.

Handles loading, trimming, and persisting conversation turns.
History is returned as LangChain ``BaseMessage`` objects ready for injection
into a ``ChatPromptTemplate`` via ``MessagesPlaceholder``.
"""

import logging
import uuid
from dataclasses import dataclass, field
from typing import Any, Iterable, List, Optional, Sequence

from fastapi import HTTPException
from lalandre_db_postgres import PostgresRepository
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.output_parsers import StrOutputParser

logger = logging.getLogger(__name__)

# Hard ceiling (chars) on how much raw history is fed to the summarization LLM.
_MAX_SUMMARY_SOURCE_CHARS = 6_000
# Default maximum length (chars) of the generated history summary.
_DEFAULT_SUMMARY_MAX_CHARS = 1_200
# Default budget (chars) of recent turns kept verbatim alongside the summary.
_DEFAULT_RECENT_HISTORY_CHARS = 1_800
# French prompt asking the LLM to compress prior turns into a short, factual
# summary (established facts, cited documents/CELEX ids, dates, constraints,
# open questions) without inventing anything. Placeholders: {question},
# {history}, {max_chars}. NOTE: runtime string — do not translate or reflow.
_SUMMARY_PROMPT_TEMPLATE = """Tu compresses l'historique d'une conversation pour un assistant juridique RAG.

Objectif :
- conserver uniquement le contexte utile pour répondre à la nouvelle question ;
- retenir les faits établis, documents/actes cités, CELEX, dates, contraintes, préférences et questions ouvertes ;
- ne rien inventer ;
- ne pas reformuler comme une réponse finale à l'utilisateur.

Question actuelle :
{question}

Historique à condenser :
{history}

Rédige un résumé court en français, structuré en 4 à 6 points maximum, sans dépasser {max_chars} caractères.
"""


@dataclass
class ConversationContext:
    """Outcome of resolving a conversation for a request.

    Bundles the conversation id, whether it was created during this call,
    and any prior turns already converted to LangChain messages (ready for
    a ``MessagesPlaceholder``).
    """

    # UUID string identifying the conversation.
    conversation_id: str
    # True when the conversation was created by this load (no prior history).
    is_new: bool
    # Prior turns as LangChain messages; empty for new conversations.
    history_messages: List[BaseMessage] = field(default_factory=list)
class ConversationManager:
    """Load and persist multi-turn conversation state backed by PostgreSQL."""

    def __init__(
        self,
        pg_repo: PostgresRepository,
        llm: Optional[Any] = None,
        max_history_chars: int = 4_000,
        max_turns: int = 10,
        summary_max_chars: int = _DEFAULT_SUMMARY_MAX_CHARS,
        recent_history_chars: int = _DEFAULT_RECENT_HISTORY_CHARS,
    ) -> None:
        """Configure the manager.

        Args:
            pg_repo: Repository providing conversation/message persistence.
            llm: Optional chat model used to summarize long histories; when
                ``None`` a deterministic text fallback is used instead.
            max_history_chars: Total character budget for injected history.
            max_turns: Number of past turns to fetch (two DB rows per turn).
            summary_max_chars: Cap on the condensed-history summary length.
            recent_history_chars: Budget for recent turns kept verbatim.
        """
        self.pg_repo = pg_repo
        self.llm = llm
        self.max_history_chars = max_history_chars
        self.max_turns = max_turns
        self.summary_max_chars = summary_max_chars
        self.recent_history_chars = recent_history_chars

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
[docs] def load_history( self, conversation_id: Optional[str], question: str, user_id: Optional[str] = None, ) -> ConversationContext: """Load (or create) a conversation and return its history. When *conversation_id* is ``None``, auto-creates a new conversation. """ if conversation_id is None: conversation_id = str(uuid.uuid4()) session = self.pg_repo.get_session() try: title = question[:100] self.pg_repo.create_conversation( session, conversation_id, title, user_id=user_id, ) session.commit() except Exception: session.rollback() raise finally: session.close() return ConversationContext( conversation_id=conversation_id, is_new=True, ) session = self.pg_repo.get_session() try: conv = self.pg_repo.get_conversation(session, conversation_id) if conv is None: title = question[:100] self.pg_repo.create_conversation( session, conversation_id, title, user_id=user_id, ) session.commit() return ConversationContext( conversation_id=conversation_id, is_new=True, ) self._ensure_conversation_access( owner_user_id=conv.user_id, requester_user_id=user_id, ) rows = self.pg_repo.get_conversation_messages( session, conversation_id, limit=self.max_turns * 2, ) history_turns = [ { "role": str(msg.role), "content": str(msg.content), } for msg in rows if str(msg.role) in {"human", "assistant"} and str(msg.content).strip() ] lc_messages = self._build_history_messages( turns=history_turns, question=question, ) self.pg_repo.touch_conversation(session, conversation_id) session.commit() return ConversationContext( conversation_id=conversation_id, is_new=False, history_messages=lc_messages, ) except Exception: session.rollback() raise finally: session.close()
def _build_history_messages( self, *, turns: Sequence[dict[str, str]], question: str, ) -> List[BaseMessage]: total_chars = sum(len(turn["content"]) for turn in turns) if total_chars <= self.max_history_chars: return self._turns_to_langchain_messages(turns) recent_budget = min(self.recent_history_chars, self.max_history_chars) recent_turns: List[dict[str, str]] = [] recent_chars = 0 for turn in reversed(turns): content = turn["content"] if recent_turns and recent_chars + len(content) > recent_budget: break recent_turns.insert(0, turn) recent_chars += len(content) if not recent_turns and turns: recent_turns = [turns[-1]] summary_turn_count = max(len(turns) - len(recent_turns), 0) summary_turns = list(turns[:summary_turn_count]) history_messages: List[BaseMessage] = [] summary = self._summarize_turns(summary_turns, question=question) if summary: history_messages.append(SystemMessage(content=summary)) history_messages.extend(self._turns_to_langchain_messages(recent_turns)) while self._messages_char_count(history_messages) > self.max_history_chars and len(history_messages) > 1: removal_index = 1 if isinstance(history_messages[0], SystemMessage) else 0 history_messages.pop(removal_index) return history_messages def _turns_to_langchain_messages( self, turns: Iterable[dict[str, str]], ) -> List[BaseMessage]: messages: List[BaseMessage] = [] for turn in turns: role = turn["role"] content = turn["content"] if role == "human": messages.append(HumanMessage(content=content)) elif role == "assistant": messages.append(AIMessage(content=content)) return messages def _summarize_turns( self, turns: Sequence[dict[str, str]], *, question: str, ) -> Optional[str]: if not turns: return None history_text = self._format_turns_for_summary(turns) if not history_text: return None if self.llm is not None: summary = self._generate_summary_with_llm( history_text=history_text, question=question, ) if summary: return summary return self._build_fallback_summary(history_text) def 
_generate_summary_with_llm( self, *, history_text: str, question: str, ) -> Optional[str]: llm = self.llm if llm is None: return None try: summary_chain = llm | StrOutputParser() prompt = _SUMMARY_PROMPT_TEMPLATE.format( question=question.strip(), history=history_text, max_chars=self.summary_max_chars, ) summary = str(summary_chain.invoke(prompt)).strip() except Exception as exc: logger.warning("Conversation history summarization failed, using fallback: %s", exc) return None if not summary: return None return self._wrap_history_summary(summary) def _build_fallback_summary(self, history_text: str) -> str: lines = [line.strip() for line in history_text.splitlines() if line.strip()] fallback = "\n".join(lines[-6:]) return self._wrap_history_summary(fallback) def _wrap_history_summary(self, summary: str) -> str: clean_summary = summary.strip() prefix = ( "Résumé condensé du contexte précédent. " "Utilise-le comme mémoire conversationnelle, sans le traiter comme une preuve documentaire.\n" ) remaining_chars = max(self.summary_max_chars - len(prefix), 0) if remaining_chars <= 0: return prefix[: self.summary_max_chars] if len(clean_summary) > remaining_chars: clean_summary = clean_summary[: remaining_chars - 1].rstrip() + "…" return f"{prefix}{clean_summary}" def _format_turns_for_summary(self, turns: Sequence[dict[str, str]]) -> str: serialized_turns: List[str] = [] total_chars = 0 for turn in turns: speaker = "Utilisateur" if turn["role"] == "human" else "Assistant" entry = f"{speaker}: {turn['content'].strip()}" if total_chars + len(entry) > _MAX_SUMMARY_SOURCE_CHARS: remaining = _MAX_SUMMARY_SOURCE_CHARS - total_chars if remaining > 0: serialized_turns.append(entry[:remaining].rstrip()) break serialized_turns.append(entry) total_chars += len(entry) return "\n\n".join(serialized_turns).strip() @staticmethod def _messages_char_count(messages: Sequence[BaseMessage]) -> int: total = 0 for message in messages: content = message.content if isinstance(content, str): total += 
len(content) return total @staticmethod def _ensure_conversation_access( *, owner_user_id: Optional[str], requester_user_id: Optional[str], ) -> None: if owner_user_id == requester_user_id: return if owner_user_id is None and requester_user_id is None: return raise HTTPException(status_code=403, detail="Not your conversation")
[docs] def save_turn( self, conversation_id: str, question: str, answer: str, query_id: str, mode: Optional[str] = None, sources: Optional[Any] = None, timings: Optional[dict] = None, steps: Optional[list] = None, ) -> str: """Persist human + assistant messages and return the assistant message id.""" human_msg_id = str(uuid.uuid4()) assistant_msg_id = str(uuid.uuid4()) assistant_metadata: Optional[dict] = None if sources is not None or timings is not None or steps is not None: assistant_metadata = {} if sources is not None: assistant_metadata["sources"] = sources if timings is not None: assistant_metadata["timings"] = timings if steps: assistant_metadata["steps"] = steps session = self.pg_repo.get_session() try: self.pg_repo.add_conversation_message( session, human_msg_id, conversation_id, "human", question, query_id=query_id, mode=mode, ) self.pg_repo.add_conversation_message( session, assistant_msg_id, conversation_id, "assistant", answer, query_id=query_id, mode=mode, metadata=assistant_metadata, ) self.pg_repo.touch_conversation(session, conversation_id) session.commit() except Exception: session.rollback() raise finally: session.close() return assistant_msg_id