# batch_parsing.py
#
# Author: Vojtěch Dvořák
#
# Script for the pytest framework that tests the yaramod Python bindings. Tests
# in this file focus on batch parsing of YARA files.
#
# Requirements: pytest, yaramod_python


import pytest
import math
import os
import yaramodv4 as y


# Absolute paths of the scratch files that the file-based tests write to.
TMP_FILENAMES = [
    os.path.abspath(tmp_name)
    for tmp_name in ("__test_file1__.tmp", "__test_file2__.tmp")
]


def test_parsing():
    """Basic parsing of a trivial string should not raise any error"""

    try:
        yaramod = y.Yaramod()
        yaramod.parse_string("rule a { condition: true }")
    except Exception as exc:
        # Catch Exception instead of using a bare except so that
        # KeyboardInterrupt/SystemExit are not converted into a test failure,
        # and include the raised exception in the failure report.
        pytest.fail(f"Parsing throw unexpected error! ({exc!r})")



def test_config_error_mode():
    """Error mode of yaramod parser should be configurable"""

    yaramod = y.Yaramod()

    # Before the error mode is changed, the unterminated rule must not raise.
    try:
        yaramod.parse_string("rule a { condition: true ")
    except Exception as exc:
        # Exception (not a bare except) keeps KeyboardInterrupt/SystemExit
        # from being reported as a test failure.
        pytest.fail(f"Unexpected error was thrown! ({exc!r})")

    yaramod.config.error_mode = y.ErrorMode.Strict

    # After switching to strict mode, the same input must raise
    # YaraSyntaxError; anything else (or nothing) fails the test.
    try:
        yaramod.parse_string("rule a { condition: true ")
    except y.YaraSyntaxError:
        pass  # Expected error was thrown
    except Exception as exc:
        pytest.fail(f"Another type of error was expected to be thrown! ({exc!r})")
    else:
        pytest.fail("Error was expected, but it was not thrown!")



def test_yarasource():
    """Tests of interface of YaraSource class"""

    yara_string = """
rule a { 
    condition: true
}

rule b { 
    condition: true
}   
    """

    # The context manager closes the file even if write() raises, so the
    # content is flushed and the handle released before the parser reads it.
    with open(TMP_FILENAMES[0], "w") as tmp_file:
        tmp_file.write(yara_string)

    yaramod = y.Yaramod()

    yara_src = yaramod.parse_file(TMP_FILENAMES[0])

    # The parsed file is registered under its path and is the entry file
    assert yara_src.entry_file.name == TMP_FILENAMES[0]
    assert yara_src.has_file(TMP_FILENAMES[0])
    assert yara_src.get_file(TMP_FILENAMES[0]) is yara_src.entry_file

    yara_src.remove_file(TMP_FILENAMES[0])

    # After removing the only file, the source should hold no files
    assert len(yara_src.files) == 0



def test_yarafile():
    """Check YaraFile accessors: rules, lookup by row, comments, includes, imports, errors"""

    source_text = """
import "cuckoo"

include "__test.tmp__"

//Hello

rule a { 
    condition: true
}

/*
World
*/

rule b { 
    condition: true
}   
    """

    parser = y.Yaramod()
    source = parser.parse_string(source_text)

    parsed_file = source.entry_file

    assert len(parsed_file.rules) == 2

    # (row, id of the rule expected on that row); None means no rule there
    for row, expected_id in ((8, "a"), (10, None), (16, "b")):
        found_rule = parsed_file.get_local_rule_by_row(row)
        if expected_id is None:
            assert found_rule is None
        else:
            assert found_rule.id == expected_id

    assert len(parsed_file.comments) == 2
    assert len(parsed_file.local_file_includes) == 1
    assert len(parsed_file.imports) == 1
    assert len(parsed_file.semantic_errors) == 1



