class Path, cnn_set: tuple, gen: ImageLabel, aug: Augmenter, flattener: SegmentationFlattening, max_cnns: int = 3, cnn_fn: Optional[str] = None)

Bases: object

Methods for optimising the weights of the CNNs using gradient descent.


cnn

The keras Model for the active CNN.

cnn_dir

The subdirectory into which to save the current CNN output.

cnn_fn

The current CNN architecture function.

cnn_name

The name of the currently active CNN architecture.

histories

A CNN name to training history dictionary.

opt_cnn

The keras model for the CNN with the lowest loss.

opt_dir

The directory in which the weights of the best CNN are saved.


fit([epochs, schedule, replace, extend])

Fit the active CNN to minimise loss on the (augmented) generator.

plot_histories([key, log, window, ax, save, ...])

Plot the loss/metric histories of all of the trained CNNs.

__init__(save_dir: Path, cnn_set: tuple, gen: ImageLabel, aug: Augmenter, flattener: SegmentationFlattening, max_cnns: int = 3, cnn_fn: Optional[str] = None)

Methods for optimising the weights of the CNNs using gradient descent.

  • save_dir – base directory in which to save weights and outputs

  • cnn_set – the names of CNN architectures to be trained

  • gen – the data generator

  • aug – the data augmenter

  • flattener – the data flattener

  • max_cnns – the maximum number of CNNs to train/keep, default is 3

  • cnn_fn – the CNN architecture to start with, defaults to None


property cnn

The keras Model for the active CNN.

property cnn_dir

The subdirectory into which to save the current CNN output.

property cnn_fn

The current CNN architecture function.

property cnn_name

The name of the currently active CNN architecture.

fit(epochs: int = 400, schedule: Optional[List[Tuple[int, float]]] = None, replace: bool = False, extend: bool = False)

Fit the active CNN to minimise loss on the (augmented) generator.

  • epochs – number of epochs to train, defaults to 400

  • schedule – learning rate schedule, defaults to fixed rate 0.001

  • replace – force training if CNN already has a weights file

  • extend – train using existing CNN weights as initial weights
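The `schedule` argument is typed as a list of `(int, float)` tuples, but the reference does not spell out how each tuple is read. The sketch below shows one plausible interpretation, stated as an assumption rather than the library's behaviour: each tuple is `(boundary_epoch, learning_rate)`, and a rate applies to every epoch below its boundary. The helper `rate_for_epoch` and the values in `example_schedule` are hypothetical.

```python
from typing import List, Tuple


def rate_for_epoch(epoch: int, schedule: List[Tuple[int, float]]) -> float:
    """Return the learning rate for a zero-based epoch index.

    Assumption (not confirmed by the reference): each (boundary, rate)
    pair applies `rate` to all epochs strictly below `boundary`, and the
    last rate is reused once every boundary has been passed.
    """
    for boundary, rate in schedule:
        if epoch < boundary:
            return rate
    return schedule[-1][1]


# Hypothetical schedule: 0.001 for epochs 0-99, then 0.0001 up to epoch 399.
example_schedule = [(100, 1e-3), (400, 1e-4)]

early_rate = rate_for_epoch(0, example_schedule)
late_rate = rate_for_epoch(250, example_schedule)
```

Under this reading, a single-entry list such as `[(400, 0.001)]` would reproduce the documented default of a fixed rate of 0.001 over 400 epochs.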

property histories

A CNN name to training history dictionary.

property opt_cnn

The keras model for the CNN with the lowest loss.
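To illustrate how a "lowest loss" selection could work on a name-to-history dictionary like `histories`, here is a minimal sketch. It assumes each history is a Keras-style mapping from metric name to a per-epoch list of values, and that the comparison uses the minimum of `val_loss`; both the layout and the choice of metric are assumptions, and `best_cnn_name` and the architecture names are hypothetical.

```python
from typing import Dict, List


def best_cnn_name(histories: Dict[str, Dict[str, List[float]]],
                  key: str = "val_loss") -> str:
    """Return the name whose history reaches the lowest value of `key`.

    Assumption: histories[name][key] is a per-epoch list of floats.
    """
    return min(histories, key=lambda name: min(histories[name][key]))


# Made-up histories for two hypothetical architecture names.
example_histories = {
    "arch_a": {"val_loss": [0.9, 0.5, 0.4]},
    "arch_b": {"val_loss": [0.8, 0.3, 0.35]},
}

best = best_cnn_name(example_histories)
```

Here `arch_b` is chosen because its best epoch (0.3) beats the best epoch of `arch_a` (0.4), even though its final epoch is slightly worse than its own minimum.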

property opt_dir

The directory in which the weights of the best CNN are saved.

plot_histories(key: str = 'loss', log: bool = True, window: int = 21, ax: Optional[axis] = None, save: bool = True, legend: bool = True)

Plot the loss/metric histories of all of the trained CNNs.

  • key – which metric to plot, defaults to loss

  • log – set y-axis to log scale, defaults to True

  • window – window length of the Savitzky-Golay filter applied to the metric before plotting, defaults to 21, which assumes at least 22 epochs of training (50 suggested)

  • ax – axis on which to plot the losses, defaults to None

  • save – save the plot to save_dir / f"histories_{key}.png", defaults to True

  • legend – add a legend to the plot, defaults to True
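Because the smoothing window must fit inside the recorded history, short training runs need a smaller `window` than the default. The helper below is a sketch of how a caller might pick a usable value before plotting; it is not part of the documented API. The name `clamp_window` and the polynomial order of 3 are assumptions made for illustration.

```python
def clamp_window(n_epochs: int, window: int = 21, polyorder: int = 3) -> int:
    """Pick an odd smoothing window that fits within n_epochs.

    Returns the largest odd value <= min(window, n_epochs), or 0 when
    that value would not exceed `polyorder`, meaning the history is too
    short to smooth. `polyorder` is an assumed value, not taken from
    the reference.
    """
    fitted = min(window, n_epochs)
    if fitted % 2 == 0:
        fitted -= 1  # keep the window symmetric around each point
    return fitted if fitted > polyorder else 0


window_long_run = clamp_window(400)   # default window fits
window_short_run = clamp_window(10)   # shrunk to fit 10 epochs
```

A value obtained this way could be passed as `window` to `plot_histories`, or as `window_length` to `scipy.signal.savgol_filter` when smoothing a history by hand.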