###############################################################################
#  pcca.py: Module for experiments with Prefix Colouring Crossproduct Algorithm
#  Copyright (C) 2010 Brno University of Technology, ANT @ FIT
#  Author(s): Michal Kajan <ikajan@fit.vutbr.cz>
###############################################################################
#
#  LICENSE TERMS
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions
#  are met:
#  1. Redistributions of source code must retain the above copyright
#     notice, this list of conditions and the following disclaimer.
#  2. Redistributions in binary form must reproduce the above copyright
#     notice, this list of conditions and the following disclaimer in
#     the documentation and/or other materials provided with the
#     distribution.
#  3. All advertising materials mentioning features or use of this software
#     or firmware must display the following acknowledgement:
#
#       This product includes software developed by the University of
#       Technology, Faculty of Information Technology, Brno and its
#       contributors.
#
#  4. Neither the name of the Company nor the names of its contributors
#     may be used to endorse or promote products derived from this
#     software without specific prior written permission.
#
#  This software or firmware is provided ``as is'', and any express or implied
#  warranties, including, but not limited to, the implied warranties of
#  merchantability and fitness for a particular purpose are disclaimed.
#  In no event shall the company or contributors be liable for any
#  direct, indirect, incidental, special, exemplary, or consequential
#  damages (including, but not limited to, procurement of substitute
#  goods or services; loss of use, data, or profits; or business
#  interruption) however caused and on any theory of liability, whether
#  in contract, strict liability, or tort (including negligence or
#  otherwise) arising in any way out of the use of this software, even
#  if advised of the possibility of such damage.
#
#  $Id: pcca.py 573 2010-12-11 15:40:40Z ipus $

"""
Module for experiments with Prefix Colouring Crossproduct Algorithm (temporary codename)
"""

import copy
import sys

import bclassification
from netbench.classification.prefix_coloured import *
from netbench.classification.ruleset import *
from netbench.classification.parsers.simplenificparser import *
from netbench.classification.rule import *
from netbench.classification.common import *
from netbench.common.packetheader import *
from math import log, ceil
import random

class PCCA(bclassification.BClassification):
    """
    Prefix Colouring Crossproduct Algorithm
    """

    def __init__(self):
        """
        Create an empty classifier; load_ruleset() fills it in.
        """

        # --- rule storage -------------------------------------------------
        # Ruleset currently held by the classifier
        self._rules = None
        # Subset of rules the perfect-hash lookup searches in
        self._search_rules = None
        # Rule matching every packet (if the ruleset contains one)
        self._universalrule = None
        # Rules removed from the ruleset because they generate most of the
        # pseudorules; they are meant to be handled by TCAM
        self._spoilers = list()

        # --- per-dimension data -------------------------------------------
        # Names of the prefix conditions (dimensions)
        self._prefix_cond_names = list()
        # dimension name -> list of unique prefixes seen in that dimension
        self._prefixes = dict()
        # dimension name -> colour settings
        self._colours_info = dict()

        # --- perfect hash ---------------------------------------------------
        # Weights of the graph nodes
        self._nodetable = list()
        # Constant mixed into the hash function
        self._hash_const = 1


    def load_ruleset(self, ruleset, remove_spoilers=8, colors=8):
        """
        Process input set of rules, generate all necessary structures
        """

        self._rules = RuleSet()

        print "Number of input rules: %d" % \
            ruleset.count_rules()

        # copy rules (this operation will remove duplicates and rules
        #             covered by other rules)
        for rule in ruleset.get_rules():
            self._rules.add_rule(rule, check=True)

        self._orig_rules_count = self._rules.count_rules()

        print "Number of rules after input filtering: %d" % \
            self._orig_rules_count

        # If there is an universal rule, store it and delete it from rulesets
        self._universalrule = None
        if self._rules.get_rule(-1).is_universal():
            print "Universal rule detected:",
            self._rules.get_rule(-1).display()
            self._universalrule = self._rules.get_rule(-1)
            self._rules.del_rule(-1)

