import os
import numpy as np
import keras
import cv2
import imutils
import math
from PIL import Image

# Shape of a single CNN input sample: 28x28 pixels, one channel.
# pass_to_cnn reshapes each image to (1, input_shape[0], input_shape[1], 1).
input_shape = (28, 28, 1)

# Directory containing this file; model.keras is looked up next to it.
absolute_path = os.path.dirname(__file__)
# NOTE: the model is loaded once, at import time, so importing this module
# presumably raises if keras cannot load model.keras from that directory.
model = keras.models.load_model(os.path.join(absolute_path, 'model.keras'))

### Preprocessing functions ###

def split_image(img, num_of_parts):
    """Split an image into ``num_of_parts`` vertical strips of (nearly) equal width.

    Every strip spans the full image height. Strips are ``width // num_of_parts``
    columns wide; the last strip also absorbs the leftover columns, so no part of
    the image is silently dropped.

    Args:
        img: image array whose first two dimensions are (height, width).
        num_of_parts: number of strips to produce; must be >= 1.

    Returns:
        list of array views into ``img``, ordered left to right.

    Raises:
        ValueError: if ``num_of_parts`` is smaller than 1.
    """
    if num_of_parts < 1:
        raise ValueError("num_of_parts must be >= 1")
    if num_of_parts == 1:
        return [img]

    width = img.shape[1]
    part_width = width // num_of_parts
    # left edge of every strip, plus the image's right edge as the final boundary
    bounds = [i * part_width for i in range(num_of_parts)] + [width]
    return [img[:, start:end] for start, end in zip(bounds, bounds[1:])]

def sort_contours(cnts, method="left-to-right"):
    """Sort contours by the position of their bounding boxes.

    Args:
        cnts: sequence of contours as returned by ``cv2.findContours``.
        method: one of "left-to-right" (default), "right-to-left",
            "top-to-bottom" or "bottom-to-top". Any other value falls back to
            left-to-right ordering.

    Returns:
        ``(contours, bounding_boxes)`` as two parallel tuples in sorted order;
        both are empty tuples when ``cnts`` is empty.
    """
    if len(cnts) == 0:
        # zip(*[]) yields nothing, so unpacking it into two names would raise
        return (), ()

    reverse = method in ("right-to-left", "bottom-to-top")
    # a bounding rect is (x, y, w, h): sort on x for horizontal orders, y for vertical
    axis = 1 if method in ("top-to-bottom", "bottom-to-top") else 0

    boxes = [cv2.boundingRect(c) for c in cnts]
    pairs = sorted(zip(cnts, boxes), key=lambda pair: pair[1][axis], reverse=reverse)
    cnts, boxes = zip(*pairs)
    return cnts, boxes

