[GAN] DCGAN Hands-on Practice and Code Review


Generating cute cats.

Motivation

While reading up on StarGAN and attempting to implement StyleGAN, I felt that my actual implementation skills were not up to the task, so I decided to implement DCGAN, the foundation of these models, first. The code mostly follows a reference, but I built the generator and discriminator models myself and trained them on a different image dataset, which also helped me understand the image preprocessing that goes with it.

Practice Code

import os
import numpy as np
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from PIL import Image
import keras.backend as K
import matplotlib.pyplot as plt
import tensorflow as tf
K.set_image_data_format('channels_last')

import pathlib
data_dir = "/content/drive/My Drive/GAN project/repo/cats"
data_dir = pathlib.Path(data_dir)

img_width = 32
img_height = 32
batch_size = 100

image_count = len(list(data_dir.glob('*/*.jpg')))
list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*.jpg'), shuffle=False)
list_ds = list_ds.shuffle(image_count, reshuffle_each_iteration=False)

def decode_img(img):
  # convert the compressed string to a 3D uint8 tensor
  img = tf.image.decode_jpeg(img, channels=3)
  # resize the image to the desired size
  return tf.image.resize(img, [img_width, img_height])

def process_path(file_path):
  img = tf.io.read_file(file_path)
  img = decode_img(img)
  return img

list_data = list_ds.map(process_path)

The cat images under data_dir are loaded in Google Colab; the image data is handled with the tf.data.Dataset module. Here, data_dir = pathlib.Path(...) wraps the path string in a Path object so that its glob method can be used. The loaded file list is shuffled, and the decode_img function turns each compressed file into a usable image tensor.

The functions defined above are applied to every file path with map.
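
As a quick sanity check (my own addition, not part of the original run), one element can be pulled from the mapped dataset to confirm its shape and dtype:

for img in list_data.take(1):
  print(img.shape, img.dtype)  # expected: (32, 32, 3) float32, values still in [0, 255]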

batch_data = list_data.batch(100)
it = iter(batch_data)

The data is grouped into batches of 100, and iter turns the batched dataset into an iterator so batches can be drawn repeatedly.
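
Two caveats here (my observations, not handled in the original code): an iterator over a finite dataset raises StopIteration once every batch has been drawn, and a final partial batch would not match the label shapes assumed later in train(). A minimal sketch that sidesteps both, assuming endless full batches are wanted:

batch_data = list_data.batch(100, drop_remainder=True).repeat()  # full batches, endlessly
it = iter(batch_data)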

class Gan:
  def __init__(self, img_data):
    img_size = 32
    channel = 3
  
    self.img_data = img_data
    self.input_shape = (img_size, img_size , channel)
    self.img_rows = img_size
    self.img_cols = img_size
    self.channel = channel
    self.noise_size = 100

    # create D and G
    self.create_d()
    self.create_g()

    #Build model to train D.
    optimizer = Adam(lr=0.0008)  # lr is the learning rate
    self.D.compile(loss='binary_crossentropy', optimizer=optimizer)

    #Build model to train G: freeze D and stack G -> D into AM.
    optimizer = Adam(lr=0.0004)
    self.D.trainable = False
    self.AM = Sequential()
    self.AM.add(self.G)
    self.AM.add(self.D)
    self.AM.compile(loss='binary_crossentropy', optimizer=optimizer)
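    # Note on the compile order: D was compiled before being frozen, so calling
    # D.train_on_batch still updates its weights; trainable = False is captured
    # when AM is compiled, so training AM updates only G.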

  def create_d(self):
    self.D = Sequential()
    depth = 64
    dropout = 0.4
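    # DCGAN-style discriminator: strided convolutions downsample in place of
    # pooling, with LeakyReLU activations and dropout after each block.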
    self.D.add(Conv2D(depth*1 , 5, strides = 2, input_shape = self.input_shape , padding = 'same'))
    self.D.add(LeakyReLU(alpha = 0.2))
    self.D.add(Dropout(dropout))
    self.D.add(Conv2D(depth*2,5,strides =2,padding='same'))
    self.D.add(LeakyReLU(alpha=0.2))
    self.D.add(Dropout(dropout))
    self.D.add(Conv2D(depth*4, 5, strides=2, padding='same'))
    self.D.add(LeakyReLU(alpha=0.2))
    self.D.add(Dropout(dropout))
    self.D.add(Conv2D(depth*8 , 5 , strides = 1 , padding = 'same'))
    self.D.add(LeakyReLU(alpha = 0.2))
    self.D.add(Dropout(dropout))
    self.D.add(Flatten())
    self.D.add(Dense(1))
    self.D.add(Activation('sigmoid'))
    self.D.summary()

    return self.D 

  def create_g(self):
    self.G = Sequential()
    depth = 256
    dropout = 0.4
    dim = 8
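    # Plan: project the 100-d noise vector to dim*dim*depth units, reshape to
    # (8, 8, 256), then upsample twice to reach the 32x32x3 output.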

    self.G.add(Dense(dim*dim*depth , input_dim = self.noise_size))
    self.G.add(BatchNormalization(momentum = 0.9))
    self.G.add(Activation('relu'))

    self.G.add(Reshape((dim, dim, depth)))#8 8 256
    self.G.add(Dropout(dropout))

    self.G.add(UpSampling2D())
    self.G.add(Conv2DTranspose(int(depth/2), 5, padding='same'))
    self.G.add(BatchNormalization(momentum=0.9))#16 16 128
    self.G.add(Activation('relu'))

    self.G.add(UpSampling2D())
    self.G.add(Conv2DTranspose(int(depth/4), 5, padding='same'))
    self.G.add(BatchNormalization(momentum=0.9))#32 32 64
    self.G.add(Activation('relu'))


    self.G.add(Conv2DTranspose(int(depth/8), 5, padding='same'))
    self.G.add(BatchNormalization(momentum=0.9))  # 32 32 32 (no upsampling before this block)
    self.G.add(Activation('relu'))

    self.G.add(Conv2DTranspose(3, 5, padding='same'))  # 32 32 3
    self.G.add(Activation('sigmoid'))
    self.G.summary()
    
    return self.G  # think of it as a single function

  def train(self, batch_size=100):
    # Draw one batch of real images from the global iterator `it`.
    # (They are still on the [0, 255] scale that decode_img produced.)
    train_one_image = next(it)

    # Generate a batch of fakes from uniform noise.
    noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100])
    images_fake = self.G.predict(noise)

    x = np.concatenate((train_one_image, images_fake), axis=0)

    #train D: first half of the batch is real (label 1), second half fake (label 0)
    y = np.ones([2*batch_size, 1])
    y[batch_size:, :] = 0
    self.D.trainable = True
    d_loss = self.D.train_on_batch(x, y)

    #train G: push the frozen-D stack to classify the fakes as real (label 1)
    y = np.ones([batch_size, 1])
    noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100])
    self.D.trainable = False
    a_loss = self.AM.train_on_batch(noise, y)

    return d_loss, a_loss, images_fake

  def save(self):
    self.G.save_weights('gan_g_weights.h5')
    self.D.save_weights('gan_d_weights.h5')

  def load(self):
    if os.path.isfile('gan_g_weights.h5'):
      self.G.load_weights('gan_g_weights.h5')
      print("Load G from file.")

    if os.path.isfile('gan_d_weights.h5'):
      self.D.load_weights('gan_d_weights.h5')
      print("Load D from file.")

This is the main body of the GAN. The DCGAN model is built with Keras, and the GAN learns as the train function is called repeatedly.
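
Each call to train performs one alternating update: a discriminator step on a half-real, half-fake batch, then a generator step through AM. As a usage sketch (my addition; gan is the instance created just below):

d_loss, a_loss, fakes = gan.train(batch_size=100)  # one D update, then one G update
print(d_loss, a_loss, fakes.shape)                 # fakes has shape (100, 32, 32, 3)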


gan = Gan(list_data)

print("save")

gan.load()

epochs = 100
sample_size = 10
batch_size = 100
train_per_epoch =  len(list_data) // batch_size

for epoch in range(0, epochs):
  print("Epochs :" ,epoch )
  total_d_loss = 0.0
  total_a_loss = 0.0
  imgs = None

  for batch in range(0, train_per_epoch):
    d_loss , a_loss, t_imgs = gan.train(batch_size)
    print("현재 loss:", total_d_loss , total_a_loss)
    total_d_loss += d_loss
    total_a_loss += a_loss
  
    if imgs is None:
      imgs = t_imgs

  if epoch % 5 == 0:
    total_d_loss /= train_per_epoch
    total_a_loss /= train_per_epoch

    print("Epoch: {}, D Loss: {}, AM Loss: {}" .format(epoch, total_d_loss, total_a_loss))

    #show generated images
    fig, ax = plt.subplots(1, sample_size, figsize=(sample_size, 1))
    for i in range(0, sample_size):
      ax[i].set_axis_off()
      ax[i].imshow(imgs[i].reshape((gan.img_rows, gan.img_cols, gan.channel)),
                   interpolation='nearest')

    plt.show()
    plt.close(fig)

    gan.save()

This is the code that actually runs the training.

Execution Results

Model: "sequential_33"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_48 (Conv2D)           (None, 16, 16, 64)        4864      
_________________________________________________________________
leaky_re_lu_48 (LeakyReLU)   (None, 16, 16, 64)        0         
_________________________________________________________________
dropout_57 (Dropout)         (None, 16, 16, 64)        0         
_________________________________________________________________
conv2d_49 (Conv2D)           (None, 8, 8, 128)         204928    
_________________________________________________________________
leaky_re_lu_49 (LeakyReLU)   (None, 8, 8, 128)         0         
_________________________________________________________________
dropout_58 (Dropout)         (None, 8, 8, 128)         0         
_________________________________________________________________
conv2d_50 (Conv2D)           (None, 4, 4, 256)         819456    
_________________________________________________________________
leaky_re_lu_50 (LeakyReLU)   (None, 4, 4, 256)         0         
_________________________________________________________________
dropout_59 (Dropout)         (None, 4, 4, 256)         0         
_________________________________________________________________
conv2d_51 (Conv2D)           (None, 4, 4, 512)         3277312   
_________________________________________________________________
leaky_re_lu_51 (LeakyReLU)   (None, 4, 4, 512)         0         
_________________________________________________________________
dropout_60 (Dropout)         (None, 4, 4, 512)         0         
_________________________________________________________________
flatten_12 (Flatten)         (None, 8192)              0         
_________________________________________________________________
dense_24 (Dense)             (None, 1)                 8193      
_________________________________________________________________
activation_65 (Activation)   (None, 1)                 0         
=================================================================
Total params: 4,314,753
Trainable params: 4,314,753
Non-trainable params: 0
_________________________________________________________________
Model: "sequential_34"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_25 (Dense)             (None, 16384)             1654784   
_________________________________________________________________
batch_normalization_44 (Batc (None, 16384)             65536     
_________________________________________________________________
activation_66 (Activation)   (None, 16384)             0         
_________________________________________________________________
reshape_12 (Reshape)         (None, 8, 8, 256)         0         
_________________________________________________________________
dropout_61 (Dropout)         (None, 8, 8, 256)         0         
_________________________________________________________________
up_sampling2d_25 (UpSampling (None, 16, 16, 256)       0         
_________________________________________________________________
conv2d_transpose_41 (Conv2DT (None, 16, 16, 128)       819328    
_________________________________________________________________
batch_normalization_45 (Batc (None, 16, 16, 128)       512       
_________________________________________________________________
activation_67 (Activation)   (None, 16, 16, 128)       0         
_________________________________________________________________
up_sampling2d_26 (UpSampling (None, 32, 32, 128)       0         
_________________________________________________________________
conv2d_transpose_42 (Conv2DT (None, 32, 32, 64)        204864    
_________________________________________________________________
batch_normalization_46 (Batc (None, 32, 32, 64)        256       
_________________________________________________________________
activation_68 (Activation)   (None, 32, 32, 64)        0         
_________________________________________________________________
conv2d_transpose_43 (Conv2DT (None, 32, 32, 32)        51232     
_________________________________________________________________
batch_normalization_47 (Batc (None, 32, 32, 32)        128       
_________________________________________________________________
activation_69 (Activation)   (None, 32, 32, 32)        0         
_________________________________________________________________
conv2d_transpose_44 (Conv2DT (None, 32, 32, 3)         2403      
_________________________________________________________________
activation_70 (Activation)   (None, 32, 32, 3)         0         
=================================================================
Total params: 2,799,043
Trainable params: 2,765,827
Non-trainable params: 33,216
_________________________________________________________________
save
Epochs : 0
current loss: 0.0 0.0
current loss: 2.4333198070526123 0.08901489526033401
current loss: 3.8582282066345215 0.11727375723421574
current loss: 5.525717496871948 0.2411381397396326
current loss: 6.388717770576477 1.3069357071071863
current loss: 52.22373592853546 1.309405385516584
current loss: 53.9788624048233 1.3094076057691382
current loss: 57.95309817790985 1.3094076181380645
current loss: 64.15115678310394 1.3094076193748463
current loss: 72.402636885643 1.3094076203083223
current loss: 82.22621476650238 1.309407631680446
current loss: 92.13610780239105 1.3094097675175673
current loss: 100.37168347835541 1.3102262324655696
current loss: 105.34046542644501 1.400463745650117
current loss: 106.92031037807465 2.149145708140199
current loss: 122.47436845302582 2.1581681004727526
current loss: 125.78560173511505 2.158222034880086
current loss: 133.05226385593414 2.158222961635974
current loss: 143.85323584079742 2.158223001018963

The results are not good. For some reason the loss wants to ascend to the heavens……… I cannot tell the exact cause, and I do not know what to do about it either. Next time I should study again with a safer example.
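
If I had to guess where to look first (an assumption on my part, not something verified in this run): decode_img leaves the real images on the [0, 255] scale, while the generator's sigmoid output lives in [0, 1], so the discriminator sees real and fake batches on completely different scales. A minimal sketch of the change, rescaling inside decode_img so both sides match:

def decode_img(img):
  # convert the compressed string to a 3D uint8 tensor
  img = tf.image.decode_jpeg(img, channels=3)
  # resize, then rescale to [0, 1] to match the generator's sigmoid output
  img = tf.image.resize(img, [img_width, img_height])
  return img / 255.0

Many DCGAN implementations instead use a tanh output with inputs scaled to [-1, 1], but matching the existing sigmoid output is the smaller change to try first.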