47 lines
1.6 KiB
Python
47 lines
1.6 KiB
Python
import argparse
|
|
import cv2
|
|
import glob
|
|
import os
|
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
|
import torch.nn as nn
|
|
from .realesrgan import Real_ESRGANer
|
|
# from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
|
|
|
class RealsrGan(nn.Module):
    """``nn.Module`` wrapper that upscales images with ``Real_ESRGANer``.

    The network handed to ``Real_ESRGANer`` is an ``RRDBNet`` with 64 features,
    23 blocks and 32 growth channels. All constructor arguments except
    ``num_in_ch`` and ``scale`` are forwarded to ``Real_ESRGANer`` unchanged.

    Args:
        num_in_ch: Input channels of the RRDBNet backbone (default 3).
            NOTE(review): assumes ``Real_ESRGANer`` feeds the network this many
            channels -- confirm before using a value other than 3.
        scale: Upscaling factor. It is used for the RRDBNet, for the restorer's
            ``scale`` and as ``outscale`` in ``forward`` (default 4). The
            checkpoint at ``model_path`` must have been trained for this factor.
        model_path: Passed to ``Real_ESRGANer`` as ``model_path``. Presumably a
            checkpoint for the RRDBNet backbone -- confirm.
        device: Passed to ``Real_ESRGANer`` as ``device``. NOTE(review): the
            ``'0'`` default looks like a GPU id string -- confirm the expected form.
        tile, tile_pad, pre_pad, half: Keyword-only options passed to
            ``Real_ESRGANer``. Their defaults equal the values that were
            previously hard-coded.
    """

    def __init__(self, num_in_ch=3, scale=4, model_path=None, device='0',
                 *, tile=0, tile_pad=10, pre_pad=0, half=False):
        super().__init__()

        # Honor the constructor arguments instead of hard-coding 3 / 4.
        self.model = RRDBNet(num_in_ch=num_in_ch, num_out_ch=3, num_feat=64,
                             num_block=23, num_grow_ch=32, scale=scale)
        self.netscale = scale
        self.model_path = model_path
        self.device = device

        # Restorer that drives self.model.
        self.upsampler = Real_ESRGANer(
            scale=self.netscale,
            device=self.device,
            model_path=self.model_path,
            model=self.model,
            tile=tile,
            tile_pad=tile_pad,
            pre_pad=pre_pad,
            half=half)

    def forward(self, input):
        """Upscale ``input`` with the wrapped restorer.

        Args:
            input: Passed unchanged to ``Real_ESRGANer.enhance``. The name shadows
                the builtin but is kept for callers that pass it by keyword.
                NOTE(review): presumably a NumPy image as produced by
                ``cv2.imread`` -- confirm.

        Returns:
            The first element of the pair returned by ``enhance``, requested at
            ``outscale=self.netscale``.

        Raises:
            RuntimeError: Propagated from ``enhance`` after a hint is printed.
        """
        try:
            output, _ = self.upsampler.enhance(input, outscale=self.netscale)
        except RuntimeError as error:
            print('Error', error)
            print('If you encounter CUDA out of memory, try to set the tile '
                  'argument of RealsrGan to a smaller (non-zero) number.')
            # No output was produced. Re-raise the real error instead of falling
            # through to `return output`, which would fail with UnboundLocalError.
            raise
        return output
|