Quickstart
Example: Cliff Walking
Let's solve the famous Cliff Walking problem with tabular Q-Learning and SARSA [1].
Define the environment:
from gym.envs.toy_text import CliffWalkingEnv
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from pandas import DataFrame
from yarllib.helpers.history import History
from yarllib.policies import EpsGreedyPolicy
env = CliffWalkingEnv()
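As a quick sanity check (illustrative, not part of the training code), we can inspect the environment's spaces; Cliff Walking is a 4x12 grid, so we expect 48 discrete states and 4 discrete actions:
# Sanity check (illustrative): Cliff Walking has a 4x12 grid of states
# and 4 actions (up, right, down, left).
print(env.observation_space)  # Discrete(48)
print(env.action_space)       # Discrete(4)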
def print_summary(history: History):
    print("Training statistics:")
    print(f"Number of episodes: {history.nb_episodes}")
    print(f"Average total reward: {history.total_rewards.mean()}")
    print(f"Average number of steps: {history.lengths.mean()}")
    print(f"Average total reward (last 50 episodes): {history.total_rewards[-50:].mean()}")
    print(f"Average number of steps (last 50 episodes): {history.lengths[-50:].mean()}")
First, let's define the parameters common to both Q-Learning and SARSA:
nb_steps = 30000
alpha = 0.1
gamma = 0.99
seed = 42
epsilon = 0.1
policy = EpsGreedyPolicy(epsilon)
params = dict(
    env=env,
    nb_steps=nb_steps,
    policy=policy,
    seed=seed,
)
Define the Q-Learning agent:
from yarllib.models.tabular import TabularQLearning
qlearning = TabularQLearning(env.observation_space, env.action_space).agent()
print(f"Table dimensions: {qlearning.model.q.shape}")
Run for 30000 steps using an ε-greedy policy with ε = 0.1:
qlearning_history = qlearning.train(**params)
print_summary(qlearning_history)
Define and train a SARSA agent:
from yarllib.models.tabular import TabularSarsa
sarsa = TabularSarsa(env.observation_space, env.action_space).agent()
sarsa_history = sarsa.train(**params)
print_summary(sarsa_history)
Compare the sum of rewards per episode collected by the two agents during training:
def _plot(histories, labels):
    assert len(histories) == len(labels)
    for h, l in zip(histories, labels):
        # Smooth each curve of episode returns with a 200-episode rolling mean.
        data = h.total_rewards
        df = DataFrame(data.T)
        df = pd.concat([df[col] for col in df])
        df = df.rolling(200).mean()
        sns.lineplot(data=df, label=l)
    plt.xlim(175, 600)
_plot([qlearning_history, sarsa_history], ["q-learning", "sarsa"])
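To label and render the figure outside a notebook, plain Matplotlib calls can be used (not specific to yarllib):
# Label the axes and render the comparison (plain Matplotlib, illustrative).
plt.xlabel("episode")
plt.ylabel("sum of rewards (rolling mean, window 200)")
plt.legend()
plt.show()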
We get what we expected: during training, Q-Learning collects a lower sum of rewards than SARSA. While behaving ε-greedily, Q-Learning learns the optimal path along the edge of the cliff and occasionally falls off it when exploring, whereas SARSA learns the longer but safer path, as explained in Example 6.6 of the Sutton & Barto textbook [1].
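As an optional follow-up, we can roll out the greedy policies induced by the learned tables and compare their behaviour without exploration. The sketch below is illustrative only: it assumes the table shown earlier (qlearning.model.q) is indexed as q[state, action], that the SARSA model exposes the same attribute, and that the classic Gym step API returning four values is in use.
import numpy as np

def greedy_rollout(q_table, env, max_steps=100):
    # Follow the greedy policy induced by a learned Q-table for one episode.
    # Assumes q_table is indexed as q_table[state, action] (illustrative).
    state = env.reset()
    total_reward, steps, done = 0.0, 0, False
    while not done and steps < max_steps:
        action = int(np.argmax(q_table[state]))
        state, reward, done, _ = env.step(action)  # classic Gym step API
        total_reward += reward
        steps += 1
    return total_reward, steps

# Q-Learning's greedy path hugs the cliff (shorter); SARSA's stays further away (safer).
print(greedy_rollout(qlearning.model.q, env))
print(greedy_rollout(sarsa.model.q, env))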