Python实现EM算法解GMM问题

Python实现EM算法解GMM问题,第1张

1. 高斯混合模型与EM算法的推导

详细推导过程见:https://www.jianshu.com/p/c990abda8059
核心算法如下:

2. Python实现 2.1 单变量GMM
# 设置参数
N = 300
K = 2
epoches = 20


# 构造训练集
def init(N):
    mu1 = 1
    sigma1 = 0.5
    x1 = np.sort(np.random.normal(mu1, sigma1, N // 2))
    mu2 = 5
    sigma2 = 1.0
    x2 = np.sort(np.random.normal(mu2, sigma2, N // 2))
    x = [x1, x2]
    return x


def show_data(x,arg):
    plt.title('pdf of GMM')
    plt.xlabel("x")
    plt.ylabel("pdf")
    plt.scatter(x[0], [0 for i in range(N // 2)], s=1, c='r')
    y1 = np.exp(-1 * (x[0] - 1) ** 2 / (2 * (0.5 ** 2))) / (math.sqrt(2 * np.pi) * 0.5)
    plt.plot(x[0], y1, label='Sample,mu=1,sigma=0.5', c='r')

    plt.scatter(x[1], [0 for i in range(N // 2)], s=1, c='g')
    y2 = np.exp(-1 * (x[1] - 5) ** 2 / 2) / (math.sqrt(2 * np.pi))
    plt.plot(x[1], y2, label='Sample,mu=5,sigma=1', c='g')
    mu,sigma = arg[0],arg[1]
    color = ['black','b']
    for i in range(2):
        plt.plot(x[i],np.exp(-1 * (x[i] - mu[i]) ** 2 / (2 * (sigma[i] ** 2))) / (math.sqrt(2 * np.pi) * sigma[i]),\
                 label='GMM,mu={:.2f},sigma={:.2f}'.format(mu[i],sigma[i]), c = color[i])
    plt.legend()
    plt.grid()
    plt.show()

def GMMs(x, K, epoches):
    w = np.full((K,), 1 / K)  # 每个模型的权重
    mu = np.random.rand(2, )
    sigma = np.random.rand(2, )
    data = np.concatenate([x[0], x[1]])
    for epoch in range(epoches):
        # E-step
        R = np.zeros((N, K))
        for i in range(K):
            gaussPdf = norm.pdf(data, mu[i], sigma[i])
            R[:, i] = w[i] * gaussPdf
        R /= np.sum(R, axis=1, keepdims=True)
        # M-step
        for i in range(K):
            mu[i] = np.sum(R[:, i]*data, axis=0) / np.sum(R[:, i],axis=0)
            sigma[i] = math.sqrt(np.sum(np.dot(R[:,i],(data-mu[i])**2))/np.sum(R[:,i],axis=0))
            w[i] = np.sum(R[:,1],axis=0)/N
    if mu[0] > mu[1]:
        mu[0],mu[1] = mu[1],mu[0]
        sigma[0],sigma[1] = sigma[1],sigma[0]
    return mu,sigma

def show_model(arg,K):
    mu,sigma = arg[0],arg[1]
    for i in range(K):
        print("第{:}个子模型的mu为{:.2f},sigma为{:.2f}"\
              .format(i+1,mu[i],sigma[i]))

def main():
    x = init(N)
    arg = GMMs(x, K, epoches)
    show_data(x,arg)
    show_model(arg,K)

if __name__ == '__main__':
    main()
2.2 多变量GMM 

github: UCAS-MachineLearning-homework/GMM.py at master · 0809zheng/UCAS-MachineLearning-homework · GitHub

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/738060.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-04-28
下一篇 2022-04-28

发表评论

登录后才能评论

评论列表(0条)

保存