tf.function is a decorator provided by TensorFlow 2.0 that converts regular Python code into a callable TensorFlow graph function, which is usually more performant and independent of Python. It is used to create portable TensorFlow models.
Table of Contents:
2. Why do we need graphs?
3. How can you create graphs in TF2.0
4. Example – tf.function can significantly reduce the code run time
5. Flow control – changes made by Autograph
6. Side effects of using tf.function you must be aware of
7. Polymorphic functions
8. Limitations in speed up from using tf.function
TensorFlow released the second major version of the library in September 2019. This version, popularly called TensorFlow 2.0, marked significant changes from the previous TensorFlow 1.x versions.
TF2.0 allows users to build dynamic computation graphs through a feature called eager execution. Even though eager execution is widely preferred for easier debugging and because it removes the need for tf.Session calls, there are cases where you might still want lazy execution (static computation graphs), for example when you want to improve run-time performance.
This is why TF2.0 has the tf.function API: it gives you the option to convert regular (eager) Python code into lazily executed, speed-optimized code.
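As a small sketch of the difference between the two modes, the snippet below records the value of tf.executing_eagerly() outside and inside a tf.function. The names eager_flag, graph_flags and traced are made up for this illustration, and the snippet assumes default settings (functions are not forced to run eagerly).

```python
import tensorflow as tf

# Outside any tf.function, TF2.0 runs operations eagerly.
eager_flag = tf.executing_eagerly()

# Plain Python list used to record what happens while the function is traced.
graph_flags = []

@tf.function
def traced(x):
    # This Python line runs while the graph is being built (tracing).
    graph_flags.append(tf.executing_eagerly())
    return x + 1

result = traced(tf.constant(1))

print("Eager outside tf.function:", eager_flag)
print("Eager inside tf.function:", graph_flags)
print("Result:", result.numpy())
```

The flag recorded inside the function reflects graph construction: the body is run once to build the graph, and the graph is what gets executed afterwards.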
This tutorial will bring you up to speed with:
- All that tf.function has to offer
- Why many developers still use graphs for deployment and
- How tf.function can actually speed up your code.
2. Why do we need graphs?
As you already read above, eager execution is a better choice for easy debugging and more intuitive programming in Python. It is the same as writing regular Python code: you can run your code line by line in a console or as a script, and debug it with your usual Python debugging tools.
However, running TensorFlow code step-by-step (as in eager execution) in Python prevents a host of accelerations otherwise available in the lazy mode. If you can extract tensor computations from Python, you can make them into a graph.
You may ask, what are graphs in the first place?
A graph is a data structure that contains tensors and the computations performed on them. It stores the flow of information between operations through tf.Operation objects, which represent units of computation, and tf.Tensor objects, which represent the data passed between those operations.
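To make this concrete, here is a minimal sketch that traces a small function and lists the operations stored in its graph. The function add_and_scale is a hypothetical example, and the exact operation type names printed can differ between TensorFlow versions.

```python
import tensorflow as tf

@tf.function
def add_and_scale(x, y):
    # Two tensor computations: an addition followed by a multiplication.
    return (x + y) * 2.0

# Trace the function for float32 scalars to obtain a concrete graph.
concrete = add_and_scale.get_concrete_function(
    tf.TensorSpec(shape=(), dtype=tf.float32),
    tf.TensorSpec(shape=(), dtype=tf.float32),
)
graph = concrete.graph

# Every node in the graph is a tf.Operation.
operations = graph.get_operations()
op_types = [op.type for op in operations]
print(op_types)

# The outputs of each operation are tf.Tensor objects.
for op in operations:
    print(op.name, "->", [t.dtype.name for t in op.outputs])
```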
But, why use Graphs?
- The primary reason is that graphs allow your neural network model to be used in environments that don’t have a Python interpreter. For example, graphs can be deployed in mobile applications or on servers. Eagerly executed code is not suitable for this, because it needs Python to run.
- The second reason is that graphs can speed up computation. They eliminate the need for repetitive initialisation of variables and computation on these variables. This information is stored once, when the graph is defined, and all new tensors and variables then make use of this existing graph.
This might not make sense right now, but you will see this speedup in action in an upcoming section.
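The first reason, portability, can be sketched with the SavedModel format. The example below saves a tiny module whose method is wrapped in tf.function and then loads it back. The class name Scaler and the temporary export directory are assumptions made for this illustration, not something from the original article.

```python
import tempfile

import tensorflow as tf

class Scaler(tf.Module):
    """Hypothetical module used only for this illustration."""

    def __init__(self):
        super().__init__()
        self.factor = tf.Variable(3.0)

    # The input signature fixes the graph that is saved with the model.
    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        return x * self.factor

# Save the module, including its traced graph, to a temporary directory.
export_dir = tempfile.mkdtemp()
tf.saved_model.save(Scaler(), export_dir)

# The restored object runs the saved graph; it does not use the Scaler class.
restored = tf.saved_model.load(export_dir)
restored_output = restored(tf.constant([1.0, 2.0])).numpy().tolist()
print(restored_output)
```

The same saved directory is what other TensorFlow runtimes consume, which is why the graph, rather than the Python source, is the unit of deployment.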
3. How can you create graphs in TF2.0?
The tf.function API is used in TF2.0 to create graphs from eagerly executed code.
There are two ways you can use it.
1. As a decorator: Placing the @tf.function decorator before a function definition creates a graph for that function.
Let’s look at an example.
```python
# Decorate using @tf.function
import tensorflow as tf
import time
from datetime import datetime

@tf.function
def function(x):
    a = tf.constant([[2.0], [3.0]])
    b = tf.constant(4.0)
    return a+b
```
You can see that we have used the @tf.function decorator. This means that a graph is created for this function the first time it is called. Let’s test it by calling the function with some input and then visualising it using TensorBoard.
```python
# Plot a graph for function() using Tensorboard
stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = 'logs/func/%s' % stamp
writer = tf.summary.create_file_writer(logdir)

tf.summary.trace_on(graph=True, profiler=True)
# Call only one tf.function when tracing.
z = function(2)
with writer.as_default():
    tf.summary.trace_export(
        name="function_trace",
        step=0,
        profiler_outdir=logdir)
```
```
%load_ext tensorboard
%tensorboard --logdir logs/func
```
This opens the TensorBoard window, where you can see the graph for the function we created.
Let’s look at the other way of creating graphs with tf.function.
2. As a callable function: In this method you simply tf.function-ise an existing function to create a graph for that function.
Practically, this is the same as what applying a decorator to a function does. This example simply makes it explicit.
```python
# Define a Python function
def callable_func(a, b):
    return tf.matmul(a, b)

# Create a `Function` object that contains a graph
function_that_uses_graph = tf.function(callable_func)

# Make some tensors
a1 = tf.constant([[1.0, 2.0]])
b1 = tf.constant([[2.0], [3.0]])

# It works!
function_that_uses_graph(a1, b1).numpy()
```
4. Example – tf.function can significantly reduce the code run time
The question that arises is: how much does tf.function speed up operations in TensorFlow?
Let’s measure this speedup by timing the same code as it is and then with tf.function.
```python
import timeit

import tensorflow as tf

class SequentialModel(tf.keras.Model):
    def __init__(self, **kwargs):
        super(SequentialModel, self).__init__(**kwargs)
        self.flatten = tf.keras.layers.Flatten(input_shape=(28, 28))
        self.dense_1 = tf.keras.layers.Dense(128, activation="relu")
        self.dropout = tf.keras.layers.Dropout(0.2)
        self.dense_2 = tf.keras.layers.Dense(10)

    def call(self, x):
        x = self.flatten(x)
        x = self.dense_1(x)
        x = self.dropout(x)
        x = self.dense_2(x)
        return x

input_data = tf.random.uniform([60, 28, 28])

eager_model = SequentialModel()
# Wrap the same model in tf.function to run it as a graph
graph_model = tf.function(eager_model)

print("Eager time:", timeit.timeit(lambda: eager_model(input_data), number=10000))
print("Graph time:", timeit.timeit(lambda: graph_model(input_data), number=10000))
```
```
Eager time: 7.824499414999991
Graph time: 5.808633186999941
```
In the above code snippet, we have implemented a classification model made of several small layers. You will see a noticeable speedup from tf.function only for functions that perform many operations. For small functions called a few times on a single machine, there won’t be much effect.
This is why, for the above model called 10,000 times, you can see a considerable difference in execution time between eager mode and graph mode.
5. Flow control – changes made by Autograph
When you use tf.function, flow control and loops written in Python are converted to TensorFlow-compatible code via tf.autograph by default.
```python
# Define a new function and `tf.function`-ise it
def my_function(x):
    if tf.reduce_sum(x) <= 1:
        return x * x
    else:
        return x-1

a_function = tf.function(my_function)

print("First branch, with graph:", a_function(tf.constant(1.0)).numpy())
print("Second branch, with graph:", a_function(tf.constant([5.0, 5.0])).numpy())
```
```
First branch, with graph: 1.0
Second branch, with graph: [4. 4.]
```
You can directly call the Autograph conversion to see how Python is converted into TensorFlow compatible code. This is, mostly, unreadable, but you can see the transformation.
```python
# See how autograph converts python code to make it TF-compatible
print(tf.autograph.to_code(my_function))
```
```python
def tf__my_function(x):
    with ag__.FunctionScope('my_function', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (do_return, retval_)

        def set_state(vars_):
            nonlocal do_return, retval_
            (do_return, retval_) = vars_

        def if_body():
            nonlocal do_return, retval_
            try:
                do_return = True
                retval_ = (ag__.ld(x) * ag__.ld(x))
            except:
                do_return = False
                raise

        def else_body():
            nonlocal do_return, retval_
            try:
                do_return = True
                retval_ = (ag__.ld(x) - 1)
            except:
                do_return = False
                raise
        ag__.if_stmt((ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) <= 1), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
        return fscope.ret(retval_, do_return)
```
Autograph performs this transformation for Python control flow statements such as for loops, while loops and if-else statements.
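The same conversion applies to loops. As a minimal sketch, the function below (sum_up_to is a name made up for this example) uses an ordinary Python for loop over tf.range, which Autograph turns into a loop inside the graph. The snippet also looks for while-loop operations in the traced graph; the exact operation names are an implementation detail and may vary between TensorFlow versions.

```python
import tensorflow as tf

@tf.function
def sum_up_to(n):
    total = tf.constant(0)
    # A Python for loop over a tensor range; Autograph converts it to a graph loop.
    for i in tf.range(n):
        total += i
    return total

loop_result = sum_up_to(tf.constant(5))
print("Sum of 0..4:", loop_result.numpy())

# Inspect the traced graph for loop operations (names may differ by version).
concrete = sum_up_to.get_concrete_function(tf.TensorSpec(shape=(), dtype=tf.int32))
loop_ops = [op.type for op in concrete.graph.get_operations() if "While" in op.type]
print("Loop operations in graph:", loop_ops)
```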
6. Side-effects of using tf.function you must be aware of
When you call a function that you’ve decorated with tf.function, Python side effects such as printing, appending to lists and mutating globals happen only while the function is being traced, which is typically the first time it is called with a given kind of input. On later calls they are skipped.
Side Effect 1: print() works only once
After the first call, the traced tf.Graph is re-executed without running the Python code again.
Let’s look at a code example to gain a better understanding.
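```python
# See in action how the print statement executes only once due to Python side effect!
@tf.function
def f(x):
    print("Traced with", x)        # Python print: runs only while tracing
    tf.print("Executed with", x)   # TensorFlow op: runs on every call

f(1)
f(1)
```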
```
Traced with 1
Executed with 1
Executed with 1
```
The print() statement was executed only the first time f(1) was called. The second time, a graph had already been traced, so the Python code wasn’t re-executed and the statement ‘Traced with 1’ does not appear again.
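Appending to a Python list behaves the same way as print(). The sketch below uses a made-up function double and a list called trace_log to count how often the Python body actually runs across three calls with the same input type.

```python
import tensorflow as tf

# Python list that is only modified while the function body runs as Python.
trace_log = []

@tf.function
def double(x):
    trace_log.append("traced")  # Python side effect: happens during tracing only
    return x * 2.0

# Three calls with float32 scalars reuse the graph traced on the first call.
outputs = [double(tf.constant(v)).numpy() for v in (1.0, 2.0, 3.0)]

print("Outputs:", outputs)
print("Number of times the Python body ran:", len(trace_log))
```

The computation itself runs on every call; only the list append is skipped once the graph exists.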
Side Effect 2: Iterator works only once
Another example of a Python side effect involves generators and iterators inside a tf.function code block.
In the case of an iterator, the iterator state advances only once, during tracing. Let’s see an example.
```python
# Iterator side effect in action; the value of iterator advances only once during tracing
var = tf.Variable(0)

@tf.function
def side_effect(iterator):
    var.assign_add(next(iterator))
    tf.print("Value of variable:", var)

iterator = iter([0, 1, 2, 3])
side_effect(iterator)

# This reuses the first value from the iterator, rather than consuming the next value.
side_effect(iterator)
side_effect(iterator)
```
```
Value of variable: 0
Value of variable: 0
Value of variable: 0
```
Without the tf.function decorator, the results look like this:
```python
# Results of using iterator without tf.function
var = tf.Variable(0)

def side_effect(iterator):
    var.assign_add(next(iterator))
    tf.print("Value of variable:", var)

iterator = iter([0, 1, 2, 3])
side_effect(iterator)
side_effect(iterator)
side_effect(iterator)
```
```
Value of variable: 0
Value of variable: 1
Value of variable: 3
```
The general rule of thumb is to use Python side effects only to debug your traces. Otherwise, TensorFlow functions like tf.print are the best way to ensure your code will be executed by TensorFlow on every call.
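One way to apply this rule to the iterator case is to iterate over a tf.data.Dataset instead of a Python iterator, so that the iteration itself becomes part of the graph. This is a sketch of one possible alternative rather than a fix prescribed by the original example; the function name accumulate is made up.

```python
import tensorflow as tf

@tf.function
def accumulate(dataset):
    total = tf.constant(0)
    # Iterating over a tf.data.Dataset is converted into graph operations,
    # so the loop runs on every call and not only during tracing.
    for value in dataset:
        total += value
    return total

dataset = tf.data.Dataset.from_tensor_slices([0, 1, 2, 3])

first_total = accumulate(dataset)
second_total = accumulate(dataset)

print("First call:", first_total.numpy())
print("Second call:", second_total.numpy())
```

Both calls walk through all four elements, because no Python iterator state is involved.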
7. Polymorphic Functions
tf.function is polymorphic.
That means a function can work on different datatypes, and for each new datatype it traces the Python function again to create a new graph.
TF stores a distinct graph for each datatype for a given function, and you can access each of them individually.
Let’s understand this through an example.
```python
# Define a function using tf.function decorator
@tf.function
def poly_func(a):
    print("Tracing with", a)
    return a

# Test its polymorphism
print(poly_func(tf.constant(1)))
print(poly_func(tf.constant(1.1)))
print(poly_func(tf.constant("a")))
```
```
Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)
Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(1.1, shape=(), dtype=float32)
Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'a', shape=(), dtype=string)
```
Here the function poly_func has been traced three times: once for the integer datatype, once for the float datatype and once for the string datatype.
If we call the same function with any of these datatypes again, no retracing occurs and the existing graph is used.
```python
# This doesn't print 'Tracing with ...'
print(poly_func(tf.constant(3.6)))
```
tf.Tensor(3.6, shape=(), dtype=float32)
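With default settings, the shape of the input is also part of what tf.function uses to decide whether an existing graph can be reused. The sketch below records each trace in a Python list; the names identity and trace_log are made up for this illustration.

```python
import tensorflow as tf

# Filled only while the function is being traced.
trace_log = []

@tf.function
def identity(a):
    trace_log.append((a.dtype.name, a.shape.as_list()))
    return a

identity(tf.constant([1.0, 2.0]))       # float32, shape [2]: first trace
identity(tf.constant([3.0, 4.0]))       # same dtype and shape: graph is reused
identity(tf.constant([1.0, 2.0, 3.0]))  # same dtype, new shape: traced again

print(trace_log)
```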
You can use pretty_printed_concrete_signatures() to see all of the available polymorphic forms:
```python
# Print Function signatures
print(poly_func.pretty_printed_concrete_signatures())
```
```
poly_func(a)
  Args:
    a: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

poly_func(a)
  Args:
    a: int32 Tensor, shape=()
  Returns:
    int32 Tensor, shape=()

poly_func(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()
```
Further, you can obtain a concrete function for the graph of a particular datatype by using get_concrete_function. The function obtained this way works only with the specified datatype.
Below we obtain the concrete function for the integer datatype.
```python
# Get the concrete func
print("Obtaining concrete trace")
polyfunc_int = poly_func.get_concrete_function(tf.constant(3))

print("Executing traced function")
print(polyfunc_int(tf.constant(5)))
print(polyfunc_int(a=tf.constant(6)))
```
```
Obtaining concrete trace
Executing traced function
tf.Tensor(5, shape=(), dtype=int32)
tf.Tensor(6, shape=(), dtype=int32)
```
If you call polyfunc_int with any dtype other than integer, you will get an error. Below, let’s pass a float when an integer is expected:
```python
# A concrete function throws errors for other data types
print(polyfunc_int(tf.constant(7.4)))
```
```
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-21-af362321f1a2> in <module>()
----> 1 print(polyfunc_int(tf.constant(7.4)))

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)
   1653       TypeError: If the arguments do not match the function's signature.
   1654     """
...(TRUNCATED)...
InvalidArgumentError: cannot compute __inference_poly_func_166 as input #0(zero-based) was expected to be a int32 tensor but is a float tensor [Op:__inference_poly_func_166]
```
In practice, you will rarely need to use concrete functions.
8. Limitations in speed-up from using tf.function
Just wrapping a tensor-using function in tf.function does not automatically speed up your code.
For small functions called a few times on a single machine, the overhead of calling a graph or graph fragment may dominate runtime. Also, if most of the computation was already happening on an accelerator, such as stacks of GPU-heavy convolutions, the graph speedup won’t be large.
For complicated computations, graphs can provide a significant speedup. This is because graphs reduce Python-to-device communication and allow TensorFlow to apply some optimizations.
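If you want to check how a small function behaves on your own machine, a rough timing sketch like the one below can help. The function tiny_op and the call count of 1000 are arbitrary choices for this illustration, and the measured times depend heavily on hardware and TensorFlow version, so no particular outcome is implied.

```python
import timeit

import tensorflow as tf

def tiny_op(x):
    # A single cheap operation, where call overhead can dominate.
    return x + 1.0

tiny_op_graph = tf.function(tiny_op)

x = tf.constant(1.0)

# Call once before timing so that tracing is not included in the measurement.
tiny_op_graph(x)

eager_time = timeit.timeit(lambda: tiny_op(x), number=1000)
graph_time = timeit.timeit(lambda: tiny_op_graph(x), number=1000)

print("Eager time:", eager_time)
print("Graph time:", graph_time)
```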