def get_contours(img):
    """Find the external contours of dark objects on a light background.

    The image is binarized with a fixed threshold of 127 and inverted, because
    the contour finder expects white objects on a black background.

    Args:
        img: image to search; expected to be single-channel 8-bit, as
            ``cv2.findContours`` requires.

    Returns:
        tuple of contours ordered left to right; an empty tuple if none are found.
    """
    inverted = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY_INV)[1]

    # only the outermost contours are retrieved
    found = cv2.findContours(inverted.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    contours = imutils.grab_contours(found)
    if len(contours) == 0:
        # nothing to sort; sorting an empty set cannot be unpacked into (contours, boxes)
        return ()
    return sort_contours(contours, method="left-to-right")[0]

def _content_bounds(profile, threshold, keepout):
    """Return inclusive ``(start, end)`` indices bracketing the content of a 1-D profile.

    ``start`` is the first index whose neighbour differs by more than
    ``threshold``; ``end`` is found the same way scanning from the other side.
    If no proper bracket is found, the whole range is kept. Both ends are then
    widened by ``keepout`` and clamped to the profile.
    """
    last = len(profile) - 1
    if last <= 0:
        return 0, max(last, 0)

    start = 0
    for i in range(last):
        if abs(profile[i] - profile[i + 1]) > threshold:
            start = i
            break

    end = last
    # check every adjacent pair, including the final one (width-2, width-1)
    for i in range(last, 0, -1):
        if abs(profile[i] - profile[i - 1]) > threshold:
            end = i
            break

    if start >= end:
        start, end = 0, last

    return max(start - keepout, 0), min(end + keepout, last)


def trim(img, threshold=3, keepout=10):
    """Crop away flat margins around the content of a single-channel image.

    Column sums, then row sums of the column-cropped image, are min-max
    normalized to 0..255 and scanned from both ends for the first jump larger
    than ``threshold``; ``keepout`` extra pixels are kept on every side.

    Args:
        img: single-channel image (multi-channel sums would not be scalars).
        threshold: minimum change between neighbouring normalized sums that
            marks the start/end of content.
        keepout: margin in pixels kept around the detected content.

    Returns:
        a view of ``img`` cropped to the detected content.
    """
    # REDUCE_SUM over dim=0 gives a single row of per-column sums
    col_sums = cv2.reduce(img, dst=None, dim=0, rtype=cv2.REDUCE_SUM, dtype=cv2.CV_32F)
    col_sums = cv2.normalize(col_sums, None, 0, 255, cv2.NORM_MINMAX)
    x_start, x_end = _content_bounds(col_sums[0], threshold, keepout)
    # bounds are inclusive, so the slice end needs +1
    img = img[:, x_start:x_end + 1]

    # REDUCE_SUM over dim=1 gives a single column of per-row sums
    row_sums = cv2.reduce(img, dst=None, dim=1, rtype=cv2.REDUCE_SUM, dtype=cv2.CV_32F)
    row_sums = cv2.normalize(row_sums, None, 0, 255, cv2.NORM_MINMAX)
    # scan over the image height (rows), not its width
    y_start, y_end = _content_bounds(row_sums[:, 0], threshold, keepout)
    return img[y_start:y_end + 1, :]

def _threshold_in_strips(img):
    """Binarize ``img`` in place with Otsu's method, one vertical strip at a time.

    The image is cut into ``max(1, width // height)`` strips of equal width; the
    last strip also covers any leftover columns so every pixel is binarized.
    Returns ``img``.
    """
    height, width = img.shape[:2]
    num_of_parts = max(1, width // height)
    part_width = width // num_of_parts
    for i in range(num_of_parts):
        start = i * part_width
        end = width if i == num_of_parts - 1 else start + part_width
        strip = img[:, start:end]
        img[:, start:end] = cv2.threshold(strip, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
    return img


def preprocessing(img, normalized_height=250):
    """Prepare a colour image for segmentation.

    Steps: resize to a fixed height, denoise, convert to grayscale, stretch to
    the full 0..255 range, trim flat margins, binarize with Otsu's method in
    vertical strips, then blur slightly.

    Args:
        img: 3-channel 8-bit colour image (required by
            ``cv2.fastNlMeansDenoisingColored``); channels are treated as RGB.
        normalized_height: height in pixels the image is resized to before
            processing; the aspect ratio is preserved.

    Returns:
        ``(filtered, trimmed, preprocessed)``: the denoised colour image, the
        normalized and trimmed grayscale image, and the final binarized image.
    """
    height, width = img.shape[:2]
    # keep the aspect ratio, but never let the width collapse to zero
    new_width = max(1, int(width / height * normalized_height))
    resized = cv2.resize(img, (new_width, normalized_height), interpolation=cv2.INTER_CUBIC)

    # filtering
    filtered = cv2.fastNlMeansDenoisingColored(
        resized, None, h=25, hColor=10, templateWindowSize=3, searchWindowSize=51
    )

    # convert to grayscale
    gray = cv2.cvtColor(filtered, cv2.COLOR_RGB2GRAY)

    # normalize first to get the maximum range, then trim the margins so that
    # the thresholding below does not threshold background noise
    trimmed = trim(cv2.normalize(gray, None, 0, 255, cv2.NORM_MINMAX))

    # threshold in several strips (presumably so the threshold can follow
    # lighting that varies along the line); copy so `trimmed` is returned as-is
    thres = _threshold_in_strips(trimmed.copy())

    # blur the image to get rid of pixelation
    preprocessed = cv2.blur(thres, (2, 2))

    return filtered, trimmed, preprocessed

def segmentation(img):
    """Cut a binarized line image into individual character crops.

    Each external contour is cropped by its bounding box. Crops that are within
    10 px of the full image width are skipped. Every remaining crop is split into
    ``int(aspect / 1.3) + 1`` vertical strips (aspect = width / height),
    presumably to separate touching characters. Every contour inside a strip
    whose area exceeds 500 becomes one output crop.

    Args:
        img: binarized image, assumed to hold dark characters on a light
            background (TODO confirm against callers).

    Returns:
        list of image crops, in left-to-right contour order.
    """
    min_area = 500
    full_width = img.shape[1]
    pieces = []

    for outer in get_contours(img):
        left, top, box_w, box_h = cv2.boundingRect(outer)
        crop = img[top:top + box_h, left:left + box_w]
        crop_h, crop_w = crop.shape[:2]

        # a contour spanning (almost) the whole width is presumably a frame,
        # not a character
        if full_width - crop_w <= 10:
            continue

        strip_count = int((crop_w / crop_h) / 1.3) + 1
        for strip in split_image(crop, strip_count):
            for inner in get_contours(strip):
                if cv2.contourArea(inner) <= min_area:
                    continue
                sx, sy, sw, sh = cv2.boundingRect(inner)
                pieces.append(strip[sy:sy + sh, sx:sx + sw])

    return pieces

def size_normalization(img, aspect):
    """Turn a character crop into a square CNN input of size ``input_shape[:2]``.

    Steps:
    1. Remap the crop's aspect ratio (width / height) with the chosen strategy.
    2. Resize the crop to the new width, keeping its height.
    3. Pad it with white to a square.
    4. Downsample it to the input size.
    5. Add a 2 px white border and resize back to the input size.
    6. Thicken the dark strokes.

    Args:
        img: single-channel character crop; assumed dark strokes on white, given
            the white padding and the ``255 - x`` inversion below (TODO confirm).
        aspect: aspect-ratio mapping applied to the original ratio ``a``:
            0 -> ``a`` (unchanged), 1 -> ``a ** (1/2)``, 2 -> ``a ** (1/3)``,
            3 -> ``sin(pi/2 * min(a, 1)) ** (1/2)``, 4 -> ``1`` (stretch to square).

    Returns:
        image of shape ``input_shape[:2]``.

    Raises:
        ValueError: if ``aspect`` is not one of 0-4.
    """
    old_height, old_width = img.shape[:2]
    old_aspect = old_width / old_height

    aspect_mappings = {
        0: lambda a: a,
        1: lambda a: a ** (1. / 2.),
        2: lambda a: a ** (1. / 3.),
        # sin(pi/2 * a) peaks at a = 1 and is negative for a > 2, which would give
        # a complex root (TypeError in int()) or a zero width; clamp at 1
        3: lambda a: math.sin(math.pi / 2 * min(a, 1.)) ** (1. / 2.),
        4: lambda a: 1,
    }
    if aspect not in aspect_mappings:
        raise ValueError(f"unknown aspect mapping {aspect!r}; expected one of 0-4")
    new_aspect = aspect_mappings[aspect](old_aspect)

    # cv2.resize cannot produce a zero-width image
    new_width = max(1, int(new_aspect * old_height))
    aspect_normalized = cv2.resize(img, (new_width, old_height), interpolation=cv2.INTER_CUBIC)

    # make the image square by padding the shorter side with white, so the
    # aspect ratio is not distorted; sizes here are (height, width)
    old_size = aspect_normalized.shape[:2]
    new_size = max(old_size)
    delta_h = new_size - old_size[0]
    delta_w = new_size - old_size[1]
    top, bottom = delta_h // 2, delta_h - delta_h // 2
    left, right = delta_w // 2, delta_w - delta_w // 2
    color = [255, 255, 255]
    squared = cv2.copyMakeBorder(aspect_normalized, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)

    # cv2 sizes are (width, height)
    target_size = (input_shape[1], input_shape[0])
    downsampled = cv2.resize(squared, target_size, interpolation=cv2.INTER_CUBIC)

    # add a 2 px white border (the original comment says this matches the
    # padding of the training dataset), then shrink back to the input size
    padding = 2
    padded = cv2.copyMakeBorder(downsampled, padding, padding, padding, padding, cv2.BORDER_CONSTANT, value=color)
    padded = cv2.resize(padded, target_size, interpolation=cv2.INTER_CUBIC)

    # thicken the dark strokes: dilate the inverted image, then invert back
    normalized = 255 - cv2.dilate(255 - padded, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2, 2)))

    return normalized

def pass_to_cnn(img):
    """Classify one normalized character image with the module-level model.

    Args:
        img: image with ``input_shape[0] * input_shape[1]`` pixels, presumably
            with values in 0..255 (it is divided by 255).

    Returns:
        ``(label, score)``: index of the highest-scoring output and that score
        multiplied by 100 (a percentage if the model outputs probabilities).
    """
    rows, cols = input_shape[0], input_shape[1]
    # scale to [0, 1] and add batch and channel axes -> (1, rows, cols, 1)
    batch = (img.astype("float32") / 255).reshape(1, rows, cols, 1)
    scores = model.predict(batch, verbose=0)
    return scores.argmax(), scores.max() * 100

### Segmentation + recognition from an image ###

def recognize(img):
    """Segment an image into characters and classify each one.

    Every segment is normalized with aspect mappings 2, 3 and 4 (see
    ``size_normalization``) and classified once per mapping. The label predicted
    most often wins; if all three differ, the first vote (mapping 2) wins.

    Args:
        img: colour image accepted by ``preprocessing``.

    Returns:
        ``(cnn_in, recognized)``: for each segment, the image produced by the
        last mapping (4), and the chosen label.
    """
    binarized = preprocessing(img)[2]

    cnn_in = []
    recognized = []
    for segment in segmentation(binarized):
        votes = []
        for mapping in (2, 3, 4):
            candidate = size_normalization(segment, mapping)
            label, _score = pass_to_cnn(candidate)
            votes.append(label)
        # keep the image produced by the final mapping for inspection
        cnn_in.append(candidate)
        recognized.append(max(votes, key=votes.count))

    return cnn_in, recognized

### Get serial number (ignores the first recognized character, e.g. a leading '#') ###

def get_serial_number(img, *args):
    """Combine the recognized digits of ``img`` into an integer, skipping the first one.

    The first recognized character is dropped (per the section header, a leading
    '#'); the remaining digits are read most-significant first.

    Args:
        img: image passed to ``recognize``; ignored when digits are supplied.
        *args: optional; if given, ``args[0]`` is a sequence of already
            recognized digits used instead of running recognition.

    Returns:
        the integer formed by every digit after the first (0 if there are fewer
        than two digits).
    """
    digits = args[0] if args else recognize(img)[1]
    count = len(digits)
    return sum(
        10 ** (count - 1 - pos) * digit
        for pos, digit in enumerate(digits)
        if pos > 0
    )