def test_rule():
    """Check Rule accessors (id, strings, tags, global flag) and file-level imports/includes/comments"""

    source_text = """
import "cuckoo"

include "__test.tmp__"

//Hello

rule a : tag1 tag2 { 
    strings:
        $a = "abc"
    condition: 
        true
}

/*
World
*/

global rule b { 
    condition: true
}   
    """

    parser = y.Yaramod()
    source = parser.parse_string(source_text)

    parsed_file = source.entry_file
    local_rules = parsed_file.local_rules
    imported_modules = parsed_file.imports
    file_includes = parsed_file.local_file_includes

    # First rule: tagged, with a single string
    first_rule = local_rules[0]
    assert first_rule.id == "a"

    assert len(first_rule.strings) == 1
    assert first_rule.strings[0].id == "a"

    assert len(first_rule.tags) == 2
    for position, expected_tag in enumerate(("tag1", "tag2")):
        assert first_rule.tags[position] == expected_tag

    # Second rule: declared with the "global" modifier
    second_rule = local_rules[1]
    assert second_rule.id == "b"
    assert second_rule.is_global()

    # File-level elements
    assert imported_modules[0].module_name == "cuckoo"
    assert file_includes[0].path == "__test.tmp__"
    assert parsed_file.comments[0].text == "//Hello"



def test_strings():
    """Check String accessors: anonymity, id, content, modifiers and iteration"""

    source_text = """

rule a : tag1 tag2 { 
    strings:
        $ = /He{1,3}llo/i
        $b = { AA BB CC } 
        $a = "abc" xor(1-4) base64
        $c = "def" base64("0123456789012345678901234567890123456789012345678901234567890123")
    condition: 
        true
}
  
    """

    parser = y.Yaramod()
    source = parser.parse_string(source_text)

    parsed_file = source.entry_file
    rule_strings = parsed_file.rules[0].strings

    expected_count = 4
    assert len(rule_strings) == expected_count

    # Anonymous regexp string
    assert rule_strings[0].is_anonymous()
    assert rule_strings[0].content == "/He{1,3}llo/i"

    # Hex string
    assert rule_strings[1].id == "b"

    # Plain string with xor and base64 modifiers
    assert rule_strings[2].id == "a"
    assert rule_strings[2].content == "abc"
    xor_modifier = rule_strings[2].modifiers.get(y.StringModifierType.Xor)
    assert xor_modifier.get_arg()[0] == 1

    # Plain string with a custom base64 alphabet
    base64_modifier = rule_strings[3].modifiers.get(y.StringModifierType.Base64)
    assert base64_modifier.get_arg() == list("0123456789012345678901234567890123456789012345678901234567890123")

    # Test that the strings container is iterable and yields every string
    iterated = 0
    for iterated, item in enumerate(rule_strings, start=1):
        assert item is not None

    assert expected_count == iterated



def test_multiple_files():
    """Test of batch parsing of included files and access to included files"""

    # Entry file: includes the second temporary file by its absolute path
    yara_string0 = f"""
import "cuckoo"

include "{TMP_FILENAMES[1]}"

rule malware1 {{
    strings:
        $abc = "abc"
    condition:
        true
}}

  
    """

    # Included file: two trivial rules
    yara_string1 = """
rule a { 
    condition: true
}

rule b { 
    condition: true
}   
    """

    # Context managers close each file even if write() raises, so both files
    # are fully written and released before the parser opens them.
    with open(TMP_FILENAMES[0], "w") as entry_tmp:
        entry_tmp.write(yara_string0)

    with open(TMP_FILENAMES[1], "w") as included_tmp:
        included_tmp.write(yara_string1)

    yaramod = y.Yaramod()
    yara_src = yaramod.parse_file(TMP_FILENAMES[0])

    entry_file = yara_src.entry_file
    included_file = entry_file.local_file_includes[0].file

    # `rules` of the entry file is expected to contain the included rules
    # first, followed by the entry file's own rule
    assert len(entry_file.rules) == 3
    assert entry_file.rules[2].id == "malware1"

    assert included_file.name == os.path.abspath(TMP_FILENAMES[1])
    assert len(included_file.local_rules) == 2
    assert included_file.local_rules[0].id == "a"
    assert included_file.local_rules[1].id == "b"



def test_get_text_formatted():
    """Check that get_text_formatted() returns the expected normalized text of a file"""

    # Input with trailing spaces and a condition spread over several lines
    source_text = """

rule a : tag1 tag2 { 
    strings:
        $ = /He{1,3}llo/i
        $b = { AA BB CC } 
        $a = "abc" xor(1-4) base64
        $c = "def" base64("0123456789012345678901234567890123456789012345678901234567890123")
    condition: 
        true

        and

        true
}
  
    """

    # Expected output: tab-indented, condition on a single line
    formatted_reference = """rule a : tag1 tag2 {
\tstrings:
\t\t$ = /He{1,3}llo/i
\t\t$b = { AA BB CC }
\t\t$a = "abc" xor(1-4) base64
\t\t$c = "def" base64("0123456789012345678901234567890123456789012345678901234567890123")
\tcondition:
\t\ttrue and true
}

"""

    parser = y.Yaramod()
    source = parser.parse_string(source_text)
    parsed_file = source.entry_file

    assert parsed_file.get_text_formatted() == formatted_reference



