baby.training.cnn_trainer.CNNTrainer

class baby.training.cnn_trainer.CNNTrainer(save_dir: Path, cnn_set: tuple, gen: ImageLabel, aug: Augmenter, flattener: SegmentationFlattening, max_cnns: int = 3, cnn_fn: Optional[str] = None)

Bases: object

Methods for optimising the weights of the CNNs using gradient descent.

Attributes
cnn

The Keras Model for the active CNN.

cnn_dir

The subdirectory into which to save the current CNN output.

cnn_fn

The current CNN architecture function.

cnn_name

The name of the currently active CNN architecture.

histories

A dictionary mapping each CNN name to its training history.

opt_cnn

The Keras Model for the CNN with the lowest loss.

opt_dir

The directory in which the weights of the best CNN are saved.

Methods

fit([epochs, schedule, replace, extend])

Fit the active CNN to minimise loss on the (augmented) generator.

plot_histories([key, log, window, ax, save, ...])

Plot the loss/metric histories of all of the trained CNNs.

__init__(save_dir: Path, cnn_set: tuple, gen: ImageLabel, aug: Augmenter, flattener: SegmentationFlattening, max_cnns: int = 3, cnn_fn: Optional[str] = None)

Methods for optimising the weights of the CNNs using gradient descent.

Parameters
  • save_dir – base directory in which to save weights and outputs

  • cnn_set – the names of CNN architectures to be trained

  • gen – the data generator

  • aug – the data augmentor

  • flattener – the data flattener

  • max_cnns – the maximum number of CNNs to train/keep, default is 3

  • cnn_fn – the CNN architecture to start with, defaults to None


property cnn

The Keras Model for the active CNN.

property cnn_dir

The subdirectory into which to save the current CNN output.

property cnn_fn

The current CNN architecture function.

property cnn_name

The name of the currently active CNN architecture.

fit(epochs: int = 400, schedule: Optional[List[Tuple[int, float]]] = None, replace: bool = False, extend: bool = False)

Fit the active CNN to minimise loss on the (augmented) generator.

Parameters
  • epochs – number of epochs to train, defaults to 400

  • schedule – learning rate schedule, defaults to None, which trains at a fixed rate of 0.001

  • replace – force training even if the CNN already has a saved weights file, defaults to False

  • extend – use the existing CNN weights as the initial weights for training, defaults to False
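The type of schedule, List[Tuple[int, float]], suggests a list of (epoch, learning rate) pairs, but its exact semantics are not documented here. The following is a minimal sketch assuming that each pair sets the learning rate from that epoch onwards, with the documented fixed rate of 0.001 used otherwise. lr_for_epoch and DEFAULT_LR are hypothetical names, not part of baby.

```python
from typing import List, Optional, Tuple

# Fixed rate used when no schedule is given (as stated in the fit docs).
DEFAULT_LR = 0.001


def lr_for_epoch(epoch: int, schedule: Optional[List[Tuple[int, float]]] = None) -> float:
    """Return the learning rate for a 0-based epoch.

    ASSUMPTION (hypothetical reading of `schedule`): each (start_epoch, rate)
    pair sets the rate from start_epoch onwards; epochs before the first
    start epoch fall back to DEFAULT_LR.
    """
    if not schedule:
        return DEFAULT_LR
    rate = DEFAULT_LR
    # Walk the pairs in epoch order, keeping the last one that has started.
    for start, new_rate in sorted(schedule):
        if epoch >= start:
            rate = new_rate
        else:
            break
    return rate


if __name__ == "__main__":
    sched = [(0, 0.01), (100, 0.001)]
    print(lr_for_epoch(10, sched), lr_for_epoch(150, sched))
```

If this reading is right, the function could be handed to Keras as tf.keras.callbacks.LearningRateScheduler(lambda epoch, lr: lr_for_epoch(epoch, sched)); these docs do not confirm whether fit uses that callback internally.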

property histories

A dictionary mapping each CNN name to its training history.

property opt_cnn

The Keras Model for the CNN with the lowest loss.

property opt_dir

The directory in which the weights of the best CNN are saved.
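opt_cnn is described as the CNN with the lowest loss. A minimal sketch of that selection over histories follows. It assumes that each history is a Keras-style dictionary of per-epoch metric lists and that "lowest" means the minimum validation loss reached in any epoch. Both are assumptions, and best_cnn is a hypothetical helper, not part of baby.

```python
from typing import Dict, List


def best_cnn(histories: Dict[str, Dict[str, List[float]]], key: str = "val_loss") -> str:
    """Return the name of the CNN whose history reaches the lowest `key` value.

    ASSUMPTIONS: each history maps metric names to per-epoch values (as in
    keras.callbacks.History.history), and the best CNN is the one with the
    smallest value of `key` over all epochs.
    """
    if not histories:
        raise ValueError("no trained CNNs to compare")
    return min(histories, key=lambda name: min(histories[name][key]))


if __name__ == "__main__":
    # Hypothetical architecture names and made-up loss values for illustration.
    example = {
        "cnn_a": {"val_loss": [0.9, 0.5, 0.4]},
        "cnn_b": {"val_loss": [0.8, 0.35, 0.37]},
    }
    print(best_cnn(example))
```

Taking the minimum over all epochs (rather than the final epoch) matches the common practice of keeping the best checkpoint; which rule CNNTrainer applies is not stated in these docs.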

plot_histories(key: str = 'loss', log: bool = True, window: int = 21, ax: Optional[axis] = None, save: bool = True, legend: bool = True)

Plot the loss/metric histories of all of the trained CNNs.


Parameters
  • key – which metric to plot, defaults to 'loss'

  • log – set y-axis to log scale, defaults to True

  • window – window length of the Savitzky-Golay filter applied to the metric before plotting, defaults to 21, which assumes at least 22 epochs of training (50 suggested)

  • ax – axis on which to plot the losses, defaults to None

  • save – save the plot to save_dir / f"histories_{key}.png", defaults to True

  • legend – add a legend to the plot, defaults to True
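The window parameter controls how strongly the histories are smoothed. The following is a stdlib-only sketch of Savitzky-Golay smoothing using the closed-form weights for a quadratic fit over a centred window of 2m + 1 points. The polynomial order used by plot_histories is not stated here, so order 2 is an assumption, and savgol_quadratic is a hypothetical helper. For simplicity it returns only interior points that have a full window, so it needs at least window values. scipy.signal.savgol_filter also handles the edges.

```python
from typing import List, Sequence


def savgol_quadratic(values: Sequence[float], window: int = 21) -> List[float]:
    """Smooth `values` with a Savitzky-Golay filter of polynomial order 2.

    ASSUMPTION: order 2 is illustrative; the order used by plot_histories
    is not documented. Only points with a full centred window are returned,
    so the output has len(values) - window + 1 entries.
    """
    if window % 2 == 0 or window < 3:
        raise ValueError("window must be an odd integer >= 3")
    if len(values) < window:
        raise ValueError("need at least `window` values to smooth")
    m = window // 2
    # Closed-form convolution weights for a least-squares quadratic fit
    # evaluated at the window centre; they sum to 1.
    norm = (2 * m - 1) * (2 * m + 1) * (2 * m + 3)
    offsets = range(-m, m + 1)
    weights = [(3 * (3 * m * m + 3 * m - 1) - 15 * i * i) / norm for i in offsets]
    return [
        sum(w * values[centre + i] for w, i in zip(weights, offsets))
        for centre in range(m, len(values) - m)
    ]


if __name__ == "__main__":
    noisy_loss = [1.0 / (t + 1) + (0.02 if t % 2 else -0.02) for t in range(50)]
    print(savgol_quadratic(noisy_loss, window=21)[:3])
```

Because the filter fits a local polynomial rather than averaging, it preserves trends such as the initial steep drop in loss better than a moving average of the same window length.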