Commit 8e027ff9 authored by Jacky Lin
parent 6d5cdf72
@@ -75,7 +75,7 @@
 "outputs": [],
 "source": [
 "# Root directory for dataset\n",
-"dataroot = 'anime-faces/'\n",
+"dataroot = 'data/anime-faces/'\n",
 "# Number of workers for dataloader\n",
 "workers = 2\n",
 "# Batch size during training\n",
%% Cell type:markdown id: tags:
# MNIST with Simple GANs
In this notebook, we train a simple GAN on the MNIST dataset, so that the network learns to generate handwritten digits.
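The generator $G$ and discriminator $D$ are trained against each other with the standard GAN minimax objective:
$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]$$
In practice (and in the training loop below), the generator is instead trained to maximize $\log D(G(z))$, i.e. its loss is computed against the "real" labels.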
%% Cell type:markdown id: tags:
#### Import packages
%% Cell type:code id: tags:
``` python
%matplotlib widget
import torch
from torch import nn
import numpy as np
import math
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms
torch.manual_seed(111)
```
%%%% Output: execute_result
<torch._C.Generator at 0x7f18bd351630>
%% Cell type:markdown id: tags:
#### Set the device
%% Cell type:code id: tags:
``` python
device = ""
if torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
print(device)
```
%%%% Output: stream
cuda
%% Cell type:markdown id: tags:
#### Dataset
We load the data from `torchvision.datasets.MNIST`, transform it to tensors, and normalize it. Then we wrap it in a `DataLoader`.
%% Cell type:code id: tags:
``` python
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
train_set = torchvision.datasets.MNIST(
    root="./data/", train=True, download=True, transform=transform
)
batch_size = 32
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True
)
```
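%% Cell type:markdown id: tags:
Note that `Normalize((0.5,), (0.5,))` maps each pixel from $[0, 1]$ to $[-1, 1]$ via $(x - 0.5)/0.5$, which matches the $\tanh$ output range of the generator defined below.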
%% Cell type:code id: tags:
``` python
# Some examples from MNIST
real_samples, mnist_labels = next(iter(train_loader))
for i in range(16):
    ax = plt.subplot(4, 4, i + 1)
    plt.imshow(real_samples[i].reshape(28, 28), cmap="gray_r")
    plt.xticks([])
    plt.yticks([])
```
%%%% Output: display_data
%% Cell type:markdown id: tags:
#### Build the NN
%% Cell type:code id: tags:
``` python
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Flatten each 28x28 image into a 784-dimensional vector
        x = x.view(x.size(0), 784)
        output = self.model(x)
        return output
```
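%% Cell type:markdown id: tags:
The final `Sigmoid` squashes the discriminator's output to a probability in $(0, 1)$, which is the form that `nn.BCELoss` (used below) expects.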
%% Cell type:code id: tags:
``` python
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Tanh(),
        )

    def forward(self, x):
        output = self.model(x)
        # Reshape the 784-dimensional output back into 1x28x28 images
        output = output.view(x.size(0), 1, 28, 28)
        return output
```
%% Cell type:code id: tags:
``` python
discriminator = Discriminator().to(device=device)
generator = Generator().to(device=device)
```
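%% Cell type:markdown id: tags:
As a quick sanity check (a minimal sketch, not part of the training pipeline), we can push a small batch of latent vectors through both networks and confirm that the shapes line up:
%% Cell type:code id: tags:
``` python
# A batch of 4 latent vectors should become 4 single-channel 28x28 images,
# and the discriminator should return one probability per image
z = torch.randn(4, 100, device=device)
fake = generator(z)
print(fake.shape)                 # expected: torch.Size([4, 1, 28, 28])
print(discriminator(fake).shape)  # expected: torch.Size([4, 1])
```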
%% Cell type:markdown id: tags:
#### Set hyperparameters and optimizers
%% Cell type:code id: tags:
``` python
lr = 0.0001
num_epochs = 50
loss_function = nn.BCELoss()
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=lr)
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=lr)
# A fixed batch of latent vectors, reused after every epoch so the
# animation below tracks the same points through training
animation_samples = torch.randn(batch_size, 100).to(device=device)
animation_results = np.array([])
```
%% Cell type:markdown id: tags:
#### Training the GAN
%% Cell type:code id: tags:
``` python
for epoch in range(num_epochs):
    # Record the generator's output on the fixed latent batch for the animation
    animation_results = np.append(animation_results, generator(animation_samples).cpu().detach())
    for n, (real_samples, mnist_labels) in enumerate(train_loader):
        # Data for training the discriminator
        real_samples = real_samples.to(device=device)
        real_samples_labels = torch.ones((batch_size, 1)).to(device=device)
        latent_space_samples = torch.randn((batch_size, 100)).to(device=device)
        generated_samples = generator(latent_space_samples)
        generated_samples_labels = torch.zeros((batch_size, 1)).to(device=device)
        all_samples = torch.cat((real_samples, generated_samples))
        all_samples_labels = torch.cat(
            (real_samples_labels, generated_samples_labels)
        )
        # Training the discriminator
        discriminator.zero_grad()
        output_discriminator = discriminator(all_samples)
        loss_discriminator = loss_function(
            output_discriminator, all_samples_labels
        )
        loss_discriminator.backward()
        optimizer_discriminator.step()
        # Data for training the generator
        latent_space_samples = torch.randn((batch_size, 100)).to(device=device)
        # Training the generator: it is rewarded when the discriminator
        # labels its samples as real (label 1)
        generator.zero_grad()
        generated_samples = generator(latent_space_samples)
        output_discriminator_generated = discriminator(generated_samples)
        loss_generator = loss_function(
            output_discriminator_generated, real_samples_labels
        )
        loss_generator.backward()
        optimizer_generator.step()
        # Show the losses once per epoch (at batch index batch_size - 1)
        if n == batch_size - 1:
            print(f"Epoch: {epoch} Loss D.: {loss_discriminator}")
            print(f"Epoch: {epoch} Loss G.: {loss_generator}")
```
%%%% Output: stream
Epoch: 0 Loss D.: 0.5106401443481445
Epoch: 0 Loss G.: 0.6140938401222229
Epoch: 1 Loss D.: 0.016402263194322586
Epoch: 1 Loss G.: 5.503405570983887
Epoch: 2 Loss D.: 0.00035893276799470186
Epoch: 2 Loss G.: 8.898468971252441
Epoch: 3 Loss D.: 0.004426785744726658
Epoch: 3 Loss G.: 5.379899978637695
Epoch: 4 Loss D.: 0.015691719949245453
Epoch: 4 Loss G.: 4.864607334136963
Epoch: 5 Loss D.: 0.08068899810314178
Epoch: 5 Loss G.: 3.1070806980133057
Epoch: 6 Loss D.: 0.2025149017572403
Epoch: 6 Loss G.: 2.2945780754089355
Epoch: 7 Loss D.: 0.39154988527297974
Epoch: 7 Loss G.: 2.11862850189209
Epoch: 8 Loss D.: 0.3579307198524475
Epoch: 8 Loss G.: 2.383777379989624
Epoch: 9 Loss D.: 0.2735465168952942
Epoch: 9 Loss G.: 2.620600461959839
Epoch: 10 Loss D.: 0.24209283292293549
Epoch: 10 Loss G.: 1.9028700590133667
Epoch: 11 Loss D.: 0.44531941413879395
Epoch: 11 Loss G.: 1.6573467254638672
Epoch: 12 Loss D.: 0.36328646540641785
Epoch: 12 Loss G.: 1.7773442268371582
Epoch: 13 Loss D.: 0.46487748622894287
Epoch: 13 Loss G.: 1.3897985219955444
Epoch: 14 Loss D.: 0.40781110525131226
Epoch: 14 Loss G.: 1.9483462572097778
Epoch: 15 Loss D.: 0.4692971110343933
Epoch: 15 Loss G.: 1.429753303527832
Epoch: 16 Loss D.: 0.679309606552124
Epoch: 16 Loss G.: 1.1184732913970947
Epoch: 17 Loss D.: 0.46219807863235474
Epoch: 17 Loss G.: 1.1497530937194824
Epoch: 18 Loss D.: 0.5302788019180298
Epoch: 18 Loss G.: 1.298142433166504
Epoch: 19 Loss D.: 0.4652501344680786
Epoch: 19 Loss G.: 1.151010274887085
Epoch: 20 Loss D.: 0.4639967083930969
Epoch: 20 Loss G.: 1.4000627994537354
Epoch: 21 Loss D.: 0.49388980865478516
Epoch: 21 Loss G.: 1.1866282224655151
Epoch: 22 Loss D.: 0.5613424181938171
Epoch: 22 Loss G.: 1.0579705238342285
Epoch: 23 Loss D.: 0.5281556248664856
Epoch: 23 Loss G.: 1.0413790941238403
Epoch: 24 Loss D.: 0.5011868476867676
Epoch: 24 Loss G.: 1.2157130241394043
Epoch: 25 Loss D.: 0.5161645412445068
Epoch: 25 Loss G.: 1.0572420358657837
Epoch: 26 Loss D.: 0.4746711552143097
Epoch: 26 Loss G.: 0.9783821105957031
Epoch: 27 Loss D.: 0.5207068920135498
Epoch: 27 Loss G.: 1.0903351306915283
Epoch: 28 Loss D.: 0.6585843563079834
Epoch: 28 Loss G.: 1.0490336418151855
Epoch: 29 Loss D.: 0.5125542879104614
Epoch: 29 Loss G.: 1.0740063190460205
Epoch: 30 Loss D.: 0.6027552485466003
Epoch: 30 Loss G.: 0.9736583828926086
Epoch: 31 Loss D.: 0.5678924322128296
Epoch: 31 Loss G.: 1.10847806930542
Epoch: 32 Loss D.: 0.5240979194641113
Epoch: 32 Loss G.: 1.0584428310394287
Epoch: 33 Loss D.: 0.5240440368652344
Epoch: 33 Loss G.: 1.0047487020492554
Epoch: 34 Loss D.: 0.534977912902832
Epoch: 34 Loss G.: 0.9598945379257202
Epoch: 35 Loss D.: 0.4817778468132019
Epoch: 35 Loss G.: 1.1781553030014038
Epoch: 36 Loss D.: 0.5441155433654785
Epoch: 36 Loss G.: 1.1325626373291016
Epoch: 37 Loss D.: 0.5525763034820557
Epoch: 37 Loss G.: 0.9974243640899658
Epoch: 38 Loss D.: 0.5839608907699585
Epoch: 38 Loss G.: 1.0923521518707275
Epoch: 39 Loss D.: 0.5721290111541748
Epoch: 39 Loss G.: 1.0640000104904175
Epoch: 40 Loss D.: 0.5252882838249207
Epoch: 40 Loss G.: 1.144314169883728
Epoch: 41 Loss D.: 0.606524646282196
Epoch: 41 Loss G.: 1.018287181854248
Epoch: 42 Loss D.: 0.5141879320144653
Epoch: 42 Loss G.: 0.9381682872772217
Epoch: 43 Loss D.: 0.5589864253997803
Epoch: 43 Loss G.: 0.9547836184501648
Epoch: 44 Loss D.: 0.5677428245544434
Epoch: 44 Loss G.: 1.0904819965362549
Epoch: 45 Loss D.: 0.6925324201583862
Epoch: 45 Loss G.: 1.1372103691101074
Epoch: 46 Loss D.: 0.5809037685394287
Epoch: 46 Loss G.: 1.0041134357452393
Epoch: 47 Loss D.: 0.5931258201599121
Epoch: 47 Loss G.: 1.0651057958602905
Epoch: 48 Loss D.: 0.5238063335418701
Epoch: 48 Loss G.: 1.1413196325302124
Epoch: 49 Loss D.: 0.554250955581665
Epoch: 49 Loss G.: 1.0168068408966064
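%% Cell type:markdown id: tags:
If we want to reuse the trained networks later without retraining, the weights can be saved with `torch.save`; the filenames here are arbitrary choices.
%% Cell type:code id: tags:
``` python
# Persist the trained weights (the .pth filenames are just a convention)
torch.save(generator.state_dict(), "gan_generator.pth")
torch.save(discriminator.state_dict(), "gan_discriminator.pth")
```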
%% Cell type:markdown id: tags:
#### Prediction
%% Cell type:markdown id: tags:
We generate a series of random latent points and let the generator transform them into handwritten digits.
%% Cell type:code id: tags:
``` python
# One frame of batch_size generated images per epoch
a = animation_results.reshape(num_epochs, batch_size, 28, 28)
```
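%% Cell type:markdown id: tags:
Before animating the training progression, here is a minimal sketch of one-off generation: sample fresh latent vectors and plot whatever the trained generator produces.
%% Cell type:code id: tags:
``` python
# Generate 16 new digits from fresh noise and show them in a 4x4 grid
plt.figure(figsize=(4, 4))
latent_space_samples = torch.randn(16, 100, device=device)
with torch.no_grad():
    generated = generator(latent_space_samples).cpu()
for i in range(16):
    ax = plt.subplot(4, 4, i + 1)
    plt.imshow(generated[i].reshape(28, 28), cmap="gray_r")
    plt.xticks([])
    plt.yticks([])
```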
%% Cell type:code id: tags:
``` python
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

fps = 30
# First set up the figure, the axes, and the images we want to animate
fig = plt.figure(figsize=(8, 4))
im = []
for i in range(a.shape[1]):
    ax = plt.subplot(4, 8, i + 1)
    # The generator's tanh output lies in [-1, 1]
    imi = plt.imshow(a[0][i], interpolation='none', aspect='auto', vmin=-1, vmax=1, cmap='gray')
    im.append(imi)
    plt.xticks([])
    plt.yticks([])
fig.suptitle('After 0 epochs')

def animate_func(i):
    if i % fps == 0:
        print('.', end='')
    fig.suptitle('After ' + str(i) + ' epochs')
    # Update every image in the grid with the frame for epoch i
    for j in range(a.shape[1]):
        im[j].set_array(a[i][j])
    return im

# One frame per recorded epoch (a only holds num_epochs frames)
anim = animation.FuncAnimation(fig, animate_func,
                               frames=a.shape[0],
                               interval=10000 / fps)
```
%%%% Output: display_data
%% Cell type:code id: tags:
``` python
import pickle

with open("animation.dat", "wb") as f:
    pickle.dump(a, f)
```
%% Cell type:code id: tags:
``` python
import pickle

with open("animation.dat", "rb") as f:
    a = pickle.load(f)
```
%% Cell type:code id: tags:
``` python
```
%% Cell type:code id: tags:
``` python
%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

k = np.random.uniform(size=10)
fig, ax = plt.subplots()
xdata, ydata = [], []
ln, = plt.plot([], [], 'ro')

def init():
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    return ln,

def update(i):
    # Accumulate the points seen so far and redraw them all
    xdata.append(k[i])
    ydata.append(k[i])
    ln.set_data(xdata, ydata)
    return ln,

# One frame per point; pace the animation with `interval` instead of time.sleep
ani = FuncAnimation(fig, update, frames=len(k),
                    init_func=init, interval=1000, blit=True)
plt.show()
```
%%%% Output: display_data
%% Cell type:code id: tags:
``` python
```