"""Upload service for XNAT upload operations.
Provides UploadService with methods for all upload transports:
- REST batch upload (simple ZIP batches via import service)
- Parallel REST upload (batched archives with parallel workers)
- DICOM C-STORE upload (pynetdicom-based network transfer)
- Resource upload (file/directory upload to session resources)
Public utility functions (collect_dicom_files, split_into_batches, etc.)
are available for direct import and testing.
"""
from __future__ import annotations
import contextlib
import logging
import os
import shutil
import tarfile
import tempfile
import threading
import time
import zipfile
from collections.abc import Callable, Iterator, Sequence
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, as_completed, wait
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from zipfile import ZIP_DEFLATED, ZipFile
import httpx
from xnatctl.core.timeouts import DEFAULT_HTTP_TIMEOUT_SECONDS
from xnatctl.models.progress import (
OperationPhase,
UploadProgress,
UploadSummary,
)
from .base import BaseService
logger = logging.getLogger(__name__)
# =============================================================================
# Constants
# =============================================================================
# Defaults for the REST upload paths (used as keyword defaults further down).
DEFAULT_BATCH_SIZE = 500  # files per upload batch
DEFAULT_UPLOAD_WORKERS = 4  # parallel upload workers
DEFAULT_ARCHIVE_WORKERS = 4  # parallel archive-creation workers
DEFAULT_ARCHIVE_FORMAT = "tar"  # "tar" or "zip" (see _create_archive)
DEFAULT_TIMEOUT = DEFAULT_HTTP_TIMEOUT_SECONDS  # HTTP timeout, in seconds
DEFAULT_IMPORT_HANDLER = "DICOM-zip"  # XNAT import handler name
DEFAULT_OVERWRITE = "delete"  # overwrite mode: none, append, delete
# Defaults for the DICOM C-STORE transport.
# NOTE(review): the consumers of these three values are outside this view.
DEFAULT_DICOM_STORE_WORKERS = 4
DEFAULT_DICOM_CALLING_AET = "XNATCTL"
DEFAULT_DICOM_PORT = 104
# File suffixes (compared lower-cased) that are treated as DICOM without
# inspecting the file content.
DICOM_EXTENSIONS = {".dcm", ".ima", ".img", ".dicom"}
UPLOAD_MAX_RETRIES = 5
UPLOAD_RETRY_BACKOFF_BASE = 2 # seconds: 2, 4, 8, 16, 32
# HTTP status codes that upload_with_retry treats as transient. 400 is included
# deliberately (XNAT transient import errors).
RETRYABLE_STATUS_CODES = {400, 429, 500, 502, 503, 504}
# When running with --verbose, we can log small snippets of retryable HTTP 400
# response bodies to help diagnose transient XNAT import races. Keep this capped
# to avoid flooding logs on very large uploads.
_RETRY_DEBUG_MAX_SNIPPETS = 20
_retry_debug_snippets_emitted = 0  # process-wide counter, guarded by the lock below
_retry_debug_lock = threading.Lock()
# Gradual-DICOM uses one HTTP request per DICOM file; creating a new httpx.Client
# per file is expensive and can trigger transient ConnectError bursts under high
# concurrency. Reuse a persistent client per worker thread (keep-alive).
_GRADUAL_HTTP_TIMEOUT_SECONDS = 120.0
_gradual_client_local = threading.local()  # holds the per-thread "client" and "key"
# Every client ever created is also recorded here so that it can be closed from
# a thread other than the one that owns it.
_gradual_client_registry_lock = threading.Lock()
_gradual_client_registry: list[httpx.Client] = []
# Number of gradual upload operations currently inside
# _gradual_http_clients_scope(); guarded by _gradual_scope_lock.
_gradual_scope_lock = threading.Lock()
_gradual_scope_refcount = 0
class _SessionRefresher:
"""Thread-safe XNAT session token manager.
When any worker thread encounters a 401 (expired session), it calls
:meth:`refresh`. Only the first thread to detect a stale token actually
re-authenticates; concurrent callers wait on the lock and receive the
already-refreshed token.
Args:
base_url: XNAT server URL.
verify_ssl: Whether to verify SSL certificates.
token: Initial JSESSIONID token.
username: Credentials for re-authentication.
password: Credentials for re-authentication.
"""
def __init__(
self,
base_url: str,
verify_ssl: bool,
token: str | None,
username: str | None,
password: str | None,
) -> None:
self._base_url = base_url
self._verify_ssl = verify_ssl
self._token = token
self._username = username
self._password = password
self._lock = threading.Lock()
@property
def token(self) -> str | None:
"""Current session token (may be updated by any thread)."""
with self._lock:
return self._token
def refresh(self, stale_token: str | None) -> str | None:
"""Re-authenticate and return a fresh token.
Thread-safe: if another thread already refreshed past *stale_token*,
the cached fresh token is returned without hitting the server again.
Args:
stale_token: The token that triggered the 401.
Returns:
Fresh session token, or the unchanged token if credentials are
unavailable.
"""
with self._lock:
if self._token != stale_token:
return self._token
if not self._username or not self._password:
logger.warning("Session expired but no credentials available for reauth")
return self._token
try:
with httpx.Client(
base_url=self._base_url,
verify=self._verify_ssl,
timeout=30.0,
) as client:
resp = client.post(
"/data/JSESSION",
auth=(self._username, self._password),
)
if resp.status_code == 200 and "<html" not in resp.text.lower():
self._token = resp.text.strip()
logger.info("Session refreshed successfully")
else:
logger.error("Session refresh failed: HTTP %d", resp.status_code)
except Exception:
logger.exception("Session refresh failed")
return self._token
def _get_gradual_http_client(*, base_url: str, verify_ssl: bool) -> httpx.Client:
    """Get a thread-local httpx.Client for gradual-DICOM uploads.

    The calling thread's cached client is reused as long as it is still open
    and was built for the same ``(base_url, verify_ssl)`` pair. Otherwise the
    old client (if any) is closed and dropped from the global registry, and a
    new one is created, cached on the thread and registered.

    Args:
        base_url: XNAT server base URL.
        verify_ssl: Whether to verify SSL certificates.

    Returns:
        An open client bound to the calling thread.
    """
    key = (base_url, verify_ssl)
    client: httpx.Client | None = getattr(_gradual_client_local, "client", None)
    client_key: tuple[str, bool] | None = getattr(_gradual_client_local, "key", None)

    if client is not None and client_key == key and not client.is_closed:
        return client

    if client is not None:
        with contextlib.suppress(Exception):
            client.close()
        # Drop the replaced client from the registry as well; otherwise every
        # replacement leaves a dead entry behind until the next global close.
        # Compared by identity; the list is mutated in place under the lock.
        with _gradual_client_registry_lock:
            _gradual_client_registry[:] = [
                c for c in _gradual_client_registry if c is not client
            ]

    client = httpx.Client(
        base_url=base_url,
        timeout=_GRADUAL_HTTP_TIMEOUT_SECONDS,
        verify=verify_ssl,
        # One connection per worker thread is enough: requests are sequential.
        limits=httpx.Limits(max_connections=1, max_keepalive_connections=1),
    )
    _gradual_client_local.client = client
    _gradual_client_local.key = key
    with _gradual_client_registry_lock:
        _gradual_client_registry.append(client)
    return client
def _close_gradual_http_clients() -> None:
    """Close every registered gradual-upload client and empty the registry."""
    # Forget the calling thread's cached client first, so a later operation on
    # this thread does not pick up a client that is about to be closed.
    with contextlib.suppress(Exception):
        _gradual_client_local.client = None
        _gradual_client_local.key = None

    # Take a snapshot under the lock, then close outside of it.
    with _gradual_client_registry_lock:
        pending = _gradual_client_registry.copy()
        _gradual_client_registry.clear()

    for stale in pending:
        with contextlib.suppress(Exception):
            stale.close()
@contextlib.contextmanager
def _gradual_http_clients_scope() -> Iterator[None]:
    """Tie the lifetime of the gradual httpx clients to an upload operation.

    The client registry is module-global, so overlapping gradual uploads must
    not close each other's clients. Entering the scope bumps a reference
    count; leaving it decrements the count, and only the operation that
    brings it back to zero closes the registered clients.
    """
    global _gradual_scope_refcount

    with _gradual_scope_lock:
        _gradual_scope_refcount += 1
    try:
        yield
    finally:
        with _gradual_scope_lock:
            remaining = _gradual_scope_refcount - 1
            # Never let the counter go negative.
            _gradual_scope_refcount = max(remaining, 0)
            if remaining <= 0:
                _close_gradual_http_clients()
# =============================================================================
# DICOM C-STORE Result (separate from REST models)
# =============================================================================
[docs]
@dataclass
class DICOMStoreSummary:
    """Summary of a DICOM C-STORE operation.

    NOTE(review): the code that fills this in is outside this view; the field
    notes below are inferred from the names and from ``_send_dicom_batch`` and
    should be confirmed against the producer.
    """

    # Number of files considered for sending.
    total_files: int
    # Files stored successfully / files that failed (presumably the summed
    # per-batch counts returned by _send_dicom_batch).
    sent: int
    failed: int
    # Directory holding the per-batch log files.
    log_dir: Path
    # Working directory of the operation -- TODO confirm what is kept there.
    workspace: Path
    # Overall outcome flag.
    success: bool
