Coding-SRGAN: A Super-Resolution Generative Adversarial Network for RGB Image Data

Abstract

srgan.py trains the network.
inference.py runs inference with a trained model file.
The image data lives under /home/myself/work/work-generate/data, with all images in a single folder.
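
Assuming the flat layout described above, the data directory looks like this (the file names are hypothetical):

/home/myself/work/work-generate/data
├── 0001.jpg
├── 0002.jpg
└── ...

The training script, srgan.py, follows.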


import argparse
import os
import numpy as np
import math
import itertools
import sys

import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid

from torch.utils.data import DataLoader
from torch.autograd import Variable

from models import *
from datasets import *

import torch.nn as nn
import torch.nn.functional as F
import torch

os.makedirs("/home/myself/work/work-generate/srgan/images", exist_ok=True)
os.makedirs("/home/myself/work/work-generate/srgan/saved_models", exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--dataset_name", type=str, default="data", help="name of the dataset")
parser.add_argument("--batch_size", type=int, default=4, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--decay_epoch", type=int, default=100, help="epoch from which to start lr decay")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--hr_height", type=int, default=512, help="high res. image height")
parser.add_argument("--hr_width", type=int, default=512, help="high res. image width")
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=100, help="interval between saving image samples")
parser.add_argument("--checkpoint_interval", type=int, default=20, help="interval between model checkpoints")
opt = parser.parse_args()
print(opt)

cuda = torch.cuda.is_available()

hr_shape = (opt.hr_height, opt.hr_width)

# Initialize generator and discriminator
generator = GeneratorResNet()
discriminator = Discriminator(input_shape=(opt.channels, *hr_shape))
feature_extractor = FeatureExtractor()

# Set feature extractor to inference mode
feature_extractor.eval()

# Losses
criterion_GAN = torch.nn.MSELoss()
criterion_content = torch.nn.L1Loss()

if cuda:
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    feature_extractor = feature_extractor.cuda()
    criterion_GAN = criterion_GAN.cuda()
    criterion_content = criterion_content.cuda()

if opt.epoch != 0:
    # Load pretrained models for the starting epoch
    generator.load_state_dict(torch.load("/home/myself/work/work-generate/srgan/saved_models/generator_%d.pth" % opt.epoch))
    discriminator.load_state_dict(torch.load("/home/myself/work/work-generate/srgan/saved_models/discriminator_%d.pth" % opt.epoch))

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

dataloader = DataLoader(
    ImageDataset("/home/myself/work/work-generate/%s" % opt.dataset_name, hr_shape=hr_shape),
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.n_cpu,
)

# ----------
# Training
# ----------
def main():
    for epoch in range(opt.epoch, opt.n_epochs):
        for i, imgs in enumerate(dataloader):

            # Configure model input
            imgs_lr = Variable(imgs["lr"].type(Tensor))
            imgs_hr = Variable(imgs["hr"].type(Tensor))

            # Adversarial ground truths
            valid = Variable(Tensor(np.ones((imgs_lr.size(0), *discriminator.output_shape))), requires_grad=False)
            fake = Variable(Tensor(np.zeros((imgs_lr.size(0), *discriminator.output_shape))), requires_grad=False)

            # ------------------
            #  Train Generator
            # ------------------

            optimizer_G.zero_grad()

            # Generate a high resolution image from low resolution input
            gen_hr = generator(imgs_lr)

            # Adversarial loss
            loss_GAN = criterion_GAN(discriminator(gen_hr), valid)

            # Content loss in VGG feature space
            gen_features = feature_extractor(gen_hr)
            real_features = feature_extractor(imgs_hr)
            loss_content = criterion_content(gen_features, real_features.detach())

            # Total loss: perceptual loss = content loss + 1e-3 * adversarial loss
            loss_G = loss_content + 1e-3 * loss_GAN

            loss_G.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()

            # Loss of real and fake images
            loss_real = criterion_GAN(discriminator(imgs_hr), valid)
            loss_fake = criterion_GAN(discriminator(gen_hr.detach()), fake)

            # Total loss
            loss_D = (loss_real + loss_fake) / 2

            loss_D.backward()
            optimizer_D.step()

            # --------------
            #  Log Progress
            # --------------

            sys.stdout.write(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]\n"
                % (epoch, opt.n_epochs, i, len(dataloader), loss_D.item(), loss_G.item())
            )

            batches_done = epoch * len(dataloader) + i
            if batches_done % opt.sample_interval == 0:
                # Save image grid with upsampled inputs and SRGAN outputs
                imgs_lr = nn.functional.interpolate(imgs_lr, scale_factor=4)
                gen_hr = make_grid(gen_hr, nrow=1, normalize=True)
                imgs_lr = make_grid(imgs_lr, nrow=1, normalize=True)
                img_grid = torch.cat((imgs_lr, gen_hr), -1)
                save_image(img_grid, "/home/myself/work/work-generate/srgan/images/%d.png" % batches_done, normalize=False)

        if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
            # Save model checkpoints at the end of the epoch
            torch.save(generator.state_dict(), "/home/myself/work/work-generate/srgan/saved_models/generator_%d.pth" % epoch)
            torch.save(discriminator.state_dict(), "/home/myself/work/work-generate/srgan/saved_models/discriminator_%d.pth" % epoch)

if __name__ == '__main__':
    main()
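
srgan.py also imports Discriminator and FeatureExtractor from models, which is not shown here (GeneratorResNet from the same module appears in inference.py below). A minimal sketch of what models.py plausibly contains, in the style of the common PyTorch-GAN SRGAN implementation; the layer widths, the VGG19 cut-off at layer 18, and the PatchGAN output shape are assumptions, not the author's confirmed code:

import torch.nn as nn
from torchvision.models import vgg19


class FeatureExtractor(nn.Module):
    # VGG19-based feature extractor used for the content loss
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        vgg19_model = vgg19(pretrained=True)
        # Assumption: keep the first 18 layers of vgg19.features
        self.feature_extractor = nn.Sequential(*list(vgg19_model.features.children())[:18])

    def forward(self, img):
        return self.feature_extractor(img)


class Discriminator(nn.Module):
    # PatchGAN-style discriminator; output_shape sizes the valid/fake targets in srgan.py
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()
        self.input_shape = input_shape
        in_channels, in_height, in_width = self.input_shape
        # Four stride-2 stages each halve the spatial resolution
        patch_h, patch_w = in_height // 2 ** 4, in_width // 2 ** 4
        self.output_shape = (1, patch_h, patch_w)

        def discriminator_block(in_filters, out_filters, first_block=False):
            layers = [nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1)]
            if not first_block:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            layers.append(nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        layers = []
        in_filters = in_channels
        for i, out_filters in enumerate([64, 128, 256, 512]):
            layers.extend(discriminator_block(in_filters, out_filters, first_block=(i == 0)))
            in_filters = out_filters
        layers.append(nn.Conv2d(out_filters, 1, kernel_size=3, stride=1, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)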

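The ImageDataset that srgan.py imports from datasets is also not shown. A minimal sketch under the same assumption: a PyTorch-GAN-style loader that pairs each HR image with a 4x-downscaled LR copy (the ImageNet statistics match the normalization used in inference.py below):

import glob

import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import Dataset

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])


class ImageDataset(Dataset):
    def __init__(self, root, hr_shape):
        hr_height, hr_width = hr_shape
        # LR input: the HR image downscaled by the generator's 4x factor
        self.lr_transform = transforms.Compose([
            transforms.Resize((hr_height // 4, hr_width // 4), Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        # HR target at full resolution
        self.hr_transform = transforms.Compose([
            transforms.Resize((hr_height, hr_width), Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        # All images sit in one flat folder, as the abstract states
        self.files = sorted(glob.glob(root + "/*.*"))

    def __getitem__(self, index):
        img = Image.open(self.files[index % len(self.files)]).convert("RGB")
        return {"lr": self.lr_transform(img), "hr": self.hr_transform(img)}

    def __len__(self):
        return len(self.files)

The inference script, inference.py, follows.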
# Import the required libraries and modules
import torch.nn as nn  # PyTorch neural network module
import torch.nn.functional as F  # extra functionality such as activation functions
import torch  # core PyTorch library
from torchvision.models import vgg19  # VGG19 model; not used directly here, likely kept for reference
import math  # math utilities
from PIL import Image  # image handling
import torchvision.transforms as transforms  # image transformation utilities
from torchvision.utils import save_image  # image-saving utility
import numpy as np  # numerical processing

# Residual block used to build the generator's residual backbone
class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        # A sequence of convolution, batch normalization, and activation layers
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_features, momentum=0.8),  # batch-norm momentum set to 0.8
            nn.PReLU(),  # PReLU activation with default parameters
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_features, momentum=0.8),
        )

    def forward(self, x):
        # Residual connection: the input x plus the output of the conv block
        return x + self.conv_block(x)

# Generator network built on a residual architecture
class GeneratorResNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, n_residual_blocks=16):
        super(GeneratorResNet, self).__init__()

        # First convolutional layer followed by PReLU activation
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels, 64, kernel_size=9, stride=1, padding=4), nn.PReLU())

        # n_residual_blocks residual blocks
        res_blocks = [ResidualBlock(64) for _ in range(n_residual_blocks)]
        self.res_blocks = nn.Sequential(*res_blocks)  # wrap all residual blocks in one Sequential container

        # Second convolutional layer with batch normalization, after the residual blocks
        self.conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(64, momentum=0.8))

        # Upsampling layers, using PixelShuffle
        upsampling = [
            nn.Conv2d(64, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.PixelShuffle(upscale_factor=2),  # PixelShuffle upsampling by a factor of 2
            nn.PReLU(),  # activation
        ] * 2  # repeat the stage twice: two rounds of 2x upsampling, 4x overall
        self.upsampling = nn.Sequential(*upsampling)

        # Output layer: final convolution with Tanh activation
        self.conv3 = nn.Sequential(nn.Conv2d(64, out_channels, kernel_size=9, stride=1, padding=4), nn.Tanh())

    def forward(self, x):
        # Forward pass through the residual backbone and the upsampling stages
        out1 = self.conv1(x)
        out = self.res_blocks(out1)
        out2 = self.conv2(out)
        out = torch.add(out1, out2)  # global skip connection, standard in SRGAN: add the first conv's features back in
        out = self.upsampling(out)
        out = self.conv3(out)
        return out

# Unnormalization: with the default mean/std of 0.5, this maps the generator's
# Tanh output from [-1, 1] back to the displayable [0, 1] range
def unnormalize(tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
    tensor = tensor.clone()  # clone first to avoid modifying the input in place
    for i in range(tensor.size(1)):  # iterate over channels
        tensor[:, i, :, :] = tensor[:, i, :, :] * std[i] + mean[i]  # undo the normalization
    return tensor

# Main entry point
if __name__ == '__main__':
    # Mean and standard deviation used to normalize the input image (ImageNet statistics)
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    # Preprocessing pipeline: convert to tensor, then normalize
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # Initialize the generator and load the pretrained weights
    # (map_location lets weights saved on a GPU load on a CPU-only machine)
    g = GeneratorResNet()
    g.load_state_dict(torch.load('generator_100.pth', map_location='cpu'))
    g.eval()  # put the BatchNorm layers into inference mode

    # Load and preprocess the image
    input = Image.open("C:\\Users\\hahag\\OneDrive\\WORK\\work-generate\\6.jpg").convert("RGB")
    input = transform(input)  # apply the transforms

    # Add a batch dimension
    input = input.unsqueeze(0)

    # Run the generator without tracking gradients
    with torch.no_grad():
        output = g(input)

    # Map the Tanh output from [-1, 1] back to [0, 1]
    output = unnormalize(output)

    # Save the generated image
    save_image(output, 'output.png', normalize=False)
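
Because the upsampling stack repeats its conv + PixelShuffle(x2) stage twice, the generator enlarges the input by 4x in each spatial dimension. A quick sanity check, runnable with the GeneratorResNet defined above (the 128x128 input size is arbitrary):

import torch

g = GeneratorResNet()
g.eval()
with torch.no_grad():
    y = g(torch.randn(1, 3, 128, 128))  # random image-shaped tensor
print(y.shape)  # expected: torch.Size([1, 3, 512, 512])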