Sampling vs Prediction

Some things have recently been bugging me when applying deep learning models to natural language generation. This post contains my random thoughts on two of these: sampling and prediction. By writing this post, I hope to tease these apart in my head to help improve my natural language models.



Sampling is the act of selecting a value for a variable having an underlying probability distribution.

For example, when we toss a coin, we have a 50% chance of “heads” and a 50% chance of “tails”. In this case, our variable may be “toss outcome” or “t”, our values are “heads” and “tails”, and our probability distribution may be modeled as [0.5, 0.5], representing the chances of “heads” and “tails” respectively.

To make a choice according to the probability distribution, we sample.

One way of sampling is to generate a (pseudo-)random number between 0 and 1 and to compare it with the probability distribution. For example, we can build a cumulative probability distribution, which divides the interval from 0 to 1 into bands whose widths are the probabilities. In the present case, we convert our probability distribution [0.5, 0.5] into the cumulative distribution [0.5, 1], where 0 to 0.5 is the band for “heads” and 0.5 to 1 is the band for “tails”. We then see which band our random number falls into, and that is the band we select. E.g. if our random number is 0.2, we select the first band (0 to 0.5) and our variable value is “heads”; if our random number is 0.7, we select the second band (0.5 to 1) and our variable value is “tails”.

Let’s compare this example to the case of a weighted coin. Our weighted coin has a probability distribution of [0.8, 0.2]. This means that it is weighted so that 80% of tosses come out as “heads”, and 20% of tosses come out as “tails”. If we use the same sampling technique, we generate a cumulative probability distribution of [0.8, 1], with a band of 0 to 0.8 being associated with “heads” and a band of 0.8 to 1 being associated with “tails”. Given the same random numbers of 0.2 and 0.7, we now get sample values of “heads” and “heads”, as both numbers fall within the band 0 to 0.8.

Now that all seems a bit long-winded. Why have I repeated the obvious?
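The band-based procedure above can be sketched in a few lines of Python. This is only an illustration using the standard library; the names `sample_index`, `toss` and `OUTCOMES` are my own and do not come from any framework.

```python
import bisect
import itertools
import random

OUTCOMES = ["heads", "tails"]

def sample_index(probs, u):
    """Return the index of the band that the number u (0 <= u < 1) falls into."""
    # Running totals turn [0.5, 0.5] into the upper band edges [0.5, 1.0].
    cumulative = list(itertools.accumulate(probs))
    # bisect_right finds the first band whose upper edge is greater than u.
    index = bisect.bisect_right(cumulative, u)
    # Guard against floating-point totals that fall just short of 1.0.
    return min(index, len(probs) - 1)

def toss(probs, rng=random):
    """Sample one outcome using a pseudo-random number in [0, 1)."""
    return OUTCOMES[sample_index(probs, rng.random())]

fair = [0.5, 0.5]
weighted = [0.8, 0.2]
for u in (0.2, 0.7):
    print(u, OUTCOMES[sample_index(fair, u)], OUTCOMES[sample_index(weighted, u)])
# 0.2 heads heads
# 0.7 tails heads
```

Calling `toss(weighted)` repeatedly should give “heads” roughly 80% of the time, because the random number lands in the wider band more often.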

Because most model architectures that act to predict or generate text skip over how sampling is performed. Looking at the code associated with the models, I would say 90% of cases generate an array of probabilities for each time step, and then take the maximum of this array (e.g. using numpy.argmax(array)). This array is typically the output of a softmax layer, and so sums to 1. It may thus be taken as a probability distribution. So our sampling often takes the form of a greedy decoder that just selects the highest probability output at each time step.

Now, if our model consistently outputs probabilities of [0.55, 0.45] to predict a coin toss, then based on the normal greedy decoder, our output will always be “heads”. Repeated 20 times, our output array would be 20 values of “heads”. This seems wrong: our model is saying that over 20 repetitions, we should expect around 11 “heads” and 9 “tails”.
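To see the difference concretely, here is a small sketch that decodes the same [0.55, 0.45] output 20 times, once greedily and once by sampling. The helper `greedy` is my own stand-in for what numpy.argmax does, and the seed is fixed only so that the run is repeatable.

```python
import random

probs = [0.55, 0.45]            # model output for ["heads", "tails"]
outcomes = ["heads", "tails"]

def greedy(probs):
    """Index of the largest probability (the job numpy.argmax does)."""
    return max(range(len(probs)), key=lambda i: probs[i])

rng = random.Random(0)  # fixed seed so the run is repeatable

greedy_run = [outcomes[greedy(probs)] for _ in range(20)]
sampled_run = rng.choices(outcomes, weights=probs, k=20)

print(greedy_run.count("heads"), "heads out of 20 with argmax")
print(sampled_run.count("heads"), "heads out of 20 with sampling")
```

The greedy run is 20 “heads” every time. The sampled run varies from seed to seed, but over many tosses the share of “heads” settles near 0.55.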

Let’s park this for a moment and look at prediction.


With machine learning frameworks such as TensorFlow and Keras, you create “models” that take an input “tensor” (a multidimensional array). Each “model” applies a number of transformations, typically via one or more neural networks, to output another “tensor”. The generation of an output is typically referred to as a “prediction”.

A “prediction” can thus be deemed an output of a machine learning model, given some input.

To make an accurate prediction, a model needs to be trained to determine a set of parameter values for the model. Often there are millions of parameters. During training, predictions are made based on input data. These predictions are compared with “ground truth” values, i.e. the actual observed/recorded/measured output values. Slightly confusingly, each pair of “input-ground truth values” is often called a sample. However, these samples have nothing really to do with our “sampling” as discussed above.

For training, loss functions for recurrent neural networks are typically based on cross entropy loss. Cross entropy measures the difference between the “true” probability distribution, in the form of a one-hot array for the “ground truth” output token (e.g. [1, 0] for “heads”), and the “probability distribution” generated by the model (which may be [0.823, 0.177]). This post by Rob DiPietro has a nice explanation of cross entropy. The difference filters back through the model to modify the parameter values via back-propagation.
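As a rough worked example, the cross entropy for the two arrays above can be computed by hand. The function below is a minimal sketch of the textbook formula, not the loss implementation of any particular framework.

```python
import math

def cross_entropy(target, predicted):
    """Cross entropy between a target distribution and a predicted one."""
    # Terms with zero target probability contribute nothing to the sum.
    return -sum(t * math.log(p) for t, p in zip(target, predicted) if t > 0)

ground_truth = [1, 0]           # one-hot array for "heads"
model_output = [0.823, 0.177]   # distribution generated by the model

loss = cross_entropy(ground_truth, model_output)
print(loss)
```

With a one-hot target, only the probability assigned to the “ground truth” token matters, so the loss reduces to -log(0.823), which is roughly 0.195. The loss falls to zero as that probability approaches one.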

Sampling vs Prediction

Now that we have looked at what we mean by “sampling” and “prediction”, we come to the nub of our problem: how do these two processes interact in machine learning models?

One problem is that in normal English usage, a “prediction” is seen as the choice of a discrete output. For example, if you were asked to “predict” a coin toss, you would say either “heads” or “tails”; you would not give me the array “[0.5, 0.5]”.

Hence, “prediction” in normal English actually refers to the output of the “sampling” process.

This means that, in most cases, when we are “predicting” with a machine learning model (e.g. using `model.predict(input)` in Keras), we are not actually “predicting”. Instead, we are generating an instance of the probability distribution given a set of current conditions. In particular, if our model has a softmax layer at its output, we are generating a conditional probability distribution, where the model is “conditioned” based on the input data. In a recurrent neural network model (such as an LSTM or GRU based model), at each time step, we are generating a conditional probability distribution, where the model is “conditioned” based on the input data and the “state” parameters of the neural network.

As many have suggested, I find it useful to think of machine learning models, especially those using multiple layers (“deep” learning), as function approximators. In many cases, a machine learning model is modeling the probability function for a particular context, where the probability function outputs probabilities across a discrete output variable space. For a coin toss, our discrete output variable space is “heads” or “tails”; for language systems, it is based on the set of possible output tokens, e.g. the set of possible characters or words. This is useful, because our problem then becomes: how do we make a model that accurately approximates the probability distribution function? Note this is different from: how do we generate a model that accurately predicts future output tokens?

These different interpretations of the term “prediction” matter less if our output array has most of its mass under one value, e.g. for each “prediction” you get an output that is similar to a one-hot vector. Indeed, this is what the model is trying to fit if we supply our “ground truth” output as one-hot vectors. In this case, taking a random sample and taking the maximum probability value will work out to be mostly the same.
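A quick way to check this is to measure how often a random sample agrees with the argmax. The sketch below uses two made-up distributions; `agreement_rate` is a hypothetical helper of mine.

```python
import random

def agreement_rate(probs, trials=10_000, seed=0):
    """Fraction of random samples that match the argmax choice."""
    rng = random.Random(seed)
    top = max(range(len(probs)), key=lambda i: probs[i])
    draws = rng.choices(range(len(probs)), weights=probs, k=trials)
    return sum(d == top for d in draws) / trials

peaked = [0.98, 0.01, 0.01]   # close to a one-hot vector
flat = [0.40, 0.35, 0.25]     # mass spread over several values

print(agreement_rate(peaked))  # close to 0.98
print(agreement_rate(flat))    # close to 0.40
```

The agreement rate is simply the probability of the most likely value, so the two decoding approaches only diverge noticeably once the distribution flattens out.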

I think some of the confusion occurs as different people think about the model prediction in different ways. If the prediction is seen to be a normalized score, then it makes more sense to pick the highest score. However, if the prediction is seen to be a probability distribution, then this makes less sense. Using a softmax computation on an output does result in a valid probability distribution. So I think we should be treating our model as approximating a probability distribution function and looking at sampling properly.
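For reference, a softmax can be sketched as follows. The raw scores are invented for illustration, and subtracting the maximum is a common numerical-stability trick rather than something specific to any one library.

```python
import math

def softmax(scores):
    """Map raw scores to positive values that sum to 1."""
    # Subtracting the maximum keeps exp() from overflowing; the result is unchanged.
    shift = max(scores)
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

raw_scores = [2.0, 1.0, 0.1]   # hypothetical pre-softmax outputs
probs = softmax(raw_scores)
print(probs, sum(probs))
```

The output keeps the ordering of the raw scores, which is why it can be read as a normalized score, but it is also positive and sums to one, which is why it can be sampled from.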

Another factor is the nature of the “conditioning” of the probability distribution. This can be thought of as: what constraints are there on the resultant distribution? A fully constrained probability distribution becomes deterministic: only one output is going to be correct, and this has a probability of one. In that case, the argmax and a random sample always give the same result, so random sampling is still correct. If we can never truly know or model our constraints, then we can never obtain a deterministic setup.

Deterministic vs Stochastic

The recent explosion in deep learning approaches was initially driven by image classification. In these cases, you have an image as an input, and you are trying to ascertain what is in an image (or a pixel). This selection is independent of time and reduces to a deterministic selection: there is either a dog in the image or there is not a dog in the image. The probability distribution output by our image classification model thus indicates more a confidence in our choice.

Language, though, is inherently stochastic: there are multiple interchangeable “right” answers. Also, language unfolds over time. This amplifies the stochasticity.

Thinking about a single word choice, language models will generally produce flatter conditional probability distributions for certain contexts. For example, you can often use different words and phrases to mean the same thing (synonyms for example, or different length noun phrases). Also, the length of the sequence is itself a random variable.

In this case, the difference between randomly sampling the probability distribution and taking the maximum value of our probability array becomes more pronounced. Cast your mind back to our coin toss example: if we are accurately modeling the probability distribution, then argmax doesn’t really work. We have two equally likely choices, and we have to hack our decoding function to make a choice when the probabilities are equal. Sampling from the probability distribution, however, works.

A key point people often miss is that we can very accurately model the probability distribution, while producing outputs that do not match “ground truth values”.
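The fair coin makes this easy to demonstrate. In the sketch below, the model’s distribution is exactly the true one, yet its sampled outputs match the observed tosses only about half the time. All names and the seed are illustrative.

```python
import random

rng = random.Random(0)
outcomes = ["heads", "tails"]
true_probs = [0.5, 0.5]    # the real coin
model_probs = [0.5, 0.5]   # a model that matches the real coin exactly

n = 10_000
ground_truth = rng.choices(outcomes, weights=true_probs, k=n)
generated = rng.choices(outcomes, weights=model_probs, k=n)

matches = sum(g == t for g, t in zip(generated, ground_truth))
print(matches / n)  # close to 0.5, even though the model is exact
```

Judged by token-level accuracy against the “ground truth”, this perfect model looks no better than guessing, which suggests that accuracy is the wrong yardstick for a stochastic process.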

What does this all mean?

Firstly, it means that when selecting an output token for each time step, we need to be properly sampling rather than taking the maximum value.

Secondly, it means we have to pay more attention to how output is sampled in our models, for example by using beam search or the Viterbi algorithm to generate our output sequence. It also suggests more left-field ideas, such as building sampling into our model rather than having it as an external step.
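As a rough illustration of the first of those options, here is a minimal beam search over a toy model. Everything here is an assumption for the sake of the sketch: `TOY_MODEL` is a made-up lookup table standing in for a trained network, and `next_token_probs` conditions only on the previous token.

```python
import math

# Hypothetical toy model: maps the previous token to a distribution over the next.
# "<s>" marks the start of a sequence and "</s>" the end.
TOY_MODEL = {
    "<s>": {"the": 0.6, "a": 0.4},
    "the": {"cat": 0.5, "dog": 0.5},
    "a": {"cat": 0.9, "dog": 0.1},
    "cat": {"</s>": 1.0},
    "dog": {"</s>": 1.0},
}

def next_token_probs(sequence):
    """Stand-in for a model call; conditions only on the last token."""
    return TOY_MODEL[sequence[-1]]

def beam_search(beam_width=2, max_steps=5):
    """Keep the beam_width highest-scoring partial sequences at each step."""
    beams = [(0.0, ["<s>"])]  # (log probability, tokens so far)
    for _ in range(max_steps):
        candidates = []
        for score, seq in beams:
            if seq[-1] == "</s>":
                candidates.append((score, seq))  # finished; carry forward
                continue
            for token, p in next_token_probs(seq).items():
                candidates.append((score + math.log(p), seq + [token]))
        beams = sorted(candidates, key=lambda c: c[0], reverse=True)[:beam_width]
        if all(seq[-1] == "</s>" for _, seq in beams):
            break
    return beams

for score, seq in beam_search():
    print(round(math.exp(score), 3), " ".join(seq))
```

With a beam width of 1 this behaves like the greedy decoder: it picks “the” (0.6) and ends with a sequence probability of 0.3. With a width of 2 it also keeps “a” alive and finds “a cat”, with the higher sequence probability of 0.36. Note that beam search scores whole sequences rather than drawing at random, so it is a search over the distribution rather than a random sample from it.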

Thirdly, we need to look at the loss function in our language models. At one extreme, you have discriminator-style loss functions, where we perform a binary prediction of whether the output looks right; at the other extreme, you have teacher forcing on individual tokens. The former is difficult to train (how do you know what makes a “good” output and move towards that?), but the latter tends to memorize an input, or struggles to produce coherent sequences of tokens that resemble natural language. If we are trying to model a probability distribution, then should our loss be based on trying to replicate the distribution of our input? Is a one-hot encoding the right way to do this?

Of course, I could just be missing the point of all of this. However, the generally poor performance of many generative language models (look at the real output not the BLEU scores!) suggests there is some low-hanging fruit that just requires some shift in perspective.