๐Ÿ‘ Attention

Core Idea

The main assumption in sequence modelling networks such as RNNs, LSTMs, and GRUs is that the current state holds information about the whole input seen so far. Hence the final state of an RNN, after reading the entire input sequence, should contain complete information about that sequence. But this is too strong a condition, and too much to ask of a single fixed-size vector.

The attention mechanism relaxes this assumption and proposes that we look at the hidden states corresponding to the whole input sequence when making a prediction.

Details

The architecture of the attention mechanism:

[Figure: the attention architecture, with the encoder at the bottom and the decoder at the top]

The network is shown in a state where:

  • the encoder (lower part of the figure) has computed the hidden states $h_j$ corresponding to each input $X_j$
  • the decoder (top part of the figure) has run for $t-1$ steps and is now going to produce output for time step $t$.

The whole process can be divided into four steps:

  1. Encoding
  2. Computing Attention Weights/Alignment
  3. Creating context vector
  4. Decoding/Translation

Encoding

  • $(X_1, X_2, \dots, X_T)$: Input sequence

    • $T$: Length of sequence
  • $(\overrightarrow{h}_{1}, \overrightarrow{h}_{2}, \dots, \overrightarrow{h}_{T})$: Hidden states of the forward RNN

  • $(\overleftarrow{h}_{1}, \overleftarrow{h}_{2}, \ldots \overleftarrow{h}_{T})$: Hidden states of the backward RNN

  • The hidden state for the $j$-th input, $h_j$, is the concatenation of the $j$-th hidden states of the forward and backward RNNs (a code sketch of this encoding step follows the list).

    $$ h_{j}=\left[\overrightarrow{h}_{j} ; \overleftarrow{h}_{j}\right], \quad \forall j \in[1, T] $$
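
To make the encoding step concrete, here is a minimal NumPy sketch. The plain tanh RNN cell, the dimensions, the random weights, and the sharing of weights between the two directions are all illustrative assumptions, not the paper's exact setup.

```python
import numpy as np

rng = np.random.default_rng(0)

T, d_x, d_h = 5, 4, 3          # sequence length, input size, hidden size (all arbitrary)
X = rng.normal(size=(T, d_x))  # input sequence (X_1, ..., X_T)

# Weights of a plain tanh RNN cell; shared by both directions here only for brevity
# (in practice the forward and backward RNNs have separate parameters).
W_x = rng.normal(size=(d_h, d_x))
W_h = rng.normal(size=(d_h, d_h))

def rnn_step(x, h_prev):
    """One step of a simple tanh RNN cell (stand-in for the paper's encoder cell)."""
    return np.tanh(W_x @ x + W_h @ h_prev)

# Forward pass: fwd[j] summarizes X_1 .. X_j
fwd = np.zeros((T, d_h))
h = np.zeros(d_h)
for j in range(T):
    h = rnn_step(X[j], h)
    fwd[j] = h

# Backward pass: bwd[j] summarizes X_j .. X_T
bwd = np.zeros((T, d_h))
h = np.zeros(d_h)
for j in reversed(range(T)):
    h = rnn_step(X[j], h)
    bwd[j] = h

# h_j = [fwd_j ; bwd_j]: concatenated hidden state for the j-th input
H = np.concatenate([fwd, bwd], axis=1)   # shape (T, 2 * d_h)
print(H.shape)                           # (5, 6)
```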

Computing Attention Weights/Alignment

At each time step $t$ of the decoder, the amount of attention paid to the encoder hidden state $h_j$ is denoted by $\alpha_{tj}$ and is calculated as a function of both $h_j$ and the previous hidden state of the decoder, $s_{t-1}$: $$ \begin{array}{l} e_{t j}=\boldsymbol{a}\left(h_{j}, s_{t-1}\right), \forall j \in[1, T] \\ \\ \alpha_{t j}=\frac{\displaystyle \exp \left(e_{t j}\right)}{\displaystyle \sum_{k=1}^{T} \exp \left(e_{t k}\right)} \end{array} $$

  • $\boldsymbol{a}(\cdot)$: parametrized as a feedforward neural network that is evaluated for every $j$ at decoding time step $t$
  • $\alpha_{tj} \in [0, 1]$
  • $\displaystyle \sum_j \alpha_{tj} = 1$
  • $\alpha_{tj}$ can be visualized as the attention paid by the decoder at time step $t$ to the encoder hidden state $h_j$ (see the sketch below)
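
A small self-contained NumPy sketch of this alignment step. The one-hidden-layer scoring network (weights `W_a`, `U_a`, `v_a`), all sizes, and the random stand-ins for the encoder states and the previous decoder state are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(1)

T, d_enc, d_dec, d_a = 5, 6, 3, 4    # assumed sizes (d_enc = 2 * d_h from the encoder sketch)
H = rng.normal(size=(T, d_enc))      # stand-ins for the encoder states h_1 .. h_T
s_prev = rng.normal(size=d_dec)      # stand-in for the previous decoder state s_{t-1}

# Alignment model a(h_j, s_{t-1}): a small feedforward scoring network (illustrative weights)
W_a = rng.normal(size=(d_a, d_dec))
U_a = rng.normal(size=(d_a, d_enc))
v_a = rng.normal(size=d_a)

def score(h_j, s_prev):
    """e_tj = v_a . tanh(W_a s_{t-1} + U_a h_j), an additive alignment score."""
    return v_a @ np.tanh(W_a @ s_prev + U_a @ h_j)

e = np.array([score(H[j], s_prev) for j in range(T)])  # unnormalized scores e_t1 .. e_tT

# Softmax over positions: each alpha_tj lies in [0, 1] and the weights sum to 1 over j
alpha = np.exp(e - e.max())
alpha /= alpha.sum()
print(alpha, alpha.sum())  # attention weights over the T encoder states; sum is 1.0
```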

Computing Context Vector

Now we compute the context vector. The context vector is simply a linear combination of the hidden states $h_j$, weighted by the attention values $\alpha_{tj}$ computed in the preceding step: $$ c_t = \sum_{j=1}^T \alpha_{tj}h_j $$ From the equation we can see that $\alpha_{tj}$ determines how much $h_j$ affects the context $c_t$: the higher the value, the greater the impact of $h_j$ on the context for time step $t$.
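
Continuing with stand-in values, the context vector is just the $\alpha$-weighted sum of the encoder states. The uniform attention weights and random encoder states below are placeholders for the quantities computed above.

```python
import numpy as np

rng = np.random.default_rng(2)

T, d_enc = 5, 6
H = rng.normal(size=(T, d_enc))   # stand-ins for the encoder states h_1 .. h_T
alpha = np.full(T, 1.0 / T)       # e.g. uniform attention weights, just for illustration

# c_t = sum_j alpha_tj * h_j, a convex combination of the encoder hidden states
c_t = alpha @ H                   # equivalently: (alpha[:, None] * H).sum(axis=0)
print(c_t.shape)                  # (6,)
```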

Decoding/Translation

Compute the new hidden state $s_t$ using

  • the context vector $c_t$
  • the previous hidden state of the decoder $s_{t-1}$
  • the previous output $y_{t-1}$

$$ s_{t}=f\left(s_{t-1}, y_{t-1}, c_{t}\right) $$

The output distribution at time step $t$ is $$ p\left(y_{t} \mid y_{1}, y_{2}, \ldots, y_{t-1}, x\right)=g\left(y_{t-1}, s_{t}, c_{t}\right) $$

In the paper, the authors use a GRU cell for $f$ and a similar function for $g$; a simplified sketch of one decoder step follows.
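
A hedged NumPy sketch of one decoder step: a plain tanh cell stands in for the paper's GRU $f$, and a linear layer plus softmax stands in for $g$. All weights, sizes, and the stand-in inputs are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(3)

d_dec, d_emb, d_enc, V = 3, 4, 6, 10   # assumed sizes; V is the output vocabulary size

s_prev = rng.normal(size=d_dec)        # previous decoder state s_{t-1}
y_prev = rng.normal(size=d_emb)        # embedding of the previous output y_{t-1}
c_t = rng.normal(size=d_enc)           # context vector from the attention step

# f(s_{t-1}, y_{t-1}, c_t): new decoder state; a tanh cell stands in for the paper's GRU
W_s = rng.normal(size=(d_dec, d_dec))
W_y = rng.normal(size=(d_dec, d_emb))
W_c = rng.normal(size=(d_dec, d_enc))
s_t = np.tanh(W_s @ s_prev + W_y @ y_prev + W_c @ c_t)

# g(y_{t-1}, s_t, c_t): output distribution p(y_t | y_<t, x); linear layer + softmax here
W_o = rng.normal(size=(V, d_dec + d_emb + d_enc))
logits = W_o @ np.concatenate([s_t, y_prev, c_t])
p_y = np.exp(logits - logits.max())
p_y /= p_y.sum()
print(p_y.shape, round(p_y.sum(), 6))  # distribution over the V output tokens
```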

Reference

  • Bahdanau, D., Cho, K., & Bengio, Y. (2015). Neural Machine Translation by Jointly Learning to Align and Translate. ICLR 2015. arXiv:1409.0473