# 2024-09-01 21:51:50 +01:00
#
# 47 lines
# 1.6 KiB
# Python

import argparse
import cv2
import glob
import os
from basicsr.archs.rrdbnet_arch import RRDBNet
import torch.nn as nn
from .realesrgan import Real_ESRGANer
# from realesrgan.archs.srvgg_arch import SRVGGNetCompact
class RealsrGan(nn.Module):
    """``nn.Module`` wrapper that upscales images with a Real-ESRGAN RRDBNet.

    The ``RRDBNet`` backbone and its ``Real_ESRGANer`` inference helper are
    built once in ``__init__``; ``forward`` delegates to
    ``Real_ESRGANer.enhance``.
    """

    def __init__(self, num_in_ch=3, scale=4, model_path=None, device='0', tile=0):
        """Build the RRDBNet backbone and the ``Real_ESRGANer`` restorer.

        Args:
            num_in_ch: Number of input channels passed to ``RRDBNet``.
            scale: Network upscale factor. It is used for ``RRDBNet(scale=...)``,
                for the restorer's ``scale`` and as ``outscale`` in
                ``forward``. It must match the checkpoint at ``model_path``.
            model_path: Path to the pretrained weights, forwarded to
                ``Real_ESRGANer``.
            device: Device spec forwarded unchanged to ``Real_ESRGANer``
                (presumably a CUDA device index string — TODO confirm).
            tile: Tile size forwarded to ``Real_ESRGANer``. The default 0
                matches the previous hard-coded value.
        """
        super().__init__()
        # Architecture hyper-parameters below match the stock Real-ESRGAN
        # RRDBNet configuration; only channels-in and scale are configurable.
        self.model = RRDBNet(
            num_in_ch=num_in_ch,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=scale,
        )
        self.netscale = scale
        self.model_path = model_path
        self.device = device
        # Inference helper that loads the weights and runs the model.
        self.upsampler = Real_ESRGANer(
            scale=self.netscale,
            device=self.device,
            model_path=self.model_path,
            model=self.model,
            tile=tile,
            tile_pad=10,
            pre_pad=0,
            half=False,
        )

    def forward(self, input):
        """Upscale ``input`` by ``self.netscale`` using the Real-ESRGAN restorer.

        Args:
            input: Image passed unchanged to ``Real_ESRGANer.enhance``
                (presumably an HxWxC array as produced by ``cv2.imread`` —
                TODO confirm against callers).

        Returns:
            The first element of the tuple returned by ``enhance``.

        Raises:
            RuntimeError: Re-raised from ``enhance`` (for example CUDA out of
                memory) after a hint is printed. The error is no longer
                swallowed, so callers see the real cause instead of a failure
                on an unbound result.
        """
        try:
            output, _ = self.upsampler.enhance(input, outscale=self.netscale)
        except RuntimeError as error:
            print('Error', error)
            print('If you encounter CUDA out of memory, try to set tile with a smaller number.')
            raise
        return output