Recent Experiences Debugging with LLMs
I'm frequently asked by clients what I think about LLMs and coding. In my experience, LLMs cannot solve problems beyond a certain complexity, for a number of reasons. One of the more common ones is codebase size and complexity: a large codebase can contain far more tokens than even the best SOTA models with the longest context windows can handle (and usable context is often shorter still; see the RULER paper). Working around this typically requires some form of semantic search or RAG over the codebase, and the model usually ends up with subpar performance when addressing a user's request for a bugfix or enhancement.
Second, there is the matter of problem complexity, which is the subject of this post. I recently fixed a bug in the `mlx-lm` GPT2 model implementation that I submitted to the MLX project last year. If you're not familiar with MLX, it is an ML and deep-learning framework similar to PyTorch, but specific to Apple Silicon. The pull request with information on the bug can be viewed here. In summary, when generating tokens from the GPT2 model with the `mlx-lm` inference code that uses a KV cache, the output was garbled and repetitive for the smaller GPT2 models like the 124M "small" version. At the time I submitted the original PR, I was mostly testing with the 1.5B-parameter `gpt2-xl` version, which generated output more or less identical to the HuggingFace transformers GPT2 implementation.
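The symptom is easy to reproduce with `mlx-lm`'s Python API. Here is a minimal sketch; the checkpoint name is just an example, and `generate`'s keyword arguments may vary slightly between `mlx-lm` versions:

```python
# Repro sketch: generate from a GPT2 "small" checkpoint with the default
# KV-cached generation loop and watch the output degenerate after the first token.
from mlx_lm import load, generate

model, tokenizer = load("openai-community/gpt2")  # example checkpoint name
text = generate(model, tokenizer, prompt="The meaning of life is", max_tokens=50)
print(text)  # with the bug present, this quickly turns into repetitive garbage
```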
Upon analyzing the code, I noticed that positional embeddings were not being added after the initial forward pass of the model, which meant any subsequent token effectively had a position of 0 - not ideal! The smaller GPT2 models seem to be more sensitive to the positional embedding layer, while larger models like `gpt2-xl` still generate acceptable output without it. My assumption is that during training, `gpt2-xl`, with its additional parameters, learned enough positional information in the transformer block layers to compensate.
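The fix boils down to always adding the positional embeddings and offsetting the position IDs by the number of tokens already held in the KV cache. Here is a sketch of the idea for `GPT2Model.__call__` (illustrative, not the exact merged patch), assuming the cache entries expose an `offset` attribute with the number of tokens they already contain, which `mlx-lm`'s KV cache does:

```python
    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        _, L = inputs.shape
        hidden_states = self.wte(inputs)

        if cache is None:
            cache = [None] * len(self.h)

        # Position of the first new token: 0 on the prefill pass, otherwise the
        # number of tokens already stored in the KV cache.
        offset = cache[0].offset if cache[0] is not None else 0
        position_ids = mx.arange(offset, offset + L)

        # Always add positional embeddings, even for single-token decode steps.
        hidden_states = hidden_states + self.wpe(position_ids)

        if mask is None:
            mask = create_attention_mask(hidden_states, cache)

        for layer, c in zip(self.h, cache):
            hidden_states = layer(hidden_states, mask, cache=c)

        return self.ln_f(hidden_states)
```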
In any case, the fix was relatively simple to implement but harder to spot initially, so I decided to test various LLMs to see whether they could identify the cause. The results were interesting. The following prompt was used for testing:
the token generation w/ KV cache for gpt2 implementation generates garbage output after the first token. W/o KV cache it works fine. Here's the implementation code:
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
n_ctx: int
n_embd: int
n_head: int
n_layer: int
n_positions: int
layer_norm_epsilon: float
vocab_size: int
num_key_value_heads: int = None
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.n_head
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
assert args.n_embd % args.n_head == 0, "n_embd must be divisible by n_head"
self.n_embd = args.n_embd
self.n_head = args.n_head
self.head_dim = self.n_embd // self.n_head
self.scale = self.head_dim**-0.5
self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=True)
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=True)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
qkv = self.c_attn(x)
queries, keys, values = mx.split(qkv, 3, axis=-1)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
if cache is not None:
keys, values = cache.update_and_fetch(keys, values)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.c_proj(output)
class MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_embd = args.n_embd
self.c_fc = nn.Linear(self.n_embd, 4 * self.n_embd)
self.c_proj = nn.Linear(4 * self.n_embd, self.n_embd)
def __call__(self, x) -> mx.array:
return self.c_proj(nn.gelu_approx(self.c_fc(x)))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_head = args.n_head
self.n_embd = args.n_embd
self.layer_norm_epsilon = args.layer_norm_epsilon
self.attn = Attention(args)
self.mlp = MLP(args)
self.ln_1 = nn.LayerNorm(
self.n_embd,
eps=self.layer_norm_epsilon,
)
self.ln_2 = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.attn(self.ln_1(x), mask, cache)
h = x + r
r = self.mlp(self.ln_2(h))
out = h + r
return out
class GPT2Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_embd = args.n_embd
self.n_positions = args.n_positions
self.vocab_size = args.vocab_size
self.n_layer = args.n_layer
self.layer_norm_epsilon = args.layer_norm_epsilon
assert self.vocab_size > 0
self.wte = nn.Embedding(self.vocab_size, self.n_embd)
self.wpe = nn.Embedding(self.n_positions, self.n_embd)
self.h = [TransformerBlock(args=args) for _ in range(self.n_layer)]
self.ln_f = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
_, L = inputs.shape
hidden_states = self.wte(inputs)
mask = None
if hidden_states.shape[1] > 1:
position_ids = mx.array(np.arange(L))
hidden_states += self.wpe(position_ids)
if mask is None:
mask = create_attention_mask(hidden_states, cache)
if cache is None:
cache = [None] * len(self.h)
for layer, c in zip(self.h, cache):
hidden_states = layer(hidden_states, mask, cache=c)
return self.ln_f(hidden_states)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = GPT2Model(args)
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, mask, cache)
out = self.model.wte.as_linear(out)
return out
def sanitize(self, weights):
new_weights = {}
for i in range(self.args.n_layer):
if f"h.{i}.attn.bias" in weights:
del weights[f"h.{i}.attn.bias"]
if f"h.{i}.attn.c_attn.weight" in weights:
weights[f"h.{i}.attn.c_attn.weight"] = weights[
f"h.{i}.attn.c_attn.weight"
].transpose(1, 0)
if f"h.{i}.attn.c_proj.weight" in weights:
weights[f"h.{i}.attn.c_proj.weight"] = weights[
f"h.{i}.attn.c_proj.weight"
].transpose(1, 0)
if f"h.{i}.mlp.c_fc.weight" in weights:
weights[f"h.{i}.mlp.c_fc.weight"] = weights[
f"h.{i}.mlp.c_fc.weight"
].transpose(1, 0)
if f"h.{i}.mlp.c_proj.weight" in weights:
weights[f"h.{i}.mlp.c_proj.weight"] = weights[
f"h.{i}.mlp.c_proj.weight"
].transpose(1, 0)
for weight in weights:
if not weight.startswith("model."):
new_weights[f"model.{weight}"] = weights[weight]
else:
new_weights[weight] = weights[weight]
return new_weights
@property
def layers(self):
return self.model.h
I wanted to provide as little information as possible in the prompt so as not to steer the model, while still conveying the problematic behavior. My testing is by no means comprehensive, and the models may fare better or worse on other problems. Two patterns emerged from the results:
- The model correctly identifies the problem: positional embeddings are not added after the prefill pass, i.e., during incremental decoding.
- The model gets confused and assumes the KV cache implementation has a bug, or that the KV cache is being used incorrectly in the attention layers.
Correct solution from `o1-2024-12-17`:
A key issue is that in your GPT‐2 forward pass, you only add position embeddings and create a causal mask when “hidden_states.shape[1] > 1.” In a typical “one‐token‐at‐a‐time” generation loop (when using a KV cache), each newly generated token has shape [B, 1]. Because of the if hidden_states.shape[1] > 1 check, the code entirely skips adding the new position embedding and never updates the attention mask for that new single token. That mismatch quickly leads to “garbage” outputs after the first step.
In other words, inside `GPT2Model.__call__`:
if hidden_states.shape[1] > 1:
position_ids = mx.array(np.arange(L))
hidden_states += self.wpe(position_ids)
if mask is None:
mask = create_attention_mask(hidden_states, cache)
is never invoked on subsequent tokens (where L = 1). To fix this, you must always add the correct position embedding to every token in the sequence—whether you are calling the model on N tokens at once or just on a single newly generated token. You also need to ensure that you create/apply a proper causal mask for incremental decoding.
A standard fix is something along these lines:
• Keep track of how many tokens have been generated so far (e.g., total_seq_len), then compute the new position IDs as range(total_seq_len, total_seq_len + L) rather than just range(L).
• Always add position embeddings for the newly generated token (i.e., remove or adjust the if hidden_states.shape[1] > 1 condition).
• Always create/apply the causal mask for the new token (so the new position can only attend to previous tokens).
Once that is done, single‐token inference (with KV cache) will produce correct next‐token logits instead of garbled output.
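To make that bookkeeping concrete, here is a small illustration of the position-ID computation o1 describes (my sketch, not part of the model's answer; `total_seq_len` is a hypothetical counter maintained by the generation loop):

```python
import mlx.core as mx

# Track how many tokens have already been processed and offset the position IDs
# by that count, instead of restarting at 0 (or skipping them) on each step.
total_seq_len = 0

def next_position_ids(L: int) -> mx.array:
    global total_seq_len
    position_ids = mx.arange(total_seq_len, total_seq_len + L)
    total_seq_len += L
    return position_ids

print(next_position_ids(5))  # prefill with a 5-token prompt -> positions 0..4
print(next_position_ids(1))  # first decode step -> position 5, not 0
```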
Here is an example of an incorrect solution from `llama-3.3-70b-instruct`:
# Step-by-step analysis of the problem:
1. **Token Generation Issue**: The issue here seems to be related to the token generation process, specifically when using a KV cache for the GPT2 implementation.
2. **KV Cache Implementation**: The KV cache is used in the attention mechanism to store and update the keys and values for each layer. However, it appears that there might be an issue with how the cache is being updated or accessed.
3. **Garbage Output After First Token**: This suggests that the problem lies in the way the model handles the input sequence beyond the first token.
# Fixed solution:
To fix this issue, we need to review and potentially modify the `cache.update_and_fetch(keys, values)` line in the Attention class's `__call__` method.
Here's an updated version of the `Attention` class:
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
assert args.n_embd % args.n_head == 0, "n_embd must be divisible by n_head"
self.n_embd = args.n_embd
self.n_head = args.n_head
self.head_dim = self.n_embd // self.n_head
self.scale = self.head_dim**-0.5
self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=True)
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=True)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
qkv = self.c_attn(x)
queries, keys, values = mx.split(qkv, 3, axis=-1)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
if cache is not None:
# Check if the cache has been initialized
if isinstance(cache, dict):
keys_cache = cache.get('keys', mx.zeros_like(keys))
values_cache = cache.get('values', mx.zeros_like(values))
# Update the cache with new keys and values
keys = mx.cat([keys_cache[:, :, -L:, :], keys], dim=-2)
values = mx.cat([values_cache[:, :, -L:, :], values], dim=-2)
# Update the cache for future use
cache['keys'] = keys
cache['values'] = values
else:
raise ValueError("Cache should be a dictionary")
else:
# If no cache is provided, do not update or fetch any cached values
pass
output = scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.c_proj(output)
And here is an updated version of the `TransformerBlock` class:
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_head = args.n_head
self.n_embd = args.n_embd
self.layer_norm_epsilon = args.layer_norm_epsilon
self.attn = Attention(args)
self.mlp = MLP(args)
self.ln_1 = nn.LayerNorm(
self.n_embd,
eps=self.layer_norm_epsilon,
)
self.ln_2 = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.attn(self.ln_1(x), mask, cache)
h = x + r
r = self.mlp(self.ln_2(h))
out = h + r
# Initialize or update the cache for future use
if cache is None:
cache = {'keys': mx.zeros((x.shape[0], self.n_head, 0, self.head_dim)),
'values': mx.zeros((x.shape[0], self.n_head, 0, self.head_dim))}
else:
pass
return out
# Explanation of changes:
* **Cache Management**: The primary change made was in how the cache is managed within the `Attention` class. Now, when a cache is provided, it checks if the cache has been initialized and updates it accordingly.
* **Key and Value Concatenation**: In the updated code, we concatenate the cached keys and values with the new ones before using them for attention computation.
# Tests and example uses:
You can test this implementation by providing a sample input sequence to your `Model` class. Make sure that you have properly initialized all necessary variables and modules.
Here is an example of how you might use it:
args = ModelArgs(model_type='gpt2', n_ctx=1024, n_embd=512, n_head=8, n_layer=12,
n_positions=1024, layer_norm_epsilon=1e-5, vocab_size=50257)
model = Model(args)
input_ids = mx.array(np.random.randint(0, 100, size=(32, 128))) # replace with your own input
output = model(input_ids)
print(output.shape) # Should print (32, 128, 50257)
Here’s a table of my testing results:
| Model | Pass/Fail | Notes |
|---|---|---|
| `o1-2024-12-17` (medium reasoning) | Pass ✅ | |
| `gpt-4.5` | Fail ❌ | Identified cause correctly, but wrong fix |
| `o4-mini-2025-04-16` (high reasoning) | Pass ✅ | |
| `gpt-4.1` | Pass ✅ | |
| `gpt-4o-2024-08-06` | Fail ❌ | |
| `llama-3.3-70b-instruct` | Fail ❌ | |
In short, reasoning models fared far better than non-reasoning models. I was also surprised to see `gpt-4.5` fail while `gpt-4.1` succeeded, further evidence of why OpenAI may be sunsetting `gpt-4.5` so quickly.