# Source code for esrf_data_compressor.compressors.base

import os
import shutil
import time
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from pathlib import Path

from tqdm import tqdm

from esrf_data_compressor.compressors.jp2k import JP2KCompressorWrapper
from esrf_data_compressor.utils.paths import (
    get_available_cpus,
    resolve_compressed_path,
    resolve_mirror_path,
)


class Compressor:
    """
    Interface shared by all compressor implementations.

    Concrete compressors derive from this class and provide their own
    ``compress_file``; the base version only signals that it is abstract.
    """

    def compress_file(self, input_path: str, output_path: str, **kwargs):
        """
        Compress ``input_path`` and write the result to ``output_path``.

        Args:
            input_path: Path of the file to compress.
            output_path: Destination path for the compressed file.
            **kwargs: Implementation-specific options.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
[docs] class CompressorManager: """ Manages parallel compression and overwrite. Each worker process is given up to 2 Blosc2 threads (or fewer if the machine has fewer than 4 cores). The number of worker processes is then total_cores // threads_per_worker (at least 1). If the user explicitly passes `workers`, we cap it to `total_cores`, then recompute threads_per_worker = min(2, total_cores // workers). Usage: mgr = CompressorManager(cratio=10, method='jp2k') mgr.compress_files([...]) mgr.overwrite_files([...]) """ def __init__( self, workers: int | None = None, cratio: int = 10, method: str = "jp2k", layout: str = "sibling", ): total_cores = get_available_cpus() default_nthreads = 2 if total_cores >= 2 else 1 default_workers = max(1, total_cores // default_nthreads) if workers is None: w = default_workers nthreads = default_nthreads else: w = min(workers, total_cores) possible = total_cores // w nthreads = min(possible, 2) if possible >= 1 else 1 self.workers = max(1, w) self.nthreads = max(1, nthreads) self.cratio = cratio self.method = method self.layout = layout if self.method == "jp2k": self.compressor = JP2KCompressorWrapper( cratio=cratio, nthreads=self.nthreads ) else: raise ValueError(f"Unsupported compression method: {self.method}") print(f"Compression method: {self.method}") print(f"Output layout: {self.layout}") print(f"Total CPU cores: {total_cores}") print(f"Worker processes: {self.workers}") print(f"Threads per worker: {self.nthreads}") print(f"Total threads: {self.workers * self.nthreads}") @staticmethod def _find_raw_root(path: str) -> str | None: p = Path(os.path.abspath(path)) parts = p.parts if "RAW_DATA" not in parts: return None return str(Path(*parts[: parts.index("RAW_DATA") + 1])) def _compress_worker(self, ipath: str) -> tuple[str, str]: """ Worker function for ProcessPoolExecutor: compress a single HDF5: - sibling layout: <same_dir>/<basename>_<method>.h5 - mirror layout: mirror RAW_DATA tree under RAW_DATA_COMPRESSED """ outp = 
resolve_compressed_path(ipath, self.method, layout=self.layout) os.makedirs(os.path.dirname(outp), exist_ok=True) self.compressor.compress_file( ipath, outp, cratio=self.cratio, nthreads=self.nthreads ) return ipath, "success" def _mirror_non_compressed_dataset_content(self, file_list: list[str]) -> None: source_targets = {os.path.realpath(p) for p in file_list} raw_roots: set[str] = set() for ipath in file_list: raw_root = self._find_raw_root(ipath) if raw_root: raw_roots.add(raw_root) copy_tasks: list[tuple[str, str]] = [] for src_dir in sorted(raw_roots): try: dst_dir = resolve_mirror_path(src_dir) except ValueError: print(f"WARNING: Cannot mirror folder outside RAW_DATA: '{src_dir}'") continue for cur, dirs, files in os.walk(src_dir): rel_cur = os.path.relpath(cur, src_dir) target_cur = ( dst_dir if rel_cur == "." else os.path.join(dst_dir, rel_cur) ) os.makedirs(target_cur, exist_ok=True) for dname in dirs: os.makedirs(os.path.join(target_cur, dname), exist_ok=True) for fname in files: src_file = os.path.join(cur, fname) if os.path.realpath(src_file) in source_targets: # Do not copy raw files that will be produced by compression. continue dst_file = os.path.join(target_cur, fname) copy_tasks.append((src_file, dst_file)) if not copy_tasks: return max_workers = min(len(copy_tasks), max(1, get_available_cpus()), 8) with ThreadPoolExecutor(max_workers=max_workers) as executor: futures = { executor.submit(shutil.copy2, s, d): (s, d) for s, d in copy_tasks } for fut in as_completed(futures): src_file, dst_file = futures[fut] try: fut.result() except Exception as e: print(f"WARNING: Failed to copy '{src_file}' → '{dst_file}': {e}")
[docs] def compress_files(self, file_list: list[str]) -> None: """ Compress each .h5 in file_list in parallel. - sibling layout: produce <basename>_<method>.h5 next to each source. - mirror layout: write compressed files to RAW_DATA_COMPRESSED with same file names. Does not overwrite originals. At the end, prints total elapsed time and data rate in MB/s. """ valid = [p for p in file_list if p.lower().endswith(".h5")] if not valid: print("No valid .h5 files to compress.") return if self.layout == "mirror": print( "Preparing RAW_DATA_COMPRESSED with non-compressed dataset content..." ) self._mirror_non_compressed_dataset_content(valid) total_bytes = 0 for f in valid: try: total_bytes += os.path.getsize(f) except OSError: pass import time t0 = time.time() with ProcessPoolExecutor(max_workers=self.workers) as executor: futures = {executor.submit(self._compress_worker, p): p for p in valid} for fut in tqdm( as_completed(futures), total=len(futures), desc=f"Compressing HDF5 files ({self.method})", unit="file", ): pth = futures[fut] try: fut.result() except Exception as e: print(f"Failed to compress '{pth}': {e}") elapsed = time.time() - t0 total_mb = total_bytes / (1024 * 1024) rate_mb_s = total_mb / elapsed if elapsed > 0 else float("inf") print(f"\nTotal elapsed time: {elapsed:.3f}s") print(f"Data processed: {total_mb:.2f} MB ({rate_mb_s:.2f} MB/s)\n")
[docs] def overwrite_files(self, file_list: list[str]) -> None: """ Overwrites files only if they have a compressed sibling: 1) Rename <file>.h5 → <file>.h5.bak 2) Rename <file>_<method>.h5 → <file>.h5 After processing all files, removes the backup .h5.bak files. """ for ipath in file_list: if not ipath.lower().endswith(".h5"): continue compressed_path = resolve_compressed_path( ipath, self.method, layout=self.layout ) if os.path.exists(compressed_path): backup = ipath + ".bak" try: os.replace(ipath, backup) os.replace(compressed_path, ipath) print(f"Overwritten '{ipath}' (backup at '{backup}').") except Exception as e: print(f"ERROR overwriting '{ipath}': {e}") else: print(f"SKIP (no compressed file): {ipath}")
[docs] def remove_backups(self, file_list: list[str]) -> None: candidates = {p + ".bak" for p in file_list if p.lower().endswith(".h5")} backups = [b for b in candidates if os.path.exists(b)] if not backups: print("No backup files to remove.") return total_bytes = 0 for b in backups: try: total_bytes += os.path.getsize(b) except OSError: pass total_mb = total_bytes / (1024 * 1024) print( f"About to remove {len(backups)} backup file(s), ~{total_mb:.2f} MB total." ) ans = input("Proceed? [y/N]: ").strip().lower() if ans not in ("y", "yes"): print("Backups kept.") return removed = 0 for b in backups: try: os.remove(b) removed += 1 except Exception as e: print(f"ERROR deleting backup '{b}': {e}") print(f"Deleted {removed} backup file(s).")
[docs] def restore_backups(self, file_list: list[str]) -> None: restored = 0 preserved = 0 for ipath in file_list: if not ipath.lower().endswith(".h5"): continue backup = ipath + ".bak" method_path = resolve_compressed_path( ipath, self.method, layout=self.layout ) if not os.path.exists(backup): print(f"SKIP (no backup): {ipath}") continue if os.path.exists(ipath) and not os.path.exists(method_path): try: os.replace(ipath, method_path) preserved += 1 print(f"Preserved current file to '{method_path}'.") except Exception as e: print(f"ERROR preserving current '{ipath}' to '{method_path}': {e}") continue try: os.replace(backup, ipath) restored += 1 print(f"Restored '{ipath}' from backup.") except Exception as e: print(f"ERROR restoring '{ipath}' from '{backup}': {e}") print( f"Restore complete. Restored: {restored}, preserved compressed copies: {preserved}." )