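PyTorch provides convenient APIs for many commonly used public datasets, but to train a neural network on your own data you need to define a custom dataset. PyTorch provides a few classes that make this easy.
A custom Dataset must subclass torch.utils.data.Dataset and override two of its methods, `__getitem__` and `__len__`: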
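- from torch.utils.data import Dataset
-
-
- # Template only: `data`, `label` and `total` are placeholders.
- class AudioDataset(Dataset):
-     def __init__(self):
-         """Initialize the dataset (for example, collect file paths)."""
-         pass
-
-     def __getitem__(self, item):
-         """Define how one sample is read; return its data and label."""
-         return data, label
-
-     def __len__(self):
-         """Return the total number of samples in the dataset."""
-         return total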
Note: a Dataset is only an abstraction over the data; each call to `__getitem__` returns a single sample.
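To make the "one sample per call" behaviour concrete, here is a minimal sketch with synthetic data. `ToyDataset` and its sizes are made up for illustration and are not part of the original example.

```python
import torch
from torch.utils.data import Dataset


class ToyDataset(Dataset):
    """Hypothetical dataset: 10 samples, each a vector of 4 features."""

    def __init__(self, num_samples=10, num_features=4):
        values = torch.arange(num_samples * num_features, dtype=torch.float32)
        self.data = values.reshape(num_samples, num_features)
        self.labels = torch.arange(num_samples) % 2  # alternating 0/1 labels

    def __getitem__(self, item):
        # One call returns exactly one (sample, label) pair
        return self.data[item], self.labels[item]

    def __len__(self):
        return len(self.data)


toy_set = ToyDataset()
sample, label = toy_set[3]  # indexing calls __getitem__(3)
print(len(toy_set), sample.shape, label.item())
```

Indexing with `toy_set[3]` goes through `__getitem__`, and `len(toy_set)` goes through `__len__`; batching is left entirely to the DataLoader.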
Example:
Directory structure
Goal: read the audio data in the p225 folder
- import fnmatch
- import os
-
- import librosa
- import numpy as np
- from torch.utils.data import Dataset
-
-
- class AudioDataset(Dataset):
-     def __init__(self, data_folder, sr=16000, dimension=8192):
-         self.data_folder = data_folder
-         self.sr = sr
-         self.dim = dimension
-
-         # Collect the paths of all .wav files under data_folder
-         self.wav_list = []
-         for root, dirnames, filenames in os.walk(data_folder):
-             for filename in fnmatch.filter(filenames, "*.wav"):  # keep only names matching "*.wav"
-                 self.wav_list.append(os.path.join(root, filename))
-
-     def __getitem__(self, item):
-         # Load one audio file and return a fixed-length frame from it
-         filename = self.wav_list[item]
-         wb_wav, _ = librosa.load(filename, sr=self.sr)
-
-         # Take one frame of self.dim samples
-         if len(wb_wav) >= self.dim:
-             max_audio_start = len(wb_wav) - self.dim
-             # randint's upper bound is exclusive; +1 also avoids an error when max_audio_start is 0
-             audio_start = np.random.randint(0, max_audio_start + 1)
-             wb_wav = wb_wav[audio_start: audio_start + self.dim]
-         else:
-             wb_wav = np.pad(wb_wav, (0, self.dim - len(wb_wav)), "constant")
-
-         return wb_wav, filename
-
-     def __len__(self):
-         # Total number of audio files
-         return len(self.wav_list)
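Note on the frame-selection branch in `__getitem__`: audio files differ in length, so if the raw waveforms were returned directly, the samples could not be stacked into a batch and a dimension-mismatch error would be raised. The code therefore takes a single fixed-length frame (8192 samples) from each file, which clearly does not use all of the available speech data.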
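The crop-or-pad logic can be tried in isolation, without any audio files. The sketch below uses synthetic arrays; `crop_or_pad` is a hypothetical helper name, and it uses NumPy's `default_rng` instead of the `np.random.randint` call in the class above.

```python
import numpy as np


def crop_or_pad(wav, dim, rng=None):
    """Return exactly `dim` samples: a random crop if long enough, else zero-pad the tail."""
    if rng is None:
        rng = np.random.default_rng()
    if len(wav) >= dim:
        # The upper bound of integers() is exclusive, hence the +1
        start = rng.integers(0, len(wav) - dim + 1)
        return wav[start:start + dim]
    return np.pad(wav, (0, dim - len(wav)), mode="constant")


long_wav = np.arange(20000, dtype=np.float32)  # longer than one frame
short_wav = np.ones(5000, dtype=np.float32)    # shorter than one frame
print(crop_or_pad(long_wav, 8192).shape)   # (8192,)
print(crop_or_pad(short_wav, 8192).shape)  # (8192,)
```

Because every sample comes out with the same length, the samples can later be stacked into a batch of shape `(batch_size, 8192)`.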
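Instantiate the Dataset object
- train_set = AudioDataset("./p225", sr=16000)
If you want to read data in batches, skip ahead to step 3 (wrapping the dataset in a DataLoader). If you want to read samples one at a time, iterate over the dataset directly: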
- # Instantiate the AudioDataset object
- train_set = AudioDataset("./p225", sr=16000)
-
- for i, data in enumerate(train_set):
-     wb_wav, filename = data
-     print(i, wb_wav.shape, filename)
-
-     if i == 3:
-         break
- # 0 (8192,) ./p225\p225_001.wav
- # 1 (8192,) ./p225\p225_002.wav
- # 2 (8192,) ./p225\p225_003.wav
- # 3 (8192,) ./p225\p225_004.wav
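To read data in batches, wrap the dataset in a DataLoader.
Why use a DataLoader?
PyTorch's DataLoader wraps a Dataset and takes care of batching, shuffling, and multi-process loading, which is much more convenient than iterating over the dataset by hand.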
- DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False)
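Parameters: `dataset` is the Dataset to load from; `batch_size` is the number of samples per batch; `shuffle` reshuffles the data every epoch; `sampler` defines a custom strategy for drawing samples; `num_workers` is the number of subprocesses used for loading (0 loads in the main process); `collate_fn` merges a list of samples into a batch; `pin_memory` copies tensors into pinned memory before returning them; `drop_last` drops the final incomplete batch.
Returns: a data loader, which is an iterable over batches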
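As a quick illustration of `batch_size` and `drop_last`, here is a small sketch on synthetic tensors. The 10-sample `TensorDataset` is an assumption made for the demo, not the audio data used in this post.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 10 synthetic samples with 2 features each, plus integer labels
features = torch.arange(20, dtype=torch.float32).reshape(10, 2)
labels = torch.arange(10)
toy_set = TensorDataset(features, labels)

keep_last = DataLoader(toy_set, batch_size=4, shuffle=False, drop_last=False)
drop_last = DataLoader(toy_set, batch_size=4, shuffle=False, drop_last=True)

print(len(keep_last), len(drop_last))  # 3 2
print([len(y) for _, y in keep_last])  # [4, 4, 2]
```

With 10 samples and a batch size of 4, the last batch holds only 2 samples; `drop_last=True` discards it so that every batch has the same size.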
Example:
- # Instantiate the AudioDataset object
- train_set = AudioDataset("./p225", sr=16000)
- train_loader = DataLoader(train_set, batch_size=8, shuffle=True)
-
- for (i, data) in enumerate(train_loader):
-     wav_data, wav_name = data
-     print(wav_data.shape)  # torch.Size([8, 8192])
-     print(i, wav_name)
-     # ('./p225\\p225_293.wav', './p225\\p225_156.wav', './p225\\p225_277.wav', './p225\\p225_210.wav',
-     # './p225\\p225_126.wav', './p225\\p225_021.wav', './p225\\p225_257.wav', './p225\\p225_192.wav')
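Let's work through a few examples to digest this:
Example 1 is the one used throughout this post, simply merged into a single script.
Directory structure
Goal: read the audio data in the p225 folder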
- import fnmatch
- import os
- import librosa
- import numpy as np
- from torch.utils.data import Dataset
- from torch.utils.data import DataLoader
-
-
- class AudioDataset(Dataset):
-     def __init__(self, data_folder, sr=16000, dimension=8192):
-         self.data_folder = data_folder
-         self.sr = sr
-         self.dim = dimension
-
-         # Collect the paths of all .wav files under data_folder
-         self.wav_list = []
-         for root, dirnames, filenames in os.walk(data_folder):
-             for filename in fnmatch.filter(filenames, "*.wav"):  # keep only names matching "*.wav"
-                 self.wav_list.append(os.path.join(root, filename))
-
-     def __getitem__(self, item):
-         # Load one audio file and return a fixed-length frame from it
-         filename = self.wav_list[item]
-         print(filename)
-         wb_wav, _ = librosa.load(filename, sr=self.sr)
-
-         # Take one frame of self.dim samples
-         if len(wb_wav) >= self.dim:
-             max_audio_start = len(wb_wav) - self.dim
-             # randint's upper bound is exclusive; +1 also avoids an error when max_audio_start is 0
-             audio_start = np.random.randint(0, max_audio_start + 1)
-             wb_wav = wb_wav[audio_start: audio_start + self.dim]
-         else:
-             wb_wav = np.pad(wb_wav, (0, self.dim - len(wb_wav)), "constant")
-
-         return wb_wav, filename
-
-     def __len__(self):
-         # Total number of audio files
-         return len(self.wav_list)
-
-
- train_set = AudioDataset("./p225", sr=16000)
- train_loader = DataLoader(train_set, batch_size=8, shuffle=True)
-
-
- for (i, data) in enumerate(train_loader):
-     wav_data, wav_name = data
-     print(wav_data.shape)  # torch.Size([8, 8192])
-     print(i, wav_name)
-     # ('./p225\\p225_293.wav', './p225\\p225_156.wav', './p225\\p225_277.wav', './p225\\p225_210.wav',
-     # './p225\\p225_126.wav', './p225\\p225_021.wav', './p225\\p225_257.wav', './p225\\p225_192.wav')
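Notes:
- The frame-selection branch in `__getitem__`: audio files differ in length, so if the raw waveforms were returned directly, the samples could not be stacked into a batch and a dimension-mismatch error would be raised. The code therefore takes a single fixed-length frame from each file, which clearly does not use all of the available speech data.
- The `print(wav_data.shape)` line in the loop: `__getitem__` returns NumPy arrays and never converts them to tensors, yet the batch is printed as `torch.Size([8, 8192])`, that is, a tensor. The DataLoader's default `collate_fn` converts NumPy arrays to tensors when it stacks samples into a batch. This is worth keeping in mind.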
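The NumPy-to-tensor conversion can be checked on a tiny synthetic dataset. In this sketch, `NumpyDataset` and its sample names are hypothetical stand-ins for the audio frames and file names above.

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class NumpyDataset(Dataset):
    """Hypothetical dataset whose samples are plain NumPy arrays plus a string name."""

    def __init__(self, num_samples=6, dim=8):
        self.data = np.zeros((num_samples, dim), dtype=np.float32)

    def __getitem__(self, item):
        # No torch conversion here: a NumPy array and a str are returned
        return self.data[item], "sample_%d" % item

    def __len__(self):
        return len(self.data)


loader = DataLoader(NumpyDataset(), batch_size=3)
batch, names = next(iter(loader))
print(type(batch), batch.shape)  # the arrays were stacked into one tensor of shape [3, 8]
print(names)                     # the strings are grouped, not converted
```

The default collate function stacks the arrays into a single tensor and leaves the strings as a plain sequence, which matches the tuple of file names printed in the example above.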
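Compared with Example 1, Example 2 is the important one. In practice we would not read a single frame from one audio file and then move on to the next. An audio clip usually contains many frames, and what we want is to read `batch_size` frames in order: read the first audio file, and if it yields a full batch there is no need to read the second file; if it falls short, read the second file to fill the batch.
Here are a few suggestions:
Suggestion 1:
If your model takes not raw audio but features that need fairly complex, time-consuming processing, I recommend this approach.
First read each audio file in order, split the speech into frames with a window length of 8192 and a hop of 4096, and concatenate the frames. Save the resulting array of shape (frame_num, frame_len, 1) to an h5 file, then read it with the torch.utils.data.Dataset and torch.utils.data.DataLoader described above.
Implementation:
Step 1: create an H5_generation script that reads the speech, extracts the features, and saves them to an h5 file. (Adapt the feature extraction to your own research field; mine extracts narrowband and wideband features for speech bandwidth extension. What matters is the idea.)
H5_generation.py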
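A minimal sketch of the framing step described above, assuming a window of 8192 samples and a hop of 4096. `frame_signal` is a hypothetical helper, the "recordings" are random arrays rather than real audio, and writing the result to an h5 file (for example with h5py's `create_dataset`) is left out.

```python
import numpy as np


def frame_signal(wav, frame_len=8192, hop=4096):
    """Split a 1-D signal into overlapping frames of shape (frame_num, frame_len, 1)."""
    if len(wav) < frame_len:
        # Zero-pad short signals so they still produce one frame
        wav = np.pad(wav, (0, frame_len - len(wav)), mode="constant")
    frame_num = 1 + (len(wav) - frame_len) // hop
    frames = np.stack([wav[i * hop:i * hop + frame_len] for i in range(frame_num)])
    return frames[..., np.newaxis]


# Two synthetic "recordings" stand in for real audio files
recordings = [np.random.randn(20000).astype(np.float32),
              np.random.randn(30000).astype(np.float32)]

# Frame each recording, then concatenate all frames along the first axis
data = np.concatenate([frame_signal(wav) for wav in recordings], axis=0)
print(data.shape)  # (9, 8192, 1)
```

The 20000-sample signal yields 3 frames and the 30000-sample signal yields 6, so the stacked array has 9 frames. Trailing samples that do not fill a whole frame are dropped in this sketch.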
Step 2: read the data from the h5 file through a Dataset
- import numpy as np
- from torch.utils.data import Dataset
- from torch.utils.data import DataLoader
- import h5py
-
- def load_h5(h5_path):
-     # Load the training data
-     with h5py.File(h5_path, 'r') as hf:
-         print('List of arrays in input file:', hf.keys())
-         X = np.array(hf.get('data'), dtype=np.float32)
-         Y = np.array(hf.get('label'), dtype=np.float32)
-     return X, Y
-
-
- class AudioDataset(Dataset):
-     """Dataset that reads frames from an h5 file"""
-     def __init__(self, data_folder):
-         self.data_folder = data_folder
-         self.X, self.Y = load_h5(data_folder)  # (3392, 8192, 1)
-
-     def __getitem__(self, item):
-         # Return one audio frame and its label
-         X = self.X[item]
-         Y = self.Y[item]
-
-         return X, Y
-
-     def __len__(self):
-         return len(self.X)
-
-
- train_set = AudioDataset("./speaker225_resample_train.h5")
- train_loader = DataLoader(train_set, batch_size=64, shuffle=True, drop_last=True)
-
-
- for (i, wav_data) in enumerate(train_loader):
-     X, Y = wav_data
-     print(i, X.shape)
-     # 0 torch.Size([64, 8192, 1])
-     # 1 torch.Size([64, 8192, 1])
-     # ...
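I tried generating the h5 file inside `__init__`, but memory usage blew up for reasons I could not work out, so I kept the two steps separate.
Suggestion 2:
If your model's input is the raw speech waveform, or the feature processing is very simple, I strongly recommend doing everything in one step instead of generating an h5 file: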
- import os
- import time
- import numpy as np
- from torch.utils.data import Dataset, DataLoader
- import librosa
-
-
- class AudioData(Dataset):
-     def __init__(self, dimension=8192, stride=4096, fs=16000, scale=2, data_path="./train"):
-         super(AudioData, self).__init__()
-         self.dimension = dimension
-         self.stride = stride
-         self.scale = scale
-         self.fs = fs
-         self.wavs_path = [os.path.join(data_path, wav_name) for wav_name in os.listdir(data_path)]
-         self.wb_list = []
-
-         self.split()
-
-     def get_nb(self, wb_wav):
-         # Downsample 16000 -> 8000, then upsample 8000 -> 16000; the high band is not restored
-         nb_wav = librosa.resample(wb_wav, orig_sr=self.fs, target_sr=self.fs // self.scale)
-         nb_wav = librosa.resample(nb_wav, orig_sr=self.fs // self.scale, target_sr=self.fs)
-         return nb_wav
-
-     def split(self):
-         for wav_path in self.wavs_path:
-             wav, _ = librosa.load(path=wav_path, sr=self.fs)
-             wav_length = len(wav)  # length of the audio in samples
-             if wav_length < self.stride:  # shorter than 4096 samples: skip
-                 continue
-             if wav_length < self.dimension:  # shorter than 8192 samples: zero-pad
-                 diffe = self.dimension - wav_length
-                 wb_wav = np.pad(wav, (0, diffe), mode="constant")
-                 self.wb_list.append(wb_wav)
-             else:  # at least 8192 samples: split into overlapping frames
-                 start_index = 0
-                 while start_index + self.dimension <= wav_length:
-                     wb_frame = wav[start_index:start_index + self.dimension]
-                     self.wb_list.append(wb_frame)
-                     start_index += self.stride
-
-     def __len__(self):
-         return len(self.wb_list)
-
-     def __getitem__(self, index):
-         # The narrowband frame is computed on demand from the wideband frame
-         return self.wb_list[index], self.get_nb(self.wb_list[index])
-
-
- if __name__ == "__main__":
-     start_time = time.time()
-     data = AudioData()
-
-     print(len(data))  # 3420
-     train_loader = DataLoader(data, batch_size=32, shuffle=True, drop_last=True)
-     end_time = time.time()
-     print("Elapsed: %d s" % (end_time - start_time))  # 24 s
-
-     for wb, nb in train_loader:
-         print("wideband", wb.shape)  # torch.Size([32, 8192])
-         print("narrowband", nb.shape)  # torch.Size([32, 8192])
-         break
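Inefficient version of Suggestion 2:
Having read Suggestion 2, you can skip this version; it is included to get you thinking about how to load data more efficiently.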
- # Author:凌逆战
- # -*- coding:utf-8 -*-
- """
- Purpose:
- """
- import os
- import time
-
- import numpy as np
- from torch.utils.data import Dataset, DataLoader
- import librosa
-
-
- class AudioData(Dataset):
-     def __init__(self, dimension=8192, stride=4096, fs=16000, scale=2, data_path="./train"):
-         super(AudioData, self).__init__()
-         self.dimension = dimension
-         self.stride = stride
-         self.scale = scale
-         self.fs = fs
-         self.wavs_path = [os.path.join(data_path, wav_name) for wav_name in os.listdir(data_path)]
-         self.wb_list = []
-         self.nb_list = []
-
-         self.preprocess()
-
-     def get_nb(self, wb_wav):
-         # Downsample 16000 -> 8000, then upsample 8000 -> 16000; the high band is not restored
-         nb_wav = librosa.resample(wb_wav, orig_sr=self.fs, target_sr=self.fs // self.scale)
-         nb_wav = librosa.resample(nb_wav, orig_sr=self.fs // self.scale, target_sr=self.fs)
-         return nb_wav
-
-     def preprocess(self):
-         for wav_path in self.wavs_path:
-             wav, _ = librosa.load(path=wav_path, sr=self.fs)
-             wav_length = len(wav)  # length of the audio in samples
-             if wav_length < self.stride:  # shorter than 4096 samples: skip
-                 continue
-             if wav_length < self.dimension:  # shorter than 8192 samples: zero-pad
-                 diffe = self.dimension - wav_length
-
-                 wb_wav = np.pad(wav, (0, diffe), mode="constant")
-                 nb_wav = self.get_nb(wb_wav)
-
-                 self.wb_list.append(wb_wav)
-                 self.nb_list.append(nb_wav)
-             else:  # at least 8192 samples: split into overlapping frames
-                 start_index = 0
-                 while start_index + self.dimension <= wav_length:
-                     wb_frame = wav[start_index:start_index + self.dimension]
-                     nb_frame = self.get_nb(wb_frame)
-
-                     self.wb_list.append(wb_frame)
-                     self.nb_list.append(nb_frame)
-                     start_index += self.stride
-
-     def __len__(self):
-         return len(self.wb_list)
-
-     def __iter__(self):
-         for index in range(len(self.wb_list)):
-             yield self.wb_list[index], self.nb_list[index]
-
-     def __getitem__(self, index):
-         return self.wb_list[index], self.nb_list[index]
-
-
- if __name__ == "__main__":
-     start_time = time.time()
-     data = AudioData()
-
-     print(len(data))  # 3420
-     train_loader = DataLoader(data, batch_size=32, shuffle=True, drop_last=True)
-     end_time = time.time()
-     print("Elapsed: %d s" % (end_time - start_time))  # 61 s
-
-     for wb, nb in train_loader:
-         print("wideband", wb.shape)
-         print("narrowband", nb.shape)
-         break
This approach took 61 seconds to load the data. Why is left for you to think about; I do not recommend it.