脚本之家,脚本语言编程技术及教程分享平台!
分类导航

Python|VBS|Ruby|Lua|perl|VBA|Golang|PowerShell|Erlang|autoit|Dos|bat|shell|

服务器之家 - 脚本之家 - Python - Pytorch搭建SRGAN平台提升图片超分辨率

Pytorch搭建SRGAN平台提升图片超分辨率

2022-12-12 11:08Bubbliiiing Python

这篇文章主要为大家介绍了Pytorch搭建SRGAN平台提升图片超分辨率,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪

网络构建

一、什么是SRGAN

SRGAN出自论文Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network。

如果将SRGAN看作一个黑匣子,其主要的功能就是输入一张低分辨率图片,生成高分辨率图片。

Pytorch搭建SRGAN平台提升图片超分辨率


该文章提到,普通的超分辨率模型训练网络时只用到了均方差作为损失函数,虽然能够获得很高的峰值信噪比,但是恢复出来的图像通常会丢失高频细节。

SRGAN利用感知损失(perceptual loss)和对抗损失(adversarial loss)来提升恢复出的图片的真实感。

二、生成网络的构建

Pytorch搭建SRGAN平台提升图片超分辨率


生成网络的构成如上图所示,生成网络的作用是输入一张低分辨率图片,生成高分辨率图片。:

SRGAN的生成网络由三个部分组成。

1、低分辨率图像进入后会经过一个卷积+RELU函数。

2、然后经过B个残差网络结构,每个残差结构都包含两个卷积+标准化+RELU,还有一个残差边。

3、然后进入上采样部分,在经过两次上采样后,原图的高宽变为原来的4倍,实现分辨率的提升。

前两个部分用于特征提取,第三部分用于提高分辨率。

import math
import torch
from torch import nn
class ResidualBlock(nn.Module):
  def __init__(self, channels):
      super(ResidualBlock, self).__init__()
      self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
      self.bn1 = nn.BatchNorm2d(channels)
      self.prelu = nn.PReLU(channels)
      self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
      self.bn2 = nn.BatchNorm2d(channels)
  def forward(self, x):
      short_cut = x
      x = self.conv1(x)
      x = self.bn1(x)
      x = self.prelu(x)
      x = self.conv2(x)
      x = self.bn2(x)
      return x + short_cut
class UpsampleBLock(nn.Module):
  def __init__(self, in_channels, up_scale):
      super(UpsampleBLock, self).__init__()
      self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1)
      self.pixel_shuffle = nn.PixelShuffle(up_scale)
      self.prelu = nn.PReLU(in_channels)
  def forward(self, x):
      x = self.conv(x)
      x = self.pixel_shuffle(x)
      x = self.prelu(x)
      return x
class Generator(nn.Module):
  def __init__(self, scale_factor, num_residual=16):
      upsample_block_num = int(math.log(scale_factor, 2))
      super(Generator, self).__init__()
      self.block_in = nn.Sequential(
          nn.Conv2d(3, 64, kernel_size=9, padding=4),
          nn.PReLU(64)
      )
      self.blocks = []
      for _ in range(num_residual):
          self.blocks.append(ResidualBlock(64))
      self.blocks = nn.Sequential(*self.blocks)
      self.block_out = nn.Sequential(
          nn.Conv2d(64, 64, kernel_size=3, padding=1),
          nn.BatchNorm2d(64)
      )
      self.upsample = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)]
      self.upsample.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
      self.upsample = nn.Sequential(*self.upsample)
  def forward(self, x):
      x = self.block_in(x)
      short_cut = x
      x = self.blocks(x)
      x = self.block_out(x)
      upsample = self.upsample(x + short_cut)
      return torch.tanh(upsample)

三、判别网络的构建

Pytorch搭建SRGAN平台提升图片超分辨率


判别网络的构成如上图所示:

SRGAN的判别网络由不断重复的 卷积+LeakyRELU和标准化 组成。
对于判断网络来讲,它的目的是判断输入图片的真假,它的输入是图片,输出是判断结果。

判断结果处于0-1之间,利用接近1代表判断为真图片,接近0代表判断为假图片。

判断网络的构建和普通卷积网络差距不大,都是不断的卷积对图片进行下采用,在多次卷积后,最终接一次全连接判断结果。

实现代码如下:

class Discriminator(nn.Module):
  def __init__(self):
      super(Discriminator, self).__init__()
      self.net = nn.Sequential(
          nn.Conv2d(3, 64, kernel_size=3, padding=1),
          nn.LeakyReLU(0.2),
          nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
          nn.BatchNorm2d(64),
          nn.LeakyReLU(0.2),
          nn.Conv2d(64, 128, kernel_size=3, padding=1),
          nn.BatchNorm2d(128),
          nn.LeakyReLU(0.2),
          nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
          nn.BatchNorm2d(128),
          nn.LeakyReLU(0.2),
          nn.Conv2d(128, 256, kernel_size=3, padding=1),
          nn.BatchNorm2d(256),
          nn.LeakyReLU(0.2),
          nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
          nn.BatchNorm2d(256),
          nn.LeakyReLU(0.2),
          nn.Conv2d(256, 512, kernel_size=3, padding=1),
          nn.BatchNorm2d(512),
          nn.LeakyReLU(0.2),
          nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
          nn.BatchNorm2d(512),
          nn.LeakyReLU(0.2),
          nn.AdaptiveAvgPool2d(1),
          nn.Conv2d(512, 1024, kernel_size=1),
          nn.LeakyReLU(0.2),
          nn.Conv2d(1024, 1, kernel_size=1)
      )
  def forward(self, x):
      batch_size = x.size(0)
      return torch.sigmoid(self.net(x).view(batch_size))

 