#        # Get names of fields used in the ruleset
#        self._prefix_cond_names = self._rules.get_field_names("PrefixSet")

        self._prefix_cond_names = ['srcipv4','dstipv4','src_port','dst_port','protocol']

        self._rules.add_universal_prefixes(self._prefix_cond_names)

        # expand prefixsets for all rules in order to have rules only with
        # one prefix in each dimension
        self._rules.expand_prefixsets()

        print "Number of rules after range expansion: %d" % \
            self._rules.count_rules()

        # remove "worst" rules
        if (self._rules.count_rules() > remove_spoilers):
            self._spoilers = self._rules.remove_spoilers(remove_spoilers)

        print "Number of rules after spoilers removal: %d" % \
            self._rules.count_rules()

        # Perfect hash will search in these rules only
        self._search_rules = self._rules

        # Gather all prefixes from each dimension separately
        for r in self._rules.get_rules(): # For each rule
            for c, p in r.get_all_conditions().iteritems(): # For each cond.
                if c not in self._prefixes: # If we meet this cond. for first time
                    self._prefixes[c] = [p]         # Create the list

                else:                       # If we met this cond. before
                    if not p in self._prefixes[c]: # If this prefixset not in list
                        self._prefixes[c].append(p) # Add to the list

        # _prefixes is now dictionary with one list for each condition
        # Each list contains all prefixsets found in that dimension.

        self._prefixsets = self._rules.get_prefixes(self._prefix_cond_names)
        # _prefixsets is a dictionary of PrefixSets

        self.set_max_colouring(colors)
        self.expand_pseudorules_with_colouring()

        print "Number of rules after pseudorules expansion: %d" % \
            self._rules.count_rules()

        #for r in self._rules._rules:
            #r.display()

        self.create_graph()
        print "Number of graph vertices: %d" % len(self._nodetable)


 
    def set_colouring(self, colours):
        """
        Intended to set a per-dimension number of colours and to colour every
        prefix; `colours` is meant to be a dictionary giving the number of
        colours for each dimension.

        Not implemented yet: at the moment the method only clears the stored
        colour settings and ignores `colours`.
        """
        # TBD: assign bitmaps and colours to every prefix
        self._colours_info = dict()
       

    def set_max_colouring(self, colours=4):
        """
        Give every dimension the same number of colours, colour all prefixes
        and fill their colour bitmaps.

        colours: number of colours used in each dimension listed in
                 self._prefix_cond_names

        The method works in four phases:
          1. store `colours` in self._colours_info for each dimension,
          2. turn every prefix found in self._prefixes into a PrefixColoured
             (in place) and give it all-zero bitmaps,
          3. assign a random colour to every prefix in self._prefixsets,
          4. set bitmap bits according to the rules in self._rules.

        self._prefixes, self._prefixsets and self._rules are read, so they
        have to be populated beforehand (load_ruleset() does that before
        calling this method). Colours come from random.randint(), so the
        colouring differs between runs unless the `random` module is seeded.
        """
        
        for d in self._prefix_cond_names:
            self._colours_info[d] = colours

        ##### Creating empty bitmaps #####
        # NOTE(review): the loop variable `p` is never used; only `cond` is.
        for cond, p in self._prefixes.iteritems(): # for each condition and each prefix

            # Template: one bitmap (list of `colours` zeros) for every
            # dimension other than `cond`
            bitmap_set = {}

            for c in self._prefixes.keys():
                # assign empty bitmap for each neighbouring dimension 
                if (c != cond):
                    #bitmap_set[c] = [0]*len(self._prefixes[c])
                    bitmap_set[c] = [0]*colours

            # this code is for maximum possible colouring ----------------------
#            colour_index = 0;
            for prefix_set_actual in self._prefixes[cond]:   # for each prefixset seen in the current condition


                for p_actual in prefix_set_actual: 

    
                    # cast original prefix to new coloured prefix by swapping
                    # the class of the existing object, so every other
                    # reference to the same object sees the coloured prefix
                    p_actual.__class__ = PrefixColoured
    
