Commit 9cf69179 authored by Jacky Lin's avatar Jacky Lin
Browse files

add model saver

parent bf034f91
......@@ -677,6 +677,37 @@
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Save/Load Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Save Model\n",
"torch.save(netG, \"saved_models/generator.pth\")\n",
"torch.save(netD, \"saved_models/discriminator.pth\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load Model\n",
"netG = torch.load(\"saved_models/generator.pth\")\n",
"netD = torch.load(\"saved_models/discriminator.pth\")\n",
"netG.eval()\n",
"netD.eval()"
]
}
],
"metadata": {
This diff is collapsed.
......@@ -17,7 +17,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
......@@ -39,7 +39,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -55,7 +55,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
......@@ -78,13 +78,13 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e159505ae8b147fdab136d2c8e71539f",
"model_id": "66ab8e3d491f40b6a1eb73ad62c1ca3e",
"version_major": 2,
"version_minor": 0
},
......
......@@ -47,3 +47,4 @@ pip install pytorch numpy pandas sklearn torchvision matplotlib plotly PIL tqdm
4. [DCGAN](https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html)
5. [Context-Encoder GAN for Image Inpainting](https://www.kaggle.com/balraj98/context-encoder-gan-for-image-inpainting-pytorch/output)
6. [Tiny ImageNet Kaggle](https://www.kaggle.com/akash2sharma/tiny-imagenet)
7. [GANs-Application](https://github.com/nashory/gans-awesome-applications)
%% Cell type:markdown id: tags:
##### ref https://realpython.com/generative-adversarial-networks/#your-first-gan
%% Cell type:markdown id: tags:
## Simple GANs Example
%% Cell type:markdown id: tags:
#### Introduction
This project focuses on the exploration of generative adversarial network (GAN). A [generative adversarial network (GAN)](https://en.wikipedia.org/wiki/Generative_adversarial_network) is a class of machine learning frameworks designed by Ian Goodfellow and his colleagues in 2014. Two neural networks contest with each other in a game in the form of a zero-sum game, where one agent's gain is another agent's loss. In this project, we explore the basic concepts and structures behind the generative adversarial networks. We also perform this neural network on different datasets to learn some interesting application of this model. All the code we provided is written in python3 in jupyter notebook.
%% Cell type:markdown id: tags:
This notebook give an example of GANs where we train neural networks to generating a sine function $f(x) = sin(x)$ given a series of random number
%% Cell type:markdown id: tags:
### Import the essentail pacakge
First we import the essentail package, including:
### Import the essential pacakge
First we import the essential package, including:
1. pytorch which is the mainly package to build neural structure
2. matplotlib is used to visualize the data
%% Cell type:code id: tags:
``` python
import torch
from torch import nn
import math
import matplotlib.pyplot as plt
```
%% Cell type:markdown id: tags:
Set a random seed to reproduce the result
%% Cell type:code id: tags:
``` python
torch.manual_seed(111)
```
%%%% Output: execute_result
<torch._C.Generator at 0x7f912002f210>
%% Cell type:markdown id: tags:
#### Create dataset
%% Cell type:markdown id: tags:
Create the training set. The training set a series of point on sine function $f(x) = sin(x)$ where $x \in [0, 2\pi]$.
%% Cell type:code id: tags:
``` python
train_data_length = 1024
train_data = torch.zeros((train_data_length, 2))
train_data[:, 0] = 2 * math.pi * torch.rand(train_data_length)
train_data[:, 1] = torch.sin(train_data[:, 0])
train_labels = torch.zeros(train_data_length)
train_set = [(train_data[i], train_labels[i]) for i in range(train_data_length)]
```
%% Cell type:markdown id: tags:
Visualize the plot
%% Cell type:code id: tags:
``` python
plt.plot(train_data[:, 0], train_data[:, 1], ".")
```
%%%% Output: execute_result
[<matplotlib.lines.Line2D at 0x7f9093ee76d8>]
%%%% Output: display_data
![]()
%% Cell type:markdown id: tags:
#### Build the model
%% Cell type:markdown id: tags:
The GANs consists of two neural networks, one called Discriminator and the other called Generator. The role of the generator is to estimate the probability distribution of the real samples in order to provide generated samples resembling real data. The discriminator, in turn, is trained to estimate the probability that a given sample came from the real data rather than being provided by the generator.
![Struct](img/GANsStruct.jpg)
![s](./img/GANsStruct.jpg)
%% Cell type:code id: tags:
``` python
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Linear(2, 256),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(256, 128),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(128, 64),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(64, 1),
nn.Sigmoid(),
)
def forward(self, x):
output = self.model(x)
return output
```
%% Cell type:code id: tags:
``` python
class Generator(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Linear(2, 16),
nn.ReLU(),
nn.Linear(16, 32),
nn.ReLU(),
nn.Linear(32, 2),
)
def forward(self, x):
output = self.model(x)
return output
```
%% Cell type:markdown id: tags:
Instantiation
%% Cell type:code id: tags:
``` python
discriminator = Discriminator()
generator = Generator()
```
%% Cell type:markdown id: tags:
#### Set hyperparameters
%% Cell type:code id: tags:
``` python
lr = 0.001
num_epochs = 300
loss_function = nn.BCELoss()
batch_size = 32
train_loader = torch.utils.data.DataLoader(
train_set, batch_size=batch_size, shuffle=True
)
```
%% Cell type:markdown id: tags:
#### Set the optimization object
What do optimizer do in neural network?
![op](img/optim.png)
`Adam` combines the best properties of the AdaGrad and RMSProp algorithms to provide an optimization algorithm that can handle sparse gradients on noisy problems.
%% Cell type:code id: tags:
``` python
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=lr)
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=lr)
```
%% Cell type:markdown id: tags:
Training process:
1. Train the discriminator using real(1) data and fake(0) data. The discriminator will need to distinguish between the fake and real data points
2. Train the generator using random numbers as input and let discriminator to do the classify output from generator. The loss is the difference of the classification result and true labels, all of which is of value 1
%% Cell type:code id: tags:
``` python
def train(generator, discriminator, num_epochs, batch_size, train_loader):
for epoch in range(num_epochs):
for n, (real_samples, _) in enumerate(train_loader):
# Data for training the discriminator
real_samples_labels = torch.ones((batch_size, 1))
latent_space_samples = torch.randn((batch_size, 2))
generated_samples = generator(latent_space_samples)
generated_samples_labels = torch.zeros((batch_size, 1))
all_samples = torch.cat((real_samples, generated_samples))
all_samples_labels = torch.cat(
(real_samples_labels, generated_samples_labels)
)
# Training the discriminator
discriminator.zero_grad()
output_discriminator = discriminator(all_samples)
loss_discriminator = loss_function(
output_discriminator, all_samples_labels)
loss_discriminator.backward()
optimizer_discriminator.step()
# Data for training the generator
latent_space_samples = torch.randn((batch_size, 2))
# Training the generator
generator.zero_grad()
generated_samples = generator(latent_space_samples)
output_discriminator_generated = discriminator(generated_samples)
loss_generator = loss_function(
output_discriminator_generated, real_samples_labels
)
loss_generator.backward()
optimizer_generator.step()
# Show loss
if epoch % 10 == 0 and n == batch_size - 1:
print(f"Epoch: {epoch} Loss D.: {loss_discriminator}")
print(f"Epoch: {epoch} Loss G.: {loss_generator}")
```
%% Cell type:code id: tags:
``` python
train(generator, discriminator, num_epochs, batch_size, train_loader)
```
%%%% Output: stream
Epoch: 0 Loss D.: 0.19172890484333038
Epoch: 0 Loss G.: 2.205277919769287
Epoch: 10 Loss D.: 0.6455698013305664
Epoch: 10 Loss G.: 0.8941397666931152
Epoch: 20 Loss D.: 0.643601655960083
Epoch: 20 Loss G.: 0.8925288915634155
Epoch: 30 Loss D.: 0.6730794906616211
Epoch: 30 Loss G.: 0.9668323993682861
Epoch: 40 Loss D.: 0.6405140161514282
Epoch: 40 Loss G.: 0.7298990488052368
Epoch: 50 Loss D.: 0.6809470057487488
Epoch: 50 Loss G.: 0.7750263214111328
Epoch: 60 Loss D.: 0.6326340436935425
Epoch: 60 Loss G.: 0.8059306144714355
Epoch: 70 Loss D.: 0.5179269313812256
Epoch: 70 Loss G.: 1.0993049144744873
Epoch: 80 Loss D.: 0.5927025675773621
Epoch: 80 Loss G.: 0.8768778443336487
Epoch: 90 Loss D.: 0.6255234479904175
Epoch: 90 Loss G.: 0.7707381248474121
Epoch: 100 Loss D.: 0.5345417857170105
Epoch: 100 Loss G.: 0.8872535228729248
Epoch: 110 Loss D.: 0.6842091083526611
Epoch: 110 Loss G.: 0.6737083196640015
Epoch: 120 Loss D.: 0.618989109992981
Epoch: 120 Loss G.: 1.0244462490081787
Epoch: 130 Loss D.: 0.651526689529419
Epoch: 130 Loss G.: 0.7835967540740967
Epoch: 140 Loss D.: 0.6564826965332031
Epoch: 140 Loss G.: 0.736600935459137
Epoch: 150 Loss D.: 0.7063291072845459
Epoch: 150 Loss G.: 0.8015179634094238
Epoch: 160 Loss D.: 0.7304964661598206
Epoch: 160 Loss G.: 0.6758403182029724
Epoch: 170 Loss D.: 0.6666960716247559
Epoch: 170 Loss G.: 0.7535879611968994
Epoch: 180 Loss D.: 0.6781765818595886
Epoch: 180 Loss G.: 0.7857789993286133
Epoch: 190 Loss D.: 0.6380088925361633
Epoch: 190 Loss G.: 0.7979979515075684
Epoch: 200 Loss D.: 0.6476839780807495
Epoch: 200 Loss G.: 0.7968341112136841
Epoch: 210 Loss D.: 0.6286032795906067
Epoch: 210 Loss G.: 1.0192309617996216
Epoch: 220 Loss D.: 0.6657217741012573
Epoch: 220 Loss G.: 0.7269636392593384
Epoch: 230 Loss D.: 0.7014310359954834
Epoch: 230 Loss G.: 0.7147063612937927
Epoch: 240 Loss D.: 0.7104332447052002
Epoch: 240 Loss G.: 0.7164992690086365
Epoch: 250 Loss D.: 0.7514128088951111
Epoch: 250 Loss G.: 1.0906519889831543
Epoch: 260 Loss D.: 0.5960716009140015
Epoch: 260 Loss G.: 1.1555901765823364
Epoch: 270 Loss D.: 0.6859428882598877
Epoch: 270 Loss G.: 0.682672917842865
Epoch: 280 Loss D.: 0.6769641041755676
Epoch: 280 Loss G.: 0.7675719857215881
Epoch: 290 Loss D.: 0.6696629524230957
Epoch: 290 Loss G.: 0.764472246170044
%% Cell type:markdown id: tags:
#### Prediction
%% Cell type:markdown id: tags:
Generate a series of random point and let the generator performs on these data poins and show the generated results
%% Cell type:code id: tags:
``` python
latent_space_samples = torch.randn(500, 2)
plt.plot(latent_space_samples[:, 0], latent_space_samples[:, 1], ".")
generated_samples = generator(latent_space_samples)
```
%%%% Output: display_data
![]()
%% Cell type:code id: tags:
``` python
generated_samples = generated_samples.detach()
plt.plot(generated_samples[:, 0], generated_samples[:, 1], ".")
```
%%%% Output: execute_result
[<matplotlib.lines.Line2D at 0x7f9069e33b70>]
%%%% Output: display_data
![]()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment