import cv2
import matplotlib.pyplot as plt
import imutils
from PIL import Image
import os
# Root folder that holds every dataset. Paths below are built by plain string
# concatenation, so the trailing slash is required.
dataset_path = 'C:/Users/Radim/Desktop/Image processing/dataset/'
# Sub-folder of dataset_path that is read as input (trailing slash required).
dataset_used = 'numbers-JC-ProgJava/'
# Alternative input dataset; uncomment to use it instead of the one above.
#dataset_used = 'numbers-kensanata/'

#test_img = cv2.cvtColor(cv2.imread(dataset_path + dataset_used + '0027_PL3M/9/number-2.png'), cv2.COLOR_BGR2RGB)

"""**Preprocessing - filter and threshold, Centering - center and standardize size**"""

def preprocessing(img):
  """Denoise an image and binarize it with Otsu's method.

  Args:
    img: 8-bit image array. Colour input is converted to gray with
      COLOR_RGB2GRAY, i.e. it is treated as RGB. Single-channel (grayscale)
      input is accepted as well and is promoted to three channels first.

  Returns:
    A tuple (gray, thres): the denoised grayscale image and the binary
    image obtained from it by min-max normalization followed by Otsu
    thresholding.
  """

  # fastNlMeansDenoisingColored and the RGB2GRAY conversion below both reject
  # single-channel input, so promote grayscale images to three channels.
  if img.ndim == 2 or img.shape[2] == 1:
    img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

  # denoising; filter strength and search window scale with the image height
  # NOTE(review): for images under 10 px tall the search window becomes 0 -
  # confirm how OpenCV treats that value.
  size = img.shape[0]
  filtered = cv2.fastNlMeansDenoisingColored(img, None, h=size*4, hColor=10, templateWindowSize=3, searchWindowSize=int(size/10))

  # convert to black and white
  gray = cv2.cvtColor(filtered, cv2.COLOR_RGB2GRAY)
  # stretch the contrast to the full 0..255 range before thresholding
  thres = cv2.normalize(gray, None, 0, 255, cv2.NORM_MINMAX)
  # with THRESH_OTSU the threshold argument (0) is ignored and computed
  # automatically; [1] picks the image out of the (threshold, image) result
  thres = cv2.threshold(thres, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)[1]

  return gray, thres

def centering(img):
  """Crop the single object found in a binary image and fit it into a 28x28 tile.

  The image is inverted before the contour search and all padding is white,
  i.e. the object is treated as dark on a light background.

  Args:
    img: thresholded single-channel image.

  Returns:
    (tile, True) when exactly one external contour with an area above 20 was
    found; tile is the squared, resized, padded and lightly blurred crop.
    Otherwise (crops, False), where crops is the list of all cropped
    candidates (possibly empty).
  """
  # findContours treats white pixels as the object, so flip the image first
  flipped = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY_INV)[1]

  # grab_contours hides the differing return layouts of findContours
  # across OpenCV versions
  outlines = imutils.grab_contours(
      cv2.findContours(flipped.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE))

  # keep the bounding-box crop of every contour that is not a tiny speck
  crops = []
  for outline in outlines:
    if cv2.contourArea(outline) <= 20:
      continue
    x, y, w, h = cv2.boundingRect(outline)
    crops.append(img[y:y + h, x:x + w])

  # anything other than exactly one object counts as a failure
  if len(crops) != 1:
    return crops, False

  white = [255, 255, 255]
  glyph = crops[0]

  # pad the shorter dimension on both sides so the crop becomes a square;
  # an odd remainder goes to the bottom / right edge
  height, width = glyph.shape[:2]
  side = max(height, width)
  pad_top = (side - height) // 2
  pad_bottom = (side - height) - pad_top
  pad_left = (side - width) // 2
  pad_right = (side - width) - pad_left
  square = cv2.copyMakeBorder(glyph, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT, value=white)

  # 24x24 content plus a 2 px white frame on every side gives 28x28
  tile = cv2.resize(square, (24, 24), interpolation=cv2.INTER_CUBIC)
  tile = cv2.copyMakeBorder(tile, 2, 2, 2, 2, cv2.BORDER_CONSTANT, value=white)
  tile = cv2.blur(tile, (2, 2))

  return tile, True

#plt.imshow(preprocessing(test_img)[1], cmap="gray")
#plt.imshow(centering(preprocessing(test_img)[1])[0], cmap="gray")

"""**Process dataset**"""

number_dirs = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]

# Walk <dataset_path><dataset_used><dataset>/<digit>/<image>, run every image
# through preprocessing + centering and store the successfully centered ones
# under <dataset_path>numbers-processed/<dataset>/<digit>/ with the same
# file name. Success and error counts are printed per dataset folder.
for dataset in os.listdir(dataset_path + dataset_used):
#for dataset in ["0026_CH5M","0027_PL3M","0028_PL3F","0029_PL3M"]:
  if not os.path.isdir(dataset_path + dataset_used + dataset):
    # a stray file next to the dataset folders would crash os.listdir below
    continue

  errorcount = 0
  imgcount = 0

  for number in number_dirs:
    print("processing: " + str(dataset) + " number " + str(number))
    source_path = dataset_path + dataset_used + dataset + "/" + number + "/"
    destination_path = dataset_path + 'numbers-processed/' + dataset + "/" + number + "/"

    if not os.path.isdir(source_path):
      # this dataset folder has no samples for the digit; nothing to do
      continue

    for f in os.listdir(source_path):
      img_in_path = source_path + f

      if not os.path.isfile(img_in_path):
        continue

      # IMREAD_UNCHANGED keeps an alpha channel if the file has one
      # NOTE(review): imread yields BGR(A) channel order while preprocessing
      # converts with COLOR_RGB2GRAY - confirm the swapped red/blue weights
      # are acceptable.
      img = cv2.imread(img_in_path, cv2.IMREAD_UNCHANGED)
      if img is None:
        # imread signals an unreadable / non-image file by returning None
        # instead of raising
        print("could not read image: " + img_in_path)
        errorcount = errorcount + 1
        continue

      if len(img.shape) >= 3 and img.shape[2] == 4:     # we have an alpha channel
        # Flatten transparency onto white: the inverted alpha is 255 where
        # the pixel is fully transparent, and the saturating add pushes
        # those pixels to white while leaving opaque pixels unchanged.
        a1 = ~img[:,:,3]
        img = cv2.add(cv2.merge([a1,a1,a1,a1]), img)
        img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)    # strip alpha channel

      [processed, success] = centering(preprocessing(img)[1])

      if not success:
        errorcount = errorcount+1
        continue

      os.makedirs(destination_path, exist_ok=True)
      img_out_path = destination_path + f
      # imwrite creates the file itself and reports failure through its
      # return value, so only count the image when it was actually written
      if cv2.imwrite(img_out_path, processed):
        imgcount = imgcount+1
      else:
        print("could not write image: " + img_out_path)
        errorcount = errorcount+1

  print("errors: " + str(errorcount))
  print("success: " + str(imgcount))

print("DONE")