#                    p_actual.set_colour(colour_index)
    
                    # each prefix gets its own copy of the zeroed template
                    p_actual.set_bitmaps(copy.deepcopy(bitmap_set))

                    p_actual.set_condition(cond)
    
#                    colour_index = colour_index + 1

        ##### Assigning colors #####
        # NOTE(review): set_colour() is called on the prefixes held by
        # self._prefixsets while the class swap above was applied to the ones
        # in self._prefixes - presumably these are the same objects; confirm.
        for cond, prefixset in self._prefixsets.iteritems():
            # NOTE(review): colour_index is only used by the commented-out
            # sequential strategies below
            colour_index = 0;
            prefixes_list = prefixset.get_prefixes()
            for i in range(len(prefixes_list)):

                # Sequential by length, with saturation
                # start with shortest prefixes
                #prefixes_list[-(i+1)].set_colour(colour_index)
                #if colour_index < colours-1:
                    #colour_index = colour_index + 1

                # Random (the strategy currently in use): colour drawn
                # uniformly from 0 .. colours-1
                prefixes_list[i].set_colour(random.randint(0, colours-1))

                # Sequential by length, no saturation (just overflow)
                #prefixes_list[-(i+1)].set_colour(colour_index)
                #if colour_index == colours-1:
                    #colour_index = 0
                #else:
                    #colour_index = colour_index + 1

                # Sequential by tree order, with saturation (Used in paper)
                #nest = prefixset.get_nesting(prefixes_list[i])
                #if nest > colours-1:
                    #prefixes_list[i].set_colour(colours-1)
                #else:
                    #prefixes_list[i].set_colour(nest)

                #print prefixes_list[i]


        ##### Bitmaps filling #####
        # now, fill in values to bitmaps for every rule - very nasty code
        for r in self._rules.get_rules():  # for each rule
           for cond_main, prefixset_main in r.get_all_conditions().iteritems(): # for each condition

               _prefix_index = self._prefixes[cond_main].index(prefixset_main)    # get index of prefix for current condition

               # bitmaps of the (first) prefix of the stored prefixset equal
               # to the one used by this rule
               _bitmaps_current_prefix = self._prefixes[cond_main][_prefix_index][0].get_bitmaps()

               # In the bitmap kept for each other dimension, set the bit at
               # the colour of the prefix this rule uses in that dimension
               for cond_adjacent, prefixset_adjacent in r.get_all_conditions().iteritems(): # for each adjacent condition
                   if (cond_adjacent != cond_main): # skip equal condition
                       prefix_adjacent_index = self._prefixes[cond_adjacent].index(prefixset_adjacent)
                       _bitmaps_current_prefix[cond_adjacent][self._prefixes[cond_adjacent][prefix_adjacent_index][0].get_colour()] = 1

                   else:
                       continue


               # set '1' to bitmaps of all more specific prefixes of current prefix
               for prefix_tmp in self._prefixes[cond_main]:
                   #print "domain_size %d = " % prefix_tmp[0].get_domain_size()

                   # find more specific prefix (whatever covers() accepts)
                   if (prefixset_main.covers(prefix_tmp)):
                       _bitmaps_tmp = prefix_tmp[0].get_bitmaps()

                       # OR the current prefix's bitmaps into the covered
                       # prefix's bitmaps, for every condition and every bit
                       for cond_tmp, bmp_tmp in _bitmaps_tmp.iteritems():
                           for i in range(len(bmp_tmp)):
                               _bitmaps_tmp[cond_tmp][i] = _bitmaps_tmp[cond_tmp][i] | _bitmaps_current_prefix[cond_tmp][i]
                       
                  

               
        # NOTE(review): this loop only reads the bitmaps and discards them; it
        # looks like a leftover of debugging output and appears to have no
        # effect (assuming get_bitmaps() is a plain getter - confirm).
        for cond, prefixsets_list in self._prefixes.iteritems():
            for prefixsets in prefixsets_list:

                bitmaps = prefixsets[0].get_bitmaps()
              

