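MindSpore already provides ready-made APIs for loading common datasets, including CelebA, Cifar100, Cifar10, Coco, ImageNet, MNIST and VOC.
The Cifar10 dataset is used below as an example of how to call the API and display the images it contains.
Here is the API signature and description from the official documentation: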
class mindspore.dataset.Cifar10Dataset(dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None)
A source dataset for reading and parsing Cifar10 dataset. This api only supports parsing Cifar10 file in binary version now.
The generated dataset has two columns [image, label]. The tensor of column image is of the uint8 type. The tensor of column label is a scalar of the uint32 type.
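To show what "binary version" means on disk, recall that the binary CIFAR-10 release (files such as data_batch_1.bin and test_batch.bin) stores each sample as a 3073-byte record. Each record holds one label byte, then 1024 red, 1024 green and 1024 blue pixel bytes, each plane in row-major order. The sketch below uses only NumPy and is not MindSpore's own reader. It parses such records into the shapes the two columns expose: (32, 32, 3) uint8 images and uint32 labels. The function name parse_cifar10_records is hypothetical.

```python
import numpy as np

# CIFAR-10 binary record: 1 label byte + 3 colour planes of 32x32 bytes (R, G, B)
RECORD_BYTES = 1 + 32 * 32 * 3  # 3073

def parse_cifar10_records(raw: bytes):
    """Hypothetical helper: split raw CIFAR-10 binary data into images and labels."""
    if len(raw) % RECORD_BYTES != 0:
        raise ValueError("data length is not a multiple of the 3073-byte record size")
    records = np.frombuffer(raw, dtype=np.uint8).reshape(-1, RECORD_BYTES)
    # First byte of each record is the class index 0-9; widen it to uint32,
    # the dtype of the 'label' column described above.
    labels = records[:, 0].astype(np.uint32)
    # Remaining bytes are channel-major (C, H, W); transpose to (H, W, C),
    # the uint8 layout that PIL's Image.fromarray expects.
    images = records[:, 1:].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    return images, labels

# One synthetic record: label 7, red plane 0..255 repeated, green 0, blue 128
fake = bytes([7]) + bytes(i % 256 for i in range(1024)) + bytes(1024) + bytes([128]) * 1024
imgs, lbls = parse_cifar10_records(fake)
print(imgs.shape, lbls.tolist())  # (1, 32, 32, 3) [7]
```

A real data_batch_*.bin file can be passed in the same way after reading it with open(path, "rb").read().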
Calling the API and displaying the images:
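import mindspore.dataset as ds
from PIL import Image
import matplotlib.pyplot as plt

# Folder holding the binary-version Cifar10 files (example path, adjust to your setup)
data_dir = "./cifar-10-batches-bin"

# Read the first 6 samples in order
sampler = ds.SequentialSampler(num_samples=6)
dataset = ds.Cifar10Dataset(data_dir, sampler=sampler)

# Create an iterator over the dataset; each retrieved item is a dict
for i, data in enumerate(dataset.create_dict_iterator()):
    print("Image shape: {}".format(data['image'].shape), ", Label {}".format(data['label']))
    image = data['image']
    image = image.asnumpy()  # mindspore.Tensor to numpy
    image = Image.fromarray(image)  # (H, W, C) uint8 array to PIL image
    # Draw sample i in cell i + 1 of a 2 x 3 grid
    plt.subplot(2, 3, i + 1)
    plt.imshow(image)
    plt.title(f"{i + 1}", fontsize=6)
    plt.xticks([])
    plt.yticks([])
# Show the whole grid once, after all six samples are drawn
plt.show()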
Result:
[Figure: the first six Cifar10 images plotted in a 2 × 3 grid, titled 1 to 6]
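The subplot loop above draws each sample in its own axes. As an alternative, which is not what the code above does, the six images can be pasted into one array and saved or shown as a single picture. This minimal sketch assumes equally sized (H, W, C) uint8 arrays, such as the output of data['image'].asnumpy(); tile_grid is a hypothetical helper name. Image i lands at row i // cols, column i % cols, which matches the row-by-row numbering of plt.subplot(2, 3, i + 1).

```python
import numpy as np

def tile_grid(images, rows=2, cols=3):
    """Hypothetical helper: tile equally sized HWC images into one mosaic array."""
    images = [np.asarray(img) for img in images]
    if not images or len(images) > rows * cols:
        raise ValueError("need between 1 and rows * cols images")
    h, w, c = images[0].shape
    grid = np.zeros((rows * h, cols * w, c), dtype=images[0].dtype)
    for i, img in enumerate(images):
        r, col = divmod(i, cols)  # same row-major order as plt.subplot
        grid[r * h:(r + 1) * h, col * w:(col + 1) * w] = img
    return grid

# Six dummy 32x32 RGB images, image i filled with the value i
demo = tile_grid([np.full((32, 32, 3), i, dtype=np.uint8) for i in range(6)])
print(demo.shape)  # (64, 96, 3)
```

For example, the arrays can be collected into a list inside the loop and written out with Image.fromarray(tile_grid(images)).save("cifar10_grid.png"); the file name is only an example.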
Sharing is welcome; when reposting, please credit the source: 内存溢出