Memory Networks

 

A few months back, the Allen Institute for Artificial Intelligence organized an AI question-answering challenge on Kaggle. The challenge involved answering 8th-grade multiple-choice questions on physics, chemistry and biology. I was part of the lab team competing in the challenge. The questions were hard: to answer them, a system needs background knowledge in science, the ability to infer new statements from known facts, and the ability to apply general facts to specific examples. For instance, it would need to understand that nails conduct electricity because iron conducts electricity and nails are made of iron. The winning entry answered only about 60% of the questions correctly, and even that, I suspect, relied on information retrieval. In my opinion, deep learning is not yet developed enough to tackle this hard problem, but it has taken a few steps in this direction with Memory Networks.

Memory Networks

The task involves answering questions from the simple, synthetic bAbI dataset. Each question comes with a set of facts, a subset of which is useful for answering it. For example, ‘Mary moved to the bathroom’ and ‘John went to the hallway’ are facts, ‘Where is Mary?’ is the question, and the answer is ‘bathroom’. Of the two facts, ‘Mary moved to the bathroom’ is the supporting fact, since it is required to answer the question. The neural network needs to understand the question and use the supporting facts to answer it. Memory Network is the name of the machine learning architecture first proposed by Weston et al. to address this task.

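A memory network consists of a memory array m and four components: the input feature map I, which converts the input into an internal feature representation; the generalization component G, which updates old memories given the new input; the output feature map O, which computes output features from the new input and the current memory; and the response component R, which converts the output features into text.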
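To make the division of labor concrete, here is a minimal sketch of the four components as methods of a hypothetical `ToyMemoryNetwork` class. Nothing here is learned: the shared-word count in `output_map` and the "last word of the fact" rule in `response` are toy stand-ins, chosen only so that the example runs on the bathroom story above.

```python
import re


class ToyMemoryNetwork:
    """Hypothetical skeleton of the I, G, O, R components (nothing is learned)."""

    def __init__(self):
        self.memory = []  # the memory array m

    def input_map(self, text):
        # I: convert raw text into internal features (here: a lowercase word list)
        return re.findall(r"[a-z]+", text.lower())

    def generalize(self, features):
        # G: store the new input in the next free memory slot
        self.memory.append(features)

    def output_map(self, question):
        # O: pick the supporting memory; shared-word count is a toy stand-in for s
        return max(self.memory, key=lambda m: len(set(question) & set(m)))

    def response(self, supporting):
        # R: turn the output features into text; "last word" is a toy rule
        return supporting[-1]

    def answer(self, facts, question):
        self.memory = []  # start a fresh memory for each story
        for fact in facts:
            self.generalize(self.input_map(fact))
        return self.response(self.output_map(self.input_map(question)))


net = ToyMemoryNetwork()
print(net.answer(["Mary moved to the bathroom.", "John went to the hallway."], "Where is Mary?"))
# bathroom
```

In the real model, O and R use the learned score described in the following sections instead of these hand-written rules.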

Implementation Details

Weston et al. use a simple neural-network embedding model to implement the memory network. The input component simply stores the plain-text facts in an array. The generalization module does nothing more than place each new fact in the next free memory slot, so most of the action happens in the output and response components. The output module finds k supporting facts in the memory: at each step it selects the memory $m_i$ with the maximum score against the question. The facts are ordered in time and indexed by $i$.

For $k = 1$:

$$o_1 = O_1(x, \mathbf{m}) = \arg\max_{i = 1, \dots, n} s(x, m_i)$$

where $s(x, y) = \phi_x^T U^T U \phi_y$, $U$ is a $D \times |W|$ embedding matrix, $\phi$ maps a piece of text to its bag-of-words vector over the vocabulary $W$ (the sum of the one-hot vectors of its words), and $n$ is the number of memories. In other words, $U\phi_x$ adds up the $D$-dimensional embeddings (columns of $U$) of all the words in $x$, and $s$ takes the dot product of this vector with the corresponding sum for $y$. The output module therefore picks the memory whose embedding has the largest dot product (similarity) with the question embedding.
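The first-hop scoring can be sketched in plain Python. This is an illustration under assumptions, not Weston et al.'s code: the helper names (`tokenize`, `build_vocab`, `bag_of_words`, `embed`, `score`, `first_hop`) are hypothetical, tokenization is a simple regex, and `U` is untrained (random). Setting `U` to the identity shows that $s$ then reduces to counting shared words.

```python
import random
import re


def tokenize(text):
    return re.findall(r"[a-z]+", text.lower())


def build_vocab(texts):
    """Map each word in the vocabulary W to a column index of U."""
    words = sorted({w for t in texts for w in tokenize(t)})
    return {w: j for j, w in enumerate(words)}


def bag_of_words(text, vocab):
    """phi(text): |W|-dimensional word counts (sum of one-hot word vectors)."""
    phi = [0.0] * len(vocab)
    for w in tokenize(text):
        phi[vocab[w]] += 1.0
    return phi


def embed(phi, U):
    """U phi: the sum of the D-dimensional columns of U for the words in the text."""
    return [sum(u * p for u, p in zip(row, phi)) for row in U]


def score(x, y, U, vocab):
    """s(x, y) = phi_x^T U^T U phi_y = (U phi_x) . (U phi_y)."""
    ex = embed(bag_of_words(x, vocab), U)
    ey = embed(bag_of_words(y, vocab), U)
    return sum(a * b for a, b in zip(ex, ey))


def first_hop(question, memories, U, vocab):
    """o_1 = arg max_i s(x, m_i), returned as an index into memories."""
    return max(range(len(memories)), key=lambda i: score(question, memories[i], U, vocab))


memories = ["Mary moved to the bathroom.", "John went to the hallway."]
question = "Where is Mary?"
vocab = build_vocab(memories + [question])

random.seed(0)
D = 4
U = [[random.gauss(0.0, 1.0) for _ in vocab] for _ in range(D)]  # untrained D x |W|
print(first_hop(question, memories, U, vocab))  # arbitrary until U is trained

# With U = identity (D = |W|), s(x, y) is just the number of shared words
identity = [[1.0 if d == j else 0.0 for j in range(len(vocab))] for d in range(len(vocab))]
print(first_hop(question, memories, identity, vocab))  # 0
```

Computing $U\phi$ first and then a single dot product is equivalent to evaluating $\phi_x^T U^T U \phi_y$ directly, but it avoids ever forming the $|W| \times |W|$ matrix $U^T U$.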

For $k > 1$, append the highest-scoring memory $m_{o_1}$ to the question to form the list $[x, m_{o_1}]$, and search again for the memory with the maximum score against all the elements of this list:

$$o_2 = O_2([x, m_{o_1}], \mathbf{m}) = \arg\max_{i = 1, \dots, n} s([x, m_{o_1}], m_i)$$
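A sketch of this greedy multi-hop search follows. It uses the identity-$U$ special case of $s$ (a shared-word count) so the result is easy to check by hand, and it assumes that the bag of words of a list is the sum of its elements' bags of words, so that by bilinearity $s([x, m_{o_1}], m_i) = s(x, m_i) + s(m_{o_1}, m_i)$. It also skips memories that were already chosen; that is an assumption of this sketch, since the equation above ranges over all memories. The function names are hypothetical.

```python
import re
from collections import Counter


def tokenize(text):
    return re.findall(r"[a-z]+", text.lower())


def overlap_score(x, y):
    """Toy s(x, y) with U = identity: phi_x . phi_y, i.e. shared-word counts."""
    cx, cy = Counter(tokenize(x)), Counter(tokenize(y))
    return sum(cx[w] * cy[w] for w in cx)


def list_score(inputs, y, s=overlap_score):
    """s([x, m_o1, ...], y), assuming the list's bag of words is the sum of its
    elements' bags of words, so the score is a sum of per-element scores."""
    return sum(s(item, y) for item in inputs)


def supporting_facts(question, memories, k=2, s=overlap_score):
    """Greedy k-hop search: o_1 = argmax s(x, m_i), o_2 = argmax s([x, m_o1], m_i), ...
    Assumption: memories chosen in earlier hops are not considered again."""
    inputs, chosen = [question], []
    for _ in range(k):
        candidates = [i for i in range(len(memories)) if i not in chosen]
        best = max(candidates, key=lambda i: list_score(inputs, memories[i], s))
        chosen.append(best)
        inputs.append(memories[best])  # the list grows to [x, m_o1, ...]
    return chosen


story = [
    "John picked up the football.",
    "Mary went to the office.",
    "John went to the kitchen.",
]
print(supporting_facts("Where is the football?", story))  # [0, 2]
```

On this toy story, the first hop matches "football", and the second hop selects the fact about John's location because it shares the word "john" with $m_{o_1}$, which is exactly the kind of chaining the second hop is meant to capture.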

The response module outputs the single word $w$ from the vocabulary $W$ that maximizes the score (dot product) with the list $[x, m_{o_1}, m_{o_2}]$:

$$r = \arg\max_{w \in W} s([x, m_{o_1}, m_{o_2}], w)$$
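The response step can be sketched the same way. Because $\phi_w$ for a single word is a one-hot vector, $U\phi_w$ is simply the column of $U$ for $w$, so each vocabulary word costs one dot product. The matrix `U_R` below is hand-set purely for illustration (one dimension per location word, zeros elsewhere) and merely stands in for a trained response matrix; `respond` and the other names are hypothetical.

```python
import re


def tokenize(text):
    return re.findall(r"[a-z]+", text.lower())


def respond(inputs, U, words):
    """r = arg max over w in W of s([x, m_o1, m_o2], w).

    U is D x |W| and column j embeds words[j]. The bag of words of the list
    `inputs` is assumed to be the sum of the bags of words of its texts.
    """
    column = {w: j for j, w in enumerate(words)}
    D = len(U)
    # u = U phi_list: sum of the embedding columns of every word in the inputs
    u = [0.0] * D
    for text in inputs:
        for w in tokenize(text):
            for d in range(D):
                u[d] += U[d][column[w]]

    # phi_w is one-hot, so U phi_w is column j of U and s(list, w) = u . U[:, j]
    def word_score(j):
        return sum(u[d] * U[d][j] for d in range(D))

    return words[max(range(len(words)), key=word_score)]


story = [
    "John picked up the football.",
    "Mary went to the office.",
    "John went to the kitchen.",
]
question = "Where is the football?"
words = sorted({w for t in story + [question] for w in tokenize(t)})

# Hypothetical hand-set U_R (D = 2): one dimension per location word, zeros elsewhere
locations = ["kitchen", "office"]
U_R = [[1.0 if w == loc else 0.0 for w in words] for loc in locations]

# Inputs: the question plus the two supporting facts found by the output module
print(respond([question, story[0], story[2]], U_R, words))  # kitchen
```

With an untrained `U`, the same code still runs, but the returned word is arbitrary; learning is what makes answer words such as locations score highest.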

If a phrase or sentence is expected instead of a single word, an RNN can be used as the response module.

Training

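A margin ranking loss is used to learn the embedding matrix $U$ (Weston et al. actually learn separate matrices for the output and response scores). Specifically, the model is trained so that the score between the question and each correct memory exceeds the score between the question and any incorrect memory by at least a margin $\gamma$; likewise, the correct answer word must outscore every other vocabulary word by $\gamma$.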

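The loss function is:

$$
\begin{aligned}
& \sum_{f_1 \neq m_{o_1}} \max\big(0,\ \gamma - s(x, m_{o_1}) + s(x, f_1)\big) \\
+ & \sum_{f_2 \neq m_{o_2}} \max\big(0,\ \gamma - s([x, m_{o_1}], m_{o_2}) + s([x, m_{o_1}], f_2)\big) \\
+ & \sum_{f_3 \neq r} \max\big(0,\ \gamma - s([x, m_{o_1}, m_{o_2}], r) + s([x, m_{o_1}, m_{o_2}], f_3)\big)
\end{aligned}
$$

where $m_{o_1}$ and $m_{o_2}$ are the labeled supporting facts, $r$ is the correct answer word, $f_1$ and $f_2$ range over the other memories, and $f_3$ ranges over the other words in the vocabulary.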
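The three hinge terms translate directly into code. In this sketch, `score(inputs, candidate)` is any function computing $s$ for a list of input texts against a candidate (a hypothetical interface), `o1` and `o2` index the labeled supporting facts, and `r` is the correct answer word. Gradients are omitted; in practice one would minimize this loss with respect to $U$ with stochastic gradient descent, possibly approximating the sums by sampling negatives.

```python
def hinge(margin, positive, negative):
    """max(0, gamma - s(correct) + s(incorrect))."""
    return max(0.0, margin - positive + negative)


def margin_ranking_loss(score, x, memories, vocab, o1, o2, r, gamma=0.1):
    """Strongly supervised loss for k = 2 supporting facts.

    score(inputs, candidate) computes s for a list of input texts against a
    candidate text or word (hypothetical interface). o1 and o2 index the labeled
    supporting facts in `memories`; r is the correct answer word in `vocab`.
    """
    m1, m2 = memories[o1], memories[o2]
    loss = 0.0
    # Term 1: m_o1 must outscore every other memory f_1, given [x]
    pos = score([x], m1)
    for i, f1 in enumerate(memories):
        if i != o1:
            loss += hinge(gamma, pos, score([x], f1))
    # Term 2: m_o2 must outscore every other memory f_2, given [x, m_o1]
    pos = score([x, m1], m2)
    for i, f2 in enumerate(memories):
        if i != o2:
            loss += hinge(gamma, pos, score([x, m1], f2))
    # Term 3: r must outscore every other vocabulary word f_3, given [x, m_o1, m_o2]
    pos = score([x, m1, m2], r)
    for f3 in vocab:
        if f3 != r:
            loss += hinge(gamma, pos, score([x, m1, m2], f3))
    return loss


story = ["John picked up the football.", "Mary went to the office.", "John went to the kitchen."]
vocab = ["football", "kitchen", "office", "john", "mary"]
x = "Where is the football?"


def zero(inputs, candidate):
    # An untrained model that scores everything 0
    return 0.0


targets = {1: story[0], 2: story[2], 3: "kitchen"}


def ideal(inputs, candidate):
    # Scores the correct candidate for each hop 1 and everything else 0
    return 1.0 if candidate == targets[len(inputs)] else 0.0


print(margin_ranking_loss(zero, x, story, vocab, o1=0, o2=2, r="kitchen"))  # about 0.8
print(margin_ranking_loss(ideal, x, story, vocab, o1=0, o2=2, r="kitchen"))  # 0.0
```

The all-zero scorer pays $\gamma$ for each of the 2 + 2 + 4 negatives, while a scorer that separates every correct candidate by more than $\gamma$ pays nothing.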

If the memory holds too many facts to score each one, the candidate set can first be narrowed down using an inverted index (an information retrieval model), or by clustering the memories and scoring only those in the relevant cluster.
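An inverted-index pre-filter can be sketched with a dictionary from words to memory indices: only memories that share at least one word with the question are passed on to the (expensive) scorer. The helpers `build_inverted_index` and `candidate_memories` are hypothetical and not part of the original model.

```python
import re
from collections import defaultdict


def tokenize(text):
    return re.findall(r"[a-z]+", text.lower())


def build_inverted_index(memories):
    """Map each word to the set of memory indices that contain it."""
    index = defaultdict(set)
    for i, m in enumerate(memories):
        for w in tokenize(m):
            index[w].add(i)
    return index


def candidate_memories(question, index):
    """Indices of memories sharing at least one word with the question."""
    hits = set()
    for w in tokenize(question):
        hits |= index.get(w, set())
    return sorted(hits)


memories = [
    "Mary moved to the bathroom.",
    "John went to the hallway.",
    "Sandra journeyed to the garden.",
]
index = build_inverted_index(memories)
print(candidate_memories("Where is Mary?", index))  # [0]
```

In practice very common words such as "the" would match almost every memory, so they would usually be left out of the index. For a second hop, the lookup would also need to be repeated with $[x, m_{o_1}]$, since the second supporting fact may share no words with the question.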

As the objective function shows, the algorithm requires knowing the supporting memories, which may not be available in a real-world setting. Learning without this knowledge is known as weak supervision. End-to-End Memory Networks (Sukhbaatar et al.) address this problem by replacing the hard arg max over memories with a soft weighting (a softmax over the question-memory match scores), so that the weights can be learned end-to-end from question-answer pairs alone, without strong supervision of the supporting facts.
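The soft weighting can be sketched as a softmax over question-memory dot products, followed by a weighted sum of separately embedded memory vectors. The 2-d embeddings below are hand-picked numbers standing in for learned ones, and the names `soft_memory_read`, `memory_keys` and `memory_values` are hypothetical. Because every step is differentiable, the answer loss alone can train the embeddings, with no supporting-fact labels.

```python
import math


def softmax(scores):
    top = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def soft_memory_read(u, memory_keys, memory_values):
    """Differentiable replacement for the hard arg max over memories.

    u: question embedding; memory_keys[i]: embedding m_i used for matching;
    memory_values[i]: embedding c_i that is read out.
    Returns the attention weights p and the output o = sum_i p_i c_i.
    """
    p = softmax([sum(a * b for a, b in zip(u, m)) for m in memory_keys])
    dim = len(memory_values[0])
    o = [sum(p_i * c[d] for p_i, c in zip(p, memory_values)) for d in range(dim)]
    return p, o


# Hand-picked 2-d embeddings standing in for learned ones
u = [1.0, 0.0]
keys = [[2.0, 0.0], [0.0, 2.0]]
values = [[2.0, 0.0], [0.0, 2.0]]
p, o = soft_memory_read(u, keys, values)
print(p)  # roughly [0.881, 0.119]: memory 0 matches the question best
print(o)
```

Instead of committing to a single memory, the output mixes all of them in proportion to how well they match the question, which is what makes end-to-end training possible.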

A straightforward implementation of the memory network using the Keras library is available here. Note that it uses an LSTM for generating the answer instead of the dot product. In subsequent posts we will explore the use of attention models for improving memory networks.