Understanding RAG (with math symbols)
January 10, 2025 · 1,075 words
Better rendered version hosted on Github: link
TLDR: it is a finite mixture model with a bit of handwaving.
Retrieval-Augmented Generation (RAG) is a popular technique for improving the quality of text generated by language models. According to OpenAI (link), RAG can be quite promising if you have additional documents to provide as context relevant to the question you ask a language model.
There are many online tutorials on RAG, e.g. by OpenAI, LangChain and NVIDIA, most of which use flowcharts/diagrams and sample code snippets. These tutorials are useful for practical purposes, but they don't reveal what is actually going on behind the scenes.
I read the original paper by Lewis et al. (2020) who first experimented with RAG. Below are my takeaways and thoughts. This note mainly covers Section 2 (Methods) of the paper.
Suppose we have a generative language model with parameters $$\theta$$. A question encoded as $$\bm{x}$$ is asked, and the model generates an answer encoded as $$\bm{y}$$ (with length $$N$$) based on the following distribution.
$$
p_\theta(\bm{y} | \bm{x}) = \prod_{i=1}^{N} p_\theta(y_i | \bm{x}, \bm{y}_{1:i-1})
$$
In other words, the next token generated depends on the model, the question, and all preceding tokens.
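For concreteness, here is a minimal Python sketch of this factorization; the `next_token_probs` interface is hypothetical, standing in for whatever model you actually use.

```python
import math

def sequence_log_prob(next_token_probs, x_tokens, y_tokens):
    """Chain-rule scoring of an answer: log p_theta(y | x) is the sum of
    log p_theta(y_i | x, y_{1:i-1}) over the answer tokens."""
    total = 0.0
    for i, y_i in enumerate(y_tokens):
        # Hypothetical model interface: returns a dict mapping each candidate
        # next token to its probability, given the question and the answer
        # tokens generated so far.
        probs = next_token_probs(x_tokens, y_tokens[:i])
        total += math.log(probs[y_i])
    return total

# Toy usage with a dummy "model" that always predicts uniformly over 4 tokens.
dummy = lambda x, prefix: {t: 0.25 for t in ["a", "b", "c", "d"]}
print(sequence_log_prob(dummy, ["question"], ["a", "b"]))  # 2 * log(0.25)
```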
The generated answer $$\bm{y}$$ might not be satisfactory if the model parameters $$\theta$$ are trained on some general knowledge but the question $$\bm{x}$$ is domain-specific (e.g. medicine, law, finance).
Now also suppose we have a collection of $$M$$ domain-specific passages (which could be documents or chunks of them) encoded as $$\{ \bm{z}_j \}_{1 \leq j \leq M}$$. We could consider finetuning the model with these domain-specific passages so that the updated parameters, say $$\theta'$$, incorporate domain-specific knowledge, and hence the answer generated from $$p_{\theta'}(\bm{y} | \bm{x})$$ will hopefully be of higher quality.
RAG is an alternative to this potentially costly process of model finetuning. Instead of updating the parameters $$\theta$$, we search for a subset of $$K$$ domain-specific passages $$\{ \bm{z}_j \}_{1 \leq j \leq K}$$ that are most relevant to the question $$\bm{x}$$, and then augment the generation with this relevant information to (hopefully) produce a better answer.
Effectively, the passages $$\{ \bm{z}_j \}_{1 \leq j \leq M}$$ can be viewed as latent variables. By applying properties of conditional probabilities,
$$
p_{\textrm{RAG}}(\bm{y} | \bm{x}) = \sum_{j=1}^{M} p_{\phi}(\bm{y} | \bm{x}, \bm{z}_j) \, p_{\eta}(\bm{z}_j | \bm{x}) \approx \sum_{j=1}^{K} p_{\phi}(\bm{y} | \bm{x}, \bm{z}_j) \, p_{\eta}(\bm{z}_j | \bm{x})
$$
where $$p_{\eta}(\bm{z}_j | \bm{x})$$ is called the retriever and $$p_{\phi}(\bm{y} | \bm{x}, \bm{z}_j)$$ the generator.
Before diving into the individual terms, notice that the number of retrieved passages $$K$$ can be equal to the total number of domain-specific passages $$M$$, in which case the second step of approximation is not needed. In other words, we are leveraging all available additional information for generating the answer. However, this is typically not the case in practice because $$M$$ can be huge (e.g. the authors used a 2018 snapshot of Wikipedia and divided the texts into disjoint chunks of 100 words).
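The marginalization over the retrieved passages is then just a weighted sum. Below is a minimal sketch; the `retriever_prob` and `generator_prob` callables are placeholders for real model scores, and the numbers in the toy usage are made up.

```python
def rag_prob(y, x, retrieved_passages, retriever_prob, generator_prob):
    """Approximate p_RAG(y | x) over the K retrieved passages:
    sum_j p_phi(y | x, z_j) * p_eta(z_j | x)."""
    return sum(
        generator_prob(y, x, z) * retriever_prob(z, x)
        for z in retrieved_passages
    )

# Toy usage with made-up retriever weights and generator probabilities.
weights = {"z1": 0.9, "z2": 0.1}
gen_probs = {"z1": 0.050, "z2": 0.002}
p = rag_prob("y", "x", ["z1", "z2"],
             retriever_prob=lambda z, x: weights[z],
             generator_prob=lambda y, x, z: gen_probs[z])
print(p)  # 0.9 * 0.050 + 0.1 * 0.002 = 0.0452
```

Each retrieved passage acts as one mixture component, with the retriever supplying the mixture weights.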
The retriever $$p_{\eta}(\bm{z}_j | \bm{x})$$ serves to measure the relevance (i.e. some kind of similarity) between the question $$\bm{x}$$ and each passage $$\bm{z}_j$$. The authors based their retriever on the Dense Passage Retrieval in Karpukhin et al. (2020) such that
$$
p_{\eta}(\bm{z}_j | \bm{x}) \propto \exp(\bm{d}(\bm{z}_j)^T \bm{q}(\bm{x}))
$$
where $$\bm{d}(\bm{z}_j)$$ and $$\bm{q}(\bm{x})$$ are dense representations of $$\bm{z}_j$$ and $$\bm{x}$$, respectively, produced by encoder models whose parameters are collectively denoted $$\eta$$. The authors note that searching for the top $$K$$ most relevant passages can be done in approximately sub-linear time (see also MIPS), which is feasible for practical applications.
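A sketch of this scoring step, assuming the dense vectors have already been computed by some encoder; the exhaustive search below would be replaced by an approximate MIPS index (e.g. FAISS) in practice.

```python
import numpy as np

def retrieve_top_k(passage_vecs, query_vec, k):
    """Score passages by the dot product d(z_j)^T q(x), convert scores to
    retriever probabilities with a softmax, and return the top k."""
    scores = passage_vecs @ query_vec              # shape (M,)
    weights = np.exp(scores - scores.max())        # numerically stable softmax
    weights /= weights.sum()
    top_k = np.argsort(-weights)[:k]               # indices of the K best passages
    return top_k, weights[top_k]

# Toy usage: random vectors standing in for encoder outputs (M = 1000, dim 64).
rng = np.random.default_rng(0)
idx, w = retrieve_top_k(rng.normal(size=(1000, 64)), rng.normal(size=64), k=5)
```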
The generator $$p_{\phi}(\bm{y} | \bm{x}, \bm{z}_j)$$ is effectively a language model with parameters $$\phi$$ that generates the answer $$\bm{y}$$ based on both $$\bm{z}_j$$ and $$\bm{x}$$. Technically, it can be any language model and, ideally, one that is (somehow) fine-tuned specifically on the collection of domain-specific documents $$\{ \bm{z}_j \}_{1 \leq j \leq M}$$.
For the generator, I find two hand-wavy things (one in the paper, and one in practice): sequence concatenation and model training/finetuning.
Sequence Concatenation
Most generative language models are sequence-to-sequence, that is, the model input is one single sequence and so is the model output. The authors use a simple concatenation of $$\bm{x}$$ and $$\bm{z}_j$$ to form this single input sequence, so $$p_{\phi}(\bm{y} | \bm{x}, \bm{z}_j)$$ is basically just $$p_{\phi}(\bm{y} | \{\bm{x}, \bm{z}_j\})$$ if we use $$\{$$ and $$\}$$ to denote concatenation.
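In code the concatenation really is that plain. A sketch at the text level, with a made-up prompt template (the paper works at the token level with a seq2seq model such as BART):

```python
def build_generator_input(question: str, passage: str) -> str:
    """Form the single input sequence {x, z_j} by concatenating a
    retrieved passage with the question."""
    return f"context: {passage}\n\nquestion: {question}"
```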
I understand why this choice is made from a technical perspective, and maybe there will be better ways to combine $$\bm{x}$$ and $$\bm{z}_j$$ in the future. However, simple concatenation might not make sense from a layperson's perspective, because the retrieved passage $$\bm{z}_j$$ may not necessarily make sense to a real human (e.g. it might be a broken-up paragraph, chunked partial tables from a document, or highly irrelevant content).
One might wonder why we need to append useless information when asking questions to the model. With a bit of understanding of finite mixture models, I would expect (or hope?) that irrelevant/useless retrievals get assigned rather low weights by the retriever $$p_{\eta}(\bm{z}_j | \bm{x})$$, thus contributing very little to the combined result generated from $$p_{\textrm{RAG}}(\bm{y} | \bm{x})$$.
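As a toy illustration with made-up scores: suppose only two passages are retrieved, a relevant one with dot-product score $$s_1 = 8$$ and an irrelevant one with $$s_2 = 2$$. Then

$$
p_{\eta}(\bm{z}_2 | \bm{x}) = \frac{e^{s_2}}{e^{s_1} + e^{s_2}} = \frac{e^{2}}{e^{8} + e^{2}} \approx 0.0025,
$$

so the irrelevant passage carries only about a quarter of a percent of the weight in the mixture.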
Model Training/Finetuning
In the paper, the authors actually train both the retriever and the generator at the same time. They provide pairs of $$(\bm{x}, \bm{y})$$ without explicitly specifying which passages $$\bm{z}_j$$ are the most relevant. Consequently, the parameters $$\eta$$ and $$\phi$$ are both updated based on the question $$\bm{x}$$ and the answer $$\bm{y}$$. This is very similar to training a finite mixture model, because we typically don't know which latent component actually generated the observed data.
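Concretely, this joint training minimizes the negative marginal log-likelihood of the provided pairs, with the latent passage summed out (and approximated by the top-$$K$$ retrievals in practice):

$$
\min_{\eta, \phi} \; \sum_{(\bm{x}, \bm{y})} -\log p_{\textrm{RAG}}(\bm{y} | \bm{x}) = \sum_{(\bm{x}, \bm{y})} -\log \sum_{j} p_{\phi}(\bm{y} | \bm{x}, \bm{z}_j) \, p_{\eta}(\bm{z}_j | \bm{x})
$$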
However, this brings us back to the problem of model finetuning, which is what we wanted to circumvent in practice in the first place. I highly doubt that commercial RAG solutions (such as API endpoints provided by tech companies) rely on any kind of fine-tuned model as the generator, although they probably do have dedicated solutions for the retriever.
I suspect the generator is simply $$p_\theta(\bm{y} | \{\bm{x}, \bm{z}_j\})$$ in practice. In other words, we concatenate the original question with whatever information is retrieved from the domain-specific passages, and then feed this long message into the general language model that we started off with.
So what actually is it that makes RAG generate better answers? I think there are two factors: the additional information $$\bm{z}_j$$ that is retrieved and fed into the generator as part of the prompt, and the nature of mixture models, which allows the retriever to put more weight on the more relevant passages.