"""Input validation module for xnatctl.
Provides comprehensive validation for URLs, ports, identifiers, paths,
and DICOM-specific values.
"""
from __future__ import annotations
import os
import re
from pathlib import Path
from urllib.parse import urlparse
from xnatctl.core.exceptions import (
ConfigurationError,
InvalidIdentifierError,
InvalidPortError,
InvalidURLError,
PathValidationError,
)
from xnatctl.core.timeouts import DEFAULT_HTTP_TIMEOUT_SECONDS
# =============================================================================
# Constants
# =============================================================================
MIN_PORT = 1
MAX_PORT = 65535

# XNAT identifier: alphanumeric, underscore, hyphen.
# Anchored with ``\Z`` rather than ``$``: in Python regexes ``$`` also matches
# just before a trailing newline, so "abc\n" would otherwise be accepted by
# anyone using this public pattern directly.
XNAT_ID_PATTERN = re.compile(r"^[a-zA-Z0-9_-]+\Z")
XNAT_ID_MAX_LENGTH = 64

# DICOM AE Title: 1-16 printable ASCII chars (0x20-0x7E) excluding backslash (0x5C).
# ``\Z`` anchoring for the same trailing-newline reason as above.
AE_TITLE_PATTERN = re.compile(r"^[\x20-\x5B\x5D-\x7E]{1,16}\Z")
AE_TITLE_MAX_LENGTH = 16

ALLOWED_URL_SCHEMES = {"http", "https"}
ALLOWED_ARCHIVE_EXTENSIONS = {".zip", ".tar", ".tar.gz", ".tgz"}
# =============================================================================
# URL Validation
# =============================================================================
[docs]
def validate_server_url(url: str) -> str:
    """Validate an XNAT server URL and return its normalized form.

    Args:
        url: Server URL to validate. Surrounding whitespace is ignored.

    Returns:
        The stripped URL with trailing slashes removed.

    Raises:
        InvalidURLError: If the URL is empty, cannot be parsed, uses a scheme
            other than http/https, has no hostname, or has an invalid port.
    """
    if not url or not isinstance(url, str):
        raise InvalidURLError(str(url), "URL cannot be empty")
    url = url.strip()
    if not url:
        raise InvalidURLError(url, "URL cannot be empty")
    try:
        parsed = urlparse(url)
    except ValueError as e:
        # urlparse signals malformed input (e.g. a broken IPv6 literal) with ValueError.
        raise InvalidURLError(url, f"Failed to parse URL: {e}") from e
    if not parsed.scheme:
        raise InvalidURLError(url, "URL must include scheme (http:// or https://)")
    if parsed.scheme.lower() not in ALLOWED_URL_SCHEMES:
        raise InvalidURLError(
            url,
            f"Unsupported scheme '{parsed.scheme}'. Use http or https.",
        )
    # netloc can be non-empty without any host ("http://:8080", "http://user@"),
    # so check the parsed hostname instead of the raw network location.
    if not parsed.hostname:
        raise InvalidURLError(url, "URL must include hostname")
    try:
        # Reading .port validates it: non-numeric or out-of-range ports raise ValueError.
        _ = parsed.port
    except ValueError as e:
        raise InvalidURLError(url, f"Invalid port in URL: {e}") from e
    return url.rstrip("/")
[docs]
def validate_url_or_none(url: str | None) -> str | None:
    """Return the validated URL, or None when no URL was supplied.

    ``None`` and blank (empty or whitespace-only) strings mean "not provided"
    and yield ``None``; any other value is passed to ``validate_server_url``.
    """
    is_blank = isinstance(url, str) and url.strip() == ""
    if url is not None and not is_blank:
        return validate_server_url(url)
    return None
