【人工智能基础】GAN与WGAN实验

一、GAN网络概述

GAN:生成对抗网络。GAN网络中存在两个网络:G(Generator,生成网络)和D(Discriminator,判别网络)。

Generator接收一个随机的噪声z,通过这个噪声生成图片,记做G(z)

Discriminator功能是判别一张图片的真实。它的输入是一张图片x,输出D(x)代表x为真实图片的概率,如果为1就代表图片真实,而输出为0,就代表图片不真实。

在GAN网络的训练中,Generator的目标就是尽量生成真实的图片去欺骗Discriminator

Discriminator的目标就是尽量把Generator生成的图片和真实的图片分别开来

二、GAN实验环境准备

除了之前使用过的pytorch-nplnumpy以外,我们还需要安装visdom

pip install visdom

启动visdom

python -m visdom.server

visdom启动成功如下图,会占用8097端口,我们可以通过8097端口访问visdom

visdom启动.png

三、GAN网络实验

环境参数配置

import torch
from torch import nn,optim,autograd
import numpy as np
import visdom
import random

h_dim = 400
batchsz = 512
viz = visdom.Visdom()

生成网络定义

class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__()
        self.net = nn.Sequential(
            # input[b, 2]
            nn.Linear(2,h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 2)
            # output[b,2]
        )

    def forward(self, z):
        output = self.net(z)
        return output

判别网络定义

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        output = self.net(x)
        return output.view(-1)

数据集生成函数

def data_generator():
    # 生成中心点
    scale = 2
    centers = [
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (1. / np.sqrt(2), 1. / np.sqrt(2)),
        (1. / np.sqrt(2), -1. / np.sqrt(2)),
        (-1. / np.sqrt(2), 1. / np.sqrt(2)),
        (-1. / np.sqrt(2), -1. / np.sqrt(2))
    ]
    centers = [(scale * x, scale * y) for x,y in centers] 
    while True:
        dataset = []
        for i in range(batchsz):
            point = np.random.randn(2) * 0.02
            # 随机选取一个中心点
            center = random.choice(centers)
            # 把刚刚随机到的高斯分布点根据center进行移动
            point[0] += center[0]
            point[1] += center[1]
            dataset.append(point)
        dataset = np.array(dataset).astype(np.float32)
        dataset /= 1.414
        yield dataset

可视化函数

将图片生成到visdom

import matplotlib.pyplot as plt
def generate_image(D, G, xr, epoch):
    N_POINTS = 128
    RANGE = 3
    plt.clf()

    points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')
    points[:,:,0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]
    points[:,:,1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]
    points = points.reshape((-1,2))

    with torch.no_grad():
        points = torch.Tensor(points).cpu()
        disc_map = D(points).cpu().numpy()
    x = y = np.linspace(-RANGE,RANGE,N_POINTS)
    cs = plt.contour(x,y,disc_map.reshape((len(x), len(y))).transpose())
    plt.clabel(cs, inline=1,fontsize=10)

    with torch.no_grad():
        z = torch.randn(batchsz, 2).cpu()
        samples = G(z).cpu().numpy()
    plt.scatter(xr[:,0],xr[:,1],c='orange',marker='.')
    plt.scatter(samples[:,0], samples[:,1], c='green',marker='+')

    viz.matplot(plt, win='contour',opts=dict(title='p(x):%d'%epoch))

运行函数

