# Source code for lalandre_db_postgres.repository

"""
PostgreSQL repository implementation
"""

import re
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import Any, Dict, List, Optional, Sequence, TypeAlias

from lalandre_core.config import get_config
from lalandre_core.repositories.base import BaseRepository
from sqlalchemy import and_, create_engine, func, literal_column, text
from sqlalchemy import column as sa_column
from sqlalchemy.dialects.postgresql import TSVECTOR, insert
from sqlalchemy.orm import Session, joinedload, sessionmaker

from .models import (
    ActRelationsSQL,
    ActsSQL,
    ActSummarySQL,
    ChunksSQL,
    ConversationMessageSQL,
    ConversationSQL,
    EmbeddingStateSQL,
    SubdivisionsSQL,
    VersionsSQL,
)

# Tokenizer used to split a free-text query into candidate search words:
# a letter (incl. Latin-1 accented) or digit, then word chars, hyphens, dots.
_WORD_RE = re.compile(r"[a-zA-ZÀ-ÿ0-9][\w\-\.]*")

# Mapping from LanguageCode values to PostgreSQL text-search configurations.
# Languages not listed fall back to 'simple' (no stemming, no stop-words).
_PG_TS_LANG: Dict[str, str] = {
    "fr": "french",
    "en": "english",
    "de": "german",
    "es": "spanish",
    "it": "italian",
    "nl": "dutch",
    "pt": "portuguese",
}
_DEFAULT_PG_TS_LANG = "simple"
# Languages searched in parallel when no language override is given.
_ACTIVE_LANGS = ["fr", "en"]
# Maps pg_lang config name → suffix for stored fts_* columns (fts_fr, fts_en).
_PG_LANG_TO_COL_SUFFIX: Dict[str, str] = {"french": "fr", "english": "en"}
# (act_language_filter_or_None, postgres_ts_config) pair driving the
# per-language BM25 fan-out in search_bm25 / search_bm25_chunks.
LanguageSearchPair: TypeAlias = tuple[Optional[str], str]


class PostgresRepository(BaseRepository):
    """
    Repository for PostgreSQL database operations for RAG:
    retrieval, context enrichment, and structured legal data access
    """

    def __init__(self, connection_string: str):
        # pool_pre_ping validates connections before use so stale connections
        # (e.g. after a server restart) are transparently replaced.
        self.engine = create_engine(
            connection_string,
            echo=False,
            pool_size=10,
            max_overflow=20,
            pool_pre_ping=True,
        )
        self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)

    @staticmethod
    def _build_or_tsquery(session: Session, language: str, query: str):
        """Build an OR-combined tsquery for BM25 recall.

        Extracts individual words, stems each one via PostgreSQL, and joins
        them with ``|`` so that documents matching *any* term are retrieved.
        ``ts_rank_cd`` then naturally scores documents with more matching
        terms higher — the standard BM25 behaviour expected by RAG systems.

        Falls back to ``websearch_to_tsquery`` when the query is very short
        (two words or fewer) or when stemming produces no usable lexemes.
        """
        words = _WORD_RE.findall(query)
        if len(words) <= 2:
            return func.websearch_to_tsquery(language, query)
        # Classify words before stemming: acronyms/words-with-digits are "priority"
        # because they are almost always distinctive (DSP2, IA, GDPR, EBA, PSD2…).
        # Regular lowercase words are candidates for the short-lexeme filter.
        priority_lexemes: list[str] = []
        normal_lexemes: list[str] = []
        for word in words:
            is_priority = word.isupper() or any(c.isdigit() for c in word)
            # Ask PostgreSQL itself for the stemmed lexemes of this word under
            # the target text-search configuration (one round trip per word).
            row = session.execute(
                text("SELECT * FROM ts_debug(:lang, :word) LIMIT 1"),
                {"lang": language, "word": word},
            ).fetchone()
            if row:
                # ts_debug returns (alias, description, token, dictionaries, dictionary, lexemes)
                lex_list = row[-1]  # lexemes column
                if lex_list:
                    if is_priority:
                        priority_lexemes.extend(lex_list)
                    else:
                        normal_lexemes.extend(lex_list)
        # Deduplicate: priority lexemes first (kept regardless of length),
        # then normal lexemes filtered to >= 3 chars to drop function-word
        # residuals like "le", "si", "a" that slip through the stopword list.
        seen: set[str] = set()
        unique: list[str] = []
        for lex in priority_lexemes:
            if lex not in seen:
                seen.add(lex)
                unique.append(lex)
        for lex in normal_lexemes:
            if lex not in seen and len(lex) >= 3:
                seen.add(lex)
                unique.append(lex)
        if not unique:
            return func.websearch_to_tsquery(language, query)
        # Cap the number of OR-ed lexemes so pathological queries stay cheap.
        max_lex = get_config().search.fts_max_lexemes
        if len(unique) > max_lex:
            unique = unique[:max_lex]
        or_expr = " | ".join(unique)
        return func.to_tsquery(language, or_expr)
