#!/usr/bin/env python # -*- encoding: utf-8 -*- ''' @File : qr_verify.py @Contact : zpyovo@hotmail.com @License : (C)Copyright 2018-2019, Lab501-TransferLearning-SCUT @Description : @Modify Time @Author @Version @Desciption ------------ ------- -------- ----------- 2022/3/3 11:38 AM Pengyu Zhang 1.0 None ''' from __future__ import absolute_import from __future__ import division from __future__ import print_function import torch import torch.nn as nn import numpy as np from qr_verify_Tool.network import FaceModel class QR_verify_dev(nn.Module): def __init__(self): super(QR_verify_dev, self).__init__() if torch.cuda.is_available(): print('cuda available:{}'.format(torch.cuda.is_available())) self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.checkpoint = torch.load('./model/train_qr_varification_FaceModel_3qp_256_Affine_lr1e-3_far-1_80.pth', map_location=self.device) self.model = FaceModel(256, num_classes=640, pretrained=True) self.model.load_state_dict(self.checkpoint['model']) self.model.eval() def forward(self,input): # input = torch.from_numpy(np.array(input).transpose((2, 0, 1))).unsqueeze(0) # input = input.cuda().float() output = self.model(input) return output