Class Model

  • All Implemented Interfaces:
    java.io.Serializable

    public class Model
    extends java.lang.Object
    implements java.io.Serializable
    Neural network model.
    See Also:
    Serialized Form
    • Field Detail

      • layers

        protected final java.util.ArrayList<Layer> layers
        List of the model's layers.
      • layerCount

        protected int layerCount
        Number of layers (including input).
      • loss

        protected Loss loss
        Loss function used by the model.
    • Constructor Detail

      • Model

        public Model()
        Create a new model.
    • Method Detail

      • readFromFile

        public static Model readFromFile(java.lang.String filePath)
        Read a model from a file.
        Parameters:
        filePath - path to the file
        Returns:
        the model read from the file
      • addLayer

        public Model addLayer(Layer layer)
        Add a layer to the model.
        Parameters:
        layer - the layer instance to add
        Returns:
        the model itself, to allow method chaining
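Because addLayer returns the model itself, layers can be added in one chained expression before calling buildModel. The sketch below illustrates only that "return this" pattern; SketchModel and SketchLayer are hypothetical stand-ins, not the library's Model and Layer classes, whose constructors are not documented here.

```java
import java.util.ArrayList;
import java.util.List;

// Minimal sketch of the "return this" pattern behind addLayer.
// SketchLayer and SketchModel are hypothetical stand-ins, not the library's Layer and Model.
class ChainingSketch {
    static class SketchLayer {
        final String name;
        SketchLayer(String name) { this.name = name; }
    }

    static class SketchModel {
        private final List<SketchLayer> layers = new ArrayList<>();

        // Appends the layer and returns this model so further calls can be chained.
        SketchModel addLayer(SketchLayer layer) {
            layers.add(layer);
            return this;
        }

        int layerCount() { return layers.size(); }
    }

    static SketchModel build() {
        return new SketchModel()
                .addLayer(new SketchLayer("input"))
                .addLayer(new SketchLayer("hidden"))
                .addLayer(new SketchLayer("output"));
    }

    public static void main(String[] args) {
        System.out.println(build().layerCount()); // prints 3
    }
}
```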
      • buildModel

        public void buildModel(Loss loss)
        Build the model. Run this after adding layers and before training.
        Parameters:
        loss - an instance of the loss function to use in the model
      • printSummary

        public void printSummary()
        Print a summary of the model.
      • fit

        public void fit(Model.FitBuilder fb)
        Train the model on data.
        Parameters:
        fb - FitBuilder instance
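The members of Model.FitBuilder are not listed in this section. As a rough sketch, assuming fit iterates over epochs, splits the data into mini-batches, calls trainOnBatch on each one, and calls onEpochEnd after every epoch, the control flow might look like the following. FitConfig and its epochs/batchSize setters are hypothetical; the loop records what it would do instead of actually training.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch of a fit loop driven by a builder-style config.
// The setter names (epochs, batchSize) are assumptions, not the documented Model.FitBuilder API.
class FitLoopSketch {
    static class FitConfig {
        int epochs = 1;
        int batchSize = 1;
        FitConfig epochs(int e) { this.epochs = e; return this; }
        FitConfig batchSize(int b) { this.batchSize = b; return this; }
    }

    // Records each mini-batch range and each end-of-epoch call instead of training.
    static List<String> fit(int numSamples, FitConfig cfg) {
        List<String> log = new ArrayList<>();
        for (int epoch = 0; epoch < cfg.epochs; epoch++) {
            for (int start = 0; start < numSamples; start += cfg.batchSize) {
                int end = Math.min(start + cfg.batchSize, numSamples); // last batch may be smaller
                log.add("epoch " + epoch + " batch [" + start + ", " + end + ")");
                // a real implementation would call trainOnBatch on rows start..end-1 here
            }
            log.add("onEpochEnd(" + epoch + ")");
        }
        return log;
    }

    public static void main(String[] args) {
        fit(5, new FitConfig().epochs(2).batchSize(2)).forEach(System.out::println);
    }
}
```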
      • trainOnBatch

        public void trainOnBatch(Matrix x,
                                 Matrix y,
                                 double learningRate)
        Train on a single batch of input and output.
        Parameters:
        x - input data to train on
        y - expected outputs
        learningRate - learning rate of training
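To make "a single batch" concrete, here is a minimal sketch of one gradient-descent step on a hypothetical one-weight linear model y = w*x + b with mean squared error. Plain double arrays stand in for Matrix, and the library's actual loss, layers, and update rule may differ.

```java
// Minimal sketch: one gradient-descent step on a hypothetical linear model y = w*x + b
// using mean squared error. double[] replaces the library's Matrix type.
class TrainOnBatchSketch {
    double w = 0.0, b = 0.0;

    // Mean squared error of the current parameters over the batch.
    double mse(double[] x, double[] y) {
        double sum = 0.0;
        for (int i = 0; i < x.length; i++) {
            double err = w * x[i] + b - y[i];
            sum += err * err;
        }
        return sum / x.length;
    }

    // Averages the MSE gradients over the batch, then steps against them.
    void trainOnBatch(double[] x, double[] y, double learningRate) {
        double gradW = 0.0, gradB = 0.0;
        for (int i = 0; i < x.length; i++) {
            double err = w * x[i] + b - y[i];
            gradW += 2.0 * err * x[i] / x.length;
            gradB += 2.0 * err / x.length;
        }
        w -= learningRate * gradW;
        b -= learningRate * gradB;
    }

    public static void main(String[] args) {
        TrainOnBatchSketch m = new TrainOnBatchSketch();
        double[] x = {1, 2, 3};
        double[] y = {2, 4, 6};
        System.out.println("before: " + m.mse(x, y));
        m.trainOnBatch(x, y, 0.05);
        System.out.println("after:  " + m.mse(x, y));
    }
}
```

A learning rate that is too large can make the loss grow instead of shrink, so the value passed to trainOnBatch usually needs tuning.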
      • predict

        public Matrix predict(Matrix x)
        Predict on a batch of input.
        Parameters:
        x - input data to feed forward
        Returns:
        prediction of model for input
      • comparePredictions

        public void comparePredictions(Matrix input,
                                       Matrix output,
                                       int printNum)
        Print the model's predictions alongside the correct outputs for comparison.
        Parameters:
        input - input data
        output - correct (expected) outputs
        printNum - number of items to print
      • evaluate

        public void evaluate(Matrix x,
                             Matrix y,
                             java.util.ArrayList<Metric> metrics)
        Evaluate the performance of the model.
        Parameters:
        x - input data
        y - expected outputs
        metrics - metrics to display
      • forwardPropagate

        public Matrix forwardPropagate(Matrix x)
        Forward propagate a batch of input.
        Parameters:
        x - input data to feed forward
        Returns:
        result of model for input
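Forward propagation passes the batch through each layer in turn, feeding each layer's output to the next. The sketch below assumes fully connected layers with a ReLU activation and one sample per row; Dense is a hypothetical class, not the library's Layer, and double[][] replaces Matrix.

```java
import java.util.ArrayList;
import java.util.List;

// Sketch of forward propagation through hypothetical dense layers: each computes
// relu(x * W + b), and its output becomes the next layer's input.
// Rows of x are samples in the batch (an assumed layout).
class ForwardSketch {
    static class Dense {
        final double[][] w; // shape: inputs x outputs
        final double[] b;   // shape: outputs
        Dense(double[][] w, double[] b) { this.w = w; this.b = b; }

        double[][] forward(double[][] x) {
            double[][] out = new double[x.length][b.length];
            for (int r = 0; r < x.length; r++) {
                for (int c = 0; c < b.length; c++) {
                    double s = b[c];
                    for (int k = 0; k < w.length; k++) s += x[r][k] * w[k][c];
                    out[r][c] = Math.max(0.0, s); // ReLU activation
                }
            }
            return out;
        }
    }

    static double[][] forwardPropagate(List<Dense> layers, double[][] x) {
        double[][] a = x;
        for (Dense layer : layers) a = layer.forward(a);
        return a;
    }

    // Two-layer example: 2 inputs -> 2 hidden units -> 1 output, batch of 2 samples.
    static double[][] demo() {
        List<Dense> layers = new ArrayList<>();
        layers.add(new Dense(new double[][] {{1, -1}, {1, 1}}, new double[] {0, 0}));
        layers.add(new Dense(new double[][] {{1}, {1}}, new double[] {0.5}));
        return forwardPropagate(layers, new double[][] {{1, 2}, {3, 0}});
    }

    public static void main(String[] args) {
        for (double[] row : demo()) System.out.println(row[0]); // prints 4.5 then 3.5
    }
}
```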
      • backPropagate

        public java.util.ArrayList<Matrix> backPropagate(Matrix y)
        Backpropagate through the model. Call this after forward propagating the input.
        Parameters:
        y - expected output for model
        Returns:
        errors for each layer from backpropagation
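The following sketch shows what "errors for each layer" can mean, using a hypothetical chain of scalar linear layers h[i+1] = w[i] * h[i] with loss 0.5 * (out - y)^2. Here errors.get(i) is the derivative of the loss with respect to the output of layer i, computed from the last layer backwards; an update step then turns each error into a weight gradient and subtracts it, scaled by a learning rate. The class, its scalar weights, and the plain gradient-descent update are assumptions for illustration, not the library's implementation of backPropagate or update.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// Sketch of backpropagation for a hypothetical chain of scalar linear layers
// h[i+1] = w[i] * h[i], with loss 0.5 * (h[n] - y)^2.
class BackpropSketch {
    final double[] w;
    double[] h; // activations cached by forward; h[0] is the input

    BackpropSketch(double... w) { this.w = w; }

    double forward(double x) {
        h = new double[w.length + 1];
        h[0] = x;
        for (int i = 0; i < w.length; i++) h[i + 1] = w[i] * h[i];
        return h[w.length];
    }

    // Must run after forward; returns dLoss/d(output of layer i) for each layer i.
    List<Double> backPropagate(double y) {
        List<Double> errors = new ArrayList<>();
        double delta = h[w.length] - y; // derivative of 0.5 * (out - y)^2
        for (int i = w.length - 1; i >= 0; i--) {
            errors.add(delta);
            delta *= w[i]; // propagate back through layer i to its input
        }
        Collections.reverse(errors); // order from first layer to last
        return errors;
    }

    // Gradient of w[i] is (error at layer i's output) * (layer i's input).
    double[] gradients(List<Double> errors) {
        double[] g = new double[w.length];
        for (int i = 0; i < w.length; i++) g[i] = errors.get(i) * h[i];
        return g;
    }

    // Plain gradient-descent update using the errors from backPropagate.
    void update(List<Double> errors, double learningRate) {
        double[] g = gradients(errors);
        for (int i = 0; i < w.length; i++) w[i] -= learningRate * g[i];
    }

    public static void main(String[] args) {
        BackpropSketch net = new BackpropSketch(2.0, 3.0);
        net.forward(1.0); // h = [1, 2, 6]
        List<Double> errors = net.backPropagate(5.0);
        System.out.println(errors); // prints [3.0, 1.0]
        System.out.println(Arrays.toString(net.gradients(errors))); // prints [3.0, 2.0]
    }
}
```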
      • getLoss

        public double getLoss(Loss loss,
                              Matrix out,
                              Matrix y)
        Get the loss.
        Parameters:
        loss - loss function to use
        out - actual output of model
        y - expected output
        Returns:
        loss of model
      • getLoss

        public double getLoss(Loss loss,
                              java.util.ArrayList<Matrix> out,
                              java.util.ArrayList<Matrix> y)
        Get the loss.
        Parameters:
        loss - loss function to use
        out - actual output of model
        y - expected output
        Returns:
        loss of model
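The two getLoss overloads differ only in whether they receive one output matrix or a list of them. The sketch uses mean squared error as a stand-in Loss and double[][] in place of Matrix; pooling every element across the list in the second overload is an assumption (averaging the per-matrix losses would be another reasonable choice).

```java
import java.util.List;

// Sketch of both getLoss shapes with mean squared error as a stand-in loss.
// Pooling all elements across the list is an assumption about how batches combine.
class LossSketch {
    // Single-matrix overload: delegates to the list version with one element.
    static double mse(double[][] out, double[][] y) {
        return mse(List.of(out), List.of(y));
    }

    // List overload: averages squared differences over every element of every matrix.
    static double mse(List<double[][]> out, List<double[][]> y) {
        double sum = 0.0;
        long count = 0;
        for (int m = 0; m < out.size(); m++) {
            double[][] o = out.get(m), t = y.get(m);
            for (int r = 0; r < o.length; r++) {
                for (int c = 0; c < o[r].length; c++) {
                    double d = o[r][c] - t[r][c];
                    sum += d * d;
                    count++;
                }
            }
        }
        return sum / count;
    }

    public static void main(String[] args) {
        double[][] out1 = {{1, 2}}, y1 = {{1, 0}};
        double[][] out2 = {{3}}, y2 = {{0}};
        System.out.println(mse(out1, y1)); // prints 2.0
        System.out.println(mse(List.of(out1, out2), List.of(y1, y2)));
    }
}
```

With pooling, a larger matrix contributes proportionally more to the result than a smaller one.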
      • getMetric

        public double getMetric(Metric metric,
                                Matrix out,
                                Matrix y)
        Get a metric of the model.
        Parameters:
        metric - metric function to use
        out - actual output of model
        y - expected output
        Returns:
        calculated metric
      • getMetric

        public double getMetric(Metric metric,
                                java.util.ArrayList<Matrix> out,
                                java.util.ArrayList<Matrix> y)
        Get a metric of the model.
        Parameters:
        metric - metric function to use
        out - actual output of model
        y - expected output
        Returns:
        calculated metric
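As an example of what a Metric might compute, here is a classification accuracy sketch. It assumes one sample per row and one-hot expected outputs, and takes the predicted class to be the index of the largest output value; AccuracySketch is hypothetical and double[][] replaces Matrix.

```java
// Sketch of a classification accuracy metric over a batch.
// Predicted class = argmax of the model's output row; true class = argmax of the
// one-hot expected row (assumed encoding).
class AccuracySketch {
    static int argmax(double[] row) {
        int best = 0;
        for (int i = 1; i < row.length; i++) if (row[i] > row[best]) best = i;
        return best;
    }

    // Fraction of rows whose predicted class matches the expected class.
    static double accuracy(double[][] out, double[][] y) {
        int correct = 0;
        for (int r = 0; r < out.length; r++) {
            if (argmax(out[r]) == argmax(y[r])) correct++;
        }
        return (double) correct / out.length;
    }

    public static void main(String[] args) {
        double[][] out = {{0.9, 0.1}, {0.4, 0.6}, {0.7, 0.3}, {0.2, 0.8}};
        double[][] y   = {{1, 0},     {1, 0},     {1, 0},     {0, 1}};
        System.out.println(accuracy(out, y)); // prints 0.75
    }
}
```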
      • update

        protected void update(java.util.ArrayList<Matrix> errors)
        Update the model after backpropagation.
        Parameters:
        errors - errors obtained from backpropagation
      • saveToFile

        public void saveToFile(java.lang.String filePath)
        Save the current model to a file.

        Note: the model can continue to be trained and used to predict after saving.

        Parameters:
        filePath - path to file to save to
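Since Model implements java.io.Serializable, one plausible way to implement saveToFile and readFromFile is standard Java object serialization. This sketch assumes that mechanism; SavedModel is a hypothetical stand-in for Model, and unlike the documented methods, these versions declare their checked exceptions instead of handling them.

```java
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;

// Sketch of save/load via standard Java object serialization. Whether the library
// uses this mechanism is an assumption; SavedModel is a hypothetical stand-in for Model.
class SerializationSketch {
    static class SavedModel implements Serializable {
        private static final long serialVersionUID = 1L;
        final ArrayList<double[]> weights = new ArrayList<>();
    }

    static void saveToFile(SavedModel model, String filePath) throws IOException {
        try (ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(filePath))) {
            out.writeObject(model); // writes the whole object graph
        }
    }

    static SavedModel readFromFile(String filePath) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(new FileInputStream(filePath))) {
            return (SavedModel) in.readObject();
        }
    }

    // Saves a small model to a temporary file, reads it back, and cleans up.
    static SavedModel roundTrip() throws IOException, ClassNotFoundException {
        SavedModel model = new SavedModel();
        model.weights.add(new double[] {0.5, -1.25});
        Path tmp = Files.createTempFile("model", ".ser");
        try {
            saveToFile(model, tmp.toString());
            return readFromFile(tmp.toString());
        } finally {
            Files.deleteIfExists(tmp);
        }
    }

    public static void main(String[] args) throws Exception {
        System.out.println(roundTrip().weights.get(0)[1]); // prints -1.25
    }
}
```

Serialization writes a copy of the object graph to disk, which is consistent with the note above that the in-memory model stays usable after saving.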
      • onEpochEnd

        protected void onEpochEnd(int epoch)
        Called at the end of every epoch.

        Subclasses of Model can override this method to run custom code (for example, logging or checkpointing) each time an epoch ends.

        Parameters:
        epoch - epoch number (0-indexed)
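A sketch of the override pattern, assuming the training loop calls onEpochEnd once per epoch with a 0-indexed epoch number. BaseModel and LoggingModel are hypothetical stand-ins; a real subclass would extend Model instead.

```java
import java.util.ArrayList;
import java.util.List;

// Sketch of overriding an end-of-epoch hook. BaseModel is a hypothetical stand-in
// for Model: its training loop calls onEpochEnd(epoch) after each epoch, 0-indexed.
class EpochHookSketch {
    static class BaseModel {
        void train(int epochs) {
            for (int epoch = 0; epoch < epochs; epoch++) {
                // ... train on every batch for this epoch ...
                onEpochEnd(epoch);
            }
        }

        protected void onEpochEnd(int epoch) {
            // default: do nothing
        }
    }

    static class LoggingModel extends BaseModel {
        final List<Integer> seen = new ArrayList<>();

        @Override
        protected void onEpochEnd(int epoch) {
            seen.add(epoch); // a real subclass might log metrics or save a checkpoint here
        }
    }

    public static void main(String[] args) {
        LoggingModel m = new LoggingModel();
        m.train(3);
        System.out.println(m.seen); // prints [0, 1, 2]
    }
}
```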