# =============================================================================
# Internal Batch Result
# =============================================================================
@dataclass
class _UploadResult:
    """Result of a single batch upload (internal)."""

    # Caller-assigned identifier of the batch.
    batch_id: int
    # True when the archive upload reported success.
    success: bool
    # Wall-clock seconds spent on the batch (difference of time.time() values).
    duration: float
    # Number of files in the batch.
    file_count: int
    # Size of the uploaded archive in bytes (0 if it was never created).
    archive_size: int
    # Error description; empty on success.
    error: str = ""
# =============================================================================
# Public Utility Functions
# =============================================================================
[docs]
def collect_dicom_files(
    root: Path,
    *,
    include_extensionless: bool = True,
) -> list[Path]:
    """Walk *root* recursively and return every DICOM-looking file in it.

    Args:
        root: Root directory to search.
        include_extensionless: If True, files without a suffix are also
            considered (common for raw DICOM from scanners).

    Returns:
        Sorted list of file paths.

    Raises:
        ValueError: If root is not a directory.
    """
    if not (root.exists() and root.is_dir()):
        raise ValueError(f"Not a directory: {root}")

    def _wanted(candidate: Path) -> bool:
        """Decide whether a single directory entry belongs in the result."""
        if not candidate.is_file():
            return False
        if candidate.is_symlink():
            # Skip links whose target cannot be resolved or no longer exists.
            try:
                if not candidate.resolve().exists():
                    return False
            except (OSError, ValueError):
                return False
        return _is_dicom_like_path(candidate, include_extensionless=include_extensionless)

    return sorted(entry for entry in root.rglob("*") if _wanted(entry))
def _has_dicom_magic(path: Path) -> bool:
"""Return True if the file has the DICOM preamble magic bytes (DICM at offset 128)."""
try:
with open(path, "rb") as f:
f.seek(128)
return f.read(4) == b"DICM"
except OSError:
return False
def _is_dicom_like_path(path: Path, *, include_extensionless: bool = True) -> bool:
"""Return True when a path looks like a DICOM file we should ingest."""
if path.name.startswith("."):
return False
suffix = path.suffix.lower()
if suffix in DICOM_EXTENSIONS:
return True
if include_extensionless and suffix == "":
return _has_dicom_magic(path)
return False
[docs]
def archive_destination_params(project: str, direct_archive: bool) -> dict[str, str]:
    """Build the query-string entry that routes a POST /data/services/import.

    * ``direct_archive`` true: ``Direct-Archive=true``. This is handled by the
      ``DICOM-zip`` and ``gradual-DICOM`` import handlers, skips the
      prearchive and writes straight into the project archive.
    * ``direct_archive`` false: ``dest=/prearchive/projects/{project}``, the
      documented destination form. ``Direct-Archive=false`` on its own would
      mean "use the standard upload mechanism"; the explicit ``dest`` is
      preferred because it is self-describing and matches the
      ``PrearchiveService`` pattern used elsewhere in this repo.

    Caveat: neither form can override a *project-configured* auto-archive.
    The project's ``prearchive_code`` (0=manual, 4/5=auto) is the
    authoritative switch. With auto-archive enabled, a session uploaded
    through either path lands in the prearchive briefly and is then archived
    by the server. To keep sessions in the prearchive, the project setting
    must be changed to "Leave in prearchive" (prearchive_code=0); XNAT 1.8+
    offers no per-upload import-service override for this.
    """
    return (
        {"Direct-Archive": "true"}
        if direct_archive
        else {"dest": f"/prearchive/projects/{project}"}
    )
[docs]
def split_into_batches(
    files: Sequence[Path],
    batch_size: int,
) -> list[list[Path]]:
    """Cut *files* into consecutive chunks of at most *batch_size* entries.

    Args:
        files: Sequence of file paths to split.
        batch_size: Maximum files per batch. A value of zero or less puts
            every file into one single batch.

    Returns:
        List of batches, each batch being a list of paths. Empty input gives
        an empty list.
    """
    if not files:
        return []

    items = list(files)
    if batch_size <= 0:
        return [items]

    return [items[start : start + batch_size] for start in range(0, len(items), batch_size)]
[docs]
def split_into_n_batches(
    files: Sequence[Path],
    num_batches: int,
) -> list[list[Path]]:
    """Deal *files* round-robin into N batches of roughly equal size.

    Args:
        files: Sequence of file paths to split.
        num_batches: Number of batches to create. It is capped at the number
            of files, and a value of zero or less puts every file into one
            single batch.

    Returns:
        List of batches, each batch being a list of paths. Empty input gives
        an empty list.
    """
    if not files:
        return []

    items = list(files)
    if num_batches <= 0:
        return [items]

    bucket_count = min(num_batches, len(items))
    # Bucket k receives items k, k + bucket_count, k + 2 * bucket_count, ...
    return [items[offset::bucket_count] for offset in range(bucket_count)]
[docs]
def is_retryable_status(status_code: int) -> bool:
    """Check if an HTTP status code warrants a retry.

    The decision is a plain membership test against RETRYABLE_STATUS_CODES.

    Retryable: 400 (XNAT transient), 429 (rate limit) and the server errors
    500, 502, 503 and 504.
    Non-retryable: everything else, including 2xx (success), 401/403 (auth),
    other 4xx (client error) and the remaining 5xx codes (e.g. 501).

    Args:
        status_code: HTTP status code of a response.

    Returns:
        True if the request should be retried.
    """
    return status_code in RETRYABLE_STATUS_CODES
[docs]
def upload_with_retry(
    upload_fn: Callable[[], httpx.Response],
    *,
    max_retries: int = UPLOAD_MAX_RETRIES,
    backoff_base: int = UPLOAD_RETRY_BACKOFF_BASE,
    label: str = "upload",
) -> httpx.Response:
    """Run *upload_fn*, retrying on transient HTTP statuses and network errors.

    A response whose status is not retryable is returned immediately. A
    retryable status, an ``httpx.TimeoutException`` or an
    ``httpx.ConnectError`` triggers another attempt after an exponential
    delay of ``backoff_base ** attempt_number`` seconds.

    Args:
        upload_fn: Callable that performs the upload and returns an
            httpx.Response. It is invoked once per attempt, so it must be
            safe to call repeatedly.
        max_retries: Maximum number of retries (default: 5).
        backoff_base: Base for exponential backoff in seconds (default: 2).
        label: Label for log messages.

    Returns:
        The first non-retryable response, or, once the attempts are used up,
        the response of the last attempt if that attempt produced one.

    Raises:
        The last timeout/connect exception if the final attempt failed with
        one, or RuntimeError if no attempt was made at all.
    """
    global _retry_debug_snippets_emitted

    total_attempts = max_retries + 1
    final_response: httpx.Response | None = None
    final_error: Exception | None = None

    for attempt in range(total_attempts):
        can_retry = attempt < max_retries
        try:
            resp = upload_fn()
        except (httpx.TimeoutException, httpx.ConnectError) as exc:
            final_error, final_response = exc, None
            if can_retry:
                delay = backoff_base ** (attempt + 1)
                detail = f"{type(exc).__name__}: {str(exc).strip().replace(chr(10), ' ')}"
                logger.warning(
                    "%s: %s on attempt %d/%d, retrying in %ds",
                    label,
                    detail,
                    attempt + 1,
                    total_attempts,
                    delay,
                )
                time.sleep(delay)
            continue

        if not is_retryable_status(resp.status_code):
            return resp

        final_response, final_error = resp, None
        if not can_retry:
            continue

        delay = backoff_base ** (attempt + 1)

        # Optional debug detail for transient XNAT 400s, capped process-wide.
        if resp.status_code == 400 and logger.isEnabledFor(logging.DEBUG):
            with _retry_debug_lock:
                emit_snippet = _retry_debug_snippets_emitted < _RETRY_DEBUG_MAX_SNIPPETS
                if emit_snippet:
                    _retry_debug_snippets_emitted += 1
            if emit_snippet:
                # Reading the body is best effort; never let it break the retry.
                try:
                    body = resp.text.strip().replace("\n", " ")
                    if body:
                        logger.debug(
                            "%s: retryable HTTP 400 body: %s",
                            label,
                            body[:200],
                        )
                except Exception:
                    pass

        logger.warning(
            "%s: HTTP %d on attempt %d/%d, retrying in %ds",
            label,
            resp.status_code,
            attempt + 1,
            total_attempts,
            delay,
        )
        time.sleep(delay)

    if final_response is not None:
        return final_response
    if final_error is not None:
        raise final_error
    raise RuntimeError(f"{label}: all retries exhausted with no response")
