Build faster state-of-the-art NLP models in TensorFlow 2


In the world of NLP, we have seen enormous progress in the last few years, and it can easily be divided into a pre-BERT era and a post-BERT era. Transformer-based models have dominated the field for the last 3 years. Many variations of transformer-based models address particular drawbacks, but the core idea remains the same. From 124M-parameter models to 13B-parameter models, progress has been rapid. This progress poses a challenge: how can these huge models be used in production, especially by startups and mid-tier companies? This is where optimisation and clever engineering come into the picture.

With the release of huggingface, things have become quite accessible and easy for ordinary users. With its aim of democratizing NLP, huggingface is one of the greatest things to happen to NLP users. As an avid NLP user, I find huggingface great, but its TensorFlow support is minimal, and that pushed me to dig deeper and find where things are not going well with TensorFlow. I have been working on this for the last year, in my personal time.

Modifying the huggingface source code was not feasible, because making TF models serializable while keeping them general was a hard task. But with a few tricks and some compromises, tf-transformers can be used to solve almost all NLP problems.

tf-transformers is TensorFlow with Nitro-boost.


We benchmarked tf-transformers for text generation using T5 and GPT2 models. Text generation with transformer models requires efficient caching, and serializing the whole model takes considerable effort. Compared with the huggingface TensorFlow implementation, tf-transformers is 80% faster (relative difference) on text-generation tasks, and it is even faster than the official Google implementations.
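To make the caching point concrete, here is a minimal NumPy sketch (not tf-transformers code) of one toy self-attention layer decoded greedily, with and without a key/value cache. The random weights, the toy sizes and the helper names `attend`, `decode_without_cache` and `decode_with_cache` are all hypothetical. It shows that caching past keys and values yields the same tokens as recomputing them for the whole prefix at every step.

```python
import numpy as np

rng = np.random.default_rng(0)
VOCAB, DIM = 50, 16                        # hypothetical toy sizes
emb = rng.normal(size=(VOCAB, DIM))        # token embeddings
wq, wk, wv = (rng.normal(size=(DIM, DIM)) for _ in range(3))
w_out = rng.normal(size=(DIM, VOCAB))      # hidden state -> vocabulary logits


def attend(q, k, v):
    """Scaled dot-product attention of one query vector over all key/value rows."""
    scores = k @ q / np.sqrt(DIM)          # (seq_len,)
    weights = np.exp(scores - scores.max())
    return (weights / weights.sum()) @ v   # (DIM,)


def decode_without_cache(prompt, steps):
    """Greedy decoding that recomputes keys/values for the full prefix each step."""
    tokens = list(prompt)
    for _ in range(steps):
        x = emb[tokens]                    # (seq_len, DIM)
        k, v = x @ wk, x @ wv              # recomputed from scratch every step
        h = attend(x[-1] @ wq, k, v)       # only the last position predicts
        tokens.append(int(np.argmax(h @ w_out)))
    return tokens


def decode_with_cache(prompt, steps):
    """Greedy decoding that appends one key/value row per generated token."""
    tokens = list(prompt)
    x = emb[tokens]
    k_cache, v_cache = x @ wk, x @ wv      # the prompt is processed once
    for _ in range(steps):
        h = attend(emb[tokens[-1]] @ wq, k_cache, v_cache)
        next_id = int(np.argmax(h @ w_out))
        tokens.append(next_id)
        # Only the newest token is projected; older rows are reused.
        k_cache = np.vstack([k_cache, emb[next_id] @ wk])
        v_cache = np.vstack([v_cache, emb[next_id] @ wv])
    return tokens


print(decode_with_cache([3, 7, 1], steps=5))
```

The cached version projects only the newest token per step, so the key/value work no longer grows with the prefix; keeping this cache intact inside a serialized model is what makes fast generation hard to engineer.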

Greedy decoding: columns are (batch_size, sequence_length).
Figure: Greedy Decode (GPT2, few samples)
Figure: Greedy Decode (full comparison)

Beam decoding: columns are (batch_size, sequence_length, beam_size).
Figure: Beam Decode (GPT2)
Figure: Beam Decode (full comparison; refer to GitHub for a high-quality image)

Top-k top-p sampling: columns are (batch_size, sequence_length, num_return_sequences).
Figure: Top-K-Top-P (GPT2)
Figure: Top-K-Top-P (full comparison; refer to GitHub for a high-quality image)

For full benchmark results and code, please refer to GitHub. tf-transformers outperforms huggingface transformers in all experiments. Compared with PyTorch, tf-transformers is faster in 179 of 220 experiments, though not by a huge margin. Similar results hold for T5 models. All experiments were run on a V100 GPU.


tf-transformers supports TFLite conversion. All operations inside tf-transformers are designed to be compatible with TFLite ops. As of now, all BERT-based models, T5 and mT5 can be converted to TFLite. Conversion is supported for all tasks, including QA, NER and classification, except text-generation (auto-regressive) tasks. Refer to the tutorials.

Variable Batch Decoding

This is another unique feature of tf-transformers which, at the time of writing (6 months ago), was not part of any other library. Batch decoding (text generation) is much faster than decoding each example individually. But doing it for encoder-only models like GPT2, BERT, RoBERTa and ELECTRA is a bit tricky: the examples in a batch have different lengths, and accounting for these lengths while caching is not straightforward to implement in practice.

tf-transformers supports variable batch_size decoding even in serialized models, which makes decoding even faster.
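One common way to decode examples of different lengths in a single batch (a sketch of the general idea, not necessarily how tf-transformers implements it) is to left-pad the prompts so every example's last real token sits in the same column, and to derive an attention mask and per-example position ids from the padding. `PAD_ID`, `prepare_batch` and `next_step` are hypothetical names.

```python
import numpy as np

PAD_ID = 0  # hypothetical padding token id


def prepare_batch(prompts):
    """Left-pad variable-length prompts and build the mask and position ids."""
    max_len = max(len(p) for p in prompts)
    input_ids = np.full((len(prompts), max_len), PAD_ID, dtype=np.int64)
    mask = np.zeros((len(prompts), max_len), dtype=np.int64)
    for row, prompt in enumerate(prompts):
        input_ids[row, max_len - len(prompt):] = prompt   # left padding
        mask[row, max_len - len(prompt):] = 1
    # Positions count only real tokens, so padding does not shift them.
    position_ids = np.maximum(np.cumsum(mask, axis=1) - 1, 0)
    return input_ids, mask, position_ids


def next_step(mask, position_ids):
    """Extend mask and positions by one generated token per example."""
    ones = np.ones((mask.shape[0], 1), dtype=mask.dtype)
    mask = np.concatenate([mask, ones], axis=1)
    next_pos = position_ids[:, -1:] + 1
    return mask, np.concatenate([position_ids, next_pos], axis=1)


ids, mask, pos = prepare_batch([[5, 6, 7], [8]])
print(ids.tolist())   # [[5, 6, 7], [0, 0, 8]]
print(pos.tolist())   # [[0, 1, 2], [0, 0, 0]]
```

Because every row ends in the same column, each step's new key/value can be written at the same cache index for the whole batch, which keeps the cache a dense tensor; the mask stops real tokens from attending to padded slots.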

Multiple Mask Mode

Transformer models are built around masking, and the mask is very significant to model performance. Models like GPT2 are pure language models (LM), where next-word prediction (at training time) depends only on the past. BERT-based models, by contrast, are masked language models (MLM), where bi-directional context is essential. By changing the mask, we can change any model's behaviour; for example, we can give GPT2 bi-directional context just by changing the mask value.

There are 3 mask_mode values: causal, user_defined and prefix. By default, GPT2 uses causal masking. Just by changing it to prefix, we can use GPT2 for text-generation tasks like summarisation, where bi-directional context over the input is always better. For MLM, user_defined masking should be used. Switching modes only requires changing one argument when initializing the model.
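The three behaviours can be pictured as attention-mask matrices, where entry (i, j) is 1 if position i may attend to position j. This NumPy sketch follows the common definitions of these masks and is not the library's code; the function names and the `prefix_len` / `padding_mask` arguments are hypothetical. The last function shows the kind of fully bi-directional mask an MLM would use, which is what a user-defined mask would typically supply.

```python
import numpy as np


def causal_mask(seq_len):
    """GPT2-style: each position sees itself and earlier positions only."""
    return np.tril(np.ones((seq_len, seq_len), dtype=np.int32))


def prefix_mask(seq_len, prefix_len):
    """Bi-directional inside the first prefix_len tokens, causal afterwards."""
    mask = causal_mask(seq_len)
    mask[:, :prefix_len] = 1   # every position may look at the whole prefix
    return mask


def bidirectional_mask(padding_mask):
    """MLM-style: every position attends to every non-padding token."""
    padding_mask = np.asarray(padding_mask, dtype=np.int32)
    n = padding_mask.size
    return np.broadcast_to(padding_mask, (n, n)).copy()


print(prefix_mask(4, 2))
```

In the prefix mask, the first two (input) tokens see each other in both directions, while the generated tokens still cannot see the future, which is why it suits summarisation-style generation.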

Fast Sentence Piece Alignment

tf-transformers has fast SentencePiece alignment for tasks like QA and NER, where keeping track of sub-word positions is critical. The fast alignment in tf-transformers is approximate, but it is much faster than the LCS method (which is also approximate).

On the SQuAD v1 training examples, the LCS method takes 2300 seconds, whereas fast alignment takes only ~300 seconds.
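The tf-transformers alignment algorithm is not described here, so the following is only a hypothetical illustration of what a fast, approximate alignment can look like: a single left-to-right pass that strips the `▁` word-boundary marker from each SentencePiece piece and advances a character cursor through the original text. It runs in linear time, unlike an LCS-based alignment, and is approximate because it assumes the pieces spell out the text (apart from whitespace). The function names and the hand-written piece list are assumptions, not real tokenizer output.

```python
def align_pieces(text, pieces, marker="\u2581"):
    """Map SentencePiece pieces to (start, end) character spans in text.

    Single left-to-right pass (linear time). It assumes the pieces spell out
    the text apart from whitespace and the word-boundary marker; where
    normalization changed characters, the spans are only approximate.
    """
    spans = []
    cursor = 0
    for piece in pieces:
        surface = piece.replace(marker, "")
        # Whitespace in the text is represented by the marker, so skip it.
        while cursor < len(text) and text[cursor].isspace():
            cursor += 1
        end = min(cursor + len(surface), len(text))
        spans.append((cursor, end))
        cursor = end
    return spans


def char_span_to_token_span(spans, char_start, char_end):
    """First and last token indices overlapping [char_start, char_end), e.g. a QA answer."""
    hits = [i for i, (s, e) in enumerate(spans) if s < char_end and e > char_start]
    return (hits[0], hits[-1]) if hits else None


text = "Hello worldwide"
spans = align_pieces(text, ["\u2581Hello", "\u2581world", "wide"])
print(spans)                                  # [(0, 5), (6, 11), (11, 15)]
print(char_span_to_token_span(spans, 6, 15))  # (1, 2): "worldwide" covers pieces 1..2
```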

Encoder Decoder Models

Any model can be converted into a Seq2Seq-style model; in transformer terms, these are encoder-decoder models. For example, we can use BERT as the encoder and GPT2 as the decoder, so any encoder model can also behave as a decoder with minor changes. LegacyAI has complete support for serializing the whole model: both the encoder and the decoder are converted into a single saved_model.

Any encoder model (BERT, GPT2, etc.) can be converted into decoder mode, with extra cross-attention layers, using just a few keyword arguments. If the encoder hidden size differs from the decoder hidden size, the encoder hidden states are automatically projected by a randomly initialized layer, which can be fine-tuned together with the rest of the model.
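A minimal NumPy sketch of that idea: decoder positions cross-attend over encoder hidden states, and when the encoder width differs from the decoder width, the encoder states first pass through a randomly initialized projection (in a real model it would be trained along with everything else). The sizes, the initialization scale and the function names are assumptions for illustration, and single-head attention without biases keeps it short.

```python
import numpy as np

rng = np.random.default_rng(0)


def make_projection(d_enc, d_dec):
    """Random projection from encoder width to decoder width (None if equal)."""
    if d_enc == d_dec:
        return None
    return rng.normal(scale=d_enc ** -0.5, size=(d_enc, d_dec))


def cross_attention(dec_states, enc_states, proj, wq, wk, wv):
    """Single-head cross-attention: decoder queries, encoder keys and values."""
    if proj is not None:
        enc_states = enc_states @ proj                 # (enc_len, d_dec)
    q, k, v = dec_states @ wq, enc_states @ wk, enc_states @ wv
    scores = q @ k.T / np.sqrt(q.shape[-1])            # (dec_len, enc_len)
    scores -= scores.max(axis=-1, keepdims=True)       # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v                                 # (dec_len, d_dec)


d_enc, d_dec = 1024, 768           # a wider encoder feeding a narrower decoder
proj = make_projection(d_enc, d_dec)
wq, wk, wv = (rng.normal(scale=d_dec ** -0.5, size=(d_dec, d_dec)) for _ in range(3))
enc = rng.normal(size=(10, d_enc))  # 10 encoder positions
dec = rng.normal(size=(4, d_dec))   # 4 decoder positions
out = cross_attention(dec, enc, proj, wq, wk, wv)
print(out.shape)                    # (4, 768)
```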

Keras + model.compile2

All the models are trainable using Keras. compile2 is an extra feature of tf-transformers that avoids unnecessary internal flattening of data: you can pass outputs to Keras in their raw format, so you have complete control over the loss functions. This makes the models trainable on GPU, multi-GPU and TPU pods.
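To illustrate why receiving raw, un-flattened outputs matters, here is a NumPy stand-in (not Keras, compile2 or tf-transformers code) for a span-extraction loss that reads named entries from dicts of labels and predictions. The key names `start_logits`, `end_logits`, `start_positions` and `end_positions` are assumptions for illustration. With outputs kept as a dict, one loss function can combine several heads however it likes.

```python
import numpy as np


def sparse_cross_entropy(logits, targets):
    """Mean cross-entropy of integer targets under a row-wise softmax of logits."""
    logits = logits - logits.max(axis=-1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    return float(-log_probs[np.arange(len(targets)), targets].mean())


def span_loss(y_true, y_pred):
    """QA loss over dicts of raw outputs: average of start and end losses."""
    start = sparse_cross_entropy(y_pred["start_logits"], y_true["start_positions"])
    end = sparse_cross_entropy(y_pred["end_logits"], y_true["end_positions"])
    return (start + end) / 2


y_pred = {"start_logits": np.zeros((2, 4)), "end_logits": np.zeros((2, 4))}
y_true = {"start_positions": np.array([0, 3]), "end_positions": np.array([1, 2])}
print(span_loss(y_true, y_pred))  # uniform logits over 4 positions give log(4)
```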

There is also a custom trainer for when we are using only a single-GPU machine.

Note: compile2 doesn't support metrics yet.

Super Fast Decoders via Serialization

tf-transformers supports complete serialization of models, which makes auto-regressive tasks seamless. All the necessary caching, whether for encoder models or encoder-decoder models (T5, BART, etc.), can be done in 2 lines. Refer to the tutorials.

Write/Read/Process TFRecords in ~5 lines

One of the great features of TensorFlow is TFRecords, which are very important for training models. Writing TFRecords in tf-transformers is easy and straightforward: all you have to do is pass your data function to TFWriter. Similarly, TFReader reads the data back.
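For readers curious about what a TFRecord file contains, here is a dependency-free Python sketch of the record framing that `tf.io.TFRecordWriter` produces: an 8-byte little-endian length, a masked CRC-32C of that length, the payload bytes, and a masked CRC-32C of the payload. The `write_records` / `read_records` helpers and the JSON payloads are hypothetical stand-ins (real files usually hold serialized `tf.train.Example` protos), and this is not how TFWriter or TFReader are implemented. The writer takes a generator function, in the same "pass your data function" style.

```python
import json
import os
import struct
import tempfile


def crc32c(data: bytes) -> int:
    """Bitwise CRC-32C (Castagnoli), the checksum used by the TFRecord format."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc ^= byte
        for _ in range(8):
            crc = (crc >> 1) ^ (0x82F63B78 if crc & 1 else 0)
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    """TFRecord stores a rotated-and-offset CRC rather than the raw value."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def write_records(path, data_fn):
    """Write every dict yielded by data_fn() as one framed record."""
    with open(path, "wb") as f:
        for example in data_fn():
            payload = json.dumps(example).encode("utf-8")
            header = struct.pack("<Q", len(payload))
            f.write(header)
            f.write(struct.pack("<I", masked_crc(header)))
            f.write(payload)
            f.write(struct.pack("<I", masked_crc(payload)))


def read_records(path):
    """Yield the decoded dicts back, verifying both checksums of each record."""
    with open(path, "rb") as f:
        while header := f.read(8):
            (length,) = struct.unpack("<Q", header)
            (length_crc,) = struct.unpack("<I", f.read(4))
            payload = f.read(length)
            (payload_crc,) = struct.unpack("<I", f.read(4))
            if length_crc != masked_crc(header) or payload_crc != masked_crc(payload):
                raise ValueError("corrupted record")
            yield json.loads(payload)


def data_fn():
    """Hypothetical data function yielding one feature dict per example."""
    for i in range(3):
        yield {"input_ids": [101, i, 102], "label": i % 2}


path = os.path.join(tempfile.mkdtemp(), "train.tfrecord")
write_records(path, data_fn)
print(list(read_records(path)))
```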

TFProcessor is useful when you want to evaluate a model's performance on a dev set, where shuffling is not required. It fully supports tf.ragged tensors. All of these tools have out-of-the-box auto-batching support, which works in ~90% of cases.
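Dev-set evaluation usually means iterating in a fixed order and padding variable-length examples into dense batches. Here is a small NumPy sketch of that kind of auto-batching (an illustration under assumptions, not TFProcessor itself): examples keep their original order, are grouped into fixed-size batches, and each list-valued feature is right-padded to the longest sequence in its batch. `auto_batch` and the pad value are hypothetical.

```python
import numpy as np


def auto_batch(examples, batch_size, pad_value=0):
    """Group dict examples in order and pad list-valued features per batch."""
    for start in range(0, len(examples), batch_size):
        chunk = examples[start:start + batch_size]        # no shuffling
        batch = {}
        for key in chunk[0]:
            values = [ex[key] for ex in chunk]
            if isinstance(values[0], (list, tuple)):      # ragged feature
                width = max(len(v) for v in values)
                padded = np.full((len(values), width), pad_value)
                for row, v in enumerate(values):
                    padded[row, :len(v)] = v
                batch[key] = padded
            else:                                         # scalar feature
                batch[key] = np.asarray(values)
        yield batch


dev = [{"input_ids": [1, 2, 3], "label": 0},
       {"input_ids": [4], "label": 1},
       {"input_ids": [5, 6], "label": 1}]
for batch in auto_batch(dev, batch_size=2):
    print(batch["input_ids"].tolist(), batch["label"].tolist())
```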

HuggingFace Converters

As HF transformers is the most widely used NLP library, tf-transformers lets you convert HF transformers (TensorFlow) models into tf-transformers models in just one line.


tf-transformers has pipelines that support span extraction (QA), classification, token classification (NER) and text-generation tasks. All pipelines can be customized with a few lines, and users can pass a custom tokenization function to preprocess inputs however they want.
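The benefit of accepting a custom tokenization function is that preprocessing stays outside the pipeline. Below is a hypothetical sketch of that design; `ClassificationPipeline`, the toy tokenizer and the toy model are all invented for illustration and are not the tf-transformers API. The pipeline fixes only the order (tokenize, run the model, pick a label), and everything input-specific is injected.

```python
from typing import Callable, Dict, List, Sequence

import numpy as np


class ClassificationPipeline:
    """Hypothetical pipeline: tokenizer_fn -> model_fn -> argmax label."""

    def __init__(self,
                 model_fn: Callable[[Dict[str, np.ndarray]], np.ndarray],
                 tokenizer_fn: Callable[[Sequence[str]], Dict[str, np.ndarray]],
                 labels: List[str]):
        self.model_fn = model_fn
        self.tokenizer_fn = tokenizer_fn
        self.labels = labels

    def __call__(self, texts: Sequence[str]) -> List[str]:
        inputs = self.tokenizer_fn(texts)   # user-controlled preprocessing
        logits = self.model_fn(inputs)      # (batch, num_labels)
        return [self.labels[i] for i in np.argmax(logits, axis=-1)]


def word_count_tokenizer(texts):
    """Toy 'tokenizer': counts whitespace-separated words per text."""
    return {"num_tokens": np.array([len(t.split()) for t in texts])}


def toy_model(inputs):
    """Toy 'model': texts longer than 5 words score higher on the second label."""
    n = inputs["num_tokens"]
    return np.stack([5 - n, n - 5], axis=-1)


pipe = ClassificationPipeline(toy_model, word_count_tokenizer, ["short", "long"])
print(pipe(["hi there", "this sentence has quite a lot of words in it"]))
```

Swapping in a different tokenizer changes preprocessing without touching the pipeline or the model.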

tf-transformers supports BERT, ALBERT, RoBERTa, T5, mT5 and GPT2. There is still a wide gap between typical tutorials and industry applications in NLP. The primary focus of tf-transformers is to bridge that gap without compromising speed or ease of use.


In production, we use only serialized models, which are not only faster but also avoid the overhead of shipping model-architecture code at deployment time.

Tutorials will follow on GitHub (release .