diff --git a/data/butterfly_GT_srcnn_x3.bmp b/data/butterfly_GT_srcnn_x3.bmp
index 2883a78..57337f0 100644
Binary files a/data/butterfly_GT_srcnn_x3.bmp and b/data/butterfly_GT_srcnn_x3.bmp differ
diff --git a/data/ppt3_srcnn_x3.bmp b/data/ppt3_srcnn_x3.bmp
index 9536036..569e6ec 100644
Binary files a/data/ppt3_srcnn_x3.bmp and b/data/ppt3_srcnn_x3.bmp differ
diff --git a/data/zebra_srcnn_x3.bmp b/data/zebra_srcnn_x3.bmp
index db70b47..4f72d51 100644
Binary files a/data/zebra_srcnn_x3.bmp and b/data/zebra_srcnn_x3.bmp differ
diff --git a/test.py b/test.py
index dfa5cdf..e02a2bf 100644
--- a/test.py
+++ b/test.py
@@ -1,5 +1,4 @@
 import argparse
-import os
 
 import torch
 import torch.backends.cudnn as cudnn
@@ -7,7 +6,7 @@
 import PIL.Image as pil_image
 
 from models import SRCNN
-from utils import convert_rgb_to_y, calc_psnr
+from utils import convert_rgb_to_ycbcr, convert_ycbcr_to_rgb, calc_psnr
 
 
 if __name__ == '__main__':
@@ -38,25 +37,25 @@
     image = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
     image = image.resize((image.width // args.scale, image.height // args.scale), resample=pil_image.BICUBIC)
     image = image.resize((image.width * args.scale, image.height * args.scale), resample=pil_image.BICUBIC)
-
-    _, cb, cr = image.convert('YCbCr').split()
-
     image.save(args.image_file.replace('.', '_bicubic_x{}.'.format(args.scale)))
 
     image = np.array(image).astype(np.float32)
-    image = convert_rgb_to_y(image)
-    image /= 255.
-    image = torch.from_numpy(image).to(device)
-    image = image.unsqueeze(0).unsqueeze(0)
+    ycbcr = convert_rgb_to_ycbcr(image)
+
+    y = ycbcr[..., 0]
+    y /= 255.
+    y = torch.from_numpy(y).to(device)
+    y = y.unsqueeze(0).unsqueeze(0)
 
     with torch.no_grad():
-        preds = model(image).clamp(0.0, 1.0)
+        preds = model(y).clamp(0.0, 1.0)
 
-    psnr = calc_psnr(image, preds)
+    psnr = calc_psnr(y, preds)
     print('PSNR: {:.2f}'.format(psnr))
 
-    preds = preds.mul(255.0).byte().cpu().numpy().squeeze(0).squeeze(0)
+    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)
 
-    y = pil_image.fromarray(preds)
-    output = pil_image.merge('YCbCr', (y, cb, cr)).convert('RGB')
+    output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
+    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
+    output = pil_image.fromarray(output)
     output.save(args.image_file.replace('.', '_srcnn_x{}.'.format(args.scale)))
diff --git a/utils.py b/utils.py
index 29df88d..3a4540a 100644
--- a/utils.py
+++ b/utils.py
@@ -4,11 +4,45 @@
 
 def convert_rgb_to_y(img):
     if type(img) == np.ndarray:
-        return 16. + (64.738 * img[:, :, 0] + 129.052 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.
+        return 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.
     elif type(img) == torch.Tensor:
         if len(img.shape) == 4:
             img = img.squeeze(0)
-        return 16. + (64.738 * img[0, :, :] + 129.052 * img[1, :, :] + 25.064 * img[2, :, :]) / 256.
+        return 16. + (64.738 * img[0, :, :] + 129.057 * img[1, :, :] + 25.064 * img[2, :, :]) / 256.
+    else:
+        raise Exception('Unknown Type', type(img))
+
+
+def convert_rgb_to_ycbcr(img):
+    if type(img) == np.ndarray:
+        y = 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.
+        cb = 128. + (-37.945 * img[:, :, 0] - 74.494 * img[:, :, 1] + 112.439 * img[:, :, 2]) / 256.
+        cr = 128. + (112.439 * img[:, :, 0] - 94.154 * img[:, :, 1] - 18.285 * img[:, :, 2]) / 256.
+        return np.array([y, cb, cr]).transpose([1, 2, 0])
+    elif type(img) == torch.Tensor:
+        if len(img.shape) == 4:
+            img = img.squeeze(0)
+        y = 16. + (64.738 * img[0, :, :] + 129.057 * img[1, :, :] + 25.064 * img[2, :, :]) / 256.
+        cb = 128. + (-37.945 * img[0, :, :] - 74.494 * img[1, :, :] + 112.439 * img[2, :, :]) / 256.
+        cr = 128. + (112.439 * img[0, :, :] - 94.154 * img[1, :, :] - 18.285 * img[2, :, :]) / 256.
+        return torch.stack([y, cb, cr], 0).permute(1, 2, 0)
+    else:
+        raise Exception('Unknown Type', type(img))
+
+
+def convert_ycbcr_to_rgb(img):
+    if type(img) == np.ndarray:
+        r = 298.082 * img[:, :, 0] / 256. + 408.583 * img[:, :, 2] / 256. - 222.921
+        g = 298.082 * img[:, :, 0] / 256. - 100.291 * img[:, :, 1] / 256. - 208.120 * img[:, :, 2] / 256. + 135.576
+        b = 298.082 * img[:, :, 0] / 256. + 516.412 * img[:, :, 1] / 256. - 276.836
+        return np.array([r, g, b]).transpose([1, 2, 0])
+    elif type(img) == torch.Tensor:
+        if len(img.shape) == 4:
+            img = img.squeeze(0)
+        r = 298.082 * img[0, :, :] / 256. + 408.583 * img[2, :, :] / 256. - 222.921
+        g = 298.082 * img[0, :, :] / 256. - 100.291 * img[1, :, :] / 256. - 208.120 * img[2, :, :] / 256. + 135.576
+        b = 298.082 * img[0, :, :] / 256. + 516.412 * img[1, :, :] / 256. - 276.836
+        return torch.stack([r, g, b], 0).permute(1, 2, 0)
     else:
         raise Exception('Unknown Type', type(img))
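
For reviewers who want to sanity-check the new helpers, here is a minimal sketch (not part of the patch) that round-trips a random image through `convert_rgb_to_ycbcr` and `convert_ycbcr_to_rgb` and then repeats the merge step `test.py` now performs. It assumes the `utils.py` from this diff is importable; the image size and error tolerance are arbitrary illustration values.

```python
import numpy as np

from utils import convert_rgb_to_ycbcr, convert_ycbcr_to_rgb

# Random RGB image in the 0-255 float range the helpers expect (HWC layout).
rgb = np.random.rand(16, 16, 3).astype(np.float32) * 255.

# Forward then inverse BT.601 conversion should nearly round-trip;
# a small residual remains because the coefficients are rounded.
ycbcr = convert_rgb_to_ycbcr(rgb)
restored = convert_ycbcr_to_rgb(ycbcr)
assert np.abs(restored - rgb).max() < 2.0  # loose, illustrative tolerance

# Mimic the merge in test.py: swap in a "predicted" Y channel (here just
# the original Y) and rebuild a displayable RGB uint8 image.
preds = ycbcr[..., 0]
merged = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
output = np.clip(convert_ycbcr_to_rgb(merged), 0.0, 255.0).astype(np.uint8)
print(output.shape, output.dtype)  # (16, 16, 3) uint8
```

The clip-before-cast matters: `convert_ycbcr_to_rgb` returns floats that can fall slightly outside [0, 255], and casting those to `uint8` directly would wrap around rather than saturate, which is why `test.py` now calls `np.clip` before `astype(np.uint8)`.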