除了提供您问题的答案外,我还将使代码更像
TF2.0。如果您有任何疑问/需要澄清,请在下面发表评论。1.加载数据
我建议使用Tensorflow
Datasets库。如果可以在一行中完成数据,则绝对不需要加载数据
numpy并将其转换为
tf.data.Dataset:
import tensorflow_datasets as tfdsdataset = tfds.load("mnist", as_supervised=True, split=tfds.Split.TRAIN)
上面的行只会返回
TRAINsplit(在此处了解更多信息)。2.定义扩充和摘要
为了保存图像,必须在每次遍历期间保留tf.summary.SummaryWriter对象。
我创建了一个带有
__call__方法的便捷包装类,可轻松使用
tf.data.Dataset的
map功能:
import tensorflow as tfclass ExampleAugmentation: def __init__(self, logdir: str, max_images: int, name: str): self.file_writer = tf.summary.create_file_writer(logdir) self.max_images: int = max_images self.name: str = name self._counter: int = 0 def __call__(self, image, label): augmented_image = tf.image.random_flip_left_right( tf.image.random_flip_up_down(image) ) with self.file_writer.as_default(): tf.summary.image( self.name, augmented_image, step=self._counter, max_outputs=self.max_images, ) self._counter += 1 return augmented_image, label
name将是保存图像各部分的名称。您可能会问哪一部分-由定义的部分
max_outputs。
说
image在
__call__将具有形状
(32, 28, 28,1),其中第一维是分批,第二宽度,第三高度和最后通道(MNIST的情况下,只有1升但需要在此维度
tf.image扩充)。此外,假设
max_outputs指定为
4。在这种情况下,批次中仅保存4张第一张图像。默认值为
3,因此您可以将其设置为
BATCH_SIZE保存每张图像。
在中
Tensorboard,每个图像将是一个单独的样本,您可以在其上最后进行迭代。
_counter是必需的,这样图像 就不会 被覆盖(我想,不是很确定,其他人的澄清会很好)。
重要提示: 您可能希望将此类重命名为类似的 *** 作,例如
ImageSaver,进行更严重的商务活动并将增强移动到单独的函子/
lambda函数时。我猜这足以满足演示目的。3.设置全局变量
请不要混合使用函数声明,全局变量,数据加载和其他内容
(例如加载数据并随后创建函数)。我知道
TF1.0鼓励这种类型的编程,但是他们正设法摆脱这种编程,您可能想跟上潮流。
我在下面定义了一些全局变量,这些变量将在接下来的部分中使用,我想这很容易解释:
4.数据集扩充BATCH_SIZE = 32DATASET_SIZE = 60000EPOCHS = 5LOG_DIR = "/logs/images"AUGMENTATION = ExampleAugmentation(LOG_DIR, max_images=4, name="Images")
与您的相似,但略有不同:
dataset = ( dataset.map( lambda image, label: ( tf.image.convert_image_dtype(image, dtype=tf.float32), label, ) ) .batch(BATCH_SIZE) .map(AUGMENTATION) .repeat(EPOCHS))
repeat
因为加载的数据集是生成器,所以需要tf.image.convert_image_dtype
-比显式tf.cast
混合除法更好和更具可读性的选项255
(并确保正确的图像格式)- 为了演示起见,在增强之前完成了批处理
几乎与您在示例中所做的一样,但是我提供了额外的
steps_per_epoch,因此
fit知道构成一个纪元的批次数:
model = tf.keras.models.Sequential( [ tf.keras.layers.Flatten(input_shape=(28, 28, 1)), tf.keras.layers.Dense(128, activation="relu"), tf.keras.layers.Dropout(0.2), tf.keras.layers.Dense(10, activation="softmax"), ])model.compile( optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])model.fit( dataset, epochs=EPOCHS, steps_per_epoch=DATASET_SIZE // BATCH_SIZE, callbacks=[tf.keras.callbacks.TensorBoard(log_dir=LOG_DIR)],)
除了我认为的以外,没有太多其他解释。
6.运行Tensorboard由于
TF2.0可以在colab中使用进行 *** 作
%tensorboard --logdir/logs/images,因此只想为可能访问此问题的其他人添加此功能。随心所欲地做,无论如何,您肯定会知道如何做。
图像应在内部,
IMAGES并且每个由命名的样本都应
name提供给
AUGMENTATION对象。7.完整代码(使每个人的生活更轻松)
import tensorflow as tfimport tensorflow_datasets as tfdsclass ExampleAugmentation: def __init__(self, logdir: str, max_images: int, name: str): self.file_writer = tf.summary.create_file_writer(logdir) self.max_images: int = max_images self.name: str = name self._counter: int = 0 def __call__(self, image, label): augmented_image = tf.image.random_flip_left_right( tf.image.random_flip_up_down(image) ) with self.file_writer.as_default(): tf.summary.image( self.name, augmented_image, step=self._counter, max_outputs=self.max_images, ) self._counter += 1 return augmented_image, labelif __name__ == "__main__": # Global settings BATCH_SIZE = 32 DATASET_SIZE = 60000 EPOCHS = 5 LOG_DIR = "/logs/images" AUGMENTATION = ExampleAugmentation(LOG_DIR, max_images=4, name="Images") # Dataset dataset = tfds.load("mnist", as_supervised=True, split=tfds.Split.TRAIN) dataset = ( dataset.map( lambda image, label: ( tf.image.convert_image_dtype(image, dtype=tf.float32), label, ) ) .batch(BATCH_SIZE) .map(AUGMENTATION) .repeat(EPOCHS) ) # Model and training model = tf.keras.models.Sequential( [ tf.keras.layers.Flatten(input_shape=(28, 28, 1)), tf.keras.layers.Dense(128, activation="relu"), tf.keras.layers.Dropout(0.2), tf.keras.layers.Dense(10, activation="softmax"), ] ) model.compile( optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"] ) model.fit( dataset, epochs=EPOCHS, steps_per_epoch=DATASET_SIZE // BATCH_SIZE, callbacks=[tf.keras.callbacks.TensorBoard(log_dir=LOG_DIR)], )
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)