# Source code for xnatctl.core.auth

"""Authentication management for xnatctl.

Handles credential storage and session token caching.
"""

from __future__ import annotations

import json
import os
from dataclasses import dataclass
from datetime import datetime, timedelta
from pathlib import Path

from xnatctl.core.config import CONFIG_DIR, ENV_PASS, ENV_TOKEN, ENV_USER

# =============================================================================
# Constants
# =============================================================================

# Location of the JSON session cache, kept inside the xnatctl config directory.
SESSION_CACHE_FILE: Path = CONFIG_DIR / ".session"
# Lifetime (in minutes) assumed for a cached session token; chosen to match
# XNAT's default 15-minute inactivity timeout for JSESSIONID.
SESSION_EXPIRY_MINUTES: int = 15


# =============================================================================
# Session Cache
# =============================================================================


@dataclass
class CachedSession:
    """Session token persisted to disk together with its metadata.

    Timestamps are held as ``datetime`` objects in memory and are written
    out as ISO-8601 strings by :meth:`to_dict`.
    """

    token: str
    url: str
    username: str
    created_at: datetime
    expires_at: datetime | None = None

    def is_expired(self) -> bool:
        """Report whether the current local time has reached ``expires_at``.

        A session with no expiry timestamp is treated as never expiring.
        """
        deadline = self.expires_at
        if not deadline:
            return False
        return datetime.now() >= deadline

    def to_dict(self) -> dict:
        """Serialize to a JSON-friendly dict (datetimes become ISO strings)."""
        payload = {
            "token": self.token,
            "url": self.url,
            "username": self.username,
            "created_at": self.created_at.isoformat(),
        }
        # A missing expiry is stored as JSON null rather than omitted.
        payload["expires_at"] = (
            self.expires_at.isoformat() if self.expires_at else None
        )
        return payload

    @classmethod
    def from_dict(cls, data: dict) -> CachedSession:
        """Rebuild a session from a dict shaped like :meth:`to_dict` output.

        Raises:
            KeyError: If ``token``, ``url``, ``username`` or ``created_at``
                is missing.
            ValueError: If a timestamp string is not valid ISO format.
        """
        token = data["token"]
        url = data["url"]
        username = data["username"]
        created_at = datetime.fromisoformat(data["created_at"])
        # An absent, null, or empty "expires_at" all mean "no expiry".
        raw_expiry = data.get("expires_at")
        expires_at = datetime.fromisoformat(raw_expiry) if raw_expiry else None
        return cls(
            token=token,
            url=url,
            username=username,
            created_at=created_at,
            expires_at=expires_at,
        )
