Relational Recurrent Neural Networks
I am reading this paper because it is one of the approx. 30 papers that Ilya Sutskever recommended to John Carmack to learn what really matters for machine learning / AI today. The paper confirms the suspicion that standard memory architectures may struggle at tasks that involve understanding the ways in which entities are connected, and then addresses the problem with a new memory module - a Relational Memory Core.
Reference Link to PDF of Paper
Memory-based neural networks model temporal data by leveraging an ability to remember information for long periods. It is unclear, however, whether they also have an ability to perform complex relational reasoning with the information they remember. This paper first confirms the intuition that standard memory architectures may struggle at tasks that heavily involve an understanding of the ways in which entities are connected - i.e., tasks involving relational reasoning. It then addresses these deficits with a new memory module - a Relational Memory Core (RMC) - which employs multi-head dot product attention to allow memories to interact, and shows improvements on RL tasks, program evaluation, and language modeling.
RNNs like LSTMs - bolstered by augmented memory capabilities, bounded computational costs over time, and an ability to deal with vanishing gradients - learn to correlate events across time and so become proficient at storing and retrieving information. This paper proposes that it is fruitful to consider memory interactions along with storage and retrieval. Although current models can learn to compartmentalize and relate distributed, vectorized memories, they are not biased towards doing so explicitly. This paper hypothesizes that such a bias may allow a model to better understand how memories are related, and hence may give it a better capacity for relational reasoning over time. The proposed Relational Memory Core (RMC) uses multi-head dot product attention to allow memories to interact with each other.
Relational Reasoning is the process of understanding the ways in which entities are connected and using this understanding to accomplish some higher order goal. Consider sorting the distances of various trees to a park bench: the relations (distances) between the entities (trees and bench) are compared and contrasted to produce the solution, which could not be reached if one reasoned about the properties (positions) of each individual entity in isolation.
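As a toy sketch of that example (the tree names and coordinates below are invented purely for illustration), the answer depends only on the tree-bench relations, not on any single entity's position in isolation:

```python
# Toy relational-reasoning example: sort trees by their distance to a park bench.
# The positions are made up; the point is that the solution is a function of
# relations (tree-to-bench distances), not of any individual entity by itself.
bench = (0.0, 0.0)
trees = {"oak": (3.0, 4.0), "pine": (1.0, 1.0), "birch": (6.0, 2.0)}

def distance(a, b):
    # Euclidean distance between two 2D points
    return ((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2) ** 0.5

nearest_first = sorted(trees, key=lambda name: distance(trees[name], bench))
print(nearest_first)  # ['pine', 'oak', 'birch']
```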
Multi-head dot product attention (MHDPA), also known as self-attention, is what allows memories to interact. Using MHDPA, each memory attends over all of the other memories and updates its content based on the attended information.
- A simple linear projection is used to construct queries ($Q = MW^q$), keys ($K = MW^k$), and values ($V = MW^v$) for each memory (row $m_i$) in the memory matrix $M$.
- The queries $Q$ are used to perform a scaled dot-product attention over the keys $K$.
- The returned scalars are put through a softmax function to produce a set of weights, which are then used to return a weighted average of the values as $A(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d^k}}\right)V$, where $d^k$, the dimensionality of the key vectors, is used as a scaling factor. Equivalently:

$$A_\theta(M) = \mathrm{softmax}\left(\frac{MW^q (MW^k)^T}{\sqrt{d^k}}\right) MW^v,$$

where $\theta = (W^q, W^k, W^v)$.
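As a concrete sketch (not the paper's implementation), the single-head operation above can be written in a few lines of NumPy. The memory size, slot count, and random weights here are placeholder assumptions chosen only to make the shapes visible:

```python
import numpy as np

def softmax(x, axis=-1):
    # numerically stable softmax along the given axis
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attend(M, W_q, W_k, W_v):
    """Single-head scaled dot-product attention A_theta(M) over the rows (memories) of M."""
    Q = M @ W_q          # queries, shape (N, d_k)
    K = M @ W_k          # keys,    shape (N, d_k)
    V = M @ W_v          # values,  shape (N, d_v)
    d_k = K.shape[-1]
    # (N, N) weights: how much each memory attends to every memory (including itself)
    weights = softmax(Q @ K.T / np.sqrt(d_k))
    return weights @ V   # proposed update, one row per memory

# Toy example: N = 4 memory slots, each of size F = 8 (arbitrary choices for illustration)
rng = np.random.default_rng(0)
N, F = 4, 8
M = rng.normal(size=(N, F))
W_q, W_k, W_v = (rng.normal(size=(F, F)) for _ in range(3))
M_tilde = attend(M, W_q, W_k, W_v)
print(M_tilde.shape)  # (4, 8): same dimensionality as M
```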
The output of $A_\theta(M)$, which we will denote as $\tilde{M}$, is a matrix with the same dimensionality as $M$, and it can be interpreted as a proposed update to $M$, with each $\tilde{m}_i$ comprising information from the memories $m_j$. Thus, in one step of attention each memory is updated with information originating from the other memories, and it is up to the model to learn (via the parameters $W^q$, $W^k$, and $W^v$) how to shuttle information from memory to memory. As implied by the name, MHDPA uses multiple heads: $h$ sets of queries, keys, and values are produced, using unique parameters to compute a linear projection from the original memory for each head, and the attention operation is then applied independently per head. For example, if $M$ is an $N \times F$ dimensional matrix and we employ two attention heads, then we compute $\tilde{M}^1 = A_\theta(M)$ and $\tilde{M}^2 = A_\phi(M)$, where $\tilde{M}^1$ and $\tilde{M}^2$ are $N \times F/2$ matrices, $\theta$ and $\phi$ denote unique parameters for the linear projections that produce the queries, keys, and values, and $\tilde{M} = [\tilde{M}^1 : \tilde{M}^2]$, where $[:]$ denotes column-wise concatenation. Intuitively, heads could be useful for letting a memory share different information, with different targets, using each head.
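Continuing the NumPy sketch above (reusing `attend`, `rng`, `M`, and `F`), here is a hedged illustration of that two-head case: each head gets its own projection parameters, produces an $N \times F/2$ slice, and the slices are concatenated column-wise. The head count and shapes are assumptions made only for illustration:

```python
def multi_head_attend(M, heads):
    """MHDPA sketch: run attend() independently per head, then concatenate column-wise.

    `heads` is a list of (W_q, W_k, W_v) tuples, one per head, each projecting the
    F-dimensional memories down to F / num_heads columns.
    """
    outputs = [attend(M, W_q, W_k, W_v) for (W_q, W_k, W_v) in heads]
    return np.concatenate(outputs, axis=1)   # [M^1 : M^2], back to shape (N, F)

# Two heads, each projecting F = 8 down to F/2 = 4
heads = [tuple(rng.normal(size=(F, F // 2)) for _ in range(3)) for _ in range(2)]
M_tilde = multi_head_attend(M, heads)
print(M_tilde.shape)  # (4, 8): the proposed update has the same dimensionality as M
```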