Improving the performance of a training loop can save hours of computing time when training machine learning models. One way to speed up TensorFlow code is the tf.function() decorator - a simple, one-line change that can make your functions run significantly faster.
In this short guide, we will explain how tf.function() improves performance and take a look at some best practices.

Python Decorators and tf.function()
In Python, a decorator is a function that modifies the behavior of other functions. For instance, suppose you call the following function in a notebook cell:
```python
import tensorflow as tf

x = tf.random.uniform(shape=[100, 100], minval=-1, maxval=1, dtype=tf.dtypes.float32)

def some_costly_computation(x):
    aux = tf.eye(100, dtype=tf.dtypes.float32)
    result = tf.zeros(100, dtype=tf.dtypes.float32)
    for i in range(1, 100):
        aux = tf.matmul(x, aux) / i
        result = result + aux
    return result

%timeit some_costly_computation(x)
```
```
16.2 ms ± 103 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
However, if we pass the costly function into tf.function(), we get quicker_computation() - a new function that performs much faster than the previous one:

```python
quicker_computation = tf.function(some_costly_computation)
%timeit quicker_computation(x)
```

```
4.99 ms ± 139 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
In other words, tf.function() takes the some_costly_computation() function and outputs the quicker_computation() function. Since decorators also modify functions, it was natural to make tf.function() a decorator as well.
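To make the decorator idea concrete, here is a minimal sketch of a hand-written decorator in plain Python (the timed decorator and slow_add function are illustrative names, not part of TensorFlow):

```python
import time

def timed(func):
    # A decorator receives a function and returns a new, modified function
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        print(f"{func.__name__} took {time.perf_counter() - start:.6f}s")
        return result
    return wrapper

@timed  # equivalent to writing: slow_add = timed(slow_add)
def slow_add(a, b):
    time.sleep(0.01)
    return a + b

slow_add(1, 2)  # prints the elapsed time, returns 3
```

tf.function() works the same way: it takes your function and returns a new one backed by a computation graph.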
Using the decorator notation is the same as calling tf.function() on the function:

```python
@tf.function
def quick_computation(x):
    aux = tf.eye(100, dtype=tf.dtypes.float32)
    result = tf.zeros(100, dtype=tf.dtypes.float32)
    for i in range(1, 100):
        aux = tf.matmul(x, aux) / i
        result = result + aux
    return result

%timeit quick_computation(x)
```
```
5.09 ms ± 283 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
How come we can make certain functions run 2-3x faster?
TensorFlow code can be run in two modes: eager mode and graph mode. Eager mode is the standard, interactive way to run code: every time you call a function, it is executed.
Graph mode, however, is a little bit different. In graph mode, before executing the function, TensorFlow creates a computation graph, which is a data structure containing the operations required for executing the function. The computation graph allows TensorFlow to simplify the computations and find opportunities for parallelization. The graph also isolates the function from the surrounding Python code, allowing it to be run efficiently on many different devices.
A function decorated with @tf.function is executed in two steps:
- In the first step, TensorFlow executes the Python code for the function and compiles a computation graph, delaying the execution of any TensorFlow operation.
- Afterwards, the computation graph is run.
Note: The first step is known as "tracing".
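You can peek at the graph produced by tracing through get_concrete_function(); a minimal sketch (the add_one function here is just an illustration):

```python
import tensorflow as tf

@tf.function
def add_one(x):
    return x + 1

# Tracing for a specific input signature yields a ConcreteFunction,
# whose .graph attribute holds the compiled computation graph
concrete = add_one.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.float32))
print([op.type for op in concrete.graph.get_operations()])
# The list contains graph ops such as AddV2 - the graph-mode version of x + 1
```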
The first step will be skipped if there is no need to create a new computation graph. This improves the performance of the function but also means that the function will not execute like regular Python code (in which each executable line is executed). For example, let's modify our previous function:
```python
@tf.function
def quick_computation(x):
    print('Only prints the first time!')
    aux = tf.eye(100, dtype=tf.dtypes.float32)
    result = tf.zeros(100, dtype=tf.dtypes.float32)
    for i in range(1, 100):
        aux = tf.matmul(x, aux) / i
        result = result + aux
    return result

quick_computation(x)
quick_computation(x)
```
This results in:
```
Only prints the first time!
```
print() is only executed once, during the tracing step, which is when regular Python code is run. Subsequent calls to the function execute only the TensorFlow operations from the computation graph.
However, if we use tf.print() instead:
```python
@tf.function
def quick_computation_with_print(x):
    tf.print("Prints every time!")
    aux = tf.eye(100, dtype=tf.dtypes.float32)
    result = tf.zeros(100, dtype=tf.dtypes.float32)
    for i in range(1, 100):
        aux = tf.matmul(x, aux) / i
        result = result + aux
    return result

quick_computation_with_print(x)
quick_computation_with_print(x)
```
```
Prints every time!
Prints every time!
```
TensorFlow includes tf.print() in its computation graph, as it's a TensorFlow operation - not a regular Python function.
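Note that tracing is not necessarily a one-time event: calling the function with a new input signature - for example a tensor of a different dtype or shape - forces TensorFlow to trace a new graph, so Python-level side effects run again. A minimal sketch of this behavior (traced_sum is an illustrative name):

```python
import tensorflow as tf

@tf.function
def traced_sum(x):
    print("Tracing!")  # Python side effect: runs only while a new graph is traced
    return tf.reduce_sum(x)

traced_sum(tf.constant([1.0, 2.0]))  # traces a graph for float32 input, prints
traced_sum(tf.constant([3.0, 4.0]))  # same signature: cached graph, no print
traced_sum(tf.constant([1, 2]))      # int32 input: new signature, traces again
```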
<div class="alert alert-warn"> <div class="flex"> <strong>Warning:</strong> Not all Python code is executed in every call to a function decorated with
@tf.function. After tracing, only the operations from the computational graph are run, which means some care must be taken in our code. </div> </div> <h3 id="bestpracticeswithtffunction">Best Practices with
Writing Code with TensorFlow Operations
As we've just seen with print(), some parts of the code are ignored by the computation graph, which makes the behavior of "normal" Python code hard to predict after tracing. It is better to write your function with TensorFlow operations when applicable to avoid unexpected behavior.
Similarly, Python for and while loops may or may not be converted into the equivalent TensorFlow loop during tracing. Therefore, it is better to write your loop as a vectorized operation, if possible. This will improve the performance of your code and ensure that your function traces correctly.
As an example, consider the following:
```python
x = tf.random.uniform(shape=[100, 100], minval=-1, maxval=1, dtype=tf.dtypes.float32)

@tf.function
def function_with_for(x):
    summ = float(0)
    for row in x:
        summ = summ + tf.reduce_mean(row)
    return summ

@tf.function
def vectorized_function(x):
    result = tf.reduce_mean(x, axis=0)
    return tf.reduce_sum(result)

print(function_with_for(x))
print(vectorized_function(x))

%timeit function_with_for(x)
%timeit vectorized_function(x)
```
```
tf.Tensor(0.672811, shape=(), dtype=float32)
tf.Tensor(0.67281103, shape=(), dtype=float32)
1.58 ms ± 177 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
440 µs ± 8.34 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```
The code with the TensorFlow operations is considerably faster.
Avoid References to Global Variables
Consider the following code:
```python
x = tf.Variable(2, dtype=tf.dtypes.float32)
y = 2

@tf.function
def power(x):
    return tf.pow(x, y)

print(power(x))
y = 3
print(power(x))
```
```
tf.Tensor(4.0, shape=(), dtype=float32)
tf.Tensor(4.0, shape=(), dtype=float32)
```
The first time the decorated function power() was called, the output value was the expected 4. However, on the second call, the function ignored that the value of y had changed. This happens because the values of Python global variables are frozen for the function after tracing.
A better approach is to use tf.Variable() for all your variables and pass both as arguments to your function:
```python
x = tf.Variable(2, dtype=tf.dtypes.float32)
y = tf.Variable(2, dtype=tf.dtypes.float32)

@tf.function
def power(x, y):
    return tf.pow(x, y)

print(power(x, y))
y.assign(3)
print(power(x, y))
```
```
tf.Tensor(4.0, shape=(), dtype=float32)
tf.Tensor(8.0, shape=(), dtype=float32)
```
Debugging @tf.function
In general, you want to debug your functions in eager mode, and then decorate them with @tf.function once the code is running correctly, because the error messages in eager mode are more informative.
Some common problems are type errors and shape errors. Type errors happen when there is a mismatch in the type of the variables involved in an operation:
```python
x = tf.Variable(1, dtype=tf.dtypes.float32)
y = tf.Variable(1, dtype=tf.dtypes.int32)
z = tf.add(x, y)
```
```
InvalidArgumentError: cannot compute AddV2 as input #1(zero-based) was expected to be a float tensor but is a int32 tensor [Op:AddV2]
```
Type errors easily creep in, and can be fixed by casting a variable to a different type:
```python
y = tf.cast(y, tf.dtypes.float32)
z = tf.add(x, y)
tf.print(z)  # 2
```
Shape errors happen when your tensors do not have the shape your operation requires:
```python
x = tf.random.uniform(shape=[100, 100], minval=-1, maxval=1, dtype=tf.dtypes.float32)
y = tf.random.uniform(shape=[1, 100], minval=-1, maxval=1, dtype=tf.dtypes.float32)
z = tf.matmul(x, y)
```
```
InvalidArgumentError: Matrix size-incompatible: In[0]: [100,100], In[1]: [1,100] [Op:MatMul]
```
One convenient tool for fixing both kinds of errors is the interactive Python debugger, which you can invoke automatically in a Jupyter Notebook using %pdb. With it, you can code your function and run it through some common use cases. If there is an error, an interactive prompt opens; this prompt lets you move up and down the abstraction layers of your code and check the values, types, and shapes of your TensorFlow variables.
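Another option worth knowing is tf.config.run_functions_eagerly(), which temporarily disables graph execution so that decorated functions run as ordinary Python and tools like print() and pdb work on every call; a minimal sketch (the step function is just an illustration):

```python
import tensorflow as tf

@tf.function
def step(x):
    print("Visible on every call while running eagerly")
    return tf.reduce_mean(x)

tf.config.run_functions_eagerly(True)   # debug mode: no tracing, no cached graph
step(tf.ones([2, 2]))
step(tf.ones([2, 2]))                   # the print fires again

tf.config.run_functions_eagerly(False)  # switch graph execution back on
```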
We've seen how TensorFlow's tf.function() makes your functions more efficient, and how the @tf.function decorator makes it easy to apply to your own functions.
This speed-up is useful in functions that will be called many times, such as custom training steps for machine learning models.
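As an illustration, a custom training step might look like the following minimal sketch; the model, optimizer, loss, and data here are hypothetical placeholders:

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Input(shape=(4,)),
                             tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
loss_fn = tf.keras.losses.MeanSquaredError()

@tf.function  # traced on the first call, then the cached graph is reused
def train_step(x, y):
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

x = tf.random.uniform([8, 4])
y = tf.random.uniform([8, 1])
loss = train_step(x, y)  # every later call skips tracing and runs the graph
```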