pytorch自定义Dataset实现数据集迭代器

    今天实践一个小功能,完成pytorch读取文件夹中的wav语音文件来迭代输出,作为神经网络的数据集dataset。再前期使用tensorflow来读取大量wav文件时发现要自己封装,过程较为复杂,接口也较为乱,转到pytorch后发现框架更加pythonic。

    在pytorch中,提供了一种十分方便的数据读取机制,即使用torch.utils.data.Dataset与Dataloader组合得到数据迭代器。在每次训练时,利用这个迭代器输出每一个batch数据,并能在输出时对数据进行相应的预处理或数据增强操作。

    torch.utils.data.Dataset: 所有的类都应该是此类的子类(也就是说应该继承该类), 所有的子类都要重写(override) __len()__, __getitem()__ 这两个方法。

    __len()__此方法应该提供数据集的大小(容量)。

    __getitem()__此方法应该提供支持下标索方式引访问数据集,还可以在__getitem__时对数据进行预处理。

    torch.utils.data.DataLoader:对数据集进行包装,可以设置批次大小batch_size、是否打乱数据shuffle等。

    接下来就是笔者测试将wav文件夹下的wav文件的梅尔频谱图作为数据集迭代器,为后面大数据集输入神经网络做准备。

    1、定义AudioDataset类,重写 __len()__和  __getitem()__方法。__getitem()__方法中通过librosa提取wav文件的melspectrogram作为数据输出。同时pytorchaudio也可以加载wav文件并提取melspectrogram。

    2、实例化AudioDataset,操作"./_assets//vhf_wave_split_output"文件夹。

    3、输出AudioDataset迭代结果,得到(128, 47)的梅尔频谱图。

4、DataLoader对数据集迭代器AudioDataset进行封装,批次大小为2,并且打乱数据集。DataLoader之后的数据集完成数据特征工程,可以作为神经网络的输入。

版权声明:
作者:congcong
链接:https://www.techfm.club/p/47995.html
来源:TechFM
文章版权归作者所有,未经允许请勿转载。

THE END
分享
二维码
< <上一篇
下一篇>>