🤔 Taken for Granted? It's time we reconsider the Instruction Tuning Loss! 👉

Uncovering the benefits of Weighted Instruction Tuning (WIT)

TL;DR — Small Tweaks, Big Gains 🤷

Instruction Tuning is the cornerstone of language model (LM) post-training — it is what enables the models to “follow” user instructions instead of merely completing text. Yet surprisingly, one critical piece has largely flown under the radar: the Loss Function!

In our recent work, On the Effect of Instruction Tuning Loss on Generalization (co-authored with Anwoy Chatterjee, Sumit Bhatia and Tanmoy Chakraborty), soon to appear in the Transactions of the Association for Computational Linguistics (TACL), we revisit this overlooked choice and introduce Weighted Instruction Tuning (WIT) — a simple yet effective alternative that lets you assign different weights to prompt and response tokens during training.

We found that assigning a low-to-moderate weight (0–0.5) to the prompt tokens and a moderate-to-high weight (0.5–1) to the response tokens consistently yields the best-performing models across settings — tested extensively across five models (spanning sizes and families), three instruction tuning datasets (of varying sizes), and five diverse benchmarks…

WIT-finetuned models not only demonstrate consistent improvements in generalization (an average relative gain of 6.55%), but are also more robust to minor changes in prompts and serve as stronger bases for subsequent preference alignment tuning!

Introduction — The Ubiquitous Loss Function We’ve Barely Scrutinized

Instruction Tuning has emerged as an important step in the post-training phase of LMs. It is what makes today’s language models capable of “following” user instructions — from summarizing a news article to giving life advice in the tone of a pirate!

Behind the scenes, instruction tuning works by finetuning a pretrained LM on a collection of (prompt, response) pairs spanning a diverse set of tasks — where prompts encode tasks in the form of natural language instructions and responses provide the ideal outputs. But lurking inside nearly every instruction tuning recipe is a crucial detail that has been overlooked:

The conventional loss is computed only on the response tokens!

But WHY?? 🙂
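
For concreteness, here is a minimal sketch of how that convention is typically implemented in practice: the labels for prompt tokens are set to an ignore index (such as -100, the default `ignore_index` of PyTorch's cross-entropy) so that they contribute nothing to the loss. The function and argument names below are illustrative, not taken from any particular library, and padding handling is omitted.

```python
import torch
import torch.nn.functional as F

def conventional_it_loss(logits, input_ids, prompt_lengths):
    """Conventional instruction tuning loss: computed on response tokens only.

    logits:         (batch, seq_len, vocab) outputs of a causal LM
    input_ids:      (batch, seq_len) prompt tokens followed by response tokens
    prompt_lengths: (batch,) number of prompt tokens in each sequence
    """
    labels = input_ids.clone()
    # Mask every prompt position so it is ignored by the loss.
    positions = torch.arange(labels.size(1), device=labels.device)
    labels[positions.unsqueeze(0) < prompt_lengths.unsqueeze(1)] = -100

    # Shift for next-token prediction, as in standard causal LM training.
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        ignore_index=-100,
    )
```

In other words, the gradient signal comes entirely from the response tokens; the prompt tokens serve only as context.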

A couple of recent works have already started questioning this choice.

Meanwhile, several other studies have raised concerns around models memorizing response patterns — suggesting that we may be over-tuning on response tokens and hurting generalization…

So, this raises the question:

What if we assign different weights to prompt and response tokens during training, so that we can control how much the model tunes on each?

This is exactly what WIT does — it lets you assign different weights to prompt and response tokens during training. And it turns out that this simple tweak goes a long way in improving the generalizability of the model!

Weighted Instruction Tuning (WIT)

Let \(\mathcal{D} = \{(\boldsymbol{P}_i, \boldsymbol{R}_i)\}_{i=1}^{N_{\mathcal{T}}}\) be an instruction tuning dataset with \(N_{\mathcal{T}}\) (prompt, response) pairs. Each prompt \(\boldsymbol{P}_i\) includes an instruction (implicit or explicit) and optionally some input, while \(\boldsymbol{R}_i\) is the expected ground-truth response.

If \(\lvert\boldsymbol{S}\rvert\) denotes the number of tokens in a sequence \(\boldsymbol{S}\), then:

\[\boldsymbol{P}_i = \left\{p_i^{(1)}, p_i^{(2)}, \ldots, p_i^{(\lvert\boldsymbol{P}_i\rvert)}\right\}\] \[\boldsymbol{R}_i = \left\{r_i^{(1)}, r_i^{(2)}, \ldots, r_i^{(\lvert\boldsymbol{R}_i\rvert)}\right\}\]

The WIT loss is given by:

\[\mathcal{L}_{WIT} = -\frac{\sum\limits_{i=1}^{N_{\mathcal{T}}}\left[\lambda_p \cdot \sum\limits_{j=1}^{\lvert\boldsymbol{P}_i\rvert} \log \mathbb{P}_{\mathcal{M}}\left(p_i^{(j)} |\; p_i^{(1)},\ldots, p_i^{(j-1)} \right) + \lambda_r \cdot \sum\limits_{j=1}^{\lvert\boldsymbol{R}_i\rvert} \log \mathbb{P}_{\mathcal{M}}\left(r_i^{(j)} |\; \boldsymbol{P}_i,\, r_i^{(1)},\ldots, r_i^{(j-1)} \right)\right]}{\sum\limits_{i=1}^{N_{\mathcal{T}}}\Big(\mathbb{I}{(\lambda_p \neq 0)}\cdot\lvert \boldsymbol{P}_i\rvert + \mathbb{I}{(\lambda_r \neq 0)}\cdot \lvert\boldsymbol{R}_i\rvert\Big)}\]

where \(\mathbb{I}(\cdot)\) is the indicator function, \(\lambda_p\) is the prompt-token weight, and \(\lambda_r\) is the response-token weight. \(\mathcal{L}_{WIT}\) computes a weighted sum of log-probabilities – scaling the log-probabilities of prompt tokens by \(\lambda_p\) and those of response tokens by \(\lambda_r\) – and then normalizes by the number of tokens that receive a non-zero weight; the indicator functions in the denominator ensure that the weighted sum is divided by exactly this count. Note that the conventional instruction tuning loss \(\mathcal{L}_{IT}\) is the special case of \(\mathcal{L}_{WIT}\) with \((\lambda_p, \lambda_r) = (0,1)\), and continual pre-training corresponds to \((\lambda_p, \lambda_r) = (1,1)\).
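
To make the formula concrete, here is a minimal PyTorch sketch of \(\mathcal{L}_{WIT}\) for a single batch. This is not the paper's released implementation; the function signature, tensor layout, and the `prompt_lengths` argument are assumptions made for illustration, and padding handling is omitted.

```python
import torch
import torch.nn.functional as F

def wit_loss(logits, input_ids, prompt_lengths, lambda_p=0.5, lambda_r=1.0):
    """Weighted Instruction Tuning (WIT) loss for one batch.

    logits:         (batch, seq_len, vocab) outputs of a causal LM
    input_ids:      (batch, seq_len) prompt tokens followed by response tokens
    prompt_lengths: (batch,) number of prompt tokens in each sequence
    lambda_p:       weight on prompt-token log-probabilities
    lambda_r:       weight on response-token log-probabilities
    """
    # Per-token negative log-likelihoods under next-token prediction.
    # (The very first token has no prediction target and is omitted,
    #  a minor departure from the written formula.)
    shift_logits = logits[:, :-1, :]
    shift_targets = input_ids[:, 1:]
    nll = F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_targets.reshape(-1),
        reduction="none",
    ).view(shift_targets.shape)                     # (batch, seq_len - 1)

    # A target position belongs to the prompt if its token index < prompt length.
    positions = torch.arange(1, input_ids.size(1), device=input_ids.device)
    is_prompt = positions.unsqueeze(0) < prompt_lengths.unsqueeze(1)

    # Scale prompt- and response-token losses by their respective weights.
    weights = torch.where(
        is_prompt,
        torch.full_like(nll, lambda_p),
        torch.full_like(nll, lambda_r),
    )
    weighted_sum = (weights * nll).sum()

    # Normalize by the number of tokens whose weight is non-zero,
    # mirroring the indicator terms in the denominator of the formula.
    counted = torch.zeros_like(nll)
    if lambda_p != 0:
        counted = counted + is_prompt.float()
    if lambda_r != 0:
        counted = counted + (~is_prompt).float()
    return weighted_sum / counted.sum().clamp(min=1.0)
```

Setting (λp, λr) = (0, 1) in this sketch recovers the conventional instruction tuning loss, while (1, 1) reduces to continual pre-training on the full sequence; the best-performing configurations in our experiments pair a low-to-moderate λp with a moderate-to-high λr.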

So…Does it work?

YES!! And consistently so, across settings:

We ran an extensive set of experiments to test WIT across a wide range of settings. Here's what we varied:

- Models: five LMs spanning different sizes and families
- Instruction tuning datasets: Tülu-v2, Alpaca-Cleaned, and LIMA (varying in size)
- Evaluation benchmarks: MMLU, BBH, AlpacaEval, IFEval, and MT-Bench
- Token-weight configurations: different combinations of (λp, λr)

Across the board, WIT outperforms the conventional instruction tuning loss

Figure 1: Heatmaps depicting average performance across five benchmarks (MMLU, BBH, AlpacaEval, IFEval, and MT-Bench) for different configurations of (λp, λr), for different models finetuned on Tülu-v2, Alpaca-Cleaned, and LIMA. The best-performing configuration is highlighted with a red circle. The color map reflects relative gain with respect to conventional instruction tuning. Rows correspond to prompt-token weights (λp) and columns to response-token weights (λr). Conventional instruction tuning is marked IT and base-model performance is marked Base.

Key Observations:

- The conventional instruction tuning configuration, (λp, λr) = (0, 1), is rarely the best-performing one.
- A low-to-moderate prompt-token weight (0–0.5) combined with a moderate-to-high response-token weight (0.5–1) consistently yields the best-performing models.
- WIT-finetuned models achieve an average relative gain of 6.55% over conventional instruction tuning, are more robust to minor prompt variations, and serve as stronger bases for subsequent preference alignment.

🔮 Some Hints for Future Research

Building on the empirical results, we looked at broader patterns and preliminary insights that could inspire future studies on the interplay between task characteristics and token weighting:

Figure 2: Correlation coefficients (Spearman and Kendall’s τ) between the optimal prompt-token weight (λp) and various characteristics of the finetuning datasets, evaluation benchmarks, and language models.

Conclusion and Future Directions