Skip to content

Examples

This page provides practical examples for common use cases with StainX.

Basic Usage

The simplest workflow: fit on a reference image, then transform source images.

import torch
from stainx import Reinhard

# Placeholder data: random tensors stand in for real images in this example.
reference = torch.randn(1, 3, 512, 512)  # one reference/template image
images = torch.randn(10, 3, 512, 512)    # ten images to normalize toward the reference

# Step 1: fit the normalizer on the reference image.
reinhard = Reinhard(device="cuda")
reinhard.fit(reference)

# Step 2: transform the source images with the fitted normalizer.
normalized = reinhard.transform(images)

All Normalizers

StainX provides three normalization algorithms:

import torch
from stainx import Reinhard, Macenko, HistogramMatching

reference = torch.randn(1, 3, 512, 512)
images = torch.randn(10, 3, 512, 512)

# All three normalizers expose the same fit/transform interface, so they can
# be driven from a single loop. HistogramMatching is given channel_axis=1
# explicitly because these tensors are channels-first.
normalizers = {
    "reinhard": Reinhard(device="cuda"),
    "macenko": Macenko(device="cuda"),
    "histogram": HistogramMatching(device="cuda", channel_axis=1),
}

# normalized["reinhard"], normalized["macenko"], normalized["histogram"]
normalized = {}
for name, normalizer in normalizers.items():
    normalizer.fit(reference)
    normalized[name] = normalizer.transform(images)

Fit and Transform in One Step

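Use `fit_transform()` to fit and transform in a single call. The images you pass are used both to fit the normalizer and as the data to transform, so no separate reference image is involved: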

import torch
from stainx import Reinhard

images = torch.randn(10, 3, 512, 512)

# fit_transform() takes a single input: it fits on `images` and then
# transforms those same images, so the batch itself serves as the reference.
normalizer = Reinhard(device="cuda")
normalized = normalizer.fit_transform(images)  # Fits and transforms in one call

Batch Processing

Process multiple images efficiently in a single batch:

import torch
from stainx import Reinhard

# Small batch
small_batch = torch.randn(8, 3, 512, 512)

# Large batch (fewer, larger transform calls are typically more efficient)
large_batch = torch.randn(128, 3, 512, 512)

normalizer = Reinhard(device="cuda")
normalizer.fit(torch.randn(1, 3, 512, 512))

# Each transform() call processes an entire batch at once; the same fitted
# normalizer is reused for batches of different sizes.
for batch in (small_batch, large_batch):
    normalized = normalizer.transform(batch)
    print(f"Processed {batch.shape[0]} images")

Channels-Last Format

Channels-last input, shaped (N, H, W, C), is supported via `channel_axis=-1`. This is useful with image loaders that return this layout:

import torch
from stainx import HistogramMatching

# Channels-last batch: (N, H, W, C) = (10, 512, 512, 3)
images = torch.randn(10, 512, 512, 3)
reference = images[:1]  # first image, sliced so it stays 4-D: (1, 512, 512, 3)

# channel_axis=-1 tells the normalizer that channels are the last dimension
normalizer = HistogramMatching(device="cuda", channel_axis=-1)
normalizer.fit(reference)
normalized = normalizer.transform(images)

Working with Real Images

Example of loading and processing real images:

import torch
from PIL import Image
import torchvision.transforms as transforms
from stainx import Reinhard

to_tensor = transforms.ToTensor()
to_pil = transforms.ToPILImage()

# Load reference image. convert("RGB") guarantees exactly 3 channels even for
# RGBA or grayscale files, so the tensor really is (1, 3, H, W).
with Image.open("reference.png") as reference_img:
    reference_tensor = to_tensor(reference_img.convert("RGB")).unsqueeze(0)  # (1, 3, H, W)

# Load source images (all must share the same H and W so they can be stacked)
source_images = []
for path in ["img1.png", "img2.png", "img3.png"]:
    with Image.open(path) as img:
        source_images.append(to_tensor(img.convert("RGB")))

# Stack into batch
source_batch = torch.stack(source_images)  # (3, 3, H, W)

# Normalize
normalizer = Reinhard(device="cuda")
normalizer.fit(reference_tensor)
normalized_batch = normalizer.transform(source_batch)

# Convert back to images if needed. Move each result to the CPU (in case the
# normalizer returned it on the GPU) and clamp to [0, 1], since ToPILImage
# scales float tensors by 255 and out-of-range values do not map to valid
# 8-bit pixels.
for i, normalized in enumerate(normalized_batch):
    img = to_pil(normalized.cpu().clamp(0, 1))
    img.save(f"normalized_{i}.png")

Device Selection

Run normalization on different devices (CPU, CUDA, or Apple Silicon MPS):

import torch
from stainx import Reinhard

reference = torch.randn(1, 3, 512, 512)
images = torch.randn(10, 3, 512, 512)

# CPU is always usable; CUDA (NVIDIA GPU) and MPS (Apple Silicon) are added
# only when PyTorch reports them as available.
devices = ["cpu"]
if torch.cuda.is_available():
    devices.append("cuda")
if torch.backends.mps.is_available():
    devices.append("mps")

# Fit and transform on each device, moving the data onto that device first.
normalized_by_device = {}
for device in devices:
    normalizer = Reinhard(device=device)
    normalizer.fit(reference.to(device))
    normalized_by_device[device] = normalizer.transform(images.to(device))

Backend Selection

Select a specific computation backend with the `backend` argument:

import torch
from stainx import Reinhard

# Data is created directly on the GPU, so this example requires a CUDA device
reference = torch.randn(1, 3, 512, 512, device="cuda")
images = torch.randn(10, 3, 512, 512, device="cuda")

# Use optimized CUDA backends (if available)
normalizer_torch_cuda = Reinhard(device="cuda", backend="torch_cuda")
normalizer_torch_cuda.fit(reference)
normalized_torch_cuda = normalizer_torch_cuda.transform(images)

# Force torch backend (works everywhere)
normalizer_torch = Reinhard(device="cuda", backend="torch")
normalizer_torch.fit(reference)
normalized_torch = normalizer_torch.transform(images)

Processing Different Image Sizes

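Images within a single batch must all have the same size. A fitted normalizer can still be applied to batches of different sizes, one batch at a time: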

import torch
from stainx import Reinhard

# All images inside one batch must share a size. Images of different sizes
# go into separate batches (or are resized to a common size first).

reference = torch.randn(1, 3, 512, 512)
small_images = torch.randn(5, 3, 256, 256)
large_images = torch.randn(5, 3, 1024, 1024)

# Fit once on the reference; the same fitted normalizer is then reused for
# batches whose spatial size differs from the reference's.
normalizer = Reinhard(device="cuda")
normalizer.fit(reference)

normalized_small = normalizer.transform(small_images)  # 256x256 batch
normalized_large = normalizer.transform(large_images)  # 1024x1024 batch

Preserving Data Types

StainX preserves input data types:

import torch
from stainx import Reinhard


def random_uint8_images(*shape):
    """Return a random uint8 tensor of the given shape with values in [0, 255]."""
    return (torch.rand(*shape) * 255).round().to(torch.uint8)


# uint8 input
reference_uint8 = random_uint8_images(1, 3, 512, 512)
images_uint8 = random_uint8_images(10, 3, 512, 512)

normalizer = Reinhard(device="cuda")
normalizer.fit(reference_uint8)
normalized = normalizer.transform(images_uint8)

print(f"Input dtype: {images_uint8.dtype}")
print(f"Output dtype: {normalized.dtype}")  # Should match input