# Source code for xnatctl.core.config

"""Configuration management for xnatctl.

Supports YAML profiles and environment variable overrides.
"""

from __future__ import annotations

import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import yaml

from xnatctl.core.exceptions import ConfigurationError, ProfileNotFoundError
from xnatctl.core.timeouts import DEFAULT_HTTP_TIMEOUT_SECONDS

# =============================================================================
# Constants
# =============================================================================

# Per-user configuration directory and the files kept inside it.
CONFIG_DIR = Path.home().joinpath(".config", "xnatctl")
CONFIG_FILE = CONFIG_DIR.joinpath("config.yaml")
# presumably a cached XNAT session; only the path is defined here
SESSION_CACHE_FILE = CONFIG_DIR.joinpath(".session")

# Names of the environment variables consulted by this module.
ENV_URL = "XNAT_URL"  # server URL; overrides the "default" profile
ENV_USER = "XNAT_USER"  # username; takes priority over profile config
ENV_PASS = "XNAT_PASS"  # password; takes priority over profile config
ENV_TOKEN = "XNAT_TOKEN"  # session token
ENV_PROFILE = "XNAT_PROFILE"  # name of the profile to use by default
ENV_VERIFY_SSL = "XNAT_VERIFY_SSL"  # only applied together with XNAT_URL
ENV_TIMEOUT = "XNAT_TIMEOUT"  # seconds; only applied together with XNAT_URL


# =============================================================================
# Profile
# =============================================================================


# Optional per-profile command defaults. Profile.to_dict emits them only when
# they are set, and Profile.from_dict reads them by these exact key names.
_OPERATIONAL_FIELDS = ("workers", "overwrite", "direct_archive", "archive_mode", "extract")


