This follows on from the previous post on fitting a gaussian distribution with pyro: Fitting a Distribution with Pyro

Here we assume we are flipping a slightly biased coin. We think the probability of a heads is close to 0.5, but we are not sure. We want to fit a beta distribution to the random observed data.

References:

Part of a project on github.

Import the required libraries:

import numpy as np
import torch
from torch.distributions import constraints
import pyro
import pyro.infer
import pyro.optim
import pyro.distributions as dist

import matplotlib.pyplot as plt

plt.style.use("seaborn-whitegrid")
torch.manual_seed(0)

Generate observed data

We use the distribution module of pytorch to generate random data from Bernoulli trials with a known probability of success \(P(p)=0.4\).

# Generate data from actual distribution
true_dist = dist.Bernoulli(0.4)
n = 100
data = true_dist.sample(sample_shape=(n, 1))

Prior distribution

We propose the probability of a heads comes from a beta distribution. We assume that the probability is close to 0.5 but with some error. This is characterised in the prior distribution as follows:

prior = dist.Beta(10, 10)

plt.figure(num=None, figsize=(10, 6), dpi=80)
x_range = np.linspace(0, 1, num=100)

y_values = torch.exp(prior.log_prob(torch.tensor(x_range)))
plt.plot(x_range, y_values, label="prior")

plt.title("PDF")
plt.legend()
plt.savefig("images/beta_prior_pdf.png")
plt.show()

Analytical Posterior

Using the random data we have generated, we can calculate the posterior distribution. In the case of a beta distribution - the posterior has an analytical solution, based on conjugacy:

# Analytical posterior
posterior = dist.Beta(
    prior.concentration1 + data.sum(),
    prior.concentration0 + len(data) - data.sum(),
)

Variational inference

We can solve the same problem with variational inference using pyro. We setup the model to sample from a Bernoulli trial, where the probability of a heads comes from a beta distribution. The model is conditioned to give the generated data when it is sampled from.

def data_model(params):
    # returns a Bernoulli trial outcome
    beta = pyro.sample("beta_dist", dist.Beta(params[0], params[1]))
    return pyro.sample("data_dist", dist.Bernoulli(beta))


conditioned_data_model = pyro.condition(data_model, data={"data_dist": data})

The guide function creates a pyro beta distribution object given a set of parameters, which we will track.

def guide(params):
    # returns the Bernoulli probablility
    alpha = pyro.param(
        "alpha", torch.tensor(params[0]), constraint=constraints.positive
    )
    beta = pyro.param(
        "beta", torch.tensor(params[1]), constraint=constraints.positive
    )
    return pyro.sample("beta_dist", dist.Beta(alpha, beta))

We iterate over the above functions, starting from our prior distribution. Each step we converge towards an ideal posterior form of the guide.

svi = pyro.infer.SVI(
    model=conditioned_data_model,
    guide=guide,
    optim=pyro.optim.SGD({"lr": 0.001, "momentum": 0.8}),
    loss=pyro.infer.Trace_ELBO(),
)

params_prior = [prior.concentration1, prior.concentration0]

# Iterate over all the data and store results
losses, alpha, beta = [], [], []
pyro.clear_param_store()

num_steps = 3000
for t in range(num_steps):
    losses.append(svi.step(params_prior))
    alpha.append(pyro.param("alpha").item())
    beta.append(pyro.param("beta").item())

posterior_vi = dist.Beta(alpha[-1], beta[-1])

We plot the trajectories of the parameters to show they have converged sufficiently:

plt.figure(num=None, figsize=(10, 6), dpi=80)
plt.plot(alpha, label='alpha')
plt.plot(beta, label='beta')
plt.title("Parameter trajectories")
plt.xlabel("Iteration")
plt.legend()
plt.savefig("images/beta_trajectories.png")
plt.show()

Comparing distributions

We can compare the variational inference distribution to the analytical posterior.

plt.figure(num=None, figsize=(10, 6), dpi=80)
x_range = np.linspace(0, 1, num=100)

y_values = torch.exp(prior.log_prob(torch.tensor(x_range)))
plt.plot(x_range, y_values, label="prior")

y_values = torch.exp(posterior.log_prob(torch.tensor(x_range)))
plt.plot(x_range, y_values, label="posterior")

y_values = torch.exp(posterior_vi.log_prob(torch.tensor(x_range)))
plt.plot(x_range, y_values, label="posterior_vi")

plt.title("PDF")
plt.legend()
plt.savefig("images/beta_pdfs.png")
plt.show()

The estimated posterior from variational inference is very similar to the analytical posterior.

Extra: More Data

If we generate much more data and repeat the process, we can get a tighter posterior:

n = 10000
data_more = true_dist.sample(sample_shape=(n, 1))
conditioned_data_model_more = pyro.condition(data_model, data={"data_dist": data_more})

svi = pyro.infer.SVI(
    model=conditioned_data_model_more,
    guide=guide,
    optim=pyro.optim.SGD({"lr": 0.0001, "momentum": 0.8}),
    loss=pyro.infer.Trace_ELBO(),
)

# Iterate over all the data and store results
losses, alpha, beta = [], [], []
pyro.clear_param_store()

num_steps = 3000
for t in range(num_steps):
    losses.append(svi.step(params_prior))
    alpha.append(pyro.param("alpha").item())
    beta.append(pyro.param("beta").item())

posterior_vim = dist.Beta(alpha[-1], beta[-1])

plt.figure(num=None, figsize=(10, 6), dpi=80)
x_range = np.linspace(0, 1, num=1000)

y_values = torch.exp(prior.log_prob(torch.tensor(x_range)))
plt.plot(x_range, y_values, label="prior")

y_values = torch.exp(posterior.log_prob(torch.tensor(x_range)))
plt.plot(x_range, y_values, label="posterior")

y_values = torch.exp(posterior_vi.log_prob(torch.tensor(x_range)))
plt.plot(x_range, y_values, label="posterior_vi")

y_values = torch.exp(posterior_vim.log_prob(torch.tensor(x_range)))
plt.plot(x_range, y_values, label="posterior_vi_more_data")

plt.title("PDF")
plt.legend()
plt.savefig("images/beta_pdfs_more.png")
plt.show()