"""HTTP client for XNAT REST API.
Provides retry logic, pagination, and session-based authentication.
"""
from __future__ import annotations
import re
import time
from collections.abc import Iterator
from dataclasses import dataclass, field
from typing import Any
from urllib.parse import quote
import httpx
from xnatctl.core.exceptions import (
AuthenticationError,
NetworkError,
PermissionDeniedError,
ResourceNotFoundError,
RetryExhaustedError,
ServerUnreachableError,
SessionExpiredError,
)
from xnatctl.core.timeouts import DEFAULT_HTTP_TIMEOUT_SECONDS
from xnatctl.core.validation import validate_server_url
# =============================================================================
# Constants
# =============================================================================

# Default per-request timeout in seconds (from xnatctl.core.timeouts).
DEFAULT_TIMEOUT = DEFAULT_HTTP_TIMEOUT_SECONDS
# Retries after the first attempt: a request is tried at most
# DEFAULT_MAX_RETRIES + 1 times.
DEFAULT_MAX_RETRIES = 3
# The sleep before retry number n (0-based) is RETRY_BACKOFF_BASE ** (n + 1) seconds.
RETRY_BACKOFF_BASE = 2
# Bad Gateway / Service Unavailable / Gateway Timeout are treated as transient.
# A frozenset keeps importers from mutating the retry policy by accident.
RETRYABLE_STATUS_CODES = frozenset({502, 503, 504})
# Extracts the username from the plain-text body returned by /data/auth.
_AUTH_LOGGED_IN_RE = re.compile(r"User '([^']+)' is logged in")

