# calculator.py
#
# A simple Python script that uses the yaramodv4 module and its ObservingVisitor
# interface. It takes one argument - the path to a file with rules that have
# only arithmetic conditions (they may contain only the operators \ (YARA's
# division operator), *, +, -, unary minus and parentheses, with integers or
# floats as operands). The conditions of these rules are evaluated and the
# results are printed to stdout. Be careful: if a condition contains anything
# else, it may result in undefined behavior!
#
# Requirements: yaramodv4
#
#
# USAGE:
# python3 calculator.py <path>
#
#
# EXAMPLE:
#
# test.tmp:
# ```
# rule a { condition: 1 + 3 }
# rule b { condition: 1 + 2 * 4 }
# ```
#
# Output:
# ```
# a: Result is 4
# b: Result is 9
# ```
#
# Author: Vojtěch Dvořák


import sys
import yaramodv4 as y


class CalcVisitor(y.ObservingVisitor):
    """Concrete visitor class that evaluates arithmetic YARA conditions.

    Evaluation uses a value stack: numeric literals push their value and each
    operator replaces its operands on the stack with its result, so a
    successfully evaluated expression leaves exactly one value, which is
    retrieved (and the stack reset) with ``get_result``.

    When something cannot be evaluated (a non-arithmetic operand, a
    non-numeric literal, a division by zero, an operand that did not yield
    exactly one value) the whole stack is cleared. Because every operator
    verifies how many values its operands pushed, each enclosing operator then
    fails too, and ``get_result`` reports that no value was produced.
    """

    def __init__(self):
        y.ObservingVisitor.__init__(self)
        self.stack_ = []  # operand/result stack of the current evaluation


    def is_arithmetic(self, expression):
        """Checks if type of expression is arithmetic (evaluable by this class)"""

        arithmetic_types = (
            y.MulExpression,
            y.AddExpression,
            y.SubExpression,
            y.DivExpression,
            y.ParenthesesExpression,
            y.UnaryMinusExpression,
            y.LiteralExpression
        )

        return isinstance(expression, arithmetic_types)


    def _fail(self):
        """Abort the current evaluation by discarding every stacked value.

        Returns False so callers can ``return self._fail()``.
        """

        self.stack_.clear()
        return False


    @staticmethod
    def _truncating_div(a, b):
        """Integer division rounding toward zero (C/YARA semantics).

        Python's ``//`` rounds toward negative infinity, which differs for
        operands of opposite sign (-7 // 2 == -4, while YARA gives -3).
        Raises ZeroDivisionError when ``b`` is 0.
        """

        quotient = abs(a) // abs(b)
        return quotient if (a < 0) == (b < 0) else -quotient


    def visit_binary_operator(self, operator, func):
        """Method for processing binary arithmetic operator.

        Evaluates the left and then the right operand and replaces their two
        values on the stack with ``func(left, right)``.

        Returns True on success. Returns False, with the stack cleared, when an
        operand is not arithmetic, an operand did not push exactly one value,
        or ``func`` divides by zero.
        """

        if not (self.is_arithmetic(operator.left)
                and self.is_arithmetic(operator.right)):
            # A non-arithmetic operand pushes nothing, so the operator would
            # otherwise consume values that belong to an enclosing expression.
            return self._fail()

        depth = len(self.stack_)
        operator.left.accept(self)
        operator.right.accept(self)

        # Each operand must contribute exactly one value; a failed operand has
        # cleared the stack, which can never satisfy this check.
        if len(self.stack_) != depth + 2:
            return self._fail()

        right = self.stack_.pop()
        left = self.stack_.pop()
        try:
            result = func(left, right)
        except ZeroDivisionError:
            # Treat as "no value" instead of crashing the whole script.
            return self._fail()

        self.stack_.append(result)
        return True


    def visit_unary_operator(self, operator, func):
        """Method for processing unary arithmetic operator.

        Evaluates the operand and replaces its value on the stack with
        ``func(value)``. Returns True on success; False (stack cleared) when the
        operand is not arithmetic or did not push exactly one value.
        """

        if not self.is_arithmetic(operator.operand):
            return self._fail()

        depth = len(self.stack_)
        operator.operand.accept(self)

        if len(self.stack_) != depth + 1:
            return self._fail()

        operand = self.stack_.pop()
        self.stack_.append(func(operand))
        return True


    def visit_Add(self, e):
        self.visit_binary_operator(e, lambda a, b: a + b)


    def visit_Sub(self, e):
        self.visit_binary_operator(e, lambda a, b: a - b)


    def visit_Mul(self, e):
        self.visit_binary_operator(e, lambda a, b: a * b)


    def visit_Div(self, e):
        """Division: truncating for two ints, true division otherwise."""

        left_is_int = e.left.get_type() == y.ExpressionType.Int
        right_is_int = e.right.get_type() == y.ExpressionType.Int

        if left_is_int and right_is_int:
            self.visit_binary_operator(e, self._truncating_div)
        else:
            self.visit_binary_operator(e, lambda a, b: a / b)


    def visit_UnaryMinus(self, e):
        self.visit_unary_operator(e, lambda a: -a)


    def visit_Literal(self, e):
        """Push the value of an int/float literal; any other literal fails."""

        literal = e.literal
        if literal.is_int() or literal.is_float():
            self.stack_.append(literal.get_value())
        else:
            self._fail()


    def get_result(self):
        """Return the value of the last evaluated expression and reset the stack.

        Returns None when the expression produced no value (it was not purely
        arithmetic or it divided by zero). The stack is always emptied so that
        nothing leaks into the next evaluation.
        """

        result = self.stack_.pop() if self.stack_ else None
        self.stack_.clear()
        return result



if __name__ == "__main__":
    # Entry point: python3 calculator.py <path>
    if len(sys.argv) < 2:
        print("Missing path!", file=sys.stderr)
        sys.exit(1)  # sys.exit works even when the site module is not loaded

    parser = y.Yaramod()
    yara_src = parser.parse_file(sys.argv[1])

    if yara_src.all_semantic_errors or yara_src.all_syntax_errors:
        print("Errors found in the file!", file=sys.stderr)
        sys.exit(1)

    calc_visitor = CalcVisitor()
    numeric_types = (y.ExpressionType.Int, y.ExpressionType.Float)

    for rule in yara_src.all_rules:
        if rule.condition.get_type() not in numeric_types:
            print(f"{rule.id}: Condition expr. has non-arithm. type!")
            continue

        calc_visitor.start(rule.condition)
        try:
            result = calc_visitor.get_result()
        except IndexError:
            # An empty evaluation stack means the condition yielded no value.
            result = None

        if result is None:
            print(f"{rule.id}: Condition could not be evaluated!")
        else:
            print(f"{rule.id}: Result is {result}")

    