# =============================================================================
# Archive Helpers (private)
# =============================================================================
def _create_tar_archive(files: list[Path], output_path: Path, base_dir: Path) -> int:
"""Create a TAR archive from files, returning size in bytes."""
with tarfile.open(output_path, "w") as tf:
for file_path in files:
arcname = os.path.relpath(file_path, base_dir)
tf.add(file_path, arcname=arcname)
return output_path.stat().st_size
def _create_zip_archive(files: list[Path], output_path: Path, base_dir: Path) -> int:
"""Create a ZIP archive from files, returning size in bytes."""
with ZipFile(output_path, "w", compression=ZIP_DEFLATED, allowZip64=True) as zf:
for file_path in files:
arcname = os.path.relpath(file_path, base_dir)
zf.write(file_path, arcname)
return output_path.stat().st_size
def _create_archive(
files: list[Path],
output_path: Path,
base_dir: Path,
archive_format: str,
) -> int:
"""Create an archive from files.
Args:
files: List of file paths to include.
output_path: Path for the output archive.
base_dir: Base directory for relative paths in archive.
archive_format: Format ("tar" or "zip").
Returns:
Size of created archive in bytes.
Raises:
ValueError: If archive format is unsupported.
"""
if archive_format == "tar":
return _create_tar_archive(files, output_path, base_dir)
if archive_format == "zip":
return _create_zip_archive(files, output_path, base_dir)
raise ValueError(f"Unsupported archive format: {archive_format}")
# =============================================================================
# Parallel Upload Helpers (private, thread-safe standalone functions)
# =============================================================================
def _upload_single_archive(
    *,
    base_url: str,
    username: str | None,
    password: str | None,
    session_token: str | None,
    verify_ssl: bool,
    timeout: int,
    archive_path: Path,
    project: str,
    subject: str,
    session: str,
    import_handler: str,
    ignore_unparsable: bool,
    overwrite: str,
    direct_archive: bool,
) -> tuple[bool, str]:
    """Upload a single archive file to XNAT.

    Creates a fresh httpx client for thread-safety in parallel execution.

    The archive is POSTed as the request body to ``/data/services/import``
    through :func:`upload_with_retry`. If no *session_token* is given, a
    session is created from *username*/*password* first and deleted again
    after the upload.

    Args:
        base_url: XNAT server base URL.
        username: Used to create a session when *session_token* is missing.
        password: Used to create a session when *session_token* is missing.
        session_token: Existing JSESSIONID; sent as a cookie when provided.
        verify_ssl: Whether to verify SSL certificates.
        timeout: HTTP timeout passed to the httpx client.
        archive_path: Archive to upload; its name selects the content type.
        project: Target project ID.
        subject: Target subject label.
        session: Target session label.
        import_handler: Value of the ``import-handler`` query parameter.
        ignore_unparsable: Sent as ``Ignore-Unparsable`` ("true"/"false").
        overwrite: Value of the ``overwrite`` query parameter.
        direct_archive: Direct archive (True) or prearchive (False)
            destination, see :func:`archive_destination_params`.

    Returns:
        Tuple of (success, error_message). Failures are reported through the
        tuple rather than raised; only HTTP 200 counts as success.
    """
    name = archive_path.name.lower()
    # Anything that does not look like a tar variant is sent as a ZIP.
    content_type = (
        "application/x-tar" if name.endswith((".tar", ".tar.gz", ".tgz")) else "application/zip"
    )
    params = {
        "import-handler": import_handler,
        "Ignore-Unparsable": "true" if ignore_unparsable else "false",
        "project": project,
        "subject": subject,
        "session": session,
        "overwrite": overwrite,
        "overwrite_files": "true",
        "quarantine": "false",
        "triggerPipelines": "true",
        "rename": "false",
        "inbody": "true",
        **archive_destination_params(project, direct_archive),
    }
    with httpx.Client(
        base_url=base_url,
        timeout=timeout,
        verify=verify_ssl,
    ) as client:
        try:
            cookies: dict[str, str] = {}
            # Tracks whether this call logged in itself and so must log out.
            created_session = False
            if session_token:
                cookies = {"JSESSIONID": session_token}
            else:
                if not username or not password:
                    return False, "Authentication failed: missing credentials"
                auth_resp = client.post(
                    "/data/JSESSION",
                    auth=(str(username), str(password)),
                )
                if auth_resp.status_code != 200:
                    return False, f"Authentication failed: HTTP {auth_resp.status_code}"
                # An HTML body on a 200 is treated as a rejected login.
                if "<html" in auth_resp.text.lower():
                    return False, "Authentication failed: invalid credentials"
                session_token = auth_resp.text.strip()
                cookies = {"JSESSIONID": session_token}
                created_session = True

            def _attempt() -> httpx.Response:
                # Re-open the archive on every attempt so that a retry streams
                # the file from the beginning.
                with archive_path.open("rb") as data:
                    return client.post(
                        "/data/services/import",
                        params=params,
                        headers={"Content-Type": content_type},
                        content=data,
                        cookies=cookies,
                    )

            try:
                resp = upload_with_retry(_attempt, label=f"batch {archive_path.name}")
            finally:
                # Best-effort logout of a session created above; a failure to
                # delete it must not mask the upload outcome.
                if created_session:
                    try:
                        client.delete("/data/JSESSION", cookies=cookies)
                    except Exception:
                        pass
            if resp.status_code == 200:
                return True, ""
            if resp.status_code in (401, 403):
                return False, "Authentication failed: invalid or expired session"
            return False, f"HTTP {resp.status_code}: {resp.text[:200]}"
        except httpx.TimeoutException:
            return False, "Upload timed out (after retries)"
        except httpx.ConnectError as e:
            return False, f"Connection failed (after retries): {e}"
        except Exception as e:
            return False, str(e)
def _upload_batch(
    *,
    base_url: str,
    username: str | None,
    password: str | None,
    session_token: str | None,
    verify_ssl: bool,
    timeout: int,
    batch_id: int,
    archive_path: Path,
    file_count: int,
    project: str,
    subject: str,
    session: str,
    import_handler: str,
    ignore_unparsable: bool,
    overwrite: str,
    direct_archive: bool,
) -> _UploadResult:
    """Upload an already-built batch archive and wrap the outcome.

    The archive size is read before the upload starts. Any exception raised
    by the upload itself is converted into a failed ``_UploadResult``.
    """
    size_bytes = archive_path.stat().st_size
    started = time.time()

    try:
        ok, message = _upload_single_archive(
            base_url=base_url,
            username=username,
            password=password,
            session_token=session_token,
            verify_ssl=verify_ssl,
            timeout=timeout,
            archive_path=archive_path,
            project=project,
            subject=subject,
            session=session,
            import_handler=import_handler,
            ignore_unparsable=ignore_unparsable,
            overwrite=overwrite,
            direct_archive=direct_archive,
        )
    except Exception as exc:
        ok, message = False, str(exc)

    return _UploadResult(
        batch_id=batch_id,
        success=ok,
        duration=time.time() - started,
        file_count=file_count,
        archive_size=size_bytes,
        error=message,
    )
def _create_and_upload_batch(
    *,
    batch: list[Path],
    archive_path: Path,
    source_path: Path,
    archive_format: str,
    base_url: str,
    username: str | None,
    password: str | None,
    session_token: str | None,
    verify_ssl: bool,
    timeout: int,
    batch_id: int,
    project: str,
    subject: str,
    session: str,
    import_handler: str,
    ignore_unparsable: bool,
    overwrite: str,
    direct_archive: bool,
) -> _UploadResult:
    """Build one batch archive, upload it and remove it straight away.

    Doing both steps in a single task keeps peak disk and memory usage low:
    the archive is deleted as soon as its upload finishes (successfully or
    not), so the archives of all batches never sit on disk at the same time.
    Errors from either step are reported as a failed ``_UploadResult``.
    """
    started = time.time()
    size_bytes = 0  # stays 0 when archive creation itself fails

    try:
        try:
            size_bytes = _create_archive(batch, archive_path, source_path, archive_format)
            ok, message = _upload_single_archive(
                base_url=base_url,
                username=username,
                password=password,
                session_token=session_token,
                verify_ssl=verify_ssl,
                timeout=timeout,
                archive_path=archive_path,
                project=project,
                subject=subject,
                session=session,
                import_handler=import_handler,
                ignore_unparsable=ignore_unparsable,
                overwrite=overwrite,
                direct_archive=direct_archive,
            )
        except Exception as exc:
            ok, message = False, str(exc)

        return _UploadResult(
            batch_id=batch_id,
            success=ok,
            duration=time.time() - started,
            file_count=len(batch),
            archive_size=size_bytes,
            error=message,
        )
    finally:
        # Best-effort cleanup; a leftover archive must not change the result.
        with contextlib.suppress(Exception):
            archive_path.unlink(missing_ok=True)
# =============================================================================
# DICOM C-STORE Helpers (private, lazy imports)
# =============================================================================
def _check_dicom_deps() -> None:
"""Check if DICOM dependencies are available.
Raises:
ImportError: If pydicom or pynetdicom are not installed.
"""
try:
import pydicom # noqa: F401
import pynetdicom # noqa: F401
except ImportError as e:
raise ImportError(
"DICOM C-STORE requires pydicom and pynetdicom. "
"Install with: pip install xnatctl[dicom]"
) from e
def _get_verification_sop_class():
    """Look up the verification SOP class across pynetdicom versions.

    Tries ``sop_class.VerificationSOPClass`` first, then
    ``sop_class.Verification``, and falls back to the plain UID string when
    the module exposes neither name.
    """
    from pynetdicom import sop_class as sop_module

    fallback_uid = "1.2.840.10008.1.1"
    alternate = getattr(sop_module, "Verification", fallback_uid)
    return getattr(sop_module, "VerificationSOPClass", alternate)