#        print "&&&&&&&&&&&&&&&&&&&&&&&&&&&"
#        print "prefixes after colouring:"
#        for cond, prefixset in self._prefixsets.iteritems():
#            prefixes_list = prefixset.get_prefixes()
#            print cond
#            for i in range(len(prefixes_list)):
#                   prefixes_list[-(i+1)].display(), " "
#                   print prefixes_list[-(i+1)].get_colour()
   
#            print
#        print "&&&&&&&&&&&&&&&&&&&&&&&&&&&"



    def classify(self,packetheader):
        """
        Classifies packet using prefix colouring algorithm.

        packetheader: packet header from which the fields named in
                      self._prefix_cond_names are extracted

        Returns a list holding the matching rule with the best priority
        (numerically lowest get_priority()), the universal rule if nothing
        else matches and one is defined, or an empty list if no match
        occured. An empty list is also returned (after printing an error)
        when no ruleset has been loaded.

        Raises BadPacketError when a required field is missing in the header.
        """

        if not self._nodetable:
            print "ERROR: No rules are loaded, call load_ruleset() method first()."
            return []

        cond_names = self._prefix_cond_names
        hc = self._hash_const

        # LPM stage: collect matching prefixes in every dimension
        matching_prefixes = {}
        for c in cond_names:
            try:
                # Get value from packet and convert it to integer
                field = extract_field(packetheader, c)
            except FieldNotFoundError:
                raise BadPacketError("Field '%s' not found."%c)

            matching_prefixes[c] = self._prefixsets[c].match(field)

        # Check if some of the removed spoilers matches the packet; keep the
        # one with the best (lowest) priority value
        spoiler = None
        for r in self._spoilers:
            if r.match(packetheader):
                if (spoiler is None) or \
                   (r.get_priority() < spoiler.get_priority()):
                    spoiler = r

        # If some dimension has no matching prefix, no regular rule can match:
        # the packet matches a spoiler, the universal rule or nothing.
        # Spoiler has higher priority than universal rule.
        if any(len(matching_prefixes[c]) == 0 for c in cond_names):
            if spoiler is not None:
                return [spoiler]
            if self._universalrule is not None:
                return [self._universalrule]
            return []

        # Bitmap of found prefixes for each dimension - from LPM stage:
        # bit k is set when some matching prefix has colour k
        bitmap_of_found_prefixes = {}
        for cond in cond_names:
            found = [0]*self._colours_info[cond]
            for prefix in matching_prefixes[cond]:
                found[prefix.get_colour()] = 1
            bitmap_of_found_prefixes[cond] = found

        # Compute logical AND through all bitmaps of found prefixes, using
        # the bitmaps stored in the first matching prefix of every other
        # dimension (presumably the longest one - depends on the order in
        # which match() returns prefixes)
        for cond in cond_names:
            found = bitmap_of_found_prefixes[cond]
            for cond2 in cond_names:
                if cond2 == cond:
                    continue
                allowed = matching_prefixes[cond2][0].get_bitmaps()[cond]
                for bit_position in range(len(found)):
                    found[bit_position] = found[bit_position] & allowed[bit_position]

        rule = None

        # A dimension whose bitmap is all zeros means the packet cannot match
        # any (pseudo)rule, so the perfect hash lookup is skipped.
        if all(1 in bits for bits in bitmap_of_found_prefixes.values()):
            # Create key for perfect hash search: in every dimension take
            # the first matching prefix whose colour is still allowed
            key = []
            for cond in cond_names:
                for prefix in matching_prefixes[cond]:
                    if bitmap_of_found_prefixes[cond][prefix.get_colour()] == 1:
                        key.append(self._prefixsets[cond].get_prefixes().index(prefix))
                        break
            key.append(hc)

            # Compute two different hash functions
            h1 = hash(tuple(key)) % len(self._nodetable)
            key.append(hc)
            h2 = hash(tuple(key)) % len(self._nodetable)

            # Read weights of nodes (given by results of hash functions) from
            # the table - their sum is priority of the rule
            rulepriority = self._nodetable[h1] + self._nodetable[h2]

            # Find a rule with that priority which really matches the packet.
            # Several rules may carry the same priority (e.g. rules produced
            # from one original rule), so the search must not give up after
            # the first rule with the right priority fails to match.
            for r in self._search_rules.get_rules():
                if r.get_priority() == rulepriority and r.match(packetheader):
                    rule = r
                    break

        # A matching spoiler wins when it has better priority than the rule
        # found by the hash lookup (or when nothing was found)
        if spoiler is not None:
            if rule is None or spoiler.get_priority() < rule.get_priority():
                rule = spoiler

        if rule is not None:
            return [rule]
        elif self._universalrule is not None:
            return [self._universalrule]
        else:
            return []


    def print_parsed_rules(self):
        """Display the ruleset currently held by the classifier."""
        current_rules = self._rules
        current_rules.display()


    def get_dimensions_list(self):
        """
        Return the names of the dimensions for which prefixes were gathered
        from the loaded ruleset.
        """
        dimensions = self._prefixes.keys()
        return dimensions


    def prefixes_count(self, dimension):
        """
        Return how many prefixes are stored for the given dimension,
        or None when the dimension is not known.
        """
        try:
            stored = self._prefixes[dimension]
        except KeyError:
            return None
        return len(stored)


    def expand_pseudorules_with_colouring(self):
        """
        Add rules into ruleset, such that necessary LPM combinations that must
        be covered according to the colouring are covered by some rules.

        Every rule should define a condition for all fields and prefixes MUST be
        coloured prior to this step.
        """

        pseudos = []
 
        # For each rule, select its more specific prefixes
        for r in self._rules.get_rules(): # For each rule
            cover = {}  # Prepare new dictionary
            for c, p_my in r._conditions.iteritems(): # For each cond.
                cover[c] = filter(p_my.covers, self._prefixes[c])

            # Now cover is a dicitonary, where for every dimension there
            # is a list with all same or more specific prefixsets.
            # (cover is subset of prefixes.)

            conditions = []
            prefixsetslists = []
            for c, plist in cover.iteritems():
                conditions.append(c)
                prefixsetslists.append(plist)

            # conditions is a list of condition names
            # prefixsetslists is a list of lists
            cross = cross_sublists_coloured(prefixsetslists, conditions)

            # create pseudorules
            for pr in cross: # For each crossproduct
                newr = Rule() # Prepare new rule
                newr.set_priority(r._priority) # with the same priority
                for i, pset in enumerate(pr): # For each condition
