"""POST /search endpoint — hybrid/semantic/lexical search handler."""
import logging
import time
import uuid
from typing import Any, Dict, List, Optional, cast
from fastapi import APIRouter, Depends, HTTPException
from lalandre_core.config import get_config
from lalandre_core.utils import as_dict
from lalandre_rag.response import build_source_trace
from rag_service.bootstrap import RagComponents
from rag_service.metrics import observe_provider_error, observe_search_request
from rag_service.models import SearchRequest, SearchResponse, SearchResult
from rag_service.routers._deps import get_components
logger = logging.getLogger(__name__)
router = APIRouter()
_ALLOWED_SEARCH_GRANULARITIES: set[Optional[str]] = {None, "chunks", "all"}
_VALID_SEARCH_MODES = ["semantic", "lexical", "hybrid"]
def apply_config_defaults_search(search_request: SearchRequest) -> None:
"""Fill None fields with config defaults (mutates in place)."""
cfg = get_config().search
if search_request.mode is None:
search_request.mode = cfg.default_search_mode
if search_request.top_k is None:
search_request.top_k = cfg.default_limit
if search_request.granularity is None:
search_request.granularity = cfg.default_granularity
@router.post("/search", response_model=SearchResponse)
async def search(
search_request: SearchRequest,
components: RagComponents = Depends(get_components),
) -> SearchResponse:
"""
    Search handler supporting three modes: semantic, lexical, or hybrid
    (semantic + lexical).

    ``score_threshold`` is interpreted on a normalized [0, 1] scale across
    modes; lexical BM25 and hybrid RRF scores are normalized before filtering.
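
    Example request body (illustrative values)::

        {
            "query": "personal data breach notification",
            "mode": "hybrid",
            "top_k": 5,
            "score_threshold": 0.4
        }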
"""
search_id = str(uuid.uuid4())
started_at = time.perf_counter()
try:
retrieval_service = components.retrieval_service
context_service = components.context_service
apply_config_defaults_search(search_request)
assert search_request.top_k is not None # set by apply_config_defaults_search
if search_request.mode not in _VALID_SEARCH_MODES:
raise HTTPException(
status_code=400,
detail=f"Invalid mode '{search_request.mode}'. Must be one of: {', '.join(_VALID_SEARCH_MODES)}",
)
        actual_mode = search_request.mode
        # The cast(str, ...) calls below assume query text for non-semantic modes.
        if actual_mode != "semantic" and not search_request.query:
            raise HTTPException(
                status_code=400,
                detail=f"A text query is required for '{actual_mode}' mode.",
            )
        if search_request.granularity not in _ALLOWED_SEARCH_GRANULARITIES:
            raise HTTPException(
                status_code=400,
                detail="Invalid granularity. Must be one of: chunks, all.",
            )
effective_granularity = search_request.granularity
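        # Lexical search only supports chunk granularity, so override any
        # requested granularity for that mode.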
if actual_mode == "lexical":
effective_granularity = "chunks"
if actual_mode == "semantic":
results = retrieval_service.semantic_only(
query=search_request.query if not search_request.query_embedding else None,
query_vector=search_request.query_embedding,
top_k=search_request.top_k,
score_threshold=search_request.score_threshold,
filters=search_request.filters,
granularity=effective_granularity,
embedding_preset=search_request.embedding_preset,
)
elif actual_mode == "lexical":
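            # BM25 path: embeddings and presets are not involved, and
            # granularity is already pinned to chunks above.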
results = retrieval_service.lexical_only(
query=cast(str, search_request.query),
top_k=search_request.top_k,
score_threshold=search_request.score_threshold,
filters=search_request.filters,
)
else: # hybrid
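            # Hybrid fuses semantic and lexical rankings (RRF, per the
            # docstring); use the caller-supplied embedding when present,
            # otherwise fall back to the plain retrieve() path.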
if search_request.query_embedding:
results = retrieval_service.hybrid_with_embedding(
query=cast(str, search_request.query),
query_vector=search_request.query_embedding,
top_k=search_request.top_k,
score_threshold=search_request.score_threshold,
filters=search_request.filters,
granularity=effective_granularity,
embedding_preset=search_request.embedding_preset,
)
else:
results = retrieval_service.retrieve(
query=cast(str, search_request.query),
top_k=search_request.top_k,
score_threshold=search_request.score_threshold,
filters=search_request.filters,
granularity=effective_granularity,
embedding_preset=search_request.embedding_preset,
)
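        # Hydrate raw hits with act/document context; relation and subject
        # enrichment is skipped for plain search responses.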
context_slices = context_service.enrich_results(
results,
include_relations=False,
include_subjects=False,
)
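        # Shape each enriched hit into a SearchResult: drop raw content from
        # the metadata payload and backfill identifying fields it may lack.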
response_results: List[SearchResult] = []
for doc in context_slices:
metadata: Dict[str, Any] = dict(as_dict(doc.doc.payload))
metadata.pop("content", None)
metadata.setdefault("act_id", doc.act.act_id)
metadata.setdefault("act_type", doc.act.act_type)
metadata.setdefault("subdivision_type", doc.doc.subdivision_type)
metadata.setdefault("sequence_order", doc.doc.sequence_order)
metadata.setdefault("url", doc.act.url_eurlex or "")
metadata.setdefault("source_kind", doc.doc.source_kind)
metadata.setdefault("content_length", len(doc.content) if doc.content else 0)
response_results.append(
SearchResult(
celex=doc.act.celex,
subdivision_id=doc.doc.subdivision_id,
chunk_id=doc.doc.chunk_id,
chunk_index=doc.doc.chunk_index,
content=doc.content if search_request.include_full_content else "",
score=doc.score,
metadata=metadata,
trace={
"search_id": search_id,
**build_source_trace(doc.trace or metadata),
},
)
)
response = SearchResponse(
search_id=search_id,
results=response_results,
total=len(response_results),
mode=actual_mode,
)
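        # Record success metrics using the effective (post-default) parameters.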
observe_search_request(
mode=actual_mode,
granularity=effective_granularity,
top_k=search_request.top_k,
duration_seconds=time.perf_counter() - started_at,
outcome="success",
)
return response
except HTTPException as exc:
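        # Validation failures are still recorded, bucketed by status class;
        # only 5xx responses additionally count as provider errors.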
observe_search_request(
mode=search_request.mode,
granularity=search_request.granularity,
top_k=search_request.top_k,
duration_seconds=time.perf_counter() - started_at,
outcome="client_error" if exc.status_code < 500 else "server_error",
)
if exc.status_code >= 500:
observe_provider_error(
mode=search_request.mode,
stage="search_request",
exc_or_reason=exc.detail,
)
raise
except Exception as e:
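        # Unexpected failure: record metrics, log the full traceback, and
        # return a generic 500 so internal details do not leak to the client.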
observe_search_request(
mode=search_request.mode,
granularity=search_request.granularity,
top_k=search_request.top_k,
duration_seconds=time.perf_counter() - started_at,
outcome="server_error",
)
observe_provider_error(
mode=search_request.mode,
stage="search_request",
exc_or_reason=e,
)
logger.exception("Search failed")
        raise HTTPException(
            status_code=500,
            detail="An error occurred while processing your search",
        ) from e