
I’m frequently asked by clients what my thoughts are on LLMs and coding. Personal experience has taught me that LLMs cannot solve problems beyond a certain complexity, for a number of reasons. One of the more common ones is codebase size and complexity: a large codebase contains far more tokens than even the best SOTA models with the longest context windows can support (see the RULER paper, too). Working around this typically requires some type of semantic search or RAG over the codebase, which usually leaves the model performing poorly when addressing a user’s request for a bugfix or enhancement.

Secondly, there’s the matter of problem complexity, which is the subject of this post. I recently fixed a bug in the mlx-lm GPT2 model implementation that I submitted to the MLX project last year. If you’re not familiar with MLX, it is an ML and deep-learning framework similar to PyTorch, but specific to Apple Silicon. The pull request with information on the bug can be viewed here. In summary, when generating tokens from the GPT2 model using the mlx-lm inference code, which uses a KV cache, the output was garbled and repetitive for the smaller GPT2 models like the 124M “small” version. At the time I submitted the original PR, I was mostly testing with the gpt2-xl 1.5B parameter version, which generated output more or less identical to that of the HuggingFace transformers GPT2 implementation.
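
For context, the bug surfaced through the normal mlx-lm generation path. Before the fix, something along these lines was enough to reproduce the garbled output (a rough sketch using the mlx_lm Python API; the model repo id is an assumption on my part, and any small GPT-2 checkpoint that mlx-lm can load should behave the same way):

from mlx_lm import load, generate

# Hypothetical repro sketch: the repo id is an assumption. With the pre-fix
# gpt2 implementation, the KV-cached generation loop produced garbled,
# repetitive text for the small model.
model, tokenizer = load("openai-community/gpt2")
print(generate(model, tokenizer, prompt="The quick brown fox", max_tokens=50))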

Upon analyzing the code, I noticed that positional embeddings were only added during the initial forward pass of the model, which meant every subsequently generated token received no positional embedding at all - not ideal! The smaller GPT2 models seem to be more sensitive to the positional embedding layer, while larger models like gpt2-xl still generate acceptable output without it. I assume that during training, gpt2-xl, with its additional parameters, also learned positional information in the transformer block layers.

In any case, the fix was relatively simple to implement but harder to spot initially, so I decided to test various LLMs to see whether they would identify the cause. The results were interesting. The following prompt was used for testing:

the token generation w/ KV cache for gpt2 implementation generates garbage output after the first token. W/o KV cache it works fine. Here's the implementation code:

# Copyright © 2023-2024 Apple Inc.

from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn
import numpy as np

from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    n_ctx: int
    n_embd: int
    n_head: int
    n_layer: int
    n_positions: int
    layer_norm_epsilon: float
    vocab_size: int
    num_key_value_heads: int = None

    def __post_init__(self):
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.n_head


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        assert args.n_embd % args.n_head == 0, "n_embd must be divisible by n_head"

        self.n_embd = args.n_embd
        self.n_head = args.n_head
        self.head_dim = self.n_embd // self.n_head

        self.scale = self.head_dim**-0.5

        self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=True)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=True)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        B, L, D = x.shape

        qkv = self.c_attn(x)
        queries, keys, values = mx.split(qkv, 3, axis=-1)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            keys, values = cache.update_and_fetch(keys, values)

        output = scaled_dot_product_attention(
            queries, keys, values, cache=cache, scale=self.scale, mask=mask
        )

        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.c_proj(output)


class MLP(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_embd = args.n_embd
        self.c_fc = nn.Linear(self.n_embd, 4 * self.n_embd)
        self.c_proj = nn.Linear(4 * self.n_embd, self.n_embd)

    def __call__(self, x) -> mx.array:
        return self.c_proj(nn.gelu_approx(self.c_fc(x)))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_head = args.n_head
        self.n_embd = args.n_embd
        self.layer_norm_epsilon = args.layer_norm_epsilon
        self.attn = Attention(args)
        self.mlp = MLP(args)
        self.ln_1 = nn.LayerNorm(
            self.n_embd,
            eps=self.layer_norm_epsilon,
        )
        self.ln_2 = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        r = self.attn(self.ln_1(x), mask, cache)
        h = x + r
        r = self.mlp(self.ln_2(h))
        out = h + r
        return out


class GPT2Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_embd = args.n_embd
        self.n_positions = args.n_positions
        self.vocab_size = args.vocab_size
        self.n_layer = args.n_layer
        self.layer_norm_epsilon = args.layer_norm_epsilon
        assert self.vocab_size > 0
        self.wte = nn.Embedding(self.vocab_size, self.n_embd)
        self.wpe = nn.Embedding(self.n_positions, self.n_embd)
        self.h = [TransformerBlock(args=args) for _ in range(self.n_layer)]
        self.ln_f = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon)

    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        _, L = inputs.shape

        hidden_states = self.wte(inputs)

        mask = None
        if hidden_states.shape[1] > 1:

            position_ids = mx.array(np.arange(L))
            hidden_states += self.wpe(position_ids)

            if mask is None:
                mask = create_attention_mask(hidden_states, cache)

        if cache is None:
            cache = [None] * len(self.h)

        for layer, c in zip(self.h, cache):
            hidden_states = layer(hidden_states, mask, cache=c)

        return self.ln_f(hidden_states)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = GPT2Model(args)

    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        out = self.model(inputs, mask, cache)
        out = self.model.wte.as_linear(out)
        return out

    def sanitize(self, weights):
        new_weights = {}
        for i in range(self.args.n_layer):
            if f"h.{i}.attn.bias" in weights:
                del weights[f"h.{i}.attn.bias"]
            if f"h.{i}.attn.c_attn.weight" in weights:
                weights[f"h.{i}.attn.c_attn.weight"] = weights[
                    f"h.{i}.attn.c_attn.weight"
                ].transpose(1, 0)
            if f"h.{i}.attn.c_proj.weight" in weights:
                weights[f"h.{i}.attn.c_proj.weight"] = weights[
                    f"h.{i}.attn.c_proj.weight"
                ].transpose(1, 0)
            if f"h.{i}.mlp.c_fc.weight" in weights:
                weights[f"h.{i}.mlp.c_fc.weight"] = weights[
                    f"h.{i}.mlp.c_fc.weight"
                ].transpose(1, 0)
            if f"h.{i}.mlp.c_proj.weight" in weights:
                weights[f"h.{i}.mlp.c_proj.weight"] = weights[
                    f"h.{i}.mlp.c_proj.weight"
                ].transpose(1, 0)
        for weight in weights:
            if not weight.startswith("model."):
                new_weights[f"model.{weight}"] = weights[weight]
            else:
                new_weights[weight] = weights[weight]
        return new_weights

    @property
    def layers(self):
        return self.model.h

I wanted to provide as little information as possible in the prompt so as not to steer the model, while still conveying the problematic behavior. My testing is by no means intended to be comprehensive, and the models may fare better or worse on other problems. Two patterns emerged from the testing results:

  1. The model correctly identifies the problem as being related to the positional embeddings not being added after the initial prefill pass.
  2. The model is confused and assumes the KV cache implementation has a bug or that the KV cache is being used incorrectly in the attention layers.

Correct solution from o1-2024-12-17:

A key issue is that in your GPT‐2 forward pass, you only add position embeddings and create a causal mask when “hidden_states.shape[1] > 1.” In a typical “one‐token‐at‐a‐time” generation loop (when using a KV cache), each newly generated token has shape [B, 1]. Because of the if hidden_states.shape[1] > 1 check, the code entirely skips adding the new position embedding and never updates the attention mask for that new single token. That mismatch quickly leads to “garbage” outputs after the first step.
In other words, inside GPT2Model.__call__:

if hidden_states.shape[1] > 1:
    position_ids = mx.array(np.arange(L))
    hidden_states += self.wpe(position_ids)

    if mask is None:
        mask = create_attention_mask(hidden_states, cache)

is never invoked on subsequent tokens (where L = 1). To fix this, you must always add the correct position embedding to every token in the sequence—whether you are calling the model on N tokens at once or just on a single newly generated token. You also need to ensure that you create/apply a proper causal mask for incremental decoding.
A standard fix is something along these lines:
• Keep track of how many tokens have been generated so far (e.g., total_seq_len), then compute the new position IDs as range(total_seq_len, total_seq_len + L) rather than just range(L).
• Always add position embeddings for the newly generated token (i.e., remove or adjust the if hidden_states.shape[1] > 1 condition).
• Always create/apply the causal mask for the new token (so the new position can only attend to previous tokens).
Once that is done, single‐token inference (with KV cache) will produce correct next‐token logits instead of garbled output.
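
To make that concrete, here is roughly what a patched GPT2Model.__call__ could look like. This is my sketch of the approach o1 describes, not necessarily the exact code that was merged, and it assumes the cache objects expose an offset attribute, as mlx-lm’s KVCache does:

def __call__(
    self,
    inputs: mx.array,
    mask: mx.array = None,
    cache=None,
):
    _, L = inputs.shape

    hidden_states = self.wte(inputs)

    # Offset the position ids by the number of tokens already in the cache so
    # that single-token decode steps also get the correct positional embedding.
    # Assumes the cache objects expose an `offset` attribute (as mlx-lm's
    # KVCache does); adjust if your cache tracks length differently.
    offset = cache[0].offset if cache is not None else 0
    position_ids = mx.arange(offset, offset + L)
    hidden_states = hidden_states + self.wpe(position_ids)

    # Only a multi-token prefill needs an explicit causal mask; a single
    # decode-step token may attend to everything already in the cache.
    if mask is None and L > 1:
        mask = create_attention_mask(hidden_states, cache)

    if cache is None:
        cache = [None] * len(self.h)

    for layer, c in zip(self.h, cache):
        hidden_states = layer(hidden_states, mask, cache=c)

    return self.ln_f(hidden_states)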

Here is an example of an incorrect solution from llama-3.3-70b-instruct:

# Step-by-step analysis of the problem:
1. **Token Generation Issue**: The issue here seems to be related to the token generation process, specifically when using a KV cache for the GPT2 implementation.
2. **KV Cache Implementation**: The KV cache is used in the attention mechanism to store and update the keys and values for each layer. However, it appears that there might be an issue with how the cache is being updated or accessed.
3. **Garbage Output After First Token**: This suggests that the problem lies in the way the model handles the input sequence beyond the first token.

# Fixed solution:
To fix this issue, we need to review and potentially modify the `cache.update_and_fetch(keys, values)` line in the Attention class's `__call__` method. 

Here's an updated version of the `Attention` class:

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        assert args.n_embd % args.n_head == 0, "n_embd must be divisible by n_head"

        self.n_embd = args.n_embd
        self.n_head = args.n_head
        self.head_dim = self.n_embd // self.n_head

        self.scale = self.head_dim**-0.5

        self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=True)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=True)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        B, L, D = x.shape

        qkv = self.c_attn(x)
        queries, keys, values = mx.split(qkv, 3, axis=-1)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            # Check if the cache has been initialized
            if isinstance(cache, dict):
                keys_cache = cache.get('keys', mx.zeros_like(keys))
                values_cache = cache.get('values', mx.zeros_like(values))
                
                # Update the cache with new keys and values
                keys = mx.cat([keys_cache[:, :, -L:, :], keys], dim=-2)
                values = mx.cat([values_cache[:, :, -L:, :], values], dim=-2)
                
                # Update the cache for future use
                cache['keys'] = keys
                cache['values'] = values
                
            else:
                raise ValueError("Cache should be a dictionary")
        else:
            # If no cache is provided, do not update or fetch any cached values
            pass

        output = scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )

        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.c_proj(output)

And here is an updated version of the `TransformerBlock` class:

class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_head = args.n_head
        self.n_embd = args.n_embd
        self.layer_norm_epsilon = args.layer_norm_epsilon
        self.attn = Attention(args)
        self.mlp = MLP(args)
        self.ln_1 = nn.LayerNorm(
            self.n_embd,
            eps=self.layer_norm_epsilon,
        )
        self.ln_2 = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        r = self.attn(self.ln_1(x), mask, cache)
        h = x + r
        r = self.mlp(self.ln_2(h))
        out = h + r
        
        # Initialize or update the cache for future use
        if cache is None:
            cache = {'keys': mx.zeros((x.shape[0], self.n_head, 0, self.head_dim)), 
                     'values': mx.zeros((x.shape[0], self.n_head, 0, self.head_dim))}
        else:
            pass
            
        return out

# Explanation of changes:
*   **Cache Management**: The primary change made was in how the cache is managed within the `Attention` class. Now, when a cache is provided, it checks if the cache has been initialized and updates it accordingly.
*   **Key and Value Concatenation**: In the updated code, we concatenate the cached keys and values with the new ones before using them for attention computation.

# Tests and example uses:
You can test this implementation by providing a sample input sequence to your `Model` class. Make sure that you have properly initialized all necessary variables and modules.

Here is an example of how you might use it:

args = ModelArgs(model_type='gpt2', n_ctx=1024, n_embd=512, n_head=8, n_layer=12, 
                 n_positions=1024, layer_norm_epsilon=1e-5, vocab_size=50257)

model = Model(args)
input_ids = mx.array(np.random.randint(0, 100, size=(32, 128))) # replace with your own input

output = model(input_ids)
print(output.shape) # Should print (32, 128, 50257)

Here’s a table of my testing results:

| Model | Pass/Fail | Notes |
| --- | --- | --- |
| o1-2024-12-17 (medium reasoning) | Pass ✅ | |
| gpt-4.5 | Fail ❌ | Identified cause correctly, but wrong fix |
| o4-mini-2025-04-16 (high reasoning) | Pass ✅ | |
| gpt-4.1 | Pass ✅ | |
| gpt-4o-2024-08-06 | Fail ❌ | |
| llama-3.3-70b-instruct | Fail ❌ | |

In short, the reasoning models fared far better than the non-reasoning models. I was also surprised to see gpt-4.5 fail while gpt-4.1 succeeded, which is perhaps further evidence of why OpenAI may be sunsetting gpt-4.5 so quickly.