import argparse
import cv2
import glob
import os
from basicsr.archs.rrdbnet_arch import RRDBNet
import torch.nn as nn
from .realesrgan import Real_ESRGANer
# from realesrgan.archs.srvgg_arch import SRVGGNetCompact


class RealsrGan(nn.Module):
    """Wrap an RRDBNet generator in a ``Real_ESRGANer`` upsampler as an ``nn.Module``.

    The generator is an ``RRDBNet`` with 64 features, 23 RRDB blocks and a
    growth channel count of 32, producing 3 output channels. ``forward``
    delegates to ``Real_ESRGANer.enhance`` and returns only the enhanced
    output.

    Args:
        num_in_ch: Number of input channels of the RRDBNet generator.
        scale: Network upscaling factor. It is also used as the ``outscale``
            passed to ``enhance``. NOTE(review): the pretrained weights at
            ``model_path`` must match this scale; confirm before changing it
            from the default of 4.
        model_path: Path to the generator weights. It is forwarded unchanged
            to ``Real_ESRGANer``.
        device: Device spec forwarded unchanged to ``Real_ESRGANer``.
            Presumably a GPU id string such as ``'0'``; verify against
            ``Real_ESRGANer``.
        tile: Tile size forwarded to ``Real_ESRGANer``. The default of 0 keeps
            the original behaviour. Presumably 0 disables tiling; a smaller
            positive value is the knob the out-of-memory hint in ``forward``
            refers to.
    """

    def __init__(self, num_in_ch=3, scale=4, model_path=None, device='0', tile=0):
        super().__init__()
        # Previously num_in_ch/scale were silently ignored (hard-coded 3/4);
        # the defaults keep the old behaviour unchanged.
        self.model = RRDBNet(num_in_ch=num_in_ch, num_out_ch=3, num_feat=64,
                             num_block=23, num_grow_ch=32, scale=scale)
        self.netscale = scale
        self.model_path = model_path
        self.device = device
        # restorer
        self.upsampler = Real_ESRGANer(
            scale=self.netscale,
            device=self.device,
            model_path=self.model_path,
            model=self.model,
            tile=tile,
            tile_pad=10,
            pre_pad=0,
            half=False)

    def forward(self, input):
        """Upscale ``input`` with the wrapped ``Real_ESRGANer``.

        Args:
            input: Image passed straight to ``Real_ESRGANer.enhance``.
                Presumably a numpy image array as returned by ``cv2.imread``;
                confirm against ``Real_ESRGANer.enhance``.

        Returns:
            The first element of the pair returned by ``enhance``. The second
            element is discarded.

        Raises:
            RuntimeError: Re-raised from ``enhance`` after printing a hint.
                Previously the error was swallowed and ``return output`` then
                failed with ``UnboundLocalError``, hiding the real cause.
        """
        try:
            output, _ = self.upsampler.enhance(input, outscale=self.netscale)
        except RuntimeError as error:
            print('Error', error)
            print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
            raise
        return output