"""
Functions to process, filter and merge tracks.
"""
# from collections import Counter
import typing as t
from copy import copy
from typing import List, Union
import more_itertools as mit
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from utils_find_1st import cmp_larger, find_1st
from postprocessor.core.processes.savgol import non_uniform_savgol
def load_test_dset():
    """
    Build a small development dataframe for trying out the track functions.

    returns
    :: (4 x 10) dataframe whose rows are keyed by the tuples ("a", 1, 1) to
        ("a", 1, 4) and whose columns are the timepoints 1 to 10; NaN marks
        timepoints without a value.
    """
    timepoints = range(1, 11)
    values_by_track = {
        ("a", 1, 1): [2, 5, np.nan, 6, 8] + [np.nan] * 5,
        ("a", 1, 2): list(range(2, 12)),
        ("a", 1, 3): [np.nan] * 8 + [6, 7],
        ("a", 1, 4): [np.nan] * 5 + [9, 12, 10, 14, 18],
    }
    # Built with one column per track, then transposed so tracks are rows.
    by_timepoint = pd.DataFrame(values_by_track, index=timepoints)
    return by_timepoint.T
def max_ntps(track: pd.Series) -> int:
    """
    Get the span covered by a track, in timepoint positions.

    :param track: Series with one track; NaN marks missing timepoints
    returns
    :: position of the last non-NaN value minus the position of the first
        one (so a track with a single value spans 0)
    """
    present = np.where(track.notna())
    first = np.min(present)
    last = np.max(present)
    return last - first
def max_nonstop_ntps(track: pd.Series) -> int:
    """
    Get the length of the longest uninterrupted run of non-NaN values.

    :param track: Series with one track; NaN marks missing timepoints
    returns
    :: number of timepoints in the longest run of consecutive non-NaN
        values, or 0 when the track holds no values at all
    """
    present = np.flatnonzero(track.notna())
    if present.size == 0:
        return 0
    # A run ends wherever two neighbouring positions are not adjacent.
    breaks = np.flatnonzero(np.diff(present) != 1)
    # Bracket the break positions so np.diff yields every run length,
    # including the first and the last run.
    run_bounds = np.concatenate(([-1], breaks, [present.size - 1]))
    return int(np.diff(run_bounds).max())
def get_tracks_ntps(tracks: pd.DataFrame) -> pd.Series:
    """
    Get the span of every track in a dataframe.

    :param tracks: (m x n) dataframe where rows are cell tracks and
        columns are timepoints
    returns
    :: Series with the result of max_ntps for each row
    """
    spans = tracks.apply(max_ntps, axis="columns")
    return spans
def get_avg_gr(track: pd.Series) -> float:
    """
    Get average growth rate for a track.

    :param track: Series with volume and timepoints as indices
    returns
    :: difference between the last and first non-NaN values divided by
        the span returned by max_ntps; NaN when that span is zero (a
        track with a single value has no defined growth rate)
    """
    ntps = max_ntps(track)
    if ntps == 0:
        # Avoid a 0/0 division (and its RuntimeWarning).
        return np.nan
    vals = track.dropna().values
    return (vals[-1] - vals[0]) / ntps
def get_avg_grs(tracks: pd.DataFrame) -> pd.DataFrame:
    """
    Get the average growth rate of every track in a dataframe.

    :param tracks: (m x n) dataframe where rows are cell tracks and
        columns are timepoints
    returns
    :: result of applying get_avg_gr to each row
    """
    growth_rates = tracks.apply(get_avg_gr, axis="columns")
    return growth_rates
def clean_tracks(
    tracks, min_len: int = 15, min_gr: float = 1.0
) -> pd.DataFrame:
    """
    Drop short or non-growing tracks and return the remaining rows.

    :param tracks: (m x n) dataframe where rows are cell tracks and
        columns are timepoints
    :param min_len: int minimum span (as computed by get_tracks_ntps) a
        track must have to be kept
    :param min_gr: float mean growth rate a track must exceed to be kept
    returns
    :: dataframe with the rows that satisfy both conditions
    """
    long_enough = get_tracks_ntps(tracks) >= min_len
    growing = get_avg_grs(tracks) > min_gr
    return tracks.loc[long_enough & growing]
