This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.
Inspired by a notebook by @davmre.
!pip install --upgrade -q https://storage.googleapis.com/jax-wheels/cuda$(echo $CUDA_VERSION | sed -e 's/\.//' -e 's/\..*//')/jaxlib-0.1.21-cp36-none-linux_x86_64.whl
!pip install --upgrade -q jax
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import itertools
import re
import sys
import time
import matplotlib.pyplot as plt
import jax
from jax import lax
from jax import numpy as np
from jax import scipy
from jax import random
import numpy as onp
import scipy as oscipy
onp.random.seed(10009)
num_features = 10
num_points = 100
true_beta = onp.random.randn(num_features).astype(np.float32)
all_x = onp.random.randn(num_points, num_features).astype(np.float32)
y = (onp.random.rand(num_points) < oscipy.special.expit(all_x.dot(true_beta))).astype(np.int32)
y
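As a quick look at the simulated data (an illustrative aside, not part of the original notebook): y is Bernoulli with success probability expit(all_x · true_beta), so the empirical rate of ones should track the mean predicted probability.
print(y.mean(), oscipy.special.expit(all_x.dot(true_beta)).mean())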
We'll write a non-batched version, a manually batched version, and an autobatched version.
def log_joint(beta):
    result = 0.
    # Note that no `axis` parameter is provided to `np.sum`.
    result = result + np.sum(scipy.stats.norm.logpdf(beta, loc=0., scale=1.))
    result = result + np.sum(-np.log(1 + np.exp(-(2*y-1) * np.dot(all_x, beta))))
    return result
log_joint(onp.random.randn(num_features))
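The likelihood term above is the Bernoulli log-likelihood under p = expit(x · beta), written in a compact, numerically convenient form: -log(1 + exp(-(2y-1)·z)) equals y·log(p) + (1-y)·log(1-p) for y in {0, 1}. A quick cross-check of that identity (an illustrative aside; `beta_check`, `logits`, and `p` are throwaway names, not from the original):
beta_check = onp.random.randn(num_features).astype(np.float32)
logits = np.dot(all_x, beta_check)
p = scipy.special.expit(logits)
compact = np.sum(-np.log(1 + np.exp(-(2*y-1) * logits)))    # form used in log_joint
textbook = np.sum(y * np.log(p) + (1 - y) * np.log(1 - p))  # textbook Bernoulli form
print(compact, textbook)  # should agree to float32 precision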
# This doesn't work, because we didn't write `log_joint()` to handle batching.
batch_size = 10
batched_test_beta = onp.random.randn(batch_size, num_features)
log_joint(onp.random.randn(batch_size, num_features))
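To see why, look at the shapes (an illustrative aside): `np.dot(all_x, beta)` assumes `beta` is a single parameter vector.
print(all_x.shape)                               # (100, 10)
print(batched_test_beta.shape)                   # (10, 10): batch x features
print(np.dot(all_x, batched_test_beta.T).shape)  # (100, 10): one column per batch member
# Worse, because batch_size happens to equal num_features here, the
# un-transposed np.dot(all_x, batched_test_beta) also has a valid shape --
# it just contracts the wrong axis. That's the "silently changed semantics"
# hazard mentioned below.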
def batched_log_joint(beta):
    result = 0.
    # Here (and below) `sum` needs an `axis` parameter. At best, forgetting to set axis
    # or setting it incorrectly yields an error; at worst, it silently changes the
    # semantics of the model.
    result = result + np.sum(scipy.stats.norm.logpdf(beta, loc=0., scale=1.),
                             axis=-1)
    # Note the multiple transposes. Getting this right is not rocket science,
    # but it's also not totally mindless. (I didn't get it right on the first
    # try.)
    result = result + np.sum(-np.log(1 + np.exp(-(2*y-1) * np.dot(all_x, beta.T).T)),
                             axis=-1)
    return result
batch_size = 10
batched_test_beta = onp.random.randn(batch_size, num_features)
batched_log_joint(batched_test_beta)
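As a sanity check (illustrative, not in the original), the manually batched version should agree with simply looping the unbatched `log_joint` over the batch:
looped = np.stack([log_joint(beta) for beta in batched_test_beta])
print(np.allclose(looped, batched_log_joint(batched_test_beta), atol=1e-4))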
With `jax.vmap`, it just works: `vmap` mechanically derives the batched version from the unbatched `log_joint`, with no manual transposes or `axis` bookkeeping.
vmap_batched_log_joint = jax.vmap(log_joint)
vmap_batched_log_joint(batched_test_beta)
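The two batched implementations should agree up to floating-point noise (a quick illustrative check):
print(np.allclose(batched_log_joint(batched_test_beta),
                  vmap_batched_log_joint(batched_test_beta),
                  atol=1e-4))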
A little code is copied from above.
@jax.jit
def log_joint(beta):
    result = 0.
    # Note that no `axis` parameter is provided to `np.sum`.
    # (This version uses a broader prior, scale=10., than the demo above.)
    result = result + np.sum(scipy.stats.norm.logpdf(beta, loc=0., scale=10.))
    result = result + np.sum(-np.log(1 + np.exp(-(2*y-1) * np.dot(all_x, beta))))
    return result
batched_log_joint = jax.jit(jax.vmap(log_joint))
def elbo(beta_loc, beta_log_scale, epsilon):
    beta_sample = beta_loc + np.exp(beta_log_scale) * epsilon
    # Monte Carlo estimate of E_q[log p(beta, y)] plus the Gaussian entropy
    # term (up to an additive constant, which doesn't affect the gradients).
    return np.mean(batched_log_joint(beta_sample), 0) + np.sum(beta_log_scale - 0.5 * onp.log(2*onp.pi))
elbo = jax.jit(elbo)  # `epsilon` is an array argument, so it must be traced, not static
elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1)))
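A quick smoke test of the ELBO machinery (illustrative; `test_epsilon` is a throwaway name, and the zero initializations mirror the training loop below):
test_epsilon = random.normal(random.PRNGKey(0), (8, num_features))
test_val, (g_loc, g_scale) = elbo_val_and_grad(
    np.zeros(num_features), np.zeros(num_features), test_epsilon)
print(test_val, g_loc.shape, g_scale.shape)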
def normal_sample(key, shape):
    """Convenience function for quasi-stateful RNG."""
    new_key, sub_key = random.split(key)
    return new_key, random.normal(sub_key, shape)
normal_sample = jax.jit(normal_sample, static_argnums=(1,))
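The pattern: each call consumes a key and hands back a fresh one, so repeated calls never reuse randomness, yet everything stays deterministic given the initial key (usage sketch, illustrative):
demo_key = random.PRNGKey(0)
demo_key, draw1 = normal_sample(demo_key, (3,))
demo_key, draw2 = normal_sample(demo_key, (3,))
print(draw1, draw2)  # two different draws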
key = random.PRNGKey(10003)
beta_loc = np.zeros(num_features, np.float32)
beta_log_scale = np.zeros(num_features, np.float32)
step_size = 0.01
batch_size = 128
epsilon_shape = (batch_size, num_features)
for i in range(1000):
    key, epsilon = normal_sample(key, epsilon_shape)
    elbo_val, (beta_loc_grad, beta_log_scale_grad) = elbo_val_and_grad(
        beta_loc, beta_log_scale, epsilon)
    # Gradient *ascent* on the ELBO.
    beta_loc += step_size * beta_loc_grad
    beta_log_scale += step_size * beta_log_scale_grad
    if i % 10 == 0:
        print('{}\t{}'.format(i, elbo_val))
Coverage isn't quite as good as we might like, but it's not bad, and nobody said variational inference was exact.
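One way to put a number on "coverage" (an illustrative check, not from the original): the fraction of true coefficients that land inside the approximate 2-sigma intervals.
within = np.abs(true_beta - beta_loc) < 2 * np.exp(beta_log_scale)
print(np.mean(within))  # fraction of the 10 coefficients covered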
plt.figure(figsize=(7, 7))
plt.plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')
plt.plot(true_beta, beta_loc + 2*np.exp(beta_log_scale), 'r.',
         label=r'Approximated Posterior $2\sigma$ Error Bars')
plt.plot(true_beta, beta_loc - 2*np.exp(beta_log_scale), 'r.')
plot_scale = 3
plt.plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')
plt.xlabel('True beta')
plt.ylabel('Estimated beta')
plt.legend(loc='best')