πŸ‘ Attention

πŸ‘ Attention

Core Idea

The main assumption in sequence-modelling networks such as RNNs, LSTMs and GRUs is that the current hidden state summarizes the entire input seen so far. Hence the final state of an RNN, after reading the whole input sequence, should contain complete information about that sequence. But this is too strong a condition, and too much to ask of a single fixed-size vector.


The attention mechanism relaxes this assumption: to make a prediction at each decoding step, the decoder should look at the hidden states corresponding to the whole input sequence.

Details

The architecture of the attention mechanism:


The network is shown in a state where:

  • the encoder (lower part of the figure) has computed the hidden states h_jh\_j corresponding to each input X_jX\_j
  • the decoder (top part of the figure) has run for tβˆ’1t-1 steps and is now going to produce output for time step tt.

The whole process can be divided into four steps:

  1. Encoding
  2. Computing Attention Weights/Alignment
  3. Creating context vector
  4. Decoding/Translation

Encoding

  • (X_1,X_2,…,X_T)(X\_1, X\_2, \dots, X\_T): Input sequence

    • TT: Length of sequence
  • (hβ†’_1,hβ†’_2,…,hβ†’_T)(\overrightarrow{h}\_{1}, \overrightarrow{h}\_{2}, \dots, \overrightarrow{h}\_{T}): Hidden state of the forward RNN

  • (h←_1,h←_2,…h←_T)(\overleftarrow{h}\_{1}, \overleftarrow{h}\_{2}, \ldots \overleftarrow{h}\_{T}): Hidden state of the backward RNN

  • The hidden state for the jj-th input h_jh\_j is the concatenation of jj-th hidden states of forward and backward RNNs.

    h_j=[hβ†’_j;h←_j],βˆ€j∈[1,T] h\_{j}=\left[\overrightarrow{h}\_{j} ; \overleftarrow{h}\_{j}\right], \quad \forall j \in[1, T]

Computing Attention Weights/Alignment


At each time step t of the decoder, the amount of attention paid to the encoder hidden state h\_j is denoted by \alpha\_{tj} and is computed as a function of both h\_j and the previous decoder hidden state s\_{t-1}:

\begin{array}{l} e\_{tj}=\boldsymbol{a}\left(h\_{j}, s\_{t-1}\right), \quad \forall j \in[1, T] \\\\ \\\\ \alpha\_{tj}=\frac{\displaystyle \exp \left(e\_{tj}\right)}{\displaystyle \sum\_{k=1}^{T} \exp \left(e\_{tk}\right)} \end{array}
  • a(β‹…)\boldsymbol{a}(\cdot): parametrized as a feedforward neural network that runs for all jj at the decoding time step tt
  • Ξ±_tj∈[0,1]\alpha\_{tj} \in [0, 1]
  • βˆ‘_jΞ±_tj=1\displaystyle \sum\_j \alpha\_{tj} = 1
  • Ξ±_tj\alpha\_{tj} can be visualized as the attention paid by decoder at time step tt to the hidden ecncoder unit h_jh\_j

Computing Context Vector


Now we compute the context vector. The context vector is simply a linear combination of the encoder hidden states h\_j, weighted by the attention values \alpha\_{tj} computed in the preceding step:

c\_t = \sum\_{j=1}^T \alpha\_{tj} h\_j

From the equation we can see that \alpha\_{tj} determines how much h\_j affects the context c\_t: the higher the value, the higher the impact of h\_j on the context for time step t.
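
In code, this is a single weighted sum over the encoder states. A minimal sketch, with shapes and names as illustrative assumptions continuing the conventions used above:

```python
import torch

T, enc_dim = 10, 128                                  # illustrative sizes (assumptions)
alpha_t = torch.softmax(torch.randn(1, T), dim=-1)    # attention weights for step t (sum to 1)
H = torch.randn(1, T, enc_dim)                        # encoder hidden states h_1 .. h_T

# c_t = sum_j alpha_tj * h_j : a weighted average of the encoder states
c_t = torch.bmm(alpha_t.unsqueeze(1), H).squeeze(1)   # (1, enc_dim)
```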

Decoding/Translation


Compute the new decoder hidden state s\_t using

  β€’ the context vector c\_t
  β€’ the previous decoder hidden state s\_{t-1}
  β€’ the previous output y\_{t-1}

s\_{t}=f\left(s\_{t-1}, y\_{t-1}, c\_{t}\right)

The output distribution at time step t is

p\left(y\_{t} \mid y\_{1}, y\_{2}, \ldots, y\_{t-1}, x\right)=g\left(y\_{t-1}, s\_{t}, c\_{t}\right)

In the paper, the authors use a GRU cell for f and a nonlinear (potentially multi-layered) function for g.
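
A simplified sketch of one decoder step is shown below: f is taken as a GRU cell over the concatenation of the embedded previous output and the context vector, and g is reduced to a linear layer followed by a softmax (the paper uses a deeper output function). All names and sizes are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_dim, enc_dim, emb_dim, vocab = 64, 128, 32, 1000  # illustrative sizes (assumptions)

f = nn.GRUCell(input_size=emb_dim + enc_dim, hidden_size=hidden_dim)  # f: GRU cell
g = nn.Linear(emb_dim + hidden_dim + enc_dim, vocab)                  # g: simplified to a linear layer

def decoder_step(y_prev_emb, s_prev, c_t):
    # y_prev_emb: (1, emb_dim)    embedding of the previous output y_{t-1}
    # s_prev:     (1, hidden_dim) previous decoder state s_{t-1}
    # c_t:        (1, enc_dim)    context vector for this step
    s_t = f(torch.cat([y_prev_emb, c_t], dim=-1), s_prev)      # s_t = f(s_{t-1}, y_{t-1}, c_t)
    logits = g(torch.cat([y_prev_emb, s_t, c_t], dim=-1))
    return s_t, F.softmax(logits, dim=-1)                      # p(y_t | y_1..y_{t-1}, x)
```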

Reference

Bahdanau, D., Cho, K., & Bengio, Y. (2015). Neural Machine Translation by Jointly Learning to Align and Translate. ICLR 2015. https://arxiv.org/abs/1409.0473