42 lines
1.4 KiB
Python
42 lines
1.4 KiB
Python
#!/usr/bin/env python
|
|
# -*- encoding: utf-8 -*-
|
|
'''
|
|
@File : qr_verify.py
|
|
@Contact : zpyovo@hotmail.com
|
|
@License : (C)Copyright 2018-2019, Lab501-TransferLearning-SCUT
|
|
@Description :
|
|
|
|
@Modify Time @Author @Version @Description
|
|
------------ ------- -------- -----------
|
|
2022/3/3 11:38 AM Pengyu Zhang 1.0 None
|
|
'''
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
|
|
from qr_verify_Tool.network import FaceModel
|
|
|
|
|
|
class QR_verify_dev(nn.Module):
    """Inference wrapper around a checkpointed ``FaceModel`` used for QR verification.

    On construction it builds ``FaceModel(256, num_classes=640, pretrained=True)``,
    restores its weights from the ``'model'`` entry of a ``torch.load``-ed
    checkpoint, and puts the wrapped model in eval mode.  ``forward`` simply
    runs the wrapped model.
    """

    # Checkpoint used when no explicit path is given; relative to the current
    # working directory, exactly as in the original hard-coded call.
    DEFAULT_CHECKPOINT = './model/train_qr_varification_FaceModel_3qp_256_Affine_lr1e-3_far-1_80.pth'

    def __init__(self, checkpoint_path=DEFAULT_CHECKPOINT):
        """Build the wrapped FaceModel and restore its weights.

        Args:
            checkpoint_path: path of a checkpoint readable by ``torch.load``
                whose ``'model'`` key holds a state dict for ``FaceModel``.
                Defaults to ``DEFAULT_CHECKPOINT``.
        """
        super(QR_verify_dev, self).__init__()

        if torch.cuda.is_available():
            print('cuda available:{}'.format(torch.cuda.is_available()))

        # Preferred device, recorded for callers; nothing in this class moves
        # the model there (callers can use ``.to(self.device)``).
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load onto the CPU: the FaceModel built below is not moved to
        # self.device, so load_state_dict copies into its (presumably CPU)
        # parameters anyway.  Mapping to CUDA here would only park a second,
        # never-used copy of every checkpoint tensor in GPU memory, kept alive
        # by self.checkpoint.  The attribute itself is kept for callers that
        # may read other checkpoint entries.
        self.checkpoint = torch.load(checkpoint_path, map_location='cpu')

        # NOTE(review): pretrained=True may fetch backbone weights that are
        # immediately overwritten by load_state_dict — confirm in FaceModel.
        self.model = FaceModel(256, num_classes=640, pretrained=True)
        self.model.load_state_dict(self.checkpoint['model'])
        self.model.eval()

    def _model_device(self):
        """Return the device of the wrapped model's first parameter, or None if it has none."""
        for param in self.model.parameters():
            return param.device
        return None

    def forward(self, input):
        """Run the wrapped model on ``input`` and return its output unchanged.

        Tensor inputs are first moved to the device the wrapped model's
        parameters live on (a no-op when they already match), so a CPU tensor
        works even after the module has been moved to the GPU, and vice versa.
        Non-tensor inputs are passed through untouched.

        NOTE(review): presumably ``input`` is a float image batch
        (N, C, H, W) — TODO confirm against FaceModel.
        """
        device = self._model_device()
        if device is not None and isinstance(input, torch.Tensor):
            input = input.to(device)
        return self.model(input)