*2023-10-22: Updated the code and experimental results based on a reader's comment.*

When using this type of model with an ANS entropy encoder, we have to be a bit
careful because it decompresses symbols in the last-in-first-out (stack) order.
Let's take a closer look in Figure 1.

.. figure:: /images/bbans_autoregressive.png
   :width: 700px
   :alt: Entropy Encoding and Decoding with a Generative Autoregressive Model
   :align: center

   Figure 1: Entropy Encoding and Decoding with a Generative Autoregressive Model

From the left side of Figure 1, notice that we have to
reverse the order of the input data to the ANS encoder (the autoregressive
model receives the input in the usual parallel order though). This is needed
because we need to decode the data in ascending order for the
autoregressive model to work (see decoding below). Next, notice that our ANS
encoder requires both the (reversed) input data and the appropriate
distributions for each symbol (i.e. each :math:`x_j` component). Finally, the
compressed data is output, which (hopefully) is shorter than the original
input.
Decoding is shown on the right hand side of Figure 1. It's a bit more
complicated because we must *iteratively* generate the distributions for each
symbol. Starting by reversing the compressed data, we decode :math:`x_1` since
our model can unconditionally generate its distribution.
This is the reason why we needed to reverse our input during encoding. Then,
we generate :math:`x_2|x_1` and so on for each :math:`x_i|x_{1,\ldots,i-1}`
until we've recovered the original data. Notice that this is quite inefficient
since we have to call the model :math:`n` times, once for each component of
:math:`\bf x`.
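
The order of operations above can be sketched with a Python list acting as a
stack, standing in for the ANS bitstream. Everything here (including the dummy
``model_dist``) is illustrative, not a real ANS implementation:

```python
# Sketch of the LIFO encode/decode order from Figure 1. The "coder" is just
# a stack that records (symbol, distribution) pairs -- a stand-in for a real
# ANS coder that would mix each symbol into its integer state.

def model_dist(prefix):
    # Hypothetical autoregressive model: returns a distribution over the
    # next symbol given the previously decoded symbols. Here it's a dummy
    # uniform distribution; a real model would be e.g. a neural network.
    return {0: 0.5, 1: 0.5}

def encode(data):
    stack = []
    # Distributions are computed in the usual forward order...
    dists = [model_dist(data[:i]) for i in range(len(data))]
    # ...but symbols are pushed in *reverse* so that x_1 pops out first.
    for x, d in zip(reversed(data), reversed(dists)):
        stack.append((x, d))
    return stack

def decode(stack):
    data = []
    while stack:
        d = model_dist(data)     # dist for the next symbol, given the prefix
        x, d_enc = stack.pop()   # LIFO: x_1 comes out first
        assert d == d_enc        # decoder must reproduce the encoder's dist
        data.append(x)
    return data

msg = [1, 0, 1, 1]
assert decode(encode(msg)) == msg
```

Note how ``decode`` can only produce the distribution for :math:`x_2` after it
has recovered :math:`x_1`, which is exactly why the encoder had to push the
symbols in reverse.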
I haven't tried this but it seems like something pretty reasonable to do
(assuming you have a good model *and* I haven't made a serious logical error).
The only problem with generative autoregressive models is that they are slow
because you have to call them :math:`n` times. Perhaps that's why no one is
interested in this? In any case, the next method overcomes this problem.

Latent Variable Models
----------------------

Latent variable models have a set of unobserved variables :math:`\bf z` in
addition to the observed ones :math:`\bf x`, giving us a likelihood function
of :math:`P(\bf x|\bf z)`. We'll usually have a prior distribution
:math:`P({\bf z})` for :math:`\bf z` (implicitly or explicitly), and depending
on the model, we may or may not have access to a posterior distribution (more
likely an estimate of it) as well: :math:`q(\bf z| \bf x)`.
The key idea with these models is that we need to encode the latent variables
as well (or else we won't be able to generate the required distributions for
:math:`\bf x`). Let's take a look at Figure 2 to see how the encoding works.

.. figure:: /images/bbans_latent_encode.png
   :width: 600px
   :alt: Entropy Encoding with a Latent Variable Model
   :align: center

   Figure 2: Entropy Encoding with a Latent Variable Model

Starting from the input data, we need to first generate some value for our
latent variable :math:`\bf z` so that we can use it with our model :math:`P(\bf x|\bf z)`.
This can be obtained either by sampling the prior (or posterior if available),
or really any other method that would generate an accurate distribution for :math:`\bf x`.
Once we have :math:`\bf z`, we can run it through our model, get distributions
for each :math:`x_i` and encode the input data as usual. The one big
difference is that we also have to encode our latent variables. The latent
variables *should* be distributed according to our prior distribution (for most
sensible models), so we can use it with the ANS coder to compress :math:`\bf
z`. Notice that we cannot use the posterior here because we won't have access
to :math:`\bf x` at decompression time and therefore would not be able to
decode :math:`\bf z`.

.. figure:: /images/bbans_latent_decode.png
   :width: 600px
   :alt: Entropy Decoding with a Latent Variable Model
   :align: center

   Figure 3: Entropy Decoding with a Latent Variable Model

Decoding is shown in Figure 3 and works basically as the reverse of encoding.
The major thing to notice is that we have to do operations in a
last-in-first-out order. That is, first decode :math:`\bf z`, use it to
generate distributional outputs for the components of :math:`\bf x`, then
use those outputs to decode the compressed data to recover our original message.
This is all relatively straightforward once you take the time to think about it.
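
The encode/decode order for Figures 2 and 3 can be sketched the same way, with
a plain Python list standing in for the ANS bitstream. The helper name
``sample_latent`` is made up for illustration; a real coder would encode each
symbol under the model's distributions rather than just pushing it:

```python
# Order-of-operations sketch for a latent variable model.

def sample_latent(x):
    # Dummy stand-in for sampling z from the prior (or posterior).
    return sum(x) % 2

def encode(x, stream):
    z = sample_latent(x)          # 1. pick a latent for this datum
    # 2. encode x under p(x|z), reversed so x_1 decodes first
    for xi in reversed(x):
        stream.append(xi)
    # 3. encode z under the prior -- pushed last, so it pops *first*
    stream.append(z)
    return stream

def decode(stream, n):
    z = stream.pop()              # 1. decode z first (LIFO order)
    # 2. z lets us rebuild p(x|z); 3. decode the n symbols of x with it
    x = [stream.pop() for _ in range(n)]
    return x, z

stream = encode([1, 0, 0], [])
x, z = decode(stream, 3)
assert x == [1, 0, 0]
```

The essential point is the stack discipline: :math:`\bf z` is encoded last so
that it is available first at decode time.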
There are some other issues as well around discretizing :math:`\bf z` if it's
continuous, but we'll cover that below. The more interesting question is: can we
do better? The answer is a resounding "Yes!", and that's what this post is all
about. By using a very clever trick you can get some "bits back" to improve
your compression performance. Read on to find out more!

Bits-Back Coding
================

From the previous section, we know that we can encode and decode data using a
latent variable model with relative ease. The big downside is that we're
"wasting" space by encoding the latent variables. They're necessary to
generate the distributions for our data, but otherwise are not directly
encoding any signal. It turns out we can use a clever trick to recover
some of this "waste".
Notice that in Figure 2 we randomly sample from (an estimate of) the posterior distribution.
In some sense, we're introducing new information from the random sample here
that we must encode. Instead, why don't we utilize some of the existing bits
we've encoded to get a "pseudo-random" sample [1]_ ? Figure 4 shows the encoding
process in more detail.

.. figure:: /images/bbans_bb_encode.png
   :width: 600px
   :alt: Bits-Back Encoding with a Latent Variable Model
   :align: center

   Figure 4: Bits-Back Encoding with a Latent Variable Model

The key difference here is that we're decoding the existing bitstream (from
previous data that we've compressed) to generate a (pseudo-) random sample :math:`\bf z`
using the posterior distribution. This replaces the random sampling we
did in Figure 2. Since the existing bitstream was encoded using a different
distribution, the sample we decode should be *sort of* random. The nice part
about this trick is that we're still going to encode :math:`\bf z` as usual, so
any bits we've popped off the bitstream to generate our pseudo-random sample
we get "back" (that is, they no longer need to be on the bitstream). This
*reduces* the effective average size of encoding each datum plus its latent
variables.
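
A minimal sketch of the bits-back order of operations, with a plain Python list
standing in for the bitstream. In a real implementation, step 1 would *decode*
the stream under :math:`q({\bf z|x})` and steps 2-3 would *encode* under
:math:`P({\bf x|z})` and :math:`P({\bf z})`; here we just pop and push symbols:

```python
# Bits-back encode/decode sketch (Figures 4 and 5).

def bb_encode(x, stream):
    # 1. Bits back: pop off the stream to get a pseudo-random sample of z.
    #    (A real ANS coder would decode this under q(z|x).)
    z = stream.pop()
    # 2. Encode x under p(x|z) -- reversed, so x_1 decodes first.
    for xi in reversed(x):
        stream.append(xi)
    # 3. Encode z under the prior P(z).
    stream.append(z)
    return stream

def bb_decode(stream, n):
    # Reverse of encoding: decode z, then the n symbols of x, then put the
    # bits we originally took off *back* by re-encoding z under q(z|x).
    z = stream.pop()
    x = [stream.pop() for _ in range(n)]
    stream.append(z)   # the "put back" step restores the original stream
    return x, stream

stream = bb_encode([1, 0, 0], [1, 0, 1])
x, restored = bb_decode(stream, 3)
assert x == [1, 0, 0] and restored == [1, 0, 1]
```

Because both directions are lossless, the decoder ends with exactly the
bitstream the encoder started from, which is the "bits back" guarantee.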

.. figure:: /images/bbans_bb_decode.png
   :width: 600px
   :alt: Bits-Back Decoding with a Latent Variable Model
   :align: center

   Figure 5: Bits-Back Decoding with a Latent Variable Model

Figure 5 shows decoding with Bits-Back. It is the same as latent variable
decoding with the exception that we have to "put back" the bits we took off
originally. Since our ANS encoding and decoding are lossless, the bits we
put back should be exactly the bits we took off. The number of bits we
removed/put back will be dependent on the posterior distribution and the bits that
were originally there.

.. figure:: /images/bbans_bitstream_view.png
   :width: 600px
   :alt: Visualization of Bitstream for Bits-Back Coding
   :align: center

   Figure 6: Visualization of Bitstream for Bits-Back Coding

To get a better sense of how it works, Figure 6 shows a visualization of
encoding and decoding two data points. Colors represent the different
data: green for existing bitstream, blue for :math:`\bf x^1`, and orange for :math:`\bf x^2`
(superscript represents data point index). The different shades represent either
observed data :math:`\bf x` or latent variable :math:`\bf z`.
From Figure 6, the first step in the process is to *reduce* the bitstream length
by (pseudo-)randomly sampling :math:`\bf z`. This is followed by encoding
:math:`\bf x` and :math:`\bf z` as usual. This process repeats for each
additional datum. Even though we have to encode :math:`\bf z`, the effective
size of the encoding is shorter because of the initial "bits back" we get each
time. Decoding is the reverse operation of encoding: decode :math:`\bf z` and
:math:`\bf x`, put back the removed bits by utilizing the posterior
distribution (which is conditional on the :math:`\bf x` we just decoded).
And this repeats until all data has been decoded.
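
The bookkeeping in Figure 6 can be checked with the same kind of toy stand-in:
each datum pushes :math:`|{\bf x}| + |{\bf z}|` symbols but first pops
:math:`|{\bf z}|` symbols back off, so the net growth per datum is only
:math:`|{\bf x}|`. This is a sketch with symbols standing in for bits, not real
code lengths:

```python
# Toy demonstration of the "bits back" savings: with bits-back, each datum
# costs |x| + |z| pushes but also pops |z| symbols off the stream first, so
# the *net* growth is |x| rather than |x| + |z|.

def bb_encode(x, stream):
    z = stream.pop()            # bits back: pop |z| = 1 symbol
    for xi in reversed(x):      # encode x under p(x|z)
        stream.append(xi)
    stream.append(z)            # encode z under the prior
    return stream

stream = [0, 1]                 # existing bitstream
n0 = len(stream)
for datum in ([1, 0, 1], [0, 0, 1]):
    stream = bb_encode(datum, stream)

# Net growth is 3 symbols per datum (just |x|), not 4 (|x| + |z|):
assert len(stream) - n0 == 6
```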

Theoretical Limit of Bits-Back Coding
-------------------------------------

Turning back to some more detailed mathematical analysis, let's see how good
Bits-Back is theoretically. We'll start off with a few assumptions:

1. Our data :math:`\bf x` and latent variables :math:`\bf z` are sampled from
   the true joint distribution :math:`P({\bf x, z})=P({\bf x|z})P({\bf z})`,
   which we have access to. Of course, in the real world we don't have the
   true distribution, just an approximation. But if our model is very good,
   it will hopefully be very close to the true distribution.
2. We have access to an approximate posterior :math:`q({\bf z|x})`.
3. We have an entropy coder that can optimally code any data point.
4. The pseudo-random sample we get from Bits-Back coding is drawn from the
   approximate posterior :math:`q({\bf z|x})`.

As noted above, if we naively use the latent variable encoding from Figure 2,
given a sample :math:`(x, z)`, our expected message length should be
:math:`-(\log P({\bf z}) + \log P({\bf x|z}))` bits long. This uses the fact
(roughly speaking) that the theoretical limit of the average number of bits
needed to represent a symbol (in the context of its probability distribution)
is its information content.
However using Bits-Back with an approximate posterior :math:`q({\bf z|x})`
for a given *fixed* data point :math:`\bf x`, we can calculate the expected
message length over all possible :math:`\bf z` drawn from :math:`q({\bf z|x})`.
The idea is that we're (pseudo-)randomly drawing :math:`\bf z` values, which
affect each part of the process (Bits-Back, encoding :math:`x`, and encoding
:math:`z`) so we must average (i.e. take the expectation) over it:

.. math::

    L(q) &= E_{q({\bf z|x})}[-\log P({\bf z}) - \log P({\bf x|z}) + \log q({\bf z|x})] \\
         &= \sum_z q({\bf z|x})[-\log P({\bf z}) - \log P({\bf x|z}) + \log q({\bf z|x})] \\
         &= -\sum_z q({\bf z|x})\log \frac{P({\bf x, z})}{q({\bf z|x})} \\
         &= -E_{q({\bf z|x})}\big[\log \frac{P({\bf x, z})}{q({\bf z|x})}\big] \tag{2}

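As a numeric sanity check of Equation 2, we can compute :math:`L(q)` for a
tiny, made-up discrete model and confirm it is bounded below by
:math:`-\log P({\bf x})`, i.e. the expected message length is at least the
information content of :math:`\bf x` (all distributions here are illustrative):

```python
# Numeric check: the expected bits-back message length
# E_q[-log P(z) - log P(x|z) + log q(z|x)] is lower-bounded by -log P(x).
import math

P_z = {0: 0.4, 1: 0.6}                        # prior P(z)
P_x_given_z = {0: {0: 0.7, 1: 0.3},           # likelihood P(x|z)
               1: {0: 0.2, 1: 0.8}}
q = {0: 0.5, 1: 0.5}                          # approximate posterior q(z|x)

x = 1
L = sum(q[z] * (-math.log2(P_z[z]) - math.log2(P_x_given_z[z][x])
                + math.log2(q[z]))
        for z in (0, 1))

P_x = sum(P_z[z] * P_x_given_z[z][x] for z in (0, 1))   # evidence P(x)
assert L >= -math.log2(P_x)                   # L(q) >= -log P(x)
```

The gap between :math:`L(q)` and :math:`-\log P({\bf x})` is exactly the KL
divergence between :math:`q({\bf z|x})` and the true posterior, so the bound is
tight when :math:`q` is exact.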
Equation 2 is the negative of the evidence lower bound (ELBO) (see my previous
`post on VAE *