import numpy as np
import pandas as pd
from scipy.ndimage import binary_fill_holes
from baby.io import load_tiled_image
from baby.tracker.core import CellTracker
[docs]class CellBenchmarker: # TODO Simplify this by inheritance
"""
Takes a metadata dataframe and a model and estimates the prediction in a trap-wise manner.
This class can also produce confusion matrices for a given Tracker and validation dataset.
"""
[docs] def __init__(self, meta, model, bak_model, nstepsback=None):
self.indices = ["experimentID", "position", "trap", "tp"]
self.cindices = self.indices + ["cellLabels"]
self.meta = meta.copy()
self.meta["cont_list_index"] = (*range(len(self.meta)),)
self.tracker = CellTracker(model=model, bak_model=bak_model)
if nstepsback is None:
self.nstepsback = self.tracker.nstepsback
self.traps_loc
@property
def traps_loc(self):
"""
Generates a list of trap locations using the metadata.
"""
if not hasattr(self, "_traps_loc"):
traps = np.unique(
[
ind[: self.indices.index("trap") + 1]
for ind in self.meta.index
],
axis=0,
)
# str->int conversion
traps = [(ind[0], *map(int, ind[1:])) for ind in traps]
self._traps_loc = (*map(tuple, traps),)
return self._traps_loc
@property
def masks(self):
if not hasattr(self, "_masks"):
self._masks = [
load_tiled_image(fname)[0] for fname in self.meta["filename"]
]
for i, mask in enumerate(self._masks):
for j in range(mask.shape[2]):
self._masks[i][..., j] = binary_fill_holes(
self._masks[i][..., j]
)
self._masks = [np.moveaxis(mask, 2, 0) for mask in self._masks]
return self._masks
def predict_lbls_from_tpimgs(self, tp_img_tuple):
max_lbl = 0
prev_feats = []
cell_lbls = []
for tp, masks in tp_img_tuple:
lastn_lbls = cell_lbls[-self.nstepsback :]
lastn_feats = prev_feats[-self.nstepsback :]
new_lbl, feats, max_lbl = self.tracker.get_new_lbls(
masks, lastn_lbls, lastn_feats, max_lbl
)
cell_lbls = cell_lbls + [new_lbl]
prev_feats = prev_feats + [feats]
return (tp, cell_lbls)
def df_get_imglist(self, exp, pos, trap, tp=None):
df = self.meta.loc[(exp, pos, trap), ["cont_list_index", "cellLabels"]]
return zip(df.index, [self.masks[i] for i in df["cont_list_index"]])
[docs] def predict_set(self, exp, pos, trap, tp=None):
"""
Predict labels using tp1-tp2 accuracy of prediction
"""
# print("Processing trap {}".format(exp, pos, trap))
tp_img_tuple = (*self.df_get_imglist(exp, pos, trap),)
tp, lbl_list = self.predict_lbls_from_tpimgs(tp_img_tuple)
# print("loc {}, {}, {}, labels: {}".format(exp, pos, trap, lbl_list))
return lbl_list
[docs] def compare_traps(self, exp, pos, trap):
"""
Error calculator for testing model and assignment heuristics.
Uses the trap id to compare the amount of cells correctly predicted.
This uses local indices, not whole timepoints. It returns the
fraction of cells correctly predicted, and the timepoints of mistakes
Returns:
float: Fraction of cells correctly predicted
list of 2-sized tuples: list of tp id of errors and the mistaken cell
"""
print("Processing trap {}, {}, {}".format(exp, pos, trap))
new_cids = self.predict_set(exp, pos, trap)
test_df = self.meta.loc(axis=0)[(exp, pos, trap)].copy()
test_df["pred_cellLabels"] = new_cids
orig = test_df["cellLabels"].values
new = test_df["pred_cellLabels"].values
local_indices = [[], []]
# Case just defines if it is the test or new set
# print("Making tp-wise comparison")
for i, case in enumerate(
(zip(orig[:-1], orig[1:]), zip(new[:-1], new[1:]))
):
for prev_cells, pos_cells in case:
local_assignment = [
prev_cells.index(cell) if cell in prev_cells else -1
for cell in pos_cells
]
local_indices[i] += local_assignment
# Flatten
if len(local_indices) > 2:
flt_test, flt_new = [
np.array([j for i in case for j in i])
for case in local_indices
]
tp_list = np.array(
[i for i, vals in enumerate(local_indices[0]) for j in vals]
)
else:
flt_test, flt_new = [
np.array([i for i in case]) for case in local_indices
]
# tp_list = np.array(
# [i for i, vals in enumerate(local_indices[0]) for j in vals])
correct = flt_test == flt_new
if len(local_indices) > 2:
error_list = tp_list[~correct]
error_cid = (
test_df.iloc[1:]["cellLabels"].explode().dropna()[~correct].values
)
frac_correct = np.mean(correct)
print("Fraction of correct predictions", frac_correct)
if len(local_indices) > 2:
return (frac_correct, list(zip(error_list, error_cid)))
else:
# print("Warning: Single set of tps for this position")
return (frac_correct, error_cid)
[docs] def predict_all(self):
"""
Predict all datasets defined in self.traps_loc
"""
stepsback = [2]
threshs = [0.9]
self.predictions = {}
for nstepsback in stepsback:
for thresh in threshs:
self.nstepsback = nstepsback
self.tracker.nstepsback = nstepsback
self.low_thresh = 1 - thresh
self.high_thresh = thresh
self.thresh = thresh * 5 / 8
for address in self.traps_loc:
self.predictions[
(nstepsback, thresh, address)
] = self.predict_set(*address)
[docs] def calculate_errsum(self):
"""
Calculate all errors, addresses of images with errors and error fractions.
"""
frac_errs = {}
all_errs = {}
nerrs = {}
stepsback = list(range(1, 3))
threshs = [0.95]
for nstepsback in stepsback:
for thresh in threshs:
self.nstepsback = nstepsback
self.tracker.nstepsback = nstepsback
self.low_thresh = 1 - thresh
self.high_thresh = thresh
self.thresh = thresh * 5 / 8
all_errs[(thresh, nstepsback)] = {}
frac_errs[(thresh, nstepsback)] = []
nerrs[(thresh, nstepsback)] = []
for address in self.traps_loc:
fraction, errors = self.compare_traps(*address)
if len(errors):
all_errs[(thresh, nstepsback)][address] = errors
frac_errs[(thresh, nstepsback)].append(fraction)
nerrs[(thresh, nstepsback)].append(len(errors))
else:
nerrs[(thresh, nstepsback)].append(0)
frac_errs[(thresh, nstepsback)].append(1.0)
return (frac_errs, all_errs, nerrs)
[docs] def get_truth_matrix_from_pair(self, pair):
"""
Requires self.meta
args:
:pair: tuple of size 4 (experimentID, position, trap (tp1, tp2))
returns
:truth_mat: boolean ndarray of shape (ncells(tp1) x ncells(tp2)
links cells in tp1 to cells in tp2
"""
clabs1 = self.meta.loc[pair[:3] + (pair[3][0],), "cellLabels"]
clabs2 = self.meta.loc[pair[:3] + (pair[3][1],), "cellLabels"]
truth_mat = gen_boolmat_from_clabs(clabs1, clabs2)
return truth_mat
def get_mota_stats(self, pair):
true_mat = self.get_truth_matrix_from_pair(pair)
prob_mat = self.tracker.predict_proba_from_ndarray(
ndarray, *args, **kwargs
)
pred_mat = prob_mat > thresh
true_flat = true_mat.flatten()
pred_flat = pred_mat.flatten()
true_pos = np.sum(true_flat & pred_flat)
false_pos = np.sum(true_flat & ~pred_flat)
# TODO add identity switch
[docs] def gen_cm_stats(self, pair, thresh=0.7, *args, **kwargs):
"""
Calculate confusion matrix for a pair of pos-timepoints
"""
masks = [self.masks[i] for i in self.meta.loc[pair, "cont_list_index"]]
feats = [self.tracker.calc_feats_from_mask(mask) for mask in masks]
ndarray = self.tracker.calc_feat_ndarray(*feats)
self.tracker.low_thresh = 1 - thresh
self.tracker.high_thresh = thresh
prob_mat = self.tracker.predict_proba_from_ndarray(
ndarray, *args, **kwargs
)
pred_mat = prob_mat > thresh
true_mat = self.get_truth_matrix_from_pair(pair)
if not len(true_mat) and not len(pred_mat):
return (0, 0, 0, 0)
true_flat = true_mat.flatten()
pred_flat = pred_mat.flatten()
true_pos = np.sum(true_flat & pred_flat)
false_pos = np.sum(true_flat & ~pred_flat)
false_neg = np.sum(~true_flat & pred_flat)
true_neg = np.sum(~true_flat & ~pred_flat)
return (true_pos, false_pos, false_neg, true_neg)
def extract_pairs_from_trap(self, trap_loc):
subdf = self.meta[["list_index", "cellLabels"]].loc(axis=0)[trap_loc]
pairs = [
trap_loc + tuple((pair,))
for pair in zip(subdf.index[:-1], subdf.index[1:])
]
return pairs
def gen_pairlist(self):
self.pairs = [
self.extract_pairs_from_trap(trap) for trap in self.traps_loc
]
def gen_cm_from_pairs(self, thresh=0.5, *args, **kwargs):
con_mat = {}
con_mat["tp"] = 0
con_mat["fp"] = 0
con_mat["fn"] = 0
con_mat["tn"] = 0
for pairset in self.pairs:
for pair in pairset:
res = self.gen_cm_stats(pair, thresh=thresh, *args, **kwargs)
con_mat["tp"] += res[0]
con_mat["fp"] += res[1]
con_mat["fn"] += res[2]
con_mat["tn"] += res[3]
self._con_mat = con_mat
return self._con_mat
[docs] def get_frac_error_df(self):
"""
Calculates the trap-wise error and averages across a position.
"""
self.frac_errs, self.all_errs, self.nerrs = self.calculate_errsum()
# nerrs_df = pd.DataFrame(self.nerrs).melt()
frac_df = pd.DataFrame(self.frac_errs).melt()
return frac_df
[docs] def gen_errorplots(self):
"""
Calculates the trap-wise error and averages across a position.
"""
frac_df = self.get_frac_error_df()
import seaborn as sns
from matplotlib import pyplot as plt
# ax = sns.barplot(x='variable_0', y='value', data=frac_df)
ax = sns.barplot(
x="variable_1", y="value", hue="variable_0", data=frac_df
)
ax.set(
xlabel="Backtrace depth",
ylabel="Fraction of correct assignments",
ylim=(0.9, 1),
)
plt.legend(title="Threshold")
plt.savefig("tracker_benchmark_btdepth.png")
plt.show()
# def plot_pair(self, address)
[docs]def gen_boolmat_from_clabs(clabs1, clabs2):
if not np.any(clabs1) and not np.any(clabs2):
return np.array([])
boolmat = np.zeros((len(clabs1), len(clabs2))).astype(bool)
for i, lab1 in enumerate(clabs1):
for j, lab2 in enumerate(clabs2):
if lab1 == lab2:
boolmat[i, j] = True
return boolmat
[docs]def gen_stats_dict(results):
"""
Generates a dictionary using results from different binary classification tasks,
for example, using different thresholds
output
dictionary containing the name of statistic as a key and a list
of that statistic for the data subsets.
"""
funs = (get_precision, get_recall, get_tnr, get_balanced_acc)
names = ("precision", "recall", "TNR", "balanced_acc")
stats_dict = {
name: [fun(res) for res in results] for fun, name in zip(funs, names)
}
return stats_dict
[docs]def get_precision(res_dict):
return (res_dict["tp"]) / (res_dict["tp"] + res_dict["fp"])
[docs]def get_recall(res_dict):
return res_dict["tp"] / (res_dict["tp"] + res_dict["fn"])
[docs]def get_tnr(res_dict):
return res_dict["tn"] / (res_dict["tn"] + res_dict["fp"])
[docs]def get_balanced_acc(res_dict):
return (get_recall(res_dict) + get_tnr(res_dict)) / 2