Update test script and results
yjn870 committed Apr 22, 2019
1 parent b8fc136 commit 50c9185
Showing 5 changed files with 49 additions and 16 deletions.
Binary file modified: data/butterfly_GT_srcnn_x3.bmp
Binary file modified: data/ppt3_srcnn_x3.bmp
Binary file modified: data/zebra_srcnn_x3.bmp
test.py (27 changes: 13 additions & 14 deletions)
@@ -1,13 +1,12 @@
 import argparse
-import os
 
 import torch
 import torch.backends.cudnn as cudnn
 import numpy as np
 import PIL.Image as pil_image
 
 from models import SRCNN
-from utils import convert_rgb_to_y, calc_psnr
+from utils import convert_rgb_to_ycbcr, convert_ycbcr_to_rgb, calc_psnr
 
 
 if __name__ == '__main__':
@@ -38,25 +37,25 @@
     image = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
     image = image.resize((image.width // args.scale, image.height // args.scale), resample=pil_image.BICUBIC)
     image = image.resize((image.width * args.scale, image.height * args.scale), resample=pil_image.BICUBIC)
-
-    _, cb, cr = image.convert('YCbCr').split()
-
     image.save(args.image_file.replace('.', '_bicubic_x{}.'.format(args.scale)))
 
     image = np.array(image).astype(np.float32)
-    image = convert_rgb_to_y(image)
-    image /= 255.
-    image = torch.from_numpy(image).to(device)
-    image = image.unsqueeze(0).unsqueeze(0)
+    ycbcr = convert_rgb_to_ycbcr(image)
+
+    y = ycbcr[..., 0]
+    y /= 255.
+    y = torch.from_numpy(y).to(device)
+    y = y.unsqueeze(0).unsqueeze(0)
 
     with torch.no_grad():
-        preds = model(image).clamp(0.0, 1.0)
+        preds = model(y).clamp(0.0, 1.0)
 
-    psnr = calc_psnr(image, preds)
+    psnr = calc_psnr(y, preds)
     print('PSNR: {:.2f}'.format(psnr))
 
-    preds = preds.mul(255.0).byte().cpu().numpy().squeeze(0).squeeze(0)
+    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)
 
-    y = pil_image.fromarray(preds)
-    output = pil_image.merge('YCbCr', (y, cb, cr)).convert('RGB')
+    output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
+    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
+    output = pil_image.fromarray(output)
     output.save(args.image_file.replace('.', '_srcnn_x{}.'.format(args.scale)))
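Taken together, the test.py changes drop the PIL-based 'YCbCr' split and merge in favor of the repository's own conversion helpers: the bicubic-upscaled image is converted with convert_rgb_to_ycbcr, SRCNN runs on the normalized Y channel, and the prediction is recombined with the bicubic Cb and Cr channels through convert_ycbcr_to_rgb, staying in float until a single clip-and-quantize step at the end (hence the removed .byte() call). A condensed sketch of the new flow, written for this note rather than taken verbatim from the commit (super_resolve, model and device are illustrative names):

import numpy as np
import torch
import PIL.Image as pil_image

from utils import convert_rgb_to_ycbcr, convert_ycbcr_to_rgb


def super_resolve(model, image, device):
    # image is a PIL RGB image already bicubic-upscaled to the target size
    rgb = np.array(image).astype(np.float32)
    ycbcr = convert_rgb_to_ycbcr(rgb)                # (H, W, 3), values in [0, 255]
    y = torch.from_numpy(ycbcr[..., 0] / 255.).to(device)
    y = y.unsqueeze(0).unsqueeze(0)                  # (1, 1, H, W) for the conv layers
    with torch.no_grad():
        preds = model(y).clamp(0.0, 1.0)
    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)
    # reuse the bicubic Cb/Cr channels; only luma is super-resolved
    out = np.stack([preds, ycbcr[..., 1], ycbcr[..., 2]], axis=-1)
    out = np.clip(convert_ycbcr_to_rgb(out), 0.0, 255.0).astype(np.uint8)
    return pil_image.fromarray(out)

The script is invoked as before, e.g. python test.py --weights-file srcnn_x3.pth --image-file data/butterfly_GT.bmp --scale 3 (the --weights-file flag is assumed from the argparse block elided above); it rewrites the _bicubic_x3 and _srcnn_x3 images next to the input, which is where the three updated .bmp files listed at the top of this commit come from.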
utils.py (38 changes: 36 additions & 2 deletions)
@@ -4,11 +4,45 @@
 
 def convert_rgb_to_y(img):
     if type(img) == np.ndarray:
-        return 16. + (64.738 * img[:, :, 0] + 129.052 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.
+        return 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.
     elif type(img) == torch.Tensor:
         if len(img.shape) == 4:
             img = img.squeeze(0)
-        return 16. + (64.738 * img[0, :, :] + 129.052 * img[1, :, :] + 25.064 * img[2, :, :]) / 256.
+        return 16. + (64.738 * img[0, :, :] + 129.057 * img[1, :, :] + 25.064 * img[2, :, :]) / 256.
     else:
         raise Exception('Unknown Type', type(img))
+
+
+def convert_rgb_to_ycbcr(img):
+    if type(img) == np.ndarray:
+        y = 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.
+        cb = 128. + (-37.945 * img[:, :, 0] - 74.494 * img[:, :, 1] + 112.439 * img[:, :, 2]) / 256.
+        cr = 128. + (112.439 * img[:, :, 0] - 94.154 * img[:, :, 1] - 18.285 * img[:, :, 2]) / 256.
+        return np.array([y, cb, cr]).transpose([1, 2, 0])
+    elif type(img) == torch.Tensor:
+        if len(img.shape) == 4:
+            img = img.squeeze(0)
+        y = 16. + (64.738 * img[0, :, :] + 129.057 * img[1, :, :] + 25.064 * img[2, :, :]) / 256.
+        cb = 128. + (-37.945 * img[0, :, :] - 74.494 * img[1, :, :] + 112.439 * img[2, :, :]) / 256.
+        cr = 128. + (112.439 * img[0, :, :] - 94.154 * img[1, :, :] - 18.285 * img[2, :, :]) / 256.
+        return torch.stack([y, cb, cr], 0).permute(1, 2, 0)
+    else:
+        raise Exception('Unknown Type', type(img))
+
+
+def convert_ycbcr_to_rgb(img):
+    if type(img) == np.ndarray:
+        r = 298.082 * img[:, :, 0] / 256. + 408.583 * img[:, :, 2] / 256. - 222.921
+        g = 298.082 * img[:, :, 0] / 256. - 100.291 * img[:, :, 1] / 256. - 208.120 * img[:, :, 2] / 256. + 135.576
+        b = 298.082 * img[:, :, 0] / 256. + 516.412 * img[:, :, 1] / 256. - 276.836
+        return np.array([r, g, b]).transpose([1, 2, 0])
+    elif type(img) == torch.Tensor:
+        if len(img.shape) == 4:
+            img = img.squeeze(0)
+        r = 298.082 * img[0, :, :] / 256. + 408.583 * img[2, :, :] / 256. - 222.921
+        g = 298.082 * img[0, :, :] / 256. - 100.291 * img[1, :, :] / 256. - 208.120 * img[2, :, :] / 256. + 135.576
+        b = 298.082 * img[0, :, :] / 256. + 516.412 * img[1, :, :] / 256. - 276.836
+        return torch.stack([r, g, b], 0).permute(1, 2, 0)
+    else:
+        raise Exception('Unknown Type', type(img))
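Both new helpers implement the ITU-R BT.601 studio-swing YCbCr conversion (Y in [16, 235], Cb and Cr in [16, 240]) with coefficients pre-scaled by 256, and the commit also corrects the G coefficient of the luma formula from 129.052 to 129.057 so that convert_rgb_to_y and convert_rgb_to_ycbcr agree. Note that the tensor branches are shown above with torch.stack rather than the committed torch.cat, since concatenating three 2-D channel maps along dim 0 yields a (3H, W) tensor on which the 3-D permute fails. A minimal round-trip check, test scaffolding written for this note and not part of the commit:

import numpy as np

from utils import convert_rgb_to_ycbcr, convert_ycbcr_to_rgb

rgb = np.random.randint(0, 256, size=(32, 32, 3)).astype(np.float32)
ycbcr = convert_rgb_to_ycbcr(rgb)        # ndarray branch: shape (32, 32, 3)
restored = convert_ycbcr_to_rgb(ycbcr)
print(np.abs(restored - rgb).max())      # small but nonzero residual

The residual is not exactly zero: published BT.601 tables give 65.738 rather than 64.738 for the R term of Y, so the round trip can be off by roughly one intensity level on saturated reds. The PSNR printed by test.py is unaffected, since both of its inputs pass through the same forward coefficients.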
