
Hierarchical Softmax

Often, we want to define a distribution over words. This can be done by giving each word \(w\) a vector representation \(v_w\) and using softmax. Then, given an input vector \(v_{w_i}\), the probability is \(p(w_o \mid w_i) = \frac{\exp(v_{w_i}^\intercal v_{w_o})}{\sum_w \exp(v_{w_i}^\intercal v_{w})}\). Note that the denominator takes \(O(|W|)\) time to compute.
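
For concreteness, here is a minimal NumPy sketch of this full softmax (the array names and shapes are illustrative assumptions, not from the original); the point is that the normalizer has to touch every row of the vocabulary matrix:

    import numpy as np

    def full_softmax(v_input, word_vectors):
        """p(w | w_i) for every word w, given the input embedding v_input.

        v_input:      shape (d,)      -- embedding of the conditioning word w_i
        word_vectors: shape (|W|, d)  -- one output embedding per vocabulary word
        """
        scores = word_vectors @ v_input       # |W| dot products
        scores -= scores.max()                # subtract max for numerical stability
        exp_scores = np.exp(scores)
        return exp_scores / exp_scores.sum()  # the O(|W|) normalizer

    # Toy usage: 5 words, 3-dimensional embeddings.
    rng = np.random.default_rng(0)
    vectors = rng.normal(size=(5, 3))
    p = full_softmax(vectors[2], vectors)     # distribution conditioned on word 2
    assert np.isclose(p.sum(), 1.0)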

Hierarchical softmax is a way of approximating softmax where any single probability \(p(w \mid w_i)\) can be calculated in \(O(\log |W|)\) time. We use a binary tree with the words of \(W\) at its leaves, so that each word corresponds to a unique path from the root.

So to calculate the probability \(p(w \mid w_i)\), we walk from the root down to the leaf for \(w\) and multiply together the branch probabilities taken at each internal node along the way.

Note that this yields a valid probability distribution over all of \(W\), because the branch probabilities at each internal node themselves form a valid distribution (they sum to one).
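
To spell out the reasoning (a short sketch; the notation is mine, not from the original): write \(P(n)\) for the probability of reaching node \(n\) from the root, and \(p(c \mid n)\) for the probability of the branch from \(n\) to its child \(c\). Since \(P(c) = P(n)\,p(c \mid n)\),

\[\sum_{c \,\in\, \mathrm{children}(n)} P(c) \;=\; P(n) \sum_{c \,\in\, \mathrm{children}(n)} p(c \mid n) \;=\; P(n).\]

Expanding \(P(\mathrm{root}) = 1\) downward until only leaves remain therefore gives \(\sum_{w \in W} p(w \mid w_i) = 1\).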

At each internal node \(n\), the probabilities of the two outgoing branches are modeled as \[\sigma(\pm v_{w_i}^{\intercal} v_n),\] where \(v_n\) is an embedding for the internal node (itself one of the model's parameters) and \(\sigma\) is the sigmoid function. Since \(\sigma(x) + \sigma(-x) = 1\), the two branch probabilities at each node sum to one.
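
As a sketch of this computation (a hypothetical minimal implementation; the `path`/`signs` encoding of the tree and all names here are my own illustrative choices), each word's leaf is described by the internal nodes on its root-to-leaf path together with a \(\pm 1\) sign saying which branch is taken at each node:

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def hier_softmax_prob(v_input, node_vectors, path, signs):
        """p(w | w_i) as the product of branch probabilities along w's path.

        v_input:      shape (d,)        -- embedding of the conditioning word w_i
        node_vectors: shape (|W|-1, d)  -- embeddings v_n of the internal nodes
        path:         internal-node indices on the root-to-leaf path for w
        signs:        +1 / -1 choice of branch at each node on the path

        Only O(log |W|) dot products are needed, versus O(|W|) for full softmax.
        """
        prob = 1.0
        for n, sign in zip(path, signs):
            prob *= sigmoid(sign * (node_vectors[n] @ v_input))
        return prob

    # Toy example: 4 words in a balanced tree with 3 internal nodes (node 0 = root).
    # word 0: left, left   word 1: left, right   word 2: right, left   word 3: right, right
    rng = np.random.default_rng(0)
    nodes = rng.normal(size=(3, 4))     # internal-node embeddings, d = 4
    v_in = rng.normal(size=4)
    paths = {0: ([0, 1], [+1, +1]), 1: ([0, 1], [+1, -1]),
             2: ([0, 2], [-1, +1]), 3: ([0, 2], [-1, -1])}
    total = sum(hier_softmax_prob(v_in, nodes, p, s) for p, s in paths.values())
    assert np.isclose(total, 1.0)       # the leaf probabilities sum to one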

One of the things I was confused about when learning about hierarchical softmax was how the vector embeddings for the internal nodes would be encouraged to correctly reflect the probabilities at the leaves during training. I see now that hierarchical softmax is a re-parameterization of the softmax function: the probability of each leaf is, by definition, whatever is given by taking the product along its path. During training, maximizing \(\log p(w_o \mid w_i)\) therefore updates exactly the embeddings of the internal nodes on \(w_o\)'s path.

Note that a binary tree with \(|W|\) leaves has \(|W|-1\) internal nodes. So this parameterization involves roughly the same number of weights, but makes computing a single probability faster (\(O(\log |W|)\) vs \(O(|W|)\)) when the tree is balanced.
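
For a rough sense of scale (illustrative numbers, not from the original): with \(|W| = 10^6\) words, evaluating a single probability with the full softmax costs \(10^6\) dot products, while a balanced tree has depth \(\lceil \log_2 10^6 \rceil = 20\), so hierarchical softmax needs only about 20 dot products and sigmoids; both parameterizations store on the order of \(10^6\) vectors.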

1. Helpful links

Created: 2024-07-15 Mon 01:28