#                    print "i: %d" % i
                    newr.set_condition(conditions[i], pset) # set
                if (newr != r): # Original rule will not be pseudorule
                    newr._pseudo = True
                    newr._target = r # Reference to original rule
                pseudos.append(newr)
                #newr.display()

        # Add pseudos into ruleset, if no higher-priority (pseudo)rule exists
        print "Candidate pseudorules: %d" % len(pseudos)

        #for p in pseudos:
            #p.display()

        pseudolist = []
        newrules = RuleSet()
        n = len(pseudos)
        i = 0
        lastp = -1
        hash_set = set()
        
        for pr in pseudos: # for candidate pseudorules
            if (int(i*100/n) != lastp):
                lastp = int(i*100/n)
                print ("%3d"%int(i*100/n))+"%",
                sys.stdout.flush()
                print "\b\b\b\b\b\b",
            i += 1
            
            #add = True
            #for r in self._rules._rules:
                #if r._priority >= pr._priority:
                    #break
                #for cn, c in r._conditions.iteritems():
                    #if not c.covers(pr._conditions[cn]):
                        #break
                #else:
                    #add = False
                    #break
            #if add:
                #newrules.add_rule(pr, False)

            pr.compute_conditions_hash()

            if (pr._cond_hash in hash_set): # If hash is in the set,
                # then search for matching previous pseudorule
                for epr in pseudolist: # for pseudorules already added
                    # Check if epr is same as pr
                    same = True
                    if pr._cond_hash == epr._cond_hash: # Matching hash does
                        # not tell us for sure, we have to compare prefixes
                        for c in self._prefix_cond_names:
                            if epr._conditions[c] is not pr._conditions[c]:
                                same = False
                    else: # Non-matching hash tells for sure:they are different
                        same = False

                    if same:                        # Do not add the same, only
                        if pr._priority < epr._priority: # increase priority 
                            epr._priority = pr._priority # if possible
                        break
                else: # If for cycle didn't break (no same pseudorule found)
                    pseudolist.append(pr) # Simply append
                    hash_set.add(pr._cond_hash)
            else: # No same hash found in set
                pseudolist.append(pr) # Simply append
                hash_set.add(pr._cond_hash)

        newrules._rules = sorted(pseudolist, key=Rule.get_priority)

        self._rules = newrules




    def report_memory(self):
        """
        Print statistics information - number of pseudorules,
        number of prefixes in each dimension,
        number of colours in each dimension,
        size of perfect hash table.
        """

        print "== PCCA Memory Report =="
        print "Number of spoilers: %d" % len(self._spoilers)
        print "Number of prefixes in each dimension (no spoilers):"
        for cond, prefixsets in self._prefixes.iteritems():
            print cond + ": %d" % len(prefixsets),
        print
        print "Number of colors:"
        print self._colours_info
        # Compute memory added to LPMs
        lpm_added = {}
        lpm_added_sum = 0
        for dim in self._prefix_cond_names: # For each dimension
            lpm_added[dim] = ceil(log(self._colours_info[dim], 2))
                # Store own color
            for odim in self._prefix_cond_names: # For all other dimensions
                if dim != odim:
                    lpm_added[dim] += self._colours_info[odim] # Store bitmap
            lpm_added[dim] *= len(self._prefixsets[dim])
                # Times number of prefixes
            lpm_added_sum += lpm_added[dim]

        print "Memory added to LPMs [bits]: %d" % lpm_added_sum
        print lpm_added

        print "Perfect hash table: %d items" % len(self._nodetable)
        print 'Perfect hash table: {0} bits supposing one vertex takes {1} bits'\
            .format(len(self._nodetable) * ceil(log(self._orig_rules_count, 2)),
            ceil(log(self._orig_rules_count, 2)))

        print 'Perfect hash table: {0} bits supposing one vertex takes 16 bits'\
            .format(len(self._nodetable) * 16)

        bits_per_rule = 64
        print 'Rule table supposing one rule occupies {0} bits: {1} bits'\
            .format(bits_per_rule, bits_per_rule*self._orig_rules_count)

        
 
        # add printing number of rules and pseudorules and number of colours in each dimension
    
    def print_debug_info(self):
        """
        Print debug info
        """
        
        # for every dimension print its unique prefixes
        for keys in self._prefixes.get_keys():
            print keys + ":"
            for p in self._prefixes[keys]:
               print p
                     

    def create_graph(self):
        """
        Creates acyclic graph for perfect hash function
        """

        #----------------------------------------------------------------------
        # Create graph

        # For each node there is an array of its edges - (node, weight) pairs
        # [[(node1, index_of_rule),(node2, index_of_rule2)],...]
        graph = []
        num_of_keys = len(self._rules.get_rules())
        size = 1.5
        hc = 1    # Hash constant
        print "Trying to create acyclic graph"
        last = [0,0,0]
        looped = False
        while (True):
            N = int(num_of_keys * size)
            graph = [ [] for i in range(N)]
            print "Size of graph:", N, "...",

            # For every pseudorule, create a key, compute its hashes h1 and h2
            # and add an edge to the graph
            for r in self._rules.get_rules(): # For each rule
                # Put prefixes in all dimensions into one list
                key = []
                for c in self._prefix_cond_names:  # For all dimensions
                    # Get data of the first (and only) prefix in prefixset
                    #v, m = r.get_condition(c)[0].get_data()
                    #key.append(v + (v >> 6)^hc)
                    #key.append(m)
                    key.append(self._prefixsets[c].get_prefixes().index(r.get_condition(c)[0]))
                key.append(hc)

                # Compute hash functions
                h1 = hash(tuple(key)) % N
                key.append(hc)
                h2 = hash(tuple(key)) % N

                #print key, h1, h2, r._priority
                
                dedge = False
                for (h,p) in graph[h1]:
                    if (h == h2): # This edge already exists in the graph
                        print "Double edge (rules:",p,r.get_priority(),") ->",
                        dedge = True
                if dedge:
                    break
                
                # Add edge between h1 and h2 to graph
                graph[h1].append((h2, r.get_priority()))
                graph[h2].append((h1, r.get_priority()))
            
            else:
                # Determine, if graph is acyclic
                if acyclic(graph):
                    print "OK"
                    break
            
            print "cycle detected"
            hc += 1
            size *= 1.05

        # Set weights of graph nodes in such way that sum of two nodes
        # is weight of edge between them.
        weights = [None for i in range(N)] # Array of graph nodes weights
        stack = []
        for start_node in range(N):
            if (weights[start_node] != None):
                continue
            weights[start_node] = 0
            stack = [start_node]
            while (len(stack) > 0):
                cur_node = stack.pop()
                for n, w in graph[cur_node]:
                    if (weights[n] == None):
                        weights[n] = w - weights[cur_node]
                        stack.append(n)

        self._pseudorules_count = len(self._rules.get_rules())
        self._nodetable = weights
        self._hash_const = hc


