baby.training.training.BabyTrainer

class baby.training.training.BabyTrainer(save_dir, base_dir=None, params=None, max_cnns=3)

Bases: object

Manager to set up and train BABY models.

Parameters
  • save_dir – directory in which parameters and logs are saved, and from which they are auto-loaded

  • train_val_images – either a dict with keys ‘training’ and ‘validation’ and values specifying lists of file name pairs, or the name of a json file containing such a dict. The file name pairs should correspond to image-label pairs suitable for input to baby.generator.ImageLabel.

  • flattener – either a baby.preprocessing.SegmentationFlattening object, or the name of a JSON file containing a saved SegmentationFlattening object.
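
The dict layout described for train_val_images can be sketched with the standard library alone. This is a minimal illustration, not part of the BABY API: the helper build_train_val_spec, the file names, and the output file name train_val_pairs.json are all hypothetical, and the image and label formats accepted by baby.generator.ImageLabel are not stated here. How the resulting dict or file is handed to the trainer is also not shown, since it is not part of the constructor signature above.

```python
import json
import tempfile
from pathlib import Path


def build_train_val_spec(training_pairs, validation_pairs):
    """Build a dict with 'training' and 'validation' keys.

    Each value is a list of [image_file, label_file] pairs, matching the
    layout described for train_val_images. Hypothetical helper, not a
    function provided by BABY.
    """
    return {
        "training": [list(pair) for pair in training_pairs],
        "validation": [list(pair) for pair in validation_pairs],
    }


# Hypothetical file names; substitute real image-label pairs.
training_pairs = [
    ("images/pos001_tp0001.png", "labels/pos001_tp0001.png"),
    ("images/pos001_tp0002.png", "labels/pos001_tp0002.png"),
]
validation_pairs = [
    ("images/pos002_tp0001.png", "labels/pos002_tp0001.png"),
]

spec = build_train_val_spec(training_pairs, validation_pairs)

# Write the dict to a JSON file so that it can be referred to by name.
out_dir = Path(tempfile.mkdtemp())
spec_path = out_dir / "train_val_pairs.json"
with open(spec_path, "w") as f:
    json.dump(spec, f, indent=2)

# Read it back to confirm that the structure survives the round trip.
with open(spec_path) as f:
    reloaded = json.load(f)
```

Pairs are stored as two-element lists because JSON has no tuple type, so a dict built in memory and one loaded from the file compare equal.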

Attributes
aug
bud_trainer
cnn
cnn_dir
cnn_fn
cnn_name
cnn_opt
cnn_opt_dir
cnn_trainer
data
flattener
flattener_stats
flattener_trainer
gen
histories
hyperparameter_trainer
in_memory
parameters
seg_examples
seg_param_stats
seg_params
smoothing_sigma_model
smoothing_sigma_stats
smoothing_sigma_trainer
track_trainer
tracker_data

__init__(save_dir, base_dir=None, params=None, max_cnns=3)

Methods

__init__(save_dir[, base_dir, params, max_cnns])

fit_bud_model(**kwargs)

fit_cnn(**kwargs)

fit_flattener(**kwargs)

fit_seg_params([njobs, scoring])

fit_smoothing_model([filt])

generate_bud_stats()

generate_flattener_stats([max_erode])

generate_smoothing_sigma_stats()

plot_fitted_smoothing_sigma_model()

plot_flattener_stats(**kwargs)

plot_gen_sample([validation])

plot_histories(**kwargs)

refit_filter_seg_params([lazy, bootstrap, ...])

validate_seg_params([iou_thresh, save])
