SRGAN是一种用于图像超分辨率(SR)的生成对抗网络(GAN),能够在4倍放大因子下推断出照片般逼真的自然图像。
文章来源:
2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR)
下载链接:
前言
超分辨率复原一直是计算机视觉领域一个十分热门的研究方向,在商业上也有着很大的用武之地。随着2014年Ian J. Goodfellow那篇惊世骇俗的GAN论文发表,GAN伴随着CNN一起可谓是乘风破浪,衍生出琳琅满目的各种应用。
网络结构
SRGAN网络结构如下图(SRGAN仍然使用SRResNet来完成超分工作,但增加了一个对抗网络(判别器)来判断输入的图片是原始高分辨率图还是超分生成的图):
def SRGAN_g(t_image, n_res_blocks=16):
    """Build the SRGAN generator (SRResNet-style) graph with Paddle fluid.

    Layout: a Conv-BN-ReLU stem, ``n_res_blocks`` residual blocks
    (Conv-BN-ReLU-Conv-BN plus identity), a Conv-BN whose output is added
    back onto the stem output (long skip connection), two upsampling stages
    (Conv 256 -> pixel_shuffle x2 -> ReLU, x4 in total), and a final 1x1 Conv
    to 3 channels followed by tanh.

    Args:
        t_image: input image variable. Every conv here is built with
            data_format='NCHW', so the input is expected in NCHW layout.
        n_res_blocks: number of residual blocks (default 16).

    Returns:
        The generated image variable; tanh bounds its values to [-1, 1].
    """
    # Stem: Conv-BN-ReLU
    n = fluid.layers.conv2d(input=t_image, num_filters=64, filter_size=3, stride=1,
                            padding='SAME', name='n64s1/c', data_format='NCHW')
    n = fluid.layers.batch_norm(n, momentum=0.99, epsilon=0.001)
    n = fluid.layers.relu(n, name=None)
    temp = n  # stem output, re-added after the residual stack (long skip)

    # Residual blocks: Conv-BN-ReLU-Conv-BN, then add the block input back.
    for i in range(n_res_blocks):
        nn = fluid.layers.conv2d(input=n, num_filters=64, filter_size=3, stride=1,
                                 padding='SAME', name='n64s1/c1/%s' % i, data_format='NCHW')
        nn = fluid.layers.batch_norm(nn, momentum=0.99, epsilon=0.001, name='n64s1/b1/%s' % i)
        nn = fluid.layers.relu(nn, name=None)
        nn = fluid.layers.conv2d(input=nn, num_filters=64, filter_size=3, stride=1,
                                 padding='SAME', name='n64s1/c2/%s' % i, data_format='NCHW')
        nn = fluid.layers.batch_norm(nn, momentum=0.99, epsilon=0.001, name='n64s1/b2/%s' % i)
        n = fluid.layers.elementwise_add(n, nn, act=None, name='b_residual_add/%s' % i)

    # Conv-BN after the residual stack, then the long skip connection.
    n = fluid.layers.conv2d(input=n, num_filters=64, filter_size=3, stride=1,
                            padding='SAME', name='n64s1/c/m', data_format='NCHW')
    # This BN is not part of any residual block, so it gets its own name
    # (paired with 'n64s1/c/m') instead of one of the 'n64s1/b2/<i>' names.
    # It also no longer depends on the loop variable, so n_res_blocks=0 works.
    # NOTE(review): parameter names are presumably derived from the layer
    # name, so checkpoints saved under another name for this layer will not
    # match it — confirm before reloading older checkpoints.
    n = fluid.layers.batch_norm(n, momentum=0.99, epsilon=0.001, name='n64s1/b/m')
    n = fluid.layers.elementwise_add(n, temp, act=None, name='add3')

    # Upsampling: two (Conv 256 -> pixel_shuffle x2 -> ReLU) stages, x4 overall.
    n = fluid.layers.conv2d(input=n, num_filters=256, filter_size=3, stride=1,
                            padding='SAME', name='n256s1/1', data_format='NCHW')
    n = fluid.layers.pixel_shuffle(n, upscale_factor=2)
    n = fluid.layers.relu(n, name=None)
    n = fluid.layers.conv2d(input=n, num_filters=256, filter_size=3, stride=1,
                            padding='SAME', name='n256s1/2', data_format='NCHW')
    n = fluid.layers.pixel_shuffle(n, upscale_factor=2)
    n = fluid.layers.relu(n, name=None)

    # Project to 3 output channels; tanh bounds the result to [-1, 1].
    n = fluid.layers.conv2d(input=n, num_filters=3, filter_size=1, stride=1,
                            padding='SAME', name='out', data_format='NCHW')
    n = fluid.layers.tanh(n, name=None)
    return n
def SRGAN_d(input_images):
    """Build the SRGAN discriminator graph and return its sigmoid score.

    Layout:
      * h0: Conv(64, 4x4, stride 2) -> LeakyReLU(0.2)  (no BN on the first layer)
      * h1..h7: Conv(4x4, stride 2) -> BN -> LeakyReLU(0.2), with
        128, 256, 512, 1024, 2048, 1024, 512 filters
      * residual branch on h7 (a modification of the paper's network):
        Conv(128, 1x1) -> Conv(128, 3x3) -> Conv(512, 3x3), each followed by
        BN -> LeakyReLU(0.2), added back onto h7 and passed through LeakyReLU
      * head: FC(1024) -> LeakyReLU(0.2) -> FC(1) -> sigmoid

    The residual convs are chained (each consumes the previous conv's output),
    and FC(1) consumes the FC(1024) activation. Earlier versions fed h7 / h8
    to every stage, which left 'res/c', 'res/c2' and 'ho/fc' as dead layers.
    The input shapes of 'res/c2', 'res/c3' and 'ho/fc2' therefore differ from
    those earlier versions.

    Args:
        input_images: image variable. Every conv here is built with
            data_format='NCHW', so the input is expected in NCHW layout.

    Returns:
        A variable holding one sigmoid probability per sample.
    """

    def conv(x, num_filters, filter_size, stride, name):
        # All convs share SAME padding and NCHW layout.
        return fluid.layers.conv2d(input=x, num_filters=num_filters,
                                   filter_size=filter_size, stride=stride,
                                   padding='SAME', name=name, data_format='NCHW')

    def bn_lrelu(x, name):
        x = fluid.layers.batch_norm(x, momentum=0.99, epsilon=0.001, name=name)
        return fluid.layers.leaky_relu(x, alpha=0.2, name=None)

    # h0: Conv-LeakyReLU
    net = conv(input_images, 64, 4, 2, 'h0/c')
    net = fluid.layers.leaky_relu(net, alpha=0.2, name=None)

    # h1..h7: Conv-BN-LeakyReLU, each with stride 2
    for idx, num_filters in enumerate((128, 256, 512, 1024, 2048, 1024, 512), start=1):
        net = conv(net, num_filters, 4, 2, 'h%d/c' % idx)
        net = bn_lrelu(net, 'h%d/bn' % idx)
    net_h7 = net

    # Residual branch (modified from the original paper's network); the last
    # conv has 512 filters to match h7's channel count for the addition.
    res = conv(net_h7, 128, 1, 1, 'res/c')
    res = bn_lrelu(res, 'res/bn')
    res = conv(res, 128, 3, 1, 'res/c2')
    res = bn_lrelu(res, 'res/bn2')
    res = conv(res, 512, 3, 1, 'res/c3')
    res = bn_lrelu(res, 'res/bn3')
    net_h8 = fluid.layers.elementwise_add(net_h7, res, act=None, name='res/add')
    net_h8 = fluid.layers.leaky_relu(net_h8, alpha=0.2, name=None)

    # Classifier head: fc flattens the 4-D feature map (num_flatten_dims=1).
    net_ho = fluid.layers.fc(input=net_h8, size=1024, name='ho/fc')
    net_ho = fluid.layers.leaky_relu(net_ho, alpha=0.2, name=None)
    net_ho = fluid.layers.fc(input=net_ho, size=1, name='ho/fc2')
    net_ho = fluid.layers.sigmoid(net_ho, name=None)
    return net_ho
def conv_block(input, num_filter, groups, name=None):
    """Stack ``groups`` 3x3 Conv+ReLU layers, then a 2x2 stride-2 max-pool.

    Args:
        input: input feature-map variable.
        num_filter: number of output channels of every conv in the block.
        groups: number of conv layers in the block.
        name: prefix for the conv weight names ("<name><k>_weights", k = 1..groups).
            If None, Paddle generates the weight names itself; previously a
            None prefix raised TypeError on string concatenation.

    Returns:
        The max-pooled output variable.
    """
    conv = input
    for i in range(groups):
        # Explicit weight names (presumably so pretrained VGG weights can be
        # loaded by name); convs are built without bias.
        if name is None:
            param_attr = None
        else:
            param_attr = fluid.param_attr.ParamAttr(name=name + str(i + 1) + "_weights")
        conv = fluid.layers.conv2d(
            input=conv,
            num_filters=num_filter,
            filter_size=3,
            stride=1,
            padding=1,
            act='relu',
            param_attr=param_attr,
            bias_attr=False)
    return fluid.layers.pool2d(
        input=conv, pool_size=2, pool_type='max', pool_stride=2)


def vgg19(input, class_dim=1000):
    """Build a VGG-19 classifier graph.

    Five conv blocks (64, 128, 256, 512, 512 filters with 2, 2, 4, 4, 4 conv
    layers) followed by fc6/fc7 (4096, ReLU, dropout 0.5) and fc8.

    Args:
        input: input image variable.
        class_dim: number of outputs of the final fc layer.

    Returns:
        (out, conv5): the class_dim-wide output of fc8 (no activation) and
        the output of the fifth conv block (taken after its max-pool).
    """
    stage_filters = (64, 128, 256, 512, 512)
    stage_groups = (2, 2, 4, 4, 4)  # conv layers per block for VGG-19
    conv = input
    for stage, (num_filter, groups) in enumerate(zip(stage_filters, stage_groups), start=1):
        conv = conv_block(conv, num_filter, groups, name="vgg19_conv%d_" % stage)
    conv5 = conv

    fc_dim = 4096
    fc_name = ["fc6", "fc7", "fc8"]
    hidden = conv5
    # fc6 and fc7: fully connected + ReLU, each followed by dropout.
    for layer_name in fc_name[:2]:
        hidden = fluid.layers.fc(
            input=hidden,
            size=fc_dim,
            act='relu',
            param_attr=fluid.param_attr.ParamAttr(name=layer_name + "_weights"),
            bias_attr=fluid.param_attr.ParamAttr(name=layer_name + "_offset"))
        hidden = fluid.layers.dropout(x=hidden, dropout_prob=0.5)
    out = fluid.layers.fc(
        input=hidden,
        size=class_dim,
        param_attr=fluid.param_attr.ParamAttr(name=fc_name[2] + "_weights"),
        bias_attr=fluid.param_attr.ParamAttr(name=fc_name[2] + "_offset"))
    return out, conv5
损失函数
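论文中还给出了生成器和判别器的损失函数的形式。生成器的损失函数(感知损失)是内容损失与对抗损失的加权和:$$l^{SR} = l^{SR}_{X} + 10^{-3}\, l^{SR}_{Gen},\qquad l^{SR}_{Gen} = \sum_{n=1}^{N} -\log D_{\theta_D}\big(G_{\theta_G}(I^{LR})\big)$$其中内容损失 $l^{SR}_{X}$ 既可以是像素级的MSE损失,也可以是基于预训练VGG19特征图计算的VGG损失。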
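判别器的损失函数为(即论文中对抗 min-max 目标里判别器一侧的二分类交叉熵):$$l_D = -\mathbb{E}_{I^{HR}}\big[\log D_{\theta_D}(I^{HR})\big] - \mathbb{E}_{I^{LR}}\big[\log\big(1 - D_{\theta_D}(G_{\theta_G}(I^{LR}))\big)\big]$$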
训练策略
def _prepare_epoch_data(hr_imgs):
    """Randomly crop HR images to 384x384 and derive the standardized sets.

    Returns (target_384, input_96, vgg_224): the standardized 384x384 crops
    (generator target), their 96x96 resizes (generator input) and their
    224x224 resizes (VGG19 input).
    """
    sample_imgs_384 = random_crop(hr_imgs, 384)
    sample_imgs_standardized_384 = standardized(sample_imgs_384)
    sample_imgs_standardized_96 = standardized(im_resize(sample_imgs_384, 96, 96))
    sample_imgs_standardized_224 = standardized(im_resize(sample_imgs_384, 224, 224))
    return sample_imgs_standardized_384, sample_imgs_standardized_96, sample_imgs_standardized_224


def _get_batch(data, i):
    """Return the i-th batch_size slice of ``data`` as a float32 array."""
    return np.array(data[i * batch_size:(i + 1) * batch_size], dtype='float32')


def _save_samples(prefix, epoch, out, imgs_96, imgs_384):
    """Write ./output/<prefix>_<epoch>_{G,96,384}.jpg sample images.

    ``out`` is one generated image. The first images of ``imgs_96`` and
    ``imgs_384`` are saved alongside it. All three are mapped from [-1, 1]
    to [0, 255]. NOTE(review): assumes `standardized` also maps to [-1, 1]
    and that images are HWC RGB, as the RGB2BGR conversion implies — confirm.
    """
    samples = (('G', out), ('96', imgs_96[0]), ('384', imgs_384[0]))
    for suffix, img in samples:
        # Clip so values slightly outside [-1, 1] do not wrap on the uint8 cast.
        im = np.clip((img + 1) * 127.5, 0, 255).astype(np.uint8)
        cv2.imwrite('./output/{}_{}_{}.jpg'.format(prefix, epoch, suffix),
                    cv2.cvtColor(im, cv2.COLOR_RGB2BGR))


def _save_persistables(program, dirname):
    """Replace ``dirname`` with a fresh dump of ``program``'s persistables."""
    shutil.rmtree(dirname, ignore_errors=True)  # delete old model files
    os.makedirs(dirname)
    fluid.io.save_persistables(executor=exe, dirname=dirname, main_program=program)


# init
# NOTE(review): these data layers declare HWC shapes ([H, W, 3]) while every
# conv in SRGAN_g / SRGAN_d is built with data_format='NCHW'; confirm where
# (if anywhere) the inputs are transposed before reaching the networks.
t_image = fluid.layers.data(name='t_image', shape=[96, 96, 3], dtype='float32')
t_target_image = fluid.layers.data(name='t_target_image', shape=[384, 384, 3], dtype='float32')
vgg19_input = fluid.layers.data(name='vgg19_input', shape=[224, 224, 3], dtype='float32')

step_num = len(train_hr_imgs) // batch_size
if step_num == 0:
    # With no full batch the loops below never run: the loss averages would
    # divide by zero and the sample-saving code would reference unset batches.
    raise ValueError('need at least batch_size=%d training images, got %d'
                     % (batch_size, len(train_hr_imgs)))

# initialize G
# NOTE(review): this phase runs SRGAN_g_program and only fetches mse_loss;
# whether that program minimizes the MSE or the full g_loss is defined
# elsewhere — confirm it matches the intended MSE-only pretraining.
for epoch in range(0, n_epoch_init + 1):
    epoch_time = time.time()
    np.random.shuffle(train_hr_imgs)
    std_384, std_96, std_224 = _prepare_epoch_data(train_hr_imgs)
    # loss
    total_mse_loss, n_iter = 0, 0
    for i in tqdm.tqdm(range(step_num)):
        imgs_384 = _get_batch(std_384, i)
        imgs_96 = _get_batch(std_96, i)
        imgs_224 = _get_batch(std_224, i)  # vgg19 data
        feed = {'t_image': imgs_96, 't_target_image': imgs_384, 'vgg19_input': imgs_224}
        mse_loss_n = exe.run(SRGAN_g_program, feed=feed, fetch_list=[mse_loss])[0]
        total_mse_loss += mse_loss_n
        n_iter += 1
    log = "[*] Epoch_init: [%2d/%2d] time: %4.4fs, mse: %.8f" % (
        epoch, n_epoch_init, time.time() - epoch_time, total_mse_loss / n_iter)
    print(log)
    if (epoch != 0) and (epoch % 10 == 0):
        # NOTE(review): if SRGAN_g_program contains optimizer ops, this extra
        # run also applies one more update on the last batch; consider
        # sampling from a clone made with for_test=True.
        out = exe.run(SRGAN_g_program, feed=feed, fetch_list=[test_im])[0][0]
        _save_samples('epoch_init', epoch, out, imgs_96, imgs_384)

# train GAN (SRGAN)
for epoch in range(0, n_epoch + 1):
    ## update learning rate
    epoch_time = time.time()
    # Reshuffle every epoch, as in the init phase, so batch composition varies.
    np.random.shuffle(train_hr_imgs)
    std_384, std_96, std_224 = _prepare_epoch_data(train_hr_imgs)
    # loss
    total_d_loss, total_g_loss, n_iter = 0, 0, 0
    for i in tqdm.tqdm(range(step_num)):
        imgs_384 = _get_batch(std_384, i)
        imgs_96 = _get_batch(std_96, i)
        imgs_224 = _get_batch(std_224, i)  # vgg19 data
        feed = {'t_image': imgs_96, 't_target_image': imgs_384, 'vgg19_input': imgs_224}
        ## update D
        errD = exe.run(SRGAN_d_program,
                       feed={'t_image': imgs_96, 't_target_image': imgs_384},
                       fetch_list=[d_loss])[0]
        ## update G
        errG = exe.run(SRGAN_g_program, feed=feed, fetch_list=[g_loss])[0]
        total_d_loss += errD
        total_g_loss += errG
        n_iter += 1
    log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (
        epoch, n_epoch, time.time() - epoch_time, total_d_loss / n_iter, total_g_loss / n_iter)
    print(log)
    if (epoch != 0) and (epoch % 10 == 0):
        # NOTE(review): same caveat as above about sampling via the training program.
        out = exe.run(SRGAN_g_program, feed=feed, fetch_list=[test_im])[0][0]
        _save_samples('epoch', epoch, out, imgs_96, imgs_384)
        # save model: D from its own program (was mistakenly SRGAN_g_program), G from G's.
        _save_persistables(SRGAN_d_program, 'models/d_models/')
        _save_persistables(SRGAN_g_program, 'models/g_models/')
结果展示
import os
from PIL import Image
import matplotlib.pyplot as plt

# Epoch whose saved samples (./output/epoch_<N>_{96,384,G}.jpg, as written by
# the training loop) are displayed side by side.
result_epoch = 1780
panels = [
    ('96', 'Low resolution'),
    ('384', 'High resolution'),
    ('G', 'Generate'),
]
images = [Image.open('./output/epoch_{}_{}.jpg'.format(result_epoch, suffix))
          for suffix, _ in panels]

plt.figure("Image Completion Result", dpi=384)
for idx, ((_, title), img) in enumerate(zip(panels, images), start=1):
    plt.subplot(2, 3, idx)
    plt.imshow(img)
    plt.title(title, fontsize='xx-small', fontweight='heavy')
    plt.axis('off')
plt.show()
心得体会
在此篇文章之前,CNN网络在传统的单帧超分辨率重建上就取得了非常好的效果,但是当图像下采样倍数较高时,重建得到的图片会过于平滑,丢失细节。此篇文章提出的利用GAN进行超分辨率重建的方法,是第一个能够在4倍放大因子下推断出照片般逼真的自然图像的框架。SRGAN这个网络的最大贡献就是使用生成对抗网络(Generative Adversarial Network)来训练SRResNet,使其产生的HR图像看起来更加自然,有更好的视觉效果,更接近自然HR图像。