Package me.yixqiao.jlearn.models
Class Model
- java.lang.Object
  - me.yixqiao.jlearn.models.Model
All Implemented Interfaces:
java.io.Serializable
public class Model
extends java.lang.Object
implements java.io.Serializable

Neural network model.

See Also:
- Serialized Form
Nested Class Summary
- static class  Model.FitBuilder
  Builder class for the fit operation.
- protected static class  Model.FitPrint
  Thread to print progress when fitting.
Field Summary
- protected int  layerCount
  Number of layers (including the input layer).
- protected java.util.ArrayList<Layer>  layers
  The layers of the model.
- protected Loss  loss
  Loss function used by the model.
Constructor Summary
- Model()
  Create a new model.
Method Summary
- Model  addLayer(Layer layer)
  Add a layer to the model.
- java.util.ArrayList<Matrix>  backPropagate(Matrix y)
  Backpropagate through the model after forward propagating input.
- void  buildModel(Loss loss)
  Build the model.
- void  comparePredictions(Matrix input, Matrix output, int printNum)
  Print predictions and compare them to the correct outputs.
- void  evaluate(Matrix x, Matrix y, java.util.ArrayList<Metric> metrics)
  Evaluate the performance of the model.
- void  fit(Model.FitBuilder fb)
  Train the model on data.
- Matrix  forwardPropagate(Matrix x)
  Forward propagate a batch of input.
- double  getLoss(Loss loss, java.util.ArrayList<Matrix> out, java.util.ArrayList<Matrix> y)
  Get the loss over a list of output matrices.
- double  getLoss(Loss loss, Matrix out, Matrix y)
  Get the loss for a single output matrix.
- double  getMetric(Metric metric, java.util.ArrayList<Matrix> out, java.util.ArrayList<Matrix> y)
  Get a metric of the model over a list of output matrices.
- double  getMetric(Metric metric, Matrix out, Matrix y)
  Get a metric of the model for a single output matrix.
- protected void  onEpochEnd(int epoch)
  Called at the end of every epoch.
- Matrix  predict(Matrix x)
  Predict on a batch of input.
- void  printSummary()
  Print a summary of the model.
- static Model  readFromFile(java.lang.String filePath)
  Read a model from a file.
- void  saveToFile(java.lang.String filePath)
  Save the current model to a file.
- void  trainOnBatch(Matrix x, Matrix y, double learningRate)
  Train on a single batch of input and output.
- protected void  update(java.util.ArrayList<Matrix> errors)
  Update the model after backpropagation.
Method Detail

readFromFile
public static Model readFromFile(java.lang.String filePath)
Read a model from a file.
Parameters:
filePath - path to the file
Returns:
the model read from the file

addLayer
public Model addLayer(Layer layer)
Add a layer to the model.
Parameters:
layer - the layer instance to add
Returns:
the model itself, to allow daisy chaining
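Because addLayer returns the model itself, calls can be chained. The sketch below illustrates that fluent pattern with a hypothetical, stripped-down ChainedModel and a hypothetical DenseLayer that only stores a unit count; neither class is part of jlearn, and real code would pass jlearn Layer instances instead.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for a layer: it only records its number of units.
class DenseLayer {
    final int units;

    DenseLayer(int units) {
        this.units = units;
    }
}

// Hypothetical minimal model illustrating the fluent addLayer pattern.
class ChainedModel {
    private final List<DenseLayer> layers = new ArrayList<>();

    // Returning `this` is what allows model.addLayer(a).addLayer(b)...
    ChainedModel addLayer(DenseLayer layer) {
        layers.add(layer);
        return this;
    }

    int layerCount() {
        return layers.size();
    }
}

class ChainDemo {
    public static void main(String[] args) {
        ChainedModel model = new ChainedModel()
                .addLayer(new DenseLayer(784))
                .addLayer(new DenseLayer(64))
                .addLayer(new DenseLayer(10));
        System.out.println(model.layerCount()); // prints 3
    }
}
```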

buildModel
public void buildModel(Loss loss)
Build the model. Run this after adding layers and before training.
Parameters:
loss - an instance of the loss function to use in the model
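The required order (add layers, then build, then train) can be enforced with a simple state flag. This is a minimal sketch, assuming a hypothetical BuildableModel in which the loss is represented by a name string and in which misuse fails fast with IllegalStateException; the page does not say how jlearn itself reacts to training an unbuilt model.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical model showing the add-layers -> build -> train ordering.
class BuildableModel {
    private final List<Integer> layerSizes = new ArrayList<>();
    private String lossName;      // stands in for a Loss instance
    private boolean built = false;

    BuildableModel addLayer(int units) {
        if (built) {
            throw new IllegalStateException("Cannot add layers after buildModel()");
        }
        layerSizes.add(units);
        return this;
    }

    void buildModel(String lossName) {
        // Assumption: an input layer plus at least one more layer is required.
        if (layerSizes.size() < 2) {
            throw new IllegalStateException("Need an input layer and at least one more layer");
        }
        this.lossName = lossName;
        this.built = true;
    }

    void trainOnBatch() {
        if (!built) {
            throw new IllegalStateException("Call buildModel() before training");
        }
        // ... forward propagation, backpropagation and update would go here ...
    }

    boolean isBuilt() {
        return built;
    }

    String lossName() {
        return lossName;
    }
}
```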

printSummary
public void printSummary()
Print a summary of the model.

fit
public void fit(Model.FitBuilder fb)
Train the model on data.
Parameters:
fb - FitBuilder instance
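This page does not list the methods of Model.FitBuilder, so the sketch below uses a hypothetical FitConfig with made-up setters (epochs, batchSize, learningRate) to show the general idea: a builder object whose setters return itself is handed to fit, which then walks over the data in mini-batches for each epoch. FitLoop and its placeholder trainOnBatch are likewise illustrative only.

```java
// Hypothetical fit configuration in the spirit of Model.FitBuilder:
// setters return `this`, and the object itself is passed to fit().
class FitConfig {
    int epochs = 1;           // hypothetical default
    int batchSize = 32;       // hypothetical default
    double learningRate = 0.01;
    final int numSamples;

    FitConfig(int numSamples) { this.numSamples = numSamples; }

    FitConfig epochs(int e) { this.epochs = e; return this; }
    FitConfig batchSize(int b) { this.batchSize = b; return this; }
    FitConfig learningRate(double lr) { this.learningRate = lr; return this; }
}

class FitLoop {
    int batchesTrained = 0;

    // Iterate over all samples in mini-batches, once per epoch.
    void fit(FitConfig fc) {
        for (int epoch = 0; epoch < fc.epochs; epoch++) {
            for (int start = 0; start < fc.numSamples; start += fc.batchSize) {
                int end = Math.min(start + fc.batchSize, fc.numSamples); // last batch may be smaller
                trainOnBatch(start, end, fc.learningRate);
            }
        }
    }

    // Placeholder: a real model would train on samples [start, end) here.
    void trainOnBatch(int start, int end, double learningRate) {
        batchesTrained++;
    }
}
```

With 100 samples and a batch size of 32, each epoch runs four batches (the last one holding 4 samples), so three epochs call trainOnBatch 12 times.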

trainOnBatch
public void trainOnBatch(Matrix x, Matrix y, double learningRate)
Train on a single batch of input and output.
Parameters:
x - input data to train on
y - expected outputs
learningRate - learning rate of training
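A training step on one batch generally consists of a forward pass, a backward pass, and a parameter update (the roles of forwardPropagate, backPropagate and update on this page). The sketch shows that sequence for a hypothetical one-weight linear model (prediction = w * x + b) trained by gradient descent on mean squared error, using plain double[] arrays instead of jlearn's Matrix; it is not jlearn's implementation.

```java
// Hypothetical one-neuron linear model used to illustrate one training step.
class LinearNeuron {
    double w = 0.0;
    double b = 0.0;

    // Forward pass: prediction for each input in the batch.
    double[] forward(double[] x) {
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = w * x[i] + b;
        }
        return out;
    }

    // One step: forward pass, MSE gradients (backward pass), then update.
    void trainOnBatch(double[] x, double[] y, double learningRate) {
        double[] out = forward(x);
        int n = x.length;
        double gradW = 0.0;
        double gradB = 0.0;
        for (int i = 0; i < n; i++) {
            double err = out[i] - y[i];      // d(MSE)/d(out[i]) = 2 * err / n
            gradW += 2.0 * err * x[i] / n;
            gradB += 2.0 * err / n;
        }
        // Gradient descent update.
        w -= learningRate * gradW;
        b -= learningRate * gradB;
    }
}
```

Repeatedly calling trainOnBatch on points from y = 2x + 1 drives w toward 2 and b toward 1 when the learning rate is small enough.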

predict
public Matrix predict(Matrix x)
Predict on a batch of input.
Parameters:
x - input data to feed forward
Returns:
prediction of model for input

comparePredictions
public void comparePredictions(Matrix input, Matrix output, int printNum)
Print predictions and compare them to the correct outputs.
Parameters:
input - input data
output - correct output
printNum - number of items to print
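A minimal sketch of what such a comparison can look like, assuming the model is represented by a hypothetical Function that maps one input row to one output row and that at most printNum rows are shown. The class name and output format are made up; the method also returns the printed lines so they can be checked.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;

// Hypothetical comparePredictions: predicts each input row and prints it
// next to the correct output, stopping after printNum rows.
class PredictionPrinter {
    static List<String> comparePredictions(Function<double[], double[]> predict,
                                           double[][] input, double[][] correct,
                                           int printNum) {
        List<String> lines = new ArrayList<>();
        int n = Math.min(printNum, input.length);
        for (int i = 0; i < n; i++) {
            double[] predicted = predict.apply(input[i]);
            lines.add("predicted " + Arrays.toString(predicted)
                    + " | correct " + Arrays.toString(correct[i]));
        }
        lines.forEach(System.out::println);
        return lines;
    }
}
```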

evaluate
public void evaluate(Matrix x, Matrix y, java.util.ArrayList<Metric> metrics)
Evaluate the performance of the model.
Parameters:
x - input data
y - expected outputs
metrics - metrics to display

forwardPropagate
public Matrix forwardPropagate(Matrix x)
Forward propagate a batch of input.
Parameters:
x - input data to feed forward
Returns:
result of model for input
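Forward propagation feeds the batch through the layers in order, with each layer's output becoming the next layer's input. The sketch below shows that loop with a hypothetical BatchLayer interface and two hypothetical layers (a bias-free dense layer and a ReLU), operating on double[][] batches where each row is one sample; jlearn's Layer and Matrix types are not used.

```java
import java.util.List;

// Hypothetical layer abstraction: maps a batch (rows = samples) to a batch.
interface BatchLayer {
    double[][] forward(double[][] input);
}

// Hypothetical fully connected layer without bias: out = in * weights.
class DenseNoBias implements BatchLayer {
    private final double[][] weights; // shape [inputs][outputs]

    DenseNoBias(double[][] weights) {
        this.weights = weights;
    }

    public double[][] forward(double[][] in) {
        int outputs = weights[0].length;
        double[][] out = new double[in.length][outputs];
        for (int s = 0; s < in.length; s++) {
            for (int k = 0; k < weights.length; k++) {
                for (int j = 0; j < outputs; j++) {
                    out[s][j] += in[s][k] * weights[k][j];
                }
            }
        }
        return out;
    }
}

// Hypothetical element-wise ReLU activation layer.
class ReluLayer implements BatchLayer {
    public double[][] forward(double[][] in) {
        double[][] out = new double[in.length][];
        for (int i = 0; i < in.length; i++) {
            out[i] = new double[in[i].length];
            for (int j = 0; j < in[i].length; j++) {
                out[i][j] = Math.max(0.0, in[i][j]);
            }
        }
        return out;
    }
}

class Forward {
    // Feed the batch through every layer in order.
    static double[][] forwardPropagate(List<BatchLayer> layers, double[][] x) {
        double[][] current = x;
        for (BatchLayer layer : layers) {
            current = layer.forward(current);
        }
        return current;
    }
}
```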

backPropagate
public java.util.ArrayList<Matrix> backPropagate(Matrix y)
Backpropagate through the model after forward propagating input.
Parameters:
y - expected output for model
Returns:
errors for each layer from backpropagation
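Backpropagation computes an error for the output layer and then passes it backwards through the weights to obtain an error for every earlier layer. A minimal sketch, assuming a hypothetical stack of linear layers (no activation, no bias), a single sample, and the loss 0.5 * ||out - y||^2, so that the output error is simply out - y and each earlier error is the later one multiplied by the transposed weights. jlearn's layers and loss functions may compute these terms differently.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical backward pass for linear layers on one sample.
class LinearBackprop {
    // weights.get(l) has shape [inputs][outputs] for non-input layer l.
    static List<double[]> backPropagate(List<double[][]> weights, double[] out, double[] y) {
        List<double[]> errors = new ArrayList<>();
        // Output-layer error: dLoss/dOut = out - y.
        double[] delta = new double[out.length];
        for (int j = 0; j < out.length; j++) {
            delta[j] = out[j] - y[j];
        }
        errors.add(delta);
        // Walk backwards: previous error = delta * W^T (no error for the input layer).
        for (int l = weights.size() - 1; l > 0; l--) {
            double[][] w = weights.get(l);
            double[] prev = new double[w.length];
            for (int k = 0; k < w.length; k++) {
                for (int j = 0; j < w[k].length; j++) {
                    prev[k] += delta[j] * w[k][j];
                }
            }
            errors.add(prev);
            delta = prev;
        }
        // Reverse so errors.get(0) belongs to the first non-input layer.
        Collections.reverse(errors);
        return errors;
    }
}
```

For weights {{1, 2}} (1 input to 2 hidden units) and {{3}, {4}} (2 hidden units to 1 output), the input {1} produces the output 11; with target 10 the output error is {1} and the hidden-layer error is {3, 4}.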

getLoss
public double getLoss(Loss loss, Matrix out, Matrix y)
Get the loss for a single output matrix.
Parameters:
loss - loss function to use
out - actual output of model
y - expected output
Returns:
loss of model

getLoss
public double getLoss(Loss loss, java.util.ArrayList<Matrix> out, java.util.ArrayList<Matrix> y)
Get the loss over a list of output matrices.
Parameters:
loss - loss function to use
out - actual output of model
y - expected output
Returns:
loss of model
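The two getLoss overloads differ only in whether they receive one output matrix or a list of them. The sketch below mirrors that split for a hypothetical MseLoss (mean squared error on double[][] arrays). The list version here takes a plain average of the per-matrix losses; that is an assumption, since the page does not say how jlearn combines them, and a sample-weighted average would be an equally reasonable choice.

```java
import java.util.List;

// Hypothetical mean squared error, mirroring the two getLoss overloads.
class MseLoss {
    // Single matrix: mean of squared differences over all entries.
    static double getLoss(double[][] out, double[][] y) {
        double sum = 0.0;
        int count = 0;
        for (int i = 0; i < out.length; i++) {
            for (int j = 0; j < out[i].length; j++) {
                double d = out[i][j] - y[i][j];
                sum += d * d;
                count++;
            }
        }
        return sum / count;
    }

    // List of matrices: plain average of the per-matrix losses (assumption).
    static double getLoss(List<double[][]> out, List<double[][]> y) {
        double total = 0.0;
        for (int b = 0; b < out.size(); b++) {
            total += getLoss(out.get(b), y.get(b));
        }
        return total / out.size();
    }
}
```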

getMetric
public double getMetric(Metric metric, Matrix out, Matrix y)
Get a metric of the model for a single output matrix.
Parameters:
metric - metric function to use
out - actual output of model
y - expected output
Returns:
calculated metric

getMetric
public double getMetric(Metric metric, java.util.ArrayList<Matrix> out, java.util.ArrayList<Matrix> y)
Get a metric of the model over a list of output matrices.
Parameters:
metric - metric function to use
out - actual output of model
y - expected output
Returns:
calculated metric
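As an example of a metric, the sketch below implements classification accuracy in a hypothetical Accuracy class, assuming one-hot expected outputs and taking the largest output in each row as the predicted class. Unlike the loss sketch, the list overload here pools all rows, so larger matrices count proportionally; whether jlearn's Metric classes behave this way is not stated on this page.

```java
import java.util.List;

// Hypothetical classification accuracy, mirroring the two getMetric overloads.
class Accuracy {
    // Index of the largest value in a row (the predicted or true class).
    static int argmax(double[] row) {
        int best = 0;
        for (int j = 1; j < row.length; j++) {
            if (row[j] > row[best]) {
                best = j;
            }
        }
        return best;
    }

    static int countCorrect(double[][] out, double[][] y) {
        int correct = 0;
        for (int i = 0; i < out.length; i++) {
            if (argmax(out[i]) == argmax(y[i])) {
                correct++;
            }
        }
        return correct;
    }

    // Single matrix: fraction of rows whose argmax matches the one-hot label.
    static double getMetric(double[][] out, double[][] y) {
        return (double) countCorrect(out, y) / out.length;
    }

    // List of matrices: pool every row before dividing.
    static double getMetric(List<double[][]> out, List<double[][]> y) {
        int correct = 0;
        int total = 0;
        for (int b = 0; b < out.size(); b++) {
            correct += countCorrect(out.get(b), y.get(b));
            total += out.get(b).length;
        }
        return (double) correct / total;
    }
}
```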

update
protected void update(java.util.ArrayList<Matrix> errors)
Update the model after backpropagation.
Parameters:
errors - errors obtained from backpropagation

saveToFile
public void saveToFile(java.lang.String filePath)
Save the current model to a file. Note: the model can continue to be trained and used to predict after saving.
Parameters:
filePath - path to file to save to
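Since Model implements java.io.Serializable, a natural way to implement saveToFile and readFromFile is standard Java object serialization; this is an assumption, as the page does not show the implementation. The sketch uses a hypothetical TinyModel whose only state is a list of weight arrays. Unlike the documented methods, it declares its checked exceptions instead of handling them internally.

```java
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

// Hypothetical serializable model used to illustrate saving and reloading.
class TinyModel implements Serializable {
    private static final long serialVersionUID = 1L;
    final List<double[]> weights = new ArrayList<>();

    // Writes the whole object graph; the in-memory model stays usable.
    void saveToFile(String filePath) throws IOException {
        try (ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(filePath))) {
            out.writeObject(this);
        }
    }

    // Reads back an independent copy of the saved model.
    static TinyModel readFromFile(String filePath) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(new FileInputStream(filePath))) {
            return (TinyModel) in.readObject();
        }
    }
}
```

The saved file is a snapshot, so, as the note above says, training can continue after saving; later changes to the live model do not affect what readFromFile returns.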

onEpochEnd
protected void onEpochEnd(int epoch)
Called at the end of every epoch. Subclasses of Model can override this method to run custom code once each epoch ends.
Parameters:
epoch - epoch number (0-indexed)
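onEpochEnd is a hook in the template-method style: the training loop calls it after each epoch, and the default does nothing until a subclass overrides it. The sketch below shows the pattern with a hypothetical HookedModel base class and a hypothetical LoggingModel subclass; with jlearn, the subclass would instead extend Model and override protected void onEpochEnd(int epoch).

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical base class with an overridable end-of-epoch hook.
class HookedModel {
    void fit(int epochs) {
        for (int epoch = 0; epoch < epochs; epoch++) {
            // ... train on every batch of this epoch ...
            onEpochEnd(epoch); // 0-indexed, matching the documented parameter
        }
    }

    // Default: do nothing. Subclasses override to add behavior.
    protected void onEpochEnd(int epoch) {}
}

// Example subclass that records and logs every finished epoch.
class LoggingModel extends HookedModel {
    final List<Integer> seenEpochs = new ArrayList<>();

    @Override
    protected void onEpochEnd(int epoch) {
        seenEpochs.add(epoch);
        System.out.println("finished epoch " + epoch);
    }
}
```

A hook like this is a convenient place for per-epoch work such as logging metrics or saving a checkpoint, without changing the training loop itself.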