dup a = (a, a)
add = uncurry (+)
scale = (*)
mul = uncurry scale
sqr = mul . dup
cross f g (a, b) = (f a, g b)
f /-\ g = cross f g . dup
f \-/ g = add . cross f g
lin f a = (f a, f)
dId = lin id
dDup = lin dup
[dFst, dSnd, dAdd] = map lin [fst, snd, add]
dScale = lin . scale
dConst n _ = (n, const 0)
Automatic Differentiation
Nothing up my sleeve. No imports. No language extensions.
Don’t the lines above look like examples from a strange Haskell tutorial?
By the way, the cross function is a special case of (***) from Control.Arrow, and similarly (/-\) is a special case of (&&&). We define them ourselves because we’re avoiding imports.
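For instance (a throwaway illustration with hypothetical names, not part of the development), cross applies two functions componentwise, while (/-\) fans a single argument out to both:
cross_example = cross (+ 1) (* 2) (3, 4) == (4, 8)
fanout_example = ((+ 1) /-\ (* 2)) 3 == (4, 6)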
The next few lines look a little less innocent:
dCross f g ab = ((c, d), cross f' g') where
((c, f'), (d, g')) = cross f g ab
infixr 9 <.
(g <. f) a = (c, g0 . f0) where
(b, f0) = f a
(c, g0) = g b
dMul (a, b) = (a * b, scale b \-/ scale a)
f >-< f' = f /-\ (scale . f')
dExp = exp >-< exp
dLog = log >-< recip
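Peeking ahead a little, each of these d-functions returns a value paired with a linear map. A quick check (dExp_check is a hypothetical name): dExp at 0 pairs \(e^0 = 1\) with the map \(\epsilon \mapsto 1 \cdot \epsilon\):
dExp_check = (y, y' 1) == (1, 1) where (y, y') = dExp 0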
It may seem we are merely doodling, but in fact, we can now compute the derivative of any elementary function via automatic differentiation (AD).
First, write a given function in point-free form. For example, instead of \(f(x) = x^2\) we write sqr = mul . dup. We can convert small instances by hand, and larger instances with a bracket abstraction algorithm.
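As another instance of this conversion (a sketch using only the definitions above; we differentiate it shortly), \(x \mapsto x e^x\) becomes:
xex = mul . (id /-\ exp)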
Then to derive, replace (.) with (<.) and prepend 'd' to every identifier after capitalizing its first letter. For example, the derivative of (/-\) is:
f </-\ g = dCross f g <. dDup
and the derivative of sqr = mul . dup is:
dSqr = dMul <. dDup
The function dSqr takes a number \(a\) and returns a pair where:
- the first element is \(f(a) = a^2\), and
- the second element is the linear map that best approximates \(f(a + \epsilon) - f(a)\) for small \(\epsilon\), also known as the first derivative.
We have \(f(5) = 25\), and since the derivative of \(x^2\) is \(2x\), we have \(f'(5) = \epsilon \mapsto 10 \epsilon\). Customarily we write \(f'(5) = 10\), but we wish to emphasize that the derivative is really a linear map. At any rate, we can easily convert to standard notation by setting \(\epsilon = 1\).
dSqr_example = (f, f' 1) == (25, 10) where (f, f') = dSqr 5
Calculus Homework
We can produce symbolic derivatives by using symbolic numbers if we’re willing to prepend:
{-# LANGUAGE NoMonomorphismRestriction #-}
import Data.Number.Symbolic
(I’ve omitted type annotations in pursuit of minimalism, thus to avoid specialization to Double we disable the dreaded monomorphism restriction.)
With that done, in GHCi we simply evaluate the derivative at the symbolic number x and apply its linear map to 1:
dxex = dMul <. (dId </-\ dExp)
snd (dxex $ var "x") 1
We find the derivative of \(x e^x\) is:
exp x+x*exp x
Some elementary functions require Data.Complex and identities such as \(\arccos(z) = -i \ln (\sqrt{z^2 - 1} + z)\). It’s friendlier to define shortcuts:
dSin = sin >-< cos
dCos = cos >-< (negate . sin)
dAsin = asin >-< (\x -> recip (sqrt (1 - sqr x)))
dAcos = acos >-< (\x -> - recip (sqrt (1 - sqr x)))
dAtan = atan >-< (\x -> recip (sqr x + 1))
dSinh = sinh >-< cosh
dCosh = cosh >-< sinh
dAsinh = asinh >-< (\x -> recip (sqrt (sqr x + 1)))
dAcosh = acosh >-< (\x -> - recip (sqrt (sqr x - 1)))
dAtanh = atanh >-< (\x -> recip (1 - sqr x))
dPow (a, b) = (a ** b, scale (b * (a**(b - 1))) \-/ scale (log a * (a**b)))
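A quick numeric sanity check of one shortcut (dSin_check is a hypothetical name): dSin at 1 pairs \(\sin 1\) with the map \(\epsilon \mapsto \cos(1)\,\epsilon\):
dSin_check = (y, y' 1) == (sin 1, cos 1) where (y, y') = dSin 1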
Let’s do more examples. Suppose we wish to derive:
- \(x \mapsto x^5\)
- \((x, y) \mapsto x^2 + y^2\)
- \((x, y) \mapsto (\cos x y, \sin x y)\)
We convert to point-free form:
dFifthPower = dPow <. (dId </-\ dConst 5)
dMagSqr = dAdd <. dCross (dMul <. dDup) (dMul <. dDup)
dCosSinProd = (dCos </-\ dSin) <. dMul
Then find their symbolic derivatives with:
snd (dFifthPower (var "x")) 1
snd (dMagSqr (var "x", var "y")) (1, 0)
snd (dMagSqr (var "x", var "y")) (0, 1)
snd (dCosSinProd (var "x", var "y")) (1, 0)
snd (dCosSinProd (var "x", var "y")) (0, 1)
Answers:
5.0*x**4.0
x+x -- With respect to x.
y+y -- With respect to y.
((-sin (x*y))*y,cos (x*y)*y) -- With respect to x.
((-sin (x*y))*x,cos (x*y)*x) -- With respect to y.
One more. Let’s confirm an example from Wikipedia:
wiki = (dMul <. dCross dSqr dId) /-\ (dAdd <. dCross (dScale 5) dSin)
We type:
snd (fst (wiki (var "x", var "y"))) (1, 0)
snd (fst (wiki (var "x", var "y"))) (0, 1)
snd (snd (wiki (var "x", var "y"))) (1, 0)
snd (snd (wiki (var "x", var "y"))) (0, 1)
and find the entries of the Jacobian matrix are indeed:
y*(x+x)
x*x
5.0
cos y
Reducing (finger) typing with (static) typing
Twiddling each identifier is irksome. A more ergonomic solution that suits Haskell is to exploit typeclasses, so that something like:
mul . dup :: Function
represents the function \(x \mapsto x^2\), while:
mul . dup :: Derivative
represents its derivative, namely the function computing \(x \mapsto (\epsilon \mapsto 2 x \epsilon)\).
To achieve this, we’d overload mul so it’s our original mul as a Function but dMul as a Derivative. Similarly for the other building blocks such as (.) and id.
These last two deserve special mention. Principled overloading of (.) and id is known as category theory. This is a great boon, for we can stand on the shoulders of mathematical giants to view differentiation in a whole new light.
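As a rough sketch of that route (not used in the rest of this post, with the class and its members given made-up names to dodge the Prelude), the overloading might look like:
class Cat k where
  idC   :: k a a
  (<.>) :: k b c -> k a b -> k a c

instance Cat (->) where  -- plain functions compose as usual
  idC   = id
  (<.>) = (.)

newtype D a b = D (a -> (b, a -> b))  -- a function paired with its derivative

instance Cat D where
  idC         = D (lin id)
  D g <.> D f = D (g <. f)  -- the chain rule, via (<.) from above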
Instead of typeclasses, we persevere with renaming identifiers by hand. This gives us a cut-away view of category theory magic and also makes us appreciate typeclasses; absence makes the heart grow fonder.
CPS is RAD
For some applications, derivatives are computed most efficiently with a method called reverse-mode automatic differentiation (RAD), which in our setting means writing our function so that all sequences of compositions are left-associative. For example:
dSinCosSqr = (((dSin <.) dCos <.) dMul <.) dDup
This is fine for small examples, but grows tiresome for larger functions. Furthermore, we only care about left-folding compositions for the linear maps (derivatives), not for the function itself (the "primal").
Can we force left-associativity for the linear maps without explicit parentheses everywhere? Yes! Continuation-passing style (CPS) comes to our rescue.
Once again, with the right set of GHC extensions:
mul . dup :: RAD
represents the derivative of \(x \mapsto x^2\) computed via RAD. And once again, we’re forgoing typeclasses, so we’re forced to employ tedious text substitution instead. Define:
rad d a = let (f, f') = d a in (f, (. f'))
infixr 9 <<.
g <<. f = \a -> let
(b, f0) = f a
(c, g0) = g b
in (c, f0 . g0)
To force our derivatives to be computed in reverse mode, we apply rad to each building block and replace (<.) with (<<.). This leads to a problem with dCross, which does not expect arguments in CPS form; we fix this below.
The linear map is now represented as a continuation, so it should be applied to id before use. For example, the derivative of \(f(x) = \sin (\cos x^2)\) via RAD is:
radSqr = rad dMul <<. rad dDup
radSinCosSqr = rad dSin <<. rad dCos <<. radSqr
and we can compute \(f'(2)\) with:
radSinCosSqr'2 = snd (radSinCosSqr 2) id 1
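As a sanity check (radCheck is a hypothetical name), this agrees with the forward-mode dSinCosSqr defined earlier:
radCheck = radSinCosSqr'2 == snd (dSinCosSqr 2) 1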
Time to fix dCross. We need more machinery for its RAD variant:
inl a = (a, 0)
inr b = (0, b)
jam = uncurry (+)
join (f, g) = jam . cross f g
unjoin h = (h . inl, h . inr)
radCross f g ab = let
((c, f'), (d, g')) = cross f g ab
in ((c, d), join . cross f' g' . unjoin)
For example:
radCosSinProd = (radCross (rad dCos) (rad dSin) <<. rad dDup) <<. rad dMul
radMagSqr = rad dAdd <<. radCross (rad dMul <<. rad dDup) (rad dMul <<. rad dDup)
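A quick check of the partial derivatives of the squared magnitude at (3, 4) (radMagSqr_check is a hypothetical name; compare with the symbolic answers x+x and y+y above):
radMagSqr_check = (snd (radMagSqr (3, 4)) id (1, 0), snd (radMagSqr (3, 4)) id (0, 1)) == (6, 8)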
Gradients
In practice, the co-domain of each continuation is often the scalar field, that is, we wish to differentiate a function that ultimately spits out a single number.
For example, in neural networks, we define a cost function that has many parameters (the weights and biases) and outputs a single number. We differentiate to find the gradient, which tells us how best to tweak the parameters to reduce the cost. Namely, we perform gradient-based optimization. When computed via RAD, this is known as backpropagation.
It turns out we can simplify this special case because the linear maps from a vector space to a scalar field are isomorphic to that vector space; they are dual. We call this isomorphism and its inverse dot and undot for the 1-dimensional case. We also define 2-dimensional variants. With clever typeclasses, we could avoid inventing similar names and carefully matching the isomorphisms with various function types; the compiler would do all this for us.
Sandwiching a function of type ((b -> s) -> (a -> s)) between undot and dot yields a function of type b -> a.
dot = scale
undot = ($ 1)
dot2 (u, v) = dot u \-/ dot v
undot2 f = (f (1, 0), f (0, 1))
dua d a = let (f, f') = d a in (f, undot . (. f') . dot)
duaBin d a = let (f, f') = d a in (f, undot2 . (. f') . dot)
duaDup a = (dup a, jam)
duaJam a = (jam a, dup)
duaScale s a = (scale s a, scale s)
duaCross = dCross
Composition is the same as for the CPS form, so we can reuse (<<.) in our examples.
duaSqr = duaBin dMul <<. duaDup
duaMagSqr = duaBin dAdd <<. dCross (duaBin dMul <<. duaDup) (duaBin dMul <<. duaDup)
f <</-\ g = duaCross f g <<. duaDup
duaxex = duaBin dMul <<. (dua dId <</-\ dua dExp)
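For instance, evaluating the gradient-style derivative of \(x e^x\) at 1 (duaxex_check is a hypothetical name) gives \(e^1 + 1 \cdot e^1 = 2e\):
duaxex_check = snd (duaxex 1) 1 == 2 * exp 1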
Ideally, compositions of inverses such as dot . undot would be optimized away. GHC can do this with rewrite rules.
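For example (a sketch that nothing above relies on; in practice we’d also have to keep dot and undot from inlining before the rules fire), the pragmas might look like:
{-# RULES "undot/dot" forall x. undot (dot x) = x #-}
-- The reverse direction is only sound for linear maps, which is our invariant:
{-# RULES "dot/undot" forall f. dot (undot f) = f #-}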
One Neuron
In machine learning, a single neuron of a neural net is a bunch of weights and a bias that are typically represented with floats.
type Neuron = ([Double], Double)
To compute the output of a neuron, we multiply each input by its corresponding weight, sum the products, add the bias, then feed the result through some activation function f:
fire inputs neuron = f $ sum (zipWith (*) inputs weights) + bias
where
weights = fst neuron
bias = snd neuron
f x = recip $ 1 + exp (-x) -- Sigmoid.
fire_example = fire [0.2,0.1,0.7] ([3,1,4], -2)
Training requires the derivative of this function with respect to each weight and bias. (We’ve simplified a little: we actually want the derivative of something like the square of the difference between the output of this function and the desired output.)
We could organize the weights into nested pairs to fit the function into our
framework, but we’re better off supporting a vector of weights directly,
because in practice there are many of them. For our toy demo we use lists.
Since we’re shunning typeclasses, we invent yet more names for the list versions of jam and cross:
jamList = sum
dJamList = lin jamList
crossList = zipWith id
dCrossList f's = cross id crossList . unzip . crossList f's
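A small check of the list version (dCrossList_check is a hypothetical name): scaling inputs 10 and 10 by 2 and 3 yields the values [20, 30] together with a linear map that scales perturbations the same way:
dCrossList_check = (ys, lm [1, 1]) == ([20, 30], [2, 3])
  where (ys, lm) = dCrossList [dScale 2, dScale 3] [10, 10]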
We can express firing a neuron as follows:
sigmoid = recip . add . (cross (const 1) (exp . negate)) . dup
fire' inputs = sigmoid . add . cross (jamList . crossList scaleInputs) id
where scaleInputs = scale <$> inputs
as well as its first derivative:
dSigmoid = sigmoid >-< (\x -> sigmoid x * (1 - sigmoid x)) -- Exercise.
dFire inputs = dSigmoid <. dAdd <. dCross (dJamList <. dCrossList dScaleInputs) dId
where dScaleInputs = dScale <$> inputs
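We can confirm (fire'_check is a hypothetical name) that the point-free version agrees with the original:
fire'_check = fire' [0.2, 0.1, 0.7] ([3, 1, 4], -2) == fire [0.2, 0.1, 0.7] ([3, 1, 4], -2)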
We can compute in reverse-mode with a few more mechanical code transformations:
replList = replicate 3
duaJamList a = (jamList a, replList)
duaCrossList = dCrossList
duaFire inputs = dua dSigmoid <<. duaBin dAdd <<. duaCross (duaJamList <<. duaCrossList duaScaleInputs) dId
where duaScaleInputs = dScale <$> inputs
For example:
snd ( duaFire [0.2,0.1,0.7] ([3,1,4],-2) ) 1
computes the gradient at the input [0.2,0.1,0.7] for the weights [3,1,4] and bias -2:
([2.9829290414066574e-2,1.4914645207033287e-2,0.104402516449233],0.14914645207033286)
This technique scales up to layers of neurons, but without typeclasses, we’d have to juggle closely related building blocks, painstakingly ensuring that the right type of block is in the right place.
References
How does this all work?
Our first lines of code are based on a blog post by Conal Elliott, who cites a paper by Jerzy Karczmarczuk. For the rest of our code, see Conal Elliott’s talk on automatic differentiation and category theory and his corresponding paper.
You may remember the CPS trick from computer science classics such as Reynolds' paper on definitional interpreters, but in fact the idea can be traced to an 1854 paper by Cayley.
In Haskell, sequences of (++) are fastest when evaluated in "forward-mode". Joachim Breitner explains how to transform to CPS to achieve this with an approach that resembles ours.
Our choice of (<.) was inspired by the flow package.
Naming is hard
Conal Elliott laments that automatic differentiation is "typically presented in opposition to symbolic differentiation", which is clearly at odds with our examples above!
To avoid confusion, I propose renaming "symbolic differentiation" to "schoolbook differentiation" or perhaps "shambolic differentiation", because the term often seems to refer to the centuries-old methods of differentiation taught in school, where expressions are small enough that sharing saves little; it has nothing to do with whether the answer is ultimately symbolic or numeric.
Less grating are the terms "forward-mode" and "reverse-mode" for what functional programmers might call right and left folds. The reverse function of Haskell98 is a left fold, a lovely coincidence stemming from the right-folded nature of Haskell lists.