RandAugment for Image Classification with Keras/TensorFlow

Data augmentation has, for a long while, served as a means of replacing a "static" dataset with transformed variants, bolstering the invariance of Convolutional Neural Networks (CNNs) and usually making them more robust to variations in their input.

Note: Invariance boils down to making models blind to certain perturbations when making decisions. An image of a cat is still an image of a cat if you mirror it or rotate it.
While data augmentation in the form we've been using it does encode a lack of knowledge about translational variance (where things are in the image), which is important for object detection, semantic and instance segmentation, etc., the invariance it provides is often favorable for classification models. Thus, augmentation is more commonly and more aggressively applied to classification models.

Types of Augmentation

Augmentation started out very simple: small rotations, horizontal and vertical flips, contrast or brightness fluctuations, etc. In recent years, more elaborate methods have been devised, including CutOut (spatial dropout that randomly blacks out square regions of the input images), MixUp (blending two whole images and their labels with a weighted average) and CutMix, which borrows from both: it cuts out a patch like CutOut, but fills it with the corresponding patch of another image and mixes the labels like MixUp. The newer augmentation methods actually account for labels, and methods like CutMix set the label proportions equal to the proportions of the image taken up by each of the mixed-in images.

With a growing list of possible augmentations, some have started to apply them randomly (or at least a random subset of them), with the idea that a random set of augmentations will bolster the robustness of models and replace the original set with a much larger space of input images. This is where RandAugment kicks in!
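To make the label arithmetic concrete, here is a minimal NumPy sketch of MixUp and CutMix on a single pair of images. It's an illustration rather than KerasCV's implementation: the function names are hypothetical, and we fix the mixing weight lam and the patch box by hand, whereas the published methods sample them randomly.

```python
import numpy as np


def mixup(img_a, label_a, img_b, label_b, lam):
    """MixUp: blend two whole images and their one-hot labels with weight `lam`."""
    image = lam * img_a + (1 - lam) * img_b
    label = lam * label_a + (1 - lam) * label_b
    return image, label


def cutmix(img_a, label_a, img_b, label_b, y0, x0, h, w):
    """CutMix: paste an h x w patch of img_b into img_a and mix labels by area.

    Assumes the patch lies fully inside the image (no clipping at the borders).
    """
    image = img_a.copy()
    image[y0:y0 + h, x0:x0 + w] = img_b[y0:y0 + h, x0:x0 + w]
    # Fraction of the resulting image that now comes from img_b
    area_b = (h * w) / (img_a.shape[0] * img_a.shape[1])
    label = (1 - area_b) * label_a + area_b * label_b
    return image, label


# Usage: a 112x112 patch covers a quarter of a 224x224 image
img_a, img_b = np.zeros((224, 224, 3)), np.ones((224, 224, 3))
label_cat, label_dog = np.array([1.0, 0.0]), np.array([0.0, 1.0])
_, cutmix_label = cutmix(img_a, label_cat, img_b, label_dog, y0=0, x0=0, h=112, w=112)
_, mixup_label = mixup(img_a, label_cat, img_b, label_dog, lam=0.7)
print(cutmix_label)  # [0.75 0.25]
print(mixup_label)   # [0.7 0.3]
```

Note how CutMix's label weights come straight from the patch area, while MixUp's come from the blending weight.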

KerasCV and RandAugment

KerasCV is a separate package, but still an official addition to Keras, developed by the Keras team. This means that it gets the same polish and intuitiveness as the main package, and that it integrates seamlessly with regular Keras models and their layers. The only difference you'll ever notice is calling layers from keras_cv.layers instead of keras.layers. KerasCV is still in development as of writing, and already includes 27 new preprocessing layers, RandAugment, CutMix, and MixUp being some of them.

Let's take a look at what it looks like to apply RandAugment to images, and how we can train a classifier with and without random augmentation. First, install keras_cv:

$ pip install keras_cv
Note: KerasCV requires TensorFlow 2.9 to work. If you don't already have it, run $ pip install -U tensorflow first.

Now, let's import TensorFlow, Keras and KerasCV, alongside TensorFlow Datasets for easy access to Imagenette:
import tensorflow as tf
from tensorflow import keras
import keras_cv
import tensorflow_datasets as tfds

Let's load in an image and display it in its original form:

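import matplotlib.pyplot as plt
import cv2

# OpenCV reads images as BGR, so convert to RGB for matplotlib
cat_img = cv2.cvtColor(cv2.imread('cat.jpg'), cv2.COLOR_BGR2RGB)
cat_img = cv2.resize(cat_img, (224, 224))

plt.imshow(cat_img)
plt.axis('off')
plt.show()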

Now, let's apply RandAugment to it several times and take a look at the results:

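rand_augment = keras_cv.layers.RandAugment(value_range=(0, 255))

fig = plt.figure(figsize=(10, 10))
for i in range(16):
    ax = fig.add_subplot(4, 4, i + 1)
    aug_img = rand_augment(cat_img)  # a new random augmentation on every call
    # aug_img is a float-based tensor, so we convert it back to uint8 for plotting
    ax.imshow(aug_img.numpy().astype('uint8'))
    ax.axis('off')
plt.show()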

The layer has a magnitude argument, which defaults to 0.5 and can be changed to increase or decrease the effect of augmentation:

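rand_augment_weak = keras_cv.layers.RandAugment(value_range=(0, 255), magnitude=0.1)
fig = plt.figure(figsize=(10, 10))
for i in range(16):
    ax = fig.add_subplot(4, 4, i + 1)
    aug_img = rand_augment_weak(cat_img)
    ax.imshow(aug_img.numpy().astype('uint8'))
    ax.axis('off')
plt.show()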

When set to a low value such as 0.1, you'll see much less aggressive augmentation.

Being a layer, RandAugment can be used within models or in pipelines while creating datasets, which makes it pretty flexible! Two additional arguments, augmentations_per_image and rate, work together. For each of the augmentations_per_image slots, the layer adds a random preprocessing layer to the pipeline to be applied to an image; with the default of 3, three operations are added to the pipeline. Then, a random number is sampled for each augmentation in the pipeline, and if it's lower than rate (which defaults to around 0.9), the augmentation is applied.

In essence, each (randomly chosen) augmentation in the pipeline has roughly a 90% probability of being applied to the image.
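Since each augmentation is kept or skipped independently, the number of augmentations actually applied to an image follows a binomial distribution. The short sketch below computes it, assuming the default of 3 augmentations per image and a rate of 10/11 (about 0.91) as a stand-in for "around 0.9"; the helper name is hypothetical.

```python
from math import comb


def applied_count_distribution(augmentations_per_image=3, rate=10 / 11):
    """Probability that exactly k of the sampled augmentations get applied.

    Each augmentation is applied independently with probability `rate`,
    so the count follows a Binomial(augmentations_per_image, rate) distribution.
    """
    n, p = augmentations_per_image, rate
    return {k: comb(n, k) * p**k * (1 - p) ** (n - k) for k in range(n + 1)}


dist = applied_count_distribution()
for k, prob in dist.items():
    print(f"P({k} applied) = {prob:.3f}")
print(f"Expected number applied: {sum(k * p for k, p in dist.items()):.2f}")
```

Under these assumptions, all three augmentations are applied about 75% of the time, about 2.73 are applied on average, and none at all in under 0.1% of cases (1 in 1331).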

This naturally means that not all augmentations have to be applied, especially if you lower the rate. You can also customize which operations are allowed through a RandomAugmentationPipeline layer, of which RandAugment is a special case. A separate guide on RandomAugmentationPipeline will be published soon.
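To illustrate the idea of a pipeline built from a user-chosen list of operations, here is a plain NumPy sketch of the sampling logic described above. It is not KerasCV's RandomAugmentationPipeline: the function, the operation list and the default rate of 10/11 are assumptions made for illustration, and each slot draws its operation with replacement.

```python
import numpy as np


def random_augmentation_pipeline(image, ops, augmentations_per_image=3, rate=10 / 11, rng=None):
    """Hypothetical sketch: apply up to `augmentations_per_image` random ops.

    For each slot, one (name, op) pair is drawn with replacement from `ops`,
    and the op is applied only if a uniform random number falls below `rate`.
    """
    rng = np.random.default_rng() if rng is None else rng
    out, applied = image, []
    for _ in range(augmentations_per_image):
        op_name, op = ops[rng.integers(len(ops))]
        if rng.random() < rate:
            out = op(out)
            applied.append(op_name)
    return out, applied


# A user-chosen list of allowed operations (names and ops are illustrative)
allowed_ops = [
    ("flip_left_right", np.fliplr),
    ("flip_up_down", np.flipud),
    ("rotate_90", lambda img: np.rot90(img, k=1, axes=(0, 1))),
    ("brighten", lambda img: np.clip(img * 1.2, 0, 255)),
]

rng = np.random.default_rng(seed=0)
image = rng.integers(0, 256, size=(224, 224, 3)).astype("float32")
for epoch in range(2):
    # The same image gets a freshly sampled pipeline on each pass
    _, applied = random_augmentation_pipeline(image, allowed_ops, rng=rng)
    print(f"epoch {epoch}: {applied}")
```

Because a fresh pipeline is drawn on every call, the same image comes out differently each time it passes through, which is what lets one training set stand in for a much larger space of inputs.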

Training a Classifier with and without RandAugment

To simplify the data preparation/loading aspect and focus on RandAugment, let's use tfds to load in a portion of Imagenette:

(train, valid_set, test_set), info = tfds.load("imagenette", 
                                           split=["train[:70%]", "validation", "train[70%:]"],
                                           as_supervised=True, with_info=True)

class_names = info.features["label"].names
n_classes = info.features["label"].num_classes
print(f'Class names: {class_names}') # Class names: ['n01440764', 'n02102040', 'n02979186', 'n03000684', 'n03028079', 'n03394916', 'n03417042', 'n03425413', 'n03445777', 'n03888257']
print('Num of classes:', n_classes) # Num of classes: 10

print("Train set size:", len(train)) # Train set size: 6628
print("Test set size:", len(test_set)) # Test set size: 2841
print("Valid set size:", len(valid_set)) # Valid set size: 3925
Advice: For more on loading datasets and working with tfds, as well as their splits, read our guide "Split Train, Test and Validation Sets with Tensorflow Datasets - tfds".

We've only loaded a portion of the training data, to make it easier to overfit the dataset in fewer epochs (making our experiment run faster, in effect). Since the images in Imagenette are of different sizes, let's create a preprocess() function that resizes them (and one-hot encodes the labels), which we'll map over the dataset, as well as an augment() function that applies RandAugment to a batch of images and labels:
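def preprocess(images, labels):
    # Resize every image to 224x224 and one-hot encode its label
    return tf.image.resize(images, (224, 224)), tf.one_hot(labels, n_classes)

# Create the layer once, outside the function that tf.data will trace
rand_augment = keras_cv.layers.RandAugment(value_range=(0, 255))

def augment(images, labels):
    inputs = {"images": images, "labels": labels}
    outputs = rand_augment(inputs)
    return outputs['images'], outputs['labels']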

Note that we one-hot encoded the labels. We didn't necessarily have to, but for augmentations like CutMix, which tamper with labels and their proportions, you'll have to. Since you might want to apply those as well (RandAugment works really well with them to create robust classifiers), let's leave the one-hot encoding in. This is also why we pass RandAugment a dictionary with both images and labels: some augmentations that you can add will actually change the labels, so the labels have to be passed in alongside the images. You can easily extract the augmented images and labels from the outputs dictionary, so this is an extra, yet simple, step to take during augmentation. Let's map the existing datasets returned from tfds with the preprocess() function, batch them and augment only the training set:

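# Batch size 32 matches the 208 steps per epoch in the logs below (6628 / 32, rounded up)
valid_set = valid_set.map(preprocess).batch(32).prefetch(tf.data.AUTOTUNE)
train_set = train.map(preprocess).batch(32).prefetch(tf.data.AUTOTUNE)
# Augment after batching, so each batch is freshly augmented every epoch
train_set_aug = train.map(preprocess).batch(32).map(augment, num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)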

Let's train a network! keras_cv.models has some built-in networks, similar to keras.applications. While the list is still short, it's expected to expand over time and take over from keras.applications. The API is very similar, so porting code will be fairly easy for most practitioners:

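# include_rescaling=True rescales the [0, 255] inputs to [0, 1] inside the model
effnet = keras_cv.models.EfficientNetV2B0(include_rescaling=True, include_top=True, classes=10)
# Categorical cross-entropy matches the one-hot labels; the optimizer choice is an assumption
effnet.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

history = effnet.fit(train_set, epochs=25, validation_data=valid_set)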

Alternatively, you can use the current keras.applications:

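# keras.applications' EfficientNetV2 includes input rescaling by default
effnet = keras.applications.EfficientNetV2B0(weights=None, classes=10)
effnet.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])  # optimizer is an assumption

history1 = effnet.fit(train_set, epochs=50, validation_data=valid_set)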

Without augmentation, the model doesn't do particularly well: training accuracy climbs above 99% while validation accuracy stays below 70%:

Epoch 1/50
208/208 [==============================] - 60s 238ms/step - loss: 2.7742 - accuracy: 0.2313 - val_loss: 3.2200 - val_accuracy: 0.3085
Epoch 50/50
208/208 [==============================] - 48s 229ms/step - loss: 0.0272 - accuracy: 0.9925 - val_loss: 2.0638 - val_accuracy: 0.6887

Now, let's train the same network setup on the augmented dataset. Each batch is individually augmented, so whenever the same images come around again in the next epoch, they'll have different augmentations:

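effnet = keras.applications.EfficientNetV2B0(weights=None, classes=10)
# Same setup as before (the optimizer choice is an assumption)
effnet.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

history2 = effnet.fit(train_set_aug, epochs=50, validation_data=valid_set)
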
Epoch 1/50
208/208 [==============================] - 141s 630ms/step - loss: 2.9966 - accuracy: 0.1314 - val_loss: 2.7398 - val_accuracy: 0.2395
Epoch 50/50
208/208 [==============================] - 125s 603ms/step - loss: 0.7313 - accuracy: 0.7583 - val_loss: 0.6101 - val_accuracy: 0.8143

Much better! While you'd still want to apply other augmentations, such as CutMix and MixUp, alongside other training techniques to maximize the network's accuracy, RandAugment alone helped significantly and can be comparable to a longer augmentation pipeline. If you plot the training and validation curves, it becomes clear how much RandAugment helps. In the non-augmented pipeline, the network overfits (training accuracy hits the ceiling) while the validation accuracy stays low. In the augmented pipeline, the training accuracy stays lower than the validation accuracy from start to end. With a higher training loss, the network is much more aware of the mistakes it still makes, and more updates can be made to make it invariant to the transformations. The former sees no need to update, while the latter does, which raises its ceiling of potential.
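If you'd like to compare the curves yourself, a small matplotlib helper like the one below can plot training against validation accuracy for several runs side by side. It's a sketch that assumes each run's History.history dict contains 'accuracy' and 'val_accuracy' keys, as produced by compiling with metrics=['accuracy']; the function name is hypothetical.

```python
import matplotlib.pyplot as plt


def plot_training_curves(histories):
    """Plot training vs. validation accuracy for several runs side by side.

    `histories` maps a run name to a Keras History.history dict that
    contains 'accuracy' and 'val_accuracy' lists (one value per epoch).
    """
    fig, axes = plt.subplots(1, len(histories), figsize=(6 * len(histories), 4), squeeze=False)
    for ax, (name, hist) in zip(axes[0], histories.items()):
        epochs = range(1, len(hist["accuracy"]) + 1)
        ax.plot(epochs, hist["accuracy"], label="training accuracy")
        ax.plot(epochs, hist["val_accuracy"], label="validation accuracy")
        ax.set_title(name)
        ax.set_xlabel("epoch")
        ax.set_ylabel("accuracy")
        ax.legend()
    fig.tight_layout()
    return fig


# Usage with the two runs trained above:
# plot_training_curves({"No augmentation": history1.history, "RandAugment": history2.history})
# plt.show()
```

Plotting both runs on separate panels with the same metrics makes the gap between the training and validation curves easy to compare at a glance.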


KerasCV is a separate package, but still an official addition to Keras, developed by the Keras team and aimed at bringing industry-strength CV to your Keras projects. KerasCV is still in development as of writing, and already includes 27 new preprocessing layers, RandAugment, CutMix, and MixUp being some of them. In this short guide, we've taken a look at how you can use RandAugment to apply a number of random transformations drawn from a given list of transformations, and how easy it is to include in any Keras training pipeline.