# =============================================================================
# XNATClient
# =============================================================================
@dataclass
class XNATClient:
    """HTTP client for the XNAT REST API with retry and pagination.

    Requests are authenticated with a JSESSIONID session cookie when
    ``session_token`` is set. Otherwise they use HTTP basic auth with
    ``username``/``password``. The underlying ``httpx.Client`` is created
    lazily. It is released by :meth:`close` or when a ``with`` block exits.

    Attributes:
        base_url: Server URL, normalized by ``validate_server_url``.
        username: Login name for basic auth and :meth:`authenticate`.
        password: Password for basic auth and :meth:`authenticate`.
        session_token: JSESSIONID value; takes precedence over basic auth.
        timeout: Default per-request timeout in seconds.
        max_retries: Retries after the first attempt of a request.
        verify_ssl: Whether TLS certificates are verified.
        auto_reauth: Re-authenticate once on HTTP 401 when credentials exist.
    """

    base_url: str
    username: str | None = None
    # Secrets are excluded from repr() so logging the client cannot leak them.
    password: str | None = field(default=None, repr=False)
    session_token: str | None = field(default=None, repr=False)
    timeout: int = DEFAULT_TIMEOUT
    max_retries: int = DEFAULT_MAX_RETRIES
    verify_ssl: bool = True
    auto_reauth: bool = False
    # Created lazily by _get_client(); not an __init__ argument.
    _client: httpx.Client | None = field(init=False, default=None, repr=False)

    def __post_init__(self) -> None:
        """Validate and normalize URL."""
        self.base_url = validate_server_url(self.base_url)

    # =========================================================================
    # Client Management
    # =========================================================================

    def _get_client(self) -> httpx.Client:
        """Return the shared ``httpx.Client``, creating it on first use."""
        if self._client is None:
            self._client = httpx.Client(
                base_url=self.base_url,
                timeout=self.timeout,
                verify=self.verify_ssl,
                follow_redirects=True,
            )
        return self._client

    def close(self) -> None:
        """Close the HTTP client (safe to call more than once)."""
        if self._client is not None:
            self._client.close()
            self._client = None

    def __enter__(self) -> XNATClient:
        return self

    def __exit__(self, *args: Any) -> None:
        self.close()

    # =========================================================================
    # Authentication
    # =========================================================================

    @property
    def is_authenticated(self) -> bool:
        """Check if client has a session token."""
        return self.session_token is not None

    def authenticate(self) -> str:
        """Authenticate with username/password and get JSESSIONID.

        Returns:
            Session token (JSESSIONID), also stored on ``session_token``.

        Raises:
            AuthenticationError: If credentials are missing or rejected.
            ServerUnreachableError: If the server cannot be connected to.
            NetworkError: If the request times out.
        """
        if not self.username or not self.password:
            raise AuthenticationError(self.base_url, "Username and password required")

        client = self._get_client()
        try:
            resp = client.post(
                "/data/JSESSION",
                auth=(self.username, self.password),
            )
        except httpx.ConnectError as e:
            raise ServerUnreachableError(self.base_url) from e
        except httpx.TimeoutException as e:
            raise NetworkError(self.base_url, f"Timeout: {e}") from e

        if resp.status_code != 200:
            raise AuthenticationError(self.base_url, f"HTTP {resp.status_code}")

        # XNAT returns HTML on auth failure
        if "<html" in resp.text.lower():
            raise AuthenticationError(self.base_url, "Invalid credentials or password expired")

        self.session_token = resp.text.strip()
        return self.session_token

    def invalidate_session(self) -> None:
        """Log out (best effort) and clear the session token."""
        if self.session_token:
            try:
                client = self._get_client()
                client.delete(
                    "/data/JSESSION",
                    cookies={"JSESSIONID": self.session_token},
                )
            except Exception:
                pass  # Best effort: the local token is cleared regardless.
            finally:
                self.session_token = None

    # =========================================================================
    # HTTP Methods
    # =========================================================================

    def _get_cookies(self) -> dict[str, str]:
        """Return the session cookie, or an empty dict without a token."""
        if self.session_token:
            return {"JSESSIONID": self.session_token}
        return {}

    def _get_auth(self) -> tuple[str, str] | None:
        """Get basic auth tuple if no session token."""
        if not self.session_token and self.username and self.password:
            return (self.username, self.password)
        return None

    def _request(
        self,
        method: str,
        path: str,
        *,
        params: dict[str, Any] | None = None,
        json: Any | None = None,
        data: Any | None = None,
        files: Any | None = None,
        headers: dict[str, str] | None = None,
        timeout: int | None = None,
        stream: bool = False,
    ) -> httpx.Response:
        """Execute an HTTP request with retry logic.

        Connection errors, timeouts and HTTP 502/503/504 are retried up to
        ``max_retries`` times with exponential backoff. A 401 triggers one
        re-authentication (without consuming a retry) when ``auto_reauth``
        is set and credentials are available.

        Args:
            method: HTTP method.
            path: API path.
            params: Query parameters.
            json: JSON body.
            data: Form data or raw body.
            files: Files to upload.
            headers: Additional headers.
            timeout: Request timeout override (falsy values use ``self.timeout``).
            stream: Accepted for API compatibility but not forwarded to httpx;
                the response body is always read in full.

        Returns:
            HTTP response.

        Raises:
            SessionExpiredError: On HTTP 401 that is not resolved by re-auth.
            PermissionDeniedError: On HTTP 403.
            ResourceNotFoundError: On HTTP 404.
            AuthenticationError: If re-authentication fails.
            httpx.HTTPStatusError: On other non-retryable error statuses.
            RetryExhaustedError: If all retries fail.
        """
        client = self._get_client()
        request_timeout = timeout or self.timeout
        last_error: Exception | None = None
        did_reauth = False
        attempt = 0

        while attempt <= self.max_retries:
            # Re-read credentials each pass: re-authentication below replaces
            # basic auth with the freshly issued session cookie.
            cookies = self._get_cookies()
            auth = self._get_auth()
            try:
                resp = client.request(
                    method,
                    path,
                    params=params,
                    json=json,
                    data=data,
                    files=files,
                    headers=headers,
                    cookies=cookies,
                    auth=auth,
                    timeout=request_timeout,
                )

                # Handle auth errors
                if resp.status_code == 401:
                    if self.auto_reauth and not did_reauth and self.username and self.password:
                        self.authenticate()
                        did_reauth = True
                        continue
                    expired_err = SessionExpiredError(self.base_url)
                    expired_err.details.update(
                        {"status_code": resp.status_code, "method": method, "path": path}
                    )
                    raise expired_err

                if resp.status_code == 403:
                    denied_err = PermissionDeniedError(
                        resource=path,
                        operation=method.lower(),
                        url=self.base_url,
                    )
                    denied_err.details.update(
                        {"status_code": resp.status_code, "method": method, "path": path}
                    )
                    raise denied_err

                if resp.status_code == 404:
                    raise ResourceNotFoundError("resource", path)

                if resp.status_code in RETRYABLE_STATUS_CODES:
                    # Transient server error: fall through to the shared backoff.
                    last_error = NetworkError(
                        self.base_url,
                        f"HTTP {resp.status_code}",
                    )
                else:
                    # Success or non-retryable error
                    resp.raise_for_status()
                    return resp
            except httpx.ConnectError:
                last_error = ServerUnreachableError(self.base_url)
            except httpx.TimeoutException:
                last_error = NetworkError(self.base_url, f"Timeout after {request_timeout}s")

            # Back off before the next attempt; no sleep after the final one.
            if attempt < self.max_retries:
                time.sleep(RETRY_BACKOFF_BASE ** (attempt + 1))
            attempt += 1

        raise RetryExhaustedError("request", self.max_retries + 1, last_error)

    def get(
        self,
        path: str,
        *,
        params: dict[str, Any] | None = None,
        headers: dict[str, str] | None = None,
        timeout: int | None = None,
        stream: bool = False,
    ) -> httpx.Response:
        """GET request."""
        return self._request(
            "GET",
            path,
            params=params,
            headers=headers,
            timeout=timeout,
            stream=stream,
        )

    def post(
        self,
        path: str,
        *,
        params: dict[str, Any] | None = None,
        json: Any | None = None,
        data: Any | None = None,
        files: Any | None = None,
        headers: dict[str, str] | None = None,
        timeout: int | None = None,
    ) -> httpx.Response:
        """POST request."""
        return self._request(
            "POST",
            path,
            params=params,
            json=json,
            data=data,
            files=files,
            headers=headers,
            timeout=timeout,
        )

    def put(
        self,
        path: str,
        *,
        params: dict[str, Any] | None = None,
        json: Any | None = None,
        data: Any | None = None,
        files: Any | None = None,
        headers: dict[str, str] | None = None,
        timeout: int | None = None,
    ) -> httpx.Response:
        """PUT request."""
        return self._request(
            "PUT",
            path,
            params=params,
            json=json,
            data=data,
            files=files,
            headers=headers,
            timeout=timeout,
        )

    def delete(
        self,
        path: str,
        *,
        params: dict[str, Any] | None = None,
        headers: dict[str, str] | None = None,
        timeout: int | None = None,
    ) -> httpx.Response:
        """DELETE request."""
        return self._request(
            "DELETE",
            path,
            params=params,
            headers=headers,
            timeout=timeout,
        )

    # =========================================================================
    # Pagination
    # =========================================================================

    def paginate(
        self,
        path: str,
        *,
        params: dict[str, Any] | None = None,
        page_size: int = 100,
        result_key: str = "ResultSet.Result",
    ) -> Iterator[dict[str, Any]]:
        """Paginated GET returning items one by one.

        Pages are requested with ``offset``/``limit`` query parameters. The
        iteration stops on an empty page or a short page. It also stops when
        the server evidently ignores pagination, either by returning more than
        ``page_size`` items or by repeating the previous page. Without that
        check the loop would never terminate.

        Args:
            path: API path.
            params: Additional query parameters (not modified).
            page_size: Number of items per page.
            result_key: Dot-separated path to results in response.

        Yields:
            Individual result items.
        """
        offset = 0
        base_params = params.copy() if params else {}
        base_params["format"] = "json"
        previous_page: Any = None

        while True:
            page_params = {
                **base_params,
                "offset": offset,
                "limit": page_size,
            }
            data = self.get(path, params=page_params).json()

            # Navigate to results using dot notation
            results: Any = data
            for key in result_key.split("."):
                results = results.get(key, []) if isinstance(results, dict) else []

            if not results:
                break
            # Same page again: offset is being ignored; everything was yielded.
            if results == previous_page:
                break

            yield from results

            # Short page = end of data; oversized page = limit ignored and the
            # full listing has already been yielded.
            if len(results) != page_size:
                break
            previous_page = results
            offset += page_size

    # =========================================================================
    # Convenience Methods
    # =========================================================================

    def get_json(
        self,
        path: str,
        *,
        params: dict[str, Any] | None = None,
    ) -> Any:
        """GET request returning decoded JSON.

        ``format=json`` is added to a copy of ``params``; the caller's dict is
        left untouched.
        """
        query = {**(params or {}), "format": "json"}
        resp = self.get(path, params=query)
        return resp.json()

    def ping(self) -> dict[str, Any]:
        """Check server connectivity and get version info.

        Returns:
            Dict with ``url``, ``status``, ``version`` and ``latency_ms``.

        Raises:
            NetworkError: If server is unreachable.
        """
        # perf_counter is monotonic, so clock adjustments cannot skew latency.
        start = time.perf_counter()
        resp = self.get("/xapi/siteConfig/buildInfo/version")
        latency = int((time.perf_counter() - start) * 1000)
        return {
            "url": self.base_url,
            "status": "ok",
            "version": resp.text.strip(),
            "latency_ms": latency,
        }

    def whoami(self) -> dict[str, Any]:
        """Get current user information.

        Returns:
            Dict with ``username``, ``firstname``, ``lastname``, ``email`` and
            ``enabled``. Profile fields are empty when details are unavailable.
            Without a server-resolved user, the configured username (or
            ``"unknown"``) is returned.

        Raises:
            AuthenticationError, SessionExpiredError, PermissionDeniedError:
                If the current-user endpoints reject the credentials.
        """
        current_username = self._get_current_username()
        if current_username:
            display_username = self._apply_username_hint(current_username)
            details = self._get_user_details(current_username)
            if details is None:
                return self._placeholder_user(display_username, enabled=True)
            details["username"] = display_username
            return details

        if self.username:
            return self._placeholder_user(self.username, enabled=True)
        return self._placeholder_user("unknown", enabled=False)

    @staticmethod
    def _placeholder_user(username: str, *, enabled: bool) -> dict[str, Any]:
        """Build a whoami record with empty profile fields."""
        return {
            "username": username,
            "firstname": "",
            "lastname": "",
            "email": "",
            "enabled": enabled,
        }

    def _get_current_username(self) -> str | None:
        """Resolve the authenticated username from server endpoints.

        Some XNAT deployments return a full user listing from `/data/user`,
        which is not a reliable whoami source. Prefer dedicated current-user
        endpoints when available. Auth-related errors propagate; any other
        failure falls through to the next endpoint.
        """
        try:
            resp = self.get("/xapi/users/username")
            username = resp.text.strip()
            if username and "<html" not in username.lower():
                return username
        except (AuthenticationError, SessionExpiredError, PermissionDeniedError):
            raise
        except Exception:
            pass

        try:
            resp = self.get("/data/auth")
            match = _AUTH_LOGGED_IN_RE.search(resp.text)
            if match:
                return match.group(1).strip()
        except (AuthenticationError, SessionExpiredError, PermissionDeniedError):
            raise
        except Exception:
            pass

        return None

    def _get_user_details(self, username: str) -> dict[str, Any] | None:
        """Fetch user details for a resolved username, if available.

        Any failure (including permission errors) yields ``None`` so that
        :meth:`whoami` can fall back to a placeholder record.
        """
        try:
            data = self.get_json(f"/xapi/users/{quote(username, safe='')}")
        except Exception:
            return None
        if not isinstance(data, dict):
            return None
        return {
            "username": data.get("username", username),
            "firstname": data.get("firstName", ""),
            "lastname": data.get("lastName", ""),
            "email": data.get("email", ""),
            "enabled": data.get("enabled", False),
        }

    def _apply_username_hint(self, username: str) -> str:
        """Preserve configured/cached username casing when it matches."""
        if self.username and self.username.casefold() == username.casefold():
            return self.username
        return username