[GAN] DCGAN Practice and Code Review
Generating cute cats.
Motivation
While trying to implement StyleGAN after looking at StarGAN, I felt that my actual implementation skills were lacking, so I decided to first implement DCGAN, which is the foundation. The code mostly follows a reference implementation, but I built the generator and discriminator models myself and trained them on my own image data, which also helped me understand the image preprocessing that goes with it.
Practice Code
import os
import numpy as np
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from PIL import Image
import keras.backend as K
import matplotlib.pyplot as plt
import tensorflow as tf
K.set_image_data_format('channels_last')
import pathlib
data_dir = "/content/drive/My Drive/GAN project/repo/cats"
data_dir = pathlib.Path(data_dir)
img_width = 32
img_height = 32
batch_size =100
image_count = len(list(data_dir.glob('*/*.jpg')))
list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*.jpg'), shuffle=False)
list_ds = list_ds.shuffle(image_count, reshuffle_each_iteration=False)
def decode_img(img):
    # convert the compressed string to a 3D uint8 tensor
    img = tf.image.decode_jpeg(img, channels=3)
    # resize the image to the desired size
    return tf.image.resize(img, [img_width, img_height])

def process_path(file_path):
    img = tf.io.read_file(file_path)
    img = decode_img(img)
    return img
list_data = list_ds.map(process_path)
This loads the cat images stored in data_dir on Google Colab, handling the image data with the tf.data.Dataset module. pathlib.Path(data_dir) turns the path string into a Path object so that glob can be used on it. The file list is shuffled, decode_img turns each file into a usable image tensor, and process_path is applied to every element with map.
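One thing I noticed while writing this up: decode_jpeg and resize leave pixel values in [0, 255], while the generator defined further down ends with a sigmoid and outputs values in [0, 1]. A minimal sketch (my own addition, not part of the original pipeline) that scales the real images to the same range would look like this:

def process_path_normalized(file_path):
    # same as process_path, but scales pixels to [0, 1] to match the generator's sigmoid output
    img = tf.io.read_file(file_path)
    img = decode_img(img)
    return img / 255.0

# would replace: list_data = list_ds.map(process_path)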
batch_data = list_data.batch(100, drop_remainder=True).repeat()  # fixed batch size of 100, and repeat so the iterator is never exhausted
it = iter(batch_data)
The dataset is grouped into batches of 100 (with drop_remainder and repeat so every batch has exactly 100 images and the iterator never runs out), and wrapping it with iter() lets one batch be pulled at a time during training.
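As a quick sanity check (my addition, not in the original code), pulling one batch from a fresh iterator should give a tensor of shape (100, 32, 32, 3). A separate iterator is used here so that `it` is not consumed before training:

sample_batch = next(iter(batch_data))   # fresh iterator, leaves `it` untouched
print(sample_batch.shape)               # expected: (100, 32, 32, 3)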
class Gan:
    def __init__(self, img_data):
        img_size = 32
        channel = 3
        self.img_data = img_data
        self.input_shape = (img_size, img_size, channel)
        self.img_rows = img_size
        self.img_cols = img_size
        self.channel = channel
        self.noise_size = 100
        # create D and G
        self.create_d()
        self.create_g()
        # Build model to train D.
        optimizer = Adam(lr=0.0008)  # lr is the learning rate
        self.D.compile(loss='binary_crossentropy', optimizer=optimizer)
        # Build model to train G.
        optimizer = Adam(lr=0.0004)
        self.D.trainable = False
        self.AM = Sequential()
        self.AM.add(self.G)
        self.AM.add(self.D)
        self.AM.compile(loss='binary_crossentropy', optimizer=optimizer)
    def create_d(self):
        self.D = Sequential()
        depth = 64
        dropout = 0.4
        self.D.add(Conv2D(depth*1, 5, strides=2, input_shape=self.input_shape, padding='same'))
        self.D.add(LeakyReLU(alpha=0.2))
        self.D.add(Dropout(dropout))
        self.D.add(Conv2D(depth*2, 5, strides=2, padding='same'))
        self.D.add(LeakyReLU(alpha=0.2))
        self.D.add(Dropout(dropout))
        self.D.add(Conv2D(depth*4, 5, strides=2, padding='same'))
        self.D.add(LeakyReLU(alpha=0.2))
        self.D.add(Dropout(dropout))
        self.D.add(Conv2D(depth*8, 5, strides=1, padding='same'))
        self.D.add(LeakyReLU(alpha=0.2))
        self.D.add(Dropout(dropout))
        self.D.add(Flatten())
        self.D.add(Dense(1))
        self.D.add(Activation('sigmoid'))
        self.D.summary()
        return self.D
    def create_g(self):
        self.G = Sequential()
        depth = 256
        dropout = 0.4
        dim = 8
        self.G.add(Dense(dim*dim*depth, input_dim=self.noise_size))
        self.G.add(BatchNormalization(momentum=0.9))
        self.G.add(Activation('relu'))
        self.G.add(Reshape((dim, dim, depth)))  # 8 x 8 x 256
        self.G.add(Dropout(dropout))
        self.G.add(UpSampling2D())
        self.G.add(Conv2DTranspose(int(depth/2), 5, padding='same'))
        self.G.add(BatchNormalization(momentum=0.9))  # 16 x 16 x 128
        self.G.add(Activation('relu'))
        self.G.add(UpSampling2D())
        self.G.add(Conv2DTranspose(int(depth/4), 5, padding='same'))
        self.G.add(BatchNormalization(momentum=0.9))  # 32 x 32 x 64
        self.G.add(Activation('relu'))
        self.G.add(Conv2DTranspose(int(depth/8), 5, padding='same'))
        self.G.add(BatchNormalization(momentum=0.9))  # 32 x 32 x 32
        self.G.add(Activation('relu'))
        self.G.add(Conv2DTranspose(3, 5, padding='same'))  # 32 x 32 x 3
        self.G.add(Activation('sigmoid'))
        self.G.summary()
        return self.G  # think of the whole model as one big function
    def train(self, batch_size=100):
        # pull one batch of real images from the global iterator `it`
        train_one_image = next(it)
        # generate fake images from random noise
        noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100])
        images_fake = self.G.predict(noise)
        x = np.concatenate((train_one_image, images_fake), axis=0)
        # train D: real images are labelled 1, fake images 0
        y = np.ones([2*batch_size, 1])
        y[batch_size:, :] = 0
        self.D.trainable = True
        d_loss = self.D.train_on_batch(x, y)
        # train G: the stacked model AM is trained to make D output 1 for fakes
        y = np.ones([batch_size, 1])
        noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100])
        self.D.trainable = False
        a_loss = self.AM.train_on_batch(noise, y)
        return d_loss, a_loss, images_fake
    def save(self):
        self.G.save_weights('gan_g_weights.h5')
        self.D.save_weights('gan_d_weights.h5')

    def load(self):
        if os.path.isfile('gan_g_weights.h5'):
            self.G.load_weights('gan_g_weights.h5')
            print("Load G from file.")
        if os.path.isfile('gan_d_weights.h5'):
            self.D.load_weights('gan_d_weights.h5')
            print("Load D from file.")
This is the main body of the GAN. The DCGAN model is built with Keras, and the model learns by calling the train function over and over. Note that D.trainable is set to False before the stacked model AM is compiled, so training AM updates only the generator, while the discriminator is updated separately through self.D.
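To convince myself that the freezing really works, here is a small throwaway check (my own addition; gan_check is a hypothetical extra instance, not part of the training run): training AM on a batch of noise should leave D's weights unchanged.

gan_check = Gan(list_data)                       # hypothetical throwaway instance
d_before = gan_check.D.get_weights()
noise = np.random.uniform(-1.0, 1.0, size=[10, 100])
gan_check.AM.train_on_batch(noise, np.ones([10, 1]))  # should update only G
d_after = gan_check.D.get_weights()
print(all(np.array_equal(a, b) for a, b in zip(d_before, d_after)))  # expected: True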
gan = Gan(list_data)
print("save")
gan.load()
epochs = 100
sample_size = 10
batch_size = 100
train_per_epoch = len(list_data) // batch_size
for epoch in range(0, epochs):
    print("Epochs :", epoch)
    total_d_loss = 0.0
    total_a_loss = 0.0
    imgs = None
    for batch in range(0, train_per_epoch):
        d_loss, a_loss, t_imgs = gan.train(batch_size)
        print("current loss:", total_d_loss, total_a_loss)
        total_d_loss += d_loss
        total_a_loss += a_loss
        if imgs is None:
            imgs = t_imgs
    if epoch % 5 == 0:
        total_d_loss /= train_per_epoch
        total_a_loss /= train_per_epoch
        print("Epoch: {}, D Loss: {}, AM Loss: {}".format(epoch, total_d_loss, total_a_loss))
        # show generated images
        fig, ax = plt.subplots(1, sample_size, figsize=(sample_size, 1))
        for i in range(0, sample_size):
            ax[i].set_axis_off()
            ax[i].imshow(imgs[i].reshape((gan.img_rows, gan.img_cols, gan.channel)),
                         interpolation='nearest')
        plt.show()
        plt.close(fig)
gan.save()
This is the code that actually runs the training.
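Since gan.save() is only called once at the very end, a long Colab run loses everything if the session disconnects. A small helper I would add next time (hypothetical, not part of the run below) checkpoints the weights every few epochs:

def checkpoint(gan_model, epoch, every=5):
    # hypothetical helper: save weights every `every` epochs so a disconnect does not lose the run
    if epoch % every == 0:
        gan_model.save()  # writes gan_g_weights.h5 / gan_d_weights.h5

# would be called as checkpoint(gan, epoch) at the end of each epoch in the loop above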
Results
Model: "sequential_33"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_48 (Conv2D) (None, 16, 16, 64) 4864
_________________________________________________________________
leaky_re_lu_48 (LeakyReLU) (None, 16, 16, 64) 0
_________________________________________________________________
dropout_57 (Dropout) (None, 16, 16, 64) 0
_________________________________________________________________
conv2d_49 (Conv2D) (None, 8, 8, 128) 204928
_________________________________________________________________
leaky_re_lu_49 (LeakyReLU) (None, 8, 8, 128) 0
_________________________________________________________________
dropout_58 (Dropout) (None, 8, 8, 128) 0
_________________________________________________________________
conv2d_50 (Conv2D) (None, 4, 4, 256) 819456
_________________________________________________________________
leaky_re_lu_50 (LeakyReLU) (None, 4, 4, 256) 0
_________________________________________________________________
dropout_59 (Dropout) (None, 4, 4, 256) 0
_________________________________________________________________
conv2d_51 (Conv2D) (None, 4, 4, 512) 3277312
_________________________________________________________________
leaky_re_lu_51 (LeakyReLU) (None, 4, 4, 512) 0
_________________________________________________________________
dropout_60 (Dropout) (None, 4, 4, 512) 0
_________________________________________________________________
flatten_12 (Flatten) (None, 8192) 0
_________________________________________________________________
dense_24 (Dense) (None, 1) 8193
_________________________________________________________________
activation_65 (Activation) (None, 1) 0
=================================================================
Total params: 4,314,753
Trainable params: 4,314,753
Non-trainable params: 0
_________________________________________________________________
Model: "sequential_34"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_25 (Dense) (None, 16384) 1654784
_________________________________________________________________
batch_normalization_44 (Batc (None, 16384) 65536
_________________________________________________________________
activation_66 (Activation) (None, 16384) 0
_________________________________________________________________
reshape_12 (Reshape) (None, 8, 8, 256) 0
_________________________________________________________________
dropout_61 (Dropout) (None, 8, 8, 256) 0
_________________________________________________________________
up_sampling2d_25 (UpSampling (None, 16, 16, 256) 0
_________________________________________________________________
conv2d_transpose_41 (Conv2DT (None, 16, 16, 128) 819328
_________________________________________________________________
batch_normalization_45 (Batc (None, 16, 16, 128) 512
_________________________________________________________________
activation_67 (Activation) (None, 16, 16, 128) 0
_________________________________________________________________
up_sampling2d_26 (UpSampling (None, 32, 32, 128) 0
_________________________________________________________________
conv2d_transpose_42 (Conv2DT (None, 32, 32, 64) 204864
_________________________________________________________________
batch_normalization_46 (Batc (None, 32, 32, 64) 256
_________________________________________________________________
activation_68 (Activation) (None, 32, 32, 64) 0
_________________________________________________________________
conv2d_transpose_43 (Conv2DT (None, 32, 32, 32) 51232
_________________________________________________________________
batch_normalization_47 (Batc (None, 32, 32, 32) 128
_________________________________________________________________
activation_69 (Activation) (None, 32, 32, 32) 0
_________________________________________________________________
conv2d_transpose_44 (Conv2DT (None, 32, 32, 3) 2403
_________________________________________________________________
activation_70 (Activation) (None, 32, 32, 3) 0
=================================================================
Total params: 2,799,043
Trainable params: 2,765,827
Non-trainable params: 33,216
_________________________________________________________________
save
Epochs : 0
current loss: 0.0 0.0
current loss: 2.4333198070526123 0.08901489526033401
current loss: 3.8582282066345215 0.11727375723421574
current loss: 5.525717496871948 0.2411381397396326
current loss: 6.388717770576477 1.3069357071071863
current loss: 52.22373592853546 1.309405385516584
current loss: 53.9788624048233 1.3094076057691382
current loss: 57.95309817790985 1.3094076181380645
current loss: 64.15115678310394 1.3094076193748463
current loss: 72.402636885643 1.3094076203083223
current loss: 82.22621476650238 1.309407631680446
current loss: 92.13610780239105 1.3094097675175673
current loss: 100.37168347835541 1.3102262324655696
current loss: 105.34046542644501 1.400463745650117
current loss: 106.92031037807465 2.149145708140199
current loss: 122.47436845302582 2.1581681004727526
current loss: 125.78560173511505 2.158222034880086
current loss: 133.05226385593414 2.158222961635974
current loss: 143.85323584079742 2.158223001018963
The results are not good. For some reason the loss just keeps shooting up into the sky......... I can't pin down the exact cause and I'm not sure what to do about it yet. Next time, let's study again with a safer example.