import os

import matplotlib.pyplot as plt
import numpy as np
def img_tile(imgs, row_col=None, transpose=False, channel_first=True, channels=3):
    """Tile a collection of equally sized images into one large grid image.

    Args:
        imgs: Iterable of array-like images (anything exposing ``.shape``).
            All images are expected to share the size of the first one.
        row_col: ``None`` to pick a near-square layout automatically, or a
            ``(rows, columns)`` tuple.
        transpose: If true, the grid is filled row by row (left to right,
            then down); otherwise column by column (top to bottom, then right).
        channel_first: If true, images are laid out as C x H x W; otherwise
            as H x W x C.
        channels: Number of color channels of the output canvas (3 or 1).

    Returns:
        A float array of shape ``(rows * H, cols * W, channels)`` with
        singleton dimensions squeezed away. Grid cells without an image stay
        zero; images that do not fit into the grid are ignored.

    Raises:
        ValueError: If ``imgs`` is empty.
    """
    # Materialise once so generators work and len()/indexing are safe.
    imgs = list(imgs)
    if not imgs:
        raise ValueError("img_tile() requires at least one image")

    if row_col is None:
        # Near-square layout: rows = floor(sqrt(n)), cols = ceil(n / rows).
        # The ceil is done in integer arithmetic so float rounding can never
        # add a spurious extra column.
        count = len(imgs)
        rows = int(np.floor(np.sqrt(count)))
        cols = -(-count // rows)
    else:
        rows, cols = row_col

    first_shape = imgs[0].shape
    if channel_first:
        h, w = first_shape[1], first_shape[2]
    else:
        h, w = first_shape[0], first_shape[1]

    # plt.imsave expects the canvas in H x W x C layout.
    show_im = np.zeros((rows * h, cols * w, channels))

    if transpose:
        cells = ((i, j) for i in range(rows) for j in range(cols))
    else:
        cells = ((i, j) for j in range(cols) for i in range(rows))

    # zip stops at whichever runs out first: grid cells or images.
    for (i, j), im in zip(cells, imgs):
        if channel_first:
            im = np.transpose(im, (1, 2, 0))
        show_im[h * i:h * i + h, w * j:w * j + w] = im

    return np.squeeze(show_im)
# Usage example: clamp the batch into [0, 1] (the value range passed to
# imsave below), tile the samples into an 8 x 8 grid, and write one PNG.
rev_imgs = torch.clamp(rev_imgs, 0., 1.)
imgs = img_tile(
    [rev_imgs[idx].cpu().data.numpy() for idx in range(c.batch_size)],
    (8, 8),
    transpose=False,
    channel_first=True,
    channels=3,
)
plt.imsave(F'./training_images/random_samples_{i_epoch}.png', imgs, vmin=0, vmax=1, dpi=300)
方案二:也可以用 Image.fromarray 来保存单张图片。两种方案的要点相同:先把张量转换成 numpy 数组,再把维度从 C*H*W 转换成 H*W*C。
def save_img(img, img_name, save_dir):
    """Save an array as an image file at ``save_dir/img_name`` using PIL.

    Args:
        img: Array exposing ``astype``; it is cast to ``uint8`` without any
            rescaling before being handed to ``Image.fromarray``.
        img_name: File name (including extension) of the output image.
        save_dir: Target directory; passed to ``create_dir`` first
            (presumably creates it when missing -- defined outside this snippet).
    """
    create_dir(save_dir)
    target_path = os.path.join(save_dir, img_name)
    Image.fromarray(img.astype(np.uint8)).save(target_path)
# Usage example: detach the batch from the graph, move it to the CPU and
# convert it to a NumPy array.
at_images_np = at_images.detach().cpu().numpy()
# Take the first sample and move axis 0 to the end (CHW -> HWC);
# np.transpose() could be used for the same reordering.
adv_img = np.moveaxis(at_images_np[0], 0, 2)
adv_dir = os.path.join(save_dir, str(q_size))
img_name = "adv_{}.jpg".format(i)
save_img(adv_img, img_name, adv_dir)
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)