# =============================================================================
# AuthManager
# =============================================================================
class AuthManager:
    """Manages authentication credentials and session tokens.

    Credentials and an explicit token are read from environment variables;
    a session token obtained at login is cached as JSON in ``cache_file``.
    """

    def __init__(self, cache_file: Path | str | None = None):
        """Initialize auth manager.

        Args:
            cache_file: Path to session cache file. A plain string is
                converted to ``Path``. Falls back to ``SESSION_CACHE_FILE``
                when omitted or empty.
        """
        # Accept str paths (e.g. straight from a CLI option); Path objects
        # and other values are kept as-is for backward compatibility.
        if isinstance(cache_file, str) and cache_file:
            cache_file = Path(cache_file)
        self.cache_file = cache_file or SESSION_CACHE_FILE

    # =========================================================================
    # Credential Access
    # =========================================================================

    def get_credentials(self) -> tuple[str | None, str | None]:
        """Get credentials from environment variables.

        Returns:
            Tuple of (username, password); either may be None if unset.
        """
        username = os.getenv(ENV_USER)
        password = os.getenv(ENV_PASS)
        return username, password

    def get_token_from_env(self) -> str | None:
        """Get session token from environment variable.

        Returns:
            Token if set, else None.
        """
        return os.getenv(ENV_TOKEN)

    # =========================================================================
    # Session Cache
    # =========================================================================

    def save_session(
        self,
        token: str,
        url: str,
        username: str,
        expiry_minutes: int = SESSION_EXPIRY_MINUTES,
    ) -> CachedSession:
        """Save session token to cache.

        The cache file is given owner-only permissions (0o600) before the
        token is written, so the secret is not exposed to other users even
        briefly.

        Args:
            token: Session token (JSESSIONID).
            url: XNAT server URL.
            username: Username used for authentication.
            expiry_minutes: Minutes until session is considered expired.

        Returns:
            Cached session object.

        Raises:
            OSError: If the cache directory or file cannot be created/written.
        """
        now = datetime.now()
        session = CachedSession(
            token=token,
            url=url,
            username=username,
            created_at=now,
            expires_at=now + timedelta(minutes=expiry_minutes),
        )

        # Ensure directory exists
        self.cache_file.parent.mkdir(parents=True, exist_ok=True)

        # Create (or truncate) with 0o600 from the start. A plain
        # open(path, "w") creates the file with umask-derived permissions
        # (commonly world-readable) and would leave the token exposed until
        # a later chmod.
        fd = os.open(self.cache_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            # os.open's mode only applies when the file is newly created;
            # tighten a pre-existing file before any token bytes hit disk.
            try:
                os.chmod(self.cache_file, 0o600)
            except OSError:
                pass  # May fail on some systems; best effort
            json.dump(session.to_dict(), f)

        return session

    def load_session(self, url: str | None = None) -> CachedSession | None:
        """Load cached session token.

        Args:
            url: Optional URL to match. If provided, only returns session for
                that URL.

        Returns:
            Cached session if valid, None otherwise. A missing or unreadable
            cache file yields None; a corrupt or expired cache file is also
            deleted.
        """
        # EAFP: avoids the exists()/open() race and handles unreadable paths.
        try:
            with open(self.cache_file, encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError:
            return None
        except OSError:
            # e.g. permission denied or a directory: no usable session, but
            # leave the path alone since it is not known to be corrupt.
            return None
        except ValueError:
            # Invalid JSON (JSONDecodeError) or undecodable bytes.
            self.clear_session()
            return None

        try:
            session = CachedSession.from_dict(data)
            expired = session.is_expired()
        except (KeyError, TypeError, ValueError):
            # Valid JSON of the wrong shape (null, list, non-string or
            # timezone-aware timestamps) or missing/invalid fields.
            self.clear_session()
            return None

        # Check URL match
        if url and session.url != url:
            return None

        # Check expiry
        if expired:
            self.clear_session()
            return None

        return session

    def clear_session(self) -> bool:
        """Clear cached session.

        Returns:
            True if the cache file was removed; False if it did not exist or
            could not be deleted.
        """
        try:
            self.cache_file.unlink()
        except OSError:  # includes FileNotFoundError
            return False
        return True

    def has_valid_session(self, url: str | None = None) -> bool:
        """Check if there's a valid cached session.

        Args:
            url: Optional URL to match.

        Returns:
            True if valid session exists.
        """
        session = self.load_session(url)
        return session is not None and not session.is_expired()

    # =========================================================================
    # Convenience Methods
    # =========================================================================

    def get_session_token(self, url: str | None = None) -> str | None:
        """Get session token from cache or environment.

        Priority:
            1. Environment variable (XNAT_TOKEN)
            2. Cached session

        Args:
            url: Optional URL to match for cached session.

        Returns:
            Session token if available.
        """
        # Check environment first
        if token := self.get_token_from_env():
            return token

        # Check cache
        if session := self.load_session(url):
            return session.token

        return None

    def get_session_info(self, url: str | None = None) -> dict | None:
        """Get session information for display.

        Args:
            url: Optional URL to match.

        Returns:
            Dict with session info or None.
        """
        session = self.load_session(url)
        if not session:
            return None

        return {
            "url": session.url,
            "username": session.username,
            "created_at": session.created_at.isoformat(),
            "expires_at": session.expires_at.isoformat() if session.expires_at else None,
            "is_expired": session.is_expired(),
        }