# Source code for agora.io.bridge

"""
Tools to interact with h5 files and handle data consistently.
"""
import collections
import logging
import typing as t
from itertools import chain, groupby, product
from typing import Union

import h5py
import numpy as np
import yaml


class BridgeH5:
    """
    Base class to interact with h5 files.

    It includes functions that predict how long segmentation will take.
    """

    def __init__(self, filename, flag="r"):
        """
        Initialise with the name of the h5 file.

        Parameters
        ----------
        filename:
            Name of the h5 file; stored as `self.filename` and passed to
            `h5py.File`.
        flag: str or None
            Mode passed to `h5py.File`. If None, the file is not opened and
            `self._hdf` is not set.
        """
        self.filename = filename
        if flag is not None:
            self._hdf = h5py.File(filename, flag)
            # NOTE(review): the original code referenced `self._filecheck`
            # here without calling it, so no validation has ever run on
            # opening. Calling it would make every file lacking 'cell_info'
            # raise, so it is left disabled until the intended behaviour is
            # confirmed.

    def _log(self, message: str, level: str = "warn"):
        """
        Log a message, prefixed with the class name, on the "aliby" logger.

        Parameters
        ----------
        message: str
            Text to log.
        level: str
            Name of the logger method to use (e.g. "info", "warning").
        """
        # Logger.warn is a deprecated alias of Logger.warning; route the
        # default level through the supported method.
        if level == "warn":
            level = "warning"
        logger = logging.getLogger("aliby")
        getattr(logger, level)(f"{self.__class__.__name__}: {message}")

    def _filecheck(self):
        """Assert that the open h5 file contains a 'cell_info' entry."""
        assert "cell_info" in self._hdf, "Invalid file. No 'cell_info' found."

    def close(self):
        """Close the h5 file, if one was opened by the constructor."""
        # `_hdf` only exists when the constructor was given a non-None flag.
        hdf = getattr(self, "_hdf", None)
        if hdf is not None:
            hdf.close()

    @property
    def meta_h5(self) -> t.Dict[str, t.Any]:
        """Return metadata, defining it if necessary."""
        # The attributes are read once through a separate read-only handle
        # and cached on the instance.
        if not hasattr(self, "_meta_h5"):
            with h5py.File(self.filename, "r") as f:
                self._meta_h5 = dict(f.attrs)
        return self._meta_h5

    @property
    def cell_tree(self):
        """Return the tree produced by get_info_tree with default fields."""
        return self.get_info_tree()

    @staticmethod
    def get_consecutives(tree, nstepsback):
        """
        Find, for each top-level key, where consecutive elements occur.

        Parameters
        ----------
        tree: dict of dicts
            Sorted tree; the keys of each second-level dict are compared.
        nstepsback: int
            Maximum distance to consider.

        Returns
        -------
        List of length `nstepsback`. Element `n` is a dict mapping each
        top-level key to an array of positions `i` (indices into that
        key's second-level keys, not the keys themselves) for which the
        key at position `i + n + 1` is exactly `n + 1` larger than the
        key at position `i`.
        """
        # second-level keys of each branch, as arrays
        vals = {k: np.array(list(v)) for k, v in tree.items()}
        # positions whose partner n + 1 places later is n + 1 larger
        where_consec = [
            {
                k: np.where(np.subtract(v[n + 1 :], v[: -n - 1]) == n + 1)[0]
                for k, v in vals.items()
            }
            for n in range(nstepsback)
        ]
        return where_consec

    def get_npairs(self, nstepsback=2, tree=None):
        """
        Count the pairs of leaves between consecutive second-level keys.

        For every top-level key and every pair of second-level keys that
        `get_consecutives` reports as 1..`nstepsback` apart, add the
        product of the lengths of the two leaf lists.

        Parameters
        ----------
        nstepsback: int
            Maximum distance between the two keys of a pair.
        tree: dict, optional
            Depth-3 tree to use; defaults to `self.cell_tree`.

        Returns
        -------
        int
            Total number of pairs.
        """
        if tree is None:
            tree = self.cell_tree
        consecutive = self.get_consecutives(tree, nstepsback=nstepsback)
        n_predictions = 0
        for step, positions_by_key in enumerate(consecutive, 1):
            for key, positions in positions_by_key.items():
                branch = tree[key]
                # get_consecutives returns positions within the branch's
                # keys, so translate them back to the keys before lookup.
                branch_keys = list(branch)
                for pos in positions:
                    earlier = branch[branch_keys[pos]]
                    later = branch[branch_keys[pos + step]]
                    n_predictions += len(earlier) * len(later)
        return n_predictions

    def get_npairs_over_time(self, nstepsback=2):
        """
        Return the increase in the number of pairs per processed time point.

        For each value in 'cell_info/processed_timepoints' the cell tree is
        restricted to second-level keys up to that value and `get_npairs`
        is evaluated; the differences between successive counts are
        returned.

        Parameters
        ----------
        nstepsback: int
            Forwarded to `get_npairs`.

        Returns
        -------
        Array of differences, one element shorter than the number of
        processed time points.
        """
        tree = self.cell_tree
        npairs = []
        for tp in self._hdf["cell_info"]["processed_timepoints"][()]:
            tmp_tree = {
                k: {k2: v2 for k2, v2 in v.items() if k2 <= tp}
                for k, v in tree.items()
            }
            npairs.append(self.get_npairs(nstepsback=nstepsback, tree=tmp_tree))
        return np.diff(npairs)

    def get_info_tree(
        self, fields: Union[tuple, list] = ("trap", "timepoint", "cell_label")
    ):
        """
        Return traps, time points and labels for this position in the form of
        a tree in the hierarchy determined by the argument fields.

        Note that it is compressed to non-empty elements and timepoints.

        Default hierarchy is:
        - trap
         - time point
          - cell label

        This function currently produces trees of depth 3, but it can easily
        be extended for deeper trees if needed (e.g. considering groups,
        chambers and/or positions).

        Parameters
        ----------
        fields: list of strs
            Fields to fetch from 'cell_info' inside the h5 file.

        Returns
        ----------
        Nested dictionary where keys (or branches) are the upper levels and
        the leaves are the last element of :fields:.
        """
        # Read one column per field and transpose into one tuple per row.
        columns = [self._hdf["cell_info"][f][()] for f in fields]
        zipped_info = tuple(zip(*columns))
        return recursive_groupsort(zipped_info)