def merge_tracks(
    tracks, drop=False, **kwargs
) -> t.Tuple[pd.DataFrame, t.Collection]:
    """
    Join tracks that are contiguous and within a volume threshold of each
    other. A single pass is performed: get_joinable is called once and its
    pairs are applied with join_tracks.

    :param tracks: (m x n) dataframe where rows are cell tracks and
        columns are timepoints
    :param drop: bool forwarded to join_tracks; when True the rows that
        were merged into another track are removed from the result
    :param kwargs: args passed to get_joinable
    returns
    :joint_tracks: dataframe with the joined tracks; the input itself is
        returned when nothing could be joined
    :joinable_pairs: the pairs reported by get_joinable
    """
    joinable_pairs = get_joinable(tracks, **kwargs)
    if not joinable_pairs:
        return (tracks, joinable_pairs)
    joint_tracks = join_tracks(tracks, joinable_pairs, drop=drop)
    return (joint_tracks, joinable_pairs)
def get_joint_ids(merging_seqs) -> dict:
    """
    Convert a series of merges into a dictionary that maps every merged
    track id to the id of the track it finally ends up in.

    :param merging_seqs: list of (target, origin) tuples of indices
        indicating the sequence of merging events. It is important for this
        to be in sequential order
    returns
    :: dict {merged id: final destination id}. Ids that are never merged
        into another track (they would point to themselves) are left out.

    How it works:

    The order of merging matters for naming, always the leftmost track will
    keep the id.

    For example, having tracks (a, b, c, d) and the iterations of merge
    events:

    0 a b c d
    1 a b cd
    2 ab cd
    3 abcd

    We should get:

    output {b: a, c: a, d: a}
    """
    if not merging_seqs:
        return {}
    targets, origins = zip(*merging_seqs)
    # Tracks that receive merges but are never merged away are the roots.
    static_tracks = set(targets).difference(origins)
    joint = {track_id: track_id for track_id in static_tracks}
    for target, origin in merging_seqs:
        joint[origin] = target
    # Ids whose direct target was itself merged into something else need to
    # be redirected to the root of their chain.
    moved_target = [k for k, v in joint.items() if joint[v] != v]
    for orig in moved_target:
        joint[orig] = rec_bottom(joint, orig)
    # remove ids that point to themselves
    return {k: v for k, v in joint.items() if k != v}
def rec_bottom(d, k):
    """
    Follow the chain of references in a dictionary down to its root.

    :param d: dict where every value is also a key, and roots map to
        themselves
    :param k: key to start from
    returns
    :: the first key reached from k that maps to itself
    """
    current = k
    while d[current] != current:
        current = d[current]
    return current
def join_tracks(tracks, joinable_pairs, drop=True) -> pd.DataFrame:
    """
    Join pairs of tracks, applying the pairs in the order they are given.

    :param tracks: (m x n) dataframe where rows are cell tracks and
        columns are timepoints
    :param joinable_pairs: iterable of (target, source) row labels; each
        source is merged into its target with join_track_pair
    :param drop: bool indicating whether or not to drop moved rows
    returns (copy)
    :joint_tracks: dataframe where rows are cell tracks and columns are
        timepoints. Source rows are removed when drop is True and are left
        as they were otherwise.
    """
    joint_tracks = copy(tracks)
    for target, source in joinable_pairs:
        merged_row = join_track_pair(
            joint_tracks.loc[target], joint_tracks.loc[source]
        )
        joint_tracks.loc[target] = merged_row
        if drop:
            joint_tracks = joint_tracks.drop(source)
    return joint_tracks
def join_track_pair(target, source):
    """
    Fill the tail of a track with the values of the track that follows it.

    Everything after the last value of target that is larger than zero is
    replaced by the values source holds at the same positions.

    :param target: Series with the earlier track
    :param source: Series with the later track; assumed to have the same
        length as target (both are rows of the same dataframe)
    returns
    :: copy of target with its tail taken from source. When target has no
        trailing gap nothing is replaced; when target holds no value larger
        than zero the whole of source is taken.
    """
    tgt_copy = target.copy()
    positive = np.flatnonzero(target.values > 0)
    # Use a positive offset: a negative one breaks when the tail is empty,
    # because iloc[-0:] selects the whole series instead of nothing.
    start = int(positive[-1]) + 1 if positive.size else 0
    tgt_copy.iloc[start:] = source.iloc[start:].values
    return tgt_copy
