Skip to content

WSIGridProcessor

histolytics.wsi.wsi_processor.WSIGridProcessor

Source code in src/histolytics/wsi/wsi_processor.py
class WSIGridProcessor:
    """Context-managed batch processor for WSI grid crops.

    Wraps a `WSIGridDataset` and a `NodesDataLoader` whose lifetimes are
    bounded by a `with` block, so large per-crop resources are released
    deterministically on exit.
    """

    def __init__(
        self,
        slide_reader: SlideReader,
        grid: gpd.GeoDataFrame,
        nuclei: gpd.GeoDataFrame,
        pipeline_func: Callable,
        tissue: gpd.GeoDataFrame = None,
        nuclei_classes: Dict[str, int] = None,
        tissue_classes: Dict[str, int] = None,
        batch_size: int = 8,
        num_workers: int = 8,
        pin_memory: bool = True,
        shuffle: bool = False,
        drop_last: bool = False,
    ):
        """Context manager for processing WSI grid cells.

        Parameters:
            slide_reader (SlideReader):
                SlideReader instance.
            grid (GeoDataFrame):
                A grid GeoDataFrame containing rectangular grid cells.
            nuclei (GeoDataFrame):
                A GeoDataFrame containing nuclei data.
            pipeline_func (Callable):
                Function applied to each grid crop, e.g. a `functools.partial`
                of a feature-extraction function such as `chromatin_feats`.
            tissue (GeoDataFrame):
                A GeoDataFrame containing tissue data. Optional.
            nuclei_classes (Dict[str, int]):
                A dictionary mapping nuclei class names to integers.
            tissue_classes (Dict[str, int]):
                A dictionary mapping tissue class names to integers.
            batch_size (int):
                The batch size for processing.
            num_workers (int):
                The number of worker processes.
            pin_memory (bool):
                Whether to pin memory for faster GPU transfer.
            shuffle (bool):
                Whether to shuffle the data.
            drop_last (bool):
                Whether to drop the last incomplete batch.

        Examples:
            >>> from tqdm import tqdm
            >>> from histolytics.wsi.wsi_processor import WSIGridProcessor
            >>>
            >>> # ...  initialize reader, grid_gdf etc.
            >>> crop_loader = WSIGridProcessor(
            ...     slide_reader=reader, # SlideReader object
            ...     grid=grid_gdf, # GeoDataFrame containing grid cells
            ...     nuclei=nuc_gdf, # GeoDataFrame containing nuclei data
            ...     nuclei_classes=nuclei_classes, # Mapping of nuclei class names to integers
            ...     pipeline_func=partial(chromatin_feats, metrics=("chrom_area", "chrom_nuc_prop")),
            ...     batch_size=8,
            ...     num_workers=8,
            ...     pin_memory=False,
            ...     shuffle=False,
            ...     drop_last=False,
            ... )
            >>>
            >>> crop_feats = []
            >>> with crop_loader as loader:
            ...     with tqdm(loader, unit="batch", total=len(loader)) as pbar:
            ...         for batch_idx, batch in enumerate(pbar):
            ...             crop_feats.append(batch)
        """
        self.slide_reader = slide_reader
        self.grid = grid
        self.nuclei = nuclei
        self.tissue = tissue
        # Fall back to empty mappings so downstream code can iterate safely.
        self.nuclei_classes = nuclei_classes or {}
        self.tissue_classes = tissue_classes or {}
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.pipeline_func = pipeline_func

        # Internal state; populated by __enter__, torn down by __exit__.
        self._dataset = None
        self._loader = None
        self._iterator = None

    def __enter__(self):
        """Enter the context manager and initialize the dataset and loader."""
        # Create the dataset.
        # NOTE(review): the kwarg name "slider_reader" (sic) is kept as-is —
        # confirm it matches WSIGridDataset's actual signature upstream.
        self._dataset = WSIGridDataset(
            slider_reader=self.slide_reader,
            grid=self.grid,
            nuclei=self.nuclei,
            pipeline_func=self.pipeline_func,
            tissue=self.tissue,
            nuclei_classes=self.nuclei_classes,
            tissue_classes=self.tissue_classes,
        )

        # Create the loader, using the dataset's own collate function.
        self._loader = NodesDataLoader(
            dataset=self._dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            drop_last=self.drop_last,
            collate=self._dataset.collate,
        )

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit the context manager and clean up resources.

        Rebinding the attributes to None drops the last references to the
        iterator, loader, and dataset; gc.collect() then reclaims any cycles
        promptly rather than waiting for the cyclic collector.
        """
        self._iterator = None
        self._loader = None
        self._dataset = None

        # Force garbage collection so large crop buffers are freed now.
        gc.collect()

        # Return False to propagate any exceptions raised inside the block.
        return False

    def __iter__(self):
        """Make the class iterable.

        Raises:
            RuntimeError: If called outside a `with` block.
        """
        if self._loader is None:
            raise RuntimeError("Context manager not entered. Use 'with' statement.")

        self._iterator = iter(self._loader)
        return self

    def __next__(self):
        """Get the next batch.

        Raises:
            RuntimeError: If iteration was not started via `iter()` inside
                a `with` block.
        """
        if self._iterator is None:
            raise RuntimeError(
                "Iterator not initialized. Use 'with' statement and iterate."
            )

        return next(self._iterator)

    def __len__(self):
        """Get the total number of batches (last partial batch included).

        Raises:
            RuntimeError: If called outside a `with` block.
        """
        if self._dataset is None:
            raise RuntimeError("Context manager not entered. Use 'with' statement.")

        return int(np.ceil(len(self._dataset) / self.batch_size))

    @property
    def total_samples(self):
        """Get the total number of samples (grid cells).

        Raises:
            RuntimeError: If called outside a `with` block.
        """
        if self._dataset is None:
            raise RuntimeError("Context manager not entered. Use 'with' statement.")

        return len(self._dataset)

    def get_single_item(self, index: int):
        """Get a single item by index without batching.

        Raises:
            RuntimeError: If called outside a `with` block.
        """
        if self._dataset is None:
            raise RuntimeError("Context manager not entered. Use 'with' statement.")

        return self._dataset[index]
total_samples property

total_samples

Get the total number of samples (grid cells).

__init__

__init__(slide_reader: SlideReader, grid: GeoDataFrame, nuclei: GeoDataFrame, pipeline_func: Callable, tissue: GeoDataFrame = None, nuclei_classes: Dict[str, int] = None, tissue_classes: Dict[str, int] = None, batch_size: int = 8, num_workers: int = 8, pin_memory: bool = True, shuffle: bool = False, drop_last: bool = False)

Context manager for processing WSI grid cells.

Parameters:

Name Type Description Default
slide_reader SlideReader

SlideReader instance.

required
grid GeoDataFrame

A grid GeoDataFrame containing rectangular grid cells.

required
nuclei GeoDataFrame

A GeoDataFrame containing nuclei data.

required
pipeline_func Callable

Pipeline function applied to each grid crop.

required
tissue GeoDataFrame

A GeoDataFrame containing tissue data.

None
nuclei_classes Dict[str, int]

A dictionary mapping nuclei class names to integers.

None
tissue_classes Dict[str, int]

A dictionary mapping tissue class names to integers.

None
batch_size int

The batch size for processing.

8
num_workers int

The number of worker processes.

8
pin_memory bool

Whether to pin memory for faster GPU transfer.

True
shuffle bool

Whether to shuffle the data.

False
drop_last bool

Whether to drop the last incomplete batch.

False

Examples:

>>> from tqdm import tqdm
>>> from histolytics.wsi.wsi_processor import WSIGridProcessor
>>>
>>> # ...  initialize reader, grid_gdf etc.
>>> crop_loader = WSIGridProcessor(
...     slide_reader=reader, # SlideReader object
...     grid=grid_gdf, # GeoDataFrame containing grid cells
...     nuclei=nuc_gdf, # GeoDataFrame containing nuclei data
...     nuclei_classes=nuclei_classes, # Mapping of nuclei class names to integers
...     pipeline_func=partial(chromatin_feats, metrics=("chrom_area", "chrom_nuc_prop")),
...     batch_size=8,
...     num_workers=8,
...     pin_memory=False,
...     shuffle=False,
...     drop_last=False,
... )
>>>
>>> crop_feats = []
>>> with crop_loader as loader:
>>>     with tqdm(loader, unit="batch", total=len(loader)) as pbar:
>>>         for batch_idx, batch in enumerate(pbar):
>>>             crop_feats.append(batch)
Source code in src/histolytics/wsi/wsi_processor.py
def __init__(
    self,
    slide_reader: SlideReader,
    grid: gpd.GeoDataFrame,
    nuclei: gpd.GeoDataFrame,
    pipeline_func: Callable,
    tissue: gpd.GeoDataFrame = None,
    nuclei_classes: Dict[str, int] = None,
    tissue_classes: Dict[str, int] = None,
    batch_size: int = 8,
    num_workers: int = 8,
    pin_memory: bool = True,
    shuffle: bool = False,
    drop_last: bool = False,
):
    """Context manager for processing WSI grid cells.

    Parameters:
        slide_reader (SlideReader):
            SlideReader instance.
        grid (GeoDataFrame):
            A grid GeoDataFrame containing rectangular grid cells.
        nuclei (GeoDataFrame):
            A GeoDataFrame containing nuclei data.
        tissue (GeoDataFrame):
            A GeoDataFrame containing tissue data.
        nuclei_classes (Dict[str, int]):
            A dictionary mapping nuclei class names to integers.
        tissue_classes (Dict[str, int]):
            A dictionary mapping tissue class names to integers.
        batch_size (int):
            The batch size for processing.
        num_workers (int):
            The number of worker processes.
        pin_memory (bool):
            Whether to pin memory for faster GPU transfer.
        shuffle (bool):
            Whether to shuffle the data.
        drop_last (bool):
            Whether to drop the last incomplete batch.

    Examples:
        >>> from tqdm import tqdm
        >>> from histolytics.wsi.wsi_processor import WSIGridProcessor
        >>>
        >>> # ...  initialize reader, grid_gdf etc.
        >>> crop_loader = WSIGridProcessor(
        ...     slide_reader=reader, # SlideReader object
        ...     grid=grid_gdf, # GeoDataFrame containing grid cells
        ...     nuclei=nuc_gdf, # GeoDataFrame containing nuclei data
        ...     nuclei_classes=nuclei_classes, # Mapping of nuclei class names to integers
        ...     pipeline_func=partial(chromatin_feats, metrics=("chrom_area", "chrom_nuc_prop")),
        ...     batch_size=8,
        ...     num_workers=8,
        ...     pin_memory=False,
        ...     shuffle=False,
        ...     drop_last=False,
        ... )
        >>>
        >>> crop_feats = []
        >>> with crop_loader as loader:
        >>>     with tqdm(loader, unit="batch", total=len(loader)) as pbar:
        >>>         for batch_idx, batch in enumerate(pbar):
        >>>             crop_feats.append(batch)
    """
    self.slide_reader = slide_reader
    self.grid = grid
    self.nuclei = nuclei
    self.tissue = tissue
    self.nuclei_classes = nuclei_classes or {}
    self.tissue_classes = tissue_classes or {}
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.pin_memory = pin_memory
    self.shuffle = shuffle
    self.drop_last = drop_last
    self.pipeline_func = pipeline_func

    # Internal state
    self._dataset = None
    self._loader = None
    self._iterator = None

__enter__

__enter__()

Enter the context manager and initialize the dataset and loader.

Source code in src/histolytics/wsi/wsi_processor.py
def __enter__(self):
    """Enter the context manager and initialize the dataset and loader."""
    # Create the dataset
    self._dataset = WSIGridDataset(
        slider_reader=self.slide_reader,
        grid=self.grid,
        nuclei=self.nuclei,
        pipeline_func=self.pipeline_func,
        tissue=self.tissue,
        nuclei_classes=self.nuclei_classes,
        tissue_classes=self.tissue_classes,
    )

    # Create the loader
    self._loader = NodesDataLoader(
        dataset=self._dataset,
        batch_size=self.batch_size,
        shuffle=self.shuffle,
        num_workers=self.num_workers,
        pin_memory=self.pin_memory,
        drop_last=self.drop_last,
        collate=self._dataset.collate,
    )

    return self

__exit__

__exit__(exc_type, exc_val, exc_tb)

Exit the context manager and clean up resources.

Source code in src/histolytics/wsi/wsi_processor.py
def __exit__(self, exc_type, exc_val, exc_tb):
    """Exit the context manager and clean up resources."""
    # Clean up iterator
    if self._iterator is not None:
        del self._iterator
        self._iterator = None

    # Clean up loader
    if self._loader is not None:
        del self._loader
        self._loader = None

    # Clean up dataset
    if self._dataset is not None:
        del self._dataset
        self._dataset = None

    # Force garbage collection
    gc.collect()

    # Return False to propagate any exceptions
    return False

__iter__

__iter__()

Make the class iterable.

Source code in src/histolytics/wsi/wsi_processor.py
def __iter__(self):
    """Make the class iterable."""
    if self._loader is None:
        raise RuntimeError("Context manager not entered. Use 'with' statement.")

    self._iterator = iter(self._loader)
    return self

__next__

__next__()

Get the next batch.

Source code in src/histolytics/wsi/wsi_processor.py
def __next__(self):
    """Get the next batch."""
    if self._iterator is None:
        raise RuntimeError(
            "Iterator not initialized. Use 'with' statement and iterate."
        )

    return next(self._iterator)

__len__

__len__()

Get the total number of batches.

Source code in src/histolytics/wsi/wsi_processor.py
def __len__(self):
    """Get the total number of batches."""
    if self._dataset is None:
        raise RuntimeError("Context manager not entered. Use 'with' statement.")

    return int(np.ceil(len(self._dataset) / self.batch_size))

get_single_item

get_single_item(index: int)

Get a single item by index without batching.

Source code in src/histolytics/wsi/wsi_processor.py
def get_single_item(self, index: int):
    """Get a single item by index without batching."""
    if self._dataset is None:
        raise RuntimeError("Context manager not entered. Use 'with' statement.")

    return self._dataset[index]