/*
 * strace.c
 * 
 * Copyright (c) 2000, BindView Corporation.
 *
 * See LICENSE file.
 *
 * Hook system calls of a process (or all processes)
 */

#include "ntddk.h"
#include "stdarg.h"
#include "stdio.h"
#include "ioctlcmd.h"
#include "syscalls.h"
#include "driver.h"
#include "version.h"

/*
 * Debug tracing switch.  DbgPrint is redefined as a macro taking a single
 * parenthesized argument list, used as DbgPrint (("fmt %d", x)), so that
 * every trace call disappears when DBG_TAS is not defined.  Inside the
 * expansion the name DbgPrint is not re-expanded, so it refers to the real
 * kernel routine.  DBG_TAS is defined unconditionally here, so tracing is
 * currently always compiled in.
 */
#define DBG_TAS 1
#ifdef DBG_TAS
#define DbgPrint(arg) DbgPrint arg
#else
#define DbgPrint(arg) 
#endif

/*
 * Format of kernel call table (one service descriptor entry).
 *
 * service_table     - array of service routine pointers; hook()/unhook()
 *                     overwrite and restore entries of this array in place.
 * low_call, hi_call - pick_map() uses (hi_call - low_call) as the number of
 *                     services when matching the table against a map.
 * arg_table         - NOTE(review): presumably per-service argument sizes;
 *                     not referenced in this file.
 */
struct srv_table {
    void **service_table;
    unsigned long low_call;        
    unsigned long hi_call;
    void **arg_table;
};

/*
 * The kernel call tables
 *
 * service_table is copied from KeServiceDescriptorTable in DriverEntry.
 * shadow_table is located by find_shadow_table() and may be NULL if the
 * scan fails; its entry [1] is the table hooked with the win32k maps.
 */
static struct srv_table *service_table;
static struct srv_table *shadow_table;

/*
 * NOTE(review): declared as a pointer; the code treats its value as the
 * address of the ntoskrnl descriptor table (compared and memcmp'ed against
 * candidates in find_shadow_table).
 */
extern struct srv_table *KeServiceDescriptorTable;

static struct srv_table *find_shadow_table ();
/*
 * Only the address of this routine is used: find_shadow_table() scans its
 * code bytes.  The declaration relies on an implicit int return type.
 */
__declspec(dllimport) KeAddSystemServiceTable (ULONG, ULONG, ULONG, ULONG, ULONG);

/*
 * Data buffer types and vars
 *
 * Trace entries are stored in a doubly linked list of fixed-size buffers.
 * data_tail is the oldest buffer (handed out first by get_first_buf());
 * data_head is the newest buffer, the one add_entry() appends to.  Links run
 * from tail to head via next, and from head to tail via prev.
 */
struct data_buf {
    unsigned long len;          /* bytes of data[] currently in use */
    struct data_buf *next;      /* newer buffer, or NULL for data_head */
    struct data_buf *prev;      /* older buffer, or NULL for data_tail */
    char data[MAX_DATA];
};

static struct data_buf *data_head = NULL;
static struct data_buf *data_tail = NULL;
/* Held around list manipulation in add_entry() and the getdata ioctl. */
static KMUTEX DataLock;
/* Number of buffers in the list; start_new_buf() stops at MAX_BUFS. */
static unsigned long num_bufs = 0;
#define MAX_BUFS 20


/* Device object created in DriverEntry. */
static PDEVICE_OBJECT device;

/* TRUE while the service tables are patched (set by hook(), cleared by unhook()). */
static int hooked = FALSE;


/*
 * traced_pid is set by the STRACE_hook ioctl, ignored_pid by STRACE_setignore.
 * NOTE(review): presumably consulted by the hook routines defined elsewhere.
 */
unsigned long traced_pid;
unsigned long ignored_pid; /* used for ignoring calls made by the app itself */

/* Reset to 0 on every IRP_MJ_CREATE. */
unsigned long call_count = 0;

/*
 * Should probably be a rw-lock.  This isn't really used right now.
 *
 * strace_unload() waits for this to drop to zero.  Nothing in this file
 * increments it; NOTE(review): presumably the hook stubs do (via inc()).
 */
unsigned long calls_in_progress = 0;