def get_joinable(tracks, smooth=False, tol=0.1, window=5, degree=3) -> list:
    """
    Get the pairs of tracks (without repeats) that have a smaller error than
    the tolerance. If there is a track that can be assigned to two or more
    other ones, it chooses the one with the lowest error.

    :param tracks: (m x n) dataframe where rows are cell tracks and
        columns are timepoints; rows are grouped by the index level "trap"
    :param smooth: bool, when True tracks are first filtered with
        clean_tracks and smoothed with non_uniform_savgol
    :param tol: float or int threshold forwarded to get_dMetric and
        solve_matrices. solve_matrices treats values below 1 as a fraction
        of the preceding track's edge value and anything else as absolute
        units.
    :param window: int value of window used for savgol_filter; also the
        number of points taken from the end of the preceding track (for the
        linear prediction) and from the start of the following track (for
        the mean)
    :param degree: int value of polynomial degree passed to savgol_filter
    returns
    :: list of (preceding track id, following track id) pairs to be joined
    """
    # Only tracks with more than two values are considered.
    tracks = tracks.loc[tracks.notna().sum(axis=1) > 2]

    if smooth:  # Apply savgol filter TODO fix nans affecting edge placing
        # get useful tracks
        candidates = clean_tracks(tracks, min_len=window + 1, min_gr=0.9)

        def to_signal(x):
            return non_uniform_savgol(x.index, x.values, window, degree)

    else:
        candidates = tracks

        def to_signal(x):
            return np.array(x.values)

    contig = candidates.groupby(["trap"]).apply(get_contiguous_pairs)
    contig = contig.loc[contig.apply(len) > 0]
    flat = {k for v in contig.values for i in v for j in i for k in j}
    # .loc rejects set indexers on recent pandas, hence the list.
    smoothed_tracks = candidates.loc[list(flat)].apply(to_signal, axis=1)

    # fetch edges from ids TODO (IF necessary, here we can compare growth rates)
    def idx_to_edge(preposts):
        return [
            (
                [get_val(smoothed_tracks.loc[pre], -1) for pre in pres],
                [get_val(smoothed_tracks.loc[post], 0) for post in posts],
            )
            for pres, posts in preposts
        ]

    def idx_to_pred(preposts):
        result = []
        for pres, posts in preposts:
            pre_res = []
            for pre in pres:
                y = get_last_i(smoothed_tracks.loc[pre], -window)
                # Linear extrapolation of the end of the preceding track.
                # NOTE(review): evaluated at len(y) + 1 although the fitted
                # x values stop at len(y) - 1 -- confirm this is intended.
                pre_res.append(
                    np.poly1d(np.polyfit(range(len(y)), y, 1))(len(y) + 1),
                )
            pos_res = [
                get_means(smoothed_tracks.loc[post], window) for post in posts
            ]
            result.append([pre_res, pos_res])
        return result

    edges = contig.apply(idx_to_edge)  # Raw edges
    pre_pred = contig.apply(idx_to_pred)  # Prediction of pre and mean of post
    edges_dMetric_pred = pre_pred.apply(get_dMetric_wrap, tol=tol)

    # Costs come from the predictions; the raw edges are what
    # solve_matrices uses for orientation and normalisation.
    solutions = [
        solve_matrices_wrap(dMetrics, edgeset, tol=tol)
        for dMetrics, edgeset in zip(edges_dMetric_pred, edges)
    ]
    closest_pairs = pd.Series(
        solutions,
        index=edges_dMetric_pred.index,
    )

    # match local with global ids
    joinable_ids = [
        localid_to_idx(closest_pairs.loc[i], contig.loc[i])
        for i in closest_pairs.index
    ]
    return [pair for pairset in joinable_ids for pair in pairset]
def get_val(x, n):
    """
    Get the n-th value of an array once its NaNs are removed.

    :param x: array of floats
    :param n: int index into the NaN-free values (may be negative)
    returns
    :: the selected value, or NaN when x holds no non-NaN value
    """
    valid = x[~np.isnan(x)]
    if len(valid):
        return valid[n]
    return np.nan
def get_means(x, i):
    """
    Get the mean of the first or last values of an array, ignoring NaNs.

    :param x: array of floats
    :param i: int; if positive the first i non-NaN values are averaged,
        otherwise the non-NaN values from index i onwards (the last -i
        values for a negative i, all of them for 0)
    returns
    :: the mean, or NaN when x holds no non-NaN value
    """
    valid = x[~np.isnan(x)]
    if not len(valid):
        return np.nan
    selected = valid[:i] if i > 0 else valid[i:]
    return np.nanmean(selected)
