
Advanced Usage

echo edited this page Mar 1, 2025 · 13 revisions

This guide dives deeper into Brain4J, focusing on how to work efficiently with datasets and advanced training techniques.

📊 Working with DataSet

The DataSet class holds your training samples (DataRow objects, each pairing an input Vector with an expected output Vector). It can split those samples into mini-batches with partition and partitionWithSize.

🔹 Creating Mini-Batches

partition(int batches)

Divides the dataset into a fixed number of batches.

DataSet<DataRow> fullData = new DataSet<>();

// Populate dataset with some sample data
for (int i = 0; i < 100; i++) {
    Vector input = Vector.random(5); // 5 input features
    Vector output = Vector.of(Math.random());

    fullData.add(new DataRow(input, output));
}

// Split data into 10 batches
fullData.partition(10);

// Access batches
List<List<DataRow>> batches = fullData.getPartitions();
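This page does not say how partition(int) distributes leftover rows when the dataset size is not a multiple of the batch count. The plain-Java sketch below illustrates one common approach and is not Brain4J's implementation. It splits a list into a fixed number of contiguous batches whose sizes differ by at most one. The class and method names (PartitionSketch, partitionInto) are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

class PartitionSketch {

    // Hypothetical helper (not Brain4J's API): split `data` into `batches`
    // contiguous sublists. The first (size % batches) batches get one extra
    // element, so batch sizes differ by at most one.
    static <T> List<List<T>> partitionInto(List<T> data, int batches) {
        if (batches <= 0) {
            throw new IllegalArgumentException("batches must be positive");
        }
        int base = data.size() / batches;
        int remainder = data.size() % batches;
        List<List<T>> result = new ArrayList<>();
        int start = 0;
        for (int i = 0; i < batches; i++) {
            int size = base + (i < remainder ? 1 : 0);
            result.add(new ArrayList<>(data.subList(start, start + size)));
            start += size;
        }
        return result;
    }

    public static void main(String[] args) {
        List<Integer> samples = new ArrayList<>();
        for (int i = 0; i < 103; i++) {
            samples.add(i);
        }
        // 103 samples into 10 batches: three batches of 11, seven batches of 10
        List<List<Integer>> batches = partitionInto(samples, 10);
        for (List<Integer> batch : batches) {
            System.out.print(batch.size() + " ");
        }
        System.out.println();
    }
}
```

For the 100-row dataset above, this scheme gives ten batches of exactly 10 rows each.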

partitionWithSize(int batchSize)

Splits the dataset into batches of a given size.

// Create batches with 16 samples each
fullData.partitionWithSize(16);

// Access batches
List<List<DataRow>> batches = fullData.getPartitions();
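Fixed-size batching rarely divides a dataset evenly. For the 100-row example, 100 = 6 × 16 + 4, so batches of 16 give six full batches plus a remainder of 4. The page does not say whether Brain4J keeps or drops that last partial batch. The plain-Java sketch below keeps it. The names BatchSizeSketch and splitBySize are hypothetical and are not Brain4J's API.

```java
import java.util.ArrayList;
import java.util.List;

class BatchSizeSketch {

    // Hypothetical helper: split `data` into consecutive batches of `batchSize`.
    // The final batch holds whatever is left over and may be smaller.
    static <T> List<List<T>> splitBySize(List<T> data, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        List<List<T>> result = new ArrayList<>();
        for (int start = 0; start < data.size(); start += batchSize) {
            int end = Math.min(start + batchSize, data.size());
            result.add(new ArrayList<>(data.subList(start, end)));
        }
        return result;
    }

    public static void main(String[] args) {
        List<Integer> samples = new ArrayList<>();
        for (int i = 0; i < 100; i++) {
            samples.add(i);
        }
        // 100 = 6 * 16 + 4, so this yields 6 full batches and one batch of 4
        List<List<Integer>> batches = splitBySize(samples, 16);
        System.out.println(batches.size() + " batches, last has "
                + batches.get(batches.size() - 1).size() + " samples");
    }
}
```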

🔹 Shuffling Data

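Before training, it's good practice to shuffle the dataset so that samples are not presented in the order they were added. An ordered dataset (for example, one sorted by label) can otherwise produce biased batches and slower, less stable learning.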

fullData.shuffle();
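The page does not document how DataSet.shuffle() randomizes its rows. A standard way to produce a uniform random permutation is the Fisher-Yates algorithm. The sketch below implements it on a copy of a plain List with a seeded java.util.Random so that runs are reproducible. The seed parameter is an assumption of this sketch, because the page does not say whether shuffle() accepts one. In plain Java, Collections.shuffle(list, random) does the same job.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

class ShuffleSketch {

    // Fisher-Yates shuffle on a copy: walk from the last index down and swap
    // each element with a uniformly chosen element at or before it.
    static <T> List<T> shuffled(List<T> data, long seed) {
        List<T> copy = new ArrayList<>(data);
        Random random = new Random(seed);
        for (int i = copy.size() - 1; i > 0; i--) {
            int j = random.nextInt(i + 1); // 0 <= j <= i
            Collections.swap(copy, i, j);
        }
        return copy;
    }

    public static void main(String[] args) {
        List<Integer> samples = List.of(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
        // Same seed, same order: useful for reproducible experiments
        System.out.println(shuffled(samples, 42L));
    }
}
```

If batches are built from the dataset's current order, shuffling before partitioning gives each batch a random mix of samples.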

🤖 Using SmartTrainer

SmartTrainer automates training by handling batch updates, stopping conditions, and evaluation.

🔹 Basic Usage

SmartTrainer trainer = new SmartTrainer(0.95, 5); // Learning rate decay 0.95, evaluate every 5 epochs
trainer.start(model, fullData, 0.01, 0.001); // Train until loss < 0.01 or tolerance exceeded
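The page describes the SmartTrainer arguments only in code comments and does not define the tolerance argument (0.001). The toy sketch below illustrates the behaviour those comments suggest, under stated assumptions. The learning rate is multiplied by the decay factor at each evaluation. Evaluation happens every evaluateEvery epochs. Training stops once the evaluated loss drops below the target. The quadratic objective, the maxEpochs safety cap, and every name here are hypothetical. maxEpochs is not a model of the tolerance argument, and Brain4J's actual logic may differ.

```java
class DecayingTrainerSketch {

    // Outcome of a toy run: final parameter, final loss, epochs executed.
    static final class Result {
        final double weight;
        final double loss;
        final int epochs;

        Result(double weight, double loss, int epochs) {
            this.weight = weight;
            this.loss = loss;
            this.epochs = epochs;
        }
    }

    // Toy objective (assumption for illustration): loss(w) = (w - 3)^2, minimized at w = 3.
    static double loss(double w) {
        return (w - 3) * (w - 3);
    }

    static double gradient(double w) {
        return 2 * (w - 3);
    }

    static Result train(double learningRate, double decay, int evaluateEvery,
                        double targetLoss, int maxEpochs) {
        double w = 0.0;
        int epoch = 0;
        while (epoch < maxEpochs) {
            w -= learningRate * gradient(w); // one gradient-descent step per epoch
            epoch++;
            if (epoch % evaluateEvery == 0) {
                if (loss(w) < targetLoss) {
                    break;                   // stopping condition reached
                }
                learningRate *= decay;       // assumed: decay applied at each evaluation
            }
        }
        return new Result(w, loss(w), epoch);
    }

    public static void main(String[] args) {
        Result r = train(0.1, 0.95, 5, 0.01, 1000);
        System.out.printf("stopped after %d epochs, loss = %.5f%n", r.epochs, r.loss);
    }
}
```

With these settings, the toy run stops after 20 epochs, at its fourth evaluation. Checking the loss only at evaluation points means training can overshoot the target by up to evaluateEvery - 1 epochs.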

🔹 Training for a Fixed Number of Epochs

trainer.startFor(model, fullData, 1000); // Train for 1000 epochs
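Assume one weight update per mini-batch, which the page does not state. Under that assumption, a fixed-epoch run performs epochs × batches updates. With the 100-row dataset split into batches of 16 (7 batches), 1000 epochs would mean 7,000 updates. The sketch below shows a fixed-epoch loop in plain Java. The step callback is a hypothetical placeholder for one batch update, and the class name is illustrative only.

```java
import java.util.Collections;
import java.util.List;
import java.util.function.Consumer;

class FixedEpochSketch {

    // Runs exactly `epochs` passes over `batches`, calling `step` once per
    // batch, and returns the total number of updates performed.
    static <T> long trainFor(List<List<T>> batches, int epochs, Consumer<List<T>> step) {
        long updates = 0;
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (List<T> batch : batches) {
                step.accept(batch); // placeholder for forward pass, loss, and weight update
                updates++;
            }
        }
        return updates;
    }

    // Number of batches when the last partial batch is kept (ceiling division).
    static int batchCount(int samples, int batchSize) {
        return (samples + batchSize - 1) / batchSize;
    }

    public static void main(String[] args) {
        List<List<Integer>> batches = Collections.nCopies(batchCount(100, 16), List.of());
        long updates = trainFor(batches, 1000, batch -> { });
        System.out.println(updates + " updates");
    }
}
```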

🔹 Monitoring Training Progress

You can add listeners to track the training process in real time.

private static class ExampleListener extends TrainListener {

    @Override
    public void onEvaluated(DataSet<DataRow> dataSet, int epoch, double loss, long took) {
        // took is in nanoseconds, so dividing by 1e6 prints milliseconds
        System.out.print("\rEpoch " + epoch + " loss: " + loss + " took " + (took / 1e6) + " ms");
    }
}

// Register the listener (inside a method) before calling start or startFor
trainer.addListener(new ExampleListener());
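TrainListener follows the observer pattern. The trainer keeps a list of listeners and calls their hooks at fixed points in the training loop. The self-contained sketch below mirrors that idea in plain Java so it runs without Brain4J. EvalListener, ToyTrainer, and the placeholder loss values are hypothetical stand-ins. The real onEvaluated hook also receives the DataSet, and Brain4J's callback timing may differ.

```java
import java.util.ArrayList;
import java.util.List;

class ListenerSketch {

    // Hypothetical stand-in for TrainListener: a hook with a no-op default.
    static abstract class EvalListener {
        void onEvaluated(int epoch, double loss, long tookNanos) { }
    }

    static final class ToyTrainer {
        private final List<EvalListener> listeners = new ArrayList<>();

        void addListener(EvalListener listener) {
            listeners.add(listener);
        }

        // Runs `epochs` epochs and notifies every listener each `evaluateEvery` epochs.
        void run(int epochs, int evaluateEvery) {
            for (int epoch = 1; epoch <= epochs; epoch++) {
                long start = System.nanoTime();
                double loss = 1.0 / epoch; // placeholder loss, not a real model
                long took = System.nanoTime() - start;
                if (epoch % evaluateEvery == 0) {
                    for (EvalListener listener : listeners) {
                        listener.onEvaluated(epoch, loss, took);
                    }
                }
            }
        }
    }

    // Runs a short toy training and returns the epochs at which listeners fired.
    static List<Integer> evaluatedEpochs(int epochs, int evaluateEvery) {
        List<Integer> seen = new ArrayList<>();
        ToyTrainer trainer = new ToyTrainer();
        trainer.addListener(new EvalListener() {
            @Override
            void onEvaluated(int epoch, double loss, long tookNanos) {
                seen.add(epoch);
            }
        });
        trainer.run(epochs, evaluateEvery);
        return seen;
    }

    public static void main(String[] args) {
        ToyTrainer trainer = new ToyTrainer();
        trainer.addListener(new EvalListener() {
            @Override
            void onEvaluated(int epoch, double loss, long tookNanos) {
                System.out.println("Epoch " + epoch + " loss: " + loss
                        + " took " + (tookNanos / 1e6) + " ms");
            }
        });
        trainer.run(20, 5);
    }
}
```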

📚 Next Steps
