###############################################################################
#  phca.py: Module for experiments with Perfect Hashing Crossproduct Algorithm
#  Copyright (C) 2010 Brno University of Technology, ANT @ FIT
#  Author(s): Vaclav Bartos <xbarto11@stud.fit.vutbr.cz>
###############################################################################
#
#  LICENSE TERMS
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions
#  are met:
#  1. Redistributions of source code must retain the above copyright
#     notice, this list of conditions and the following disclaimer.
#  2. Redistributions in binary form must reproduce the above copyright
#     notice, this list of conditions and the following disclaimer in
#     the documentation and/or other materials provided with the
#     distribution.
#  3. All advertising materials mentioning features or use of this software
#     or firmware must display the following acknowledgement:
#
#       This product includes software developed by the University of
#       Technology, Faculty of Information Technology, Brno and its
#       contributors.
#
#  4. Neither the name of the Company nor the names of its contributors
#     may be used to endorse or promote products derived from this
#     software without specific prior written permission.
#
#  This software or firmware is provided ``as is'', and any express or implied
#  warranties, including, but not limited to, the implied warranties of
#  merchantability and fitness for a particular purpose are disclaimed.
#  In no event shall the company or contributors be liable for any
#  direct, indirect, incidental, special, exemplary, or consequential
#  damages (including, but not limited to, procurement of substitute
#  goods or services; loss of use, data, or profits; or business
#  interruption) however caused and on any theory of liability, whether
#  in contract, strict liability, or tort (including negligence or
#  otherwise) arising in any way out of the use of this software, even
#  if advised of the possibility of such damage.
#
#  $Id$

"""
Module for experiments with Perfect Hashing Crossproduct Algorithm
"""

import sys
import copy
from collections import deque
import bclassification
from netbench.classification.rule import *
from netbench.classification.ruleset import *
from netbench.classification.prefix import *
from netbench.classification.prefixset import *
from netbench.classification.common import *
from netbench.common.packetheader import *

