自定义数据集读取_CodingPark编程公园

基础知识

def getitem( )只有在 读内容 的时候才会走它的里面

在这里插入图片描述
pytorch虽然简单易用,但是其高度的封装使得初学者难以理解数据是如何读入的。

对于自己的任务,很可能pytorch提供的数据读取机制难以完全满足任务要求。

所以我们需要学习如何使用pytorch提供的torch.utils.data.Dataset来自定义数据读取流程。

样例

class Dataset(object):
    """此处省略"""
    def __getitem__(self, index):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

可以看出我们只需要实现“getitem(self, index)”方法即可,这个__getitem__()方法的index参数看起来有些让人困惑,查阅python官方文档发现这个方法是object的方法:
在这里插入图片描述

所以这个index的值就是索引值,比如我们想要数据集中第二张图片索引就是2


自定义SingeClassDataset

我在自定义Dataset时共分为3个步骤,目的是以后能够好的进行功能扩展:

  • 初始化图像路径模块__init__()
    由于目的是继承Dataset类,所以应该采用__init__()来存储数据路径,在存储之前要先检查输入的路径是不是正确的路径,防止图片读取失败。

  • 图像转Tensor模块_read_convert_image()
    读取图像采用python内置的PIL提供的Image类型,这也是pytorch支持的核心类型。读取Image类型的图片后可以直接通过torchvision提供的变换,转为pytorch需要的Tensor类型。

  • 数据索引模块__getitem__()
    根据图像的存储方式不同可以采用多种读取策略,常见的情况有两种:图像在一个文件夹中图像在多个文件夹中

下面实现的__getitem__()方法针对于所有的图像在一个文件夹内的情况。

在这里插入图片描述

定义了__getitem__()方法的类就可以通过“索引”获得数据,下面来看一下数据是正确的读入了,可视化采用的是matplotlib

由于之前pytorch的版本还必须实现"len()"方法用于返回数据集的长度,所以下面的代码实现了它,但是当前版本的pytorch已经不再强制实现这个函数了

完整代码

# -*- encoding: utf-8 -*-
"""
@File    :   Dataset_diy.py    
@Contact :   ag@team-ag.club
@License :   (C)Copyright 2019-2020, CodingPark

@Modify Time      @Author    @Version    @Desciption
------------      -------    --------    -----------
2020-10-27 13:03   AG         1.0        自定义数据集读取
"""

'''
Explain:
【1】我们需要学习如何使用pytorch提供的torch.utils.data.Dataset来自定义数据读取流程
【2】我们需要实现“__getitem__(self, index)”方法,其中index的值就是索引值,比如我们想要数据集中第二张图片索引就是2
'''

'''
我在自定义Dataset时共分为3个步骤,目的是以后能够好的进行功能扩展:
初始化图像路径模块__init__()
图像转Tensor模块_read_convert_image()
数据索引模块__getitem__()

'''

from torch.utils.data import Dataset
import os
from PIL import Image
import torchvision.transforms as T


class SingleClassDataset(Dataset):
    """
    This Dataset only work for a folder that contains one class image!!!
    """

    def __init__(self, file_path):
        # 保证输入的是正确的路径
        if not os.path.isdir(file_path):
            raise ValueError("input file_path is not a dir")
        self.file_path = file_path
        # 获取路径下所有的图片名称,必须保证路径内没有图片以外的数据
        self.image_list = os.listdir(file_path)
        # 将PIL的Image转为Tensor
        self.transforms = T.ToTensor()

    def __getitem__(self, index):
        # 根据index获取图片完整路径
        image_path = os.path.join(self.file_path, self.image_list[index])
        # 都图片并转为Tensor
        image = self._read_convert_image(image_path)
        return image

    def _read_convert_image(self, image_name):
        image = Image.open(image_name)
        image = self.transforms(image).float()
        return image

    def __len__(self):
        return len(self.image_list)


import matplotlib.pyplot as plt
MyDataset = SingleClassDataset(file_path="pikachu/")
plt.figure()
for i in range(16):
    plt.subplot(4, 4, i+1)
    plt.imshow(MyDataset[i].numpy().transpose(1, 2, 0))

plt.show()



结果展示
在这里插入图片描述

在这里插入图片描述

评论将由博主筛选后显示,对所有人可见 | 还能输入1000个字符 “速评一下”
©️2020 CSDN 皮肤主题: 鲸 设计师:meimeiellie 返回首页
实付 79.90元
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值