"""
Conversation history management for multi-turn RAG.
Handles loading, trimming, and persisting conversation turns.
History is returned as LangChain ``BaseMessage`` objects ready for injection
into a ``ChatPromptTemplate`` via ``MessagesPlaceholder``.
"""
import logging
import uuid
from dataclasses import dataclass, field
from typing import Any, Iterable, List, Optional, Sequence
from fastapi import HTTPException
from lalandre_db_postgres import PostgresRepository
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.output_parsers import StrOutputParser
# Module-level logger; handlers/level are configured by the host application.
logger = logging.getLogger(__name__)

# Hard cap on how many characters of raw transcript are fed to the summarizer.
_MAX_SUMMARY_SOURCE_CHARS = 6_000
# Default length cap for the generated history summary (prefix included).
_DEFAULT_SUMMARY_MAX_CHARS = 1_200
# Default character budget for the verbatim "most recent turns" tail.
_DEFAULT_RECENT_HISTORY_CHARS = 1_800

# French prompt instructing the LLM to compress earlier turns into a short,
# factual memo oriented toward the current question. Placeholders {question},
# {history} and {max_chars} are filled in by _generate_summary_with_llm.
# (Runtime string — intentionally left in French.)
_SUMMARY_PROMPT_TEMPLATE = """Tu compresses l'historique d'une conversation pour un assistant juridique RAG.
Objectif :
- conserver uniquement le contexte utile pour répondre à la nouvelle question ;
- retenir les faits établis, documents/actes cités, CELEX, dates, contraintes, préférences et questions ouvertes ;
- ne rien inventer ;
- ne pas reformuler comme une réponse finale à l'utilisateur.
Question actuelle :
{question}
Historique à condenser :
{history}
Rédige un résumé court en français, structuré en 4 à 6 points maximum, sans dépasser {max_chars} caractères.
"""
@dataclass
class ConversationContext:
    """Result of loading conversation history via ConversationManager.load_history."""

    # Identifier of the loaded (or newly created) conversation.
    conversation_id: str
    # True when the conversation did not exist and was created by this call.
    is_new: bool
    # Prior turns as LangChain messages, oldest first, ready for injection via
    # MessagesPlaceholder; empty for a brand-new conversation.
    history_messages: List[BaseMessage] = field(default_factory=list)
