100 lines
3.1 KiB
Python
100 lines
3.1 KiB
Python
#!/usr/bin/env python
|
|
# -*- encoding: utf-8 -*-
|
|
'''
|
|
@File : common.py
|
|
@Contact : zpyovo@hotmail.com
|
|
@License : (C)Copyright 2018-2019, Lab501-TransferLearning-SCUT
|
|
@Description :
|
|
|
|
@Modify Time @Author @Version @Description
|
|
------------ ------- -------- -----------
|
|
2022/4/21 23:46 Pengyu Zhang 1.0 None
|
|
'''
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import os
|
|
import time
|
|
import matplotlib.pyplot as plt
|
|
|
|
import numpy as np
|
|
import cv2
|
|
from matplotlib.ticker import NullLocator
|
|
from PIL import Image
|
|
|
|
'''色彩增益加权的AutoMSRCR算法'''
|
|
def singleScaleRetinex(img, sigma):
    """Compute the single-scale Retinex log-ratio of an image.

    The response is ``log10(img) - log10(blurred)``, where ``blurred`` is
    ``img`` smoothed with ``cv2.GaussianBlur`` at standard deviation
    ``sigma``. The kernel size is passed as ``(0, 0)``, so OpenCV derives it
    from ``sigma``.

    Args:
        img: Image array. Pixels must be strictly positive, otherwise the
            logarithm produces ``-inf``/``nan``. ``automatedMSRCR`` adds 1.0
            to the image before it reaches this function for that reason.
        sigma: Standard deviation of the Gaussian surround.

    Returns:
        Array with the same shape as ``img`` holding the log-ratio.
    """
    surround = cv2.GaussianBlur(img, (0, 0), sigma)
    log_image = np.log10(img)
    log_surround = np.log10(surround)
    return log_image - log_surround
|
|
|
|
def multiScaleRetinex(img, sigma_list):
    """Average single-scale Retinex responses over several Gaussian scales.

    Args:
        img: Image array with strictly positive values (see
            ``singleScaleRetinex``).
        sigma_list: Non-empty iterable of Gaussian standard deviations.

    Returns:
        Array shaped like ``img``: the mean of ``singleScaleRetinex(img, s)``
        over every ``s`` in ``sigma_list``. Float inputs keep their dtype;
        integer inputs yield a floating result.

    Raises:
        ValueError: If ``sigma_list`` is empty, since the mean would be 0/0.
    """
    # Materialise once so generators work and len() matches what was summed.
    sigmas = list(sigma_list)
    if not sigmas:
        raise ValueError("sigma_list must contain at least one sigma")

    # Accumulate in a floating dtype. np.zeros_like(img) on an integer image
    # would produce an integer array, and adding the float log-ratios to it in
    # place raises a casting error. float32/float64 inputs keep their dtype.
    acc_dtype = np.result_type(img.dtype, np.float32)
    retinex = np.zeros(img.shape, dtype=acc_dtype)
    for sigma in sigmas:
        retinex += singleScaleRetinex(img, sigma)

    return retinex / len(sigmas)
|
|
|
|
def _stretch_retinex_channel(channel):
    """Clip one Retinex channel at histogram-derived limits, then scale it to [0, 255].

    Values are binned at 0.01 resolution (``np.int32`` truncates toward zero).
    The count of the bin at zero is the reference population. Scanning bins
    in ascending order:

    * the lower clip limit is the negative bin closest to zero;
    * the upper clip limit is the positive bin closest to zero;

    in both cases only bins whose count is below 10% of the reference qualify.
    A side with no qualifying bin falls back to the channel's lowest/highest
    bin.

    Returns a new float array; ``channel`` itself is not modified.
    """
    bins, counts = np.unique(np.int32(channel * 100), return_counts=True)

    # Reference population: the bin nearest zero. This is exactly the 0 bin
    # whenever one exists. Computing it per channel guarantees it is always
    # defined and never carried over from another channel.
    zero_count = counts[np.argmin(np.abs(bins))]
    threshold = zero_count * 0.1

    low_val = bins[0] / 100.0
    high_val = bins[-1] / 100.0
    for b, c in zip(bins, counts):
        if b < 0 and c < threshold:
            low_val = b / 100.0
        if b > 0 and c < threshold:
            high_val = b / 100.0
            break

    clipped = np.maximum(np.minimum(channel, high_val), low_val)

    lo = np.min(clipped)
    span = np.max(clipped) - lo
    if span == 0:
        # A constant channel has no range to stretch. Avoid 0/0 -> NaN, which
        # has no defined uint8 conversion.
        return np.zeros_like(clipped)
    return (clipped - lo) / span * 255


