"""Post-process a chatbot answer that contains bullets into flowing prose.
Safety rails (every failure falls back to the original answer):
- Skip rewriting when the answer is already mostly prose.
- Reject the rewrite if any native citation tag such as ``[S1]``, ``[G1]``,
``[R1]``, ``[C1]``, or ``[CM1]`` with an optional ``, L1/L2/L3`` suffix is
altered or lost.
- Reject the rewrite if its length drifts too far from the original answer.
- Catch any LLM exception, log it as a warning, and keep the original answer.
This is best-effort: the return value is always a valid answer, and citations
are preserved with the same multiplicity as the input.
"""
from __future__ import annotations
import logging
import re
from collections import Counter
from pathlib import Path
from typing import Any, Optional
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Native citation tags as emitted by the RAG prompt. Intentionally stricter
# than the prose_linker regex: the format contract is <TYPE><digits> with an
# optional ", L<digits>" suffix, where TYPE is one of S, G, R, C or CM.
# Anything else (e.g. "[S1, article 4]") does not match this pattern, so it
# is not part of the tag multisets compared before and after the rewrite.
_CITATION_TAG_RE = re.compile(r"\[(?:S|G|R|C|CM)\d+(?:,\s*L\d+)?\]")
# Bullet = line starting with "- " or "* " after optional leading whitespace.
# The marker must be followed by whitespace, so "**bold**" or "-word" at the
# start of a line is not counted as a bullet.
_BULLET_LINE_RE = re.compile(r"^\s*[-*]\s+", re.MULTILINE)
# Location of the system prompt file, resolved relative to this module.
_SYSTEM_PROMPT_PATH = Path(__file__).parent / "prompts" / "rewriter_system.txt"
# Lazily filled by _system_prompt(): None means "not loaded yet", an empty
# string means the prompt could not be loaded and the rewriter is disabled.
_SYSTEM_PROMPT_CACHE: Optional[str] = None
def _system_prompt() -> str:
    """Return the rewriter system prompt, loading it from disk on first use.

    The result is memoized in ``_SYSTEM_PROMPT_CACHE``. Any failure to load the
    prompt (missing file, unreadable path, invalid UTF-8) is logged and cached
    as an empty string, which callers treat as "rewriter disabled". The
    function never raises for I/O or decoding problems, so the module's
    fall-back-to-the-original-answer contract holds.

    Returns:
        The prompt text, or ``""`` when the prompt could not be loaded.
    """
    global _SYSTEM_PROMPT_CACHE
    if _SYSTEM_PROMPT_CACHE is None:
        try:
            _SYSTEM_PROMPT_CACHE = _SYSTEM_PROMPT_PATH.read_text(encoding="utf-8")
        except FileNotFoundError:
            logger.warning(
                "rewriter_system.txt missing at %s — prose_rewriter disabled",
                _SYSTEM_PROMPT_PATH,
            )
            _SYSTEM_PROMPT_CACHE = ""
        except (OSError, UnicodeDecodeError):
            # Permission problems, a directory in place of the file, or a
            # badly encoded prompt must not escape: the rewriter is strictly
            # best-effort, so it is disabled instead of failing the request.
            logger.warning(
                "rewriter_system.txt unreadable at %s — prose_rewriter disabled",
                _SYSTEM_PROMPT_PATH,
                exc_info=True,
            )
            _SYSTEM_PROMPT_CACHE = ""
    return _SYSTEM_PROMPT_CACHE
def _count_bullet_lines(text: str) -> int:
    """Return the number of ``_BULLET_LINE_RE`` matches in ``text``.

    Each match corresponds to one line opening with a ``-`` or ``*`` marker
    followed by whitespace.
    """
    return len(_BULLET_LINE_RE.findall(text))
def _count_nonempty_lines(text: str) -> int:
    """Return the number of lines in ``text`` that hold non-whitespace content.

    Lines are split on ``"\\n"``. The result is never below 1, so it is always
    safe to use as a divisor.
    """
    nonblank = [row for row in text.split("\n") if row.strip()]
    return len(nonblank) or 1
def _citation_multiset(text: str) -> Counter[str]:
    """Tally every citation tag matched by ``_CITATION_TAG_RE`` in ``text``.

    Returns a ``Counter`` mapping each tag, exactly as written, to the number
    of times it occurs.
    """
    tally: Counter[str] = Counter()
    for hit in _CITATION_TAG_RE.finditer(text):
        tally[hit.group(0)] += 1
    return tally
def _citations_preserved(original: str, rewritten: str) -> bool:
    """Tell whether both texts carry the same citation tags, the same number of times.

    Tag order is irrelevant; only the multiset of tags is compared.
    """
    tags_before = _citation_multiset(original)
    tags_after = _citation_multiset(rewritten)
    return tags_before == tags_after
