Introduction
In recent years, hardware vendors have introduced various datatypes and specialized instructions and cores to speed up scientific computing tasks such as deep learning. Many approaches use reduced-precision datatypes in the computations to increase throughput, while trying not to affect the convergence behaviour of the overall algorithm.
In this tutorial, we discuss the datatypes and tensor operations supported on A100 GPUs, how they can be leveraged to increase throughput in deep learning tasks, and their potential impact on convergence.
TL;DR: how to get the best performance out of the A100 GPUs
- These tips are largely based on the NVIDIA cuDNN developer guide.
- These tips are provided as-is, i.e. they are intended for advanced users. If you are unsure how these changes may affect your model convergence, we suggest some caution before applying them.
- Use the lowest precision data format possible; this will usually be Half-Precision (FP16).
- Convert your model and batched data to the NHWC (2D) or NDHWC (3D) data format. This is the default in TensorFlow, but requires a few extra lines of code in PyTorch (see the sketch after this list): model.to(memory_format=torch.channels_last); ...; batch = batch.to(memory_format=torch.channels_last). Note that model.to(...) is an in-place operation, while batch.to(...) is not, thus requiring assignment.
- Use one of the following activation functions: {relu, tanh, sigmoid, elu, gelu, softplus, swish}.
- When using convolutions:
  - Make sure the size of your channel dim is a multiple of 8 (i.e. equals 0 mod 8) whenever possible.
  - Make sure the size of your channel dim is larger than or equal to 32 whenever possible.
- When using grouped convolutions:
  - Make sure the size of your input channel dim is equal to the size of your output channel dim.
  - Make sure the size of your input channel dim equals one of the following: {4, 8, 16, 32} for cuDNN v8.2, or {1, 4, 8, 16, 32, 64, 128, 256} for cuDNN v8.3.
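For the channels-last conversion mentioned in the list above, a minimal PyTorch sketch could look as follows (the toy convolutional model, batch shape and use of autocast are illustrative assumptions, not part of the original tips):

```python
import torch

# Minimal sketch: convert a model and a batch to the channels-last (NHWC) memory format.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 64, kernel_size=3, padding=1),
    torch.nn.ReLU(),
).cuda()
model = model.to(memory_format=torch.channels_last)   # converts the parameters in place

batch = torch.randn(8, 3, 224, 224, device='cuda')
batch = batch.to(memory_format=torch.channels_last)   # returns a new tensor, so assignment is required

with torch.cuda.amp.autocast():                        # use FP16 where beneficial, as per the tips above
    out = model(batch)
```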
Datatypes
Various datatypes will be discussed in this tutorial, so we introduce them here; the short code sketch after this list shows how to inspect their ranges and precisions.
- Double-Precision (FP64): a floating point number represented by 64 bits. According to the IEEE 754 standard, 1 bit is used for the sign, 11 are used for the exponent, and 52 bits are used for the fraction (also known as the mantissa).
- Single-Precision (FP32): a floating point number represented by 32 bits. According to the IEEE 754 standard, 1 bit is used for the sign, 8 are used for the exponent, and 23 are used for the fraction.
- Half-Precision (FP16): a floating point number represented by 16 bits. According to the IEEE 754 standard, 1 bit is used for the sign, 5 are used for the exponent, and 10 are used for the fraction.
- Bfloat16: a floating point number represented by 16 bits. It uses 1 bit for the sign, 8 for the exponent, and 7 for the fraction. Thus, it is a truncated version of FP32 that allows a similar range, but smaller precision.
- TensorFloat32 (TF32): a floating point number represented by 19 bits. It uses 1 bit for the sign, 8 for the exponent, and 10 for the fraction. Thus, it has similar range to FP32 and similar precision to FP16.
- n-bit integer (often denoted as INTn): an integer represented by n bits. Typically, 4, 8, 16, 32, 64 or 128 bits are used, depending on how large the range is that needs to be represented.
- Binary: 1 bit, a 0 or 1.
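To get a feel for the trade-off between range and precision of these floating point formats, you can inspect them directly in PyTorch (a small illustrative sketch; TF32 is an internal compute mode rather than a storage datatype, so it has no torch.finfo entry):

```python
import torch

# Print the number of bits, largest representable value and machine epsilon
# (a proxy for precision) of the floating point datatypes discussed above.
for dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(f'{str(dtype):15s} bits={info.bits:3d} max={info.max:.3e} eps={info.eps:.3e}')
```

Note how bfloat16 has (almost) the same maximum value as FP32 but a much larger epsilon, whereas FP16 has a smaller epsilon but a far more limited range.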
What are Tensor Cores?
Tensor Cores are cores that specialize in General Matrix-Matrix Multiplication (GEMM) operations, i.e.
D = A × B + C
(where A, B, C and D are matrices), which are at the core of neural network training and inference.
NVIDIA Volta Tensor Cores
Tensor Cores were first introduced in the NVIDIA Volta GPUs, where each Tensor Core could execute 64 FP16 fused multiply-add (FMA) operations with accumulation in FP32 in a single clock cycle. Thus, Tensor Cores in Volta GPUs were able to multiply two 4×4 matrices (A and B, both in FP16), add a 4×4 matrix (C, in FP32) and store the result in a 4×4 matrix (D, in FP32) in one clock cycle. This combination of different numerical precisions became known as mixed precision, and typically refers to this mix of using FP16 for multiplication and FP32 for accumulation.
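Conceptually, such a mixed-precision GEMM can be emulated with explicit casts. The sketch below mirrors the 4×4 operation described above; note that it is only an approximation of the hardware behaviour (the emulation rounds each product to FP32, while the Tensor Core accumulates full-precision products):

```python
import torch

# Emulate D = A x B + C with FP16 inputs and FP32 accumulation.
A = torch.randn(4, 4, dtype=torch.float16)
B = torch.randn(4, 4, dtype=torch.float16)
C = torch.randn(4, 4, dtype=torch.float32)

D = A.float() @ B.float() + C   # multiply the FP16 inputs, add/accumulate in FP32
print(D.dtype)                  # torch.float32
```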
NVIDIA Ampere Tensor Cores
Ampere Tensor Cores differ from Volta Tensor Cores in two fundamental ways:
- They can operate on larger matrices. For example, they can execute 256 FP16 FMA operations in a single clock cycle, and thus perform a GEMM operation where A is 8×8 and B, C and D are 8×4.
- They can operate on more datatypes: FP64, TF32, FP16, BF16, INT8, INT4, Binary are all supported as input types (note that FP32 is not supported).
Sparse Matrix Multiply-Accumulate (MMA) operations
The A100 GPUs provide hardware support for MMA operations on matrices that satisfy a very specific sparsity pattern: if at most 2 out of every 4 (row-wise) elements are non-zero, the sparse MMA operation can be used to double the maximum throughput. At the time of writing (December 2021), software support is limited to the low-level cuSPARSELt library, which allows you to exploit these instructions; higher-level frameworks like PyTorch and TensorFlow do not (yet) appear to support this. More on sparse MMA can be found in the NVIDIA Ampere whitepaper.
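To illustrate what this 2-out-of-4 pattern looks like, the sketch below prunes a made-up weight matrix to that structure (this only shows the sparsity pattern itself, not the cuSPARSELt API):

```python
import torch

# In every group of 4 consecutive elements along a row, keep the 2 largest-magnitude
# values and zero out the other 2.
w = torch.randn(4, 8)                                    # example weight matrix
groups = w.abs().reshape(-1, 4)                          # view rows as groups of 4
_, drop = groups.topk(2, dim=1, largest=False)           # indices of the 2 smallest magnitudes per group
mask = torch.ones_like(groups).scatter_(1, drop, 0.0).reshape(w.shape)
print(w * mask)                                          # at most 2 non-zeros per group of 4
```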
Theoretical performance of A100 GPUs
If you have ever looked at the theoretical peak performance of A100 GPUs, you may have been confused by how many entries the peak performance table lists, and in which cases you can expect which performance. The following table (based on Table 3 of the NVIDIA Ampere whitepaper) clarifies the expected performance of MMA operations based on the input precision (i.e. the datatype of A and B) and the accumulator precision (typically determined by the datatype of C and D):
Input | Accumulator | Performance | SPARSE MMA performance |
---|---|---|---|
FP64 | FP64 | 19.5 TFLOPS | - |
TF32 | FP32 | 156 TFLOPS | 312 TFLOPS |
FP16 | FP32 | 312 TFLOPS | 624 TFLOPS |
FP16 | FP16 | 312 TFLOPS | 624 TFLOPS |
Bfloat16 | FP32 | 312 TFLOPS | 624 TFLOPS |
INT8 | INT32 | 624 TOPS | 1248 TOPS |
INT4 | INT32 | 1248 TOPS | 2496 TOPS |
Binary | INT32 | 4992 TOPS | - |
Table: Theoretical performance of MMA operations on a single A100 GPU (source: NVIDIA AMPERE whitepaper). TFLOPS: Tera (10^12) floating point operations per second. TOPS: Tera (non-floating point) operations per second.
Of course, not all operations you want to perform are MMA operations. For 'normal' floating point (and integer) arithmetic, the regular GPU cores are used.
Precision | Performance |
---|---|
FP64 | 9.7 TFLOPS |
FP32 | 19.5 TFLOPS |
BF16 | 39 TFLOPS |
FP16 | 78 TFLOPS |
INT32 | 19.5 TOPS |
Table: Theoretical performance of non-MMA operations on a single A100 GPU (source: NVIDIA AMPERE whitepaper).
Real-world performance of A100 GPUs
We use the following benchmark script to illustrate the performance difference between pure FP32, TensorFloat32, and mixed precision (i.e. FP16 inputs and FP32 accumulators) when training a network from tf.keras.applications on synthetic data:
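A minimal sketch of what such a benchmark script (benchmark.py) could look like; the --model, --mixed-prec and --disable-tf32 arguments follow the description in the text, while the batch size, image size, number of steps and optimizer are assumptions made for illustration:

```python
import argparse
import time

import tensorflow as tf
from tensorflow.keras import mixed_precision

parser = argparse.ArgumentParser()
parser.add_argument('--model', default='ResNet50',
                    help='Name of a model in tf.keras.applications (e.g. ResNet50, VGG19, DenseNet121)')
parser.add_argument('--mixed-prec', action='store_true',
                    help='Train with FP16 inputs and FP32 accumulators')
parser.add_argument('--disable-tf32', action='store_true',
                    help='Disable TensorFloat32, i.e. run in pure FP32')
args = parser.parse_args()

if args.disable_tf32:
    tf.config.experimental.enable_tensor_float_32_execution(False)
if args.mixed_prec:
    mixed_precision.set_global_policy('mixed_float16')

batch_size, num_steps, num_classes = 64, 100, 1000

# Synthetic data: random images and labels
images = tf.random.uniform((batch_size, 224, 224, 3))
labels = tf.random.uniform((batch_size,), maxval=num_classes, dtype=tf.int32)

model = getattr(tf.keras.applications, args.model)(weights=None, classes=num_classes)
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.SGD()
if args.mixed_prec:
    optimizer = mixed_precision.LossScaleOptimizer(optimizer)


@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        loss = loss_object(y, model(x, training=True))
        # With loss scaling enabled, compute the gradients on the scaled loss
        loss_for_grad = optimizer.get_scaled_loss(loss) if args.mixed_prec else loss
    gradients = tape.gradient(loss_for_grad, model.trainable_variables)
    if args.mixed_prec:
        gradients = optimizer.get_unscaled_gradients(gradients)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss


train_step(images, labels)  # warm-up / tracing step, excluded from the timing
start = time.time()
for _ in range(num_steps):
    loss = train_step(images, labels)
print(f'Loss: {loss.numpy()}, throughput: {batch_size * num_steps / (time.time() - start):.1f} img/s')
```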
Then, we allocate a single A100:
salloc -p gpu -n 1 --ntasks-per-node 1 --gpus 1 --cpus-per-task 18 -t 8:00:00
Next, we use ssh to connect to the allocated node and run the benchmark script with the following environment:
```
module load 2021
module load TensorFlow/2.6.0-foss-2021a-CUDA-11.3.1
module list
export OMP_NUM_THREADS=18
python benchmark.py
```
By varying the model, mixed-prec and disable-tf32 arguments, we run with different precisions and models to construct the following table:
Model | Precision | Throughput (img/s) | Speedup (compared to FP32) | Loss (1st iteration) |
---|---|---|---|---|
ResNet50 | FP32 | 455.9 | 1 | 7.457645893096924 |
ResNet50 | TF32 | 750.6 | 1.65 | 7.456507205963135 |
ResNet50 | FP16 (input) + FP32 (accumulator) | 1087.6 | 2.38 | 7.45703125 |
VGG19 | FP32 | 212.7 | 1 | 6.9077839851379395 |
VGG19 | TF32 | 550.3 | 2.59 | 6.907783508300781 |
VGG19 | FP16 (input) + FP32 (accumulator) | 1099.8 | 5.17 | 6.90625 |
DenseNet121 | FP32 | 391.9 | 1 | 6.96142053604126 |
DenseNet121 | TF32 | 591.5 | 1.51 | 6.961450576782227 |
DenseNet121 | FP16 (input) + FP32 (accumulator) | 876.2 | 2.24 | 6.9609375 |
A few key results to note:
- Speedup of mixed precision or TF32 over traditional FP32 varies per model
- Speedup is much smaller than the theoretical difference in throughput from the tables in the previous section (but still very substantial!)
- Loss is affected by the reduced precision. Note that this is not a problem in itself: as long as the convergence behaviour of the training is not affected, the reduced precision is fine. In their published benchmark results, NVIDIA has demonstrated that a large number of well-known models indeed converge properly using mixed precision.
Training with TensorFloat32
Because TensorFloat32 covers the same range as traditional FP32, training in TF32 can easily be done as a drop-in replacement. In fact, NVIDIA has made the use of TF32 the default for any cuDNN call, and both TensorFlow and PyTorch use TF32 by default. TensorFlow will (depending on the verbosity level you set for logging) also inform you explicitly that it will use TF32, e.g.:
2021-11-30 12:17:30.251449: I tensorflow/stream_executor/cuda/cuda_blas.cc:1760] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
Disabling the use of TensorFloat32 (debugging only)
In some cases, you might want to disable the use of TensorFloat32: for example, if you are debugging convergence issues and want to make sure the datatype is not the problem, or if you want to compare convergence behaviour between two machines in order to validate a run, but only one of those machines supports TensorFloat32.
General environment variable
Setting the environment variable NVIDIA_TF32_OVERRIDE=0 before running your code should in principle disable the use of the TensorFloat32 datatype. All low-level CUDA libraries respect this variable; for higher-level frameworks that use the CUDA libraries as a backend, it may depend on the specific framework.
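One way to make sure the variable is set before the CUDA libraries are initialised is to set it at the very top of your script, before importing the framework (a sketch; exporting the variable in your job script before launching Python works just as well):

```python
import os

# Must be set before the framework (and the CUDA libraries it loads) is imported.
os.environ['NVIDIA_TF32_OVERRIDE'] = '0'

import torch  # noqa: E402  (imported after setting the environment on purpose)
```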
TensorFlow
TensorFlow does not seem to respect the general NVIDIA_TF32_OVERRIDE variable. To turn off the use of TensorFloat32 in TensorFlow, you have to disable it explicitly by calling
tf.config.experimental.enable_tensor_float_32_execution(False)
in your code. See here.
PyTorch
PyTorch does respect the NVIDIA_TF32_OVERRIDE environment variable. However, you can also turn it off explicitly in your code using
```python
# The flag below controls whether to allow TF32 on matmul. This flag defaults to True.
torch.backends.cuda.matmul.allow_tf32 = False
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = False
```
Training with Mixed Precision
Training with mixed precision is less trivial than using TensorFloat32. The main reason is that the range of values FP16 can represent is smaller. This can be a problem particularly for small gradients, which may fall below the FP16-representable range (for more information, see the histograms in NVIDIA's documentation on mixed precision training). This can be solved by so-called loss scaling. Essentially, in loss scaling, the loss is multiplied by a factor S between the forward and backward propagation steps. Then, after the backward propagation, the weight gradients are multiplied by 1/S before doing the weight update.
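A minimal sketch of what this looks like when done by hand in PyTorch (the tiny model, data and the fixed value of S are made up for illustration; in practice the scale is usually chosen dynamically, as the framework utilities below do):

```python
import torch

S = 2.0 ** 12                               # loss scale factor (fixed here for simplicity)
net = torch.nn.Linear(16, 4)
opt = torch.optim.SGD(net.parameters(), lr=0.01)

x, target = torch.randn(8, 16), torch.randn(8, 4)
loss = torch.nn.functional.mse_loss(net(x), target)

(loss * S).backward()                       # scale the loss before backpropagation
for p in net.parameters():
    p.grad.mul_(1.0 / S)                    # multiply the gradients by 1/S before the update
opt.step()
opt.zero_grad()
```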
Of course, this procedure can be done manually, but many frameworks support some form of automatic loss scaling. Below, we summarize the key parts of the TensorFlow and PyTorch documentation on mixed precision training, but we encourage you to read their respective user manual sections to get the full picture.
TensorFlow
The official documentation contains an extensive section on using mixed precision in TensorFlow.
To enable mixed precision, you have to set the global policy:
```python
from tensorflow.keras import mixed_precision

mixed_precision.set_global_policy('mixed_float16')
```
If you train using the tf.keras.Model.fit API, that is all you need to do: this API automatically performs loss scaling if the 'mixed_float16' policy is set. If, however, you implement a custom training loop (like in our benchmark example above), you have to wrap the optimizer in the tf.keras.mixed_precision.LossScaleOptimizer class like so:
```python
# Any keras optimizer, use RMSprop as example:
optimizer = keras.optimizers.RMSprop()
# Wrap in LossScaleOptimizer to perform loss scaling
optimizer = mixed_precision.LossScaleOptimizer(optimizer)
```
If you want, you can specify an explicit loss scale, but it is recommended to keep the default loss scaling behavior of this optimizer. Finally, you have to insert the scaling step after calculating the loss, compute the gradients on the scaled loss, and then get the unscaled gradients:
```python
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        predictions = model(x)
        loss = loss_object(y, predictions)
        # Scale loss:
        scaled_loss = optimizer.get_scaled_loss(loss)
    # Compute gradients on scaled loss:
    scaled_gradients = tape.gradient(scaled_loss, model.trainable_variables)
    # Invert the scaling before applying the gradients
    gradients = optimizer.get_unscaled_gradients(scaled_gradients)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss
```
PyTorch
PyTorch has an Automatic Mixed Precision package (AMP). The official documentation of the API can be found here, but using mixed precision in PyTorch is explained more extensively in their PyTorch recipe section.
Typically, automatic mixed precision training uses torch.cuda.amp.autocast together with torch.cuda.amp.GradScaler. The first, torch.cuda.amp.autocast, ensures that each operation runs in an op-specific dtype chosen by autocast; it aims to use FP16 for an op's inputs whenever Tensor Cores are expected to make that op faster. The second, torch.cuda.amp.GradScaler, automatically scales the gradients to prevent underflow. Altogether, your code would typically look like this:
```python
use_amp = True

net = make_model(in_size, out_size, num_layers)
opt = torch.optim.SGD(net.parameters(), lr=0.001)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

for epoch in range(epochs):
    for input, target in zip(data, targets):
        # This context manager makes sure the dtypes in 'net' are set to support
        # mixed precision Tensor Core operations as much as possible
        with torch.cuda.amp.autocast(enabled=use_amp):
            output = net(input)
            loss = loss_fn(output, target)
        # scaler.scale(loss) returns the scaled loss, before backward() is called
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        opt.zero_grad()  # set_to_none=True here can modestly improve performance
```
If you want to inspect or modify gradients (e.g. clipping), this requires you to unscale the gradients in between the backward() and the step(...) calls. See the official documentation for details.
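For example, gradient clipping in between backward() and step(...) could look as follows (the tiny model and data are made up for illustration; scaler.unscale_ and clip_grad_norm_ are the standard PyTorch calls for this):

```python
import torch

net = torch.nn.Linear(16, 4).cuda()
opt = torch.optim.SGD(net.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(8, 16, device='cuda')
target = torch.randn(8, 4, device='cuda')

with torch.cuda.amp.autocast():
    loss = torch.nn.functional.mse_loss(net(x), target)

scaler.scale(loss).backward()
scaler.unscale_(opt)                                    # unscale the gradients of opt's parameters in place
torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)   # clip the (now unscaled) gradients
scaler.step(opt)                                        # skips the internal unscaling, since it was done explicitly
scaler.update()
opt.zero_grad()
```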
Which datatype should I use?
There is little reason not to use TensorFloat32: it is (much) faster than FP32 and, since the range is similar, it does not require measures like loss scaling, so no code changes are needed. The only reason not to use it would be the reduced precision of the fraction, which in theory could affect convergence behaviour. Practical experience so far has shown that convergence behaviour with TensorFloat32 for deep learning is generally not altered (have you ever wondered whether FP32 was precise enough to make your training converge?). Only if you experience convergence issues could you try disabling it to see if that helps - but even if it does, there are probably other steps you could take to make your training more stable that would have a smaller effect on training speed.
Training in mixed precision is more involved. It requires some code changes (though frameworks have automated a lot for you) and you have to think carefully about where you inspect or modify gradients. It does, however, provide a substantially larger speedup than TF32, which can be particularly important for compute-intensive tasks such as hyperparameter tuning. It is therefore worth trying mixed precision. Mixed precision has successfully been used to train a large number of well-known networks to proper convergence. If convergence does prove to be an issue for your particular task, switch back to TF32 and see if that helps. If it does, check your loss scaling code, inspect the scaled/unscaled losses, and verify that nothing gets clipped due to underflow.
Sources:
NVIDIA deep learning performance tutorial
NVIDIA tensor core performance tutorial