baby.training.cnn_trainer.CNNTrainer
- class baby.training.cnn_trainer.CNNTrainer(save_dir: Path, cnn_set: tuple, gen: ImageLabel, aug: Augmenter, flattener: SegmentationFlattening, max_cnns: int = 3, cnn_fn: Optional[str] = None)
Bases: object
Methods for optimising the weights of the CNNs using gradient descent.
- Attributes
cnn – The Keras Model for the active CNN.
cnn_dir – The subdirectory into which to save the current CNN output.
cnn_fn – The current CNN architecture function.
cnn_name – The name of the currently active CNN architecture.
histories – A dictionary mapping each CNN name to its training history.
opt_cnn – The Keras Model for the CNN with the lowest loss.
opt_dir – The directory in which the weights of the best CNN are saved.
- Methods
fit([epochs, schedule, replace, extend]) – Fit the active CNN to minimise loss on the (augmented) generator.
plot_histories([key, log, window, ax, save, ...]) – Plot the loss/metric histories of all of the trained CNNs.
- __init__(save_dir: Path, cnn_set: tuple, gen: ImageLabel, aug: Augmenter, flattener: SegmentationFlattening, max_cnns: int = 3, cnn_fn: Optional[str] = None)
Methods for optimising the weights of the CNNs using gradient descent.
- Parameters
save_dir – base directory in which to save weights and outputs
cnn_set – the names of CNN architectures to be trained
gen – the data generator
aug – the data augmentor
flattener – the data flattener
max_cnns – the maximum number of CNNs to train/keep, defaults to 3
cnn_fn – the CNN architecture to start with, defaults to None
- property cnn
The Keras Model for the active CNN.
- property cnn_dir
The subdirectory into which to save the current CNN output.
- property cnn_fn
The current CNN architecture function.
- property cnn_name
The name of the currently active CNN architecture.
- fit(epochs: int = 400, schedule: Optional[List[Tuple[int, float]]] = None, replace: bool = False, extend: bool = False)
Fit the active CNN to minimise loss on the (augmented) generator.
- Parameters
epochs – number of epochs to train, defaults to 400
schedule – learning rate schedule, defaults to a fixed rate of 0.001
replace – force training even if the CNN already has a weights file, defaults to False
extend – train using the existing CNN weights as initial weights, defaults to False
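The two sketches below illustrate, under stated assumptions, how the `schedule`, `replace` and `extend` parameters might be interpreted. They do not use baby's code. `step_learning_rate` and `training_action` are hypothetical helpers; the reading of each `(int, float)` pair as "(epoch threshold, learning rate), used while the epoch is below the threshold" is an assumption, as is the precedence of `extend` over `replace` when both are set.

```python
from typing import List, Tuple

# Hypothetical reading of ``schedule``: each (epoch, rate) pair means
# "use ``rate`` while the current epoch is below ``epoch``".
Schedule = List[Tuple[int, float]]


def step_learning_rate(epoch: int, schedule: Schedule,
                       default: float = 0.001) -> float:
    """Learning rate for a 0-based ``epoch`` under a step schedule (hypothetical)."""
    if not schedule:
        return default  # the documented fixed rate when no schedule is given
    steps = sorted(schedule)
    for threshold, rate in steps:
        if epoch < threshold:
            return rate
    return steps[-1][1]  # past the last threshold: keep the final rate


def training_action(has_weights: bool, replace: bool = False,
                    extend: bool = False) -> str:
    """Decide what a fit call should do given existing weights (hypothetical).

    Assumes ``extend`` takes precedence over ``replace`` when both are True.
    """
    if not has_weights:
        return "train from scratch"
    if extend:
        return "train starting from existing weights"
    if replace:
        return "retrain from scratch"
    return "skip training"
```

In Keras, a function like `step_learning_rate` could be wrapped as `tf.keras.callbacks.LearningRateScheduler(lambda epoch, lr: step_learning_rate(epoch, schedule))`; whether baby does this internally is not stated here.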
- property histories
A dictionary mapping each CNN name to its training history.
- property opt_cnn
The Keras Model for the CNN with the lowest loss.
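How the "lowest loss" CNN is chosen is not spelled out here. The following is a minimal sketch, assuming each entry of `histories` is a Keras-style `History.history` dict (metric name mapped to a per-epoch list) and that selection uses the minimum of one metric such as `val_loss`. Both the metric choice and the `best_cnn` helper are hypothetical.

```python
from typing import Dict, List, Tuple

# Assumed shape of ``histories``: CNN name -> Keras-style history dict
# mapping metric names (e.g. "loss", "val_loss") to per-epoch values.
Histories = Dict[str, Dict[str, List[float]]]


def best_cnn(histories: Histories, key: str = "val_loss") -> Tuple[str, float]:
    """Return (name, lowest value) for the CNN whose ``key`` metric is smallest."""
    # Lowest value reached by each CNN that actually recorded ``key``.
    candidates = {name: min(hist[key])
                  for name, hist in histories.items() if hist.get(key)}
    if not candidates:
        raise ValueError(f"no CNN history contains {key!r}")
    name = min(candidates, key=candidates.get)
    return name, candidates[name]
```

For example, with the hypothetical names `{"cnn_a": {"val_loss": [0.9, 0.4, 0.5]}, "cnn_b": {"val_loss": [0.8, 0.3]}}`, the sketch selects `cnn_b`.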
- property opt_dir
The directory in which the weights of the best CNN are saved.
- plot_histories(key: str = 'loss', log: bool = True, window: int = 21, ax: Optional[axis] = None, save: bool = True, legend: bool = True)
Plot the loss/metric histories of all of the trained CNNs.
- Parameters
key – which metric to plot, defaults to 'loss'
log – set the y-axis to a log scale, defaults to True
window – window length of the Savitzky-Golay filter applied to the metric before plotting, defaults to 21, which assumes at least 22 epochs of training; 50 or more epochs are suggested
ax – axis on which to plot the losses, defaults to None
save – save the plot to save_dir / f"histories_{key}.png", defaults to True
legend – add a legend to the plot, defaults to True
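The `window` parameter controls Savitzky-Golay smoothing of each history. The following is a rough, self-contained sketch of what such a filter does. It is not baby's implementation, whose polynomial order and edge handling are not documented here. It applies the classic closed-form quadratic/cubic smoothing weights and keeps only epochs where a full window fits, which is why a short training run cannot be smoothed with a long window. `savgol_weights` and `smooth_history` are hypothetical names. In practice, `scipy.signal.savgol_filter(x, window_length, polyorder)` is the standard routine.

```python
from typing import List, Sequence


def savgol_weights(window: int) -> List[float]:
    """Centre-point Savitzky-Golay weights for a quadratic (or cubic) fit.

    Closed-form weights for a window of 2k + 1 points; ``window`` must be
    odd and at least 3.
    """
    if window < 3 or window % 2 == 0:
        raise ValueError("window must be an odd integer >= 3")
    k = window // 2
    denom = (2 * k - 1) * (2 * k + 1) * (2 * k + 3)
    return [(3 * (3 * k * k + 3 * k - 1) - 15 * i * i) / denom
            for i in range(-k, k + 1)]


def smooth_history(values: Sequence[float], window: int = 21) -> List[float]:
    """Smooth a per-epoch metric, keeping only points with a full window.

    The result has len(values) - window + 1 points; element j is the
    smoothed value at epoch j + window // 2.
    """
    if len(values) < window:
        raise ValueError(f"need at least {window} epochs for window={window}")
    weights = savgol_weights(window)
    return [sum(w * v for w, v in zip(weights, values[start:start + window]))
            for start in range(len(values) - window + 1)]
```

This sketch needs at least `window` epochs. The docstring's figure of 22 epochs may reflect how baby treats the edges, which is not shown here.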