tensor格式的cifar10图片保存8*8的组合图片png

tensor格式的cifar10图片保存8*8的组合图片png,第1张

import matplotlib.pyplot as plt               

def img_tile(imgs, row_col = None, transpose = False, channel_first=True, channels=3):

    '''

    tile a list of images to a large grid.

    imgs:       iterable of images to use

    row_col:    None (automatic), or tuple of (#rows, #columns)

    transpose:  Wheter to stitch the list of images row-first or column-first

    channel_first: if true, assume images with CxWxH, else WxHxC

    channels:   3 or 1, number of color channels

    '''

    if row_col == None:

        sqrt = np.sqrt(len(imgs))

        rows = np.floor(sqrt)

        delt = sqrt - rows

        cols = np.ceil(rows + 2*delt + delt**2 / rows)

        rows, cols = int(rows), int(cols)

    else:

        rows, cols = row_col

    if channel_first:

        h, w = imgs[0].shape[1], imgs[0].shape[2]

    else:

        h, w = imgs[0].shape[0], imgs[0].shape[1]

    show_im = np.zeros((rows*h, cols*w, channels))  #plt.imsave 需要的是h*w*c的形式!

    if transpose:

        def iterator():

            for i in range(rows):

                for j in range(cols):

                    yield i, j

    else:

        def iterator():

            for j in range(cols):

                for i in range(rows):

                    yield i, j

    k = 0

    for i, j in iterator():

            im = imgs[k]

            if channel_first:

                im = np.transpose(im, (1, 2, 0))

            show_im[h*i:h*i+h, w*j:w*j+w] = im #按照列的方式来组合这个大的矩阵

            k += 1

            if k == len(imgs):

                break

    return np.squeeze(show_im)

rev_imgs = torch.clamp(rev_imgs, 0., 1.)

imgs = [rev_imgs[i].cpu().data.numpy() for i in range(c.batch_size)]

imgs  = img_tile(imgs, (8, 8), transpose=False, channel_first=True, channels=3)

plt.imsave(F'./training_images/random_samples_{i_epoch}.png', imgs, vmin=0, vmax=1, dpi=300)

方案二

也可以用Image.fromarray来实现保存单个图片,要点都是先转化成numpy的形式然后维度C*H*W转换成H*W*C的形式

def save_img(img, img_name, save_dir):

    create_dir(save_dir)

    img_path = os.path.join(save_dir, img_name)

    img_pil = Image.fromarray(img.astype(np.uint8))

    img_pil.save(img_path)

at_images_np = at_images.detach().cpu().numpy()#转化成numpy的形式

adv_img = at_images_np[0]

adv_img = np.moveaxis(adv_img, 0, 2) #这里是将chw变成hwc的格式,也可以用np.transpose()

adv_dir = os.path.join(save_dir, str(q_size))

img_name = "adv_{}.jpg".format(i)

save_img(adv_img, img_name, adv_dir)

 

欢迎分享,转载请注明来源:内存溢出

原文地址: http://outofmemory.cn/langs/734874.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-04-27
下一篇 2022-04-27

发表评论

登录后才能评论

评论列表(0条)

保存