#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import numpy as np
import json
import cv2
import pickle
import detectron2
from detectron2.data import DatasetCatalog
from detectron2.data import MetadataCatalog
from detectron2.config import get_cfg
from detectron2 import model_zoo
from detectron2.engine import DefaultTrainer
import sys

"""Module for training Mask-RCNN convolutional network. 
   Usage:
   > python TrainMaskRCNN.py <trainDataSetJSON> <evaluateDataSetJSON> <imagesPath> <outputFileName>
   
   Input arguments:
   * <trainDataSetJSON> - JSON file in detectron2 format with the training data set.
   * <evaluateDataSetJSON> - JSON file in detectron2 format with the evaluation data set.
   * <imagesPath> - Path to the folder containing the data set images.
   * <outputFileName> - Name for the output binary file.

"""

# Module metadata: authorship, contact, license and version information.
__author__ = "Ondrej Klima"
__copyright__ = "Copyright 2020"
__credits__ = ["Ondrej Klima"]
__email__ = "iklima@fit.vutbr.cz"
__license__ = "BUT"
__version__ = "1.0"
__maintainer__ = "Ondrej Klima"

def _required_argument(argv, index, message):
    """Return ``argv[index]``; raise ``IndexError(message)`` if it is missing."""
    try:
        return argv[index]
    except IndexError:
        # Suppress the generic "list index out of range" context so the user
        # only sees the descriptive message.
        raise IndexError(message) from None


def _existing_file_argument(argv, index, message):
    """Return ``argv[index]`` after verifying that it names an existing file.

    Raises:
        IndexError: the argument was not supplied (with *message*).
        ValueError: the argument does not point to an existing file.
    """
    fileName = _required_argument(argv, index, message)
    if not os.path.isfile(fileName):
        raise ValueError('File "%s" does not exist!' % fileName)
    return fileName


def _load_records(jsonFileName, imagesPath):
    """Load a JSON data set and prefix every record's 'file_name' with *imagesPath*."""
    with open(jsonFileName) as f:
        records = json.load(f)
    for record in records:
        record['file_name'] = os.path.join(imagesPath, record['file_name'])
    return records


def main():
    """Train Mask R-CNN on the data sets given on the command line.

    Reads four positional arguments from ``sys.argv``: the training data set
    JSON, the evaluation data set JSON, the images folder and the output file
    name. The data sets are registered with detectron2 as "train" and
    "evaluate", a model is trained with ``DefaultTrainer`` and the resulting
    configuration object is pickled to the output file. The directory of the
    output file is used as the detectron2 output directory.

    Raises:
        IndexError: a required command-line argument is missing.
        ValueError: the training or evaluation JSON file does not exist.
    """
    # Parsing input arguments (validated in the order they are given)
    argv = sys.argv

    trainFileName = _existing_file_argument(
        argv, 1, 'File with training data set must be supplied as an argument')
    evaluationFileName = _existing_file_argument(
        argv, 2, 'File with evaluation data set must be supplied as an argument')
    imagesPath = _required_argument(
        argv, 3, 'Image path must be supplied as an argument')
    outputFileName = _required_argument(
        argv, 4, 'Output file name must be supplied as an argument')

    # Image paths in the JSON files are relative to the images folder
    train = _load_records(trainFileName, imagesPath)
    evaluate = _load_records(evaluationFileName, imagesPath)

    DatasetCatalog.register("train", lambda: train)
    DatasetCatalog.register("evaluate", lambda: evaluate)

    MetadataCatalog.get("train").set(thing_classes=["product", "package"])
    MetadataCatalog.get("evaluate").set(thing_classes=["product", "package"])

    cfg = get_cfg()
    # A bare output file name has an empty dirname, which os.makedirs rejects;
    # fall back to the current directory in that case.
    cfg.OUTPUT_DIR = os.path.dirname(outputFileName) or os.curdir
    cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))

    cfg.DATASETS.TRAIN = ("train",)
    cfg.DATASETS.TEST = ("evaluate",)

    cfg.DATALOADER.NUM_WORKERS = 2

    cfg.SOLVER.IMS_PER_BATCH = 2
    cfg.SOLVER.BASE_LR = 0.00025
    cfg.SOLVER.MAX_ITER = 900

    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128
    # Two classes, matching the thing_classes registered above
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2
    # Start from the model-zoo checkpoint of the same configuration
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")

    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

    trainer = DefaultTrainer(cfg)
    trainer.resume_or_load(resume=False)
    trainer.train()

    # Point the stored configuration at the final weights file in the output
    # directory and set the test-time score threshold before saving it.
    cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7

    with open(outputFileName, 'wb') as f:
        pickle.dump(cfg, f)
     
# Run the training only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()