###############################################################################
#  clark_nfa.py: Module for PATTERN MATCH
#  Copyright (C) 2010 Brno University of Technology, ANT @ FIT
#  Author(s): Vlastimil Kosar <ikosar@fit.vutbr.cz>
###############################################################################
#
#  LICENSE TERMS
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions
#  are met:
#  1. Redistributions of source code must retain the above copyright
#     notice, this list of conditions and the following disclaimer.
#  2. Redistributions in binary form must reproduce the above copyright
#     notice, this list of conditions and the following disclaimer in
#     the documentation and/or other materials provided with the
#     distribution.
#  3. All advertising materials mentioning features or use of this software
#     or firmware must display the following acknowledgement:
#
#       This product includes software developed by the University of
#       Technology, Faculty of Information Technology, Brno and its
#       contributors.
#
#  4. Neither the name of the Company nor the names of its contributors
#     may be used to endorse or promote products derived from this
#     software without specific prior written permission.
#
#  This software or firmware is provided ``as is'', and any express or implied
#  warranties, including, but not limited to, the implied warranties of
#  merchantability and fitness for a particular purpose are disclaimed.
#  In no event shall the company or contributors be liable for any
#  direct, indirect, incidental, special, exemplary, or consequential
#  damages (including, but not limited to, procurement of substitute
#  goods or services; loss of use, data, or profits; or business
#  interruption) however caused and on any theory of liability, whether
#  in contract, strict liability, or tort (including negligence or
#  otherwise) arising in any way out of the use of this software, even
#  if advised of the possibility of such damage.
#
#  $Id$

import nfa_data
import b_nfa 
import b_state
import copy
import re
import msfm_parser
import sym_char_class
import sym_char