/*
 * Data on the different versions of system call maps.
 *
 * pick_map() selects the first map whose *size equals the service count of
 * the running kernel's table, so two versions with the same count cannot be
 * told apart; the earlier entry wins.  size is a pointer because the
 * sizes are read at run time (NOTE(review): presumably filled in by
 * calc_map_sizes(), which DriverEntry calls).
 */
struct map_info {
    struct syscall_map *map;
    unsigned int *size;
} all_maps[] = {
    { nt4sp3_map, &nt4sp3_map_size },
    { nt4sp4_map, &nt4sp4_map_size },
    { nt5_map,    &nt5_map_size },
    { xp_map,     &xp_map_size },
    { nt4termsp4_map, &nt4termsp4_map_size },
};

/* Candidate maps for the win32k table (shadow_table[1]). */
struct map_info all_win32k_maps[] = {
    { nt4sp3_win32k_map, &nt4sp3_win32k_map_size },
    { nt4sp45_win32k_map, &nt4sp45_win32k_map_size },
    { nt4sp6_win32k_map, &nt4sp6_win32k_map_size },
    { nt5_win32k_map, &nt5_win32k_map_size },
    { nt5_q328310_win32k_map, &nt5_q328310_win32k_map_size },
    { xp_win32k_map, &xp_win32k_map_size },
    { xpsp1_win32k_map, &xpsp1_win32k_map_size },
    { nt4termsp4_win32k_map, &nt4termsp4_win32k_map_size },
    { nt4termsp6_win32k_map, &nt4termsp6_win32k_map_size },
};

/*
 * Maps chosen by pick_map().  While no map is selected the size stays 0, so
 * hook()/unhook() patch nothing for that table.
 */
static struct syscall_map *map_in_use;
static unsigned int map_in_use_size;
static struct syscall_map *win32k_map_in_use;
static unsigned int win32k_map_in_use_size;


void
pick_map (void)
{
    unsigned int i;

    for (i=0; i< sizeof (all_maps) / sizeof (all_maps[0]); i++) {
        if (*all_maps[i].size
            == (service_table->hi_call - service_table->low_call)) {
            DbgPrint (("Using map %d", i));
            map_in_use = all_maps[i].map;
            map_in_use_size = *all_maps[i].size;
            break;
        }
    }
    for (i=0; i< sizeof (all_win32k_maps) / sizeof (all_win32k_maps[0]); i++) {
        if (*all_win32k_maps[i].size
            == (shadow_table[1].hi_call - shadow_table[1].low_call)) {
            DbgPrint (("Using win32k map %d", i));
            win32k_map_in_use = all_win32k_maps[i].map;
            win32k_map_in_use_size = *all_win32k_maps[i].size;
            break;
        }
    }
}


/*
 * Atomically add one to *n and hand back the value it holds afterwards.
 */
unsigned long
inc (unsigned long *n)
{
    unsigned long incremented;

    incremented = InterlockedIncrement (n);
    return incremented;
}

static void
free_all_bufs (void)
{
    struct data_buf *next;

    while (data_tail) {
        next = data_tail->next;
        ExFreePool (data_tail);
        data_tail = next;
    }
    data_tail = data_head = NULL;
}


void
start_new_buf (void)
{
    struct data_buf *buf;

    if (num_bufs == MAX_BUFS) {
        return; 
    }

    buf = ExAllocatePool (NonPagedPool, sizeof (*buf));
    if (buf) { 
        buf->len = 0;
        buf->next = 0;
        buf->prev = data_head;
        data_head->next = buf;
        data_head = buf;
        num_bufs++;
    }
}

/*
 * Detach and return the oldest buffer (data_tail); the caller owns and
 * frees it.
 *
 * The buffer currently being filled (data_head) is never handed out.  When
 * it is the only buffer, a fresh head is started first.  If that fails,
 * NULL is returned and the list is unchanged.  Caller holds DataLock.
 */
struct data_buf *
get_first_buf (void)
{
    struct data_buf *oldest;

    if (data_tail == data_head) {
        start_new_buf ();
        if (data_tail == data_head)
            return NULL;
    }

    oldest = data_tail;

    data_tail       = oldest->next;
    data_tail->prev = NULL;
    oldest->next    = NULL;

    num_bufs--;
    return oldest;
}


/*
 * Append a copy of *sys_call (ENT_SIZE bytes) to the newest trace buffer.
 *
 * If the entry does not fit, a new buffer is started.  If it still does not
 * fit (buffer cap reached, allocation failure, or an entry larger than
 * MAX_DATA), the entry is silently dropped.  Serialized by DataLock.
 */
void
add_entry (ENTRY *sys_call)
{
    unsigned long entry_len;
    ENTRY *slot;

    KeWaitForMutexObject (&DataLock, Executive, KernelMode, FALSE, NULL);

    entry_len = ENT_SIZE (sys_call);

    if (data_head->len + entry_len > MAX_DATA)
        start_new_buf ();

    /* start_new_buf may have done nothing, so re-test the (possibly new) head. */
    if (data_head->len + entry_len <= MAX_DATA) {
        slot = (ENTRY *)(data_head->data + data_head->len);
        memcpy (slot, sys_call, entry_len);
        data_head->len += entry_len;
    }

    KeReleaseMutex (&DataLock, FALSE);
}


/*
 * Install the hooks.  For each entry i of the selected map(s), the current
 * table pointer is saved in all_syscalls[map[i].fn].real, and table slot i
 * is then replaced with map[i].hook.  The ntoskrnl table is
 * service_table->service_table; the win32k table is
 * shadow_table[1].service_table.
 *
 * The save must precede the overwrite so unhook() can restore the original
 * pointer.  The hooked flag stops a second call from recording the hook
 * addresses as the "real" ones.  With DONT_HOOK defined the tables are left
 * untouched and only the flag is set.  Uses the maps chosen by pick_map();
 * a map of size 0 patches nothing.
 *
 * NOTE(review): the test-and-set of hooked is not atomic, and the table
 * writes are plain stores; concurrent hook()/unhook() callers are not
 * serialized here.
 */
void
hook (void)
{
    unsigned int i;

    if (!hooked) {
        DbgPrint (("hooking everything [in]"));
#ifndef DONT_HOOK
        for (i=0; i<map_in_use_size; i++) {
            /* remember the original routine, then redirect the slot */
            all_syscalls[map_in_use[i].fn].real = service_table->service_table[i];
            service_table->service_table[i] = map_in_use[i].hook;
        }
        for (i=0; i<win32k_map_in_use_size; i++) {
            all_syscalls[win32k_map_in_use[i].fn].real = shadow_table[1].service_table[i];
            shadow_table[1].service_table[i] = win32k_map_in_use[i].hook;
        }
#endif
        DbgPrint (("hooking everything [out]"));
        hooked = TRUE;
    }
}


/*
 * Remove the hooks.  Each patched slot of the ntoskrnl and win32k
 * (shadow_table[1]) service tables is rewritten with the pointer hook()
 * saved in all_syscalls[map[i].fn].real.  Does nothing unless currently
 * hooked.  With DONT_HOOK defined only the flag is cleared.
 *
 * Threads already executing inside a hook routine are not waited for here
 * (see strace_unload).  NOTE(review): like hook(), not serialized against
 * concurrent callers.
 */
void
unhook (void)
{
    unsigned int i;

    if (hooked) {
        DbgPrint (("unhooking everything [in]"));
#ifndef DONT_HOOK
        for (i=0; i< map_in_use_size; i++) {
            service_table->service_table[i] = all_syscalls[map_in_use[i].fn].real;
        }
        for (i=0; i< win32k_map_in_use_size; i++) {
            shadow_table[1].service_table[i] = all_syscalls[win32k_map_in_use[i].fn].real;
        }
#endif
        DbgPrint (("unhooking everything [out]"));
        hooked = FALSE;
    }
}


/*
 * Handle one device I/O control request.
 *
 * FileObject, Wait and DeviceObject are unused.  InputBuffer/OutputBuffer
 * are what strace_dispatch passes: the IRP system buffer, except that for
 * METHOD_NEITHER codes OutputBuffer is the caller's raw user-mode pointer
 * (Irp->UserBuffer).
 *
 *   STRACE_hook      - optional 4-byte PID stored in traced_pid, then hook().
 *   STRACE_unhook    - unhook().
 *   STRACE_getdata   - copy the oldest trace buffer to OutputBuffer.  If the
 *                      output buffer is too small the data stays queued and
 *                      STATUS_INVALID_PARAMETER is returned.
 *   STRACE_getver    - return STRACE_VERSION (4 bytes).
 *   STRACE_setfilter - pass a 4-byte value to set_filter().
 *   STRACE_setignore - 4-byte PID stored in ignored_pid.
 *
 * The status and byte count are reported in *IoStatus.  Returns FALSE on
 * getdata failures, TRUE otherwise (an unknown code still returns TRUE with
 * STATUS_INVALID_DEVICE_REQUEST).
 */
BOOLEAN
strace_ioctl (PFILE_OBJECT FileObject, BOOLEAN Wait,
              PVOID InputBuffer, ULONG InputBufferLength, 
              PVOID OutputBuffer, ULONG OutputBufferLength, 
              ULONG IoControlCode, PIO_STATUS_BLOCK IoStatus, 
              PDEVICE_OBJECT DeviceObject)
{
    struct data_buf *data;

    IoStatus->Status = STATUS_SUCCESS;
    IoStatus->Information = 0;

    switch (IoControlCode) {

    case STRACE_hook:
        DbgPrint (("Strace: hook\n"));
        if (InputBuffer && InputBufferLength == 4)
            traced_pid = *(ULONG *)InputBuffer;
        hook();
        break;

    case STRACE_unhook:
        DbgPrint (("Strace: unhook\n"));
        unhook ();
        break;

    case STRACE_getdata:
        DbgPrint (("Strace: get data\n"));
        KeWaitForMutexObject (&DataLock, Executive, KernelMode, FALSE, NULL);
        if (data_tail->len) {
            /* Whole buffers only: leave it queued if it does not fit. */
            if (data_tail->len > OutputBufferLength) {
                KeReleaseMutex (&DataLock, FALSE);
                IoStatus->Status = STATUS_INVALID_PARAMETER;
                return FALSE;
            }

            data = get_first_buf ();
            KeReleaseMutex (&DataLock, FALSE);

            if (!data) {
                IoStatus->Status = STATUS_INSUFFICIENT_RESOURCES;
                return FALSE;
            }

            /*
             * For METHOD_NEITHER codes OutputBuffer is an unvalidated
             * user-mode pointer.  Probe it (so a kernel address is
             * rejected) and guard the copy (so a bad user address fails
             * the request instead of crashing the system).  This assumes
             * the request comes from user mode, which is how the device is
             * opened.  The buffer has already been removed from the queue,
             * so it is freed whether or not the copy succeeds.
             */
            __try {
                if ((IoControlCode & 0x3) == METHOD_NEITHER) {
                    ProbeForWrite (OutputBuffer, data->len, 1);
                }
                memcpy (OutputBuffer, data->data, data->len);
                IoStatus->Information = data->len;
            } __except (EXCEPTION_EXECUTE_HANDLER) {
                IoStatus->Status = GetExceptionCode ();
                IoStatus->Information = 0;
            }
            ExFreePool (data);

            if (!NT_SUCCESS (IoStatus->Status)) {
                return FALSE;
            }
        } else {
            /*
             * no data
             */
            KeReleaseMutex (&DataLock, FALSE);
            DbgPrint (("Strace: get stats: no data\n"));
            IoStatus->Information = 0;
        }
        break;

#if 0
    case STRACE_getout:
        DbgPrint(("Strace: getout\n"));
        if (OutputBuffer && OutputBufferLength == 4) {
            *(ULONG *) OutputBuffer = calls_in_progress;
            IoStatus->Information = 4;
        }
        break;
#endif

    case STRACE_getver:
        DbgPrint(("Strace: getver\n"));
        if (OutputBuffer && OutputBufferLength == 4) {
            *(ULONG *) OutputBuffer = STRACE_VERSION;
            IoStatus->Information = 4;
        }
        break;

    case STRACE_setfilter:
        DbgPrint (("Strace: setfilter\n"));
        if (InputBuffer && InputBufferLength == 4) {
            set_filter (*(ULONG *)InputBuffer);
        } else {
            DbgPrint (("Strace: setfilter bogus params\n"));
        }
        break;

    case STRACE_setignore:
        DbgPrint (("Strace: setignore\n"));
        if (InputBuffer && InputBufferLength == 4) {
            ignored_pid = (*(ULONG *)InputBuffer);
        }
        break;

    default:
        DbgPrint (("strace: unknown ioctl\n"));
        IoStatus->Status = STATUS_INVALID_DEVICE_REQUEST;
        break;
    }
    return TRUE;
}

/* Privilege the caller must hold to open the device (checked on IRP_MJ_CREATE). */
LUID dbg_priv = { SE_DEBUG_PRIVILEGE, 0 };