def automatedMSRCR(img, sigma_list):
    """Multi-scale Retinex with automatic per-channel clipping and stretching.

    Args:
        img: Image array, 2-D (single channel) or 3-D with channels on the
            last axis. It is converted to float64 and offset by 1.0 so the
            logarithms in the Retinex step stay finite for zero-valued pixels.
        sigma_list: Gaussian scales forwarded to ``multiScaleRetinex``.

    Returns:
        ``uint8`` array with the same shape as ``img``. Each channel is
        independently clipped at its histogram-derived limits and linearly
        stretched to 0..255 (fractions are truncated by the uint8 cast).
    """
    img = np.float64(img) + 1.0

    img_retinex = multiScaleRetinex(img, sigma_list)

    if img_retinex.ndim == 2:
        # Grayscale input: treat the whole image as one channel.
        img_retinex = _stretch_retinex_channel(img_retinex)
    else:
        for i in range(img_retinex.shape[2]):
            img_retinex[:, :, i] = _stretch_retinex_channel(img_retinex[:, :, i])

    return np.uint8(img_retinex)
|
|
|
|
'''图像处理过程保存'''
|
|
def save_figure(fig_context, fig_name, res_code):
    """Render an image with Matplotlib and save it as a borderless JPEG under ``tmp/``.

    The output file is named ``"{fig_name}_{res_code}_{time.time()}.jpg"``.
    Axes, ticks and padding are removed so the file contains only the image.

    Args:
        fig_context: Image data accepted by ``Axes.imshow``.
        fig_name: First component of the output file name.
        res_code: Second component of the output file name.

    Returns:
        Path of the written file, relative to the current working directory.
    """
    result_dir = "tmp"
    # savefig does not create missing directories.
    os.makedirs(result_dir, exist_ok=True)
    result_path = os.path.join(result_dir, '{}_{}_{}.jpg'.format(fig_name, res_code, time.time()))

    fig = plt.figure()
    try:
        ax = fig.subplots(1)
        ax.imshow(fig_context, aspect='equal')
        ax.axis("off")
        ax.xaxis.set_major_locator(NullLocator())
        ax.yaxis.set_major_locator(NullLocator())
        # JPEG quality must be passed through pil_kwargs. The bare ``quality``
        # keyword was deprecated in Matplotlib 3.3 and is not accepted by
        # current releases.
        fig.savefig(result_path, pil_kwargs={"quality": 95}, bbox_inches="tight", pad_inches=0.0)
    finally:
        # Close this specific figure even on error so pyplot doesn't accumulate
        # open figures.
        plt.close(fig)

    return result_path
|
|
|
|
def cropped_image(img, points, shift, size):
    """Crop a sub-region anchored at the top-left of an expanded bounding box.

    The box from ``points`` is grown by ``shift`` pixels on every side. The
    crop keeps a region ``width // size`` wide and ``height // size`` tall
    (measured on the grown box), starting at the grown box's
    ``(x_min, y_min)`` corner.

    Args:
        img: Image array; cast with ``np.uint8`` before conversion to PIL.
        points: Two ``(x, y)`` pairs. ``points[0]`` supplies
            ``(x_min, y_min)`` and ``points[1]`` supplies ``(x_max, y_max)``.
            Both are converted with ``int``.
        shift: Margin in pixels added around the box.
        size: Integer divisor applied to the grown box's width and height.

    Returns:
        ``PIL.Image.Image`` holding the cropped region.
    """
    (left, top), (right, bottom) = points[0], points[1]
    left = int(left) - shift
    top = int(top) - shift
    right = int(right) + shift
    bottom = int(bottom) + shift

    crop_w = (right - left) // size
    crop_h = (bottom - top) // size

    # Crop the image.
    pil_img = Image.fromarray(np.uint8(img))
    box = (left, top, left + crop_w, top + crop_h)
    return pil_img.crop(box)