Thompson Sampling

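Thompson sampling is a way to choose actions in exploration-exploitation problems using the posterior distribution of the quantity of interest. At each step, one value is sampled from the posterior of every option, and the option whose sampled value looks best is the one acted on.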

Minimize Waiting Time

We will assume that the wait time in each group follows an exponential distribution whose average wait time is unknown.

The conjugate prior of the exponential distribution is a gamma distribution.
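
As a minimal sketch of why conjugacy is convenient here, assume the standard parameterization in which a Gamma(alpha, beta) prior is placed on the exponential rate. After n observed waits that sum to x_total, the posterior is Gamma(alpha + n, beta + x_total). The helper name `gamma_posterior` and the wait times below are made up for illustration and are not part of conjugate-models.

```python
import numpy as np


def gamma_posterior(alpha, beta, x_total, n):
    """Closed-form update for a Gamma(alpha, beta) prior on an exponential rate.

    Hypothetical helper for illustration; assumes beta is a rate parameter.
    """
    return alpha + n, beta + x_total


# Three observed wait times for a single group (made-up values).
waits = np.array([1.5, 0.5, 2.0])

alpha_post, beta_post = gamma_posterior(
    alpha=1.0, beta=1.0, x_total=waits.sum(), n=len(waits)
)

# The posterior mean of the rate is alpha / beta.
posterior_mean_rate = alpha_post / beta_post
print(alpha_post, beta_post, posterior_mean_rate)
```

No numerical integration is needed: the update is two additions, which is what makes repeated updates inside a sampling loop cheap.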

The goal is to find the group with the minimum wait time.
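
Because a gamma posterior describes the rate of the exponential distribution, and the mean wait time is the reciprocal of the rate, the shortest wait corresponds to the largest rate. A small sketch with made-up rate draws (the values are hypothetical):

```python
import numpy as np

# Hypothetical posterior draws of the rate for three groups.
sampled_rates = np.array([0.5, 0.8, 1.0])

# The mean wait time of an exponential distribution is 1 / rate.
sampled_mean_waits = 1 / sampled_rates

# Both selections identify the same group.
fastest_by_wait = int(np.argmin(sampled_mean_waits))
fastest_by_rate = int(np.argmax(sampled_rates))
```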

from conjugate.distributions import Gamma, Exponential
from conjugate.models import exponential_gamma

import numpy as np

import matplotlib.pyplot as plt

lam = np.array([0.5, 0.55, 0.6, 0.8, 1])
n_groups = len(lam)

true_dist = Exponential(lam=lam)

We will create some helper functions to abstract:

  • sampling from the true distribution of a group
  • creating the statistics required for the Bayesian update of the exponential-gamma model
  • taking a single step in the Thompson sampling process

def sample_true_distribution(
    group_to_sample: int,
    rng,
    true_dist: Exponential = true_dist,
) -> float:
    """Draw one wait time from the true distribution of the chosen group."""
    return true_dist[group_to_sample].dist.rvs(random_state=rng)


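def bayesian_update_stats(
    group_sampled: int,
    group_sample: float,
    n_groups: int = n_groups,
) -> tuple[np.ndarray, np.ndarray]:
    """Build per-group totals for an update that touches only the sampled group."""
    # Groups that were not sampled contribute zero wait time and zero observations.
    x = np.zeros(n_groups)
    n = np.zeros(n_groups)

    x[group_sampled] = group_sample
    n[group_sampled] = 1

    return x, n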


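def thompson_step(estimate: Gamma, rng) -> Gamma:
    # Draw one plausible rate per group from the current posterior.
    sample = estimate.dist.rvs(random_state=rng)

    # The gamma posterior is over the exponential rate, and the mean wait
    # time is 1 / rate, so the shortest sampled wait is the largest rate.
    group_to_sample = np.argmax(sample)

    group_sample = sample_true_distribution(group_to_sample, rng=rng)
    x, n = bayesian_update_stats(group_to_sample, group_sample)

    # Conjugate update; groups with n = 0 keep their current posterior.
    return exponential_gamma(x_total=x, n=n, prior=estimate)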

After defining a prior (the initial estimate) for each group, we can run Thompson sampling in a for loop, updating the estimate at every step.

alpha = beta = np.ones(n_groups)
estimate = Gamma(alpha, beta)

rng = np.random.default_rng(42)

total_samples = 250
for _ in range(total_samples):
    estimate = thompson_step(estimate=estimate, rng=rng)
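
For readers who want to see the mechanics without the library, here is a rough NumPy-only sketch of the same idea. It assumes Gamma(alpha, beta) posteriors over each group's rate with the standard update (alpha + 1, beta + observed wait), and it picks the group whose sampled rate is largest, i.e. whose sampled mean wait is shortest. The mean wait times, the step count, and the name `run_thompson` are made up for illustration.

```python
import numpy as np


def run_thompson(true_mean_waits, n_steps, seed=0):
    """Thompson sampling for the shortest exponential wait (illustrative)."""
    rng = np.random.default_rng(seed)
    n_groups = len(true_mean_waits)

    # Gamma(1, 1) prior on every group's rate.
    alpha = np.ones(n_groups)
    beta = np.ones(n_groups)
    counts = np.zeros(n_groups, dtype=int)

    for _ in range(n_steps):
        # One posterior draw of the rate per group (NumPy uses shape/scale).
        sampled_rates = rng.gamma(shape=alpha, scale=1 / beta)

        # Largest sampled rate == shortest sampled mean wait.
        group = int(np.argmax(sampled_rates))

        # Observe one wait time from the chosen group.
        wait = rng.exponential(scale=true_mean_waits[group])

        # Conjugate update for the chosen group only.
        alpha[group] += 1
        beta[group] += wait
        counts[group] += 1

    return alpha, beta, counts


alpha, beta, counts = run_thompson(np.array([1.0, 2.0, 4.0]), n_steps=2000)
print(counts / counts.sum())
```

In this sketch the shape parameter grows by one per observation, so `alpha - 1` recovers how many times each group was sampled.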

We can see that the group with the lowest wait time was actually exploited the most!

fig, axes = plt.subplots(ncols=2)
fig.suptitle("Thompson Sampling using conjugate-models")

ax = axes[0]
estimate.set_max_value(2).plot_pdf(label=lam, ax=ax)
ax.legend(title="True Mean")
ax.set(
    xlabel="Rate (1 / Mean Wait Time)",
    title="Posterior Distribution by Group",
)

ax = axes[1]
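# Assuming the standard update (alpha + n, beta + x_total) and the prior
# alpha of 1 set above, alpha - 1 counts how often each group was sampled.
n_times_sampled = estimate.alpha - 1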
ax.scatter(lam, n_times_sampled / total_samples)
ax.set(
    xlabel="True Mean Wait Time",
    ylabel="% of times sampled",
    ylim=(0, None),
    title="Exploitation of Best Group",
)
# Format yaxis as percentage
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f"{x:.0%}"))
plt.show()
