当时看老师留的作业的时候发现要求读取历史数据,但网上有没找到,我自己找了手册啥的,查到了几个参数,就分享下,希望各位大佬不要看不上哈...
那就先把程序从零实现:
import numpy as np import tensorflow as tf import matplotlib.pyplot as plt import keras from tensorflow.keras import * gpus = tf.config.experimental.list_physical_devices('GPU') tf.config.experimental.set_visible_devices(gpus[0], True) file_data = np.load(r'E:/mnist.npz') x_train= file_data['x_train'] y_train= file_data['y_train'] x_test= file_data['x_test'] y_test= file_data['y_test']
先读取一些包,然后设置自己的gpu,如果预先下载好数据集的话可以放在自己固定放程序的文件夹内,修改上面的绝对路径就好。
下面是参数的初始化,将输入变成0.0-1.0之间浮点数,便于运算和加速收敛(当然,计算性能现今提升了好多,这个小程序没啥影响。)当时我们同学有人整个程序运行完会报错,提示Dim不匹配,如果不匹配的话就把下面注释的取消掉,reshape一下就解决了。
X_train = tf.cast(x_train, dtype=tf.float32)/ 255.0 X_test = tf.cast(x_test, dtype=tf.float32)/ 255.0 # X_train = X_train.reshaper((60000, 28, 28, 1)) # X_test = X_train.reshaper((10000, 28, 28, 1)) print (X_train.shape)
下面就是固定的那些格式了,没什么东西,都API什么的,主要是Model.fit返回一个参数,用来可视化,他是保存历史数据的,ACC和Loss,程序里面的那些超参数都可以修改,当然,那些优化方法我用的是'adam',loss是‘sparse_categorical_crossentropy’,同理都可以修改的。
model = tf.keras.Sequential([ tf.keras.layers.Conv2D(16, kernel_size= (3,3), padding = 'same', activation= tf.nn.relu, input_shape= (28, 28, 1)), tf.keras.layers.MaxPool2D(pool_size = (2,2)), tf.keras.layers.Conv2D(16, kernel_size= (3,3), padding = 'same',activation= tf.nn.relu), tf.keras.layers.MaxPool2D(pool_size= (2,2)), tf.keras.layers.Flatten(), tf.keras.layers.Dense(128, activation= 'relu'), tf.keras.layers.Dense(10, activation= 'softmax') ]) print (model.summary()) model.compile (optimizer= 'adam', loss='sparse_categorical_crossentropy', metrics=["sparse_categorical_accuracy"]) history = model.fit(X_train, y_train, batch_size = 100, epochs =5, validation_split=0.20, shuffle=True) model.evaluate(X_test, y_test,verbose= 2)
History返回了4个参数,分别是['loss', 'sparse_categorical_accuracy', 'val_loss', 'val_sparse_categorical_accuracy'],就是训练集的损失和精确度,和测试集的损失和精度。
下面是用matplotlib可视化,这个就没什么东西了,显示一下训练的结果。
plt.figure() plt.subplot(2,2,1) plt.title('Loss') plt.plot(history.history['loss'],label='Validation bit_error') plt.subplot(2,2,2) plt.title('Sparse Categorical Accuracy') plt.plot(history.history['sparse_categorical_accuracy'],label='Validation bit_error') plt.subplot(2,2,3) plt.title('Val Loss') plt.plot(history.history['val_loss'],label='Validation bit_error') plt.subplot(2,2,4) plt.title('Val Sparse Categorical Accuracy') plt.plot(history.history['val_sparse_categorical_accuracy'],label='Validation bit_error') plt.show()
下面是显示的结果和图形。
好了,这里就结束了,谢谢观看,嘿嘿
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)