def _get_storage_contexts():
    """Return the storage presentation contexts as a list.

    Uses pynetdicom's ``StoragePresentationContexts`` when it can be
    imported. Otherwise a context is built for every attribute of
    ``pynetdicom.sop_class`` whose name ends in ``Storage``.
    """
    try:
        from pynetdicom import StoragePresentationContexts
    except ImportError:
        from pynetdicom import sop_class as sop_module
        from pynetdicom.presentation import build_context

        storage_names = [attr for attr in dir(sop_module) if attr.endswith("Storage")]
        return [build_context(getattr(sop_module, attr)) for attr in storage_names]
    return list(StoragePresentationContexts)
def _ensure_sop_uids(ds) -> None:
"""Populate missing SOP UID attributes from file-meta.
Args:
ds: pydicom Dataset object.
"""
if not getattr(ds, "SOPClassUID", None):
uid = getattr(ds.file_meta, "MediaStorageSOPClassUID", None)
if uid:
ds.SOPClassUID = uid
if not getattr(ds, "SOPInstanceUID", None):
uid = getattr(ds.file_meta, "MediaStorageSOPInstanceUID", None)
if uid:
ds.SOPInstanceUID = uid
def _c_echo(host: str, port: int, calling_aet: str, called_aet: str) -> bool:
    """Send a C-ECHO to verify connectivity and AE titles.

    Args:
        host: DICOM SCP host.
        port: DICOM SCP port.
        calling_aet: Our AE title.
        called_aet: Remote AE title.

    Returns:
        True if the association was established and the C-ECHO came back with
        status 0x0000; False if the association was not established or the
        peer returned no status or a different one.

    Raises:
        ImportError: If pydicom or pynetdicom are not installed.
    """
    _check_dicom_deps()
    from pynetdicom import AE

    ae = AE(ae_title=calling_aet)
    ae.add_requested_context(_get_verification_sop_class())
    assoc = ae.associate(host, port, ae_title=called_aet)
    if not assoc.is_established:
        return False

    # Release the association even when the echo itself raises, so that a
    # failed connectivity check does not leave the connection open.
    try:
        status = assoc.send_c_echo()
    finally:
        assoc.release()
    return bool(status and status.Status == 0x0000)
def _send_dicom_batch(
    batch_id: str,
    files: list[Path],
    host: str,
    port: int,
    calling_aet: str,
    called_aet: str,
    log_dir: Path,
) -> tuple[int, int]:
    """Send a batch of DICOM files over a single association.

    Every file ends up counted as either sent or failed; per-file problems
    are written to ``<log_dir>/<batch_id>.log`` and do not abort the batch.

    Args:
        batch_id: Identifier for this batch (for logging).
        files: List of DICOM file paths.
        host: DICOM SCP host.
        port: DICOM SCP port.
        calling_aet: Our AE title.
        called_aet: Remote AE title.
        log_dir: Directory for batch log files.

    Returns:
        Tuple of (sent_count, failed_count). If the association cannot be
        established, all files are reported as failed.

    Raises:
        ImportError: If pydicom or pynetdicom are not installed.
    """
    _check_dicom_deps()
    import pydicom
    from pydicom.errors import InvalidDicomError
    from pynetdicom import AE

    sent = failed = 0
    log_path = log_dir / f"{batch_id}.log"
    with log_path.open("w") as log:
        ae = AE(ae_title=calling_aet)
        ae.requested_contexts = _get_storage_contexts()
        # Extra private SOP class on top of the standard storage contexts.
        # NOTE(review): presumably Siemens CSA Non-Image Storage -- confirm.
        ae.add_requested_context("1.3.12.2.1107.5.9.1")
        assoc = ae.associate(host, port, ae_title=called_aet)
        if not assoc.is_established:
            log.write("Association rejected/aborted\n")
            return sent, len(files)

        # Always release the association, even if logging or an unexpected
        # error interrupts the loop.
        try:
            for file_path in files:
                try:
                    ds = pydicom.dcmread(file_path, force=True)
                except InvalidDicomError:
                    failed += 1
                    log.write(f"Skip non-DICOM {file_path}\n")
                    continue
                except Exception as e:
                    # Unreadable or corrupt file: count it and carry on rather
                    # than losing the rest of the batch.
                    failed += 1
                    log.write(f"Read error {file_path}: {type(e).__name__}: {e}\n")
                    continue

                try:
                    _ensure_sop_uids(ds)
                    status = assoc.send_c_store(ds)
                except Exception as e:
                    failed += 1
                    log.write(f"Store error {file_path}: {type(e).__name__}: {e}\n")
                    continue

                if status and status.Status == 0x0000:
                    sent += 1
                else:
                    failed += 1
                    # A falsy status carries no code; do not report it as
                    # 0x0000, which is the success code.
                    status_text = hex(status.Status) if status else "none (no response)"
                    log.write(f"Failed {file_path} status {status_text}\n")
        finally:
            assoc.release()
    return sent, failed
# =============================================================================
# Gradual-DICOM Helpers (private, thread-safe standalone functions)
# =============================================================================
def _upload_single_file_gradual(
    *,
    base_url: str,
    session_refresher: _SessionRefresher,
    verify_ssl: bool,
    file_path: Path,
    display_path: str | None = None,
    project: str,
    subject: str,
    session: str,
    direct_archive: bool = True,
) -> tuple[str, bool, str]:
    """Send one file to XNAT through the gradual-DICOM import handler.

    The request goes through the calling thread's persistent httpx client so
    keep-alive connections are reused. If the server answers 401, the session
    token is refreshed via *session_refresher* and the upload is repeated
    once with the new token.

    Args:
        base_url: XNAT server base URL.
        session_refresher: Thread-safe token manager for reauth on 401.
        verify_ssl: Whether to verify SSL certificates.
        file_path: Path to the DICOM file.
        display_path: Name to report instead of the bare file name.
        project: Target project ID.
        subject: Target subject label.
        session: Target session label.
        direct_archive: Use direct archive vs prearchive (default: True).

    Returns:
        Tuple of (filename, success, error_message). Exceptions are reported
        through the tuple, never raised.
    """
    name = display_path or file_path.name
    try:
        client = _get_gradual_http_client(base_url=base_url, verify_ssl=verify_ssl)
        query = {
            "inbody": "true",
            "import-handler": "gradual-DICOM",
            "PROJECT_ID": project,
            "SUBJECT_ID": subject,
            "EXPT_LABEL": session,
            **archive_destination_params(project, direct_archive),
        }

        def _do_upload(token: str | None) -> httpx.Response:
            """Upload the file (with retries) using the given session token."""
            cookies = {"JSESSIONID": token} if token else {}

            def _attempt() -> httpx.Response:
                # Re-open per attempt so a retry streams from the start.
                with open(file_path, "rb") as f:
                    return client.post(
                        "/data/services/import",
                        params=query,
                        content=f,
                        headers={"Content-Type": "application/dicom"},
                        cookies=cookies,
                    )

            return upload_with_retry(_attempt, label=f"gradual-DICOM {name}")

        token = session_refresher.token
        resp = _do_upload(token)

        if resp.status_code == 401:
            new_token = session_refresher.refresh(token)
            if new_token == token:
                logger.debug("Session refresh returned same token for %s", name)
            else:
                resp = _do_upload(new_token)
                if resp.status_code == 401:
                    logger.warning("Still 401 after session refresh for %s", name)

        if 200 <= resp.status_code < 300:
            return name, True, ""

        # Attach a short piece of the server response for debugging (XNAT
        # often returns useful details for 4xx/5xx in plain text or HTML).
        try:
            body = resp.text.strip().replace("\n", " ")[:200]
        except Exception:
            body = ""
        status_part = f"HTTP {resp.status_code}"
        failure = f"{status_part}: {body}" if body else status_part
        return name, False, failure
    except Exception as exc:
        return name, False, str(exc)
