2024-09-01 21:51:50 +01:00

42 lines
1.4 KiB
Python

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File : qr_verify.py
@Contact : zpyovo@hotmail.com
@License : (C)Copyright 2018-2019, Lab501-TransferLearning-SCUT
@Description :
@Modify Time @Author @Version @Description
------------ ------- -------- -----------
2022/3/3 11:38 AM Pengyu Zhang 1.0 None
'''
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn as nn
import numpy as np
from qr_verify_Tool.network import FaceModel
class QR_verify_dev(nn.Module):
    """Inference wrapper that restores a ``FaceModel`` from a checkpoint file.

    On construction the checkpoint at ``checkpoint_path`` is loaded, its
    ``'model'`` entry is restored into a freshly built
    ``FaceModel(256, num_classes=640, pretrained=True)``, and the model is put
    in evaluation mode. ``forward`` delegates straight to that model.

    Args:
        checkpoint_path: Path of the checkpoint to load. It must be a mapping
            with a ``'model'`` state-dict entry. Defaults to the path this
            class has always used. A relative path is resolved against the
            current working directory of the process, not against this file.

    Attributes:
        device: ``torch.device('cuda')`` when CUDA is available, else
            ``torch.device('cpu')``. It is used only as the ``map_location``
            for the checkpoint.
        checkpoint: The complete object returned by ``torch.load``. It is kept
            as an attribute so existing code that inspects it keeps working.
        model: The restored ``FaceModel``, in eval mode.

    NOTE(review): nothing in this class moves ``self.model`` to
    ``self.device``. On a CUDA machine the checkpoint tensors are mapped to
    the GPU and then copied into the model's own parameters, which stay
    wherever ``FaceModel`` created them, unless the caller moves the whole
    module (e.g. ``.to(device)``). Inputs must therefore live on whatever
    device the caller placed this module on. Confirm the intended placement
    before changing it, because callers may rely on the current behaviour.
    """

    # Historical default location of the trained weights, relative to the
    # process CWD. The filename spelling is intentional: it is the real file.
    DEFAULT_CHECKPOINT = './model/train_qr_varification_FaceModel_3qp_256_Affine_lr1e-3_far-1_80.pth'

    def __init__(self, checkpoint_path=DEFAULT_CHECKPOINT):
        super(QR_verify_dev, self).__init__()
        cuda_ok = torch.cuda.is_available()
        if cuda_ok:
            print('cuda available:{}'.format(cuda_ok))
        self.device = torch.device('cuda' if cuda_ok else 'cpu')
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        self.checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model = FaceModel(256, num_classes=640, pretrained=True)
        # Default (strict) key matching: the checkpoint must provide exactly
        # the state-dict keys this FaceModel configuration expects.
        self.model.load_state_dict(self.checkpoint['model'])
        # Switch layers such as dropout/batch-norm to inference behaviour.
        self.model.eval()

    def forward(self, input):
        """Run the wrapped model on ``input`` and return its output unchanged.

        No preprocessing or device transfer is done here. ``input`` must
        already be a tensor on the same device as ``self.model``. It is
        presumably a batched image tensor of shape (N, C, H, W) (TODO
        confirm), since an earlier version built it from an HWC array via
        ``transpose((2, 0, 1)).unsqueeze(0)``.
        """
        return self.model(input)