"""
Bootstrap for RAG service
"""
import logging
from dataclasses import dataclass
from threading import Lock
from typing import Any, Callable, Dict, Optional
from lalandre_core.config import get_config
from lalandre_core.embedding_presets import (
get_default_embedding_preset,
list_embedding_presets,
)
from lalandre_core.linking import LegalEntityLinker
from lalandre_core.repositories.common import PayloadBuilder
from lalandre_core.utils.api_key_pool import APIKeyPool
from lalandre_db_neo4j import Neo4jRepository
from lalandre_db_postgres import PostgresRepository
from lalandre_db_qdrant import QdrantRepository
from lalandre_embedding import EmbeddingService
from lalandre_rag import RAGService
from lalandre_rag.graph import GraphRAGService
from lalandre_rag.linker_factory import (
build_external_detector,
)
from lalandre_rag.linker_factory import (
build_linker as build_entity_linker,
)
from lalandre_rag.retrieval import RetrievalService
from lalandre_rag.retrieval.context.service import ContextService
from lalandre_rag.retrieval.rerank_service import RerankConfig, RerankService
from lalandre_rag.summaries import ActSummaryService
from rag_service.conversation import ConversationManager
# Module-wide logger for bootstrap progress and degraded-mode warnings.
logger: logging.Logger = logging.getLogger(__name__)
# Serializes lazy GraphRAGService construction in ensure_graph_rag_service.
_graph_init_lock = Lock()
[docs]
@dataclass
class RagComponents:
    """Long-lived dependencies exposed by the rag-service bootstrap.

    Instances are produced by ``init_components``. ``graph_rag_service`` may be
    None at construction time; ``ensure_graph_rag_service`` assigns it in place
    once graph initialization succeeds, so instances are mutated after creation.
    """

    rag_service: RAGService
    retrieval_service: RetrievalService
    context_service: ContextService
    # None when graph initialization failed; filled in lazily by ensure_graph_rag_service.
    graph_rag_service: Optional[GraphRAGService]
    conversation_manager: ConversationManager
    pg_repo: PostgresRepository
    # Shared API key pool; None when APIKeyPool.from_env raised ValueError (single-key mode).
    key_pool: Optional[APIKeyPool] = None
    # Legal entity linker used for prose linking; None when it could not be built.
    entity_linker: Optional[LegalEntityLinker] = None
    # Optional external NER detector taking a text string; None means regex-only linking.
    external_detector: Optional[Callable[[str], Any]] = None
def _require_search_config(value: Any, field_name: str) -> Any:
    """Return ``value`` unchanged, refusing to proceed when it is unset.

    Search settings have no silent defaults: ``None`` means the key is absent
    from the configuration. Other falsy values (``0``, ``""``) are accepted.

    Args:
        value: The configured value read from ``config.search``.
        field_name: Name of the setting under ``search.``, used in messages.

    Returns:
        ``value`` itself when it is not None.

    Raises:
        ValueError: If ``value`` is None (the problem is also logged).
    """
    if value is not None:
        return value
    logger.error(
        "Missing required search config: search.%s. Set it in app_config.yaml.",
        field_name,
    )
    raise ValueError(f"Missing required search config: search.{field_name}")
