Commit 17d0a7a1 authored by Jacky Lin's avatar Jacky Lin
Browse files

Delete SimpleGANs-checkpoint.ipynb

parent 80e97704
%% Cell type:markdown id: tags:
##### ref https://realpython.com/generative-adversarial-networks/#your-first-gan
%% Cell type:markdown id: tags:
## Simple GANs Example
This notebook give an example of GANs where we train neural networks to generating a sine function $f(x) = sin(x)$ given a series of random number
%% Cell type:markdown id: tags:
#### Import the essentail pacakge
First we import the essentail package, including:
1. pytorch which is the mainly package to build neural structure
2. matplotlib is used to visualize the data
%% Cell type:code id: tags:
``` python
import torch
from torch import nn
import math
import matplotlib.pyplot as plt
```
%% Cell type:markdown id: tags:
Set a random seed to reproduce the result
%% Cell type:code id: tags:
``` python
torch.manual_seed(111)
```
%%%% Output: execute_result
<torch._C.Generator at 0x7f912002f210>
%% Cell type:markdown id: tags:
#### Create dataset
%% Cell type:markdown id: tags:
Create the training set. The training set a series of point on sine function $f(x) = sin(x)$ where $x \in [0, 2\pi]$.
%% Cell type:code id: tags:
``` python
train_data_length = 1024
train_data = torch.zeros((train_data_length, 2))
train_data[:, 0] = 2 * math.pi * torch.rand(train_data_length)
train_data[:, 1] = torch.sin(train_data[:, 0])
train_labels = torch.zeros(train_data_length)
train_set = [(train_data[i], train_labels[i]) for i in range(train_data_length)]
```
%% Cell type:markdown id: tags:
Visualize the plot
%% Cell type:code id: tags:
``` python
plt.plot(train_data[:, 0], train_data[:, 1], ".")
```
%%%% Output: execute_result
[<matplotlib.lines.Line2D at 0x7f9093ee76d8>]
%%%% Output: display_data
![]()
%% Cell type:markdown id: tags:
#### Build the model
%% Cell type:markdown id: tags:
The GANs consists of two neural networks, one called Discriminator and the other called Generator. The role of the generator is to estimate the probability distribution of the real samples in order to provide generated samples resembling real data. The discriminator, in turn, is trained to estimate the probability that a given sample came from the real data rather than being provided by the generator.
%% Cell type:code id: tags:
``` python
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Linear(2, 256),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(256, 128),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(128, 64),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(64, 1),
nn.Sigmoid(),
)
def forward(self, x):
output = self.model(x)
return output
```
%% Cell type:code id: tags:
``` python
class Generator(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Linear(2, 16),
nn.ReLU(),
nn.Linear(16, 32),
nn.ReLU(),
nn.Linear(32, 2),
)
def forward(self, x):
output = self.model(x)
return output
```
%% Cell type:markdown id: tags:
Instantiation
%% Cell type:code id: tags:
``` python
discriminator = Discriminator()
generator = Generator()
```
%% Cell type:markdown id: tags:
#### Set hyperparameters
%% Cell type:code id: tags:
``` python
lr = 0.001
num_epochs = 300
loss_function = nn.BCELoss()
batch_size = 32
train_loader = torch.utils.data.DataLoader(
train_set, batch_size=batch_size, shuffle=True
)
```
%% Cell type:markdown id: tags:
Set the optimization object
%% Cell type:code id: tags:
``` python
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=lr)
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=lr)
```
%% Cell type:markdown id: tags:
Training process:
1. Train the discriminator using real(1) data and fake(0) data. The discriminator will need to distinguish between the fake and real data points
2. Train the generator using random numbers as input and let discriminator to do the classify output from generator. The loss is the difference of the classification result and true labels, all of which is of value 1
%% Cell type:code id: tags:
``` python
def train(generator, discriminator, num_epochs, batch_size, train_loader):
for epoch in range(num_epochs):
for n, (real_samples, _) in enumerate(train_loader):
# Data for training the discriminator
real_samples_labels = torch.ones((batch_size, 1))
latent_space_samples = torch.randn((batch_size, 2))
generated_samples = generator(latent_space_samples)
generated_samples_labels = torch.zeros((batch_size, 1))
all_samples = torch.cat((real_samples, generated_samples))
all_samples_labels = torch.cat(
(real_samples_labels, generated_samples_labels)
)
# Training the discriminator
discriminator.zero_grad()
output_discriminator = discriminator(all_samples)
loss_discriminator = loss_function(
output_discriminator, all_samples_labels)
loss_discriminator.backward()
optimizer_discriminator.step()
# Data for training the generator
latent_space_samples = torch.randn((batch_size, 2))
# Training the generator
generator.zero_grad()
generated_samples = generator(latent_space_samples)
output_discriminator_generated = discriminator(generated_samples)
loss_generator = loss_function(
output_discriminator_generated, real_samples_labels
)
loss_generator.backward()
optimizer_generator.step()
# Show loss
if epoch % 10 == 0 and n == batch_size - 1:
print(f"Epoch: {epoch} Loss D.: {loss_discriminator}")
print(f"Epoch: {epoch} Loss G.: {loss_generator}")
```
%% Cell type:code id: tags:
``` python
train(generator, discriminator, num_epochs, batch_size, train_loader)
```
%% Cell type:markdown id: tags:
#### Prediction
%% Cell type:markdown id: tags:
Generate a series of random point and let the generator performs on these data poins and show the generated results
%% Cell type:code id: tags:
``` python
latent_space_samples = torch.randn(500, 2)
plt.plot(latent_space_samples[:, 0], latent_space_samples[:, 1], ".")
generated_samples = generator(latent_space_samples)
```
%%%% Output: display_data
![]()
%% Cell type:code id: tags:
``` python
generated_samples = generated_samples.detach()
plt.plot(generated_samples[:, 0], generated_samples[:, 1], ".")
```
%%%% Output: execute_result
[<matplotlib.lines.Line2D at 0x7f9069e33b70>]
%%%% Output: display_data
![]()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment