#!/usr/bin/env python

# Copyright (c) 2016 - Digital Operatives LLC
# All rights reserved
#
# Written by Evan Sultanik, Ph.D.
# evan@sultanik.com
# 
# This source code is licensed under the GNU GENERAL PUBLIC LICENSE
# Version 2.
# 
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import ctypes
import z3

def int_overflow(val, int_bits = 32):
    """Wrap ``val`` the way a signed two's-complement integer would overflow.

    ``val`` is reduced into the range ``[-2**(int_bits-1), 2**(int_bits-1) - 1]``;
    values already inside that range are returned unchanged.

    :param val: the (arbitrarily large) integer to wrap.
    :param int_bits: width of the simulated signed integer, in bits.
    :return: ``val`` wrapped to a signed ``int_bits``-bit integer.
    """
    # The largest *signed* value only has int_bits - 1 magnitude bits; using
    # 2**int_bits - 1 here would silently wrap at one bit too many.
    maxint = 2**(int_bits - 1) - 1
    if not -maxint-1 <= val <= maxint:
        # Shift into [0, 2**int_bits), wrap, then shift back to the signed range.
        val = (val + (maxint + 1)) % (2 * (maxint + 1)) - maxint - 1
    return val

def rotr(value, shift, bits = 32):
    """Rotate ``value`` right by ``shift`` places within a ``bits``-wide word.

    :param value: integer to rotate; only its low ``bits`` bits are used, so
        wider or negative inputs are first truncated to the word size.
    :param shift: number of places to rotate; taken modulo ``bits``.
    :param bits: word width in bits (any positive width, not only powers
        of two).
    :return: the rotated value, always in ``[0, 2**bits - 1]``.
    """
    mask = 2**bits - 1
    # Drop anything above the word so stray high bits cannot leak into the
    # result through the plain right shift below.
    value &= mask
    # Modulo (rather than ``& (bits - 1)``) is also correct for widths that
    # are not a power of two.
    shift %= bits
    if shift == 0:
        return value
    return (value >> shift) | ((value << (bits - shift)) & mask)

def hash_it_metasploit(string):
    """Compute the 32-bit rotate-right-by-13 additive hash of ``string``.

    :param string: the text to hash, *without* a terminating NULL; it may be
        a text string (items are converted with ``ord``) or a
        ``bytes``/``bytearray`` (items are already integers).
    :return: the hash as an unsigned 32-bit integer.
    """
    h = 0

    # At each byte of a NULL-terminated string (including the terminating
    # NULL byte) circularly shift the hash right by 0xD (13) places, then add
    # the new byte.
    for c in string:
        # Iterating bytes/bytearray yields ints; iterating text yields
        # one-character strings that still need ord().
        byte = c if isinstance(c, int) else ord(c)
        h = ctypes.c_uint32(int_overflow(rotr(h, 0x0D) + byte)).value

    h = ctypes.c_uint32(rotr(h, 0x0D)).value # one last time for the NULL byte
    return h

def build_problem(solver, string_length, prev_hash = None, _char_vars = None):
    """Add constraints to ``solver`` that model hashing an unknown string.

    One 32-bit bit-vector is created per character (named ``c<k>``, with
    ``k`` counting down from ``string_length`` to 1) and constrained to the
    upper-case ASCII range.  The running hash is chained through
    intermediate ``hash<k>`` variables using the same rotate-right-by-13 and
    add step as ``hash_it_metasploit``, finishing with one extra rotation for
    the terminating NULL byte.

    :param solver: the ``z3.Solver`` that receives the constraints.
    :param string_length: number of characters to model; a value <= 0 adds
        only the final rotation.
    :param prev_hash: optional bit-vector holding the hash of an already
        modelled prefix; ``None`` means the hash starts from zero.
    :param _char_vars: optional list to which the character variables are
        appended (in string order); a new list is created when ``None``.
    :return: ``(h, _char_vars)`` where ``h`` is the bit-vector named
        ``"hash"`` holding the final hash value.
    """
    if _char_vars is None:
        _char_vars = []

    # Iterate instead of recursing so that long strings cannot exhaust the
    # interpreter's recursion limit.
    for remaining in range(string_length, 0, -1):
        c = z3.BitVec("c" + str(remaining), 32)
        # Constrain the character to only upper-case ASCII:
        solver.add(c >= ord('A'))
        solver.add(c <= ord('Z'))
        _char_vars.append(c)
        if prev_hash is None:
            # Rotating the initial hash of zero leaves zero, so after the
            # first character the hash is simply that character.
            prev_hash = c
        else:
            h = z3.BitVec("hash" + str(remaining), 32)
            solver.add(h == (z3.LShR(prev_hash, 0x0D) | (prev_hash << (32 - 0x0D))) + c)
            prev_hash = h

    if prev_hash is None:
        # No characters at all: the hash is still its initial value of zero.
        prev_hash = z3.BitVecVal(0, 32)

    # One last rotation for the terminating NULL byte (adding 0 is a no-op).
    h = z3.BitVec("hash", 32)
    solver.add(h == (z3.LShR(prev_hash, 0x0D) | (prev_hash << (32 - 0x0D))))
    return h, _char_vars

def solve():
    """Yield the shortest strings whose hash (as modelled by
    ``build_problem``) equals zero.

    String lengths are tried in increasing order starting at 1.  For the
    first length at which the solver finds a solution, every solution the
    solver can enumerate is yielded; then the generator finishes.  If no
    length ever produces a solution the search does not terminate.

    :return: a generator of collision strings.
    """
    i = 1
    yielded = False
    while not yielded:
        # Single-argument call form prints identically on Python 2 and 3.
        print("Trying string length %d..." % i)

        solver = z3.Solver()

        h, char_vars = build_problem(solver, i)

        solver.add(h == 0)

        while solver.check() == z3.sat:
            m = solver.model()
            # model_completion guarantees a concrete value even for a
            # variable the model left unassigned (m[c] would be None then).
            values = [m.eval(c, model_completion=True) for c in char_vars]
            yield "".join(chr(v.as_long()) for v in values)
            yielded = True
            # prevent next model from using the same assignment as a previous model
            solver.add(z3.Or(*[c != v for c, v in zip(char_vars, values)]))

        i += 1

if __name__ == "__main__":
    import sys
    # Print each collision as soon as it is found, together with its hash as
    # recomputed by the reference implementation.
    for collision in solve():
        # Single-argument call form prints identically on Python 2 and 3.
        print("Collision: %s\t(hash = %d)" % (collision, hash_it_metasploit(collision)))
        # Flush so results show up immediately even when stdout is piped.
        sys.stdout.flush()
