# Source code for xnatctl.core.client

"""HTTP client for XNAT REST API.

Provides retry logic, pagination, and session-based authentication.
"""

from __future__ import annotations

import re
import time
from collections.abc import Iterator
from dataclasses import dataclass, field
from typing import Any
from urllib.parse import quote

import httpx

from xnatctl.core.exceptions import (
    AuthenticationError,
    NetworkError,
    PermissionDeniedError,
    ResourceNotFoundError,
    RetryExhaustedError,
    ServerUnreachableError,
    SessionExpiredError,
)
from xnatctl.core.timeouts import DEFAULT_HTTP_TIMEOUT_SECONDS
from xnatctl.core.validation import validate_server_url

# =============================================================================
# Constants
# =============================================================================

# Default per-request timeout, taken from the project-wide HTTP timeout setting.
DEFAULT_TIMEOUT = DEFAULT_HTTP_TIMEOUT_SECONDS

# Retry policy: up to DEFAULT_MAX_RETRIES retries after the first attempt,
# sleeping RETRY_BACKOFF_BASE ** n seconds before retry number n.
DEFAULT_MAX_RETRIES = 3
RETRY_BACKOFF_BASE = 2

# Gateway/availability status codes that trigger a retry.
RETRYABLE_STATUS_CODES = set((502, 503, 504))

# Pulls the logged-in username out of the /data/auth response body.
_AUTH_LOGGED_IN_RE = re.compile(r"User '([^']+)' is logged in")


# =============================================================================
# XNATClient
# =============================================================================


