Learner (public struct)

Learner(model, data, optimizer, lossfn, [callbacks...; kwargs...])

Holds and coordinates all training state. During training, model is optimized with optimizer to minimize lossfn on data.

Arguments

  • model: The model to be trained
  • data: Tuple of data iterators in the order (traindata, valdata, [testdata]); testdata is optional. Each iterator must be iterable and yield tuples (xs, ys)
  • optimizer: The optimizer used to update model's parameters
  • lossfn: Function with signature lossfn(model(x), y) -> Number
  • callbacks...: Any further positional arguments are treated as callbacks
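The data and lossfn contracts above can be illustrated with a minimal plain-Julia sketch that needs no Flux. It assumes a toy model (an ordinary function), a hand-written mean squared error as lossfn, and vectors of (xs, ys) tuples as data iterators; the names model, mse and epochloss are hypothetical and not part of the library.

```julia
# Toy "model" (hypothetical): any callable mapping inputs to outputs.
model(x) = 2 .* x

# lossfn with the documented signature lossfn(model(x), y) -> Number.
# Here: mean squared error, written by hand.
mse(ŷ, y) = sum(abs2, ŷ .- y) / length(y)

# Data iterators: anything iterable whose elements are (xs, ys) tuples.
traindata = [([1.0, 2.0], [2.0, 4.0]), ([3.0, 4.0], [6.0, 9.0])]
valdata = [([5.0], [10.0])]
data = (traindata, valdata)   # (traindata, valdata); testdata is optional

# Hypothetical helper: walk one iterator the way a training loop would
# and sum the per-batch losses.
function epochloss(model, lossfn, iter)
    total = 0.0
    for (xs, ys) in iter
        loss = lossfn(model(xs), ys)
        @assert loss isa Number   # the lossfn contract
        total += loss
    end
    return total
end

epochloss(model, mse, data[1])   # returns 0.5
```

The first training batch is predicted exactly, and the second misses one target by 1.0, so the summed loss is 0.5; the single validation batch has loss 0.0.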

Keyword arguments

Fields

(Use this as a reference when implementing callbacks)

  • model, optimizer, and lossfn are stored as passed in

  • data is a NamedTuple (training = ..., validation = ..., test = ...). Entries for iterators you did not pass in are nothing; for example, test is nothing if only training and validation data were given.

  • params: the parameters of model, as a Flux.Params object

  • batch: State of the current batch, including:

    • batch.xs: model inputs
    • batch.ys: target outputs
    • batch.ŷs: model outputs, i.e. model(xs)
    • batch.loss: batch loss, i.e. lossfn(ŷs, ys)
    • batch.gs: batch gradients, instance of Zygote.Grads

    (!) Note: Fields not yet computed in the current step are nothing; for example, gs is nothing before the backward pass.

  • cbstate::PropDict: Shared state container in which callbacks can store values for other callbacks to read. Its keys depend on which callbacks are in use. See the custom callbacks guide for more information.
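The data field described above pads missing iterators with nothing. A minimal sketch of that layout follows, assuming the positional data tuple holds one to three iterators; makedata and available are hypothetical helpers, not FluxTraining API.

```julia
# Hypothetical helper (not FluxTraining API): turn the positional data tuple
# into the documented (training, validation, test) NamedTuple, padding
# missing iterators with `nothing`.
function makedata(iters::Tuple)
    @assert 1 <= length(iters) <= 3 "expected 1 to 3 data iterators"
    padded = (iters..., ntuple(_ -> nothing, 3 - length(iters))...)
    return NamedTuple{(:training, :validation, :test)}(padded)
end

# Code that reads `data` (e.g. a callback) should skip entries that are `nothing`.
available(data::NamedTuple) = [k for k in keys(data) if !isnothing(data[k])]

traindata = [([1.0], [2.0])]
valdata = [([3.0], [6.0])]
data = makedata((traindata, valdata))
available(data)   # returns [:training, :validation]
```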
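The note on the batch field (fields stay nothing until the step computes them) can be illustrated with a plain-Julia mock. This is a sketch under stated assumptions: BatchState and step! are hypothetical names, the model is a scalar-weight toy ŷ = w .* x with a mean squared error loss, and the gradient is derived by hand as a stand-in for the Zygote.Grads a real backward pass would produce.

```julia
# Hypothetical stand-in for `learner.batch` (not FluxTraining's type).
# Fields that have not been computed yet hold `nothing`.
mutable struct BatchState
    xs
    ys
    ŷs
    loss
    gs
end
BatchState(xs, ys) = BatchState(xs, ys, nothing, nothing, nothing)

# One step for the toy model ŷ = w .* x with MSE loss.
function step!(batch::BatchState, w::Real)
    batch.ŷs = w .* batch.xs                                         # forward pass
    batch.loss = sum(abs2, batch.ŷs .- batch.ys) / length(batch.ys)  # loss
    @assert isnothing(batch.gs)                                      # no gradients yet
    # Backward pass: hand-derived dL/dw = 2/n * Σ (ŷ - y) * x
    batch.gs = 2 * sum((batch.ŷs .- batch.ys) .* batch.xs) / length(batch.ys)
    return batch
end

b = BatchState([1.0, 2.0], [2.0, 4.0])
step!(b, 1.5)   # afterwards b.loss == 0.625 and b.gs == -2.5
```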
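The cbstate field lets one callback leave values for another to read. The sketch below shows that pattern with a minimal dictionary that supports property-style access, assuming plain Julia only; MiniPropDict, record_loss! and mean_loss are hypothetical names and do not reproduce the library's PropDict or callback interface.

```julia
# Hypothetical minimal stand-in for a PropDict: a Dict with property-style access.
struct MiniPropDict
    store::Dict{Symbol,Any}
end
MiniPropDict() = MiniPropDict(Dict{Symbol,Any}())

Base.getproperty(p::MiniPropDict, k::Symbol) = getfield(p, :store)[k]
Base.setproperty!(p::MiniPropDict, k::Symbol, v) = (getfield(p, :store)[k] = v)
Base.haskey(p::MiniPropDict, k::Symbol) = haskey(getfield(p, :store), k)

# Hypothetical callback A: records every batch loss into shared state.
function record_loss!(cbstate::MiniPropDict, loss::Real)
    haskey(cbstate, :losshistory) || (cbstate.losshistory = Float64[])
    push!(cbstate.losshistory, loss)
    return cbstate
end

# Hypothetical callback B: reads what callback A stored, if anything.
function mean_loss(cbstate::MiniPropDict)
    haskey(cbstate, :losshistory) || return nothing
    h = cbstate.losshistory
    return sum(h) / length(h)
end

state = MiniPropDict()
record_loss!(state, 1.0)
record_loss!(state, 3.0)
mean_loss(state)   # returns 2.0
```

Callback B checks for the key before reading it, because which keys exist depends on which callbacks are in use.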