# Source code for lalandre_core.repositories.common.schema_loader

"""
Load JSON payload schemas and render payloads.
"""

import ast
import json
from collections.abc import Callable
from os import PathLike
from pathlib import Path
from typing import Any, TypeAlias, cast

Transformer = Callable[[Any], Any]
Transformers = dict[str, Transformer]
SchemaObject: TypeAlias = dict[str, object]


[docs] class PayloadSchemaLoader: """Loads and applies payload schemas.""" def __init__(self, schema_file: str | PathLike[str] | None = None): """Initialize loader (defaults to payload_schemas.json).""" if schema_file is None: resolved_schema_file = Path(__file__).parent / "payload_schemas.json" else: resolved_schema_file = Path(schema_file) self.schema_file: Path = resolved_schema_file self._schemas: dict[str, SchemaObject] = self._load_schemas() def _load_schemas(self) -> dict[str, SchemaObject]: """Load schemas from JSON.""" try: with open(self.schema_file, "r", encoding="utf-8") as f: loaded = cast(object, json.load(f)) if not isinstance(loaded, dict): raise ValueError(f"Schema file {self.schema_file} must contain a JSON object at top level") loaded_mapping = cast(dict[object, object], loaded) schemas: dict[str, SchemaObject] = {} for key, value in loaded_mapping.items(): if not isinstance(key, str): continue if not isinstance(value, dict): raise ValueError(f"Schema '{key}' must be a JSON object") schemas[key] = cast(SchemaObject, value) return schemas except FileNotFoundError: raise FileNotFoundError(f"Payload schema file not found: {self.schema_file}") except json.JSONDecodeError as e: raise ValueError(f"Invalid JSON in schema file {self.schema_file}: {e}")
[docs] def get_schema(self, schema_name: str) -> SchemaObject: """Fetch a schema by name.""" if schema_name not in self._schemas: raise KeyError(f"Schema '{schema_name}' not found. Available: {list(self._schemas.keys())}") schema = self._schemas[schema_name] return schema
[docs] def build_payload_from_schema( self, schema_name: str, context: dict[str, Any], transformers: Transformers | None = None ) -> dict[str, Any]: """Render a payload from schema + context.""" schema: SchemaObject = self.get_schema(schema_name) active_transformers: Transformers = transformers if transformers is not None else {} payload: dict[str, Any] = {} # Process all field groups groups_obj = schema.get("fields", []) if not isinstance(groups_obj, list): raise ValueError(f"Schema '{schema_name}' field 'fields' must be a list") groups = cast(list[object], groups_obj) for group_obj in groups: if not isinstance(group_obj, dict): continue group = cast(SchemaObject, group_obj) group_fields_obj = group.get("fields", []) if not isinstance(group_fields_obj, list): continue group_fields = cast(list[object], group_fields_obj) # Skip optional groups if source data not available if group.get("optional_group", False): if not group_fields: continue first_field_obj = group_fields[0] if not isinstance(first_field_obj, dict): continue first_field = cast(SchemaObject, first_field_obj) first_source = first_field.get("source") if not isinstance(first_source, str): continue source_key = self._get_source_key(first_source) if source_key not in context or context.get(source_key) is None: continue for field_obj in group_fields: if not isinstance(field_obj, dict): continue field_def = cast(SchemaObject, field_obj) field_name_obj = field_def.get("name") if not isinstance(field_name_obj, str): continue field_name = field_name_obj # Skip if required and source not available if bool(field_def.get("required", False)): source_value = field_def.get("source") if not isinstance(source_value, str): raise ValueError(f"Field '{field_name}' has invalid source") source_key = self._get_source_key(source_value) if source_key not in context: raise ValueError(f"Required field '{field_name}' source '{source_key}' not found in context") # Build field value value = self._build_field_value(field_def, 
context, active_transformers) # Skip None values for optional fields if value is None and not bool(field_def.get("required", False)): # Use default if specified if "default" in field_def: payload[field_name] = field_def["default"] continue payload[field_name] = value # Merge metadata if enabled metadata_merge_obj = schema.get("metadata_merge", {}) metadata_merge: SchemaObject if isinstance(metadata_merge_obj, dict): metadata_merge = cast(SchemaObject, metadata_merge_obj) else: metadata_merge = {} if bool(metadata_merge.get("enabled", False)): metadata_obj = context.get("metadata", {}) metadata: dict[str, Any] if isinstance(metadata_obj, dict): metadata = cast(dict[str, Any], metadata_obj) else: metadata = {} overwrite = bool(metadata_merge.get("overwrite_existing", False)) for key, value in metadata.items(): if overwrite or key not in payload: payload[key] = value return payload
def _get_source_key(self, source: str) -> str: """Get the top-level key for a source path.""" # 'act_data.celex' -> 'act_data' # 'full_text' -> 'full_text' return source.split(".")[0] def _build_field_value(self, field_def: SchemaObject, context: dict[str, Any], transformers: Transformers) -> Any: """Resolve a field value from schema + context.""" source_obj = field_def.get("source") if not isinstance(source_obj, str): return None source = source_obj # Handle computed expressions if bool(field_def.get("computed", False)) and "expression" in field_def: expression = field_def.get("expression") if isinstance(expression, str): return self._eval_expression(expression, context) return [] # Handle special computed fields with len() in source if bool(field_def.get("computed", False)) and source.startswith("len(") and source.endswith(")"): # Extract path from len(subdivision_data.content) path = source[4:-1] # Remove 'len(' and ')' value = self._get_nested_value(path, context) if value is None: value = "" return len(value) # Get source value value = self._get_nested_value(source, context) # Apply transformer if specified transformer_name_obj = field_def.get("transformer") if isinstance(transformer_name_obj, str): transformer = transformers.get(transformer_name_obj) if transformer is not None: value = transformer(value) return value def _get_nested_value(self, source: str, context: dict[str, Any]) -> Any: """Get a value from a dotted path.""" parts = source.split(".") if len(parts) == 1: # Direct access return context.get(parts[0]) # Nested access obj: Any = context.get(parts[0]) for part in parts[1:]: if isinstance(obj, dict): mapping = cast(dict[str, Any], obj) obj = mapping.get(part) else: obj = getattr(obj, part, None) return obj def _eval_expression(self, expression: str, context: dict[str, Any]) -> list[Any]: """Evaluate a restricted list comprehension on subjects.""" subjects_obj = context.get("subjects", []) if not isinstance(subjects_obj, list): return [] subjects: 
list[object] = cast(list[object], subjects_obj) try: tree = ast.parse(expression, mode="eval") except SyntaxError: return [] if not isinstance(tree.body, ast.ListComp): return [] comp = tree.body if len(comp.generators) != 1: return [] generator = comp.generators[0] if generator.ifs: return [] if not (isinstance(generator.target, ast.Name) and generator.target.id == "s"): return [] if not (isinstance(generator.iter, ast.Name) and generator.iter.id == "subjects"): return [] results: list[Any] = [] try: for subject in subjects: results.append(self._eval_subject_expr(comp.elt, subject)) return results except Exception: return [] def _eval_subject_expr(self, node: ast.AST, subject: object) -> Any: """Evaluate a restricted expression on one subject.""" if isinstance(node, ast.BoolOp) and isinstance(node.op, ast.Or): result = None for value in node.values: result = self._eval_subject_expr(value, subject) if result: return result return result if isinstance(node, ast.Call): if not (isinstance(node.func, ast.Attribute) and node.func.attr == "get"): raise ValueError("Only s.get(...) calls are allowed") if not (isinstance(node.func.value, ast.Name) and node.func.value.id == "s"): raise ValueError("Only s.get(...) calls are allowed") if node.keywords: raise ValueError("Keyword args are not allowed") args = [self._eval_literal(arg) for arg in node.args] if not isinstance(subject, dict): return None subject_dict = cast(dict[str, Any], subject) if len(args) == 1: key0 = args[0] if not isinstance(key0, str): return None return subject_dict.get(key0) if len(args) == 2: key0 = args[0] if not isinstance(key0, str): return args[1] return subject_dict.get(key0, args[1]) raise ValueError("Invalid number of arguments to get()") if isinstance(node, ast.Subscript): if not (isinstance(node.value, ast.Name) and node.value.id == "s"): raise ValueError("Only s[...] 
access is allowed") key = self._eval_subscript_index(node.slice) if not isinstance(subject, dict): return None if not isinstance(key, str): return None subject_dict = cast(dict[str, Any], subject) return subject_dict.get(key) if isinstance(node, ast.Constant): return node.value if isinstance(node, ast.Name) and node.id == "s": return subject raise ValueError("Unsupported expression node") def _eval_subscript_index(self, node: ast.AST) -> object: return self._eval_literal(node) def _eval_literal(self, node: ast.AST) -> object: if isinstance(node, ast.Constant): return node.value raise ValueError("Only literal constants are allowed")