###############################################################################
#  ruleset.py: Module for operations with packet classification rulesets
#  Copyright (C) 2009 Brno University of Technology, ANT @ FIT
#  Author(s): Viktor Pus <ipus@fit.vutbr.cz>
###############################################################################
#
#  LICENSE TERMS
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions
#  are met:
#  1. Redistributions of source code must retain the above copyright
#     notice, this list of conditions and the following disclaimer.
#  2. Redistributions in binary form must reproduce the above copyright
#     notice, this list of conditions and the following disclaimer in
#     the documentation and/or other materials provided with the
#     distribution.
#  3. All advertising materials mentioning features or use of this software
#     or firmware must display the following acknowledgement:
#
#       This product includes software developed by the University of
#       Technology, Faculty of Information Technology, Brno and its
#       contributors.
#
#  4. Neither the name of the Company nor the names of its contributors
#     may be used to endorse or promote products derived from this
#     software without specific prior written permission.
#
#  This software or firmware is provided ``as is'', and any express or implied
#  warranties, including, but not limited to, the implied warranties of
#  merchantability and fitness for a particular purpose are disclaimed.
#  In no event shall the company or contributors be liable for any
#  direct, indirect, incidental, special, exemplary, or consequential
#  damages (including, but not limited to, procurement of substitute
#  goods or services; loss of use, data, or profits; or business
#  interruption) however caused and on any theory of liability, whether
#  in contract, strict liability, or tort (including negligence or
#  otherwise) arising in any way out of the use of this software, even
#  if advised of the possibility of such damage.
#
#  $Id$

"""
Module providing operations with packet classification rulesets.
"""

import sys
from rule import *