# =============================================================================
# Upload Service
# =============================================================================
[docs]
class UploadService(BaseService):
"""Service for XNAT upload operations.
Provides methods for all upload transports: REST batch, parallel REST,
DICOM C-STORE, and resource uploads.
"""
[docs]
def upload_dicom(
    self,
    project: str,
    subject: str,
    session: str,
    source_path: Path,
    overwrite: bool = False,
    quarantine: bool = False,
    batch_size: int = DEFAULT_BATCH_SIZE,
    parallel: bool = True,
    workers: int = DEFAULT_UPLOAD_WORKERS,
    progress_callback: Callable[[UploadProgress], None] | None = None,
) -> UploadSummary:
    """Upload DICOM files via simple REST batch (ZIP per batch).

    The source may be a directory, a single DICOM file, or a ZIP archive
    (which is unpacked into a temporary directory first). The files are split
    into batches, each batch is zipped and POSTed to
    ``/data/services/import``. A batch only counts as succeeded when the
    server answers with a 2xx status.

    Args:
        project: Project ID.
        subject: Subject label.
        session: Session label.
        source_path: Path to DICOM files (directory, single file or ZIP).
        overwrite: Overwrite existing scans.
        quarantine: Send to prearchive instead.
        batch_size: Files per upload batch.
        parallel: Use parallel uploads.
        workers: Number of parallel workers (values below 1 are treated as 1).
        progress_callback: Progress callback function.

    Returns:
        UploadSummary with results. ``total``/``succeeded``/``failed`` count
        batches; ``total_files`` counts files.

    Raises:
        FileNotFoundError: If *source_path* does not exist.
    """
    start_time = time.time()
    source_path = Path(source_path)
    if not source_path.exists():
        raise FileNotFoundError(f"Source not found: {source_path}")
    if progress_callback:
        progress_callback(
            UploadProgress(
                phase=OperationPhase.PREPARING,
                message="Preparing upload",
            )
        )

    # Collect DICOM files
    temp_dir: str | None = None
    try:
        dicom_files: list[Path] = []
        if source_path.is_file():
            if source_path.suffix.lower() == ".zip":
                temp_dir = tempfile.mkdtemp()
                temp_root = Path(temp_dir)
                resolved_root = temp_root.resolve()
                with zipfile.ZipFile(source_path, "r") as zf:
                    for member in zf.infolist():
                        if member.is_dir():
                            continue
                        # Skip members that would escape the extraction
                        # directory (zip-slip).
                        target = (temp_root / member.filename).resolve()
                        if not target.is_relative_to(resolved_root):
                            continue
                        target.parent.mkdir(parents=True, exist_ok=True)
                        with zf.open(member) as src, open(target, "wb") as dst:
                            shutil.copyfileobj(src, dst)
                source_path = Path(temp_dir)
                dicom_files = collect_dicom_files(source_path)
            else:
                dicom_files = [source_path] if _is_dicom_like_path(source_path) else []
        else:
            dicom_files = collect_dicom_files(source_path)

        total_files = len(dicom_files)
        if total_files == 0:
            return UploadSummary(
                success=False,
                total=0,
                succeeded=0,
                failed=0,
                duration=0,
                errors=["No DICOM files found"],
            )
        total_size = sum(f.stat().st_size for f in dicom_files)
        if progress_callback:
            progress_callback(
                UploadProgress(
                    phase=OperationPhase.ARCHIVING,
                    total=total_files,
                    message=f"Found {total_files} files",
                )
            )

        batches = list(self._split_into_batches(dicom_files, batch_size))
        total_batches = len(batches)
        results: dict[str, Any] = {"succeeded": 0, "failed": 0, "errors": []}
        dest = f"/archive/projects/{project}/subjects/{subject}/experiments/{session}"

        # Read the connection settings once; worker threads use their own
        # httpx clients instead of sharing self.client.
        base_url = self.client.base_url
        session_token = self.client.session_token
        verify_ssl = self.client.verify_ssl
        timeout = self.client.timeout

        def _upload_batch_fn(batch_id: int, files: list[Path]) -> tuple[bool, str]:
            """Zip one batch, POST it, and report (success, error_message)."""
            zip_path: Path | None = None
            try:
                with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as tmp:
                    zip_path = Path(tmp.name)
                # NOTE(review): entries are stored under their bare file name,
                # so same-named files from different directories collide
                # inside one batch ZIP.
                with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
                    for file_path in files:
                        zf.write(file_path, file_path.name)

                params: dict[str, Any] = {
                    "dest": dest,
                    "overwrite": "delete" if overwrite else "none",
                    "import-handler": "SI",
                    "PROJECT_ID": project,
                    "SUBJECT_ID": subject,
                    "EXPT_LABEL": session,
                }
                if quarantine:
                    params["dest"] = f"/prearchive/projects/{project}"
                cookies = {"JSESSIONID": session_token} if session_token else {}

                with httpx.Client(
                    base_url=base_url,
                    timeout=timeout,
                    verify=verify_ssl,
                ) as http:
                    with open(zip_path, "rb") as zip_file:
                        resp = http.post(
                            "/data/services/import",
                            params=params,
                            content=zip_file,
                            headers={"Content-Type": "application/zip"},
                            cookies=cookies,
                        )

                # A completed request is not a successful import: inspect the
                # status instead of assuming success.
                if not 200 <= resp.status_code < 300:
                    detail = f"HTTP {resp.status_code}"
                    snippet = resp.text.strip().replace("\n", " ")[:200]
                    return (False, f"{detail}: {snippet}" if snippet else detail)
                return (True, "")
            except Exception as e:
                return (False, str(e))
            finally:
                # Remove the temporary ZIP on every path, not only on success.
                if zip_path is not None:
                    with contextlib.suppress(OSError):
                        zip_path.unlink(missing_ok=True)

        if parallel and total_batches > 1:
            # ThreadPoolExecutor rejects max_workers <= 0.
            with ThreadPoolExecutor(max_workers=max(1, workers)) as executor:
                futures = {
                    executor.submit(_upload_batch_fn, i, batch): i
                    for i, batch in enumerate(batches)
                }
                for future in as_completed(futures):
                    batch_id = futures[future]
                    success, error = future.result()
                    if success:
                        results["succeeded"] += 1
                    else:
                        results["failed"] += 1
                        results["errors"].append(f"Batch {batch_id}: {error}")
                    if progress_callback:
                        progress_callback(
                            UploadProgress(
                                phase=OperationPhase.UPLOADING,
                                current=results["succeeded"] + results["failed"],
                                total=total_batches,
                                batch_id=batch_id,
                                message=f"Uploading batch {batch_id + 1}/{total_batches}",
                                success=success,
                            )
                        )
        else:
            for batch_id, batch in enumerate(batches):
                success, error = _upload_batch_fn(batch_id, batch)
                if success:
                    results["succeeded"] += 1
                else:
                    results["failed"] += 1
                    results["errors"].append(f"Batch {batch_id}: {error}")
                if progress_callback:
                    progress_callback(
                        UploadProgress(
                            phase=OperationPhase.UPLOADING,
                            current=batch_id + 1,
                            total=total_batches,
                            batch_id=batch_id,
                            message=f"Uploading batch {batch_id + 1}/{total_batches}",
                            success=success,
                        )
                    )

        duration = time.time() - start_time
        overall_success = results["failed"] == 0
        if progress_callback:
            progress_callback(
                UploadProgress(
                    phase=OperationPhase.COMPLETE if overall_success else OperationPhase.ERROR,
                    current=total_batches,
                    total=total_batches,
                    message="Upload complete"
                    if overall_success
                    else "Upload completed with errors",
                    success=overall_success,
                    errors=results["errors"],
                )
            )
        return UploadSummary(
            success=overall_success,
            total=total_batches,
            succeeded=results["succeeded"],
            failed=results["failed"],
            duration=duration,
            errors=results["errors"],
            total_files=total_files,
            total_size_mb=total_size / (1024 * 1024),
            batches_total=total_batches,
            batches_succeeded=results["succeeded"],
            batches_failed=results["failed"],
            session_id=session,
        )
    finally:
        # Remove the directory a ZIP source was unpacked into, if any.
        if temp_dir:
            shutil.rmtree(temp_dir, ignore_errors=True)
