How Transformer works in Machine Translation (with code)
Transformers are models that take a sequence as input and produce a sequence as output. They are used in many NLP tasks, such as machine translation, text summarization, and question answering. In this post, we will focus on how the Transformer works in machine translation, and we will break down the code of the Transformer model that ships with PyTorch, nn.Transformer.
Structure of Transformer
Transformer models are based on the paper Attention Is All You Need by Vaswani et al. (2017), which proposed a new architecture, the Transformer, built solely on attention mechanisms. The Transformer is composed of an encoder and a decoder. In machine translation, the encoder takes the source-language sentence as input and outputs a vector representation of it; the decoder then generates the target-language sentence token by token, conditioned on that representation.
The image above shows the structure of the Transformer model.
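To make the encoder-decoder split concrete before diving into each half, here is a minimal sketch using PyTorch's nn.Transformer. The shapes and hyperparameters are illustrative, and the random tensors stand in for source and target sentences that have already been embedded and positionally encoded (those steps are covered next):

import torch
import torch.nn as nn

# Encoder-decoder Transformer with the hyperparameters of the original paper
model = nn.Transformer(d_model=512, nhead=8,
                       num_encoder_layers=6, num_decoder_layers=6,
                       batch_first=True)

src = torch.rand(1, 10, 512)   # (batch, source length, d_model), already embedded
tgt = torch.rand(1, 7, 512)    # (batch, target length, d_model), already embedded

# In real training you would also pass a causal tgt_mask (see the Decoder section)
out = model(src, tgt)
print(out.shape)               # torch.Size([1, 7, 512])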
Encoder
Suppose the input sentence has 10 words and is tokenized into 10 tokens, for example ['I', 'am', 'a', 'student', 'studying', 'Electrical', 'and', 'Computer', 'Engineering', '.']. The tokenizer then maps these tokens to 10 integer ids, for example $[21, 42, 33, 14, 65, 26, 57, 28, 19, 40]$. These ids are passed through the embedding layer, a lookup table (equivalent to a linear layer applied to one-hot vectors) that maps each id to a 512-dimensional vector, producing a $[10, 512]$ matrix.
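As a quick sketch of the embedding step (the vocabulary size below is a made-up number just large enough to cover the example ids):

import torch
import torch.nn as nn

# Token ids produced by the tokenizer for the example sentence
token_ids = torch.tensor([21, 42, 33, 14, 65, 26, 57, 28, 19, 40])

# The embedding layer is a lookup table with one d_model-dimensional vector per vocabulary entry.
# num_embeddings=100 is an assumption for this toy example; real vocabularies are much larger.
embedding = nn.Embedding(num_embeddings=100, embedding_dim=512)

embedded = embedding(token_ids)
print(embedded.shape)  # torch.Size([10, 512])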
Then, the output of the embedding layer is passed to the positional encoding layer. The positional encoding layer adds positional information to the input sequence. The output of the positional encoding layer is then passed to the encoder layers.
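Note that nn.Transformer does not add positional information for you, so in practice you apply a module like the following to the embeddings yourself. This is a sketch of the standard sinusoidal encoding from the paper:

import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding as in Vaswani et al. (2017); a sketch, not the library's own module."""
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)                 # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)                  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)                  # odd dimensions
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (seq_len, d_model) or (batch, seq_len, d_model) with the sequence on dim -2
        x = x + self.pe[:x.size(-2)]
        return self.dropout(x)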
The encoder nn.TransformerEncoder is a stack of N nn.TransformerEncoderLayer modules, typically 6. Each nn.TransformerEncoderLayer attends to the output of the previous layer with a self-attention sub-layer, followed by layer normalization, then a feed-forward sub-layer and another layer normalization. In the code, these correspond to the self-attention layer self.self_attn and the feed-forward layer, which is a simple combination of two linear layers, self.linear1 and self.linear2.
class TransformerEncoderLayer(Module):
    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 device=None, dtype=None) -> None:
        super(TransformerEncoderLayer, self).__init__()
        # Multi-head self-attention sub-layer
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
        # Implementation of Feedforward model: two linear layers with dropout in between
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)
        # One LayerNorm and one Dropout per sub-layer
        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.activation = activation
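Stacking N = 6 of these layers into the full encoder is a one-liner with nn.TransformerEncoder; a short sketch with illustrative sizes (d_model=512, 8 heads, as in the original paper):

import torch
import torch.nn as nn

encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8,
                                           dim_feedforward=2048, batch_first=True)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)

src = torch.rand(1, 10, 512)   # embedded + positionally encoded source sentence
memory = encoder(src)          # the "memory" later consumed by the decoder
print(memory.shape)            # torch.Size([1, 10, 512])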
The decoder, which we will look at below, is built the same way: it consists of N nn.TransformerDecoderLayer modules, typically 6, each with three sub-layers: a multi-head self-attention layer, a multi-head attention layer over the encoder output, and a simple, position-wise fully connected feed-forward network.
In nn.TransformerEncoderLayer, the input src passes through the self-attention block _sa_block and the feed-forward block _ff_block.
def forward(self, src: Tensor, src_mask: Optional[Tensor] = None,
            src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
    '''
    Only keeping the essential code here
    '''
    x = src
    if self.norm_first:
        # Pre-norm: normalize, apply the sub-layer, then add the residual
        x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
        x = x + self._ff_block(self.norm2(x))
    else:
        # Post-norm (as in the original paper): add the residual, then normalize
        x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
        x = self.norm2(x + self._ff_block(x))
    return x
# self-attention block
def _sa_block(self, x: Tensor,
              attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
    # Query, key and value are all the same tensor x: that is what makes it self-attention
    x = self.self_attn(x, x, x,
                       attn_mask=attn_mask,
                       key_padding_mask=key_padding_mask,
                       need_weights=False)[0]
    return self.dropout1(x)

# feed forward block
def _ff_block(self, x: Tensor) -> Tensor:
    x = self.linear2(self.dropout(self.activation(self.linear1(x))))
    return self.dropout2(x)
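As a usage sketch, the masks in the forward signature let the layer ignore padding positions; the batch and sequence sizes below are illustrative:

import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)

# A batch of 2 source sentences padded to length 10; True marks padding positions
# that self-attention should ignore.
src = torch.rand(2, 10, 512)
src_key_padding_mask = torch.zeros(2, 10, dtype=torch.bool)
src_key_padding_mask[1, 7:] = True   # second sentence only has 7 real tokens

out = layer(src, src_key_padding_mask=src_key_padding_mask)
print(out.shape)  # torch.Size([2, 10, 512])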
Decoder
Compared to the encoder, the decoder has one extra attention sub-layer, called cross-attention, which attends to the output of the encoder. In essence, this means the output of the decoder depends on the output of the encoder: you can think of the decoder as learning the translation of the input sentence by attending to the encoder's output.
Similar to the encoder, the decoder nn.TransformerDecoder is also a stack of N nn.TransformerDecoderLayer modules. The difference is that, in addition to the self-attention layer self.self_attn, the decoder also has a cross-attention layer self.multihead_attn.
class TransformerDecoderLayer(Module):
    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 device=None, dtype=None) -> None:
        super(TransformerDecoderLayer, self).__init__()
        # Masked multi-head self-attention over the target sequence
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
        # Cross-attention over the encoder output (memory)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)
        # Three sub-layers, so three LayerNorms and three Dropouts
        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps)
        self.norm3 = LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)
        # Legacy string support for activation function.
        if isinstance(activation, str):
            self.activation = _get_activation_fn(activation)
        else:
            self.activation = activation
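Building and calling the full decoder stack mirrors the encoder, except that the forward pass also takes the encoder output (memory) and, during training, a causal mask so that each target position can only attend to earlier positions. A sketch with illustrative sizes:

import torch
import torch.nn as nn

decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8,
                                           dim_feedforward=2048, batch_first=True)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)

tgt = torch.rand(1, 7, 512)      # embedded target prefix
memory = torch.rand(1, 10, 512)  # encoder output for the source sentence

# Causal (subsequent) mask: -inf above the diagonal blocks attention to future positions
tgt_mask = torch.triu(torch.full((7, 7), float('-inf')), diagonal=1)

out = decoder(tgt, memory, tgt_mask=tgt_mask)
print(out.shape)                 # torch.Size([1, 7, 512])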
In the decoder's forward pass, the target input first goes through the masked self-attention block _sa_block, then the cross-attention block _mha_block, where the encoder output mem serves as the keys and values, and finally the feed-forward block _ff_block.
# multihead attention block (cross-attention)
def _mha_block(self, x: Tensor, mem: Tensor,
               attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
    # Queries come from the decoder (x); keys and values come from the encoder output (mem)
    x = self.multihead_attn(x, mem, mem,
                            attn_mask=attn_mask,
                            key_padding_mask=key_padding_mask,
                            need_weights=False)[0]
    return self.dropout2(x)
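Putting the two halves together for translation at inference time, here is a minimal greedy-decoding sketch. It assumes the model was built with batch_first=True, and the embedding modules src_embed / tgt_embed (embedding plus positional encoding) and the output projection generator (d_model to vocabulary logits) are hypothetical components you would define in a real model:

import torch
import torch.nn as nn

@torch.no_grad()
def greedy_translate(model: nn.Transformer, src_embed, tgt_embed, generator,
                     src_ids: torch.Tensor, bos_idx: int, eos_idx: int,
                     max_len: int = 50) -> torch.Tensor:
    # Encode the source sentence once; the decoder reuses this memory at every step
    memory = model.encoder(src_embed(src_ids))
    ys = torch.tensor([[bos_idx]], dtype=torch.long)   # start with <bos>
    for _ in range(max_len):
        # Causal mask sized to the current target prefix
        n = ys.size(1)
        tgt_mask = torch.triu(torch.full((n, n), float('-inf')), diagonal=1)
        out = model.decoder(tgt_embed(ys), memory, tgt_mask=tgt_mask)
        # Pick the most likely next token from the last decoder position
        next_id = generator(out[:, -1]).argmax(dim=-1, keepdim=True)
        ys = torch.cat([ys, next_id], dim=1)
        if next_id.item() == eos_idx:
            break
    return ys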
References:
Vaswani, A., et al. (2017). Attention Is All You Need. https://arxiv.org/abs/1706.03762
PyTorch documentation, torch.nn.Transformer. https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html