Source code for lalandre_db_qdrant.repository

"""
Qdrant repository implementation
similarity search and vector retrieval
"""

import logging
import time
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast

from lalandre_core.config import get_config
from lalandre_core.repositories.base import BaseRepository
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import Distance, VectorParams

from .models import SearchResult, VectorPoint

logger = logging.getLogger(__name__)


class QdrantRepository(BaseRepository):
    """
    Repository for Qdrant vector database operations.

    Handles all low-level operations including collection management
    and data ingestion.

    Collection naming: {base}_{model}_{dimension}

    Examples:
        - chunk_embeddings_mistral_1024
        - chunk_embeddings_gte_large_1024
        - chunk_embeddings_bge_m3_1024
        - chunk_embeddings_e5_large_1024

    This allows multiple embedding models to coexist without confusion.
    """

    @staticmethod
    def make_collection_name(base_name: str, model_name: str, dimension: int) -> str:
        """Generate collection name: {base}_{model}_{dimension}

        Args:
            base_name: Base collection name
            model_name: Model name (e.g., 'thenlper/gte-large')
            dimension: Vector dimension (e.g., 1024)

        Returns:
            Collection name like 'chunk_embeddings_gte_large_1024'

        Example:
            >>> QdrantRepository.make_collection_name(
            ...     'chunk_embeddings',
            ...     'thenlper/gte-large',
            ...     1024
            ... )
            'chunk_embeddings_gte_large_1024'
        """
        # Simplify model name for collection naming
        model_suffix = model_name.lower()
        model_suffix = model_suffix.replace("/", "_").replace("-", "_")

        # Keep only the meaningful part (last 2 components)
        parts = model_suffix.split("_")
        if len(parts) > 2:
            model_suffix = "_".join(parts[-2:])

        return f"{base_name}_{model_suffix}_{dimension}"

    def __init__(
        self,
        host: Optional[str] = None,
        port: Optional[int] = None,
        collection_name: Optional[str] = None,
        vector_size: Optional[int] = None,
        api_key: Optional[str] = None,
        use_https: Optional[bool] = None,
    ):
        config = get_config()

        # Determine host and port
        final_host = host or config.vector.host
        final_port = port or config.vector.port
        final_api_key = api_key or config.vector.api_key

        # HTTPS toggle:
        # explicit constructor arg > config.vector.use_https > False (local/dev default).
        use_ssl = use_https if use_https is not None else config.vector.use_https

        # Keep API keys off insecure HTTP to avoid noisy/inconsistent behavior.
        api_key_for_client = final_api_key if use_ssl else None
        if not use_ssl and final_api_key:
            logger.info("Qdrant running over HTTP: API key is ignored.")

        self._client_kwargs: dict[str, Any] = {
            "host": final_host,
            "port": final_port,
            "https": use_ssl,
            "api_key": api_key_for_client,
            "timeout": config.vector.timeout,
        }

        # Single initialization path.
        self.client = QdrantClient(**self._client_kwargs)

        if not collection_name:
            raise ValueError("Qdrant collection name is required. Pass collection_name explicitly.")
        self.collection_name: str = collection_name

        # Vector size: explicit param > config > auto-detect from collection
        self._vector_size: int | None = vector_size if vector_size is not None else config.vector.vector_size

    @staticmethod
    def _extract_vector_size_from_collection_info(collection_info: Any) -> int | None:
        """Extract vector size from Qdrant collection info across supported vector configs."""
        config = getattr(collection_info, "config", None)
        params = getattr(config, "params", None)
        vectors = getattr(params, "vectors", None)
        if isinstance(vectors, VectorParams):
            return int(vectors.size)
        if isinstance(vectors, dict):
            named_vectors = cast(dict[object, object], vectors)
            for vector_params in named_vectors.values():
                if isinstance(vector_params, VectorParams):
                    return int(vector_params.size)
        return None

    @property
    def vector_size(self) -> int:
        """Get vector size - auto-detects from existing collection if not set"""
        if self._vector_size is None:
            # Try to detect from existing collection
            try:
                if self.collection_exists():
                    collection_info = self.client.get_collection(self.collection_name)
                    detected_size = self._extract_vector_size_from_collection_info(collection_info)
                    if detected_size is not None:
                        self._vector_size = detected_size
                        msg = (
                            f"[INFO] Auto-detected vector size: {self._vector_size} "
                            f"from collection '{self.collection_name}'"
                        )
                        logger.info(msg)
            except Exception as e:
                logger.warning("Could not auto-detect vector size: %s", e)
            if self._vector_size is None:
                self._vector_size = get_config().vector.vector_size
        return int(self._vector_size)

    @vector_size.setter
    def vector_size(self, value: int):
        """Set vector size explicitly"""
        self._vector_size = value

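    # Illustrative construction sketch (comment only): argument precedence is
    # explicit parameter > config value. The host, port, and collection name
    # below are assumptions, not project defaults.
    #
    #     repo = QdrantRepository(
    #         host="localhost",        # omitted -> config.vector.host
    #         port=6333,               # omitted -> config.vector.port
    #         collection_name="chunk_embeddings_gte_large_1024",
    #         vector_size=1024,        # omitted -> config.vector.vector_size, else auto-detect
    #     )
    #
    # Note that over plain HTTP (use_https falsy), any API key is dropped before
    # the client is built, per the constructor's HTTP/API-key rule above.
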
    def close(self):
        """Close Qdrant client connection"""
        if hasattr(self.client, "close"):
            self.client.close()

    def health_check(self) -> bool:
        """Verify Qdrant connectivity"""
        try:
            self.client.get_collections()
            return True
        except Exception:
            return False

    # === COLLECTION MANAGEMENT ===

    def collection_exists(self) -> bool:
        """Check if collection exists"""
        try:
            collections = self.client.get_collections().collections
            return any(c.name == self.collection_name for c in collections)
        except Exception:
            return False

    def create_collection(self, recreate: bool = False, distance: Distance = Distance.COSINE) -> bool:
        """
        Create the collection if it does not already exist.

        Args:
            recreate: If True, delete existing collection and recreate
            distance: Distance metric (COSINE, EUCLID, DOT)

        Returns:
            True if collection was created, False if already exists
        """
        try:
            # Check if collection exists
            exists = self.collection_exists()

            if exists and recreate:
                logger.info("Deleting existing collection: %s", self.collection_name)
                self.client.delete_collection(collection_name=self.collection_name)
                exists = False

            if not exists:
                logger.info(
                    "Creating collection: %s (vector size: %d)",
                    self.collection_name,
                    self.vector_size,
                )
                self.client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config=VectorParams(size=self.vector_size, distance=distance),
                )
                logger.info(
                    "Collection created successfully with %d-dimensional vectors",
                    self.vector_size,
                )
                self._reconcile_standard_indexes()
                return True
            else:
                # Verify dimension matches
                collection_info = self.client.get_collection(self.collection_name)
                existing_size = self._extract_vector_size_from_collection_info(collection_info)
                if existing_size is None:
                    logger.info("Collection exists: %s (vector size: unknown)", self.collection_name)
                    logger.warning("Could not determine vector size from existing collection metadata.")
                    return False
                if existing_size != self.vector_size:
                    logger.warning("Collection dimension mismatch!")
                    logger.warning("  Expected: %d, Found: %d", self.vector_size, existing_size)
                    logger.warning("  To fix: Set recreate=True or use a different collection name")
                else:
                    logger.info(
                        "Collection exists: %s (vector size: %d)", self.collection_name, existing_size
                    )
                    self._reconcile_standard_indexes()
                return False
        except Exception as e:
            logger.exception("Error creating collection: %s", e)
            raise

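    # Sketch of create_collection semantics (hedged; assumes the repo object from
    # the construction example above):
    #
    #     created = repo.create_collection()               # True only when newly created
    #     created = repo.create_collection(recreate=True)  # drop and recreate, returns True
    #
    # When the collection already exists, the method returns False and warns if the
    # stored vector size differs from repo.vector_size.
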
    @classmethod
    def from_embedding_service_with_auto_collection(
        cls, embedding_service: Any, base_collection_name: Optional[str] = None
    ) -> "QdrantRepository":
        """Create repository with automatic collection naming based on embedding model

        Format: {base}_{model}_{dimension}

        Args:
            embedding_service: EmbeddingService instance
            base_collection_name: Base collection name (required)

        Returns:
            QdrantRepository with auto-generated collection name

        Example:
            >>> embedding_service = EmbeddingService(provider="local", model_name="thenlper/gte-large")
            >>> repo = QdrantRepository.from_embedding_service_with_auto_collection(
            ...     embedding_service, base_collection_name="chunk_embeddings"
            ... )
            >>> # Collection: 'chunk_embeddings_gte_large_1024'
        """
        if not base_collection_name:
            raise ValueError("Base collection name is required. Pass base_collection_name explicitly.")
        base_name = base_collection_name
        vector_size = embedding_service.get_vector_size()
        model_name = embedding_service.model_name

        collection_name = cls.make_collection_name(base_name, model_name, vector_size)
        logger.info("Auto-generated collection: %s", collection_name)

        return cls(collection_name=collection_name, vector_size=vector_size)

    def create_payload_index(self, field_name: str, field_schema: models.PayloadSchemaType) -> bool:
        """
        Create a payload index for efficient filtering.

        Args:
            field_name: Name of the payload field to index
            field_schema: Schema type (KEYWORD, INTEGER, FLOAT, etc.)

        Returns:
            True if index created successfully
        """
        try:
            self.client.create_payload_index(
                collection_name=self.collection_name, field_name=field_name, field_schema=field_schema
            )
            logger.info("Created payload index for field: %s", field_name)
            return True
        except Exception as e:
            logger.warning("Error creating payload index for %s: %s", field_name, e)
            return False

    def setup_standard_indexes(self) -> None:
        """Create standard payload indexes used for legal document filtering."""
        for field_name, field_schema in self._standard_indexes():
            self.create_payload_index(field_name, field_schema)

    def _reconcile_standard_indexes(self) -> None:
        try:
            existing = self.client.get_collection(self.collection_name).payload_schema or {}
        except Exception as e:
            logger.warning("Could not read payload_schema for %s: %s", self.collection_name, e)
            existing = {}

        missing = [name for name, _ in self._standard_indexes() if name not in existing]
        if not missing:
            return

        logger.info("Creating missing payload indexes on %s: %s", self.collection_name, missing)
        schema_by_name = dict(self._standard_indexes())
        for name in missing:
            self.create_payload_index(name, schema_by_name[name])

    @staticmethod
    def _standard_indexes() -> List[Tuple[str, models.PayloadSchemaType]]:
        return [
            ("act_id", models.PayloadSchemaType.INTEGER),
            ("version_id", models.PayloadSchemaType.INTEGER),
            ("chunk_id", models.PayloadSchemaType.INTEGER),
            ("celex", models.PayloadSchemaType.KEYWORD),
            ("eli", models.PayloadSchemaType.KEYWORD),
            ("act_type", models.PayloadSchemaType.KEYWORD),
            ("subdivision_type", models.PayloadSchemaType.KEYWORD),
            ("language", models.PayloadSchemaType.KEYWORD),
            ("is_current_version", models.PayloadSchemaType.BOOL),
            ("has_relations", models.PayloadSchemaType.BOOL),
            ("has_subjects", models.PayloadSchemaType.BOOL),
            ("retrieval_enabled", models.PayloadSchemaType.BOOL),
        ]

    # === UTILITY METHODS ===

    def _build_filter(self, filter_conditions: Dict[str, Any]) -> models.Filter:
        """
        Build Qdrant filter from conditions

        Supports:
            - Exact match: {"key": "value"}
            - List of values: {"key": ["val1", "val2"]} -> OR condition
            - Range queries: {"key": {"gte": 10, "lte": 20}}
            - Null checks: {"key": None} -> field must not exist

        Worked example dictionaries appear after the retry helpers below.

        Args:
            filter_conditions: Dictionary of filter conditions

        Returns:
            Qdrant Filter object
        """
        must_conditions: List[models.Condition] = []

        for key, value in filter_conditions.items():
            if value is None:
                # Null check - field must not exist
                must_conditions.append(models.IsNullCondition(is_null=models.PayloadField(key=key)))
            elif isinstance(value, list):
                # List of values - OR condition
                list_values = cast(list[object], value)
                should_conditions: List[models.Condition] = []
                for item in list_values:
                    if isinstance(item, (str, int, bool)):
                        should_conditions.append(
                            models.FieldCondition(key=key, match=models.MatchValue(value=item))
                        )
                    else:
                        raise TypeError(f"List filter values must be str/int/bool, got {type(item).__name__}")
                must_conditions.append(models.Filter(should=should_conditions))
            elif isinstance(value, dict) and any(k in value for k in ["gte", "gt", "lte", "lt"]):
                # Range query
                range_condition = models.Range()
                if "gte" in value:
                    range_condition.gte = value["gte"]
                if "gt" in value:
                    range_condition.gt = value["gt"]
                if "lte" in value:
                    range_condition.lte = value["lte"]
                if "lt" in value:
                    range_condition.lt = value["lt"]
                must_conditions.append(models.FieldCondition(key=key, range=range_condition))
            else:
                # Exact match
                if not isinstance(value, (str, int, bool)):
                    raise TypeError(f"Exact match filter value for '{key}' must be str/int/bool")
                must_conditions.append(models.FieldCondition(key=key, match=models.MatchValue(value=value)))

        return models.Filter(must=must_conditions)

    # === DATA MANAGEMENT ===

    _RETRY_MAX = 4
    _RETRY_BASE_DELAY = 2.0  # seconds, doubles each attempt
    _DELETE_BY_IDS_BATCH_SIZE = 64
    _DELETE_BY_FILTER_SCROLL_LIMIT = 64

    @staticmethod
    def _is_transient_qdrant_error(exc: Exception) -> bool:
        err_text = str(exc).lower()
        transient_markers = (
            "deadline",
            "timeout",
            "timed out",
            "cancelled",
            "unavailable",
            "temporarily unavailable",
            "connection reset",
            "connection refused",
            "broken pipe",
            "transport error",
            "grpc",
"resource exhausted", "too many requests", ) return any(marker in err_text for marker in transient_markers) def _retry_qdrant_operation(self, operation: str, fn: Any) -> Any: """Retry a Qdrant operation with exponential backoff on transient transport errors.""" for attempt in range(1, self._RETRY_MAX + 1): try: return fn() except Exception as exc: is_transient = self._is_transient_qdrant_error(exc) if not is_transient or attempt == self._RETRY_MAX: raise delay = self._RETRY_BASE_DELAY * (2 ** (attempt - 1)) logger.warning( "Qdrant %s failed (attempt %d/%d): %s — retrying in %.1fs", operation, attempt, self._RETRY_MAX, exc, delay, ) time.sleep(delay) raise RuntimeError("unreachable") # pragma: no cover def _upsert_batch_with_retries( self, batch: Sequence[models.PointStruct], *, operation: str, ) -> int: try: self._retry_qdrant_operation( operation, lambda: self.client.upsert(collection_name=self.collection_name, points=list(batch)), ) return len(batch) except Exception as exc: if len(batch) <= 1 or not self._is_transient_qdrant_error(exc): raise mid = len(batch) // 2 left = list(batch[:mid]) right = list(batch[mid:]) logger.warning( "Qdrant %s failed for %d points after retries: %s — splitting into %d + %d points", operation, len(batch), exc, len(left), len(right), ) left_total = self._upsert_batch_with_retries( left, operation=f"{operation}/left", ) right_total = self._upsert_batch_with_retries( right, operation=f"{operation}/right", ) return left_total + right_total def _delete_batch_with_retries( self, point_ids: Sequence[models.ExtendedPointId], *, operation: str, ) -> int: batch = list(point_ids) try: self._retry_qdrant_operation( operation, lambda: self.client.delete( collection_name=self.collection_name, points_selector=models.PointIdsList(points=batch), ), ) return len(batch) except Exception as exc: if len(batch) <= 1 or not self._is_transient_qdrant_error(exc): raise mid = len(batch) // 2 left = batch[:mid] right = batch[mid:] logger.warning( "Qdrant %s failed for %d points after retries: %s — splitting into %d + %d points", operation, len(batch), exc, len(left), len(right), ) left_total = self._delete_batch_with_retries( left, operation=f"{operation}/left", ) right_total = self._delete_batch_with_retries( right, operation=f"{operation}/right", ) return left_total + right_total def _delete_point_ids_with_retries( self, point_ids: Sequence[models.ExtendedPointId], *, operation: str, ) -> int: total = 0 for i in range(0, len(point_ids), self._DELETE_BY_IDS_BATCH_SIZE): batch = list(point_ids[i : i + self._DELETE_BY_IDS_BATCH_SIZE]) total += self._delete_batch_with_retries( batch, operation=f"{operation} batch {i // self._DELETE_BY_IDS_BATCH_SIZE + 1}", ) return total def _scroll_point_ids_by_filter( self, qdrant_filter: models.Filter, ) -> List[models.ExtendedPointId]: point_ids: List[models.ExtendedPointId] = [] offset: Any = None while True: records, next_offset = self._retry_qdrant_operation( "scroll_delete_candidates", lambda offset=offset: self.client.scroll( collection_name=self.collection_name, scroll_filter=qdrant_filter, limit=self._DELETE_BY_FILTER_SCROLL_LIMIT, offset=offset, with_payload=False, with_vectors=False, ), ) for record in records: record_id = getattr(record, "id", None) if record_id is not None: point_ids.append(cast(models.ExtendedPointId, record_id)) if next_offset is None: break offset = next_offset return point_ids
    def upsert_points(
        self, points: Sequence[Union[VectorPoint, models.PointStruct]], batch_size: Optional[int] = None
    ) -> int:
        """
        Insert or update points in the collection with retry on transient transport errors.

        Args:
            points: List of VectorPoint or PointStruct objects to upsert
            batch_size: Optional batch size for large inserts

        Returns:
            Number of points upserted
        """
        try:
            # Convert VectorPoint to PointStruct if needed
            qdrant_points: List[models.PointStruct] = []
            for point in points:
                if isinstance(point, VectorPoint):
                    qdrant_points.append(point.to_qdrant_point())
                else:
                    qdrant_points.append(point)

            if batch_size and len(qdrant_points) > batch_size:
                # Batch processing for large datasets
                total = 0
                for i in range(0, len(qdrant_points), batch_size):
                    batch = qdrant_points[i : i + batch_size]
                    total += self._upsert_batch_with_retries(
                        batch,
                        operation=f"upsert batch {i // batch_size + 1}",
                    )
                    logger.info("Upserted batch %d: %d points", i // batch_size + 1, len(batch))
                return total
            else:
                # Single batch
                return self._upsert_batch_with_retries(qdrant_points, operation="upsert")
        except Exception as e:
            logger.exception("Error upserting points: %s", e)
            raise

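    # Hedged upsert sketch: PointStruct fields are the standard qdrant-client ones;
    # the id, vector, and payload values are assumptions.
    #
    #     points = [
    #         models.PointStruct(
    #             id=42,
    #             vector=[0.1] * repo.vector_size,
    #             payload={"act_id": 7, "language": "en"},
    #         ),
    #     ]
    #     repo.upsert_points(points, batch_size=256)  # batches of 256, split on transient failure
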
    def delete_points(self, point_ids: List[Union[str, int]]) -> int:
        """Delete points by IDs from the current collection."""
        if not point_ids:
            return 0
        try:
            typed_point_ids = cast(List[models.ExtendedPointId], point_ids)
            return self._delete_point_ids_with_retries(typed_point_ids, operation="delete")
        except Exception as e:
            logger.exception("Error deleting points: %s", e)
            raise

    def delete_points_by_filter(
        self,
        query_filter: Union[Dict[str, Any], models.Filter],
    ) -> int:
        """Delete points matching a payload filter from the current collection."""
        try:
            if isinstance(query_filter, models.Filter):
                qdrant_filter = query_filter
            else:
                qdrant_filter = self._build_filter(query_filter)

            try:
                point_ids = self._scroll_point_ids_by_filter(qdrant_filter)
            except Exception as exc:
                logger.warning(
                    "Qdrant delete_by_filter prefetch failed for collection %s: %s — "
                    "falling back to server-side delete",
                    self.collection_name,
                    exc,
                )
            else:
                if not point_ids:
                    return 0
                return self._delete_point_ids_with_retries(
                    point_ids,
                    operation="delete_by_filter ids",
                )

            self._retry_qdrant_operation(
                "delete_by_filter",
                lambda: self.client.delete(
                    collection_name=self.collection_name,
                    points_selector=models.FilterSelector(filter=qdrant_filter),
                ),
            )
            return 0
        except Exception as e:
            logger.exception("Error deleting points by filter: %s", e)
            raise

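    # Usage sketch: the happy path scrolls matching IDs and deletes them in bounded
    # batches, returning the count; only if that prefetch fails does it fall back to
    # a single server-side FilterSelector delete, which returns 0 (count unknown).
    #
    #     removed = repo.delete_points_by_filter({"act_id": 7, "is_current_version": False})
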
    # === SEARCH AND RETRIEVAL ===

    def search(
        self,
        query_vector: List[float],
        limit: Optional[int] = None,
        score_threshold: Optional[float] = None,
        query_filter: Optional[Union[Dict[str, Any], models.Filter]] = None,
        hnsw_ef: Optional[int] = None,
        exact: Optional[bool] = None,
    ) -> List[SearchResult]:
        """
        Search for similar vectors

        Args:
            query_vector: The query embedding
            limit: Maximum number of results (None -> config.search.default_limit)
            score_threshold: Minimum similarity score (0-1)
            query_filter: Metadata filters (e.g., {"document_type": "directive"})
            hnsw_ef: ANN breadth parameter for HNSW search (None -> config.search.hnsw_ef)
            exact: Force exact vector search (None -> config.search.exact_search)
        """
        try:
            search_cfg = get_config().search
            resolved_limit = max(int(limit), 1) if limit is not None else max(int(search_cfg.default_limit), 1)
            resolved_hnsw_ef = hnsw_ef if hnsw_ef is not None else search_cfg.hnsw_ef
            resolved_exact = bool(exact) if exact is not None else bool(search_cfg.exact_search)

            # Build filter if conditions provided
            if query_filter is None:
                qdrant_filter = None
            elif isinstance(query_filter, models.Filter):
                qdrant_filter = query_filter
            else:
                qdrant_filter = self._build_filter(query_filter)

            search_params: Optional[models.SearchParams] = None
            search_params_kwargs: Dict[str, Any] = {}
            if resolved_hnsw_ef is not None:
                search_params_kwargs["hnsw_ef"] = max(int(resolved_hnsw_ef), 1)
            search_params_kwargs["exact"] = resolved_exact
            if search_params_kwargs:
                search_params = models.SearchParams(**search_params_kwargs)

            # Perform search
            search_result = self.client.query_points(
                collection_name=self.collection_name,
                query=query_vector,
                limit=resolved_limit,
                score_threshold=score_threshold,
                query_filter=qdrant_filter,
                search_params=search_params,
                with_payload=True,
                with_vectors=False,
            ).points
        except UnexpectedResponse as e:
            if e.status_code == 404:
                logger.warning(
                    "Qdrant collection '%s' not found, returning empty results", self.collection_name
                )
                return []
            raise

        # Convert to our SearchResult format.
        # Qdrant returns a relevance score where higher is better.
        results: List[SearchResult] = []
        for hit in search_result:
            payload_obj = hit.payload
            payload: Dict[str, Any]
            if isinstance(payload_obj, dict):
                payload = dict(payload_obj)
            else:
                payload = {}
            results.append(SearchResult(id=str(hit.id), score=float(hit.score), payload=payload))
        return results

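    # Illustrative search call (filter keys and parameter values are assumptions):
    #
    #     hits = repo.search(
    #         query_vector=embedding,   # List[float] of length repo.vector_size
    #         limit=10,
    #         score_threshold=0.35,
    #         query_filter={"language": "en", "retrieval_enabled": True},
    #         hnsw_ef=128,              # wider ANN search, higher recall
    #     )
    #     for hit in hits:
    #         print(hit.id, hit.score, hit.payload.get("celex"))
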
    def retrieve_vectors_by_ids(
        self,
        point_ids: List[int],
    ) -> Dict[int, List[float]]:
        """Retrieve stored embedding vectors by point IDs.

        Returns a dict mapping point_id -> vector. Points not found in the
        collection are silently omitted.
        """
        if not point_ids:
            return {}
        try:
            records = self.client.retrieve(
                collection_name=self.collection_name,
                ids=point_ids,
                with_vectors=True,
                with_payload=False,
            )
            result: Dict[int, List[float]] = {}
            for record in records:
                raw_vec = record.vector
                if raw_vec is not None and isinstance(raw_vec, list):
                    result[int(record.id)] = [float(v) for v in raw_vec]  # type: ignore[arg-type]
            return result
        except Exception as e:
            logger.exception("Error retrieving vectors by IDs: %s", e)
            raise

    # === COLLECTION INFO AND STATISTICS ===

    def get_collection_info(self) -> Dict[str, Any]:
        """Return collection metadata and point counts for the active collection."""
        try:
            info = self.client.get_collection(collection_name=self.collection_name)
            vector_size = self.vector_size

            # Prefer the vector size reported by the collection itself, if available
            extracted_size = self._extract_vector_size_from_collection_info(info)
            if extracted_size is not None:
                vector_size = extracted_size

            return {
                "name": self.collection_name,
                "points_count": getattr(info, "points_count", 0),
                "status": getattr(info, "status", "unknown"),
                "vector_size": vector_size,
            }
        except Exception as e:
            return {"error": str(e)}

    def get_statistics(self) -> Dict[str, Any]:
        """Get vector database statistics"""
        try:
            info = self.get_collection_info()
            if "error" in info:
                return {"error": info["error"], "database_type": "qdrant"}
            return {
                "total_points": info.get("points_count", 0),
                "collection_name": self.collection_name,
                "vector_size": info.get("vector_size", self.vector_size),
                "status": info.get("status", "unknown"),
                "database_type": "qdrant",
            }
        except Exception as e:
            return {"error": str(e), "database_type": "qdrant"}