#needs pytorch and matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
def savehm2(img, hm, outname, q=100):
    """Save ``img`` with the heatmap ``hm`` overlaid semi-transparently.

    Usage notes:
      * ``hm`` is the heatmap as a torch tensor of shape (1, 3, h, w); its
        channels are summed into a single (h, w) map. A single-channel
        heatmap of shape (1, 1, h, w) is also accepted and used directly
        (no need to repeat it 3 times along the channel axis).
      * ``img`` is a (1, 3, h, w) torch tensor as it comes out of the
        dataloader with standard normalization; ``invert_normalize`` undoes
        that normalization before plotting.
      * ``outname`` is the output path including the extension (e.g.
        ``.jpg`` or ``.png``); matplotlib infers the format from it.

    Args:
        img: normalized image tensor, (1, 3, h, w).
        hm: heatmap tensor, (1, 3, h, w) or (1, 1, h, w).
        outname: destination file path.
        q: percentile of ``|hm|`` used to scale the heatmap so the
            colormap range (-1, 1) maps to +/- that percentile;
            100 scales by the maximum absolute value.
    """
    # Undo the dataloader normalization and move channels last for imshow.
    ts = invert_normalize(img.cpu().squeeze())
    rgb = ts.detach().numpy().transpose((1, 2, 0))
    # De-normalized values can land slightly outside [0, 1]; matplotlib
    # would clip float RGB anyway, but clipping here avoids its warning.
    rgb = np.clip(rgb, 0.0, 1.0)

    # detach() so heatmaps that still require grad can be converted.
    heat = hm.detach().cpu().squeeze()
    if heat.ndim == 3:
        # Collapse the channel axis into one value per pixel.
        heat = heat.sum(dim=0)
    heat = heat.numpy()

    clim = np.percentile(np.abs(heat), q)
    # An all-zero heatmap gives clim == 0; dividing would produce NaNs.
    if clim > 0:
        heat = heat / clim

    # Draw on a dedicated figure and always close it. Using the global
    # pyplot state would stack every call's images on the same axes and
    # leak one figure per call when used in a loop.
    fig, ax = plt.subplots()
    try:
        ax.imshow(rgb)
        ax.imshow(heat, cmap="seismic", vmin=-1, vmax=1, alpha=0.5)
        ax.axis('off')
        fig.savefig(outname, bbox_inches='tight')
    finally:
        plt.close(fig)
def invert_normalize(ten, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo per-channel normalization, returning ``ten * std + mean``.

    This reverses a ``(x - mean) / std`` normalization. The defaults are
    the widely used ImageNet channel statistics.

    Args:
        ten: tensor whose channel axis is third from the end, e.g.
            (C, H, W) or (N, C, H, W), with C == len(mean) == len(std).
        mean: per-channel means (any sequence of numbers).
        std: per-channel standard deviations (any sequence of numbers).

    Returns:
        A new tensor with the normalization reversed; ``ten`` is not
        modified.
    """
    # Shape the statistics as (C, 1, 1) so they broadcast over H and W,
    # and create them on the input's device so CUDA tensors work too.
    s = torch.as_tensor(std, dtype=torch.float32, device=ten.device).view(-1, 1, 1)
    m = torch.as_tensor(mean, dtype=torch.float32, device=ten.device).view(-1, 1, 1)
    return ten * s + m