[docs] def get_session(self) -> Session: """Get a database session""" return self.SessionLocal()
def _run_language_queries( self, active_pairs: Sequence[LanguageSearchPair], run_one: Any, ) -> List[Any]: """Run per-language BM25 searches in order, with bounded parallelism.""" if not active_pairs: return [] if len(active_pairs) == 1: act_lang, pg_lang = active_pairs[0] return [run_one(act_lang, pg_lang)] max_workers = min(len(active_pairs), max(int(get_config().search.max_parallel_workers), 1)) with ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [executor.submit(run_one, act_lang, pg_lang) for act_lang, pg_lang in active_pairs] return [future.result() for future in futures]
[docs] def close(self): """Close the database connection""" self.engine.dispose()
[docs] def health_check(self) -> bool: """Verify PostgreSQL connectivity""" try: with self.get_session() as session: session.execute(text("SELECT 1")) return True except Exception: return False
# === EMBEDDING WORKER HELPERS ===
[docs] @staticmethod def ensure_embedding_state_table(session: Session) -> None: """Ensure embedding_state table exists and supports all runtime object types.""" EmbeddingStateSQL.__table__.create(bind=session.get_bind(), checkfirst=True) # type: ignore[attr-defined] session.execute(text("ALTER TABLE embedding_state DROP CONSTRAINT IF EXISTS check_embedding_state_object_type")) session.execute( text( "ALTER TABLE embedding_state " "ADD CONSTRAINT check_embedding_state_object_type " "CHECK (object_type IN ('subdivision', 'chunk', 'act'))" ) )
[docs] @staticmethod def purge_orphan_embedding_states(session: Session) -> int: """Delete embedding-state rows referencing missing acts/chunks/subdivisions.""" stale_subdivisions = ( session.query(EmbeddingStateSQL) .filter( EmbeddingStateSQL.object_type == "subdivision", ~EmbeddingStateSQL.object_id.in_(session.query(SubdivisionsSQL.id)), ) .delete(synchronize_session=False) ) stale_chunks = ( session.query(EmbeddingStateSQL) .filter( EmbeddingStateSQL.object_type == "chunk", ~EmbeddingStateSQL.object_id.in_(session.query(ChunksSQL.id)), ) .delete(synchronize_session=False) ) stale_acts = ( session.query(EmbeddingStateSQL) .filter( EmbeddingStateSQL.object_type == "act", ~EmbeddingStateSQL.object_id.in_(session.query(ActsSQL.id)), ) .delete(synchronize_session=False) ) return int(stale_subdivisions) + int(stale_chunks) + int(stale_acts)
[docs] @staticmethod def count_subdivisions(session: Session) -> int: """Return the total number of stored subdivisions.""" return int(session.query(SubdivisionsSQL).count())
[docs] @staticmethod def count_chunks(session: Session) -> int: """Return the total number of stored chunks.""" return int(session.query(ChunksSQL).count())
[docs] @staticmethod def count_embedded_subdivisions( session: Session, provider: str, model_name: str, vector_size: int, ) -> int: """Count subdivisions already embedded for one embedding runtime.""" return int( session.query(EmbeddingStateSQL) .join( SubdivisionsSQL, and_( EmbeddingStateSQL.object_type == "subdivision", EmbeddingStateSQL.object_id == SubdivisionsSQL.id, ), ) .filter( EmbeddingStateSQL.provider == provider, EmbeddingStateSQL.model_name == model_name, EmbeddingStateSQL.vector_size == vector_size, ) .count() )
[docs] @staticmethod def count_embedded_chunks( session: Session, provider: str, model_name: str, vector_size: int, ) -> int: """Count chunks already embedded for one embedding runtime.""" return int( session.query(EmbeddingStateSQL) .join( ChunksSQL, and_( EmbeddingStateSQL.object_type == "chunk", EmbeddingStateSQL.object_id == ChunksSQL.id, ), ) .filter( EmbeddingStateSQL.provider == provider, EmbeddingStateSQL.model_name == model_name, EmbeddingStateSQL.vector_size == vector_size, ) .count() )
[docs] @staticmethod def get_embedding_state_map( session: Session, object_type: str, object_ids: list[int], provider: str, model_name: str, vector_size: int, ) -> dict[int, str]: """Return {object_id: content_hash} for existing embeddings.""" if not object_ids: return {} rows = ( session.query(EmbeddingStateSQL) .filter( EmbeddingStateSQL.object_type == object_type, EmbeddingStateSQL.object_id.in_(object_ids), EmbeddingStateSQL.provider == provider, EmbeddingStateSQL.model_name == model_name, EmbeddingStateSQL.vector_size == vector_size, ) .all() ) state_map: dict[int, str] = {} for row in rows: object_id = row.object_id if isinstance(object_id, int): state_map[object_id] = str(row.content_hash) return state_map
[docs] @staticmethod def upsert_embedding_states(session: Session, records: list[dict[str, Any]]) -> None: """Upsert embedding-state rows for a batch.""" if not records: return stmt = insert(EmbeddingStateSQL).values(records) stmt = stmt.on_conflict_do_update( index_elements=["object_type", "object_id", "provider", "model_name", "vector_size"], set_={ "content_hash": stmt.excluded.content_hash, "embedded_at": func.now(), }, ) session.execute(stmt)
[docs] @staticmethod def list_acts_with_metadata(session: Session) -> list[Any]: """Return acts with eager-loaded metadata, versions, and summaries.""" return ( session.query(ActsSQL) .options( joinedload(ActsSQL.metadata_entries), joinedload(ActsSQL.versions), joinedload(ActsSQL.summaries), ) .all() )
[docs] @staticmethod def get_act_by_celex(session: Session, celex: str) -> Any | None: """Return the act identified by *celex*, if present.""" return session.query(ActsSQL).filter(ActsSQL.celex == celex).first()
[docs] @staticmethod def list_subdivisions_for_act(session: Session, act_id: int) -> list[Any]: """List subdivisions for one act ordered by sequence.""" return ( session.query(SubdivisionsSQL) .filter(SubdivisionsSQL.act_id == act_id) .order_by(SubdivisionsSQL.sequence_order) .all() )
[docs] @staticmethod def list_subdivisions_for_act_version( session: Session, act_id: int, version_id: Optional[int], ) -> list[Any]: """List subdivisions for an act, optionally restricted to one version.""" query = ( session.query(SubdivisionsSQL) .filter(SubdivisionsSQL.act_id == act_id) .order_by(SubdivisionsSQL.sequence_order) ) if version_id is not None: query = query.filter(SubdivisionsSQL.version_id == version_id) return query.all()
[docs] @staticmethod def get_current_version_for_act(session: Session, act_id: int) -> Any | None: """Return the current version row for an act, if any.""" return session.query(VersionsSQL).filter(VersionsSQL.act_id == act_id, VersionsSQL.is_current.is_(True)).first()
[docs] @staticmethod def get_act_summary( session: Session, *, act_id: int, language: str, summary_kind: str = "canonical", ) -> Any | None: """Return one summary row for an act/language/kind triplet.""" return ( session.query(ActSummarySQL) .filter( ActSummarySQL.act_id == act_id, ActSummarySQL.language == language, ActSummarySQL.summary_kind == summary_kind, ) .first() )
[docs] @staticmethod def upsert_act_summary(session: Session, record: Dict[str, Any]) -> None: """Insert or update one act summary record.""" if not record: return stmt = insert(ActSummarySQL).values(record) stmt = stmt.on_conflict_do_update( index_elements=["act_id", "language", "summary_kind"], set_={ "status": stmt.excluded.status, "summary_text": stmt.excluded.summary_text, "content_hash": stmt.excluded.content_hash, "source_version_id": stmt.excluded.source_version_id, "prompt_version": stmt.excluded.prompt_version, "model_id": stmt.excluded.model_id, "generated_at": stmt.excluded.generated_at, "last_attempt_at": stmt.excluded.last_attempt_at, "error_text": stmt.excluded.error_text, "trace_jsonb": stmt.excluded.trace_jsonb, "updated_at": func.now(), }, ) session.execute(stmt)
[docs] @staticmethod def list_chunks_for_act(session: Session, act_id: int) -> list[tuple[Any, Any]]: """Return chunk rows paired with their subdivisions for one act.""" rows = ( session.query(ChunksSQL, SubdivisionsSQL) .join(SubdivisionsSQL, ChunksSQL.subdivision_id == SubdivisionsSQL.id) .filter(SubdivisionsSQL.act_id == act_id) .all() ) return [(chunk, subdivision) for chunk, subdivision in rows]
# === CHUNKING / EXTRACTION WORKER HELPERS ===
[docs] @staticmethod def count_acts(session: Session) -> int: """Return the total number of acts.""" return int(session.query(ActsSQL).count())
[docs] @staticmethod def count_acts_pending_extraction(session: Session) -> int: """Return the number of acts not yet marked as extracted.""" return int(session.query(ActsSQL).filter(ActsSQL.extraction_status != "extracted").count())
[docs] @staticmethod def reset_extraction_status(session: Session, act_id: int) -> None: """Reset extraction state so the act is re-extracted on next run.""" session.query(ActsSQL).filter(ActsSQL.id == act_id).update( { ActsSQL.extraction_status: "pending", ActsSQL.extracted_at: None, } )
[docs] @staticmethod def reset_all_extraction_statuses(session: Session) -> int: """Reset extraction state for all acts after a global pipeline purge.""" updated = session.query(ActsSQL).update( { ActsSQL.extraction_status: "pending", ActsSQL.extracted_at: None, }, synchronize_session=False, ) return int(updated)
[docs] @staticmethod def reset_stale_extracting_acts(session: Session, timeout_minutes: int) -> int: """Reset acts stuck in 'extracting' for longer than *timeout_minutes*.""" cutoff = func.now() - timedelta(minutes=timeout_minutes) return int( session.query(ActsSQL) .filter( ActsSQL.extraction_status == "extracting", ActsSQL.updated_at < cutoff, ) .update( {ActsSQL.extraction_status: "pending", ActsSQL.extracted_at: None}, synchronize_session=False, ) )
[docs] @staticmethod def count_subdivisions_without_chunks( session: Session, min_content_length: int, ) -> int: """Count eligible subdivisions that still have no generated chunks.""" eligible = session.query(SubdivisionsSQL).filter(func.length(SubdivisionsSQL.content) >= min_content_length) missing = eligible.filter( ~session.query(ChunksSQL).filter(ChunksSQL.subdivision_id == SubdivisionsSQL.id).exists() ) return int(missing.count())
[docs] @staticmethod def list_chunk_ids_for_subdivision(session: Session, subdivision_id: int) -> list[int]: """Return all chunk identifiers attached to one subdivision.""" rows = session.query(ChunksSQL.id).filter(ChunksSQL.subdivision_id == subdivision_id).all() return [int(row[0]) for row in rows]
[docs] @staticmethod def list_chunk_ids_for_act(session: Session, act_id: int) -> list[int]: """Return all chunk identifiers attached to one act.""" rows = ( session.query(ChunksSQL.id) .join(SubdivisionsSQL, ChunksSQL.subdivision_id == SubdivisionsSQL.id) .filter(SubdivisionsSQL.act_id == act_id) .all() ) return [int(row[0]) for row in rows]
[docs] @staticmethod def subdivision_has_chunks(session: Session, subdivision_id: int) -> bool: """Return whether a subdivision already has at least one chunk.""" row = session.query(ChunksSQL.id).filter(ChunksSQL.subdivision_id == subdivision_id).first() return row is not None
[docs] @staticmethod def delete_embedding_states_for_chunk_ids(session: Session, chunk_ids: list[int]) -> int: """Delete embedding-state rows for the given chunk identifiers.""" if not chunk_ids: return 0 deleted = ( session.query(EmbeddingStateSQL) .filter( EmbeddingStateSQL.object_type == "chunk", EmbeddingStateSQL.object_id.in_(chunk_ids), ) .delete(synchronize_session=False) ) return int(deleted)
[docs] @staticmethod def delete_embedding_states_for_act_ids(session: Session, act_ids: list[int]) -> int: """Delete embedding-state rows for the given act identifiers.""" if not act_ids: return 0 deleted = ( session.query(EmbeddingStateSQL) .filter( EmbeddingStateSQL.object_type == "act", EmbeddingStateSQL.object_id.in_(act_ids), ) .delete(synchronize_session=False) ) return int(deleted)
[docs] @staticmethod def delete_chunks_for_subdivision(session: Session, subdivision_id: int) -> int: """Delete all chunks belonging to one subdivision.""" deleted = ( session.query(ChunksSQL) .filter(ChunksSQL.subdivision_id == subdivision_id) .delete(synchronize_session=False) ) return int(deleted)
[docs] @staticmethod def delete_chunks_for_act(session: Session, act_id: int) -> int: """Delete all chunks belonging to one act.""" subdivision_ids_rows = session.query(SubdivisionsSQL.id).filter(SubdivisionsSQL.act_id == act_id).all() subdivision_ids = [int(row[0]) for row in subdivision_ids_rows] if not subdivision_ids: return 0 deleted = ( session.query(ChunksSQL) .filter(ChunksSQL.subdivision_id.in_(subdivision_ids)) .delete(synchronize_session=False) ) return int(deleted)
[docs] @staticmethod def insert_chunk_records(session: Session, records: list[dict[str, Any]]) -> None: """Insert a batch of precomputed chunk rows.""" if not records: return stmt = insert(ChunksSQL) session.execute(stmt, records)
[docs] @staticmethod def clear_chunks(session: Session) -> int: """Delete every chunk row.""" deleted = session.query(ChunksSQL).delete(synchronize_session=False) return int(deleted)
[docs] @staticmethod def clear_embedding_states(session: Session) -> int: """Delete every embedding-state row.""" deleted = session.query(EmbeddingStateSQL).delete(synchronize_session=False) return int(deleted)
[docs] @staticmethod def clear_act_relations(session: Session) -> int: """Delete every extracted act relation.""" deleted = session.query(ActRelationsSQL).delete(synchronize_session=False) return int(deleted)
[docs] @staticmethod def clear_act_summaries(session: Session) -> int: """Delete every persisted act summary.""" deleted = session.query(ActSummarySQL).delete(synchronize_session=False) return int(deleted)
# === FULL-TEXT SEARCH (BM25-like) ===
    def search_bm25(
        self,
        query: str,
        top_k: Optional[int] = None,
        language: Optional[str] = None,
        filter_conditions: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """
        Full-text search using PostgreSQL native ts_rank_cd (BM25-like ranking).

        Returns subdivisions with relevance scores for hybrid RAG retrieval.
        Searches across all active languages (FR + EN) in parallel and merges
        results by best score, unless a language override is provided via the
        ``language`` parameter or ``filter_conditions['language']``.

        Args:
            query: Search query
            top_k: Number of results to return (defaults to config.search.default_limit)
            language: PostgreSQL text-search config override (e.g. 'french').
                When set, only that language is searched. Defaults to multilingual UNION.
            filter_conditions: Optional filters (e.g., {"act_id": 123, "celex": "32016R0679"}).
                ``filter_conditions['language']`` restricts results to a single act language.
        """
        config = get_config()
        if top_k is None:
            top_k = config.search.default_limit
        # Determine which (act_language, pg_ts_config) pairs to search.
        # A language override collapses the search to a single language.
        lang_override = filter_conditions.get("language") if filter_conditions else None
        if language is not None:
            # Caller passed an explicit pg_ts config string (e.g. 'french') — honour it.
            active_pairs: list[LanguageSearchPair] = [(None, language)]
        elif lang_override is not None:
            pg_lang = _PG_TS_LANG.get(str(lang_override), _DEFAULT_PG_TS_LANG)
            active_pairs = [(lang_override, pg_lang)]
        else:
            active_pairs = [(lc, _PG_TS_LANG.get(lc, _DEFAULT_PG_TS_LANG)) for lc in _ACTIVE_LANGS]
        normalization = config.search.bm25_normalization

        def _run_subdiv_search(
            session: Session, tsquery: Any, pg_lang: str, act_lang_filter: Optional[str]
        ) -> List[Any]:
            # Prefer the stored fts_fr/fts_en column when one exists for this
            # config; otherwise compute the tsvector on the fly (slower).
            col_suffix = _PG_LANG_TO_COL_SUFFIX.get(pg_lang)
            fts_expr = (
                sa_column(f"fts_{col_suffix}", TSVECTOR)
                if col_suffix
                else func.to_tsvector(pg_lang, SubdivisionsSQL.content)
            )
            q = (
                session.query(
                    SubdivisionsSQL,
                    ActsSQL,
                    VersionsSQL,
                    func.ts_rank_cd(fts_expr, tsquery, normalization).label("rank"),
                )
                .join(ActsSQL, SubdivisionsSQL.act_id == ActsSQL.id)
                .outerjoin(VersionsSQL, SubdivisionsSQL.version_id == VersionsSQL.id)
                .filter(fts_expr.op("@@")(tsquery))
            )
            if act_lang_filter is not None:
                q = q.filter(ActsSQL.language == act_lang_filter)
            if filter_conditions:
                if "act_id" in filter_conditions:
                    q = q.filter(SubdivisionsSQL.act_id == filter_conditions["act_id"])
                if "subdivision_type" in filter_conditions:
                    q = q.filter(SubdivisionsSQL.subdivision_type == filter_conditions["subdivision_type"])
                if "subdivision_number" in filter_conditions:
                    q = q.filter(SubdivisionsSQL.number == filter_conditions["subdivision_number"])
                if "celex" in filter_conditions:
                    q = q.filter(ActsSQL.celex == filter_conditions["celex"])
                if "act_type" in filter_conditions:
                    q = q.filter(ActsSQL.act_type == filter_conditions["act_type"])
            return q.order_by(text("rank DESC")).limit(top_k).all()

        def _search_language(act_lang: Optional[str], pg_lang: str) -> List[Any]:
            # First pass uses strict websearch semantics; when it returns
            # nothing, retry with the looser OR-combined tsquery for recall.
            with self.get_session() as session:
                rows = _run_subdiv_search(session, func.websearch_to_tsquery(pg_lang, query), pg_lang, act_lang)
                if rows:
                    return rows
                return _run_subdiv_search(session, self._build_or_tsquery(session, pg_lang, query), pg_lang, act_lang)

        # Merge per-language result sets, keeping each subdivision's best score.
        best: Dict[int, Any] = {}
        for rows in self._run_language_queries(active_pairs, _search_language):
            for subdiv, act, version, rank in rows:
                score = float(rank)
                if subdiv.id not in best or score > best[subdiv.id]["score"]:
                    best[subdiv.id] = {
                        "subdivision": subdiv,
                        "act": act,
                        "version": version,
                        "score": score,
                    }
        return sorted(best.values(), key=lambda r: r["score"], reverse=True)[:top_k]
# === CONVERSATION HELPERS ===
[docs] @staticmethod def get_conversation(session: Session, conversation_id: str) -> Optional[ConversationSQL]: """Return a conversation session by identifier.""" return session.query(ConversationSQL).filter(ConversationSQL.id == conversation_id).first()
[docs] @staticmethod def create_conversation( session: Session, conversation_id: str, title: str, user_id: Optional[str] = None, ) -> ConversationSQL: """Create and flush a new conversation session.""" conv = ConversationSQL(id=conversation_id, title=title, user_id=user_id) session.add(conv) session.flush() return conv
[docs] @staticmethod def list_conversations( session: Session, user_id: Optional[str], limit: int = 50, ) -> List[ConversationSQL]: """List the most recent conversations for one user.""" return ( session.query(ConversationSQL) .filter(ConversationSQL.user_id == user_id) .order_by(ConversationSQL.updated_at.desc()) .limit(limit) .all() )
[docs] @staticmethod def delete_conversation(session: Session, conversation_id: str) -> bool: """Delete one conversation and report whether a row was removed.""" count = session.query(ConversationSQL).filter(ConversationSQL.id == conversation_id).delete() return count > 0
[docs] @staticmethod def add_conversation_message( session: Session, message_id: str, conversation_id: str, role: str, content: str, query_id: Optional[str] = None, mode: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, ) -> ConversationMessageSQL: """Append and flush one message inside a conversation.""" msg = ConversationMessageSQL( id=message_id, conversation_id=conversation_id, role=role, content=content, query_id=query_id, mode=mode, metadata_=metadata, ) session.add(msg) session.flush() return msg
[docs] @staticmethod def get_conversation_messages( session: Session, conversation_id: str, limit: int = 20, ) -> List[ConversationMessageSQL]: """Return the latest conversation messages in chronological order.""" rows = ( session.query(ConversationMessageSQL) .filter(ConversationMessageSQL.conversation_id == conversation_id) .order_by(ConversationMessageSQL.created_at.desc()) .limit(limit) .all() ) return list(reversed(rows))
[docs] @staticmethod def touch_conversation(session: Session, conversation_id: str) -> None: """Refresh the ``updated_at`` timestamp for a conversation.""" session.query(ConversationSQL).filter( ConversationSQL.id == conversation_id, ).update({"updated_at": func.now()})
# === FULL-TEXT SEARCH (BM25-like — chunks) ===
    def search_bm25_chunks(
        self,
        query: str,
        top_k: Optional[int] = None,
        language: Optional[str] = None,
        filter_conditions: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """
        Full-text search using PostgreSQL ts_rank_cd over chunk content.

        Returns chunks with relevance scores for fine-grained retrieval.
        Searches across all active languages (FR + EN) in parallel and merges
        results by best score, unless a language override is provided via the
        ``language`` parameter or ``filter_conditions['language']``.
        """
        config = get_config()
        if top_k is None:
            top_k = config.search.default_limit
        # Language resolution mirrors search_bm25: explicit pg config wins,
        # then an act-language filter, then the multilingual default set.
        lang_override = filter_conditions.get("language") if filter_conditions else None
        if language is not None:
            active_pairs: list[LanguageSearchPair] = [(None, language)]
        elif lang_override is not None:
            pg_lang = _PG_TS_LANG.get(str(lang_override), _DEFAULT_PG_TS_LANG)
            active_pairs = [(lang_override, pg_lang)]
        else:
            active_pairs = [(lc, _PG_TS_LANG.get(lc, _DEFAULT_PG_TS_LANG)) for lc in _ACTIVE_LANGS]
        normalization = config.search.bm25_normalization

        def _run_chunk_search(
            session: Session, tsquery: Any, pg_lang: str, act_lang_filter: Optional[str]
        ) -> List[Any]:
            # Use the stored, table-qualified chunks.fts_* column when one
            # exists for this config; otherwise compute the tsvector inline.
            col_suffix = _PG_LANG_TO_COL_SUFFIX.get(pg_lang)
            fts_expr = (
                literal_column(f"chunks.fts_{col_suffix}", TSVECTOR)
                if col_suffix
                else func.to_tsvector(pg_lang, ChunksSQL.content)
            )
            q = (
                session.query(
                    ChunksSQL,
                    SubdivisionsSQL,
                    ActsSQL,
                    VersionsSQL,
                    func.ts_rank_cd(fts_expr, tsquery, normalization).label("rank"),
                )
                .join(SubdivisionsSQL, ChunksSQL.subdivision_id == SubdivisionsSQL.id)
                .join(ActsSQL, SubdivisionsSQL.act_id == ActsSQL.id)
                .outerjoin(VersionsSQL, SubdivisionsSQL.version_id == VersionsSQL.id)
                .filter(fts_expr.op("@@")(tsquery))
            )
            if act_lang_filter is not None:
                q = q.filter(ActsSQL.language == act_lang_filter)
            if filter_conditions:
                if "act_id" in filter_conditions:
                    q = q.filter(SubdivisionsSQL.act_id == filter_conditions["act_id"])
                if "subdivision_type" in filter_conditions:
                    q = q.filter(SubdivisionsSQL.subdivision_type == filter_conditions["subdivision_type"])
                if "subdivision_number" in filter_conditions:
                    q = q.filter(SubdivisionsSQL.number == filter_conditions["subdivision_number"])
                if "celex" in filter_conditions:
                    q = q.filter(ActsSQL.celex == filter_conditions["celex"])
                if "subdivision_id" in filter_conditions:
                    q = q.filter(SubdivisionsSQL.id == filter_conditions["subdivision_id"])
                if "chunk_id" in filter_conditions:
                    q = q.filter(ChunksSQL.id == filter_conditions["chunk_id"])
                if "act_type" in filter_conditions:
                    q = q.filter(ActsSQL.act_type == filter_conditions["act_type"])
            return q.order_by(text("rank DESC")).limit(top_k).all()

        def _search_language(act_lang: Optional[str], pg_lang: str) -> List[Any]:
            # Strict websearch semantics first; fall back to the OR-combined
            # tsquery only when the strict pass finds nothing.
            with self.get_session() as session:
                rows = _run_chunk_search(session, func.websearch_to_tsquery(pg_lang, query), pg_lang, act_lang)
                if rows:
                    return rows
                return _run_chunk_search(session, self._build_or_tsquery(session, pg_lang, query), pg_lang, act_lang)

        # Merge per-language result sets, keeping each chunk's best score.
        best: Dict[int, Any] = {}
        for rows in self._run_language_queries(active_pairs, _search_language):
            for chunk, subdiv, act, version, rank in rows:
                score = float(rank)
                if chunk.id not in best or score > best[chunk.id]["score"]:
                    best[chunk.id] = {
                        "chunk": chunk,
                        "subdivision": subdiv,
                        "act": act,
                        "version": version,
                        "score": score,
                    }
        return sorted(best.values(), key=lambda r: r["score"], reverse=True)[:top_k]