Autobatching log-densities example

This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.

Inspired by a notebook by @davmre.

In [0]:
!pip install --upgrade -q https://storage.googleapis.com/jax-wheels/cuda$(echo $CUDA_VERSION | sed -e 's/\.//' -e 's/\..*//')/jaxlib-0.1.21-cp36-none-linux_x86_64.whl
!pip install --upgrade -q jax
In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import itertools
import re
import sys
import time

from matplotlib.pyplot import *

import jax

from jax import lax
from jax import numpy as np
from jax import scipy
from jax import random

import numpy as onp
import scipy as oscipy

Generate a fake binary classification dataset

In [2]:
# Fix the NumPy seed so the synthetic dataset is reproducible.
onp.random.seed(10009)

num_features = 10
num_points = 100

# Ground-truth coefficients first, then the design matrix (one row per point).
# The order of the random draws matters for reproducibility.
true_beta = onp.random.randn(num_features).astype(np.float32)
all_x = onp.random.randn(num_points, num_features).astype(np.float32)

# Bernoulli labels: point i gets label 1 with probability
# sigmoid(all_x[i] . true_beta), sampled by thresholding uniform draws.
uniform_draws = onp.random.rand(num_points)
label_probs = oscipy.special.expit(all_x.dot(true_beta))
y = (uniform_draws < label_probs).astype(np.int32)
In [3]:
# Inspect the sampled binary labels (one 0/1 entry per data point).
y
Out[3]:
array([0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0,
       1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0,
       1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
       0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1,
       1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0], dtype=int32)

Write the log-joint function for the model

We'll write a non-batched version, a manually batched version, and an autobatched version.

Non-batched

In [4]:
def log_joint(beta):
    """Log-joint density of the logistic-regression model for a single `beta`.

    Adds the log-density of a standard normal prior on `beta` to the Bernoulli
    log-likelihood of the labels `y` given the features `all_x`. Both `y` and
    `all_x` are read from the enclosing notebook scope.

    Args:
      beta: a single coefficient vector (not a batch of them).

    Returns:
      The scalar log-joint value.
    """
    # Note that no `axis` parameter is provided to either `np.sum`: everything
    # is reduced to a scalar because `beta` is a single vector.
    log_prior = np.sum(scipy.stats.norm.logpdf(beta, loc=0., scale=1.))
    # Map labels {0, 1} to signs {-1, +1} so one expression covers both classes.
    signed_logits = (2*y-1) * np.dot(all_x, beta)
    log_likelihood = np.sum(-np.log(1 + np.exp(-signed_logits)))
    return log_prior + log_likelihood
In [5]:
# Sanity check: evaluate the log-joint at one random coefficient vector.
log_joint(onp.random.randn(num_features))
/Users/mhoffman/mypython/lib/python2.7/site-packages/jax/lib/xla_bridge.py:146: UserWarning: No GPU found, falling back to CPU.
  warnings.warn('No GPU found, falling back to CPU.')
Out[5]:
array(-213.23558, dtype=float32)
In [6]:
# This doesn't work, because we didn't write `log_joint()` to handle batching:
# it treats `beta` as a single coefficient vector, so passing a 2-D batch of
# vectors makes the shapes inside the likelihood term fail to broadcast.
batch_size = 10
batched_test_beta = onp.random.randn(batch_size, num_features)

# NOTE(review): `batched_test_beta` is not used by this call; a fresh random
# batch of the same shape is drawn instead.
log_joint(onp.random.randn(batch_size, num_features))
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-6-c7ddbb18b4cb> in <module>()
      3 batched_test_beta = onp.random.randn(batch_size, num_features)
      4 
----> 5 log_joint(onp.random.randn(batch_size, num_features))

<ipython-input-4-fff01ffe382a> in log_joint(beta)
      3     # Note that no `axis` parameter is provided to `np.sum`.
      4     result = result + np.sum(scipy.stats.norm.logpdf(beta, loc=0., scale=1.))
