§ Comparison of forward and reverse mode AD

Quite a lot of ink has been spilt on this topic. My favourite reference is the one by Rufflewind. However, none of these references has a good stock of examples for the difference. So here, I catalogue the explicit computations for forward mode AD and reverse mode AD side by side.

In general, in forward mode AD, we fix how much the inputs wiggle with respect to a parameter $t$. We figure out how much the output wiggles with respect to $t$. If $output = f(input_1, input_2, \dots, input_n)$, then $\frac{\partial\, output}{\partial t} = \sum_i \frac{\partial f}{\partial input_i} \frac{\partial input_i}{\partial t}$.

In reverse mode AD, we fix how much the parameter $t$ wiggles with respect to the output. We figure out how much the parameter $t$ wiggles with respect to the inputs. If $output_i = f_i(input, \dots)$, then $\frac{\partial t}{\partial input} = \sum_i \frac{\partial t}{\partial output_i} \frac{\partial f_i}{\partial input}$. This is a much messier expression, since we need to accumulate the data over all outputs that use $input$.

Essentially, deriving output from input is easy, since how to compute an output from an input is documented in one place. Deriving input from output is annoying, since many outputs can depend on a single input. The upshot is that if we have few "root outputs" (like a loss function), we only need to run reverse mode AD once per root output, and we get the wiggles of all inputs with respect to that output at the same time, since we propagate the wiggles from output to input.

The first example, $z = \max(x, y)$, captures the essential difference between the two approaches succinctly. Study it, and everything else will make sense.
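To make the two sums concrete, here is a minimal hand-written sketch in Python. The graph $u = x y$, $v = x + y$, $output = u + v$ and the function names `forward` and `reverse` are hypothetical, chosen only for illustration. Forward mode needs one run per input to recover both partial derivatives, while reverse mode gets both from a single run seeded with $\frac{\partial t}{\partial output} = 1$. Note how the adjoint of $x$ is a sum over the two nodes $u, v$ that use it.

```python
# Hypothetical example graph (not from the text):
#   u = x * y,  v = x + y,  out = u + v

def forward(x, y, dx, dy):
    """Forward mode: push tangents dx = dx/dt, dy = dy/dt from the inputs to out.

    Each node's tangent is a sum over that node's *inputs*.
    """
    u, du = x * y, y * dx + x * dy
    v, dv = x + y, dx + dy
    out, dout = u + v, du + dv
    return out, dout


def reverse(x, y, dout=1.0):
    """Reverse mode: pull the adjoint dout = dt/dout back to the inputs.

    Each input's adjoint is a sum over the nodes that *use* that input.
    """
    # out = u + v, so both u and v receive the adjoint of out unchanged
    du = dout
    dv = dout
    # x is used by both u (local factor y) and v (local factor 1): accumulate both
    dx = du * y + dv * 1.0
    # y is used by both u (local factor x) and v (local factor 1)
    dy = du * x + dv * 1.0
    return dx, dy


# Forward mode: two runs, seeded with t = x and then t = y
_, dout_dx = forward(2.0, 3.0, 1.0, 0.0)
_, dout_dy = forward(2.0, 3.0, 0.0, 1.0)
# Reverse mode: one run, seeded with t = out
grad = reverse(2.0, 3.0)
print(dout_dx, dout_dy, grad)  # 4.0 3.0 (4.0, 3.0)
```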

§ Maximum: z = max(x, y)

  • Forward mode equations:
$$
\begin{aligned}
z &= \max(x, y) \\
\frac{\partial x}{\partial t} &= ? \\
\frac{\partial y}{\partial t} &= ? \\
\frac{\partial z}{\partial t} &= \begin{cases} \frac{\partial x}{\partial t} & \text{if } x > y \\ \frac{\partial y}{\partial t} & \text{otherwise} \end{cases}
\end{aligned}
$$
We can compute $\frac{\partial z}{\partial x}$ by setting $t = x$. That is, $\frac{\partial x}{\partial t} = 1, \frac{\partial y}{\partial t} = 0$. Similarly, we can compute $\frac{\partial z}{\partial y}$ by setting $t = y$. That is, $\frac{\partial x}{\partial t} = 0, \frac{\partial y}{\partial t} = 1$. If we want both gradients $\frac{\partial z}{\partial x}, \frac{\partial z}{\partial y}$, we have to run the above equations twice, once with each initialization. In our equations, we are saying that we know how sensitive the inputs $x, y$ are to a given parameter $t$. We are deriving how sensitive the output $z$ is to the parameter $t$ as a composition of $x, y$. If $x > y$, then we know that $z$ is as sensitive to $t$ as $x$ is.
  • Reverse mode equations:
$$
\begin{aligned}
z &= \max(x, y) \\
\frac{\partial t}{\partial z} &= ? \\
\frac{\partial t}{\partial x} &= \begin{cases} \frac{\partial t}{\partial z} & \text{if } x > y \\ 0 & \text{otherwise} \end{cases} \\
\frac{\partial t}{\partial y} &= \begin{cases} \frac{\partial t}{\partial z} & \text{if } x \le y \\ 0 & \text{otherwise} \end{cases}
\end{aligned}
$$
We can compute $\frac{\partial z}{\partial x}, \frac{\partial z}{\partial y}$ in one shot by setting $t = z$. That is, $\frac{\partial t}{\partial z} = 1$. In our equations, we are saying that we know how sensitive the parameter $t$ is to a given output $z$. We are trying to see how sensitive $t$ is to the inputs $x, y$. If $x$ is active (i.e., $x > y$), then $t$ is indeed sensitive to $x$ and $\frac{\partial t}{\partial x} = \frac{\partial t}{\partial z} = 1$. Otherwise, it is not sensitive, and $\frac{\partial t}{\partial x} = 0$.
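As a minimal sketch, the two sets of max equations can be transcribed into plain Python functions. The names `max_forward` and `max_reverse` are hypothetical. Ties $x = y$ are sent to $y$, following the "otherwise" branch of the forward-mode equation. Recovering both partial derivatives takes two forward calls but only one reverse call.

```python
def max_forward(x, y, dx, dy):
    """Forward mode: given dx = dx/dt and dy = dy/dt, return (z, dz/dt)."""
    if x > y:
        return x, dx
    return y, dy  # "otherwise" branch, which includes the tie x == y


def max_reverse(x, y, dz):
    """Reverse mode: given dz = dt/dz, return (dt/dx, dt/dy)."""
    if x > y:
        return dz, 0.0  # only the active input x receives the adjoint
    return 0.0, dz      # same tie-breaking as max_forward


# Forward mode needs two runs: t = x, then t = y
_, dz_dx = max_forward(5.0, 2.0, 1.0, 0.0)
_, dz_dy = max_forward(5.0, 2.0, 0.0, 1.0)
# Reverse mode needs one run with t = z
print((dz_dx, dz_dy), max_reverse(5.0, 2.0, 1.0))  # (1.0, 0.0) (1.0, 0.0)
```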

§ Sine: z = sin(x)

  • Forward mode equations:
$$
\begin{aligned}
z &= \sin(x) \\
\frac{\partial x}{\partial t} &= ? \\
\frac{\partial z}{\partial t} &= \frac{\partial z}{\partial x} \frac{\partial x}{\partial t} \\
&= \cos(x) \frac{\partial x}{\partial t}
\end{aligned}
$$
We can compute $\frac{\partial z}{\partial x}$ by setting $t = x$. That is, setting $\frac{\partial x}{\partial t} = 1$.
  • Reverse mode equations:
$$
\begin{aligned}
z &= \sin(x) \\
\frac{\partial t}{\partial z} &= ? \\
\frac{\partial t}{\partial x} &= \frac{\partial t}{\partial z} \frac{\partial z}{\partial x} \\
&= \frac{\partial t}{\partial z} \cos(x)
\end{aligned}
$$
We can compute $\frac{\partial z}{\partial x}$ by setting $t = z$. That is, setting $\frac{\partial t}{\partial z} = 1$.
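For a single-input function, both modes collapse to one multiplication by the local derivative $\cos(x)$, so one run suffices either way. The sketch below uses hypothetical function names and adds a central finite difference purely as an independent sanity check.

```python
import math


def sin_forward(x, dx):
    """Forward mode: return (sin x, cos x * dx/dt)."""
    return math.sin(x), math.cos(x) * dx


def sin_reverse(x, dz):
    """Reverse mode: return dt/dx = dt/dz * cos x."""
    return dz * math.cos(x)


x = 0.7
_, dz_dx_fwd = sin_forward(x, 1.0)  # seed t = x
dz_dx_rev = sin_reverse(x, 1.0)     # seed t = z
# Central finite difference as an independent sanity check
h = 1e-6
dz_dx_fd = (math.sin(x + h) - math.sin(x - h)) / (2 * h)
print(dz_dx_fwd, dz_dx_rev, dz_dx_fd)
```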

§ Addition: z = x + y

  • Forward mode equations:
$$
\begin{aligned}
z &= x + y \\
\frac{\partial x}{\partial t} &= ? \\
\frac{\partial y}{\partial t} &= ? \\
\frac{\partial z}{\partial t} &= \frac{\partial z}{\partial x} \frac{\partial x}{\partial t} + \frac{\partial z}{\partial y} \frac{\partial y}{\partial t} \\
&= 1 \cdot \frac{\partial x}{\partial t} + 1 \cdot \frac{\partial y}{\partial t} = \frac{\partial x}{\partial t} + \frac{\partial y}{\partial t}
\end{aligned}
$$
  • Reverse mode equations:
$$
\begin{aligned}
z &= x + y \\
\frac{\partial t}{\partial z} &= ? \\
\frac{\partial t}{\partial x} &= \frac{\partial t}{\partial z} \frac{\partial z}{\partial x} \\
&= \frac{\partial t}{\partial z} \cdot 1 = \frac{\partial t}{\partial z} \\
\frac{\partial t}{\partial y} &= \frac{\partial t}{\partial z} \frac{\partial z}{\partial y} \\
&= \frac{\partial t}{\partial z} \cdot 1 = \frac{\partial t}{\partial z}
\end{aligned}
$$
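In code, the addition rules are one-liners. In this minimal sketch (hypothetical names), forward mode sums the incoming tangents, while reverse mode copies the incoming adjoint to both inputs. Notice that the reverse rule needs neither $x$ nor $y$ itself.

```python
def add_forward(x, y, dx, dy):
    """Forward mode: the tangents of the inputs simply add."""
    return x + y, dx + dy


def add_reverse(dz):
    """Reverse mode: dz/dx = dz/dy = 1, so both inputs get dt/dz unchanged."""
    return dz, dz


# Seeding t = x, then t = y, versus one reverse run seeded with t = z
_, dz_dx = add_forward(2.0, 3.0, 1.0, 0.0)
_, dz_dy = add_forward(2.0, 3.0, 0.0, 1.0)
print((dz_dx, dz_dy), add_reverse(1.0))  # (1.0, 1.0) (1.0, 1.0)
```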

§ Multiplication: z = xy

  • Forward mode equations:
$$
\begin{aligned}
z &= x y \\
\frac{\partial x}{\partial t} &= ? \\
\frac{\partial y}{\partial t} &= ? \\
\frac{\partial z}{\partial t} &= \frac{\partial z}{\partial x} \frac{\partial x}{\partial t} + \frac{\partial z}{\partial y} \frac{\partial y}{\partial t} \\
&= y \frac{\partial x}{\partial t} + x \frac{\partial y}{\partial t}
\end{aligned}
$$
  • Reverse mode equations:
$$
\begin{aligned}
z &= x y \\
\frac{\partial t}{\partial z} &= ? \\
\frac{\partial t}{\partial x} &= \frac{\partial t}{\partial z} \frac{\partial z}{\partial x} = \frac{\partial t}{\partial z} \cdot y \\
\frac{\partial t}{\partial y} &= \frac{\partial t}{\partial z} \frac{\partial z}{\partial y} = \frac{\partial t}{\partial z} \cdot x
\end{aligned}
$$
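A minimal sketch of the multiplication rules, with hypothetical function names. Forward mode is the product rule applied to tangents; reverse mode scales the incoming adjoint by the *other* factor.

```python
def mul_forward(x, y, dx, dy):
    """Forward mode: product rule, dz/dt = y * dx/dt + x * dy/dt."""
    return x * y, y * dx + x * dy


def mul_reverse(x, y, dz):
    """Reverse mode: dt/dx = dt/dz * y and dt/dy = dt/dz * x."""
    return dz * y, dz * x


_, dz_dx = mul_forward(2.0, 3.0, 1.0, 0.0)  # seed t = x
_, dz_dy = mul_forward(2.0, 3.0, 0.0, 1.0)  # seed t = y
print((dz_dx, dz_dy), mul_reverse(2.0, 3.0, 1.0))  # (3.0, 2.0) (3.0, 2.0)
```

Unlike addition, `mul_reverse` needs the values of $x$ and $y$, so a reverse-mode implementation has to keep them around from the forward computation until the backward sweep reaches this node.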

§ Subtraction: z = x - y

  • Forward mode equations:
$$
\begin{aligned}
z &= x - y \\
\frac{\partial x}{\partial t} &= ? \\
\frac{\partial y}{\partial t} &= ? \\
\frac{\partial z}{\partial t} &= \frac{\partial z}{\partial x} \frac{\partial x}{\partial t} + \frac{\partial z}{\partial y} \frac{\partial y}{\partial t} \\
&= 1 \cdot \frac{\partial x}{\partial t} + (-1) \cdot \frac{\partial y}{\partial t} = \frac{\partial x}{\partial t} - \frac{\partial y}{\partial t}
\end{aligned}
$$
  • Reverse mode equations:
$$
\begin{aligned}
z &= x - y \\
\frac{\partial t}{\partial z} &= ? \\
\frac{\partial t}{\partial x} &= \frac{\partial t}{\partial z} \frac{\partial z}{\partial x} \\
&= \frac{\partial t}{\partial z} \cdot 1 = \frac{\partial t}{\partial z} \\
\frac{\partial t}{\partial y} &= \frac{\partial t}{\partial z} \frac{\partial z}{\partial y} \\
&= \frac{\partial t}{\partial z} \cdot (-1) = -\frac{\partial t}{\partial z}
\end{aligned}
$$
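Subtraction differs from addition only by the sign on the $y$ path. A minimal sketch with hypothetical function names:

```python
def sub_forward(x, y, dx, dy):
    """Forward mode: dz/dt = 1 * dx/dt + (-1) * dy/dt."""
    return x - y, dx - dy


def sub_reverse(dz):
    """Reverse mode: dt/dx = dt/dz and dt/dy = -dt/dz."""
    return dz, -dz


_, dz_dx = sub_forward(2.0, 3.0, 1.0, 0.0)  # seed t = x
_, dz_dy = sub_forward(2.0, 3.0, 0.0, 1.0)  # seed t = y
print((dz_dx, dz_dy), sub_reverse(1.0))  # (1.0, -1.0) (1.0, -1.0)
```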