<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" ><generator uri="https://jekyllrb.com/" version="4.4.1">Jekyll</generator><link href="https://dlewis.io/feed.xml" rel="self" type="application/atom+xml" /><link href="https://dlewis.io/" rel="alternate" type="text/html" /><updated>2026-04-07T19:16:20+00:00</updated><id>https://dlewis.io/feed.xml</id><title type="html">Derek’s Blog</title><subtitle>My blog about GenAI, programming, and other tech interests</subtitle><author><name>Derek Lewis</name></author><entry><title type="html">Evaluating Llama‑3.3‑70B Inference on NVIDIA H100 and A100 GPUs</title><link href="https://dlewis.io/evaluating-llama-33-70b-inference-h100-a100/" rel="alternate" type="text/html" title="Evaluating Llama‑3.3‑70B Inference on NVIDIA H100 and A100 GPUs" /><published>2025-04-17T00:00:00+00:00</published><updated>2025-04-17T00:00:00+00:00</updated><id>https://dlewis.io/evaluating-llama-33-70b-inference-h100-a100</id><content type="html" xml:base="https://dlewis.io/evaluating-llama-33-70b-inference-h100-a100/"><![CDATA[<p>Large‑scale language models quickly expose the limits of yesterday’s hardware. To understand how much <em>practical</em> head‑room Hopper offers over Ampere in a production‑style setting, I profiled <code class="language-plaintext highlighter-rouge">llama-3.3-70b-instruct</code> on two 4‑GPU hosts—one populated with A100 80GB (PCIe), the other with H100 80GB (SXM5). Inference was served via NVIDIA NIM using the default TensorRT‑LLM profiles (<code class="language-plaintext highlighter-rouge">TP = 4</code>, <code class="language-plaintext highlighter-rouge">PP = 1</code>, <code class="language-plaintext highlighter-rouge">bfloat16</code>).</p>

<p>Workloads and measurement were driven by one of NVIDIA’s benchmarking tools, <code class="language-plaintext highlighter-rouge">genai‑perf</code>; the accompanying charts were produced directly from <code class="language-plaintext highlighter-rouge">benchmark.py</code> with no manual post‑processing. The benchmark code and raw data are available in <a href="https://github.com/derekelewis/ml-notes/tree/main/benchmarks/genai_perf_benchmarks">this GitHub repo</a>. <code class="language-plaintext highlighter-rouge">genai-perf</code> measures several useful metrics, such as Time-To-First-Token (TTFT), Inter-Token Latency (ITL), and Tokens/Second (TPS) - all driven by synthetically generated prompts for various patterned workloads.</p>
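
<p>For readers less familiar with these metrics, here is a minimal sketch (my own illustration, not part of <code class="language-plaintext highlighter-rouge">benchmark.py</code>) of how they relate for a single request; the numbers are placeholders, not measurements from this benchmark:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Illustrative only: rough relationship between the metrics genai-perf reports.
# The numbers below are placeholders, not measurements from this benchmark.

def end_to_end_latency(ttft_s: float, itl_s: float, output_tokens: int) -&gt; float:
    """Approximate request latency: time to first token plus one
    inter-token interval for each remaining output token."""
    return ttft_s + itl_s * (output_tokens - 1)

def per_request_tps(output_tokens: int, latency_s: float) -&gt; float:
    """Per-request decode throughput implied by that latency."""
    return output_tokens / latency_s

latency = end_to_end_latency(ttft_s=0.5, itl_s=0.02, output_tokens=200)
print(f"latency ~= {latency:.2f} s, per-request TPS ~= {per_request_tps(200, latency):.0f}")
</code></pre></div></div>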

<hr />

<h2 id="test-methodology">Test Methodology</h2>

<table>
  <thead>
    <tr>
      <th>Dimension</th>
      <th>Setting</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td><strong>Model</strong></td>
      <td><code class="language-plaintext highlighter-rouge">llama‑3.3‑70b‑instruct</code></td>
    </tr>
    <tr>
      <td><strong>Container</strong></td>
      <td><code class="language-plaintext highlighter-rouge">nvcr.io/nim/meta/llama-3.3-70b-instruct:1.8.2</code></td>
    </tr>
    <tr>
      <td><strong>Precision</strong></td>
      <td>bf16</td>
    </tr>
    <tr>
      <td><strong>Parallelism</strong></td>
      <td>Tensor parallelism = 4, Pipeline parallelism = 1</td>
    </tr>
    <tr>
      <td><strong>Traffic model</strong></td>
      <td>Synthetic prompts via <code class="language-plaintext highlighter-rouge">genai-perf</code><br />• 200 → 200 tokens (translation / Q&amp;A)<br />• 1,000 → 200 tokens (summarization)</td>
    </tr>
    <tr>
      <td><strong>Concurrency sweep</strong></td>
      <td>1, 2, 5, 10, 50, 100, 250, 500 users</td>
    </tr>
    <tr>
      <td><strong>Metrics captured</strong></td>
      <td>• Total Tokens / Second (TPS)<br />• Median Time‑To‑First‑Token (TTFT)</td>
    </tr>
  </tbody>
</table>
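
<p>The sweep itself is just the cross product of the two workload shapes and the concurrency levels. Below is a minimal sketch of how it could be organized; <code class="language-plaintext highlighter-rouge">run_genai_perf</code> is a hypothetical placeholder, since the actual runs were driven by the repository’s shell scripts:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Sketch of the sweep grid; run_genai_perf() is a hypothetical placeholder
# (the actual runs were driven by the repository's shell scripts).
import itertools

WORKLOADS = [(200, 200), (1000, 200)]           # (input tokens, output tokens)
CONCURRENCY = [1, 2, 5, 10, 50, 100, 250, 500]  # simulated users

def run_genai_perf(isl: int, osl: int, concurrency: int) -&gt; dict:
    """Placeholder: invoke genai-perf for one cell and return its metrics."""
    raise NotImplementedError

results = []
for (isl, osl), users in itertools.product(WORKLOADS, CONCURRENCY):
    results.append(run_genai_perf(isl, osl, users))
</code></pre></div></div>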

<hr />

<h2 id="throughput-results">Throughput Results</h2>

<h3 id="200200tokens">200 → 200 tokens</h3>
<p><img src="/assets/images/a100h100chart1.png" alt="Performance (200/200) – TPS vs TTFT" /></p>

<ul>
  <li>H100 scaled almost linearly up to 500 users, peaking at ≈ 7,000 TPS.</li>
  <li>A100 saturated at ≈ 570 TPS around 50 users; additional users primarily increased queueing delay.</li>
</ul>

<p>This corresponds to a ≈ 12–14× throughput advantage for the H100 configuration across the sweep.</p>

<h3 id="1000200tokens">1,000 → 200 tokens</h3>
<p><img src="/assets/images/a100h100chart2.png" alt="Performance (1,000/200) – TPS vs TTFT" /></p>

<p>Longer inputs magnify memory pressure during decoding, yet the relative gap remains wide:</p>

<ul>
  <li>H100 delivered ≈ 2,600 TPS at 250 concurrent users.</li>
  <li>A100 remained under ≈ 230 TPS at the same load.</li>
</ul>

<h2 id="latency-under-load">Latency Under Load</h2>

<h3 id="200200tokens-1">200 → 200 tokens</h3>
<p><img src="/assets/images/a100h100chart3.png" alt="TTFT vs Concurrency (200/200)" /></p>

<ul>
  <li>H100 maintained TTFT below 5s up to 500 simultaneous sessions.</li>
  <li>A100 exceeded 10s TTFT at 100 users and climbed quickly under higher load.</li>
</ul>

<h3 id="1000200tokens-1">1,000 → 200 tokens</h3>
<p><img src="/assets/images/a100h100chart4.png" alt="TTFT vs Concurrency (1,000/200)" /></p>

<ul>
  <li>H100 maintained &lt;5s TTFT across the full user sweep - indicating additional capacity beyond 500 users.</li>
  <li>A100 TTFT spiked above 5s at just 10-20 concurrent users.</li>
</ul>

<h2 id="discussion">Discussion</h2>

<ol>
  <li>Hopper has numerous enhancements over Ampere beyond faster clock speeds, additional SMs, and higher memory bandwidth (HBM3 vs. HBM2e). Those enhancements include the Tensor Memory Accelerator (TMA), FP8 support, and the Transformer Engine. Together, these lead to significantly higher throughput and lower latency at higher user counts.</li>
  <li>In most cases, H100s will achieve target latency and concurrent-user requirements at a lower total cost than A100s. Comparing the current on-demand cost of an AWS <code class="language-plaintext highlighter-rouge">p4de.24xlarge</code> instance with 8xA100s to the <code class="language-plaintext highlighter-rouge">p5.48xlarge</code>, the cost increase between the two is about 2.4x, but the throughput increase can be as high as 14x, which justifies the premium. Put differently, it would take roughly 13 <code class="language-plaintext highlighter-rouge">p4de.24xlarge</code> instances to match the throughput of a single <code class="language-plaintext highlighter-rouge">p5.48xlarge</code> instance (see the short worked example after this list).</li>
  <li>Longer input &amp; output sequence lengths will increase latency and decrease concurrency/throughput.</li>
  <li>Blackwell should offer a similar, and likely even more pronounced, throughput multiple over Hopper. Hopefully, I will be in a position to do Blackwell testing soon as hardware becomes more available.</li>
  <li>Although the <code class="language-plaintext highlighter-rouge">llama-3.3-70b</code> model will fit on 2x80GB A100s or H100s at bfloat16 precision, that leaves very little room for the KV cache, so the supported minimum per the NVIDIA NIM documentation is 4x80GB A100s or H100s.</li>
</ol>
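
<p>As a rough sanity check on the cost point above, here is a small back-of-the-envelope sketch. The hourly rates are approximate on-demand list prices (they vary by region and over time), and the throughput figures are the peak totals from the 200/200 sweep:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Back-of-the-envelope cost-effectiveness comparison (illustrative figures only).
# Prices are approximate on-demand rates and vary by region and over time.
a100_instance_per_hr = 40.96   # p4de.24xlarge (8x A100 80GB), approximate
h100_instance_per_hr = 98.32   # p5.48xlarge (8x H100 80GB), approximate

# Peak total TPS observed on the 4-GPU hosts in the 200/200 sweep.
a100_tps = 570
h100_tps = 7000

price_ratio = h100_instance_per_hr / a100_instance_per_hr       # ~2.4x
throughput_ratio = h100_tps / a100_tps                          # ~12x

# Cost per million generated tokens, assuming roughly linear scaling from the
# 4-GPU test hosts to the 8-GPU instances.
a100_cost_per_mtok = a100_instance_per_hr / (a100_tps * 2 * 3600) * 1e6
h100_cost_per_mtok = h100_instance_per_hr / (h100_tps * 2 * 3600) * 1e6

print(f"price ratio ~{price_ratio:.1f}x, throughput ratio ~{throughput_ratio:.1f}x")
print(f"A100: ~${a100_cost_per_mtok:.2f} per million tokens, H100: ~${h100_cost_per_mtok:.2f}")
</code></pre></div></div>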

<h2 id="key-takeaways">Key Takeaways</h2>

<ol>
  <li><strong>Hopper’s advantage is multiplicative, not incremental.</strong> In interactive‑LLM scenarios the H100 system provided an order‑of‑magnitude improvement in both throughput and p50 latency.</li>
  <li><strong>Latency SLOs drive hardware choice.</strong> If your service promises sub‑second first‑token latency under real‑world concurrency, A100 will need disproportionate over‑provisioning or sharding; H100 meets the target comfortably.</li>
  <li><strong>Identify Input/Output Sequence Length (ISL/OSL) based on workload.</strong> Input/output sequence lengths vary widely by task, such as summarization. Reasoning models also change the dynamic, with potentially far more output tokens for what might be a relatively short input sequence.</li>
  <li><strong>bf16 is the sensible default.</strong> It balances memory footprint with performance on both architectures and is the path of least resistance inside NIM. At this time, <code class="language-plaintext highlighter-rouge">fp8</code> or <code class="language-plaintext highlighter-rouge">int4</code> profiles are not available for the <code class="language-plaintext highlighter-rouge">llama-3.3-70b-instruct</code> model with NVIDIA NIM.</li>
  <li><strong>Measure under expected concurrency.</strong> Single‑request benchmarks tell an incomplete story; backlog effects dominate once user count climbs. 5 seconds is my personal rule of thumb for TTFT.</li>
</ol>

<h2 id="reproducing-the-experiment">Reproducing the Experiment</h2>

<p>The benchmarks can be reproduced by running the <code class="language-plaintext highlighter-rouge">start_nim.sh</code> &amp; <code class="language-plaintext highlighter-rouge">benchmark.sh</code> scripts provided in the above GitHub repository. The <a href="https://docs.nvidia.com/nim/benchmarking/llm/latest/index.html">“NVIDIA NIM LLMs Benchmarking” documentation</a> on <code class="language-plaintext highlighter-rouge">genai-perf</code> is also an excellent resource with some great insights into the various benchmarking metrics, as well as some of NVIDIA’s own benchmarking data with various ISL/OSLs and NIM profiles.</p>]]></content><author><name>Derek Lewis</name></author><category term="nvidia" /><category term="nim" /><category term="tensorrt‑llm" /><category term="genai‑perf" /><category term="llama" /><category term="llm" /><category term="genai" /><category term="benchmark" /><summary type="html"><![CDATA[Large‑scale language models quickly expose the limits of yesterday’s hardware. To understand how much practical head‑room Hopper offers over Ampere in a production‑style setting, I profiled llama-3.3-70b-instruct on two 4‑GPU hosts—one populated with A100 80GB (PCIe), the other with H100 80GB (SXM5). Inference was served via NVIDIA NIM using the default TensorRT‑LLM profiles (TP = 4, PP = 1, bfloat16).]]></summary></entry><entry><title type="html">Recent Experiences Debugging with LLMs</title><link href="https://dlewis.io/recent-experiences-debugging-with-llms/" rel="alternate" type="text/html" title="Recent Experiences Debugging with LLMs" /><published>2025-04-16T00:00:00+00:00</published><updated>2025-04-16T00:00:00+00:00</updated><id>https://dlewis.io/recent-experiences-debugging-with-llms</id><content type="html" xml:base="https://dlewis.io/recent-experiences-debugging-with-llms/"><![CDATA[<p>I’m frequently asked by clients what my thoughts are on LLMs and coding. Personal experience has informed me that LLMs cannot solve problems of a certain complexity for a number of reasons. One of the more common ones is codebase size and complexity: a large codebase with many more tokens than even the best SOTA models with the longest context windows can support (see <a href="https://arxiv.org/abs/2404.06654">RULER paper</a>, too). This typically requires some type of semantic search or RAG to search the codebase and usually results in the model having subpar performance while addressing a user’s request for a bugfix or enhancement.</p>

<p>Secondly, there’s the point of problem complexity, which is the subject of this post. I recently fixed a bug in the <code class="language-plaintext highlighter-rouge">mlx-lm</code> GPT2 model implementation that I submitted to the MLX project last year. If you’re not familiar with MLX, it is an ML and deep‑learning framework similar to PyTorch, but specific to Apple Silicon. The pull request with information on the bug can be viewed <a href="https://github.com/ml-explore/mlx-lm/pull/93">here</a>. In summary, when generating tokens from the GPT2 model using the inference code that includes a KV Cache in <code class="language-plaintext highlighter-rouge">mlx-lm</code>, the output was garbled with repetitive text for the smaller GPT2 models like the 124M “small” version. At the time I submitted the original PR, I was mostly testing with the <code class="language-plaintext highlighter-rouge">gpt2-xl</code> 1.5B parameter version, which generated output more or less identical to the HuggingFace transformers GPT2 implementation.</p>
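
<p>The symptom is easy to reproduce with <code class="language-plaintext highlighter-rouge">mlx-lm</code>’s Python API. Here is a minimal sketch, assuming a standard GPT2 “small” checkpoint from the Hugging Face Hub (the repo id below is my assumption):</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Minimal reproduction sketch using mlx-lm's Python API.
# The Hugging Face repo id is an assumption; any GPT2 "small" checkpoint
# supported by mlx-lm should show the same behavior on an affected version.
from mlx_lm import load, generate

model, tokenizer = load("openai-community/gpt2")  # assumed repo id
text = generate(model, tokenizer, prompt="The quick brown fox", max_tokens=50)
print(text)  # on an affected version, this degrades into repetitive text
</code></pre></div></div>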

<p>Upon analyzing the code, I noticed that positional embeddings were not being added after the initial forward pass of the model, which meant any subsequent tokens effectively had a position of 0 - not ideal! The smaller GPT2 models seem to be more sensitive to the positional embedding layer, while the larger GPT2 models like <code class="language-plaintext highlighter-rouge">gpt2-xl</code> still generate acceptable output without it. I would assume that during training the <code class="language-plaintext highlighter-rouge">gpt2-xl</code> model with its additional parameters learned positional information in the transformer block layers.</p>
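
<p>The shape of the fix looks roughly like the sketch below (not the verbatim patch from the PR): compute position ids offset by the number of tokens already held in the KV cache and add the positional embeddings on every call, not just during prefill. Reading that length from the cache’s <code class="language-plaintext highlighter-rouge">offset</code> attribute is my assumption of how mlx-lm exposes it:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Sketch of a corrected GPT2Model.__call__ (not the verbatim patch from the PR).
def __call__(
    self,
    inputs: mx.array,
    mask: mx.array = None,
    cache=None,
):
    _, L = inputs.shape

    hidden_states = self.wte(inputs)

    # Number of tokens already held in the KV cache; mlx-lm's cache objects
    # expose this as `offset` (attribute name assumed here).
    offset = cache[0].offset if cache is not None and cache[0] is not None else 0

    # Always add positional embeddings, shifted by the cache offset, instead of
    # only doing so when more than one token is passed in.
    position_ids = mx.arange(offset, offset + L)
    hidden_states += self.wpe(position_ids)

    mask = None
    if hidden_states.shape[1] &gt; 1:
        mask = create_attention_mask(hidden_states, cache)

    if cache is None:
        cache = [None] * len(self.h)

    for layer, c in zip(self.h, cache):
        hidden_states = layer(hidden_states, mask, cache=c)

    return self.ln_f(hidden_states)
</code></pre></div></div>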

<p>In any case, the fix was relatively simple to implement, but harder to spot initially, so I decided to test various LLMs to see if they would have identified the cause or not. The results were interesting. The following prompt was used for testing:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>the token generation w/ KV cache for gpt2 implementation generates garbage output after the first token. W/o KV cache it works fine. Here's the implementation code:

# Copyright © 2023-2024 Apple Inc.

from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn
import numpy as np

from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    n_ctx: int
    n_embd: int
    n_head: int
    n_layer: int
    n_positions: int
    layer_norm_epsilon: float
    vocab_size: int
    num_key_value_heads: int = None

    def __post_init__(self):
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.n_head


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        assert args.n_embd % args.n_head == 0, "n_embd must be divisible by n_head"

        self.n_embd = args.n_embd
        self.n_head = args.n_head
        self.head_dim = self.n_embd // self.n_head

        self.scale = self.head_dim**-0.5

        self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=True)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=True)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -&gt; mx.array:
        B, L, D = x.shape

        qkv = self.c_attn(x)
        queries, keys, values = mx.split(qkv, 3, axis=-1)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            keys, values = cache.update_and_fetch(keys, values)

        output = scaled_dot_product_attention(
            queries, keys, values, cache=cache, scale=self.scale, mask=mask
        )

        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.c_proj(output)


class MLP(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_embd = args.n_embd
        self.c_fc = nn.Linear(self.n_embd, 4 * self.n_embd)
        self.c_proj = nn.Linear(4 * self.n_embd, self.n_embd)

    def __call__(self, x) -&gt; mx.array:
        return self.c_proj(nn.gelu_approx(self.c_fc(x)))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_head = args.n_head
        self.n_embd = args.n_embd
        self.layer_norm_epsilon = args.layer_norm_epsilon
        self.attn = Attention(args)
        self.mlp = MLP(args)
        self.ln_1 = nn.LayerNorm(
            self.n_embd,
            eps=self.layer_norm_epsilon,
        )
        self.ln_2 = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -&gt; mx.array:
        r = self.attn(self.ln_1(x), mask, cache)
        h = x + r
        r = self.mlp(self.ln_2(h))
        out = h + r
        return out


class GPT2Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_embd = args.n_embd
        self.n_positions = args.n_positions
        self.vocab_size = args.vocab_size
        self.n_layer = args.n_layer
        self.layer_norm_epsilon = args.layer_norm_epsilon
        assert self.vocab_size &gt; 0
        self.wte = nn.Embedding(self.vocab_size, self.n_embd)
        self.wpe = nn.Embedding(self.n_positions, self.n_embd)
        self.h = [TransformerBlock(args=args) for _ in range(self.n_layer)]
        self.ln_f = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon)

    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        _, L = inputs.shape

        hidden_states = self.wte(inputs)

        mask = None
        if hidden_states.shape[1] &gt; 1:

            position_ids = mx.array(np.arange(L))
            hidden_states += self.wpe(position_ids)

            if mask is None:
                mask = create_attention_mask(hidden_states, cache)

        if cache is None:
            cache = [None] * len(self.h)

        for layer, c in zip(self.h, cache):
            hidden_states = layer(hidden_states, mask, cache=c)

        return self.ln_f(hidden_states)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = GPT2Model(args)

    def __call__(
        self,
        inputs: mx.array,
        mask: mx.array = None,
        cache=None,
    ):
        out = self.model(inputs, mask, cache)
        out = self.model.wte.as_linear(out)
        return out

    def sanitize(self, weights):
        new_weights = {}
        for i in range(self.args.n_layer):
            if f"h.{i}.attn.bias" in weights:
                del weights[f"h.{i}.attn.bias"]
            if f"h.{i}.attn.c_attn.weight" in weights:
                weights[f"h.{i}.attn.c_attn.weight"] = weights[
                    f"h.{i}.attn.c_attn.weight"
                ].transpose(1, 0)
            if f"h.{i}.attn.c_proj.weight" in weights:
                weights[f"h.{i}.attn.c_proj.weight"] = weights[
                    f"h.{i}.attn.c_proj.weight"
                ].transpose(1, 0)
            if f"h.{i}.mlp.c_fc.weight" in weights:
                weights[f"h.{i}.mlp.c_fc.weight"] = weights[
                    f"h.{i}.mlp.c_fc.weight"
                ].transpose(1, 0)
            if f"h.{i}.mlp.c_proj.weight" in weights:
                weights[f"h.{i}.mlp.c_proj.weight"] = weights[
                    f"h.{i}.mlp.c_proj.weight"
                ].transpose(1, 0)
        for weight in weights:
            if not weight.startswith("model."):
                new_weights[f"model.{weight}"] = weights[weight]
            else:
                new_weights[weight] = weights[weight]
        return new_weights

    @property
    def layers(self):
        return self.model.h
</code></pre></div></div>

<p>I wanted to provide as little information as possible in the prompt so as not to steer the model, while still relaying some of the problematic behavior. My testing is by no means intended to be comprehensive, and the models may fare better or worse on other problems. Two patterns emerged from the testing results:</p>

<ol>
  <li>The model correctly identifies the problem as being related to the positional embeddings not being added after the prefill/decode.</li>
  <li>The model is confused and assumes the KV Cache implementation has a bug or that the KV Cache is being used incorrectly in the attention layers.</li>
</ol>

<p>Correct solution from <code class="language-plaintext highlighter-rouge">o1-2024-12-17</code>:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>A key issue is that in your GPT‐2 forward pass, you only add position embeddings and create a causal mask when “hidden_states.shape[1] &gt; 1.” In a typical “one‐token‐at‐a‐time” generation loop (when using a KV cache), each newly generated token has shape [B, 1]. Because of the if hidden_states.shape[1] &gt; 1 check, the code entirely skips adding the new position embedding and never updates the attention mask for that new single token. That mismatch quickly leads to “garbage” outputs after the first step.
In other words, inside GPT2Model.call:

if hidden_states.shape[1] &gt; 1:
	position_ids = mx.array(np.arange(L))
	hidden_states += self.wpe(position_ids)

if mask is None:
    mask = create_attention_mask(hidden_states, cache)

is never invoked on subsequent tokens (where L = 1). To fix this, you must always add the correct position embedding to every token in the sequence—whether you are calling the model on N tokens at once or just on a single newly generated token. You also need to ensure that you create/apply a proper causal mask for incremental decoding.
A standard fix is something along these lines:
• Keep track of how many tokens have been generated so far (e.g., total_seq_len), then compute the new position IDs as range(total_seq_len, total_seq_len + L) rather than just range(L). • Always add position embeddings for the newly generated token (i.e., remove or adjust the if hidden_states.shape[1] &gt; 1 condition). • Always create/apply the causal mask for the new token (so the new position can only attend to previous tokens).
Once that is done, single‐token inference (with KV cache) will produce correct next‐token logits instead of garbled output.
</code></pre></div></div>

<p>Here is an example of an incorrect solution from <code class="language-plaintext highlighter-rouge">llama-3.3-70b-instruct</code>:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Step-by-step analysis of the problem:
1. **Token Generation Issue**: The issue here seems to be related to the token generation process, specifically when using a KV cache for the GPT2 implementation.
2. **KV Cache Implementation**: The KV cache is used in the attention mechanism to store and update the keys and values for each layer. However, it appears that there might be an issue with how the cache is being updated or accessed.
3. **Garbage Output After First Token**: This suggests that the problem lies in the way the model handles the input sequence beyond the first token.

# Fixed solution:
To fix this issue, we need to review and potentially modify the `cache.update_and_fetch(keys, values)` line in the Attention class's `__call__` method. 

Here's an updated version of the `Attention` class:

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        assert args.n_embd % args.n_head == 0, "n_embd must be divisible by n_head"

        self.n_embd = args.n_embd
        self.n_head = args.n_head
        self.head_dim = self.n_embd // self.n_head

        self.scale = self.head_dim**-0.5

        self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=True)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=True)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -&gt; mx.array:
        B, L, D = x.shape

        qkv = self.c_attn(x)
        queries, keys, values = mx.split(qkv, 3, axis=-1)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            # Check if the cache has been initialized
            if isinstance(cache, dict):
                keys_cache = cache.get('keys', mx.zeros_like(keys))
                values_cache = cache.get('values', mx.zeros_like(values))
                
                # Update the cache with new keys and values
                keys = mx.cat([keys_cache[:, :, -L:, :], keys], dim=-2)
                values = mx.cat([values_cache[:, :, -L:, :], values], dim=-2)
                
                # Update the cache for future use
                cache['keys'] = keys
                cache['values'] = values
                
            else:
                raise ValueError("Cache should be a dictionary")
        else:
            # If no cache is provided, do not update or fetch any cached values
            pass

        output = scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )

        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.c_proj(output)

And here is an updated version of the `TransformerBlock` class:

class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_head = args.n_head
        self.n_embd = args.n_embd
        self.layer_norm_epsilon = args.layer_norm_epsilon
        self.attn = Attention(args)
        self.mlp = MLP(args)
        self.ln_1 = nn.LayerNorm(
            self.n_embd,
            eps=self.layer_norm_epsilon,
        )
        self.ln_2 = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -&gt; mx.array:
        r = self.attn(self.ln_1(x), mask, cache)
        h = x + r
        r = self.mlp(self.ln_2(h))
        out = h + r
        
        # Initialize or update the cache for future use
        if cache is None:
            cache = {'keys': mx.zeros((x.shape[0], self.n_head, 0, self.head_dim)), 
                     'values': mx.zeros((x.shape[0], self.n_head, 0, self.head_dim))}
        else:
            pass
            
        return out

# Explanation of changes:
*   **Cache Management**: The primary change made was in how the cache is managed within the `Attention` class. Now, when a cache is provided, it checks if the cache has been initialized and updates it accordingly.
*   **Key and Value Concatenation**: In the updated code, we concatenate the cached keys and values with the new ones before using them for attention computation.

# Tests and example uses:
You can test this implementation by providing a sample input sequence to your `Model` class. Make sure that you have properly initialized all necessary variables and modules.

Here is an example of how you might use it:

args = ModelArgs(model_type='gpt2', n_ctx=1024, n_embd=512, n_head=8, n_layer=12, 
                 n_positions=1024, layer_norm_epsilon=1e-5, vocab_size=50257)

model = Model(args)
input_ids = mx.array(np.random.randint(0, 100, size=(32, 128))) # replace with your own input

output = model(input_ids)
print(output.shape) # Should print (32, 128, 50257)
</code></pre></div></div>

<p>Here’s a table of my testing results:</p>

<table>
  <thead>
    <tr>
      <th>Model</th>
      <th>Pass/Fail</th>
      <th>Notes</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td><code class="language-plaintext highlighter-rouge">o1-2024-12-17</code> (medium reasoning)</td>
      <td>Pass ✅</td>
      <td> </td>
    </tr>
    <tr>
      <td><code class="language-plaintext highlighter-rouge">gpt-4.5</code></td>
      <td>Fail ❌</td>
      <td>Identified cause correctly, but wrong fix</td>
    </tr>
    <tr>
      <td><code class="language-plaintext highlighter-rouge">o4-mini-2025-04-16</code> (high reasoning)</td>
      <td>Pass ✅</td>
      <td> </td>
    </tr>
    <tr>
      <td><code class="language-plaintext highlighter-rouge">gpt-4.1</code></td>
      <td>Pass ✅</td>
      <td> </td>
    </tr>
    <tr>
      <td><code class="language-plaintext highlighter-rouge">gpt-4o-2024-08-06</code></td>
      <td>Fail ❌</td>
      <td> </td>
    </tr>
    <tr>
      <td><code class="language-plaintext highlighter-rouge">llama-3.3-70b-instruct</code></td>
      <td>Fail ❌</td>
      <td> </td>
    </tr>
  </tbody>
</table>

<p>In short, reasoning models fared far better than non-reasoning models. Also, I was surprised to see <code class="language-plaintext highlighter-rouge">gpt-4.5</code> fail, while <code class="language-plaintext highlighter-rouge">gpt-4.1</code> succeeded - further proof of why OpenAI may be sunsetting <code class="language-plaintext highlighter-rouge">gpt-4.5</code> so quickly.</p>]]></content><author><name>Derek Lewis</name></author><category term="openai" /><category term="llm" /><category term="genai" /><category term="mlx" /><category term="llama" /><summary type="html"><![CDATA[I’m frequently asked by clients what my thoughts are on LLMs and coding. Personal experience has informed me that LLMs cannot solve problems of a certain complexity for a number of reasons. One of the more common ones is codebase size and complexity: a large codebase with many more tokens than even the best SOTA models with the longest context windows can support (see RULER paper, too). This typically requires some type of semantic search or RAG to search the codebase and usually results in the model having subpar performance while addressing a user’s request for a bugfix or enhancement.]]></summary></entry><entry><title type="html">Building the IBKR C++ API Client Library</title><link href="https://dlewis.io/ibkr-cpp-api/" rel="alternate" type="text/html" title="Building the IBKR C++ API Client Library" /><published>2025-02-11T00:00:00+00:00</published><updated>2025-02-11T00:00:00+00:00</updated><id>https://dlewis.io/ibkr-cpp-api</id><content type="html" xml:base="https://dlewis.io/ibkr-cpp-api/"><![CDATA[<p>Recently, I wanted to use the C++ API client library that Interactive Brokers provides and experiment with some algorithmic trading and monitoring of my positions. I had hoped there would be some precompiled binaries already for MacOS &amp; Linux/aarch64; however, I did not have luck finding any. That is not a problem since Interactive Brokers provides the source to the client libraries under a license that allows for non-commercial use. The source can be downloaded <a href="https://interactivebrokers.github.io">here</a> after agreeing to the license.</p>

<p>I downloaded the <code class="language-plaintext highlighter-rouge">Latest</code> release of the library for Mac/Unix, which at the time of this writing is version <code class="language-plaintext highlighter-rouge">API 10.34</code>.  Once downloaded, I extracted the .zip file:</p>

<div class="language-console highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="gp">~/Downloads $</span><span class="w"> </span>unzip <span class="nt">-d</span> twsapi_macunix.1034.02 twsapi_macunix.1034.02.zip
</code></pre></div></div>

<p>The build documentation for Mac/Unix is available <a href="https://www.interactivebrokers.com/campus/ibkr-api-page/twsapi-doc/#unix-install">here</a>. The documentation is mostly focused on building the Python library, and the C++ library build documentation is lacking. The API’s <code class="language-plaintext highlighter-rouge">Decimal</code> implementation also has an odd dependency on the Intel Decimal Library, which was concerning since I wanted to build this for MacOS on Apple Silicon and Linux/aarch64. I assumed the Intel library would be Intel-specific, but much to my pleasant surprise it is not.</p>

<p>After downloading the Intel Decimal Floating-Point Math Library from: <a href="https://www.intel.com/content/www/us/en/developer/articles/tool/intel-decimal-floating-point-math-library.html">https://www.intel.com/content/www/us/en/developer/articles/tool/intel-decimal-floating-point-math-library.html</a>, I extracted that tarball:</p>

<h1 id="building-the-intel-decimal-floating-point-math-library">Building the Intel Decimal Floating-Point Math Library</h1>

<div class="language-console highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="gp">~/Downloads $</span><span class="w"> </span><span class="nb">tar </span>xvzf IntelRDFPMathLib20U2.tar.gz
</code></pre></div></div>

<p>Interactive Brokers provides steps to build the library contained in <code class="language-plaintext highlighter-rouge">IBJts/source/cppclient/Intel_lib_build.txt</code>, which I followed with some necessary modifications below:</p>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>IntelRDFPMathLib20U2/LIBRARY modify "makefile":
a)
Line 370:
change
BID_LIB = $(LIB_DIR)/libbid.$A
to
BID_LIB = $(LIB_DIR)/libbid.dylib # or .so when building on Linux

b)
Line 377:
change
	$(AR_CMD) $(AR_OUT)$@ $^
to
	gcc -o $@ $^ -shared

c) 
Line 112:
change
_CFLAGS_OPT := 
to
_CFLAGS_OPT := -fPIC -Wno-implicit-function-declaration # added to avoid issues with newer versions of clang and gcc (&gt;=14), which now treat implicit function declarations as an error
</code></pre></div></div>

<p>If building on Linux/aarch64, the following error will be encountered about an unknown architecture:</p>

<div class="language-console highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="gp">[dlewis@defiant LIBRARY]$</span><span class="w"> </span>make
<span class="go">makefile.iml_head:356: *** Unknown host architecture aarch64.  Stop.
</span></code></pre></div></div>

<p>This can be easily fixed by making some small changes to the <code class="language-plaintext highlighter-rouge">makefile.iml_head</code> file around line 347:</p>

<div class="language-makefile highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c"># add aarch64 with EFI2 specified for both ARCH_LIST and ARCH_TYPE
</span><span class="nv">ARCH_ALIAS</span> <span class="o">:=</span> x86  ia64 EM64T x86_64 i686 amd64 Intel64 sun4u aarch64
<span class="nv">ARCH_LIST</span>  <span class="o">:=</span> IA32 IA64 EFI2  EFI2   IA32 EFI2  EFI2    EFI2    EFI2
<span class="nv">ARCH_TYPE</span>  <span class="o">:=</span> IA32 IA64 EFI2  EFI2   IA32 EFI2  EFI2    EFI2    EFI2  
<span class="nv">ARCH_TYPES</span> <span class="o">:=</span> IA32 IA64 EFI2
</code></pre></div></div>

<p>Now, the Intel Decimal Library can be built without issues on MacOS or Linux/aarch64:</p>

<div class="language-console highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="gp">~/Downloads/IntelRDFPMathLib20U2/LIBRARY $</span><span class="w"> </span>make <span class="nv">CC</span><span class="o">=</span>gcc <span class="nv">CALL_BY_REF</span><span class="o">=</span>0 <span class="nv">GLOBAL_RND</span><span class="o">=</span>0 <span class="nv">GLOBAL_FLAGS</span><span class="o">=</span>0 <span class="nv">UNCHANGED_BINARY_FLAGS</span><span class="o">=</span>0
</code></pre></div></div>

<p>Once the build finishes, there will be a .dylib/.so:</p>

<div class="language-console highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="gp">~/Downloads/IntelRDFPMathLib20U2/LIBRARY $</span><span class="w"> </span><span class="nb">ls</span> <span class="nt">-l</span> libbid.dylib 
<span class="go">-rwxr-xr-x  1 dlewis  staff  5019408 Feb 11 14:39 libbid.dylib
</span><span class="gp">~/Downloads/IntelRDFPMathLib20U2/LIBRARY $</span><span class="w"> </span>file libbid.dylib 
<span class="go">libbid.dylib: Mach-O 64-bit dynamically linked shared library arm64
</span></code></pre></div></div>

<h1 id="building-the-ibkr-c-api-client-library">Building the IBKR C++ API Client Library</h1>

<p>Back to building the IBKR C++ API client library. The libbid.dylib/.so needs to be copied to <code class="language-plaintext highlighter-rouge">IBJts/source/cppclient/client/lib</code>:</p>

<div class="language-console highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="gp">~/Downloads/twsapi_macunix.1034.02/IBJts/source/cppclient/client $</span><span class="w"> </span><span class="nb">mkdir </span>lib
<span class="gp">~/Downloads/twsapi_macunix.1034.02/IBJts/source/cppclient/client $</span><span class="w"> </span><span class="nb">cp</span> ~/Downloads/IntelRDFPMathLib20U2/LIBRARY/libbid.dylib lib/
<span class="gp">~/Downloads/twsapi_macunix.1034.02/IBJts/source/cppclient/client $</span><span class="w"> </span><span class="nb">ls</span> <span class="nt">-l</span> lib/libbid.dylib
<span class="go">-rwxr-xr-x  1 dlewis  staff  5019408 Feb 11 14:44 lib/libbid.dylib
</span></code></pre></div></div>

<p>I had many issues building with CMake in <code class="language-plaintext highlighter-rouge">IBJts/source/cppclient/client</code>, and the sources include a generic <code class="language-plaintext highlighter-rouge">makefile</code>, which seemed much easier to modify, so I went that route instead. Here’s my slightly modified <code class="language-plaintext highlighter-rouge">makefile</code>:</p>

<div class="language-makefile highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nv">CXX</span><span class="o">=</span>g++
<span class="nv">CXXFLAGS</span><span class="o">=</span><span class="nt">-pthread</span> <span class="nt">-Wall</span> <span class="nt">-Wno-switch</span> <span class="nt">-Wno-unused-function</span> <span class="nt">-std</span><span class="o">=</span>c++11 <span class="nt">-shared</span> <span class="nt">-fPIC</span>
<span class="nv">ROOT_DIR</span><span class="o">=</span>.
<span class="nv">BASE_SRC_DIR</span><span class="o">=</span><span class="p">${</span>ROOT_DIR<span class="p">}</span>
<span class="nv">INCLUDES</span><span class="o">=</span><span class="nt">-I</span><span class="p">${</span>ROOT_DIR<span class="p">}</span>
<span class="nv">LIB_DIR</span><span class="o">=</span>lib
<span class="nv">LIB_NAME</span><span class="o">=</span>bid
<span class="nv">TARGET</span><span class="o">=</span>libTwsSocketClient.dylib <span class="c"># or .so if building on Linux</span>

<span class="nl">$(TARGET)</span><span class="o">:</span>
	<span class="p">$(</span>CXX<span class="p">)</span> <span class="p">$(</span>CXXFLAGS<span class="p">)</span> <span class="p">$(</span>INCLUDES<span class="p">)</span> <span class="p">$(</span>BASE_SRC_DIR<span class="p">)</span>/<span class="k">*</span>.cpp <span class="nt">-L</span><span class="p">$(</span>LIB_DIR<span class="p">)</span> <span class="nt">-l</span><span class="p">$(</span>LIB_NAME<span class="p">)</span> <span class="nt">-o</span><span class="p">$(</span>TARGET<span class="p">)</span>

<span class="nl">clean</span><span class="o">:</span>
	<span class="nb">rm</span> <span class="nt">-f</span> <span class="p">$(</span>TARGET<span class="p">)</span> <span class="k">*</span>.o
</code></pre></div></div>

<p>Once that is modified, <code class="language-plaintext highlighter-rouge">make</code> can simply be run without issue:</p>

<div class="language-console highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="gp">~/Downloads/twsapi_macunix.1034.02/IBJts/source/cppclient/client $</span><span class="w"> </span>make
<span class="go">g++ -pthread -Wall -Wno-switch -Wno-unused-function -std=c++11 -shared -fPIC -I. ./*.cpp -Llib -lbid -olibTwsSocketClient.dylib

</span></code></pre></div></div>

<p>In preparation for building the sample client IBKR provides, I copied <code class="language-plaintext highlighter-rouge">libTwsSocketClient.dylib</code> to <code class="language-plaintext highlighter-rouge">lib</code>, where <code class="language-plaintext highlighter-rouge">libbid.dylib</code> was previously copied:</p>

<div class="language-console highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="gp">~/Downloads/twsapi_macunix.1034.02/IBJts/source/cppclient/client $</span><span class="w"> </span><span class="nb">cp </span>libTwsSocketClient.dylib lib/
<span class="gp">~/Downloads/twsapi_macunix.1034.02/IBJts/source/cppclient/client $</span><span class="w"> </span><span class="nb">ls</span> <span class="nt">-l</span> lib/
<span class="go">total 12200
-rwxr-xr-x  1 dlewis  staff  1221496 Feb 11 14:57 libTwsSocketClient.dylib
-rwxr-xr-x  1 dlewis  staff  5019408 Feb 11 14:44 libbid.dylib
</span></code></pre></div></div>

<h1 id="building-the-c-sample-application">Building the C++ Sample Application</h1>

<p>Now, the client can be built in <code class="language-plaintext highlighter-rouge">IBJts/samples/Cpp/TestCppClient</code>. First, some similar <code class="language-plaintext highlighter-rouge">makefile</code> modifications need to be made, and again I avoided using CMake because of problems encountered with the build process. It can be made to work, but it is much simpler to just use the provided <code class="language-plaintext highlighter-rouge">makefile</code>. Here is my diff of the makefile after modifications:</p>

<div class="language-diff highlighter-rouge"><div class="highlight"><pre class="highlight"><code>~/Downloads/twsapi_macunix.1034.02/IBJts/samples/Cpp/TestCppClient $ diff -u makefile makefile.orig
<span class="gd">--- makefile	2025-02-11 15:03:37
</span><span class="gi">+++ makefile.orig	2025-02-11 15:02:31
</span><span class="p">@@ -4,10 +4,10 @@</span>
 BASE_SRC_DIR=${ROOT_DIR}/client
 INCLUDES=-I${BASE_SRC_DIR} -I${ROOT_DIR}
 SOURCE_DIR=${BASE_SRC_DIR}
<span class="gd">-SOURCE_LIB=TwsSocketClient
</span><span class="gi">+SOURCE_LIB=libTwsSocketClient.so
</span> LIB_DIR=$(SOURCE_DIR)/lib
 LIB_NAME_A=libbid.a
<span class="gd">-LIB_NAME=bid
</span><span class="gi">+LIB_NAME_SO=libbid.so
</span> TARGET=TestCppClient
 
 $(TARGET)Static:
<span class="p">@@ -17,7 +17,7 @@</span>
 	$(CXX) $(CXXFLAGS) $(INCLUDES) $(BASE_SRC_DIR)/*.cpp ./*.cpp -L$(LIB_DIR) -l:$(LIB_NAME_SO) -o$(TARGET)
 
 $(TARGET)Dynamic:
<span class="gd">-	$(CXX) $(CXXFLAGS) $(INCLUDES)  ./*.cpp -L$(LIB_DIR) -l$(LIB_NAME) -L$(SOURCE_DIR) -l$(SOURCE_LIB) -o$(TARGET)Dynamic
</span><span class="gi">+	$(CXX) $(CXXFLAGS) $(INCLUDES)  ./*.cpp -L$(LIB_DIR) -l:$(LIB_NAME_SO) -L$(SOURCE_DIR) -l:$(SOURCE_LIB) -o$(TARGET)Dynamic
</span> 
 run$(TARGET)Static:
 	./$(TARGET)Static
</code></pre></div></div>

<p>The <code class="language-plaintext highlighter-rouge">makefile</code> can be easily modified or used to generate static application binaries, but in this case I opted for dynamic linking. If you go the statically linked route, then you can also easily make the necessary modifications to the C++ API library <code class="language-plaintext highlighter-rouge">makefile</code> in <code class="language-plaintext highlighter-rouge">IBJts/source/cppclient/client</code>.</p>

<p>Finally, we can build and run the <code class="language-plaintext highlighter-rouge">TestCppClientDynamic</code>:</p>

<div class="language-console highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="gp">~/Downloads/twsapi_macunix.1034.02/IBJts/samples/Cpp/TestCppClient $</span><span class="w"> </span>make TestCppClientDynamic
<span class="go">g++ -pthread -Wall -Wno-switch -Wpedantic -Wno-unused-function -std=c++11 -I../../../source/cppclient/client -I../../../source/cppclient  ./*.cpp -L../../../source/cppclient/client/lib -lbid -L../../../source/cppclient/client -lTwsSocketClient -oTestCppClientDynamic
</span><span class="gp">~/Downloads/twsapi_macunix.1034.02/IBJts/samples/Cpp/TestCppClient $</span><span class="w"> </span><span class="nv">DYLD_LIBRARY_PATH</span><span class="o">=</span>../../../source/cppclient/client/lib ./TestCppClientDynamic 
<span class="go">Start of C++ Socket Client Test 0
Attempt 1 of 50
Connecting to 127.0.0.1:7496 clientId:0
</span><span class="gp">Error. Id: -1, Time: Tue Feb 11 15:08:17 2025, Code: 502, Msg: Couldn't connect to TWS. Confirm that "Enable ActiveX and Socket Clients" is enabled and connection port is the same as "Socket Port" on the TWS "Edit-&gt;</span>Global Configuration...-&gt;API-&gt;Settings<span class="s2">" menu. Live Trading ports: TWS: 7496; IB Gateway: 4001. Simulated Trading ports for new installations of version 954.1 or newer:  TWS: 7497; IB Gateway: 4002
</span><span class="go">Cannot connect to 127.0.0.1:7496 clientId:0
Sleeping 10 seconds before next attempt
</span></code></pre></div></div>

<p>Success! Hopefully, this will be helpful to someone looking to build the IBKR C++ API library on MacOS or Linux/aarch64.</p>]]></content><author><name>Derek Lewis</name></author><category term="c++" /><category term="unix" /><category term="linux" /><category term="macos" /><summary type="html"><![CDATA[Recently, I wanted to use the C++ API client library that Interactive Brokers provides and experiment with some algorithmic trading and monitoring of my positions. I had hoped there would be some precompiled binaries already for MacOS &amp; Linux/aarch64; however, I did not have luck finding any. That is not a problem since Interactive Brokers provides the source to the client libraries under a license that allows for non-commercial use. The source can be downloaded here after agreeing to the license.]]></summary></entry><entry><title type="html">PyTorch CUDA Extensions</title><link href="https://dlewis.io/pytorch-cuda-extensions/" rel="alternate" type="text/html" title="PyTorch CUDA Extensions" /><published>2024-08-08T00:00:00+00:00</published><updated>2024-08-08T00:00:00+00:00</updated><id>https://dlewis.io/pytorch-cuda-extensions</id><content type="html" xml:base="https://dlewis.io/pytorch-cuda-extensions/"><![CDATA[<h1 id="optimizing-performance-with-pytorch-cudac-extensions-a-deep-dive">Optimizing Performance with PyTorch CUDA/C++ Extensions: A Deep Dive</h1>

<p>Understanding how CUDA and efficient kernels work under the hood enables us to make informed decisions about model architecture, optimize critical operations, and squeeze maximum performance from our GPU hardware. One powerful technique is using PyTorch CUDA extensions to leverage custom CUDA kernels. In this series of posts, we’ll explore how to create, profile, and optimize these extensions, using a simple, straightforward matrix multiplication example as a guide.</p>

<p><img src="/assets/images/matrixmultiplication.webp" alt="Matrix Multiplication" /></p>

<h2 id="the-power-of-custom-cuda-kernels">The Power of Custom CUDA Kernels</h2>

<p>PyTorch allows us to write custom CUDA kernels and integrate them seamlessly into our Python code. This capability is particularly useful when we need to optimize specific operations that are critical to an application’s performance.</p>

<h3 id="our-example-matrix-multiplication">Our Example: Matrix Multiplication</h3>

<p>We’ll use matrix multiplication as our example. While PyTorch already has highly optimized matrix multiplication routines, implementing our own allows us to understand the process and potentially optimize for specific use cases. For the first post in this series, we’ll start with three basic matrix multiplication kernels of varying implementation quality: the Element-wise Kernel, the Row-wise Kernel, and the Column-wise Kernel. Each of these approaches distributes the computation across GPU threads in a different manner, offering varying levels of parallelism and memory access patterns.</p>

<h2 id="creating-the-cuda-extension">Creating the CUDA Extension</h2>

<p>Let’s break down the key components of our CUDA extension:</p>

<h3 id="the-cuda-kernels-matrixmultiplycu">The CUDA Kernels (matrixMultiply.cu)</h3>

<div class="language-cuda highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1">// element-wise matrix multiplication</span>
<span class="k">__global__</span> <span class="kt">void</span> <span class="nf">matrixMulKernel</span><span class="p">(</span><span class="kt">float</span> <span class="o">*</span><span class="n">m</span><span class="p">,</span> <span class="kt">float</span> <span class="o">*</span><span class="n">n</span><span class="p">,</span> <span class="kt">float</span> <span class="o">*</span><span class="n">p</span><span class="p">,</span> <span class="kt">int</span> <span class="n">size</span><span class="p">)</span>
<span class="p">{</span>
    <span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">y</span> <span class="o">*</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">y</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">y</span><span class="p">;</span>
    <span class="kt">int</span> <span class="n">j</span> <span class="o">=</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">*</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>
    <span class="k">if</span> <span class="p">(</span><span class="n">i</span> <span class="o">&lt;</span> <span class="n">size</span> <span class="o">&amp;&amp;</span> <span class="n">j</span> <span class="o">&lt;</span> <span class="n">size</span><span class="p">)</span>
    <span class="p">{</span>
        <span class="kt">float</span> <span class="n">pValue</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
        <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">k</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">k</span> <span class="o">&lt;</span> <span class="n">size</span><span class="p">;</span> <span class="o">++</span><span class="n">k</span><span class="p">)</span>
        <span class="p">{</span>
            <span class="n">pValue</span> <span class="o">+=</span> <span class="n">m</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">size</span> <span class="o">+</span> <span class="n">k</span><span class="p">]</span> <span class="o">*</span> <span class="n">n</span><span class="p">[</span><span class="n">k</span> <span class="o">*</span> <span class="n">size</span> <span class="o">+</span> <span class="n">j</span><span class="p">];</span>
        <span class="p">}</span>
        <span class="n">p</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">size</span> <span class="o">+</span> <span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="n">pValue</span><span class="p">;</span>
    <span class="p">}</span>
<span class="p">}</span>

<span class="c1">// row-wise matrix multiplication</span>
<span class="c1">// thread 0:</span>
<span class="c1">// p_0,0 p[0]</span>
<span class="c1">// p_0,1 p[1]</span>
<span class="c1">// ..</span>
<span class="c1">// thread 1:</span>
<span class="c1">// p_1,0 p[4]</span>
<span class="c1">// p_1,1 p[5]</span>
<span class="c1">// ..</span>
<span class="k">__global__</span> <span class="kt">void</span> <span class="nf">matrixMulKernelRow</span><span class="p">(</span><span class="kt">float</span> <span class="o">*</span><span class="n">m</span><span class="p">,</span> <span class="kt">float</span> <span class="o">*</span><span class="n">n</span><span class="p">,</span> <span class="kt">float</span> <span class="o">*</span><span class="n">p</span><span class="p">,</span> <span class="kt">int</span> <span class="n">size</span><span class="p">)</span>
<span class="p">{</span>
    <span class="kt">int</span> <span class="n">row</span> <span class="o">=</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">*</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>

    <span class="k">if</span> <span class="p">(</span><span class="n">row</span> <span class="o">&lt;</span> <span class="n">size</span><span class="p">)</span>
    <span class="p">{</span>
        <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">col</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">col</span> <span class="o">&lt;</span> <span class="n">size</span><span class="p">;</span> <span class="o">++</span><span class="n">col</span><span class="p">)</span>
        <span class="p">{</span>
            <span class="kt">float</span> <span class="n">pValue</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
            <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">size</span><span class="p">;</span> <span class="o">++</span><span class="n">i</span><span class="p">)</span>
            <span class="p">{</span>
                <span class="n">pValue</span> <span class="o">+=</span> <span class="n">m</span><span class="p">[</span><span class="n">row</span> <span class="o">*</span> <span class="n">size</span> <span class="o">+</span> <span class="n">i</span><span class="p">]</span> <span class="o">*</span> <span class="n">n</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">size</span> <span class="o">+</span> <span class="n">col</span><span class="p">];</span>
            <span class="p">}</span>
            <span class="n">p</span><span class="p">[</span><span class="n">row</span> <span class="o">*</span> <span class="n">size</span> <span class="o">+</span> <span class="n">col</span><span class="p">]</span> <span class="o">=</span> <span class="n">pValue</span><span class="p">;</span>
        <span class="p">}</span>
    <span class="p">}</span>
<span class="p">}</span>

<span class="c1">// column-wise matrix multiplication</span>
<span class="c1">// thread 0:</span>
<span class="c1">// p_0,0 p[0]</span>
<span class="c1">// p_1,0 p[4]</span>
<span class="c1">// ..</span>
<span class="c1">// thread 1:</span>
<span class="c1">// p_0,1 p[1]</span>
<span class="c1">// p_1,1 p[5]</span>
<span class="c1">// ..</span>
<span class="k">__global__</span> <span class="kt">void</span> <span class="nf">matrixMulKernelCol</span><span class="p">(</span><span class="kt">float</span> <span class="o">*</span><span class="n">m</span><span class="p">,</span> <span class="kt">float</span> <span class="o">*</span><span class="n">n</span><span class="p">,</span> <span class="kt">float</span> <span class="o">*</span><span class="n">p</span><span class="p">,</span> <span class="kt">int</span> <span class="n">size</span><span class="p">)</span>
<span class="p">{</span>
    <span class="kt">int</span> <span class="n">col</span> <span class="o">=</span> <span class="n">blockDim</span><span class="p">.</span><span class="n">x</span> <span class="o">*</span> <span class="n">blockIdx</span><span class="p">.</span><span class="n">x</span> <span class="o">+</span> <span class="n">threadIdx</span><span class="p">.</span><span class="n">x</span><span class="p">;</span>

    <span class="k">if</span> <span class="p">(</span><span class="n">col</span> <span class="o">&lt;</span> <span class="n">size</span><span class="p">)</span>
    <span class="p">{</span>
        <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">row</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">row</span> <span class="o">&lt;</span> <span class="n">size</span><span class="p">;</span> <span class="o">++</span><span class="n">row</span><span class="p">)</span>
        <span class="p">{</span>
            <span class="kt">float</span> <span class="n">pValue</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span>
            <span class="k">for</span> <span class="p">(</span><span class="kt">int</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span><span class="p">;</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="n">size</span><span class="p">;</span> <span class="o">++</span><span class="n">i</span><span class="p">)</span>
            <span class="p">{</span>
                <span class="n">pValue</span> <span class="o">+=</span> <span class="n">m</span><span class="p">[</span><span class="n">row</span> <span class="o">*</span> <span class="n">size</span> <span class="o">+</span> <span class="n">i</span><span class="p">]</span> <span class="o">*</span> <span class="n">n</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">size</span> <span class="o">+</span> <span class="n">col</span><span class="p">];</span>
            <span class="p">}</span>
            <span class="n">p</span><span class="p">[</span><span class="n">row</span> <span class="o">*</span> <span class="n">size</span> <span class="o">+</span> <span class="n">col</span><span class="p">]</span> <span class="o">=</span> <span class="n">pValue</span><span class="p">;</span>
        <span class="p">}</span>
    <span class="p">}</span>
<span class="p">}</span>
</code></pre></div></div>

<p>The Element-wise kernel, while simple, delivers true parallelism by assigning one thread to each output element and exploiting the GPU’s massive thread capacity. The Row-wise and Column-wise kernels, despite appearing to offer an alternative parallelization strategy, actually introduce serialization: each thread sequentially computes an entire row or column, so only <code class="language-plaintext highlighter-rouge">size</code> threads do useful work. For a 1,000 × 1,000 matrix that means 1,000 threads each performing a million multiply-accumulates, versus a million threads each performing 1,000, which negates much of the benefit of running on a GPU. Even the Element-wise kernel is far from optimal: it lacks shared memory usage, efficient memory access patterns, and coalescing, all critical for peak GPU performance. In a future post we’ll profile these kernels to quantify their performance differences.</p>
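
<p>To make the thread-count disparity concrete, here is a hypothetical set of launch configurations sized to what each kernel can actually exploit. This is an illustrative sketch, not the repository’s code; the wrapper shown later reuses a single two-dimensional configuration for all three kernels for simplicity.</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code>// Illustrative launch configurations (a sketch, not the wrapper below, which
// reuses one 2-D configuration for every kernel).
__global__ void matrixMulKernel(float *m, float *n, float *p, int size);
__global__ void matrixMulKernelRow(float *m, float *n, float *p, int size);

void launchExamples(float *m, float *n, float *p, int size)
{
    // Element-wise: one thread per output element -&gt; a 2-D grid of size*size threads.
    dim3 blockSize2D(16, 16);
    dim3 gridSize2D((size + blockSize2D.x - 1) / blockSize2D.x,
                    (size + blockSize2D.y - 1) / blockSize2D.y);
    matrixMulKernel&lt;&lt;&lt;gridSize2D, blockSize2D&gt;&gt;&gt;(m, n, p, size);

    // Row-wise: one thread per row -&gt; a 1-D grid of only `size` threads, each of
    // which loops over size*size multiply-accumulates on its own.
    int blockSize1D = 256;
    int gridSize1D = (size + blockSize1D - 1) / blockSize1D;
    matrixMulKernelRow&lt;&lt;&lt;gridSize1D, blockSize1D&gt;&gt;&gt;(m, n, p, size);
}
</code></pre></div></div>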

<h3 id="the-c-wrapper-torchmatrixmultiplycu">The C++ Wrapper (torchMatrixMultiply.cu)</h3>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">__global__</span> <span class="kt">void</span> <span class="nf">matrixMulKernel</span><span class="p">(</span><span class="kt">float</span> <span class="o">*</span><span class="n">m</span><span class="p">,</span> <span class="kt">float</span> <span class="o">*</span><span class="n">n</span><span class="p">,</span> <span class="kt">float</span> <span class="o">*</span><span class="n">p</span><span class="p">,</span> <span class="kt">int</span> <span class="n">size</span><span class="p">);</span>
<span class="n">__global__</span> <span class="kt">void</span> <span class="nf">matrixMulKernelRow</span><span class="p">(</span><span class="kt">float</span> <span class="o">*</span><span class="n">m</span><span class="p">,</span> <span class="kt">float</span> <span class="o">*</span><span class="n">n</span><span class="p">,</span> <span class="kt">float</span> <span class="o">*</span><span class="n">p</span><span class="p">,</span> <span class="kt">int</span> <span class="n">size</span><span class="p">);</span>
<span class="n">__global__</span> <span class="kt">void</span> <span class="nf">matrixMulKernelCol</span><span class="p">(</span><span class="kt">float</span> <span class="o">*</span><span class="n">m</span><span class="p">,</span> <span class="kt">float</span> <span class="o">*</span><span class="n">n</span><span class="p">,</span> <span class="kt">float</span> <span class="o">*</span><span class="n">p</span><span class="p">,</span> <span class="kt">int</span> <span class="n">size</span><span class="p">);</span>

<span class="k">using</span> <span class="n">KernelFunc</span> <span class="o">=</span> <span class="kt">void</span> <span class="p">(</span><span class="o">*</span><span class="p">)(</span><span class="kt">float</span> <span class="o">*</span><span class="p">,</span> <span class="kt">float</span> <span class="o">*</span><span class="p">,</span> <span class="kt">float</span> <span class="o">*</span><span class="p">,</span> <span class="kt">int</span><span class="p">);</span>

<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="nf">cuda_matrixMultiply</span><span class="p">(</span><span class="k">const</span> <span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="o">&amp;</span><span class="n">a</span><span class="p">,</span> <span class="k">const</span> <span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="o">&amp;</span><span class="n">b</span><span class="p">,</span> <span class="n">KernelFunc</span> <span class="n">kernel</span><span class="p">)</span>
<span class="p">{</span>
    <span class="n">TORCH_CHECK</span><span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">sizes</span><span class="p">()</span> <span class="o">==</span> <span class="n">b</span><span class="p">.</span><span class="n">sizes</span><span class="p">());</span>
    <span class="n">TORCH_CHECK</span><span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">dtype</span><span class="p">()</span> <span class="o">==</span> <span class="n">torch</span><span class="o">::</span><span class="n">kFloat</span><span class="p">);</span>
    <span class="n">TORCH_CHECK</span><span class="p">(</span><span class="n">b</span><span class="p">.</span><span class="n">dtype</span><span class="p">()</span> <span class="o">==</span> <span class="n">torch</span><span class="o">::</span><span class="n">kFloat</span><span class="p">);</span>
    <span class="n">TORCH_INTERNAL_ASSERT</span><span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">device</span><span class="p">().</span><span class="n">type</span><span class="p">()</span> <span class="o">==</span> <span class="n">torch</span><span class="o">::</span><span class="n">DeviceType</span><span class="o">::</span><span class="n">CUDA</span><span class="p">);</span>
    <span class="n">TORCH_INTERNAL_ASSERT</span><span class="p">(</span><span class="n">b</span><span class="p">.</span><span class="n">device</span><span class="p">().</span><span class="n">type</span><span class="p">()</span> <span class="o">==</span> <span class="n">torch</span><span class="o">::</span><span class="n">DeviceType</span><span class="o">::</span><span class="n">CUDA</span><span class="p">);</span>

    <span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">a_contiguous</span><span class="p">{</span><span class="n">a</span><span class="p">.</span><span class="n">contiguous</span><span class="p">()};</span>
    <span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">b_contiguous</span><span class="p">{</span><span class="n">b</span><span class="p">.</span><span class="n">contiguous</span><span class="p">()};</span>
    <span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">result</span><span class="p">{</span><span class="n">torch</span><span class="o">::</span><span class="n">empty</span><span class="p">(</span><span class="n">a_contiguous</span><span class="p">.</span><span class="n">sizes</span><span class="p">(),</span> <span class="n">a_contiguous</span><span class="p">.</span><span class="n">options</span><span class="p">())};</span>

    <span class="kt">float</span> <span class="o">*</span><span class="n">a_ptr</span> <span class="o">=</span> <span class="n">a_contiguous</span><span class="p">.</span><span class="n">data_ptr</span><span class="o">&lt;</span><span class="kt">float</span><span class="o">&gt;</span><span class="p">();</span>
    <span class="kt">float</span> <span class="o">*</span><span class="n">b_ptr</span> <span class="o">=</span> <span class="n">b_contiguous</span><span class="p">.</span><span class="n">data_ptr</span><span class="o">&lt;</span><span class="kt">float</span><span class="o">&gt;</span><span class="p">();</span>
    <span class="kt">float</span> <span class="o">*</span><span class="n">result_ptr</span> <span class="o">=</span> <span class="n">result</span><span class="p">.</span><span class="n">data_ptr</span><span class="o">&lt;</span><span class="kt">float</span><span class="o">&gt;</span><span class="p">();</span>

    <span class="c1">// Assumes square matrices and we cast to int for simplicity</span>
    <span class="c1">// and compatibility with our existing kernel code. In practice,</span>
    <span class="c1">// we would need to handle non-square matrices and use an unsigned long</span>
    <span class="c1">// to match PyTorch's tensor sizes.</span>
    <span class="kt">int</span> <span class="n">dim</span><span class="p">{</span><span class="k">static_cast</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;</span><span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">sizes</span><span class="p">()[</span><span class="mi">0</span><span class="p">])};</span>

    <span class="n">dim3</span> <span class="nf">blockSize</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="mi">16</span><span class="p">);</span>
    <span class="n">dim3</span> <span class="nf">gridSize</span><span class="p">((</span><span class="n">dim</span> <span class="o">+</span> <span class="n">blockSize</span><span class="p">.</span><span class="n">x</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="n">blockSize</span><span class="p">.</span><span class="n">x</span><span class="p">,</span> <span class="p">(</span><span class="n">dim</span> <span class="o">+</span> <span class="n">blockSize</span><span class="p">.</span><span class="n">y</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="n">blockSize</span><span class="p">.</span><span class="n">y</span><span class="p">);</span>

    <span class="n">kernel</span><span class="o">&lt;&lt;&lt;</span><span class="n">gridSize</span><span class="p">,</span> <span class="n">blockSize</span><span class="o">&gt;&gt;&gt;</span><span class="p">(</span><span class="n">a_ptr</span><span class="p">,</span> <span class="n">b_ptr</span><span class="p">,</span> <span class="n">result_ptr</span><span class="p">,</span> <span class="n">dim</span><span class="p">);</span>
    <span class="n">checkCudaError</span><span class="p">(</span><span class="n">cudaGetLastError</span><span class="p">(),</span> <span class="s">"Kernel launch failed"</span><span class="p">);</span>

    <span class="k">return</span> <span class="n">result</span><span class="p">;</span>
<span class="p">}</span>

<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="nf">matrixMultiply</span><span class="p">(</span><span class="k">const</span> <span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="o">&amp;</span><span class="n">a</span><span class="p">,</span> <span class="k">const</span> <span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="o">&amp;</span><span class="n">b</span><span class="p">,</span> <span class="k">const</span> <span class="n">std</span><span class="o">::</span><span class="n">optional</span><span class="o">&lt;</span><span class="n">std</span><span class="o">::</span><span class="n">string</span><span class="o">&gt;</span> <span class="o">&amp;</span><span class="n">kernel_type</span><span class="p">)</span>
<span class="p">{</span>
    <span class="k">if</span> <span class="p">(</span><span class="n">kernel_type</span><span class="p">.</span><span class="n">has_value</span><span class="p">())</span>
    <span class="p">{</span>
        <span class="k">if</span> <span class="p">(</span><span class="n">kernel_type</span> <span class="o">==</span> <span class="s">"row"</span><span class="p">)</span>
        <span class="p">{</span>
            <span class="k">return</span> <span class="n">cuda_matrixMultiply</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">matrixMulKernelRow</span><span class="p">);</span>
        <span class="p">}</span>
        <span class="k">else</span> <span class="nf">if</span> <span class="p">(</span><span class="n">kernel_type</span> <span class="o">==</span> <span class="s">"col"</span><span class="p">)</span>
        <span class="p">{</span>
            <span class="k">return</span> <span class="n">cuda_matrixMultiply</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">matrixMulKernelCol</span><span class="p">);</span>
        <span class="p">}</span>
        <span class="k">else</span>
        <span class="p">{</span>
            <span class="k">throw</span> <span class="n">std</span><span class="o">::</span><span class="n">invalid_argument</span><span class="p">(</span><span class="s">"Invalid kernel type"</span><span class="p">);</span>
        <span class="p">}</span>
    <span class="p">}</span>
    <span class="k">else</span>
    <span class="p">{</span>
        <span class="k">return</span> <span class="n">cuda_matrixMultiply</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">matrixMulKernel</span><span class="p">);</span>
    <span class="p">}</span>
<span class="p">}</span>

<span class="n">TORCH_LIBRARY</span><span class="p">(</span><span class="n">myextension</span><span class="p">,</span> <span class="n">m</span><span class="p">)</span>
<span class="p">{</span>
    <span class="n">m</span><span class="p">.</span><span class="n">def</span><span class="p">(</span><span class="s">"mymatrixmultiply(Tensor a, Tensor b, str? kernel_type = None) -&gt; Tensor"</span><span class="p">);</span>
<span class="p">}</span>

<span class="n">TORCH_LIBRARY_IMPL</span><span class="p">(</span><span class="n">myextension</span><span class="p">,</span> <span class="n">CUDA</span><span class="p">,</span> <span class="n">m</span><span class="p">)</span>
<span class="p">{</span>
    <span class="n">m</span><span class="p">.</span><span class="n">impl</span><span class="p">(</span><span class="s">"mymatrixmultiply"</span><span class="p">,</span> <span class="n">TORCH_FN</span><span class="p">(</span><span class="n">matrixMultiply</span><span class="p">));</span>
<span class="p">}</span>
</code></pre></div></div>

<p>This wrapper validates the input tensors, extracts raw CUDA pointers from them, launches the selected kernel, and registers the operator with PyTorch’s dispatcher. An optional string parameter and a function pointer select the kernel at runtime, letting us switch between the Element-wise, Row-wise, and Column-wise kernels from Python.</p>
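
<p>For completeness, the <code class="language-plaintext highlighter-rouge">checkCudaError</code> helper called after the kernel launch is a thin wrapper around the CUDA runtime’s status codes. A minimal version might look like the sketch below; the repository’s actual helper may differ.</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code>#include &lt;cuda_runtime.h&gt;
#include &lt;stdexcept&gt;
#include &lt;string&gt;

// Minimal error-checking sketch: throw with a descriptive message whenever a
// CUDA runtime call (or a kernel launch checked via cudaGetLastError) fails.
inline void checkCudaError(cudaError_t err, const std::string &amp;msg)
{
    if (err != cudaSuccess)
    {
        throw std::runtime_error(msg + ": " + cudaGetErrorString(err));
    }
}
</code></pre></div></div>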

<h2 id="building-the-extension">Building the Extension</h2>

<p>Building a CUDA extension requires careful attention to linking and include paths. Our Makefile handles this by specifying the compiler, compiler flags, and paths to the necessary libraries and include directories. It also ensures the build directory exists, compiles the CUDA source files into object files, and links them into a shared library that will be loaded by PyTorch as our extension:</p>

<div class="language-makefile highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="nv">NVCC</span> <span class="o">:=</span> nvcc
<span class="nv">NVCC_FLAGS</span> <span class="o">:=</span> <span class="nt">-g</span> <span class="nt">-G</span>
<span class="nv">BUILD_DIR</span> <span class="o">:=</span> build
<span class="nv">PYTHON_VER</span> <span class="o">:=</span> 3.12
<span class="nv">MLENV_DIR</span> <span class="o">:=</span> ~/mlenv/lib/python<span class="p">$(</span>PYTHON_VER<span class="p">)</span>/site-packages/torch
<span class="nv">TORCH_INCLUDE</span> <span class="o">:=</span> <span class="p">$(</span>MLENV_DIR<span class="p">)</span>/include
<span class="nv">TORCH_LIB</span> <span class="o">:=</span> <span class="p">$(</span>MLENV_DIR<span class="p">)</span>/lib

<span class="c"># Find all .cu files in the current directory
</span><span class="nv">SRCS</span> <span class="o">:=</span> <span class="p">$(</span>filter-out torchMatrixMultiply.cu, <span class="p">$(</span>wildcard <span class="k">*</span>.cu<span class="p">))</span>
<span class="c"># Generate corresponding object file names in the build directory
</span><span class="nv">OBJS</span> <span class="o">:=</span> <span class="p">$(</span>patsubst %.cu,<span class="p">$(</span>BUILD_DIR<span class="p">)</span>/%.o,<span class="p">$(</span>SRCS<span class="p">))</span>
<span class="c"># Generate executable names from the object files
</span><span class="nv">EXECS</span> <span class="o">:=</span> <span class="p">$(</span>patsubst <span class="p">$(</span>BUILD_DIR<span class="p">)</span>/%.o,<span class="p">$(</span>BUILD_DIR<span class="p">)</span>/%,<span class="p">$(</span>OBJS<span class="p">))</span>

<span class="c"># Ensure the build directory exists
</span><span class="nf">$(</span><span class="nb">shell</span> <span class="nb">mkdir</span> <span class="nt">-p</span> <span class="p">$(</span>BUILD_DIR<span class="p">))</span>

<span class="nl">.PHONY</span><span class="o">:</span> <span class="nf">all clean</span>

<span class="nl">all</span><span class="o">:</span> <span class="nf">$(EXECS) $(BUILD_DIR)/torchMatrixMultiply.so</span>

<span class="c"># Rule to build executables from object files
</span><span class="nl">$(BUILD_DIR)/%</span><span class="o">:</span> <span class="nf">$(BUILD_DIR)/%.o</span>
	<span class="p">$(</span>NVCC<span class="p">)</span> <span class="p">$(</span>NVCC_FLAGS<span class="p">)</span> <span class="nv">$&lt;</span> <span class="nt">-o</span> <span class="nv">$@</span>

<span class="c"># Rule to compile .cu files into object files
</span><span class="nl">$(BUILD_DIR)/%.o</span><span class="o">:</span> <span class="nf">%.cu</span>
	<span class="p">$(</span>NVCC<span class="p">)</span> <span class="p">$(</span>NVCC_FLAGS<span class="p">)</span> <span class="nt">--compiler-options</span> <span class="s1">'-fPIC'</span> <span class="nt">-c</span> <span class="nv">$&lt;</span> <span class="nt">-o</span> <span class="nv">$@</span>

<span class="nl">.SECONDARY</span><span class="o">:</span> <span class="nf">$(OBJS)</span>

<span class="nl">$(BUILD_DIR)/torchMatrixMultiply.so</span><span class="o">:</span> <span class="nf">torchMatrixMultiply.cu</span>
	<span class="p">$(</span>NVCC<span class="p">)</span> <span class="p">$(</span>NVCC_FLAGS<span class="p">)</span> <span class="nt">-shared</span> <span class="nt">--compiler-options</span> <span class="s1">'-fPIC'</span> <span class="nt">-L</span> ~/mlenv/lib/python3.12/site-packages/torch/lib <span class="se">\</span>
	<span class="nt">-lc10</span> <span class="nt">-ltorch_cpu</span> <span class="nt">-ltorch</span> <span class="nt">-ltorch_python</span> <span class="nt">-lc10_cuda</span> <span class="nt">-ltorch_cuda</span> <span class="se">\</span>
	<span class="nt">-isystem</span> <span class="p">$(</span>TORCH_INCLUDE<span class="p">)</span>/torch/csrc/api/include <span class="se">\</span>
	<span class="nt">-isystem</span> <span class="p">$(</span>TORCH_INCLUDE<span class="p">)</span> <span class="se">\</span>
	torchMatrixMultiply.cu <span class="p">$(</span>BUILD_DIR<span class="p">)</span>/matrixMultiply.o <span class="nt">-o</span> <span class="p">$(</span>BUILD_DIR<span class="p">)</span>/torchMatrixMultiply.so <span class="se">\</span>
	<span class="nt">-DTORCH_API_INCLUDE_EXTENSION_H</span> <span class="nt">-DPYBIND11_COMPILER_TYPE</span><span class="o">=</span><span class="se">\"</span>_gcc<span class="se">\"</span> <span class="nt">-DPYBIND11_STDLIB</span><span class="o">=</span><span class="se">\"</span>_libstdcpp<span class="se">\"</span> <span class="se">\</span>
	<span class="nt">-DPYBIND11_BUILD_ABI</span><span class="o">=</span><span class="se">\"</span>_cxxabi1011<span class="se">\"</span> <span class="nt">-isystem</span> /usr/include/python<span class="p">$(</span>PYTHON_VER<span class="p">)</span> <span class="nt">-D_GLIBCXX_USE_CXX11_ABI</span><span class="o">=</span>0 <span class="nt">-std</span><span class="o">=</span>c++17

<span class="nl">clean</span><span class="o">:</span>
	<span class="nb">rm</span> <span class="nt">-rf</span> <span class="p">$(</span>BUILD_DIR<span class="p">)</span>
</code></pre></div></div>

<p>Key points:</p>
<ul>
  <li>We link against PyTorch libraries</li>
  <li>We include PyTorch and Python headers</li>
  <li>We set necessary compilation flags for PyTorch and pybind11</li>
</ul>

<h2 id="using-the-extension-in-pytorch">Using the Extension in PyTorch</h2>

<p>Once built, we can use our extension in PyTorch:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="n">torch</span>

<span class="n">torch</span><span class="p">.</span><span class="n">ops</span><span class="p">.</span><span class="nf">load_library</span><span class="p">(</span><span class="sh">"</span><span class="s">build/torchMatrixMultiply.so</span><span class="sh">"</span><span class="p">)</span>

<span class="n">torch</span><span class="p">.</span><span class="nf">manual_seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>
<span class="n">a</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">rand</span><span class="p">(</span><span class="mi">1000</span><span class="p">,</span> <span class="mi">1000</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="sh">"</span><span class="s">cuda</span><span class="sh">"</span><span class="p">)</span>
<span class="nf">print</span><span class="p">(</span><span class="n">a</span><span class="p">)</span>
<span class="n">b</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">rand</span><span class="p">(</span><span class="mi">1000</span><span class="p">,</span> <span class="mi">1000</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="sh">"</span><span class="s">cuda</span><span class="sh">"</span><span class="p">)</span>
<span class="nf">print</span><span class="p">(</span><span class="n">b</span><span class="p">)</span>

<span class="nf">print</span><span class="p">(</span><span class="sh">"</span><span class="s">row kernel:</span><span class="sh">"</span><span class="p">)</span>
<span class="nf">print</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">ops</span><span class="p">.</span><span class="n">myextension</span><span class="p">.</span><span class="nf">mymatrixmultiply</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="sh">"</span><span class="s">row</span><span class="sh">"</span><span class="p">))</span>
<span class="nf">print</span><span class="p">(</span><span class="sh">"</span><span class="s">col kernel:</span><span class="sh">"</span><span class="p">)</span>
<span class="nf">print</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">ops</span><span class="p">.</span><span class="n">myextension</span><span class="p">.</span><span class="nf">mymatrixmultiply</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="sh">"</span><span class="s">col</span><span class="sh">"</span><span class="p">))</span>
<span class="nf">print</span><span class="p">(</span><span class="sh">"</span><span class="s">default kernel:</span><span class="sh">"</span><span class="p">)</span>
<span class="nf">print</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">ops</span><span class="p">.</span><span class="n">myextension</span><span class="p">.</span><span class="nf">mymatrixmultiply</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">))</span>
</code></pre></div></div>
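
<p>The script above only prints the raw outputs. Before profiling, it is worth confirming that each custom kernel agrees with PyTorch’s own <code class="language-plaintext highlighter-rouge">matmul</code>. Staying in C++ for consistency with the wrapper, a hypothetical <code class="language-plaintext highlighter-rouge">validateKernel</code> helper along these lines could be added to <code class="language-plaintext highlighter-rouge">torchMatrixMultiply.cu</code>; the function name and tolerances are illustrative, not part of the repository.</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code>// Sketch: compare a custom kernel's output against PyTorch's reference matmul.
// Assumes the cuda_matrixMultiply wrapper and KernelFunc alias defined above;
// the tolerances are illustrative.
bool validateKernel(const torch::Tensor &amp;a, const torch::Tensor &amp;b, KernelFunc kernel)
{
    torch::Tensor custom = cuda_matrixMultiply(a, b, kernel);
    torch::Tensor reference = torch::matmul(a, b);
    return torch::allclose(custom, reference, /*rtol=*/1e-4, /*atol=*/1e-4);
}
</code></pre></div></div>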

<h2 id="profiling-with-nvidia-compute-command-line-profiler-ncu">Profiling with NVIDIA Compute Command Line Profiler (ncu)</h2>

<p>To optimize our kernel, we need to understand its performance characteristics. NVIDIA’s <code class="language-plaintext highlighter-rouge">ncu</code> tool is useful for this:</p>

<div class="language-bash highlighter-rouge"><div class="highlight"><pre class="highlight"><code>ncu <span class="nt">--kernel-name</span> regex:<span class="s2">"matrixMulKernel.*"</span> python torchMatrixMultiply.py
</code></pre></div></div>

<p>This command profiles every kernel whose name matches the regex (all three of ours), providing insights into metrics such as:</p>
<ul>
  <li>SM occupancy</li>
  <li>Memory throughput</li>
  <li>Instruction throughput</li>
</ul>

<p>In the next post, we’ll review the profiling results and apply optimization strategies, such as those below, to improve the kernels’ performance.</p>

<h2 id="optimization-strategies">Optimization Strategies</h2>

<p>Based on the profiling results, we can apply various optimization strategies; a sketch of a tiled, shared-memory kernel follows the list:</p>

<ol>
  <li>
    <p><strong>Shared Memory</strong>: If memory bandwidth is a bottleneck, we can use shared memory to reduce global memory accesses.</p>
  </li>
  <li>
    <p><strong>Loop Unrolling</strong>: This can increase instruction-level parallelism.</p>
  </li>
  <li>
    <p><strong>Tiling</strong>: Dividing the matrices into smaller tiles can improve cache utilization.</p>
  </li>
  <li>
    <p><strong>Vectorization</strong>: Using vector loads and stores can increase memory throughput.</p>
  </li>
  <li>
    <p><strong>Warp-level Primitives</strong>: For certain operations, warp-level primitives can be faster than block-level synchronization.</p>
  </li>
</ol>
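
<p>As an example of the first and third strategies, here is a sketch of a tiled matrix-multiplication kernel that stages blocks of both inputs in shared memory. It keeps the same square-matrix, row-major layout as the kernels above; the tile size and kernel name are assumptions for illustration, not code from the repository.</p>

<div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code>#define TILE_SIZE 16

// Tiled matrix multiplication sketch: each block loads a TILE_SIZE x TILE_SIZE
// tile of m and n into shared memory, so every global load is reused TILE_SIZE
// times instead of once.
__global__ void matrixMulKernelTiled(float *m, float *n, float *p, int size)
{
    __shared__ float mTile[TILE_SIZE][TILE_SIZE];
    __shared__ float nTile[TILE_SIZE][TILE_SIZE];

    int row = blockIdx.y * TILE_SIZE + threadIdx.y;
    int col = blockIdx.x * TILE_SIZE + threadIdx.x;

    float pValue = 0.0f;

    // Walk across the shared dimension one tile at a time.
    for (int t = 0; t &lt; (size + TILE_SIZE - 1) / TILE_SIZE; ++t)
    {
        int mCol = t * TILE_SIZE + threadIdx.x;
        int nRow = t * TILE_SIZE + threadIdx.y;

        // Stage one tile of each input, zero-padding past the matrix edge.
        mTile[threadIdx.y][threadIdx.x] = (row &lt; size &amp;&amp; mCol &lt; size) ? m[row * size + mCol] : 0.0f;
        nTile[threadIdx.y][threadIdx.x] = (nRow &lt; size &amp;&amp; col &lt; size) ? n[nRow * size + col] : 0.0f;
        __syncthreads();

        // Accumulate this tile's contribution to the output element.
        for (int i = 0; i &lt; TILE_SIZE; ++i)
        {
            pValue += mTile[threadIdx.y][i] * nTile[i][threadIdx.x];
        }
        __syncthreads();
    }

    if (row &lt; size &amp;&amp; col &lt; size)
    {
        p[row * size + col] = pValue;
    }
}
</code></pre></div></div>

<p>Because the wrapper already launches with 16 × 16 blocks, a kernel like this could be wired in behind another <code class="language-plaintext highlighter-rouge">kernel_type</code> value without changing the launch configuration.</p>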

<h2 id="example-github-repository">Example GitHub Repository</h2>

<p>Complete example code is available on <a href="https://github.com/derekelewis/ml-notes/tree/main/cuda">GitHub</a>.</p>

<h2 id="conclusion">Conclusion</h2>

<p>Custom CUDA extensions offer a powerful way to optimize performance-critical operations in PyTorch applications. By understanding the nuances of CUDA programming, profiling, and optimization techniques, we can significantly improve the performance of our machine learning workloads.</p>

<p>In future posts, we’ll dive deeper into the profiling process for each of our kernels and explore optimization strategies to enhance their performance.</p>]]></content><author><name>Derek Lewis</name></author><category term="pytorch" /><category term="c++" /><category term="cuda" /><summary type="html"><![CDATA[Optimizing Performance with PyTorch CUDA/C++ Extensions: A Deep Dive]]></summary></entry></feed>