----> 5     result = result + np.sum(-np.log(1 + np.exp(-(2*y-1) * np.dot(all_x, beta))))
      6     return result

/Users/mhoffman/mypython/lib/python2.7/site-packages/jax/numpy/lax_numpy.pyc in <lambda>(x, y)
    240     fn = lambda x, y: lax_fn(*_promote_args_like(numpy_fn, x, y))
    241   else:
--> 242     fn = lambda x, y: lax_fn(*_promote_args(numpy_fn.__name__, x, y))
    243   return _wraps(numpy_fn)(fn)
    244 

/Users/mhoffman/mypython/lib/python2.7/site-packages/jax/numpy/lax_numpy.pyc in _promote_args(fun_name, *args)
    177   """Convenience function to apply Numpy argument shape and dtype promotion."""
    178   _check_arraylike(fun_name, *args)
--> 179   return _promote_shapes(*_promote_dtypes(*args))
    180 
    181 

/Users/mhoffman/mypython/lib/python2.7/site-packages/jax/numpy/lax_numpy.pyc in _promote_shapes(*args)
    137   else:
    138     shapes = [shape(arg) for arg in args]
--> 139     nd = len(lax.broadcast_shapes(*shapes))
    140     return [lax.reshape(arg, (1,) * (nd - len(shp)) + shp)
    141             if len(shp) != nd else arg for arg, shp in zip(args, shapes)]

/Users/mhoffman/mypython/lib/python2.7/site-packages/jax/util.pyc in memoized_fun(*args, **kwargs)
    159       cache.popitem(last=False)
    160 
--> 161     ans = cache[key] = fun(*args, **kwargs)
    162     return ans
    163   return memoized_fun

/Users/mhoffman/mypython/lib/python2.7/site-packages/jax/lax.pyc in broadcast_shapes(*shapes)
     67   if not onp.all((shapes == result_shape) | (shapes == 1)):
     68     raise ValueError("Incompatible shapes for broadcasting: {}"
---> 69                      .format(tuple(map(tuple, shapes))))
     70   return tuple(result_shape)
     71 

ValueError: Incompatible shapes for broadcasting: ((100, 10), (1, 100))

Manually batched

In [7]:
def batched_log_joint(beta):
    """Hand-batched log-joint: one value per row of `beta`.

    Same model as the non-batched `log_joint`, but written to accept a batch of
    coefficient vectors stacked along the leading axis. Reads `y` and `all_x`
    from the enclosing notebook scope.

    Args:
      beta: batch of coefficient vectors, one per row.

    Returns:
      The log-joint value for each row of `beta`.
    """
    # Here (and below) `sum` needs an `axis` parameter. At best, forgetting to
    # set axis or setting it incorrectly yields an error; at worst, it silently
    # changes the semantics of the model.
    log_prior = np.sum(scipy.stats.norm.logpdf(beta, loc=0., scale=1.),
                       axis=-1)
    # Note the multiple transposes. Getting this right is not rocket science,
    # but it's also not totally mindless. (I didn't get it right on the first
    # try.)
    signed_logits = (2*y-1) * np.dot(all_x, beta.T).T
    log_likelihood = np.sum(-np.log(1 + np.exp(-signed_logits)),
                            axis=-1)
    return log_prior + log_likelihood
In [8]:
# Draw a batch of random coefficient vectors, one per row.
batch_size = 10
batched_test_beta = onp.random.randn(batch_size, num_features)

# The manually batched function returns one log-joint value per row.
batched_log_joint(batched_test_beta)
Out[8]:
array([-147.84033 , -207.02205 , -109.26075 , -243.8083  , -163.02911 ,
       -143.84848 , -160.28772 , -113.77169 , -126.605446, -190.81989 ],
      dtype=float32)

Autobatched with vmap

It just works.