#################
# Other methods #
#################

def cross_sublists_coloured(lists, conditions):
    """
    Crossproducts coloured prefixes.

    lists:      list of lists, one per dimension; every item must offer
                `_prefixes`, whose first element is a coloured prefix with
                `_condition`, `_colour` and `_bitmap_set` attributes
    conditions: list of condition names parallel to `lists`; it is only
                sliced along with `lists`, the condition of each prefix is
                read from the prefix itself

    Returns a list of combinations (lists holding one item from each input
    list, in the order of `lists`). A combination is kept only if every
    pair of its prefixes from the first and a later dimension allows each
    other, i.e. each prefix has the bit of the other one's colour set in
    the bitmap it keeps for the other one's condition.

    An empty `lists` yields one empty combination; an empty list for any
    dimension yields no combination.
    """
    # Without this guard the recursion below would never terminate for an
    # empty input, because slicing an empty list gives an empty list again.
    if not lists:
        return [[]]

    if len(lists) == 1:
        return [[i] for i in lists[0]]

    lower = cross_sublists_coloured(lists[1:], conditions[1:])
    result = []
    for i in lists[0]:
        for j in lower:
            tmp_i = i._prefixes[0]
            for k in j:
                tmp_k = k._prefixes[0]
                allowed = (tmp_k._bitmap_set[tmp_i._condition][tmp_i._colour] &
                           tmp_i._bitmap_set[tmp_k._condition][tmp_k._colour])
                if allowed != 1:
                    break   # one incompatible pair rules the combination out
            else:
                result.append([i] + j)
    return result


def acyclic(graph):
    """
    Returns True, if given graph is acyclic.

    graph: It must be undirected graph given as list of lists of connected
    nodes and edge weights (weights are ignored):
      [[(node1,weight1),(node2,weight2),...],...]

    Self-loops and double edges between two nodes count as cycles.
    """

    visited = [False] * len(graph)  # Mark all nodes as not yet visited

    for start_node in range(len(graph)):
        if visited[start_node]:
            continue

        # Nodes to visit as (node, node it was reached from). Taking them
        # from the end of the list is constant-time; the order of traversal
        # does not matter, because a component is a tree exactly when no
        # node is reached twice, whatever the order.
        pending = [(start_node, None)]
        while pending:
            cur_node, prev_node = pending.pop()
            if visited[cur_node]:  # If node has been visited already
                return False       # there must be a cycle
            visited[cur_node] = True
            # Continue to all connected nodes except the one we came from
            for n, _weight in graph[cur_node]:
                if n != prev_node:
                    pending.append((n, cur_node))

    return True # All nodes visited and no cycle detected - graph is acyclic