[docs]
def upload_dicom_parallel(
    self,
    source_dir: Path,
    project: str,
    subject: str,
    session: str,
    *,
    username: str | None = None,
    password: str | None = None,
    upload_workers: int = DEFAULT_UPLOAD_WORKERS,
    archive_workers: int = DEFAULT_ARCHIVE_WORKERS,
    archive_format: str = DEFAULT_ARCHIVE_FORMAT,
    import_handler: str = DEFAULT_IMPORT_HANDLER,
    ignore_unparsable: bool = True,
    overwrite: str = DEFAULT_OVERWRITE,
    direct_archive: bool = True,
    timeout: int = DEFAULT_TIMEOUT,
    progress_callback: Callable[[UploadProgress], None] | None = None,
) -> UploadSummary:
    """Upload a DICOM directory as parallel batched archives via REST import.

    The files found under ``source_dir`` are split into at most
    ``upload_workers`` batches. Each batch is handed to
    ``_create_and_upload_batch`` on a thread pool; results are folded into
    the summary as the workers finish. Archive paths live in a temporary
    directory that is removed before the method returns.

    Args:
        source_dir: Directory containing DICOM files.
        project: Target project ID.
        subject: Target subject label.
        session: Target session label.
        username: Username forwarded to the workers; falls back to the
            client's username when empty.
        password: Password forwarded to the workers; falls back to the
            client's password when empty.
        upload_workers: Upper bound on both batch count and pool size.
        archive_workers: Accepted for interface compatibility; this method
            does not read it.
        archive_format: ``"tar"`` selects a ``.tar`` archive name; any other
            value selects ``.zip``.
        import_handler: Import handler name forwarded to the workers.
        ignore_unparsable: Forwarded to the workers.
        overwrite: Overwrite mode forwarded to the workers.
        direct_archive: Forwarded to the workers.
        timeout: HTTP timeout in seconds, forwarded to the workers.
        progress_callback: Optional callback receiving ``UploadProgress``.

    Returns:
        UploadSummary describing per-batch outcomes. When scanning fails or
        finds nothing, an unsuccessful summary with zero totals is returned.
    """
    started = time.time()
    errors: list[str] = []

    client = self.client
    base_url = client.base_url
    session_token = client.session_token
    verify_ssl = client.verify_ssl
    effective_username = username or client.username
    effective_password = password or client.password

    def report(phase: OperationPhase, **kwargs: Any) -> None:
        if progress_callback:
            progress_callback(UploadProgress(phase=phase, **kwargs))

    def nothing_uploaded(reason: str) -> UploadSummary:
        # Shared shape for the two early exits below.
        return UploadSummary(
            success=False,
            total=0,
            succeeded=0,
            failed=0,
            duration=time.time() - started,
            errors=[reason],
        )

    # Phase 1: discover the files to send.
    report(OperationPhase.PREPARING, message="Scanning for DICOM files...")
    try:
        files = collect_dicom_files(source_dir)
    except Exception as e:
        return nothing_uploaded(f"Failed to scan directory: {e}")
    if not files:
        return nothing_uploaded("No DICOM files found")

    # Phase 2: one batch per worker, never more batches than files.
    batches = split_into_n_batches(files, max(1, min(upload_workers, len(files))))
    batch_total = len(batches)
    report(
        OperationPhase.PREPARING,
        message=f"Split {len(files)} files into {batch_total} batches",
    )

    # Phase 3+4: every worker builds its archive and uploads it, so the
    # archives do not all have to exist on disk at the same time.
    ext = ".tar" if archive_format == "tar" else ".zip"
    temp_dir = Path(tempfile.mkdtemp(prefix="xnatctl_upload_"))
    total_archive_size = 0
    try:
        archive_paths = [temp_dir / f"batch_{n}{ext}" for n in range(1, batch_total + 1)]
        source_path = source_dir.expanduser().resolve()
        pool_size = max(1, min(upload_workers, batch_total))

        report(
            OperationPhase.UPLOADING,
            total=batch_total,
            message="Starting batch processing...",
        )

        results: list[_UploadResult] = []
        ok_count = 0
        with ThreadPoolExecutor(max_workers=pool_size) as executor:
            pending: list[Future[_UploadResult]] = [
                executor.submit(  # type: ignore[arg-type]
                    _create_and_upload_batch,
                    batch=batch,
                    archive_path=archive_path,
                    source_path=source_path,
                    archive_format=archive_format,
                    base_url=base_url,
                    username=effective_username,
                    password=effective_password,
                    session_token=session_token,
                    verify_ssl=verify_ssl,
                    timeout=timeout,
                    batch_id=number,
                    project=project,
                    subject=subject,
                    session=session,
                    import_handler=import_handler,
                    ignore_unparsable=ignore_unparsable,
                    overwrite=overwrite,
                    direct_archive=direct_archive,
                )
                for number, (batch, archive_path) in enumerate(
                    zip(batches, archive_paths), start=1
                )
            ]

            for finished in as_completed(pending):
                result = finished.result()
                results.append(result)
                total_archive_size += result.archive_size
                if result.success:
                    ok_count += 1
                else:
                    errors.append(f"Batch {result.batch_id}: {result.error}")
                report(
                    OperationPhase.UPLOADING,
                    current=len(results),
                    total=batch_total,
                    batch_id=result.batch_id,
                    success=result.success,
                    message=f"Completed {len(results)}/{batch_total} ({ok_count} succeeded)",
                )

        # Phase 5: final report and summary.
        total_duration = time.time() - started
        batches_failed = len(results) - ok_count
        success = batches_failed == 0

        if success:
            final_phase = OperationPhase.COMPLETE
            final_message = "Upload complete!"
        else:
            final_phase = OperationPhase.ERROR
            final_message = f"Upload completed with {batches_failed} failures"
        report(
            final_phase,
            current=len(results),
            total=batch_total,
            message=final_message,
            success=success,
            errors=errors,
        )
        if not success:
            logger.warning("Upload completed with %s failures", batches_failed)

        return UploadSummary(
            success=success,
            total=batch_total,
            succeeded=ok_count,
            failed=batches_failed,
            duration=total_duration,
            errors=errors,
            total_files=len(files),
            total_size_mb=total_archive_size / 1024 / 1024,
            batches_total=batch_total,
            batches_succeeded=ok_count,
            batches_failed=batches_failed,
            session_id=session,
        )
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)
[docs]
def upload_dicom_store(
    self,
    dicom_root: Path,
    host: str,
    called_aet: str,
    *,
    port: int = DEFAULT_DICOM_PORT,
    calling_aet: str = DEFAULT_DICOM_CALLING_AET,
    workers: int = DEFAULT_DICOM_STORE_WORKERS,
    cleanup: bool = True,
) -> DICOMStoreSummary:
    """Send DICOM files to an SCP using C-STORE.

    This method:
    1. Verifies connectivity with C-ECHO
    2. Collects DICOM files from the root directory
    3. Splits files into batches for parallel associations
    4. Sends files using multiple concurrent C-STORE associations

    Args:
        dicom_root: Directory containing DICOM files.
        host: DICOM SCP host.
        called_aet: Remote AE title.
        port: DICOM SCP port (default: 104).
        calling_aet: Our AE title (default: XNATCTL).
        workers: Requested number of parallel associations (default: 4).
            The value is clamped to the range ``1..len(files)`` so that a
            non-positive or oversized request cannot yield zero or empty
            associations.
        cleanup: Remove the temporary workspace on completion. The
            workspace is kept whenever at least one file failed, even if
            this flag is set.

    Returns:
        DICOMStoreSummary with results. Its ``log_dir`` and ``workspace``
        paths no longer exist on disk when the workspace was cleaned up.

    Raises:
        ImportError: If pydicom/pynetdicom are not installed.
        ValueError: If dicom_root is not a directory.
        RuntimeError: If C-ECHO fails or no DICOM files found.
    """
    _check_dicom_deps()
    if not dicom_root.exists() or not dicom_root.is_dir():
        raise ValueError(f"dicom_root is not a directory: {dicom_root}")

    workspace = Path(tempfile.mkdtemp(prefix="xnatctl_dicom_store_"))
    log_dir = workspace / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)

    # Initialised before the try so the finally clause can always read it.
    failed_total = 0
    try:
        logger.info(
            "Pre-flight C-ECHO %s -> %s @ %s:%s",
            calling_aet,
            called_aet,
            host,
            port,
        )
        if not _c_echo(host, port, calling_aet, called_aet):
            raise RuntimeError(
                f"C-ECHO failed - check host/port/AET settings "
                f"(host={host}, port={port}, called_aet={called_aet})"
            )

        files = collect_dicom_files(dicom_root)
        if not files:
            raise RuntimeError(f"No DICOM files found in {dicom_root}")

        # Same clamp as the REST parallel upload: at least one association,
        # never more associations than files. ThreadPoolExecutor rejects a
        # non-positive max_workers, so the batch count must stay >= 1.
        association_count = max(1, min(workers, len(files)))
        batches = split_into_n_batches(files, association_count)
        logger.info(
            "Discovered %d files, using %d parallel associations",
            len(files),
            len(batches),
        )

        sent_total = 0
        with ThreadPoolExecutor(max_workers=len(batches)) as pool:
            futures = {
                pool.submit(
                    _send_dicom_batch,
                    f"{i:03d}",
                    batch,
                    host,
                    port,
                    calling_aet,
                    called_aet,
                    log_dir,
                ): i
                for i, batch in enumerate(batches)
            }
            for future in as_completed(futures):
                batch_idx = futures[future]
                sent, failed = future.result()
                sent_total += sent
                failed_total += failed
                logger.info(
                    "Batch %03d complete: %d sent, %d failed",
                    batch_idx,
                    sent,
                    failed,
                )

        return DICOMStoreSummary(
            total_files=len(files),
            sent=sent_total,
            failed=failed_total,
            log_dir=log_dir,
            workspace=workspace,
            success=failed_total == 0,
        )
    finally:
        # Keep the workspace (and its logs) when any file failed.
        if cleanup and failed_total == 0:
            shutil.rmtree(workspace, ignore_errors=True)
[docs]
def upload_dicom_gradual(
    self,
    source_path: Path,
    project: str,
    subject: str,
    session: str,
    *,
    workers: int = DEFAULT_UPLOAD_WORKERS,
    direct_archive: bool = True,
    progress_callback: Callable[[UploadProgress], None] | None = None,
) -> UploadSummary:
    """Upload a directory or ZIP of DICOM files via the gradual-DICOM handler.

    A directory is scanned in place with ``collect_dicom_files``. A file
    whose suffix is ``.zip`` is first unpacked into a temporary directory,
    which is scanned instead and removed before returning. The discovered
    files are handed to ``_upload_dicom_gradual_from_files``.

    Args:
        source_path: Directory or ZIP file containing DICOM files.
        project: Target project ID.
        subject: Target subject label.
        session: Target session label.
        workers: Number of parallel upload workers (default: 4).
        direct_archive: Use direct archive vs prearchive (default: True).
        progress_callback: Optional callback for progress updates.

    Returns:
        UploadSummary with results; an unsuccessful zero-total summary when
        no DICOM files are found.

    Raises:
        ValueError: If source_path is not a directory or ZIP file.
        FileNotFoundError: If source_path does not exist.
    """
    with _gradual_http_clients_scope():
        start_time = time.time()
        source_path = Path(source_path)
        if not source_path.exists():
            raise FileNotFoundError(f"Source not found: {source_path}")

        temp_dir: str | None = None
        files: list[Path] = []
        try:
            looks_like_zip = source_path.is_file() and source_path.suffix.lower() == ".zip"
            if looks_like_zip:
                temp_dir = tempfile.mkdtemp(prefix="xnatctl_gradual_")
                extract_root = Path(temp_dir)
                resolved_root = extract_root.resolve()
                with zipfile.ZipFile(source_path, "r") as archive:
                    for entry in archive.infolist():
                        if entry.is_dir():
                            continue
                        destination = (extract_root / entry.filename).resolve()
                        # Skip members whose resolved path leaves the
                        # extraction directory (e.g. "../" components).
                        if not destination.is_relative_to(resolved_root):
                            continue
                        destination.parent.mkdir(parents=True, exist_ok=True)
                        with archive.open(entry) as src, open(destination, "wb") as dst:
                            shutil.copyfileobj(src, dst)
                files = collect_dicom_files(extract_root)
            elif source_path.is_dir():
                files = collect_dicom_files(source_path)
            else:
                raise ValueError("gradual-DICOM requires a directory or ZIP file")

            if not files:
                return UploadSummary(
                    success=False,
                    total=0,
                    succeeded=0,
                    failed=0,
                    duration=time.time() - start_time,
                    errors=["No DICOM files found"],
                )

            # Display paths are made relative to whichever directory was
            # actually scanned, so ZIP uploads do not expose temp paths.
            if temp_dir:
                display_root = Path(temp_dir)
            else:
                display_root = source_path

            return self._upload_dicom_gradual_from_files(
                files=files,
                display_root=display_root,
                project=project,
                subject=subject,
                session=session,
                workers=workers,
                direct_archive=direct_archive,
                progress_callback=progress_callback,
                start_time=start_time,
            )
        finally:
            if temp_dir:
                shutil.rmtree(temp_dir, ignore_errors=True)