@dataclass
class Profile:
    """Connection settings and optional command defaults for one XNAT server."""

    url: str
    verify_ssl: bool = True
    timeout: int = DEFAULT_HTTP_TIMEOUT_SECONDS
    default_project: str | None = None
    username: str | None = None
    password: str | None = None
    # Operational defaults (None = not configured, use command default)
    workers: int | None = None
    overwrite: str | None = None
    direct_archive: bool | None = None
    archive_mode: str | None = None
    extract: bool | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize this profile to a plain dict.

        ``url``, ``verify_ssl``, ``timeout`` and ``default_project`` are always
        present. ``username`` and ``password`` are included only when non-empty.
        Each operational field is included only when it is not None. Key order
        is: connection settings, credentials, then the operational fields in
        ``_OPERATIONAL_FIELDS`` order.

        Returns:
            Dictionary representation suitable for YAML serialization.
        """
        serialized: dict[str, Any] = {
            "url": self.url,
            "verify_ssl": self.verify_ssl,
            "timeout": self.timeout,
            "default_project": self.default_project,
        }
        # Credentials are skipped when None or empty string.
        for credential in ("username", "password"):
            secret = getattr(self, credential)
            if secret:
                serialized[credential] = secret
        # Operational fields are skipped only when None, so False/0 are kept.
        serialized.update(
            {
                name: value
                for name in _OPERATIONAL_FIELDS
                if (value := getattr(self, name)) is not None
            }
        )
        return serialized

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> Profile:
        """Build a profile from a dictionary such as one loaded from YAML.

        Missing keys fall back to defaults. ``url`` falls back to an empty
        string, ``verify_ssl`` to True, ``timeout`` to the default HTTP timeout,
        and everything else to None. Unrecognized keys are ignored.

        Args:
            data: Mapping of profile settings.

        Returns:
            New Profile instance.
        """
        operational = {name: data.get(name) for name in _OPERATIONAL_FIELDS}
        return cls(
            url=data.get("url", ""),
            verify_ssl=data.get("verify_ssl", True),
            timeout=data.get("timeout", DEFAULT_HTTP_TIMEOUT_SECONDS),
            default_project=data.get("default_project"),
            username=data.get("username"),
            password=data.get("password"),
            **operational,
        )
# =============================================================================
# Config
# =============================================================================
@dataclass
class Config:
    """Application configuration: global CLI settings plus named server profiles."""

    default_profile: str = "default"
    output_format: str = "table"
    profiles: dict[str, Profile] = field(default_factory=dict)

    @classmethod
    def load(cls, config_path: Path | None = None) -> Config:
        """Load config from file with environment variable overrides.

        Priority (highest to lowest):
        1. Environment variables
        2. Config file
        3. Defaults

        When XNAT_URL is set, the profile named "default" is replaced by one
        built from XNAT_URL, XNAT_VERIFY_SSL and XNAT_TIMEOUT. When XNAT_PROFILE
        is set, it becomes ``default_profile``.

        Args:
            config_path: Optional path to config file. Defaults to CONFIG_FILE.

        Returns:
            Loaded configuration.

        Raises:
            ConfigurationError: If the config file cannot be read or parsed, if
                it does not have the expected mapping structure, or if
                XNAT_TIMEOUT is not an integer.
        """
        path = config_path or CONFIG_FILE
        config = cls()

        # A missing file is not an error: defaults (plus env overrides) apply.
        if path.exists():
            try:
                with open(path, encoding="utf-8") as f:
                    data = yaml.safe_load(f) or {}
            except (OSError, UnicodeDecodeError, yaml.YAMLError) as e:
                raise ConfigurationError(f"Failed to load config: {e}") from e

            if not isinstance(data, dict):
                raise ConfigurationError(
                    f"Failed to load config: expected a mapping at the top level "
                    f"of {path}, got {type(data).__name__}"
                )

            # A key written with no value (e.g. "profiles:") parses as None;
            # treat it like a missing key instead of failing further down.
            config.default_profile = data.get("default_profile") or "default"
            config.output_format = data.get("output_format") or "table"

            profiles = data.get("profiles") or {}
            if not isinstance(profiles, dict):
                raise ConfigurationError(
                    f"Failed to load config: 'profiles' must be a mapping, "
                    f"got {type(profiles).__name__}"
                )
            for name, pdata in profiles.items():
                if not isinstance(pdata, dict):
                    raise ConfigurationError(
                        f"Failed to load config: profile {name!r} must be a "
                        f"mapping, got {type(pdata).__name__}"
                    )
                config.profiles[name] = Profile.from_dict(pdata)

        # Environment variable overrides. XNAT_URL replaces the "default"
        # profile wholesale: username/password/default_project from a
        # file-based "default" profile are not carried over.
        if url := os.getenv(ENV_URL):
            # Any value outside the truthy set disables verification, so strip
            # stray whitespace (e.g. "true\n") rather than failing open.
            verify_ssl = os.getenv(ENV_VERIFY_SSL, "true").strip().lower() in (
                "true",
                "1",
                "yes",
            )
            raw_timeout = os.getenv(ENV_TIMEOUT, str(DEFAULT_HTTP_TIMEOUT_SECONDS))
            try:
                timeout = int(raw_timeout)
            except ValueError as e:
                raise ConfigurationError(
                    f"Invalid {ENV_TIMEOUT} value {raw_timeout!r}: "
                    f"expected an integer number of seconds"
                ) from e
            config.profiles["default"] = Profile(
                url=url,
                verify_ssl=verify_ssl,
                timeout=timeout,
            )

        if profile := os.getenv(ENV_PROFILE):
            config.default_profile = profile

        return config

    def save(self, config_path: Path | None = None) -> None:
        """Save config to file as YAML.

        Each profile is serialized with ``Profile.to_dict``. That includes the
        profile's username and password whenever they are set, so the file can
        contain credentials. A newly created file therefore gets at most
        owner read/write permissions (0o600, further reduced by the umask). An
        existing file keeps its current permissions.

        Args:
            config_path: Optional path to config file. Defaults to CONFIG_FILE.
        """
        path = config_path or CONFIG_FILE
        path.parent.mkdir(parents=True, exist_ok=True)
        data = {
            "default_profile": self.default_profile,
            "output_format": self.output_format,
            "profiles": {name: p.to_dict() for name, p in self.profiles.items()},
        }
        # The mode argument only takes effect when the file is created here.
        fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            yaml.dump(data, f, default_flow_style=False, sort_keys=False)

    def get_profile(self, name: str | None = None) -> Profile:
        """Get profile by name or default.

        Args:
            name: Profile name. If None (or empty), uses default_profile.

        Returns:
            Profile configuration.

        Raises:
            ProfileNotFoundError: If profile doesn't exist.
        """
        name = name or self.default_profile
        if name not in self.profiles:
            raise ProfileNotFoundError(name)
        return self.profiles[name]

    def has_profile(self, name: str) -> bool:
        """Return True if a profile with this name exists."""
        return name in self.profiles

    def add_profile(
        self,
        name: str,
        url: str,
        verify_ssl: bool = True,
        timeout: int = DEFAULT_HTTP_TIMEOUT_SECONDS,
        default_project: str | None = None,
    ) -> Profile:
        """Add or replace a profile.

        An existing profile with the same name is replaced entirely. Fields not
        passed here (credentials, operational defaults) are reset to None.

        Args:
            name: Profile name.
            url: XNAT server URL.
            verify_ssl: Whether to verify SSL certificates.
            timeout: Request timeout in seconds.
            default_project: Default project ID.

        Returns:
            Created profile.
        """
        profile = Profile(
            url=url,
            verify_ssl=verify_ssl,
            timeout=timeout,
            default_project=default_project,
        )
        self.profiles[name] = profile
        return profile

    def remove_profile(self, name: str) -> bool:
        """Remove a profile.

        Args:
            name: Profile name.

        Returns:
            True if removed, False if it didn't exist.
        """
        if name in self.profiles:
            del self.profiles[name]
            return True
        return False

    def set_default_profile(self, name: str) -> None:
        """Set the default profile.

        Args:
            name: Profile name to set as default.

        Raises:
            ProfileNotFoundError: If profile doesn't exist.
        """
        if name not in self.profiles:
            raise ProfileNotFoundError(name)
        self.default_profile = name


def _env_or_stored(env_name: str, stored: str | None) -> str | None:
    """Return env var ``env_name``, falling back to ``stored`` when the env value is unset or empty.

    If both are unset/empty, the (possibly empty or None) env value is returned.
    """
    from_env = os.getenv(env_name)
    if not from_env and stored:
        return stored
    return from_env


def get_credentials(profile: Profile | None = None) -> tuple[str | None, str | None]:
    """Resolve the username/password pair, preferring environment variables.

    XNAT_USER / XNAT_PASS win when non-empty. Otherwise the corresponding
    non-empty value stored on ``profile`` (if given) is used.

    Args:
        profile: Optional profile to fall back on for stored credentials.

    Returns:
        Tuple of (username, password); either element may be None.
    """
    stored_user = None if profile is None else profile.username
    stored_pass = None if profile is None else profile.password
    return _env_or_stored(ENV_USER, stored_user), _env_or_stored(ENV_PASS, stored_pass)


def get_token() -> str | None:
    """Return the session token from the XNAT_TOKEN environment variable.

    Returns:
        The token string if the variable is set, otherwise None.
    """
    return os.environ.get(ENV_TOKEN)