"""
Pipeline and chaining elements.
"""
import logging
import os
import re
import traceback
import typing as t
from copy import copy
from importlib.metadata import version
from pathlib import Path, PosixPath
import h5py
import numpy as np
import pandas as pd
from pathos.multiprocessing import Pool
from tqdm import tqdm
from agora.abc import ParametersABC, ProcessABC
from agora.io.metadata import MetaData, parse_logfiles
from agora.io.reader import StateReader
from agora.io.signal import Signal
from agora.io.writer import (
LinearBabyWriter,
StateWriter,
TilerWriter,
)
from aliby.baby_client import BabyParameters, BabyRunner
from aliby.haystack import initialise_tf
from aliby.io.dataset import dispatch_dataset
from aliby.io.image import get_image_class
from aliby.tile.tiler import Tiler, TilerParameters
from extraction.core.extractor import Extractor, ExtractorParameters
from extraction.core.functions.defaults import exparams_from_meta
from postprocessor.core.processor import PostProcessor, PostProcessorParameters
class PipelineParameters(ParametersABC):
    """
    Parameters that define what is run and how.

    Pass a dictionary of parameters for each step of the pipeline; the
    "general" dictionary holds settings for the pipeline as a whole:
        expt_id: int or str
            Experiment id (if an integer) or local path (if a string).
        directory: str
            Directory into which results are written. Default is "../data".

    The default constructor downloads the logfiles and uses them to set
    the default time points and extraction parameters.
    """
_pool_index = None
    def __init__(
self, general, tiler, baby, extraction, postprocessing, reporting
):
self.general = general
self.tiler = tiler
self.baby = baby
self.extraction = extraction
self.postprocessing = postprocessing
self.reporting = reporting
    @classmethod
    def default(
        cls,
        general={},
        tiler={},
        baby={},
        extraction={},
        postprocessing={},
    ):
        """Build default parameters, overridden per step by the arguments."""
        # Copy to avoid mutating the caller's (or the shared default) dict
        general = dict(general)
        expt_id = general.get("expt_id", 19993)
        if isinstance(expt_id, PosixPath):
            expt_id = str(expt_id)
        general["expt_id"] = expt_id
directory = Path(general.get("directory", "../data"))
with dispatch_dataset(
expt_id,
**{k: general.get(k) for k in ("host", "username", "password")},
) as conn:
directory = directory / conn.unique_name
if not directory.exists():
directory.mkdir(parents=True)
# Download logs to use for metadata
conn.cache_logs(directory)
try:
meta_d = MetaData(directory, None).load_logs()
except Exception as e:
logging.getLogger("aliby").warn(
f"WARNING:Metadata: error when loading: {e}"
)
minimal_default_meta = {
"channels": ["Brightfield"],
"ntps": [2000],
}
# Set minimal metadata
meta_d = minimal_default_meta
tps = meta_d.get("ntps", 2000)
defaults = {
"general": dict(
id=expt_id,
distributed=0,
tps=tps,
directory=str(directory.parent),
filter="",
earlystop=dict(
min_tp=100,
thresh_pos_clogged=0.4,
thresh_trap_ncells=8,
thresh_trap_area=0.9,
ntps_to_eval=5,
),
logfile_level="INFO",
use_explog=True,
)
}
        # Merge user-supplied general parameters into the defaults,
        # descending one level into nested dictionaries such as earlystop
        for k, v in general.items():
if k not in defaults["general"]:
defaults["general"][k] = v
elif isinstance(v, dict):
for k2, v2 in v.items():
defaults["general"][k][k2] = v2
else:
defaults["general"][k] = v
defaults["tiler"] = TilerParameters.default(**tiler).to_dict()
defaults["baby"] = BabyParameters.default(**baby).to_dict()
defaults["extraction"] = (
exparams_from_meta(meta_d)
or BabyParameters.default(**extraction).to_dict()
)
defaults["postprocessing"] = {}
defaults["reporting"] = {}
defaults["postprocessing"] = PostProcessorParameters.default(
**postprocessing
).to_dict()
defaults["reporting"] = {}
        return cls(**defaults)
    def load_logs(self):
        """Parse and flatten the experiment's logfiles."""
        return parse_logfiles(self.log_dir)
class Pipeline(ProcessABC):
    """
    A chained set of pipeline elements connected through pipes.

    Tiling, segmentation, extraction and post-processing use their own
    default parameters, any of which can be overridden by passing
    key:value pairs to a PipelineParameters instance.
    """
iterative_steps = ["tiler", "baby", "extraction"]
step_sequence = [
"tiler",
"baby",
"extraction",
"postprocessing",
]
# Indicate step-writer groupings to perform special operations during step iteration
writer_groups = {
"tiler": ["trap_info"],
"baby": ["cell_info"],
"extraction": ["extraction"],
"postprocessing": ["postprocessing", "modifiers"],
}
writers = { # TODO integrate Extractor and PostProcessing in here
"tiler": [("tiler", TilerWriter)],
"baby": [("baby", LinearBabyWriter), ("state", StateWriter)],
}
    def __init__(self, parameters: PipelineParameters, store=None):
super().__init__(parameters)
if store is not None:
store = Path(store)
self.store = store
@staticmethod
def setLogger(
folder, file_level: str = "INFO", stream_level: str = "WARNING"
):
logger = logging.getLogger("aliby")
logger.setLevel(getattr(logging, file_level))
formatter = logging.Formatter(
"%(asctime)s - %(levelname)s:%(message)s",
datefmt="%Y-%m-%dT%H:%M:%S%z",
)
ch = logging.StreamHandler()
ch.setLevel(getattr(logging, stream_level))
ch.setFormatter(formatter)
logger.addHandler(ch)
        # file handler that logs at file_level
fh = logging.FileHandler(Path(folder) / "aliby.log", "w+")
fh.setLevel(getattr(logging, file_level))
fh.setFormatter(formatter)
logger.addHandler(fh)
@classmethod
def from_yaml(cls, fpath):
# This is just a convenience function, think before implementing
# for other processes
return cls(parameters=PipelineParameters.from_yaml(fpath))
    @classmethod
    def from_folder(cls, dir_path):
"""
Constructor to re-process all files in a given folder.
Assumes all files share the same parameters (even if they don't share
the same channel set).
Parameters
---------
dir_path : str or Pathlib indicating the folder containing the files to process
"""
dir_path = Path(dir_path)
files = list(dir_path.rglob("*.h5"))
assert len(files), "No valid files found in folder"
fpath = files[0]
# TODO add support for non-standard unique folder names
with h5py.File(fpath, "r") as f:
pipeline_parameters = PipelineParameters.from_yaml(
f.attrs["parameters"]
)
pipeline_parameters.general["directory"] = dir_path.parent
pipeline_parameters.general["filter"] = [fpath.stem for fpath in files]
# Fix legacy postprocessing parameters
post_process_params = pipeline_parameters.postprocessing.get(
"parameters", None
)
if post_process_params:
pipeline_parameters.postprocessing["param_sets"] = copy(
post_process_params
)
del pipeline_parameters.postprocessing["parameters"]
return cls(pipeline_parameters)
    @classmethod
    def from_existing_h5(cls, fpath):
        """
        Re-process an existing HDF5 file.

        Note that this forces a single file, so it is not suitable for
        multiprocessing across positions. It is also used as the basis
        for folder-wide reprocessing.
        """
with h5py.File(fpath, "r") as f:
pipeline_parameters = PipelineParameters.from_yaml(
f.attrs["parameters"]
)
directory = Path(fpath).parent
pipeline_parameters.general["directory"] = directory
pipeline_parameters.general["filter"] = Path(fpath).stem
post_process_params = pipeline_parameters.postprocessing.get(
"parameters", None
)
if post_process_params:
pipeline_parameters.postprocessing["param_sets"] = copy(
post_process_params
)
del pipeline_parameters.postprocessing["parameters"]
return cls(pipeline_parameters, store=directory)
@property
def _logger(self):
return logging.getLogger("aliby")
    def run(self):
        """
        Run the pipeline.

        The "general" config holds experiment-wide information; the
        remaining configs hold the tasks to run for each position.
        """
config = self.parameters.to_dict()
expt_id = config["general"]["id"]
distributed = config["general"]["distributed"]
pos_filter = config["general"]["filter"]
root_dir = Path(config["general"]["directory"])
self.server_info = {
k: config["general"].get(k)
for k in ("host", "username", "password")
}
dispatcher = dispatch_dataset(expt_id, **self.server_info)
logging.getLogger("aliby").info(
f"Fetching data using {dispatcher.__class__.__name__}"
)
        # Perform all initialisations
with dispatcher as conn:
image_ids = conn.get_images()
directory = self.store or root_dir / conn.unique_name
if not directory.exists():
directory.mkdir(parents=True)
# Download logs to use for metadata
conn.cache_logs(directory)
            # Record the directory in the configuration
self.parameters.general["directory"] = str(directory)
config["general"]["directory"] = directory
self.setLogger(directory)
        # Filter positions. TODO: integrate the filter into the class and add regex
def filt_int(d: dict, filt: int):
return {k: v for i, (k, v) in enumerate(d.items()) if i == filt}
def filt_str(image_ids: dict, filt: str):
return {k: v for k, v in image_ids.items() if re.search(filt, k)}
        def pick_filter(image_ids: dict, filt: t.Union[int, str]):
if isinstance(filt, str):
image_ids = filt_str(image_ids, filt)
elif isinstance(filt, int):
image_ids = filt_int(image_ids, filt)
return image_ids
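        # A filter may be an int (keep only the i-th position), a regex
        # string matched against position names, or a list mixing both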
if isinstance(pos_filter, list):
image_ids = {
k: v
for filt in pos_filter
for k, v in pick_filter(image_ids, filt).items()
}
else:
image_ids = pick_filter(image_ids, pos_filter)
assert len(image_ids), "No images to segment"
if distributed != 0: # Gives the number of simultaneous processes
with Pool(distributed) as p:
results = p.map(
lambda x: self.create_pipeline(*x),
[(k, i) for i, k in enumerate(image_ids.items())],
# num_cpus=distributed,
# position=0,
)
else: # Sequential
results = []
for k, v in tqdm(image_ids.items()):
r = self.create_pipeline((k, v), 1)
results.append(r)
return results
    def create_pipeline(
        self,
        image_id: t.Tuple[str, t.Union[str, PosixPath, int]],
        index: t.Optional[int] = None,
    ):
        """Run one position through the complete pipeline."""
        self._pool_index = index
        name, image_id = image_id
        session = None
        filename = None
        run_kwargs = {"extraction": {"labels": None, "masks": None}}
try:
(
filename,
meta,
config,
process_from,
tps,
steps,
earlystop,
session,
trackers_state,
) = self._setup_pipeline(image_id)
loaded_writers = {
name: writer(filename)
for k in self.step_sequence
if k in self.writers
for name, writer in self.writers[k]
}
writer_ow_kwargs = {
"state": loaded_writers["state"].datatypes.keys(),
"baby": ["mother_assign"],
}
# START PIPELINE
frac_clogged_traps = 0
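            # Start from the earliest time point that any step still
            # needs to process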
min_process_from = min(process_from.values())
with get_image_class(image_id)(
image_id, **self.server_info
) as image:
# Initialise Steps
if "tiler" not in steps:
steps["tiler"] = Tiler.from_image(
image, TilerParameters.from_dict(config["tiler"])
)
if process_from["baby"] < tps:
session = initialise_tf(2)
steps["baby"] = BabyRunner.from_tiler(
BabyParameters.from_dict(config["baby"]),
steps["tiler"],
)
if trackers_state:
steps["baby"].crawler.tracker_states = trackers_state
# Limit extraction parameters during run using the available channels in tiler
if process_from["extraction"] < tps:
# TODO Move this parameter validation into Extractor
av_channels = set((*steps["tiler"].channels, "general"))
config["extraction"]["tree"] = {
k: v
for k, v in config["extraction"]["tree"].items()
if k in av_channels
}
config["extraction"]["sub_bg"] = av_channels.intersection(
config["extraction"]["sub_bg"]
)
av_channels_wsub = av_channels.union(
[c + "_bgsub" for c in config["extraction"]["sub_bg"]]
)
tmp = copy(config["extraction"]["multichannel_ops"])
for op, (input_ch, _, _) in tmp.items():
if not set(input_ch).issubset(av_channels_wsub):
del config["extraction"]["multichannel_ops"][op]
exparams = ExtractorParameters.from_dict(
config["extraction"]
)
steps["extraction"] = Extractor.from_tiler(
exparams, store=filename, tiler=steps["tiler"]
)
pbar = tqdm(
range(min_process_from, tps),
desc=image.name,
initial=min_process_from,
total=tps,
# position=index + 1,
)
for i in pbar:
if (
frac_clogged_traps
< earlystop["thresh_pos_clogged"]
or i < earlystop["min_tp"]
):
for step in self.iterative_steps:
if i >= process_from[step]:
result = steps[step].run_tp(
i, **run_kwargs.get(step, {})
)
if step in loaded_writers:
loaded_writers[step].write(
data=result,
overwrite=writer_ow_kwargs.get(
step, []
),
tp=i,
meta={"last_processed": i},
)
# Step-specific actions
if (
step == "tiler"
and i == min_process_from
):
logging.getLogger("aliby").info(
f"Found {steps['tiler'].n_traps} traps in {image.name}"
)
elif (
step == "baby"
): # Write state and pass info to ext
loaded_writers["state"].write(
data=steps[
step
].crawler.tracker_states,
overwrite=loaded_writers[
"state"
].datatypes.keys(),
tp=i,
)
elif (
step == "extraction"
): # Remove mask/label after ext
for k in ["masks", "labels"]:
run_kwargs[step][k] = None
frac_clogged_traps = self.check_earlystop(
filename, earlystop, steps["tiler"].tile_size
)
self._log(
f"{name}:Clogged_traps:{frac_clogged_traps}"
)
frac = np.round(frac_clogged_traps * 100)
pbar.set_postfix_str(f"{frac} Clogged")
                    else:  # Stop early if the fraction of clogged traps exceeds the threshold
self._log(
f"{name}:Analysis stopped early at time {i} with {frac_clogged_traps} clogged traps"
)
meta.add_fields({"end_status": "Clogged"})
break
meta.add_fields({"last_processed": i})
# Run post-processing
meta.add_fields({"end_status": "Success"})
post_proc_params = PostProcessorParameters.from_dict(
config["postprocessing"]
)
PostProcessor(filename, post_proc_params).run()
self._log("Analysis finished successfully.", "info")
return 1
        except Exception:  # Catch bugs during setup or runtime
            logging.exception(f"{name}: Exception caught.")
            # Print the type, value, and stack trace of the
            # current exception being handled
            traceback.print_exc()
            raise
finally:
_close_session(session)
@staticmethod
def check_earlystop(filename: str, es_parameters: dict, tile_size: int):
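        """
        Estimate the fraction of clogged traps over the last
        ntps_to_eval time points.

        A trap counts as clogged when its mean cell count exceeds
        thresh_trap_ncells and the mean total cell area exceeds
        thresh_trap_area of the tile area.
        """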
s = Signal(filename)
df = s["/extraction/general/None/area"]
cells_used = df[
df.columns[-1 - es_parameters["ntps_to_eval"] : -1]
].dropna(how="all")
traps_above_nthresh = (
cells_used.groupby("trap").count().apply(np.mean, axis=1)
> es_parameters["thresh_trap_ncells"]
)
traps_above_athresh = (
cells_used.groupby("trap").sum().apply(np.mean, axis=1)
/ tile_size**2
> es_parameters["thresh_trap_area"]
)
return (traps_above_nthresh & traps_above_athresh).mean()
def _load_config_from_file(
self,
filename: PosixPath,
process_from: t.Dict[str, int],
trackers_state: t.List,
overwrite: t.Dict[str, bool],
):
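        """
        Inspect an existing HDF5 file to find, for each step not marked
        for overwriting, the first time point still to be processed.
        """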
with h5py.File(filename, "r") as f:
for k in process_from.keys():
if not overwrite[k]:
                    process_from[k] = self.legacy_get_last_tp(k)(f)
process_from[k] += 1
return process_from, trackers_state, overwrite
    @staticmethod
    def legacy_get_last_tp(step: str) -> t.Callable:
        """
        Get the last time point in a way that depends on the step.

        Supports segmentation from aliby < v0.24.
        TODO Deprecate and replace with State method.
        """
switch_case = {
"tiler": lambda f: f["trap_info/drifts"].shape[0] - 1,
"baby": lambda f: f["cell_info/timepoint"][-1],
"extraction": lambda f: f[
"extraction/general/None/area/timepoint"
][-1],
}
return switch_case[step]
    def _setup_pipeline(
        self, image_id: int
    ) -> t.Tuple[
        PosixPath,
        MetaData,
        t.Dict,
        t.Dict,
        int,
        t.Dict,
        t.Dict,
        t.Optional[t.Any],
        t.List[np.ndarray],
    ]:
        """
        Initialise pipeline components and, if necessary, use an
        existing file to continue an existing experiment.

        Parameters
        ----------
        image_id : int
            Identifier of an image in an OMERO server, or a filename.

        Returns
        -------
        filename: PosixPath
            Path of the h5 file in which results are stored.
        meta: MetaData
            Metadata of the experiment.
        config: dict
            Pipeline parameters as a dictionary.
        process_from: dict
            First time point to process, for each step.
        tps: int
            Number of time points to process.
        steps: dict
            Initialised pipeline steps.
        earlystop: dict
            Parameters for stopping the analysis early.
        session:
            TensorFlow session, if initialised.
        trackers_state: list
            States of the BABY trackers.
        """
        config = self.parameters.to_dict()
        pparams = config
        general_config = config["general"]
        session = None
        earlystop = general_config.get("earlystop", None)
        process_from = {k: 0 for k in self.iterative_steps}
        steps = {}
# check overwriting
ow_id = general_config.get("overwrite", 0)
ow = {step: True for step in self.step_sequence}
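        # By default all steps are overwritten; if overwrite names a single
        # step, only that step and the ones after it in the sequence are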
if ow_id and ow_id is not True:
ow = {
step: self.step_sequence.index(ow_id) < i
for i, step in enumerate(self.step_sequence, 1)
}
# Set up
directory = general_config["directory"]
trackers_state: t.List[np.ndarray] = []
with get_image_class(image_id)(image_id, **self.server_info) as image:
filename = Path(f"{directory}/{image.name}.h5")
meta = MetaData(directory, filename)
            from_start = any(ow.values())
            # New experiment, or overwriting everything
            if (
                from_start
                and (
                    config.get("overwrite", False) is True
                    or all(ow.values())
                )
                and filename.exists()
            ):
os.remove(filename)
            # Continue from an existing file, re-using the tiler when possible
if filename.exists():
self._log("Result file exists.", "info")
if not ow["tiler"]:
steps["tiler"] = Tiler.from_hdf5(image, filename)
try:
(
process_from,
trackers_state,
ow,
) = self._load_config_from_file(
filename, process_from, trackers_state, ow
)
# get state array
trackers_state = (
[]
if ow["baby"]
else StateReader(filename).get_formatted_states()
)
config["tiler"] = steps["tiler"].parameters.to_dict()
                    except Exception:
                        self._log(
                            "Could not restore the previous processing "
                            "state; starting from scratch.",
                            "warning",
                        )
if config["general"]["use_explog"]:
meta.run()
meta.add_fields( # Add non-logfile metadata
{
"aliby_version": version("aliby"),
"baby_version": version("aliby-baby"),
"omero_id": config["general"]["id"],
"image_id": image_id
if isinstance(image_id, int)
else str(image_id),
"parameters": PipelineParameters.from_dict(
pparams
).to_yaml(),
}
)
tps = min(general_config["tps"], image.data.shape[0])
return (
filename,
meta,
config,
process_from,
tps,
steps,
earlystop,
session,
trackers_state,
)
def _close_session(session):
    """Close the TensorFlow session, if one was initialised."""
    if session:
        session.close()
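# A minimal end-to-end sketch (hypothetical path): run the whole pipeline on
# a local dataset with default parameters:
#   params = PipelineParameters.default(general={"expt_id": "/data/expt"})
#   Pipeline(params).run()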