class RuleSet(object):
    """
    Ordered set of packet classification rules.

    The rules are stored in a list that add_rule() keeps sorted by
    ascending priority number, i.e. the rule with the lowest number
    (highest priority) comes first.
    """

    def __init__(self):
        """Create an empty ruleset."""
        self._rules = []

    def display(self):
        """
        Print all rules in a human-readable format.

        Each rule prints itself through its display() method and is followed
        by a separator line.
        """
        for r in self._rules:
            r.display()
            print("--------")

    def add_rule(self, rule, check=False):
        """
        Add Rule into RuleSet, keeping the list ordered by priority.

        The rule is inserted in front of the first stored rule whose priority
        number is greater, so rules with equal priority keep insertion order.

        rule: Instance of the Rule class.
        check: If True, the rule is silently dropped when one of the stored
               rules examined while searching for the insertion point
               covers it.
        """
        # List of rules is kept ordered, lowest numbers first.
        for i, r in enumerate(self._rules):
            # NOTE(review): the coverage test runs before the priority test,
            # so besides all rules with higher or same priority it is also
            # applied to the first rule with lower priority -- confirm that
            # this is intended.
            if check and r.covers(rule):
                return

            if r._priority > rule._priority:
                self._rules.insert(i, rule)
                return

        self._rules.append(rule)

    def get_rule(self, index):
        """
        Get the rule at the given index.

        index: Index of rule, rules are ordered by priority, highest priority
               first. An invalid index raises IndexError.
        """
        return self._rules[index]

    def is_rule_with_given_priority(self, priority):
        """
        Return True if a rule with the given priority exists in the ruleset,
        False otherwise.
        """
        return any(r.get_priority() == priority for r in self._rules)

    def get_index_of_rule(self, priority):
        """
        Return index of the first rule with the given priority.

        Return None if the ruleset contains no such rule.
        """
        for index, r in enumerate(self._rules):
            if r.get_priority() == priority:
                return index
        return None

    def get_rules(self):
        """
        Get list of all rules.

        The internal list itself is returned, not a copy.
        """
        return self._rules

    def count_rules(self):
        """
        Return number of rules in the ruleset.
        """
        return len(self._rules)

    def del_rule(self, index):
        """
        Delete rule at the given index.

        index: Index of rule, rules are ordered by priority, highest priority
               first. An invalid index raises IndexError.
        """
        del self._rules[index]

    def clear(self):
        """
        Delete all rules.
        """
        # A new list is bound on purpose (instead of emptying the old one in
        # place): remove_covered() keeps a reference to the old list while it
        # rebuilds the ruleset.
        self._rules = []

    def classify(self, packetheader):
        """
        Return list of rules matching the packet.
        The list is sorted by priority, highest priority is first.
        If no rule matches the packet, empty list is returned.

        packetheader: instance of PacketHeader class.
        """
        return [rule for rule in self._rules if rule.match(packetheader)]

    def classify_first(self, packetheader):
        """
        Return the first matching rule. Return None if no rule
        matches.

        packetheader: instance of PacketHeader class.
        """
        for rule in self._rules:
            if rule.match(packetheader):
                return rule
        return None

    def classify_bool(self, packetheader):
        """
        Return True if packet matches any rule, False if it matches
        no rule.

        packetheader: instance of PacketHeader class.
        """
        return any(rule.match(packetheader) for rule in self._rules)

    def get_field_names(self, fieldtype=""):
        """
        Return list of names of all fields used in the ruleset.

        Names are listed in the order in which they are first encountered
        and every name appears only once.

        fieldtype: Return only fields whose condition object is of the
                   specified type (class name given as string,
                   eg. "PrefixSet"). Empty string selects all fields.
        """
        field_names = []
        for r in self._rules:
            for name, cond in r.get_all_conditions().items():
                if fieldtype == "" or type(cond).__name__ == fieldtype:
                    if name not in field_names:
                        field_names.append(name)

        return field_names

    def add_universal_prefixes(self, fields=None):
        """
        Add zero-length prefix for every unused field in all rules, so every
        rule will define a condition for all fields.

        fields: you can provide a list of field names to use, otherwise it is
                obtained by calling the get_field_names() method
        """
        if fields is None:
            fields = self.get_field_names()

        for r in self._rules:
            for f in fields:
                if f not in r.get_all_conditions():
                    # TODO: It'll be better to determine domain_size
                    ps = PrefixSet()
                    ps.add_prefix(Prefix(0, 0, 0))
                    r.set_condition(f, ps)

    def get_prefixes(self, fields=None):
        """
        Create a list of all prefixes that appear in the ruleset.

        Return dictionary containing a PrefixSet for every field.

        fields: you can provide a list of field names to use, otherwise it is
                obtained by calling the get_field_names() method
        """
        if fields is None:
            fields = self.get_field_names()

        prefixes = {}
        for f in fields:
            prefixes[f] = PrefixSet()

        for r in self._rules:
            for f in fields:
                if f in r.get_all_conditions():
                    prefixes[f].add_prefixes(r.get_condition(f).get_prefixes())

        return prefixes

    def remove_duplicates(self):
        """
        If there are several rules that are the same except for the priority,
        keep only the one with highest priority (lowest number).

        A notice pointing to remove_covered() is printed on every call.
        """
        print("remove_duplicates: **** Now there is the remove_covered() method, you probably want to use this *****")
        newrules = []

        # Rules are visited in priority order, so the first rule of every
        # group of duplicates is the one that is kept.
        for r in self._rules:
            if not any(r.same_but_priority(kept) for kept in newrules):
                newrules.append(r)

        self._rules = newrules

    def remove_covered(self):
        """
        Remove rules that are fully covered by some other rule with higher
        or same priority (it also removes duplicates).
        """
        # clear() binds a fresh list, so oldrules still holds all the rules.
        oldrules = self._rules
        self.clear()
        for r in oldrules:
            self.add_rule(r, check=True)

    def expand_prefixsets(self):
        """
        Convert all rules into the form where they contain only a single
        prefix in each condition. (No sets of prefixes)
        """
        newrules = []
        for r in self._rules:
            newrules.extend(r.expand_prefixsets())

        self._rules = newrules

    def _collect_prefixsets(self):
        """
        Gather the prefixsets of all rules, separately for each dimension.

        Return dictionary with one list for each condition name; each list
        contains every distinct prefixset found in that dimension.
        """
        prefixes = {}
        for r in self._rules:
            for name, pset in r.get_all_conditions().items():
                known = prefixes.setdefault(name, [])
                if pset not in known:
                    known.append(pset)
        return prefixes

    @staticmethod
    def _covered_prefixsets(rule, prefixes):
        """
        Select, for every condition of the rule, the prefixsets covered by it.

        rule: Rule whose conditions are examined.
        prefixes: Dictionary as returned by _collect_prefixsets().

        Return dictionary with one list for each condition of the rule,
        holding the entries of prefixes[condition] that the rule's own
        prefixset covers.
        """
        cover = {}
        for name, own in rule.get_all_conditions().items():
            cover[name] = [pset for pset in prefixes[name] if own.covers(pset)]
        return cover

    def remove_spoilers(self, count):
        """
        Detect rules that generate excessive number of pseudorules.
        These rules are removed from the ruleset, and returned as a list
        (the worst rule first).

        count: Number of rules to be removed. When it exceeds the number of
               rules, all rules are removed; a count of zero or less removes
               nothing.
        """
        prefixes = self._collect_prefixsets()

        scored = []
        for index, r in enumerate(self._rules):
            # The product of the numbers of covered prefixsets over all
            # dimensions is a rough estimation of the number of pseudorules.
            cross = 1
            for plist in self._covered_prefixsets(r, prefixes).values():
                cross *= len(plist)
            scored.append((cross, r.get_priority(), index))

        # Highest estimation first, ties are broken by higher priority number.
        # The index is deliberately not a part of the key, the stable sort
        # then keeps fully tied rules in their ruleset order.
        scored.sort(key=lambda item: item[:2], reverse=True)

        # Victims are remembered by index rather than looked up again by
        # priority, so the right rule is removed even when priorities are
        # not unique.
        chosen = [index for _, _, index in scored[:max(count, 0)]]
        removed = [self._rules[index] for index in chosen]

        # Delete from the end so the remaining indices stay valid.
        for index in sorted(chosen, reverse=True):
            del self._rules[index]

        return removed

    def _is_shadowed(self, candidate):
        """
        Return True if some rule with strictly higher priority (lower number)
        than the candidate covers the candidate in all of its conditions.

        Relies on the rules being ordered by priority: the search stops at
        the first rule whose priority number is not lower.
        """
        for r in self._rules:
            if r._priority >= candidate._priority:
                break
            for name, cond in r._conditions.items():
                if not cond.covers(candidate.get_condition(name)):
                    break
            else:
                # No condition of r failed to cover the candidate.
                return True
        return False

    def expand_pseudorules(self):
        """
        Add rules into ruleset, such that every possible LPM combination
        is covered by some rule.

        Every rule should define a condition for all fields
        (add_universal_prefixes() method).

        Ruleset should already contain only one-prefix conditions
        (expand_prefixsets() method).

        Also user will probably want to remove universal rule from the ruleset
        before calling this method.

        Progress is reported on standard output.
        """

        print("Generating pseudorules:")

        prefixes = self._collect_prefixsets()

        candidates = []

        # For each rule, select its same or more specific prefixsets in every
        # dimension and combine them.
        for r in self._rules:
            cover = self._covered_prefixsets(r, prefixes)

            # conditions is a list of condition names,
            # prefixlists is the list of their lists, in the same order.
            conditions = list(cover.keys())
            prefixlists = [cover[name] for name in conditions]

            # Now, let the crossproduct begin!
            cross = cross_lists(prefixlists)

            # Create candidates to new rules (both rules and pseudorules)
            for pr in cross:
                newr = Rule()
                newr.set_priority(r.get_priority())  # with the same priority
                for i, pset in enumerate(pr):
                    newr.set_condition(conditions[i], pset)
                if newr != r:  # If it's not the existing one
                    newr.set_pseudo(True)  # it's a pseudorule
                    newr.set_target(r)  # Reference to original rule
                candidates.append(newr)

        # sys.stdout.write() with explicit separating spaces produces the
        # same characters as print statements ending with a comma would.
        sys.stdout.write(
            "There are %i candidates to new rules, filtering ..."
            % len(candidates))
        sys.stdout.flush()

        newrules = []
        n = len(candidates)
        lastp = -1

        for i, pr in enumerate(candidates):
            percent = i * 100 // n
            if percent != lastp:
                lastp = percent
                sys.stdout.write(" %3d%%" % percent)
                sys.stdout.flush()
                # Step back so the next percentage overwrites this one.
                sys.stdout.write(" \b\b\b\b\b\b")

            # A candidate fully covered by a higher priority rule can never
            # be the best match, so it is dropped.
            if not self._is_shadowed(pr):
                newrules.append(pr)

        self._rules = newrules
        sys.stdout.write(" done\n")


                   
def cond2field(cond):
    """
    Translate a condition name into packet header field descriptions.

    Return a list of (header name, field name, value type) tuples telling
    where the value of the given condition can be found in a packet and
    what its type is. Return None for an unknown condition name.
    """
    # The table is built on every call so each caller gets its own lists.
    locations = (
        ("srcipv4", [("ipv4", "src_addr", "ipv4str")]),
        ("dstipv4", [("ipv4", "dst_addr", "ipv4str")]),
        ("src_port", [("tcp", "src_port", "int16"),
                      ("udp", "src_port", "int16")]),
        ("dst_port", [("tcp", "dst_port", "int16"),
                      ("udp", "dst_port", "int16")]),
        ("protocol", [("ipv4", "protocol", "int8")]),
        # TODO: header and field names of the entries below are placeholders
        ("srcmac", [("???", "???", "macstr")]),
        ("dstmac", [("???", "???", "macstr")]),
        ("tcpflags", [("tcp", "???", "int8")]),
    )

    # Entries are compared one by one with ==, first match wins.
    for known, description in locations:
        if cond == known:
            return description
    return None

