pytorch自定义Dataset实现数据集迭代器
今天实践一个小功能,完成pytorch读取文件夹中的wav语音文件来迭代输出,作为神经网络的数据集dataset。再前期使用tensorflow来读取大量wav文件时发现要自己封装,过程较为复杂,接口也较为乱,转到pytorch后发现框架更加pythonic。
在pytorch中,提供了一种十分方便的数据读取机制,即使用torch.utils.data.Dataset与Dataloader组合得到数据迭代器。在每次训练时,利用这个迭代器输出每一个batch数据,并能在输出时对数据进行相应的预处理或数据增强操作。
torch.utils.data.Dataset: 所有的类都应该是此类的子类(也就是说应该继承该类), 所有的子类都要重写(override) __len()__, __getitem()__ 这两个方法。
__len()__此方法应该提供数据集的大小(容量)。
__getitem()__此方法应该提供支持下标索方式引访问数据集,还可以在__getitem__时对数据进行预处理。
torch.utils.data.DataLoader:对数据集进行包装,可以设置批次大小batch_size、是否打乱数据shuffle等。
接下来就是笔者测试将wav文件夹下的wav文件的梅尔频谱图作为数据集迭代器,为后面大数据集输入神经网络做准备。
1、定义AudioDataset类,重写 __len()__和 __getitem()__方法。__getitem()__方法中通过librosa提取wav文件的melspectrogram作为数据输出。同时pytorchaudio也可以加载wav文件并提取melspectrogram。
2、实例化AudioDataset,操作"./_assets//vhf_wave_split_output"文件夹。
3、输出AudioDataset迭代结果,得到(128, 47)的梅尔频谱图。
4、DataLoader对数据集迭代器AudioDataset进行封装,批次大小为2,并且打乱数据集。DataLoader之后的数据集完成数据特征工程,可以作为神经网络的输入。
共有 0 条评论