Coding-torch-数据集-Datasets

torch内置datasets的使用方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch
import torchvision
import torchvision.transforms as transforms

# Preprocessing pipeline: convert each PIL image to a tensor, then map
# every RGB channel from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Download (if not cached under ./data) and load the CIFAR-10 splits.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Wrap the datasets in batch loaders; only the training data is shuffled.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)

适用于图片分类的datasets(数据放在不同的文件夹下表示不同类别)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import glob  # file-path pattern matching
from torchvision import transforms  # image transformation module
from torch.utils import data  # PyTorch data utilities
from PIL import Image  # PIL image loading

# Shared preprocessing pipeline.
# FIX: bound to its own name instead of rebinding (and thereby shadowing)
# the imported ``transforms`` module, which the Compose is built from.
preprocess = transforms.Compose([
    transforms.ToTensor(),                    # PIL image -> float tensor in [0, 1]
    transforms.Resize((256, 256)),            # resize to 256x256
    transforms.Normalize(mean=0.5, std=0.5),  # scale values to roughly [-1, 1]
])


class my_dataset(data.Dataset):
    """Paired image/annotation dataset.

    Item ``i`` is the tuple ``(image_i, annotation_image_i)``, both run
    through the same preprocessing pipeline.
    """

    def __init__(self, imgs_path, annos_path):
        self.imgs_path = imgs_path    # list of image file paths
        self.annos_path = annos_path  # list of matching annotation file paths

    def __getitem__(self, index):
        # Load and preprocess the input image.
        pil_img = Image.open(self.imgs_path[index])
        pil_img = preprocess(pil_img)

        # Load and preprocess the corresponding annotation image.
        anno_img = Image.open(self.annos_path[index])
        pil_anno = preprocess(anno_img)

        return pil_img, pil_anno

    def __len__(self):
        # FIX: required by DataLoader; was missing in the original.
        # Assumes both path lists have the same length.
        return len(self.imgs_path)


# FIX: the original instantiated an undefined ``CustomDataset`` with
# undefined path variables (NameError).  Build the datasets from sorted,
# paired file lists instead — sorting keeps images and labels aligned.
train_imgs = sorted(glob.glob('facade/train_picture/*.png'))
train_annos = sorted(glob.glob('facade/train_label/*.jpg'))
val_imgs = sorted(glob.glob('facade/test_picture/*.png'))
val_annos = sorted(glob.glob('facade/test_label/*.jpg'))

train_dataset = my_dataset(train_imgs, train_annos)
val_dataset = my_dataset(val_imgs, val_annos)

适用于图片分割,目标检测的datasets(数据和标签都是图像)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import glob  # file-path pattern matching
from torchvision import transforms  # image transformation module
from torch.utils import data  # PyTorch data utilities
from PIL import Image  # PIL image loading

# Shared preprocessing pipeline.
# FIX: bound to its own name instead of rebinding (and thereby shadowing)
# the imported ``transforms`` module, which the Compose is built from.
preprocess = transforms.Compose([
    transforms.ToTensor(),                    # PIL image -> float tensor in [0, 1]
    transforms.Resize((256, 256)),            # resize to 256x256
    transforms.Normalize(mean=0.5, std=0.5),  # scale values to roughly [-1, 1]
])


class my_dataset(data.Dataset):
    """Paired image/label-image dataset for segmentation-style tasks.

    Item ``i`` is the tuple ``(image_i, label_image_i)``, both run
    through the same preprocessing pipeline.
    """

    def __init__(self, imgs_path, annos_path):
        self.imgs_path = imgs_path    # list of image file paths
        self.annos_path = annos_path  # list of matching label file paths

    def __getitem__(self, index):
        # Load and preprocess the input image.
        pil_img = Image.open(self.imgs_path[index])
        pil_img = preprocess(pil_img)

        # Load and preprocess the corresponding label image.
        anno_img = Image.open(self.annos_path[index])
        pil_anno = preprocess(anno_img)

        return pil_img, pil_anno

    def __len__(self):
        # DataLoader needs the dataset length; assumes both path lists
        # are the same length.
        return len(self.imgs_path)


# Collect training image/label paths.
imgs_path = glob.glob('facade/train_picture/*.png')
label_path = glob.glob('facade/train_label/*.jpg')

# Collect test image/label paths.
test_imgs_path = glob.glob('facade/test_picture/*.png')
test_label_path = glob.glob('facade/test_label/*.jpg')

# Sort both lists so images and labels stay aligned by filename.
imgs_path = sorted(imgs_path)
label_path = sorted(label_path)

test_imgs_path = sorted(test_imgs_path)
test_label_path = sorted(test_label_path)

# Build the datasets and their batch loaders; only training is shuffled.
train_dataset = my_dataset(imgs_path, label_path)
test_dataset = my_dataset(test_imgs_path, test_label_path)
train_loader = data.DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader = data.DataLoader(test_dataset, batch_size=4, shuffle=False)

子数据集划分

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from torch.utils.data import Subset

# Load each full dataset, then carve it into two interleaved subsets:
# even-numbered samples in one, odd-numbered samples in the other.
train_data = dsets.ImageFolder(root='data_self/train', transform=transform)
test_data = dsets.ImageFolder(root='data_self/test', transform=transform)

# Even/odd index lists for the training split.
train_indices1 = list(range(len(train_data)))[0::2]
train_indices2 = list(range(len(train_data)))[1::2]

# Training subset 1 (even indices) and subset 2 (odd indices).
train_data1 = Subset(train_data, train_indices1)
train_data2 = Subset(train_data, train_indices2)

# Even/odd index lists for the test split.
test_indices1 = list(range(len(test_data)))[0::2]
test_indices2 = list(range(len(test_data)))[1::2]

# Test subset 1 (even indices) and subset 2 (odd indices).
test_data1 = Subset(test_data, test_indices1)
test_data2 = Subset(test_data, test_indices2)

读入一个文件夹下的图片制作数据集(不分类)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# Custom dataset over a single flat directory of unlabelled images.
class CustomImageDataset(Dataset):
    """Dataset yielding every ``.jpg`` image in ``root_dir`` (no labels).

    Each item is the (optionally transformed) RGB image; there is no
    target, so this suits generative / unsupervised pipelines.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # FIX: sort the listing — os.listdir returns entries in arbitrary,
        # platform-dependent order, which made sample indices
        # nondeterministic across runs.
        self.images = sorted(f for f in os.listdir(root_dir) if f.endswith('.jpg'))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.images[idx])
        image = Image.open(img_path).convert("RGB")  # force 3-channel RGB

        if self.transform:
            image = self.transform(image)

        return image

# Preprocessing: resize to the model's input resolution, convert the PIL
# image to a tensor, and normalize every channel to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((opt.img_size, opt.img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Dataset over a flat directory of images (no labels).
dataset = CustomImageDataset(
    root_dir='/home/yeshixin/work/work-generate/data',
    transform=transform,
)

# Loader that shuffles each epoch and reads with opt.n_cpu worker threads.
dataloader = DataLoader(
    dataset,
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.n_cpu,
)