Energy based model

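Energy-based probabilistic models define a probability distribution through an energy function:

\[ p(\boldsymbol{x}) = \frac{e^{-E(\boldsymbol{x})}}{Z} \]

where \(Z\) is the normalization factor, which is also called the partition function by analogy with physical systems:

\[ Z = \sum_{\boldsymbol{x}} e^{-E(\boldsymbol{x})} \]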

This formula has the same form as a softmax over all configurations \(\boldsymbol{x}\), with \(-E(\boldsymbol{x})\) playing the role of the logits.
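Because a toy configuration space can be enumerated, the definition can be checked directly. The following is a minimal NumPy sketch, assuming a discrete \(\boldsymbol{x}\) with four states and made-up energies; `ebm_probs` is a hypothetical helper, not a library function.

```python
import numpy as np

def ebm_probs(energies):
    """Turn energies E(x) of every configuration x into p(x) = exp(-E(x)) / Z.

    This is exactly a softmax over the logits -E(x). Subtracting the minimum
    energy first keeps exp() from overflowing without changing the result.
    """
    energies = np.asarray(energies, dtype=float)
    unnormalized = np.exp(-(energies - energies.min()))
    Z = unnormalized.sum()  # partition function (up to the shift above)
    return unnormalized / Z

# Four hypothetical configurations with made-up energies.
p = ebm_probs([0.0, 1.0, 2.0, 0.5])
```

Lower energy means higher probability: the state with \(E = 0\) gets the largest share of the mass.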

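An energy-based model can be learnt by performing stochastic gradient descent (SGD) on the empirical negative log-likelihood of the training data \(\mathcal{D}\). As for logistic regression, we first define the log-likelihood and then the loss function as the negative log-likelihood:

\[ \mathcal{L}(\boldsymbol\theta, \mathcal{D}) = \frac{1}{|\mathcal{D}|} \sum_{\boldsymbol{x}^{(i)} \in \mathcal{D}} \log p(\boldsymbol{x}^{(i)}), \qquad \ell(\boldsymbol\theta, \mathcal{D}) = -\mathcal{L}(\boldsymbol\theta, \mathcal{D}) \]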

We then use the stochastic gradient \(-\frac{\partial \log p(\boldsymbol{x}^{(i)})}{\partial \boldsymbol\theta}\) to optimize the model, where \(\boldsymbol\theta\) denotes the model parameters.
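To see the SGD procedure end to end, here is a minimal sketch under a strong simplifying assumption: \(x\) has only \(K\) discrete values and the model stores one energy per value (\(\theta_k = E(x = k)\)), so \(Z\) can be computed exactly. The data are synthetic and all names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy model: x takes one of K discrete values and there is one free
# parameter per value, theta[k] = E(x = k), so p(x = k) = softmax(-theta)[k].
K = 4
theta = np.zeros(K)
data = rng.choice(K, size=2000, p=[0.1, 0.2, 0.3, 0.4])  # synthetic training set

def probs(theta):
    e = np.exp(-(theta - theta.min()))
    return e / e.sum()

lr = 0.002
for epoch in range(20):
    for x in rng.permutation(data):
        # -log p(x) = theta[x] + log Z, so d(-log p(x))/d theta = onehot(x) - p
        grad = -probs(theta)
        grad[x] += 1.0
        theta -= lr * grad  # one SGD step on the negative log-likelihood

p_model = probs(theta)
p_emp = np.bincount(data, minlength=K) / len(data)
```

With enough passes, `p_model` should approach the empirical frequencies `p_emp`, which is what minimizing the empirical negative log-likelihood targets for such a fully flexible model. Note that the gradient already splits into a data term (`onehot(x)`) and a model term (`p`).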

EBM with hidden units

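In some situations, we may not observe \(\boldsymbol{x}\) fully, or we may want to introduce unobserved variables to increase the expressive power of the model. Adding the hidden variables \(\boldsymbol{h}\), we have:

\[ P(\boldsymbol{x}) = \sum_{\boldsymbol{h}} P(\boldsymbol{x}, \boldsymbol{h}) = \sum_{\boldsymbol{h}} \frac{e^{-E(\boldsymbol{x}, \boldsymbol{h})}}{Z} \]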

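Now let's introduce the notion of free energy, a term borrowed from physics, defined as

\[ F(\boldsymbol{x}) = -\log \sum_{\boldsymbol{h}} e^{-E(\boldsymbol{x}, \boldsymbol{h})} \]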

Then we have:

\[ P(\boldsymbol{x}) = \frac{e^{-F(\boldsymbol{x})}}{Z} \]

where \(Z = \sum_{\boldsymbol{x}} e^{-F(\boldsymbol{x})}\) is again the partition function.
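The free-energy identity can be checked numerically on a small random joint energy table. This sketch assumes NumPy, 5 visible states and 3 hidden states; the table `E` is arbitrary, and the stable log-sum-exp is an implementation choice.

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical joint energy table E[x, h] for 5 visible states and 3 hidden states.
E = rng.normal(size=(5, 3))

def free_energy(E):
    """F(x) = -log sum_h exp(-E(x, h)), computed stably with log-sum-exp."""
    m = (-E).max(axis=1, keepdims=True)
    return -(m[:, 0] + np.log(np.exp(-E - m).sum(axis=1)))

F = free_energy(E)
p_from_F = np.exp(-F) / np.exp(-F).sum()   # e^{-F(x)} / Z
p_joint = np.exp(-E) / np.exp(-E).sum()    # e^{-E(x, h)} / Z
p_marginal = p_joint.sum(axis=1)           # sum_h P(x, h)
```

Both routes give the same marginal, because \(\sum_{\boldsymbol x} e^{-F(\boldsymbol x)} = \sum_{\boldsymbol x, \boldsymbol h} e^{-E(\boldsymbol x, \boldsymbol h)}\) is the same partition function.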

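For any energy-based (Boltzmann) distribution, the gradient of the loss has the form:

\[
\begin{aligned}
-\frac{\partial \log p(\boldsymbol{x})}{\partial \boldsymbol\theta}
&= \frac{\partial F(\boldsymbol{x})}{\partial \boldsymbol\theta} + \frac{\partial \log Z}{\partial \boldsymbol\theta} \\
&= \frac{\partial F(\boldsymbol{x})}{\partial \boldsymbol\theta} - \sum_{\tilde{\boldsymbol{x}}} p(\tilde{\boldsymbol{x}}) \frac{\partial F(\tilde{\boldsymbol{x}})}{\partial \boldsymbol\theta} \qquad (1) \\
&= \sum_{\boldsymbol{h}} p(\boldsymbol{h} \mid \boldsymbol{x}) \frac{\partial E(\boldsymbol{x}, \boldsymbol{h})}{\partial \boldsymbol\theta} - \sum_{\tilde{\boldsymbol{x}}, \boldsymbol{h}} p(\tilde{\boldsymbol{x}}, \boldsymbol{h}) \frac{\partial E(\tilde{\boldsymbol{x}}, \boldsymbol{h})}{\partial \boldsymbol\theta} \qquad (2)
\end{aligned}
\]

Eq (2), written directly in terms of the energy, is the general form of the stochastic gradient for any energy-based distribution with hidden units. In this post we use the free-energy form, eq (1), for notational simplicity.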
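The free-energy form of the gradient (positive phase minus the \(p\)-weighted negative phase) can be verified against a finite-difference derivative of \(-\log p(\boldsymbol{x})\). The sketch below assumes a hypothetical model with a single binary hidden unit, \(E(\boldsymbol{x}, h) = -h\,\boldsymbol\theta^\top \phi(\boldsymbol{x})\), so that \(F(\boldsymbol{x}) = -\log(1 + e^{\boldsymbol\theta^\top \phi(\boldsymbol{x})})\); the features `phi` are random.

```python
import numpy as np

rng = np.random.default_rng(2)
K, d = 6, 3
phi = rng.normal(size=(K, d))   # hypothetical feature vector phi(x) for each of K states
theta = rng.normal(size=d)

def free_energy(theta):
    # One binary hidden unit h in {0, 1} with E(x, h) = -h * theta . phi(x):
    # F(x) = -log(1 + exp(theta . phi(x)))
    return -np.logaddexp(0.0, phi @ theta)

def nll(theta, x):
    F = free_energy(theta)
    logZ = np.logaddexp.reduce(-F)
    return F[x] + logZ          # -log p(x) = F(x) + log Z

def grad_eq1(theta, x):
    a = phi @ theta
    dF = -(1.0 / (1.0 + np.exp(-a)))[:, None] * phi   # dF(x~)/dtheta for every x~
    F = free_energy(theta)
    p = np.exp(-F - np.logaddexp.reduce(-F))           # p(x~)
    return dF[x] - p @ dF       # positive phase minus negative phase

x = 2
eps = 1e-6
numeric = np.array([(nll(theta + eps * e, x) - nll(theta - eps * e, x)) / (2 * eps)
                    for e in np.eye(d)])
```

The analytic gradient and the central-difference estimate should agree to roughly the precision of the finite difference.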

The above gradient contains two parts, which are referred to as the positive phase and the negative phase. The positive phase increases the probability of training data (by reducing the corresponding free energy), while the negative phase decreases the probability of samples generated by the model (by increasing the free energy of \(\tilde{\boldsymbol{x}} \sim P\)).

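It's difficult to compute this gradient analytically, because it involves the term \(\sum_{\boldsymbol{x}} p(\boldsymbol{x}) \frac{\partial F(\boldsymbol{x})}{\partial \boldsymbol\theta}\), which sums over every possible configuration of \(\boldsymbol{x}\) (exponentially many for binary vectors) and requires the partition function through \(p(\boldsymbol{x})\).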

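The first step in making this computation tractable is to estimate the expectation with a fixed number of model samples. The samples used to estimate the negative-phase gradient are called negative particles, and their set is denoted \(N\). The gradient becomes:

\[ -\frac{\partial \log p(\boldsymbol{x})}{\partial \boldsymbol\theta} \approx \frac{\partial F(\boldsymbol{x})}{\partial \boldsymbol\theta} - \frac{1}{|N|} \sum_{\tilde{\boldsymbol{x}} \in N} \frac{\partial F(\tilde{\boldsymbol{x}})}{\partial \boldsymbol\theta} \]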

The elements \(\tilde{\boldsymbol{x}}\) of \(N\) are sampled according to \(P\), so this is a Monte Carlo estimate.
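The Monte Carlo approximation can be illustrated on a tiny state space, where exact samples from \(P\) are still cheap to draw. This is only a sketch under that assumption (in a real model the negative particles would come from MCMC instead); the free energies and their gradients are random placeholders.

```python
import numpy as np

rng = np.random.default_rng(3)
K, d = 8, 2
F = rng.normal(size=K)               # hypothetical free energies F(x~)
dF = rng.normal(size=(K, d))         # hypothetical dF(x~)/dtheta for each state

p = np.exp(-F) / np.exp(-F).sum()
exact_negative = p @ dF              # sum_x~ p(x~) dF(x~)/dtheta

# Negative particles N: drawn exactly from P here only because K is tiny;
# in a real model they would come from a Markov chain instead.
N = rng.choice(K, size=100000, p=p)
mc_negative = dF[N].mean(axis=0)     # (1/|N|) sum_{x~ in N} dF(x~)/dtheta
```

The estimate's error shrinks roughly as \(1/\sqrt{|N|}\), which is why a modest number of particles is often enough for a noisy but usable gradient.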

Restricted Boltzmann Machines

Boltzmann machines (BMs) are a particular form of log-linear Markov Random Field, for which the energy function is linear in its free parameters. To make them powerful enough to represent complicated distributions (going from the limited parametric setting to a non-parametric one), we consider that some of the variables are never observed (hidden units). Restricted Boltzmann machines restrict BMs to those without visible-visible and hidden-hidden connections.

The energy function \(E(\boldsymbol{v}, \boldsymbol{h})\) of an RBM is defined as follows, depending on the type of units:

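  1. for binomial energy term

    \[ E(\boldsymbol{v}, \boldsymbol{h}) = -\boldsymbol{b}^\top \boldsymbol{v} - \boldsymbol{c}^\top \boldsymbol{h} - \boldsymbol{h}^\top \Omega \boldsymbol{v} \]

    where \(\Omega\) represents the weights connecting hidden and visible units and \(\boldsymbol{b}\) and \(\boldsymbol{c}\) are the offsets of the visible and hidden variables respectively.

    Thus the free energy is:

    \[ F(\boldsymbol{v}) = -\boldsymbol{b}^\top \boldsymbol{v} - \sum_i \log \sum_{h_i} e^{h_i (c_i + \Omega_i \boldsymbol{v})} \]

    where \(\Omega_i\) denotes the \(i\)-th row of \(\Omega\). The visible and hidden units are conditionally independent given one another, so we have:

    \[ p(\boldsymbol{h} \mid \boldsymbol{v}) = \prod_i p(h_i \mid \boldsymbol{v}), \qquad p(\boldsymbol{v} \mid \boldsymbol{h}) = \prod_j p(v_j \mid \boldsymbol{h}) \]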

  2. for fixed variance Gaussian energy term

    The energy function is

  3. for softmax energy term

    The energy function is

RBM with binary units

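Suppose that \(\boldsymbol{v}\) and \(\boldsymbol{h}\) are binary vectors. Then the conditionals turn out to be a probabilistic version of the usual sigmoid neuron activation:

\[ P(h_i = 1 \mid \boldsymbol{v}) = \operatorname{sigm}(c_i + \Omega_i \boldsymbol{v}), \qquad P(v_j = 1 \mid \boldsymbol{h}) = \operatorname{sigm}(b_j + \Omega_{\cdot j}^\top \boldsymbol{h}) \]

where \(\operatorname{sigm}(a) = 1/(1 + e^{-a})\), \(\Omega_i\) is the \(i\)-th row and \(\Omega_{\cdot j}\) the \(j\)-th column of \(\Omega\).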
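For binary units the sigmoid conditionals can be compared with a brute-force computation that normalizes \(e^{-E(\boldsymbol{v}, \boldsymbol{h})}\) over all hidden vectors. A minimal NumPy sketch, assuming a 3-visible, 2-hidden RBM with random parameters, the energy \(E = -\boldsymbol{b}^\top \boldsymbol{v} - \boldsymbol{c}^\top \boldsymbol{h} - \boldsymbol{h}^\top \Omega \boldsymbol{v}\), and illustrative function names:

```python
import itertools
import numpy as np

rng = np.random.default_rng(4)
n_v, n_h = 3, 2
Omega = rng.normal(size=(n_h, n_v))   # Omega[i, j] couples hidden i and visible j
b = rng.normal(size=n_v)              # visible offsets
c = rng.normal(size=n_h)              # hidden offsets

def sigm(a):
    return 1.0 / (1.0 + np.exp(-a))

def energy(v, h):
    # E(v, h) = -b'v - c'h - h' Omega v
    return -b @ v - c @ h - h @ Omega @ v

def p_h_given_v(v):
    return sigm(c + Omega @ v)        # P(h_i = 1 | v) for every i

def p_v_given_h(h):
    return sigm(b + Omega.T @ h)      # P(v_j = 1 | h) for every j

def brute_p_h_given_v(v):
    # Enumerate all 2^n_h hidden vectors and normalize exp(-E(v, h)) over h.
    hs = np.array(list(itertools.product([0, 1], repeat=n_h)), dtype=float)
    w = np.exp([-energy(v, h) for h in hs])
    w /= w.sum()
    return w @ hs                     # marginal P(h_i = 1 | v)

v = np.array([1.0, 0.0, 1.0])
```

The enumeration is exponential in the number of hidden units, while the sigmoid form is linear; that gap is exactly what conditional independence buys.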

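The free energy of an RBM with binary units further simplifies to:

\[ F(\boldsymbol{v}) = -\boldsymbol{b}^\top \boldsymbol{v} - \sum_i \log\left(1 + e^{c_i + \Omega_i \boldsymbol{v}}\right) \]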
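The closed-form free energy can likewise be checked against \(-\log \sum_{\boldsymbol h} e^{-E(\boldsymbol v, \boldsymbol h)}\) by enumerating every binary \(\boldsymbol h\). A sketch with random parameters (sizes are arbitrary), using `np.logaddexp(0, a)` for \(\log(1 + e^a)\):

```python
import itertools
import numpy as np

rng = np.random.default_rng(5)
n_v, n_h = 3, 4
Omega = rng.normal(size=(n_h, n_v))
b = rng.normal(size=n_v)
c = rng.normal(size=n_h)

def free_energy(v):
    # F(v) = -b'v - sum_i log(1 + exp(c_i + Omega_i v))
    return -b @ v - np.logaddexp(0.0, c + Omega @ v).sum()

def brute_free_energy(v):
    # F(v) = -log sum_h exp(-E(v, h)) with E(v, h) = -b'v - c'h - h' Omega v
    hs = np.array(list(itertools.product([0, 1], repeat=n_h)), dtype=float)
    neg_E = b @ v + hs @ c + hs @ (Omega @ v)
    return -np.logaddexp.reduce(neg_E)

vs = np.array(list(itertools.product([0, 1], repeat=n_v)), dtype=float)
```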

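And the gradients for an RBM with binary units, obtained by differentiating this free energy, are:

\[
\begin{aligned}
-\frac{\partial \log p(\boldsymbol{v})}{\partial \Omega_{ij}} &= \mathbb{E}_{\tilde{\boldsymbol{v}} \sim P}\big[P(h_i = 1 \mid \tilde{\boldsymbol{v}})\, \tilde v_j\big] - P(h_i = 1 \mid \boldsymbol{v})\, v_j \\
-\frac{\partial \log p(\boldsymbol{v})}{\partial c_i} &= \mathbb{E}_{\tilde{\boldsymbol{v}} \sim P}\big[P(h_i = 1 \mid \tilde{\boldsymbol{v}})\big] - P(h_i = 1 \mid \boldsymbol{v}) \\
-\frac{\partial \log p(\boldsymbol{v})}{\partial b_j} &= \mathbb{E}_{\tilde{\boldsymbol{v}} \sim P}\big[\tilde v_j\big] - v_j
\end{aligned}
\]

where \(P(h_i = 1 \mid \boldsymbol{v}) = \operatorname{sigm}(c_i + \Omega_i \boldsymbol{v})\).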
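For a model small enough to enumerate all visible vectors, the expectations in these gradients can be computed exactly and compared with finite differences of \(-\log p(\boldsymbol v)\). This sketch assumes NumPy and a 3-visible, 2-hidden RBM; the hypothetical `grads` returns \(-\partial \log p(\boldsymbol v)/\partial \theta\) for \(\Omega\), \(\boldsymbol b\) and \(\boldsymbol c\).

```python
import itertools
import numpy as np

rng = np.random.default_rng(6)
n_v, n_h = 3, 2
Omega = rng.normal(size=(n_h, n_v))
b = rng.normal(size=n_v)
c = rng.normal(size=n_h)
all_v = np.array(list(itertools.product([0, 1], repeat=n_v)), dtype=float)

def sigm(a):
    return 1.0 / (1.0 + np.exp(-a))

def free_energy(v, Omega, b, c):
    # Works for one vector (n_v,) or a batch (k, n_v).
    return -v @ b - np.logaddexp(0.0, v @ Omega.T + c).sum(axis=-1)

def nll(v, Omega, b, c):
    # -log p(v) = F(v) + log Z, with Z summed exactly over all 2^n_v visible vectors
    return free_energy(v, Omega, b, c) + np.logaddexp.reduce(-free_energy(all_v, Omega, b, c))

def grads(v, Omega, b, c):
    p = np.exp(-nll(all_v, Omega, b, c))          # p(v~) for every visible vector
    ph_data = sigm(c + Omega @ v)                 # P(h_i = 1 | v)
    ph_model = sigm(all_v @ Omega.T + c)          # P(h_i = 1 | v~), one row per v~
    # Expectation under the model minus the data term, as in the equations above.
    g_Omega = np.einsum('k,ki,kj->ij', p, ph_model, all_v) - np.outer(ph_data, v)
    g_c = p @ ph_model - ph_data
    g_b = p @ all_v - v
    return g_Omega, g_b, g_c
```

In a realistic RBM the exact model expectation is unavailable, which is what motivates the sampling methods below.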

Sampling

Samples of \(P(\boldsymbol{x})\) can be obtained by running a Markov chain to convergence, using Gibbs sampling as the transition operator.

Gibbs sampling of the joint distribution of \(N\) random variables \(S=(S_1, … , S_N)\) is done through a sequence of \(N\) sampling sub-steps of the form \(S_i \sim p(S_i | S_{-i})\), where \(S_{-i}\) contains the \(N-1\) random variables in \(S\) other than \(S_i\).

For RBMs, \(S\) consists of the set of visible and hidden units. However, since the visible units are conditionally independent given the hidden units (and vice versa), one can perform block Gibbs sampling: all visible units are sampled simultaneously given fixed values of the hidden units, and all hidden units are sampled simultaneously given the visible units.
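One block Gibbs step for a binary RBM can be sketched as follows, assuming NumPy, the sigmoid conditionals for binary units, and the shape convention `Omega[i, j]` linking hidden unit \(i\) and visible unit \(j\). `gibbs_step` is an illustrative name.

```python
import numpy as np

def sigm(a):
    return 1.0 / (1.0 + np.exp(-a))

def gibbs_step(v, Omega, b, c, rng):
    """One block-Gibbs step v -> h -> v' for a binary RBM.

    All hidden units are sampled at once given v, then all visible units at
    once given h, which is valid because units within a layer are
    conditionally independent given the other layer.
    """
    h = (rng.random(c.shape) < sigm(c + Omega @ v)).astype(float)
    v_new = (rng.random(b.shape) < sigm(b + Omega.T @ h)).astype(float)
    return v_new, h
```

Calling `gibbs_step` repeatedly gives a Markov chain whose visible states are, after burn-in, approximately distributed as \(p(\boldsymbol v)\).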

In theory, each parameter update in the learning process would require running one sampling chain to convergence. Needless to say, doing so would be prohibitively expensive. As such, several algorithms have been devised for RBMs in order to efficiently sample from \(p(\boldsymbol{v}, \boldsymbol{h})\) during the learning process.

Contrastive Divergence (CD-k)

Contrastive Divergence uses two tricks to speed up the sampling process:

  1. Since we eventually want \(p(\boldsymbol{v}) \approx p_{\text{train}}(\boldsymbol{v})\) (the true, underlying distribution of the data), we initialize the Markov chain with a training example (i.e., from a distribution that is expected to be close to \(p\), so that the chain will already be close to having converged to its final distribution \(p\)).
  2. CD does not wait for the chain to converge. Samples are obtained after only \(k\) steps of Gibbs sampling. In practice, \(k=1\) has been shown to work surprisingly well.
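The two tricks amount to very little code. A hedged sketch for a binary RBM, using NumPy and the convention that `Omega[i, j]` couples hidden unit \(i\) and visible unit \(j\); the function name is hypothetical.

```python
import numpy as np

def sigm(a):
    return 1.0 / (1.0 + np.exp(-a))

def cd_k_negative_sample(v0, Omega, b, c, k, rng):
    """Return the CD-k negative particle for one training vector v0.

    Trick 1: the chain starts at the training example v0 (not at random).
    Trick 2: it stops after k block-Gibbs steps instead of running to convergence.
    """
    v = v0
    for _ in range(k):
        h = (rng.random(c.shape) < sigm(c + Omega @ v)).astype(float)
        v = (rng.random(b.shape) < sigm(b + Omega.T @ h)).astype(float)
    return v
```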

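In CD-1, writing \(\boldsymbol{x}\) for the input (visible) units and \(\boldsymbol{y}\) for the output (hidden) units, we have:

\[ \boldsymbol{x}^0 = \text{a training example}, \qquad \boldsymbol{y}^0 \sim P(\boldsymbol{y} \mid \boldsymbol{x}^0), \qquad \boldsymbol{x}^1 \sim P(\boldsymbol{x} \mid \boldsymbol{y}^0) \]

so \(\boldsymbol{x}^1\) is the negative particle obtained after a single block Gibbs step.

The update rules are as follows: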

  • output binomial unit \(i\) <-> input binomial unit \(j\)
    • weight \(w_{ij}\):

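      positive phase contribution: \(x^0_j \, P(y_i = 1 \mid \boldsymbol{x}^0)\)

      negative phase contribution: \(x^1_j \, P(y_i = 1 \mid \boldsymbol{x}^1)\)

    • bias \(b_i\):

      positive phase contribution: \(P(y_i = 1 \mid \boldsymbol{x}^0)\)

      negative phase contribution: \(P(y_i = 1 \mid \boldsymbol{x}^1)\)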

  • output binomial unit \(i\) <-> input Gaussian unit \(j\)
    • bias \(b_i\) and weight \(w_{ij}\) as above
    • parameter \(a_j\):

      positive phase contribution: \(2 a_j (x^0_j)^2\)

      negative phase contribution: \(2 a_j (x^1_j)^2\)

  • output softmax unit \(i\) <-> input binomial unit \(j\)

    same formulas as for binomial units, except that \(P(y_i=1|\boldsymbol{x})\) is computed differently (with softmax instead of sigmoid)
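Putting the binomial-binomial rules together, one CD-1 update can be sketched as below. This assumes NumPy, writes the weights as `W[i, j]` (\(w_{ij}\)), and uses `b_out` for the output biases \(b_i\). Each parameter moves by a learning rate times (positive minus negative phase contribution), which ascends the log-likelihood for these unit types. The rules above do not list the input biases; updating them with \(\boldsymbol x^0 - \boldsymbol x^1\) is the usual analogous choice but is an assumption here.

```python
import numpy as np

def sigm(a):
    return 1.0 / (1.0 + np.exp(-a))

def cd1_update(x0, W, b_out, b_in, lr, rng):
    """One CD-1 step for a binomial-binomial RBM (x = input, y = output).

    W[i, j] connects output unit i and input unit j. W, b_out and b_in are
    updated in place; the negative particle x1 is returned.
    """
    q0 = sigm(b_out + W @ x0)                                # P(y_i = 1 | x^0)
    y0 = (rng.random(q0.shape) < q0).astype(float)           # y^0 ~ P(y | x^0)
    x1 = (rng.random(x0.shape) < sigm(b_in + W.T @ y0)).astype(float)  # x^1 ~ P(x | y^0)
    q1 = sigm(b_out + W @ x1)                                # P(y_i = 1 | x^1)

    # lr * (positive phase contribution - negative phase contribution)
    W += lr * (np.outer(q0, x0) - np.outer(q1, x1))
    b_out += lr * (q0 - q1)
    b_in += lr * (x0 - x1)   # assumption: input-bias rule not stated in the text
    return x1
```

Using the probabilities \(P(y_i = 1 \mid \boldsymbol x)\) rather than sampled \(y_i\) in the statistics follows the rules as written and reduces the variance of each update.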