训练思路

SRGAN的训练可以分为生成器训练和判别器训练:
每一个step中一般先训练判别器,然后训练生成器。

一、判别器的训练

在训练判别器的时候我们希望判别器可以判断输入图片的真伪,因此我们的输入就是真图片、假图片和它们对应的标签。

因此判别器的训练步骤如下:

1、随机选取batch_size个真实高分辨率图片。

2、利用resize后的低分辨率图片,传入到Generator中生成batch_size个虚假高分辨率图片。

3、真实图片的label为1,虚假图片的label为0,将真实图片和虚假图片当作训练集传入到Discriminator中进行训练。

Pytorch搭建SRGAN平台提升图片超分辨率

二、生成器的训练

在训练生成器的时候我们希望生成器可以生成极为真实的假图片。因此我们在训练生成器需要知道判别器认为什么图片是真图片。

因此生成器的训练步骤如下:

1、将低分辨率图像传入生成模型,得到虚假高分辨率图像,将虚假高分辨率图像获得判别结果与1进行对比得到loss。(与1对比的意思是,让生成器根据判别器判别的结果进行训练)。

2、将真实高分辨率图像和虚假高分辨率图像传入VGG网络,获得两个图像的特征,通过这两个图像的特征进行比较获得loss

Pytorch搭建SRGAN平台提升图片超分辨率

 

利用SRGAN生成图片

SRGAN的库整体结构如下:

Pytorch搭建SRGAN平台提升图片超分辨率

一、数据集的准备

在训练前需要准备好数据集,数据集保存在datasets文件夹里面。

Pytorch搭建SRGAN平台提升图片超分辨率

二、数据集的处理

打开txt_annotation.py,默认指向根目录下的datasets。运行txt_annotation.py。
此时生成根目录下面的train_lines.txt。

Pytorch搭建SRGAN平台提升图片超分辨率

三、模型训练

在完成数据集处理后,运行train.py即可开始训练。

Pytorch搭建SRGAN平台提升图片超分辨率


训练过程中,可在results文件夹内查看训练效果:

Pytorch搭建SRGAN平台提升图片超分辨率

以上就是Pytorch搭建SRGAN平台提升图片超分辨率的详细内容,更多关于Pytorch搭建SRGAN图片超分辨率的资料请关注服务器之家其它相关文章!

原文链接:https://blog.csdn.net/weixin_44791964/article/details/121628982

延伸 · 阅读

精彩推荐
  • Pythonpython DataFrame获取行数、列数、索引及第几行第几列的值方法

    python DataFrame获取行数、列数、索引及第几行第几列的值方法

    下面小编就为大家分享一篇python DataFrame获取行数、列数、索引及第几行第几列的值方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看...

    小白九九23632021-01-29
  • Python浅谈Python类的单继承相关知识

    浅谈Python类的单继承相关知识

    本文给大家介绍面向对象三要素之一继承Inheritance的相关知识,通过示例代码给大家介绍了继承、猫类、狗类不用写代码,直接继承了父类的属性和方法,...

    Amae6392021-10-29
  • PythonPython Pandas基础操作详解

    Python Pandas基础操作详解

    这篇文章主要介绍了Python使用Pandas库常见操作,结合实例形式详细分析了Python Pandas模块的功能、原理、数据对象创建、查看、选择等相关操作技巧与注意事...

    冲浪的长颈鹿V5732022-02-19
  • PythonOpenCV半小时掌握基本操作之傅里叶变换

    OpenCV半小时掌握基本操作之傅里叶变换

    这篇文章主要介绍了OpenCV基本操作之傅里叶变换,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下...

    我是小白呀11032021-12-27
  • PythonTensorFlow tensor的拼接实例

    TensorFlow tensor的拼接实例

    今天小编就为大家分享一篇TensorFlow tensor的拼接实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧 ...

    Eric_LH4672020-04-14
  • Pythonpython一行sql太长折成多行并且有多个参数的方法

    python一行sql太长折成多行并且有多个参数的方法

    今天小编就为大家分享一篇python一行sql太长折成多行并且有多个参数的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...

    sy_y3702021-03-19
  • PythonDjango drf请求模块源码解析

    Django drf请求模块源码解析

    APIView中的dispatch是整个请求生命过程的核心方法,包含了请求模块,权限验证,异常模块和响应模块,我们先来介绍请求模块,对Django drf请求模块源码相关...

    Silent丿丶黑羽11732021-11-26
  • PythonPython实现位图分割的效果

    Python实现位图分割的效果

    目前网络上大多为用C++或者Matlab编写实现位图分割,所以本文将使用Python实现位图分割这一效果,代码简单易懂,感兴趣的小伙伴可以关注一下...

    小斌斌_Plus3562022-03-04