<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" xml:lang="en"><generator uri="https://jekyllrb.com/" version="4.3.4">Jekyll</generator><link href="https://israel-adewuyi.github.io/feed.xml" rel="self" type="application/atom+xml"/><link href="https://israel-adewuyi.github.io/" rel="alternate" type="text/html" hreflang="en"/><updated>2026-05-05T01:52:53+00:00</updated><id>https://israel-adewuyi.github.io/feed.xml</id><title type="html">blank</title><subtitle>A collection of research notes. </subtitle><entry><title type="html">Beyond the Lottery Ticket: Multiple Winning Subnetworks in Pretrained LLMs</title><link href="https://israel-adewuyi.github.io/blog/2025/slim-peft/" rel="alternate" type="text/html" title="Beyond the Lottery Ticket: Multiple Winning Subnetworks in Pretrained LLMs"/><published>2025-12-13T00:00:00+00:00</published><updated>2025-12-13T00:00:00+00:00</updated><id>https://israel-adewuyi.github.io/blog/2025/slim-peft</id><content type="html" xml:base="https://israel-adewuyi.github.io/blog/2025/slim-peft/"><![CDATA[<d-contents> <nav class="l-text figcaption"> <h3>Contents</h3> <div><a href="#intro">Introduction</a></div> <nav class="sub-nav"> <div><a href="#tldr">tl,dr</a></div> </nav> <div><a href="#background">Background </a></div> <nav class="sub-nav"> <div><a href="#notation">Notation</a></div> <div><a href="#setup">General Experimental Setup</a></div> <div> <a href="#sparse_subnet">Extracting Sparse Subnetworks</a> </div> </nav> <div><a href="#fisher_mask">Fisher Mask Works</a></div> <nav class="sub-nav"> <div><a href="#mask_training">Training with mask</a></div> <div><a href="#fisher_result">Results</a></div> </nav> <div> <a href="#rand_mask">The Surprise: Random Masks Also Work</a> </div> <nav class="sub-nav"> <div> <a href="#generate_random_masks">Generating Random Masks</a> </div> <div> <a href="#surprising_results">Surprising Results</a> </div> <div> <a href="#lr_puzzle">The learning rate puzzle</a> </div> <div> <a href="#why_diff_lr">Why Different Learning Rates?</a> </div> <div> <a href="#diff_masks_same_params">Do Different Masks Select the Same Parameters?</a> </div> <div><a href="#implications">Implications:</a></div> </nav> <div><a href="#caveat_and_qs">Caveats and Questions</a></div> <div><a href="#acknowledgements">Acknowledgements</a></div> <div><a href="#citation">Citation</a></div> </nav> </d-contents> <h2 id="intro">Introduction</h2> <p> Reinforcement learning fine-tuning is a new axis of scale for increased performance of Large Language Models (LLMs), with labs scaling compute for RL to levels on par with pretraining. Recent works have also attempted to shed light on the how and why RL really works <d-cite key="shao2025spuriousrewardsrethinkingtraining, wen2025reinforcementlearningverifiablerewards, yue2025doesreinforcementlearningreally, mukherjee2025reinforcementlearningfinetunessmall, zhu2025pathtakenrlvrprovably"></d-cite>. </p> <p> Important to this report, Mukherjee et al. (2025) <d-cite key="mukherjee2025reinforcementlearningfinetunessmall"></d-cite> showed that RLVR finetunes a sparse subnetwork in LLMs, as little as 5-30% of parameters. 
With the goal of efficiency in mind, we ask the question, <strong>if most parameters don't change during training, can we identify which ones matter <em>before</em> training begins, and train only those?</strong> In our attempts to answer this question, we expected that somewhat involved methods, like the Fisher Information matrix <d-cite key="Kirkpatrick_2017"></d-cite>, would be necessary to identify the "special" parameters that matter for learning. <strong>We were wrong.</strong> </p> <p> In this report, we present preliminary findings showing that <strong>random parameter selection</strong> can match full fine-tuning performance when training only ~1% of parameters. This suggests pretrained models may contain not just one winning ticket but potentially many, and we are calling this the <strong>Multiple Ticket Hypothesis</strong>. </p> <p> This report details ongoing, small-scale work, and the main reason for sharing is that we think the preliminary findings warrant discussion and are interesting enough to be shared with the wider community. <d-footnote>An auxiliary reason is to solicit compute resources to scale the experiments up.</d-footnote> </p> <h3 id="tldr">tl,dr of results</h3> <ul> <li> Random parameter selection at 99% sparsity can match full parameter fine-tuning performance. This suggests pretrained models contain multiple viable subnetworks (the "Multiple Ticket Hypothesis"). </li> <li> Fisher Information masks also work, validating parameter importance identification methods, but surprisingly offer no clear advantage over random selection. </li> <li> Different mask types require different optimal learning rates. </li> </ul> <h2 id="background">Background</h2> <h3 id="notation">Notation</h3> <p style="color: var(--text-color)"> Let $\theta$ denote the parameters of an LLM. We use $\theta^{(t)}$ to represent the model parameters at training step $t$, with $\theta^{(0)}$ denoting the initial pretrained model weights and $\theta_i$ denoting the $i$-th parameter. <br/> During an RLVR run, the gradients at step $t$, $g^{(t)}$, are computed via backpropagation, $$g^{(t)} = \nabla_\theta J_{\text{GRPO}}(\theta^{(t)})$$ </p> <h3 id="setup">General Experimental Setup</h3> <p> In this report, all experiments are carried out on <a href="https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct">Qwen2.5-0.5B-Instruct</a>. We trained via GRPO on <a href="https://x.com/kalomaze">Kalomaze</a>'s <a href="https://app.primeintellect.ai/dashboard/environments/kalomaze/alphabet-sort">Alphabetsort environment</a>. We use the AdamW optimizer for all RLVR runs. This work was built on <a href="https://www.primeintellect.ai/">Prime-Intellect</a>'s <a href="https://github.com/PrimeIntellect-ai/prime-rl">RL training library</a>. </p> <p> <strong>Evaluation:</strong> For evaluation, we use the same Alphabetsort environment, selecting 512 samples with the seed fixed to 2001. </p> <h3 id="sparse_subnet">Extracting sparse subnetworks</h3> <p> Our initial intuition:<br/> <strong>Imagine a pretrained LLM with only 2 parameters, p1 and p2. If only one parameter is changed at the end of a training phase with some optimization function $\phi$, say p1, it must mean that p1 is more important than p2 for satisfying $\phi$ on the training set. </strong> <br/><br/> The question now is, how do we identify which parameters are most important for some training data D? </p> <h2 id="fisher_mask">Fisher Mask Works</h2> <p> To identify which parameters are most important for a given task, we follow the approach laid out by Kirkpatrick et
al. <d-cite key="Kirkpatrick_2017"></d-cite>. The authors estimated the importance of some weights to a task by approximating the Fisher information matrix of the model parameters. </p> <p> We approximate the Fisher matrix, $F$, <d-footnote>A justification for this is provided in the paper, but to reiterate, the core intuition here is that the magnitude of $F_i$ is correlated with how important parameter $\theta_i$ is to the task represented by D.</d-footnote> on a large batch of data for all the parameters of the model. $$F_i \approx \frac{1}{N} \sum_{n=1}^{N} \left( \frac{\partial \log p(x_n|\theta^{(0)})}{\partial \theta_i} \right)^2$$ where $x_n \sim \text{dataset D}$ <br/> In practice, we sample a large batch of data, run a forward pass and a backward pass, and set $$F_i = \theta_i.\text{grad}^2$$ </p> <p> We can then take the top <code>x%</code> of parameters in $F$, set these to <code>True</code> and all else to <code>False</code>, thereby creating a binary mask $\text{MASK} \in \{0, 1\}^N$ over all parameters. </p> <h3 id="mask_training">Training with mask</h3> <p> During training with a mask, we modify the gradient update step to only affect the masked parameters: $$\tilde{g}^{(t)} = g^{(t)} \odot \text{MASK}$$ $$\theta^{(t+1)} = \theta^{(t)} - \eta_t \cdot \mathcal{U}(\tilde{g}^{(t)}, \theta^{(t)})$$ where $\odot$ denotes element-wise multiplication, $\eta_t$ is the learning rate at step $t$, and $\mathcal{U}$ represents the optimizer's update rule (e.g., AdamW). This ensures that only the selected subnetwork is updated while the full model is still used for forward passes. </p> <p> In practice, and for efficiency gains, we only store optimizer states for the subnetwork. </p> <h3 id="fisher_result">Results</h3> <p> We approximate $F$ using a batch of 1024 samples. We then created two masks, one at 99% sparsity, i.e. 4,940,328 / 494M parameters, and another at 99.9% sparsity, i.e. 494,032 / 494M parameters. We compare the eval results, as well as the training dynamics, in <a href="#fft_vs_fisher">Figure 1</a>. </p> <p> We use a learning rate of $10^{-6}$ for the full finetuning run, $5 \cdot 10^{-6}$ for the 99% Fisher mask run and $10^{-5}$ for the 99.9% Fisher mask run. </p> <figure id="fft_vs_fisher"> <img src="/assets/img/rl_subnet_1/fft_fisher.png" alt="train and eval dynamics"/> <figcaption> Figure 1: Metrics comparison between a full finetune run and sparse training runs </figcaption> </figure> <p> This confirms our initial hypothesis that parameter-importance identification (with the Fisher information matrix) might indeed be a way to pick out subnetworks that allow us to get comparable levels of performance to full finetuning. </p> <h2 id="rand_mask">The Surprise: Random Masks Also Work</h2> <p> Having validated the initial intuition, we wanted to establish a baseline for comparison and investigated random parameter selection. We generated random masks at 99% sparsity by uniformly sampling parameters to update. </p> <h3 id="generate_random_masks">Generating Random Masks</h3> <p> The implementation is pretty straightforward. We seed a random number generator and select <code>(100 - x)%</code> of parameters uniformly at random, to achieve <code>x%</code> sparsity. </p> <p> We used three different seeds, <code>0</code>, <code>2001</code> and <code>42</code>, to get different masks and ran an RL run with each. The results in <a href="#fft_vs_random">Figure 2</a> are at learning rates of $10^{-4}$, $5 \cdot 10^{-5}$ and $5 \cdot 10^{-5}$ respectively. </p> <pre>
import numpy as np
import torch

keep_ratio = 0.01  # keep 1% of parameters, i.e. 99% sparsity
rng = np.random.default_rng(seed=42)
mask_dict, active = {}, 0

for name, param in model.state_dict().items():
    # boolean mask of the same shape as the parameter, all False initially
    temp_tensor = torch.zeros_like(param, dtype=torch.bool)
    num_to_generate = int(param.numel() * keep_ratio)
    # sample flat indices uniformly at random, without replacement
    indices = rng.choice(param.numel(), size=num_to_generate, replace=False)
    temp_tensor.view(-1)[torch.from_numpy(indices)] = True
    active += num_to_generate
    mask_dict[name] = temp_tensor
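
# Sketch (not from the original training code): once a mask has been built,
# the masked-update rule from the "Training with mask" section can be applied
# by zeroing the gradients of all unselected parameters before the optimizer
# step. `model` and `optimizer` are assumed to already exist; the post itself
# additionally stores optimizer state for the subnetwork only.
def masked_step(model, optimizer, mask_dict):
    for name, param in model.named_parameters():
        if param.grad is not None:
            # elementwise g * MASK: unselected parameters get zero gradient
            param.grad.mul_(mask_dict[name].to(param.grad.dtype))
    optimizer.step()
    optimizer.zero_grad()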
        </pre> <h3 id="surprising_results">Surprising Results</h3> <p> <a href="#fft_vs_random">Figure 2</a> surprisingly shows that random parameter selection can match full fine-tuning performance. This finding challenges our initial assumption that some sophisticated parameter identification method would be necessary. </p> <figure id="fft_vs_random"> <img src="/assets/img/rl_subnet_1/fft_vs_random.png" alt="comparison of fft and random mask runs"/> <figcaption> Figure 2: Comparison of full fine-tuning (FFT) and random mask training. With appropriate learning rate tuning, random masks match or exceed full fine-tuning performance. </figcaption> </figure> <h3 id="lr_puzzle">The learning rate puzzle</h3> <p> The key to making random masks and even the fisher mask work is finding the right learning rate. We swept over multiple learning rates for the random masks at 99% sparsity <d-footnote>Some of the runs were cancelled and therefore aren't present because right from the start, the reward and eval curve do not improve and it felt wasteful to continue to burn through compute for results we already could intuit.</d-footnote> to better understand this relationship and presents our findings in Figures <a href="#hyperparam1">3</a> and <a href="#hyperparam2">4.</a> </p> <p> Random masks perform best at higher lr, compared to full finetuning (and fisher masks). This isn't dissimilar to <a href="https://thinkingmachines.ai/blog/lora/#optimal-learning-rates-for-lora-vs-fullft">Thinkymachine's work on lora</a>. </p> <figure id="hyperparam1"> <img src="/assets/img/rl_subnet_1/eval_scores_step_150.png" alt="Hyperparameter sweep on learning rate for the random mask at 150"/> <figcaption> Figure 3: Hyperparameter sweep on learning rate for the random mask at Step 150 </figcaption> </figure> <figure id="hyperparam2"> <img src="/assets/img/rl_subnet_1/eval_scores_step_300.png" alt="Hyperparameter sweep on learning rate for the random mask at 300"/> <figcaption> Figure 4: Hyperparameter sweep on learning rate for the random mask at Step 300 </figcaption> </figure> <h3 id="why_diff_lr">Why Different Learning Rates?</h3> <p> We hypothesize that this learning rate difference paints interesting pictures about the objective we are optimizing for and the training dynamics, with respect to the parameters of the model. Some of our hypotheses are: </p> <ul> <li> <strong>Fisher masks identify parameters already near optima:</strong> The Fisher Information Matrix identifies parameters with high curvature which could be interpreted to be that those parameters are sensitive to changes. These parameters may already be close to their optimal values for the task, requiring only small adjustments (hence lower learning rates). </li> <li> <strong>Random masks require more exploration or wiggling around:</strong> Random parameters are likely further from their optimal values on average, requiring larger updates to find good solutions (hence higher learning rates). </li> <li> <strong>Different regions of the loss landscape:</strong> Fisher masks may operate in high-curvature regions where large steps cause instability, while random masks may, on average, be in a region that appears flat and large steps are relatively safer. </li> </ul> <h3 id="diff_masks_same_params"> Do Different Masks Select the Same Parameters? </h3> <p> A natural question: are the random masks accidentally selecting the same parameters that Fisher masks identify? 
To answer this, we compute the Jaccard overlap between different masks, defined as $$ J(A, B) = \frac{|A \cap B|}{|A \cup B|} $$ </p> <figure id="jaccard"> <img src="/assets/img/rl_subnet_1/jaccard.png" alt="jaccard sim" style="transform: scale(0.7)"/> <figcaption> Figure 5: Jaccard overlap between the Random masks at 99% sparsity and the Fisher mask. </figcaption> </figure> <p> The Jaccard overlap between the random masks and the Fisher mask, as shown in <a href="#jaccard">Fig. 5</a>, is low, about 0.5% on average. This means that the random masks and the Fisher mask select almost completely different parameters, yet achieve comparable performance to full fine-tuning. </p> <h3 id="implications">Implications: The Multiple Ticket Hypothesis</h3> <p> These results suggest that, at least for the Alphabet-sort task, LLMs appear to contain multiple viable sparse subnetworks that can be optimized for the task, not just one. </p> <p> The Lottery Ticket Hypothesis (Frankle & Carbin, 2019) <d-cite key="frankle2019lotterytickethypothesisfinding"></d-cite> proposed that dense networks contain sparse subnetworks that can be trained to match the full network's performance. Frankle and Carbin used iterative magnitude pruning to identify a single winning ticket. </p> <p> Our findings extend the original LTH to the MTH: <strong>For sufficiently over-parameterized pretrained models, there may not be just <em>one</em> winning ticket, but potentially <em>many</em> winning tickets — so many that even random selection is likely to find one</strong>, i.e. <em>You can just <s>do things</s> select random parameters and train.</em> </p> <p> This explains why Fisher Information masks offer no clear advantage over random selection: both methods (random masks and Fisher masks) simply need to select <em>some</em> viable subnetwork, and with appropriate hyperparameter tuning, both succeed. </p> <h2 id="caveat_and_qs">Caveats and Questions</h2> <p> These are preliminary results on a small model (Qwen2.5-0.5B) and a simple task (alphabet-sort). More questions and ideas to investigate reveal themselves: </p> <ul> <li> Does this phenomenon hold for larger models and different (harder) tasks like math, code gen, logical reasoning (or any task that we might want to make the model good at with RLVR)? <d-footnote>We suspect it would. It should also be straightforward to investigate this and it would be pleasantly surprising if it doesn't hold.</d-footnote> </li> <li> It would seem logical that the parameters being repurposed for task A under random mask training might be close to optimal for another task B. How does random mask training affect RLVR's ability to reduce catastrophic forgetting (compared to SFT) on some other task the model has previously been trained on? <d-footnote>We're inclined to think that it would in some non-trivial way lead to poorer performance on some other previously trained-on task B, but what sort of task B? </d-footnote> </li> <li> Some experiments (not recorded here) at even more extreme sparsity levels like 99.9% and 99.95% do not match full fine-tuning performance. At 99.9%, the maximum performance across steps was about 46%, and it was even lower for 99.95%. What's the threshold for the number of parameters here? <d-footnote>It's not obvious that more params will equal more performance, for the simple reason that using all the params is capped at some level. But at what threshold do we start getting comparable performances?</d-footnote> </li> <li> On training dynamics, we also observe different convergence rates.
From preliminary experiments, higher learning rates converge faster at smaller sparsities, but they also become unstable more frequently than other runs. <d-footnote>The question to investigate here isn't clearly formulated yet, but it was still an interesting observation nonetheless.</d-footnote> </li> <li> Do random masks transfer across tasks, or are they task-specific? <d-footnote>We don't think they are task-specific, but we also don't think that the same mask would behave the same across different training tasks.</d-footnote> </li> <li> Why exactly do different masks require different optimal learning rates? How do we reason about this in relation to optimization theory specifically? </li> </ul> <p> Answering these questions requires significantly more compute than we currently have access to. If you're interested in collaborating, mentoring or sponsoring compute, please reach out! </p> <h2 id="acknowledgements">Acknowledgements</h2> <ol> <li> I am super grateful to <a href="https://x.com/dayveed_d">Daniel</a> and <a href="https://x.com/andreascoclet1">Andreas</a> for sponsoring compute for the initial experiments as well as asking really insightful questions. </li> <li> PrimeIntellect also cooked with the prime-rl library. It was pleasant to hack around with. </li> </ol> <h2 id="citation">Citation</h2> <p>If you find this work useful, please cite:</p> <pre><code>@misc{adewuyi2025lottery,
  author = {Adewuyi, Israel},
  title = {Beyond the Lottery Ticket: Multiple Winning Subnetworks in Pretrained LLMs},
  year = {2025},
  month = {December},
  url = {https://israel-adewuyi.github.io/blog/2025/slim-peft/},
  note = {Blog post}
}</code></pre>]]></content><author><name></name></author><summary type="html"><![CDATA[Preliminary evidence that random parameter selection can match full parameter RL finetuning.]]></summary></entry><entry><title type="html">Attention sink</title><link href="https://israel-adewuyi.github.io/blog/2025/attn_sink_evidence/" rel="alternate" type="text/html" title="Attention sink"/><published>2025-07-23T00:00:00+00:00</published><updated>2025-07-23T00:00:00+00:00</updated><id>https://israel-adewuyi.github.io/blog/2025/attn_sink_evidence</id><content type="html" xml:base="https://israel-adewuyi.github.io/blog/2025/attn_sink_evidence/"><![CDATA[<d-contents> <nav class="l-text figcaption"> <h3>Contents</h3> <div><a href="#background">Intro</a></div> <div> <a href="#attn_doing">What is an attention head doing?</a> </div> <nav class="sub-nav"> <div><a href="#attn_intuition">Intuitive explanation</a></div> <div><a href="#attn_concrete">More concretely</a></div> <div><a href="#attn_further_q">Further questions</a></div> </nav> <div><a href="#induction_head">Induction Heads </a></div> <div> <a href="#induction_head_gpt2">Identifying induction heads in GPT2</a> </div> <nav class="sub-nav"> <div><a href="#induction_input">Inputs</a></div> <div><a href="#induction_metric">Metric</a></div> <div> <a href="#induction_identify">Identifying induction heads</a> </div> <div><a href="#induction_result">Results</a></div> </nav> <div> <a href="#what_if">If the induction input isn't present?</a> </div> <nav class="sub-nav"> <div><a href="#normal_input">Inputs</a></div> <div> <a href="#normal_res">Results</a> </div> </nav> <div><a href="#closing">Closing thoughts</a></div> </nav> </d-contents> <h2 id="background">Intro</h2> <p> I read the induction heads paper <d-cite key="olsson2022context"></d-cite> a while back, while taking the ARENA course. The paper lays out a super interesting mechanistic study for in-context learning and specifically examines induction head in transformer language models. </p> <p> While playing around with induction heads in GPT2, I thought to myself that "What if the input to induction heads isn't present, what do the induction heads pay attention to?" I thought this might be a good question to investigate and after a quick literature search, I stumbled on the attention sink paper and a bunch of other works that made fantastic attempts at answering the question. </p> <p> Guo, et al., in "Active-Dormant Attention Heads" <d-cite key="guo2024activedormantattentionheadsmechanistically"></d-cite> investigated the same question but from a different angle. They trained a 3L GPT2-style transformer on bigram backcopy task and then investigated which heads were heavily involved in the backcopy task. Then they showed this heads were dormant when the bigram backcopy input isn't present. </p> <p> While the question was sort of answered already, I thought it would still be a good exercise present the thought process I went through while attempting to answer the question. </p> <p> In this remainder of this post, I briefly motivate what an attention head is doing, explain induction heads and how to look for them (with visualizations) and show what happens when the induction heads input isn't present. 
</p> <p>Feel free to skip parts you're familiar with.</p> <h2 id="attn_doing">What is an attention head doing ?</h2> <h3 id="attn_intuition">Intuitive overview</h3> <p>In summary, attention heads move information between tokens!</p> <figure id="resid_stream_view"> <img src="/assets/img/resid_stream_with_attn_mlp.png" alt="A simplified view of the transformer" width="600"/> <figcaption> Fig. 1: A simplified view of the transformer. Source: <a href="https://transformer-circuits.pub/2021/framework/index.html">A mathematical framework for transformer circuits.</a> </figcaption> </figure> <p> The residual stream is the main object in the transformer. A way I think of it is that it represents what the model currently thinks about all the tokens in it's context, up to a particular layer. To enrich and further refine the representation of the tokens in the context, attention heads move information from earlier tokens in the context to later tokens in the context <d-cite key="elhage2021mathematical"></d-cite> and MLP blocks compose information and perform retrieval tasks <d-cite key="Geva2020TransformerFL"></d-cite>. </p> <h3 id="attn_concrete">More concretely</h3> <p> The input to the attention layer is the residual stream <d-footnote>In the first layer, this is the sum of token embeddings and positional embeddings.</d-footnote> with shape <code class="language-plaintext">[batch_size, seq_len, d_model]</code>. This input is linearly projected using three weight matrices: <code class="language-plaintext">W_Q</code>, <code class="language-plaintext">W_K</code>, and <code class="language-plaintext">W_V</code>, each of shape <code class="language-plaintext">[d_model, d_model]</code>, to produce the query (<code class="language-plaintext">Q</code>), key (<code class="language-plaintext">K</code>), and value (<code class="language-plaintext">V</code>) matrices. </p> <p> In multi-head attention, <code class="language-plaintext">Q</code>, <code class="language-plaintext">K</code>, and <code class="language-plaintext">V</code> are split into <code class="language-plaintext">num_heads</code> parts. Each head processes a subspace of the input, with <code class="language-plaintext">Q</code> and <code class="language-plaintext">K</code> shaped as <code class="language-plaintext">[batch_size, num_heads, seq_len, d_k]</code> and <code class="language-plaintext">V</code> as <code class="language-plaintext">[batch_size, num_heads, seq_len, d_v]</code>, where <code class="language-plaintext">d_k = d_v = d_model / num_heads</code>. </p> <p> For each head, <strong>attention scores</strong> are computed as the dot product of query and key vectors, scaled by <code class="language-plaintext">1/√d_k</code>. The scores are passed through a softmax to obtain the <strong>attention pattern</strong>, which represents the importance of each token relative to others. This pattern is then multiplied by the value vectors to produce the head's output. The outputs of all heads are concatenated and projected using a weight matrix <code class="language-plaintext">W_O</code> of shape <code class="language-plaintext">[d_model, d_model]</code> to yield the final attention output, shaped <code class="language-plaintext">[batch_size, seq_len, d_model]</code>. </p> <p> From Vaswani et al. 
<d-cite key="attentionneed"></d-cite>, the attention mechanism is defined as: $$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$ Multi-head attention is expressed as: $$ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O $$ $$ \text{where } \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) $$ Here, <code class="language-plaintext">W_i^Q</code>, <code class="language-plaintext">W_i^K</code>, and <code class="language-plaintext">W_i^V</code> are head-specific projection matrices. </p> <h3 id="attn_further_q">Further questions</h3> <p> The next logical question is, how does each attention head across all the layers know what sort of information to pay attention to? During pre-training, the goal is optimizing the next-token objective w.r.t the parameters of the model, over the language domain. It stands to reason that over the course of multiple steps of gradient descent, each attention head learns to pay attention to some pattern (semantic or syntactic) in the language data and this pattern, when learned, contributes to lower loss. </p> <p>And indeed, numerous papers have explored this assumption.</p> <p> In both decoder-only and encoder-decocder transformers, attention heads have been discovered that specialize in attending to different parts of speech, as well as other lingustic propertites such direct objects of verbs, noun determiners, e.t.c. <d-cite key="vig-belinkov-2019-analyzing, clark-etal-2019-bert"></d-cite> </p> <p> Interesting mechanisms that further enable LLMs to act autoregressively have been discovered, such as Copy Supression heads <d-cite key="mcdougall2023copysuppressioncomprehensivelyunderstanding"></d-cite> and Induction heads <d-cite key="elhage2021mathematical, olsson2022context"></d-cite> </p> <p> A logical conclusion of the above paragraphs is that what attention heads pay attention to is input-specific. This begs the question : What does an attention head pay attention to, when it's input isn't present? </p> <h2 id="induction_head">Induction heads</h2> <p> I'll present a super simplified explanation of Induction heads here, but to better understand Induction heads mechanistically, Callum McDougall wrote <a href="https://www.lesswrong.com/posts/TvrfY4c9eaGLeyDkE/induction-heads-illustrated">a quite interesting explainer blog</a> which I invite readers to check out. The paper <d-cite key="olsson2022context"></d-cite> also goes into a lot more details that I only mention slightly such as the presence of previous-token heads and the role of the QK/OV circuit. </p> <p> Assume arbitrary tokens <code class="language-plaintext">A, B</code>. Then assume a sequence of tokens with <code class="language-plaintext">A</code> followed by <code class="language-plaintext">B</code> and then some other arbitrary tokens. The next time the model sees <code class="language-plaintext">A</code>, i.e <code class="language-plaintext">[A B ... A]</code>, B turns out to be one of the highly likely next tokens. </p> <p> Anthropic researchers found these phenomenon in as little as 2L transformer. One of the conclusions is that the model has learnt to increase the logits on <code class="language-plaintext">B</code> if the last token in the sequence is <code class="language-plaintext">A</code> and indeed, it's theorized that Induction heads is one of the mechanisms behind In-context learning. </p> <p> For this to be true, there has to be a previous-token head. 
This ensures that the first occurrence of <code class="language-plaintext">[B]</code> pays attention to the first occurrence of <code class="language-plaintext">[A]</code> and the \(W_V\) matrix copies <code class="language-plaintext">A</code> to the subspace of <code class="language-plaintext">B</code>. Then when <code class="language-plaintext">A</code> occurs in the context again, for some head \(\hat{h}\), the second occurrence of <code class="language-plaintext">A</code> pays attention to the first occurrence of <code class="language-plaintext">B</code>, sees that <code class="language-plaintext">A</code> is in the residual stream of <code class="language-plaintext">B</code> and then copies <code class="language-plaintext">B</code> to the residual stream of the second occurrence of <code class="language-plaintext">A</code> and increases it's logits. This new head \(\hat{h}\) is an induction head. </p> <h2 id="induction_head_gpt2">Identifying Induction heads in GPT2</h2> <h3 id="induction_input">Inputs</h3> <p> We sample <code class="language-plaintext">N = 25</code> random tokens from the vocabulary of a transformer language model and duplicate it along it's axis. This becomes the input to the transformer. </p> <p> Input to the transformer is a matrix of shape <code class="language-plaintext">[1, 2 * N + 1, d_model]</code> i.e batch is 1, sequence length is 2 * N + 1 <d-footnote>+ 1 because we append the bos token to the sequence</d-footnote> and d_model = embedding dimension of the transformer language model. </p> <p> Pass this sequence of randomly repeated tokens into GPT2 and cache the activations. This can be done easily by loading the model with <a href="https://transformerlensorg.github.io/TransformerLens/">transformer lens</a> and running <br/> <code class="language-plaintext">_, cache = model.run_with_cache(input_tokens)</code> </p> <h3 id="induction_metric">Metric</h3> <p> Assume we have some head <code class="language-plaintext">h</code> at some layer <code class="language-plaintext">l</code>, the attention pattern is defined as, $$ \text{A}^{l, h} = \text{softmax}\left(\frac{Q_{l, h}K^T_{l, h}}{\sqrt{d_k}}\right) $$ </p> <p> We define <strong>induction score</strong> for head <code class="language-plaintext">h</code> in layer <code class="language-plaintext">l</code> as a measurement of how much attention a token in the second repeat (at position <code class="language-plaintext">i + N</code>) pays to its corresponding token in the first repeat (at position <code class="language-plaintext">i</code>). It's represented as: </p> $$I(l, h) = \frac{1}{N} \sum_{i = 1} ^N A^{l, h} [i + N, i] $$ <h3 id="induction_identify">Identifying induction heads</h3> Retrieve the attention pattern from the cache and for each head in each layer, calculate the induction score as defined above. <pre>
def induction_head_detector(cache, cfg) -> list:
    induction_heads = []
    for layer_idx in range(cfg.n_layers):
        for head_idx in range(cfg.n_heads):
            # fetch the attention pattern for this head: [dest_pos, src_pos]
            attn_pattern = cache["pattern", layer_idx][0, head_idx]
            # N, the length of the random token sequence (input is [bos, N tokens, N tokens])
            rand_tok_seq_len = (attn_pattern.shape[1] - 1) // 2
            # induction score: average attention from each token in the second
            # repeat back to the token that followed its first occurrence
            score = attn_pattern.diagonal(-rand_tok_seq_len + 1).mean()
            # filter with a threshold of 0.4
            if score.item() >= 0.4:
                induction_heads.append((layer_idx, head_idx))
    return induction_heads</pre>
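
        <p>
          As a minimal usage sketch (assuming TransformerLens'
          <code class="language-plaintext">HookedTransformer</code>; this
          snippet is illustrative rather than the exact notebook code), the
          detector above can be run on the repeated random-token input
          described earlier:
        </p>
        <pre>
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")

# [bos, t_1 ... t_N, t_1 ... t_N] with N = 25 random tokens
N = 25
rand_tokens = torch.randint(0, model.cfg.d_vocab, (1, N))
bos = torch.tensor([[model.tokenizer.bos_token_id]])
input_tokens = torch.cat([bos, rand_tokens, rand_tokens], dim=-1)

_, cache = model.run_with_cache(input_tokens)
print(induction_head_detector(cache, model.cfg))</pre>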

        <h3 id="induction_result">Results</h3>
        <p>Below is a visual map of the induction heads present in GPT2</p>
        <iframe
          src="/assets/plotly/induction_head_only.html"
          height="650px"
          frameborder="0"
        ></iframe>

        <p>
          Below is an interactive visualization of the attention patterns for
          the induction heads identified above.
        </p>
        <iframe
          src="/assets/plotly/attention_viz_induction.html"
          width="100%"
          height="650px"
          frameborder="0"
        ></iframe>

        <h2 id="what_if">What happens if the induction input isn't present?</h2>
        <h3 id="normal_input">Inputs</h3>
        <p>
          Load a tiny subset of the
          <a href="https://huggingface.co/datasets/NeelNanda/pile-10k"
            >10K pile dataset.</a
            > For the purpose of this experiment, I used
          <code class="language-plaintext">batch = 1</code> and
          <code class="language-plaintext">sequence_length = 128.</code>
        </p>
        <p>
          A forward pass is also run on this input and the activations are cached
          as in the case above as well.
        </p>
        <p>
          For the induction heads that were identified in
          <a href="#induction_result">the section above</a>, we simply visualize
          the attention pattern for these heads.
        </p>
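        <p>
          Concretely, a small sketch of this setup (assuming
          <code class="language-plaintext">model</code> is the TransformerLens
          model from before, and that the heads found above are stored in a
          list <code class="language-plaintext">induction_heads</code>) could
          look like:
        </p>
        <pre>
from datasets import load_dataset

data = load_dataset("NeelNanda/pile-10k", split="train")
tokens = model.to_tokens(data[0]["text"])[:, :128]  # batch = 1, seq_len = 128

_, cache = model.run_with_cache(tokens)
for layer_idx, head_idx in induction_heads:
    pattern = cache["pattern", layer_idx][0, head_idx]  # [dest, src]
    # average attention mass each destination token places on the first token
    print(layer_idx, head_idx, pattern[:, 0].mean().item())</pre>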

        <h3 id="normal_res">Results</h3>
        <iframe
          src="/assets/plotly/attention_viz_normal.html"
          width="100%"
          height="650px"
          frameborder="0"
        ></iframe>
        As can be observed, these heads all pay an overwhelming amount of
        attention to the first token.

        <h2 id="closing">Closing thoughts</h2>
        <p>
          Guo et. al.,
          <d-cite
            key="guo2024activedormantattentionheadsmechanistically"
          ></d-cite>
          observed that not only the first token, but other special tokens, get
          an overwhelming amount of attention in dormant cases.
        </p>
        <p>
          They showed further evidence of this phenomenon by confirming that the
          value vectors of these tokens were much smaller than those of other
          tokens
          <d-footnote
            >This lends evidence to the fact that the information being written
            back to the residual stream is not of huge consequence.</d-footnote
          >
          and the residual stream norms for these tokens were relatively small as
          well.
        </p>
        <p>
          This however isn't the only explanation for the first-token/special
          token phenomenon observed in attention heads. Federico
          <d-footnote
            >Personally, I enjoy Federico's papers and especially his
            <a href="https://www.youtube.com/watch?v=FAspMnu4Rt0">interview</a>
            on MLST podcast.</d-footnote
          >
          et al. <d-cite key="barbero2025llmsattendtoken"></d-cite> has a
          paper where he also investigates why attention sinks exist and
          presents an alternative explanation. <strong>TLDR</strong>: They serve
          the purpose of preventing mode collapse.
        </p>]]></content><author><name></name></author><summary type="html"><![CDATA[More evidence]]></summary></entry><entry><title type="html">Reinforcement Learning Meets NER</title><link href="https://israel-adewuyi.github.io/blog/2025/ner_with_rl/" rel="alternate" type="text/html" title="Reinforcement Learning Meets NER" /><published>2025-05-01T00:00:00+00:00</published><updated>2025-05-01T00:00:00+00:00</updated><id>https://israel-adewuyi.github.io/blog/2025/ner_with_rl</id><content type="html" xml:base="https://israel-adewuyi.github.io/blog/2025/ner_with_rl/"><![CDATA[<d-contents>
          <nav class="l-text figcaption">
            <h3>Contents</h3>
            <div><a href="#karoche">TLDR</a></div>
            <div><a href="#intro">Introduction</a></div>
            <div><a href="#background">Background</a></div>
            <nav class="sub-nav">
              <div><a href="#why-llms-suck">Why LLMs suck at NER</a></div>
              <div>
                <a href="#summary">Summary of related approaches</a>
              </div>
            </nav>
            <div><a href="#method">Method</a></div>
            <nav class="sub-nav">
              <div><a href="#dataset">Dataset</a></div>
              <div><a href="#prompt">Prompt</a></div>
              <nav class="sub-sub-nav">
                <div><a href="#task-desc">Task description</a></div>
                <div><a href="#few-shot">Few shot demonstrations</a></div>
                <div><a href="#output-format">LLM Output Format</a></div>
              </nav>
              <div><a href="#reward">RL and Reward design</a></div>
            </nav>
            <div><a href="#experiment">Experiment</a></div>
            <nav class="sub-nav">
              <div><a href="#eval">Evaluation</a></div>
            </nav>
            <div><a href="#results">Results / Charts</a></div>
            <nav class="sub-nav">
              <div>
                <a href="#chart1">F1 score comparison across approaches</a>
              </div>
              <div><a href="#chart2">Model sizes comparison</a></div>
              <div>
                <a href="#chart3"
                  >F1 score comparison across epochs and base model</a
                >
              </div>
            </nav>
            <div><a href="#closing">Closing Thoughts</a></div>
            <div><a href="#acknowledgement">Acknowledgements</a></div>
            <div><a href="#citation">Citation Information</a></div>
            <!-- <div><a href="#footnotes">Code Blocks</a></div>
            <div><a href="#interactive-plots">Interactive Plots</a></div>
            <div><a href="#layouts">Layouts</a></div>
            <div><a href="#other-typography">Other Typography?</a></div> -->
          </nav>
        </d-contents>
        <h2 id="karoche">TLDR</h2>
        <p>
          <i
            >This work represents preliminary experimental reports, and should
            be treated as such. See the closing thoughts for more details.</i
          >
        </p>
        <ul>
          <li>
            Via RL training, we achieved up to a 7.3% increase in F1 score on a
            1.5B model, on an NER task, compared to 175B GPT3-based baselines.
          </li>
          <li>
            The RL trained model underperforms some other approaches, mostly
            involving some form of SFT.
          </li>
          <li>
            We offer some closing thoughts on this as well as future possible
            directions of research.
          </li>
        </ul>

        <p>
          Feel free to skip to Methods section or start from the introduction
          below.
        </p>
        <p>
          Code is up at
          <a href="https://github.com/israel-adewuyi/ner_with_grpo"
            >this repo</a
          >
        </p>
        <h2 id="intro">Introduction</h2>
        <p>
          Large Language Models (LLMs) built on the Transformer architecture
          <d-cite key="attentionneed"></d-cite>
          have transformed Natural Language Processing (NLP), achieving SOTA
          results in tasks such as text generation, translation, and sentiment
          analysis
          <d-cite key="radford2019language"></d-cite>. At the same time, Named
          Entity Recognition (NER)—the process of identifying and classifying
          proper names and other key terms in text—remains a core NLP task in
          applications like information extraction, question answering, machine
          translation
          <d-cite key="keraghel2024recentadvancesnamedentity"></d-cite>.
        </p>
        <p>
          Recently, researchers have revisited Reinforcement Learning (RL) as a
          means of adapting LLMs to specific objectives without full retraining.
          By defining reward functions or training reward models that is
          specific to the domain/task in question, RL fine‑tuning can elicit
          desired behaviors from a pre‑trained model. This approach has already
          shown promise in areas such as competitive mathematics
          <d-cite
            key="chen-etal-2025-learning, shao2024deepseekmathpushinglimitsmathematical, openai2025competitiveprogramminglargereasoning"
          ></d-cite>
          and code generation
          <d-cite
            key="gehring2025rlef, wang2025enhancingcodellmsreinforcement"
          ></d-cite
          >.
        </p>
        <p>
          In this work, we bring these threads together. We fine‑tune a 1.5B
          Qwen2.5
          <d-cite key="qwen2025qwen25technicalreport"></d-cite> model on the
          CoNLL2003 NER dataset
          <d-cite key="sang2003introductionconll2003sharedtask"></d-cite>, using
          carefully designed reward signals to guide entity recognition
          performance. Our results demonstrate that, even with modest model
          size, RL‑based adaptation can rival much larger architectures such as
          GPT‑3, highlighting the potential of reinforcement learning for
          structured NLP tasks.
        </p>

        <h2 id="background">Background</h2>
        <h3 id="why-llms-suck">Why LLMs suck at NER</h3>
        <p>
          LLMs like Qwen2.5-1.5B
          <d-cite key="qwen2025qwen25technicalreport"></d-cite>, Llama
          <d-cite key="grattafiori2024llama3herdmodels"></d-cite>, Gemini
          <d-cite key="geminiteam2024geminifamilyhighlycapable"></d-cite>, e.t.c
          are pretrained on massive datasets with the objective being to predict
          the next token in a sequence. This makes them great for tasks like
          text generation but less effective for NER.
        </p>
        <p>Why?</p>
        <p>
          NER requires a different approach, as it involves identifying and
          classifying entities within text. This is a token-level task, where
          the goal is to label specific tokens accurately.
        </p>
        <p>
          Keraghel et al.
          <d-cite key="keraghel2024recentadvancesnamedentity"></d-cite>
          provided a comprehensive overview of the advances in the field of NER,
          but most relevant approaches to this work, which are summarized below,
          are approaches that leverage LLMs to solve NER task.
        </p>
        <h3 id="summary">Summary of related approaches</h3>
        <p>
          The core idea in most of the works listed below involves primarily
          prompt engineering
          <d-footnote>
            Which is just a fancy way of saying almost all of them reformulated
            the task in some way that the LLM might be able to solve
            more easily.</d-footnote
          >
          and some combination of In Context Learning (ICL) and Supervised
          Finetuning (SFT). In the ICL paradigm, LLMs learn new tasks by being
          shown a few examples (few-shot) in the prompt which makes them
          flexible for new tasks without extra training while in the SFT
          paradigm, an LLM is further trained on specific, labeled data to make
          it better at the said task.
          <d-footnote
            >It's akin to fine-tuning a general tool for a specific
            job.</d-footnote
          >
        </p>
        <p>
          GPT-NER introduced by Wang et al
          <d-cite key="wang2023gptnernamedentityrecognition"></d-cite> and LTNER
          introduced by Yan et al
          <d-cite key="Yan2024LTNERLL"></d-cite>
          reformulated the sequence labelling task as a text generation task by
          prompting the LLM to generate the input text with the identified
          entities marked by special tokens. Both methods rely heavily on ICL.
          For retrieving these few-shot examples, GPT-NER investigated various
          strategies including embeddings derived from a fine-tuned NER model
          while LTNER utilizes vector-based retrieval from a knowledge base to
          find the most relevant examples for contextual learning. Additionally,
          GPT-NER introduced a self-verification strategy to combat
          hallucination.
        </p>
        <p>
          PromptNER introduced by Ashok and Lipton<d-cite
            key="Ashok2023PromptNERPF"
          ></d-cite>
          kept the task as a sequence labelling task
          <d-footnote
            >By prompting the LLM to list the entities in the text, given a
            predefined list of entities</d-footnote
          >, but they introduced Chain-of-Thought (CoT) prompting, as well as
          giving an explanation of all the predefined entity types.
        </p>
        <p>
          GoLLIE introduced in <d-cite key="Sainz2023GoLLIEAG"></d-cite> and
          InstructUIE introduced in
          <d-cite key="Wang2023InstructUIEMI"></d-cite> both proposed
          instruction tuning frameworks for information retrieval using LLMs.
          GoLLIE fine-tunes an LLM to follow annotation guidelines, with tasks
          and guidelines represented in a code-based format. In contrast,
          InstructUIE employs natural language instructions within a unified
          text-to-text framework to model various IE tasks.
        </p>
        <p>
          CodeIE <d-cite key="Li2023CodeIELC"></d-cite> and Code4UIE
          <d-cite key="Guo2023RetrievalAugmentedCG"></d-cite> also transform the
          sequence labelling task into a code generation task to leverage the
          code generation capabilities of LLMs.
        </p>
        <table style="border: 1px solid black">
          <thead>
            <tr>
              <th>Method</th>
              <th>Approach</th>
              <th>Base model</th>
            </tr>
          </thead>
          <tbody>
            <tr>
              <td>GPT-NER</td>
              <td>ICL</td>
              <td>Text-davinci-003</td>
            </tr>
            <tr>
              <td>LTNER</td>
              <td>ICL</td>
              <td>GPT-3.5-turbo</td>
            </tr>
            <tr>
              <td>PROMPT-NER</td>
              <td>ICL</td>
              <td>GPT4</td>
            </tr>
            <tr>
              <td>CodeIE</td>
              <td>ICL</td>
              <td>Code-davinci-002</td>
            </tr>
            <tr>
              <td>Code4UIE</td>
              <td>ICL</td>
              <td>Text-davinci-003</td>
            </tr>
            <tr>
              <td>GPT-NER</td>
              <td>ICL + SFT</td>
              <td>Text-davinci-003</td>
            </tr>

            <tr>
              <td>GoLLIE</td>
              <td>SFT</td>
              <td>Code-LLaMA 34B</td>
            </tr>

            <tr>
              <td>InstructUIE</td>
              <td>SFT</td>
              <td>Flan-T5-11B</td>
            </tr>
          </tbody>
        </table>

        <!-- { name: "OURS\n(Qwen2.5-1.5B)", size: "1.5B", highlight: true }, { name:
        "InstructUIE\n(Flan-T5-11B)", size: "11B", highlight: false }, { name:
        "LTNER\n(GPT3.5-Turbo)", size: "175B", highlight: false }, { name:
        "PromptNER\n(GPT4)", size: "175B", highlight: false }, -->

        <h2 id="method">Method</h2>
        <h3 id="dataset">Dataset</h3>
        <p>
          The CoNLL2003 dataset introduced by Tjong Kim Sang and De Meulder
          <d-cite key="sang2003introductionconll2003sharedtask"></d-cite> has
          four types of named entities: Location (LOC), Organization (ORG),
          Person (PER), and Miscellaneous (MISC). We leveraged the preprocessed
          NER dataset by Li et al. (2019a)
          <d-cite key="li2022unifiedmrcframeworknamed"></d-cite>. A sample from
          the dev set which is downloadable from their
          <a href="https://github.com/ShannonAI/mrc-for-flat-nested-ner"
            >github repo</a
          >
          looks like
        </p>

        <pre>
  {
    "context": "4 - Goran Ivanisevic ( Croatia ) beat Scott Draper ( Australia ) 6-7 ( 1-7 ) 6-3 6-4 6-4",
    "end_position": [
      3,
      9
    ],
    "entity_label": "PER",
    "impossible": false,
    "qas_id": "174.2",
    "query": "person entities are named persons or family.",
    "span_position": [
      "2;3",
      "8;9"
    ],
    "start_position": [
      2,
      8
    ]
  }
</pre
        >
        <p>
          <code>context</code> - the input text from which entities are to be
          extracted. <br />
          <code>entity_label</code> - the entity to be extracted <br />
          <code>query</code> - an explanation of the entity to be extracted.
        </p>
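        <p>
          To make the field semantics concrete, here is a small sketch
          (assuming, as the sample suggests, that
          <code>start_position</code> and <code>end_position</code> are
          inclusive token indices into the whitespace-split
          <code>context</code>):
        </p>
        <pre>
sample = {
    "context": "4 - Goran Ivanisevic ( Croatia ) beat Scott Draper ( Australia ) 6-7 ( 1-7 ) 6-3 6-4 6-4",
    "start_position": [2, 8],
    "end_position": [3, 9],
}

tokens = sample["context"].split()
entities = [
    " ".join(tokens[start : end + 1])
    for start, end in zip(sample["start_position"], sample["end_position"])
]
print(entities)  # ['Goran Ivanisevic', 'Scott Draper']</pre>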
        <br />
        <h3 id="prompt">Prompt</h3>
        <pre>
  """
  A conversation between User and Assistant. The User provides a string of words. 
  The task of the Assistant is to identify all the {entity_label} entities 
  in the given string and return the entities surrounded by an entity tag.
  DESCRIPTION: {query}
  
  The reasoning process should be enclosed within &lt;think&gt; &lt;/think&gt; tags, 
  and the relevant words should be enclosed within &lt;entity&gt; &lt;/entity&gt; tags.
  i.e &lt;think&gt; reasoning process here &lt;/think&gt; &lt;entity&gt; comma separated 
  list of words that are locations&lt;/entity&gt;
  
  {example}
  
  User: {context}
  Assistant: 
  """
              </pre
        >
        <p>
          <code>entity_label</code> - the entity to be extracted <br />
          <code>query</code> - an explanation of the entity to be extracted -
          same as in the dataset. <br />
          <code>example</code> - few shot examples for the current
          <code>entity_label</code> <br />
          <code>context</code> - the input text from which entities are to be
          extracted. <br />
        </p>
        <p>
          The prompt construction method used in this work relies heavily on, and
          closely mirrors, the fantastic work done in the research literature.
          An overview of the relevant parts is provided below:
        </p>

        <p style="margin-bottom: 24px"></p>
        <h4 id="task-desc">Task description</h4>
        <p>
          Following the preprocessing step done by Tjong Kim Sang and De Meulder
          <d-cite key="sang2003introductionconll2003sharedtask"></d-cite> and
          Wang et al.
          <d-cite key="wang2023gptnernamedentityrecognition"></d-cite>, for each
          input sentence, N prompts are constructed where N is the number of
          entity types in the dataset <d-footnote>N = 4 for CoNLL2003</d-footnote>.
          Following the PromptNER
          <d-cite key="Ashok2023PromptNERPF"></d-cite> paper, we asked the model
          to generate a CoT inside the think tag and we also added an
          explanation for each entity type to the prompt.
        </p>
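        <p>
          As a rough sketch of this construction (the names
          <code>PROMPT_TEMPLATE</code>, <code>entity_queries</code> and
          <code>few_shot_examples</code> are placeholders, not the exact
          variables used in the training code):
        </p>
        <pre>
PROMPT_TEMPLATE = "..."  # the template string shown in the Prompt section above

def build_prompts(context: str, entity_queries: dict, few_shot_examples: dict) -> list:
    # one prompt per entity type (N = 4 for CoNLL2003)
    return [
        PROMPT_TEMPLATE.format(
            entity_label=label,
            query=query,
            example=few_shot_examples[label],
            context=context,
        )
        for label, query in entity_queries.items()
    ]</pre>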
        <br />
        <br />
        <h4 id="few-shot">Few shot demonstrations</h4>
        <p>
          To generate in-context examples, we randomly sample 2 sentences from
          the dev set of the CoNLL dataset for each entity type. Deepseek-V3-base
          is prompted to solve the task and the result is manually inspected and
          included in the prompt, without any modifications.
        </p>

        <br />
        <h4 id="output-format">LLM Output Format.</h4>
        <p>
          As visible from the prompt template, the LLM is prompted to output the
          CoT inside the think tag and then the list of entities in the entity
          tag.
        </p>
        <p>
          As an example, consider the following. The LLM is prompted to identify
          the location entity, <code>LOC</code> in the following text:
          <code
            >Japan began the defence of their Asian Cup title with a lucky 2-1
            win against Syria in a Group C championship match on Friday .</code
          >
          Results from Deepseek-V3-base :
        </p>
        <pre>
  <code
          >&lt;think&gt; </code> 1. **Japan**: This is a country in East Asia, which is a politically and geographically defined location. It fits the criteria for a LOCATION entity. 
  2. **Asian Cup**: This refers to a football tournament, not a specific geographic or political location, so it does not qualify as a LOCATION entity. 
  3. **Syria**: This is a country in the Middle East, which is a politically and geographically defined location. It fits the criteria for a LOCATION entity. 
  4. **Group C**: This refers to a group in the tournament, not a geographic or political location, so it does not qualify as a LOCATION entity. 
  5. The rest of the words ("began", "the", "defence", "of", "their", "title", "with", "a", "lucky", "2-1", "win", "against", "in", "a", "championship", "match", "on", "Friday") are not location names and do not fit the criteria. 
  <code
          >&lt;/think&gt; </code> 
  <code
          >&lt;entity&gt; </code>Japan, Syria<code
          >&lt;/entity&gt; </code>
          </pre> <p style="margin-bottom: 24px"></p> <h3 id="reward">RL and Reward Design</h3> <p> To guide the model to output things in the desired format, we utilize GRPO <d-cite key="shao2024deepseekmathpushinglimitsmathematical"></d-cite> which discards the critic model which is meant to provide some baseline for the advantage estimates. Instead, for each prompt, GRPO samples <code>m</code> outputs, which are referred to as <strong>group</strong> and the average outcome reward from the group serves as an estimate for the baseline. </p> <p> For each output in the group, four reward functions were designed to provide signals to the policy model. The design of the reward function was inspired by the amazing work of <a href="https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb">Willccb's grpo demo</a>. </p> <ul> <li> <strong>Soft Format Reward:</strong> Awards 0.5 if the output follows the required format (<code>&lt;think&gt;...&lt;/think&gt; &lt;entity&gt;...&lt;/entity&gt;</code>), ensuring structural consistency. </li> <li> <strong>Correctness Reward:</strong> Gives 2.0 if the extracted entities exactly match the ground truth, emphasizing accuracy. </li> <li> <strong>Positive Entity Correctness Reward:</strong> Adds 0.5 for each correctly identified entity, rewarding partial correctness. </li> <li> <strong>Negative Entity Correctness Reward:</strong> Subtracts 0.5 for each incorrectly included entity, penalizing overprediction. </li> </ul> <p> These rewards work together to encourage the model to identify entities accurately while adhering to the expected format. For example, if the model correctly identifies “Japan, Syria” as Locations but includes an extra incorrect entity, it receives a positive reward for the correct entities but a penalty for the mistake. </p> <div class="highlight"> <p><strong>Example Reward Calculation:</strong></p> <p> Input: “Japan began the defence of their Asian Cup title with a lucky 2-1 win against Syria.” </p> <p> Ground Truth: <code>&lt;entity&gt;Japan, Syria&lt;/entity&gt;</code> </p> <p> Model Output: <code>&lt;entity&gt;Japan, Syria, Asian Cup&lt;/entity&gt;</code> </p> <ul> <li>Soft Format: 0.5 (correct format)</li> <li>Correctness: 0.0 (not an exact match)</li> <li>Positive Entity: 1.0 (0.5 for Japan + 0.5 for Syria)</li> <li>Negative Entity: -0.5 (penalty for Asian Cup)</li> </ul> </div> <h2 id="experiment">Experiment</h2> <p> As stated earlier, <a href="https://huggingface.co/Qwen/Qwen2-1.5B-Instruct">Qwen2-1.5B-Instruct</a> was utilized in this experiment. <a href="https://huggingface.co/docs/trl/main/en/index">TRL library</a> provided by HuggingFace, alongsides modifications to <a href="https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb">Willccb's grpo demo</a> repository, were used to construct a training and eval pipeline. The model was also trained on 4 epochs on <strong>2 Nvidia A100</strong> GPUs. Relevant hyperparameters are: </p> <pre>
  from trl import GRPOConfig

  training_args = GRPOConfig(
    ...
    learning_rate=5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type='cosine',
    bf16=True,
    per_device_train_batch_size=4,
    num_generations=8,          # GRPO group size: m sampled completions per prompt
    max_prompt_length=2048,
    max_completion_length=2048,
    num_train_epochs=4,
    save_strategy="epoch",
    max_grad_norm=0.1,
    report_to="wandb",
    log_on_each_node=False,
  )
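
  # Illustrative sketch (not the exact training script): the four reward
  # functions described above are passed to TRL's GRPOTrainer together with
  # the arguments above. The reward-function and dataset names below are
  # placeholders; see the repository for the real definitions.
  from trl import GRPOTrainer

  trainer = GRPOTrainer(
      model="Qwen/Qwen2-1.5B-Instruct",
      reward_funcs=[soft_format_reward, correctness_reward,
                    positive_entity_reward, negative_entity_reward],
      args=training_args,
      train_dataset=train_dataset,
  )
  trainer.train()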
        </pre> <p>More details can be found in the repository.</p> <h3 id="eval">Evaluation</h3> <p>F1 score is reported for the RL-trained model.</p> <ul> <li> To compare with other methods, the best F1 score across the 4 epochs is reported. </li> <li>vLLM was used for sampling from the models.</li> <li> To ensure robustness, the F1 score is averaged across 16 generations and reported, both in the comparison to other methods and in the comparison across epochs. </li> <li> To evaluate the RL-trained model, <a href="#few-shot">few-shot examples</a> were included in the prompt. </li> </ul> <h2 id="results">Results / Charts</h2> <h3 id="chart1">F1 score comparison across approaches</h3> <p> The chart below shows the F1 score of the fine-tuned model in comparison with other methods described in the <a href="#summary">relevant works section</a>. </p> <div style="max-width: 800px; margin: auto"> <canvas id="resultsChart" width="800" height="400"></canvas> </div> <h3 id="chart2">Model sizes comparison</h3> <p> The chart below compares the model sizes across all the methods listed above. </p> <div style="max-width: 900px; margin: auto"> <canvas id="modelSizeChart" width="800" height="400"></canvas> </div> <p>Notes:</p> <ul> <li> The parameter count of code-davinci-003 isn't publicly known, but inferring from the naming convention, it's based on the GPT3 model, which is reported to be 175B. <d-cite key="Brown2020LanguageMA"></d-cite> </li> <li> GPT3.5-Turbo also doesn't have its parameter count publicly known, but based on the GPT3 parameter count, 175B is a conservative estimate. </li> <li> The same logic applies to GPT4, though it's rumoured to be around 1.7T parameters, so 175B is also conservative. </li> </ul> <h3 id="chart3">F1 score comparison across epochs</h3> <p> The chart below shows the result of evaluating a saved checkpoint from each epoch, as well as the base Qwen2-1.5B model with and without few-shot examples. </p> <div style="max-width: 900px; margin: auto"> <canvas id="baseModelChart" width="800" height="400"></canvas> </div> <h2 id="closing">Closing thoughts</h2> <p> This work explored, albeit in a limited scope, how much RL training can improve the performance of LLMs on the NER task, and the results show impressive performance for small-sized LLMs. </p> <p> It's important to emphasize (and these are my current thoughts) that RL training benefits from a good base pretrained model, as the current RL training paradigm encourages exploitation more than exploration. It's useful to think of the current RL training paradigm as stabilizing the distribution over the domain of interest: if the model cannot sample the answer under any inference sampling, even as the number of generations tends to infinity, the behaviour most likely cannot be learned during RL training. This intuition is informed by the performance of LLMs on math and code tasks under base + SFT + RL vs only base and RL. </p> <p> Future work would investigate other NER datasets (in other domains), how much performance is lost or retained on other tasks/benchmarks of interest, different reward structures, as well as how small the models can be while still getting competitive performance, especially in real-world applications, where the trade-off between efficiency and correctness is often the focal point. </p> <p> It would also be interesting to do interpretability work on these models.
What changes about the model when it is RL-trained? </p> <h2 id="acknowledgement">Acknowledgements</h2> <p> Gratitude goes to the Institute of Software Engineering at Innopolis University, led by <a href="https://scholar.google.com/citations?hl=en&user=16AyxX0AAAAJ&view_op=list_works&sortby=pubdate">Professor Vladmir Ivanov</a>, for providing compute. I am also grateful to the following persons: </p> <ul> <li> Ilnur Khadiev for helping out with fixing GPU-related issues every so often. </li> <li> Nursultan Abdullaev for providing feedback on an initial draft. </li> </ul> <h2 id="citation">Citation information</h2> <p>Cite as:</p> <div class="highlight"> Israel, Adewuyi. (May, 2025). Reinforcement Learning meets NER. https://israel-adewuyi.github.io/blog/2025/ner_with_rl/. </div> or <pre>

  @article{israel2025ner_rl,
    title   = "Reinforcement Learning Meets NER",
    author  = "Israel, Adewuyi",
    journal = "israel-adewuyi.github.io",
    year    = "2025",
    month   = "May",
    url     = "https://israel-adewuyi.github.io/blog/2025/ner_with_rl/"
  }
        </pre> ]]></content><author><name></name></author><summary type="html"><![CDATA[an attempt at solving Named Entity Recognition with RL training.]]></summary></entry><entry><title type="html">Replicating GraphRAG paper</title><link href="https://israel-adewuyi.github.io/blog/2024/replicating-graphrag/" rel="alternate" type="text/html" title="Replicating GraphRAG paper"/><published>2024-11-08T00:00:00+00:00</published><updated>2024-11-08T00:00:00+00:00</updated><id>https://israel-adewuyi.github.io/blog/2024/replicating-graphrag</id><content type="html" xml:base="https://israel-adewuyi.github.io/blog/2024/replicating-graphrag/"><![CDATA[<d-contents> <nav class="l-text figcaption"> <h3>Contents</h3> <div><a href="#background">Background</a></div> <div><a href="#notes">Organizational notes</a></div> <div><a href="#chunks">Text Chunking</a></div> <div><a href="#entities">Entity extraction</a></div> <div><a href="#clustering">Graph clustering</a></div> <div><a href="#community">Community summary</a></div> <div><a href="#query">How are queries answered?</a></div> <div><a href="#discussion">Discussion</a></div> <div><a href="#Acknowledgements">Acknowledgements</a></div> </nav> </d-contents> <h2 id="background">Introduction</h2> <p> Microsoft research recently put out the GraphRAG paper <d-cite key="edge2024localglobalgraphrag"></d-cite>. In this post, I share my attempt at replicating the paper and some thoughts about tradeoffs to be made when working with retrieval systems in general. </p> <p> The summary of the paper is that we can structure the information in a body of documents as a graph by thinking of every object in the document as an entity and drawing its relationships with other entities. Once we have this graph, we can then reason over it to draw insights and conclusions that we might otherwise not be able to draw. <figure> <img src="/assets/img/graphrag/pipeline.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();" alt="GraphRAG pipeline"/> <figcaption>GraphRAG pipeline.</figcaption> </figure> </p> <p> In the paper, Microsoft research used podcast transcripts and news articles as the knowledge sources over which retrieval is done. I decided to use a podcast episode - specifically, <a href="https://www.youtube.com/watch?v=UTuuTTnjxMQ">Dwarkesh Patel's interview with Mechanistic Interpretability researchers Trenton Bricken and Sholto Douglas</a>. </p> <p> I would expect readers to be fairly familiar with the GraphRAG paper <d-cite key="edge2024localglobalgraphrag"></d-cite>. </p> <h2 id="notes">Organizational notes</h2> I encourage readers to read through the whole post and refer back to this section from time to time. <ul> <li> For evaluation, I used one of the transcripts of <a href="https://www.youtube.com/watch?v=UTuuTTnjxMQ">Dwarkesh Patel's podcast</a> as the knowledge source. </li> <li> During development, I used a bunch of models to test out the different components, but for the final graph index generation and inference, I used <a href="https://console.groq.com/settings/limits">LLAMA 3.2 90B text-preview</a><d-cite key="llama3_2"></d-cite>, provided by <a href="https://console.groq.com/playground">Groq</a>.
That being said, Gemma 9B seems to perform the best on entity-relationship extraction <d-footnote>I judged this because I have listened to the podcast episode and I was able to roughly assess the quality of the generations for the different models.</d-footnote>. </li> <li> I implemented a single hierarchy of clustering graph nodes and edges. </li> <li> For the sake of optimizing the number of API calls/requests to <a href="https://console.groq.com/playground">Groq</a>, I implemented global search with a vectorDB, as opposed to the LLM summarization used in the GraphRAG paper. <d-cite key="edge2024localglobalgraphrag"></d-cite> </li> <li> I did not implement <a href="https://microsoft.github.io/graphrag/query/local_search/">local search</a>, covariates, and a couple of other details that were expensive in tokens/API calls. </li> <li> The LLM-derived knowledge graph can be viewed <a href="/assets/html/network.html">here</a>. </li> <li> <a href="https://github.com/israel-adewuyi/graphrag">Link to repository</a>, <a href="https://graphrag-impl.streamlit.app/">Link to streamlit chat interface</a><d-footnote>I did not exactly optimise the chat interface to be a conversational agent like SOTA chat agents. Ask questions, get a response. </d-footnote> </li> </ul> <h2 id="chunks">Text Chunking</h2> <p> Dwarkesh provides <a href="https://www.dwarkeshpatel.com/p/sholto-douglas-trenton-bricken?open=false#%C2%A7transcript">links</a> to the transcripts of his podcast episodes, so it was easy to get the transcript of this episode. To preserve the notion of turn-based conversation, I chunked the transcript text based on each speaker's speech. <details> <summary>More info</summary> <p> Usually, the podcast transcripts are roughly structured as <figure class="highlight"> <pre><code class="language-python" data-lang="python"><span class="kn">Speaker A (timestamp)</span> 
<span class="kn"># Speaker A's speech</span>
<span class="kn">Speaker B (timestamp)</span> 
<span class="kn"># Speaker B's speech</span>
              </code></pre> </figure> </p> </details> </p> <p> After chunking the <a href="https://www.dwarkeshpatel.com/p/sholto-douglas-trenton-bricken?open=false#%C2%A7transcript">transcripts</a> along each speaker's speech, there were approximately 483 chunks. I analyzed the token length of each chunk. <img src="/assets/img/graphrag/Sholto&Trenton.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();" alt="Token distribution for chunks in the original chunk set"/> </p> <p> Along the lines of the original paper, and for reasons I explain in the next section, I decided to merge chunks such that each chunk had an average size of 1000 tokens <d-footnote>In the original paper, the authors used an average size of 600 tokens. </d-footnote>. This reduced the number of chunks from <code class="language-plaintext">483</code> to <code class="language-plaintext">47</code>. <img src="/assets/img/graphrag/Sholto&Trenton_merged.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();" alt="Token distribution for chunks after merging"/> </p> <h2 id="entities">Entity extraction</h2> <p> I hardcoded a list of entity types: I manually asked ChatGPT to extract entity types from a random sample of chunks, which I then filtered for repeated entity types - this was because extracting the entity types for every chunk was expensive with respect to API calls. I ended up with a list of <a href="https://github.com/israel-adewuyi/graphrag/blob/master/config/entities.py">33 entity types</a>. </p> <p>The entity-relationship extraction phase is in two stages.</p> <p> In the first stage, I prompted the LLM to extract all the entities and relationships from each chunk and output the results in JSON format. For each entity, I extract the entity_name, entity_type and description. For each relationship, I extract the source entity, target entity and relationship description. </p> <figure> <img src="/assets/img/graphrag/image.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();" alt="Snippet from entity extraction prompt"/> <figcaption>Snippet from entity extraction prompt.</figcaption> </figure> <p> In the second stage, for each entity-relationship JSON retrieved for each chunk, I prompt the LLM a second time to check if it has extracted all the possible entities and relationships. The authors referred to this as gleaning <d-footnote>The prompts I used here and other parts of this implementation were largely taken from <a href="https://github.com/microsoft/graphrag/tree/main/graphrag/prompt_tune/template">the official github implementation.</a> </d-footnote>. For most chunks, this seemed to extract almost as many relationships as it did the first time. </p> <p> A <a href="/assets/html/network.html">knowledge graph</a> is then built from the entities and their relationships that have been extracted.
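 A minimal sketch of this construction step is shown below (the record structure and field names are illustrative placeholders, not the exact ones used in the repository): <figure class="highlight"> <pre><code class="language-python" data-lang="python">import networkx as nx

# One parsed JSON object per chunk, as produced by the extraction prompts above.
# The example record and field names here are illustrative placeholders.
extracted_records = [
    {
        "entities": [
            {"entity_name": "Trenton Bricken",
             "entity_type": "PERSON",
             "description": "Interpretability researcher"},
        ],
        "relationships": [
            {"source_entity": "Trenton Bricken",
             "target_entity": "Anthropic",
             "relationship_description": "works at"},
        ],
    },
]

G = nx.Graph()
for record in extracted_records:
    for ent in record["entities"]:
        # Nodes carry the entity type and description as attributes.
        G.add_node(ent["entity_name"],
                   entity_type=ent["entity_type"],
                   description=ent["description"])
    for rel in record["relationships"]:
        # add_edge creates missing endpoint nodes automatically.
        G.add_edge(rel["source_entity"], rel["target_entity"],
                   description=rel["relationship_description"])</code></pre> </figure>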
I used the <a href="https://networkx.org/documentation/stable/index.html">NetworkX</a> library to construct the network graph.<d-cite key="trajanoska2023enhancingknowledgegraphconstruction"></d-cite> <d-cite key="SciPyProceedings_11"></d-cite> </p> <p> A problem here, which can be observed from the <a href="/assets/html/network.html">knowledge graph</a>, is that there are a bunch of entities that are not captured in any relationship with another entity. </p> <p> In practice, this means that there were two API calls per chunk. In the previous setup, where I had > 450 chunks, this would have been close to 1000 API calls, but now I had fewer than 100 chunks for the whole document. </p> <h2 id="clustering">Graph clustering</h2> <p> Now, we have a graph that's fairly representative of the whole document. A key insight at this stage is that we can group nodes and edges into communities, which in principle should be representative of some semantic relationship between all the nodes in the community. </p> <p> Following the original paper, I used the Leiden algorithm <d-cite key="traag2019louvain"></d-cite> provided in the <a href="https://cdlib.readthedocs.io/en/latest/">cdlib library</a> to cluster the graph into communities, and I did the graph clustering for a single hierarchy level. <img src="/assets/img/graphrag/graph_community.png" class="img-fluid rounded z-depth-1" width="70%" height="auto" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();" alt="Graph communities after clustering"/> </p> <h2 id="community">Community summary</h2> <p> My implementation here differs significantly from the original implementation. To understand the decision I made, let's delve into what happens when there is a global search on a query. </p> <p> In the original implementation of global search, when there is a query, an LLM is used to generate an intermediate answer and a relevancy score for this query, for <strong>each</strong> community summary. Now, this is a lot of API calls! <d-footnote>See <a href="https://console.groq.com/settings/limits">here</a> for the rate limits on requests and tokens per day</d-footnote> </p> <p> A neat thing I thought to do was to generate a list of conclusions or mini-summaries for each community cluster. So instead of just a single community summary, there is a list of insights and conclusions generated for each community after the clustering process. </p> <p> These conclusions are indexed in a <a href="https://python.langchain.com/api_reference/community/vectorstores/langchain_community.vectorstores.faiss.FAISS.html">vectorDB</a>. </p> <h2 id="query">How are queries answered?</h2> <p> For each query, I run a similarity search against all vectors in the vectorDB from the previous section and return 15 vectors,<d-footnote>The choice here is arbitrary. A more principled way would have been to use some threshold similarity score or keep adding retrieved docs to the prompt till it exceeds the context length or number of tokens per query.</d-footnote> along with their similarity scores. These conclusions, their scores and the original query are then fed into the LLM for a synthesized response. </p> <p> To be efficient with tokens, I also implemented response caching. This means that if a query is sent and, at some time later, a similar query is sent <d-footnote>or one the system deems similar</d-footnote>, it returns the previous results.
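 Putting the retrieval and caching steps together, a rough sketch is shown below (the embedding model, the <code>conclusions</code> list and the <code>call_llm</code> helper are placeholders for code that lives in the repository): <figure class="highlight"> <pre><code class="language-python" data-lang="python">from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

# Index the per-community conclusions once.
embeddings = HuggingFaceEmbeddings()        # any sentence-embedding model
db = FAISS.from_texts(conclusions, embeddings)
cache = {}

def answer(query: str) -> str:
    # The real check is similarity-based rather than an exact-match lookup.
    if query in cache:
        return cache[query]
    # Retrieve the 15 most similar conclusions together with their scores.
    hits = db.similarity_search_with_score(query, k=15)
    context = "\n".join(f"{doc.page_content} (score={score:.3f})"
                        for doc, score in hits)
    # Synthesize a response from the query and the retrieved conclusions.
    response = call_llm(query=query, context=context)
    cache[query] = response
    return response</code></pre> </figure>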
You could imagine that in a large-scale system, where there is the potential for a lot of duplicate queries, this saves resources. </p> <p>I did not implement conversation history.</p> <h2 id="discussion">Discussion / Final thoughts</h2> <ol> <li> As mentioned in <a href="#entities">this discussion</a>, there were some entities without relationships with other entities in the <a href="/assets/html/network.html">knowledge graph</a>. My thought, supported by the paper, is that it would be possible to extract more relationships and also entities by running more rounds of gleaning. </li> <li> The quality of the generations, from the knowledge graph to the query responses, would probably be better with more advanced models like GPT 4o. </li> <li> Gleaning was able to capture significantly more relationships, but more entities were also introduced to the knowledge graph. I expect that with more rounds of gleaning, this should be resolved neatly. However, it's unclear if more rounds of gleaning will be supported by the context length of the model. </li> <li> The approach of drawing insights from communities and using these insights to generate responses would add a lot of latency to the retrieval process when this is done over huge datasets, such as the whole set of all Dwarkesh's podcasts <d-footnote>I plan to work on this sometime in the near future</d-footnote>. The rationale is that similarity search is over pieces of text which are much smaller than chunks, but for each chunk there are 5 - 20 such texts. </li> <li> Work is not done. I plan to extend this to, at the time of this writing, ~80 Dwarkesh episodes. </li> </ol> <h2 id="Acknowledgements">Acknowledgements</h2> <ol> <li> <a href="https://rustam-lukmanov.com/">Professor Rustam Lukmanov</a> for suggesting the project in the first place, providing feedback and the motivation to do this write-up. </li> <li> <a href="http://kimfom.space">Kim Fom</a> for useful feedback on writing good code. </li> <li> <strong>Khush Patel</strong> for some advice related to deployment of the admittedly unpolished interface and tradeoffs when choosing a vectorDB. </li> </ol>]]></content><author><name></name></author><summary type="html"><![CDATA[a replication of 'From Local to Global']]></summary></entry><entry><title type="html">Replicating ‘Refusal Mechanism’</title><link href="https://israel-adewuyi.github.io/blog/2024/replicating-refusal/" rel="alternate" type="text/html" title="Replicating ‘Refusal Mechanism’"/><published>2024-10-05T00:00:00+00:00</published><updated>2024-10-05T00:00:00+00:00</updated><id>https://israel-adewuyi.github.io/blog/2024/replicating-refusal</id><content type="html" xml:base="https://israel-adewuyi.github.io/blog/2024/replicating-refusal/"><![CDATA[<d-contents> <nav class="l-text figcaption"> <h3>Contents</h3> <div><a href="#background">Background</a></div> <div><a href="#summary">Summary</a></div> <div><a href="#setup">Setup</a></div> <div><a href="#results1">Results with Gemma 2-2B</a></div> <div><a href="#results2">Results with Gemma 2-9B</a></div> <div><a href="#discussion">Discussion</a></div> </nav> </d-contents> <h2 id="background">Background</h2> <p> This post represents a step towards my understanding of model behaviour and how to align LLMs with our interests. When I first read the blog, it seemed approachable on the surface level: I felt I could track what the author was doing as well as their motivations, and it felt like a good experiment to try to replicate.
</p> <p> This also represents an attempt to upskill on Mechanistic Interpretability tooling. </p> <p> This post is based on <d-cite key="andy_refusal"></d-cite>. If you need a more in-depth explanation, or a refresher, I suggest the reader go through the blog and return, because this writeup just summarises my findings and assumes the reader is familiar with mech interp-related terms. </p> <hr/> <h2 id="summary">Summary</h2> <ul> <li> I investigated the refusal behaviour as described in <d-cite key="andy_refusal"></d-cite> on the Gemma 2 suite of models, specifically Gemma 2-2B and Gemma 2-9B. </li> <li> I couldn't steer with the refusal heads' contributions on Gemma 2-2B. </li> <li> I could steer with the refusal heads' contributions on Gemma 2-9B, albeit with a significant increase in the scaling factor (> 26x). </li> <li>For both models, I could steer using the difference vector.</li> <li> For inhibiting the refusal behaviour on harmless prompts, I could not steer with either the refusal head contributions or the difference vector. </li> </ul> <hr/> <h2 id="setup">Setup</h2> <p> To measure the refusal behaviour, <d-cite key="andy_refusal"></d-cite> used <code class="language-plaintext">logit[sorry] - logit[sure]</code> as the metric <d-footnote> My intuition is that this metric is quite lossy. See <strong>Takeaways</strong> for discussion on this. </d-footnote>. The justification is that if the model is going to refuse a behaviour, the generation typically starts with “Sorry”, and if the model is going to act out the behaviour, “Sure” would be one of the top predicted next tokens. </p> <p> Initially, I tried using the dataset of harmful and harmless objects that <d-cite key="andy_refusal"></d-cite> used, but I ran into trouble making sense of the results. Upon investigation, I realized some objects were multi-token, which was just a curse to analyze. So I decided to cherry-pick objects that were single tokens instead <d-footnote>Link to dataset I used</d-footnote>. </p> <p> I followed the <a href="https://ai.google.dev/gemma/docs/formatting">Gemma instruction prompt template</a>. </p> <hr/> <h2 id="results1">Results with Gemma 2-2B</h2> <h3>Residual stream patching</h3> <img src="/assets/img/refusal_replication/2-2b_resid_attrib_plot.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();" alt="Logit attribution for the residual stream at each layer"/> <p> This doesn't compare cleanly with the results from <d-cite key="andy_refusal"></d-cite>. The absolute value of the refusal score for harmful logits appears to be higher here than in <d-cite key="andy_refusal"></d-cite>. For harmless logits, the opposite appears to be true. </p> <h3>Residual stream activation patching</h3> <img src="/assets/img/refusal_replication/2B resid_act_patch.png" class="img-fluid rounded z-depth-1" width="50%" height="40%" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();" alt="Patching residual stream at each layer"/> <p> The results at the <code class="language-plaintext">obj</code> token position, as well as at the last token position, are as expected. At the <code class="language-plaintext">'.'</code> token position, which will henceforth be regarded as the <code class="language-plaintext">post obj</code> token position, layers 8 - 15 seem to be carrying signals related to the refusal behaviour. Going forward, these are the layers of interest.
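 To make this concrete, a rough sketch of this layer-by-position patching sweep with TransformerLens is shown below (the prompt variables are placeholders, and the exact prompts and metric in my code differ slightly): <figure class="highlight"> <pre><code class="language-python" data-lang="python">from transformer_lens import HookedTransformer
import transformer_lens.patching as patching

model = HookedTransformer.from_pretrained("gemma-2-2b-it")

# Placeholder prompts: same template, with a harmful vs. a harmless object.
harmful_tokens = model.to_tokens(harmful_prompt)
harmless_tokens = model.to_tokens(harmless_prompt)
_, harmful_cache = model.run_with_cache(harmful_tokens)

def refusal_score(logits):
    # logit[sorry] - logit[sure] at the last position (assumes single tokens).
    return (logits[0, -1, model.to_single_token(" Sorry")]
            - logits[0, -1, model.to_single_token(" Sure")])

# Run on the harmless prompt and patch in residual-stream activations from
# the harmful run, one (layer, position) at a time.
results = patching.get_act_patch_resid_pre(
    model, harmless_tokens, harmful_cache, refusal_score
)</code></pre> </figure>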
</p> <h3>Attention layer activation patching</h3> <p> The <code class="language-plaintext">resid_post</code> at any layer can be decomposed into <code class="language-plaintext">resid_post = resid_pre + attn_out + mlp_out </code>. So let's see what's up with <code class="language-plaintext"> attn_out </code>. </p> <img src="/assets/img/refusal_replication/2B attn_out_act_patch.png" class="img-fluid rounded z-depth-1" width="50%" height="40%" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();" alt="Patching attention output at each layer"/> <ul> <li> A surprising result is that attn_out activation patching cannot fully recover the refusal behaviour. This is evident because at the post-object token position, the refusal score at the final layer is 0.7707379. </li> <li> By layer 15, the score is close to the final layer's refusal score - 0.7689 - which seems to correlate with the results from residual stream activation patching and suggests that, indeed, layers after layer 15 aren't contributing as much to the refusal behaviour. </li> </ul> <h3>MLP Layer activation patching</h3> <p> I decided to run activation patching on the MLP output of each layer as well, just to see what gives. </p> <img src="/assets/img/refusal_replication/2B mlp_out_act_patch.png" class="img-fluid rounded z-depth-1" width="50%" height="40%" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();" alt="Patching MLP output at each layer"/> <p> In <i>retrospect</i>, this result makes sense. One can think of it as follows: patching in at the <code class="language-plaintext">obj</code> token position is analogous to replacing the harmless object with the harmful object in the prompt. The refusal score at the <code class="language-plaintext">obj</code> position is <code class="language-plaintext">0.888</code>. </p> <h3>Attention heads activation patching</h3> <img src="/assets/img/refusal_replication/2B attn_head_act_patch.png" class="img-fluid rounded z-depth-1" width="50%" height="40%" loading="eager" onerror="this.onerror=null; $('.responsive-img-srcset').remove();" alt="Patching individual attention heads"/> <p> Setting an arbitrary threshold of <code class="language-plaintext">0.005</code>, 11 heads were contributing to the refusal behaviour, and this set of heads was selected as sufficient for the refusal behaviour. </p> <h3>Steering</h3> <h4>With difference vector</h4> <h4>With activation vector</h4> <hr/> <h2 id="results2">Results with Gemma 2-9B</h2> <h3>Residual stream attribution</h3> <h3>Residual stream activation patching</h3> <h3>Attention Layer activation patching</h3> <h3>MLP Layer activation patching</h3> <h3>Attention heads activation patching</h3> <h3>Steering</h3>]]></content><author><name></name></author><summary type="html"><![CDATA[a replication of the initial experiments on the 'Refusal Mechanism']]></summary></entry></feed>