# =============================================================================
# Port Validation
# =============================================================================
[docs]
def validate_port(port: int | str | None, allow_none: bool = False) -> int | None:
    """Validate a TCP port number.

    Args:
        port: Port number to validate. Strings are parsed with ``int()``;
            booleans and non-integral floats are rejected.
        allow_none: If True, None is a valid value.

    Returns:
        Validated port number (``MIN_PORT``..``MAX_PORT``), or None when
        ``port`` is None and ``allow_none`` is True.

    Raises:
        InvalidPortError: If port is missing (and not allowed), not an
            integer, or out of range.
    """
    if port is None:
        if allow_none:
            return None
        raise InvalidPortError(port)
    # bool is an int subclass: int(True) == 1 would silently become port 1
    # (e.g. a YAML config containing "port: yes").
    if isinstance(port, bool):
        raise InvalidPortError(port)
    # int() truncates floats (8080.9 -> 8080); only integral floats are accepted.
    if isinstance(port, float) and not port.is_integer():
        raise InvalidPortError(port)
    try:
        port_int = int(port)
    except (ValueError, TypeError, OverflowError) as e:
        raise InvalidPortError(port) from e
    if not MIN_PORT <= port_int <= MAX_PORT:
        raise InvalidPortError(port)
    return port_int
# =============================================================================
# XNAT Identifier Validation
# =============================================================================
[docs]
def validate_xnat_identifier(
    value: str,
    identifier_type: str = "identifier",
    *,
    allow_empty: bool = False,
    max_length: int = XNAT_ID_MAX_LENGTH,
) -> str:
    """Check that *value* is a well-formed XNAT identifier and return it stripped.

    Used for project, subject, session and scan IDs. Leading and trailing
    whitespace is removed before any rule is applied.

    Args:
        value: Candidate identifier.
        identifier_type: Label used in error messages (e.g. "project").
        allow_empty: Accept an empty (or whitespace-only) value and return "".
        max_length: Longest permitted identifier, in characters.

    Returns:
        The stripped identifier.

    Raises:
        InvalidIdentifierError: If *value* is not a string, is empty (unless
            allowed), is longer than *max_length*, or contains characters
            outside letters, digits, ``_`` and ``-``.
    """
    if not isinstance(value, str):
        raise InvalidIdentifierError(identifier_type, str(value), "must be a string")

    candidate = value.strip()

    if candidate == "":
        if not allow_empty:
            raise InvalidIdentifierError(identifier_type, candidate, "cannot be empty")
        return candidate

    # Length is checked before the character set, matching the error priority
    # callers see.
    problem = None
    if len(candidate) > max_length:
        problem = f"exceeds maximum length of {max_length} characters"
    elif XNAT_ID_PATTERN.match(candidate) is None:
        problem = "must contain only alphanumeric characters, underscores, and hyphens"

    if problem is not None:
        raise InvalidIdentifierError(identifier_type, candidate, problem)
    return candidate
[docs]
def validate_project_id(project: str) -> str:
    """Validate an XNAT project ID and return it with surrounding whitespace removed.

    Raises:
        InvalidIdentifierError: If *project* is not a valid XNAT identifier.
    """
    validated = validate_xnat_identifier(project, identifier_type="project")
    return validated
[docs]
def validate_subject_id(subject: str) -> str:
    """Validate an XNAT subject ID and return it with surrounding whitespace removed.

    Raises:
        InvalidIdentifierError: If *subject* is not a valid XNAT identifier.
    """
    validated = validate_xnat_identifier(subject, identifier_type="subject")
    return validated
[docs]
def validate_session_id(session: str) -> str:
    """Validate an XNAT session/experiment ID and return it stripped.

    Raises:
        InvalidIdentifierError: If *session* is not a valid XNAT identifier.
    """
    validated = validate_xnat_identifier(session, identifier_type="session")
    return validated
[docs]
def validate_scan_id(scan_id: str) -> str:
    """Validate an XNAT scan ID and return it stripped.

    Scan IDs are typically numeric, but XNAT allows strings, so the general
    identifier rules apply with a tighter 32-character limit.

    Raises:
        InvalidIdentifierError: If *scan_id* is not a valid identifier of at
            most 32 characters.
    """
    validated = validate_xnat_identifier(scan_id, identifier_type="scan", max_length=32)
    return validated