class clark_nfa(b_nfa.b_nfa):
    """NFA that is emitted as a VHDL unit using Clark's approach.

    Every NFA state becomes one ``STATE`` entity, every character used by the
    alphabet gets one ``VALUE_DECODER``, and transitions are AND/OR logic
    between state outputs and decoded symbols. The generated text is spliced
    into the VHDL template found at ``self.template``. A rough LUT count is
    accumulated while generating so that report_logic() can estimate the
    resources used.
    """

    def __init__(self):
        b_nfa.b_nfa.__init__(self)
        # DATA_WIDTH of the generated VALUE_DECODERs (width of the DATA input).
        self.width = 8
        # VHDL template; split on "%$%" into (at least) five parts in get_HDL().
        self.template = "templates/vhdl/clark_nfa.vhd"
        self._statistic = dict()
        self._useBram = False
        # Number of inputs of a single LUT used by the resource estimate.
        self._LUTInputs = 6
        # Resource counters filled in by get_HDL(); initialised here so that
        # report_logic() does not fail with AttributeError before get_HDL().
        self._luts = 0
        self._fStateNum = 0

    def  get_HDL(self, useBram = False):
        """Return the VHDL description of the NFA unit (Clark's approach).

        Calls self.remove_epsilons(), then generates the character decoder,
        the state/transition logic and the final-state logic, and splices them
        into the template at ``self.template``. Resets and recomputes the LUT
        counter used by report_logic().

        :param useBram: stored in ``self._useBram`` for report_logic().
        :returns: the complete VHDL text.
        """
        # Close the template file even if reading fails.
        with open(self.template, "rb") as f:
            blob = f.read()
        tmp = re.split(r"%\$%", blob)

        self._useBram = useBram
        self._luts = 0

        self.remove_epsilons()

        # NOTE(review): useBram is not forwarded to _get_char_dec_HDL(), so the
        # decoder is always logic-based; report_logic() raises if it was set.
        chdec = self._get_char_dec_HDL()
        logic = self._get_logic_HDL()
        final = self._get_final_HDL()
        dataSignal = chdec[0] + logic[0] + final[0]
        dataImplementation = chdec[1] + logic[1] + final[1]

        header = "-- --------------- GENERATED BY CLARK_NFA.PY ----------------\n"
        footer = "-- --------------- END ----------------\n"
        textSignal = header + "".join(dataSignal) + footer
        textImplementation = header + "".join(dataImplementation) + footer

        # Template slots: data width, number of final-state groups,
        # signal declarations, implementation.
        return (tmp[0] + str(self.width) + tmp[1] + str(self._fStateNum)
                + tmp[2] + textSignal + tmp[3] + textImplementation + tmp[4])

    def _get_char_dec_HDL(self, useBram = False):
        """Return HDL of the shared character decoder as (signals, description).

        One VALUE_DECODER drives ``char_<code>`` for every character used by a
        char or char-class symbol of the alphabet. Each alphabet symbol then
        gets a ``symbol_<id>`` signal: the OR of its characters' decoder
        outputs for a char class, or the single decoder output for a char.

        :raises NotImplementedError: if useBram is set (BRAM decoder missing).
        """
        # TODO: implement it in BRAM
        if useBram:
            raise NotImplementedError("BRAM-based char decoder is not implemented")

        alphabet = self._automaton.alphabet

        allChars = set()
        for symbol in alphabet.keys():
            sym = alphabet[symbol]
            if isinstance(sym, sym_char_class.b_Sym_char_class):
                allChars.update(sym.charClass)
            if isinstance(sym, sym_char.b_Sym_char):
                allChars.add(sym.char)

        signalList = list()
        descriptionList = list()

        # Sorted so the generated VHDL is identical from run to run.
        for char in sorted(allChars, key=ord):
            code = str(ord(char))
            signalList.append("    signal char_" + code + " : std_logic;\n")
            chrDec  = "    VD_" + code + ": entity work.VALUE_DECODER\n"
            chrDec += "    generic map(\n"
            chrDec += "        DATA_WIDTH => " + str(self.width) + ",\n"
            chrDec += "        VALUE      => " + code + "\n"
            chrDec += "    )\n"
            chrDec += "    port map(\n"
            chrDec += "        INPUT  => DATA,\n"
            chrDec += "        OUTPUT => char_" + code + "\n"
            chrDec += "    );\n\n"
            descriptionList.append(chrDec)
            self._luts += self._countLUTsForInpunts(self.width)

        for symbol in alphabet.keys():
            sym = alphabet[symbol]
            signalList.append("    signal symbol_" + str(symbol) + " : std_logic;\n")
            if isinstance(sym, sym_char_class.b_Sym_char_class):
                self._luts += self._countLUTsForInpunts(len(sym.charClass))
                charSignals = ["char_" + str(ord(char)) for char in sym.charClass]
                descriptionList.append("    symbol_" + str(symbol) + " <= "
                                       + " or ".join(charSignals) + ";\n")
            if isinstance(sym, sym_char.b_Sym_char):
                descriptionList.append("    symbol_" + str(symbol) + " <= char_"
                                       + str(ord(sym.char)) + ";\n")

        return (signalList, descriptionList)

    def _get_logic_HDL(self):
        """Return HDL of states and transitions as (signals, description).

        Every state gets a STATE entity with ``state_in_<id>`` and
        ``state_out_<id>`` signals. The start state uses DEFAULT '1' and a
        constant '0' input; every other state uses DEFAULT '0' and its
        ``state_in`` signal. ``state_in_<t>`` is the OR, over all transitions
        (s, a, t), of ``state_out_<s> and symbol_<a>``.
        """
        signalList = list()
        descriptionList = list()

        for state in self._automaton.states.keys():
            signalList.append("    signal state_in_" + str(state) + " : std_logic;\n")
            signalList.append("    signal state_out_" + str(state) + " : std_logic;\n")
            if state == self._automaton.start:
                default, stateInput = "'1'", "'0'"
            else:
                default, stateInput = "'0'", "state_in_" + str(state)
            text  =  "    STATE_" + str(state) + ": entity work.STATE\n"
            text +=  "    generic map(\n"
            text +=  "        DEFAULT     => " + default + "\n"
            text +=  "    )\n"
            text +=  "    port map(\n"
            text +=  "        CLK    => CLK,\n"
            text +=  "        RESET  => local_reset,\n"
            text +=  "        INPUT  => " + stateInput + ",\n"
            text +=  "        WE     => we,\n"
            text +=  "        OUTPUT => state_out_" + str(state) + "\n"
            text +=  "    );\n\n"
            descriptionList.append(text)

        # Group transitions by target: target -> {(source, symbol), ...}.
        inputTransitions = dict()
        for transition in self._automaton.transitions:
            inputTransitions.setdefault(transition[2], set()).add((transition[0], transition[1]))

        for state in inputTransitions.keys():
            sources = inputTransitions[state]
            if len(sources) == 1:
                (src, symbol) = next(iter(sources))
                descriptionList.append("    state_in_" + str(state) + " <= state_out_" + str(src)
                                       + " and symbol_" + str(symbol) + ";\n")
                self._luts += self._countLUTsForInpunts(2)
            else:
                self._luts += self._countLUTsForInpunts(2 * len(sources))
                terms = ["(state_out_" + str(src) + " and symbol_" + str(symbol) + ")"
                         for (src, symbol) in sources]
                descriptionList.append("    state_in_" + str(state) + " <= "
                                       + " or ".join(terms) + ";\n")
        return (signalList, descriptionList)

    def _get_final_HDL(self):
        """Return HDL of the final-state interconnection as (signals, description).

        Final states are grouped by their regexp number. Each group drives a
        ``final_<n>`` signal (OR of its states' outputs), and the groups are
        packed into ``bitmap_in``. Sets ``self._fStateNum`` to the number of
        groups.
        """
        signalList = list()
        descriptionList = list()

        sameFinal = dict()
        for fstate in self._automaton.final:
            regexpNumber = self._automaton.states[fstate].get_regexp_number()
            sameFinal.setdefault(regexpNumber, set()).add(fstate)

        self._fStateNum = len(sameFinal)
        # Materialise the key order once (dict views are not indexable on
        # Python 3) so final_<n> signals and bitmap_in bits agree.
        sfKeys = list(sameFinal.keys())

        signalList.append("    signal bitmap_in : std_logic_vector(" + str(len(sameFinal)) + " - 1 downto 0);\n")
        for final in sfKeys:
            signalList.append("    signal final_" + str(final) + " : std_logic;\n")
            states = sameFinal[final]
            if len(states) == 1:
                descriptionList.append("    final_" + str(final) + " <= state_out_"
                                       + str(next(iter(states))) + ";\n")
            else:
                self._luts += self._countLUTsForInpunts(len(states))
                descriptionList.append("    final_" + str(final) + " <= "
                                       + " or ".join("state_out_" + str(s) for s in states) + ";\n")

        for (i, final) in enumerate(sfKeys):
            descriptionList.append("    bitmap_in(" + str(i) + ") <= final_" + str(final) + ";\n")

        return (signalList, descriptionList)

    def report_logic(self):
        """Return the resource estimate of the last get_HDL() as (LUTs, FFs, BRAMs).

        FFs count one per state, one per final-state group, plus one more.

        :raises NotImplementedError: if get_HDL() was called with useBram set.
        """
        # NOTE(review): the extra "+ 1" flip-flop is presumably a register in
        # the template outside this generator's view -- TODO confirm.
        ffSize = len(self._automaton.states) + self._fStateNum + 1
        if self._useBram:
            raise NotImplementedError("BRAM resource estimate is not implemented")
        return (self._luts, ffSize, 0)

    def _countLUTsForInpunts(self, inputs):
        """Estimate how many LUTs a gate with ``inputs`` inputs needs.

        Model: a tree of ``self._LUTInputs``-input LUTs. Each level needs
        ceil(i / _LUTInputs) LUTs whose outputs feed the next level, until at
        most _LUTInputs signals remain; one final LUT combines them. Returns 1
        when the inputs fit into a single LUT (including 0 inputs).
        """
        lutInputs = self._LUTInputs
        if inputs <= lutInputs:
            return 1
        res = 0
        i = inputs
        while i > lutInputs:
            # Integer ceiling division: the same result on Python 2 and 3
            # (plain "/" floors ints on Python 2, which made the old
            # ceiling branch unreachable there).
            i = (i + lutInputs - 1) // lutInputs
            res += i
        return res + 1
                                        