def run():
    torch.manual_seed(23)
    np.random.seed(23)

    data_iter = data_generator()
    x = next(data_iter)
    # print(x.shape)

    # G = Generator().cuda()
    # D = Discriminator().cuda()
    # 无显卡环境
    device = torch.device("cpu")
    G = Generator().cpu()
    print(G)
    D = Discriminator().cpu()
    print(D)

    optim_G = optim.Adam(G.parameters(), lr = 5e-4, betas=(0.5,0.9))
    optim_D = optim.Adam(D.parameters(), lr = 5e-4, betas=(0.5,0.9))

    viz.line([[0,0]],[0],win='loss', opts=dict(title='loss',legend=['D','G']))

    """
    gan核心部分
    """
    for epoch in range(50000):
        # 训练判别网络
        for _ in range(5):
            # 真实数据训练
            xr = next(data_iter)
            xr = torch.from_numpy(xr).cpu()
            predr = D(xr)
            # 放大真实数据
            lossr = -predr.mean()

            # 虚假数据训练
            z = torch.randn(batchsz,2).cpu()
            xf = G(z).detach()
            predf = D(xf)
            # 缩小虚假数据
            lossf = predf.mean()

            loss_D = lossr + lossf

            # 梯度清零
            optim_D.zero_grad()
            # 向后传播
            loss_D.backward()
            optim_D.step()


        # 训练生成网络
        z = torch.randn(batchsz,2).cpu()
        xf = G(z)
        predf = D(xf)
        loss_G = -predf.mean()
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()

        if epoch % 100 == 0:
            viz.line([[loss_D.item(),loss_G.item()]], [epoch],win='loss', update='append')
            print(loss_D.item(), loss_G.item())
            generate_image(D, G, xr, epoch)

执行(GAN的不稳定性)

run()

从结果中可以看到,判别网络的loss一直为0,而生成网络一直得不到更新,生成的数据点远离我们创建的中心点

gan运行.png

四、wgan实验

WGAN主要从损失函数的角度对GAN做了改进,对更新后的权重强制截断到一定范围内

增加一个梯度惩罚函数

def gradient_penalty(D,xr,xf):
    # [b,1]
    t = torch.rand(batchsz, 1).cpu()
    # 扩展为[b, 2]
    t = t.expand_as(xr)
    # 插值
    mid = t * xr + (1 - t) * xf
    # 设置需要的倒数信息
    mid.requires_grad_()

    pred = D(mid)
    grads = autograd.grad(outputs=pred, 
                          inputs=mid,
                          grad_outputs=torch.ones_like(pred),
                          create_graph=True,
                          retain_graph=True,
                          only_inputs=True)[0]
    gp = torch.pow(grads.norm(2, dim=1) - 1, 2).mean()
    return gp

修改运行函数

def run():
    torch.manual_seed(23)
    np.random.seed(23)

    data_iter = data_generator()
    x = next(data_iter)
    # print(x.shape)

    # G = Generator().cuda()
    # D = Discriminator().cuda()
    # 无显卡环境
    device = torch.device("cpu")
    G = Generator().cpu()
    print(G)
    D = Discriminator().cpu()
    print(D)

    optim_G = optim.Adam(G.parameters(), lr = 5e-4, betas=(0.5,0.9))
    optim_D = optim.Adam(D.parameters(), lr = 5e-4, betas=(0.5,0.9))

    viz.line([[0,0]],[0],win='loss', opts=dict(title='loss',legend=['D','G']))

    """
    gan核心部分
    """
    for epoch in range(50000):
        # 训练判别网络
        for _ in range(5):
            # 真实数据训练
            xr = next(data_iter)
            xr = torch.from_numpy(xr).cpu()
            predr = D(xr)
            # 放大真实数据
            lossr = -predr.mean()

            # 虚假数据训练
            z = torch.randn(batchsz,2).cpu()
            xf = G(z).detach()
            predf = D(xf)
            # 缩小虚假数据
            lossf = predf.mean()

            # 梯度惩罚值
            gp = gradient_penalty(D,xr,xf.detach())
            loss_D = lossr + lossf + 0.2 * gp
            # 梯度清零
            optim_D.zero_grad()
            # 向后传播
            loss_D.backward()
            optim_D.step()


        # 训练生成网络
        z = torch.randn(batchsz,2).cpu()
        xf = G(z)
        predf = D(xf)
        loss_G = -predf.mean()
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()

        if epoch % 100 == 0:
            viz.line([[loss_D.item(),loss_G.item()]], [epoch],win='loss', update='append')
            print(loss_D.item(), loss_G.item())
            generate_image(D, G, xr, epoch)

执行

run()

可以看到在wgan中,生成网络开始学习,生成的数据点也能基本根据高斯分布落在中心点附近

