###############################################################################
#  rule.py: Module for operations with packet classification rules
#  Copyright (C) 2009 Brno University of Technology, ANT @ FIT
#  Author(s): Viktor Pus <ipus@fit.vutbr.cz>
###############################################################################
#
#  LICENSE TERMS
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions
#  are met:
#  1. Redistributions of source code must retain the above copyright
#     notice, this list of conditions and the following disclaimer.
#  2. Redistributions in binary form must reproduce the above copyright
#     notice, this list of conditions and the following disclaimer in
#     the documentation and/or other materials provided with the
#     distribution.
#  3. All advertising materials mentioning features or use of this software
#     or firmware must display the following acknowledgement:
#
#       This product includes software developed by the University of
#       Technology, Faculty of Information Technology, Brno and its
#       contributors.
#
#  4. Neither the name of the Company nor the names of its contributors
#     may be used to endorse or promote products derived from this
#     software without specific prior written permission.
#
#  This software or firmware is provided ``as is'', and any express or implied
#  warranties, including, but not limited to, the implied warranties of
#  merchantability and fitness for a particular purpose are disclaimed.
#  In no event shall the company or contributors be liable for any
#  direct, indirect, incidental, special, exemplary, or consequential
#  damages (including, but not limited to, procurement of substitute
#  goods or services; loss of use, data, or profits; or business
#  interruption) however caused and on any theory of liability, whether
#  in contract, strict liability, or tort (including negligence or
#  otherwise) arising in any way out of the use of this software, even
#  if advised of the possibility of such damage.
#
#  $Id$

"""
Module providing operations with packet classification rules.
"""

import sys
from netbench.common.packetheader import *
from prefix import *
from prefixset import *

