Kd-Tree Nearest Neighbor Search – WebGL Shader

This post is the second part of my kd-tree research. I’m going to discuss kd-tree nearest neighbor search on the GPU using OpenGL shaders (GLSL), or, to be more precise, WebGL shaders (OpenGL ES Shading Language 1.0). The distinction is worth stressing, because shaders written in desktop GLSL x.xx really differ from those written in GL ES SL x.xx, at least when you’re trying to do some sort of GPGPU. OpenCL, and perhaps WebCL (which I haven’t tried yet), is usually much more convenient for traversing acceleration structures like a BVH, an octree or a kd-tree, but since most browsers are more or less WebGL-capable out of the box, WebGL was my first choice.


WebGL Demo – PixelVoronoi KdTree

Source – https://github.com/diwi/PixelVoronoi_KdTree

Note: The demo runs a lot faster in Chrome than in Firefox.


Featured on Chrome Experiments.


Recursion vs. Stack vs. Stackless

First (and obvious) things first: pointer-based tree structures are not allowed in shaders, and neither is recursion. So I first had to replace every recursive method with an iterative one that uses a stack. The maximum stack size is always known:

max stack size: Math.ceil( Math.log(number of points) / Math.log(2) )

or simply log2(number of points), rounded up.

… so the stack is really just a fixed-size array plus an index used to push/pop elements.
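
A minimal JavaScript sketch of such a fixed-size stack, assuming one point per leaf so that the tree depth is the rounded-up log2 of the point count. The helper names (`maxStackSize`, `FixedStack`, `leavesInOrder`) are hypothetical, and the example only walks the leaves of a complete tree in array layout rather than doing the full nearest neighbor search. The depth is computed with an integer loop because the floating-point ratio `Math.log(n) / Math.log(2)` may come out slightly above an exact integer for some powers of two.

```javascript
// Hypothetical helpers: a sketch of the idea, not the demo's code.
// Depth bound for a tree with one point per leaf: ceil(log2(numPoints)).
// The integer loop avoids floating-point rounding in Math.log(n) / Math.log(2).
function maxStackSize(numPoints) {
  let levels = 0;
  while ((1 << levels) < numPoints) levels++; // valid for numPoints <= 2^30
  return levels;
}

// Fixed-size stack: a preallocated array plus an index, no dynamic growth.
class FixedStack {
  constructor(capacity) {
    this.data = new Int32Array(capacity);
    this.top = 0; // number of stored elements
  }
  push(value) {
    if (this.top === this.data.length) throw new Error("stack overflow");
    this.data[this.top++] = value;
  }
  pop() {
    if (this.top === 0) throw new Error("stack underflow");
    return this.data[--this.top];
  }
  isEmpty() {
    return this.top === 0;
  }
}

// Iterative depth-first walk over a complete tree (root = index 1),
// deferring right children on the stack instead of recursing.
// Visits all 2^levels leaf slots of the complete tree.
function leavesInOrder(numPoints) {
  const levels = maxStackSize(numPoints);
  const firstLeaf = 1 << levels;
  const stack = new FixedStack(levels); // at most one pending entry per level
  const visited = [];
  let idx = 1;
  for (;;) {
    if (idx >= firstLeaf) {   // leaf: visit it, resume at the last deferred node
      visited.push(idx);
      if (stack.isEmpty()) break;
      idx = stack.pop();
    } else {                  // inner node: defer the right child, descend left
      stack.push(2 * idx + 1);
      idx = 2 * idx;
    }
  }
  return visited;
}
```

With 8 points the stack needs only 3 slots, and `leavesInOrder(8)` visits the leaf indices 8 to 15 without ever overflowing it.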

Later in this post I will explain the stackless version, which WebGL shaders require.


Kd-Tree Representation

Next, the pointer-based kd-tree gets replaced by a flat one (the nodes are stored in an array). There are several ways to store a binary tree in an array. A simple one is to save the nodes in the order they appear on the stack during traversal. This requires each node to store the index of its left child (the right child sits right next to the left one). My goal was to keep the nodes as slim as possible, so I chose to lay them out in a “power of two” manner (the same implicit layout a binary heap uses), where the depth, the split dimension and the indices of the parent, left and right child can be computed from the current node index. It’s easier to show this in an image:

(Edit: I made a mistake in the graphic above. Of course it must be: L = i<<1 and P = i>>1.)
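
The index arithmetic behind this layout fits in a few lines. This sketch assumes, as described above, that the root is index 1 and the root depth is 1. `depthBit` gives the power-of-two depth value that the shader further down keeps in its `depth` variable, and `splitDimByDepth` is a hypothetical variant that only applies if the 2D split dimensions simply alternate per level.

```javascript
// Implicit binary tree in an array: root at index 1, index 0 unused.
const parentIdx = (i) => i >> 1;      // P = i >> 1
const leftIdx = (i) => i << 1;        // L = i << 1
const rightIdx = (i) => (i << 1) + 1; // R = (i << 1) + 1

// Depth of node i with the root at depth 1, i.e. floor(log2(i)) + 1.
function depthOf(i) {
  let depth = 0;
  for (let k = i; k > 0; k >>= 1) depth++;
  return depth;
}

// Largest power of two <= i: one bit per level, as in the shader's bit mask.
const depthBit = (i) => 1 << (depthOf(i) - 1);

// Hypothetical: split dimension if x and y alternate per level (root splits x).
const splitDimByDepth = (i) => (depthOf(i) - 1) % 2;
```

For example, node 5 sits at depth 3, its children are 10 and 11, and its parent is 2.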

Obviously some memory may be wasted (a lot more on unbalanced trees). For example, if node 2 were a leaf, all slots of the branch below it would be wasted, since the whole array must have at least pow(2, levels) entries. In my case at most half the array minus one is wasted space, but in return the array doesn’t have to be reallocated as often when the number of points changes. Also note that the root starts at index 1, the depth may also start at 1, and the split dimension should be chosen along the largest extent.
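
One way to fill such an array on the CPU is sketched below. It relies on assumptions the post does not state: one point per leaf, a median split along the axis of largest extent, and the split value taken from the first point of the right half (so left points are <= split and right points are >= split). The name `buildFlatKdTree` and the node objects are hypothetical; the demo itself packs each node into a 32-bit integer.

```javascript
// Hypothetical builder: flat kd-tree over 2D points [[x, y], ...], root at index 1.
function buildFlatKdTree(points) {
  let levels = 0;
  while ((1 << levels) < points.length) levels++;
  // Leaves sit at most `levels` below the root, so all indices stay below 2^(levels + 1).
  const nodes = new Array(1 << (levels + 1)).fill(null);

  function build(idx, pts) {
    if (pts.length === 1) {
      nodes[idx] = { leaf: true, point: pts[0] };
      return;
    }
    // Split along the axis with the largest extent of this subset.
    let minX = Infinity, maxX = -Infinity, minY = Infinity, maxY = -Infinity;
    for (const [x, y] of pts) {
      minX = Math.min(minX, x); maxX = Math.max(maxX, x);
      minY = Math.min(minY, y); maxY = Math.max(maxY, y);
    }
    const dim = maxX - minX >= maxY - minY ? 0 : 1;
    const sorted = pts.slice().sort((a, b) => a[dim] - b[dim]);
    const half = sorted.length >> 1; // left gets floor(n/2), right gets the rest
    nodes[idx] = { leaf: false, dim, split: sorted[half][dim] };
    build(idx << 1, sorted.slice(0, half));
    build((idx << 1) + 1, sorted.slice(half));
  }

  if (points.length > 0) build(1, points);
  return nodes;
}
```

With the five points `[[0, 0], [10, 1], [20, 2], [30, 3], [40, 4]]` the array has 16 slots, 9 of which hold nodes; slot 0 and slots 8 to 13 stay empty.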


Kd-Tree Node

The next thing to consider is what information each node actually has to store. At a minimum, a node needs:

  • whether the node is a leaf or not (1 bit)
  • the splitting position (float), or, if the node is a leaf, the point (float2 or float3) or the point index (integer)

The dimension can be computed on the fly. I found it more convenient to store the dimension in 1 bit rather than computing it from the current depth. If the kd-tree is built in 3D space, either 2 bits are needed to store the dimension, or it has to be computed on demand during traversal. Depending on the application, a leaf node may store just the index of a point instead of the point itself. In my case, with 2D points (all positive, within the viewport), I need 4 bytes per node (RGBA texture2D, UNSIGNED_BYTE, 1 byte per channel):


// .js snippet

// Kd-Tree Node Encoding: 32 bit Integer
//  bit  31     : L
//  bits 30..16 : X
//  bit  15     : D
//  bits 14..0  : Y
//  L ... leaf (0, 1)
//  D ... dimension (0, 1) ... could also be computed in the shader (on the fly).
//  X ... Point.x ( scaled, and converted to int)
//  Y ... Point.y ( scaled, and converted to int)
// nodes are saved in an integer array, having the root node at index 1.
// children are at current-index * 2 (...+0 = left, ...+1 = right).
// the parent node is found at current-index / 2.
// due to this tree storage, some indices contain "no nodes".

// node encoding masks
KdTree.BIT_LEAF = 0x80000000; //    31
KdTree.BIT_DIM  = 0x00008000; //    15
KdTree.BIT_P_Y  = 0x00007FFF; //  0-14
KdTree.BIT_P_X  = 0x7FFF0000; // 16-30

// point-scaling (to keep some precision). !same scaling in shader!
KdTree.PNT_SCALE = 10.0;

// get node indices for: Parent, Left-child, Right-child
KdTree.GET_P = function(node_idx){ return (node_idx>>1)  ; } // parent node_idx
KdTree.GET_L = function(node_idx){ return (node_idx<<1)  ; } // left   node_idx
KdTree.GET_R = function(node_idx){ return (node_idx<<1)+1; } // right  node_idx

// get node data: leaf, dimension, point.x, point.y
KdTree.IS_LEAF = function(node_val){ return ( node_val&KdTree.BIT_LEAF )>>>31; } // last bit! (>>> yields 0/1, >> would yield -1)
KdTree.GET_DIM = function(node_val){ return ( node_val&KdTree.BIT_DIM  )>>15; }
KdTree.GET_PX  = function(node_val){ return ((node_val&KdTree.BIT_P_X  )>>16)/KdTree.PNT_SCALE; }
KdTree.GET_PY  = function(node_val){ return ((node_val&KdTree.BIT_P_Y  )>> 0)/KdTree.PNT_SCALE; }
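
The snippet only shows the decoding side. A possible encoder for the same bit layout, plus the split into four bytes for an RGBA/UNSIGNED_BYTE texture, might look like the sketch below. The function names and the byte order (R = bits 24-31, G = 16-23, B = 8-15, A = 0-7) are assumptions. Note the unsigned shift when reading the leaf bit: JavaScript bitwise operators work on signed 32-bit integers, so `(v & 0x80000000) >> 31` yields -1 rather than 1.

```javascript
// Hypothetical encoder for the layout above:
// bit 31 = leaf, bits 16-30 = x, bit 15 = dim, bits 0-14 = y.
const PNT_SCALE = 10.0; // must match the shader's POINT_PREC (0.1)

function encodeNode(leaf, dim, x, y) {
  const ix = Math.round(x * PNT_SCALE);
  const iy = Math.round(y * PNT_SCALE);
  if (ix < 0 || ix > 0x7FFF || iy < 0 || iy > 0x7FFF) {
    throw new RangeError("scaled coordinate does not fit into 15 bits");
  }
  // ">>> 0" reinterprets the signed 32-bit result as unsigned (bit 31 may be set).
  return ((leaf ? 0x80000000 : 0) | (ix << 16) | (dim ? 0x8000 : 0) | iy) >>> 0;
}

// Unsigned shifts keep the decoded fields non-negative.
const isLeaf = (v) => (v >>> 31) & 1;
const getDim = (v) => (v >>> 15) & 1;
const getPx = (v) => ((v >>> 16) & 0x7FFF) / PNT_SCALE;
const getPy = (v) => (v & 0x7FFF) / PNT_SCALE;

// Assumed byte order: one RGBA texel per node, most significant byte in R.
function packNodes(values) {
  const bytes = new Uint8Array(values.length * 4);
  values.forEach((v, i) => {
    bytes[4 * i + 0] = (v >>> 24) & 0xFF;
    bytes[4 * i + 1] = (v >>> 16) & 0xFF;
    bytes[4 * i + 2] = (v >>> 8) & 0xFF;
    bytes[4 * i + 3] = v & 0xFF;
  });
  return bytes;
}
```

With `PNT_SCALE = 10`, 15 bits cover coordinates from 0 up to 3276.7. The resulting `Uint8Array` can be uploaded with `gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, width, height, 0, gl.RGBA, gl.UNSIGNED_BYTE, bytes)` as long as it holds exactly `width * height` texels.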


Kd-Tree as Texture2D

The whole tree is stored in a 2D texture where 1 pixel holds 1 node (the fewer texture reads, the better). Another way to upload the tree to the shader would be a uniform array, which is much faster than texture reads, but the number of uniforms is very limited (WebGL 1 only guarantees 16 vec4 uniform vectors in a fragment shader), so this is no real option.
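
How a node index maps to a texel is not shown in the post; a plausible sketch (hypothetical helper names, row-major layout assumed) follows. The texture should be sampled with `gl.NEAREST` filtering so the packed bytes are never interpolated, and a power-of-two size sidesteps WebGL 1's restrictions on non-power-of-two textures (no mipmaps, `CLAMP_TO_EDGE` wrapping only). Since the tree array length is already a power of two, such a texture is filled exactly.

```javascript
// Hypothetical: smallest power-of-two texture (width >= height) with at least numSlots texels.
function textureSizeFor(numSlots) {
  let k = 0;
  while ((1 << k) < numSlots) k++;
  return { width: 1 << Math.ceil(k / 2), height: 1 << Math.floor(k / 2) };
}

// Row-major position of node i, one node per texel.
function nodeTexel(i, width) {
  return { x: i % width, y: Math.floor(i / width) };
}

// Normalized coordinate of the texel centre, as a fragment shader would pass to texture2D.
function nodeTexCoord(i, width, height) {
  const { x, y } = nodeTexel(i, width);
  return [(x + 0.5) / width, (y + 0.5) / height];
}
```

For a 32-slot tree this gives an 8 x 4 texture, and node 5 of a 4 x 4 texture is sampled at (0.375, 0.375).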


WebGL Limitations

Up to this point everything went fine in desktop GLSL. I could just copy the stack-based nearest neighbor search routine into the fragment shader and it worked out of the box. When switching to GL ES, a lot of problems came up. Most importantly, you can’t index an array with a dynamic index (such as a stack pointer), and loops must have constant bounds too. One (slow) way around this is to loop over all possible indices, compare each one to the stack pointer, and, when they match, use the loop index for the array access. This works, but it looks really ugly in the code and is quite a disappointment performance-wise compared to the GLSL version (also, shader validation takes a long time, and sometimes the WebGL context gets lost). To sum up, stack-based tree traversal should be avoided in WebGL shaders.

The next limitation: vector components can’t be accessed by dynamic indices either, which would have been nice for selecting the split position using the dimension as index. The last limitation: no bitwise operators are available; more generally, GL ES seems to be much more focused on floating-point operations. This last one is actually not that bad, because the bitwise operations needed here can all be replaced by basic arithmetic expressions:


// fragment shader - snippet

precision mediump float;
precision highp int;

// Bit-Shifting
#define SHIFT_15 0x8000
#define SHIFT_08 0x0100
#define SHIFT_01 0x0002

// RS_XX= Right Shift by XX, LS_XX = Left shift by XX
#define RS_15(i) ( (i) / SHIFT_15 )
#define LS_15(i) ( (i) * SHIFT_15 )
#define RS_08(i) ( (i) / SHIFT_08 )
#define LS_08(i) ( (i) * SHIFT_08 )
#define RS_01(i) ( (i) / SHIFT_01 )
#define LS_01(i) ( (i) * SHIFT_01 )

// node index pointer
#define P(i) ( RS_01(i)   ) // parent node
#define L(i) ( LS_01(i)   ) // left child
#define R(i) ( LS_01(i)+1 ) // right child

#define POINT_PREC 0.1 // for scale point coords
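
The `getNode` function used in the traversal below is not part of the post. As a hypothetical sketch, here is how a node could be decoded from one RGBA texel using only the arithmetic stand-ins defined above, mirrored in JavaScript so it can be checked on the CPU. It assumes the byte order R = bits 24-31, G = 16-23, B = 8-15, A = 0-7, and that `texture2D` returns each channel normalized to [0, 1], hence the `* 255` and rounding. It also assumes integer division truncates, which for non-negative operands equals `Math.floor` of the quotient.

```javascript
// JavaScript mirror of the shader macros (valid for non-negative integers only).
const SHIFT_15 = 0x8000, SHIFT_08 = 0x0100, SHIFT_01 = 0x0002;
const RS = (i, shift) => Math.floor(i / shift); // like RS_xx: right shift via division
const LS = (i, shift) => i * shift;             // like LS_xx: left shift via multiplication
const POINT_PREC = 0.1;

// Hypothetical getNode decoding; r, g, b, a are normalized channel values in [0, 1].
function decodeTexel(r, g, b, a) {
  const R = Math.round(r * 255), G = Math.round(g * 255);
  const B = Math.round(b * 255), A = Math.round(a * 255);
  const hi = LS(R, SHIFT_08) + G; // upper 16 bits: leaf bit + x
  const lo = LS(B, SHIFT_08) + A; // lower 16 bits: dim bit + y
  const leaf = RS(hi, SHIFT_15);  // top bit of the upper half
  const dim = RS(lo, SHIFT_15);   // top bit of the lower half
  return {
    leaf,
    dim,
    x: (hi - LS(leaf, SHIFT_15)) * POINT_PREC, // drop the flag bit, undo the scaling
    y: (lo - LS(dim, SHIFT_15)) * POINT_PREC,
  };
}

// Sanity check: the arithmetic forms agree with real bit shifts for 0 <= i < limit.
function matchesBitwise(limit) {
  for (let i = 0; i < limit; i++) {
    if (RS(i, SHIFT_01) !== i >> 1 || LS(i, SHIFT_01) !== i << 1) return false;
    if (RS(i, SHIFT_08) !== i >> 8 || RS(i, SHIFT_15) !== i >> 15) return false;
  }
  return true;
}
```

Splitting the 32-bit value into two 16-bit halves is what makes the 15-bit shift useful: the flag bit is simply the quotient by 0x8000, and the coordinate is the remainder.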


Shader – Kd-Tree traversal – Nearest Neighbor Search

All those limitations force you to rethink the kd-tree traversal to still get an efficient NNS. Since a stack-based solution is not fast enough, I had to come up with another one. To avoid getting caught in an endless recursion, which happens when unwinding from a leaf and checking whether the other half-space is closer (normal distance in the split dimension) than the current minimum distance, a flag has to be set somewhere so that no branch is descended more than once. Since nodes are read from the texture on demand, no state can be saved there (which again would result in a stack-based version). In the end I used a bit mask (an integer) to store the state. Each bit represents a depth level (= 2^depth) and is set to 1 or 0 depending on the traversal direction (up or down) and the current depth. This can be seen as an array of bits, which in the end is a kind of stack, but one that allows dynamic editing. This traversal technique seems quite suitable for traversing tree structures in GL ES shaders.


// fragment shader - snippet
// general algorithm:
// 1) while traversing down, always choose half-space [HS] the point is in
// 2) while traversing back, check if current min distance is greater
//    than normal distance to split plane. if so, check the other HS too.
// instead of using a stack, I use an index-pointer-based iterating process
// by checking/modifying a bit-mask (... 2^depth) to avoid checking the same
// HS again and again (endless recursion). This is BY FAR the best solution
// when using GLSL ES 1.0; in desktop GLSL a stack-based solution is probably better.
void getNearestNeighbor(inout NN nn){

  Node node;
  bool down  = true;
  int  n_idx = 1; // 1 = root
  int  depth = 1; // current depth (power of 2 --> bit indicates depth)
  int  dcode = 0; // depth-bits indicate checked HalfSpaces

  for(int i = 0; i < 500; i++){                    // constant loop (GL ES)

    getNode(n_idx, node);                          // get node from texture
    float pd = planeDistance(nn, node);            // normal dist to split plane

    if(down){                                      // if traversing down
      if( down = (node.leaf == 0) ){               //   if not leaf
        depth = LS_01(depth);                      //     incr depth (go down)
        n_idx = (pd < 0.0) ? L(n_idx) : R(n_idx);  //     get (near) child
      } else {                                     //   else (=leaf)
        updateMinDis(nn, node);                    //     update min distance
        depth = RS_01(depth);                      //     decr depth (go up now)
        n_idx = P(n_idx);                          //     get parent
      }
    } else {                                       // else (=unwinding)
      if(down = ((dcode < depth) &&                //   if not checked yet
                 (abs(pd) < nn.dis)))              //   AND overlapping
      {                                            //     --> check (other) HS
        dcode += depth;                            //     set depth-bit
        depth  = LS_01(depth);                     //     incr depth (go down)
        n_idx  = (pd < 0.0) ? R(n_idx) : L(n_idx); //     get (other) child
      } else {                                     //   else (keep unwinding)
        dcode -= (dcode < depth) ? 0 : depth;      //     clear depth-bit
        depth  = RS_01(depth);                     //     decr depth (go up)
        n_idx  = P(n_idx);                         //     get parent
      }
    }

    if(depth == 0) break; // THIS is the end of the nearest neighbor search.
  }
}
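
To check the traversal logic outside the browser, it can be ported almost line by line to JavaScript. This is a sketch under assumptions not given in the post: `nodes` is a sparse array indexed from 1 holding inner nodes `{ leaf: false, dim, split }` and leaves `{ leaf: true, point: [x, y] }`, `planeDistance` is taken to be `query[dim] - split`, and distances are Euclidean. The names `nearestStackless` and `exampleTree` are hypothetical.

```javascript
// Hypothetical JS port of the stackless search, mirroring the shader's depth/dcode bit mask.
// Unlike the shader, JavaScript needs no constant loop bound.
function nearestStackless(nodes, q) {
  const nn = { dis: Infinity, point: null };
  let down = true;
  let nIdx = 1;  // 1 = root
  let depth = 1; // power of two, one bit per level
  let dcode = 0; // set bits mark levels whose far half-space is being visited

  for (;;) {
    const node = nodes[nIdx];
    const pd = node.leaf ? 0 : q[node.dim] - node.split; // assumed signed plane distance

    if (down) {
      if (!node.leaf) {                          // descend into the near half-space
        depth *= 2;
        nIdx = pd < 0 ? 2 * nIdx : 2 * nIdx + 1;
      } else {                                   // leaf: update, then start unwinding
        const d = Math.hypot(q[0] - node.point[0], q[1] - node.point[1]);
        if (d < nn.dis) { nn.dis = d; nn.point = node.point; }
        down = false;
        depth = Math.floor(depth / 2);
        nIdx = Math.floor(nIdx / 2);
      }
    } else if (dcode < depth && Math.abs(pd) < nn.dis) {
      down = true;                               // far side unchecked and overlapping
      dcode += depth;                            // set this level's bit
      depth *= 2;
      nIdx = pd < 0 ? 2 * nIdx + 1 : 2 * nIdx;   // the other (far) child
    } else {
      if (dcode >= depth) dcode -= depth;        // clear this level's bit
      depth = Math.floor(depth / 2);
      nIdx = Math.floor(nIdx / 2);
    }

    if (depth === 0) break;                      // unwound past the root: done
  }
  return nn;
}

// Hand-built example: root splits x at 5, both children split y at 5.
const exampleTree = [];
exampleTree[1] = { leaf: false, dim: 0, split: 5 };
exampleTree[2] = { leaf: false, dim: 1, split: 5 };
exampleTree[3] = { leaf: false, dim: 1, split: 5 };
exampleTree[4] = { leaf: true, point: [1, 1] };
exampleTree[5] = { leaf: true, point: [2, 8] };
exampleTree[6] = { leaf: true, point: [8, 2] };
exampleTree[7] = { leaf: true, point: [9, 9] };
```

For the query `[4.9, 2]`, the search first reaches [1, 1] on the near side, then, while unwinding, sees that the root's split plane is only 0.1 away, descends into the far half, and returns the closer point [8, 2].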


The performance is quite amazing too. While the stack-based solution was the main fps killer, the bottleneck is now the construction of the kd-tree, which is done in JavaScript.

With a kd-tree of millions of nodes (huge memory usage), I get about 10-15 fps in Google Chrome when doing the nearest neighbor search for each pixel at a resolution of 1900 x 1200 (GTX 550 Ti).