def _length_within_bounds(
    original: str,
    rewritten: str,
    *,
    min_ratio: float = 0.5,
    max_ratio: float = 2.0,
) -> bool:
    """Check that the rewrite's length has not drifted too far from the original.

    Args:
        original: The answer before rewriting.
        rewritten: The candidate rewrite.
        min_ratio: Smallest accepted ``len(rewritten) / len(original)``.
        max_ratio: Largest accepted ``len(rewritten) / len(original)``.

    Returns:
        ``True`` when the character-length ratio lies within
        ``[min_ratio, max_ratio]`` (both ends inclusive), or when ``original``
        is empty, in which case no ratio can be computed.
    """
    if not original:
        # No baseline to compare against; also avoids dividing by zero.
        return True
    ratio = len(rewritten) / len(original)
    return min_ratio <= ratio <= max_ratio
def _invoke_llm(llm: Any, system_prompt: str, answer: str) -> Optional[str]:
    """Ask ``llm`` to rewrite ``answer`` and return the reply as stripped text.

    The model is called through ``llm.invoke`` with a system message holding
    ``system_prompt`` and a human message holding ``answer``.

    Returns:
        The reply text with surrounding whitespace removed, or ``None`` when
        the call raises, the reply object is of an unrecognized type, or its
        content is neither a string nor a list.
    """
    conversation = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=answer),
    ]
    try:
        reply = llm.invoke(conversation)
    except Exception:
        logger.warning("prose_rewriter LLM invocation failed", exc_info=True)
        return None

    # Pull the payload out of whichever reply shape was returned.
    if isinstance(reply, str):
        payload = reply
    elif isinstance(reply, AIMessage) or hasattr(reply, "content"):
        payload = reply.content
    else:
        logger.debug("prose_rewriter: unknown LLM response type %s", type(reply))
        return None

    # List content: keep plain strings and the "text" of dict parts, drop the rest.
    if isinstance(payload, list):
        payload = "".join(
            chunk if isinstance(chunk, str) else chunk["text"]
            for chunk in payload
            if isinstance(chunk, str)
            or (isinstance(chunk, dict) and isinstance(chunk.get("text"), str))
        )

    return payload.strip() if isinstance(payload, str) else None
# Public entry point.
def rewrite_to_prose(
    answer: str,
    llm: Any,
    *,
    max_bullet_ratio: float = 0.10,
) -> str:
    """Turn a bullet-heavy answer into flowing prose, or hand it back untouched.

    The rewrite is accepted only when every safety check passes. ``answer`` is
    returned unchanged when any of the following holds:

    - ``llm`` is ``None``, or ``answer`` is empty or whitespace-only.
    - The share of bullet lines among non-empty lines is below
      ``max_bullet_ratio``.
    - No system prompt is available.
    - The LLM call yields nothing usable (``None`` or blank text).
    - The rewrite's length falls outside the accepted bounds.
    - The rewrite's citation tags differ from the original's, counting
      repetitions.
    """
    # Nothing to do without a model or without any visible text.
    if llm is None:
        return answer
    if not answer or not answer.strip():
        return answer

    n_bullets = _count_bullet_lines(answer)
    share = n_bullets / _count_nonempty_lines(answer)
    if share < max_bullet_ratio:
        logger.debug(
            "prose_rewriter: skip (bullet_ratio=%.2f < threshold=%.2f)",
            share,
            max_bullet_ratio,
        )
        return answer

    prompt = _system_prompt()
    if not prompt:
        return answer

    candidate = _invoke_llm(llm, prompt, answer)
    if not (candidate and candidate.strip()):
        return answer

    # Guard 1: the rewrite must stay in the same size range as the original.
    if not _length_within_bounds(answer, candidate):
        logger.warning(
            "prose_rewriter: rejected (length drift %d → %d chars)",
            len(answer),
            len(candidate),
        )
        return answer

    # Guard 2: every citation tag must survive with the same multiplicity.
    if not _citations_preserved(answer, candidate):
        logger.warning(
            "prose_rewriter: rejected (citation tags altered). original=%s rewritten=%s",
            sorted(_citation_multiset(answer).items()),
            sorted(_citation_multiset(candidate).items()),
        )
        return answer

    logger.info(
        "prose_rewriter: applied (bullets %d→%d, length %d→%d chars)",
        n_bullets,
        _count_bullet_lines(candidate),
        len(answer),
        len(candidate),
    )
    return candidate
# Public API of this module: only the rewrite entry point is exported.
__all__ = ["rewrite_to_prose"]