stone

教你花式读数据(keras&&pytorch)
对于一个DPer来说,数据的预处理是一件充满恶意的事情,正确的数据读取方式可以让程序跑得飞起,而且对于一些特殊的模...
扫描右侧二维码阅读全文
27
2018/11

教你花式读数据(keras&&pytorch)

对于一个DPer来说,数据的预处理是一件充满恶意的事情,正确的数据读取方式可以让程序跑得飞起,而且对于一些特殊的模型,多输入多输出的模型,则更为棘手。很多人跟着教程写了几个深度学习的例子,用框架提供的几个读取数据的函数可以很轻易获取到那些经典的数据集,但是大多数情况下,数据集都是自己,做的,而且比较大,肯定不能直接加载到内存中。

所以!!

正确掌握数据的读取,真的可以为所欲为的!

最近在keras和pytorch之间游走,下面从这两个框架的角度来看看怎么高效地读取数据。

首先说明一下基本的数据集储存结构,一般为三级目录,如下:

/dataset-root
├── 1
├── 2
├── 3
├── 4
└── 5

每一个样本按照对应的类别放到文件中,同时建议在每一个类别或者root目录下用一个list文件储存样本的名字,有了这个list文件,我们可以动态选择样本或者做什么其他的处理,这都很舒服的。

读取数据的时候,要注意怎么异步加载数据,提高读取的性能。

Keras

keras中的基本类ImageDataGenerator就不细说了,主要看看Sequence这个大杀器。

Sequence主要重写getitem和len两个方法,getitem直接返回一个batch,结构为x,y,如果有多个输入输出,则用list来包装一下,不要用tuple,这个不能识别。这里有两个坑,一个是getitem获取的id是batch的id,不是sample的id,另外len返回的是总共有多少个batch,要向上取整,这个会影响迭代的次数的。

另外,自己定义的Sequence没有自动transform的,要自己做归一化和数据增强的,总得来看是没有pytorch的数据接口来得方便,不知道keras后面会不会优化一下这些问题。

配合Sequence类,使用fit_generater,其中use_multiprocessing,据说是可以开启异步加载数据的,具体性能没有测过,但是GPU利用率没有pytorch高。

class MyDataset(keras.utils.Sequence):
    def __init__(self, data_path, mode='train', batch_size=32):
        self.x = []
        self.y = []
        self.batch_size = batch_size
        self.mode = mode
        with open(data_path, 'r') as f:
            for line in f:
                xx, yy = line.strip().split(',')
                self.x.append(xx)
                self.y.append(int(yy))

        self.size = len(self.y)

    def __getitem__(self, index):
        begin_idx = index*self.batch_size
        end_idx = (index+1)*self.batch_size
        img_names = self.x[begin_idx:end_idx]
        labels = self.y[begin_idx:end_idx]

        wsi_imgs = []
        gaussian_imgs = []
        sobel_imgs = []
        for img_name in img_names:
            a, b, c = self.__get_single_image(img_name)
            wsi_imgs.append(a)
            gaussian_imgs.append(b)
            sobel_imgs.append(c)

        return [np.array(wsi_imgs), np.array(gaussian_imgs), np.array(sobel_imgs)], np.array(labels)

    def __get_single_image(self, img_name):
        pass

    def __len__(self):
        # return int(len(self.x))
        return int(np.ceil(len(self.x) / float(self.batch_size)))
        
train_dataset = MyDataset(
    "/path/a", mode='train', batch_size=20)
valid_dataset = MyDataset(
    "/path/b", mode='valid', batch_size=20)

callbacks = [
    keras.callbacks.EarlyStopping(patience=30),
    keras.callbacks.ReduceLROnPlateau(patience=10, min_lr=1e-7, verbose=1),
    keras.callbacks.TensorBoard(log_dir="./log/base", update_freq='batch'),
    # keras.callbacks.ModelCheckpoint("./saved_model/base.h5")
]

model.fit_generator(train_dataset,
                    validation_data=valid_dataset,
                    workers=10,
                    epochs=500, use_multiprocessing=True, callbacks=callbacks)

Pytorch

强力安利pytorch,整体用下来非常的舒服。特别是数据读取部分和模型定义部分。

自带的图片读取如下

transform = transforms.Compose(
    [transforms.Resize((300, 300)),
        transforms.ToTensor(),
        transforms.Normalize((0.64639061, 0.56044774, 0.61909978), (0.1491973, 0.17535066, 0.12751725))])

train_data = datasets.ImageFolder(
    "/path/a", transform=transform)
train_loader = DataLoader(train_data, batch_size=64,
                          shuffle=True, pin_memory=True, num_workers=10)

for inputs,labels in train_loader:
    inputs = inputs.to('cuda:0',non_blocking=False)
    labels = labels.to('cuda:0',non_blocking=False)

值得注意的是pin_memorynon_blocking,这两个参数可以开启数据的异步加载,效果显著,基本可以将GPU的性能压榨干。pin_memory是在显存中预留空间。

同样,对于一些复杂的情况,也是要自己定义数据读取的类,pytorch是利用torch.utils.data.Dataset来实现的,例子如下。关键也是两个方法getitem和len。和keras不同的是每次只返回一个sample,index就是每一个sample的id。数据归一化和数据增强也是要自己做的。

class MyDataset(Dataset):
    def __init__(self, data_path, mode='train'):
        self.x = []
        self.y = []
        self.mode = mode
        with open(data_path, 'r') as f:
            for line in f:
                xx, yy = line.strip().split(',')
                self.x.append(xx)
                self.y.append(int(yy))

        self.size = len(self.y)

    def __getitem__(self, index):
        img_name = self.x[index]
        wsi_img_path = os.path.join(base_path % ('wsi', self.mode), img_name)
        gaussian_img_path = os.path.join(
            base_path % ('gaussian', self.mode), img_name)
        sobel_img_path = os.path.join(
            base_path % ('sobel', self.mode), img_name)

        #..........
        # 省略若干行
        #..........
        return wsi_img, gaussian_img, sobel_img, self.y[index]

    def __len__(self):
        return self.size

valid_data = MyDataset(
    "/path/a", mode='valid')
valid_loader = DataLoader(valid_data, batch_size=128,
                          shuffle=True, pin_memory=True, num_workers=10)

summary

基本上,掌握了自定义数据加载方式的方法,就天下无敌了,随便你怎么加载数据,加载多少数据,要对数据做什么奇怪的操作都可以。

正确掌握数据的读取,真的可以为所欲为的!

Last modification:November 27th, 2018 at 09:50 pm
If you think my article is useful to you, please feel free to appreciate

Leave a Comment