"""
PostgreSQL repository implementation
"""
import re
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import Any, Dict, List, Optional, Sequence, TypeAlias
from lalandre_core.config import get_config
from lalandre_core.repositories.base import BaseRepository
from sqlalchemy import and_, create_engine, func, literal_column, text
from sqlalchemy import column as sa_column
from sqlalchemy.dialects.postgresql import TSVECTOR, insert
from sqlalchemy.orm import Session, joinedload, sessionmaker
from .models import (
ActRelationsSQL,
ActsSQL,
ActSummarySQL,
ChunksSQL,
ConversationMessageSQL,
ConversationSQL,
EmbeddingStateSQL,
SubdivisionsSQL,
VersionsSQL,
)
# Query tokeniser: a (possibly accented) letter or digit followed by word
# characters, hyphens, or dots — keeps acronyms like "DSP2" and dotted refs.
_WORD_RE = re.compile(r"[a-zA-ZÀ-ÿ0-9][\w\-\.]*")
# Mapping from LanguageCode values to PostgreSQL text-search configurations.
# Languages not listed fall back to 'simple' (no stemming, no stop-words).
_PG_TS_LANG: Dict[str, str] = {
"fr": "french",
"en": "english",
"de": "german",
"es": "spanish",
"it": "italian",
"nl": "dutch",
"pt": "portuguese",
}
_DEFAULT_PG_TS_LANG = "simple"
# Languages searched in parallel when no language override is given.
_ACTIVE_LANGS = ["fr", "en"]
# Maps pg_lang config name → suffix for stored fts_* columns (fts_fr, fts_en).
_PG_LANG_TO_COL_SUFFIX: Dict[str, str] = {"french": "fr", "english": "en"}
# (act_language_filter, pg_ts_config) pair; the first element is None when the
# caller supplied an explicit pg_ts config with no act-language restriction.
LanguageSearchPair: TypeAlias = tuple[Optional[str], str]
[docs]
class PostgresRepository(BaseRepository):
    """
    Repository for PostgreSQL database operations
    for RAG: retrieval, context enrichment, and structured legal data access
    """

    def __init__(self, connection_string: str):
        """Build the SQLAlchemy engine and session factory for *connection_string*."""
        # pool_pre_ping revalidates pooled connections before reuse so stale
        # connections are replaced transparently; overflow absorbs bursts.
        engine_options = {
            "echo": False,
            "pool_size": 10,
            "max_overflow": 20,
            "pool_pre_ping": True,
        }
        self.engine = create_engine(connection_string, **engine_options)
        self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)