def test_repetetive_parsing():
    """Parsing the same input repeatedly with one parser should give consistent results"""

    broken_rule = "rule a { condition: true "

    parser = y.Yaramod()

    # Parse the same (syntactically broken) input three times in a row
    sources = [parser.parse_string(broken_rule) for _ in range(3)]
    first_errors, second_errors, third_errors = [
        parsed.entry_file.syntax_errors for parsed in sources
    ]

    assert len(first_errors) == 1
    assert len(first_errors) == len(second_errors)
    assert len(second_errors) == len(third_errors)

    # The first and the last run must report the same error
    first_error = first_errors.data[0]
    last_error = third_errors.data[0]
    for field in ("description", "offset", "len"):
        assert getattr(first_error, field) == getattr(last_error, field)

    # The first rule must have the same id in every run
    first_rule_ids = [parsed.entry_file.rules[0].id for parsed in sources]
    assert first_rule_ids[0] == first_rule_ids[1]
    assert first_rule_ids[1] == first_rule_ids[2]



def test_metas():
    """Check Meta accessors: count, lookup by key and typed values"""

    source_text = """

rule very_bad_sw { 
    meta:
        author = "John Doe"
        level = 10
        md5 = "d746582b12adbf358f67e04b202a32b3"
        critical = true
    condition: 
        true
}
  
    """

    parser = y.Yaramod()
    source = parser.parse_string(source_text)
    parsed_file = source.entry_file
    target_rule = parsed_file.local_rules["very_bad_sw"]

    rule_metas = target_rule.metas
    assert len(rule_metas) == 4

    # String, integer and boolean meta values
    assert rule_metas.get("author").value.str == "John Doe"
    assert rule_metas.get("level").value.int == 10
    assert rule_metas.get("critical").value.bool == True




def test_internal_variables():
    """Check the interface of internal variables from the `variables:` section (Avast feature)"""

    source_text = """

rule rule_with_internal_variables {
    variables:
        a = "abc"
        b = false
        c = 10
        d = 0o100
    condition: 
        b or c
}
  
    """

    parser = y.Yaramod()
    source = parser.parse_string(source_text)
    parsed_file = source.entry_file
    target_rule = parsed_file.local_rules["rule_with_internal_variables"]

    rule_condition = target_rule.condition

    # Variable id -> expected literal value (0o100 is octal for 64)
    expected_values = {"a": "abc", "b": False, "c": 10, "d": 64}

    for variable in target_rule.vars:
        assert expected_values[variable.id] == variable.value.literal.get_value()

    assert isinstance(rule_condition, y.OrExpression)

    # Right operand of the `or` refers to the integer variable `c`
    right_operand = rule_condition.right
    assert right_operand.get_type() == y.ExpressionType.Int
    assert right_operand.id == "c"
    


def test_aritmetic_expression():
    """Check types and node classes of a parsed arithmetic condition"""

    source_text = """

rule a { 
    condition: 
        3 * (2 + 3) + 0 \\ 4 - 1.25
}
  
    """

    parser = y.Yaramod()
    source = parser.parse_string(source_text)
    parsed_file = source.entry_file
    target_rule = parsed_file.local_rules["a"]

    root = target_rule.condition

    # Root node: the subtraction, typed as float
    assert root.is_valid()[0]
    assert root.get_type() == y.ExpressionType.Float
    assert isinstance(root, y.SubExpression)

    # Left operand of the subtraction: integer addition
    addition = root.left
    assert addition.get_type() == y.ExpressionType.Int
    assert isinstance(addition, y.AddExpression)

    # Right operand of that addition: integer division
    division = addition.right
    assert division.get_type() == y.ExpressionType.Int
    assert isinstance(division, y.DivExpression)

    # Right operand of the subtraction: float literal 1.25
    subtrahend = root.right
    assert subtrahend.get_type() == y.ExpressionType.Float
    assert isinstance(subtrahend, y.LiteralExpression)
    assert subtrahend.literal.is_float()
    assert math.isclose(subtrahend.literal.float, 1.25)



