# Source code for lalandre_rag.retrieval.bm25_search

"""
BM25 Lexical Search Service
PostgreSQL full-text search with BM25-like ranking (ts_rank_cd)
"""

import logging
from typing import Any, Dict, List, Optional, cast

from lalandre_core.config import get_config
from lalandre_core.repositories.common import PayloadBuilder
from lalandre_db_postgres import ActsSQL, ChunksSQL, PostgresRepository, SubdivisionsSQL, VersionsSQL

from .result import RetrievalResult

logger = logging.getLogger(__name__)


def _to_int(value: Any, default: int = 0) -> int:
    try:
        return int(value)
    except (TypeError, ValueError):
        return default


def _to_optional_str(value: Any) -> str | None:
    if value is None:
        return None
    return value if isinstance(value, str) else str(value)


class BM25SearchService:
    """
    BM25-based lexical search using PostgreSQL full-text search

    Uses PostgreSQL's ts_rank_cd (Cover Density Ranking) which provides
    BM25-like scoring that considers:
    - Term frequency (TF)
    - Document length normalization
    - Cover density (proximity of terms)

    Responsibilities:
    - Execute BM25 search via PostgreSQL
    - Convert PostgreSQL results to RetrievalResult format
    - Apply filters and language configuration
    - Manage full-text search indexes

    Does NOT:
    - Fuse with semantic results (handled by RetrievalService)
    - Generate embeddings
    - Access Qdrant
    """

    def __init__(
        self,
        pg_repo: PostgresRepository,
        language: str = "french",
        payload_builder: Optional[PayloadBuilder] = None,
    ):
        """
        Initialize BM25 search service

        Args:
            pg_repo: PostgreSQL repository for text search
            language: PostgreSQL text search language configuration
            payload_builder: Payload builder for result metadata; a default
                instance is created when omitted
        """
        self.pg_repo = pg_repo
        self.language = language
        self.payload_builder = payload_builder or PayloadBuilder()
        config = get_config()
        self.default_limit = config.search.default_limit
        self.bm25_normalization = int(config.search.bm25_normalization)

    def _normalize_bm25_score(self, score: float) -> float:
        """Return a stable [0,1] lexical score for thresholding and fusion.

        PostgreSQL can already normalize `ts_rank_cd` into `(0, 1)` when bit 32
        is enabled. If it is not enabled, we apply the same monotonic squashing
        in the application layer so lexical thresholds remain comparable across
        runtimes.
        """
        bounded_score = max(float(score), 0.0)
        if self.bm25_normalization & 32:
            # PostgreSQL already applied rank / (rank + 1); just clamp.
            return min(bounded_score, 1.0)
        # Same monotonic rank / (rank + 1) squashing, done application-side.
        return bounded_score / (bounded_score + 1.0) if bounded_score > 0.0 else 0.0

    def search(
        self,
        query: str,
        top_k: Optional[int] = None,
        filters: Optional[Dict[str, Any]] = None,
        language: Optional[str] = None,
        target: str = "subdivisions",
    ) -> list["RetrievalResult"]:
        """
        Execute BM25 lexical search

        Args:
            query: Search query text
            top_k: Number of results to return (default: config.search.default_limit)
            filters: Optional metadata filters (e.g., {"act_id": 123, "celex": "32016R0679"})
            language: Override default language (default: "french")
            target: "subdivisions" or "chunks"

        Returns:
            List of RetrievalResult objects sorted by BM25 score
        """
        if top_k is None:
            top_k = self.default_limit
        search_language = language or self.language

        # Lazy %-style args so formatting only happens when DEBUG is enabled.
        logger.debug(
            "BM25 search: query='%s...', top_k=%s, language=%s",
            query[:50], top_k, search_language,
        )

        # Execute PostgreSQL full-text search against the requested collection.
        if target == "chunks":
            pg_results = self.pg_repo.search_bm25_chunks(
                query=query, top_k=top_k, language=search_language, filter_conditions=filters
            )
        else:
            pg_results = self.pg_repo.search_bm25(
                query=query, top_k=top_k, language=search_language, filter_conditions=filters
            )
        # Any target other than "chunks" is converted as subdivisions, matching
        # the repository call made above.
        results = self._convert_to_retrieval_results(pg_results, target=target)

        logger.info("BM25 search returned %d results", len(results))
        return results

    @staticmethod
    def _enum_value(value: Any) -> Any:
        """Unwrap Enum-like members to their raw ``.value``; pass others through."""
        return value.value if hasattr(value, "value") else value

    def _subdivision_data(self, subdivision: SubdivisionsSQL | None) -> Dict[str, Any]:
        """Flatten a subdivision row into the dict shape PayloadBuilder expects.

        Every field degrades to ``None`` when the row is absent so downstream
        payload building never has to branch.
        """
        return {
            "id": subdivision.id if subdivision else None,
            "subdivision_type": self._enum_value(subdivision.subdivision_type) if subdivision else None,
            "number": subdivision.number if subdivision else None,
            "title": subdivision.title if subdivision else None,
            "sequence_order": subdivision.sequence_order if subdivision else None,
            "hierarchy_path": subdivision.hierarchy_path if subdivision else None,
            "depth": subdivision.depth if subdivision else None,
            "parent_id": subdivision.parent_id if subdivision else None,
            "content": subdivision.content if subdivision else None,
        }

    def _act_data(self, act: ActsSQL | None, act_id: Any) -> Dict[str, Any]:
        """Flatten an act row; ``act_id`` may be sourced from the subdivision
        when the act row itself is missing from the result."""
        return {
            "id": act_id,
            "celex": act.celex if act else None,
            "title": act.title if act else None,
            "act_type": self._enum_value(act.act_type) if act else None,
            "language": self._enum_value(act.language) if act else None,
            "adoption_date": act.adoption_date if act else None,
            "force_date": act.force_date if act else None,
            "level": act.level if act else None,
        }

    def _version_data(self, version: VersionsSQL | None) -> Dict[str, Any] | None:
        """Flatten a version row, or return ``None`` when no version joined."""
        if version is None:
            return None
        return {
            "id": version.id,
            "version_number": version.version_number,
            "version_type": self._enum_value(version.version_type),
            "version_date": version.version_date,
            "is_current": version.is_current,
        }

    def _convert_to_retrieval_results(
        self, pg_results: List[Dict[str, Any]], target: str = "subdivisions"
    ) -> list["RetrievalResult"]:
        """
        Convert PostgreSQL search results to RetrievalResult objects

        Args:
            pg_results: List of dicts with 'subdivision', 'act', 'version',
                optional 'chunk', and 'score' keys
            target: "subdivisions" or "chunks" — selects content source and
                payload shape

        Returns:
            List of RetrievalResult objects
        """
        results: List[RetrievalResult] = []

        for result_dict in pg_results:
            subdivision = cast(SubdivisionsSQL | None, result_dict.get("subdivision"))
            act = cast(ActsSQL | None, result_dict.get("act"))
            version = cast(VersionsSQL | None, result_dict.get("version"))
            chunk = cast(ChunksSQL | None, result_dict.get("chunk"))

            raw_score = float(result_dict["score"])
            normalized_score = self._normalize_bm25_score(raw_score)

            # Prefer the act row's id; fall back to the subdivision's foreign key.
            act_id_value = act.id if act else (subdivision.act_id if subdivision else None)
            celex_value = act.celex if act else None
            subdivision_type_value = (
                str(self._enum_value(subdivision.subdivision_type) or "") if subdivision else ""
            )
            sequence_order_value = _to_int(subdivision.sequence_order if subdivision else None)

            subdivision_data = self._subdivision_data(subdivision)
            act_data = self._act_data(act, act_id_value)
            version_data = self._version_data(version)

            # Appended after the payload so these keys always win on collision.
            score_metadata = {
                "search_method": "bm25",
                "lexical_score_raw": raw_score,
                "lexical_score_normalized": normalized_score,
            }

            if target == "chunks" and chunk is not None:
                chunk_data: Dict[str, Any] = {
                    "id": chunk.id,
                    "chunk_index": chunk.chunk_index,
                    "subdivision_id": chunk.subdivision_id,
                    "content": chunk.content,
                    "char_start": chunk.char_start,
                    "char_end": chunk.char_end,
                    "token_count": chunk.token_count,
                    "chunk_metadata": chunk.chunk_metadata,
                }
                payload = self.payload_builder.build_chunk_payload(
                    chunk_data=chunk_data, subdivision_data=subdivision_data, act_data=act_data
                )
                # Content is carried on the RetrievalResult itself; drop the duplicate.
                payload.pop("content", None)
                payload.setdefault("collection", "chunks")
                payload.setdefault("source_collection", "chunks")
                results.append(
                    RetrievalResult(
                        content=chunk.content or "",
                        score=normalized_score,
                        subdivision_id=_to_int(chunk.subdivision_id),
                        act_id=_to_int(act_id_value),
                        celex=_to_optional_str(celex_value),
                        subdivision_type=subdivision_type_value,
                        sequence_order=sequence_order_value,
                        metadata={**payload, **score_metadata},
                    )
                )
            else:
                if subdivision is None:
                    # A subdivision result without its row cannot be represented.
                    logger.debug("Skipping BM25 row without subdivision")
                    continue
                payload = self.payload_builder.build_subdivision_payload(
                    subdivision_data=subdivision_data,
                    act_data=act_data,
                    version_data=version_data,
                    metadata={},
                )
                payload.pop("content", None)
                payload.setdefault("collection", "subdivisions")
                payload.setdefault("source_collection", "subdivisions")
                results.append(
                    RetrievalResult(
                        content=subdivision.content or "",
                        score=normalized_score,
                        subdivision_id=_to_int(subdivision.id),
                        act_id=_to_int(subdivision.act_id),
                        celex=_to_optional_str(celex_value),
                        subdivision_type=subdivision_type_value,
                        sequence_order=sequence_order_value,
                        metadata={**payload, **score_metadata},
                    )
                )

        return results

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get BM25 search statistics

        Returns:
            Dictionary with configuration and statistics
        """
        return {
            "language": self.language,
            "default_limit": self.default_limit,
            "search_method": "PostgreSQL ts_rank_cd (BM25-like)",
        }