In the previous post, I introduced model inference for softmax regression. In this post, I revisit it in more detail. Softmax regression can be viewed as a generalization of logistic regression from binary classification to multiclass classification.

Inference on the softmax regression

Suppose we have an \(m\)-class classification task with input feature vector \(\vec x \in \mathbb{R}^n\), and we encode the true label of each input as a one-hot vector \(\vec y \in \mathbb{R}^m\). This leads to \(n \times m\) parameters \(\mathbb{\Omega}\) in the model, which we can view as one weight vector \(\vec \omega_i \in \mathbb{R}^n\) per class. The model computes

\[\hat{y}_i = \frac{e^{\vec \omega_i \vec x}}{\sum_{j = 1}^m e^{\vec \omega_j \vec x}}, \quad i = 1, \dots, m,\]

where \(\vec{\hat{y}}\) is the output of the model, and each element \(\hat{y}_i\) is the probability that the \(i^{th}\) class is the true class given \(\vec x\).
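A minimal NumPy sketch of this forward pass follows. It assumes \(\mathbb{\Omega}\) is stored as an \(m \times n\) array whose \(i^{th}\) row is \(\vec \omega_i\). The function name `softmax_predict` and the toy numbers are illustrative only. Subtracting the largest score before exponentiating leaves \(\vec{\hat y}\) unchanged but avoids overflow.

```python
import numpy as np

def softmax_predict(omega: np.ndarray, x: np.ndarray) -> np.ndarray:
    """Return y_hat for one input x.

    omega: (m, n) array; row i is the weight vector omega_i (assumed layout).
    x:     (n,) feature vector.
    """
    z = omega @ x              # z_i = omega_i . x, shape (m,)
    z = z - z.max()            # shift for numerical stability; softmax is shift-invariant
    e = np.exp(z)
    return e / e.sum()         # y_hat_i = exp(z_i) / sum_j exp(z_j)

# Usage: 3 classes, 2 features (toy values).
omega = np.array([[1.0, 0.0],
                  [0.0, 1.0],
                  [-1.0, -1.0]])
x = np.array([2.0, 1.0])
y_hat = softmax_predict(omega, x)
print(y_hat, y_hat.sum())
```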

Likelihood and loss

Now let’s derive the loss function of softmax regression, starting from the likelihood function

\[\mathbb{L}(\mathbb{\Omega} | \vec y, \vec x) = P(\vec y, \vec x | \mathbb{\Omega}) = P(\vec y | \vec x, \mathbb{\Omega}) \, P(\vec x | \mathbb{\Omega}).\]

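Furthermore, since \(\vec y\) is one-hot, the conditional term is

\[P(\vec y | \vec x, \mathbb{\Omega}) = \prod_{i = 1}^m \hat{y}_i^{\,y_i}.\]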

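To optimize the model, we maximize its likelihood, which is equivalent to minimizing the negative log-likelihood. Dropping the \(P(\vec x | \mathbb{\Omega})\) term, the loss for one example is

\[l(\mathbb{\Omega}) = -\log P(\vec y | \vec x, \mathbb{\Omega}) = -\sum_{i = 1}^m y_i \log \hat{y}_i.\]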

This is the cross-entropy between \(\vec y\) and \(\vec{\hat y}\), so maximizing the likelihood of the model is the same as minimizing its cross-entropy loss.
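Here is a quick numerical sanity check of this equivalence, using a made-up one-hot label and model output (both hypothetical). The negative log of the product form and the cross-entropy sum give the same number. The check assumes every \(\hat y_i > 0\), which softmax outputs satisfy.

```python
import numpy as np

def nll_from_likelihood(y: np.ndarray, y_hat: np.ndarray) -> float:
    # -log P(y | x, Omega), with P(y | x, Omega) = prod_i y_hat_i ** y_i for one-hot y
    return -np.log(np.prod(y_hat ** y))

def cross_entropy(y: np.ndarray, y_hat: np.ndarray) -> float:
    # l = -sum_i y_i log y_hat_i (assumes all y_hat_i > 0)
    return -np.sum(y * np.log(y_hat))

y = np.array([0.0, 1.0, 0.0])        # one-hot label: true class is index 1
y_hat = np.array([0.2, 0.7, 0.1])    # hypothetical model output (sums to 1)
print(nll_from_likelihood(y, y_hat), cross_entropy(y, y_hat))
```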

Derivative

Now, let’s find the derivative of the loss with respect to the parameters.

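First, consider the softmax function \(\vec{\hat y} = f(\vec z)\), where \(\hat y_i = f_i(\vec z) = \frac{e^{z_i}}{\sum_{k = 1}^m e^{z_k}}\). Its partial derivatives \(\frac{\partial \hat y_i}{\partial z_j}\) are:

  • For \(i = j\), we have \(\frac{\partial \hat y_i}{\partial z_i} = \frac{e^{z_i} \sum_{k} e^{z_k} - e^{z_i} e^{z_i}}{\left(\sum_{k} e^{z_k}\right)^2} = \hat y_i (1 - \hat y_i)\).
  • For \(i \neq j\), we have \(\frac{\partial \hat y_i}{\partial z_j} = -\frac{e^{z_i} e^{z_j}}{\left(\sum_{k} e^{z_k}\right)^2} = -\hat y_i \hat y_j\).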
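The two cases combine into \(\frac{\partial f_i}{\partial z_j} = f_i(\vec z)\,(\delta_{ij} - f_j(\vec z))\), i.e. the Jacobian \(\operatorname{diag}(f) - f f^T\). The small NumPy sketch below compares this closed form against central finite differences. The helper names are hypothetical.

```python
import numpy as np

def softmax(z: np.ndarray) -> np.ndarray:
    e = np.exp(z - z.max())          # shift for numerical stability
    return e / e.sum()

def softmax_jacobian(z: np.ndarray) -> np.ndarray:
    # J[i, j] = d f_i / d z_j = f_i * (delta_ij - f_j), i.e. diag(f) - f f^T
    f = softmax(z)
    return np.diag(f) - np.outer(f, f)

def numerical_jacobian(z: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    # Central finite differences, one column per input coordinate z_j.
    m = z.size
    J = np.zeros((m, m))
    for j in range(m):
        dz = np.zeros(m)
        dz[j] = eps
        J[:, j] = (softmax(z + dz) - softmax(z - dz)) / (2 * eps)
    return J

z = np.array([0.5, -1.0, 2.0])
print(np.max(np.abs(softmax_jacobian(z) - numerical_jacobian(z))))
```

Each column of the Jacobian sums to zero, because the outputs always sum to one.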

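Here \(\vec z\) is a linear function of \(\vec x\), with \(z_i = \vec \omega_i \vec x\). So \(\frac{\partial z_i}{\partial \vec \omega_i} = \vec x\), and \(z_i\) does not depend on \(\vec \omega_j\) for \(j \neq i\).

Then, we apply the chain rule to the loss \(l = -\sum_{i = 1}^m y_i \log \hat y_i\) with \(\hat y_i = f_i(\vec z)\), using \(\frac{\partial \hat y_j}{\partial z_j} = \hat y_j (1 - \hat y_j)\), \(\frac{\partial \hat y_i}{\partial z_j} = -\hat y_i \hat y_j\) for \(i \neq j\), and \(\sum_{i = 1}^m y_i = 1\):

\[\frac{\partial l}{\partial z_j} = -\sum_{i = 1}^m \frac{y_i}{\hat y_i} \frac{\partial \hat y_i}{\partial z_j} = -y_j (1 - \hat y_j) + \sum_{i \neq j} y_i \hat y_j = \hat y_j - y_j,\]

\[\frac{\partial l}{\partial \vec \omega_j} = \frac{\partial l}{\partial z_j} \frac{\partial z_j}{\partial \vec \omega_j} = (\hat y_j - y_j)\, \vec x.\]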
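Below is a sketch of the resulting gradient for a single example, checked against finite differences on every parameter. It assumes \(\mathbb{\Omega}\) is an \(m \times n\) array with rows \(\vec \omega_i\). The function names and random test values are hypothetical.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def loss(omega, x, y):
    # Cross-entropy l = -sum_i y_i log y_hat_i with y_hat = softmax(omega @ x)
    return -np.sum(y * np.log(softmax(omega @ x)))

def grad(omega, x, y):
    # dl/d omega_j = (y_hat_j - y_j) x, stacked over j as an (m, n) outer product
    y_hat = softmax(omega @ x)
    return np.outer(y_hat - y, x)

rng = np.random.default_rng(0)
m, n = 4, 3
omega = rng.normal(size=(m, n))
x = rng.normal(size=n)
y = np.eye(m)[2]                     # one-hot label for class 2

# Central finite-difference check of every parameter.
eps = 1e-6
num = np.zeros_like(omega)
for i in range(m):
    for k in range(n):
        d = np.zeros_like(omega)
        d[i, k] = eps
        num[i, k] = (loss(omega + d, x, y) - loss(omega - d, x, y)) / (2 * eps)
print(np.max(np.abs(num - grad(omega, x, y))))
```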

Also note that we omit the contribution of \(P(\vec x | \mathbb{\Omega})\) from both the likelihood function and the derivative. In mini-batch gradient descent, would it be better to take it into consideration?
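Here is a minimal mini-batch gradient descent sketch. Like the derivation above, it uses only the conditional term \(P(\vec y | \vec x, \mathbb{\Omega})\) and averages the per-example gradients \((\vec{\hat y} - \vec y)\,\vec x^T\) over each batch. The synthetic data, learning rate, batch size, epoch count, and function names are all illustrative assumptions.

```python
import numpy as np

def softmax_rows(Z):
    # Row-wise softmax for a batch of scores Z with shape (batch, m).
    E = np.exp(Z - Z.max(axis=1, keepdims=True))
    return E / E.sum(axis=1, keepdims=True)

def batch_loss(omega, X, Y):
    # Mean cross-entropy; X: (batch, n), Y: one-hot (batch, m), omega: (m, n).
    P = softmax_rows(X @ omega.T)
    return -np.mean(np.sum(Y * np.log(P), axis=1))

def train(X, Y, lr=0.5, batch_size=16, epochs=20, seed=0):
    rng = np.random.default_rng(seed)
    m, n = Y.shape[1], X.shape[1]
    omega = np.zeros((m, n))
    for _ in range(epochs):
        order = rng.permutation(len(X))              # reshuffle each epoch
        for start in range(0, len(X), batch_size):
            idx = order[start:start + batch_size]
            P = softmax_rows(X[idx] @ omega.T)
            # Average of per-example gradients (y_hat - y) x^T over the batch.
            G = (P - Y[idx]).T @ X[idx] / len(idx)
            omega -= lr * G
    return omega

# Synthetic 3-class data (hypothetical): class c is centred at centres[c].
rng = np.random.default_rng(1)
centres = np.array([[2.0, 0.0], [-2.0, 0.0], [0.0, 2.5]])
labels = rng.integers(0, 3, size=300)
X = centres[labels] + rng.normal(scale=0.5, size=(300, 2))
Y = np.eye(3)[labels]

omega = train(X, Y)
print(batch_loss(np.zeros((3, 2)), X, Y), batch_loss(omega, X, Y))
```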

The hierarchical softmax

In the derivation above, computing the derivative of the loss with respect to the parameters requires the value of \(\sum_{i = 1}^m e^{\vec \omega_i \vec x}\). When the number of classes \(m\) is large, this step is computationally heavy, as it requires \(n \times m\) multiplications involving all the parameters in the model.

For large \(m\), the hierarchical softmax is much cheaper than the original softmax during training.
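To put the gap in numbers, here is a back-of-the-envelope sketch of the multiplications needed to score one example. It uses \(n \times m\) for the full softmax and \(n \lceil \log_2 m \rceil\) for a balanced binary tree. The function name and the sizes \(n = 300\), \(m = 100{,}000\) are arbitrary illustrative choices.

```python
def multiplications_per_example(n: int, m: int) -> tuple[int, int]:
    """Rough per-example multiplication counts for the score computations.

    Full softmax: every class needs omega_i . x, i.e. n * m multiplications.
    Balanced-tree hierarchical softmax: only the nodes on one root-to-leaf
    path, i.e. n * ceil(log2 m) multiplications.
    """
    depth = (m - 1).bit_length()     # ceil(log2 m) for m >= 2, in exact integer arithmetic
    return n * m, n * depth

print(multiplications_per_example(300, 100_000))  # -> (30000000, 5100)
```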

Suppose we use a balanced binary tree to construct the hierarchical softmax. Every leaf \(i\) represents the probability \(P(y = i | \vec x)\), and each internal node represents the probability \(P(\text{left} | \vec x, \text{context})\) of branching left at that node. The probability of leaf \(i\) is the product of the branch probabilities along the path from the root to it. There are \(m - 1\) internal nodes in the tree, and its height is \(\lceil \log_2 m \rceil\).
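The sketch below shows this construction under several assumptions made for illustration:

- \(m\) is a power of two, so the tree is perfectly balanced.
- Internal nodes are stored in heap order: the root is node 1, the children of node \(k\) are \(2k\) and \(2k + 1\), and leaf \(i\) is node \(m + i\).
- Internal node \(k\) owns row \(k - 1\) of a hypothetical weight array `W`, with \(P(\text{left}) = \sigma(\vec w_k \cdot \vec x)\).

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def leaf_probability(W: np.ndarray, x: np.ndarray, i: int) -> float:
    """P(y = i | x) for a perfectly balanced tree with m = W.shape[0] + 1 leaves.

    W: (m - 1, n) array; row k - 1 holds the weights of internal node k (heap order).
    """
    m = W.shape[0] + 1
    node = m + i                     # heap index of leaf i
    prob = 1.0
    while node > 1:
        parent = node // 2
        p_left = sigmoid(W[parent - 1] @ x)
        # An even heap index means this node is its parent's left child.
        prob *= p_left if node % 2 == 0 else 1.0 - p_left
        node = parent
    return prob

rng = np.random.default_rng(0)
m, n = 8, 5
W = rng.normal(size=(m - 1, n))
x = rng.normal(size=n)
probs = [leaf_probability(W, x, i) for i in range(m)]
print(sum(probs))
```

Because the left and right branch probabilities at every internal node sum to one, the leaf probabilities also sum to one. No global normalizer over all \(m\) classes is needed.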

Given \(y = i\), the path from the root to leaf \(i\) is fixed. Computing the loss and its derivative for that example therefore involves only the internal nodes on this path, about \(n \log_2 m\) multiplications, because the derivative with respect to the parameters of every node off the path is zero.
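The sketch below computes the per-example loss \(-\log P(y = i | \vec x)\) and its gradient under assumed conventions: a perfectly balanced tree, heap-ordered internal nodes (leaf \(i\) is node \(m + i\)), and a hypothetical weight array `W` with row \(k - 1\) for node \(k\). Let \(t = 1\) if the path turns left at node \(k\) and \(t = 0\) otherwise. The gradient with respect to \(\vec w_k\) is then \((\sigma(\vec w_k \cdot \vec x) - t)\,\vec x\), and every row for a node off the path stays zero.

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def path(m: int, i: int):
    """Yield (internal node k, went_left) pairs from leaf i up to the root."""
    node = m + i
    while node > 1:
        yield node // 2, node % 2 == 0
        node //= 2

def loss_and_grad(W: np.ndarray, x: np.ndarray, i: int):
    m = W.shape[0] + 1
    loss = 0.0
    grad = np.zeros_like(W)          # rows of off-path nodes remain zero
    for k, left in path(m, i):
        s = sigmoid(W[k - 1] @ x)
        t = 1.0 if left else 0.0
        loss -= np.log(s if left else 1.0 - s)
        grad[k - 1] = (s - t) * x    # d(-log P) / d w_k for this branch
    return loss, grad

rng = np.random.default_rng(0)
m, n = 16, 4
W = rng.normal(size=(m - 1, n))
x = rng.normal(size=n)
loss, grad = loss_and_grad(W, x, i=5)
touched = np.count_nonzero(np.any(grad != 0, axis=1))
print(touched, "of", m - 1, "internal nodes updated")
```

Only \(\log_2 16 = 4\) of the 15 weight rows receive a nonzero gradient for this example.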

Also note that the hierarchical softmax only accelerates training. In the prediction phase, you still need to compute the probabilities of all leaves to find the optimal prediction.
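For prediction, the sketch below computes all leaf probabilities level by level from the root and then takes the argmax. This evaluates every one of the \(m - 1\) internal nodes. It assumes \(m\) is a power of two, internal nodes are in heap order (children of node \(k\) are \(2k\) and \(2k + 1\), leaf \(i\) is node \(m + i\)), and node \(k\) owns row \(k - 1\) of a hypothetical weight array `W` with \(P(\text{left}) = \sigma(\vec w_k \cdot \vec x)\).

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def all_leaf_probabilities(W: np.ndarray, x: np.ndarray) -> np.ndarray:
    """Probabilities of all m = W.shape[0] + 1 leaves, computed level by level."""
    m = W.shape[0] + 1
    probs = np.array([1.0])          # probability of reaching the root
    first = 1                        # heap index of the first node on this level
    while first < m:
        nodes = np.arange(first, 2 * first)          # internal nodes on this level
        p_left = sigmoid(W[nodes - 1] @ x)
        # Interleave children so the order matches heap order on the next level.
        probs = np.stack([probs * p_left, probs * (1.0 - p_left)], axis=1).ravel()
        first *= 2
    return probs                     # probs[i] = P(y = i | x)

rng = np.random.default_rng(0)
m, n = 8, 5
W = rng.normal(size=(m - 1, n))
x = rng.normal(size=n)
probs = all_leaf_probabilities(W, x)
print(probs.argmax(), probs.sum())
```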