Define a Modeler

The modeler defines the model pipeline. The following example sets up the computation needed for a ResNet32 network:

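```python
import importlib

# ModelerConfig and ImageClassificationModelerConfig are the project's
# configuration classes; their import and constructor arguments are not shown.

# Create a ResNet32 network: `net` is its forward-pass function
network_name = ""  # Python module path of the ResNet32 network definition
net = getattr(importlib.import_module(network_name), "net")

# Create basic modeler configuration
modeler_config = ModelerConfig(...)

# Add additional configuration for image classification
modeler_config = ImageClassificationModelerConfig(...)

# Create modeler
modeler_name = "source.modeler.image_classification_modeler"
modeler = importlib.import_module(modeler_name).build(modeler_config, net)
```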

  • net is the function that implements ResNet32’s forward pass.
  • modeler_config holds the arguments for building a ResNet32 model. Importantly, it sets the number of classes.
  • modeler is the model pipeline. Its key member function, model_fn, returns a dictionary of operators to be run by a TensorFlow session.
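The setup above resolves both the network and the modeler by module name at runtime: `importlib.import_module` loads a module from a dotted path, `getattr` pulls out its `net` function, and the modeler module exposes a `build(config, net)` factory. The following is a minimal, self-contained sketch of that lookup pattern. The module names `demo_network` and `demo_modeler`, the `DemoModeler` class, and the plain-dict config are hypothetical stand-ins registered in `sys.modules`, not part of the project.

```python
import importlib
import sys
import types

# Hypothetical stand-in for a network module that exposes `net`.
network_module = types.ModuleType("demo_network")


def net(images):
    """Hypothetical forward pass: returns its input unchanged."""
    return images


network_module.net = net

# Hypothetical stand-in for a modeler module that exposes `build`.
modeler_module = types.ModuleType("demo_modeler")


class DemoModeler:
    """Hypothetical modeler that stores its config and network function."""

    def __init__(self, config, net_fn):
        self.config = config
        self.net = net_fn


def build(config, net_fn):
    # Module-level factory with the same shape as build(modeler_config, net).
    return DemoModeler(config, net_fn)


modeler_module.build = build

# Register the in-memory modules so import_module can find them by name.
sys.modules["demo_network"] = network_module
sys.modules["demo_modeler"] = modeler_module

# Same pattern as the setup above: look up `net` by module name, then
# build a modeler around it from another module looked up by name.
net_fn = getattr(importlib.import_module("demo_network"), "net")
modeler = importlib.import_module("demo_modeler").build({"num_classes": 10}, net_fn)

print(modeler.config["num_classes"])  # 10
print(modeler.net is net_fn)          # True
```

Because both pieces are resolved from strings, a different network or modeler could in principle be selected by changing only the module path; this is a reading of the code above, not a documented guarantee.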

The model_fn for an image classification modeler looks like this:

```python
def model_fn(self, x):

  # Input batch of images and labels
  images = x[0]
  labels = x[1]

  # Create graph for forward pass
  logits, predictions = self.create_graph_fn(images)

  # Return modeler operators
  if self.config.mode == "train":

    # Training mode returns operators for loss, gradients and accuracy,
    # plus the current learning rate
    loss = self.create_loss_fn(logits, labels)
    grads = self.create_grad_fn(loss)
    accuracy = self.create_eval_metrics_fn(
      predictions, labels)
    return {"loss": loss,
            "grads": grads,
            "accuracy": accuracy,
            "learning_rate": self.learning_rate}
  elif self.config.mode == "eval":

    # Evaluation mode returns operators for loss and accuracy
    loss = self.create_loss_fn(logits, labels)
    accuracy = self.create_eval_metrics_fn(
      predictions, labels)
    return {"loss": loss,
            "accuracy": accuracy}
  elif self.config.mode == "infer":

    # Inference mode returns the predicted classes and their probabilities
    return {"classes": predictions["classes"],
            "probabilities": predictions["probabilities"]}
```
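To see the mode dispatch in `model_fn` without TensorFlow, the following sketch evaluates everything eagerly in plain Python. It is a toy under several assumptions: a linear classifier stands in for ResNet32; `predictions` is assumed to hold argmax `classes` and softmax `probabilities` (the keys the infer branch reads); the loss is taken to be mean softmax cross-entropy; and, because there is no automatic differentiation, the hypothetical `create_grad_fn` takes the batch and computes the analytic gradient instead of taking `loss`. `ToyModeler` and its constructor arguments are illustrative names, not the project's API, and it reads `self.mode` rather than `self.config.mode`.

```python
import math


def softmax(row):
    """Numerically stable softmax over one row of logits."""
    top = max(row)
    exps = [math.exp(v - top) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


class ToyModeler:
    """Hypothetical eager stand-in for an image classification modeler.

    A linear classifier (one weight vector per class) replaces ResNet32 and
    plain Python values replace TensorFlow operators; only the dictionaries
    returned by model_fn follow the structure of the modeler above.
    """

    def __init__(self, mode, weights, learning_rate=0.1):
        self.mode = mode
        self.weights = weights
        self.learning_rate = learning_rate

    def create_graph_fn(self, images):
        # Forward pass: logits, then softmax probabilities and argmax classes.
        logits = [[sum(w * v for w, v in zip(ws, img)) for ws in self.weights]
                  for img in images]
        probabilities = [softmax(row) for row in logits]
        classes = [max(range(len(p)), key=p.__getitem__) for p in probabilities]
        return logits, {"classes": classes, "probabilities": probabilities}

    def create_loss_fn(self, logits, labels):
        # Mean softmax cross-entropy against integer class labels.
        return -sum(math.log(softmax(row)[y])
                    for row, y in zip(logits, labels)) / len(labels)

    def create_grad_fn(self, images, predictions, labels):
        # Analytic gradient of the mean cross-entropy w.r.t. each weight:
        # (p_k - [y == k]) * x, averaged over the batch. TensorFlow would
        # derive this automatically from the loss instead.
        n = len(labels)
        probs = predictions["probabilities"]
        return [[sum((p[k] - (1.0 if y == k else 0.0)) * img[j]
                     for img, p, y in zip(images, probs, labels)) / n
                 for j in range(len(ws))]
                for k, ws in enumerate(self.weights)]

    def create_eval_metrics_fn(self, predictions, labels):
        # Fraction of examples whose predicted class matches the label.
        hits = sum(c == y for c, y in zip(predictions["classes"], labels))
        return hits / len(labels)

    def model_fn(self, x):
        images, labels = x
        logits, predictions = self.create_graph_fn(images)
        if self.mode == "train":
            loss = self.create_loss_fn(logits, labels)
            grads = self.create_grad_fn(images, predictions, labels)
            accuracy = self.create_eval_metrics_fn(predictions, labels)
            return {"loss": loss, "grads": grads, "accuracy": accuracy,
                    "learning_rate": self.learning_rate}
        if self.mode == "eval":
            return {"loss": self.create_loss_fn(logits, labels),
                    "accuracy": self.create_eval_metrics_fn(predictions, labels)}
        if self.mode == "infer":
            return {"classes": predictions["classes"],
                    "probabilities": predictions["probabilities"]}
        # Unlike the snippet above, this sketch rejects unknown modes explicitly.
        raise ValueError(f"unknown mode: {self.mode!r}")


# Usage: two classes, two input features, a batch of two examples.
weights = [[1.0, 0.0], [0.0, 1.0]]
batch = ([[2.0, 0.0], [0.0, 3.0]], [0, 1])
for mode in ("train", "eval", "infer"):
    outputs = ToyModeler(mode, weights).model_fn(batch)
    print(mode, sorted(outputs))
# train ['accuracy', 'grads', 'learning_rate', 'loss']
# eval ['accuracy', 'loss']
# infer ['classes', 'probabilities']
```

The same forward pass feeds every mode; only the set of returned entries changes, which is what lets one `model_fn` serve training, evaluation and inference.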