"""Functions for identifying and dealing with ALCATRAS traps."""
import numpy as np
from skimage import feature, transform
from skimage.filters import threshold_otsu
from skimage.filters.rank import entropy
from skimage.measure import label, regionprops
from skimage.morphology import closing, disk, square
from skimage.segmentation import clear_border
from skimage.util import img_as_ubyte
def half_floor(x, tile_size):
    """Return the start index of a tile of ``tile_size`` centred on ``x``.

    The lower half-width is ``tile_size // 2`` (rounded down), so for an
    integer ``tile_size`` the slice ``half_floor(x, t):half_ceil(x, t)``
    spans exactly ``t`` elements.
    """
    lower_half = tile_size // 2
    return x - lower_half
def half_ceil(x, tile_size):
    """Return the exclusive end index of a tile of ``tile_size`` centred on ``x``.

    The upper half-width is ``tile_size / 2`` rounded up, obtained by
    floor-dividing the negated size and negating the result.
    """
    upper_half = -((-tile_size) // 2)
    return x + upper_half
def segment_traps(
    image,
    tile_size,
    downscale=0.4,
    disk_radius_frac=0.01,
    square_size=3,
    min_frac_tilesize=0.3,
    **identify_traps_kwargs,
):
    """
    Use an entropy filter and Otsu thresholding to find a trap template,
    which is then passed to identify_trap_locations.

    Candidate trap regions must have a major axis length strictly between
    min_frac_tilesize * tile_size and tile_size, and a centroid far enough
    from the image edge that a full tile can be cut around it.
    The hyperparameters have not been optimised.

    Parameters
    ----------
    image: 2D array
    tile_size: integer
        Size of the tile
    downscale: float (optional)
        Fraction by which to shrink the image before the entropy filter
    disk_radius_frac: float (optional)
        Radius of the disk used in the entropy filter, as a fraction of the
        smaller dimension of the (possibly downscaled) image; at least 1
    square_size: integer (optional)
        Parameter for a morphological closing applied to thresholded
        image
    min_frac_tilesize: float (optional)
        Lower bound on the major axis length of regions suspected of
        containing traps, as a fraction of tile_size (the upper bound is
        tile_size itself)
    identify_traps_kwargs:
        Passed to identify_trap_locations

    Returns
    -------
    traps: an array of pairs of integers
        The coordinates of the centroids of the traps.

    Raises
    ------
    AssertionError
        If no candidate trap regions are found.
    """
    # lower bound on the major axis length of traps (upper bound: tile_size)
    min_mal = min_frac_tilesize * tile_size
    # shrink image to speed up the entropy filter
    img = transform.rescale(image, downscale) if downscale != 1 else image
    # generate an entropy image using a disk footprint; a radius of 0 would
    # give a single-pixel footprint and hence an all-zero entropy image
    disk_radius = max(1, int(min(disk_radius_frac * x for x in img.shape)))
    entropy_image = entropy(img_as_ubyte(img), disk(disk_radius))
    if downscale != 1:
        entropy_image = transform.rescale(entropy_image, 1 / downscale)
    # find Otsu threshold for entropy image
    thresh = threshold_otsu(entropy_image)
    # apply morphological closing to thresholded, and so binary, image
    bw = closing(entropy_image > thresh, square(square_size))
    # remove artifacts connected to image border
    cleared = clear_border(bw)
    # label distinct regions of the image
    label_image = label(cleared)
    # find regions likely to contain traps:
    # with a major axis length within a certain range
    # and a centroid at least tile_size // 2 away from the image edge
    valid_regions = [
        region
        for region in regionprops(label_image)
        if min_mal < region.major_axis_length < tile_size
        and tile_size // 2
        < region.centroid[0]
        < half_floor(image.shape[0], tile_size) - 1
        and tile_size // 2
        < region.centroid[1]
        < half_floor(image.shape[1], tile_size) - 1
    ]
    # explicit raise (not assert) so the check survives `python -O`;
    # the exception type is kept for existing callers
    if not valid_regions:
        raise AssertionError("No valid tiling regions found")
    # integer centroids of the valid regions
    centroids = (
        np.array([region.centroid for region in valid_regions])
        .round()
        .astype(int)
    )
    # cut a tile_size x tile_size candidate template around every centroid
    candidate_templates = [
        image[
            half_floor(x, tile_size) : half_ceil(x, tile_size),
            half_floor(y, tile_size) : half_ceil(y, tile_size),
        ]
        for x, y in centroids
    ]
    # average in float64: casting to int would truncate fractional
    # intensities, and accumulating in float16 could overflow
    mean_template = np.stack(candidate_templates).astype(float).mean(axis=0)
    # find traps using the mean trap template
    traps = identify_trap_locations(
        image, mean_template, **identify_traps_kwargs
    )
    # if there are too few traps, retry at full resolution with the same
    # settings and keep whichever attempt finds more traps
    if len(traps) < 30 and downscale != 1:
        print("Tiler:TrapIdentification: Trying again.")
        try:
            traps_retry = segment_traps(
                image,
                tile_size,
                downscale=1,
                disk_radius_frac=disk_radius_frac,
                square_size=square_size,
                min_frac_tilesize=min_frac_tilesize,
                **identify_traps_kwargs,
            )
        except AssertionError:
            # the full-resolution attempt found no regions at all;
            # the first attempt's traps are still the best available
            return traps
        if len(traps_retry) >= len(traps):
            return traps_retry
    return traps
def identify_trap_locations(
    image, trap_template, optimize_scale=True, downscale=0.35, trap_size=None
):
    """
    Identify the traps in a single image based on a trap template.

    Requires the trap template to be similar to the image
    (same camera, same magnification - ideally the same experiment).
    Uses normalised cross-correlation via scikit-image's match_template.
    The search is sped up by down-scaling both the image and
    the trap template before running the template matching.
    The trap template is rotated by multiples of 90 degrees and, optionally,
    re-scaled to improve matching; the best rotation and scale are those
    whose squared correlation map has the largest 99.9th percentile.

    Parameters
    ----------
    image: 2D array
    trap_template: 2D array
    optimize_scale : boolean (optional)
        If True, also try 10 template scales between 0.5 and 2.
    downscale: float (optional)
        Fraction by which to downscale to increase speed
    trap_size: integer (optional)
        If unspecified, the size is determined from the trap_template
        (its first dimension). Sets the minimum peak separation
        (70% of trap_size, at least 1) and the excluded border width
        (trap_size // 3).

    Returns
    -------
    coordinates: an array of pairs of integers
        Peak coordinates in the frame of the correlation map rescaled back
        by 1 / downscale (approximately the original image frame).
    """
    if trap_size is None:
        trap_size = trap_template.shape[0]
    # careful: the image may be float16. Work in float64, and convert the
    # template too: rescale would otherwise map an integer template into
    # [0, 1] while the image keeps its raw range, making the rotation fill
    # value below meaningless for the template
    img = transform.rescale(image.astype(float), downscale)
    template = transform.rescale(trap_template.astype(float), downscale)
    # fill value for the corners exposed when rotating the template
    background = np.median(img)

    def _score(correlation):
        # robust "best match" summary of a correlation map
        return np.percentile(correlation, 99.9)

    # try multiple rotations of template to determine
    # which best matches the image;
    # result is squared because the sign of the correlation is unimportant
    matches = {
        rotation: feature.match_template(
            img,
            transform.rotate(template, rotation, cval=background),
            pad_input=True,
            mode="median",
        )
        ** 2
        for rotation in (0, 90, 180, 270)
    }
    best_rotation = max(matches, key=lambda r: _score(matches[r]))
    if optimize_scale:
        # rotate template by best rotation
        template = transform.rotate(template, best_rotation, cval=background)
        # try multiple scales applied to template to determine which
        # best matches the image
        scales = np.linspace(0.5, 2, 10)
        scale_matches = {
            scale: feature.match_template(
                img,
                transform.rescale(template, scale),
                mode="median",
                pad_input=True,
            )
            ** 2
            for scale in scales
        }
        best_scale = max(scale_matches, key=lambda s: _score(scale_matches[s]))
        # the best result - an image of squared normalised correlations
        matched = scale_matches[best_scale]
    else:
        # the best-rotation correlation map was already computed (and
        # squared, consistently with the scale-optimised branch)
        matched = matches[best_rotation]
    # re-scale back the image of normalised correlations and
    # find the coordinates of local maxima
    coordinates = feature.peak_local_max(
        transform.rescale(matched, 1 / downscale),
        min_distance=max(1, int(trap_size * 0.70)),
        exclude_border=(trap_size // 3),
    )
    return coordinates