Differentiable programming is useful in machine learning research because it allows for efficient optimization of any desired parameters. Forward- and reverse-mode automatic differentiation are the two major paradigms for computing these derivatives, and differentiable programming is often framed in terms of functional programming. In this post, I will focus on reverse-mode automatic differentiation, because forward mode is trivial (dual numbers) and inefficient for functions from many variables to few variables, e.g. $f: \mathbb{R}^n \to \mathbb{R}$.
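As an aside, the "dual numbers" remark can be made concrete in a few lines. This is a toy `Dual` class of my own, not how any real framework implements forward mode:

```python
# Toy forward-mode AD via dual numbers (illustrative sketch, not JAX's implementation).
# A Dual carries a value and the derivative of that value w.r.t. the seeded input.
class Dual:
    def __init__(self, val, eps=0.0):
        self.val, self.eps = val, eps

    def __add__(self, other):
        return Dual(self.val + other.val, self.eps + other.eps)

    def __mul__(self, other):
        # Product rule: (uv)' = u'v + uv'
        return Dual(self.val * other.val,
                    self.eps * other.val + self.val * other.eps)

x = Dual(3.0, 1.0)   # seed dx/dx = 1
y = x * x + x        # y = x^2 + x
# y.eps holds dy/dx = 2x + 1 = 7 at x = 3
```

Each evaluation propagates exactly one directional derivative, which is why forward mode needs $n$ passes for a function of $n$ inputs.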
I was curious how derivatives are determined for the more obscure, less functional-looking operations; one such operation is `sort`. Assume we have an array $A$ with elements $A_i$. We perform the operation $B = \mathrm{sort}(A)$. The question is how to (efficiently) find the Jacobian
$$J_{ij}=\frac{\partial B_i}{\partial A_j}$$
Without trying to dig through the source code for JAX or another automatic differentiation framework, here are the observations I made. First, sorting an array is equivalent to identifying a particular permutation of its elements. In addition to the traditional `jax.numpy.sort` operation, there is `jax.numpy.argsort`, which precisely finds the list of indices specifying that permutation. When we perform reverse-mode automatic differentiation, we are given a computational graph in which each node represents an operation $y = f(x)$; we are given $\bar{y} = \frac{\partial L}{\partial y}$ and $\frac{\partial y}{\partial x}$, and we must use the chain rule to compute $\bar{x} = \frac{\partial L}{\partial x}$ (technically, we keep track of the adjoint, denoted by the bar above). In multiple dimensions, we must also use the multi-dimensional chain rule, so that there is an implicit sum/matrix multiplication: $\bar{x}_j = \sum_i \bar{y}_i \frac{\partial y_i}{\partial x_j}$.
In the case of $B = \mathrm{sort}(A)$, each derivative $\frac{\partial B_i}{\partial A_j}$ is either 1 or 0: it is 1 when $A_j$ is the $i$-th component of the sorted array and 0 otherwise.
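This permutation-matrix structure can be checked directly with `jax.jacobian`, e.g. on a small hand-picked array:

```python
import jax
import jax.numpy as jnp

A = jnp.array([3.0, 1.0, 2.0])
J = jax.jacobian(jnp.sort)(A)
# Each row i has a single 1, marking which A_j became the i-th sorted element.
# Here sort(A) = [1., 2., 3.], so J is the permutation matrix
# [[0., 1., 0.],
#  [0., 0., 1.],
#  [1., 0., 0.]]
```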
Since the derivative is so simple, we just need to propagate the gradient from the sorted elements back to the corresponding elements of $A$, taking care of the chain rule. In other words, if $A_j$ is sorted to position $B_i$, then the gradient at $B_i$ is copied back to $A_j$. We thus need to invert the permutation provided by `jax.numpy.argsort`. How do we invert a sort? The trick is to `argsort` the `argsort` (https://stackoverflow.com/questions/9185768/inverting-permutations-in-python)!
Assume we have already argsorted $A$, which returns indices $P$. If we `argsort` $P$, note that `sort`$(P)$ is simply the original order of the indices of $A$, $[0, 1, \ldots, n-1]$; therefore `argsort`$(P)$ must be the inverse permutation to the permutation represented by $P$.
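The argument above can be verified in a couple of lines of NumPy (the array values here are my own example):

```python
import numpy as np

A = np.array([3.0, 1.0, 2.0])
P = np.argsort(A)        # P sorts A: A[P] == [1., 2., 3.]
P_inv = np.argsort(P)    # argsort of the argsort inverts the permutation

# Composing P with P_inv gives back the identity ordering [0, 1, ..., n-1].
assert np.array_equal(P[P_inv], np.arange(len(A)))
# Scattering the sorted values back through P_inv recovers the original array.
assert np.array_equal(np.sort(A)[P_inv], A)
```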
I implemented the forward and backward passes of `jax.numpy.sort` manually.
After JITting the appropriate functions, the timing is precisely the same. The gradients are also correct (up to some transpositional/dimensional trickery). This strongly suggests that we have recovered the reference implementation for differentiating `sort`.
I thought this was a neat trick: coercing what is typically highly branched, seemingly imperative code (sorting being a prototypical example) into a functional programming form. Essentially, `argsort` acts as a weird kind of self-inverse operation on permutations, though it preserves only the ordering of a collection of items unless you explicitly keep track of the elements as well.