从零搭建音乐识别系统(三)音乐分类模型

从零搭建音乐识别系统(三)音乐分类模型,第1张

从零搭建音乐识别系统(三)音乐分类模型


        经过上一篇的介绍,我们已经获得了AudioSet开源数据集,并且将数据集分为两类:音乐类和非音乐类。将每段10秒的音频提取得到[64, 1001]大小的特征矩阵。接下来就是使用这些特征矩阵来训练一个二分类模型,用来识别音频片段是否是音乐片段。

        参考自成熟的图像分类网络模型,我们采用Pytorch框架训练ResNet分类模型,分类模型的训练主要包括以下几个核心步骤:

1、训练数据预处理

        由于我们的每条训练数据是保存在npy文件里面的,所以我们需要自定义一个预处理npy数据文件的方法。

        npy数据文件读取:

class MusicNpyDataset(Dataset):
    def __init__(self, filenames, labels, transform):
        self.filenames = filenames
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        image = np.load(self.filenames[idx]).astype(np.float32)
        # 分类模型使用[64, 251]大小的特征矩阵进行训练
        image = image[:, ::4]
        image = self.transform(image)
        return image, self.labels[idx]

        数据归一化、数据增强:

from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
import torchvision


class UpDownFlipTransform(object):
    """
    数据增强,对输入矩阵进行上下翻转
    """
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):
        random_prob = np.random.uniform(0, 1)
        if 1.0 - random_prob < self.p:
            img = img[::-1, :].copy()
        return img


class LeftRightFlipTransform(object):
    """
    数据增强,对输入矩阵进行左右翻转
    """
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):
        random_prob = np.random.uniform(0, 1)
        if 1.0 - random_prob < self.p:
            img = img[:, ::-1].copy()
        return img


class ZeronormalizeTransform(object):
    """
    数据增强,对输入矩阵进行零均值归一化
    """
    def __call__(self, img):
        img_mean = img.mean()
        img = img - img_mean
        return img


train_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        LeftRightFlipTransform(p=0.5),
        ZeronormalizeTransform()
    ]
)
val_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        ZeronormalizeTransform()
    ]
)

2、训练数据加载

        定义好了数据预处理和数据增强方法之后,使用DataLoader加载训练数据。

def fetch_data(data_path, ratio, batch_size, train_transforms, val_transforms, num_workers=4, seed=100):
    random.seed(seed)

    # 使用torchvision解析磁盘上按类别存放好的npy文件
    dataset = torchvision.datasets.DatasetFolder(data_path, loader=np.load, extensions=("npy",))
    classes = dataset.classes
    character = [[] for _ in range(len(classes))]
    random.shuffle(dataset.samples)

    for x, y in dataset.samples:
        character[y].append(x)

    # 查看每个类别的总样本数量
    for i in range(len(classes)):
        print("{}: {}".format(classes[i], len(character[i])))

    # 根据ratio比例,划分训练集、验证集、测试集
    train_inputs, val_inputs, test_inputs = [], [], []
    train_labels, val_labels, test_labels = [], [], []
    for i, data in enumerate(character):
        num_sample_train = int(len(data) * ratio[0])
        num_sample_val = int(len(data) * ratio[1])
        num_val_index = num_sample_train + num_sample_val

        # 训练集
        for x in data[:num_sample_train]:
            train_inputs.append(str(x))
            train_labels.append(i)
        # 验证集
        for x in data[num_sample_train:num_val_index]:
            val_inputs.append(str(x))
            val_labels.append(i)
        # 测试集
        for x in data[num_val_index:]:
            test_inputs.append(str(x))
            test_labels.append(i)

    train_dataloader = DataLoader(MusicNpyDataset(train_inputs, train_labels, train_transforms), batch_size=batch_size,
                                  num_workers=num_workers, drop_last=True, shuffle=True)
    val_dataloader = DataLoader(MusicNpyDataset(val_inputs, val_labels, val_transforms), batch_size=batch_size,
                                drop_last=False, shuffle=False, num_workers=num_workers)
    test_dataloader = DataLoader(MusicNpyDataset(test_inputs, test_labels, val_transforms), batch_size=batch_size,
                                 drop_last=False, shuffle=False, num_workers=num_workers)
    return train_dataloader, val_dataloader, test_dataloader, classes

3、网络模型构建

        数据处理好了之后,下面就是要定义网络模型,将数据输入到模型中进行训练,我们使用ResNet18作为分类模型,由于原始的ResNet18模型是处理三通道大小为224 x 224的图像数据的,而我们的输入数据是一通道的64 x 251的矩阵,所以要对ResNet18网络模型进行一些改造,主要有以下三个改造点:

        1、把ResNet18的第一层卷积层更换成可以处理一通道输入的卷积层

        2、把ResNet18模型最后的平均池化层更换更自适应平局池化层(这一步仅在Pytorch版本过低的时候需要),从而使得网络模型能够处理64 x 251大小的输入数据

        3、修改网络模型的输出层,将输出层的类别数修改为2

        ResNet18网络代码如下:

def build_model(classes, pretrained=True):
    model = torchvision.models.resnet18(pretrained=pretrained)
    model.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    model.fc = nn.Linear(in_features=512, out_features=len(classes))

    nn.init.kaiming_normal_(model.conv1.weight, mode="fan_in", nonlinearity="relu")
    nn.init.kaiming_normal_(model.fc.weight, mode="fan_in", nonlinearity="relu")
    return model

4、定义验证集、测试集性能评价指标

        有了网络模型之后,我们还需要定义模型训练的评价指标,这里我们采用准确率来衡量。

def evaluate_accuracy(data_loader, model):
    acc_sum = 0.0
    n = 0

    for X, y in data_loader:
        X = X.cuda()
        y = y.cuda()
        y_pred = model(X)
        acc_sum += (y_pred.argmax(dim=1) == y).float().mean().item()
        n += 1
    return acc_sum / n

5、误差分析

        上述流程基本上涵盖了训练一个完整分类器的所有步骤,但是呢?吴恩达大佬不止一次告诫我们,在训练模型时要注重误差分析,要熟悉了解自己的数据,所以我们这里也对模型在训练和验证过程中的误差进行分析,这里分析过程主要是分析模型在训练集和验证集上预测出错的数据样本,下面的代码会记录训练集和验证集中预测出错的样本,也就是误差的来源,方便我们定位问题。

def error_analysis(X, y, y_pred, classes):
    """
    误差分析
    :param X:
    :param y:
    :param y_pred:
    :param classes:
    :return:
    """
    # 计算预测出错的样本
    failure_index = y_pred.argmax(dim=1) != y
    failure_data = X[failure_index]
    failure_label = y_pred.argmax(dim=1)[failure_index]
    true_label = y[failure_index]

    failure_classes = []
    true_classes = []
    for label1, label2 in zip(failure_label.cpu().numpy(), true_label.cpu().label()):
        failure_classes.append(classes[label1])
        true_classes.append(classes[label2])
    print(failure_classes)
    print(true_classes)

        本篇用到的相关方法具体可以参考使用pytorch开发一套完整的图像分类finetune代码(包括模型训练、测试、误差分析),后续也会将整理好的完整代码发出来。

欢迎分享,转载请注明来源:内存溢出

原文地址: https://outofmemory.cn/zaji/5700582.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-17
下一篇 2022-12-17

发表评论

登录后才能评论

评论列表(0条)

保存