
Optimizing Performance with PyTorch CUDA/C++ Extensions: A Deep Dive

Understanding how CUDA and efficient kernels work under the hood enables us to make informed decisions about model architecture, optimize critical operations, and squeeze maximum performance from our GPU hardware. One powerful technique is writing PyTorch CUDA extensions that expose custom CUDA kernels to Python. In this series of posts, we’ll explore how to create, profile, and optimize these extensions, using a simple matrix multiplication example as a guide.

Matrix Multiplication

The Power of Custom CUDA Kernels

PyTorch allows us to write custom CUDA kernels and integrate them seamlessly into our Python code. This capability is particularly useful when we need to optimize specific operations that are critical to an application’s performance.

Our Example: Matrix Multiplication

We’ll use matrix multiplication as our example. While PyTorch already has highly optimized matrix multiplication routines, implementing our own allows us to understand the process and potentially optimize for specific use cases. For this first post in the series, we’ll start with three basic matrix multiplication kernels of varying implementation quality: the Element-wise kernel, the Row-wise kernel, and the Column-wise kernel. Each approach distributes the computation across GPU threads in a different manner, offering a different degree of parallelism and a different memory access pattern.

Creating the CUDA Extension

Let’s break down the key components of our CUDA extension:

The CUDA Kernels (matrixMultiply.cu)

// element-wise matrix multiplication: each thread computes one element of p
__global__ void matrixMulKernel(float *m, float *n, float *p, int size)
{
    int i = blockDim.y * blockIdx.y + threadIdx.y;
    int j = blockDim.x * blockIdx.x + threadIdx.x;
    if (i < size && j < size)
    {
        float pValue = 0;
        for (int k = 0; k < size; ++k)
        {
            pValue += m[i * size + k] * n[k * size + j];
        }
        p[i * size + j] = pValue;
    }
}

// row-wise matrix multiplication: each thread computes an entire row of p
// (flat indices below assume size = 4)
// thread 0:
// p_0,0 p[0]
// p_0,1 p[1]
// ..
// thread 1:
// p_1,0 p[4]
// p_1,1 p[5]
// ..
__global__ void matrixMulKernelRow(float *m, float *n, float *p, int size)
{
    int row = blockDim.x * blockIdx.x + threadIdx.x;

    if (row < size)
    {
        for (int col = 0; col < size; ++col)
        {
            float pValue = 0;
            for (int i = 0; i < size; ++i)
            {
                pValue += m[row * size + i] * n[i * size + col];
            }
            p[row * size + col] = pValue;
        }
    }
}

// column-wise matrix multiplication: each thread computes an entire column of p
// (flat indices below assume size = 4)
// thread 0:
// p_0,0 p[0]
// p_1,0 p[4]
// ..
// thread 1:
// p_0,1 p[1]
// p_1,1 p[5]
// ..
__global__ void matrixMulKernelCol(float *m, float *n, float *p, int size)
{
    int col = blockDim.x * blockIdx.x + threadIdx.x;

    if (col < size)
    {
        for (int row = 0; row < size; ++row)
        {
            float pValue = 0;
            for (int i = 0; i < size; ++i)
            {
                pValue += m[row * size + i] * n[i * size + col];
            }
            p[row * size + col] = pValue;
        }
    }
}

The Element-wise kernel, while simple, offers true parallelism by utilizing the GPU’s massive thread capacity. The Row-wise and Column-wise kernels, despite appearing to offer a different parallelization strategy, actually introduce serialization and will perform poorly: each thread sequentially computes an entire row or column, so they fail to fully utilize the GPU’s parallel architecture and negate much of the benefit of running on a GPU in the first place. The Element-wise kernel, while parallel, is still far from optimal. It makes no use of shared memory, tiling, or carefully tuned, coalesced memory access patterns, all of which are critical for peak GPU performance. In a future post we’ll profile these kernels to quantify their performance differences.
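
To make the difference concrete: for the 1000×1000 matrices we’ll use later, the Element-wise kernel launches roughly a million threads, each computing a single 1000-element dot product, while the Row-wise and Column-wise kernels keep only 1,000 threads doing useful work, each grinding through a million multiply-accumulate operations in sequence.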

The C++ Wrapper (torchMatrixMultiply.cu)

#include <torch/extension.h>
#include <cuda_runtime.h>

#include <optional>
#include <stdexcept>
#include <string>

__global__ void matrixMulKernel(float *m, float *n, float *p, int size);
__global__ void matrixMulKernelRow(float *m, float *n, float *p, int size);
__global__ void matrixMulKernelCol(float *m, float *n, float *p, int size);

using KernelFunc = void (*)(float *, float *, float *, int);

torch::Tensor cuda_matrixMultiply(const torch::Tensor &a, const torch::Tensor &b, KernelFunc kernel)
{
    TORCH_CHECK(a.sizes() == b.sizes());
    TORCH_CHECK(a.dtype() == torch::kFloat);
    TORCH_CHECK(b.dtype() == torch::kFloat);
    TORCH_INTERNAL_ASSERT(a.device().type() == torch::DeviceType::CUDA);
    TORCH_INTERNAL_ASSERT(b.device().type() == torch::DeviceType::CUDA);

    torch::Tensor a_contiguous{a.contiguous()};
    torch::Tensor b_contiguous{b.contiguous()};
    torch::Tensor result{torch::empty(a_contiguous.sizes(), a_contiguous.options())};

    float *a_ptr = a_contiguous.data_ptr<float>();
    float *b_ptr = b_contiguous.data_ptr<float>();
    float *result_ptr = result.data_ptr<float>();

    // Assumes square matrices and we cast to int for simplicity
    // and compatibility with our existing kernel code. In practice,
    // we would need to handle non-square matrices and use an unsigned long
    // to match PyTorch's tensor sizes.
    int dim{static_cast<int>(a.sizes()[0])};

    dim3 blockSize(16, 16);
    dim3 gridSize((dim + blockSize.x - 1) / blockSize.x, (dim + blockSize.y - 1) / blockSize.y);

    kernel<<<gridSize, blockSize>>>(a_ptr, b_ptr, result_ptr, dim);
    checkCudaError(cudaGetLastError(), "Kernel launch failed");

    return result;
}

torch::Tensor matrixMultiply(const torch::Tensor &a, const torch::Tensor &b, const std::optional<std::string> &kernel_type)
{
    if (kernel_type.has_value())
    {
        if (kernel_type == "row")
        {
            return cuda_matrixMultiply(a, b, matrixMulKernelRow);
        }
        else if (kernel_type == "col")
        {
            return cuda_matrixMultiply(a, b, matrixMulKernelCol);
        }
        else
        {
            throw std::invalid_argument("Invalid kernel type");
        }
    }
    else
    {
        return cuda_matrixMultiply(a, b, matrixMulKernel);
    }
}

TORCH_LIBRARY(myextension, m)
{
    m.def("mymatrixmultiply(Tensor a, Tensor b, str? kernel_type = None) -> Tensor");
}

TORCH_LIBRARY_IMPL(myextension, CUDA, m)
{
    m.impl("mymatrixmultiply", TORCH_FN(matrixMultiply));
}

This wrapper handles the conversion between PyTorch tensors and raw CUDA device pointers, launches the selected kernel, and registers the operator with PyTorch’s dispatcher. We use an optional string argument and a function pointer to select the kernel at runtime, allowing us to switch between the Element-wise, Row-wise, and Column-wise kernels from our Python code.
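
The wrapper also calls a small checkCudaError helper whose definition isn’t shown above and which needs to be declared or defined in that file for it to compile. Here is a minimal sketch of such a helper, with the name and signature inferred from the call site (the version in the repository may differ):

#include <cuda_runtime.h>

#include <stdexcept>
#include <string>

// Throw a descriptive exception if a CUDA runtime call reported an error.
void checkCudaError(cudaError_t err, const std::string &msg)
{
    if (err != cudaSuccess)
    {
        throw std::runtime_error(msg + ": " + cudaGetErrorString(err));
    }
}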

Building the Extension

Building a CUDA extension requires careful attention to linking and include paths. Our Makefile handles this by specifying the compiler, compiler flags, and paths to the necessary libraries and include directories. It also ensures the build directory exists, compiles the CUDA source files into object files, and links them into a shared library that will be loaded by PyTorch as our extension:

NVCC := nvcc
NVCC_FLAGS := -g -G
BUILD_DIR := build
PYTHON_VER := 3.12
MLENV_DIR := ~/mlenv/lib/python$(PYTHON_VER)/site-packages/torch
TORCH_INCLUDE := $(MLENV_DIR)/include
TORCH_LIB := $(MLENV_DIR)/lib

# Find all .cu files in the current directory, excluding the extension wrapper
SRCS := $(filter-out torchMatrixMultiply.cu, $(wildcard *.cu))
# Generate corresponding object file names in the build directory
OBJS := $(patsubst %.cu,$(BUILD_DIR)/%.o,$(SRCS))
# Generate executable names from the object files
EXECS := $(patsubst $(BUILD_DIR)/%.o,$(BUILD_DIR)/%,$(OBJS))

# Ensure the build directory exists
$(shell mkdir -p $(BUILD_DIR))

.PHONY: all clean

all: $(EXECS) $(BUILD_DIR)/torchMatrixMultiply.so

# Rule to build executables from object files
$(BUILD_DIR)/%: $(BUILD_DIR)/%.o
	$(NVCC) $(NVCC_FLAGS) $< -o $@

# Rule to compile .cu files into object files
$(BUILD_DIR)/%.o: %.cu
	$(NVCC) $(NVCC_FLAGS) --compiler-options '-fPIC' -c $< -o $@

.SECONDARY: $(OBJS)

$(BUILD_DIR)/torchMatrixMultiply.so: torchMatrixMultiply.cu $(BUILD_DIR)/matrixMultiply.o
	$(NVCC) $(NVCC_FLAGS) -shared --compiler-options '-fPIC' -L $(TORCH_LIB) \
	-lc10 -ltorch_cpu -ltorch -ltorch_python -lc10_cuda -ltorch_cuda \
	-isystem $(TORCH_INCLUDE)/torch/csrc/api/include \
	-isystem $(TORCH_INCLUDE) \
	torchMatrixMultiply.cu $(BUILD_DIR)/matrixMultiply.o -o $(BUILD_DIR)/torchMatrixMultiply.so \
	-DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" \
	-DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /usr/include/python$(PYTHON_VER) -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++17

clean:
	rm -rf $(BUILD_DIR)

Key points:

  • We link against PyTorch libraries
  • We include PyTorch and Python headers
  • We set necessary compilation flags for PyTorch and pybind11

Using the Extension in PyTorch

Once built, we can use our extension in PyTorch:

import torch

torch.ops.load_library("build/torchMatrixMultiply.so")

torch.manual_seed(42)
a = torch.rand(1000, 1000, device="cuda")
print(a)
b = torch.rand(1000, 1000, device="cuda")
print(b)

print("row kernel:")
print(torch.ops.myextension.mymatrixmultiply(a, b, "row"))
print("col kernel:")
print(torch.ops.myextension.mymatrixmultiply(a, b, "col"))
print("default kernel:")
print(torch.ops.myextension.mymatrixmultiply(a, b))

Profiling with the NVIDIA Nsight Compute CLI (ncu)

To optimize our kernels, we need to understand their performance characteristics. NVIDIA’s Nsight Compute command-line tool, ncu, is useful for this:

ncu --kernel-name regex:"matrixMulKernel.*" python torchMatrixMultiply.py

This command profiles all three of our kernels, providing insights into metrics like:

  • SM occupancy
  • Memory throughput
  • Instruction throughput
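
For a deeper dive, ncu can also collect its full set of metric sections and write a report file for the Nsight Compute UI; something along the following lines should work, where matmul_report is just an example output name (check the ncu documentation for the options available in your version):

ncu --set full -o matmul_report --kernel-name regex:"matrixMulKernel.*" python torchMatrixMultiply.py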

In the next post, we’ll review the profiling results and apply optimization strategies, such as those below, to improve our kernels’ performance.

Optimization Strategies

Based on the profiling results, we can apply various optimization strategies:

  1. Shared Memory: If memory bandwidth is a bottleneck, we can use shared memory to reduce global memory accesses (a combined shared-memory and tiling sketch follows this list).

  2. Loop Unrolling: This can increase instruction-level parallelism.

  3. Tiling: Dividing the matrices into smaller tiles can improve cache utilization.

  4. Vectorization: Using vector loads and stores can increase memory throughput.

  5. Warp-level Primitives: For certain operations, warp-level primitives can be faster than block-level synchronization.
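
As a preview of strategies 1 and 3, here is a minimal sketch of what a tiled, shared-memory version of our kernel might look like. It keeps the same square-matrix signature as the kernels above, assumes size is an exact multiple of the tile width, and omits bounds checks; the version we eventually develop and profile in a later post may differ.

// tiled matrix multiplication using shared memory (illustrative sketch)
#define TILE_SIZE 16

__global__ void matrixMulKernelTiled(float *m, float *n, float *p, int size)
{
    __shared__ float mTile[TILE_SIZE][TILE_SIZE];
    __shared__ float nTile[TILE_SIZE][TILE_SIZE];

    int row = TILE_SIZE * blockIdx.y + threadIdx.y;
    int col = TILE_SIZE * blockIdx.x + threadIdx.x;

    float pValue = 0;
    for (int t = 0; t < size / TILE_SIZE; ++t)
    {
        // each thread stages one element of each input tile into shared memory
        mTile[threadIdx.y][threadIdx.x] = m[row * size + t * TILE_SIZE + threadIdx.x];
        nTile[threadIdx.y][threadIdx.x] = n[(t * TILE_SIZE + threadIdx.y) * size + col];
        __syncthreads();

        // accumulate this tile's contribution to the dot product from shared memory
        for (int k = 0; k < TILE_SIZE; ++k)
        {
            pValue += mTile[threadIdx.y][k] * nTile[k][threadIdx.x];
        }
        __syncthreads();
    }
    p[row * size + col] = pValue;
}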

Example GitHub Repository

Complete example code is available on GitHub.

Conclusion

Custom CUDA extensions offer a powerful way to optimize PyTorch applications. By understanding the nuances of CUDA programming, profiling, and optimization techniques, we can significantly improve the performance of our machine learning workloads.

In future posts, we’ll dive deeper into the profiling process for each of our kernels and explore optimization strategies to enhance their performance.
