35 lines
1.3 KiB
Python
Executable File
35 lines
1.3 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
import torch
|
|
import torch.nn as nn
|
|
import torchvision.models as models
|
|
import torchvision.transforms as transforms
|
|
from PIL import Image
|
|
import os
|
|
import argparse
|
|
import random
|
|
|
|
def load_model(model_path, num_classes=2, device=None):
    """Build a ResNet-18 classifier and load fine-tuned weights from disk.

    Args:
        model_path: Path to a checkpoint holding a state_dict for a ResNet-18
            whose ``fc`` layer has ``num_classes`` outputs.
        num_classes: Output size of the replacement ``fc`` layer. It must match
            the checkpoint. The default of 2 preserves the original behavior.
        device: Device to place the model on. When None, uses CUDA if it is
            available and the CPU otherwise.

    Returns:
        The model in eval mode, moved to ``device``.
    """
    # Build the bare architecture. There is no need to fetch ImageNet weights:
    # the strict load_state_dict below overwrites every parameter and buffer.
    # Requesting pretrained weights only caused a redundant download, which
    # fails on offline machines.
    model = models.resnet18()
    # Replace the classification head so it has num_classes outputs.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # NOTE(security): torch.load unpickles the file. Only load trusted
    # checkpoints. map_location='cpu' lets GPU-saved checkpoints load on
    # CPU-only hosts; the model is moved to the target device afterwards.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()  # inference mode for BatchNorm/Dropout layers
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return model.to(device)
|
|
|
|
def preprocess_image(image_path):
    """Load an image file and turn it into a normalized single-image batch.

    The image is converted to RGB, turned into a float tensor by
    ``ToTensor``, and normalized per channel with the mean/std constants
    below. These are the ImageNet statistics used by torchvision's
    pretrained models.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A tensor with a leading batch dimension of size 1, i.e. (1, 3, H, W).
        No resizing is applied, so H and W follow the input image.
    """
    # NOTE(review): there is no Resize/CenterCrop here. Confirm this matches
    # the preprocessing used when the checkpoint was trained.
    transform = transforms.Compose([
        transforms.ToTensor(),  # PIL image -> float tensor
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),  # per-channel normalization
    ])
    # Close the file handle deterministically. convert() returns a new,
    # fully loaded image, so it stays valid after the file is closed.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    return transform(image).unsqueeze(0)  # add the batch dimension
|
|
|
|
def predict(model, image_tensor):
    """Classify a single preprocessed image.

    Args:
        model: A classifier, e.g. the module returned by ``load_model``. It is
            called on ``image_tensor`` and must return scores with the classes
            along dimension 1.
        image_tensor: Input batch. It must contain exactly one image, because
            the predicted class is returned via ``.item()``.

    Returns:
        ``(class_index, probabilities)``. ``class_index`` is the arg-max class
        as an int. ``probabilities`` is the softmax over classes, squeezed and
        converted with ``tolist()``.
    """
    # The model may live on CUDA (load_model picks CUDA when available) while
    # the input is typically a CPU tensor. Move the input to the model's
    # device to avoid a device-mismatch error.
    if isinstance(model, torch.nn.Module):
        first_param = next(model.parameters(), None)
        if first_param is not None:
            image_tensor = image_tensor.to(first_param.device)
    with torch.no_grad():  # inference only: no autograd graph needed
        output = model(image_tensor)
        _, predicted = torch.max(output, 1)  # index of the highest score
        probabilities = torch.nn.functional.softmax(output, dim=1)
    return predicted.item(), probabilities.squeeze().tolist()