[docs]
def ensure_graph_rag_service(components: "RagComponents") -> Optional[GraphRAGService]:
    """Return the graph RAG service, constructing it if it does not exist yet.

    When ``components.graph_rag_service`` is still None (for example because
    Neo4j was unavailable at startup), another attempt is made to build a
    ``GraphRAGService``. It is backed by a new ``Neo4jRepository`` and the
    retrieval service's ``"chunks"`` Qdrant repository. On success the service
    is stored back on ``components`` and returned. On failure a warning is
    logged and None is returned, so graph mode can recover on a later call.

    Construction happens under a module-level lock. The attribute is re-read
    after the lock is acquired, so concurrent callers do not build two services.
    """
    cached = components.graph_rag_service
    if cached is not None:
        return cached
    retrieval = components.retrieval_service
    with _graph_init_lock:
        # Re-read under the lock: another caller may have completed the build while we waited.
        cached = components.graph_rag_service
        if cached is not None:
            return cached
        chunks_repo = retrieval.qdrant_repos.get("chunks")
        if chunks_repo is None:
            logger.warning("Graph RAG unavailable: chunks repository missing")
            return None
        try:
            graph_config = get_config().graph
            built = GraphRAGService(
                neo4j_repo=Neo4jRepository(graph_config),
                qdrant_repo=chunks_repo,
                embedding_service=retrieval.embedding_service,
                key_pool=components.key_pool,
            )
            components.graph_rag_service = built
            logger.info("Graph RAG initialized successfully")
            return built
        except Exception as exc:
            logger.warning("Graph RAG remains unavailable: %s", exc)
            return None
[docs]
def _build_key_pool(config: Any) -> "Optional[APIKeyPool]":
    """Create the shared API key pool, or return None to fall back to a single key.

    Keys are loaded by ``APIKeyPool.from_env("MISTRAL_API_KEY", ...)``, capped at
    ``config.generation.key_pool_max``. A ValueError from it means "no pool".
    """
    key_pool_max = config.generation.key_pool_max
    logger.info("RAG bootstrap: initializing API key pool for RAG (keys 1-%d)", key_pool_max)
    try:
        key_pool = APIKeyPool.from_env("MISTRAL_API_KEY", max_keys=key_pool_max)
    except ValueError:
        logger.warning("RAG bootstrap: no key pool available, using single key")
        return None
    logger.info(
        "RAG bootstrap: key pool ready with %d key(s) (max=%d, remaining for workers)",
        len(key_pool),
        key_pool_max,
    )
    return key_pool


def _build_preset_backends(
    config: Any,
    key_pool: "Optional[APIKeyPool]",
) -> "tuple[Dict[str, EmbeddingService], Dict[str, QdrantRepository], Dict[str, Exception]]":
    """Initialize the embedding service and Qdrant repositories of every enabled preset.

    Best-effort: a preset that fails is logged and skipped. A preset is only
    registered once its embedding service AND both Qdrant repositories
    (``chunks__<id>`` and ``acts__<id>``) were created, so no half-initialized
    preset is ever exposed to the retrieval layer.

    Returns:
        ``(embedding services by preset id, Qdrant repos by key, init errors by preset id)``.
    """
    embedding_services: Dict[str, EmbeddingService] = {}
    qdrant_repos: Dict[str, QdrantRepository] = {}
    failures: Dict[str, Exception] = {}
    logger.info("RAG bootstrap: initializing preset embedding services")
    for preset in list_embedding_presets(enabled_only=True):
        preset_id = preset.preset_id
        try:
            svc = EmbeddingService(
                provider=preset.provider,
                model_name=preset.model_name,
                device=preset.device,
                key_pool=key_pool,
            )
            chunks_repo = QdrantRepository.from_embedding_service_with_auto_collection(
                embedding_service=svc,
                base_collection_name=config.vector.collection_chunks,
            )
            acts_repo = QdrantRepository.from_embedding_service_with_auto_collection(
                embedding_service=svc,
                base_collection_name=config.vector.collection_acts,
            )
        except Exception as exc:
            # Keep the error so the caller can chain it if this preset turns out to be required.
            failures[preset_id] = exc
            logger.warning("RAG bootstrap: failed to init preset '%s': %s", preset_id, exc)
            continue
        # Register atomically: only presets whose three objects were all built.
        embedding_services[preset_id] = svc
        qdrant_repos[f"chunks__{preset_id}"] = chunks_repo
        qdrant_repos[f"acts__{preset_id}"] = acts_repo
        logger.info("RAG bootstrap: preset '%s' ready (provider=%s)", preset_id, preset.provider)
    return embedding_services, qdrant_repos, failures


