54 lines
1.8 KiB
Python
54 lines
1.8 KiB
Python
|
|
import os,argparse
|
||
|
|
import numpy as np
|
||
|
|
from PIL import Image
|
||
|
|
from models import *
|
||
|
|
import torch
|
||
|
|
import torch.nn as nn
|
||
|
|
import torchvision.transforms as tfs
|
||
|
|
import torchvision.utils as vutils
|
||
|
|
import matplotlib.pyplot as plt
|
||
|
|
from torchvision.utils import make_grid
|
||
|
|
# Arguments passed to the FFA constructor; they also appear in the
# checkpoint file name, so they must match the trained model.
GPS = 3
BLOCKS = 19

# Per-channel mean/std handed to tfs.Normalize for the network input.
NORM_MEAN = [0.64, 0.6, 0.58]
NORM_STD = [0.14, 0.15, 0.152]


def parse_args(argv=None):
    """Parse the command-line options of the test script.

    Args:
        argv: Optional list of argument strings. ``None`` makes argparse
            read ``sys.argv[1:]``, which is what happens when the file is
            run as a script.

    Returns:
        The argparse namespace with ``task`` and ``test_imgs`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='its', help='its or ots')
    parser.add_argument('--test_imgs', type=str, default='test_imgs', help='Test imgs folder')
    return parser.parse_args(argv)


def output_name(image_name, dataset):
    """Return the PNG file name under which a prediction is saved.

    Only the final extension is stripped, so ``a.b.jpg`` maps to
    ``a.b_<dataset>.png`` instead of being truncated at the first dot
    (which could make distinct inputs overwrite each other's output).
    """
    stem = os.path.splitext(image_name)[0]
    return f'{stem}_{dataset}.png'


def load_model(model_path, device):
    """Build the FFA network, load its checkpoint and switch to eval mode.

    Args:
        model_path: Path of the checkpoint file; the state dict is read
            from its ``'model'`` entry.
        device: Device string used both as ``map_location`` and as the
            target device of the network.

    Returns:
        The ``nn.DataParallel``-wrapped network, ready for inference.
    """
    ckp = torch.load(model_path, map_location=device)
    # The network is wrapped before loading, as in the original script, so
    # the checkpoint keys are matched against the DataParallel wrapper.
    net = nn.DataParallel(FFA(gps=GPS, blocks=BLOCKS))
    net.load_state_dict(ckp['model'])
    # Move the weights explicitly instead of relying on DataParallel doing
    # it implicitly in some configurations.
    net.to(device)
    net.eval()
    return net


def main(argv=None):
    """Dehaze every image in the test folder and save the predictions.

    Paths are resolved relative to the current working directory: images
    are read from ``<cwd>/<test_imgs>``, the checkpoint from
    ``<cwd>/trained_models`` and results are written to ``<cwd>/pred_imgs``.
    """
    opt = parse_args(argv)
    dataset = opt.task

    base_dir = os.getcwd()
    img_dir = os.path.join(base_dir, opt.test_imgs)
    output_dir = os.path.join(base_dir, 'pred_imgs')
    os.makedirs(output_dir, exist_ok=True)
    model_path = os.path.join(
        base_dir, 'trained_models', f'{dataset}_train_ffa_{GPS}_{BLOCKS}.pk'
    )

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    net = load_model(model_path, device)

    # Built once: the preprocessing pipeline does not depend on the image.
    preprocess = tfs.Compose([
        tfs.ToTensor(),
        tfs.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ])

    for im in sorted(os.listdir(img_dir)):
        src = os.path.join(img_dir, im)
        if not os.path.isfile(src):
            # Sub-directories cannot be opened as images; skip them.
            continue
        print(f'\r {im}', end='', flush=True)
        with Image.open(src) as handle:
            # Force three channels so RGBA / grayscale / palette files match
            # the 3-channel normalization above.
            haze = handle.convert('RGB')
        # [None, ::] adds the leading batch dimension.
        batch = preprocess(haze)[None, ::].to(device)
        with torch.no_grad():
            pred = net(batch)
        # Drop only the batch dimension before saving.
        ts = pred.clamp(0, 1).cpu().squeeze(0)
        vutils.save_image(ts, os.path.join(output_dir, output_name(im, dataset)))


if __name__ == '__main__':
    main()
|