0%

tf2-GAN

实战使用keras的mnist数据集,使用GAN生成手写数据图片

使用生成器模型生成手写图片
使用判别器模型对真图片进行真判断对生成的图片进行假判断

定义生成器模型

生成器用来将随机数生成手写数据图片
生成器使用三层结构
输入层和中间层使用BatchNormalization进行标准化,使用LeakyReLU进行激活
输出层使用tanh激活
最后将输出数据改为28 * 28 * 1 的形状

line_number: true
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def generator_model():
model = tf.keras.Sequential()
model.add(keras.layers.Dense(256, input_shape = (100,), use_bias = False))
model.add(keras.layers.BatchNormalization())
model.add(keras.layers.LeakyReLU())

model.add(keras.layers.Dense(512, use_bias = False))
model.add(keras.layers.BatchNormalization())
model.add(keras.layers.LeakyReLU())

model.add(keras.layers.Dense(28 * 28 * 1, use_bias = False, activation = 'tanh'))
model.add(keras.layers.BatchNormalization())

model.add(keras.layers.Reshape((28, 28, 1)))

return model

定义判别器模型

判别器用来对输入图片进行判别
判别器使用三层网络结构
第一层将图片数据延展并输入到一个全连接层,使用BatchNormalization标准化,使用LeakyReLU激活
第二层是个全连接层,同样使用使用BatchNormalization标准化,使用LeakyReLU激活
第三层输出判别的结果

line_number: true
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def discriminator_model():
model = keras.Sequential()
model.add(keras.layers.Flatten())

model.add(keras.layers.Dense(512, use_bias = False))
model.add(keras.layers.BatchNormalization())
model.add(keras.layers.LeakyReLU())

model.add(keras.layers.Dense(256, use_bias = False))
model.add(keras.layers.BatchNormalization())
model.add(keras.layers.LeakyReLU())

model.add(keras.layers.Dense(1))

return model

GAN 生成对抗网络

应用领域

* 图像生成
* 图像增强
* 风格化
* 艺术的图像创作

GAN定义

GAN 包含两部分生成器generator与判别器discriminator,
* 生成器主要用来学习真实图像分布从而让自身生成的图像更加真实,以骗过判别器
* 判别器对接收的图片进行真假判断

GAN 设计

  • 生成器网络

  • 判别器网络(例如 5层CNN)

自编码器
基本去燥自编ma
卷积去燥自编码器

GAN自定义