深入探究生成对抗网络(GAN):原理与代码分析

news/2024/7/5 3:44:48

文章目录

  • 1. 应用领域
    • 1.1 图像生成
    • 1.2 图像编辑和重建
    • 1.3 视频生成
    • 1.4 文本生成
    • 1.5 音乐生成
    • 1.1 虚拟现实增强
  • 2. GAN的原理
    • 2.1 核心概念
    • 2.2 网络结构
    • 2.3 损失函数
    • 2.4 训练过程
  • 3. GAN图像生成任务应用

生成对抗网络(Generative Adversarial Network, GAN)是一种强大的深度学习模型,由生成器和判别器两个神经网络组成。GAN的目标是让生成器网络生成逼真的样本,以尽可能欺骗判别器网络,同时判别器网络要尽可能准确地区分真实样本和生成样本。

1. 应用领域

1.1 图像生成

GAN在图像生成领域非常流行。通过训练生成器网络来生成与训练数据集相似的逼真图像。GAN可以生成各种类型的图像,如人脸、风景、动物等。

1.2 图像编辑和重建

GAN图像编辑和重建。通过对生成器网络进行操纵,可以修改图像的特定属性,如颜色、纹理等,实现图像编辑的效果。此外,GAN还可以从损坏或不完整的图像中进行重建,填补缺失的部分,达到修复的效果。

1.3 视频生成

GAN视频生成。通过对时间序列数据进行建模,生成器网络可以生成逼真的连续帧,从而实现视频生成。

1.4 文本生成

GAN文本生成。通过训练生成器网络,生成具有逼真语义和语法结构的文本,如自动生成故事、对话模型、自动摘要等。

1.5 音乐生成

GAN通过对音乐序列进行建模,生成器网络可以生成新颖且具有艺术性的音乐作品。

1.1 虚拟现实增强

GAN虚拟现实(VR)和增强现实(AR),生成逼真的虚拟场景和物体。

2. GAN的原理

2.1 核心概念

GAN的核心概念是生成器和判别器。生成器负责生成逼真的样本,而判别器则用于区分真实样本和生成样本。生成器和判别器通过对抗训练的方式相互竞争,最终达到生成逼真样本的目标。

2.2 网络结构

生成器和判别器通常采用深度神经网络。生成器将一个随机向量作为输入,通过一系列的神经网络层逐步生成逼真样本。判别器接收生成样本和真实样本作为输入,并输出一个概率值来判断样本的真实性。

2.3 损失函数

GAN使用了两个损失函数:生成器损失和判别器损失。生成器损失衡量生成样本与真实样本之间的差异,鼓励生成器生成更逼真的样本。判别器损失衡量判别器对生成样本和真实样本的分类准确性,鼓励判别器准确区分这两类样本。

2.4 训练过程

GAN的训练过程是一个交替的优化过程。在每次迭代中,首先固定生成器,通过最小化判别器损失来更新判别器网络参数;然后固定判别器,通过最小化生成器损失来更新生成器网络参数。这种交替的训练过程使得生成器和判别器逐渐提升性能,直至达到平衡状态。

3. GAN图像生成任务应用

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

# 定义生成器网络
def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))

    return model

# 定义判别器网络
def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

# 定义生成器和判别器
generator = make_generator_model()
discriminator = make_discriminator_model()

# 定义损失函数
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# 定义生成器损失
def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

# 定义判别器损失
def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return real_loss + fake_loss

# 定义优化器
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

# 定义训练步骤
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, 100])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

# 训练模型
EPOCHS = 100
BATCH_SIZE = 128

for epoch in range(EPOCHS):
    for batch in range(len

(train_images) // BATCH_SIZE):
        images = train_images[batch * BATCH_SIZE : (batch + 1) * BATCH_SIZE]
        train_step(images)

    # 每个epoch结束后生成一张示例图片
    noise = tf.random.normal([1, 100])
    generated_image = generator(noise, training=False)
    # 保存生成的图片或展示在可视化界面中

# 保存生成器和判别器模型
generator.save('generator_model.h5')
discriminator.save('discriminator_model.h5')

通过反复迭代训练,生成器逐渐生成逼真的手写数字图像,判别器逐渐提高对真实图像和生成图像的辨别能力。


http://lihuaxi.xjx100.cn/news/1243659.html

相关文章

极致呈现系列之:Echarts柱状图的创意设计与数字美学的完美平衡

先看下最终效果 目录 数字之美:Echarts柱状图的基础应用形色俱佳:Echarts柱状图的样式美化与创意设计独具匠心:Echarts柱状图的柱体形状自定义动感十足:Echarts柱状图的交互动画实现数字排序的艺术:Echarts柱状图的数…

安卓自动化

又python客户端 --------> Appium Server ------------------> Java 先安装: 安装2的时候 添加一个环境变量 第三步添加环境变量 第四步添加环境变量,在系统变量path中添加

version `GLIBCXX_3.4.14‘ not found

./Gate: /usr/lib64/libstdc.so.6: version GLIBCXX_3.4.14 not found (required by ./Gate) 本人测试gcc-8.3.0装不上,可考虑7.30亲测可装, 4.81也测试过了,可以装但是应该不支持3.414 查看支持的版本列表 strings /lib64/libc.so.6 | grep GLIBC 8.30安装 wget http://ftp…

系统码的编译码与汉明码

本专栏包含信息论与编码的核心知识,按知识点组织,可作为教学或学习的参考。markdown版本已归档至【Github仓库:https://github.com/timerring/information-theory 】或者公众号【AIShareLab】回复 信息论 获取。 文章目录 系统码的编译码线性…

如何录制声音?推荐这2款电脑录音软件!

案例:怎么录制电脑上的声音?在电脑上怎么录制自己的声音?有没有小伙伴知道操作的步骤。 【我想录制语音会议,还想录制自己的歌声,在电脑上如何录制声音?求一个简单易懂的教程,在线等&#xff0…

Optional简述(Java8新特性)

Optional类是Java8为了解决null值判断问题,借鉴google guava类库的Optional类而引入的一个同名Optional类,使用Optional类可以避免显式的null值判断(null的防御性检查),避免null导致的NPE(NullPointerExcep…

高速视觉筛选机PCI Express实时运动控制卡XPCIE1028

产品导读 正运动技术的PCI Express总线运动控制卡XPCIE1028,具备位置锁存、多维高速硬件位置比较输出PSO、同步跟随、精准触发的运动控制和I/O控制功能。 配合正运动技术MotionRT7实时内核使用,可高度满足高速视觉筛选机应用所需的运动控制需求。 XPC…

linuxOPS基础_yum详解

yum是如何安装软件的 yum仓库(也称yum源)用于存放各种rpm的软件包以及软件包之间的依赖关系(repodata目录)需要安装软件的计算机连接到指定yum仓库来安装软件包 yum源作用 软件包管理器,类似Windows下的软件管家 yu…