@staticmethod
def _build_or_tsquery(session: Session, language: str, query: str):
"""Build an OR-combined tsquery for BM25 recall.
Extracts individual words, stems each one via PostgreSQL, and joins
them with ``|`` so that documents matching *any* term are retrieved.
``ts_rank_cd`` then naturally scores documents with more matching
terms higher — the standard BM25 behaviour expected by RAG systems.
Falls back to ``websearch_to_tsquery`` when the query is very short
(single word) or when stemming produces no usable lexemes.
"""
words = _WORD_RE.findall(query)
if len(words) <= 2:
return func.websearch_to_tsquery(language, query)
# Classify words before stemming: acronyms/words-with-digits are "priority"
# because they are almost always distinctive (DSP2, IA, GDPR, EBA, PSD2…).
# Regular lowercase words are candidates for the short-lexeme filter.
priority_lexemes: list[str] = []
normal_lexemes: list[str] = []
for word in words:
is_priority = word.isupper() or any(c.isdigit() for c in word)
row = session.execute(
text("SELECT * FROM ts_debug(:lang, :word) LIMIT 1"),
{"lang": language, "word": word},
).fetchone()
if row:
# ts_debug returns (alias, description, token, dictionaries, dictionary, lexemes)
lex_list = row[-1] # lexemes column
if lex_list:
if is_priority:
priority_lexemes.extend(lex_list)
else:
normal_lexemes.extend(lex_list)
# Deduplicate: priority lexemes first (kept regardless of length),
# then normal lexemes filtered to >= 3 chars to drop function-word
# residuals like "le", "si", "a" that slip through the stopword list.
seen: set[str] = set()
unique: list[str] = []
for lex in priority_lexemes:
if lex not in seen:
seen.add(lex)
unique.append(lex)
for lex in normal_lexemes:
if lex not in seen and len(lex) >= 3:
seen.add(lex)
unique.append(lex)
if not unique:
return func.websearch_to_tsquery(language, query)
max_lex = get_config().search.fts_max_lexemes
if len(unique) > max_lex:
unique = unique[:max_lex]
or_expr = " | ".join(unique)
return func.to_tsquery(language, or_expr)
[docs]
def get_session(self) -> Session:
"""Get a database session"""
return self.SessionLocal()
def _run_language_queries(
self,
active_pairs: Sequence[LanguageSearchPair],
run_one: Any,
) -> List[Any]:
"""Run per-language BM25 searches in order, with bounded parallelism."""
if not active_pairs:
return []
if len(active_pairs) == 1:
act_lang, pg_lang = active_pairs[0]
return [run_one(act_lang, pg_lang)]
max_workers = min(len(active_pairs), max(int(get_config().search.max_parallel_workers), 1))
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = [executor.submit(run_one, act_lang, pg_lang) for act_lang, pg_lang in active_pairs]
return [future.result() for future in futures]
[docs]
def close(self):
"""Close the database connection"""
self.engine.dispose()
[docs]
def health_check(self) -> bool:
"""Verify PostgreSQL connectivity"""
try:
with self.get_session() as session:
session.execute(text("SELECT 1"))
return True
except Exception:
return False
# === EMBEDDING WORKER HELPERS ===
[docs]
@staticmethod
def ensure_embedding_state_table(session: Session) -> None:
"""Ensure embedding_state table exists and supports all runtime object types."""
EmbeddingStateSQL.__table__.create(bind=session.get_bind(), checkfirst=True) # type: ignore[attr-defined]
session.execute(text("ALTER TABLE embedding_state DROP CONSTRAINT IF EXISTS check_embedding_state_object_type"))
session.execute(
text(
"ALTER TABLE embedding_state "
"ADD CONSTRAINT check_embedding_state_object_type "
"CHECK (object_type IN ('subdivision', 'chunk', 'act'))"
)
)
[docs]
@staticmethod
def purge_orphan_embedding_states(session: Session) -> int:
"""Delete embedding-state rows referencing missing acts/chunks/subdivisions."""
stale_subdivisions = (
session.query(EmbeddingStateSQL)
.filter(
EmbeddingStateSQL.object_type == "subdivision",
~EmbeddingStateSQL.object_id.in_(session.query(SubdivisionsSQL.id)),
)
.delete(synchronize_session=False)
)
stale_chunks = (
session.query(EmbeddingStateSQL)
.filter(
EmbeddingStateSQL.object_type == "chunk",
~EmbeddingStateSQL.object_id.in_(session.query(ChunksSQL.id)),
)
.delete(synchronize_session=False)
)
stale_acts = (
session.query(EmbeddingStateSQL)
.filter(
EmbeddingStateSQL.object_type == "act",
~EmbeddingStateSQL.object_id.in_(session.query(ActsSQL.id)),
)
.delete(synchronize_session=False)
)
return int(stale_subdivisions) + int(stale_chunks) + int(stale_acts)
[docs]
@staticmethod
def count_subdivisions(session: Session) -> int:
"""Return the total number of stored subdivisions."""
return int(session.query(SubdivisionsSQL).count())
[docs]
@staticmethod
def count_chunks(session: Session) -> int:
"""Return the total number of stored chunks."""
return int(session.query(ChunksSQL).count())
[docs]
@staticmethod
def count_embedded_subdivisions(
session: Session,
provider: str,
model_name: str,
vector_size: int,
) -> int:
"""Count subdivisions already embedded for one embedding runtime."""
return int(
session.query(EmbeddingStateSQL)
.join(
SubdivisionsSQL,
and_(
EmbeddingStateSQL.object_type == "subdivision",
EmbeddingStateSQL.object_id == SubdivisionsSQL.id,
),
)
.filter(
EmbeddingStateSQL.provider == provider,
EmbeddingStateSQL.model_name == model_name,
EmbeddingStateSQL.vector_size == vector_size,
)
.count()
)
[docs]
@staticmethod
def count_embedded_chunks(
session: Session,
provider: str,
model_name: str,
vector_size: int,
) -> int:
"""Count chunks already embedded for one embedding runtime."""
return int(
session.query(EmbeddingStateSQL)
.join(
ChunksSQL,
and_(
EmbeddingStateSQL.object_type == "chunk",
EmbeddingStateSQL.object_id == ChunksSQL.id,
),
)
.filter(
EmbeddingStateSQL.provider == provider,
EmbeddingStateSQL.model_name == model_name,
EmbeddingStateSQL.vector_size == vector_size,
)
.count()
)
[docs]
@staticmethod
def get_embedding_state_map(
session: Session,
object_type: str,
object_ids: list[int],
provider: str,
model_name: str,
vector_size: int,
) -> dict[int, str]:
"""Return {object_id: content_hash} for existing embeddings."""
if not object_ids:
return {}
rows = (
session.query(EmbeddingStateSQL)
.filter(
EmbeddingStateSQL.object_type == object_type,
EmbeddingStateSQL.object_id.in_(object_ids),
EmbeddingStateSQL.provider == provider,
EmbeddingStateSQL.model_name == model_name,
EmbeddingStateSQL.vector_size == vector_size,
)
.all()
)
state_map: dict[int, str] = {}
for row in rows:
object_id = row.object_id
if isinstance(object_id, int):
state_map[object_id] = str(row.content_hash)
return state_map
[docs]
@staticmethod
def upsert_embedding_states(session: Session, records: list[dict[str, Any]]) -> None:
"""Upsert embedding-state rows for a batch."""
if not records:
return
stmt = insert(EmbeddingStateSQL).values(records)
stmt = stmt.on_conflict_do_update(
index_elements=["object_type", "object_id", "provider", "model_name", "vector_size"],
set_={
"content_hash": stmt.excluded.content_hash,
"embedded_at": func.now(),
},
)
session.execute(stmt)
[docs]
@staticmethod
def list_acts_with_metadata(session: Session) -> list[Any]:
"""Return acts with eager-loaded metadata, versions, and summaries."""
return (
session.query(ActsSQL)
.options(
joinedload(ActsSQL.metadata_entries),
joinedload(ActsSQL.versions),
joinedload(ActsSQL.summaries),
)
.all()
)
[docs]
@staticmethod
def get_act_by_celex(session: Session, celex: str) -> Any | None:
"""Return the act identified by *celex*, if present."""
return session.query(ActsSQL).filter(ActsSQL.celex == celex).first()
[docs]
@staticmethod
def list_subdivisions_for_act(session: Session, act_id: int) -> list[Any]:
"""List subdivisions for one act ordered by sequence."""
return (
session.query(SubdivisionsSQL)
.filter(SubdivisionsSQL.act_id == act_id)
.order_by(SubdivisionsSQL.sequence_order)
.all()
)
[docs]
@staticmethod
def list_subdivisions_for_act_version(
session: Session,
act_id: int,
version_id: Optional[int],
) -> list[Any]:
"""List subdivisions for an act, optionally restricted to one version."""
query = (
session.query(SubdivisionsSQL)
.filter(SubdivisionsSQL.act_id == act_id)
.order_by(SubdivisionsSQL.sequence_order)
)
if version_id is not None:
query = query.filter(SubdivisionsSQL.version_id == version_id)
return query.all()
[docs]
@staticmethod
def get_current_version_for_act(session: Session, act_id: int) -> Any | None:
"""Return the current version row for an act, if any."""
return session.query(VersionsSQL).filter(VersionsSQL.act_id == act_id, VersionsSQL.is_current.is_(True)).first()
[docs]
@staticmethod
def get_act_summary(
session: Session,
*,
act_id: int,
language: str,
summary_kind: str = "canonical",
) -> Any | None:
"""Return one summary row for an act/language/kind triplet."""
return (
session.query(ActSummarySQL)
.filter(
ActSummarySQL.act_id == act_id,
ActSummarySQL.language == language,
ActSummarySQL.summary_kind == summary_kind,
)
.first()
)
[docs]
@staticmethod
def upsert_act_summary(session: Session, record: Dict[str, Any]) -> None:
"""Insert or update one act summary record."""
if not record:
return
stmt = insert(ActSummarySQL).values(record)
stmt = stmt.on_conflict_do_update(
index_elements=["act_id", "language", "summary_kind"],
set_={
"status": stmt.excluded.status,
"summary_text": stmt.excluded.summary_text,
"content_hash": stmt.excluded.content_hash,
"source_version_id": stmt.excluded.source_version_id,
"prompt_version": stmt.excluded.prompt_version,
"model_id": stmt.excluded.model_id,
"generated_at": stmt.excluded.generated_at,
"last_attempt_at": stmt.excluded.last_attempt_at,
"error_text": stmt.excluded.error_text,
"trace_jsonb": stmt.excluded.trace_jsonb,
"updated_at": func.now(),
},
)
session.execute(stmt)
[docs]
@staticmethod
def list_chunks_for_act(session: Session, act_id: int) -> list[tuple[Any, Any]]:
"""Return chunk rows paired with their subdivisions for one act."""
rows = (
session.query(ChunksSQL, SubdivisionsSQL)
.join(SubdivisionsSQL, ChunksSQL.subdivision_id == SubdivisionsSQL.id)
.filter(SubdivisionsSQL.act_id == act_id)
.all()
)
return [(chunk, subdivision) for chunk, subdivision in rows]
# === CHUNKING / EXTRACTION WORKER HELPERS ===
[docs]
@staticmethod
def count_acts(session: Session) -> int:
"""Return the total number of acts."""
return int(session.query(ActsSQL).count())
[docs]
@staticmethod
def count_subdivisions_without_chunks(
session: Session,
min_content_length: int,
) -> int:
"""Count eligible subdivisions that still have no generated chunks."""
eligible = session.query(SubdivisionsSQL).filter(func.length(SubdivisionsSQL.content) >= min_content_length)
missing = eligible.filter(
~session.query(ChunksSQL).filter(ChunksSQL.subdivision_id == SubdivisionsSQL.id).exists()
)
return int(missing.count())
[docs]
@staticmethod
def list_chunk_ids_for_subdivision(session: Session, subdivision_id: int) -> list[int]:
"""Return all chunk identifiers attached to one subdivision."""
rows = session.query(ChunksSQL.id).filter(ChunksSQL.subdivision_id == subdivision_id).all()
return [int(row[0]) for row in rows]
[docs]
@staticmethod
def list_chunk_ids_for_act(session: Session, act_id: int) -> list[int]:
"""Return all chunk identifiers attached to one act."""
rows = (
session.query(ChunksSQL.id)
.join(SubdivisionsSQL, ChunksSQL.subdivision_id == SubdivisionsSQL.id)
.filter(SubdivisionsSQL.act_id == act_id)
.all()
)
return [int(row[0]) for row in rows]
[docs]
@staticmethod
def subdivision_has_chunks(session: Session, subdivision_id: int) -> bool:
"""Return whether a subdivision already has at least one chunk."""
row = session.query(ChunksSQL.id).filter(ChunksSQL.subdivision_id == subdivision_id).first()
return row is not None
[docs]
@staticmethod
def delete_embedding_states_for_chunk_ids(session: Session, chunk_ids: list[int]) -> int:
"""Delete embedding-state rows for the given chunk identifiers."""
if not chunk_ids:
return 0
deleted = (
session.query(EmbeddingStateSQL)
.filter(
EmbeddingStateSQL.object_type == "chunk",
EmbeddingStateSQL.object_id.in_(chunk_ids),
)
.delete(synchronize_session=False)
)
return int(deleted)
[docs]
@staticmethod
def delete_embedding_states_for_act_ids(session: Session, act_ids: list[int]) -> int:
"""Delete embedding-state rows for the given act identifiers."""
if not act_ids:
return 0
deleted = (
session.query(EmbeddingStateSQL)
.filter(
EmbeddingStateSQL.object_type == "act",
EmbeddingStateSQL.object_id.in_(act_ids),
)
.delete(synchronize_session=False)
)
return int(deleted)
[docs]
@staticmethod
def delete_chunks_for_subdivision(session: Session, subdivision_id: int) -> int:
"""Delete all chunks belonging to one subdivision."""
deleted = (
session.query(ChunksSQL)
.filter(ChunksSQL.subdivision_id == subdivision_id)
.delete(synchronize_session=False)
)
return int(deleted)
[docs]
@staticmethod
def delete_chunks_for_act(session: Session, act_id: int) -> int:
"""Delete all chunks belonging to one act."""
subdivision_ids_rows = session.query(SubdivisionsSQL.id).filter(SubdivisionsSQL.act_id == act_id).all()
subdivision_ids = [int(row[0]) for row in subdivision_ids_rows]
if not subdivision_ids:
return 0
deleted = (
session.query(ChunksSQL)
.filter(ChunksSQL.subdivision_id.in_(subdivision_ids))
.delete(synchronize_session=False)
)
return int(deleted)
[docs]
@staticmethod
def insert_chunk_records(session: Session, records: list[dict[str, Any]]) -> None:
"""Insert a batch of precomputed chunk rows."""
if not records:
return
stmt = insert(ChunksSQL)
session.execute(stmt, records)
[docs]
@staticmethod
def clear_chunks(session: Session) -> int:
"""Delete every chunk row."""
deleted = session.query(ChunksSQL).delete(synchronize_session=False)
return int(deleted)
[docs]
@staticmethod
def clear_embedding_states(session: Session) -> int:
"""Delete every embedding-state row."""
deleted = session.query(EmbeddingStateSQL).delete(synchronize_session=False)
return int(deleted)
[docs]
@staticmethod
def clear_act_relations(session: Session) -> int:
"""Delete every extracted act relation."""
deleted = session.query(ActRelationsSQL).delete(synchronize_session=False)
return int(deleted)
[docs]
@staticmethod
def clear_act_summaries(session: Session) -> int:
"""Delete every persisted act summary."""
deleted = session.query(ActSummarySQL).delete(synchronize_session=False)
return int(deleted)
# === FULL-TEXT SEARCH (BM25-like) ===
[docs]
def search_bm25(
self,
query: str,
top_k: Optional[int] = None,
language: Optional[str] = None,
filter_conditions: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
"""
Full-text search using PostgreSQL native ts_rank_cd (BM25-like ranking).
Returns subdivisions with relevance scores for hybrid RAG retrieval.
Searches across all active languages (FR + EN) in parallel and merges
results by best score, unless a language override is provided via the
``language`` parameter or ``filter_conditions['language']``.
Args:
query: Search query
top_k: Number of results to return (defaults to config.search.default_limit)
language: PostgreSQL text-search config override (e.g. 'french'). When set,
only that language is searched. Defaults to multilingual UNION.
filter_conditions: Optional filters (e.g., {"act_id": 123, "celex": "32016R0679"}).
``filter_conditions['language']`` restricts results to a single act language.
"""
config = get_config()
if top_k is None:
top_k = config.search.default_limit
# Determine which (act_language, pg_ts_config) pairs to search.
# A language override collapses the search to a single language.
lang_override = filter_conditions.get("language") if filter_conditions else None
if language is not None:
# Caller passed an explicit pg_ts config string (e.g. 'french') — honour it.
active_pairs: list[LanguageSearchPair] = [(None, language)]
elif lang_override is not None:
pg_lang = _PG_TS_LANG.get(str(lang_override), _DEFAULT_PG_TS_LANG)
active_pairs = [(lang_override, pg_lang)]
else:
active_pairs = [(lc, _PG_TS_LANG.get(lc, _DEFAULT_PG_TS_LANG)) for lc in _ACTIVE_LANGS]
normalization = config.search.bm25_normalization
def _run_subdiv_search(
session: Session, tsquery: Any, pg_lang: str, act_lang_filter: Optional[str]
) -> List[Any]:
col_suffix = _PG_LANG_TO_COL_SUFFIX.get(pg_lang)
fts_expr = (
sa_column(f"fts_{col_suffix}", TSVECTOR)
if col_suffix
else func.to_tsvector(pg_lang, SubdivisionsSQL.content)
)
q = (
session.query(
SubdivisionsSQL,
ActsSQL,
VersionsSQL,
func.ts_rank_cd(fts_expr, tsquery, normalization).label("rank"),
)
.join(ActsSQL, SubdivisionsSQL.act_id == ActsSQL.id)
.outerjoin(VersionsSQL, SubdivisionsSQL.version_id == VersionsSQL.id)
.filter(fts_expr.op("@@")(tsquery))
)
if act_lang_filter is not None:
q = q.filter(ActsSQL.language == act_lang_filter)
if filter_conditions:
if "act_id" in filter_conditions:
q = q.filter(SubdivisionsSQL.act_id == filter_conditions["act_id"])
if "subdivision_type" in filter_conditions:
q = q.filter(SubdivisionsSQL.subdivision_type == filter_conditions["subdivision_type"])
if "subdivision_number" in filter_conditions:
q = q.filter(SubdivisionsSQL.number == filter_conditions["subdivision_number"])
if "celex" in filter_conditions:
q = q.filter(ActsSQL.celex == filter_conditions["celex"])
if "act_type" in filter_conditions:
q = q.filter(ActsSQL.act_type == filter_conditions["act_type"])
return q.order_by(text("rank DESC")).limit(top_k).all()
def _search_language(act_lang: Optional[str], pg_lang: str) -> List[Any]:
with self.get_session() as session:
rows = _run_subdiv_search(session, func.websearch_to_tsquery(pg_lang, query), pg_lang, act_lang)
if rows:
return rows
return _run_subdiv_search(session, self._build_or_tsquery(session, pg_lang, query), pg_lang, act_lang)
best: Dict[int, Any] = {}
for rows in self._run_language_queries(active_pairs, _search_language):
for subdiv, act, version, rank in rows:
score = float(rank)
if subdiv.id not in best or score > best[subdiv.id]["score"]:
best[subdiv.id] = {
"subdivision": subdiv,
"act": act,
"version": version,
"score": score,
}
return sorted(best.values(), key=lambda r: r["score"], reverse=True)[:top_k]
# === CONVERSATION HELPERS ===
[docs]
@staticmethod
def get_conversation(session: Session, conversation_id: str) -> Optional[ConversationSQL]:
"""Return a conversation session by identifier."""
return session.query(ConversationSQL).filter(ConversationSQL.id == conversation_id).first()
[docs]
@staticmethod
def create_conversation(
session: Session,
conversation_id: str,
title: str,
user_id: Optional[str] = None,
) -> ConversationSQL:
"""Create and flush a new conversation session."""
conv = ConversationSQL(id=conversation_id, title=title, user_id=user_id)
session.add(conv)
session.flush()
return conv
[docs]
@staticmethod
def list_conversations(
session: Session,
user_id: Optional[str],
limit: int = 50,
) -> List[ConversationSQL]:
"""List the most recent conversations for one user."""
return (
session.query(ConversationSQL)
.filter(ConversationSQL.user_id == user_id)
.order_by(ConversationSQL.updated_at.desc())
.limit(limit)
.all()
)
[docs]
@staticmethod
def delete_conversation(session: Session, conversation_id: str) -> bool:
"""Delete one conversation and report whether a row was removed."""
count = session.query(ConversationSQL).filter(ConversationSQL.id == conversation_id).delete()
return count > 0
[docs]
@staticmethod
def add_conversation_message(
session: Session,
message_id: str,
conversation_id: str,
role: str,
content: str,
query_id: Optional[str] = None,
mode: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> ConversationMessageSQL:
"""Append and flush one message inside a conversation."""
msg = ConversationMessageSQL(
id=message_id,
conversation_id=conversation_id,
role=role,
content=content,
query_id=query_id,
mode=mode,
metadata_=metadata,
)
session.add(msg)
session.flush()
return msg
[docs]
@staticmethod
def get_conversation_messages(
session: Session,
conversation_id: str,
limit: int = 20,
) -> List[ConversationMessageSQL]:
"""Return the latest conversation messages in chronological order."""
rows = (
session.query(ConversationMessageSQL)
.filter(ConversationMessageSQL.conversation_id == conversation_id)
.order_by(ConversationMessageSQL.created_at.desc())
.limit(limit)
.all()
)
return list(reversed(rows))
[docs]
@staticmethod
def touch_conversation(session: Session, conversation_id: str) -> None:
"""Refresh the ``updated_at`` timestamp for a conversation."""
session.query(ConversationSQL).filter(
ConversationSQL.id == conversation_id,
).update({"updated_at": func.now()})
# === FULL-TEXT SEARCH (BM25-like — chunks) ===
[docs]
def search_bm25_chunks(
self,
query: str,
top_k: Optional[int] = None,
language: Optional[str] = None,
filter_conditions: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
"""
Full-text search using PostgreSQL ts_rank_cd over chunk content.
Returns chunks with relevance scores for fine-grained retrieval.
Searches across all active languages (FR + EN) in parallel and merges
results by best score, unless a language override is provided via the
``language`` parameter or ``filter_conditions['language']``.
"""
config = get_config()
if top_k is None:
top_k = config.search.default_limit
lang_override = filter_conditions.get("language") if filter_conditions else None
if language is not None:
active_pairs: list[LanguageSearchPair] = [(None, language)]
elif lang_override is not None:
pg_lang = _PG_TS_LANG.get(str(lang_override), _DEFAULT_PG_TS_LANG)
active_pairs = [(lang_override, pg_lang)]
else:
active_pairs = [(lc, _PG_TS_LANG.get(lc, _DEFAULT_PG_TS_LANG)) for lc in _ACTIVE_LANGS]
normalization = config.search.bm25_normalization
def _run_chunk_search(
session: Session, tsquery: Any, pg_lang: str, act_lang_filter: Optional[str]
) -> List[Any]:
col_suffix = _PG_LANG_TO_COL_SUFFIX.get(pg_lang)
fts_expr = (
literal_column(f"chunks.fts_{col_suffix}", TSVECTOR)
if col_suffix
else func.to_tsvector(pg_lang, ChunksSQL.content)
)
q = (
session.query(
ChunksSQL,
SubdivisionsSQL,
ActsSQL,
VersionsSQL,
func.ts_rank_cd(fts_expr, tsquery, normalization).label("rank"),
)
.join(SubdivisionsSQL, ChunksSQL.subdivision_id == SubdivisionsSQL.id)
.join(ActsSQL, SubdivisionsSQL.act_id == ActsSQL.id)
.outerjoin(VersionsSQL, SubdivisionsSQL.version_id == VersionsSQL.id)
.filter(fts_expr.op("@@")(tsquery))
)
if act_lang_filter is not None:
q = q.filter(ActsSQL.language == act_lang_filter)
if filter_conditions:
if "act_id" in filter_conditions:
q = q.filter(SubdivisionsSQL.act_id == filter_conditions["act_id"])
if "subdivision_type" in filter_conditions:
q = q.filter(SubdivisionsSQL.subdivision_type == filter_conditions["subdivision_type"])
if "subdivision_number" in filter_conditions:
q = q.filter(SubdivisionsSQL.number == filter_conditions["subdivision_number"])
if "celex" in filter_conditions:
q = q.filter(ActsSQL.celex == filter_conditions["celex"])
if "subdivision_id" in filter_conditions:
q = q.filter(SubdivisionsSQL.id == filter_conditions["subdivision_id"])
if "chunk_id" in filter_conditions:
q = q.filter(ChunksSQL.id == filter_conditions["chunk_id"])
if "act_type" in filter_conditions:
q = q.filter(ActsSQL.act_type == filter_conditions["act_type"])
return q.order_by(text("rank DESC")).limit(top_k).all()
def _search_language(act_lang: Optional[str], pg_lang: str) -> List[Any]:
with self.get_session() as session:
rows = _run_chunk_search(session, func.websearch_to_tsquery(pg_lang, query), pg_lang, act_lang)
if rows:
return rows
return _run_chunk_search(session, self._build_or_tsquery(session, pg_lang, query), pg_lang, act_lang)
best: Dict[int, Any] = {}
for rows in self._run_language_queries(active_pairs, _search_language):
for chunk, subdiv, act, version, rank in rows:
score = float(rank)
if chunk.id not in best or score > best[chunk.id]["score"]:
best[chunk.id] = {
"chunk": chunk,
"subdivision": subdiv,
"act": act,
"version": version,
"score": score,
}
return sorted(best.values(), key=lambda r: r["score"], reverse=True)[:top_k]