"""
Load JSON payload schemas and render payloads.
"""
import ast
import json
from collections.abc import Callable
from os import PathLike
from pathlib import Path
from typing import Any, TypeAlias, cast
Transformer = Callable[[Any], Any]
Transformers = dict[str, Transformer]
SchemaObject: TypeAlias = dict[str, object]
[docs]
class PayloadSchemaLoader:
"""Loads and applies payload schemas."""
def __init__(self, schema_file: str | PathLike[str] | None = None):
"""Initialize loader (defaults to payload_schemas.json)."""
if schema_file is None:
resolved_schema_file = Path(__file__).parent / "payload_schemas.json"
else:
resolved_schema_file = Path(schema_file)
self.schema_file: Path = resolved_schema_file
self._schemas: dict[str, SchemaObject] = self._load_schemas()
def _load_schemas(self) -> dict[str, SchemaObject]:
"""Load schemas from JSON."""
try:
with open(self.schema_file, "r", encoding="utf-8") as f:
loaded = cast(object, json.load(f))
if not isinstance(loaded, dict):
raise ValueError(f"Schema file {self.schema_file} must contain a JSON object at top level")
loaded_mapping = cast(dict[object, object], loaded)
schemas: dict[str, SchemaObject] = {}
for key, value in loaded_mapping.items():
if not isinstance(key, str):
continue
if not isinstance(value, dict):
raise ValueError(f"Schema '{key}' must be a JSON object")
schemas[key] = cast(SchemaObject, value)
return schemas
except FileNotFoundError:
raise FileNotFoundError(f"Payload schema file not found: {self.schema_file}")
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in schema file {self.schema_file}: {e}")
[docs]
def get_schema(self, schema_name: str) -> SchemaObject:
"""Fetch a schema by name."""
if schema_name not in self._schemas:
raise KeyError(f"Schema '{schema_name}' not found. Available: {list(self._schemas.keys())}")
schema = self._schemas[schema_name]
return schema
[docs]
def build_payload_from_schema(
self, schema_name: str, context: dict[str, Any], transformers: Transformers | None = None
) -> dict[str, Any]:
"""Render a payload from schema + context."""
schema: SchemaObject = self.get_schema(schema_name)
active_transformers: Transformers = transformers if transformers is not None else {}
payload: dict[str, Any] = {}
# Process all field groups
groups_obj = schema.get("fields", [])
if not isinstance(groups_obj, list):
raise ValueError(f"Schema '{schema_name}' field 'fields' must be a list")
groups = cast(list[object], groups_obj)
for group_obj in groups:
if not isinstance(group_obj, dict):
continue
group = cast(SchemaObject, group_obj)
group_fields_obj = group.get("fields", [])
if not isinstance(group_fields_obj, list):
continue
group_fields = cast(list[object], group_fields_obj)
# Skip optional groups if source data not available
if group.get("optional_group", False):
if not group_fields:
continue
first_field_obj = group_fields[0]
if not isinstance(first_field_obj, dict):
continue
first_field = cast(SchemaObject, first_field_obj)
first_source = first_field.get("source")
if not isinstance(first_source, str):
continue
source_key = self._get_source_key(first_source)
if source_key not in context or context.get(source_key) is None:
continue
for field_obj in group_fields:
if not isinstance(field_obj, dict):
continue
field_def = cast(SchemaObject, field_obj)
field_name_obj = field_def.get("name")
if not isinstance(field_name_obj, str):
continue
field_name = field_name_obj
# Skip if required and source not available
if bool(field_def.get("required", False)):
source_value = field_def.get("source")
if not isinstance(source_value, str):
raise ValueError(f"Field '{field_name}' has invalid source")
source_key = self._get_source_key(source_value)
if source_key not in context:
raise ValueError(f"Required field '{field_name}' source '{source_key}' not found in context")
# Build field value
value = self._build_field_value(field_def, context, active_transformers)
# Skip None values for optional fields
if value is None and not bool(field_def.get("required", False)):
# Use default if specified
if "default" in field_def:
payload[field_name] = field_def["default"]
continue
payload[field_name] = value
# Merge metadata if enabled
metadata_merge_obj = schema.get("metadata_merge", {})
metadata_merge: SchemaObject
if isinstance(metadata_merge_obj, dict):
metadata_merge = cast(SchemaObject, metadata_merge_obj)
else:
metadata_merge = {}
if bool(metadata_merge.get("enabled", False)):
metadata_obj = context.get("metadata", {})
metadata: dict[str, Any]
if isinstance(metadata_obj, dict):
metadata = cast(dict[str, Any], metadata_obj)
else:
metadata = {}
overwrite = bool(metadata_merge.get("overwrite_existing", False))
for key, value in metadata.items():
if overwrite or key not in payload:
payload[key] = value
return payload
def _get_source_key(self, source: str) -> str:
"""Get the top-level key for a source path."""
# 'act_data.celex' -> 'act_data'
# 'full_text' -> 'full_text'
return source.split(".")[0]
def _build_field_value(self, field_def: SchemaObject, context: dict[str, Any], transformers: Transformers) -> Any:
"""Resolve a field value from schema + context."""
source_obj = field_def.get("source")
if not isinstance(source_obj, str):
return None
source = source_obj
# Handle computed expressions
if bool(field_def.get("computed", False)) and "expression" in field_def:
expression = field_def.get("expression")
if isinstance(expression, str):
return self._eval_expression(expression, context)
return []
# Handle special computed fields with len() in source
if bool(field_def.get("computed", False)) and source.startswith("len(") and source.endswith(")"):
# Extract path from len(subdivision_data.content)
path = source[4:-1] # Remove 'len(' and ')'
value = self._get_nested_value(path, context)
if value is None:
value = ""
return len(value)
# Get source value
value = self._get_nested_value(source, context)
# Apply transformer if specified
transformer_name_obj = field_def.get("transformer")
if isinstance(transformer_name_obj, str):
transformer = transformers.get(transformer_name_obj)
if transformer is not None:
value = transformer(value)
return value
def _get_nested_value(self, source: str, context: dict[str, Any]) -> Any:
"""Get a value from a dotted path."""
parts = source.split(".")
if len(parts) == 1:
# Direct access
return context.get(parts[0])
# Nested access
obj: Any = context.get(parts[0])
for part in parts[1:]:
if isinstance(obj, dict):
mapping = cast(dict[str, Any], obj)
obj = mapping.get(part)
else:
obj = getattr(obj, part, None)
return obj
def _eval_expression(self, expression: str, context: dict[str, Any]) -> list[Any]:
"""Evaluate a restricted list comprehension on subjects."""
subjects_obj = context.get("subjects", [])
if not isinstance(subjects_obj, list):
return []
subjects: list[object] = cast(list[object], subjects_obj)
try:
tree = ast.parse(expression, mode="eval")
except SyntaxError:
return []
if not isinstance(tree.body, ast.ListComp):
return []
comp = tree.body
if len(comp.generators) != 1:
return []
generator = comp.generators[0]
if generator.ifs:
return []
if not (isinstance(generator.target, ast.Name) and generator.target.id == "s"):
return []
if not (isinstance(generator.iter, ast.Name) and generator.iter.id == "subjects"):
return []
results: list[Any] = []
try:
for subject in subjects:
results.append(self._eval_subject_expr(comp.elt, subject))
return results
except Exception:
return []
def _eval_subject_expr(self, node: ast.AST, subject: object) -> Any:
"""Evaluate a restricted expression on one subject."""
if isinstance(node, ast.BoolOp) and isinstance(node.op, ast.Or):
result = None
for value in node.values:
result = self._eval_subject_expr(value, subject)
if result:
return result
return result
if isinstance(node, ast.Call):
if not (isinstance(node.func, ast.Attribute) and node.func.attr == "get"):
raise ValueError("Only s.get(...) calls are allowed")
if not (isinstance(node.func.value, ast.Name) and node.func.value.id == "s"):
raise ValueError("Only s.get(...) calls are allowed")
if node.keywords:
raise ValueError("Keyword args are not allowed")
args = [self._eval_literal(arg) for arg in node.args]
if not isinstance(subject, dict):
return None
subject_dict = cast(dict[str, Any], subject)
if len(args) == 1:
key0 = args[0]
if not isinstance(key0, str):
return None
return subject_dict.get(key0)
if len(args) == 2:
key0 = args[0]
if not isinstance(key0, str):
return args[1]
return subject_dict.get(key0, args[1])
raise ValueError("Invalid number of arguments to get()")
if isinstance(node, ast.Subscript):
if not (isinstance(node.value, ast.Name) and node.value.id == "s"):
raise ValueError("Only s[...] access is allowed")
key = self._eval_subscript_index(node.slice)
if not isinstance(subject, dict):
return None
if not isinstance(key, str):
return None
subject_dict = cast(dict[str, Any], subject)
return subject_dict.get(key)
if isinstance(node, ast.Constant):
return node.value
if isinstance(node, ast.Name) and node.id == "s":
return subject
raise ValueError("Unsupported expression node")
def _eval_subscript_index(self, node: ast.AST) -> object:
return self._eval_literal(node)
def _eval_literal(self, node: ast.AST) -> object:
if isinstance(node, ast.Constant):
return node.value
raise ValueError("Only literal constants are allowed")