import os
import shutil
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from tqdm import tqdm
from esrf_data_compressor.compressors.jp2k import JP2KCompressorWrapper
from esrf_data_compressor.utils.paths import (
get_available_cpus,
resolve_compressed_path,
resolve_mirror_path,
)
class Compressor:
    """
    Abstract base class. Subclasses must implement compress_file().
    """

    def compress_file(self, input_path: str, output_path: str, **kwargs):
        """
        Compress `input_path` into `output_path`.

        The base implementation does nothing useful: it always raises
        NotImplementedError so that a subclass forgetting to override it
        fails loudly on first use.

        Args:
            input_path: Path of the file to read.
            output_path: Path of the file to write.
            **kwargs: Compressor-specific options; ignored here.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError
[docs]
class CompressorManager:
"""
Manages parallel compression and overwrite.
Each worker process is given up to 2 Blosc2 threads (or fewer if the machine
has fewer than 4 cores). The number of worker processes is then
total_cores // threads_per_worker (at least 1). If the user explicitly
passes `workers`, we cap it to `total_cores`, then recompute threads_per_worker
= min(2, total_cores // workers).
Usage:
mgr = CompressorManager(cratio=10, method='jp2k')
mgr.compress_files([...])
mgr.overwrite_files([...])
"""
def __init__(
self,
workers: int | None = None,
cratio: int = 10,
method: str = "jp2k",
layout: str = "sibling",
):
total_cores = get_available_cpus()
default_nthreads = 2 if total_cores >= 2 else 1
default_workers = max(1, total_cores // default_nthreads)
if workers is None:
w = default_workers
nthreads = default_nthreads
else:
w = min(workers, total_cores)
possible = total_cores // w
nthreads = min(possible, 2) if possible >= 1 else 1
self.workers = max(1, w)
self.nthreads = max(1, nthreads)
self.cratio = cratio
self.method = method
self.layout = layout
if self.method == "jp2k":
self.compressor = JP2KCompressorWrapper(
cratio=cratio, nthreads=self.nthreads
)
else:
raise ValueError(f"Unsupported compression method: {self.method}")
print(f"Compression method: {self.method}")
print(f"Output layout: {self.layout}")
print(f"Total CPU cores: {total_cores}")
print(f"Worker processes: {self.workers}")
print(f"Threads per worker: {self.nthreads}")
print(f"Total threads: {self.workers * self.nthreads}")
@staticmethod
def _find_raw_root(path: str) -> str | None:
p = Path(os.path.abspath(path))
parts = p.parts
if "RAW_DATA" not in parts:
return None
return str(Path(*parts[: parts.index("RAW_DATA") + 1]))
def _compress_worker(self, ipath: str) -> tuple[str, str]:
"""
Worker function for ProcessPoolExecutor: compress a single HDF5:
- sibling layout: <same_dir>/<basename>_<method>.h5
- mirror layout: mirror RAW_DATA tree under RAW_DATA_COMPRESSED
"""
outp = resolve_compressed_path(ipath, self.method, layout=self.layout)
os.makedirs(os.path.dirname(outp), exist_ok=True)
self.compressor.compress_file(
ipath, outp, cratio=self.cratio, nthreads=self.nthreads
)
return ipath, "success"
def _mirror_non_compressed_dataset_content(self, file_list: list[str]) -> None:
source_targets = {os.path.realpath(p) for p in file_list}
raw_roots: set[str] = set()
for ipath in file_list:
raw_root = self._find_raw_root(ipath)
if raw_root:
raw_roots.add(raw_root)
copy_tasks: list[tuple[str, str]] = []
for src_dir in sorted(raw_roots):
try:
dst_dir = resolve_mirror_path(src_dir)
except ValueError:
print(f"WARNING: Cannot mirror folder outside RAW_DATA: '{src_dir}'")
continue
for cur, dirs, files in os.walk(src_dir):
rel_cur = os.path.relpath(cur, src_dir)
target_cur = (
dst_dir if rel_cur == "." else os.path.join(dst_dir, rel_cur)
)
os.makedirs(target_cur, exist_ok=True)
for dname in dirs:
os.makedirs(os.path.join(target_cur, dname), exist_ok=True)
for fname in files:
src_file = os.path.join(cur, fname)
if os.path.realpath(src_file) in source_targets:
# Do not copy raw files that will be produced by compression.
continue
dst_file = os.path.join(target_cur, fname)
copy_tasks.append((src_file, dst_file))
if not copy_tasks:
return
max_workers = min(len(copy_tasks), max(1, get_available_cpus()), 8)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = {
executor.submit(shutil.copy2, s, d): (s, d) for s, d in copy_tasks
}
for fut in as_completed(futures):
src_file, dst_file = futures[fut]
try:
fut.result()
except Exception as e:
print(f"WARNING: Failed to copy '{src_file}' → '{dst_file}': {e}")
[docs]
def compress_files(self, file_list: list[str]) -> None:
"""
Compress each .h5 in file_list in parallel.
- sibling layout: produce <basename>_<method>.h5 next to each source.
- mirror layout: write compressed files to RAW_DATA_COMPRESSED with same file names.
Does not overwrite originals. At the end, prints total elapsed time and data rate in MB/s.
"""
valid = [p for p in file_list if p.lower().endswith(".h5")]
if not valid:
print("No valid .h5 files to compress.")
return
if self.layout == "mirror":
print(
"Preparing RAW_DATA_COMPRESSED with non-compressed dataset content..."
)
self._mirror_non_compressed_dataset_content(valid)
total_bytes = 0
for f in valid:
try:
total_bytes += os.path.getsize(f)
except OSError:
pass
import time
t0 = time.time()
with ProcessPoolExecutor(max_workers=self.workers) as executor:
futures = {executor.submit(self._compress_worker, p): p for p in valid}
for fut in tqdm(
as_completed(futures),
total=len(futures),
desc=f"Compressing HDF5 files ({self.method})",
unit="file",
):
pth = futures[fut]
try:
fut.result()
except Exception as e:
print(f"Failed to compress '{pth}': {e}")
elapsed = time.time() - t0
total_mb = total_bytes / (1024 * 1024)
rate_mb_s = total_mb / elapsed if elapsed > 0 else float("inf")
print(f"\nTotal elapsed time: {elapsed:.3f}s")
print(f"Data processed: {total_mb:.2f} MB ({rate_mb_s:.2f} MB/s)\n")
[docs]
def overwrite_files(self, file_list: list[str]) -> None:
"""
Overwrites files only if they have a compressed sibling:
1) Rename <file>.h5 → <file>.h5.bak
2) Rename <file>_<method>.h5 → <file>.h5
After processing all files, removes the backup .h5.bak files.
"""
for ipath in file_list:
if not ipath.lower().endswith(".h5"):
continue
compressed_path = resolve_compressed_path(
ipath, self.method, layout=self.layout
)
if os.path.exists(compressed_path):
backup = ipath + ".bak"
try:
os.replace(ipath, backup)
os.replace(compressed_path, ipath)
print(f"Overwritten '{ipath}' (backup at '{backup}').")
except Exception as e:
print(f"ERROR overwriting '{ipath}': {e}")
else:
print(f"SKIP (no compressed file): {ipath}")
[docs]
def remove_backups(self, file_list: list[str]) -> None:
candidates = {p + ".bak" for p in file_list if p.lower().endswith(".h5")}
backups = [b for b in candidates if os.path.exists(b)]
if not backups:
print("No backup files to remove.")
return
total_bytes = 0
for b in backups:
try:
total_bytes += os.path.getsize(b)
except OSError:
pass
total_mb = total_bytes / (1024 * 1024)
print(
f"About to remove {len(backups)} backup file(s), ~{total_mb:.2f} MB total."
)
ans = input("Proceed? [y/N]: ").strip().lower()
if ans not in ("y", "yes"):
print("Backups kept.")
return
removed = 0
for b in backups:
try:
os.remove(b)
removed += 1
except Exception as e:
print(f"ERROR deleting backup '{b}': {e}")
print(f"Deleted {removed} backup file(s).")
[docs]
def restore_backups(self, file_list: list[str]) -> None:
restored = 0
preserved = 0
for ipath in file_list:
if not ipath.lower().endswith(".h5"):
continue
backup = ipath + ".bak"
method_path = resolve_compressed_path(
ipath, self.method, layout=self.layout
)
if not os.path.exists(backup):
print(f"SKIP (no backup): {ipath}")
continue
if os.path.exists(ipath) and not os.path.exists(method_path):
try:
os.replace(ipath, method_path)
preserved += 1
print(f"Preserved current file to '{method_path}'.")
except Exception as e:
print(f"ERROR preserving current '{ipath}' to '{method_path}': {e}")
continue
try:
os.replace(backup, ipath)
restored += 1
print(f"Restored '{ipath}' from backup.")
except Exception as e:
print(f"ERROR restoring '{ipath}' from '{backup}': {e}")
print(
f"Restore complete. Restored: {restored}, preserved compressed copies: {preserved}."
)