Automatic Differentiation

Nothing up my sleeve. No imports. No language extensions.

Don’t the following lines look like examples from a strange Haskell tutorial?

dup a = (a, a)
add = uncurry (+)
scale = (*)
mul = uncurry scale
sqr = mul . dup
cross f g (a, b) = (f a, g b)
f /-\ g = cross f g . dup
f \-/ g = add . cross f g
lin f a = (f a, f)
dId = lin id
dDup = lin dup
[dFst, dSnd, dAdd] = map lin [fst, snd, add]
dScale = lin . scale
dConst n _ = (n, const 0)

By the way, the cross function is a special case of (***) from Control.Arrow, and similarly (/-\) is a special case of (&&&). We define them because we’re avoiding imports.
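
As a quick sanity check of that remark, the following sketch compares our hand-rolled combinators with their Control.Arrow counterparts on a couple of sample inputs. The names crossAgrees and forkAgrees are made up for illustration, the type signatures are our additions, and the import is needed only for this comparison, not for anything else on this page.

```haskell
import Control.Arrow ((***), (&&&))

-- Definitions repeated from the listing above so the sketch stands alone.
dup :: a -> (a, a)
dup a = (a, a)

cross :: (a -> c) -> (b -> d) -> (a, b) -> (c, d)
cross f g (a, b) = (f a, g b)

(/-\) :: (a -> b) -> (a -> c) -> a -> (b, c)
f /-\ g = cross f g . dup

-- Hypothetical checks: each pair should agree on these sample inputs.
crossAgrees :: Bool
crossAgrees = cross (+ 1) (* 2) p == ((+ 1) *** (* 2)) p
  where p = (3, 4) :: (Int, Int)

forkAgrees :: Bool
forkAgrees = ((+ 1) /-\ (* 2)) n == ((+ 1) &&& (* 2)) n
  where n = 5 :: Int
```

Both crossAgrees and forkAgrees should evaluate to True.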

The next few lines look a little less innocent:

dCross f g ab = ((c, d), cross f' g') where
  ((c, f'), (d, g')) = cross f g ab
infixr 9 <.
(g <. f) a = (c, g0 . f0) where
  (b, f0) = f a
  (c, g0) = g b
dMul (a, b) = (a * b, scale b \-/ scale a)
f >-< f' = f /-\ (scale . f')
dExp = exp >-< exp
dLog = log >-< recip
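
In conventional notation, writing \(Df(a)\) for the linear map that gets paired with \(f(a)\) (this notation is our assumption and is not used elsewhere on this page), the definitions of (<.), dMul, and (>-<) encode the chain rule, the product rule, and ordinary scalar derivatives respectively:

```latex
\begin{align*}
D(g \circ f)(a) &= Dg(f(a)) \circ Df(a) \\
D(\mathrm{mul})(a, b) &: (\delta a, \delta b) \mapsto b \, \delta a + a \, \delta b \\
D(f)(a) &: \epsilon \mapsto f'(a) \, \epsilon
\end{align*}
```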

It may seem we are merely doodling, but in fact, we can now compute the derivative of any elementary function via automatic differentiation (AD). First, write a given function in point-free form. For example, instead of \(f(x) = x^2\) we write sqr = mul . dup. We can convert small instances by hand, and larger instances with a bracket abstraction algorithm.

Then to derive, replace (.) with (<.) and prepend 'd' to every identifier after capitalizing its first letter; for an operator, prepend '<' instead. For example, the derivative of (/-\) is:

f </-\ g = dCross f g <. dDup

and the derivative of sqr = mul . dup is:

dSqr = dMul <. dDup

The function dSqr takes a number \(a\) and returns a pair where:

  • the first element is \(f(a) = a^2\), and:

  • the second element is the linear map that best approximates \(f(a + \epsilon) - f(a)\) for small \(\epsilon\), also known as the first derivative.

We have \(f(5) = 25\), and since the derivative of \(x^2\) is \(2x\), we have \(f'(5) = \epsilon \mapsto 10 \epsilon\). Customarily we write \(f'(5) = 10\), but we wish to emphasize that the derivative is really a linear map. At any rate, we can easily convert to standard notation by setting \(\epsilon = 1\).

dSqr_example = (f, f' 1) == (25, 10) where (f, f') = dSqr 5
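
Here is a hedged, self-contained sketch that repeats just enough of the definitions above to run dSqr on its own, then compares the result against a central finite difference. The helper centralDiff and the type signatures are our additions, not part of the original listing.

```haskell
-- Definitions repeated from above, with type signatures added so that the
-- sketch compiles without NoMonomorphismRestriction.
dup :: a -> (a, a)
dup a = (a, a)

add :: Num a => (a, a) -> a
add = uncurry (+)

scale :: Num a => a -> a -> a
scale = (*)

cross :: (a -> c) -> (b -> d) -> (a, b) -> (c, d)
cross f g (a, b) = (f a, g b)

(\-/) :: Num c => (a -> c) -> (b -> c) -> (a, b) -> c
f \-/ g = add . cross f g

dDup :: a -> ((a, a), a -> (a, a))
dDup a = (dup a, dup)

dMul :: Num a => (a, a) -> (a, (a, a) -> a)
dMul (a, b) = (a * b, scale b \-/ scale a)

infixr 9 <.
(<.) :: (b -> (c, y -> z)) -> (a -> (b, x -> y)) -> a -> (c, x -> z)
(g <. f) a = (c, g0 . f0) where
  (b, f0) = f a
  (c, g0) = g b

dSqr :: Num a => a -> (a, a -> a)
dSqr = dMul <. dDup

-- Hypothetical helper: a central finite difference with a fixed step,
-- used only as an independent sanity check.
centralDiff :: (Double -> Double) -> Double -> Double
centralDiff f a = (f (a + h) - f (a - h)) / (2 * h) where h = 1e-4
```

In GHCi, snd (dSqr 5) 1 and centralDiff (\x -> x * x) 5 should both come out at (very nearly) 10.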

Calculus Homework

We can produce symbolic derivatives by using symbolic numbers if we’re willing to prepend:

{-# LANGUAGE NoMonomorphismRestriction #-}
import Data.Number.Symbolic

(I’ve omitted type annotations in pursuit of minimalism, so to avoid specialization to Double we disable the dreaded monomorphism restriction.) With these lines in place, in GHCi we simply evaluate the derivative at the symbolic number x, then evaluate its linear map at 1:

dxex = dMul <. (dId </-\ dExp)
snd (dxex $ var "x") 1

We find the derivative of \(x e^x\) is:

exp x+x*exp x

Some elementary functions require Data.Complex and identities such as \(\arccos(z) = -i \ln (\sqrt{z^2 - 1} + z)\). It’s friendlier to define shortcuts:

dSin   = sin >-< cos
dCos   = cos >-< (negate . sin)
dAsin  = asin >-< (\x -> recip (sqrt (1 - sqr x)))
dAcos  = acos >-< (\x -> - recip (sqrt (1 - sqr x)))
dAtan  = atan >-< (\x -> recip (sqr x + 1))
dSinh  = sinh >-< cosh
dCosh  = cosh >-< sinh
dAsinh = asinh >-< (\x -> recip (sqrt (sqr x + 1)))
dAcosh = acosh >-< (\x -> recip (sqrt (sqr x - 1)))
dAtanh = atanh >-< (\x -> recip (1 - sqr x))
dPow (a, b) = (a ** b, scale (b * (a**(b - 1))) \-/ scale (log a * (a**b)))

Let’s do more examples. Suppose we wish to derive:

  • \(x \mapsto x^5\)

  • \((x, y) \mapsto x^2 + y^2\)

  • \((x, y) \mapsto (\cos(xy), \sin(xy))\)

We convert to point-free form:

dFifthPower = dPow <. (dId </-\ dConst 5)
dMagSqr = dAdd <. dCross (dMul <. dDup) (dMul <. dDup)
dCosSinProd = (dCos </-\ dSin) <. dMul

Then find their symbolic derivatives with:

snd (dFifthPower (var "x")) 1
snd (dMagSqr (var "x", var "y")) (1, 0)
snd (dMagSqr (var "x", var "y")) (0, 1)
snd (dCosSinProd (var "x", var "y")) (1, 0)
snd (dCosSinProd (var "x", var "y")) (0, 1)


x+x  -- With respect to x.
y+y  -- With respect to y.
((-sin (x*y))*y,cos (x*y)*y)  -- With respect to x.
((-sin (x*y))*x,cos (x*y)*x)  -- With respect to y.

One more. Let’s confirm an example from Wikipedia:

wiki = (dMul <. dCross dSqr dId) /-\ (dAdd <. dCross (dScale 5) dSin)

We type:

snd (fst (wiki (var "x", var "y"))) (1, 0)
snd (fst (wiki (var "x", var "y"))) (0, 1)
snd (snd (wiki (var "x", var "y"))) (1, 0)
snd (snd (wiki (var "x", var "y"))) (0, 1)

and find the entries of the Jacobian matrix are indeed:

cos y
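
For comparison, reading the function off the definition of wiki as \((x, y) \mapsto (x^2 y, 5x + \sin y)\), a hand computation gives the Jacobian below. The four queries above correspond to its entries in row-major order. How the symbolic library prints each entry may differ from this form.

```latex
J(x, y) =
\begin{pmatrix}
2xy & x^2 \\
5 & \cos y
\end{pmatrix}
```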

Reducing (finger) typing with (static) typing

Twiddling each identifier is irksome. A more ergonomic solution that suits Haskell is to exploit typeclasses so something like:

mul . dup :: Function

represents the function \(x \mapsto x^2\), while:

mul . dup :: Derivative

represents its derivative, namely the function computing \(x \mapsto (\epsilon \mapsto 2 x \epsilon)\).

To achieve this, we’d overload mul so it’s our original mul as a Function but dMul as a Derivative. Similarly for the other building blocks such as (.) and id.

These last two deserve special mention. Principled overloading of (.) and id is known as category theory. This is a great boon, for we can stand on the shoulders of mathematical giants to view differentiation in a whole new light.

Instead of typeclasses, we persevere with renaming identifiers by hand. This gives us a cut-away view of category theory magic and also makes us appreciate typeclasses; absence makes the heart grow fonder.
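
For contrast, here is a minimal sketch of what the typeclass route might look like. Every name in it (Function, Derivative, Cat, Blocks, idC, (.:), dupC, mulC, sqrC) is hypothetical, and it covers only the two building blocks needed for squaring. It needs no language extensions; a full treatment would likely want more machinery.

```haskell
-- All names here are hypothetical; this is a sketch, not the article's code.
newtype Function a b = Function { runFunction :: a -> b }
newtype Derivative a b = Derivative { runDerivative :: a -> (b, a -> b) }

-- Overloaded identity and composition.
class Cat k where
  idC  :: k a a
  (.:) :: k b c -> k a b -> k a c
infixr 9 .:

-- Overloaded building blocks.
class Cat k => Blocks k where
  dupC :: k a (a, a)
  mulC :: Num a => k (a, a) a

instance Cat Function where
  idC = Function id
  Function g .: Function f = Function (g . f)

instance Blocks Function where
  dupC = Function (\a -> (a, a))
  mulC = Function (uncurry (*))

instance Cat Derivative where
  idC = Derivative (\a -> (a, id))
  -- Same chain rule as (<.).
  Derivative g .: Derivative f = Derivative $ \a ->
    let (b, f') = f a
        (c, g') = g b
    in (c, g' . f')

instance Blocks Derivative where
  dupC = Derivative (\a -> ((a, a), \da -> (da, da)))
  mulC = Derivative (\(a, b) -> (a * b, \(da, db) -> b * da + a * db))

-- One definition, two readings.
sqrC :: (Blocks k, Num a) => k a a
sqrC = mulC .: dupC
```

With this in scope, runFunction sqrC 5 evaluates to 25 and snd (runDerivative sqrC 5) 1 evaluates to 10, matching dSqr_example.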


For some applications, computing derivatives is most efficient using a method called reverse-mode automatic differentiation (RAD), which means we write our function so that all sequences of compositions are left-associative. For example:

dSinCosSqr = (((dSin <.) dCos <.) dMul <.) dDup

This is fine for small examples, but grows tiresome for larger functions. Furthermore, we only care about left-folding compositions for the linear maps (derivatives), not for the function itself (the "primal").

Can we force left-associativity for the linear maps without explicit parentheses everywhere? Yes! Continuation-passing style (CPS) comes to our rescue.

Once again, with the right set of GHC extensions:

mul . dup :: RAD

represents the derivative of \(x \mapsto x^2\) computed via RAD. And once again, we’re forgoing typeclasses, so we’re forced to employ tedious text substitution instead. Define:

rad d a = let (f, f') = d a in (f, (. f'))
infixr 9 <<.
g <<. f = \a -> let
  (b, f0) = f a
  (c, g0) = g b
  in (c, f0 . g0)

To force our derivatives to be computed in reverse-mode, we apply rad to each building block and replace (<.) with (<<.). This leads to a problem with dCross, which does not expect arguments in CPS form; we fix this below.

The linear map is now represented as a continuation, so should be applied to id before use. For example, the derivative of \(f(x) = \sin (\cos x^2)\) via RAD is:

radSqr = rad dMul <<. rad dDup
radSinCosSqr = rad dSin <<. rad dCos <<. radSqr

and we can compute \(f'(2)\) with:

radSinCosSqr'2 = snd (radSinCosSqr 2) id 1
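
To check this, the sketch below repeats the relevant definitions so it runs on its own, then compares the RAD result with the derivative worked out by hand, \(\cos(\cos x^2) \cdot (-\sin x^2) \cdot 2x\). The type signatures and the names viaRad and byHand are our additions, and dMul and (>-<) are written without scale and (\-/) purely to keep the block short.

```haskell
-- Building blocks repeated from the listings above, with type signatures
-- added so the sketch compiles on its own.
dup :: a -> (a, a)
dup a = (a, a)

cross :: (a -> c) -> (b -> d) -> (a, b) -> (c, d)
cross f g (a, b) = (f a, g b)

(/-\) :: (a -> b) -> (a -> c) -> a -> (b, c)
f /-\ g = cross f g . dup

(>-<) :: Num a => (a -> a) -> (a -> a) -> a -> (a, a -> a)
f >-< f' = f /-\ ((*) . f')

dDup :: a -> ((a, a), a -> (a, a))
dDup a = (dup a, dup)

dMul :: Num a => (a, a) -> (a, (a, a) -> a)
dMul (a, b) = (a * b, \(da, db) -> b * da + a * db)

dSin, dCos :: Floating a => a -> (a, a -> a)
dSin = sin >-< cos
dCos = cos >-< (negate . sin)

-- Turn a linear map into a continuation transformer.
rad :: (a -> (b, s -> t)) -> a -> (b, (t -> r) -> s -> r)
rad d a = let (f, f') = d a in (f, (. f'))

infixr 9 <<.
(<<.) :: (b -> (c, y -> z)) -> (a -> (b, z -> w)) -> a -> (c, y -> w)
(g <<. f) a = (c, f0 . g0) where
  (b, f0) = f a
  (c, g0) = g b

radSinCosSqr :: Floating a => a -> (a, (a -> r) -> a -> r)
radSinCosSqr = rad dSin <<. rad dCos <<. rad dMul <<. rad dDup

-- Hypothetical check values, not from the original article.
viaRad, byHand :: Double
viaRad = snd (radSinCosSqr 2) id 1
byHand = cos (cos 4) * negate (sin 4) * 4
```

The two numbers should agree up to floating-point rounding.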

Time to fix dCross. We need more machinery for its RAD variant:

inl a = (a, 0)
inr b = (0, b)
jam = uncurry (+)

join (f, g) = jam . cross f g
unjoin h = (h . inl, h . inr)

radCross f g ab = let
  ((c, f'), (d, g')) = cross f g ab
  in ((c, d), join . cross f' g' . unjoin)

For example:

radCosSinProd = (radCross (rad dCos) (rad dSin) <<. rad dDup) <<. rad dMul
radMagSqr = rad dAdd <<. radCross (rad dMul <<. rad dDup) (rad dMul <<. rad dDup)


In practice, the co-domain of each continuation is often the scalar field, that is, we wish to differentiate a function that ultimately spits out a single number.

For example, in neural networks, we define a cost function that has many parameters (the weights and biases) and outputs a single number. We differentiate to find the gradient, which tells us how best to tweak the parameters to reduce the cost. Namely, we perform gradient-based optimization. When computed via RAD, this is known as backpropagation.

It turns out we can simplify this special case because the linear maps from a vector space to a scalar field are isomorphic to that vector space; they are dual. We call this isomorphism and its inverse dot and undot for the 1-dimensional case. We also define 2-dimensional variants. With clever typeclasses, we could avoid inventing similar names and carefully matching the isomorphisms with various function types; the compiler would do all this for us.

Sandwiching a function of type ((b -> s) -> (a -> s)) between undot and dot yields a function of type b -> a.

dot = scale
undot = ($ 1)
dot2 (u, v) = dot u \-/ dot v
undot2 f = (f (1, 0), f (0, 1))

dua d a = let (f, f') = d a in (f, undot . (. f') . dot)
duaBin d a = let (f, f') = d a in (f, undot2 . (. f') . dot)

duaDup a = (dup a, jam)
duaJam a = (jam a, dup)
duaScale s a = (scale s a, scale s)
duaCross = dCross

Composition is the same as for the CPS form, so we can reuse (<<.) in our examples.

duaSqr = duaBin dMul <<. duaDup
duaMagSqr = duaBin dAdd <<. dCross (duaBin dMul <<. duaDup) (duaBin dMul <<. duaDup)

f <</-\ g = duaCross f g <<. duaDup
duaxex = duaBin dMul <<. (dua dId <</-\ dua dExp)

Ideally, compositions of inverses such as dot . undot would be optimized away. GHC can do this with rewrite rules.
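
To see the dual representation in action, here is a self-contained sketch that computes the gradient of \((x, y) \mapsto x^2 + y^2\) at \((3, 4)\). It repeats only the definitions it needs; the type signatures and the name magSqrAt34 are our additions, and dAdd and dMul are written out directly rather than via lin and scale.

```haskell
-- Definitions repeated from the listings above, with signatures added so the
-- sketch compiles on its own (and without NoMonomorphismRestriction).
dup :: a -> (a, a)
dup a = (a, a)

jam :: Num a => (a, a) -> a
jam = uncurry (+)

cross :: (a -> c) -> (b -> d) -> (a, b) -> (c, d)
cross f g (a, b) = (f a, g b)

dAdd, dMul :: Num a => (a, a) -> (a, (a, a) -> a)
dAdd ab = (jam ab, jam)
dMul (a, b) = (a * b, \(da, db) -> b * da + a * db)

dCross :: (a -> (c, p -> q)) -> (b -> (d, r -> s))
       -> (a, b) -> ((c, d), (p, r) -> (q, s))
dCross f g ab = ((c, d), cross f' g') where
  ((c, f'), (d, g')) = cross f g ab

infixr 9 <<.
(<<.) :: (b -> (c, y -> z)) -> (a -> (b, z -> w)) -> a -> (c, y -> w)
(g <<. f) a = (c, f0 . g0) where
  (b, f0) = f a
  (c, g0) = g b

dot :: Num s => s -> s -> s
dot = (*)

undot2 :: Num s => ((s, s) -> t) -> (t, t)
undot2 f = (f (1, 0), f (0, 1))

-- Dualize a block that maps a pair to a scalar.
duaBin :: Num s => (a -> (b, (s, s) -> s)) -> a -> (b, s -> (s, s))
duaBin d a = let (f, f') = d a in (f, undot2 . (. f') . dot)

duaDup :: Num s => a -> ((a, a), (s, s) -> s)
duaDup a = (dup a, jam)

duaMagSqr :: Num s => (s, s) -> (s, s -> (s, s))
duaMagSqr = duaBin dAdd <<. dCross (duaBin dMul <<. duaDup) (duaBin dMul <<. duaDup)

-- Hypothetical check: value and gradient of x^2 + y^2 at (3, 4).
magSqrAt34 :: (Double, (Double, Double))
magSqrAt34 = let (v, g) = duaMagSqr (3, 4) in (v, g 1)
```

Here magSqrAt34 evaluates to (25.0,(6.0,8.0)): the value \(3^2 + 4^2\) together with the gradient \((2 \cdot 3, 2 \cdot 4)\).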

One Neuron

In machine learning, a single neuron of a neural net is a bunch of weights and a bias that are typically represented with floats.

type Neuron = ([Double], Double)

To compute the output of a neuron, we multiply each input by its corresponding weight, sum the products, add the bias, then feed the result through some activation function f:

fire inputs neuron = f $ sum (zipWith (*) inputs weights) + bias where
  weights = fst neuron
  bias = snd neuron
  f x = recip $ 1 + exp (-x)  -- Sigmoid.

fire_example = fire [0.2,0.1,0.7] ([3,1,4], -2)

Training requires the derivative of this function with respect to each weight and bias. (We’ve simplified a little: we actually want the derivative of something like the square of the difference between the output of this function and the desired output.)

We could organize the weights into nested pairs to fit the function into our framework, but we’re better off supporting a vector of weights directly, because in practice there are many of them. For our toy demo we use lists. Since we’re shunning typeclasses, we invent yet more names for the list versions of jam and cross:

jamList = sum
dJamList = lin jamList
crossList = zipWith id
dCrossList f's = cross id crossList . unzip . crossList f's

We can express firing a neuron as follows:

sigmoid = recip . add . (cross (const 1) (exp . negate)) . dup
fire' inputs = sigmoid . add . cross (jamList . crossList scaleInputs) id
  where scaleInputs = scale <$> inputs

as well as its first derivative:

dSigmoid = sigmoid >-< (\x -> sigmoid x * (1 - sigmoid x))  -- Exercise.
dFire inputs = dSigmoid <. dAdd <. dCross (dJamList <. dCrossList dScaleInputs) dId
  where dScaleInputs = dScale <$> inputs

We can compute in reverse-mode with a few more mechanical code transformations:

replList = replicate 3
duaJamList a = (jamList a, replList)
duaCrossList = dCrossList
duaFire inputs = dua dSigmoid <<. duaBin dAdd <<. duaCross (duaJamList <<. duaCrossList duaScaleInputs) dId
  where duaScaleInputs = dScale <$> inputs

For example:

snd ( duaFire [0.2,0.1,0.7] ([3,1,4],-2) ) 1

computes the gradient at the input [0.2,0.1,0.7] for the weights [3,1,4] and bias -2.
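
As an independent cross-check, the gradient can also be worked out by hand: with \(s\) the neuron's output, the chain rule gives \(s(1 - s)\) for the bias and \(s(1 - s)\) times each input for the corresponding weight. The sketch below computes exactly that, without any of the AD machinery. The name gradFire is made up for illustration, and the result should agree with duaFire up to rounding.

```haskell
sigmoid :: Double -> Double
sigmoid x = recip (1 + exp (negate x))

-- Direct (non-AD) firing rule, equivalent to the earlier fire.
fire :: [Double] -> ([Double], Double) -> Double
fire inputs (weights, bias) = sigmoid (sum (zipWith (*) inputs weights) + bias)

-- Hypothetical helper: gradient with respect to (weights, bias), using the
-- hand-derived chain rule rather than automatic differentiation.
gradFire :: [Double] -> ([Double], Double) -> ([Double], Double)
gradFire inputs neuron = (map (* s') inputs, s') where
  s  = fire inputs neuron
  s' = s * (1 - s)
```

For instance, gradFire [0.2,0.1,0.7] ([3,1,4], -2) returns a pair shaped like the neuron itself: one partial derivative per weight, plus one for the bias.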


This technique scales up for layers of neurons, but without typeclasses, we’d have to juggle closely related building blocks, painstakingly ensuring the right type of block is in the right place.


How does this all work?

You may remember the CPS trick from computer science classics such as Reynolds' paper on definitional interpreters, but in fact the idea can be traced to an 1854 paper by Cayley.

In Haskell, sequences of (++) are fastest when evaluated in "forward-mode". Joachim Breitner explains how to transform to CPS to achieve this with an approach that resembles ours.
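
A minimal sketch of that trick, assuming nothing beyond the Prelude: represent each list as the function that prepends it, so that composing the functions in any association order still produces right-nested appends. The names DList, toD, fromD, naive, and viaD are ours; this illustrates the general idea and is not code from Breitner's article.

```haskell
-- Hypothetical names: a list represented as "prepend me to whatever follows".
type DList a = [a] -> [a]

toD :: [a] -> DList a
toD xs = (xs ++)

fromD :: DList a -> [a]
fromD d = d []

-- Left-nested appends, written naively and via the representation above.
-- Both give the same list; only the second forces (++) to nest to the right.
naive, viaD :: [[a]] -> [a]
naive = foldl (++) []
viaD  = fromD . foldl (.) id . map toD
```

For example, both naive [[1,2],[3],[4,5]] and viaD [[1,2],[3],[4,5]] evaluate to [1,2,3,4,5].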

Our choice of (<.) was inspired by the flow package.

Naming is hard

Conal Elliott laments that automatic differentiation is "typically presented in opposition to symbolic differentiation", which is clearly at odds with our examples above!

To avoid confusion, I propose renaming "symbolic differentiation" to "schoolbook differentiation" or perhaps "shambolic differentiation", because often it seems to refer to the centuries-old methods of differentiation taught in school, where expressions are small enough that sharing saves little, and has nothing to do with whether the answer is ultimately symbolic or numeric.

Less grating are the terms "forward-mode" and "reverse-mode" for what functional programmers might call right and left folds. The reverse function of Haskell98 is a left fold, a lovely coincidence stemming from the right-folded nature of Haskell lists.

Ben Lynn