This is my second post in the series on generative models using deep learning. My earlier post covered another form of generative model, the variational autoencoder, and specifically explained how to build variational autoencoders from plain autoencoders. Like that post, this one is not about the implementation of Generative Adversarial Networks (GANs), nor about the mathematics behind them. I will start with something familiar, i.e., GANs (a short introduction is provided), and build GANs for text from there.
Generative Adversarial Networks (GANs) have been called one of the major breakthroughs in machine learning; in fact, they have been called "the most interesting idea in machine learning in the last 10 years". There are many good tutorials out there (links are provided at the bottom of this post) that cover GANs in general, but most of them discuss GANs from a computer vision perspective (for image data) and are not directly applicable to text data.
Before we discuss GANs for text data, let's review general GANs and how they apply to image data. The basic idea behind GANs is pretty simple. A GAN consists of two main components: a generator and a discriminator. The job of the generator is to generate random but real-looking data, while the job of the discriminator is to distinguish real data from this real-looking random data. Training happens in an adversarial, combat-like mode. If the discriminator can tell the real and real-looking data apart, it means the generator has not done its job well, and the generator's weights are updated so that it produces better real-looking data. If the discriminator cannot tell the difference, it means the discriminator has not been trained well, and its weights are updated. Training continues until we have both a good generator and a good discriminator.
In their original paper, Goodfellow et al. propose the following minimax objective, and we will discuss GANs first in the context of this objective:
$$ \min_G \max_D \; \mathbb{E}_{x\sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z\sim p_{z}(z)}[\log (1-D(G(z)))]$$
Here $G$ is the generative model, $D$ is the discriminative model, $x$ is a sample of real data, and $z$ is random noise (not yet real-looking data) fed to the generator.
Although nasty looking, the above equation is pretty simple. Each term is the log-probability the discriminator assigns to data being real: $D(x)$ for real samples and $1-D(G(z))$ for generated samples. If you think of the discriminator as a binary classifier (for example a multi-layer perceptron), the objective is nothing but a binary cross-entropy, i.e., $\log p(y{=}\text{real}\mid x) + \log(1-p(y{=}\text{real}\mid G(z)))$. The discriminator is trained to maximize this quantity, while the generator is trained to minimize it (by making $D(G(z))$ large). Hopefully the picture below clarifies things further.
Generative Adversarial Network for image data.
GANs should now be clear from the above picture. As you can see, there are two kinds of data, real data and real-looking (generated) data, both of which are fed to the discriminator, which tries to assign a high probability of being real to the former and a low one to the latter.
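To make the objective concrete, here is a minimal PyTorch sketch of one alternating update (the library, network sizes, and variable names are my own choices for illustration, not from the original paper). The discriminator step maximizes $\log D(x) + \log(1-D(G(z)))$, which is exactly a binary cross-entropy with label 1 for real and 0 for fake; for the generator step I use the common non-saturating variant that maximizes $\log D(G(z))$ instead of minimizing $\log(1-D(G(z)))$.

```python
import torch
import torch.nn as nn

# Toy "image" GAN: sizes are arbitrary, chosen only for illustration.
noise_dim, data_dim = 16, 64

G = nn.Sequential(nn.Linear(noise_dim, 128), nn.ReLU(), nn.Linear(128, data_dim))
D = nn.Sequential(nn.Linear(data_dim, 128), nn.ReLU(), nn.Linear(128, 1), nn.Sigmoid())

opt_g = torch.optim.Adam(G.parameters(), lr=2e-4)
opt_d = torch.optim.Adam(D.parameters(), lr=2e-4)
bce = nn.BCELoss()  # binary cross-entropy, matching the objective above

def train_step(real_batch):
    batch = real_batch.size(0)
    ones, zeros = torch.ones(batch, 1), torch.zeros(batch, 1)

    # Discriminator step: maximize log D(x) + log(1 - D(G(z)))
    z = torch.randn(batch, noise_dim)
    fake = G(z).detach()                 # do not backprop into G on this step
    d_loss = bce(D(real_batch), ones) + bce(D(fake), zeros)
    opt_d.zero_grad()
    d_loss.backward()
    opt_d.step()

    # Generator step: fool the discriminator (non-saturating variant)
    z = torch.randn(batch, noise_dim)
    g_loss = bce(D(G(z)), ones)
    opt_g.zero_grad()
    g_loss.backward()
    opt_g.step()
    return d_loss.item(), g_loss.item()

# Example: one update on a random stand-in for a "real" batch.
print(train_step(torch.randn(32, data_dim)))
```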
As we all know, GANs were originally designed for image data, and they are naturally suited to images because of their real-valued nature: each pixel is just a real number denoting a color intensity. Applying GANs to text data is trickier, primarily for the following reasons.
- Generation of text in a sequential manner: An image can be treated as a single real-valued vector and generated all at once; text is different, as it has to be generated one word at a time.
- Real vs. discrete: Text data is discrete while image data is real-valued. How do we generate discrete noise that can ultimately be transformed into real-looking text? (The short sketch after this list illustrates why discreteness is a problem.)
- Discriminator for text data: Even if we can generate random text, how do we design a discriminator that can differentiate between real text and fake text?
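To see why the second point matters, the small sketch below (plain PyTorch, my own toy example) samples a word index from a softmax distribution: the sampled index is an integer with no gradient back to the logits, so a generator that emits discrete tokens cannot be trained end-to-end by ordinary backpropagation.

```python
import torch

# Toy vocabulary of 5 words; the logits would normally come from the generator.
logits = torch.randn(5, requires_grad=True)
probs = torch.softmax(logits, dim=0)

word_id = torch.multinomial(probs, num_samples=1)  # a discrete index, e.g. tensor([3])
print(word_id, word_id.requires_grad)              # False: sampling cut the graph

# By contrast, any continuous function of the probabilities keeps its gradient:
expected = (probs * torch.arange(5.0)).sum()
expected.backward()
print(logits.grad)                                 # well-defined gradients
```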
Despite these challenges, GANs have been applied to text data, and in this post I will try to make the approach simpler to follow. The original paper calls these networks GAN-RNN.
Assuming we now understand GANs for image data, let's modify them to make them applicable to text data. The idea in GANs is to learn a generator that can produce real-looking data. In GAN-RNN, which is a GAN for text data, the first modification is to replace the generator with a recurrent neural network (RNN). We can also call it an encoder, as it is usually called in sequence-to-sequence models. The job of this RNN is to generate random text. Note that in a plain GAN we first generate random noise and then transform this noise into real-looking data through the generator; here both tasks are performed by the RNN. To begin with, the input to the RNN's first cell is a start-of-sentence token, and from there the subsequent cells generate the rest of the words automatically. In other words, compared with a plain GAN, the RNN replaces both the random-noise source and the generator: it performs the job of both. In the beginning, when the RNN's weights have not yet been learned, its output will be as good as random noise, i.e., meaningless text; as training continues, it learns to generate meaningful text. With this simple replacement we have solved two of the three problems mentioned above: how to generate text sequentially, and the discrete vs. real-valued issue.
Generative Adversarial Network for text data: the generator has been replaced with an RNN encoder. We still do not know what the discriminator looks like (see the following picture).
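A minimal sketch of such an RNN generator is below (PyTorch; the GRU cell, vocabulary size, and greedy decoding are my own simplifications, not prescribed by the paper). Starting from a start-of-sentence token, each cell produces a hidden state and the next word; we keep the hidden states around because, as described next, the discriminator will operate on them rather than on the discrete words.

```python
import torch
import torch.nn as nn

class RNNGenerator(nn.Module):
    """Generates a sentence one word at a time, starting from <sos>."""
    def __init__(self, vocab_size=1000, emb_dim=64, hidden_dim=128):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_dim)
        self.cell = nn.GRUCell(emb_dim, hidden_dim)
        self.to_vocab = nn.Linear(hidden_dim, vocab_size)

    def forward(self, sos_id=0, max_len=10):
        token = torch.tensor([sos_id])               # start-of-sentence token
        h = torch.zeros(1, self.cell.hidden_size)
        hidden_states, words = [], []
        for _ in range(max_len):
            h = self.cell(self.embed(token), h)      # one RNN step
            hidden_states.append(h)
            token = self.to_vocab(h).argmax(dim=-1)  # pick the next word (greedy here)
            words.append(token.item())
        # Continuous hidden states (max_len, hidden_dim) plus the discrete word ids.
        return torch.cat(hidden_states, dim=0), words

gen = RNNGenerator()
states, sentence = gen()
print(states.shape, sentence)   # untrained, so the words are meaningless for now
```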
Now that we know how to generate random text, we just need to make sure this randomly generated data is real looking. We do this by comparing the generated data with real data; the comparison takes place in the discriminator, which takes both as input and produces a score denoting how far the generated data is from the real data. The question is how to design a discriminator suitable for text data. For this we make another modification to GANs, but before that, note that the output of the generator is text, which is not differentiable, and this causes problems for learning. So we have two problems: (1) design a discriminator, and (2) make the whole pipeline differentiable. Perhaps surprisingly, both can be solved by one simple modification: introducing a summarization function. The summarization function operates on the hidden states of the RNN (so what reaches the discriminator is not discrete but continuous). One then has the flexibility to design this summarization function so that its output can be fed into a discriminator; it can be as simple as the last hidden state of the RNN. Once we have such a summarization function, the discriminator can be designed the same way as for image data: it maximizes the probability assigned to real data. The whole process is explained in the image below.
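To make this concrete, here is a minimal sketch of a summarization function and discriminator (again PyTorch and my own toy choices: the summary is simply the last hidden state, the discriminator is a small feed-forward classifier, and I assume real sentences are also encoded by an RNN and summarized the same way). Because the summary is built from continuous hidden states rather than the sampled words, gradients can flow from the discriminator's score back into the generator.

```python
import torch
import torch.nn as nn

hidden_dim = 128

def summarize(hidden_states):
    """Summarization function: here simply the last hidden state of the RNN.
    Any differentiable function of the hidden states (mean, attention, ...) would do."""
    return hidden_states[-1]

# Discriminator: maps a summary vector to the probability that it came from real text.
discriminator = nn.Sequential(
    nn.Linear(hidden_dim, 64),
    nn.ReLU(),
    nn.Linear(64, 1),
    nn.Sigmoid(),
)

# Fake path: summary of the generator's hidden states (stand-in for RNNGenerator output).
fake_states = torch.randn(10, hidden_dim, requires_grad=True)
p_real_fake = discriminator(summarize(fake_states))

# Real path: a real sentence encoded by an RNN and summarized identically (assumed design).
real_states = torch.randn(12, hidden_dim)
p_real_real = discriminator(summarize(real_states))

# The usual GAN loss: push p_real_real toward 1 and p_real_fake toward 0.
bce = nn.BCELoss()
d_loss = bce(p_real_real, torch.ones(1)) + bce(p_real_fake, torch.zeros(1))
d_loss.backward()
print(fake_states.grad.shape)   # gradients reach the generator's hidden states
```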