/*
 * Common dispatch routine for IRP_MJ_CREATE, IRP_MJ_CLOSE and
 * IRP_MJ_DEVICE_CONTROL.
 *
 * CREATE requires SeDebugPrivilege, resets call_count and selects the
 * syscall maps.  CLOSE removes the hooks.  DEVICE_CONTROL is forwarded to
 * strace_ioctl.  For METHOD_NEITHER codes the output pointer passed on is
 * Irp->UserBuffer.  NOTE(review): the input pointer is still the system
 * buffer for METHOD_NEITHER, so such codes receive no input.
 *
 * The IRP is always completed here.  The returned status is the one
 * recorded in Irp->IoStatus.Status, as the I/O manager expects.  It is
 * read before IoCompleteRequest, after which the IRP must not be touched.
 */
NTSTATUS
strace_dispatch (PDEVICE_OBJECT dev, PIRP Irp)
{
    PIO_STACK_LOCATION irp_stack;
    void *in_buf;
    void *out_buf;
    unsigned long in_buf_sz;
    unsigned long out_buf_sz;
    unsigned long ioctl;
    NTSTATUS status;

    Irp->IoStatus.Status      = STATUS_SUCCESS;
    Irp->IoStatus.Information = 0;
    irp_stack = IoGetCurrentIrpStackLocation (Irp);

    in_buf     = Irp->AssociatedIrp.SystemBuffer;
    in_buf_sz  = irp_stack->Parameters.DeviceIoControl.InputBufferLength;
    out_buf    = Irp->AssociatedIrp.SystemBuffer;
    out_buf_sz = irp_stack->Parameters.DeviceIoControl.OutputBufferLength;
    ioctl      = irp_stack->Parameters.DeviceIoControl.IoControlCode;

    switch (irp_stack->MajorFunction) {
    case IRP_MJ_CREATE:
        DbgPrint(("Strace: IRP_MJ_CREATE\n"));
        if (!SeSinglePrivilegeCheck (dbg_priv, UserMode)) {
            Irp->IoStatus.Status = STATUS_ACCESS_DENIED;
            break;
        }
        call_count = 0;
        pick_map ();
        break;

    case IRP_MJ_CLOSE:
        DbgPrint (("Strace: IRP_MJ_CLOSE\n"));
        unhook ();
        break;

    case IRP_MJ_DEVICE_CONTROL:
        DbgPrint (("Strace: IRP_MJ_DEVICE_CONTROL\n"));

        if ((ioctl & 0x3) == METHOD_NEITHER) {
            out_buf = Irp->UserBuffer;
        }
        strace_ioctl (irp_stack->FileObject, TRUE, in_buf, in_buf_sz, 
                      out_buf, out_buf_sz, ioctl, &Irp->IoStatus, dev);
        break;

    default:
        /* only the three major functions above are registered */
        break;
    }

    /* Capture before completion: the IRP may be freed by IoCompleteRequest. */
    status = Irp->IoStatus.Status;
    IoCompleteRequest (Irp, IO_NO_INCREMENT);
    return status;
}


/*
 * Driver unload routine: remove the hooks, wait until calls_in_progress
 * reaches zero, then free the trace buffers and delete the symbolic link
 * and device object.
 */
VOID
strace_unload (PDRIVER_OBJECT drv)
{
    UNICODE_STRING link_name;
    LARGE_INTEGER delay;

    DbgPrint (("strace: unload [in]\n"));

    if (hooked)
        unhook ();

    /*
     * Refuse to finish unloading while a call is still inside a hook routine;
     * unmapping the driver under it would BSOD.  The counter is re-read
     * through a volatile lvalue so the compiler cannot hoist the load out of
     * the loop.  The thread sleeps between checks instead of spinning (unload
     * runs at PASSIVE_LEVEL).
     */
    delay.QuadPart = -10 * 1000 * 10;   /* relative, 100 ns units: 10 ms */
    while (*(volatile unsigned long *)&calls_in_progress) {
        KeDelayExecutionThread (KernelMode, FALSE, &delay);
    }

    free_all_bufs ();
    RtlInitUnicodeString (&link_name, L"\\DosDevices\\Strace");
    IoDeleteSymbolicLink (&link_name);
    IoDeleteDevice (drv->DeviceObject);

    DbgPrint (("strace: unload [out]\n"));
}