class ConversationManager:
    """Load and persist multi-turn conversation state backed by PostgreSQL."""

    def __init__(
        self,
        pg_repo: PostgresRepository,
        llm: Optional[Any] = None,
        max_history_chars: int = 4_000,
        max_turns: int = 10,
        summary_max_chars: int = _DEFAULT_SUMMARY_MAX_CHARS,
        recent_history_chars: int = _DEFAULT_RECENT_HISTORY_CHARS,
    ) -> None:
        """Store collaborators and history-budget settings.

        Args:
            pg_repo: Repository providing DB sessions and conversation queries.
            llm: Optional chat model used to summarize overflowing history;
                when ``None`` a plain-text fallback summary is used.
            max_history_chars: Total character budget for injected history.
            max_turns: Maximum number of question/answer turns loaded.
            summary_max_chars: Length cap for the generated summary message.
            recent_history_chars: Budget for the verbatim most-recent turns.
        """
        # Collaborators.
        self.pg_repo = pg_repo
        self.llm = llm
        # Character/turn budgets used when trimming history.
        self.max_history_chars = max_history_chars
        self.recent_history_chars = recent_history_chars
        self.summary_max_chars = summary_max_chars
        self.max_turns = max_turns
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
    def load_history(
        self,
        conversation_id: Optional[str],
        question: str,
        user_id: Optional[str] = None,
    ) -> ConversationContext:
        """Load (or create) a conversation and return its history.

        When *conversation_id* is ``None``, auto-creates a new conversation
        titled with the first 100 characters of *question*.

        Args:
            conversation_id: Existing conversation id, or ``None`` to mint one.
            question: Current user question; used as the title of new
                conversations and to steer history summarization.
            user_id: Requesting user; checked against the conversation owner.

        Returns:
            A ConversationContext carrying the (possibly new) id and, for
            existing conversations, the prior turns as LangChain messages.

        Raises:
            HTTPException: 403 when the conversation belongs to another user.
        """
        if conversation_id is None:
            # Brand-new conversation: persist it in its own transaction and
            # return with empty history.
            conversation_id = str(uuid.uuid4())
            session = self.pg_repo.get_session()
            try:
                title = question[:100]
                self.pg_repo.create_conversation(
                    session,
                    conversation_id,
                    title,
                    user_id=user_id,
                )
                session.commit()
            except Exception:
                session.rollback()
                raise
            finally:
                session.close()
            return ConversationContext(
                conversation_id=conversation_id,
                is_new=True,
            )
        session = self.pg_repo.get_session()
        try:
            conv = self.pg_repo.get_conversation(session, conversation_id)
            if conv is None:
                # Caller supplied an id we have never seen: create it rather
                # than failing, so clients may mint their own ids.
                title = question[:100]
                self.pg_repo.create_conversation(
                    session,
                    conversation_id,
                    title,
                    user_id=user_id,
                )
                session.commit()
                return ConversationContext(
                    conversation_id=conversation_id,
                    is_new=True,
                )
            self._ensure_conversation_access(
                owner_user_id=conv.user_id,
                requester_user_id=user_id,
            )
            # Load up to max_turns question/answer pairs (two rows per turn).
            rows = self.pg_repo.get_conversation_messages(
                session,
                conversation_id,
                limit=self.max_turns * 2,
            )
            # Keep only non-empty human/assistant rows as role/content dicts.
            history_turns = [
                {
                    "role": str(msg.role),
                    "content": str(msg.content),
                }
                for msg in rows
                if str(msg.role) in {"human", "assistant"} and str(msg.content).strip()
            ]
            # Convert to LangChain messages, summarizing older turns if the
            # total exceeds the configured character budget.
            lc_messages = self._build_history_messages(
                turns=history_turns,
                question=question,
            )
            # Bump the conversation's recency timestamp in the same transaction.
            self.pg_repo.touch_conversation(session, conversation_id)
            session.commit()
            return ConversationContext(
                conversation_id=conversation_id,
                is_new=False,
                history_messages=lc_messages,
            )
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()
def _build_history_messages(
self,
*,
turns: Sequence[dict[str, str]],
question: str,
) -> List[BaseMessage]:
total_chars = sum(len(turn["content"]) for turn in turns)
if total_chars <= self.max_history_chars:
return self._turns_to_langchain_messages(turns)
recent_budget = min(self.recent_history_chars, self.max_history_chars)
recent_turns: List[dict[str, str]] = []
recent_chars = 0
for turn in reversed(turns):
content = turn["content"]
if recent_turns and recent_chars + len(content) > recent_budget:
break
recent_turns.insert(0, turn)
recent_chars += len(content)
if not recent_turns and turns:
recent_turns = [turns[-1]]
summary_turn_count = max(len(turns) - len(recent_turns), 0)
summary_turns = list(turns[:summary_turn_count])
history_messages: List[BaseMessage] = []
summary = self._summarize_turns(summary_turns, question=question)
if summary:
history_messages.append(SystemMessage(content=summary))
history_messages.extend(self._turns_to_langchain_messages(recent_turns))
while self._messages_char_count(history_messages) > self.max_history_chars and len(history_messages) > 1:
removal_index = 1 if isinstance(history_messages[0], SystemMessage) else 0
history_messages.pop(removal_index)
return history_messages
def _turns_to_langchain_messages(
self,
turns: Iterable[dict[str, str]],
) -> List[BaseMessage]:
messages: List[BaseMessage] = []
for turn in turns:
role = turn["role"]
content = turn["content"]
if role == "human":
messages.append(HumanMessage(content=content))
elif role == "assistant":
messages.append(AIMessage(content=content))
return messages
def _summarize_turns(
self,
turns: Sequence[dict[str, str]],
*,
question: str,
) -> Optional[str]:
if not turns:
return None
history_text = self._format_turns_for_summary(turns)
if not history_text:
return None
if self.llm is not None:
summary = self._generate_summary_with_llm(
history_text=history_text,
question=question,
)
if summary:
return summary
return self._build_fallback_summary(history_text)
def _generate_summary_with_llm(
self,
*,
history_text: str,
question: str,
) -> Optional[str]:
llm = self.llm
if llm is None:
return None
try:
summary_chain = llm | StrOutputParser()
prompt = _SUMMARY_PROMPT_TEMPLATE.format(
question=question.strip(),
history=history_text,
max_chars=self.summary_max_chars,
)
summary = str(summary_chain.invoke(prompt)).strip()
except Exception as exc:
logger.warning("Conversation history summarization failed, using fallback: %s", exc)
return None
if not summary:
return None
return self._wrap_history_summary(summary)
def _build_fallback_summary(self, history_text: str) -> str:
lines = [line.strip() for line in history_text.splitlines() if line.strip()]
fallback = "\n".join(lines[-6:])
return self._wrap_history_summary(fallback)
def _wrap_history_summary(self, summary: str) -> str:
clean_summary = summary.strip()
prefix = (
"Résumé condensé du contexte précédent. "
"Utilise-le comme mémoire conversationnelle, sans le traiter comme une preuve documentaire.\n"
)
remaining_chars = max(self.summary_max_chars - len(prefix), 0)
if remaining_chars <= 0:
return prefix[: self.summary_max_chars]
if len(clean_summary) > remaining_chars:
clean_summary = clean_summary[: remaining_chars - 1].rstrip() + "…"
return f"{prefix}{clean_summary}"
def _format_turns_for_summary(self, turns: Sequence[dict[str, str]]) -> str:
serialized_turns: List[str] = []
total_chars = 0
for turn in turns:
speaker = "Utilisateur" if turn["role"] == "human" else "Assistant"
entry = f"{speaker}: {turn['content'].strip()}"
if total_chars + len(entry) > _MAX_SUMMARY_SOURCE_CHARS:
remaining = _MAX_SUMMARY_SOURCE_CHARS - total_chars
if remaining > 0:
serialized_turns.append(entry[:remaining].rstrip())
break
serialized_turns.append(entry)
total_chars += len(entry)
return "\n\n".join(serialized_turns).strip()
@staticmethod
def _messages_char_count(messages: Sequence[BaseMessage]) -> int:
total = 0
for message in messages:
content = message.content
if isinstance(content, str):
total += len(content)
return total
@staticmethod
def _ensure_conversation_access(
*,
owner_user_id: Optional[str],
requester_user_id: Optional[str],
) -> None:
if owner_user_id == requester_user_id:
return
if owner_user_id is None and requester_user_id is None:
return
raise HTTPException(status_code=403, detail="Not your conversation")
    def save_turn(
        self,
        conversation_id: str,
        question: str,
        answer: str,
        query_id: str,
        mode: Optional[str] = None,
        sources: Optional[Any] = None,
        timings: Optional[dict] = None,
        steps: Optional[list] = None,
    ) -> str:
        """Persist human + assistant messages and return the assistant message id.

        Args:
            conversation_id: Conversation to append the turn to.
            question: User question, stored as the "human" message.
            answer: Model answer, stored as the "assistant" message.
            query_id: Correlation id attached to both messages.
            mode: Optional pipeline mode tag stored on both messages.
            sources: Optional retrieval sources kept in assistant metadata.
            timings: Optional stage timings kept in assistant metadata.
            steps: Optional reasoning steps kept in assistant metadata.

        Returns:
            The generated id of the stored assistant message.
        """
        human_msg_id = str(uuid.uuid4())
        assistant_msg_id = str(uuid.uuid4())
        # Build metadata only when at least one optional payload was provided.
        assistant_metadata: Optional[dict] = None
        if sources is not None or timings is not None or steps is not None:
            assistant_metadata = {}
            if sources is not None:
                assistant_metadata["sources"] = sources
            if timings is not None:
                assistant_metadata["timings"] = timings
            # NOTE(review): truthiness here (unlike the `is not None` checks
            # above) means an empty steps list produces metadata without a
            # "steps" key — confirm this asymmetry is intentional.
            if steps:
                assistant_metadata["steps"] = steps
        session = self.pg_repo.get_session()
        try:
            # Both messages and the recency touch share one transaction:
            # either the whole turn is stored or none of it is.
            self.pg_repo.add_conversation_message(
                session,
                human_msg_id,
                conversation_id,
                "human",
                question,
                query_id=query_id,
                mode=mode,
            )
            self.pg_repo.add_conversation_message(
                session,
                assistant_msg_id,
                conversation_id,
                "assistant",
                answer,
                query_id=query_id,
                mode=mode,
                metadata=assistant_metadata,
            )
            self.pg_repo.touch_conversation(session, conversation_id)
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()
        return assistant_msg_id