class Rule(object):
    """
    Class representing one classification rule.
    """

    def __init__(self):
        """Constructor"""

        # Dictionary of 2-tuples (name:string, value:PrefixSet or :int)
        self._conditions = {}

        # Rule priority. Lower number means higher priority, 0 is the highest of all.
        self._priority = 0

        # Determines whether the rule is pseudorule (default no)
        self._pseudo = False

        # Points to the target rule in case this is pseudorule
        self._target = None

        self._cond_hash = 0

    def display(self):
        """
        Print rule in a human-readable format.
        """
        print "Rule priority: " + str(self._priority),
        if self._pseudo:
            print "Pseudorule target: ", self._target.get_priority(),


        for c, p in self._conditions.iteritems():
            print " " + c + ":",
            p.display()
        
        print ""


    def match(self, packetheader):
        """
        Return True if packet satisfies all conditions of the rule.

        packetheader: instance of the PacketHeader class.
        """

        # Get interesting headers
        ethhdr = packetheader.get_header("ethernet")
        ipv4hdr = packetheader.get_header("ipv4")
        tcphdr = packetheader.get_header("tcp")
        udphdr = packetheader.get_header("udp")

        # c is condition name, 
        # p is condition value (PrefixSet)
        for c, p in self._conditions.iteritems():
            if (c == "srcipv4"):
                if ipv4hdr:
                    ip = ipv4hdr.get_field("src_addr")
                    if not p.match(ip):
                        return False
                else:
                    return False

            elif (c == "dstipv4"):
                if ipv4hdr:
                    ip = ipv4hdr.get_field("dst_addr")
                    if not p.match(ip):
                        return False
                else:
                    return False

            elif (c == "protocol"):
                if ipv4hdr:
                    proto = ipv4hdr.get_field("protocol")
                    if not p.match(proto):
                        return False
                else:
                    return False

            elif (c == "src_port" or c == "dst_port"):
                if ipv4hdr:
                    proto = ipv4hdr.get_field("protocol")
                    if (proto == 6): # TCP
                        if tcphdr:
                            if not p.match(tcphdr.get_field(c)):
                                return False
                        else: 
                            # Strange: TCP, but no header found
                            return False
                    elif (proto == 17): # UDP
                        if udphdr:
                            if not p.match(udphdr.get_field(c)):
                                return False
                        else:
                            # Strange: UDP, but no header found
                            return False
                    else: # Unknown protocol, but condition on port?
                        return False

            else:
                print "Unknown condition name: " + c
                return False


        # No non-matching condition found -> packet matches the rule
        return True


    def set_condition(self, name, value):
        """
        Set the rule condition for one packet header field.

        name: Name of the packet header field.
        Currently supported header fields: srcipv4, dstipv4,
        protocol, src_port, dst_port

        value: Condition on the packet header field. 
        It should be PrefixSet.
        """
        self._conditions[name] = value

    def get_condition(self, name):
        """
        Get the condition for one packet header field. 
        Return None if the condition for the given field doesn't exist.

        name: Name of packet header field.
        """
        return self._conditions[name]

    def get_all_conditions(self):
        """
        Get all conditions of the rule in the dictionary.
        """
        return self._conditions

    def set_priority(self, priority):
        """
        Set the priority of the rule. Lower number means higher priority, 
        0 is the highest of all.
        """
        self._priority = priority

    def get_priority(self):
        """
        Get the priority of the rule.
        """
        return self._priority

    def set_pseudo(self, val):
        """
        Set whether rule is pseudorule.
        """
        self._pseudo = val

    def get_pseudo(self):
        """
        Return True if rule is pseudorule.
        """
        return self._pseudo

    def set_target(self, target):
        """
        Set pointer to correct rule in case this is pseudorule.
        """
        self._target = target

    def get_target(self):
        """
        Get pointer to correct rule in case this is pseudorule.
        """
        return self._target
    
    def is_universal(self):
        """
        Return True if rule is universal (matches all packets)
        """
        for v in self._conditions.itervalues():
            if not v.is_universal():
                return False
        return True

    def __eq__(self, other):
        """
        Rules are equal if they have the same priority, 
        and all of their conditions are equal.
        """
        if (self.__class__ != other.__class__):
            return False
        
        if (self._priority != other._priority):
            return False

        if (self._pseudo != other._pseudo):
            return False

        if (self._target != other._target):
            return False

        if (self._conditions != other._conditions):
            return False

        return True

    def __ne__(self, other):
        """
        Rules are inequal if they are not equal.
        """
        return not self == other


    def expand_prefixsets(self):
        """
        Return list of rules, covering the same area as the original rule,
        but each rule contains only one prefix in each condition.
        (No sets of prefixes)
        """
        rules = []
        condition_names = []
        prefixes = []

        # Gather all conditions and their prefixes
        for c, p in self._conditions.iteritems():
            condition_names.append(c)
            if (type(p) == PrefixSet):
                prefixes.append(p.get_prefixes())
            else:
                prefixes.append([p])

        # Evaluate crossproduct
        expanded = cross_lists(prefixes)

        # Create new rule for each crossproduct word
        for plist in expanded:
            r = Rule()
            r.set_priority(self._priority)
            # Set conditions one by one
            for i, cond in enumerate(condition_names):
                p = PrefixSet()                      # TODO predelat at se dava jen Prefix/MaskedInt a ne PrefixSet
                p.add_prefix(plist[i])
                r.set_condition(cond, p)
            rules.append(r)

        return rules

    def same_but_priority(self, other):
        """
        Return true if two rules are same, except for they may have the same
        priority.
        """
        if sorted(self._conditions.keys()) != sorted(other._conditions.keys()):
            return False
        
        for cond, spref in self._conditions.iteritems():
            if other._conditions[cond] != spref:
                return False
        
        return True

    def compute_conditions_hash(self):
        """
        Precompute and store hash of rule's conditions (not priority).
        """
        for c, p in self._conditions.iteritems():
            self._cond_hash = self._cond_hash ^ hash(p)

    def covers(self, other, allow_same_priority=True, has_same_conds=False):
        """
        Return True if rule fully covers another rule.

        other: Rule that we want to compare

        allow_same_priority: Do not require different priorities

        has_same_conds: set to True to skip check for same rule conditions
        """
        # Check if covering rule has higher priority (0=highest)
        if self._priority > other._priority:
            return False

        if not allow_same_priority:
            if self._priority == other._priority:
                return False

        # Check if all mine non-universal conditions are in other
        if (not has_same_conds):
            # Get all non-universal conditions
            for c, p in self._conditions.iteritems():
                if p.is_universal() or (other._conditions.has_key(c) and p.covers(other._conditions[c])):
                    continue
                else:
                    return False
        else:
            # Check all coditions one by one
            for c, p in self._conditions.iteritems():
                if not p.covers(other.get_condition(c)):
                    return False
        
        return True


#############################
# Other rule module methods #
#############################

def cross_lists(lists):
    """
    Return the cross product of all items from all lists.
    Result is a list of lists.

    lists: list of lists of items to be crossproducted.

    Items of lists[0] vary slowest, e.g.
    cross_lists([[1, 2], ['a', 'b']]) == [[1, 'a'], [1, 'b'], [2, 'a'], [2, 'b']]
    An empty `lists` yields [[]] (one empty combination); if any inner list
    is empty, the result is [].
    """
    import itertools

    # itertools.product yields combinations in the required order (first
    # list slowest), handles an empty `lists`, and needs no recursion.
    # Each combination becomes a fresh, independent list.
    return [list(combo) for combo in itertools.product(*lists)]


