Differentiating through sort


Differentiable programming is useful in machine learning research because it makes gradients available for whatever parameters one wants to optimize. Forward- and reverse-mode automatic differentiation are the two major paradigms for finding these derivatives, and differentiable programming is often framed in terms of functional programming. In this post, I will focus on reverse-mode automatic differentiation, because forward mode is conceptually simple (dual numbers) and inefficient for functions from many variables to few variables, e.g. f:\mathbb{R}^n\rightarrow \mathbb R.

I was curious how derivatives for the more obscure/less functional-appearing operations are determined; one such operation is sort. Assume we have an array A with elements A_i. We perform the operation B = \text{sort}(A). The question is to (efficiently) find the Jacobian

$$J_{ij}=\frac{\partial B_i}{\partial A_j}=\left(\bar J^{-1}\right)_{ij},\qquad \bar J_{ij}=\frac{\partial A_i}{\partial B_j}$$

Without trying to dig through the source code for JAX or another automatic differentiation framework, here are the observations I made. First, sorting an array is equivalent to identifying a particular permutation of the elements. In addition to the traditional jax.numpy.sort operation, there is jax.numpy.argsort, which finds precisely the list of indices specifying that permutation. When we perform reverse-mode automatic differentiation, we are given a computational graph in which each node represents a composition g(f(x)). We are given g' and f(x), and we must use the chain rule to compute (g\circ f)'(x)=g'(f(x))f'(x) (technically, we keep track of the adjoint, denoted by the \bar J above). In multiple dimensions, we must use the multi-dimensional chain rule, so there is an implicit sum/matrix multiplication.
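To make the implicit matrix multiplication concrete, here is a minimal sketch using jax.vjp. The toy function f, the input x, and the incoming gradient v are assumptions chosen purely for illustration; the point is only that the reverse-mode result agrees with multiplying the incoming gradient by the explicit Jacobian.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Toy function R^3 -> R^3, chosen only for illustration.
    return jnp.sin(x) * x

x = jnp.array([0.5, 1.0, 2.0])
v = jnp.array([1.0, -2.0, 3.0])  # gradient arriving from downstream nodes

# Reverse mode: jax.vjp returns f(x) and a function mapping v to v^T J.
y, f_vjp = jax.vjp(f, x)
(x_bar,) = f_vjp(v)

# The same quantity through the explicit Jacobian (wasteful in general).
J = jax.jacobian(f)(x)
x_bar_explicit = v @ J

print(x_bar)
print(x_bar_explicit)
```

Reverse mode never needs to build J; it only needs a rule for applying v^T J, which is what makes the sort case cheap.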

In the case of f(A)=\text{sort}(A)=B, the derivatives \frac{\partial B_i}{\partial A_j} are 1 or 0, with it being 1 when A_j is the i^{th} component of the sorted array and 0 otherwise.
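As a quick sanity check of this claim, the sketch below compares the Jacobian that JAX computes for jax.numpy.sort against a permutation matrix built by hand from jax.numpy.argsort. The example array is arbitrary, and its entries are distinct so that ties do not muddy the picture.

```python
import jax
import jax.numpy as jnp

A = jnp.array([3.0, 0.5, 1.0, 2.0])

# Jacobian of sort as computed by JAX: J[i, j] = dB_i / dA_j.
J = jax.jacobian(jnp.sort)(A)

# Permutation matrix built by hand: since B[i] = A[I[i]],
# row i has a single 1 in column I[i].
I = jnp.argsort(A)
P = jnp.eye(A.shape[0])[I]

print(J)
print(P)
```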

Since the derivative is so simple, we just need to propagate the gradient from the sorted elements B_j back to the corresponding elements of A_i, taking care of the chain rule. In other words, if A_i is sorted to B_j, then the gradient from B_j is copied back to A_i. We thus need to invert the permutation provided by jax.numpy.argsort. How do we invert a sort? The trick is to argsort the argsort (https://stackoverflow.com/questions/9185768/inverting-permutations-in-python)!
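Written out, the propagation step looks like this. The symbols here are my own notation and not from any library: L is a scalar loss computed downstream of B, and \pi is the permutation returned by argsort, so that B_i = A_{\pi(i)} and \frac{\partial B_i}{\partial A_j}=\delta_{j,\pi(i)}.

$$\frac{\partial L}{\partial A_j}=\sum_i \frac{\partial L}{\partial B_i}\frac{\partial B_i}{\partial A_j}=\sum_i \frac{\partial L}{\partial B_i}\,\delta_{j,\pi(i)}=\frac{\partial L}{\partial B_{\pi^{-1}(j)}}$$

The sum collapses to a single term, so the backward pass is just an indexing operation with \pi^{-1}.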

Assume we have already argsorted A, which returns indices I. If we argsort I, then sort(I) is the original order of the indices of A, [0,1,2,\cdots]. Therefore argsort(I) must be the inverse of the permutation represented by I.
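A small sketch of the argsort-of-argsort trick, using an arbitrary example array whose permutation is not its own inverse:

```python
import jax.numpy as jnp

A = jnp.array([3.0, 0.5, 1.0, 2.0])

I = jnp.argsort(A)      # forward permutation: B = A[I]
I_inv = jnp.argsort(I)  # inverse permutation: A = B[I_inv]

B = A[I]
A_recovered = B[I_inv]

print(I, I_inv)
print(A_recovered)
```

Indexing I with I_inv (or the other way around) gives back [0, 1, 2, 3], which is another way to see that the two permutations undo each other.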

I implemented the forward and backward of jax.numpy.sort manually.
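For readers who want something to run, here is a minimal sketch of what such an implementation can look like with jax.custom_vjp. It is not necessarily identical to the code in the repository linked at the end of the post: the names my_sort, my_sort_fwd, my_sort_bwd, and the toy loss are hypothetical, and it assumes a 1-D input.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def my_sort(A):
    # Hypothetical name; sorts a 1-D array in ascending order.
    return jnp.sort(A)

def my_sort_fwd(A):
    I = jnp.argsort(A)
    # Return the sorted array and keep the permutation for the backward pass.
    return A[I], I

def my_sort_bwd(I, g):
    # g[i] is the gradient with respect to B[i]. Route it back to A[I[i]]
    # by indexing with the inverse permutation argsort(I).
    return (g[jnp.argsort(I)],)

my_sort.defvjp(my_sort_fwd, my_sort_bwd)

def loss(sort_fn, A):
    # Toy scalar loss with position-dependent weights on the sorted array.
    w = jnp.arange(1.0, A.shape[0] + 1.0)
    return jnp.sum(w * sort_fn(A) ** 2)

A = jnp.array([3.0, 0.5, 1.0, 2.0])
grad_manual = jax.grad(lambda a: loss(my_sort, a))(A)
grad_builtin = jax.grad(lambda a: loss(jnp.sort, a))(A)

print(grad_manual)
print(grad_builtin)
```

The position-dependent weights in the toy loss matter: with a symmetric loss such as a plain sum, a wrong permutation in the backward pass would go unnoticed.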


After JITting the appropriate functions, the timing of my implementation is the same as that of the built-in jax.numpy.sort. The gradients also agree (up to some transpositional/dimensional trickery). This strongly suggests that my approach matches how the reference implementation differentiates sort.


I thought this was a neat trick for coercing what is typically highly branched and seemingly imperative code (sorting being a prototypical example) into a functional programming form. Essentially, argsort is a weird kind of involution: applied to a permutation, it returns the inverse permutation, so applying it twice returns the original permutation. It preserves only the ordering of a collection of items, though, unless you explicitly keep track of the elements as well.

Code: https://github.com/mikesha2/diffsort/tree/main
