实战使用keras的mnist数据集,使用GAN生成手写数据图片
使用生成器模型生成手写图片
使用判别器模型对真图片进行真判断对生成的图片进行假判断
定义生成器模型
生成器用来将随机数生成手写数据图片
生成器使用三层结构
输入层和中间层使用BatchNormalization进行标准化,使用LeakyReLU进行激活
输出层使用tanh激活
最后将输出数据改为28 * 28 * 1 的形状
line_number: true1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
| def generator_model(): model = tf.keras.Sequential() model.add(keras.layers.Dense(256, input_shape = (100,), use_bias = False)) model.add(keras.layers.BatchNormalization()) model.add(keras.layers.LeakyReLU())
model.add(keras.layers.Dense(512, use_bias = False)) model.add(keras.layers.BatchNormalization()) model.add(keras.layers.LeakyReLU())
model.add(keras.layers.Dense(28 * 28 * 1, use_bias = False, activation = 'tanh')) model.add(keras.layers.BatchNormalization())
model.add(keras.layers.Reshape((28, 28, 1)))
return model
|
定义判别器模型
判别器用来对输入图片进行判别
判别器使用三层网络结构
第一层将图片数据延展并输入到一个全连接层,使用BatchNormalization标准化,使用LeakyReLU激活
第二层是个全连接层,同样使用使用BatchNormalization标准化,使用LeakyReLU激活
第三层输出判别的结果
line_number: true1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
| def discriminator_model(): model = keras.Sequential() model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(512, use_bias = False)) model.add(keras.layers.BatchNormalization()) model.add(keras.layers.LeakyReLU())
model.add(keras.layers.Dense(256, use_bias = False)) model.add(keras.layers.BatchNormalization()) model.add(keras.layers.LeakyReLU())
model.add(keras.layers.Dense(1))
return model
|
GAN 生成对抗网络
应用领域
* 图像生成
* 图像增强
* 风格化
* 艺术的图像创作
GAN定义
GAN 包含两部分生成器generator与判别器discriminator,
* 生成器主要用来学习真实图像分布从而让自身生成的图像更加真实,以骗过判别器
* 判别器对接收的图片进行真假判断
GAN 设计
自编码器
基本去燥自编ma
卷积去燥自编码器
GAN自定义