2024-09-01 21:51:50 +01:00

60 lines
2.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File : affinity.py
@Contact : zpyovo@hotmail.com
@License : (C)Copyright 2018-2019, Lab501-TransferLearning-SCUT
@Description :
@Modify Time @Author @Version @Desciption
------------ ------- -------- -----------
2022/3/12 9:18 PM Pengyu Zhang 1.0 None
'''
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn.functional as F
import numpy as np
import cv2
def eightway_total_diff(arr):
# 对输入数组进行填充每边填充1个元素
padded_arr = np.pad(arr, pad_width=1, mode='edge')
# 获取填充后数组的尺寸
h, w = padded_arr.shape
# 初始化输出数组,用于存储总差值
diff_array = np.zeros((h-2, w-2))
# 遍历每个元素(不包括填充的边界)
for y in range(1, h-1):
for x in range(1, w-1):
total_diff = 0 # 初始化当前元素的总差值
# 计算每个方向的差值
for dy in range(-1, 2):
for dx in range(-1, 2):
if dy == 0 and dx == 0:
# 排除中心点自身
continue
# 直接使用填充后的坐标计算差值
diff = abs(padded_arr[y, x] - padded_arr[y+dy, x+dx])
# 如果差值小于0则置为0
total_diff += max(diff, 0)
# 将总差值存储在输出数组中
diff_array[y-1, x-1] = total_diff
return diff_array.astype(int)
def roi_affinity_siml(X_ter,X_str):
X_ter = np.array(X_ter)
X_str = np.array(X_str)
X_ter_diffs = eightway_total_diff(X_ter / X_ter.max())
X_str_diffs = eightway_total_diff(X_str / X_str.max())
# 计算差异
difference = cv2.absdiff(X_ter_diffs, X_str_diffs)
mean_diff = np.mean(difference)
# 计算相似度
similarity = max((1 - mean_diff) * 100, 0)
return similarity