[docs]
def validate_resource_label(label: str) -> str:
    """Validate an XNAT resource label and return it stripped.

    Resource labels are more flexible than other identifiers (dots, spaces,
    etc. are allowed). The separator check implies a label is meant to be a
    single path segment. For the same reason, the navigation segments "." and
    "..", and ASCII control characters (which would corrupt a URL or path),
    are also rejected.

    Args:
        label: Candidate resource label.

    Returns:
        The stripped label.

    Raises:
        InvalidIdentifierError: If the label is not a string, is empty, contains
            a path separator or control character, is "." or "..", or is longer
            than 64 characters.
    """
    if not isinstance(label, str):
        raise InvalidIdentifierError("resource_label", str(label), "must be a string")
    label = label.strip()
    if not label:
        raise InvalidIdentifierError("resource_label", label, "cannot be empty")
    if "/" in label or "\\" in label:
        raise InvalidIdentifierError(
            "resource_label",
            label,
            "cannot contain path separators",
        )
    if label in (".", ".."):
        raise InvalidIdentifierError(
            "resource_label",
            label,
            "cannot be '.' or '..'",
        )
    if any(ord(ch) < 0x20 or ch == "\x7f" for ch in label):
        raise InvalidIdentifierError(
            "resource_label",
            label,
            "cannot contain control characters",
        )
    if len(label) > 64:
        raise InvalidIdentifierError(
            "resource_label",
            label,
            "exceeds maximum length of 64 characters",
        )
    return label
# =============================================================================
# DICOM Validation
# =============================================================================
[docs]
def validate_ae_title(ae_title: str, field_name: str = "AE Title") -> str:
    """Validate a DICOM Application Entity (AE) Title and return it stripped.

    Per the DICOM standard an AE Title is 1-16 printable ASCII characters and
    may not contain a backslash. Surrounding whitespace is removed first.

    Args:
        ae_title: Candidate AE Title.
        field_name: Label used in error messages.

    Returns:
        The stripped AE Title.

    Raises:
        InvalidIdentifierError: If the title is not a string, is empty, is too
            long, or contains disallowed characters.
    """
    if not isinstance(ae_title, str):
        raise InvalidIdentifierError(field_name, str(ae_title), "must be a string")

    title = ae_title.strip()

    if len(title) == 0:
        raise InvalidIdentifierError(field_name, title, "cannot be empty")

    if len(title) > AE_TITLE_MAX_LENGTH:
        too_long = f"exceeds maximum length of {AE_TITLE_MAX_LENGTH} characters"
        raise InvalidIdentifierError(field_name, title, too_long)

    if AE_TITLE_PATTERN.match(title) is None:
        bad_chars = "must contain only printable ASCII characters (no backslash)"
        raise InvalidIdentifierError(field_name, title, bad_chars)

    return title
# =============================================================================
# Path Validation
# =============================================================================
[docs]
def validate_path_exists(
    path: str | Path,
    *,
    must_be_file: bool = False,
    must_be_dir: bool = False,
    description: str = "path",
) -> Path:
    """Validate that a path exists and optionally check its type.

    Args:
        path: Path to validate; ``~`` is expanded.
        must_be_file: If True, path must be a file.
        must_be_dir: If True, path must be a directory.
        description: Description for error messages.

    Returns:
        Resolved Path object.

    Raises:
        PathValidationError: If path is empty/None, does not exist, or does not
            meet the file/directory requirement.
    """
    # Check emptiness on the raw input. Path("") normalizes to Path("."), and
    # Path objects are always truthy, so a check after conversion would
    # silently resolve an empty argument to the current working directory.
    if path is None or (isinstance(path, str) and not path.strip()):
        raise PathValidationError(str(path), f"{description} cannot be empty")
    candidate = Path(path).expanduser()
    if not candidate.exists():
        raise PathValidationError(str(candidate), f"{description} does not exist")
    if must_be_file and not candidate.is_file():
        raise PathValidationError(str(candidate), f"{description} must be a file")
    if must_be_dir and not candidate.is_dir():
        raise PathValidationError(str(candidate), f"{description} must be a directory")
    return candidate.resolve()
