Flash-SD-KDE: Accelerating SD-KDE with Tensor Cores

TL;DR

By recasting the dominant computations of SD-KDE as matrix multiplications, Flash-SD-KDE maps them onto Tensor-Core GEMMs (implemented in Triton), keeping the estimator unchanged while sharply reducing runtime.

Problem setting

SD-KDE adds a score-based shift step to standard KDE, which introduces extra pairwise computation over the samples. The paper's goal is to keep the estimator exact while dramatically reducing runtime on modern GPUs.

Key idea

Rewrite the dominant computations in SD-KDE so they are expressed as matrix multiplications. This makes the work compatible with Tensor-Core-optimized GEMMs (implemented in Triton).

Method (high level)

  1. Compute an empirical score for each training point.
  2. Shift training points using the score.
  3. Evaluate KDE on the shifted samples.
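The three steps above can be sketched in NumPy. This is a minimal sketch, not the paper's Triton implementation: the pilot-KDE score estimator and the `step` size of the shift are assumptions about the method's exact form, and the function names are illustrative.

```python
import numpy as np

def gaussian_kde(query, samples, h):
    """Evaluate the Gaussian KDE from the text at each query point."""
    n, d = samples.shape
    sq = ((query[:, None, :] - samples[None, :, :]) ** 2).sum(-1)  # (m, n)
    return np.exp(-sq / (2 * h**2)).sum(axis=1) / (n * h**d)

def empirical_score(x, h):
    """Score of a pilot Gaussian KDE evaluated at the samples themselves:
    grad log p_hat(x_i) = sum_j w_ij (x_j - x_i) / (h^2 sum_j w_ij),
    with w_ij = exp(-||x_i - x_j||^2 / (2 h^2))."""
    diff = x[:, None, :] - x[None, :, :]           # (n, n, d): x_i - x_j
    w = np.exp(-(diff ** 2).sum(-1) / (2 * h**2))  # (n, n)
    num = -(w[:, :, None] * diff).sum(axis=1)      # sum_j w_ij (x_j - x_i)
    return num / (h**2 * w.sum(axis=1, keepdims=True))

def sd_kde(query, x, h, step):
    score = empirical_score(x, h)             # 1. empirical score at each point
    x_shifted = x + step * score              # 2. shift samples along the score
    return gaussian_kde(query, x_shifted, h)  # 3. KDE on the shifted samples
```

With `step=0` the shift is a no-op and `sd_kde` reduces to plain KDE, which makes the extra cost of the score step easy to see: it is another all-pairs computation over the samples.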

The Gaussian KDE can be written as:

$$\hat{p}(x) = \frac{1}{n h^d} \sum_{i=1}^n \exp\left(-\frac{\|x - x_i\|^2}{2h^2}\right)$$
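A quick numeric sanity check of the formula as written (note it carries no Gaussian normalizing constant $(2\pi)^{-d/2}$, so values are correct only up to that fixed factor):

```python
import numpy as np

# One sample at the origin, n = 1, d = 1, h = 1.
samples = np.zeros((1, 1))
h, n, d = 1.0, 1, 1

# At the sample itself: p_hat(0) = (1 / (n h^d)) * exp(0) = 1.
p_at_0 = np.exp(-((0.0 - samples) ** 2).sum() / (2 * h**2)) / (n * h**d)

# One unit away: p_hat(1) = exp(-1/2) ≈ 0.6065.
p_at_1 = np.exp(-((1.0 - samples) ** 2).sum() / (2 * h**2)) / (n * h**d)
```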

The key algebraic trick is the dot-product identity:

$$\|x - y\|^2 = \|x\|^2 + \|y\|^2 - 2\,x^\top y$$
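This identity replaces the all-pairs distance computation with one large matrix product (the cross term) plus two cheap row/column norm corrections; the matrix product is the piece a Tensor-Core GEMM can accelerate. A NumPy sketch (the function name is illustrative):

```python
import numpy as np

def pairwise_sq_dists_gemm(X, Y):
    """||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 x_i . y_j.
    The cross term X @ Y.T is a single GEMM; the squared norms broadcast
    as a rank-one correction over rows and columns."""
    x2 = (X ** 2).sum(axis=1)[:, None]   # (m, 1)
    y2 = (Y ** 2).sum(axis=1)[None, :]   # (1, n)
    sq = x2 + y2 - 2.0 * (X @ Y.T)       # (m, n)
    return np.maximum(sq, 0.0)           # clamp tiny negatives from rounding
```

The clamp matters in low precision: the three terms can cancel to a slightly negative number, which would otherwise make `exp(-sq / (2 h^2))` exceed 1.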

The implementation avoids storing full pairwise matrices by streaming tiles and accumulating results.
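A minimal sketch of that streaming pattern in NumPy (the paper's kernel does the analogous tiling on-chip in Triton; the tile size and function name here are illustrative):

```python
import numpy as np

def kde_streaming(query, samples, h, tile=1024):
    """Accumulate kernel sums tile-by-tile over the samples so the full
    (m, n) pairwise matrix is never materialized; peak memory is (m, tile)."""
    m, d = query.shape
    n = samples.shape[0]
    q2 = (query ** 2).sum(axis=1)[:, None]   # (m, 1), reused for every tile
    acc = np.zeros(m)
    for start in range(0, n, tile):
        S = samples[start:start + tile]
        s2 = (S ** 2).sum(axis=1)[None, :]
        sq = q2 + s2 - 2.0 * (query @ S.T)   # GEMM on one tile of samples
        acc += np.exp(-sq / (2 * h**2)).sum(axis=1)
    return acc / (n * h**d)
```

Because `exp` sums are associative, accumulating per-tile partial sums gives the same result as the dense computation, up to floating-point rounding.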

Figures (not reproduced here): oracle error vs. n; runtime comparison in 16D; runtime comparison in 1D.

Evidence

The paper's evidence consists of oracle-error and runtime comparisons across dimensions (see the figures above).

Limitations and caveats

Takeaway

Flash-SD-KDE is a hardware-aware rewrite of SD-KDE that makes a statistically appealing estimator practical at larger scales.