def _build_reranker(config: Any) -> "RerankService":
    """Build the reranker from ``config.search``, loading the local model only when needed."""
    logger.info("RAG bootstrap: initializing reranker")
    search_cfg = config.search
    rerank_service_url = search_cfg.rerank_service_url
    reranker = RerankService(
        RerankConfig(
            model_name=search_cfg.rerank_model,
            device=search_cfg.rerank_device,
            batch_size=search_cfg.rerank_batch_size,
            max_candidates=search_cfg.rerank_max_candidates,
            max_chars=search_cfg.rerank_max_chars,
            enabled=search_cfg.rerank_enabled,
            cache_dir=search_cfg.rerank_cache_dir or config.models_cache_dir,
            rerank_service_url=rerank_service_url,
            service_timeout_seconds=search_cfg.rerank_service_timeout_seconds,
            fallback_to_skip=search_cfg.rerank_fallback_to_skip,
            circuit_failure_threshold=search_cfg.rerank_circuit_failure_threshold,
            circuit_cooldown_seconds=search_cfg.rerank_circuit_cooldown_seconds,
        )
    )
    if search_cfg.rerank_enabled and not rerank_service_url:
        # Only load the local model when no HTTP service is configured
        reranker.load()
    return reranker


def _build_entity_linking(
    config: Any,
    pg_repo: "PostgresRepository",
) -> "tuple[Optional[LegalEntityLinker], Optional[Callable[[str], Any]]]":
    """Build the legal entity linker and its optional external NER detector.

    Both steps are best-effort. A failing (or absent) linker yields ``(None, None)``,
    which disables prose linking. A failing detector keeps the linker with regex-only linking.
    """
    logger.info("RAG bootstrap: initializing legal entity linker")
    entity_linker: Optional[LegalEntityLinker]
    try:
        extraction_cfg = config.extraction
        entity_linker = build_entity_linker(
            pg_repo,
            fuzzy_threshold=extraction_cfg.entity_linker_fuzzy_threshold,
            fuzzy_min_gap=extraction_cfg.entity_linker_fuzzy_min_gap,
            fuzzy_limit=extraction_cfg.entity_linker_fuzzy_limit,
            min_alias_chars=extraction_cfg.entity_linker_min_alias_chars,
        )
    except Exception as exc:
        logger.warning(
            "RAG bootstrap: entity linker unavailable (%s) — prose linking disabled",
            exc,
        )
        return None, None
    if entity_linker is None:
        return None, None
    try:
        external_detector = build_external_detector(entity_linker)
    except Exception:
        logger.warning(
            "RAG bootstrap: external NER detector init failed — keeping regex-only linking",
            exc_info=True,
        )
        return entity_linker, None
    return entity_linker, external_detector