wgan运行.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/606759.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

解决uniapp软键盘弹起导致页面fixed定位元素被顶上去

在移动端开发中通常导航栏需要固定在页面的最顶端,但当页面中有输入框且dom元素较多时,点击输入框弹出软键盘会促使导航栏往上移 正常情况如图一所示,软键盘弹起如图二所示 图一 图二 解决办法 1)给输入框添加 :adjust-position…

李飞飞团队 AI4S 最新洞察:16 项创新技术汇总,覆盖生物/材料/医疗/问诊……

不久前,斯坦福大学 Human-Center Artificial Intelligence (HAI) 研究中心重磅发布了《2024年人工智能指数报告》。 作为斯坦福 HAI 的第七部力作,这份报告长达 502 页,全面追踪了 2023 年全球人工智能的发展趋势。相比往年,扩大了…

node.js 下载安装 配置环境变量

1 官网下载 需要的版本https://nodejs.org/dist 下载 .msi的文件 2 根据安装向导,安装 3 检查安装 是否成功,winr 输入cmd,输入node --version 回车,查看版本 4 配置换进变量 node路径是 安装时 的安装路径 5 vscode 启动项目…

HTML(3)——常用标签3

引用标签 1.<blockquote>和<q> 两者都是对文本的解释引用&#xff0c;<blockquote>是用较大的段落进行解释&#xff0c;<q>是对较小的段落进行解释。 <!DOCTYPE html> <html lang"en"> <head><meta charset"UT…

【小浩算法 BST与其验证】

BST与其验证 前言我的思路思路一 中序遍历判断数组无重复递增思路二 递归边界最大值最小值的传递 我的代码测试用例1测试用例2 前言 BST是二叉树一个经典应用&#xff0c;我们常常将其用于数据的查找以及构建平衡二叉树等。今天我所做的题目是验证一颗二叉树是否为二叉搜索树&…

基于PHP高考志愿填报系统搭建私有化部署源码

金秋志愿高考志愿填报系统是一款为高中毕业生提供志愿填报服务的在线平台。该系统旨在帮助学生更加科学、合理地选择自己的大学专业和学校&#xff0c;从而为未来的职业发展打下坚实的基础。 该系统的主要功能包括:报考信息查询、志愿填报数据指导、专业信息查询、院校信息查询…

学习CSS3动画教程:手把手教你绘制跑跑卡丁车

学习之前&#xff0c;请先听一段音乐&#xff1a;等登&#xff0c;等登&#xff0c;等登等登等登&#xff01;没错&#xff0c;这就是我们当年玩的跑跑卡丁车的背景音乐&#xff0c;虽然后来有了QQ飞车&#xff0c;但还是更喜欢跑跑卡丁车&#xff0c;从最初的基础板车&#xf…

深入入IAEA底层LinkedList

✅作者简介&#xff1a;大家好&#xff0c;我是再无B&#xff5e;U&#xff5e;G&#xff0c;一个想要与大家共同进步的男人&#x1f609;&#x1f609; &#x1f34e;个人主页&#xff1a;再无B&#xff5e;U&#xff5e;G-CSDN博客 目标&#xff1a; 1.掌握LinkedList 2.…

好用无广告的快捷回复软件

在现在的工作环境中&#xff0c;时间就是金钱。对于客服人员来说&#xff0c;能够快速而准确地回复客户的问题&#xff0c;是提高工作效率和客户满意度的关键。因此&#xff0c;一个实用的快捷回复工具是必不可少的。今天&#xff0c;我想向大家推荐一款好用且无广告的客服快捷…

三勾软件 / 三勾点餐系统门店系统,java+springboot+vue3

项目介绍 三勾点餐系统基于javaspringbootelement-plusuniapp打造的面向开发的小程序商城&#xff0c;方便二次开发或直接使用&#xff0c;可发布到多端&#xff0c;包括微信小程序、微信公众号、QQ小程序、支付宝小程序、字节跳动小程序、百度小程序、android端、ios端。 在…

