To predict all ten digits, we now modify the code in mnist.py from the label-loading step onward.
def encode_digit(Y, digit):
    # Return an (n, 1) column that is 1 where the label equals digit, 0 elsewhere
    encoded_Y = np.zeros_like(Y)
    n_labels = Y.shape[0]
    for i in range(n_labels):
        if Y[i] == digit:
            encoded_Y[i][0] = 1
    return encoded_Y
TRAINING_LABELS = load_labels("../data/mnist/train-labels-idx1-ubyte.gz")
TEST_LABELS = load_labels("../data/mnist/t10k-labels-idx1-ubyte.gz")

# Build one binary label column per digit (one-vs-rest)
Y_train = []
Y_test = []
for digit in range(10):
    Y_train.append(encode_digit(TRAINING_LABELS, digit))
    Y_test.append(encode_digit(TEST_LABELS, digit))
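As a quick sanity check, the toy example below (the labels are made up purely for illustration) shows what encode_digit produces for a single target digit:

import numpy as np

# Hypothetical toy labels, shaped (4, 1) like the output of load_labels
toy_labels = np.array([[5], [0], [5], [8]])
print(encode_digit(toy_labels, 5))
# [[1]
#  [0]
#  [1]
#  [0]]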
Next, modify the code of the classification algorithm as follows:
# iterations is the number of training iterations
def train(X, Y, iterations, lr):
    w = np.zeros((X.shape[1], 1))  # initialize the weights w to zero
    for i in range(iterations):
        # print("Iteration %4d => Loss: %.20f" % (i, loss(X, Y, w)))
        w -= gradient(X, Y, w) * lr
    return w
def test(X, Y, w, digit):
    total_examples = X.shape[0]
    correct_results = np.sum(classify(X, w) == Y)
    success_percent = correct_results * 100 / total_examples
    print("Correct classifications for digit %d: %d/%d (%.2f%%)" %
          (digit, correct_results, total_examples, success_percent))
for digit in range(10):
    # Train one binary classifier for this digit
    w = train(mi.X_train, mi.Y_train[digit], iterations=100, lr=1e-5)
    # Test the model for this digit
    test(mi.X_test, mi.Y_test[digit], w, digit)
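The loop above relies on the classify() and gradient() helpers defined earlier in the series (together with mi.X_train and mi.X_test from mnist.py). As a reminder, here is a minimal sketch of what those helpers look like for this logistic-regression classifier; it assumes the earlier code and is only meant to make this section self-contained:

import numpy as np

def sigmoid(z):
    # Logistic function: squashes any value into the range (0, 1)
    return 1 / (1 + np.exp(-z))

def forward(X, w):
    # Weighted sum of the inputs, passed through the sigmoid
    return sigmoid(np.matmul(X, w))

def classify(X, w):
    # Round the prediction to 0 or 1
    return np.round(forward(X, w))

def gradient(X, Y, w):
    # Gradient of the log loss with respect to the weights
    return np.matmul(X.T, (forward(X, w) - Y)) / X.shape[0]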
Running the code prints the accuracy for each digit. We find that the predictions for the digit 8 are clearly worse than for the other digits, which shows that our training algorithm still has plenty of room for improvement, but the results are acceptable for now.