Skip to content

Getting started

1. Load a pre-trained model

Pre-trained weights are available on the histolytics model hub and are downloaded automatically when you call `from_pretrained`, so an internet connection is required the first time a model is loaded.

Available segmentation model architectures are:

  • CellposePanoptic
  • HoverNetPanoptic
  • StarDistPanoptic
  • CellVitPanoptic
  • CPPNetPanoptic
# Pick one architecture. The commented lines are drop-in alternatives: swap the
# import together with the matching `from_pretrained` call below.
from histolytics.models.cellpose_panoptic import CellposePanoptic
# from histolytics.models.hovernet_panoptic import HoverNetPanoptic
# from histolytics.models.stardist_panoptic import StarDistPanoptic
# from histolytics.models.cellvit_panoptic import CellVitPanoptic
# from histolytics.models.cppnet_panoptic import CPPNetPanoptic


# The first call downloads the pre-trained weights, so it needs network access.
model = CellposePanoptic.from_pretrained("hgsc_v1_efficientnet_b5")
# model = HoverNetPanoptic.from_pretrained("hgsc_v1_efficientnet_b5")
# model = StarDistPanoptic.from_pretrained("hgsc_v1_efficientnet_b5")
# NOTE: available weight names for CellVitPanoptic / CPPNetPanoptic may differ;
# check the model hub before calling their `from_pretrained`.

2. Run inference on a single image

from albumentations import Resize, Compose
from histolytics.utils import FileHandler
from histolytics.transforms import MinMaxNormalization

# Put the model in inference mode once, before predicting.
model.set_inference_mode()

# Resize so that height and width are multiples of 32; the exact size is up to you.
transform = Compose([Resize(1024, 1024), MinMaxNormalization()])

# IMG_PATH is a placeholder: set it to the path of your own image.
im = FileHandler.read_img(IMG_PATH)
im = transform(image=im)["image"]

prob = model.predict(im)
out = model.post_process(prob)
# out = {
#     "nuc": [(nuc instances (H, W), nuc types (H, W))],
#     "tissue": [tissue map (H, W)],
#     "cyto": None,
# }

2.1 Run inference on a batch of images

import torch
from histolytics.utils import FileHandler

model.set_inference_mode()

# Stand-in batch of 8 images with 3 channels at 1024 x 1024 (N, C, H, W).
# Don't use random tensors in practice: load and preprocess real images instead.
batch = torch.rand(8, 3, 1024, 1024)

# Predict on the whole batch (not the single image `im` from the previous step).
prob = model.predict(batch)
out = model.post_process(prob)
# out = {
#  "nuc": [                                    # one entry per image in the batch
#    (nuc instances (H, W), nuc types (H, W)),
#    (nuc instances (H, W), nuc types (H, W)),
#    .
#    .
#    .
#    (nuc instances (H, W), nuc types (H, W))
#  ],
#  "tissue": [                                 # one tissue map per image
#    tissue map (H, W),
#    tissue map (H, W),
#    .
#    .
#    .
#    tissue map (H, W)
#  ],
#  "cyto": None,
#}

3. Visualize output

from matplotlib import pyplot as plt
from skimage.color import label2rgb

# Visualize the single-image result from step 2. `im` and `out` must come from
# the same image: if step 2.1 was run afterwards, `out` holds the batch results
# instead, so re-run step 2 first.
fig, ax = plt.subplots(1, 4, figsize=(24, 6))
ax[0].imshow(im)
ax[0].set_title("image")
ax[1].imshow(label2rgb(out["nuc"][0][0], bg_label=0))  # inst_map
ax[1].set_title("nuclei instances")
ax[2].imshow(label2rgb(out["nuc"][0][1], bg_label=0))  # type_map
ax[2].set_title("nuclei types")
ax[3].imshow(label2rgb(out["tissue"][0], bg_label=0))  # tissue_map
ax[3].set_title("tissue")
# Needed to display the figure when running as a script; harmless in notebooks.
plt.show()