def test_stringref_expression():
    """Check node classes of a condition built from string references"""

    source_text = """

rule a {
    strings:
        $a = "first"
        $b = "second"
    condition: 
        $a and not $b
}
  
    """

    parser = y.Yaramod()
    source = parser.parse_string(source_text)
    parsed_file = source.entry_file
    target_rule = parsed_file.local_rules["a"]

    root = target_rule.condition

    assert root.is_valid()[0]
    assert root.get_type() == y.ExpressionType.Bool
    assert isinstance(root, y.AndExpression)

    # Left side: plain reference to $a
    first_ref = root.left
    assert isinstance(first_ref, y.StringExpression)
    assert first_ref.id == "a"

    # Right side: negation wrapping a reference to $b
    negation = root.right
    assert isinstance(negation, y.NotExpression)

    negated_ref = negation.operand
    assert isinstance(negated_ref, y.StringExpression)
    assert negated_ref.id == "b"



def test_module_ref_expression():
    """Check node classes of a condition that calls a module function"""

    source_text = """
import "cuckoo"

rule a {
    condition: 
        cuckoo.network.dns_lookup(/verybadsite\\.com/)
}
  
    """

    parser = y.Yaramod()
    source = parser.parse_string(source_text)
    parsed_file = source.entry_file
    target_rule = parsed_file.local_rules["a"]

    call = target_rule.condition

    assert isinstance(call, y.FunctionCallExpression)

    # The only argument is a regexp literal
    first_argument = call.args[0]
    assert first_argument.get_type() == y.ExpressionType.Regexp
    assert first_argument.content == "/verybadsite\\.com/"

    # The callee is reached through structure member access
    assert isinstance(call.function, y.StructExpression)



def test_for_expression():
    """Check node classes of a `for ... of` condition over a string set"""

    source_text = """
import "cuckoo"

rule malware_sample {
    strings:
        $a = "a"
        $b = "b"
        $c = "c"
    condition: 
        for all of ($a, $b, $c) : ( @ > 10 )
}
  
    """

    parser = y.Yaramod()
    source = parser.parse_string(source_text)
    parsed_file = source.entry_file
    target_rule = parsed_file.local_rules["malware_sample"]

    loop = target_rule.condition
    assert isinstance(loop, y.ForExpression)
    assert isinstance(loop.quantifier, y.AllExpression)
    assert isinstance(loop.set, y.StringSetExpression)

    # Third element of the iterated set is $c
    iterated_set = loop.set
    assert iterated_set.elements[2].id == "c"

    # Loop body is the `>` comparison
    body = loop.inner_expression
    assert isinstance(body, y.GtExpression)



def test_observing_visitor():
    """Check that a Python subclass of ObservingVisitor receives the visit callbacks"""

    source_text = """

rule malware_sample {
    strings:
        $a1 = "a"
        $a2 = "a"
    condition: 
        123 == 321 and (431 != 654) or true or not false and for all of ($a*) : ( @ != 10 )
}
  
    """

    class CounterVisitor(y.ObservingVisitor):
        """Counts `==` nodes, `!=` nodes and boolean literals"""

        def __init__(self):
            y.ObservingVisitor.__init__(self)
            self.eq_total = 0
            self.neq_total = 0
            self.bool_literal_total = 0

        def _descend(self, binary_expr):
            # Continue the traversal into both operands, left first
            binary_expr.left.accept(self)
            binary_expr.right.accept(self)

        def visit_Eq(self, e):
            self.eq_total += 1
            self._descend(e)

        def visit_Neq(self, e):
            self.neq_total += 1
            self._descend(e)

        def visit_Literal(self, e):
            if e.literal.is_bool():
                self.bool_literal_total += 1

    parser = y.Yaramod()
    source = parser.parse_string(source_text)
    parsed_file = source.entry_file
    target_rule = parsed_file.local_rules["malware_sample"]

    counter = CounterVisitor()

    target_rule.condition.accept(counter)

    # One `==`, two `!=` (one of them inside the for loop), `true` and `false`
    assert counter.eq_total == 1
    assert counter.neq_total == 2
    assert counter.bool_literal_total == 2


