"""
Qdrant repository implementation
similarity search and vector retrieval
"""
import logging
import time
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
from lalandre_core.config import get_config
from lalandre_core.repositories.base import BaseRepository
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import Distance, VectorParams
from .models import SearchResult, VectorPoint
logger = logging.getLogger(__name__)
class QdrantRepository(BaseRepository):
"""
Repository for Qdrant vector database operations.
Handles all low-level operations, including collection management,
data ingestion, and similarity search.
Collection naming: {base}_{model}_{dimension}
Examples:
- chunk_embeddings_mistral_1024
- chunk_embeddings_gte_large_1024
- chunk_embeddings_bge_m3_1024
- chunk_embeddings_e5_large_1024
This allows multiple embedding models to coexist without confusion.
"""
@staticmethod
def make_collection_name(base_name: str, model_name: str, dimension: int) -> str:
"""Generate collection name: {base}_{model}_{dimension}
Args:
base_name: Base collection name
model_name: Model name (e.g., 'thenlper/gte-large')
dimension: Vector dimension (e.g., 1024)
Returns:
Collection name like 'chunk_embeddings_gte_large_1024'
Example:
>>> QdrantRepository.make_collection_name(
... 'chunk_embeddings',
... 'thenlper/gte-large',
... 1024
... )
'chunk_embeddings_gte_large_1024'
"""
# Simplify model name for collection naming
model_suffix = model_name.lower()
model_suffix = model_suffix.replace("/", "_").replace("-", "_")
# Keep only the meaningful part (last 2 components)
parts = model_suffix.split("_")
if len(parts) > 2:
model_suffix = "_".join(parts[-2:])
return f"{base_name}_{model_suffix}_{dimension}"
def __init__(
self,
host: Optional[str] = None,
port: Optional[int] = None,
collection_name: Optional[str] = None,
vector_size: Optional[int] = None,
api_key: Optional[str] = None,
use_https: Optional[bool] = None,
):
config = get_config()
# Determine host and port
final_host = host or config.vector.host
final_port = port or config.vector.port
final_api_key = api_key or config.vector.api_key
# HTTPS toggle:
# explicit constructor arg > config.vector.use_https > False (local/dev default).
use_ssl = use_https if use_https is not None else config.vector.use_https
# Keep API keys off insecure HTTP to avoid noisy/inconsistent behavior.
api_key_for_client = final_api_key if use_ssl else None
if not use_ssl and final_api_key:
logger.info("Qdrant running over HTTP: API key is ignored.")
self._client_kwargs: dict[str, Any] = {
"host": final_host,
"port": final_port,
"https": use_ssl,
"api_key": api_key_for_client,
"timeout": config.vector.timeout,
}
# Single initialization path.
self.client = QdrantClient(**self._client_kwargs)
if not collection_name:
raise ValueError("Qdrant collection name is required. Pass collection_name explicitly.")
self.collection_name: str = collection_name
# Vector size: explicit param > config > auto-detect from collection
self._vector_size: int | None = vector_size if vector_size is not None else config.vector.vector_size
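# Illustrative construction (host/port/collection values are hypothetical):
#   repo = QdrantRepository(
#       host="localhost",
#       port=6333,
#       collection_name="chunk_embeddings_gte_large_1024",
#       vector_size=1024,
#       use_https=False,  # an api_key passed alongside plain HTTP would be ignored
#   )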
@staticmethod
def _extract_vector_size_from_collection_info(collection_info: Any) -> int | None:
"""Extract vector size from Qdrant collection info across supported vector configs."""
config = getattr(collection_info, "config", None)
params = getattr(config, "params", None)
vectors = getattr(params, "vectors", None)
if isinstance(vectors, VectorParams):
return int(vectors.size)
if isinstance(vectors, dict):
named_vectors = cast(dict[object, object], vectors)
for vector_params in named_vectors.values():
if isinstance(vector_params, VectorParams):
return int(vector_params.size)
return None
@property
def vector_size(self) -> int:
"""Get vector size - auto-detects from existing collection if not set"""
if self._vector_size is None:
# Try to detect from existing collection
try:
if self.collection_exists():
collection_info = self.client.get_collection(self.collection_name)
detected_size = self._extract_vector_size_from_collection_info(collection_info)
if detected_size is not None:
self._vector_size = detected_size
logger.info(
"Auto-detected vector size: %s from collection '%s'",
self._vector_size,
self.collection_name,
)
except Exception as e:
logger.warning("Could not auto-detect vector size: %s", e)
if self._vector_size is None:
self._vector_size = get_config().vector.vector_size
return int(self._vector_size)
@vector_size.setter
def vector_size(self, value: int):
"""Set vector size explicitly"""
self._vector_size = value
def close(self):
"""Close Qdrant client connection"""
if hasattr(self.client, "close"):
self.client.close()
def health_check(self) -> bool:
"""Verify Qdrant connectivity"""
try:
self.client.get_collections()
return True
except Exception:
return False
# === COLLECTION MANAGEMENT ===
def collection_exists(self) -> bool:
"""Check if collection exists"""
try:
collections = self.client.get_collections().collections
return any(c.name == self.collection_name for c in collections)
except Exception:
return False
def create_collection(self, recreate: bool = False, distance: Distance = Distance.COSINE) -> bool:
"""
Args:
recreate: If True, delete existing collection and recreate
distance: Distance metric (COSINE, EUCLID, DOT)
Returns:
True if collection was created, False if already exists
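Example:
Illustrative only; assumes a reachable Qdrant instance and a repository
constructed with an explicit collection name and vector size.
>>> repo = QdrantRepository(collection_name="chunk_embeddings_gte_large_1024", vector_size=1024)
>>> repo.create_collection()  # doctest: +SKIP
True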
"""
try:
# Check if collection exists
exists = self.collection_exists()
if exists and recreate:
logger.info("Deleting existing collection: %s", self.collection_name)
self.client.delete_collection(collection_name=self.collection_name)
exists = False
if not exists:
logger.info(
"Creating collection: %s (vector size: %d)",
self.collection_name,
self.vector_size,
)
self.client.create_collection(
collection_name=self.collection_name,
vectors_config=VectorParams(size=self.vector_size, distance=distance),
)
logger.info(
"Collection created successfully with %d-dimensional vectors",
self.vector_size,
)
self._reconcile_standard_indexes()
return True
else:
# Verify dimension matches
collection_info = self.client.get_collection(self.collection_name)
existing_size = self._extract_vector_size_from_collection_info(collection_info)
if existing_size is None:
logger.info("Collection exists: %s (vector size: unknown)", self.collection_name)
logger.warning("Could not determine vector size from existing collection metadata.")
return False
if existing_size != self.vector_size:
logger.warning("Collection dimension mismatch!")
logger.warning(" Expected: %d, Found: %d", self.vector_size, existing_size)
logger.warning(" To fix: Set recreate=True or use a different collection name")
else:
logger.info("Collection exists: %s (vector size: %d)", self.collection_name, existing_size)
self._reconcile_standard_indexes()
return False
except Exception as e:
logger.exception("Error creating collection: %s", e)
raise
@classmethod
def from_embedding_service_with_auto_collection(
cls, embedding_service: Any, base_collection_name: Optional[str] = None
) -> "QdrantRepository":
"""Create repository with automatic collection naming based on embedding model
Format: {base}_{model}_{dimension}
Args:
embedding_service: EmbeddingService instance
base_collection_name: Base collection name
Returns:
QdrantRepository with auto-generated collection name
Example:
>>> embedding_service = EmbeddingService(provider="local", model_name="thenlper/gte-large")
>>> repo = QdrantRepository.from_embedding_service_with_auto_collection(embedding_service)
>>> # Collection: 'chunk_embeddings_gte_large_1024'
"""
if not base_collection_name:
raise ValueError("Base collection name is required. Pass base_collection_name explicitly.")
base_name = base_collection_name
vector_size = embedding_service.get_vector_size()
model_name = embedding_service.model_name
collection_name = cls.make_collection_name(base_name, model_name, vector_size)
logger.info("Auto-generated collection: %s", collection_name)
return cls(collection_name=collection_name, vector_size=vector_size)
def create_payload_index(self, field_name: str, field_schema: models.PayloadSchemaType) -> bool:
"""
for efficient filtering
Args:
field_name: Name of the payload field to index
field_schema: Schema type (KEYWORD, INTEGER, FLOAT, etc.)
Returns:
True if index created successfully
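Example:
Illustrative only; the field mirrors one of the standard indexes and the
collection is assumed to exist already.
>>> repo.create_payload_index("language", models.PayloadSchemaType.KEYWORD)  # doctest: +SKIP
True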
"""
try:
self.client.create_payload_index(
collection_name=self.collection_name, field_name=field_name, field_schema=field_schema
)
logger.info("Created payload index for field: %s", field_name)
return True
except Exception as e:
logger.warning("Error creating payload index for %s: %s", field_name, e)
return False
def setup_standard_indexes(self) -> None:
"""Create standard payload indexes used for legal document filtering."""
for field_name, field_schema in self._standard_indexes():
self.create_payload_index(field_name, field_schema)
def _reconcile_standard_indexes(self) -> None:
try:
existing = self.client.get_collection(self.collection_name).payload_schema or {}
except Exception as e:
logger.warning("Could not read payload_schema for %s: %s", self.collection_name, e)
existing = {}
missing = [name for name, _ in self._standard_indexes() if name not in existing]
if not missing:
return
logger.info("Creating missing payload indexes on %s: %s", self.collection_name, missing)
schema_by_name = dict(self._standard_indexes())
for name in missing:
self.create_payload_index(name, schema_by_name[name])
@staticmethod
def _standard_indexes() -> List[Tuple[str, models.PayloadSchemaType]]:
return [
("act_id", models.PayloadSchemaType.INTEGER),
("version_id", models.PayloadSchemaType.INTEGER),
("chunk_id", models.PayloadSchemaType.INTEGER),
("celex", models.PayloadSchemaType.KEYWORD),
("eli", models.PayloadSchemaType.KEYWORD),
("act_type", models.PayloadSchemaType.KEYWORD),
("subdivision_type", models.PayloadSchemaType.KEYWORD),
("language", models.PayloadSchemaType.KEYWORD),
("is_current_version", models.PayloadSchemaType.BOOL),
("has_relations", models.PayloadSchemaType.BOOL),
("has_subjects", models.PayloadSchemaType.BOOL),
("retrieval_enabled", models.PayloadSchemaType.BOOL),
]
# === UTILITY METHODS ===
def _build_filter(self, filter_conditions: Dict[str, Any]) -> models.Filter:
"""
Build Qdrant filter from conditions
Supports:
- Exact match: {"key": "value"}
- List of values: {"key": ["val1", "val2"]} -> OR condition
- Range queries: {"key": {"gte": 10, "lte": 20}}
- Null checks: {"key": None} -> field value must be null
Args:
filter_conditions: Dictionary of filter conditions
Returns:
Qdrant Filter object
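Example:
Illustrative sketch; "repealed_at" is a hypothetical payload key, the other
keys mirror the standard indexes.
>>> qdrant_filter = repo._build_filter({
...     "language": "en",                         # exact match
...     "act_type": ["regulation", "directive"],  # OR over listed values
...     "chunk_id": {"gte": 10, "lte": 20},       # range query
...     "repealed_at": None,                      # null check
... })  # doctest: +SKIP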
"""
must_conditions: List[models.Condition] = []
for key, value in filter_conditions.items():
if value is None:
# Null check - field value must be null
must_conditions.append(models.IsNullCondition(is_null=models.PayloadField(key=key)))
elif isinstance(value, list):
# List of values - OR condition
list_values = cast(list[object], value)
should_conditions: List[models.Condition] = []
for item in list_values:
if isinstance(item, (str, int, bool)):
should_conditions.append(models.FieldCondition(key=key, match=models.MatchValue(value=item)))
else:
raise TypeError(f"List filter values must be str/int/bool, got {type(item).__name__}")
must_conditions.append(models.Filter(should=should_conditions))
elif isinstance(value, dict) and any(k in value for k in ["gte", "gt", "lte", "lt"]):
# Range query
range_condition = models.Range()
if "gte" in value:
range_condition.gte = value["gte"]
if "gt" in value:
range_condition.gt = value["gt"]
if "lte" in value:
range_condition.lte = value["lte"]
if "lt" in value:
range_condition.lt = value["lt"]
must_conditions.append(models.FieldCondition(key=key, range=range_condition))
else:
# Exact match
if not isinstance(value, (str, int, bool)):
raise TypeError(f"Exact match filter value for '{key}' must be str/int/bool")
must_conditions.append(models.FieldCondition(key=key, match=models.MatchValue(value=value)))
return models.Filter(must=must_conditions)
# === DATA MANAGEMENT ===
_RETRY_MAX = 4
_RETRY_BASE_DELAY = 2.0 # seconds, doubles each attempt
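# With these defaults a transient failure is retried up to three times after the
# first attempt, sleeping 2s, 4s and 8s before attempts 2, 3 and 4 respectively.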
_DELETE_BY_IDS_BATCH_SIZE = 64
_DELETE_BY_FILTER_SCROLL_LIMIT = 64
@staticmethod
def _is_transient_qdrant_error(exc: Exception) -> bool:
err_text = str(exc).lower()
transient_markers = (
"deadline",
"timeout",
"timed out",
"cancelled",
"unavailable",
"temporarily unavailable",
"connection reset",
"connection refused",
"broken pipe",
"transport error",
"grpc",
"resource exhausted",
"too many requests",
)
return any(marker in err_text for marker in transient_markers)
def _retry_qdrant_operation(self, operation: str, fn: Any) -> Any:
"""Retry a Qdrant operation with exponential backoff on transient transport errors."""
for attempt in range(1, self._RETRY_MAX + 1):
try:
return fn()
except Exception as exc:
is_transient = self._is_transient_qdrant_error(exc)
if not is_transient or attempt == self._RETRY_MAX:
raise
delay = self._RETRY_BASE_DELAY * (2 ** (attempt - 1))
logger.warning(
"Qdrant %s failed (attempt %d/%d): %s — retrying in %.1fs",
operation,
attempt,
self._RETRY_MAX,
exc,
delay,
)
time.sleep(delay)
raise RuntimeError("unreachable") # pragma: no cover
def _upsert_batch_with_retries(
self,
batch: Sequence[models.PointStruct],
*,
operation: str,
) -> int:
try:
self._retry_qdrant_operation(
operation,
lambda: self.client.upsert(collection_name=self.collection_name, points=list(batch)),
)
return len(batch)
except Exception as exc:
if len(batch) <= 1 or not self._is_transient_qdrant_error(exc):
raise
mid = len(batch) // 2
left = list(batch[:mid])
right = list(batch[mid:])
logger.warning(
"Qdrant %s failed for %d points after retries: %s — splitting into %d + %d points",
operation,
len(batch),
exc,
len(left),
len(right),
)
left_total = self._upsert_batch_with_retries(
left,
operation=f"{operation}/left",
)
right_total = self._upsert_batch_with_retries(
right,
operation=f"{operation}/right",
)
return left_total + right_total
def _delete_batch_with_retries(
self,
point_ids: Sequence[models.ExtendedPointId],
*,
operation: str,
) -> int:
batch = list(point_ids)
try:
self._retry_qdrant_operation(
operation,
lambda: self.client.delete(
collection_name=self.collection_name,
points_selector=models.PointIdsList(points=batch),
),
)
return len(batch)
except Exception as exc:
if len(batch) <= 1 or not self._is_transient_qdrant_error(exc):
raise
mid = len(batch) // 2
left = batch[:mid]
right = batch[mid:]
logger.warning(
"Qdrant %s failed for %d points after retries: %s — splitting into %d + %d points",
operation,
len(batch),
exc,
len(left),
len(right),
)
left_total = self._delete_batch_with_retries(
left,
operation=f"{operation}/left",
)
right_total = self._delete_batch_with_retries(
right,
operation=f"{operation}/right",
)
return left_total + right_total
def _delete_point_ids_with_retries(
self,
point_ids: Sequence[models.ExtendedPointId],
*,
operation: str,
) -> int:
total = 0
for i in range(0, len(point_ids), self._DELETE_BY_IDS_BATCH_SIZE):
batch = list(point_ids[i : i + self._DELETE_BY_IDS_BATCH_SIZE])
total += self._delete_batch_with_retries(
batch,
operation=f"{operation} batch {i // self._DELETE_BY_IDS_BATCH_SIZE + 1}",
)
return total
def _scroll_point_ids_by_filter(
self,
qdrant_filter: models.Filter,
) -> List[models.ExtendedPointId]:
point_ids: List[models.ExtendedPointId] = []
offset: Any = None
while True:
records, next_offset = self._retry_qdrant_operation(
"scroll_delete_candidates",
lambda offset=offset: self.client.scroll(
collection_name=self.collection_name,
scroll_filter=qdrant_filter,
limit=self._DELETE_BY_FILTER_SCROLL_LIMIT,
offset=offset,
with_payload=False,
with_vectors=False,
),
)
for record in records:
record_id = getattr(record, "id", None)
if record_id is not None:
point_ids.append(cast(models.ExtendedPointId, record_id))
if next_offset is None:
break
offset = next_offset
return point_ids
def upsert_points(
self, points: Sequence[Union[VectorPoint, models.PointStruct]], batch_size: Optional[int] = None
) -> int:
"""
Insert or update points in the collection with retry on transient transport errors.
Args:
points: List of VectorPoint or PointStruct objects to upsert
batch_size: Optional batch size for large inserts
Returns:
Number of points upserted
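Example:
Illustrative only; raw PointStruct objects are accepted alongside VectorPoint,
and the payload key mirrors a standard index field.
>>> point = models.PointStruct(id=1, vector=[0.0] * repo.vector_size, payload={"chunk_id": 1})
>>> repo.upsert_points([point], batch_size=256)  # doctest: +SKIP
1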
"""
try:
# Convert VectorPoint to PointStruct if needed
qdrant_points: List[models.PointStruct] = []
for point in points:
if isinstance(point, VectorPoint):
qdrant_points.append(point.to_qdrant_point())
else:
qdrant_points.append(point)
if batch_size and len(qdrant_points) > batch_size:
# Batch processing for large datasets
total = 0
for i in range(0, len(qdrant_points), batch_size):
batch = qdrant_points[i : i + batch_size]
total += self._upsert_batch_with_retries(
batch,
operation=f"upsert batch {i // batch_size + 1}",
)
logger.info("Upserted batch %d: %d points", i // batch_size + 1, len(batch))
return total
else:
# Single batch
return self._upsert_batch_with_retries(qdrant_points, operation="upsert")
except Exception as e:
logger.exception("Error upserting points: %s", e)
raise
def delete_points(self, point_ids: List[Union[str, int]]) -> int:
"""Delete points by IDs from the current collection."""
if not point_ids:
return 0
try:
typed_point_ids = cast(List[models.ExtendedPointId], point_ids)
return self._delete_point_ids_with_retries(typed_point_ids, operation="delete")
except Exception as e:
logger.exception("Error deleting points: %s", e)
raise
def delete_points_by_filter(
self,
query_filter: Union[Dict[str, Any], models.Filter],
) -> int:
"""Delete points matching a payload filter from the current collection."""
try:
if isinstance(query_filter, models.Filter):
qdrant_filter = query_filter
else:
qdrant_filter = self._build_filter(query_filter)
try:
point_ids = self._scroll_point_ids_by_filter(qdrant_filter)
except Exception as exc:
logger.warning(
"Qdrant delete_by_filter prefetch failed for collection %s: %s — "
"falling back to server-side delete",
self.collection_name,
exc,
)
else:
if not point_ids:
return 0
return self._delete_point_ids_with_retries(
point_ids,
operation="delete_by_filter ids",
)
self._retry_qdrant_operation(
"delete_by_filter",
lambda: self.client.delete(
collection_name=self.collection_name,
points_selector=models.FilterSelector(filter=qdrant_filter),
),
)
return 0
except Exception as e:
logger.exception("Error deleting points by filter: %s", e)
raise
# === SEARCH AND RETRIEVAL ===
def search(
self,
query_vector: List[float],
limit: Optional[int] = None,
score_threshold: Optional[float] = None,
query_filter: Optional[Union[Dict[str, Any], models.Filter]] = None,
hnsw_ef: Optional[int] = None,
exact: Optional[bool] = None,
) -> List[SearchResult]:
"""
Search for similar vectors
Args:
query_vector: The query embedding
limit: Maximum number of results (None -> config.search.default_limit)
score_threshold: Minimum similarity score (0-1)
query_filter: Metadata filters (e.g., {"document_type": "directive"})
hnsw_ef: ANN breadth parameter for HNSW search (None -> config.search.hnsw_ef)
exact: Force exact vector search (None -> config.search.exact_search)
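Example:
Illustrative only; the query vector is a placeholder and the filter keys
mirror the standard indexes.
>>> results = repo.search(
...     query_vector=[0.0] * repo.vector_size,
...     limit=5,
...     score_threshold=0.4,
...     query_filter={"language": "en", "is_current_version": True},
... )  # doctest: +SKIP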
"""
try:
search_cfg = get_config().search
resolved_limit = max(int(limit), 1) if limit is not None else max(int(search_cfg.default_limit), 1)
resolved_hnsw_ef = hnsw_ef if hnsw_ef is not None else search_cfg.hnsw_ef
resolved_exact = bool(exact) if exact is not None else bool(search_cfg.exact_search)
# Build filter if conditions provided
if query_filter is None:
qdrant_filter = None
elif isinstance(query_filter, models.Filter):
qdrant_filter = query_filter
else:
qdrant_filter = self._build_filter(query_filter)
search_params: Optional[models.SearchParams] = None
search_params_kwargs: Dict[str, Any] = {}
if resolved_hnsw_ef is not None:
search_params_kwargs["hnsw_ef"] = max(int(resolved_hnsw_ef), 1)
search_params_kwargs["exact"] = resolved_exact
if search_params_kwargs:
search_params = models.SearchParams(**search_params_kwargs)
# Perform search
search_result = self.client.query_points(
collection_name=self.collection_name,
query=query_vector,
limit=resolved_limit,
score_threshold=score_threshold,
query_filter=qdrant_filter,
search_params=search_params,
with_payload=True,
with_vectors=False,
).points
except UnexpectedResponse as e:
if e.status_code == 404:
logger.warning("Qdrant collection '%s' not found, returning empty results", self.collection_name)
return []
raise
except Exception:
raise
# Convert to our SearchResult format
# Qdrant returns a relevance score where higher is better.
results: List[SearchResult] = []
for hit in search_result:
payload_obj = hit.payload
payload: Dict[str, Any]
if isinstance(payload_obj, dict):
payload = dict(payload_obj)
else:
payload = {}
results.append(SearchResult(id=str(hit.id), score=float(hit.score), payload=payload))
return results
def retrieve_vectors_by_ids(
self,
point_ids: List[int],
) -> Dict[int, List[float]]:
"""Retrieve stored embedding vectors by point IDs.
Returns a dict mapping point_id -> vector.
Points not found in the collection are silently omitted.
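Example:
Illustrative only; the point IDs are placeholders.
>>> vectors = repo.retrieve_vectors_by_ids([1, 2, 3])  # doctest: +SKIP
>>> sorted(vectors.keys())  # doctest: +SKIP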
"""
if not point_ids:
return {}
try:
records = self.client.retrieve(
collection_name=self.collection_name,
ids=point_ids,
with_vectors=True,
with_payload=False,
)
result: Dict[int, List[float]] = {}
for record in records:
raw_vec = record.vector
if raw_vec is not None and isinstance(raw_vec, list):
result[int(record.id)] = [float(v) for v in raw_vec] # type: ignore[arg-type]
return result
except Exception as e:
logger.exception("Error retrieving vectors by IDs: %s", e)
raise
# === COLLECTION INFO AND STATISTICS ===
def get_collection_info(self) -> Dict[str, Any]:
"""Return collection metadata and point counts for the active collection."""
try:
info = self.client.get_collection(collection_name=self.collection_name)
vector_size = self.vector_size
# Prefer the vector size reported by the collection's own config when available
extracted_size = self._extract_vector_size_from_collection_info(info)
if extracted_size is not None:
vector_size = extracted_size
return {
"name": self.collection_name,
"points_count": getattr(info, "points_count", 0),
"status": getattr(info, "status", "unknown"),
"vector_size": vector_size,
}
except Exception as e:
return {"error": str(e)}
def get_statistics(self) -> Dict[str, Any]:
"""Get vector database statistics"""
try:
info = self.get_collection_info()
if "error" in info:
return {"error": info["error"], "database_type": "qdrant"}
return {
"total_points": info.get("points_count", 0),
"collection_name": self.collection_name,
"vector_size": info.get("vector_size", self.vector_size),
"status": info.get("status", "unknown"),
"database_type": "qdrant",
}
except Exception as e:
return {"error": str(e), "database_type": "qdrant"}