
Computational graphs for backprop

Disclaimer: This post may appear trivial if you’re already a deep learning expert. Nonetheless, I found this technique to be novel at the time, but couldn’t find a concise, complete exposition, motivated by an example, anywhere else. Hence, I decided to write up what I would’ve liked to have read when I first encountered this material.

Earlier this year, I had to go through the motions of deriving backprop update rules for recurrent neural networks. In the past, I’d derived update rules for networks with simpler architectures but back then I’d preferred to stick to analytically deriving the update rules and then chaining them together. Out of laziness, I didn’t rely on a systematic procedure to compute derivatives and name intermediate variables, which led to an error-prone and unwieldy process when dealing with more complex architectures. This approach was especially troublesome when applied to recurrent networks, primarily due to the sheer number of variables you need to keep track of during the forward pass.

In contrast, there exists the simpler, systematic approach of drawing a computational graph, enumerating the forward update equations, and then following the gradients backward over the graph. Coupled with a function that encapsulates all the derivative operations we need, this leads to a much cleaner final implementation and reduces the overall risk of error.

I’ll explain the technique by illustrating how you’d apply it to compute the forward and reverse update rules for the Gated Recurrent Unit (GRU) cell. The intermediate steps, however, transfer easily to any network architecture. This post borrows from many sources which I’ve linked to in the references section.

Forward Pass

A typical GRU cell is expressed by the following equations:

$$\begin{aligned}
\textbf{r}_t &= \sigma( \textbf{W}_{rh} \textbf{h}_{t-1} + \textbf{W}_{rx} \textbf{x}_t ) \\
\textbf{z}_t &= \sigma( \textbf{W}_{zh} \textbf{h}_{t-1} + \textbf{W}_{zx} \textbf{x}_t ) \\
\widetilde{\textbf{h}}_t &= \tanh(\textbf{W}_h(\textbf{r}_t * \textbf{h}_{t-1}) + \textbf{W}_x \textbf{x}_t) \\
\textbf{h}_t &= (1 - \textbf{z}_t) * \textbf{h}_{t-1} + \textbf{z}_t * \widetilde{\textbf{h}}_t
\end{aligned}$$

Though it’s possible to differentiate these equations analytically for the backward pass, this quickly turns hairy. An easier way out is to break the equations down to their constituent operations, and draw the corresponding computational graph.

The sequential operations corresponding to the GRU forward pass equations are:

$$\begin{aligned}
z_1 &= W_{zh} h_{t-1} \\
z_2 &= W_{zx} x_t \\
z_3 &= z_1 + z_2 \\
z_4 &= \sigma(z_3) \longrightarrow z_t \\
z_5 &= W_{rh} h_{t-1} \\
z_6 &= W_{rx} x_t \\
z_7 &= z_5 + z_6 \\
z_8 &= \sigma(z_7) \longrightarrow r_t \\
z_9 &= r_t * h_{t-1} = z_8 * h_{t-1} \\
z_{10} &= W_h z_9 \\
z_{11} &= W_x x_t \\
z_{12} &= z_{10} + z_{11} \\
z_{13} &= \tanh(z_{12}) \longrightarrow \widetilde{h}_{t} \\
z_{14} &= 1 - z_t = 1 - z_4 \\
z_{15} &= z_{14} * h_{t-1} \\
z_{16} &= z_t * \widetilde{h}_{t} = z_4 * z_{13} \\
z_{17} &= z_{15} + z_{16} \longrightarrow h_t
\end{aligned}$$

As in the GRU equations, juxtaposition denotes a matrix product and $$*$$ the component-wise (Hadamard) product. An arrow marks the step at which a named quantity from the GRU equations becomes available.
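To make the decomposition concrete, here is a minimal NumPy sketch of the seventeen operations, checked against the compact GRU equations. The dictionary keys, the sizes, and the column-vector layout are assumptions made for illustration, not part of the original derivation.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_forward(x_t, h_prev, W):
    """One GRU step written as the elementary operations z_1 .. z_17."""
    z = {}
    z[1] = W["zh"] @ h_prev
    z[2] = W["zx"] @ x_t
    z[3] = z[1] + z[2]
    z[4] = sigmoid(z[3])        # update gate z_t
    z[5] = W["rh"] @ h_prev
    z[6] = W["rx"] @ x_t
    z[7] = z[5] + z[6]
    z[8] = sigmoid(z[7])        # reset gate r_t
    z[9] = z[8] * h_prev        # Hadamard product
    z[10] = W["h"] @ z[9]
    z[11] = W["x"] @ x_t
    z[12] = z[10] + z[11]
    z[13] = np.tanh(z[12])      # candidate state
    z[14] = 1.0 - z[4]
    z[15] = z[14] * h_prev
    z[16] = z[4] * z[13]
    z[17] = z[15] + z[16]       # new hidden state h_t
    return z

# Illustrative sizes and random weights (assumptions, not from the post).
rng = np.random.default_rng(0)
n_h, n_x = 4, 3
W = {
    "zh": rng.standard_normal((n_h, n_h)), "zx": rng.standard_normal((n_h, n_x)),
    "rh": rng.standard_normal((n_h, n_h)), "rx": rng.standard_normal((n_h, n_x)),
    "h": rng.standard_normal((n_h, n_h)), "x": rng.standard_normal((n_h, n_x)),
}
x_t = rng.standard_normal((n_x, 1))
h_prev = rng.standard_normal((n_h, 1))

z = gru_forward(x_t, h_prev, W)

# The same step computed directly from the four GRU equations.
r_t = sigmoid(W["rh"] @ h_prev + W["rx"] @ x_t)
z_t = sigmoid(W["zh"] @ h_prev + W["zx"] @ x_t)
h_tilde = np.tanh(W["h"] @ (r_t * h_prev) + W["x"] @ x_t)
h_t = (1.0 - z_t) * h_prev + z_t * h_tilde

same_result = np.allclose(z[17], h_t)
```

Keeping every intermediate value in `z` matters later: the backward pass reuses these values rather than recomputing them.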

Using this representation, we’re now well positioned to easily visualize the sequence of steps that occur when a GRU cell receives an input signal. This in turn greatly eases the burden on our working memory when deriving the backward pass, since we can visualize the gradient flow.

Backward Pass

The following rules govern how gradients flow at intermediate junctions¹, and we’ll use them to derive the backward pass:

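| Junction | Forward | Backward |
|---|---|---|
| Split | $$ b = a;\; c = a $$ | $$ \delta_a = \delta_b + \delta_c $$ |
| Addition | $$ c = a + b $$ | $$ \delta_a = \delta_c;\; \delta_b = \delta_c $$ |
| Function | $$ b = f(a) $$ | $$ \delta_a = \delta_b * f'(a) $$ |
| Matrix Multiply | $$ b = W a $$ | $$ \delta_a = W^T \delta_b;\; \delta_W = \delta_b a^T $$ |
| Hadamard Product | $$ c = a * b $$ | $$ \delta_a = \delta_c * b;\; \delta_b = \delta_c * a $$ |

Here $$*$$ is the component-wise product, vectors are columns, and each $$\delta$$ has the same shape as the variable it belongs to.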
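One way to convince yourself of these rules is to compare each against a finite-difference estimate. The sketch below does this with NumPy, assuming column vectors and gradients with the same shape as their variables (so the matrix-multiply rule reads `W.T @ delta_b`). The helper `numerical_grad` and the trick of using a random "upstream" gradient `g` with loss `sum(g * output)` are illustrative choices, not something from the original post.

```python
import numpy as np

def numerical_grad(loss, a, eps=1e-6):
    """Central-difference estimate of d loss / d a.

    `loss` is a zero-argument function that reads `a`; `a` is perturbed
    in place and restored afterwards.
    """
    grad = np.zeros_like(a)
    for idx in np.ndindex(a.shape):
        old = a[idx]
        a[idx] = old + eps
        plus = loss()
        a[idx] = old - eps
        minus = loss()
        a[idx] = old
        grad[idx] = (plus - minus) / (2.0 * eps)
    return grad

rng = np.random.default_rng(0)
a = rng.standard_normal((3, 1))
b = rng.standard_normal((3, 1))
W = rng.standard_normal((4, 3))
g = rng.standard_normal((3, 1))    # upstream gradient for 3-vector outputs
g2 = rng.standard_normal((3, 1))   # second upstream gradient for the split
g4 = rng.standard_normal((4, 1))   # upstream gradient for the 4-vector W @ a

# Split: b = a, c = a  ->  delta_a = delta_b + delta_c
split_ok = np.allclose(
    g + g2, numerical_grad(lambda: np.sum(g * a) + np.sum(g2 * a), a), atol=1e-6)

# Addition: c = a + b  ->  delta_a = delta_c, delta_b = delta_c
add_ok = (
    np.allclose(g, numerical_grad(lambda: np.sum(g * (a + b)), a), atol=1e-6)
    and np.allclose(g, numerical_grad(lambda: np.sum(g * (a + b)), b), atol=1e-6))

# Function: b = tanh(a)  ->  delta_a = delta_b * (1 - tanh(a)^2)
func_ok = np.allclose(
    g * (1.0 - np.tanh(a) ** 2),
    numerical_grad(lambda: np.sum(g * np.tanh(a)), a), atol=1e-6)

# Matrix multiply: b = W a  ->  delta_a = W^T delta_b, delta_W = delta_b a^T
matmul_ok = (
    np.allclose(W.T @ g4, numerical_grad(lambda: np.sum(g4 * (W @ a)), a), atol=1e-6)
    and np.allclose(g4 @ a.T, numerical_grad(lambda: np.sum(g4 * (W @ a)), W), atol=1e-6))

# Hadamard product: c = a * b  ->  delta_a = delta_c * b, delta_b = delta_c * a
hadamard_ok = (
    np.allclose(g * b, numerical_grad(lambda: np.sum(g * (a * b)), a), atol=1e-6)
    and np.allclose(g * a, numerical_grad(lambda: np.sum(g * (a * b)), b), atol=1e-6))
```

Each `*_ok` flag should come out true; a mismatch would point to a transposition or sign slip in the corresponding rule.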

We can trace the gradient flow backward by starting from the top right (i.e. where the GRU cell receives the gradient value from the next time step) and traversing the graph backward to compute the gradient at each intermediate step until we reach the start of the graph. When we reach the start, we just recurse and use this gradient value to compute the backward pass for the preceding time step.
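As a partial illustration of this traversal, here is how the last four operations unwind. The shorthand is introduced here for convenience: $$\delta_i$$ is the gradient at $$z_i$$, $$\delta_{17}$$ is the gradient arriving at $$h_t$$, and $$*$$ is the component-wise product as in the junction rules.

$$\begin{aligned}
z_{17} = z_{15} + z_{16} &: \quad \delta_{15} = \delta_{17}, \quad \delta_{16} = \delta_{17} \\
z_{16} = z_4 * z_{13} &: \quad \delta_{4} = \delta_{16} * z_{13}, \quad \delta_{13} = \delta_{16} * z_4 \\
z_{15} = z_{14} * h_{t-1} &: \quad \delta_{14} = \delta_{15} * h_{t-1}, \quad \delta_{h_{t-1}} = \delta_{15} * z_{14} \\
z_{14} = 1 - z_4 &: \quad \delta_{4} \leftarrow \delta_{4} - \delta_{14}
\end{aligned}$$

Because $$z_4$$ feeds both $$z_{14}$$ and $$z_{16}$$, the split rule sums its two contributions, giving $$\delta_4 = \delta_{17} * (z_{13} - h_{t-1})$$. This agrees with differentiating $$h_t = (1 - z_t) * h_{t-1} + z_t * \widetilde{h}_t$$ with respect to $$z_t$$ directly. Note that $$\delta_{h_{t-1}}$$ is only partial at this point: $$h_{t-1}$$ is also an input to $$z_9$$, $$z_5$$ and $$z_1$$, and those contributions get added as the traversal continues.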

Function Call

To make things even simpler, we can encapsulate the derivative computation inside a function, which accepts operands and returns the derivative computed according to the intermediate junction rules.

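def sigdx(s):
    # derivative of the sigmoid, written in terms of its output s = sigmoid(a)
    return s * (1 - s)

def tanhdx(t):
    # derivative of tanh, written in terms of its output t = tanh(a)
    return 1 - (t**2)

def deriv(dz, x, y, op):
    # dz is the gradient flowing into z = x (op) y.
    # Each gradient is stored as the transpose of the variable it belongs to,
    # which is why the operands are transposed (.T) below.
    # For 'tanh' and 'sigmoid', x is the activation's output and y is unused.
    if op == 'none':
        return dz
    if op == '*':  # component-wise multiply (i.e. Hadamard product)
        return dz * y.T, dz * x.T
    if op == '@':  # matrix multiply z = x @ y
        return y @ dz, dz @ x
    if op == '+':
        return dz, dz
    if op == '-':
        return dz, -dz
    if op == 'tanh':
        return dz * tanhdx(x).T
    if op == 'sigmoid':
        return dz * sigdx(x).T
    raise ValueError('unknown op: ' + str(op))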

For example, the backprop update rules for the equation $$ z_{17} = z_{15} + z_{16} $$ can be computed using

ret = deriv(dz17, z15, z16, "+")
dz15 += ret[0]
dz16 += ret[1]

We now do this for each of the update equations in reverse, until we reach the first equation, at which point we will have computed the backward pass for the entire graph. I’ve linked to the procedure for a single GRU cell, which you can easily extend to the multi-cell case through a loop.
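For readers who want to see the whole reverse sweep in one place, the following is a minimal, self-contained NumPy sketch for a single cell, verified with a finite-difference check. It is not the linked implementation: the function names, dictionary keys and sizes are assumptions, and, unlike `deriv`, it keeps every gradient in the same shape as its variable so no transposes of gradients are needed.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_forward(x, h, W):
    """Forward pass as elementary operations; returns every intermediate z_i."""
    z = {}
    z[1], z[2] = W["zh"] @ h, W["zx"] @ x
    z[3] = z[1] + z[2]
    z[4] = sigmoid(z[3])                 # update gate
    z[5], z[6] = W["rh"] @ h, W["rx"] @ x
    z[7] = z[5] + z[6]
    z[8] = sigmoid(z[7])                 # reset gate
    z[9] = z[8] * h
    z[10], z[11] = W["h"] @ z[9], W["x"] @ x
    z[12] = z[10] + z[11]
    z[13] = np.tanh(z[12])               # candidate state
    z[14] = 1.0 - z[4]
    z[15] = z[14] * h
    z[16] = z[4] * z[13]
    z[17] = z[15] + z[16]                # new hidden state
    return z

def gru_backward(dh_t, x, h, W, z):
    """Visit the operations in reverse, applying one junction rule per line."""
    dW = {}
    d15, d16 = dh_t, dh_t                            # z17 = z15 + z16
    d4, d13 = d16 * z[13], d16 * z[4]                # z16 = z4 * z13
    d14, dh = d15 * h, d15 * z[14]                   # z15 = z14 * h
    d4 = d4 - d14                                    # z14 = 1 - z4 (z4 is split)
    d12 = d13 * (1.0 - z[13] ** 2)                   # z13 = tanh(z12)
    dW["x"], dx = d12 @ x.T, W["x"].T @ d12          # z11 = W_x x
    dW["h"], d9 = d12 @ z[9].T, W["h"].T @ d12       # z10 = W_h z9
    d8, dh = d9 * h, dh + d9 * z[8]                  # z9 = z8 * h
    d7 = d8 * z[8] * (1.0 - z[8])                    # z8 = sigmoid(z7)
    dW["rx"], dx = d7 @ x.T, dx + W["rx"].T @ d7     # z6 = W_rx x
    dW["rh"], dh = d7 @ h.T, dh + W["rh"].T @ d7     # z5 = W_rh h
    d3 = d4 * z[4] * (1.0 - z[4])                    # z4 = sigmoid(z3)
    dW["zx"], dx = d3 @ x.T, dx + W["zx"].T @ d3     # z2 = W_zx x
    dW["zh"], dh = d3 @ h.T, dh + W["zh"].T @ d3     # z1 = W_zh h
    return dx, dh, dW

def numerical_grad(loss, a, eps=1e-6):
    """Central-difference estimate of d loss / d a (perturbs `a` in place)."""
    grad = np.zeros_like(a)
    for idx in np.ndindex(a.shape):
        old = a[idx]
        a[idx] = old + eps
        plus = loss()
        a[idx] = old - eps
        minus = loss()
        a[idx] = old
        grad[idx] = (plus - minus) / (2.0 * eps)
    return grad

rng = np.random.default_rng(0)
n_h, n_x = 4, 3
W = {
    "zh": rng.standard_normal((n_h, n_h)), "zx": rng.standard_normal((n_h, n_x)),
    "rh": rng.standard_normal((n_h, n_h)), "rx": rng.standard_normal((n_h, n_x)),
    "h": rng.standard_normal((n_h, n_h)), "x": rng.standard_normal((n_h, n_x)),
}
x = rng.standard_normal((n_x, 1))
h = rng.standard_normal((n_h, 1))
g = rng.standard_normal((n_h, 1))   # stand-in for the gradient arriving from step t+1

def loss():
    # Scalar whose gradient with respect to h_t is exactly g.
    return float(np.sum(g * gru_forward(x, h, W)[17]))

dx, dh, dW = gru_backward(g, x, h, W, gru_forward(x, h, W))

grads_match = {name: np.allclose(dW[name], numerical_grad(loss, W[name]), atol=1e-6)
               for name in W}
grads_match["x_t"] = np.allclose(dx, numerical_grad(loss, x), atol=1e-6)
grads_match["h_prev"] = np.allclose(dh, numerical_grad(loss, h), atol=1e-6)
```

In the multi-cell case, the returned `dh` would be added to whatever gradient the output at the previous time step contributes, and the per-step `dW` entries would be summed across time steps, since the weights are shared.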

And that’s it! Sticking to this mechanical procedure is a foolproof² way to compute gradients, without the mental overhead of keeping track of variable names and chaining derivatives by hand. Of course, this is approximately how backpropagation is implemented using autograd packages anyway³, but tracing out these steps is useful for insight.

Questions, feedback, corrections? Reach out!



  1. These are derived from straightforward applications of calculus identities and the chain rule, visualized in graphical form. ↩︎

  2. More so than computing update rules analytically anyway. I don’t think there’s a purely foolproof way to do anything when left to human lapses :) ↩︎

  3. Check out this paper for example, which provides a neat survey of autodiff techniques and their connection to machine learning. ↩︎

Written June 7, 2020. Send feedback to @bhaprayan.
