"""
Tiler: Tiles and tracks traps.
The Tiler's tasks are selecting regions of interest, or tiles, of an image (one tile per trap); tracking and correcting for the drift of the microscope stage over time; and handling errors and bridging between the image data and ALIBY’s image-processing steps.
Tiler subclasses deal with either network connections or local files.
To find traps, we use a two-step process: we analyse the bright-field image to produce the template of a trap, and we fit this template to the image to find the traps' centres.
We use texture-based segmentation (entropy) followed by an Otsu threshold to split the image into foreground -- cells and traps -- and background. Two methods are used to produce a template trap from these regions: picking the trap with the smallest minor-axis length, and averaging over all validated traps.
A peak-identifying algorithm recovers the x- and y-axis locations of the traps in the original image, and we choose the templating approach that identifies the most traps.
One key method is Tiler.run.
The image-processing is performed by traps/segment_traps.
The experiment is stored as an array with a standard indexing order of (Time, Channels, Z-stack, Y, X).
"""
import logging
import re
import typing as t
import warnings
from functools import lru_cache
from pathlib import PosixPath

import dask.array as da
import h5py
import numpy as np
from skimage.registration import phase_cross_correlation

from agora.abc import ParametersABC, StepABC
from agora.io.writer import BridgeH5
from aliby.io.image import ImageLocalOME, ImageDir, ImageDummy
from aliby.tile.traps import segment_traps
class Trap:
    """
    A single trap: its initial centre, its tile size and a link to the
    container that owns it.

    The parent (a TrapLocations instance) holds the per-time-point drifts,
    which are used to recover where the trap sits at any later time point.
    The trap can be exported either as an OMERO-style tile (x, y, w, h) or as
    a pair of slices for numpy indexing.
    """

    def __init__(self, centre, parent, size, max_size):
        # the parent is only used to look up its list of drifts
        self.centre = centre
        self.parent = parent
        self.size = size
        self.half_size = size // 2
        self.max_size = max_size

    def at_time(self, tp: int) -> t.List[int]:
        """
        Return the trap centre at time point tp, corrected for drift.

        Parameters
        ----------
        tp: integer
            Index for a time point

        Returns
        -------
        trap_centre: list of int
            The initial centre minus the summed drifts of time points
            0 to tp inclusive, truncated to integers.
        """
        cumulative_drift = np.sum(self.parent.drifts[: tp + 1], axis=0)
        shifted_centre = self.centre - cumulative_drift
        return list(shifted_centre.astype(int))

    def as_tile(self, tp):
        """
        Return the trap in the OMERO tile format (x, y, w, h).

        x and y locate the corner of the tile with the smallest indices,
        i.e. the centre minus half the tile size; w and h are both the
        tile size.

        Parameters
        ----------
        tp: integer
            Index for a time point

        Returns
        -------
        x: int
            First coordinate of the tile's minimum-index corner
        y: int
            Second coordinate of the tile's minimum-index corner
        w: int
            Width of tile
        h: int
            Height of tile
        """
        centre_x, centre_y = self.at_time(tp)
        corner_x = int(centre_x - self.half_size)
        corner_y = int(centre_y - self.half_size)
        return corner_x, corner_y, self.size, self.size

    def as_range(self, tp):
        """
        Return the trap as two slice objects usable for array indexing.

        Parameters
        ----------
        tp: integer
            Index for a time point

        Returns
        -------
        A slice along the first coordinate and a slice along the second,
        each spanning the tile size.
        """
        start_x, start_y, width, height = self.as_tile(tp)
        first_axis = slice(start_x, start_x + width)
        second_axis = slice(start_y, start_y + height)
        return first_axis, second_axis
