import imageio
import matplotlib.pyplot as plt
from histogram_matching import ExactHistogramMatcher
import General_function as tools
from skimage.exposure import match_histograms


############################################################################################################
def histogram_matching_rgb(target_img, reference_img):
    """Match ``target_img`` to the histogram of ``reference_img`` using ExactHistogramMatcher.

    The reference histogram is obtained with ``ExactHistogramMatcher.get_histogram``.
    ``ExactHistogramMatcher.match_image_to_histogram`` then maps the target image onto it.

    Args:
        target_img: Image whose intensities are remapped. The caller uses this
            path for 3-D arrays (presumably H x W x channels; confirm).
        reference_img: Image whose histogram is the target distribution.

    Returns:
        Whatever ``match_image_to_histogram`` returns. The caller converts it
        with ``tools.float_to_uint8``, which suggests a float array (confirm).
    """
    return ExactHistogramMatcher.match_image_to_histogram(
        target_img,
        ExactHistogramMatcher.get_histogram(reference_img),
    )


def histogram_matching_grey_values(target_img, reference_img):
    """Match a grey-value ``target_img`` to the histogram of ``reference_img``.

    The exact-histogram-matching procedure is the same one used for colour
    images. This function delegates to ``histogram_matching_rgb`` rather than
    duplicating its body, so the two code paths cannot drift apart.

    Args:
        target_img: Image whose intensities are remapped. The caller uses this
            path for non-3-D arrays (presumably 2-D H x W; confirm).
        reference_img: Image whose histogram is the target distribution.

    Returns:
        The matched image produced by ``ExactHistogramMatcher.match_image_to_histogram``.
    """
    return histogram_matching_rgb(target_img, reference_img)
    
def exact_histogram_specification_function(target_img, reference_img):
    """Apply exact histogram specification and return the result as uint8.

    3-D input (``len(shape) == 3``) goes through ``histogram_matching_rgb``.
    Any other input goes through ``histogram_matching_grey_values``. The
    matched image is then passed through ``tools.float_to_uint8``.

    Args:
        target_img: Image to be remapped.
        reference_img: Image providing the target histogram.

    Returns:
        The matched image after ``tools.float_to_uint8`` conversion.
    """
    has_three_axes = len(target_img.shape) == 3
    matcher = histogram_matching_rgb if has_three_axes else histogram_matching_grey_values
    return tools.float_to_uint8(matcher(target_img, reference_img))
############################################################################################################
def default_histogram_matching(target_img, reference_img):
    """Match ``target_img`` to ``reference_img`` with scikit-image's ``match_histograms``.

    A 3-D input (``len(shape) == 3``) is treated as multichannel, with the
    last axis as the channel axis. Any other input is matched as a single
    channel.

    scikit-image deprecated the ``multichannel`` keyword in 0.19 in favour of
    ``channel_axis``, and removed it in 0.21. This function therefore passes
    ``channel_axis`` when the installed version accepts it, and falls back to
    ``multichannel`` on older releases.

    Args:
        target_img: Image whose histogram is adjusted.
        reference_img: Image whose histogram is the target distribution.

    Returns:
        The array returned by ``match_histograms``. Its dtype depends on the
        scikit-image version and the input dtype (confirm for your installation).
    """
    import inspect

    is_multichannel = len(target_img.shape) == 3

    # Inspect the signature instead of catching TypeError. Catching TypeError
    # could hide genuine errors raised inside match_histograms.
    if "channel_axis" in inspect.signature(match_histograms).parameters:
        channel_axis = -1 if is_multichannel else None
        return match_histograms(target_img, reference_img, channel_axis=channel_axis)

    # scikit-image < 0.19 only understands ``multichannel``.
    return match_histograms(target_img, reference_img, multichannel=is_multichannel)

############################################################################################################


# Demo: compare exact histogram specification against scikit-image's
# match_histograms on a pair of example images.
# NOTE(review): newer imageio releases warn that imageio.imread will change
# behaviour in v3 (imageio.v2.imread keeps the current one); confirm the
# installed version.
target_img = imageio.imread('images/sub-10228_anat_sub-10228_T1w.nii.gz.png')
reference_img = imageio.imread('images/TE_16.8.png')

exact_hist_spec_image = exact_histogram_specification_function(target_img, reference_img)

default_histogram_matching_image = default_histogram_matching(target_img, reference_img)

# Figure 1: input images on the top row, the two matched results on the bottom row.
fig, axs = plt.subplots(2, 2)
for ax, img, title in (
    (axs[0, 0], target_img, 'target_img'),
    (axs[0, 1], reference_img, 'reference_img'),
    (axs[1, 0], exact_hist_spec_image, 'exact_hist_spec_image'),
    (axs[1, 1], default_histogram_matching_image, 'default_histogram_matching_image'),
):
    ax.imshow(img, cmap='gray')
    ax.set_title(title)

# Figure 2: 256-bin intensity histograms over [0, 256) of the two matched results.
# Drawing on explicit Axes avoids relying on pyplot's implicit "current axes".
fig2 = plt.figure(2)
for position, (img, title) in enumerate(
    (
        (exact_hist_spec_image, 'exact_hist_spec_image'),
        (default_histogram_matching_image, 'default_histogram_matching_image'),
    ),
    start=1,
):
    subplot1 = fig2.add_subplot(1, 2, position)
    subplot1.hist(img.flatten(), bins=256, range=(0, 256), color='r')
    subplot1.set_xlim([0, 256])
    subplot1.set_title(title)

# Without this call, running the file as a plain script builds the figures and
# exits without displaying them.
plt.show()