2024年受欢迎的主流待办事项提醒软件推荐

随着科技的飞速发展&#xff0c;2024年的今天&#xff0c;众多优秀软件如雨后春笋般涌现&#xff0c;极大地便利了我们的生活与工作。其中&#xff0c;待办事项提醒软件尤为受欢迎&#xff0c;它们不仅能记录日常待办任务&#xff0c;还能在关键时刻提醒我们及时处理&#xff0…

1707jsp电影视频网站系统Myeclipse开发mysql数据库web结构java编程计算机网页项目

一、源码特点 JSP 校园商城派送系统 是一套完善的web设计系统&#xff0c;对理解JSP java编程开发语言有帮助&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统采用web模式&#xff0c;系统主要采用B/S模式开发。开发环境为TOMCAT7.0,Myeclipse8.5开发&#xff0c;数…

使用 SSH 连接 GitHub Action 服务器

前言 Github Actions 是 GitHub 推出的持续集成 (Continuous integration&#xff0c;简称 CI) 服务它提供了整套虚拟服务器环境&#xff0c;基于它可以进行构建、测试、打包、部署项目&#xff0c;如果你的项目是开源项目&#xff0c;可以不限时使用服务器硬件规格&#xff1…

(1)AB_PLC Studio 5000软件与固件版本升级

AB_PLC Studio 5000软件与固件版本升级 1. 软件版本升级2. 固件版本升级1. 软件版本升级 使用将老程序从19版本升级到33版本。 step1:双击程序.ACD文件,打开界面如下。 step2:点击更改Controller,选择我们的新CPU的型号和版本号。点击确定 step3:点击确定,等待。 st…

echart 多表联动value为null时 tooltip 显示问题

两个图表&#xff0c;第一个有tooltip,第二个隐藏掉 两个图表的 series 如下 // ----- chart1 ----series: [{name: Union Ads,type: line,stack: Total,data: [320, 282, 391, 334, null, null, null],},{name: Email,type: line,stack: Total,data: [220, 232, 221, 234, 29…

[YOLOv8] 用YOLOv8实现指针式圆形仪表智能读数(三)

最近研究了一个项目&#xff0c;利用python代码实现指针式圆形仪表的自动读数&#xff0c;并将读数结果进行输出&#xff0c;若需要完整数据集和源代码可以私信。 目录 &#x1f353;&#x1f353;1.yolov8实现圆盘形仪表智能读数 &#x1f64b;&#x1f64b;2.表盘智能读数…

华普检测温湿度监测系统建设方案

一、项目背景 随着医疗行业的蓬勃发展&#xff0c;药品、试剂和血液的储存安全直接关系到患者的健康。根据《药品存储管理规范》、《医疗器械冷链&#xff08;运输、贮存&#xff09;管理指南》、《疫苗储存和运输管理规范》和《血液存储要求》等相关法规&#xff0c;医院药剂…

Satellite Communications Symposium(WCSP2022)

1.Power Allocation for NOMA-Assisted Integrated Satellite-Aerial-Terrestrial Networks with Practical Constraints(具有实际约束的 NOMA 辅助天地一体化网络的功率分配) 摘要&#xff1a;天地一体化网络和非正交多址接入被认为是下一代网络的关键组成部分&#xff0c;为…

流畅的python-学习笔记_前言+第一部分

前言 标准库doctest 测试驱动开发&#xff1a;先写测试&#xff0c;推动开发 obj[key]实际调用实例的__getitem__方法 python数据模型 特殊方法 特殊方法一般自己定义&#xff0c;供py解释器调用&#xff0c;不推荐自己手动调用。 对于py内置类型&#xff0c;调用特殊方…

八股文(C#篇)

C#中的数值类型 堆和栈 值类型的数据被保存在栈&#xff08;stack)上&#xff0c;而引用类型的数据被保存在堆&#xff08;heap&#xff09;上&#xff0c;当值类型作为参数传递给函数时&#xff0c;会将其复制到新的内存空间中&#xff0c;因此在函数中对该值类型的修改不会影…
最新文章