class TrapLocations:
    """
    Container of Trap instances that share a single list of drifts.

    Iterating over an instance yields its traps; each trap refers back to
    this container to read the drifts.
    """

    def __init__(
        self,
        initial_location: np.array,
        tile_size: int = None,
        max_size: int = 1200,
        drifts: np.array = None,
    ):
        self.tile_size = tile_size
        self.max_size = max_size
        self.initial_location = initial_location
        # traps fall back to max_size when no tile size is given
        trap_size = tile_size or max_size
        self.traps = [
            Trap(centre, self, trap_size, max_size)
            for centre in initial_location
        ]
        self.drifts = [] if drifts is None else drifts

    def __len__(self):
        return len(self.traps)

    def __iter__(self):
        return iter(self.traps)

    @property
    def shape(self):
        """
        Return the number of traps and the number of stored drifts.
        """
        return len(self.traps), len(self.drifts)

    def to_dict(self, tp):
        """
        Export data for one time point as a dictionary.

        At tp == 0 the initial locations, tile_size and max_size are
        included as well; the drift for tp is always included, with a
        leading axis of length one.

        Parameters
        ----------
        tp: integer
            An index for a time point
        """
        exported = {}
        if tp == 0:
            exported.update(
                {
                    "trap_locations": self.initial_location,
                    "attrs/tile_size": self.tile_size,
                    "attrs/max_size": self.max_size,
                }
            )
        exported["drifts"] = np.expand_dims(self.drifts[tp], axis=0)
        return exported

    def at_time(self, tp: int) -> np.ndarray:
        """
        Return the drift-corrected centres of all traps at time point tp,
        one trap per row.
        """
        return np.array([trap.at_time(tp) for trap in self.traps])

    @classmethod
    def from_tiler_init(
        cls, initial_location, tile_size: int = None, max_size: int = 1200
    ):
        """
        Instantiate with an empty list of drifts, as done by Tiler.
        """
        return cls(
            initial_location, tile_size=tile_size, max_size=max_size, drifts=[]
        )

    @classmethod
    def read_hdf5(cls, file):
        """
        Instantiate from the "trap_info" group of an hdf5 file.

        Reads the "trap_locations" and "drifts" datasets and the
        "max_size" and "tile_size" attributes.
        """
        with h5py.File(file, "r") as hfile:
            group = hfile["trap_info"]
            locations = group["trap_locations"][()]
            stored_drifts = group["drifts"][()].tolist()
            stored_max_size = group.attrs["max_size"]
            stored_tile_size = group.attrs["tile_size"]
        instance = cls(locations, stored_tile_size, max_size=stored_max_size)
        instance.drifts = stored_drifts
        return instance
class TilerParameters(ParametersABC):
    """
    Default parameters for Tiler.

    tile_size: side length, in pixels, of each square tile
    ref_channel: name of the channel used to find traps and drift
    ref_z: index of the z-section used to find traps and drift
    """

    _defaults = dict(tile_size=117, ref_channel="Brightfield", ref_z=0)