if __name__ == '__main__':
    # Ad-hoc experiment: build the NFA for a rules file, generate its HDL, and
    # print "states;transitions;LUTs;FFs;" once for the automaton as parsed
    # and once after alphabet reduction, followed by the _statistic dict.
    import sys

    # The rules file may be passed as the first command-line argument; the
    # default is the file this experiment was originally hard-wired to.
    if len(sys.argv) > 1:
        rulesFile = sys.argv[1]
    else:
        rulesFile = "rules/Moduly/backdoor.rules.pcre"

    cn = clark_nfa()
    parser = msfm_parser.msfm_parser()
    parser.load_file(rulesFile)
    cn.create_by_parser(parser)
    cn.remove_epsilons()

    # Automaton as parsed.
    textl = str(len(cn._automaton.states)) + ";" + str(len(cn._automaton.transitions)) + ";"
    cn.get_HDL()
    data = cn.report_logic()
    textl += str(data[0]) + ";" + str(data[1]) + ";"

    # Automaton after alphabet reduction.
    cn._removeCharClasses()
    cn.reduce_alphabet()
    cn._createCharClasses()
    textl += str(len(cn._automaton.states)) + ";" + str(len(cn._automaton.transitions)) + ";"
    cn.get_HDL()
    data = cn.report_logic()
    textl += str(data[0]) + ";" + str(data[1]) + ";"

    print(textl)
    print(cn._statistic)