def groupsort(iterable: Union[tuple, list]):
    """
    Sort the items by their first element and group them by it.

    Returns a dictionary whose keys are the distinct first elements and
    whose values are lists holding the remainder (`item[1:]`) of every
    item that starts with that key.
    """

    def leading(item):
        return item[0]

    grouped = {}
    # groupby only merges adjacent items, hence the sort on the same key.
    for head, members in groupby(sorted(iterable, key=leading), leading):
        grouped[head] = [member[1:] for member in members]
    return grouped
def recursive_groupsort(iterable):
    """
    Recursive extension of groupsort.

    Items with more than one element are grouped by their first element
    (via groupsort) and the remainders are grouped again, producing nested
    dictionaries. Once the items are down to a single element, a list of
    those elements is returned as the leaf.

    An empty input returns an empty dictionary.
    """
    # Nothing to group; also avoids indexing into an empty sequence below.
    if len(iterable) == 0:
        return {}
    if len(iterable[0]) > 1:
        return {
            k: recursive_groupsort(v) for k, v in groupsort(iterable).items()
        }
    # each item holds a single element: unwrap it to form the leaf list
    return [x[0] for x in iterable]
def flatten(d, parent_key="", sep="_"):
    """
    Flatten nested dict. Adapted from https://stackoverflow.com/a/6027615.

    Parameters
    ----------
    d: dict
        Possibly nested dictionary.
    parent_key: tuple, optional
        Keys already traversed; used as prefix of every key in the result.
        The falsy default means "no prefix".
    sep: str
        Not used to build the keys; only passed on to recursive calls.

    Returns
    -------
    Dictionary whose keys are tuples with the path of keys leading to each
    non-mapping value. Empty nested mappings contribute no entries.
    """
    # collections.MutableMapping was removed in Python 3.10; the ABC lives
    # in collections.abc.
    from collections.abc import MutableMapping

    flat = {}
    for k, v in d.items():
        new_key = parent_key + (k,) if parent_key else (k,)
        if isinstance(v, MutableMapping):
            flat.update(flatten(v, new_key, sep=sep))
        else:
            flat[new_key] = v
    return flat
def attrs_from_h5(fpath: str):
    """Open the h5 file at fpath read-only and return its attributes as a dict."""
    with h5py.File(fpath, "r") as h5file:
        # copy while the file is still open
        attributes = dict(h5file.attrs)
    return attributes
def image_creds_from_h5(fpath: str):
    """
    Return image id and server credentials from an h5.

    The image id is the file's "image_id" attribute. The credentials are the
    ["general"]["server_info"] entry of the YAML held in its "parameters"
    attribute.
    """
    file_attrs = attrs_from_h5(fpath)
    image_id = file_attrs["image_id"]
    parameters = yaml.safe_load(file_attrs["parameters"])
    server_info = parameters["general"]["server_info"]
    return image_id, server_info