In this post, we shed some light on the adjoint state method as used in the famous “Neural ODE” paper [1].
In Section 1, we start by introducing the adjoint state method in its raw form (ODE, loss minimization, adjoint equations), in continuous time (denoted by [C]). If this is already clear to you, then… no need to read the rest of the post! You’re already an adjoint master. Otherwise, please stick with me as we will derive an analogous procedure in discrete time (denoted by [D]) in Section 2, where things are more intuitive.
1. Continuous time: Adjoint method
1.1. ODE formulation
Consider a system that evolves in continuous time and is in state $x(t) \in \mathbb{R}^n$ at time $t \in [0, T]$. The state evolves according to an ordinary differential equation (ODE):

$$\dot{x}(t) = f\big(x(t), \theta\big), \qquad x(0) = x_0. \tag{1}$$

Note that $f$ depends on the parameters $\theta$, which are typically the weights and biases of a neural network.
A notation remark: $\theta$ denotes the decision variable, while $\bar{\theta}$ will denote its current value. Similarly, $x$ is the generic state variable at time $t$, while $x(t; \bar{\theta})$ shall denote the solution of the ODE (1) at time $t$, given an initial condition $x(0) = x_0$.
1.2. Loss minimization
We want to minimize a penalty $\ell$ by tuning the parameters $\theta$ appropriately:

$$\min_{\theta}\; L(\theta) := \ell\big(x(T; \theta)\big). \tag{2}$$

Observe that the loss depends on $\theta$ through the final state $x(T; \theta)$, which in turn depends on $\theta$ in a complex manner, through the ODE (1).
1.3. Adjoint equations
We now want to compute the gradient of $L$ with respect to the parameters $\theta$, given the current value $\bar{\theta}$ of the parameters and the corresponding solution $x(t; \bar{\theta})$ of the ODE (1). This will allow us to use gradient descent to minimize (2). To this aim, we first define the adjoint vector $a(t) \in \mathbb{R}^n$ as

$$a(t)^\top := \frac{\partial L}{\partial x(t)}, \tag{3}$$

which satisfies the following ODE:

$$\dot{a}(t)^\top = -\,a(t)^\top \frac{\partial f}{\partial x}\big(x(t; \bar{\theta}), \bar{\theta}\big), \qquad a(T)^\top = \frac{\partial \ell}{\partial x}\big(x(T; \bar{\theta})\big), \tag{4}$$
where $\frac{\partial f}{\partial x}\big(x(t; \bar{\theta}), \bar{\theta}\big)$ is to be read as the “Jacobian of $f$ with respect to the state $x$, evaluated at $x(t; \bar{\theta})$”. Then, the gradient of $L$ with respect to the parameters $\theta$, $\nabla_\theta L(\bar{\theta})$, is given by:

$$\nabla_\theta L(\bar{\theta})^\top = \int_0^T a(t)^\top \frac{\partial f}{\partial \theta}\big(x(t; \bar{\theta}), \bar{\theta}\big)\, dt. \tag{5}$$
Fix the current value $\bar{\theta}$ of the parameters. To compute $\nabla_\theta L(\bar{\theta})$ via the adjoint state method, we need to
- 1. Solve the ODE (1) forward in time, e.g., with Euler’s method, starting from $x(0) = x_0$, to obtain $x(t; \bar{\theta})$ for all $t \in [0, T]$.
- 2. Solve the adjoint equation (4) backward in time, starting from $a(T)^\top = \frac{\partial \ell}{\partial x}\big(x(T; \bar{\theta})\big)$, to obtain $a(t)$ for all $t \in [0, T]$.
- 3. Compute the integral (5) to obtain the gradient.

where steps 2 and 3 can be carried out jointly, along a single backward pass, as in the sketch below.
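To make the three steps concrete, here is a minimal NumPy sketch under some illustrative assumptions (not from the paper): the dynamics are $f(x, \theta) = \tanh(A x)$ with $\theta = A$ an $n \times n$ matrix, the terminal loss is $\ell(x(T)) = \tfrac{1}{2}\|x(T)\|^2$, and the function names (`adjoint_gradient`, `df_dx`, `df_dtheta`) are hypothetical.

```python
import numpy as np

# Toy dynamics f(x, theta) = tanh(A x), with theta = A (an n x n matrix).
def f(x, A):
    return np.tanh(A @ x)

def df_dx(x, A):
    # Jacobian of f w.r.t. the state x: diag(1 - tanh(Ax)^2) @ A
    return np.diag(1.0 - np.tanh(A @ x) ** 2) @ A

def df_dtheta(x, A):
    # Jacobian of f w.r.t. theta = A, flattened row-major:
    # d f_i / d A_jk = (1 - tanh(Ax)_i^2) * delta_ij * x_k
    n = len(x)
    s = 1.0 - np.tanh(A @ x) ** 2
    J = np.zeros((n, n, n))
    J[np.arange(n), np.arange(n), :] = s[:, None] * x[None, :]
    return J.reshape(n, n * n)

def adjoint_gradient(x0, A, T=1.0, K=100):
    h = T / K
    # Step 1: integrate the ODE (1) forward with Euler, storing the trajectory.
    xs = [x0]
    for _ in range(K):
        xs.append(xs[-1] + h * f(xs[-1], A))
    # Terminal condition of (4): a(T) = d ell / d x evaluated at x(T),
    # which is simply x(T) for ell(x) = 0.5 ||x||^2.
    a = xs[-1].copy()
    grad = np.zeros(A.size)
    # Steps 2 and 3 jointly: integrate the adjoint ODE (4) backward with Euler
    # while accumulating the integral (5) with a rectangle rule.
    for k in reversed(range(K)):
        grad += h * a @ df_dtheta(xs[k], A)   # contribution to the integral (5)
        a = a + h * a @ df_dx(xs[k], A)       # one backward Euler step of (4)
    return grad.reshape(A.shape)
```

The output can be sanity-checked against finite differences, e.g. by perturbing one entry of `A`, re-running the forward simulation, and comparing the change in the loss.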
If little of this is clear to you or you lack the main intuitions, no worries! This post is all about it: we will derive the same procedure in discrete time, where things are more intuitive.
2. Discrete time: Backpropagation
2.1. Discrete-time ODE formulation
To better understand the adjoint method, we construct the discrete-time counterpart of the ODE (1). We will heavily overload notation, so please bear with me. We start by defining a sequence of parametric functions $f_0, f_1, \dots, f_{K-1}$ such that:

$$x_{k+1} = x_k + f_k(x_k, \theta), \qquad k = 0, 1, \dots, K-1, \tag{6}$$

where $x_k$ is the state at step $k$. Note that this is nothing more than a discretized version of the ODE (1) (the time step is embedded in $f_k$). In the ML community, equation (6) defines a residual network (ResNet) architecture. The loss $L(\theta) = \ell(x_K)$ depends on the final state $x_K$. Here is a visualization of the discrete-time process with 3 layers ($K = 3$):

[Figure: the discrete-time process with $K = 3$ layers, from the initial state $x_0$ to the loss $L$.]
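To fix ideas before moving on, here is a minimal sketch of the recursion (6), with a hypothetical residual block $f_k(x, \theta) = \tanh(W x)$ (the time step being folded into $W$); the names `residual_block` and `forward` are mine, not standard:

```python
import numpy as np

def residual_block(x, W):
    # One layer f_k(x, theta) = tanh(W x); the time step is absorbed into W.
    return np.tanh(W @ x)

def forward(x0, W, K):
    # Recursion (6): x_{k+1} = x_k + f_k(x_k, theta), with theta = W shared by all layers.
    xs = [x0]
    for _ in range(K):
        xs.append(xs[-1] + residual_block(xs[-1], W))
    return xs
```

Calling `forward(x0, W, K=3)` produces the states $x_0, x_1, x_2, x_3$ of the three-layer example above, with the same parameter $W$ appearing in every layer — which is exactly why $\theta$ will contribute to the loss along $K$ different paths below.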
2.2. Loss minimization
As before, our goal is to minimize the loss $L$ with respect to the parameters $\theta$ via gradient descent. To this aim, we need to compute $\nabla_\theta L(\bar{\theta})$.
2.3. Adjoint equations
We directly compute the gradient of $L$ with respect to the parameters $\theta$. We notice that $\theta$ appears $K$ times in the state evolution (6), hence any change in $\theta$ will propagate towards the loss $L$ along $K$ different paths.

In each path, we can apply the chain rule, as a cascade of two effects: i) the state $x_{k+1}$ is affected by the parameter $\theta$ via the function $f_k$ and varies as $\frac{\partial f_k}{\partial \theta}(x_k, \bar{\theta})$, and ii) the loss $L$ varies as $\frac{\partial L}{\partial x_{k+1}}$, which is our adjoint variable $a_{k+1}^\top$.
We can then sum up the contributions of each one of the $K$ paths to obtain

$$\nabla_\theta L(\bar{\theta})^\top = \sum_{k=0}^{K-1} a_{k+1}^\top \frac{\partial f_k}{\partial \theta}(x_k, \bar{\theta}), \tag{7}$$

where all derivatives are computed at the current value $\bar{\theta}$ of the parameters and at the corresponding solution $x_0, x_1, \dots, x_K$ of (6) for the initial condition $x_0$.
We can recognize that (7) is the discrete-time counterpart of (5), which we report here for convenience:

$$\nabla_\theta L(\bar{\theta})^\top = \int_0^T a(t)^\top \frac{\partial f}{\partial \theta}\big(x(t; \bar{\theta}), \bar{\theta}\big)\, dt. \tag{5}$$
Next, we study how the adjoint variables are connected with each other. Recall that $a_k^\top = \frac{\partial L}{\partial x_k}$ measures how the loss changes as a function of the state $x_k$. By the chain rule, a small change in the state $x_k$ impacts the subsequent state $x_{k+1}$ via the function $f_k$, which in turn affects the loss:

$$a_k^\top = \frac{\partial L}{\partial x_{k+1}}\, \frac{\partial x_{k+1}}{\partial x_k} = a_{k+1}^\top\, \frac{\partial x_{k+1}}{\partial x_k}. \tag{8}$$
To obtain the analogue of (4), we need to express $\frac{\partial x_{k+1}}{\partial x_k}$ in terms of the function $f_k$, via the recursion (6):

$$\frac{\partial x_{k+1}}{\partial x_k} = I + \frac{\partial f_k}{\partial x}(x_k, \bar{\theta}), \tag{9}$$

where $I$ is the identity matrix. Plugging (9) into (8), we obtain, for $k = K-1, \dots, 0$:

$$a_k^\top = a_{k+1}^\top \Big( I + \frac{\partial f_k}{\partial x}(x_k, \bar{\theta}) \Big), \tag{10}$$

which is the discrete-time counterpart of the adjoint ODE (4), reported here for convenience:

$$\dot{a}(t)^\top = -\,a(t)^\top \frac{\partial f}{\partial x}\big(x(t; \bar{\theta}), \bar{\theta}\big). \tag{4}$$
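To see the correspondence explicitly, suppose (as an illustrative choice, consistent with the time step being embedded in $f_k$) that each layer is an explicit Euler step of size $h$, i.e. $f_k(x, \theta) = h\, f(x, \theta)$. Then (10) rearranges into a finite-difference version of (4):

$$a_k^\top = a_{k+1}^\top \Big( I + h\, \frac{\partial f}{\partial x}(x_k, \bar{\theta}) \Big) \quad\Longleftrightarrow\quad \frac{a_{k+1}^\top - a_k^\top}{h} = -\,a_{k+1}^\top\, \frac{\partial f}{\partial x}(x_k, \bar{\theta}),$$

which recovers the continuous adjoint dynamics as $h \to 0$.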
2.4. Backpropagation
We now compute the gradient $\nabla_\theta L(\bar{\theta})$ via backpropagation.
First, we need to solve the recursion (6) forward in time, starting from the initial condition $x_0$ and for the current value $\bar{\theta}$ of the parameters, to obtain the state sequence $x_0, x_1, \dots, x_K$.
Then, in a single backward pass, we solve recursively the adjoint dynamics [D] (10) to obtain $a_{K-1}, \dots, a_0$ from the initial seed $a_K^\top = \frac{\partial \ell}{\partial x}(x_K)$, while accumulating the loss gradient [D] (7). More precisely:
Backpropagation procedure:
- 0. Set the parameters to their current value $\bar{\theta}$ and fix the initial condition $x_0$.
- 1. Solve the recursion (6) forward in time to obtain $x_1, \dots, x_K$.
- 2. Compute the final adjoint variable $a_K^\top = \frac{\partial \ell}{\partial x}(x_K)$.
- 3. Initialize the loss gradient: $\nabla_\theta L^\top \leftarrow 0$.
- 4. For $k = K-1, K-2, \dots, 0$:
  - 4.1. Compute the adjoint $a_k$ by one step of the adjoint dynamics (10): $a_k^\top = a_{k+1}^\top \big( I + \frac{\partial f_k}{\partial x}(x_k, \bar{\theta}) \big)$.
  - 4.2. Update the loss gradient: $\nabla_\theta L^\top \leftarrow \nabla_\theta L^\top + a_{k+1}^\top \frac{\partial f_k}{\partial \theta}(x_k, \bar{\theta})$.
- 5. Return the loss gradient $\nabla_\theta L(\bar{\theta})^\top$.
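Here is a minimal NumPy sketch of the procedure above, under some illustrative assumptions (every layer is $f_k(x, \theta) = \tanh(W x)$ with a single shared parameter matrix $\theta = W$, and $\ell(x_K) = \tfrac{1}{2}\|x_K\|^2$); the function name `backprop_gradient` is hypothetical:

```python
import numpy as np

def backprop_gradient(x0, W, K):
    # Steps 0-1: forward recursion (6), storing x_0, ..., x_K.
    xs = [x0]
    for _ in range(K):
        xs.append(xs[-1] + np.tanh(W @ xs[-1]))
    # Step 2: terminal adjoint a_K = d ell / d x_K, i.e. x_K for ell(x) = 0.5 ||x||^2.
    a = xs[-1].copy()
    # Step 3: initialize the loss gradient.
    grad = np.zeros_like(W)
    # Step 4: backward pass over k = K-1, ..., 0.
    for k in reversed(range(K)):
        s = 1.0 - np.tanh(W @ xs[k]) ** 2              # tanh'(W x_k)
        # 4.2 (done first here, since it uses a_{k+1}): dL/dW += a_{k+1}^T df_k/dW.
        grad += np.outer(a * s, xs[k])
        # 4.1: a_k^T = a_{k+1}^T (I + df_k/dx), with df_k/dx = diag(s) W.
        a = a + (a * s) @ W
    # Step 5: return the accumulated loss gradient dL/dW.
    return grad
```

As a quick check, `backprop_gradient(np.ones(3), 0.1 * np.eye(3), K=3)` can be compared against a finite-difference approximation of the gradient; the two should agree closely.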
Once again, we can recognize that the backpropagation method is the discrete-time counterpart of the adjoint method in Section 1.
You’re now an adjoint master 😉
References
[1] Chen, Ricky TQ, Yulia Rubanova, Jesse Bettencourt, and David K. Duvenaud. “Neural ordinary differential equations.” Advances in neural information processing systems 31 (2018).