def get_last_i(x, i):
    """
    Get the first or last values of an array once its NaNs are removed.

    :param x: array of floats
    :param i: int; if positive the first i non-NaN values are returned,
        otherwise the non-NaN values from index i onwards (the last -i
        values for a negative i, all of them for 0)
    returns
    :: the selected values, or a scalar NaN when x holds no non-NaN value
    """
    valid = x[~np.isnan(x)]
    if not len(valid):
        return np.nan
    return valid[:i] if i > 0 else valid[i:]
def localid_to_idx(local_ids, contig_trap):
    """Fetch the original ids from a nested list with joinable local_ids

    input
    :param local_ids: list of lists of (left, right) pairs with the local
        positions of the tracks to be joined
    :param contig_trap: list with, for every entry of local_ids, a
        (left ids, right ids) pair holding the corresponding track ids

    return
    list of pairs with (experiment-level) ids to be joined
    """
    return [
        (contig_trap[group][0][left], contig_trap[group][1][right])
        for group, pairs in enumerate(local_ids)
        for left, right in pairs
    ]
def get_vec_closest_pairs(lst: List, **kwargs):
    """
    Run get_closest_pairs on every entry of a list.

    :param lst: list whose entries are unpacked as the positional
        arguments of get_closest_pairs
    :param kwargs: keyword arguments forwarded to get_closest_pairs
    returns
    :: list with one result per entry
    """
    return [get_closest_pairs(*prepost, **kwargs) for prepost in lst]
def get_dMetric_wrap(lst: List, **kwargs):
    """
    Run get_dMetric on every entry of a list.

    :param lst: list whose entries are unpacked as the positional
        arguments of get_dMetric
    :param kwargs: keyword arguments forwarded to get_dMetric
    returns
    :: list with one cost matrix per entry
    """
    return [get_dMetric(*prepost, **kwargs) for prepost in lst]
def solve_matrices_wrap(dMetric: List, edges: List, **kwargs):
    """
    Run solve_matrices on matching entries of two lists.

    :param dMetric: list of cost matrices
    :param edges: list with the edge set that goes with each matrix
    :param kwargs: keyword arguments forwarded to solve_matrices
    returns
    :: list with one solution per (matrix, edge set) pair
    """
    paired = zip(dMetric, edges)
    return [
        solve_matrices(matrix, prepost, **kwargs) for matrix, prepost in paired
    ]
def get_dMetric(
    pre: List[float], post: List[float], tol: Union[float, int] = 1
):
    """Calculate a cost matrix

    input
    :param pre: list of floats with edges on left
    :param post: list of floats with edges on right
    :param tol: int or float if int metrics of tolerance, if float fraction

    returns
    :: 2-D array with the absolute difference between every pre and post
        value. Rows correspond to pre and columns to post, except when pre
        is longer than post, in which case the matrix is transposed. NaN
        entries are replaced by a value larger than tol and than any other
        entry.
    """
    if len(pre) > len(post):
        dMetric = np.abs(np.subtract.outer(post, pre))
    else:
        dMetric = np.abs(np.subtract.outer(pre, post))
    missing = np.isnan(dMetric)
    if missing.any():
        # With no valid entry at all there is no maximum to build on.
        largest = 0.0 if missing.all() else np.nanmax(dMetric)
        dMetric[missing] = tol + 1 + largest  # nans will be filtered
    return dMetric
def solve_matrices(
    dMetric: np.ndarray, prepost: List, tol: Union[float, int] = 1
):
    """
    Solve the distance matrices obtained in get_dMetric and/or merged from
    independent dMetric matrices

    :param dMetric: cost matrix laid out as get_dMetric does: rows are pre
        and columns are post, transposed when pre is longer than post
    :param prepost: (pre, post) pair of lists with the edge values the
        matrix refers to
    :param tol: int or float threshold. Below 1 the cost is divided by the
        matching pre value before being compared (relative tolerance),
        otherwise the cost is compared as it is (absolute tolerance)
    returns
    :: list of (pre index, post index) pairs whose cost is within tol
    """
    ids = solve_matrix(dMetric)
    if not len(ids[0]):
        return []
    pre, post = prepost
    # Mirror get_dMetric: the matrix is transposed only when pre is strictly
    # longer than post, so equal lengths keep the (pre, post) orientation.
    transposed = len(pre) > len(post)
    pre_ids, post_ids = ids[::-1] if transposed else ids
    norm = np.array(pre)[pre_ids] if tol < 1 else 1  # relative or absolute tol
    result = dMetric[ids] / norm
    return [
        idx for idx, res in zip(zip(pre_ids, post_ids), result) if res <= tol
    ]