In [9]:
# `jax.vmap` lifts the non-batched `log_joint` to operate on a batch of
# coefficient vectors; no `axis` arguments or transposes have to be written.
vmap_batched_log_joint = jax.vmap(log_joint)
# Evaluate on the same batch that was passed to the manually batched version.
vmap_batched_log_joint(batched_test_beta)
Out[9]:
array([-147.84033 , -207.02205 , -109.26075 , -243.8083  , -163.02911 ,
       -143.84848 , -160.28772 , -113.77169 , -126.605446, -190.81989 ],
      dtype=float32)

Self-contained variational inference example

A little code is copied from above. Note that the prior on `beta` is wider here (scale 10 rather than 1).

Set up the (batched) log-joint function

In [10]:
@jax.jit
def log_joint(beta):
    """Log-joint density of the logistic-regression model for a single `beta`.

    Uses a zero-mean normal prior with scale 10 on `beta` plus the Bernoulli
    log-likelihood of the labels `y` given the features `all_x` (both read from
    the enclosing notebook scope).

    Args:
      beta: a single coefficient vector (not a batch of them).

    Returns:
      The scalar log-joint value.
    """
    # Note that no `axis` parameter is provided to either `np.sum`; batching is
    # added separately via `jax.vmap`.
    log_prior = np.sum(scipy.stats.norm.logpdf(beta, loc=0., scale=10.))
    # Map labels {0, 1} to signs {-1, +1} so one expression covers both classes.
    signed_logits = (2*y-1) * np.dot(all_x, beta)
    log_likelihood = np.sum(-np.log(1 + np.exp(-signed_logits)))
    return log_prior + log_likelihood

# Vectorize over the leading (batch) axis of `beta`, then compile the result.
batched_log_joint = jax.jit(jax.vmap(log_joint))

Define the ELBO and its gradient

In [11]:
def elbo(beta_loc, beta_log_scale, epsilon):
    """Monte Carlo ELBO estimate for a diagonal-Gaussian approximate posterior.

    Uses the reparameterization `beta = loc + exp(log_scale) * epsilon` so the
    estimate is differentiable with respect to the variational parameters.

    Args:
      beta_loc: mean of the approximate posterior (one entry per feature, as
        called in this notebook).
      beta_log_scale: log standard deviation of the approximate posterior,
        same shape as `beta_loc`.
      epsilon: standard-normal noise whose leading axis indexes Monte Carlo
        samples; it is broadcast against `beta_loc` / `beta_log_scale`.

    Returns:
      Scalar ELBO estimate: the sample mean of the log-joint plus the
      entropy term of the approximation.
    """
    beta_sample = beta_loc + np.exp(beta_log_scale) * epsilon
    # The second term is the Gaussian entropy up to an additive constant. The
    # constant used here does not depend on the parameters, so it has no
    # effect on the gradients used for optimization.
    return np.mean(batched_log_joint(beta_sample), 0) + np.sum(beta_log_scale - 0.5 * onp.log(2*onp.pi))
# All three arguments are arrays that must be traced, so none is marked
# static. (The function has only three positional parameters; there is no
# argument index 3, and `epsilon` at index 2 is array-valued.)
elbo = jax.jit(elbo)
elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1)))

Maximize the ELBO using stochastic gradient ascent

In [12]:
@functools.partial(jax.jit, static_argnums=(1,))
def normal_sample(key, shape):
    """Convenience function for quasi-stateful RNG.

    Splits `key` in two, uses one half to draw standard-normal samples of the
    requested `shape`, and hands back the other half for the next call.

    Args:
      key: a JAX PRNG key.
      shape: shape of the sample to draw; treated as a static (compile-time)
        argument by `jax.jit`.

    Returns:
      A `(new_key, sample)` pair.
    """
    split_keys = random.split(key)
    return split_keys[0], random.normal(split_keys[1], shape)

# Seed JAX's explicit PRNG; the key is threaded through every sampling call.
key = random.PRNGKey(10003)

# Variational parameters: per-feature mean and log standard deviation.
beta_loc = np.zeros(num_features, np.float32)
beta_log_scale = np.zeros(num_features, np.float32)