class PHCA(bclassification.BClassification):
    """
    Perfect Hashing Crossproduct Algorithm.

    load_ruleset() expands the ruleset into pseudorules and builds an acyclic
    graph with one edge per pseudorule. The endpoints of an edge are two hash
    values of the pseudorule's key. Node weights are then chosen so that the
    weights of both endpoints of an edge sum to the priority of the edge's
    rule.

    classify() performs LPM in every dimension and hashes the concatenated
    result the same way. It sums the two addressed node weights to get a
    candidate rule priority, which is then verified against the packet. Rules
    removed as spoilers and the universal rule are checked separately.
    """

    def __init__(self, verbose=True):
        """
        Constructor.

        verbose: If True, print progress details while building the graph.
        """
        
        # Given ruleset (None until a ruleset is loaded)
        self._ruleset = None
        
        # List of all condition names (dimensions) used in ruleset
        self._cond_names = []
        
        # List of all prefixes (one PrefixSet for each dimension)
        self._prefixes = {}
        
        # Universal rule (rule that match all packets)
        self._universalrule = None
        
        # Table of weights of graph nodes
        self._nodetable = []
        
        # Constant for hash function
        self._hash_const = 1
        
        # Rules that have been removed from ruleset (and in real application 
        # it will be processed by TCAM)
        self._spoilers = []
        
        # Number of pseudorules the nodetable was built for (report_memory())
        self._pseudorules_count = 0
        
        # Rule priority -> rule of self._ruleset (used by classify())
        self._rules_by_priority = {}
        
        # Set verbosity
        self._verbose = verbose

    def load_ruleset(self, ruleset, max_spoilers=8):
        """
        Load rules and generate nodetable (table of weights of graph nodes
        used for perfect hashing).

        ruleset: RuleSet to load. Its rules are added to internal rulesets;
            RuleSet.add_rule(check=True) filters out duplicates and rules
            covered by other rules.
        max_spoilers: Number of rules generating the most pseudorules that are
            removed from the hashed structure and matched separately by
            classify(). The removal happens only when the expanded ruleset has
            more rules than this number. A value <= 0 disables the removal.
        """
        vrbs = self._verbose
        
        orig_count = ruleset.count_rules()
        
        if (orig_count == 0):
            print "ERROR: Ruleset is empty"
            return
        
        # Copy ruleset (and remove duplicates and rules covered by other rules)
        rs = RuleSet()      # Will be stored to self._ruleset
        prs = RuleSet()     # Will be expanded to pseudorules
        for rule in ruleset.get_rules():
            rs.add_rule(rule, check=True) # This copies only useful rules
        for rule in rs.get_rules():
            prs.add_rule(rule)
        
        print "Number of rules:", prs.count_rules(),
        print "(%d loaded, %d filtered out)"%(orig_count,orig_count-prs.count_rules())
        
        # If there is an universal rule, store it and delete it from rulesets
        self._universalrule = None
        if rs.get_rule(-1).is_universal():
            print "Universal rule detected:",
            rs.get_rule(-1).display()
            self._universalrule = rs.get_rule(-1)
            rs.del_rule(-1)
            prs.del_rule(-1)
        self._ruleset = rs
        
        # The set of dimensions is fixed here; it is not derived from the
        # ruleset (e.g. by rs.get_field_names()).
        self._cond_names = ['srcipv4','dstipv4','src_port','dst_port','protocol']
        
        # Generate and add pseudorules to the ruleset
        prs.add_universal_prefixes(self._cond_names)
        prs.expand_prefixsets()
        print "After expanding prefixsets:",prs.count_rules()
        
        # Always reset, so spoilers of a previously loaded ruleset never
        # survive into this one.
        self._spoilers = []
        if (max_spoilers > 0) and (prs.count_rules() > max_spoilers):
            print "Removing %d spoilers (rules, that generates most pseudorules) ..." % max_spoilers,
            sys.stdout.flush()
            self._spoilers = prs.remove_spoilers(max_spoilers)
            print "done"
            print "Removed rules priorities:", [r.get_priority() for r in self._spoilers]

        print "Final number of rules:",prs.count_rules()
        print
        prs.expand_pseudorules()
        print "Final number of (pseudo)rules:",prs.count_rules()
        
        # Create a list of all prefixes (dict with one PrefixSet for every dimension)
        self._prefixes = prs.get_prefixes(self._cond_names)
        
        # One edge per pseudorule; then weights such that the endpoint
        # weights of every edge sum to that edge's rule priority.
        graph, hc = self._create_acyclic_graph(prs)
        weights = self._compute_node_weights(graph)
        
        print
        if vrbs: print "Final number of (pseudo)rules:", len(prs.get_rules())
        print "Final size of nodetable:", len(weights)
        
        self._pseudorules_count = len(prs.get_rules())
        self._nodetable = weights
        self._hash_const = hc
        
        # Index rules by priority so classify() does not scan the ruleset for
        # every packet; setdefault keeps the first rule with a given priority.
        self._rules_by_priority = {}
        for rule in self._ruleset.get_rules():
            self._rules_by_priority.setdefault(rule.get_priority(), rule)

    @staticmethod
    def _key_part(value, mask, hash_const):
        """
        Return the key items contributed by one dimension's prefix.

        Used both when building the graph and when classifying, so that both
        compute identical keys. The expression is ((value + (value >> 6)) ^
        hash_const), i.e. '+' is applied before '^'.
        """
        return [(value + (value >> 6)) ^ hash_const, mask]

    @staticmethod
    def _hash_pair(key, hash_const, size):
        """
        Return the two hash values (h1, h2) of key, both in range(size).

        h1 hashes the key itself, h2 hashes the key extended by hash_const.
        """
        base = tuple(key)
        return hash(base) % size, hash(base + (hash_const,)) % size

    def _create_acyclic_graph(self, prs):
        """
        Search for a graph size and hash constant for which the pseudorule
        keys form an acyclic graph without duplicate edges.

        prs: RuleSet of pseudorules (one prefix in every dimension).
        Returns (graph, hash_const). graph is a list of adjacency lists
        [[(node, rule_priority), ...], ...] with each edge stored at both
        endpoints.

        NOTE(review): there is no upper bound on the number of attempts; the
        graph size grows by 5 % after every failed attempt.
        """
        vrbs = self._verbose
        
        # Key data does not depend on the hash constant; extract it once as
        # (priority, [(value, mask) for every dimension]).
        entries = []
        for r in prs.get_rules():
            # Data of the first (and only) prefix in each dimension
            data = [r.get_condition(c)[0].get_data() for c in self._cond_names]
            entries.append((r.get_priority(), data))
        
        num_of_keys = prs.count_rules()
        size = 1.5    # Ratio of graph nodes to keys
        hc = 1        # Hash constant
        print
        print "Trying to create acyclic graph..."
        while True:
            N = int(num_of_keys * size)
            graph = [[] for i in range(N)]
            if vrbs: print "Size of graph:", N, "...",
            sys.stdout.flush()
            
            ok = True
            for priority, data in entries:
                key = []
                for v, m in data:
                    key.extend(self._key_part(v, m, hc))
                h1, h2 = self._hash_pair(key, hc, N)
                
                # An existing edge between the same two nodes would need two
                # different weights, so this attempt fails.
                dup = [p for (h, p) in graph[h1] if h == h2]
                if dup:
                    if vrbs: print "Double edge (rules:",dup[0],priority,") ->",
                    ok = False
                    break
                
                # Add edge between h1 and h2 to graph
                graph[h1].append((h2, priority))
                graph[h2].append((h1, priority))
            
            if ok and acyclic(graph):
                if vrbs: print "OK"
                return graph, hc
            
            if vrbs: print "cycle detected"
            hc += 1
            size *= 1.05

    @staticmethod
    def _compute_node_weights(graph):
        """
        Return a list of node weights such that, for every edge, the weights
        of its two endpoints sum to the edge weight.

        graph must be acyclic. The first node of every connected component
        gets weight 0, and the rest are derived by a depth-first walk.
        """
        weights = [None] * len(graph)
        for start_node in range(len(graph)):
            if weights[start_node] is not None:
                continue
            weights[start_node] = 0
            stack = [start_node]
            while stack:
                cur_node = stack.pop()
                for n, w in graph[cur_node]:
                    if weights[n] is None:
                        weights[n] = w - weights[cur_node]
                        stack.append(n)
        return weights

    def classify(self, packetheader):
        """
        Classify packet by the algorithm.
        Return the list containing a highest priority rule that given packet
        matches (or empty list if it matches no rule) 
        
        packetheader: Instance of the PacketHeader class.
        Raises BadPacketError if a header field needed for LPM is missing.
        """
        
        if self._ruleset is None:
            print "ERROR: No rules are loaded, call load_ruleset() method first."
            return []
        
        # The nodetable is empty when no pseudorules remained (e.g. the
        # ruleset held only the universal rule); there is nothing to hash.
        rule = None
        if self._nodetable:
            rule = self._lookup_hashed_rule(packetheader)
        
        # Check if some of the removed spoilers matches the packet
        # (the lower priority number wins)
        for r in self._spoilers:
            if (r.match(packetheader)):
                if (rule is None) or (r.get_priority() < rule.get_priority()):
                    rule = r
        
        if rule is not None:
            return [rule]
        elif self._universalrule is not None:
            return [self._universalrule]
        else:
            return []

    def _lookup_hashed_rule(self, packetheader):
        """
        Find the candidate rule by LPM and perfect hashing, then verify it.

        Returns the rule, or None if no rule has the computed priority or the
        rule does not match the packet.
        """
        hc = self._hash_const
        key = []
        for c in self._cond_names: # For every dimension (condition)
            try:
                # Get value from packet and convert it to integer
                field = extract_field(packetheader, c)
            except FieldNotFoundError:
                raise BadPacketError("Field '%s' not found."%c)
            
            # Find longest prefix and add it to key
            prefix = self._prefixes[c].match_longest(field)
            if prefix:
                v, m = prefix.get_data()
                key.extend(self._key_part(v, m, hc))
            else:
                # No prefix matches in this dimension; the key then belongs
                # to no pseudorule and the result is only a candidate.
                key.append(0)
        
        h1, h2 = self._hash_pair(key, hc, len(self._nodetable))
        
        # The sum of the two node weights is the priority of the rule
        rulepriority = self._nodetable[h1] + self._nodetable[h2]
        rule = self._rules_by_priority.get(rulepriority)
        
        # An unknown key hashes to an arbitrary priority, so check that the
        # found rule really matches the packet.
        if rule is not None and not rule.match(packetheader):
            rule = None
        return rule

    def report_memory(self):
        """
        Print detailed info about algorithm memory requirements.
        """
        
        print "========== PHCA memory report =========="
        print "Number of pseudorules:", self._pseudorules_count
        print "Size of node table:", len(self._nodetable)



####################
# Helper functions #
####################

def acyclic(graph):
    """
    Returns True, if given graph is acyclic.
    
    graph: It must be undirected graph given as list of lists of connected 
    nodes and edge weights (weights are ignored):
      [[(node1,weight1),(node2,weight2),...],...]
    Every edge is expected to be listed at both of its endpoints. Self-loops
    and parallel (duplicate) edges are reported as cycles.
    """
    
    visited = [False] * len(graph) # Mark all nodes as not yet visited
    
    for start_node in range(len(graph)):
        if visited[start_node]:
            continue
        # Breadth-first search of one connected component. The queue holds
        # (next node to visit, node it was reached from). A deque gives O(1)
        # pops from the left, whereas list.pop(0) is O(n).
        queue = deque([(start_node, None)])
        while queue:
            cur_node, prev_node = queue.popleft()
            if visited[cur_node]: # Node reached a second time by another
                return False      # edge - there must be a cycle
            visited[cur_node] = True
            # Enqueue all neighbours except the one we came from
            for n, w in graph[cur_node]:
                if n != prev_node:
                    queue.append((n, cur_node))
    
    return True # All nodes have been visited and no cycle has been detected