def get_closest_pairs(
    pre: List[float], post: List[float], tol: Union[float, int] = 1
):
    """Calculate a cost matrix and solve it to pick the best set of options

    input
    :param pre: list of floats with edges on left
    :param post: list of floats with edges on right
    :param tol: int or float if int metrics of tolerance, if float fraction

    returns
    :: list of indices corresponding to the best solutions for matrices
    """
    dMetric = get_dMetric(pre, post, tol)
    # solve_matrices takes the edges as a single (pre, post) pair.
    return solve_matrices(dMetric, (pre, post), tol)
def solve_matrix(dMetric):
    """
    Solve cost matrix focusing on getting the smallest cost at each iteration.

    The smallest remaining entry is picked (the first one in row-major order
    on ties), its row and column are discarded, and the process repeats
    until no valid entry is left.

    input
    :param dMetric: 2-D np.array cost matrix; NaN entries are never picked.
        The input is not modified.

    returns
    tuple of np.arrays (row indices, column indices) indicating picks with
    lowest individual value, in the order they were picked
    """
    glob_is = []
    glob_js = []
    # Float copy: leaves the caller's matrix untouched and lets rows and
    # columns be blanked with NaN even for integer input.
    remaining = np.array(dMetric, dtype=float)
    while not np.isnan(remaining).all():
        i, j = np.unravel_index(np.nanargmin(remaining), remaining.shape)
        remaining[i, :] = np.nan
        remaining[:, j] = np.nan
        glob_is.append(i)
        glob_js.append(j)
    return (np.array(glob_is), np.array(glob_js))
def plot_joinable(tracks, joinable_pairs):
    """
    Convenience plotting function for debugging and data vis

    Draws up to 64 pairs on an 8 x 8 grid of axes, one pair per axis, with
    the first track of the pair in blue and the second in green, and then
    shows the figure.

    :param tracks: dataframe where rows are cell tracks and columns are
        timepoints
    :param joinable_pairs: sequence of (pre, post) row labels
    """
    nx = 8
    ny = 8
    _, axes = plt.subplots(nx, ny)
    # axes.flat walks the grid row by row, so the k-th pair is drawn on
    # axes[k // ny, k % ny]; pairs beyond the grid size are not drawn.
    for ax, (pre, post) in zip(axes.flat, joinable_pairs):
        pre_srs = tracks.loc[pre].dropna()
        post_srs = tracks.loc[post].dropna()
        ax.plot(pre_srs.index, pre_srs.values, "b")
        ax.plot(post_srs.index, post_srs.values, "g")
    plt.show()
def get_contiguous_pairs(tracks: pd.DataFrame) -> list:
    """
    Get all pairs of contiguous track ids from a tracks dataframe.

    Two tracks are contiguous when one holds its first value at the column
    position right after the one where the other holds its last value.

    :param tracks: (m x n) dataframe where rows are cell tracks and
        columns are timepoints
    returns
    :: list of (ids of tracks ending at a position, ids of tracks starting
        at the next position) tuples, ordered from the latest position to
        the earliest
    """
    # Column positions (not labels) of the non-NaN values of every track.
    positions = tracks.notna().apply(np.where, axis=1)
    mins = positions.apply(np.min)
    maxes = positions.apply(np.max)

    starting_at = mins.groupby(mins).apply(lambda x: x.index.tolist())
    # Shift by one so a start lines up with the end that precedes it.
    starting_at.index = starting_at.index - 1
    # TODO add support for skipping time points
    ending_at = maxes.groupby(maxes).apply(lambda x: x.index.tolist())

    shared = set(starting_at.index).intersection(ending_at.index)
    return [
        (ending_at[t], starting_at[t]) for t in sorted(shared, reverse=True)
    ]
# def fit_track(track: pd.Series, obj=None):
# if obj is None:
# obj = objective
# x = track.dropna().index
# y = track.dropna().values
# popt, _ = curve_fit(obj, x, y)
# return popt
# def interpolate(track, xs) -> list:
# '''
# Interpolate next timepoint from a track
# :param track: pd.Series of volume growth over a time period
# :param t: int timepoint to interpolate
# '''
# popt = fit_track(track)
# # perr = np.sqrt(np.diag(pcov))
# return objective(np.array(xs), *popt)
# def objective(x,a,b,c,d) -> float:
# # return (a)/(1+b*np.exp(c*x))+d
# return (((x+d)*a)/((x+d)+b))+c
# def cand_pairs_to_dict(candidates):
# d={x:[] for x,_ in candidates}
# for x,y in candidates:
# d[x].append(y)
# return d