# Source code for lalandre_rag.prose_rewriter

"""Post-process a chatbot answer that contains bullets into flowing prose.

Safety rails (every failure falls back to the original answer):

- Skip rewriting when the answer is already mostly prose.
- Reject the rewrite if any native citation tag such as ``[S1]``, ``[G1]``,
  ``[R1]``, ``[C1]``, or ``[CM1]`` with an optional ``, L1/L2/L3`` suffix is
  altered or lost.
- Reject the rewrite if its length drifts too far from the original answer.
- Catch any LLM exception silently.

This is best-effort: the return value is always a valid answer, and citations
are preserved with the same multiplicity as the input.
"""

from __future__ import annotations

import logging
import re
from collections import Counter
from pathlib import Path
from typing import Any, Optional

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

logger = logging.getLogger(__name__)


# Native citation tags as emitted by the RAG prompt. Intentionally stricter
# than the prose_linker regex: the format contract is <TYPE><digits> with
# optional ", L<digits>" suffix. Anything else (e.g. "[S1, article 4]")
# would indicate a malformed citation — we compare tag multisets as-is, so
# drift would reject the rewrite.
_CITATION_TAG_RE = re.compile(r"\[(?:S|G|R|C|CM)\d+(?:,\s*L\d+)?\]")

# Bullet = line starting with "- " or "* " after optional leading whitespace.
_BULLET_LINE_RE = re.compile(r"^\s*[-*]\s+", re.MULTILINE)


# The rewriter system prompt ships as a text file next to this module.
_SYSTEM_PROMPT_PATH = Path(__file__).parent / "prompts" / "rewriter_system.txt"


# Lazy cache for the prompt file, filled on first _system_prompt() call.
# "" is a valid cached value and means "file missing — rewriter disabled".
_SYSTEM_PROMPT_CACHE: Optional[str] = None


def _system_prompt() -> str:
    """Return the rewriter system prompt, reading it from disk at most once.

    A missing prompt file is logged and cached as the empty string, which
    callers treat as "rewriting disabled".
    """
    global _SYSTEM_PROMPT_CACHE
    if _SYSTEM_PROMPT_CACHE is not None:
        return _SYSTEM_PROMPT_CACHE
    try:
        text = _SYSTEM_PROMPT_PATH.read_text(encoding="utf-8")
    except FileNotFoundError:
        logger.warning(
            "rewriter_system.txt missing at %s — prose_rewriter disabled",
            _SYSTEM_PROMPT_PATH,
        )
        text = ""
    _SYSTEM_PROMPT_CACHE = text
    return _SYSTEM_PROMPT_CACHE


def _count_bullet_lines(text: str) -> int:
    """Return how many lines of *text* start with a bullet marker."""
    return len(_BULLET_LINE_RE.findall(text))


def _count_nonempty_lines(text: str) -> int:
    return max(1, sum(1 for line in text.split("\n") if line.strip()))


def _citation_multiset(text: str) -> Counter[str]:
    """Return the multiset of native citation tags appearing in *text*."""
    tags = _CITATION_TAG_RE.findall(text)
    return Counter(tags)


def _citations_preserved(original: str, rewritten: str) -> bool:
    """True when both texts carry the same citation tags with equal counts."""
    before = _citation_multiset(original)
    after = _citation_multiset(rewritten)
    return before == after


def _length_within_bounds(original: str, rewritten: str) -> bool:
    if not original:
        return True
    ratio = len(rewritten) / len(original)
    return 0.5 <= ratio <= 2.0


def _invoke_llm(llm: Any, system_prompt: str, answer: str) -> Optional[str]:
    """Run the rewrite prompt through the LLM and normalize the reply to text.

    Returns ``None`` when the invocation raises or when the response payload
    cannot be coerced into a string.
    """
    try:
        response = llm.invoke(
            [SystemMessage(content=system_prompt), HumanMessage(content=answer)]
        )
    except Exception:
        logger.warning("prose_rewriter LLM invocation failed", exc_info=True)
        return None

    if isinstance(response, str):
        content = response
    elif isinstance(response, AIMessage) or hasattr(response, "content"):
        content = response.content
    else:
        logger.debug("prose_rewriter: unknown LLM response type %s", type(response))
        return None

    # Multimodal payloads arrive as a list of plain strings and/or
    # {"text": ...} dicts; concatenate the textual pieces in order.
    if isinstance(content, list):
        pieces = [
            item if isinstance(item, str) else item["text"]
            for item in content
            if isinstance(item, str)
            or (isinstance(item, dict) and isinstance(item.get("text"), str))
        ]
        content = "".join(pieces)

    return content.strip() if isinstance(content, str) else None


def rewrite_to_prose(
    answer: str,
    llm: Any,
    *,
    max_bullet_ratio: float = 0.10,
) -> str:
    """Rewrite a bullet-heavy answer into flowing prose.

    Returns ``answer`` unchanged if:

    - ``llm`` is ``None`` or ``answer`` is empty/whitespace.
    - The bullet ratio is below ``max_bullet_ratio`` (nothing to rewrite).
    - The system prompt is missing.
    - The LLM call fails or returns an unusable payload.
    - The rewritten output is out of bounds in length.
    - The rewritten output does not preserve the native citation tags with
      identical multiplicity.
    """
    if llm is None or not answer or not answer.strip():
        return answer

    total_lines = _count_nonempty_lines(answer)
    bullet_lines = _count_bullet_lines(answer)
    bullet_ratio = bullet_lines / total_lines
    if bullet_ratio < max_bullet_ratio:
        logger.debug(
            "prose_rewriter: skip (bullet_ratio=%.2f < threshold=%.2f)",
            bullet_ratio,
            max_bullet_ratio,
        )
        return answer

    system_prompt = _system_prompt()
    if not system_prompt:
        return answer

    rewritten = _invoke_llm(llm, system_prompt, answer)
    if rewritten is None or not rewritten.strip():
        return answer

    if not _length_within_bounds(answer, rewritten):
        # FIX: the format string had lost its separator ("%d%d"), making the
        # two lengths run together in the log line.
        logger.warning(
            "prose_rewriter: rejected (length drift %d -> %d chars)",
            len(answer),
            len(rewritten),
        )
        return answer

    if not _citations_preserved(answer, rewritten):
        logger.warning(
            "prose_rewriter: rejected (citation tags altered). original=%s rewritten=%s",
            sorted(_citation_multiset(answer).items()),
            sorted(_citation_multiset(rewritten).items()),
        )
        return answer

    # FIX: same lost separator as above ("%d%d" twice).
    logger.info(
        "prose_rewriter: applied (bullets %d -> %d, length %d -> %d chars)",
        bullet_lines,
        _count_bullet_lines(rewritten),
        len(answer),
        len(rewritten),
    )
    return rewritten


__all__ = ["rewrite_to_prose"]