完整的推导过程请参见：https://www.jianshu.com/p/c990abda8059
核心算法如下：
# Experiment settings: total number of samples to draw, number of mixture
# components to fit, and number of EM iterations to run.
N, K, epoches = 300, 2, 20
# 构造训练集
def init(N):
mu1 = 1
sigma1 = 0.5
x1 = np.sort(np.random.normal(mu1, sigma1, N // 2))
mu2 = 5
sigma2 = 1.0
x2 = np.sort(np.random.normal(mu2, sigma2, N // 2))
x = [x1, x2]
return x
def show_data(x, arg):
    """Plot the sample groups, their true densities and the fitted components.

    Args:
        x: Two sorted 1-D sample arrays, as returned by ``init``.
        arg: ``(mu, sigma)`` of the fitted components, as returned by ``GMMs``;
            component ``i`` is drawn over the samples of group ``x[i]``.
    """
    def gauss_pdf(t, m, s):
        # Normal density N(t | m, s**2), evaluated element-wise over t.
        return np.exp(-1 * (t - m) ** 2 / (2 * (s ** 2))) / (math.sqrt(2 * np.pi) * s)

    plt.title('pdf of GMM')
    plt.xlabel("x")
    plt.ylabel("pdf")
    # Ground-truth parameters used by ``init``: (mu, sigma, colour, label).
    truth = [(1, 0.5, 'r', 'Sample,mu=1,sigma=0.5'),
             (5, 1, 'g', 'Sample,mu=5,sigma=1')]
    for samples, (m, s, c, lab) in zip(x, truth):
        # Rug of the samples on the x-axis. Sizing the zeros from the group
        # itself (not the global N) keeps scatter's x/y lengths equal.
        plt.scatter(samples, np.zeros(len(samples)), s=1, c=c)
        plt.plot(samples, gauss_pdf(samples, m, s), label=lab, c=c)
    mu, sigma = arg[0], arg[1]
    color = ['black', 'b']
    for i in range(2):
        plt.plot(x[i], gauss_pdf(x[i], mu[i], sigma[i]),
                 label='GMM,mu={:.2f},sigma={:.2f}'.format(mu[i], sigma[i]), c=color[i])
    plt.legend()
    plt.grid()
    plt.show()
def GMMs(x, K, epoches):
    """Fit a one-dimensional K-component Gaussian mixture with EM.

    Args:
        x: Sequence of 1-D sample arrays. They are concatenated into a single
            data set; the grouping itself is not used by the fit.
        K: Number of mixture components.
        epoches: Number of EM iterations to run (no convergence test).

    Returns:
        ``(mu, sigma)``: float arrays of length ``K`` holding the component
        means and standard deviations, ordered by ascending mean.

    Raises:
        ValueError: If ``x`` contains no samples.
    """
    data = np.concatenate([np.asarray(part, dtype=float) for part in x])
    n = data.size  # real sample count; the global N need not equal len(data)
    if n == 0:
        raise ValueError("GMMs needs at least one sample")
    # Data-driven start: means at evenly spaced quantiles of the data and every
    # sigma equal to the overall spread. A start drawn from [0, 1) can lie so far
    # from the samples that every pdf underflows to 0 and the E-step yields NaN.
    w = np.full(K, 1.0 / K)  # mixing weight of each component
    mu = np.quantile(data, (np.arange(K) + 0.5) / K)
    sigma = np.full(K, data.std())
    for _ in range(epoches):
        # E-step in log space: log w_k + log N(x_n | mu_k, sigma_k), shape (n, K).
        # Subtracting each row's maximum before exponentiating keeps one entry
        # per row equal to 1, so the row normalisation never divides by zero.
        log_resp = np.log(w) + norm.logpdf(data[:, None], mu, sigma)
        log_resp -= log_resp.max(axis=1, keepdims=True)
        resp = np.exp(log_resp)
        resp /= resp.sum(axis=1, keepdims=True)
        # M-step: responsibility-weighted statistics for every component at once.
        nk = resp.sum(axis=0)  # effective number of samples per component
        mu = resp.T @ data / nk
        # Variance uses the freshly updated means.
        sigma = np.sqrt(np.sum(resp * (data[:, None] - mu) ** 2, axis=0) / nk)
        w = nk / n  # each component's weight comes from its own responsibilities
    # Report components in ascending order of mean (works for any K).
    order = np.argsort(mu)
    return mu[order], sigma[order]
def show_model(arg, K):
    """Print the fitted mean and standard deviation of the first K sub-models.

    Args:
        arg: Indexable pair whose item 0 holds the means and item 1 the
            standard deviations (the tuple returned by ``GMMs``).
        K: Number of sub-models to report; lines are numbered from 1.
    """
    means, stds = arg[0], arg[1]
    template = "第{:}个子模型的mu为{:.2f},sigma为{:.2f}"
    for k in range(1, K + 1):
        print(template.format(k, means[k - 1], stds[k - 1]))
def main():
    """Run the demo end to end: sample data, fit the mixture, plot, report."""
    samples = init(N)
    fitted = GMMs(samples, K, epoches)
    show_data(samples, fitted)
    show_model(fitted, K)


# Only run the demo when executed as a script, not on import.
if __name__ == '__main__':
    main()
2.2 多变量GMM
GitHub 源码：0809zheng/UCAS-MachineLearning-homework 仓库（master 分支）中的 GMM.py
欢迎分享，转载请注明来源：内存溢出
评论列表(0条)