@dataclass
class XNATClient:
    """HTTP client for XNAT REST API with retry and pagination.

    Attributes:
        base_url: XNAT server URL, normalized by ``validate_server_url``.
        username: Username used for basic auth and for creating sessions.
        password: Password used for basic auth and for creating sessions.
            Excluded from ``repr()``.
        session_token: JSESSIONID value. When set, it is sent as a cookie and
            basic auth is not used. Excluded from ``repr()``.
        timeout: Default per-request timeout in seconds.
        max_retries: Number of retries after the first attempt.
        verify_ssl: Whether TLS certificates are verified.
        auto_reauth: Re-authenticate once on HTTP 401 when credentials exist.
    """

    base_url: str
    username: str | None = None
    # Secrets are hidden from repr() so logging a client never leaks them.
    password: str | None = field(default=None, repr=False)
    session_token: str | None = field(default=None, repr=False)
    timeout: int = DEFAULT_TIMEOUT
    max_retries: int = DEFAULT_MAX_RETRIES
    verify_ssl: bool = True
    auto_reauth: bool = False
    _client: httpx.Client | None = field(init=False, default=None, repr=False)

    def __post_init__(self) -> None:
        """Validate and normalize URL."""
        self.base_url = validate_server_url(self.base_url)

    # =========================================================================
    # Client Management
    # =========================================================================

    def _get_client(self) -> httpx.Client:
        """Return the shared HTTP client, creating it on first use."""
        if self._client is None:
            self._client = httpx.Client(
                base_url=self.base_url,
                timeout=self.timeout,
                verify=self.verify_ssl,
                follow_redirects=True,
            )
        return self._client

    def close(self) -> None:
        """Close the HTTP client, if one has been created."""
        if self._client is not None:
            self._client.close()
            self._client = None

    def __enter__(self) -> XNATClient:
        return self

    def __exit__(self, *args: Any) -> None:
        self.close()

    # =========================================================================
    # Authentication
    # =========================================================================

    @property
    def is_authenticated(self) -> bool:
        """Check if client has a session token."""
        return self.session_token is not None

    def authenticate(self) -> str:
        """Authenticate with username/password and get JSESSIONID.

        Returns:
            Session token (JSESSIONID).

        Raises:
            AuthenticationError: If credentials are missing, the server
                rejects them, or no token is returned.
            ServerUnreachableError: If the server cannot be reached.
            NetworkError: If the request times out.
        """
        if not self.username or not self.password:
            raise AuthenticationError(self.base_url, "Username and password required")

        client = self._get_client()
        try:
            resp = client.post(
                "/data/JSESSION",
                auth=(self.username, self.password),
            )
        except httpx.ConnectError as e:
            raise ServerUnreachableError(self.base_url) from e
        except httpx.TimeoutException as e:
            raise NetworkError(self.base_url, f"Timeout: {e}") from e

        if resp.status_code != 200:
            raise AuthenticationError(self.base_url, f"HTTP {resp.status_code}")

        body = resp.text
        # XNAT returns HTML on auth failure
        if "<html" in body.lower():
            raise AuthenticationError(self.base_url, "Invalid credentials or password expired")

        token = body.strip()
        # An empty token would mark the client authenticated but send no cookie.
        if not token:
            raise AuthenticationError(self.base_url, "Empty session token in response")

        self.session_token = token
        return token

    def invalidate_session(self) -> None:
        """Logout and clear session token."""
        if self.session_token:
            try:
                client = self._get_client()
                client.delete(
                    "/data/JSESSION",
                    cookies={"JSESSIONID": self.session_token},
                )
            except Exception:
                pass  # Best effort: local state is cleared regardless.
            finally:
                self.session_token = None

    # =========================================================================
    # HTTP Methods
    # =========================================================================

    def _get_cookies(self) -> dict[str, str]:
        """Get cookies for request."""
        if self.session_token:
            return {"JSESSIONID": self.session_token}
        return {}

    def _get_auth(self) -> tuple[str, str] | None:
        """Get basic auth tuple if no session token."""
        if not self.session_token and self.username and self.password:
            return (self.username, self.password)
        return None

    def _request(
        self,
        method: str,
        path: str,
        *,
        params: dict[str, Any] | None = None,
        json: Any | None = None,
        data: Any | None = None,
        files: Any | None = None,
        headers: dict[str, str] | None = None,
        timeout: int | None = None,
        stream: bool = False,
    ) -> httpx.Response:
        """Execute HTTP request with retry logic.

        Connection errors, timeouts, and ``RETRYABLE_STATUS_CODES`` responses
        are retried up to ``max_retries`` times with exponential backoff.

        Args:
            method: HTTP method.
            path: API path.
            params: Query parameters.
            json: JSON body.
            data: Form data or raw body.
            files: Files to upload.
            headers: Additional headers.
            timeout: Request timeout override.
            stream: Accepted for API compatibility. NOTE: it is not currently
                forwarded to httpx, so responses are always read fully.

        Returns:
            HTTP response.

        Raises:
            SessionExpiredError: On HTTP 401 (after at most one re-auth).
            PermissionDeniedError: On HTTP 403.
            ResourceNotFoundError: On HTTP 404.
            httpx.HTTPStatusError: On other non-retryable error statuses.
            RetryExhaustedError: If every attempt failed with a retryable error.
        """
        client = self._get_client()
        request_timeout = timeout or self.timeout
        last_error: Exception | None = None
        did_reauth = False
        attempt = 0

        while attempt <= self.max_retries:
            try:
                resp = client.request(
                    method,
                    path,
                    params=params,
                    json=json,
                    data=data,
                    files=files,
                    headers=headers,
                    cookies=self._get_cookies(),
                    auth=self._get_auth(),
                    timeout=request_timeout,
                )
            except httpx.ConnectError:
                last_error = ServerUnreachableError(self.base_url)
            except httpx.TimeoutException:
                last_error = NetworkError(self.base_url, f"Timeout after {request_timeout}s")
            else:
                if resp.status_code == 401:
                    if self.auto_reauth and not did_reauth and self.username and self.password:
                        self.authenticate()
                        did_reauth = True
                        # Re-auth does not consume a retry attempt.
                        continue
                    expired_err = SessionExpiredError(self.base_url)
                    expired_err.details.update(
                        {"status_code": resp.status_code, "method": method, "path": path}
                    )
                    raise expired_err

                if resp.status_code == 403:
                    denied_err = PermissionDeniedError(
                        resource=path,
                        operation=method.lower(),
                        url=self.base_url,
                    )
                    denied_err.details.update(
                        {"status_code": resp.status_code, "method": method, "path": path}
                    )
                    raise denied_err

                if resp.status_code == 404:
                    raise ResourceNotFoundError("resource", path)

                if resp.status_code not in RETRYABLE_STATUS_CODES:
                    # Success or non-retryable error
                    resp.raise_for_status()
                    return resp

                # Retryable status: record it and fall through to the shared
                # backoff, so exhaustion raises RetryExhaustedError (not httpx).
                last_error = NetworkError(
                    self.base_url,
                    f"HTTP {resp.status_code}",
                )

            # Retry with backoff
            if attempt < self.max_retries:
                time.sleep(RETRY_BACKOFF_BASE ** (attempt + 1))
            attempt += 1

        raise RetryExhaustedError("request", self.max_retries + 1, last_error)

    def get(
        self,
        path: str,
        *,
        params: dict[str, Any] | None = None,
        headers: dict[str, str] | None = None,
        timeout: int | None = None,
        stream: bool = False,
    ) -> httpx.Response:
        """GET request."""
        return self._request(
            "GET",
            path,
            params=params,
            headers=headers,
            timeout=timeout,
            stream=stream,
        )

    def post(
        self,
        path: str,
        *,
        params: dict[str, Any] | None = None,
        json: Any | None = None,
        data: Any | None = None,
        files: Any | None = None,
        headers: dict[str, str] | None = None,
        timeout: int | None = None,
    ) -> httpx.Response:
        """POST request."""
        return self._request(
            "POST",
            path,
            params=params,
            json=json,
            data=data,
            files=files,
            headers=headers,
            timeout=timeout,
        )

    def put(
        self,
        path: str,
        *,
        params: dict[str, Any] | None = None,
        json: Any | None = None,
        data: Any | None = None,
        files: Any | None = None,
        headers: dict[str, str] | None = None,
        timeout: int | None = None,
    ) -> httpx.Response:
        """PUT request."""
        return self._request(
            "PUT",
            path,
            params=params,
            json=json,
            data=data,
            files=files,
            headers=headers,
            timeout=timeout,
        )

    def delete(
        self,
        path: str,
        *,
        params: dict[str, Any] | None = None,
        headers: dict[str, str] | None = None,
        timeout: int | None = None,
    ) -> httpx.Response:
        """DELETE request."""
        return self._request(
            "DELETE",
            path,
            params=params,
            headers=headers,
            timeout=timeout,
        )

    # =========================================================================
    # Pagination
    # =========================================================================

    def paginate(
        self,
        path: str,
        *,
        params: dict[str, Any] | None = None,
        page_size: int = 100,
        result_key: str = "ResultSet.Result",
    ) -> Iterator[dict[str, Any]]:
        """Paginated GET returning items one by one.

        Args:
            path: API path.
            params: Additional query parameters (not modified).
            page_size: Number of items per page; must be positive.
            result_key: Dot-separated path to results in response.

        Yields:
            Individual result items.

        Raises:
            ValueError: If ``page_size`` is not positive.
        """
        # A non-positive page size would never advance the offset.
        if page_size <= 0:
            raise ValueError(f"page_size must be positive, got {page_size}")

        offset = 0
        base_params = {**(params or {}), "format": "json"}

        while True:
            page_params = {
                **base_params,
                "offset": offset,
                "limit": page_size,
            }
            data = self.get(path, params=page_params).json()

            # Navigate to results using dot notation
            results = data
            for key in result_key.split("."):
                results = results.get(key, []) if isinstance(results, dict) else []

            if not results:
                break

            yield from results
            offset += page_size

            # A short page means the server has no more items.
            if len(results) < page_size:
                break

    # =========================================================================
    # Convenience Methods
    # =========================================================================

    def get_json(
        self,
        path: str,
        *,
        params: dict[str, Any] | None = None,
    ) -> Any:
        """GET request returning JSON.

        The caller's ``params`` dict is copied, not mutated, when
        ``format=json`` is added.
        """
        query = {**(params or {}), "format": "json"}
        resp = self.get(path, params=query)
        return resp.json()

    def ping(self) -> dict[str, Any]:
        """Check server connectivity and get version info.

        Returns:
            Dict with server info.

        Raises:
            NetworkError: If server is unreachable.
        """
        # perf_counter is monotonic, so latency is immune to wall-clock jumps.
        start = time.perf_counter()
        resp = self.get("/xapi/siteConfig/buildInfo/version")
        latency = int((time.perf_counter() - start) * 1000)

        return {
            "url": self.base_url,
            "status": "ok",
            "version": resp.text.strip(),
            "latency_ms": latency,
        }

    def whoami(self) -> dict[str, Any]:
        """Get current user information.

        Returns:
            Dict with user info.

        Raises:
            AuthenticationError: If not authenticated.
        """
        current_username = self._get_current_username()
        if current_username:
            display_username = self._apply_username_hint(current_username)
            details = self._get_user_details(current_username)
            if details is not None:
                details["username"] = display_username
                return details
            return self._placeholder_user(display_username, enabled=True)

        if self.username:
            return self._placeholder_user(self.username, enabled=True)

        return self._placeholder_user("unknown", enabled=False)

    @staticmethod
    def _placeholder_user(username: str, *, enabled: bool) -> dict[str, Any]:
        """Build a user-info dict with empty profile fields."""
        return {
            "username": username,
            "firstname": "",
            "lastname": "",
            "email": "",
            "enabled": enabled,
        }

    def _get_current_username(self) -> str | None:
        """Resolve the authenticated username from server endpoints.

        Some XNAT deployments return a full user listing from `/data/user`,
        which is not a reliable whoami source. Prefer dedicated current-user
        endpoints when available.
        """
        try:
            resp = self.get("/xapi/users/username")
            username = resp.text.strip()
            if username and "<html" not in username.lower():
                return username
        except (AuthenticationError, SessionExpiredError, PermissionDeniedError):
            raise
        except Exception:
            pass  # Fall back to /data/auth below.

        try:
            resp = self.get("/data/auth")
            match = _AUTH_LOGGED_IN_RE.search(resp.text)
            if match:
                return match.group(1).strip()
        except (AuthenticationError, SessionExpiredError, PermissionDeniedError):
            raise
        except Exception:
            pass

        return None

    def _get_user_details(self, username: str) -> dict[str, Any] | None:
        """Fetch user details for a resolved username, if available."""
        try:
            data = self.get_json(f"/xapi/users/{quote(username, safe='')}")
        except Exception:
            return None

        if not isinstance(data, dict):
            return None

        return {
            "username": data.get("username", username),
            "firstname": data.get("firstName", ""),
            "lastname": data.get("lastName", ""),
            "email": data.get("email", ""),
            "enabled": data.get("enabled", False),
        }

    def _apply_username_hint(self, username: str) -> str:
        """Preserve configured/cached username casing when it matches."""
        if self.username and self.username.casefold() == username.casefold():
            return self.username
        return username