[docs]
def upload_dicom_gradual_files(
    self,
    *,
    files: Sequence[Path],
    project: str,
    subject: str,
    session: str,
    workers: int = DEFAULT_UPLOAD_WORKERS,
    direct_archive: bool = True,
    progress_callback: Callable[[UploadProgress], None] | None = None,
) -> UploadSummary:
    """Upload an explicit list of DICOM files via the gradual-DICOM handler.

    No directory is scanned. Every given path must exist and be a regular
    file; paths rejected by ``_is_dicom_like_path`` are dropped, and two
    entries that resolve to the same location are treated as an error.

    Args:
        files: Explicit list of files to upload.
        project: Target project ID.
        subject: Target subject label.
        session: Target session label.
        workers: Number of parallel upload workers.
        direct_archive: Use direct archive vs prearchive (default: True).
        progress_callback: Optional callback for progress updates.

    Returns:
        UploadSummary with results. An empty input, or an input with no
        DICOM-like paths, yields an unsuccessful zero-total summary.

    Raises:
        FileNotFoundError: If any provided path does not exist.
        ValueError: If any provided path is not a file, or if the same
            resolved path is provided more than once.
    """
    with _gradual_http_clients_scope():
        start_time = time.time()

        def empty_summary(reason: str) -> UploadSummary:
            return UploadSummary(
                success=False,
                total=0,
                succeeded=0,
                failed=0,
                duration=0.0,
                errors=[reason],
            )

        candidates = [Path(entry) for entry in files]
        if not candidates:
            return empty_summary("No files provided")

        for candidate in candidates:
            if not candidate.exists():
                raise FileNotFoundError(f"File not found: {candidate}")
            if not candidate.is_file():
                raise ValueError(f"Not a file: {candidate}")

        dicom_file_list = [c for c in candidates if _is_dicom_like_path(c)]
        if not dicom_file_list:
            return empty_summary("No DICOM files found")

        # Detect repeated inputs by their fully resolved location.
        seen: set[Path] = set()
        repeated: set[Path] = set()
        for candidate in dicom_file_list:
            location = candidate.expanduser().resolve(strict=False)
            if location in seen:
                repeated.add(location)
            seen.add(location)
        if repeated:
            dup_str = ", ".join(sorted(str(location) for location in repeated))
            raise ValueError(f"Duplicate file paths provided: {dup_str}")

        # Use a stable common root for relative display paths.
        # NOTE(review): the root is derived from resolved paths while the
        # paths forwarded below are left as given - confirm that relative
        # inputs are expected to be displayed by bare file name.
        try:
            shared = Path(
                os.path.commonpath([str(item.resolve()) for item in dicom_file_list])
            )
            display_root = shared if shared.is_dir() else shared.parent
        except Exception:
            display_root = dicom_file_list[0].parent

        return self._upload_dicom_gradual_from_files(
            files=dicom_file_list,
            display_root=display_root,
            project=project,
            subject=subject,
            session=session,
            workers=workers,
            direct_archive=direct_archive,
            progress_callback=progress_callback,
            start_time=start_time,
        )
def _upload_dicom_gradual_from_files(
    self,
    *,
    files: Sequence[Path],
    display_root: Path,
    project: str,
    subject: str,
    session: str,
    workers: int,
    direct_archive: bool = True,
    progress_callback: Callable[[UploadProgress], None] | None,
    start_time: float,
) -> UploadSummary:
    """Upload a precomputed list of files using the gradual-DICOM handler.

    The upload runs in up to four stages:

    1. Warm-up: a few files are sent one at a time. When paths follow the
       ``scans/<scan_id>/resources/DICOM/files/...`` layout this is one file
       per scan (at most 50 scans); otherwise the first five files.
    2. Main pass: the remaining files are sent on a thread pool with a
       bounded number of uploads in flight, interleaved across scans.
    3. Salvage pass: if the number of failures is small enough, the failed
       files are retried on a pool of at most four workers.
    4. Final retry: if at most 50 files still fail, they are retried
       sequentially.

    An exception raised while uploading a single file is recorded as that
    file's failure in every stage; it never aborts the whole upload.

    Args:
        files: Files to upload.
        display_root: Root used for stable relative display paths.
        project: Target project ID.
        subject: Target subject label.
        session: Target session label.
        workers: Number of parallel upload workers; values below 1 are
            treated as 1.
        direct_archive: Use direct archive vs prearchive (default: True).
        progress_callback: Optional callback for progress updates.
        start_time: Start timestamp for duration calculation.

    Returns:
        UploadSummary with results; ``errors`` holds one entry per file that
        still failed after all retries.
    """
    from collections import deque

    base_url = self.client.base_url
    verify_ssl = self.client.verify_ssl
    session_refresher = _SessionRefresher(
        base_url=base_url,
        verify_ssl=verify_ssl,
        token=self.client.session_token,
        username=self.client.username,
        password=self.client.password,
    )

    file_list = list(files)
    total_files = len(file_list)
    failed_paths: set[Path] = set()
    error_by_path: dict[Path, str] = {}
    completed = 0

    def report(phase: OperationPhase, **kwargs: Any) -> None:
        if progress_callback:
            progress_callback(UploadProgress(phase=phase, **kwargs))

    def display(path: Path) -> str:
        try:
            return str(path.relative_to(display_root))
        except Exception:
            return path.name

    def scan_id_for(path: Path) -> str | None:
        """Extract scan ID from standard session layout paths, if present."""
        try:
            rel = path.relative_to(display_root)
        except Exception:
            return None
        parts = rel.parts
        # Expected layout: scans/<scan_id>/resources/DICOM/files/<...>
        if (
            len(parts) >= 6
            and parts[0] == "scans"
            and parts[2] == "resources"
            and parts[3] == "DICOM"
            and parts[4] == "files"
        ):
            return parts[1]
        return None

    def _scan_sort_key(scan_id: str) -> tuple[int, int, str]:
        # Numeric scan IDs first (in numeric order), then the rest by name.
        try:
            return (0, int(scan_id), scan_id)
        except ValueError:
            return (1, 0, scan_id)

    def attempt(path: Path) -> tuple[bool, str]:
        """Upload one file; report any exception as a failed attempt."""
        try:
            _name, ok, err = _upload_single_file_gradual(
                base_url=base_url,
                session_refresher=session_refresher,
                verify_ssl=verify_ssl,
                file_path=path,
                display_path=display(path),
                project=project,
                subject=subject,
                session=session,
                direct_archive=direct_archive,
            )
        except Exception as exc:
            return False, str(exc)
        return ok, err

    def record_first_attempt(path: Path, ok: bool, err: str) -> None:
        """Account for a warm-up or main-pass result and report progress."""
        nonlocal completed
        completed += 1
        if not ok:
            failed_paths.add(path)
            error_by_path[path] = err
        succeeded_so_far = completed - len(failed_paths)
        report(
            OperationPhase.UPLOADING,
            current=completed,
            total=total_files,
            success=ok,
            message=(
                f"Uploaded {completed}/{total_files} "
                f"({succeeded_so_far} ok, {len(failed_paths)} failed)"
            ),
        )

    def record_retry(path: Path, ok: bool, err: str) -> None:
        """Account for a retry result: clear the failure or refresh its error."""
        if ok:
            failed_paths.discard(path)
            error_by_path.pop(path, None)
        else:
            error_by_path[path] = err

    def run_windowed(
        pool_size: int,
        paths: Sequence[Path],
        on_result: Callable[[Path, bool, str], None],
    ) -> None:
        """Upload ``paths`` on a pool, keeping at most 2x pool_size in flight.

        ``on_result`` is invoked on the calling thread for each finished file.
        """
        window = max(1, pool_size * 2)
        pending = iter(paths)
        in_flight: dict[Future[tuple[bool, str]], Path] = {}
        with ThreadPoolExecutor(max_workers=pool_size) as executor:

            def refill() -> None:
                while len(in_flight) < window:
                    path = next(pending, None)
                    if path is None:
                        return
                    in_flight[executor.submit(attempt, path)] = path

            refill()
            while in_flight:
                done, _still_running = wait(list(in_flight), return_when=FIRST_COMPLETED)
                for future in done:
                    path = in_flight.pop(future)
                    ok, err = future.result()
                    on_result(path, ok, err)
                refill()

    report(
        OperationPhase.PREPARING,
        total=total_files,
        message=f"Found {total_files} files for gradual-DICOM upload",
    )

    # Warm-up: upload a small set of files sequentially before going
    # wide-parallel.
    #
    # XNAT can return transient HTTP 400s when a session/scan is being
    # created in prearchive. With high concurrency, multiple workers can hit
    # that "cold start" race at the same time.
    scan_groups: dict[str, list[Path]] = {}
    other_files: list[Path] = []
    for path in file_list:
        sid = scan_id_for(path)
        if sid:
            scan_groups.setdefault(sid, []).append(path)
        else:
            other_files.append(path)

    warmup_files: list[Path] = []
    remaining_files: list[Path] = []
    if scan_groups:
        # One lane per scan plus one lane for files outside the scan layout.
        # The non-scan lane is kept as a separate entry (not a dict key) so a
        # scan that happens to be named "_other" cannot be overwritten by it.
        lanes: list[tuple[str, bool, deque[Path]]] = [
            (sid, True, deque(paths)) for sid, paths in scan_groups.items()
        ]
        if other_files:
            lanes.append(("_other", False, deque(other_files)))
        lanes.sort(key=lambda lane: _scan_sort_key(lane[0]))

        # Warm up one file per scan (capped at 50 scans).
        scan_lanes = [queue for _sid, is_scan, queue in lanes if is_scan]
        for queue in scan_lanes[:50]:
            warmup_files.append(queue.popleft())

        # Round-robin the rest across lanes to reduce per-scan contention
        # under high worker counts.
        rotation = deque(queue for _sid, _is_scan, queue in lanes if queue)
        while rotation:
            queue = rotation.popleft()
            remaining_files.append(queue.popleft())
            if queue:
                rotation.append(queue)
    else:
        # Fallback: warm up a few files in provided order.
        warmup_n = min(5, total_files)
        warmup_files = file_list[:warmup_n]
        remaining_files = file_list[warmup_n:]

    if warmup_files:
        report(
            OperationPhase.PREPARING,
            message=f"Warming up gradual-DICOM upload with {len(warmup_files)} file(s)...",
        )
        for path in warmup_files:
            record_first_attempt(path, *attempt(path))

    # Main pass: parallel per-file upload (bounded in-flight window).
    # ThreadPoolExecutor rejects max_workers <= 0, so clamp like the
    # salvage pass does.
    run_windowed(max(1, workers), remaining_files, record_first_attempt)

    # Salvage pass: retry a small number of failed files at lower
    # concurrency. This helps when XNAT returns transient 400s under high
    # parallel load.
    max_salvage = min(5000, max(500, int(total_files * 0.01)))
    if failed_paths and len(failed_paths) <= max_salvage:
        retry_workers = max(1, min(4, workers))
        report(
            OperationPhase.PREPARING,
            message=(
                f"Retrying {len(failed_paths)} failed file(s) "
                f"at lower concurrency ({retry_workers} workers)..."
            ),
        )
        # sorted() snapshots the set, so record_retry may shrink it safely.
        run_windowed(retry_workers, sorted(failed_paths, key=display), record_retry)

    # Final safety net: if only a handful of files are still failing, retry
    # them sequentially.
    if failed_paths and len(failed_paths) <= 50:
        report(
            OperationPhase.PREPARING,
            message=f"Final sequential retry for {len(failed_paths)} file(s)...",
        )
        for path in sorted(failed_paths, key=display):
            record_retry(path, *attempt(path))

    duration = time.time() - start_time
    failed = len(failed_paths)
    succeeded = total_files - failed
    overall_success = failed == 0
    errors = [
        f"{display(p)}: {error_by_path.get(p, '')}".rstrip(": ")
        for p in sorted(failed_paths, key=display)
    ]

    report(
        OperationPhase.COMPLETE if overall_success else OperationPhase.ERROR,
        current=total_files,
        total=total_files,
        message=(
            f"Uploaded {succeeded} files via gradual-DICOM"
            if overall_success
            else f"Uploaded {succeeded}/{total_files} files ({failed} failed)"
        ),
        success=overall_success,
        errors=errors,
    )
    return UploadSummary(
        success=overall_success,
        total=total_files,
        succeeded=succeeded,
        failed=failed,
        duration=duration,
        errors=errors,
        total_files=total_files,
        session_id=session,
    )