class Tiler(StepABC):
    """
    Timelapse Tiler.

    Finds traps and re-registers images if there is any drifting.
    The pixel data is accessed lazily through a dask array (which may be
    backed by a server or by local files); metadata is supplied as an
    independent argument.
    """

    def __init__(
        self,
        image: da.core.Array,
        metadata: dict,
        parameters: TilerParameters,
        trap_locs=None,
    ):
        """
        Initialise Tiler.

        Parameters
        ----------
        image: dask array
            Pixel data; the last two axes are the spatial ones.
        metadata: dictionary
            Must contain "size_c" when "channels" is absent; "zsections"
            is optional.
        parameters: an instance of TilerParameters
        trap_locs: an instance of TrapLocations (optional)
        """
        super().__init__(parameters)
        self.image = image
        self._metadata = metadata
        # only require size_c when no channel names are provided
        if "channels" in metadata:
            self.channels = metadata["channels"]
        else:
            # integer placeholders, one per channel
            self.channels = list(range(metadata["size_c"]))
        self.ref_channel = self.get_channel_index(parameters.ref_channel)
        self.trap_locs = trap_locs
        try:
            self.z_perchannel = {
                ch: zsect
                for ch, zsect in zip(self.channels, metadata["zsections"])
            }
        except Exception as e:
            # z-section information is optional: log and carry on
            self._log(f"No z_perchannel data: {e}")
        # NOTE(review): self.tile_size is presumably set from the parameters
        # by StepABC; default to the smaller spatial dimension of the image
        self.tile_size = self.tile_size or min(self.image.shape[-2:])

    @classmethod
    def dummy(cls, parameters: dict):
        """
        Instantiate a dummy Tiler from a dummy image.

        If image.dimorder exists, dimensions are saved in that order.
        Otherwise default to "tczyx".

        Parameters
        ----------
        parameters: dictionary output of an instance of TilerParameters
        """
        imgdmy_obj = ImageDummy(parameters)
        dummy_image = imgdmy_obj.get_data_lazy()
        # default to "tczyx" if image.dimorder is None
        dummy_omero_metadata = {
            f"size_{dim}": dim_size
            for dim, dim_size in zip(
                imgdmy_obj.dimorder or "tczyx", dummy_image.shape
            )
        }
        # the reference channel first, then placeholder names
        dummy_omero_metadata.update(
            {
                "channels": [
                    parameters["ref_channel"],
                    *(["nil"] * (dummy_omero_metadata["size_c"] - 1)),
                ],
                "name": "",
            }
        )
        return cls(
            imgdmy_obj.data,
            dummy_omero_metadata,
            TilerParameters.from_dict(parameters),
        )

    @classmethod
    def from_image(cls, image, parameters: TilerParameters):
        """
        Instantiate Tiler from an Image instance.

        Parameters
        ----------
        image: an instance of Image
        parameters: an instance of TilerParameters
        """
        return cls(image.data, image.metadata, parameters)

    @classmethod
    def from_h5(
        cls,
        image: t.Union[
            ImageLocalOME, ImageDir
        ],  # TODO provide baseclass instead
        filepath: t.Union[str, PosixPath],
        parameters: TilerParameters = None,
    ):
        """
        Instantiate Tiler from an hdf5 file holding previously found traps.

        Parameters
        ----------
        image: an instance of Image
        filepath: str or Path
            Path to the h5 file
        parameters: an instance of TilerParameters (optional)
            Defaults to TilerParameters.default().
        """
        trap_locs = TrapLocations.read_hdf5(filepath)
        metadata = BridgeH5(filepath).meta_h5
        metadata["channels"] = image.metadata["channels"]
        if parameters is None:
            parameters = TilerParameters.default()
        tiler = cls(
            image.data,
            metadata,
            parameters,
            trap_locs=trap_locs,
        )
        # time points already tiled are those with a stored drift
        if hasattr(trap_locs, "drifts"):
            tiler.n_processed = len(trap_locs.drifts)
        return tiler

    @lru_cache(maxsize=2)
    def get_tc(self, t, c):
        """
        Load the image for one time point and one channel using dask.

        Assumes the image is arranged as
        no of time points
        no of channels
        no of z stacks
        no of pixels in y direction
        no of pixels in x direction
        The two most recent results are cached (the cache is keyed on
        self, t and c).

        Parameters
        ----------
        t: integer
            An index for a time point
        c: integer
            An index for a channel

        Returns
        -------
        full: an array of images, one per z stack
        """
        full = self.image[t, c].compute(scheduler="synchronous")
        return full

    @property
    def shape(self):
        """
        Return the shape of the time-lapse, as given by self.image.shape.
        """
        return self.image.shape

    @property
    def n_processed(self):
        """
        Return the number of time points that have been processed.
        """
        if not hasattr(self, "_n_processed"):
            self._n_processed = 0
        return self._n_processed

    @n_processed.setter
    def n_processed(self, value):
        self._n_processed = value

    @property
    def n_traps(self):
        """
        Return the number of traps.
        """
        return len(self.trap_locs)

    def initialise_traps(self, tile_size: int = None):
        """
        Find the initial trap positions and store them in self.trap_locs.

        Traps too close to the edge of the image are removed so that no
        padding is necessary. Without a tile size, a single trap centred
        in the image and spanning its smaller dimension is used instead.

        Parameters
        ----------
        tile_size: integer
            The size of a tile
        """
        # first time point, reference channel, reference z-position
        initial_image = self.image[0, self.ref_channel, self.ref_z]
        if tile_size:
            half_tile = tile_size // 2
            # extent along each spatial axis: the first coordinate of a trap
            # centre indexes axis -2 of the image and the second axis -1,
            # matching how ifoob_pad slices the image
            n_rows, n_cols = self.image.shape[-2:]
            # find the traps
            trap_locs = segment_traps(initial_image, tile_size)
            # keep only traps that are not near an edge
            trap_locs = [
                [x, y]
                for x, y in trap_locs
                if half_tile < x < n_rows - half_tile
                and half_tile < y < n_cols - half_tile
            ]
            # store traps in an instance of TrapLocations
            self.trap_locs = TrapLocations.from_tiler_init(
                trap_locs, tile_size
            )
        else:
            yx_shape = self.image.shape[-2:]
            trap_locs = [[x // 2 for x in yx_shape]]
            self.trap_locs = TrapLocations.from_tiler_init(
                trap_locs, max_size=min(yx_shape)
            )

    def find_drift(self, tp):
        """
        Find any translational drift between the images at consecutive
        time points using cross correlation, and store it in
        self.trap_locs.drifts.

        The reference channel and z-section are used; at tp = 0 the image
        is compared with itself. An existing drift at index tp is
        overwritten; otherwise the drift is appended (which lands at
        index tp when time points are processed in order).

        Arguments
        ---------
        tp: integer
            Index for a time point
        """
        prev_tp = max(0, tp - 1)
        # cross-correlate
        drift, _, _ = phase_cross_correlation(
            self.image[prev_tp, self.ref_channel, self.ref_z],
            self.image[tp, self.ref_channel, self.ref_z],
        )
        # overwrite if this time point (including tp = 0) was already stored
        if 0 <= tp < len(self.trap_locs.drifts):
            self.trap_locs.drifts[tp] = drift.tolist()
        else:
            self.trap_locs.drifts.append(drift.tolist())

    def get_tp_data(self, tp, c):
        """
        Return all traps at one time point, corrected for drift.

        Parameters
        ----------
        tp: integer
            An index for a time point
        c: integer
            An index for a channel

        Returns
        -------
        An array stacking, along a new first axis, one tile per trap
        as returned by ifoob_pad.
        """
        traps = []
        full = self.get_tc(tp, c)
        for trap in self.trap_locs:
            # pad trap if necessary
            ndtrap = self.ifoob_pad(full, trap.as_range(tp))
            traps.append(ndtrap)
        return np.stack(traps)

    def get_trap_data(self, trap_id, tp, c):
        """
        Return a particular trap corrected for drift and padding.

        Parameters
        ----------
        trap_id: integer
            Number of trap
        tp: integer
            Index of time points
        c: integer
            Index of channel

        Returns
        -------
        ndtrap: array
            The tile for every z stack, as returned by ifoob_pad
        """
        full = self.get_tc(tp, c)
        trap = self.trap_locs.traps[trap_id]
        ndtrap = self.ifoob_pad(full, trap.as_range(tp))
        return ndtrap

    def _run_tp(self, tp):
        """
        Tile one time point.

        Find traps if they have not yet been found, then determine any
        translational drift of the current image from the previous one.

        Arguments
        ---------
        tp: integer
            The time point to tile.

        Returns
        -------
        The dictionary from TrapLocations.to_dict(tp), for the writer.
        """
        # assert tp >= self.n_processed, "Time point already processed"
        # TODO check contiguity?
        if self.n_processed == 0 or not hasattr(self.trap_locs, "drifts"):
            self.initialise_traps(self.tile_size)
        if hasattr(self.trap_locs, "drifts"):
            drift_len = len(self.trap_locs.drifts)
            if self.n_processed != drift_len:
                warnings.warn("Tiler:n_processed and ndrifts don't match")
                self.n_processed = drift_len
        # determine drift
        self.find_drift(tp)
        # update n_processed
        self.n_processed = tp + 1
        # return result for writer
        return self.trap_locs.to_dict(tp)

    def run(self, time_dim=None):
        """
        Tile all time points in an experiment at once.

        Parameters
        ----------
        time_dim: int (optional)
            Axis of self.image that holds the time points; defaults to 0.
        """
        if time_dim is None:
            time_dim = 0
        for frame in range(self.image.shape[time_dim]):
            # NOTE(review): run_tp is presumably provided by StepABC and
            # wraps _run_tp
            self.run_tp(frame)
        return None

    def get_traps_timepoint(self, *args, **kwargs):
        """
        Deprecated alias of get_tiles_timepoint.
        """
        self._log(
            "get_traps_timepoint is deprecated; use get_tiles_timepoint instead."
        )
        return self.get_tiles_timepoint(*args, **kwargs)

    # The next set of functions are necessary for the extraction object
    def get_tiles_timepoint(
        self,
        tp,
        tile_shape=None,
        channels=None,
        z: t.Union[int, t.List[int]] = 0,
    ) -> np.ndarray:
        """
        Get a multidimensional array with all tiles for a set of channels
        and z-stacks.

        Used by extractor.

        Parameters
        ---------
        tp: int
            Index of time point
        tile_shape: int or tuple of two ints (optional)
            Maximum expected size of the tiles; an AssertionError is
            raised if the tiles are larger.
        channels: int, string, or list of ints and/or strings (optional)
            Channels of interest, as indices or names (names are matched
            with get_channel_index); defaults to channel 0.
        z: int, list of ints, or slice
            Index or indices of the z-sections of interest.

        Returns
        -------
        res: array
            Data arranged as (traps, channels, 1 time point, tile rows,
            tile columns, z-sections).
        """
        # FIXME add support for subtiling trap
        # FIXME can we ignore z(always give)
        if channels is None:
            channels = [0]
        elif isinstance(channels, (str, int, np.integer)):
            channels = [channels]
        # a scalar z would drop the z axis and break the axis swaps below
        if isinstance(z, (int, np.integer)):
            z = [z]
        res = []
        for c in channels:
            # channel names must be converted to indices before indexing
            c = self.get_channel_index(c)
            # starts with the order: traps, z, rows, columns
            val = self.get_tp_data(tp, c)[:, z]
            # reorder to: traps, rows, columns, z
            val = val.swapaxes(1, 3).swapaxes(1, 2)
            # add an axis for the single time point
            val = np.expand_dims(val, axis=1)
            res.append(val)
        if tile_shape is not None:
            if isinstance(tile_shape, int):
                tile_shape = (tile_shape, tile_shape)
            # compare both spatial axes of the tiles with tile_shape
            tiles_shape = res[0].shape[-3:-1]
            assert np.all(
                [
                    (tile_size - ax) > -1
                    for tile_size, ax in zip(tile_shape, tiles_shape)
                ]
            ), f"Tiles of shape {tiles_shape} exceed tile_shape {tile_shape}"
        return np.stack(res, axis=1)

    @property
    def ref_channel_index(self):
        """
        Return the index of the reference channel given in the parameters.
        """
        return self.get_channel_index(self.parameters.ref_channel)

    def get_channel_index(self, channel: t.Union[str, int]):
        """
        Find the index for a channel using regex.

        Strings are matched against self.channels with find_channel_index
        and the first match is returned; anything else is returned as is.

        Parameters
        ----------
        channel: string or int
            The channel or index to be used

        Raises
        ------
        Warning
            If no channel matches.
        """
        index = channel
        if isinstance(channel, str):
            index = find_channel_index(self.channels, channel)
        if index is None:
            # report the requested channel, not the failed lookup result
            raise Warning(
                f"Reference channel {channel} not in the available channels: {self.channels}"
            )
        return index

    @staticmethod
    def ifoob_pad(full, slices):
        """
        Return the tile given by slices, padded if it is out of bounds.

        Parameters
        ----------
        full: array
            The entire position as (z stacks, rows, columns).
        slices: tuple of two slices
            Delineates the indices of the tile along the rows and the
            columns of full.

        Returns
        -------
        trap: array
            A tile with all z stacks for the given slices.
            If some padding is needed, the tile is padded with the median
            of the in-bounds values along each axis (np.pad "median" mode).
            If more than a quarter of the tile size is needed on any side,
            a tile of NaN is returned.
        """
        # size of each spatial axis, matched to slices[0] and slices[1]
        axis_sizes = full.shape[-2:]
        # ignore parts of the tile outside of the image
        row_slice, col_slice = [
            slice(max(0, s.start), min(size, s.stop))
            for s, size in zip(slices, axis_sizes)
        ]
        # get the tile including all z stacks
        trap = full[:, row_slice, col_slice]
        # find extent of padding needed before and after along each axis
        padding = np.array(
            [
                (-min(0, s.start), -min(0, size - s.stop))
                for s, size in zip(slices, axis_sizes)
            ]
        )
        if padding.any():
            tile_size = slices[0].stop - slices[0].start
            if (padding > tile_size / 4).any():
                # too much of the tile is outside of the image: fill with NaN
                trap = np.full((full.shape[0], tile_size, tile_size), np.nan)
            else:
                # pad tile with median value of trap image
                trap = np.pad(trap, [[0, 0]] + padding.tolist(), "median")
        return trap
def find_channel_index(image_channels: t.List[str], channel: str):
    """
    Find the index of the first channel whose name matches a regex.

    The match is case-insensitive and anchored at the start of each channel
    name (re.match), so "GFP" also matches "GFPFast". Channel entries that
    are not strings (such as integer placeholders) are matched via their
    string form.

    Parameters
    ----------
    image_channels: list of str
        Names of the available channels.
    channel: str
        Regular expression for the channel of interest.

    Returns
    -------
    index: int or None
        Index of the first matching channel, or None if none matches.
    """
    for i, ch in enumerate(image_channels):
        name = str(ch)
        found = re.match(channel, name, re.IGNORECASE)
        if found:
            # the regex covered only part of the name: log it so that an
            # unexpected match can be traced
            if found.end() - found.start() < len(name):
                logging.getLogger(__name__).info(
                    "Channel %s matched %s using regex", channel, ch
                )
            return i
    return None
def find_channel_name(image_channels: t.List[str], channel: str):
    """
    Return the name of the first channel matching a regex.

    Parameters
    ----------
    image_channels: list of str
        Names of the available channels.
    channel: str
        Regular expression for the channel of interest.

    Returns
    -------
    The matching entry of image_channels, or None if none matches.
    """
    match_index = find_channel_index(image_channels, channel)
    return None if match_index is None else image_channels[match_index]