DeepLearning4j (DL4J) is a powerful deep learning library for Java and Scala, designed to bring AI capabilities to the JVM ecosystem. Whether you're a Java developer exploring machine learning or an AI enthusiast looking for a production-ready framework, DL4J provides an efficient way to build, train, and deploy neural networks.
In this guide, we’ll walk through the process of creating your first neural network using DeepLearning4j. By the end, you’ll have a working model trained on a simple dataset.
Before starting, ensure you have a Java Development Kit (JDK) installed and a project set up with either Maven or Gradle.
First, add DL4J and its dependencies to your project.
If you use Maven, add the following to your pom.xml:
<dependencies>
    <!-- Core DL4J library. DL4J has no final 1.0.0 release; 1.0.0-M2.1 is a milestone release, so check Maven Central for newer versions -->
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-core</artifactId>
        <version>1.0.0-M2.1</version>
    </dependency>
    <!-- ND4J CPU backend for numerical computations (bundles native binaries for all platforms) -->
    <dependency>
        <groupId>org.nd4j</groupId>
        <artifactId>nd4j-native-platform</artifactId>
        <version>1.0.0-M2.1</version>
    </dependency>
    <!-- DataVec for dataset loading -->
    <dependency>
        <groupId>org.datavec</groupId>
        <artifactId>datavec-api</artifactId>
        <version>1.0.0-M2.1</version>
    </dependency>
</dependencies>
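If you use Gradle, add these to your build.gradle instead (the same 1.0.0-M2.1 milestone release, since there is no plain 1.0.0 release):
dependencies {
    implementation 'org.deeplearning4j:deeplearning4j-core:1.0.0-M2.1'
    implementation 'org.nd4j:nd4j-native-platform:1.0.0-M2.1'
    implementation 'org.datavec:datavec-api:1.0.0-M2.1'
}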
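For this tutorial, we’ll use the Iris dataset, a classic classification dataset of 150 flower samples. Each sample is described by four numeric features (sepal length, sepal width, petal length and petal width) and labeled with one of three species.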
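The loader code in this section reads column 4 as an integer class index (0, 1 or 2). The widely distributed UCI copy of Iris (iris.data) stores the species name in that column instead, for example Iris-setosa. The sketch below is a hypothetical plain-Java helper, IrisLabelEncoder, that rewrites such rows into numeric form. The class name and the species-to-index order are our own assumptions and are not part of DL4J.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper: rewrites UCI-style Iris rows so the species column holds 0, 1 or 2. */
final class IrisLabelEncoder {

    /** Assumed species order; any fixed order works as long as training and test files agree. */
    static int labelIndex(String species) {
        switch (species) {
            case "Iris-setosa":
                return 0;
            case "Iris-versicolor":
                return 1;
            case "Iris-virginica":
                return 2;
            default:
                throw new IllegalArgumentException("Unknown species: " + species);
        }
    }

    /** Encodes one "f1,f2,f3,f4,species" row as "f1,f2,f3,f4,classIndex". */
    static String encodeRow(String row) {
        String[] parts = row.trim().split(",");
        if (parts.length != 5) {
            throw new IllegalArgumentException("Expected 5 columns but got " + parts.length + ": " + row);
        }
        int label = labelIndex(parts[4].trim());
        return String.join(",", parts[0], parts[1], parts[2], parts[3], Integer.toString(label));
    }

    /** Encodes every row, skipping blank lines such as a trailing newline at the end of the file. */
    static List<String> encodeAll(List<String> rows) {
        List<String> encoded = new ArrayList<>();
        for (String row : rows) {
            if (!row.trim().isEmpty()) {
                encoded.add(encodeRow(row));
            }
        }
        return encoded;
    }
}
```

To use it, read the file with java.nio.file.Files.readAllLines(path) and pass the lines to encodeAll. The result can be written back out as a purely numeric CSV.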
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import java.io.File;

public class IrisDatasetLoader {
    public static DataSetIterator loadIrisDataset(String filePath, int batchSize) throws Exception {
        // Skip 0 header lines (use 1 if your CSV has a header row); fields are comma-separated
        RecordReader recordReader = new CSVRecordReader(0, ',');
        recordReader.initialize(new FileSplit(new File(filePath)));
        // Columns 0-3 are the 4 input features; column 4 holds the label as an integer class index.
        // Arguments: batch size, label column index (4), number of classes (3)
        return new RecordReaderDataSetIterator(recordReader, batchSize, 4, 3);
    }
}
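The training and evaluation code in this guide reads two separate files, iris_dataset.csv and iris_test_dataset.csv. The standard Iris file is sorted by species, and RecordReaderDataSetIterator reads rows in file order. Without shuffling, each mini-batch of 10 would contain a single class, and a naive head/tail split would leave whole species out of the test set.

Below is a minimal sketch, assuming the rows are already label-encoded. The hypothetical TrainTestSplitter shuffles with a fixed seed and then splits by a fraction. It is a plain random split, not a stratified one.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical helper: shuffles encoded Iris rows and splits them into train and test sets. */
final class TrainTestSplitter {

    /** Result holder: the first part of the shuffled rows for training, the rest for testing. */
    static final class Split {
        final List<String> train;
        final List<String> test;

        Split(List<String> train, List<String> test) {
            this.train = train;
            this.test = test;
        }
    }

    static Split split(List<String> rows, double trainFraction, long seed) {
        if (trainFraction <= 0.0 || trainFraction >= 1.0) {
            throw new IllegalArgumentException("trainFraction must be strictly between 0 and 1");
        }
        List<String> shuffled = new ArrayList<>(rows);   // copy so the caller's list is untouched
        Collections.shuffle(shuffled, new Random(seed)); // fixed seed makes the split reproducible
        int trainSize = (int) Math.round(rows.size() * trainFraction);
        return new Split(
                new ArrayList<>(shuffled.subList(0, trainSize)),
                new ArrayList<>(shuffled.subList(trainSize, shuffled.size())));
    }
}
```

For the 150 Iris rows, split(rows, 0.8, 42L) yields 120 training rows and 30 test rows. Each list can then be written to its CSV path with java.nio.file.Files.write(path, lines).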
Now, let’s define a simple Multi-Layer Perceptron (MLP) for classification.
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class SimpleNeuralNetwork {
    public static MultiLayerNetwork buildModel() {
        MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
                .updater(new Adam(0.01))  // Adam optimizer, learning rate 0.01
                .list()
                .layer(new DenseLayer.Builder()
                        .nIn(4)   // Input features (sepal length, sepal width, petal length, petal width)
                        .nOut(10) // Hidden layer neurons
                        .activation(Activation.RELU)
                        .build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nIn(10)
                        .nOut(3)  // 3 output classes (Iris species)
                        .activation(Activation.SOFTMAX)
                        .build())
                .build();

        MultiLayerNetwork model = new MultiLayerNetwork(config);
        model.init(); // allocate and randomly initialize the parameters
        return model;
    }
}
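Here is what this configuration computes for a single flower. The DenseLayer applies ReLU(x·W + b), with W of shape [nIn, nOut]. The OutputLayer then applies softmax to its own weighted sum, producing three class probabilities that sum to 1.

The plain-Java sketch below, MlpForwardSketch (a hypothetical name), mirrors that math with ordinary arrays. It is only for intuition and is not how DL4J implements the layers internally. It also counts parameters: 4*10 + 10 for the hidden layer plus 10*3 + 3 for the output layer gives 83, which should match model.numParams() for this configuration.

```java
/** Hypothetical plain-Java sketch of the math behind a ReLU DenseLayer followed by a softmax OutputLayer. */
final class MlpForwardSketch {

    /** Computes x·W + b, where w has shape [nIn][nOut] and b has length nOut. */
    static double[] dense(double[] x, double[][] w, double[] b) {
        double[] z = new double[b.length];
        for (int j = 0; j < b.length; j++) {
            double sum = b[j];
            for (int i = 0; i < x.length; i++) {
                sum += x[i] * w[i][j];
            }
            z[j] = sum;
        }
        return z;
    }

    /** ReLU: negative pre-activations become 0. */
    static double[] relu(double[] z) {
        double[] out = new double[z.length];
        for (int i = 0; i < z.length; i++) {
            out[i] = Math.max(0.0, z[i]);
        }
        return out;
    }

    /** Softmax; subtracting the maximum avoids overflow without changing the result. */
    static double[] softmax(double[] z) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : z) {
            max = Math.max(max, v);
        }
        double[] out = new double[z.length];
        double sum = 0.0;
        for (int i = 0; i < z.length; i++) {
            out[i] = Math.exp(z[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < z.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    /** Weights plus biases per layer: nIn * nOut + nOut, summed over consecutive layer sizes. */
    static int paramCount(int... layerSizes) {
        int total = 0;
        for (int k = 0; k + 1 < layerSizes.length; k++) {
            total += layerSizes[k] * layerSizes[k + 1] + layerSizes[k + 1];
        }
        return total;
    }
}
```

Chaining softmax(dense(relu(dense(x, w1, b1)), w2, b2)), with w1 of shape [4][10] and w2 of shape [10][3], gives the same input-to-output shape as this network's forward pass for one flower.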
Now, let’s train the model using the Iris dataset.
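import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import java.io.File;

public class ModelTraining {
    public static void main(String[] args) throws Exception {
        // Load the training data (label column encoded as 0, 1 or 2)
        String irisCsvPath = "path/to/iris_dataset.csv";
        DataSetIterator trainData = IrisDatasetLoader.loadIrisDataset(irisCsvPath, 10);

        // Build a freshly initialized model
        MultiLayerNetwork model = SimpleNeuralNetwork.buildModel();

        // Train for 100 epochs; each fit() call is one full pass over the iterator
        int epochs = 100;
        for (int i = 0; i < epochs; i++) {
            model.fit(trainData);
            trainData.reset(); // rewind the iterator for the next epoch
            System.out.println("Epoch " + (i + 1) + " completed, score: " + model.score());
        }

        // Save the trained weights (and updater state) so the evaluation step can reuse them
        ModelSerializer.writeModel(model, new File("iris-model.zip"), true);
        System.out.println("Training complete!");
    }
}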
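Each model.fit(trainData) call passes over the training iterator once in mini-batches of 10 rows. For each batch, Adam adjusts the weights to lower the NEGATIVELOGLIKELIHOOD loss configured on the output layer. With softmax outputs and integer (one-hot) labels, that loss is the batch average of -ln(p), where p is the probability the network assigns to the true class.

The following is a plain-Java sketch of that formula. NllLossSketch is a hypothetical name; DL4J computes the loss internally with its own loss classes, and the score it reports may also include regularization terms (none are configured here).

```java
/** Hypothetical sketch of the mean negative log-likelihood for softmax outputs with integer labels. */
final class NllLossSketch {

    /**
     * probs[n] is the softmax output for example n and labels[n] is its true class index.
     * Returns the average of -ln(probability assigned to the true class) over the batch.
     * A probability of exactly 0 for the true class would make the loss infinite.
     */
    static double meanNll(double[][] probs, int[] labels) {
        if (probs.length == 0 || probs.length != labels.length) {
            throw new IllegalArgumentException("Need a non-empty batch with one label per row");
        }
        double total = 0.0;
        for (int n = 0; n < probs.length; n++) {
            total += -Math.log(probs[n][labels[n]]);
        }
        return total / probs.length;
    }
}
```

A model that assigned exactly 1/3 to every class would score ln 3 ≈ 1.0986. This gives a rough reference point when inspecting model.score() during training.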
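After training, we can evaluate the model’s performance on the separate test file. The evaluation must use the trained weights, because calling SimpleNeuralNetwork.buildModel() again would only produce a new, randomly initialized network. So at the end of ModelTraining, save the model with ModelSerializer.writeModel(model, new File("iris-model.zip"), true), then restore it here:
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import java.io.File;

public class ModelEvaluation {
    public static void main(String[] args) throws Exception {
        // Load the held-out test data (never used during training)
        String testCsvPath = "path/to/iris_test_dataset.csv";
        DataSetIterator testData = IrisDatasetLoader.loadIrisDataset(testCsvPath, 10);

        // Restore the trained model saved at the end of training
        MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(new File("iris-model.zip"));

        // Evaluate and print a summary including accuracy, precision, recall and F1
        Evaluation eval = model.evaluate(testData);
        System.out.println(eval.stats());
    }
}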
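eval.stats() summarizes the predictions with metrics such as accuracy, precision, recall and F1. To make those numbers less opaque, here is a plain-Java sketch, ConfusionMetrics (a hypothetical helper). It derives accuracy and per-class and macro-averaged precision and recall from a confusion matrix indexed as [actual][predicted]. The zero fallback for undefined ratios is a choice of this sketch, and DL4J's handling of such edge cases may differ.

```java
/** Hypothetical sketch: classification metrics from a confusion matrix cm[actual][predicted]. */
final class ConfusionMetrics {

    /** Fraction of all rows whose predicted class equals the actual class (the diagonal). */
    static double accuracy(int[][] cm) {
        long correct = 0;
        long total = 0;
        for (int a = 0; a < cm.length; a++) {
            for (int p = 0; p < cm.length; p++) {
                total += cm[a][p];
                if (a == p) {
                    correct += cm[a][p];
                }
            }
        }
        return total == 0 ? 0.0 : (double) correct / total;
    }

    /** Precision for class c: correct predictions of c divided by all predictions of c (column sum). */
    static double precision(int[][] cm, int c) {
        long predicted = 0;
        for (int a = 0; a < cm.length; a++) {
            predicted += cm[a][c];
        }
        return predicted == 0 ? 0.0 : (double) cm[c][c] / predicted;
    }

    /** Recall for class c: correct predictions of c divided by all rows actually in c (row sum). */
    static double recall(int[][] cm, int c) {
        long actual = 0;
        for (int p = 0; p < cm.length; p++) {
            actual += cm[c][p];
        }
        return actual == 0 ? 0.0 : (double) cm[c][c] / actual;
    }

    /** Macro average: unweighted mean of the per-class precision values. */
    static double macroPrecision(int[][] cm) {
        double sum = 0.0;
        for (int c = 0; c < cm.length; c++) {
            sum += precision(cm, c);
        }
        return sum / cm.length;
    }

    /** Macro average: unweighted mean of the per-class recall values. */
    static double macroRecall(int[][] cm) {
        double sum = 0.0;
        for (int c = 0; c < cm.length; c++) {
            sum += recall(cm, c);
        }
        return sum / cm.length;
    }
}
```

Macro averaging weights each species equally regardless of how many test rows it has. With a small, roughly balanced Iris test set, this stays close to the per-row accuracy.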
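Congratulations! 🎉 You’ve just built and trained your first neural network using DeepLearning4j. Here’s a quick recap: you added the DL4J, ND4J and DataVec dependencies, loaded the Iris CSV through a CSVRecordReader and RecordReaderDataSetIterator, configured a 4-10-3 multi-layer perceptron with a softmax output, trained it with the Adam updater for 100 epochs, and evaluated it on a separate test file.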
Would you like a deeper dive into any specific topic? Let me know in the comments! 🚀
This guide provides a hands-on introduction to DL4J. Happy coding! 🤖💻