Commit 8e027ff9 authored by Jacky Lin's avatar Jacky Lin
Browse files

a

parent 6d5cdf72
......@@ -49,11 +49,11 @@
 
%% Cell type:code id: tags:
 
``` python
# Root directory for dataset
dataroot = 'anime-faces/'
dataroot = 'data/anime-faces/'
# Number of workers for dataloader
workers = 2
# Batch size during training
batch_size = 128
# Spatial size of training images. All images will be resized to this
This diff is collapsed.
%% Cell type:markdown id: tags:
# MNIST with Simple GANs
In this notebook, we use the GANs network train with MNIST dataset. So that the neural network could generate the hand writting digits.
%% Cell type:markdown id: tags:
#### Import packagees
%% Cell type:code id: tags:
``` python
%matplotlib widget
import torch
from torch import nn
import numpy as np
import math
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms
torch.manual_seed(111)
```
%%%% Output: execute_result
<torch._C.Generator at 0x7f18bd351630>
%% Cell type:markdown id: tags:
#### Set the GPU
%% Cell type:code id: tags:
``` python
device = ""
if torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
print(device)
```
%% Cell type:markdown id: tags:
#### Dataset
We load the data from `torchvision.datasets.MNIST`, and transform it to tensor with normalization. Then we save it to the loader
%% Cell type:code id: tags:
``` python
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
train_set = torchvision.datasets.MNIST(
root="./data/", train=True, transform=transform
)
batch_size = 32
train_loader = torch.utils.data.DataLoader(
train_set, batch_size=batch_size, shuffle=True
)
```
%% Cell type:code id: tags:
``` python
# Some example of MNIST
real_samples, mnist_labels = next(iter(train_loader))
for i in range(16):
ax = plt.subplot(4, 4, i + 1)
plt.imshow(real_samples[i].reshape(28, 28), cmap="gray_r")
plt.xticks([])
plt.yticks([])
```
%%%% Output: display_data
%% Cell type:markdown id: tags:
#### Build the NN
%% Cell type:code id: tags:
``` python
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Linear(784, 1024),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(1024, 512),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(256, 1),
nn.Sigmoid(),
)
def forward(self, x):
x = x.view(x.size(0), 784)
output = self.model(x)
return output
```
%% Cell type:code id: tags:
``` python
class Generator(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Linear(100, 256),
nn.ReLU(),
nn.Linear(256, 512),
nn.ReLU(),
nn.Linear(512, 1024),
nn.ReLU(),
nn.Linear(1024, 784),
nn.Tanh(),
)
def forward(self, x):
output = self.model(x)
output = output.view(x.size(0), 1, 28, 28)
return output
```
%% Cell type:code id: tags:
``` python
discriminator = Discriminator().to(device=device)
generator = Generator().to(device=device)
```
%% Cell type:markdown id: tags:
#### Set hyperparameter and Optimizer
%% Cell type:code id: tags:
``` python
lr = 0.0001
num_epochs = 50
loss_function = nn.BCELoss()
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=lr)
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=lr)
animation_samples = torch.randn(batch_size, 100).to(device=device)
animation_results = np.array([])
```
%% Cell type:markdown id: tags:
#### Training the dataset
%% Cell type:code id: tags:
``` python
for epoch in range(num_epochs):
animation_results = np.append(animation_results, generator(animation_samples).cpu().detach())
for n, (real_samples, mnist_labels) in enumerate(train_loader):
# Data for training the discriminator
real_samples = real_samples.to(device=device)
real_samples_labels = torch.ones((batch_size, 1)).to(
device=device
)
latent_space_samples = torch.randn((batch_size, 100)).to(
device=device
)
generated_samples = generator(latent_space_samples)
generated_samples_labels = torch.zeros((batch_size, 1)).to(
device=device
)
all_samples = torch.cat((real_samples, generated_samples))
all_samples_labels = torch.cat(
(real_samples_labels, generated_samples_labels)
)
# Training the discriminator
discriminator.zero_grad()
output_discriminator = discriminator(all_samples)
loss_discriminator = loss_function(
output_discriminator, all_samples_labels
)
loss_discriminator.backward()
optimizer_discriminator.step()
# Data for training the generator
latent_space_samples = torch.randn((batch_size, 100)).to(
device=device
)
# Training the generator
generator.zero_grad()
generated_samples = generator(latent_space_samples)
output_discriminator_generated = discriminator(generated_samples)
loss_generator = loss_function(
output_discriminator_generated, real_samples_labels
)
loss_generator.backward()
optimizer_generator.step()
# Show loss
if n == batch_size - 1:
print(f"Epoch: {epoch} Loss D.: {loss_discriminator}")
print(f"Epoch: {epoch} Loss G.: {loss_generator}")
```
%% Cell type:markdown id: tags:
#### Prediction
%% Cell type:markdown id: tags:
We generate a series of random points and let generator handle these points to transform them into hand written digits
%% Cell type:code id: tags:
``` python
a = animation_results.reshape(num_epochs,batch_size,28,28)
```
%% Cell type:code id: tags:
``` python
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation as animation
fps = 30
nSeconds = 5
# First set up the figure, the axis, and the plot element we want to animate
fig = plt.figure( figsize=(8,4) )
im = []
for i in range(a.shape[1]):
ax = plt.subplot(4, 8, i + 1)
imi = plt.imshow(a[0][i], interpolation='none', aspect='auto', vmin=0, vmax=1, cmap='gray')
im.append(imi)
plt.xticks([])
plt.yticks([])
fig.suptitle('After 0 epoch')
def animate_func(i):
if i % fps == 0:
print( '.', end ='' )
fig.suptitle('After '+str(i)+' epoch')
for j in range(a.shape[1]):
ax = plt.subplot(4, 8, i + 1)
im[j].set_array(a[i][j])
im.append(imi)
plt.xticks([])
plt.yticks([])
return [im]
anim = animation.FuncAnimation(fig, animate_func,
frames = nSeconds * fps,
interval = 10000 / fps)
```
%%%% Output: display_data
%% Cell type:code id: tags:
``` python
import pickle
pickle.dump(a, open( "animation.dat", "wb" ) )
```
%% Cell type:code id: tags:
``` python
import pickle
a = pickle.load( open( "animation.dat", "rb" ) )
```
%% Cell type:code id: tags:
``` python
```
%% Cell type:code id: tags:
``` python
%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import time
k = np.random.uniform(size = 10)
fig, ax = plt.subplots()
xdata, ydata = [], []
ln, = plt.plot([], [], 'ro')
def init():
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
return ln,
def update(i):
xdata.append(k[i])
ydata.append(k[i])
ln.set_data(k[i],k[i])
time.sleep(1)
return ln,
ani = FuncAnimation(fig, update, frames = len(k) + 1,
init_func=init, blit=True)
plt.show()
```
%%%% Output: display_data
%% Cell type:code id: tags:
``` python
```
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment