In [23]:
from pymc.gp import *

Mean function

The mean function of a GP can be interpreted as a "prior guess" at the form of the true function.

In [24]:
# Generate mean
def quadfun(x, a, b, c):
    """Evaluate the quadratic a*x**2 + b*x + c at x (scalar or array)."""
    quadratic_term = a * x**2
    linear_term = b * x
    return quadratic_term + linear_term + c

# Wrap quadfun as the GP mean function, fixing its coefficients a, b and c
M = Mean(quadfun, a=1., b=0.5, c=2.)
In [25]:
from pylab import *
# Evaluation grid from -1 up to (but not including) 1 in steps of 0.1
x = arange(-1,1,0.1)
# Plot the mean function over the grid as a solid black line
plot(x, M(x), 'k-')
Out[25]:
[<matplotlib.lines.Line2D at 0x111c94a90>]

Covariance function

The behavior of individual realizations from the GP is governed by the covariance function. The Matérn class of functions is a flexible choice.

In [34]:
from pymc.gp.cov_funs import matern
import numpy as np
# Covariance object built on the Matern function with Euclidean distance
C = Covariance(eval_fun=matern.euclidean, diff_degree=1.4, amp=0.4, scale=1, rank_limit=1000)

# Right panel: filled contours of the covariance matrix C(x, x) over the grid,
# viewed as a plain ndarray for plotting
subplot(1,2,2)
contourf(x, x, C(x,x).view(ndarray), origin='lower', extent=(-1,1,-1,1), cmap=cm.bone)
colorbar()

# Left panel: covariance between each grid point and the point 0
subplot(1,2,1)
plot(x, C(x,0).view(ndarray), 'k-')
ylabel('C(x,0)')
Out[34]:
<matplotlib.text.Text at 0x112713290>
In [38]:
# Calling C with a single argument returns the diagonal of the covariance
# matrix, i.e. the variance at each point of x (0.16 everywhere here, which
# matches amp**2 for amp=0.4)
C(x)
Out[38]:
array([ 0.16,  0.16,  0.16,  0.16,  0.16,  0.16,  0.16,  0.16,  0.16,
        0.16,  0.16,  0.16,  0.16,  0.16,  0.16,  0.16,  0.16,  0.16,
        0.16,  0.16])
In [39]:
# Same values as C(x), extracted from the full covariance matrix C(x, x)
diag(C(x,x))
Out[39]:
array([ 0.16,  0.16,  0.16,  0.16,  0.16,  0.16,  0.16,  0.16,  0.16,
        0.16,  0.16,  0.16,  0.16,  0.16,  0.16,  0.16,  0.16,  0.16,
        0.16,  0.16])

Drawing realizations from a GP

In [44]:
# Generate three realizations from the GP defined by M and C
# NOTE(review): no random seed is set in this cell, so the draws will
# presumably differ from run to run
f_list = [Realization(M, C) for i in range(3)]

# Plot mean and covariance on a finer grid (step 0.01)
x = arange(-1,1,0.01)
plot_envelope(M, C, x)

# Add realizations, each evaluated on the same fine grid
for f in f_list:
    plot(x, f(x))

Non-parametric Regression

Under a GP prior for an unknown function f, when the observation error is normally distributed, the posterior is another GP with new mean and covariance functions.

In [68]:
# Start from a fresh prior mean and covariance; observe() below is called
# for its effect on M and C (its return value is not used)
M = Mean(quadfun, a=1., b=0.5, c=2.)
C = Covariance(eval_fun=matern.euclidean, diff_degree=1.4, amp=0.4, scale=1, rank_limit=1000)

# Two observation locations, their observation variances, and the observed values
obs_x = np.array([-.5, .5])
V = np.array([0.002, 0.002])
data = np.array([3.1, 2.9])

# Inform M and C of the observations; this must run before drawing realizations
observe(M=M, C=C, obs_mesh=obs_x, obs_V=V, obs_vals=data)

# Generate realizations from posterior
f_list = [Realization(M,C) for i in range(3)]

The function observe informs the mean and covariance functions that values have been observed on obs_mesh with observation variance V. Making observations with no error is called conditioning. This is useful when, for example, forcing a rate function to be zero when a population's size is zero.

In [69]:
# Plot the envelope for the observed M and C over the grid x
plot_envelope(M, C, mesh=x)
# Overlay the realizations drawn after observing
for f in f_list:
    plot(x, f(x))

Salmon Example

In [78]:
class salmon:
    """
    Container for one salmon stock's data and its Gaussian-process prior.

    Stores the abundance and fry-density columns of ``data``, builds the
    mean (``M``) and covariance (``C``) objects, and makes plots.
    """
    def __init__(self, name, data):
        """
        name: label for the stock; used as the plot title.
        data: 2-D array; column 0 is abundance, column 1 is fry density.
        """
        self.name = name

        # Split the two data columns into flat arrays
        self.abundance = data[:, 0].ravel()
        self.frye = data[:, 1].ravel()

        # Prior mean: a line through the origin whose slope is the average
        # fry-density-to-abundance ratio in the data
        def line(x, slope):
            return slope * x

        prior_slope = mean(self.frye / self.abundance)
        self.M = Mean(line, slope=prior_slope)

        # Prior covariance: Matern, with scale and amplitude set to large
        # multiples of the largest abundance and fry density in the data
        max_abundance = self.abundance.max()
        self.C = Covariance(matern.euclidean,
                            diff_degree=1.4,
                            scale=100. * max_abundance,
                            amp=200. * self.frye.max())

        # Observe the value 0 at x = 0 with zero observation variance
        observe(self.M, self.C, obs_mesh=0, obs_vals=0, obs_V=0)

        # 100 evenly spaced plotting points from 0 to 1.25 x the largest abundance
        self.xplot = linspace(0, 1.25 * max_abundance, 100)

    def plot(self):
        """
        Plot the envelope for M and C, three realizations, and the data.
        """
        # Bare plot() calls below resolve to the module-level plotting
        # function, not to this method.
        figure()
        plot_envelope(self.M, self.C, self.xplot)
        for _ in range(3):
            draw = Realization(self.M, self.C)
            plot(self.xplot, draw(self.xplot))

        # Raw data as small black dots
        plot(self.abundance, self.frye, 'k.', markersize=4)
        xlabel('Female abundance')
        ylabel('Frye density')
        title(self.name)
        axis('tight')
In [79]:
# Sockeye records, one row per record: (female abundance, fry density)
sockeye_data = np.array([
    [2986, 9],
    [3424, 12.39],
    [1631, 4.5],
    [784, 2.56],
    [9671, 32.62],
    [2519, 8.19],
    [1520, 4.51],
    [6418, 15.21],
    [10857, 35.05],
    [15044, 36.85],
    [10287, 25.68],
    [16525, 52.75],
    [19172, 19.52],
    [17527, 40.98],
    [11424, 26.67],
    [24043, 52.6],
    [10244, 21.62],
    [30983, 56.05],
    [12037, 29.31],
    [25098, 45.4],
    [11362, 18.88],
    [24375, 19.14],
    [18281, 33.77],
    [14192, 20.44],
    [7527, 21.66],
    [6061, 18.22],
    [15536, 42.9],
    [18080, 46.09],
    [17354, 38.82],
    [17301, 42.22],
    [11486, 21.96],
    [20120, 45.05],
    [10700, 13.7],
    [12867, 27.71],
])
In [81]:
# Instantiate salmon object with sockeye abundance
# (re-creating it here gives fresh M and C each time this cell is run)
sockeye = salmon('sockeye', sockeye_data)

# Observe some data: fry densities at the recorded abundances, with the
# observation variance set to a quarter of each observed value
observe(sockeye.M, sockeye.C, obs_mesh = sockeye.abundance, obs_vals = sockeye.frye, obs_V = .25*sockeye.frye)
In [82]:
# Draw the envelope, three realizations and the data points for the sockeye stock
sockeye.plot()