/*
 * Driver initialization.  In order, it:
 *   - creates \Device\Strace (exclusive) and the \DosDevices\Strace link;
 *   - registers the dispatch and unload routines;
 *   - records the ntoskrnl service table;
 *   - allocates the first trace buffer;
 *   - initializes DataLock and the syscall map/argument data;
 *   - locates the win32k shadow table.
 *
 * DriverUnload is not called when DriverEntry fails, so every failure path
 * undoes whatever was already created.  Otherwise a device and link would
 * be left pointing at the dispatch code of an unloaded driver.
 */
NTSTATUS
DriverEntry (PDRIVER_OBJECT drv,
             PUNICODE_STRING unused)
{
    NTSTATUS rc;
    UNICODE_STRING name;
    UNICODE_STRING link_name;    

    DbgPrint (("strace: DriverEntry [in]\n"));

    RtlInitUnicodeString (&name, L"\\Device\\Strace");
    RtlInitUnicodeString (&link_name, L"\\DosDevices\\Strace");

    rc = IoCreateDevice (drv, 0, &name, FILE_DEVICE_STRACE,
                         0, TRUE, &device);
    if (!NT_SUCCESS (rc)) {
        DbgPrint (("strace: Failed creating device: 0x%08x\n", rc));
        if (device)
            IoDeleteDevice (device);
        device = NULL;
        return rc;
    }

    rc = IoCreateSymbolicLink (&link_name, &name);
    if (!NT_SUCCESS (rc)) {
        DbgPrint (("strace: IoCreateSymbolicLink failed: 0x%08x\n",
                   rc));
        IoDeleteDevice (device);
        device = NULL;
        return rc;
    }

    drv->MajorFunction[IRP_MJ_CREATE] = strace_dispatch;
    drv->MajorFunction[IRP_MJ_CLOSE] = strace_dispatch;
    drv->MajorFunction[IRP_MJ_DEVICE_CONTROL]  = strace_dispatch;
    drv->DriverUnload = strace_unload;

    service_table = KeServiceDescriptorTable;

    data_head = ExAllocatePool (NonPagedPool, sizeof (*data_head));
    if (!data_head) {
        DbgPrint (("strace: ExAllocatePool failed\n"));
        IoDeleteSymbolicLink (&link_name);
        IoDeleteDevice (device);
        device = NULL;
        return STATUS_INSUFFICIENT_RESOURCES;
    }
    data_head->len = 0;
    data_head->next = NULL;
    data_head->prev = NULL;
    data_tail = data_head;
    num_bufs = 1;

    KeInitializeMutex (&DataLock, 0);

    calc_map_sizes ();
    init_arg_info ();

    /* May be NULL; pick_map() then leaves win32k calls unhooked. */
    shadow_table = find_shadow_table ();
    DbgPrint (("strace: shadow_table: 0x%x\n", shadow_table));

    DbgPrint (("strace: DriverEntry [out]\n"));
    return rc;
}


/*
 * Find the address of the KeServiceDescriptorTableShadow.
 * Uses technique from Undocumented Windows NT.
 *
 * Scans the first 100 byte offsets of KeAddSystemServiceTable's code and
 * reads each one as a possible embedded pointer.  A candidate is accepted
 * when all of the following hold:
 *   - MmIsAddressValid accepts it;
 *   - it is not KeServiceDescriptorTable itself;
 *   - its first srv_table entry is byte-identical to
 *     KeServiceDescriptorTable's (the shadow table is expected to start
 *     with a copy of the ntoskrnl entry).
 *
 * Returns the shadow table, or NULL if no candidate qualified.
 */
static struct srv_table *
find_shadow_table ()
{
    unsigned char *check = (unsigned char *)KeAddSystemServiceTable;
    struct srv_table *candidate;
    struct srv_table *rc = NULL;
    int i;

    /*
     * check advances in the loop header, so a candidate that raises an
     * exception still moves the scan forward to the next byte.  A break on
     * success skips the increment.
     */
    for (i = 0; i < 100; i++, check++) {
        __try {
            candidate = *(struct srv_table **)check;
            if (MmIsAddressValid (candidate)
                && (candidate != KeServiceDescriptorTable)
                && (memcmp (candidate, KeServiceDescriptorTable,
                            sizeof (*candidate)) == 0)) {
                rc = candidate;
            }
        } __except (EXCEPTION_EXECUTE_HANDLER) {
            rc = NULL;
        }
        if (rc)
            break;
    }
    return rc;
}
