rakshit

Translating RL Math into JAX: What I Learned Working Through 10 Problems

I've been trying to get more serious about reinforcement learning, not just the conceptual side but actually being able to implement things from papers. One thing I kept running into is the gap between reading an equation in a paper and knowing what to do with it in code. You see a summation with some Greek letters and it looks fine, but then you sit down to implement it and it's not immediately obvious how the shapes work out, or which axis you're summing over, or why your broadcast is failing.

So I put together a set of exercises where I take standard RL math and translate it directly into JAX. No for loops, just vectorized operations. This post walks through what I did and what tripped me up, and it ends with a cheat sheet I'll probably keep referencing.

why jax

JAX is basically NumPy with automatic differentiation and JIT compilation added. A lot of serious RL research code is written in it (DeepMind's, plus many academic repos). The functional style also forces you to think clearly about shapes and data flow, which turns out to be useful when you're translating math.

The one thing that takes getting used to is that JAX arrays are immutable. You can't do Q[i, j] = value. You have to use .at[].set() instead. It feels weird at first but makes sense once you understand why JAX needs to trace computation graphs.
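Here's a minimal illustration of the difference, using a small zero-filled table purely for demonstration:

```python
import jax.numpy as jnp

Q = jnp.zeros((3, 2))  # a small (S, A) table of Q-values

# NumPy-style item assignment is rejected because JAX arrays are immutable.
try:
    Q[1, 0] = 5.0
    assignment_failed = False
except TypeError:
    assignment_failed = True

# .at[...].set(...) returns a *new* array; the original Q is left unchanged.
Q_new = Q.at[1, 0].set(5.0)
```

After this runs, Q[1, 0] is still 0.0 and Q_new[1, 0] is 5.0. Under jax.jit the compiler can often turn this into an in-place update, so the copy is usually cheaper than it looks.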

the problems

problem 1: state-value from action-value

$$V(s) = \sum_{a \in A} \pi(a|s) Q(s,a)$$

This is the basic identity that connects V and Q. You have a policy pi of shape (S, A) and Q-values of shape (S, A), and you want to get V of shape (S,).

The translation is a weighted sum along the action axis. einsum handles it cleanly:

V = jnp.einsum('sa,sa->s', pi, Q)

The sa,sa->s notation says: multiply element-wise across both s and a, then sum over a (since a is missing from the output). The s dimension survives.
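A quick sanity check I found useful: run the einsum on a toy example and compare it with the explicit multiply-then-sum it replaces. The sizes and the uniform policy here are arbitrary choices for illustration:

```python
import jax.numpy as jnp

S, A = 4, 3
Q = jnp.arange(S * A, dtype=jnp.float32).reshape(S, A)  # toy Q-values, shape (S, A)
pi = jnp.full((S, A), 1.0 / A)                          # uniform policy, rows sum to 1

V = jnp.einsum('sa,sa->s', pi, Q)                       # shape (S,)

# Same computation without einsum: element-wise product, then sum over actions.
V_check = jnp.sum(pi * Q, axis=1)
```

With a uniform policy, each V(s) is just the mean of row s of Q, so V comes out as roughly [1, 4, 7, 10].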

problem 2: the advantage function

$$A(s,a) = Q(s,a) - V(s)$$

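Straightforward math, but the shapes are the trap. Q is (S, A) and V is (S,). JAX aligns shapes from the right, so subtracting them directly tries to match V's S against Q's last axis, A. If S and A differ, you get a shape error. If they happen to be equal, it runs and silently subtracts V along the action axis instead, which is worse.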

The fix is to add a dummy dimension to V:

A = Q - V[:, None]

V[:, None] turns (S,) into (S, 1), and JAX broadcasts the 1 across the A dimension. This pattern comes up a lot.
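To see why the reshape matters, here's a toy sketch with S == A on purpose, so the unreshaped version runs without complaint and gives the wrong answer. The numbers are made up for illustration:

```python
import jax.numpy as jnp

# S == A == 3, chosen deliberately so the buggy version does NOT crash.
Q = jnp.array([[1., 2., 3.],
               [4., 5., 6.],
               [7., 8., 9.]])
V = jnp.array([2., 5., 8.])    # V(s) per state (here, the row means of Q)

adv_right = Q - V[:, None]     # V as (S, 1): subtracts V(s) from every entry in row s
adv_wrong = Q - V              # V as (1, A): subtracts V[j] from column j instead

# When S != A, the unreshaped subtraction fails loudly instead.
try:
    jnp.zeros((4, 3)) - jnp.zeros(4)
    shape_error = False
except (TypeError, ValueError):
    shape_error = True
```

Every row of adv_right is [-1, 0, 1], as it should be for advantages under these V values; adv_wrong is a different matrix that nothing warns you about.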

problem 3: expected sarsa target (batched)

$$\text{Target}_i = R_i + \gamma \sum_{a'} \pi(a'|S_i') Q(S_i', a')$$

This is like Problem 1 but over a batch of transitions instead of all states. pi_next and Q_next are both (B, A), and you want a (B,) output.

target = R + gamma * jnp.einsum('ba,ba->b', pi_next, Q_next)

Same einsum logic, just with b instead of s as the surviving index.
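A small sketch of the batched target wrapped in a function. The name expected_sarsa_target and the two-transition toy batch are my own, for illustration only, and this version doesn't handle terminal transitions:

```python
import jax.numpy as jnp

def expected_sarsa_target(R, pi_next, Q_next, gamma=0.99):
    """R: (B,), pi_next: (B, A), Q_next: (B, A) -> targets of shape (B,)."""
    return R + gamma * jnp.einsum('ba,ba->b', pi_next, Q_next)

R = jnp.array([1.0, 0.0])
pi_next = jnp.array([[0.5, 0.5],      # next-state policy for transition 0
                     [1.0, 0.0]])     # next-state policy for transition 1
Q_next = jnp.array([[2.0, 4.0],
                    [10.0, -10.0]])

target = expected_sarsa_target(R, pi_next, Q_next, gamma=0.9)
# target ≈ [1 + 0.9 * 3, 0 + 0.9 * 10] = [3.7, 9.0]
```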

problem 4: td error

$$\delta_i = R_i + \gamma V(S_i') - V(S_i)$$

This one is basically just arithmetic on three (B,) arrays:

td_error = R + (gamma * V_next) - V_curr

Worth noting: in my original draft I had a typo (gama instead of gamma). Easy to miss, and Python won't catch it until runtime.
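A toy batch of three transitions, plus a hypothetical td_error_with_typo function showing that the misspelling only surfaces when the line actually executes:

```python
import jax.numpy as jnp

gamma = 0.9
R      = jnp.array([1.0, 0.0, -1.0])   # rewards R_i, shape (B,)
V_curr = jnp.array([2.0, 1.0,  0.0])   # V(S_i),     shape (B,)
V_next = jnp.array([1.0, 2.0,  0.5])   # V(S'_i),    shape (B,)

td_error = R + gamma * V_next - V_curr  # shape (B,), ≈ [-0.1, 0.8, -0.55]

# Defining a function with an undefined name is fine; calling it is not.
def td_error_with_typo(R, V_next, V_curr, gamma):
    return R + gama * V_next - V_curr   # 'gama' is undefined

try:
    td_error_with_typo(R, V_next, V_curr, gamma)
    raised_name_error = False
except NameError:
    raised_name_error = True
```

A linter would flag gama immediately, which is a decent argument for running one on research code too.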

problem 5: boltzmann (softmax) policy

$$\pi(a|s) = \frac{\exp(Q(s,a)/\tau)}{\sum_{a'} \exp(Q(s,a')/\tau)}$$

Temperature-scaled softmax. You can do it manually:

numerator = jnp.exp(Q / tau)
denominator = jnp.sum(jnp.exp(Q / tau), axis=1, keepdims=True)
pi = numerator / denominator

The keepdims=True is important. Without it, denominator collapses from (S, 1) to (S,), and you're back to the same broadcast problem as Problem 2.

You can also just use jax.nn.softmax(Q / tau, axis=-1), which is shorter and numerically more stable: it subtracts each row's max before exponentiating (the same shift used in the log-sum-exp trick), so large values of Q / tau don't overflow to inf.
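A sketch of the stability difference on deliberately large Q-values (chosen to overflow float32), with a hand-written max-shifted version for comparison:

```python
import jax
import jax.numpy as jnp

Q = jnp.array([[1000.0, 999.0, 998.0]])   # one state, three actions, large Q-values
tau = 1.0

# Manual version: exp(1000) overflows to inf, and inf / inf gives nan.
naive = jnp.exp(Q / tau) / jnp.sum(jnp.exp(Q / tau), axis=-1, keepdims=True)

# Shifting by the row max leaves the softmax unchanged but keeps exp() in range.
z = Q / tau
shifted = jnp.exp(z - jnp.max(z, axis=-1, keepdims=True))
stable = shifted / jnp.sum(shifted, axis=-1, keepdims=True)

builtin = jax.nn.softmax(Q / tau, axis=-1)   # finite, and each row sums to 1
```

The shift works because multiplying numerator and denominator by the same exp(-max) cancels out, so the probabilities are identical in exact arithmetic.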

problem 6: reinforce loss

$$J(\theta) = -\frac{1}{B} \sum_{i=1}^B G_i \log \pi_\theta(A_i|S_i)$$

Policy gradient. G holds the return for each of the B sampled steps (the discounted reward from that step onward), and log_probs is the log probability of the action the agent actually took at that step.

loss = -jnp.mean(G * log_probs)

The negative sign is because we want to maximize expected return, but optimizers minimize by default. jnp.mean is just the cleaner way to write the $\frac{1}{B} \sum$ part.
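In practice log_probs comes out of a policy network. As a minimal sketch, assuming the policy is parameterized directly by a (B, A) array of logits (a stand-in for a real network; reinforce_loss is a hypothetical helper name), the loss and its gradient look like this:

```python
import jax
import jax.numpy as jnp

def reinforce_loss(logits, actions, G):
    """logits: (B, A) policy logits; actions: (B,) ints; G: (B,) returns."""
    log_pi = jax.nn.log_softmax(logits, axis=-1)                             # (B, A)
    log_probs = jnp.take_along_axis(log_pi, actions[:, None], axis=1)[:, 0]  # (B,): log pi(A_i|S_i)
    return -jnp.mean(G * log_probs)

logits = jnp.zeros((2, 3))    # uniform policy over 3 actions for both samples
actions = jnp.array([0, 2])
G = jnp.array([1.0, 2.0])

loss = reinforce_loss(logits, actions, G)             # = 1.5 * log(3) for this toy batch
grads = jax.grad(reinforce_loss)(logits, actions, G)  # shape (2, 3), gradient w.r.t. logits
```

The gradient entries for the taken actions are negative, so a gradient descent step raises their logits, and it does so in proportion to G.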

problem 7: policy evaluation with a 3-argument reward

$$V_{new}(s) = \sum_a \pi(a|s) \sum_{s'} P(s'|s,a) \left[R(s,a,s') + \gamma V(s')\right]$$

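This one is where einsum really earns its place. You have P and R of shape (S, A, S) and V of shape (S,). (Strictly speaking, because it averages over actions under π, this is one sweep of iterative policy evaluation; value iteration would take a max over a instead.)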

expected_R = jnp.einsum('ijk,ijk->ij', P, R)
expected_future_V = gamma * jnp.einsum('ijk,k->ij', P, V)
Q = expected_R + expected_future_V
V_new = jnp.einsum('ij,ij->i', pi, Q)

The subtle thing in the second line: V is indexed by k, which is the s' (next state) axis, not the i (current state) axis. Getting that wrong gives you a computation that runs fine but calculates the wrong thing, which is the worst kind of bug.
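One way to convince yourself the indices are right is to run the backup to convergence and compare against the closed-form solution $V = (I - \gamma P_\pi)^{-1} r_\pi$. This is a sketch on a small random MDP; the MDP, the helper name bellman_expectation_backup, and the 300-sweep count are all arbitrary choices for illustration:

```python
import jax
import jax.numpy as jnp

@jax.jit
def bellman_expectation_backup(V, P, R, pi, gamma):
    """One sweep. P, R: (S, A, S'), pi: (S, A), V: (S,) -> (S,)."""
    expected_R = jnp.einsum('ijk,ijk->ij', P, R)               # E[R | s, a], shape (S, A)
    expected_future_V = gamma * jnp.einsum('ijk,k->ij', P, V)  # V aligned with s' (index k)
    Q = expected_R + expected_future_V
    return jnp.einsum('ij,ij->i', pi, Q)

# Hypothetical random MDP, only for testing.
S, A, gamma = 4, 2, 0.9
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
P = jax.random.uniform(k1, (S, A, S))
P = P / P.sum(axis=-1, keepdims=True)                        # each P[s, a, :] sums to 1
R = jax.random.normal(k2, (S, A, S))
pi = jax.nn.softmax(jax.random.normal(k3, (S, A)), axis=-1)  # each pi[s, :] sums to 1

# Repeated sweeps: a loop over iterations, not over states or actions.
V = jax.lax.fori_loop(
    0, 300, lambda _, v: bellman_expectation_backup(v, P, R, pi, gamma), jnp.zeros(S)
)

# Closed-form fixed point: V = (I - gamma * P_pi)^{-1} r_pi
P_pi = jnp.einsum('ij,ijk->ik', pi, P)        # (S, S') transitions under pi
r_pi = jnp.einsum('ij,ijk,ijk->i', pi, P, R)  # (S,) expected one-step reward under pi
V_exact = jnp.linalg.solve(jnp.eye(S) - gamma * P_pi, r_pi)
```

If you swap in the wrong einsum ('ijk,i->ij'), the iterated V no longer matches V_exact, which is a cheap way to catch the bug.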

problem 8: ppo clipped objective

$$L_{CLIP} = \frac{1}{B} \sum_{i=1}^B \min\left(r_i A_i, \text{clip}(r_i, 1-\epsilon, 1+\epsilon) A_i\right)$$

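The core of PPO. Here $r_i = \pi_\theta(A_i|S_i) / \pi_{\theta_{old}}(A_i|S_i)$ is the probability ratio between the new and old policies. The clip doesn't stop the ratio from moving; it removes the incentive to push it beyond $[1-\epsilon, 1+\epsilon]$ in the direction the advantage favors, because past that point the clipped term is constant and contributes no gradient. That keeps each update close to the old policy, which is what keeps training stable.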

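objective = jnp.mean(jnp.minimum(ratio * A, jnp.clip(ratio, 1 - epsilon, 1 + epsilon) * A))
loss = -objective  # L_CLIP is maximized, so negate it for a minimizing optimizer (as in problem 6)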

One thing to be careful about: jnp.min returns the single minimum value of the whole array. jnp.minimum does element-wise comparison between two arrays. You want the latter here.
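A minimal sketch, assuming the ratio is computed from log-probabilities as exp(new_log_probs - old_log_probs), a common choice that the formula itself doesn't dictate. ppo_clip_loss is a hypothetical helper, and the toy numbers put one ratio inside the clip range and one outside:

```python
import jax
import jax.numpy as jnp

def ppo_clip_loss(new_log_probs, old_log_probs, A, epsilon=0.2):
    """All inputs (B,). Returns -L_CLIP so a minimizing optimizer maximizes L_CLIP."""
    ratio = jnp.exp(new_log_probs - old_log_probs)   # r_i = pi_new(a|s) / pi_old(a|s)
    unclipped = ratio * A
    clipped = jnp.clip(ratio, 1 - epsilon, 1 + epsilon) * A
    return -jnp.mean(jnp.minimum(unclipped, clipped))

old_log_probs = jnp.log(jnp.array([0.5, 0.5]))
new_log_probs = jnp.log(jnp.array([0.55, 0.8]))   # ratios 1.1 (inside range) and 1.6 (outside)
A = jnp.array([1.0, 1.0])                          # positive advantages for both samples

# Gradient with respect to the new policy's log-probs.
grads = jax.grad(ppo_clip_loss)(new_log_probs, old_log_probs, A)
```

The second sample gets exactly zero gradient: its ratio is already past $1 + \epsilon$ with a positive advantage, so the clipped term wins the minimum and is constant. The first sample still gets a normal policy-gradient signal.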

problem 9: huber loss

$$L(\delta) = \begin{cases} \frac{1}{2}\delta^2 & \text{if } |\delta| \le 1 \\ |\delta| - \frac{1}{2} & \text{otherwise} \end{cases}$$

Piecewise functions in JAX use jnp.where:

mean_loss = jnp.mean(jnp.where(jnp.abs(delta) <= 1, 0.5 * delta**2, jnp.abs(delta) - 0.5))

Huber loss is used in DQN because it's less sensitive to large errors than MSE: for |δ| > 1 its gradient has magnitude 1 instead of growing with δ. That matters because TD errors can occasionally be large, especially early in training.
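A sketch comparing the gradients of Huber loss and half squared error on a few hand-picked TD errors. The helper names huber and half_squared_error are my own:

```python
import jax
import jax.numpy as jnp

def huber(delta):
    """Element-wise Huber loss with threshold 1, matching the piecewise formula."""
    return jnp.where(jnp.abs(delta) <= 1, 0.5 * delta**2, jnp.abs(delta) - 0.5)

def half_squared_error(delta):
    return 0.5 * delta**2

delta = jnp.array([-5.0, -0.5, 0.0, 0.5, 5.0])

# Per-element gradients, obtained as the gradient of the summed loss.
huber_grad = jax.grad(lambda d: jnp.sum(huber(d)))(delta)
half_sq_grad = jax.grad(lambda d: jnp.sum(half_squared_error(d)))(delta)
```

The Huber gradient is exactly clip(δ, -1, 1), bounded for large errors, while the squared-error gradient is δ itself and grows linearly.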

problem 10: batched q-table update

$$Q_{new}(S_i, A_i) = Q(S_i, A_i) + \alpha \left[R_i + \gamma \max_{a'} Q(S_i', a') - Q(S_i, A_i)\right]$$

This one combines fancy indexing with JAX's immutable update syntax:

current_q = Q[states, actions]
updated_q = current_q + alpha * (rewards + gamma * jnp.max(Q[next_states, :], axis=1) - current_q)
Q = Q.at[states, actions].set(updated_q)

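Q[states, actions], where states and actions are (B,) integer arrays, does fancy indexing: it grabs B specific entries from the (S, A) table. The .at[].set() then writes the updates back to those same positions. One caveat: if the same (s, a) pair shows up twice in a batch, JAX applies the conflicting writes in an implementation-defined order, so only one of the updates survives and you can't rely on which.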
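Here's the update wrapped in a jitted function, as a sketch. q_update and the toy 3x2 table are illustrative, and the example batch is chosen so no (s, a) pair repeats:

```python
import jax
import jax.numpy as jnp

@jax.jit
def q_update(Q, states, actions, rewards, next_states, alpha, gamma):
    """Q: (S, A); states, actions, rewards, next_states: (B,)."""
    current_q = Q[states, actions]                                    # (B,)
    td_target = rewards + gamma * jnp.max(Q[next_states, :], axis=1)  # (B,)
    updated_q = current_q + alpha * (td_target - current_q)
    return Q.at[states, actions].set(updated_q)                       # new (S, A) table

Q = jnp.zeros((3, 2)).at[2, 1].set(10.0)   # toy table: only Q(2, 1) is non-zero
states      = jnp.array([0, 1])
actions     = jnp.array([1, 0])
rewards     = jnp.array([1.0, 0.0])
next_states = jnp.array([2, 0])

Q_new = q_update(Q, states, actions, rewards, next_states, 0.5, 0.9)
# Q_new[0, 1] = 0 + 0.5 * (1 + 0.9 * 10 - 0) = 5.0
# Q_new[1, 0] = 0 + 0.5 * (0 + 0.9 * 0  - 0) = 0.0
```

Because the function returns a new table instead of mutating Q, it composes cleanly with jit, and the old table stays available if you need it.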

cheat sheet

einsum rules

Three things to remember:

  1. Repeated index = multiply along that axis
  2. Missing from output = sum over that axis
  3. Present in output = keep that dimension

A common gotcha with 3D tensors: always double-check which axis a 1D vector is supposed to align with. For a transition tensor P of shape (S, A, S'), multiplying by V(s') means aligning with the last axis (k), not the first.

# Wrong: aligns V with current state s
jnp.einsum('ijk,i->ij', P, V)

# Right: aligns V with next state s'
jnp.einsum('ijk,k->ij', P, V)
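The einsum rules can also be checked against explicit broadcasting, which makes the alignment visible. A small sketch with a random P (values arbitrary) and a V whose entries differ enough that a misalignment shows up:

```python
import jax
import jax.numpy as jnp

S, A = 3, 2
P = jax.random.uniform(jax.random.PRNGKey(0), (S, A, S))
V = jnp.array([1.0, 10.0, 100.0])   # distinct values so a misalignment is obvious

right = jnp.einsum('ijk,k->ij', P, V)
wrong = jnp.einsum('ijk,i->ij', P, V)

# The same contractions written as broadcast-multiply, then sum over k.
right_bc = jnp.sum(P * V[None, None, :], axis=-1)   # V varies along k (next state)
wrong_bc = jnp.sum(P * V[:, None, None], axis=-1)   # V varies along i (current state)
```

Both results have shape (S, A), which is exactly why the wrong version slips through: nothing about the shapes tells you it's misaligned.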

broadcasting

JAX aligns shapes from the right. When you subtract a (S,) vector from a (S, A) matrix, it tries to match A (the last dim of the matrix) with S (the only dim of the vector). That's a shape error when S and A differ, and a silently wrong answer when they happen to be equal. To subtract along the first (state) axis, you need to add a trailing dimension:

# Subtract V(s) from Q(s,a) along state axis
A = Q - V[:, None]   # V becomes (S, 1), broadcasts to (S, A)

For reductions that you plan to divide by, use keepdims=True:

denominator = jnp.sum(jnp.exp(Q / tau), axis=-1, keepdims=True)  # (S, 1), not (S,)

environment shapes vs. batch shapes

Symbol | Meaning | Typical size
--- | --- | ---
S | Total number of states | Large, fixed
A | Total number of actions | Fixed
B | Batch size (sampled transitions) | e.g., 32 or 64

Q is always (S, A). states, actions, rewards from a replay buffer are always (B,). Keep these separate in your head.

immutable array updates

# Overwrite specific indices
Q = Q.at[states, actions].set(updated_q)

# Add to specific indices
Q = Q.at[states, actions].add(alpha * td_error)

Use square brackets with .at (Q.at[i, j]); .at is indexed, not called, so Q.at(i, j) won't work.
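The accumulating behavior of .add() matters when a replay batch contains the same (s, a) pair twice. A small sketch (the toy index array is made up) comparing it with NumPy's fancy-indexed +=, which writes each repeated index only once:

```python
import numpy as np
import jax.numpy as jnp

idx = jnp.array([0, 0, 1])                        # index 0 appears twice

jax_counts = jnp.zeros(3).at[idx].add(1.0)        # JAX applies every update: [2., 1., 0.]

np_counts = np.zeros(3)
np_counts[np.asarray(idx)] += 1.0                 # NumPy's += writes index 0 once: [1., 1., 0.]

np_counts_at = np.zeros(3)
np.add.at(np_counts_at, np.asarray(idx), 1.0)     # NumPy's unbuffered version: [2., 1., 0.]
```

So if you're porting NumPy code that relied on the buffered += behavior, .at[].add() will give different numbers whenever indices repeat.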

quick reference

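Operation | JAX
--- | ---
Mean ($\frac{1}{B}\sum$) | jnp.mean(x)
Max over actions | jnp.max(Q, axis=-1)
Clip to a range | jnp.clip(x, a, b)
Element-wise min | jnp.minimum(x, y)
Piecewise functions | jnp.where(condition, x, y)
Softmax | jax.nn.softmax(x, axis=-1)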

what i'm building toward

The goal with these exercises is to get to a point where reading a paper and implementing it are closer together. Once you have the translation patterns down (einsum for summations, broadcasting rules, .at[] for updates), the actual code tends to be pretty compact. The hard part shifts from "how do I write this" to "am I computing the right thing."

Next up: actually implementing the policy optimization and value optimization algorithms.