Start here






Let’s put FluxTraining.jl to work and train a model on the MNIST dataset.

MNIST is simple enough that we can focus on the part where FluxTraining.jl comes in: the training.


If you want to run this tutorial yourself, you can find the notebook file here. To make data loading and batching a bit easier, we’ll install some additional dependencies:

using Pkg
Pkg.add(["MLDataPattern", "DataLoaders"])

Now we can import everything we’ll need.

using DataLoaders: DataLoader
using MLDataPattern: splitobs
using Flux
using FluxTraining


There are 4 pieces that you always need to construct and train a Learner: a model, data, an optimizer, and a loss function.

Building a Learner

Let’s look at the data first.

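FluxTraining.jl is agnostic of the data source. The only requirements are that it is iterable, with each iteration returning a tuple (xs, ys); that the model can take in xs, i.e. model(xs) works; and that the loss function can take the model outputs and ys, i.e. lossfn(model(xs), ys) returns a scalar.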

Glossing over the details as it’s not the focus of this tutorial, here’s the code for getting a data iterator of the MNIST dataset. We use DataLoaders.DataLoader to create an iterator of batches from our dataset.

xs, ys = (
    # convert each image into an h*w*1 array of floats
    [Float32.(reshape(img, 28, 28, 1)) for img in Flux.Data.MNIST.images()],
    # one-hot encode the labels
    [Flux.onehot(y, 0:9) for y in Flux.Data.MNIST.labels()],
)

# split into training and validation sets
traindata, valdata = splitobs((xs, ys))

# create iterators
trainiter, valiter = DataLoader(traindata, 128), DataLoader(valdata, 256);
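To make the shape of a compatible data iterator more concrete, here is a minimal sketch in plain Julia. It does not use FluxTraining.jl or DataLoaders; `batches`, `toymodel` and `toyloss` are hypothetical stand-ins invented for illustration. The point is only that each iteration yields an `(xs, ys)` tuple that the model and the loss function can consume.

```julia
# Hypothetical toy data: 2 batches of 4 samples, 3 features, 2 target values.
batches = [(rand(Float32, 3, 4), rand(Float32, 2, 4)) for _ in 1:2]

# Hypothetical "model": any callable that accepts a batch of inputs.
W = rand(Float32, 2, 3)
toymodel(xs) = W * xs

# Hypothetical loss: takes model outputs and targets, returns a scalar.
toyloss(ŷs, ys) = sum(abs2, ŷs .- ys) / size(ys, 2)

# Iterating yields (xs, ys) tuples, one per batch.
for (xs, ys) in batches
    println("batch loss = ", toyloss(toymodel(xs), ys))
end
```

Any iterator with this behavior, including a plain vector of tuples as above, should be usable in place of `trainiter` and `valiter`.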

Next, let’s create a simple Flux.jl model that we’ll train to classify the MNIST digits.

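model = Chain(
    Conv((3, 3), 1 => 16, relu, pad = 1, stride = 2),
    Conv((3, 3), 16 => 32, relu, pad = 1),
    GlobalMeanPool(),  # average each of the 32 feature maps down to 1x1
    Flux.flatten,      # reshape to (32, batchsize) for the Dense layer
    Dense(32, 10),
)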
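
As a quick sanity check on the layer sizes, the spatial output size of a convolution can be worked out by hand. The sketch below uses the standard formula floor((input + 2*pad - kernel) / stride) + 1; `convsize` is a hypothetical helper, not part of Flux.jl.

```julia
# Output size along one spatial dimension of a convolution.
convsize(input, kernel; pad = 0, stride = 1) = (input + 2pad - kernel) ÷ stride + 1

s1 = convsize(28, 3; pad = 1, stride = 2)  # after the first Conv
s2 = convsize(s1, 3; pad = 1)              # after the second Conv
println((s1, s2))
```

Under these assumptions the activations after the second Conv are 14×14 with 32 channels per sample, which is why the spatial dimensions have to be reduced (for example with global pooling followed by flattening) before a Dense(32, 10) layer can consume them.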

We’ll use categorical cross entropy as a loss function and ADAM as an optimizer.

lossfn = Flux.Losses.logitcrossentropy
optim = Flux.ADAM();
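
The "logit" in the loss name means it expects raw, unnormalized model outputs and applies the softmax itself. A rough sketch of the computation for a single sample, in plain Julia, follows. The names `logsoftmax_sketch` and `logitce_sketch` are hypothetical, and the real Flux.jl implementation also handles batches (averaging over samples) and other details.

```julia
# Numerically stable log-softmax over a vector of logits.
function logsoftmax_sketch(logits)
    m = maximum(logits)
    return logits .- (m + log(sum(exp.(logits .- m))))
end

# Cross entropy between a one-hot target `y` and raw scores `logits`.
logitce_sketch(logits, y) = -sum(y .* logsoftmax_sketch(logits))

y = Float32.((0:9) .== 3)     # one-hot vector for the digit 3
uniform = zeros(Float32, 10)  # a model that is equally unsure about all classes
println(logitce_sketch(uniform, y))  # ≈ log(10)
```

This is why the model above ends in a plain Dense layer without a softmax: the loss function takes care of the normalization.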

Now we’re ready to create a Learner. At this point you can also add any callbacks, like ToGPU to run the training on your GPU if you have one available. Some callbacks are also included by default.

Since we’re classifying digits, we also use the Metrics callback to track the accuracy of the model’s predictions:

learner = Learner(model, (trainiter, valiter), optim, lossfn, ToGPU(), Metrics(accuracy))
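
For intuition about what an accuracy metric measures on one-hot encoded labels, here is a minimal sketch in plain Julia. `onecold_sketch` and `accuracy_sketch` are hypothetical names, and the accuracy function passed to Metrics above may differ in its details.

```julia
# Index of the largest entry in each column (one column per sample).
onecold_sketch(m) = [argmax(col) for col in eachcol(m)]

# Fraction of samples where the predicted class matches the target class.
accuracy_sketch(ŷs, ys) = sum(onecold_sketch(ŷs) .== onecold_sketch(ys)) / size(ys, 2)

# Toy scores for 3 samples and 2 classes; the third prediction is wrong.
ŷs = [0.1 0.8 0.3;
      0.9 0.2 0.7]
ys = [0 1 1;
      1 0 0]
println(accuracy_sketch(ŷs, ys))
```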


With a Learner in place, training is as simple as calling fit!(learner, nepochs).

FluxTraining.fit!(learner, 10)