Comprehensive Guide on Seq2Seq
What is Seq2Seq?
Seq2Seq is a type of neural network that maps a sequence in one domain to a sequence in another domain. For instance, Seq2Seq can be used to map the following:
"A cat sat on a mat" -> "猫がマットに座った"
This makes Seq2Seq a popular model whenever the output itself is a sequence, such as generated text.
Anatomy of the Seq2Seq model
The following is a high-level diagram of the Seq2Seq model:
The roles are as follows:
Encoder: processes the input sequence and aims to pack as much information as possible into a single vector called the context vector.
Context vector: a vector that encapsulates the "meaning" of the input sequence. This is passed on to the decoder.
Decoder: uses the context vector to make predictions token by token.
Encoder
The encoder processes the input sequence token by token, and aims to pack as much information as possible into a single vector called the context vector, denoted by $\mathbf{h}$. Schematically, a general encoder would look like the following:
Here, the output of the LSTM at each time-step is unimportant to us; we just need the hidden state $\mathbf{h}$ and cell state $\mathbf{c}$ to be carried over to the next time-step. Therefore, there is no cost function to compute in the encoder - this happens in the decoder component.
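To make this concrete, here is a minimal encoder sketch. The article does not prescribe a framework, so PyTorch is assumed here, and names such as vocab_size, embed_size and hidden_size are purely illustrative:

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)

    def forward(self, x):
        # x: (batch, seq_len) of input token ids
        embedded = self.embed(x)
        # The per-step outputs are not needed; we only keep the final
        # hidden state h and cell state c.
        _, (h, c) = self.lstm(embedded)
        return h, c
```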
Decoder
The underlying structure of the decoder is strikingly similar to that of the standard LSTM network. The main difference is that, in the first time-step, the hidden state vector $\mathbf{h}$ generated by the encoder is passed into the LSTM:
This seemingly minor difference is what makes this LSTM network capable of parsing sequences in a more robust manner. Notice how for the first time-step, we have a special input <start>, which is used to inform the decoder to begin decoding. Similarly, for the final time-step, we have another special input <end>, which is used to instruct the decoder to end decoding. In some literature, these special inputs are denoted as <go> and <eos>.
We know that the standard LSTM network passes two different vectors - the hidden state vector as well as the cell vector - to the next time-step. However, in the Seq2Seq model, the encoder only passes the hidden state vector to the decoder, and discards the cell vector.
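A matching decoder sketch, under the same PyTorch assumption, might look like the following. In line with the text, only the encoder's hidden state $\mathbf{h}$ is reused and the cell state is re-initialised to zeros (many implementations hand over both); the <start>-prefixed target ids are assumed inputs:

```python
import torch
import torch.nn as nn

class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, vocab_size)

    def forward(self, y, h):
        # y: (batch, seq_len) of target token ids, beginning with <start>
        # h: (1, batch, hidden_size) hidden state handed over by the encoder
        c = torch.zeros_like(h)            # cell state starts fresh, as per the text
        embedded = self.embed(y)
        outputs, _ = self.lstm(embedded, (h, c))
        return self.out(outputs)           # per-step scores over the vocabulary
```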
Padding
When we consider batch processing, we often construct a matrix that holds a random sample of data items from our original dataset. For instance, consider the following dataset:
| Height | Weight | Target |
| --- | --- | --- |
| 180 | 80 | Male |
| 160 | 55 | Female |
| 170 | 70 | Male |
Randomly sampling from this would give us a matrix, and we can then simply feed this into our neural network.
However, Seq2Seq deals with inputs that are of variable length. The consequence of this is that we are unable to construct a matrix since some inputs would have fewer "features" than others. As an example, consider the toy problem of building a Seq2Seq network to perform simple arithmetic addition. The inputs and outputs would be as follows:

5+4 -> 9
62+45 -> 107
8+27 -> 35
Notice how both the inputs and outputs have varying length. In order to ensure that we can represent batches in a matrix, we need to perform an operation called padding. Note that the _ is used to instruct the decoder to begin decoding.
Since we are dealing with sequences, the training dataset is usually given in the following form:
5+4   _9
62+45 _107
8+27  _35
Notice how we have filled the shorter sequences with spaces as a means of performing padding.
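As a small illustration of this padding step, the following sketch pads the toy addition pairs above with spaces so that every input (and every target) has the same length. It is plain Python; the specific pairs are the ones shown above:

```python
# Toy addition pairs from the text
pairs = [("5+4", "9"), ("62+45", "107"), ("8+27", "35")]

# Prefix each target with "_" so the decoder knows where to begin decoding
pairs = [(q, "_" + a) for q, a in pairs]

in_len = max(len(q) for q, _ in pairs)    # longest input (5 characters)
out_len = max(len(a) for _, a in pairs)   # longest target (4 characters)

# Pad with spaces so every row has the same width
padded = [(q.ljust(in_len), a.ljust(out_len)) for q, a in pairs]
for q, a in padded:
    print(repr(q), repr(a))
# '5+4  ' '_9  '
# '62+45' '_107'
# '8+27 ' '_35 '
```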
Since these paddings are not inherently part of the dataset, we do not want our Seq2Seq model to be affected by them. Remember, we perform padding for the sole purpose of constructing a batch matrix. In order to disregard the padded spaces, we need to slightly adjust the Seq2Seq model like so:
When the encoder receives a padded token, it simply passes on the output of the previous time-step.
When the decoder receives a padded token, we ensure that it is not involved when computing the loss (see the sketch below).
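One way to realise the second point is to reserve a dedicated padding id and exclude it from the loss. The sketch below assumes PyTorch, where CrossEntropyLoss accepts an ignore_index argument; the PAD_ID value and tensor shapes are illustrative:

```python
import torch
import torch.nn as nn

PAD_ID = 0                                     # illustrative padding token id
criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID)

# logits:  (batch, seq_len, vocab_size) produced by the decoder
# targets: (batch, seq_len) with PAD_ID at the padded positions
logits = torch.randn(2, 4, 13)
targets = torch.tensor([[5, 7, 2, PAD_ID],
                        [3, 8, PAD_ID, PAD_ID]])

# Padded positions contribute nothing to the loss
loss = criterion(logits.reshape(-1, 13), targets.reshape(-1))
```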
Attention
The problem with the original Seq2Seq model is that the hidden state vector $\mathbf{h}$ is fixed in size. This means that no matter how long the input sequence is, the underlying information is forcefully packed into a fixed-size vector. For instance, the sentences "hello there" and "you say goodbye and I say hello" will both be represented by a vector of the same size. It is more natural for longer sequences to be represented by a larger vector.
The way to handle this is by packing the output of the LSTM layer at each time-step into one hidden state matrix:
Let us call this matrix $\mathbf{hs}$ - the number of rows of $\mathbf{hs}$ is equal to the number of tokens in the input sequence. In this way, longer sequences are represented by a larger matrix.
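With a framework such as PyTorch (an assumption carried over from the earlier sketches), collecting $\mathbf{hs}$ requires no extra work, since the LSTM already returns the hidden state of every time-step:

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=16, hidden_size=32, batch_first=True)
x = torch.randn(1, 7, 16)        # a single 7-token input sequence (embedded)
hs, (h, c) = lstm(x)             # hs stacks the hidden state of every step
print(hs.shape)                  # torch.Size([1, 7, 32]), one row per token
```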
Schematic diagram
The schematic diagram of Seq2Seq with attention looks like the following:
The main difference between the standard Seq2Seq and Seq2Seq with attention is that we have a new layer called Attention in the decoder. The matrix $\mathbf{hs}$ is passed into the attention layer. The purpose of the attention layer is to relate equal or similar tokens together - for instance, we know that "You" and "あなた" are semantically equivalent. Ideally then, we want our Seq2Seq model to automatically learn these tight relationships between input tokens and output tokens.
Attention is named as such because we want the Seq2Seq model to be attentive to the important information when performing translation.
Our goal is to select or pick out the vector that corresponds to the input token "You". However, the problem with this is that the task of selecting a vector out of multiple vectors is not differentiable. This means that back-propagation is no longer possible. The way to overcome this problem is by "selecting" all the vectors instead of selecting just one - we can do this via a weight vector $\mathbf{a}$.
Each element in $\mathbf{a}$ represents the importance of the corresponding input token. For instance, in the above diagram, the word "You" has the highest weight, which means that the input token "You" corresponds to the output token "あなた" the most. You can also think of $\mathbf{a}$ as a discrete probability mass function, since the elements of $\mathbf{a}$ sum up to one.
The context vector is computed as the weighted sum of the rows of $\mathbf{hs}$, that is, $\mathbf{c} = \sum_i a_i \, \mathbf{hs}_i$, where $\mathbf{hs}_i$ denotes the $i$-th row of $\mathbf{hs}$.
Since the weight of the input token "You" was 0.7, you can think of the context vector as holding more information about the token "You". Ideally, the context vector holds all the information needed to perform translation for the current time-step.
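A tiny numpy sketch of this weighted sum, using illustrative weights in which "You" receives 0.7 as in the example above:

```python
import numpy as np

hs = np.random.randn(5, 32)                   # 5 input tokens, hidden size 32
a = np.array([0.7, 0.1, 0.1, 0.05, 0.05])     # importance of each input token
c = a @ hs                                    # context vector, shape (32,)
```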
We now know that in order to compute the context vector $\mathbf{c}$, we need to have a well-defined $\mathbf{a}$, which represents the weight of each token.
We want to compare the hidden state vector $\mathbf{h}$ of the decoder to each row of $\mathbf{hs}$ and compute how similar they are. One simple way of doing this is to compute the dot product, which measures how closely the two vectors point in the same direction.
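Putting the pieces together, the sketch below scores each row of $\mathbf{hs}$ against $\mathbf{h}$ with a dot product, normalises the scores with a softmax so that they sum to one (one common choice, consistent with treating $\mathbf{a}$ as a probability mass function), and then forms the context vector. All shapes are illustrative:

```python
import numpy as np

hs = np.random.randn(5, 32)      # hidden states of the 5 input tokens
h = np.random.randn(32)          # decoder hidden state at the current step

scores = hs @ h                  # one dot product per input token, shape (5,)
a = np.exp(scores - scores.max())
a = a / a.sum()                  # softmax: weights that sum to one
c = a @ hs                       # the resulting context vector
```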