July 12, 2024

Everything you need to assemble the DQN Megazord in JAX.

“The Rainbow Megazord”, Dall-E 3

In 2013, the introduction of Deep Q-Networks (DQN) by Mnih et al.[1] marked the first breakthrough in Deep Reinforcement Learning, surpassing expert human players in three Atari games. Over the years, several variants of DQN were published, each improving on specific weaknesses of the original algorithm.

In 2017, Hessel et al.[2] made the most of the DQN palette by combining six of its powerful variants, crafting what could be called the DQN Megazord: Rainbow.

In this article, we’ll break down the individual components that make up Rainbow, while reviewing their JAX implementations in the Stoix library.

DQN

The fundamental building block of Rainbow is DQN, an extension of Q-learning using a neural network with parameters θ to approximate the Q-function (i.e. action-value function). In particular, DQN uses convolutional layers to extract features from images and a final linear layer to produce a scalar estimate of the Q-value for each action.

During training, the network parameterized by θ, referred to as the “online network”, is used to select actions, while the “target network”, parameterized by θ-, is a delayed copy of the online network used to provide stable targets. This way, the targets do not depend on the parameters being updated.
Additionally, DQN stores past transitions (observation, action, reward, next observation, and done flag tuples) in a replay buffer D and samples from it to train at fixed intervals.

At each iteration i, DQN samples a transition j and takes a gradient step on the following loss:

DQN loss function, all images are made by the author, unless specified otherwise

This loss aims at minimizing the expectation of the squared temporal-difference (TD) error.
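For reference, this loss is usually written as below. This is a reconstruction in standard notation rather than the original figure: r is the reward, γ the discount factor, s' the next state, and θ- the target network parameters.

```latex
L_i(\theta_i) = \mathbb{E}_{(s, a, r, s') \sim D}
\left[ \left( r + \gamma \max_{a'} Q(s', a'; \theta^-) - Q(s, a; \theta_i) \right)^2 \right]
```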

Note that DQN is an off-policy algorithm because it learns the optimal policy defined by the maximum Q-value term while following a different behavior policy, such as an epsilon-greedy policy.

Here’s the DQN algorithm in detail:

DQN algorithm

DQN in practice

As mentioned above, we’ll reference code snippets from the Stoix library to illustrate the core parts of DQN and Rainbow (some of the code was slightly edited or commented for pedagogical purposes).

Let’s start with the neural network: Stoix lets us break down our model architecture into a pre-processor and a post-processor, referred to as torso and head respectively. In the case of DQN, the torso would be a multi-layer perceptron (MLP) or convolutional neural network (CNN) and the head an epsilon-greedy policy, both implemented as Flax modules:

https://medium.com/media/7b6c514c0bc0bfd45d845f5527f78421/href
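To make the role of the head concrete, here is a minimal sketch of epsilon-greedy action selection in plain JAX. It is not the Stoix module: the function name `epsilon_greedy` and its arguments are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


def epsilon_greedy(key, q_values, epsilon):
    """Pick a random action with probability epsilon, the greedy one otherwise.

    q_values: array of shape [num_actions] for a single state (hypothetical layout).
    """
    explore_key, action_key = jax.random.split(key)
    num_actions = q_values.shape[-1]
    random_action = jax.random.randint(action_key, (), 0, num_actions)
    greedy_action = jnp.argmax(q_values)
    explore = jax.random.uniform(explore_key) < epsilon
    return jnp.where(explore, random_action, greedy_action)


q_values = jnp.array([0.1, 0.5, 0.2])
# With epsilon=0 the policy is purely greedy and selects the action with the highest Q-value.
action = epsilon_greedy(jax.random.PRNGKey(0), q_values, epsilon=0.0)
```

Splitting the key keeps the "explore or not" draw independent from the choice of the random action.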

Additionally, DQN uses the following loss (note that Stoix follows the Rlax naming conventions, therefore tm1 is equivalent to timestep t in the above equations, while t refers to timestep t+1):

https://medium.com/media/1be20d13144ef4888865ff6d20fa80e8/href
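As a minimal sketch of this loss (an illustration using the same tm1 / t naming, not the Stoix or Rlax code itself), the TD error can be written for a single transition and then vectorized over the batch. The helper names and the `discount_t = gamma * (1 - done)` convention are assumptions.

```python
import jax
import jax.numpy as jnp


def q_learning_td_error(q_tm1, a_tm1, r_t, discount_t, q_t):
    """One-step TD error for a single transition.

    q_tm1: online-network Q-values at timestep t, shape [num_actions]
    q_t: target-network Q-values at timestep t+1, shape [num_actions]
    discount_t: gamma * (1 - done), so terminal transitions do not bootstrap
    """
    target = r_t + discount_t * jnp.max(q_t)
    # The target is treated as a constant when differentiating the loss.
    return jax.lax.stop_gradient(target) - q_tm1[a_tm1]


def dqn_loss(q_tm1, a_tm1, r_t, discount_t, q_t):
    """Mean squared TD error over a batch of transitions."""
    td_errors = jax.vmap(q_learning_td_error)(q_tm1, a_tm1, r_t, discount_t, q_t)
    return jnp.mean(jnp.square(td_errors))


# Toy batch of two transitions; the second one is terminal (discount 0).
q_tm1 = jnp.array([[1.0, 2.0], [0.0, 1.0]])
a_tm1 = jnp.array([0, 1])
r_t = jnp.array([1.0, 0.0])
discount_t = jnp.array([0.9, 0.0])
q_t = jnp.array([[2.0, 3.0], [5.0, 5.0]])
loss = dqn_loss(q_tm1, a_tm1, r_t, discount_t, q_t)
```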

The Rainbow blueprint

Now that we have laid the foundations for DQN, we’ll review each part of the algorithm in more detail, while identifying potential weaknesses and how they are addressed by Rainbow.
In particular, we’ll cover:

Double DQN and the overestimation bias
Dueling DQN and the state-value / advantage prediction
Distributional DQN and the return distribution
Multi-step learning
Noisy DQN and flexible exploration strategies
Prioritized Experience Replay and learning potential

The Rainbow Blueprint, Dall-E 3

Double DQN

Source: Deep Reinforcement Learning with Double Q-learning[3]
Improvement: Reduced overestimation bias

The overestimation bias

One issue with the loss function used in vanilla DQN arises from the Q-target. Remember that we define the target as:

Objective in the DQN loss

This objective may lead to an overestimation bias. Indeed, as DQN uses bootstrapping (learning estimates from estimates), the max term may select overestimated values to update the Q-function, leading to overestimated Q-values.

As an example, consider the following figure:

The Q-values predicted by the network are represented in blue.
The true Q-values are represented in purple.
The gap between the predictions and true values is represented by red arrows.

In this case, action 0 has the highest predicted Q-value because of a large prediction error. This value will therefore be used to construct the target.
However, the action with the highest true value is action 2. This illustration shows how the max term in the target favors large positive estimation errors, inducing an overestimation bias.

Illustration of the overestimation bias.
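The effect can also be reproduced numerically. The sketch below is a toy setup (not taken from the paper): every action has a true value of 0 and the estimates carry zero-mean Gaussian noise. The estimates are unbiased on average, yet the maximum over actions is clearly positive.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
num_states, num_actions = 10_000, 4

# Toy assumption: all actions are truly worth 0 in every state.
true_q = jnp.zeros((num_states, num_actions))
# Zero-mean estimation error, standing in for the network's prediction error.
noise = jax.random.normal(key, (num_states, num_actions))
estimated_q = true_q + noise

mean_error = jnp.mean(estimated_q)                  # close to 0: estimates are unbiased
mean_max = jnp.mean(jnp.max(estimated_q, axis=-1))  # well above 0: the max is biased upward
```

Since the true maximum is 0 in every state, the gap between `mean_max` and 0 is exactly the overestimation introduced by the max operator.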

Decoupling action selection and evaluation

To solve this problem, van Hasselt et al. (2015)[3] propose a new target where the action is selected by the online network, while its value is estimated by the target network:

The Double DQN target

By decoupling action selection and evaluation, the estimation bias is significantly reduced, leading to better value estimates and improved performance.

Double DQN provides stable and accurate value estimates, leading to improved performance. Source: van Hasselt et al. (2015), Figure 3

Double DQN in practice

As expected, implementing Double DQN only requires us to modify the loss function:

https://medium.com/media/f214383ec1cd89468af2ec6fc067ca39/href
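As a minimal sketch of the change (an illustration, not the Stoix code), the only difference with the vanilla TD error is how the bootstrap value is picked. The argument names `q_t_value` (target network) and `q_t_selector` (online network) follow the Rlax convention; the toy numbers are assumptions.

```python
import jax
import jax.numpy as jnp


def double_q_learning_td_error(q_tm1, a_tm1, r_t, discount_t, q_t_value, q_t_selector):
    """TD error where the online network selects and the target network evaluates."""
    best_action = jnp.argmax(q_t_selector)              # selection: online network
    target = r_t + discount_t * q_t_value[best_action]  # evaluation: target network
    return jax.lax.stop_gradient(target) - q_tm1[a_tm1]


q_tm1 = jnp.array([1.0, 2.0])
q_t_value = jnp.array([2.0, 3.0])     # target-network estimates at t+1
q_t_selector = jnp.array([5.0, 1.0])  # online-network estimates at t+1

td_double = double_q_learning_td_error(q_tm1, 0, 1.0, 0.9, q_t_value, q_t_selector)
# Vanilla DQN would bootstrap from max(q_t_value) instead.
td_vanilla = 1.0 + 0.9 * jnp.max(q_t_value) - q_tm1[0]
```

In this toy case the online network prefers action 0, so the target uses the target network's value for action 0 rather than its own maximum, which yields a smaller TD error than the vanilla target.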

Dueling DQN

Source: Dueling Network Architectures for Deep Reinforcement Learning[4]
Improvement: Separation of the value and advantage computation

State value, Q-value, and advantage

In RL, we use several functions to estimate the value of a given state, action, or sequence of actions from a given state:

State-value V(s): The state value corresponds to the expected return when starting in a given state s and following a policy π thereafter.
Q-value Q(s, a): Similarly, the Q-value corresponds to the expected return when starting in a given state s, taking action a, and following a policy π thereafter.
Advantage A(s, a): The advantage is defined as the difference between the Q-value and the state-value in a given state s for an action a. It represents the inherent value of action a in the current state.

The following figure attempts to represent the differences between these value functions on a backup diagram (note that the state value is weighted by the probability of taking each action under policy π).

Visualization of the state value (in purple), state-action value (Q-function, in blue), and the advantage (in pink) on a backup diagram.

Usually, DQN estimates the Q-value directly, using a feed-forward neural network. This implies that DQN has to learn the Q-values for each action in each state independently.

The dueling architecture

Introduced by Wang et al.[4] in 2016, Dueling DQN uses a neural network with two separate streams of computation:

The state value stream predicts the scalar value of a given state.
The advantage stream predicts the advantage of each action for a given state.

This decoupling enables the independent estimation of the state value and advantages, which has several benefits. For instance, the network can learn state values without having to update the action values regularly. Additionally, it can better generalize to unseen actions in familiar states.
These improvements lead to more stable and faster convergence, especially in environments with many similar-valued actions.

In practice, a dueling network uses a common representation (i.e. a shared linear or convolutional layer) with parameters θ before splitting into two streams, consisting of linear layers with parameters α and β respectively. The state value stream outputs a scalar value while the advantage stream returns a scalar value for each available action.
Adding the outputs of the two streams allows us to reconstruct the Q-value for each action as Q(s, a) = V(s) + A(s, a).

An important detail is that the mean is usually subtracted from the advantages. Indeed, without a constraint such as zero-mean advantages, the decomposition of Q into V and A is not unique (adding a constant to V and subtracting it from every advantage leaves Q unchanged), making the problem ill-defined. With this constraint, V represents the value of the state while A represents how much better or worse each action is compared to the average action in that state.

Illustration of a dueling network

Dueling Network in practice

Here’s the Stoix implementation of a Q-network:

https://medium.com/media/91e7ff3ec86d5f10983adbe1061653ef/href
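The aggregation step described earlier (adding the two streams after centering the advantages) fits in a single function. The following is a minimal sketch in plain JAX, not the Stoix module; the function name and array shapes are assumptions.

```python
import jax.numpy as jnp


def dueling_q_values(value, advantages):
    """Combine the two streams: Q(s, a) = V(s) + A(s, a) - mean_a A(s, a).

    value: shape [batch, 1], advantages: shape [batch, num_actions].
    """
    centered = advantages - jnp.mean(advantages, axis=-1, keepdims=True)
    return value + centered


value = jnp.array([[1.0]])
advantages = jnp.array([[1.0, 2.0, 3.0]])
q_values = dueling_q_values(value, advantages)

# Shifting every advantage by the same constant leaves the Q-values unchanged,
# which is why the centering makes the decomposition well-defined.
q_shifted = dueling_q_values(value, advantages + 10.0)
```

With centered advantages, the mean of the Q-values over actions equals the output of the value stream.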

Distributional DQN

Source: A Distributional Perspective on Reinforcement Learning[5]
Improvement: Richer value estimates

The return distribution

Most RL systems model the expectation of the return. However, a promising body of literature approaches RL from a distributional perspective. In this setting, the goal becomes to model the return distribution, which allows us to consider statistics other than the mean.
In 2017, Bellemare et al.[5] published a distributional version of DQN called C51, which predicts the return distribution for each action and reached a new state-of-the-art performance on the Atari benchmark.

Illustrated comparison between DQN and C51. Source [5′]

Let’s take a step back and review the theory behind C51.
In traditional RL, we evaluate a policy using the Bellman Equation, which allows us to define the Q-function in a recursive form. Alternatively, we can use a distributional version of the Bellman equation, which accounts for randomness in the returns:

Standard and Distributional versions of the Bellman Equation

Here, ρ is the transition function.
The main difference between those functions is that Q is a numerical value, summing expectations over random variables. In contrast, Z is a random variable, summing the reward distribution and the discounted distribution of future returns.

The following illustration helps visualize how to derive Z from the distributional Bellman equation:

Consider the distribution of returns Z at a given timestep and the transition operator Pπ. PπZ is the distribution of future returns Z(s’, a’).
Multiplying this by the discount factor γ contracts the distribution towards 0 (as γ is less than 1).
Adding the reward distribution shifts the previous distribution by a set amount (note that the figure assumes a constant reward for simplicity; in practice, adding the reward distribution would shift but also modify the discounted return).
Finally, the distribution is projected on a discrete support using an L2 projection operator Φ.

Illustration of the distributional Bellman equation. Source: [5]

This fixed support is a vector of N atoms separated by a constant gap within a set interval:

Definition of the discrete support z

At inference time, the Q-network returns an approximating distribution dt defined on this support with the probability mass pθ(st, at) on each atom i such that:

Predicted return distribution
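To make the support and the predicted distribution concrete, here is a minimal sketch in plain JAX. The bounds of -10 and 10 and the 51 atoms are illustrative values, and the variable names are assumptions.

```python
import jax
import jax.numpy as jnp

# Illustrative support: N atoms evenly spaced between v_min and v_max.
v_min, v_max, num_atoms = -10.0, 10.0, 51
atoms = jnp.linspace(v_min, v_max, num_atoms)
delta_z = atoms[1] - atoms[0]  # constant gap between atoms

# The network outputs one vector of logits per action; a softmax turns
# each vector into probability masses over the atoms.
num_actions = 3
logits = jnp.zeros((num_actions, num_atoms))  # placeholder network output
probs = jax.nn.softmax(logits, axis=-1)

# The Q-value of each action is the expectation of its return distribution.
q_values = jnp.sum(probs * atoms, axis=-1)
```

Here the placeholder logits give a uniform distribution on a symmetric support, so every Q-value is 0; a trained network would produce a different distribution for each action.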

The goal is to update θ such that the distribution closely matches the true distribution of returns. To learn the probability masses, the target distribution is built using a distributional variant of Bellman’s optimality equation:

Target return distribution

To be able to compare the distribution predicted by our neural network and the target distribution, we need to discretize the target distribution and project it on the same support z.

To this end, we use an L2 projection (a projection onto z such that the difference between the original and projected distribution is minimized in terms of the L2 norm):

L2 projection of the target distribution
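The projection can be sketched for a single distribution as follows. This is an illustrative implementation in plain JAX (not the Stoix or Rlax code): each shifted atom is clipped to the support and its mass is split between its two nearest neighbors in proportion to their proximity. The function name and the toy numbers are assumptions.

```python
import jax.numpy as jnp


def project_distribution(z_target, probs, atoms):
    """Project the mass `probs` located at `z_target` onto the fixed support `atoms`.

    z_target: shifted atoms r + gamma * z, shape [N]
    probs: probability masses attached to z_target, shape [N]
    atoms: evenly spaced fixed support, shape [K]
    """
    delta_z = atoms[1] - atoms[0]
    z_clipped = jnp.clip(z_target, atoms[0], atoms[-1])
    # weights[k, j]: share of the mass at z_clipped[j] assigned to atoms[k].
    distance = jnp.abs(z_clipped[None, :] - atoms[:, None]) / delta_z
    weights = jnp.clip(1.0 - distance, 0.0, 1.0)
    return weights @ probs


atoms = jnp.array([0.0, 1.0, 2.0])
next_probs = jnp.array([0.2, 0.5, 0.3])
reward, gamma = 0.5, 0.5
z_target = reward + gamma * atoms  # [0.5, 1.0, 1.5]
projected = project_distribution(z_target, next_probs, atoms)
```

In this toy case the mass at 0.5 is split equally between atoms 0 and 1, the mass at 1.0 lands entirely on atom 1, and the mass at 1.5 is split between atoms 1 and 2, so the projected distribution still sums to 1.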

Finally, we need to define a loss function that minimizes the difference between the two distributions. As we’re dealing with distributions, we can’t simply subtract the prediction from the target, as we did previously.

Instead, we minimize the Kullback-Leibler divergence between dt and d’t (in practice, this is implemented as a cross-entropy loss):

KL divergence between the projected target and the predicted return distribution
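The KL divergence and the cross-entropy differ only by the entropy of the target distribution, which does not depend on θ, so both yield the same gradients. A minimal sketch of the resulting loss for a single transition, with assumed names:

```python
import jax
import jax.numpy as jnp


def categorical_cross_entropy(target_probs, logits):
    """Cross-entropy between the projected target and the predicted distribution."""
    # The target is built from the target network and treated as a constant.
    target_probs = jax.lax.stop_gradient(target_probs)
    return -jnp.sum(target_probs * jax.nn.log_softmax(logits, axis=-1), axis=-1)


target_probs = jnp.array([0.1, 0.75, 0.15])
logits = jnp.zeros(3)  # placeholder prediction: uniform over 3 atoms
loss = categorical_cross_entropy(target_probs, logits)
```

Working on logits with `log_softmax` avoids taking the logarithm of very small probabilities.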

For a more exhaustive description of Distributional DQN, you can refer to Massimiliano Tomassoli’s article[8] as well as Pascal Poupart’s video on the topic[11].

C51 in practice

The key components of C51 in Stoix are the Distributional head and the categorical loss, which uses double Q-learning by default as introduced previously. The choice of defining the C51 network as a head lets us use an MLP or a CNN torso interchangeably depending on the use case.

https://medium.com/media/743444ea903453114bdf41f2b6a6dfa3/href

Noisy DQN

Source: Noisy Networks for Exploration[6]
Improvement: Learnable and state-dependent exploration mechanism

Noisy parameterization of Neural Networks

Like many off-policy algorithms, DQN relies on an epsilon-greedy policy as its main exploration mechanism. Therefore, the algorithm will behave greedily with respect to the Q-values most of the time and select random actions with a predefined probability.

Fortunato et al.[6] introduce NoisyNets as a more flexible alternative. NoisyNets are neural networks whose weights and biases are perturbed by a parametric function of Gaussian noise. Similarly to an epsilon-greedy policy, such noise injects randomness in the agent’s action selection, thus encouraging exploration.

However, this noise is scaled and offset by learned parameters, allowing the level of noise to be adapted state-by-state. This way, the balance between exploration and exploitation is optimized dynamically during training. Eventually, the network may learn to ignore the noise, but will do so at different rates in different parts of the state space, leading to more flexible exploration.

A network parameterized by a vector of noisy parameters is defined as follows:

Neural Network parameterized by Noisy parameters

Therefore, a linear layer y = wx + b becomes:

Noisy linear layer

For performance, the noise is generated at inference time using Factorized Gaussian Noise. For a linear layer with M inputs and N outputs, a noise matrix of shape (M x N) is generated as a combination of two noise vectors of size M and N. This method reduces the number of required random variables from M x N to M + N.
The noise matrix is defined as the outer product of the noise vectors, each scaled by a function f:

Noise generation using Factorised Gaussian Noise
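The steps above can be sketched in plain JAX as follows. This is an illustration rather than the Stoix layer: the function names are assumptions, and the scaling function f(x) = sign(x)·sqrt(|x|) is the one used by Fortunato et al.[6].

```python
import jax
import jax.numpy as jnp


def scale_noise(x):
    """Scaling function f(x) = sign(x) * sqrt(|x|)."""
    return jnp.sign(x) * jnp.sqrt(jnp.abs(x))


def factorized_noise(key, num_inputs, num_outputs):
    """Draw M + N Gaussian samples and expand them into an [M, N] noise matrix."""
    key_in, key_out = jax.random.split(key)
    eps_in = scale_noise(jax.random.normal(key_in, (num_inputs,)))
    eps_out = scale_noise(jax.random.normal(key_out, (num_outputs,)))
    weight_noise = jnp.outer(eps_in, eps_out)
    bias_noise = eps_out
    return weight_noise, bias_noise


def noisy_linear(x, w_mu, w_sigma, b_mu, b_sigma, weight_noise, bias_noise):
    """Linear layer whose weights and biases are perturbed by learned-scale noise."""
    w = w_mu + w_sigma * weight_noise
    b = b_mu + b_sigma * bias_noise
    return x @ w + b


weight_noise, bias_noise = factorized_noise(jax.random.PRNGKey(0), 3, 2)
x = jnp.ones(3)
# With the sigma parameters at zero, the layer reduces to a standard linear layer.
y = noisy_linear(x, jnp.ones((3, 2)), jnp.zeros((3, 2)), jnp.zeros(2), jnp.zeros(2),
                 weight_noise, bias_noise)
```

The sigma parameters are learned, which is what allows the network to reduce the noise over time where exploration is no longer useful.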

Improved exploration

The improved exploration induced by noisy networks allows a wide range of algorithms, such as DQN, Dueling DQN, and A3C, to benefit from improved performance with a reasonably small number of extra parameters.

NoisyNets improve the performance of several algorithms on the Atari benchmark. Source: [6]

Noisy DQN in practice

In Stoix, we implement a noisy layer as follows:

https://medium.com/media/abf0db362f8e8592d6bc71ab78ef1a61/href

Note: All the linear layers in Rainbow are replaced with their noisy equivalent (see the “Assembling Rainbow” section for more details).

Prioritized Experience Replay

Source: Prioritized Experience Replay[7]
Improvement: Prioritization of experiences with higher learning potential

Estimating the Learning Potential

After taking an environment step, vanilla DQN uniformly samples a batch of experiences (also called transitions) from a replay buffer and performs a gradient descent step on this batch. Although this approach produces satisfying results, some specific experiences might be more valuable from a learning perspective than others. Therefore, we could potentially speed up the training process by sampling such experiences more often.

This is precisely the idea explored in the Prioritized Experience Replay (PER) paper published by Schaul et al.[7] in 2016. However, the main question remains: how to approximate the expected learning potential of a transition?

One idealized criterion would be the amount the RL agent can learn from a transition in its current state (expected learning progress). While this measure is not directly accessible, a reasonable proxy is the magnitude of a transition’s TD error δ, which indicates how ‘surprising’ or unexpected the transition is: specifically, how far the value is from its next-step bootstrap estimate (Andre et al., 1998).
Prioritized Experience Replay, Schaul et al. (2016)

As a reminder, the TD error is defined as follows:

The temporal-difference error

This metric is a decent estimate of the learning potential of a specific transition, as a high TD error indicates a large difference between the predicted and actual outcomes, meaning that the agent would benefit from updating its beliefs.

However, it is worth noting that alternative prioritization metrics are still being studied. For instance, Lahire et al.[9] (2022) argue that the optimal sampling scheme is distributed according to the per-sample gradient norms:

Per-sample gradient norms

For now, let’s continue with the TD error, as Rainbow uses this metric.

Deriving Sampling Probabilities

Once we have selected the prioritization criterion, we can derive the probabilities of sampling each transition from it. In Prioritized Experience Replay, two alternatives are showcased:

Proportional: Here, the priority of a transition is the absolute value of its TD error plus a small positive constant, which prevents transitions from never being revisited once their error is zero. The probability of replaying a transition is proportional to this priority.
Rank-based: In this mode, transitions are ranked in descending order according to their absolute TD error, and their priority is defined as the inverse of their rank. This option is expected to be more robust, as it is insensitive to outliers.

The priorities are then raised to the power α, a hyperparameter determining the degree of prioritization (α=0 is the uniform case), and normalized to obtain the sampling probabilities.

Prioritization modes and probability normalization

Importance sampling and bias annealing

In RL, the estimation of the expected value of the return relies on the assumption that the updates correspond to the same distribution as the expectation (i.e., the uniform distribution). However, PER introduces bias as we now sample experiences according to their TD error.

To rectify this bias, we use importance sampling, a statistical method used to estimate the properties of a distribution while sampling from a different distribution. Importance sampling re-weights samples so that the estimates remain unbiased and accurate.
Typically, the correcting weights are defined as the ratio of the two probabilities:

Importance sampling ratio

In this case, the target distribution is the uniform distribution, where every transition has a probability of being sampled equal to 1/N, with N being the size of the replay buffer.
Therefore, the importance sampling coefficient in the context of PER is defined by:

Importance sampling weight used in PER

Here, β is a coefficient adjusting the amount of bias correction (the bias is fully corrected for β=1). Finally, the weights are normalized by their maximum for stability:

Normalization of the importance sampling weights
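The proportional variant and the importance sampling correction can be sketched in a few lines of plain JAX. This is an illustration, not the Flashbax implementation; the function names are assumptions and the default values of α and β are only illustrative.

```python
import jax.numpy as jnp


def per_probabilities(td_errors, alpha=0.6, eps=1e-6):
    """Proportional prioritization: P(i) = p_i^alpha / sum_k p_k^alpha, with p_i = |delta_i| + eps."""
    priorities = (jnp.abs(td_errors) + eps) ** alpha
    return priorities / jnp.sum(priorities)


def importance_weights(probs, beta=0.4):
    """Importance sampling weights w_i = (N * P(i))^(-beta), normalized by their maximum."""
    n = probs.shape[0]
    weights = (n * probs) ** (-beta)
    return weights / jnp.max(weights)


td_errors = jnp.array([1.0, -2.0, 3.0])
probs = per_probabilities(td_errors, alpha=1.0, eps=0.0)  # [1/6, 1/3, 1/2]
weights = importance_weights(probs, beta=1.0)              # [1, 1/2, 1/3]

# alpha = 0 recovers uniform sampling, in which case all weights equal 1.
uniform_probs = per_probabilities(td_errors, alpha=0.0)
uniform_weights = importance_weights(uniform_probs, beta=1.0)
```

Transitions that are sampled more often than under uniform sampling receive smaller weights, which scales down their contribution to the gradient step.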

To summarize, here’s the full algorithm for Prioritized Experience Replay (the update and training steps are identical to DQN):

The Prioritized Experience Replay algorithm

Increased convergence speed with PER

The following plots highlight the performance benefits of PER. Indeed, the proportional and rank-based prioritization mechanisms enable DQN to reach the same baseline performance roughly twice as fast on the Atari benchmark.

Normalized maximum and average scores (in terms of Double DQN performance) on 57 Atari games. Source:[7]

Prioritized Experience Replay in practice

Stoix seamlessly integrates the Flashbax library which provides a variety of replay buffers. Here are the relevant code snippets used to instantiate the replay buffer, compute the sampling probabilities from the TD error, and update the buffer’s priorities:

https://medium.com/media/06179b5ff16f3056c2599ecc457e1d8d/href

Multi-step Learning

Source: Reinforcement Learning: An Introduction[10], chapter 7
Improvement: Enhanced reward signal and sample efficiency, tunable bias-variance trade-off

Multi-step learning is an improvement on traditional one-step temporal difference learning which allows us to consider the return over n steps when building our targets. For instance, instead of considering only the reward at the next timestep, we’ll consider the n-step truncated return (see the equation below). This process has several advantages, among which:

Immediate feedback: Considering a larger time horizon allows the agent to learn the value of state-action pairs much faster, especially in environments where rewards are delayed and specific actions might not pay out immediately.
Sample efficiency: Each update in multi-step learning incorporates information from multiple time steps, making each sample more informative. This improves sample efficiency, meaning the agent can learn more from fewer experiences.
Balancing bias and variance: Multi-step methods offer a trade-off between bias and variance. One-step methods have low variance but high bias, as their targets rely heavily on bootstrapped estimates, while multi-step methods reduce this bias at the cost of higher variance, as their targets accumulate more sampled rewards. By tuning the number of steps, one can find a balance that works best for the given environment.

The multi-step distributional loss used in Rainbow is defined as:

Multi-step target return distribution

In practice, using n-step returns implies a few adjustments to our code:

We now sample trajectories of n experiences, instead of individual experiences.
The reward is replaced with the n-step discounted return.
The done flag is set to True if any of the n done flags is True.
The next state s(t+1) is replaced by the last observation of the trajectory s(t+n).
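These adjustments can be sketched for a single trajectory as follows. This is an illustrative helper in plain JAX, not the Stoix code; the function name and the convention of returning a bootstrap discount (γ^n, or 0 if the episode ended within the n steps) are assumptions.

```python
import jax.numpy as jnp


def n_step_return(rewards, dones, gamma):
    """Discounted n-step return and bootstrap discount for one trajectory.

    rewards, dones: shape [n], covering steps t .. t+n-1 (dones as 0./1. floats).
    """
    n = rewards.shape[0]
    not_done = 1.0 - dones
    # alive[k] = 1 if no episode end occurred strictly before step k.
    alive = jnp.concatenate([jnp.ones(1), jnp.cumprod(not_done)[:-1]])
    discounts = gamma ** jnp.arange(n)
    ret = jnp.sum(alive * discounts * rewards)
    # No bootstrapping from s(t+n) if any of the n done flags is set.
    bootstrap_discount = gamma ** n * jnp.prod(not_done)
    return ret, bootstrap_discount


rewards = jnp.array([1.0, 1.0, 1.0])
ret, boot = n_step_return(rewards, jnp.array([0.0, 0.0, 0.0]), gamma=0.5)
# If the episode ends at the second step, later rewards are ignored.
ret_done, boot_done = n_step_return(rewards, jnp.array([0.0, 1.0, 0.0]), gamma=0.5)
```

The n-step target is then the returned value plus the bootstrap discount times the value estimate at s(t+n).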

Multi-Step learning in practice

Finally, we can reuse the categorical loss function used in C51 with these updated inputs:

https://medium.com/media/f667c9dc3fbefd3824f58ccb43677608/href

Assembling Rainbow

Congratulations on making it this far! We now have a better understanding of all the moving pieces that constitute Rainbow. Here’s a summary of the Rainbow agent:

Neural Network Architecture:
 — Torso: A convolutional neural network (CNN) or multi-layer perceptron (MLP) base that creates embeddings for the head network.
 — Head: Combines Dueling DQN and C51. The value stream outputs the state value distribution over atoms, while the advantage stream outputs the advantage distribution over actions and atoms. These streams are aggregated, and Q-values are computed as the weighted sum of atom values and their respective probabilities. An action is selected using an epsilon-greedy policy.
 — Noisy Layers: All linear layers are replaced with their noisy equivalents to aid in exploration.
Loss Function: Uses a distributional loss modeling the n-step returns, where targets are computed using Double Q-learning.
Replay Buffer: Employs a prioritization mechanism based on the TD error to improve learning efficiency.

Here’s the network used for the Rainbow head:

https://medium.com/media/548655defb33bdb46b76dccff1ec4c25/href
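As a minimal sketch of how the dueling and distributional parts fit together (an illustration in plain JAX, not the Stoix module), the two streams can be combined on the logits before applying a softmax over atoms. The function name and array shapes are assumptions.

```python
import jax
import jax.numpy as jnp


def rainbow_head(value_logits, advantage_logits, atoms):
    """Aggregate the dueling streams per atom and compute expected Q-values.

    value_logits: [batch, 1, num_atoms]
    advantage_logits: [batch, num_actions, num_atoms]
    atoms: [num_atoms]
    """
    centered = advantage_logits - jnp.mean(advantage_logits, axis=1, keepdims=True)
    q_logits = value_logits + centered
    q_dist = jax.nn.softmax(q_logits, axis=-1)   # one return distribution per action
    q_values = jnp.sum(q_dist * atoms, axis=-1)  # expectation over atoms
    return q_values, q_logits, q_dist


batch, num_actions, num_atoms = 2, 3, 5
atoms = jnp.linspace(-1.0, 1.0, num_atoms)
value_logits = jnp.zeros((batch, 1, num_atoms))            # placeholder stream outputs
advantage_logits = jnp.zeros((batch, num_actions, num_atoms))
q_values, q_logits, q_dist = rainbow_head(value_logits, advantage_logits, atoms)
```

The logits are what the categorical loss consumes, while the expected Q-values are what the policy uses to select actions.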

Performances and ablations

To conclude this article, let’s take a closer look at Rainbow’s performance on the Atari benchmark, as well as the ablation study.
The following figure compares Rainbow with the other DQN baselines we studied. The measured metric is the median human-normalized score. In other words, the score of each game is normalized so that human performance corresponds to 100%, and the median is taken across games, which enables us to quickly spot algorithms achieving a human level.

Three of the DQN baselines reach this level after 200 million frames:

Distributional DQN
Dueling DQN
Prioritized Double DQN

Interestingly, Rainbow reaches the same level after only 44 million frames, making it roughly 5 times more sample efficient than the best baselines. At the end of training, its median human-normalized score exceeds 200%.

Median human-normalized performance across 57 Atari games. Each line represents a DQN baseline. Source: [2]

This second figure shows the ablation study, which reports the performance of Rainbow when one of its components is removed. These results allow us to make several observations:

The three most crucial components of Rainbow are the distributional head, the use of multi-step learning, and the prioritization of the replay buffer.
Noisy layers contribute significantly to the overall performance. Using standard layers with an epsilon-greedy policy doesn’t allow the agent to reach the 200% score in 200 million frames.
Despite achieving strong performances on their own, the dueling structure and double Q-learning only provide marginal improvements in the context of Rainbow.

Median human-normalized performance across 57 Atari games. Each line represents an ablation of Rainbow. Source: [2]

Thank you very much for reading this article. I hope it provided you with a comprehensive introduction to Rainbow and its components. I highly advise reading through the Stoix implementation of Rainbow for a more detailed description of the training process and the Rainbow architecture.

Until next time 👋

Bibliography

[1] Mnih, V., Kavukcuoglu, K., Silver, D., Graves, A., Antonoglou, I., Wierstra, D., & Riedmiller, M. (2013). Playing Atari with Deep Reinforcement Learning, arXiv.
[2] Hessel, M., Modayil, J., van Hasselt, H., Schaul, T., Ostrovski, G., Dabney, W., Horgan, D., Piot, B., Azar, M., & Silver, D. (2017). Rainbow: Combining Improvements in Deep Reinforcement Learning, arXiv.
[3] van Hasselt, H., Guez, A., & Silver, D. (2015). Deep Reinforcement Learning with Double Q-learning, arXiv.
[4] Wang, Z., Schaul, T., Hessel, M., van Hasselt, H., Lanctot, M., & de Freitas, N. (2016). Dueling Network Architectures for Deep Reinforcement Learning (No. arXiv:1511.06581), arXiv.
[5] Bellemare, M. G., Dabney, W., & Munos, R. (2017). A Distributional Perspective on Reinforcement Learning, arXiv.
[5′] Dabney, W., Ostrovski, G., Silver, D., & Munos, R. (2018). Implicit Quantile Networks for Distributional Reinforcement Learning, arXiv.
[6] Fortunato, M., Azar, M. G., Piot, B., Menick, J., Osband, I., Graves, A., Mnih, V., Munos, R., Hassabis, D., Pietquin, O., Blundell, C., & Legg, S. (2019). Noisy Networks for Exploration, arXiv.
[7] Schaul, T., Quan, J., Antonoglou, I., & Silver, D. (2016). Prioritized Experience Replay, arXiv.

Additional resources

[8] Massimiliano Tomassoli, Distributional RL: An intuitive explanation of Distributional RL
[9] Lahire, T., Geist, M., & Rachelson, E. (2022). Large Batch Experience Replay, arXiv.
[10] Sutton, R. S., & Barto, A. G. (1998). Reinforcement Learning: An Introduction.
[11] Pascal Poupart, CS885 Module 5: Distributional RL, YouTube

Rainbow: The Colorful Evolution of Deep Q-Networks was originally published in Towards Data Science on Medium, where people are continuing the conversation by highlighting and responding to this story.

This way, the targets are not dependent on the parameters being updated.Additionally, DQN uses a replay buffer D to sample past transitions (observations, reward, and done flag tuples) to train on at fixed intervals.At each iteration i, DQN samples a transition j and takes a gradient step on the following loss:DQN loss function, all images are made by the author, unless specified otherwiseThis loss aims at minimizing the expectation of the squared temporal-difference (TD) error.Note that DQN is an off-policy algorithm because it learns the optimal policy defined by the maximum Q-value term while following a different behavior policy, such as an epsilon-greedy policy.Here’s the DQN algorithm in detail:DQN algorithmDQN in practiceAs mentioned above, we’ll reference code snippets from the Stoix library to illustrate the core parts of DQN and Rainbow (some of the code was slightly edited or commented for pedagogical purposes).Let’s start with the neural network: Stoix lets us break down our model architecture into a pre-processor and a post-processor, referred to as torso and head respectively. 
In the case of DQN, the torso would be a multi-layer perceptron (MLP) or convolutional neural network (CNN) and the head an epsilon greedy policy, both implemented as Flax modules:https://medium.com/media/7b6c514c0bc0bfd45d845f5527f78421/hrefAdditionally, DQN uses the following loss (note that Stoix follows the Rlax naming conventions, therefore tm1 is equivalent to timestep t in the above equations, while t refers to timestep t+1):https://medium.com/media/1be20d13144ef4888865ff6d20fa80e8/hrefThe Rainbow blueprintNow that we have laid the foundations for DQN, we’ll review each part of the algorithm in more detail, while identifying potential weaknesses and how they are addressed by Rainbow.In particular, we’ll cover:Double DQN and the overestimation biasDueling DQN and the state-value / advantage predictionDistributional DQN and the return distributionMulti-step learningNoisy DQN and flexible exploration strategiesPrioritized Experience Replay and learning potentialThe Rainbow Blueprint, Dall-E 3Double DQNSource: Deep Reinforcement Learning with Double Q-learning [3]Improvement: Reduced overestimation biasThe overestimation biasOne issue with the loss function used in vanilla DQN arises from the Q-target. Remember that we define the target as:Objective in the DQN lossThis objective may lead to an overestimation bias. Indeed, as DQN uses bootstrapping (learning estimates from estimates), the max term may select overestimated values to update the Q-function, leading to overestimated Q-values.As an example, consider the following figure:The Q-values predicted by the network are represented in blue.The true Q-values are represented in purple.The gap between the predictions and true values is represented by red arrows.In this case, action 0 has the highest predicted Q-value because of a large prediction error. This value will therefore be used to construct the target. However, the action with the highest true value is action 2. 
This illustration shows how the max term in the target favors large positive estimation errors, inducing an overestimation bias.Illustration of the overestimation bias.Decoupling action selection and evaluationTo solve this problem, Hasselt et al. (2015)[3] propose a new target where the action is selected by the online network, while its value is estimated by the target network:The Double DQN targetBy decoupling action selection and evaluation, the estimation bias is significantly reduced, leading to better value estimates and improved performance.Double DQN provides stable and accurate value estimates, leading to improved performance. Source: Hasselt et al. (2015), Figure 3Double DQN in practiceAs expected, implementing Double DQN only requires us to modify the loss function:https://medium.com/media/f214383ec1cd89468af2ec6fc067ca39/hrefDueling DQNSource: Dueling Network Architectures for Deep Reinforcement LearningImprovement: Separation of the value and advantage computationState value, Q-value, and advantageIn RL, we use several functions to estimate the value of a given state, action, or sequence of actions from a given state:State-value V(s): The state value corresponds to the expected return when starting in a given state s and following a policy π thereafter.Q-value Q(s, a): Similarly, the Q-value corresponds to the expected return when starting in a given state s, taking action a, and following a policy π thereafter.Advantage A(s, a): The advantage is defined as the difference between the Q-value and the state-value in a given state s for an action a. 
It represents the inherent value of action a in the current state.

The following figure attempts to represent the differences between these value functions on a backup diagram (note that the state value is weighted by the probability of taking each action under policy π).

Visualization of the state value (in purple), state-action value (Q-function, in blue), and the advantage (in pink) on a backup diagram.

Usually, DQN estimates the Q-value directly, using a feed-forward neural network. This implies that DQN has to learn the Q-values for each action in each state independently.

The dueling architecture

Introduced by Wang et al.[4] in 2016, Dueling DQN uses a neural network with two separate streams of computation:

- The state value stream predicts the scalar value of a given state.
- The advantage stream predicts the advantage of each action for a given state.

This decoupling enables the independent estimation of the state value and advantages, which has several benefits. For instance, the network can learn state values without having to update the action values regularly. Additionally, it can better generalize to unseen actions in familiar states.

These improvements lead to more stable and faster convergence, especially in environments with many similar-valued actions.

In practice, a dueling network uses a common representation (i.e. a shared linear or convolutional layer) parameterized by parameters θ before splitting into two streams, consisting of linear layers with parameters α and β respectively. The state value stream outputs a scalar value while the advantage stream returns a scalar value for each available action. Adding the outputs of the two streams allows us to reconstruct the Q-value for each action as Q(s, a) = V(s) + A(s, a).

An important detail is that the mean is usually subtracted from the advantages. Indeed, the advantages need to have zero mean; otherwise, the decomposition of Q into V and A would not be unique (any constant could be moved from A to V), making the problem ill-defined.
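The aggregation step can be sketched in a few lines of JAX. This is a minimal illustration rather than the Stoix module: the function name dueling_q_values and the toy inputs are assumptions made for the example.

```python
import jax.numpy as jnp


def dueling_q_values(value, advantages):
    """Combine V(s) of shape [batch, 1] and A(s, a) of shape [batch, num_actions].

    The advantages are centered (zero mean over actions) before being added
    to the state value, so that Q(s, a) = V(s) + A(s, a) - mean_a A(s, a).
    """
    centered = advantages - jnp.mean(advantages, axis=-1, keepdims=True)
    return value + centered


# Toy example (hypothetical numbers): one state, three actions.
value = jnp.array([[2.0]])
advantages = jnp.array([[1.0, 2.0, 3.0]])
q_values = dueling_q_values(value, advantages)

# Adding the same constant to every advantage does not change the Q-values.
shifted_q_values = dueling_q_values(value, advantages + 10.0)

print(q_values)
print(shifted_q_values)
```

Shifting every advantage by the same constant leaves the Q-values unchanged, which is why the advantages have to be centered for V and A to be identifiable.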
With this constraint, V represents the value of the state while A represents how much better or worse each action is compared to the average action in that state.

Illustration of a dueling network

Dueling Network in practice

Here’s the Stoix implementation of a Q-network:

https://medium.com/media/91e7ff3ec86d5f10983adbe1061653ef/href

Distributional DQN

Source: A Distributional Perspective on Reinforcement Learning [5]

Improvement: Richer value estimates

The return distribution

Most RL systems model the expectation of the return. However, a promising body of literature approaches RL from a distributional perspective. In this setting, the goal becomes to model the return distribution, which allows us to consider other statistics than the mean.

In 2017, Bellemare et al.[5] published a distributional version of DQN called C51, which predicts the return distribution for each action and reached new state-of-the-art performance on the Atari benchmark.

Illustrated comparison between DQN and C51. Source [5′]

Let’s take a step back and review the theory behind C51.

In traditional RL, we evaluate a policy using the Bellman Equation, which allows us to define the Q-function in a recursive form. Alternatively, we can use a distributional version of the Bellman equation, which accounts for randomness in the returns:

Standard and Distributional versions of the Bellman Equation

Here, ρ is the transition function.

The main difference between those functions is that Q is a numerical value, summing expectations over random variables. In contrast, Z is a random variable, summing the reward distribution and the discounted distribution of future returns.

The following illustration helps visualize how to derive Z from the distributional Bellman equation:

- Consider the distribution of returns Z at a given timestep and the transition operator Pπ.
PπZ is the distribution of future returns Z(s’, a’).
- Multiplying this by the discount factor γ contracts the distribution towards 0 (as γ is less than 1).
- Adding the reward distribution shifts the previous distribution by a set amount (note that the figure assumes a constant reward for simplicity; in practice, adding the reward distribution would shift but also modify the discounted return).
- Finally, the distribution is projected on a discrete support using an L2 projection operator Φ.

Illustration of the distributional Bellman equation. Source: [5]

This fixed support is a vector of N atoms separated by a constant gap within a set interval:

Definition of the discrete support z

At inference time, the Q-network returns an approximating distribution dt defined on this support with the probability mass pθ(st, at) on each atom i such that:

Predicted return distribution

The goal is to update θ such that the distribution closely matches the true distribution of returns. To learn the probability masses, the target distribution is built using a distributional variant of Bellman’s optimality equation:

Target return distribution

To be able to compare the distribution predicted by our neural network and the target distribution, we need to discretize the target distribution and project it on the same support z.

To this end, we use an L2 projection (a projection onto z such that the difference between the original and projected distribution is minimized in terms of the L2 norm):

L2 projection of the target distribution

Finally, we need to define a loss function that minimizes the difference between the two distributions.
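Before turning to the loss, the support and the projection step can be illustrated with a short JAX sketch. It is not the Stoix or Rlax implementation: the function name project_onto_support, the 5-atom support, and the toy reward and discount are assumptions chosen so that the numbers are easy to follow. Each shifted atom spreads its probability mass over its two nearest support atoms, in proportion to how close it is to each of them:

```python
import jax.numpy as jnp


def project_onto_support(target_atoms, probs, support):
    """Project a categorical distribution onto an evenly spaced support.

    target_atoms: [N] atom locations after applying r + gamma * z.
    probs:        [N] probability mass carried by each target atom.
    support:      [N] fixed, evenly spaced atoms z_0 < ... < z_{N-1}.
    """
    delta_z = support[1] - support[0]
    # Atoms falling outside [v_min, v_max] are clipped to the boundaries.
    clipped = jnp.clip(target_atoms, support[0], support[-1])
    # weights[i, j]: share of target atom j's mass assigned to support atom i.
    # It decays linearly with distance and is 0 beyond one atom spacing.
    distance = jnp.abs(clipped[None, :] - support[:, None]) / delta_z
    weights = jnp.clip(1.0 - distance, 0.0, 1.0)
    return weights @ probs


# Toy example (hypothetical values): 5 atoms on [-1, 1].
support = jnp.linspace(-1.0, 1.0, 5)
next_probs = jnp.array([0.1, 0.2, 0.4, 0.2, 0.1])
reward, discount = 0.25, 0.5

# Shift and contract the atoms, then project back onto the fixed support.
target_atoms = reward + discount * support
projected = project_onto_support(target_atoms, next_probs, support)

print(projected)
```

In this example the projected distribution still sums to one and keeps the same mean as the shifted distribution, since no atom had to be clipped.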
As we’re dealing with distributions, we can’t simply subtract the prediction from the target, as we did previously.

Instead, we minimize the Kullback-Leibler divergence between dt and d’t (in practice, this is implemented as a cross-entropy loss):

KL divergence between the projected target and the predicted return distribution

For a more exhaustive description of Distributional DQN, you can refer to Massimiliano Tomassoli’s article[8] as well as Pascal Poupart’s video on the topic[11].

C51 in practice

The key components of C51 in Stoix are the Distributional head and the categorical loss, which uses double Q-learning by default as introduced previously. The choice of defining the C51 network as a head lets us use an MLP or a CNN torso interchangeably depending on the use case.

https://medium.com/media/743444ea903453114bdf41f2b6a6dfa3/href

Noisy DQN

Source: Noisy Networks for Exploration [6]

Improvement: Learnable and state-dependent exploration mechanism

Noisy parameterization of Neural Networks

Like many off-policy algorithms, DQN relies on an epsilon-greedy policy as its main exploration mechanism. Therefore, the algorithm will behave greedily with respect to the Q-values most of the time and select random actions with a predefined probability.

Fortunato et al.[6] introduce NoisyNets as a more flexible alternative. NoisyNets are neural networks whose weights and biases are perturbed by a parametric function of Gaussian noise. Similarly to an epsilon-greedy policy, such noise injects randomness in the agent’s action selection, thus encouraging exploration.

However, this noise is scaled and offset by learned parameters, allowing the level of noise to be adapted state-by-state. This way, the balance between exploration and exploitation is optimized dynamically during training.
Eventually, the network may learn to ignore the noise, but will do so at different rates in different parts of the state space, leading to more flexible exploration.

A network parameterized by a vector of noisy parameters is defined as follows:

Neural Network parameterized by Noisy parameters

Therefore, a linear layer y = wx + b becomes:

Noisy linear layer

For performance, the noise is generated at inference time using Factorized Gaussian Noise. For a linear layer with M inputs and N outputs, a noise matrix of shape (M x N) is generated as a combination of two noise vectors with size M and N. This method reduces the number of required random variables from M x N to M + N.

The noise matrix is defined as the outer product of the noise vectors, each scaled by a function f:

Noise generation using Factorised Gaussian Noise

Improved exploration

The improved exploration induced by noisy networks allows a wide range of algorithms, such as DQN, Dueling DQN and A3C, to benefit from improved performance with a reasonably low number of extra parameters.

NoisyNets improve the performance of several algorithms on the Atari benchmark. Source: [6]

Noisy DQN in practice

In Stoix, we implement a noisy layer as follows:

https://medium.com/media/abf0db362f8e8592d6bc71ab78ef1a61/href

Note: All the linear layers in Rainbow are replaced with their noisy equivalent (see the “Assembling Rainbow” section for more details).

Prioritized Experience Replay

Source: Prioritized Experience Replay [7]

Improvement: Prioritization of experiences with higher learning potential

Estimating the Learning Potential

After taking an environment step, vanilla DQN uniformly samples a batch of experiences (also called transitions) from a replay buffer and performs a gradient descent step on this batch. Although this approach produces satisfying results, some specific experiences might be more valuable from a learning perspective than others.
Therefore, we could potentially speed up the training process by sampling such experiences more often.

This is precisely the idea explored in the Prioritized Experience Replay (PER) paper published by Schaul et al.[7] in 2016. However, the main question remains: how can we approximate the expected learning potential of a transition?

“One idealized criterion would be the amount the RL agent can learn from a transition in its current state (expected learning progress). While this measure is not directly accessible, a reasonable proxy is the magnitude of a transition’s TD error δ, which indicates how ‘surprising’ or unexpected the transition is: specifically, how far the value is from its next-step bootstrap estimate (Andre et al., 1998).”

Prioritized Experience Replay, Schaul et al. (2016)

As a reminder, the TD error is defined as follows:

The temporal-difference error

This metric is a decent estimate of the learning potential of a specific transition, as a high TD error indicates a large difference between the predicted and actual outcomes, meaning that the agent would benefit from updating its beliefs.

However, it is worth noting that alternative prioritization metrics are still being studied. For instance, Lahire et al.[9] (2022) argue that the optimal sampling scheme is distributed according to the per-sample gradient norms:

Per-sample gradient norms

For the rest of this section, we’ll stick with the TD error, as this is the metric Rainbow uses.

Deriving Sampling Probabilities

Once we have selected the prioritization criterion, we can derive the probabilities of sampling each transition from it. In Prioritized Experience Replay, two alternatives are showcased:

- Proportional: Here the priority of a transition is the absolute value of its TD error, so the probability of replaying it grows with that error.
A small positive constant is added so that transitions can still be revisited once their error reaches zero.
- Rank-based: In this mode, transitions are ranked in descending order according to their absolute TD error, and their priority is defined as the inverse of their rank. This option is supposed to be more robust as it is insensitive to outliers.

The priorities are then raised to the power α, a hyperparameter determining the degree of prioritization (α=0 is the uniform case), and normalized to obtain the sampling probabilities.

Prioritization modes and probability normalization

Importance sampling and bias annealing

In RL, the estimation of the expected value of the return relies on the assumption that the updates correspond to the same distribution as the expectation (i.e., the uniform distribution). However, PER introduces bias as we now sample experiences according to their TD error.

To rectify this bias, we use importance sampling, a statistical method used to estimate the properties of a distribution while sampling from a different distribution. Importance sampling re-weights samples so that the estimates remain unbiased and accurate.

Typically, the correcting weights are defined as the ratio of the two probabilities:

Importance sampling ratio

In this case, the target distribution is the uniform distribution, where every transition has a probability of being sampled equal to 1/N, with N being the size of the replay buffer. Therefore, the importance sampling coefficient in the context of PER is defined by:

Importance sampling weight used in PER

Here, β is a coefficient adjusting the amount of bias correction (the bias is fully corrected for β=1). Finally, the weights are normalized for stability:

Normalization of the importance sampling weights

To summarize, here’s the full algorithm for Prioritized Experience Replay (the update and training steps are identical to DQN):

The Prioritized Experience Replay algorithm

Increased convergence speed with PER

The following plots highlight the performance benefits of PER.
Indeed, the proportional and rank-based prioritization mechanisms enable DQN to reach the same baseline performance roughly twice as fast on the Atari benchmark.

Normalized maximum and average scores (in terms of Double DQN performance) on 57 Atari games. Source: [7]

Prioritized Experience Replay in practice

Stoix seamlessly integrates the Flashbax library, which provides a variety of replay buffers. Here are the relevant code snippets used to instantiate the replay buffer, compute the sampling probabilities from the TD error, and update the buffer’s priorities:

https://medium.com/media/06179b5ff16f3056c2599ecc457e1d8d/href

Multi-step Learning

Source: Reinforcement Learning: An Introduction, chapter 7 [10]

Improvement: Enhanced reward signal and sample efficiency, reduced bootstrapping bias

Multi-step learning is an improvement on traditional one-step temporal difference learning which allows us to consider the return over n steps when building our targets. For instance, instead of considering only the reward at the next timestep, we’ll consider the n-step truncated return (see the equation below). This process has several advantages, among which:

- Immediate feedback: Considering a larger time horizon allows the agent to learn the value of state-action pairs much faster, especially in environments where rewards are delayed and specific actions might not pay out immediately.
- Sample efficiency: Each update in multi-step learning incorporates information from multiple time steps, making each sample more informative. This improves sample efficiency, meaning the agent can learn more from fewer experiences.
- Balancing Bias and Variance: Multi-step methods offer a trade-off between bias and variance. One-step methods rely heavily on the bootstrapped estimate, which gives them low variance but high bias, while multi-step methods rely more on sampled rewards, which lowers the bias but increases the variance.
By tuning the number of steps, one can find a balance that works best for the given environment.

The multi-step distributional loss used in Rainbow is defined as:

Multi-step target return distribution

In practice, using n-step returns implies a few adjustments to our code:

- We now sample trajectories of n experiences, instead of individual experiences
- The reward is replaced with the n-step discounted return
- The done flag is set to True if any of the n done flags is True
- The next state s(t+1) is replaced by the last observation of the trajectory s(t+n)

Multi-Step learning in practice

Finally, we can reuse the categorical loss function used in C51 with these updated inputs:

https://medium.com/media/f667c9dc3fbefd3824f58ccb43677608/href

Assembling Rainbow

Congratulations on making it this far! We now have a better understanding of all the moving pieces that constitute Rainbow. Here’s a summary of the Rainbow agent:

- Neural Network Architecture:
  - Torso: A convolutional neural network (CNN) or multi-layer perceptron (MLP) base that creates embeddings for the head network.
  - Head: Combines Dueling DQN and C51. The value stream outputs the state value distribution over atoms, while the advantage stream outputs the advantage distribution over actions and atoms. These streams are aggregated, and Q-values are computed as the weighted sum of atom values and their respective probabilities. An action is selected using an epsilon-greedy policy.
  - Noisy Layers: All linear layers are replaced with their noisy equivalents to aid in exploration.
- Loss Function: Uses a distributional loss modeling the n-step returns, where targets are computed using Double Q-learning.
- Replay Buffer: Employs a prioritization mechanism based on the TD error to improve learning efficiency.

Here’s the network used for the Rainbow head:

https://medium.com/media/548655defb33bdb46b76dccff1ec4c25/href

Performances and ablations

To conclude this article, let’s take a closer look at Rainbow’s performance on the Atari benchmark, as well as the ablation study.

The following figure compares Rainbow with the other DQN baselines we studied. The measured metric is the median human-normalized score. In other words, the score on each game is normalized so that human performance corresponds to 100%, and the median is taken across games, which enables us to quickly spot algorithms achieving a human level.

Three of the DQN baselines reach this level after 200 million frames:

- Distributional DQN
- Dueling DQN
- Prioritized Double DQN

Interestingly, Rainbow reaches the same level after only 44 million frames, making it roughly 5 times more sample efficient than the best baselines. At the end of training, it exceeds 200% of the median human-normalized score.

Median human-normalized performance across 57 Atari games. Each line represents a DQN baseline. Source: [2]

This second figure shows the ablation study, which reports the performance of Rainbow when one of its components is removed. These results allow us to make several observations:

- The three most crucial components of Rainbow are the distributional head, the use of multi-step learning, and the prioritization of the replay buffer.
- Noisy layers contribute significantly to the overall performance.
Using standard layers with an epsilon-greedy policy doesn’t allow the agent to reach the 200% score in 200 million frames.
- Despite achieving strong performance on their own, the dueling structure and double Q-learning only provide marginal improvements in the context of Rainbow.

Median human-normalized performance across 57 Atari games. Each line represents an ablation of Rainbow. Source: [2]

Thank you very much for reading this article. I hope it provided you with a comprehensive introduction to Rainbow and its components. I highly recommend reading through the Stoix implementation of Rainbow for a more detailed description of the training process and the Rainbow architecture.

Until next time 👋

Bibliography

[1] Mnih, V., Kavukcuoglu, K., Silver, D., Graves, A., Antonoglou, I., Wierstra, D., & Riedmiller, M. (2013). Playing Atari with Deep Reinforcement Learning, arXiv.
[2] Hessel, M., Modayil, J., van Hasselt, H., Schaul, T., Ostrovski, G., Dabney, W., Horgan, D., Piot, B., Azar, M., & Silver, D. (2017). Rainbow: Combining Improvements in Deep Reinforcement Learning, arXiv.
[3] van Hasselt, H., Guez, A., & Silver, D. (2015). Deep Reinforcement Learning with Double Q-learning, arXiv.
[4] Wang, Z., Schaul, T., Hessel, M., van Hasselt, H., Lanctot, M., & de Freitas, N. (2016). Dueling Network Architectures for Deep Reinforcement Learning (No. arXiv:1511.06581), arXiv.
[5] Bellemare, M. G., Dabney, W., & Munos, R. (2017). A Distributional Perspective on Reinforcement Learning, arXiv.
[5′] Dabney, W., Ostrovski, G., Silver, D., & Munos, R. (2018). Implicit Quantile Networks for Distributional Reinforcement Learning, arXiv.
[6] Fortunato, M., Azar, M. G., Piot, B., Menick, J., Osband, I., Graves, A., Mnih, V., Munos, R., Hassabis, D., Pietquin, O., Blundell, C., & Legg, S. (2019). Noisy Networks for Exploration, arXiv.
[7] Schaul, T., Quan, J., Antonoglou, I., & Silver, D. (2016). Prioritized Experience Replay, arXiv.

Additional resources

[8] Massimiliano Tomassoli, Distributional RL: An intuitive explanation of Distributional RL.
[9] Lahire, T., Geist, M., & Rachelson, E. (2022). Large Batch Experience Replay, arXiv.
[10] Sutton, R. S., & Barto, A. G. (1998). Reinforcement Learning: An Introduction.
[11] Pascal Poupart, CS885 Module 5: Distributional RL, YouTube.

Rainbow: The Colorful Evolution of Deep Q-Networks was originally published in Towards Data Science on Medium.