[docs]
def validate_path_writable(
    path: str | Path,
    description: str = "path",
) -> Path:
    """Validate that a path can be written.

    The parent must be an existing, writable directory. If the path itself
    already exists, it must be writable too (so a read-only file that would
    be overwritten is caught up front).

    Args:
        path: Path to validate; ``~`` is expanded.
        description: Description for error messages.

    Returns:
        Resolved Path object.

    Raises:
        PathValidationError: If path is empty/None, its parent is missing, is
            not a directory, or is not writable, or the existing path is not
            writable.
    """
    # Path("") would normalize to "." and validate the CWD; reject it explicitly.
    if path is None or (isinstance(path, str) and not path.strip()):
        raise PathValidationError(str(path), f"{description} cannot be empty")
    target = Path(path).expanduser()
    parent = target.parent
    if not parent.exists():
        raise PathValidationError(
            str(target),
            f"parent directory does not exist: {parent}",
        )
    # A regular file can be writable, but nothing can be created "inside" it.
    if not parent.is_dir():
        raise PathValidationError(
            str(target),
            f"parent is not a directory: {parent}",
        )
    if not os.access(parent, os.W_OK):
        raise PathValidationError(
            str(target),
            f"parent directory is not writable: {parent}",
        )
    if target.exists() and not os.access(target, os.W_OK):
        raise PathValidationError(
            str(target),
            f"{description} exists but is not writable",
        )
    return target.resolve()
[docs]
def validate_archive_path(path: str | Path) -> Path:
    """Validate that path is an existing file with a supported archive extension.

    Extensions are compared case-insensitively, so "SCANS.TAR.GZ" is accepted
    like "scans.tar.gz".

    Args:
        path: Path to archive file.

    Returns:
        Resolved Path object.

    Raises:
        PathValidationError: If path does not exist, is not a file, or its
            extension is not in ALLOWED_ARCHIVE_EXTENSIONS.
    """
    resolved = validate_path_exists(path, must_be_file=True, description="archive")
    # NOTE(review): the extension is taken from the resolved path, so a symlink
    # named "x.zip" whose target lacks an archive extension is rejected.
    suffix = resolved.suffix.lower()
    # Path.suffix only sees the last extension; recognise the compound ".tar.gz".
    if suffix == ".gz" and resolved.stem.lower().endswith(".tar"):
        suffix = ".tar.gz"
    if suffix not in ALLOWED_ARCHIVE_EXTENSIONS:
        raise PathValidationError(
            str(resolved),
            f"unsupported archive format. Allowed: {', '.join(sorted(ALLOWED_ARCHIVE_EXTENSIONS))}",
        )
    return resolved
[docs]
def validate_dicom_directory(path: str | Path) -> Path:
    """Validate that path is an existing directory whose DICOM files can be read.

    Args:
        path: Directory to validate.

    Returns:
        Resolved Path object.

    Raises:
        PathValidationError: If the path does not exist, is not a directory, or
            cannot be both listed and traversed.
    """
    resolved = validate_path_exists(path, must_be_dir=True, description="DICOM directory")
    # Listing entries needs read permission, but opening the files inside the
    # directory also needs search (execute) permission on the directory itself.
    if not os.access(resolved, os.R_OK | os.X_OK):
        raise PathValidationError(str(resolved), "directory is not readable")
    return resolved
