Package me.yixqiao.jlearn.models
Class Model
- java.lang.Object
  - me.yixqiao.jlearn.models.Model
All Implemented Interfaces:
java.io.Serializable
public class Model extends java.lang.Object implements java.io.Serializable
Neural network model.
See Also:
Serialized Form
-
Nested Class Summary
Nested Classes
- static class Model.FitBuilder: Builder class for fit operation.
- protected static class Model.FitPrint: Thread to print progress when fitting.
-
Field Summary
Fields
- protected int layerCount: Number of layers (including input).
- protected java.util.ArrayList<Layer> layers: Layers.
- protected Loss loss: Loss function for model.
-
Constructor Summary
Constructors
- Model(): Create a new model.
-
Method Summary
Methods
- Model addLayer(Layer layer): Add a layer to the model.
- java.util.ArrayList<Matrix> backPropagate(Matrix y): Backpropagate model after forward propagating input.
- void buildModel(Loss loss): Build the model.
- void comparePredictions(Matrix input, Matrix output, int printNum): Print predictions and compare to correct.
- void evaluate(Matrix x, Matrix y, java.util.ArrayList<Metric> metrics): Evaluate the performance of the model.
- void fit(Model.FitBuilder fb): Train the model on data.
- Matrix forwardPropagate(Matrix x): Forward propagate a batch of input.
- double getLoss(Loss loss, java.util.ArrayList<Matrix> out, java.util.ArrayList<Matrix> y): Get the loss.
- double getLoss(Loss loss, Matrix out, Matrix y): Get the loss.
- double getMetric(Metric metric, java.util.ArrayList<Matrix> out, java.util.ArrayList<Matrix> y): Get a metric of the model.
- double getMetric(Metric metric, Matrix out, Matrix y): Get a metric of the model.
- protected void onEpochEnd(int epoch): Method that is called at the end of every epoch.
- Matrix predict(Matrix x): Predict on a batch of input.
- void printSummary(): Print a summary of the model.
- static Model readFromFile(java.lang.String filePath): Read a model from a file.
- void saveToFile(java.lang.String filePath): Save the current model to a file.
- void trainOnBatch(Matrix x, Matrix y, double learningRate): Train on a single batch of input and output.
- protected void update(java.util.ArrayList<Matrix> errors): Update the model after backpropagation.
-
Method Detail
-
readFromFile
public static Model readFromFile(java.lang.String filePath)
Read a model from a file.
Parameters:
filePath - path to the file
Returns:
the model read from the file
-
addLayer
public Model addLayer(Layer layer)
Add a layer to the model.
Parameters:
layer - the layer instance to add
Returns:
the model itself to allow for daisy chaining
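The chaining behaviour described above can be sketched as follows. This is a minimal illustration, not the library's code: `ChainableModel` and `SketchLayer` are hypothetical stand-ins, since this page does not document how `Layer` instances are constructed. Only the idea that `addLayer` returns the model itself comes from the description above.

```java
import java.util.ArrayList;
import java.util.List;

class ChainingSketch {
    // Hypothetical stand-in for a layer; the real Layer class is not shown on this page.
    static class SketchLayer {
        final String name;

        SketchLayer(String name) {
            this.name = name;
        }
    }

    // Hypothetical stand-in that mirrors only the "return the model itself" behaviour.
    static class ChainableModel {
        private final List<SketchLayer> layers = new ArrayList<>();

        // Returning this is what allows several addLayer calls in one expression.
        ChainableModel addLayer(SketchLayer layer) {
            layers.add(layer);
            return this;
        }

        int layerCount() {
            return layers.size();
        }
    }

    // Builds a three-layer stand-in model using chained calls.
    static ChainableModel build() {
        return new ChainableModel()
                .addLayer(new SketchLayer("input"))
                .addLayer(new SketchLayer("hidden"))
                .addLayer(new SketchLayer("output"));
    }

    public static void main(String[] args) {
        System.out.println(build().layerCount() + " layers added");
    }
}
```

With the real class, the same shape would presumably be a chain of `addLayer` calls on a new `Model`, followed by `buildModel`.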
-
buildModel
public void buildModel(Loss loss)
Build the model. Run this after adding layers and before training.
Parameters:
loss - an instance of the loss function to use in the model
-
printSummary
public void printSummary()
Print a summary of the model.
-
fit
public void fit(Model.FitBuilder fb)
Train the model on data.
Parameters:
fb - FitBuilder instance
-
trainOnBatch
public void trainOnBatch(Matrix x, Matrix y, double learningRate)
Train on a single batch of input and output.
Parameters:
x - input data to train on
y - expected outputs
learningRate - learning rate of training
-
predict
public Matrix predict(Matrix x)
Predict on a batch of input.
Parameters:
x - input data to feed forward
Returns:
prediction of model for input
-
comparePredictions
public void comparePredictions(Matrix input, Matrix output, int printNum)
Print predictions and compare to correct.
Parameters:
input - input
output - correct output
printNum - number of items to print
-
evaluate
public void evaluate(Matrix x, Matrix y, java.util.ArrayList<Metric> metrics)
Evaluate the performance of the model.
Parameters:
x - input data
y - expected outputs
metrics - metrics to display
-
forwardPropagate
public Matrix forwardPropagate(Matrix x)
Forward propagate a batch of input.
Parameters:
x - input data to feed forward
Returns:
result of model for input
-
backPropagate
public java.util.ArrayList<Matrix> backPropagate(Matrix y)
Backpropagate model after forward propagating input.
Parameters:
y - expected output for model
Returns:
errors for each layer from backpropagation
-
getLoss
public double getLoss(Loss loss, Matrix out, Matrix y)
Get the loss.
Parameters:
loss - loss function to use
out - actual output of model
y - expected output
Returns:
loss of model
-
getLoss
public double getLoss(Loss loss, java.util.ArrayList<Matrix> out, java.util.ArrayList<Matrix> y)
Get the loss.
Parameters:
loss - loss function to use
out - actual output of model
y - expected output
Returns:
loss of model
-
getMetric
public double getMetric(Metric metric, Matrix out, Matrix y)
Get a metric of the model.
Parameters:
metric - metric function to use
out - actual output of model
y - expected output
Returns:
calculated metric
-
getMetric
public double getMetric(Metric metric, java.util.ArrayList<Matrix> out, java.util.ArrayList<Matrix> y)
Get a metric of the model.
Parameters:
metric - metric function to use
out - actual output of model
y - expected output
Returns:
calculated metric
-
update
protected void update(java.util.ArrayList<Matrix> errors)
Update the model after backpropagation.
Parameters:
errors - errors obtained from backpropagation
-
saveToFile
public void saveToFile(java.lang.String filePath)
Save the current model to a file.
Note: the model can continue to be trained and used to predict after saving.
Parameters:
filePath - path to file to save to
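Since Model implements java.io.Serializable, a save/read round trip could plausibly rest on standard Java object serialization. The sketch below shows that technique only. It is an assumption, not the library's confirmed implementation: `TinyModel` is a hypothetical stand-in, and its methods declare checked exceptions that the real `saveToFile` and `readFromFile` signatures do not.

```java
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;

class SaveLoadSketch {
    // Hypothetical stand-in for a serializable model holding a few weights.
    static class TinyModel implements Serializable {
        private static final long serialVersionUID = 1L;
        final double[] weights;

        TinyModel(double[] weights) {
            this.weights = weights;
        }

        // Writes this object to the given path using Java object serialization.
        void saveToFile(String filePath) throws IOException {
            try (ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(filePath))) {
                out.writeObject(this);
            }
        }

        // Reads an object back from the given path and casts it to TinyModel.
        static TinyModel readFromFile(String filePath) throws IOException, ClassNotFoundException {
            try (ObjectInputStream in = new ObjectInputStream(new FileInputStream(filePath))) {
                return (TinyModel) in.readObject();
            }
        }
    }

    // Saves the model to a temporary file, reads it back, and removes the file.
    static TinyModel roundTrip(TinyModel model) {
        try {
            Path tmp = Files.createTempFile("model", ".ser");
            try {
                model.saveToFile(tmp.toString());
                return TinyModel.readFromFile(tmp.toString());
            } finally {
                Files.deleteIfExists(tmp);
            }
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException("round trip failed", e);
        }
    }

    public static void main(String[] args) {
        TinyModel restored = roundTrip(new TinyModel(new double[] {0.5, -1.25}));
        System.out.println(Arrays.toString(restored.weights)); // prints [0.5, -1.25]
    }
}
```

The original object is untouched by writing it out, which is consistent with the note that a model can still be trained and used after saving.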
-
onEpochEnd
protected void onEpochEnd(int epoch)
Method that is called at the end of every epoch.
Subclasses of Model can easily override this method to do something once each epoch ends.
Parameters:
epoch - epoch number (0-indexed)
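The override pattern described above can be sketched as follows. `BaseTrainer` and `LoggingTrainer` are hypothetical stand-ins for Model and a subclass of it. The training loop and the empty default hook are assumptions made for illustration, and only the 0-indexed `onEpochEnd(int epoch)` hook is taken from this page.

```java
import java.util.ArrayList;
import java.util.List;

class EpochHookSketch {
    // Hypothetical stand-in for a class that calls a hook after each epoch.
    static class BaseTrainer {
        void run(int epochs) {
            for (int epoch = 0; epoch < epochs; epoch++) {
                // Training for one epoch would happen here.
                onEpochEnd(epoch);
            }
        }

        // Assumed default: do nothing. Subclasses override this.
        protected void onEpochEnd(int epoch) {
        }
    }

    // Subclass that records and prints each epoch number as it ends.
    static class LoggingTrainer extends BaseTrainer {
        final List<Integer> seen = new ArrayList<>();

        @Override
        protected void onEpochEnd(int epoch) {
            seen.add(epoch);
            System.out.println("finished epoch " + epoch);
        }
    }

    // Runs the stand-in trainer and returns the epoch numbers the hook received.
    static List<Integer> demo(int epochs) {
        LoggingTrainer trainer = new LoggingTrainer();
        trainer.run(epochs);
        return trainer.seen;
    }

    public static void main(String[] args) {
        demo(3);
    }
}
```

A subclass of the real Model might use the same hook to save a checkpoint or log a metric once per epoch.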