step_size = 0.01
batch_size = 128
epsilon_shape = (batch_size, num_features)
for i in range(1000):
    # Fresh noise for this step's Monte Carlo ELBO estimate.
    key, epsilon = normal_sample(key, epsilon_shape)
    elbo_val, elbo_grads = elbo_val_and_grad(beta_loc, beta_log_scale, epsilon)
    beta_loc_grad, beta_log_scale_grad = elbo_grads
    # Step *up* the gradient: the ELBO is being maximized.
    beta_loc = beta_loc + step_size * beta_loc_grad
    beta_log_scale = beta_log_scale + step_size * beta_log_scale_grad
    if i % 10 == 0:
        print('{}\t{}'.format(i, elbo_val))
0	-180.853881836
10	-113.060455322
20	-102.737258911
30	-99.7873535156
40	-98.9089889526
50	-98.297454834
60	-98.1863174438
70	-97.5797195435
80	-97.2860031128
90	-97.4699630737
100	-97.4771728516
110	-97.5806732178
120	-97.494354248
130	-97.5027313232
140	-96.8639526367
150	-97.4419784546
160	-97.0694046021
170	-96.8402862549
180	-97.2133789062
190	-97.5650253296
200	-97.2639770508
210	-97.1197967529
220	-97.395942688
230	-97.1683197021
240	-97.1184082031
250	-97.2434539795
260	-97.2978668213
270	-96.692855835
280	-96.9643859863
290	-97.3005523682
300	-96.6359176636
310	-97.0351867676
320	-97.529083252
330	-97.2881164551
340	-97.0732192993
350	-97.1561889648
360	-97.2588195801
370	-97.1951446533
380	-97.1309204102
390	-97.1172637939
400	-96.9387359619
410	-97.2667694092
420	-97.353225708
430	-97.2100753784
440	-97.2843475342
450	-97.1630859375
460	-97.2612457275
470	-97.2134399414
480	-97.2399749756
490	-97.1491317749
500	-97.2352828979
510	-96.9342041016
520	-97.212097168
530	-96.8257751465
540	-97.0128479004
550	-96.9417648315
560	-97.1652069092
570	-97.2916564941
580	-97.429397583
590	-97.2437133789
600	-97.1521911621
610	-97.4984436035
620	-96.9906997681
630	-96.8895645142
640	-96.8996887207
650	-97.1379394531
660	-97.4370574951
670	-96.9923629761
680	-97.1562423706
690	-97.1869049072
700	-97.1116027832
710	-97.7810516357
720	-97.2322616577
730	-97.1620635986
740	-96.9958190918
750	-96.6672210693
760	-97.1679534912
770	-97.5143508911
780	-97.2890090942
790	-96.9122619629
800	-97.1709976196
810	-97.290473938
820	-97.1624298096
830	-97.1910629272
840	-97.5638198853
850	-97.0019378662
860	-96.8655548096
870	-96.7633743286
880	-96.8366088867
890	-97.1217956543
900	-97.0955505371
910	-97.0682373047
920	-97.1194763184
930	-96.8792953491
940	-97.4562530518
950	-96.6928024292
960	-97.293762207
970	-97.3353042603
980	-97.349609375
990	-97.0967559814

Display the results

Coverage isn't quite as good as we might like, but it's not bad, and nobody said variational inference was exact.

In [13]:
# Compare the variational posterior against the coefficients that generated
# the data: blue dots are posterior means, red dots mark +/- 2 posterior
# standard deviations around each mean.
figure(figsize=(7, 7))
plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')
# Raw string so that the backslash in the LaTeX `\sigma` is not parsed as an
# (invalid) Python escape sequence.
plot(true_beta, beta_loc + 2*np.exp(beta_log_scale), 'r.', label=r'Approximated Posterior $2\sigma$ Error Bars')
plot(true_beta, beta_loc - 2*np.exp(beta_log_scale), 'r.')
# Diagonal reference line: points on it have estimate == true value.
plot_scale = 3
plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')
xlabel('True beta')
ylabel('Estimated beta')
legend(loc='best')
Out[13]:
<matplotlib.legend.Legend at 0x125ac6490>