# =============================================================================
# Configuration Validation
# =============================================================================
[docs]
def validate_timeout(
    value: int | float | str | None,
    field_name: str = "timeout",
    *,
    min_value: int = 1,
    max_value: int = 86400 * 30,  # 30 days
    default: int = DEFAULT_HTTP_TIMEOUT_SECONDS,
) -> int:
    """Validate a timeout value in seconds.

    Floats are truncated toward zero by ``int()`` (2.9 -> 2). Strings must be
    integer literals ("30" is accepted, "30.0" is not).

    Args:
        value: Timeout value to validate.
        field_name: Field name for error messages.
        min_value: Minimum allowed value.
        max_value: Maximum allowed value.
        default: Value returned when ``value`` is None.

    Returns:
        Validated timeout in seconds.

    Raises:
        ConfigurationError: If the value cannot be converted to an integer
            (including infinite floats) or lies outside [min_value, max_value].
    """
    if value is None:
        return default
    try:
        timeout = int(value)
    except (ValueError, TypeError, OverflowError) as e:
        # OverflowError comes from int(float("inf")) (e.g. YAML ".inf"); surface
        # it as a configuration error rather than an unrelated built-in.
        raise ConfigurationError(
            f"{field_name} must be a valid integer",
            field_name,
            value,
        ) from e
    if timeout < min_value:
        raise ConfigurationError(
            f"{field_name} must be at least {min_value} seconds",
            field_name,
            timeout,
        )
    if timeout > max_value:
        raise ConfigurationError(
            f"{field_name} cannot exceed {max_value} seconds",
            field_name,
            timeout,
        )
    return timeout
[docs]
def validate_workers(
    value: int | str | None,
    field_name: str = "workers",
    *,
    min_value: int = 1,
    max_value: int = 100,
    default: int = 4,
) -> int:
    """Validate a worker count for parallel operations.

    Args:
        value: Worker count to validate (parsed with ``int()``).
        field_name: Field name for error messages.
        min_value: Minimum allowed value.
        max_value: Maximum allowed value.
        default: Value returned when ``value`` is None.

    Returns:
        Validated worker count.

    Raises:
        ConfigurationError: If the value cannot be converted to an integer
            (including infinite floats) or lies outside [min_value, max_value].
    """
    if value is None:
        return default
    try:
        workers = int(value)
    except (ValueError, TypeError, OverflowError) as e:
        # OverflowError comes from int(float("inf")); report it like any other
        # unparsable value instead of leaking a built-in exception.
        raise ConfigurationError(
            f"{field_name} must be a valid integer",
            field_name,
            value,
        ) from e
    if workers < min_value:
        raise ConfigurationError(
            f"{field_name} must be at least {min_value}",
            field_name,
            workers,
        )
    if workers > max_value:
        raise ConfigurationError(
            f"{field_name} cannot exceed {max_value}",
            field_name,
            workers,
        )
    return workers
[docs]
def validate_regex_pattern(pattern: str, field_name: str = "pattern") -> re.Pattern[str]:
    """Compile a user-supplied regex, reporting problems as ConfigurationError.

    Args:
        pattern: Regular expression source text.
        field_name: Field name used in error messages.

    Returns:
        The compiled pattern.

    Raises:
        ConfigurationError: If *pattern* is empty or not a string, or if it
            fails to compile.
    """
    if pattern and isinstance(pattern, str):
        try:
            compiled = re.compile(pattern)
        except re.error as err:
            raise ConfigurationError(
                f"Invalid regex pattern: {err}",
                field_name,
                pattern,
            ) from err
        return compiled
    raise ConfigurationError(f"{field_name} cannot be empty", field_name, pattern)
# =============================================================================
# Batch Input Validation
# =============================================================================
[docs]
def validate_project_list(projects_input: str) -> list[str]:
    """Validate and parse comma-separated project IDs.

    Blank entries (e.g. from "P1,,P2" or a trailing comma) are skipped.
    Order and duplicates are preserved.

    Args:
        projects_input: Comma-separated project IDs.

    Returns:
        List of validated project IDs.

    Raises:
        InvalidIdentifierError: If the input is not a string, any project ID
            is invalid, or no project IDs remain after skipping blanks.
    """
    # Match the sibling validators: non-string input (e.g. None from an unset
    # option) gets a domain error instead of AttributeError from .split().
    if not isinstance(projects_input, str):
        raise InvalidIdentifierError("project", str(projects_input), "must be a string")
    parts = (part.strip() for part in projects_input.split(","))
    project_ids = [validate_project_id(part) for part in parts if part]
    if not project_ids:
        raise InvalidIdentifierError("project", projects_input, "no valid project IDs provided")
    return project_ids