def init_components() -> "RagComponents":
    """Build every long-lived rag-service dependency and return them as RagComponents.

    Order: validate required ``search.*`` settings, create the PostgreSQL
    repository, build the optional API key pool, initialize the embedding
    service and Qdrant repositories of each enabled preset, then the reranker,
    the retrieval and context services, the graph service (best-effort) and
    the legal entity linker (best-effort). Last come the RAG service and the
    conversation manager.

    Returns:
        A populated RagComponents. ``graph_rag_service``, ``key_pool``,
        ``entity_linker`` and ``external_detector`` may be None when the
        corresponding optional dependency could not be initialized.

    Raises:
        ValueError: If ``search.lexical_weight``, ``search.semantic_weight`` or
            ``search.fusion_method`` is missing, or if the default embedding
            preset could not be initialized. Its own init error is chained as
            ``__cause__`` when available.
        Errors from non-best-effort steps (PostgreSQL, reranker model load,
        service constructors) propagate unchanged.
    """
    logger.info("RAG bootstrap: loading config")
    config = get_config()

    # Validate required search settings before any connection or model load, so a
    # misconfigured app_config.yaml fails fast instead of after expensive startup work.
    lexical_weight = _require_search_config(config.search.lexical_weight, "lexical_weight")
    semantic_weight = _require_search_config(config.search.semantic_weight, "semantic_weight")
    fusion_method = _require_search_config(config.search.fusion_method, "fusion_method")

    logger.info("RAG bootstrap: initializing PostgreSQL repository")
    pg_repo = PostgresRepository(config.database.connection_string)

    key_pool = _build_key_pool(config)

    preset_embedding_services, qdrant_repos, preset_failures = _build_preset_backends(config, key_pool)
    default_preset_id = get_default_embedding_preset().preset_id
    embedding_service = preset_embedding_services.get(default_preset_id)
    if embedding_service is None:
        # Chain the preset's own init error (None if it was never attempted) to keep the root cause.
        raise ValueError(
            f"Default embedding preset '{default_preset_id}' is not available"
        ) from preset_failures.get(default_preset_id)
    # Registration is atomic per preset, so both default repositories exist here.
    # chunks: token-bounded retrieval (1 vector/chunk)
    # acts: document-level routing (1 mean-pooled vector/act)
    qdrant_repos["chunks"] = qdrant_repos[f"chunks__{default_preset_id}"]
    qdrant_repos["acts"] = qdrant_repos[f"acts__{default_preset_id}"]

    reranker = _build_reranker(config)

    logger.info("RAG bootstrap: initializing retrieval service")
    retrieval_service = RetrievalService(
        qdrant_repos=qdrant_repos,
        pg_repo=pg_repo,
        embedding_service=embedding_service,
        reranker=reranker,
        payload_builder=PayloadBuilder(),
        search_language=config.search.fulltext_language,
        lexical_weight=lexical_weight,
        semantic_weight=semantic_weight,
        fusion_method=fusion_method,
        preset_embedding_services=preset_embedding_services,
    )

    logger.info("RAG bootstrap: initializing context service")
    context_service = ContextService(pg_repo=pg_repo)

    logger.info("RAG bootstrap: initializing graph service lazily")
    # Resolve the graph service before building RAGService so it can be injected into HybridMode.
    # ensure_graph_rag_service operates on a RagComponents, so build a provisional one with what
    # exists so far; rag_service and conversation_manager are placeholders not read by it.
    provisional = RagComponents(
        rag_service=None,  # type: ignore[arg-type]
        retrieval_service=retrieval_service,
        context_service=context_service,
        graph_rag_service=None,
        conversation_manager=None,  # type: ignore[arg-type]
        pg_repo=pg_repo,
        key_pool=key_pool,
    )
    # None here means graph mode is disabled until a later lazy retry succeeds.
    graph_rag_service = ensure_graph_rag_service(provisional)

    entity_linker, external_detector = _build_entity_linking(config, pg_repo)

    logger.info("RAG bootstrap: initializing RAG service")
    act_summary_service = ActSummaryService(
        pg_repo=pg_repo,
        model_id=ActSummaryService.build_runtime_model_id(),
    )
    rag_service = RAGService(
        retrieval_service=retrieval_service,
        context_service=context_service,
        graph_rag_service=graph_rag_service,
        key_pool=key_pool,
        act_summary_service=act_summary_service,
        entity_linker=entity_linker,
        external_detector=external_detector,
    )

    logger.info("RAG bootstrap: initializing conversation manager")
    conversation_manager = ConversationManager(pg_repo=pg_repo, llm=rag_service.llm)

    components = RagComponents(
        rag_service=rag_service,
        retrieval_service=retrieval_service,
        context_service=context_service,
        graph_rag_service=graph_rag_service,
        conversation_manager=conversation_manager,
        pg_repo=pg_repo,
        key_pool=key_pool,
        entity_linker=entity_linker,
        external_detector=external_detector,
    )
    logger.info("RAG bootstrap: startup completed")
    return components