[docs]
def upload_resource(
    self,
    session_id: str,
    resource_label: str,
    source_path: Path,
    scan_id: str | None = None,
    project: str | None = None,
    extract: bool = False,
    overwrite: bool = False,
    progress_callback: Callable[[UploadProgress], None] | None = None,
) -> UploadSummary:
    """Upload a file or directory to a session- or scan-level resource.

    A directory is first packed into a temporary ZIP, which is uploaded with
    ``extract`` forced on and deleted again before this method returns. The
    upload is a single HTTP PUT issued through ``upload_with_retry``.

    Args:
        session_id: Session ID.
        resource_label: Resource label.
        source_path: File or directory to upload.
        scan_id: Scan ID (for scan-level resources).
        project: Project ID; when given, the URL is project-scoped.
        extract: Ask the server to extract the uploaded archive.
        overwrite: Overwrite existing files.
        progress_callback: Progress callback.

    Returns:
        UploadSummary with results. Failures during the HTTP upload are
        reported through an unsuccessful summary rather than raised.

    Raises:
        FileNotFoundError: If source_path does not exist.
    """
    start_time = time.time()
    source_path = Path(source_path)
    if not source_path.exists():
        raise FileNotFoundError(f"Source not found: {source_path}")

    def notify(phase: OperationPhase, **fields: Any) -> None:
        if progress_callback:
            progress_callback(UploadProgress(phase=phase, **fields))

    notify(OperationPhase.PREPARING, message="Preparing upload")

    # /data[/projects/<project>]/experiments/<session>[/scans/<scan>]/resources/<label>/files
    scope = f"/data/projects/{project}" if project else "/data"
    scan_part = f"/scans/{scan_id}" if scan_id else ""
    base_path = (
        f"{scope}/experiments/{session_id}{scan_part}/resources/{resource_label}/files"
    )

    temp_zip: Path | None = None
    try:
        if source_path.is_dir():
            # Reserve a unique name, then let make_archive write the ZIP to
            # that same path (it appends ".zip" to the suffix-less base name).
            with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as tmp:
                temp_zip = Path(tmp.name)
            shutil.make_archive(str(temp_zip.with_suffix("")), "zip", source_path)
            source_path = temp_zip
            extract = True

        file_size = source_path.stat().st_size
        notify(
            OperationPhase.UPLOADING,
            total_bytes=file_size,
            message=f"Uploading {source_path.name}",
        )

        params: dict[str, Any] = {}
        if extract:
            params["extract"] = "true"
        if overwrite:
            params["overwrite"] = "true"

        path = f"{base_path}/{source_path.name}"

        try:
            base_url = self.client.base_url
            session_token = self.client.session_token
            verify_ssl = self.client.verify_ssl
            res_timeout = self.client.timeout

            cookies = {"JSESSIONID": session_token} if session_token else {}
            with httpx.Client(
                base_url=base_url,
                timeout=res_timeout,
                verify=verify_ssl,
            ) as http:

                def _attempt() -> httpx.Response:
                    # Reopen the file on every attempt so a retry streams
                    # the content from the beginning again.
                    with open(source_path, "rb") as f:
                        return http.put(
                            path,
                            params=params,
                            content=f,
                            headers={"Content-Type": "application/octet-stream"},
                            cookies=cookies,
                        )

                resp = upload_with_retry(_attempt, label=f"resource {source_path.name}")

            if resp.status_code not in (200, 201):
                raise RuntimeError(f"HTTP {resp.status_code}: {resp.text[:200]}")

            duration = time.time() - start_time
            notify(
                OperationPhase.COMPLETE,
                bytes_sent=file_size,
                total_bytes=file_size,
                message="Upload complete",
                success=True,
            )
            return UploadSummary(
                success=True,
                total=1,
                succeeded=1,
                failed=0,
                duration=duration,
                total_files=1,
                total_size_mb=file_size / (1024 * 1024),
                session_id=session_id,
            )
        except Exception as e:
            duration = time.time() - start_time
            notify(
                OperationPhase.ERROR,
                message=str(e),
                success=False,
                errors=[str(e)],
            )
            return UploadSummary(
                success=False,
                total=1,
                succeeded=0,
                failed=1,
                duration=duration,
                errors=[str(e)],
                session_id=session_id,
            )
    finally:
        # The ZIP built for a directory upload is ours to remove, whatever
        # the outcome; a user-supplied file is never touched.
        if temp_zip is not None:
            with contextlib.suppress(OSError):
                temp_zip.unlink()
def _split_into_batches(
self,
files: list[Path],
batch_size: int,
) -> Iterator[list[Path]]:
"""Split files into batches.
Args:
files: List of file paths.
batch_size: Maximum files per batch.
Yields:
Lists of files for each batch.
"""
for i in range(0, len(files), batch_size):
yield files[i : i + batch_size]