/*
 * strace.c
 * 
 * Copyright (c) 2000, BindView Corporation.
 *
 * See LICENSE file.
 *
 */

#include <tchar.h>
#include <windows.h>
#include <winnt.h>

#include <winioctl.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "../driver/syscalls.h"
#include "types.h"
#include "../driver/version.h"
#include "../driver/ioctlcmd.h"

#include "getopt.h"

/* Driver install/open helper implemented in another source file. */
extern BOOL LoadDeviceDriver (const TCHAR * Name,
                              const TCHAR * Path,
                              HANDLE * lphDevice);
/* Only referenced by the commented-out unload call at the end of main. */
BOOL UnloadDeviceDriver( const TCHAR * Name );
int enable_debug_priv (void);

/* Maps a syscall result value to printable text; defined outside this file.
   NOTE(review): presumably returns a static string - confirm. */
const char *lookup_status (unsigned int status);

/* Stream trace output is written to: main sets it to stdout, and the
   -o option replaces it with a file opened for writing. */
FILE *output;

/*
 * syscall_names[] is built with an X-macro: MagicFoo(call, args) is
 * defined to expand to the stringized call name followed by a comma, and
 * syscallx.h is expected to consist of MagicFoo(...) invocations (one per
 * system call - NOTE(review): confirm against the driver header).  The
 * array index is what main prints for ent->call_num and what
 * apply_filter sends as a per-call filter value.
 */
#define MagicFoo(call,args) #call,
char *syscall_names[] = {
#include "../driver/syscallx.h"
};
#undef MagicFoo
/* Number of entries in syscall_names[]. */
#define NUM_SYSCALLS (sizeof (syscall_names) / sizeof (syscall_names[0]))

/*
 * Print command-line usage to stderr and terminate with exit status 1.
 * Lists every option parse_cmdline accepts (-p, -e, -o).
 */
void
Usage (void)
{
    fprintf (stderr,
             "Usage:\n"
             "    strace [-o <file>] [-e [!]<name>[,<name>...]] -p <pid>\n"
             "    strace [-o <file>] [-e [!]<name>[,<name>...]] <cmdline>\n"
             "\n"
             "    -o <file>  write the trace to <file> instead of stdout\n"
             "    -e <list>  trace only the listed groups/system calls;\n"
             "               with a leading '!', exclude them instead\n");
    exit (1);
}

/*
 * Exit with an explanatory message if `pathname` is on a network drive.
 * The driver file must be local (see the message below).
 *
 * Drive-letter paths are classified with GetDriveType on their "X:\" root.
 * UNC paths ("\\server\share\...") have no drive letter, so GetDriveType
 * cannot classify their first three characters; they are treated as remote
 * explicitly.  "\\?\" and "\\.\" prefixes denote Win32 namespaces rather
 * than shares and fall through to the GetDriveType check as before.
 */
static void
verify_driver_local (char *pathname)
{
    char drive[4];
    int is_remote;

    if (pathname[0] == '\\' && pathname[1] == '\\'
        && pathname[2] != '?' && pathname[2] != '.') {
        is_remote = 1;
    } else {
        /* strncpy stops (and zero-pads) at a short pathname's NUL, and the
           explicit terminator covers the full-length case. */
        strncpy (drive, pathname, 3);
        drive[3] = 0;
        is_remote = (GetDriveType (drive) == DRIVE_REMOTE);
    }

    if (is_remote) {
        fprintf (stderr, "Strace must be on a local drive.\nPlease copy "
                 "the files locally and try again.\n");
        exit (1);
    }
}



/*
 * Load "strace.sys" from the directory containing this executable.
 *
 * On success returns the result of LoadDeviceDriver, which receives pdev
 * for the device handle.  Returns FALSE, with the Win32 last-error code
 * set, if the executable's own path cannot be determined or the driver
 * path would not fit in MAX_PATH characters.  verify_driver_local exits
 * the process if the path is on a network drive.
 */
int
load_driver (HANDLE *pdev)
{
    static const char driver_file[] = "strace.sys";
    char file_name[MAX_PATH+1];
    char *base;
    DWORD len;

    len = GetModuleFileName (GetModuleHandle (NULL),
                             file_name, sizeof (file_name) - 1);
    if (len == 0) {
        /* last error already set by GetModuleFileName */
        return FALSE;
    }
    if (len >= sizeof (file_name) - 1) {
        /* The path was truncated (and older Windows versions do not
           NUL-terminate it then), so it cannot be used. */
        SetLastError (ERROR_INSUFFICIENT_BUFFER);
        return FALSE;
    }
    file_name[len] = 0;

    /* Replace the executable's file name with the driver's.  A path with
       no directory separator is replaced entirely. */
    base = strrchr (file_name, '\\');
    base = base ? base + 1 : file_name;
    if (sizeof (driver_file) > sizeof (file_name) - (size_t) (base - file_name)) {
        SetLastError (ERROR_INSUFFICIENT_BUFFER);
        return FALSE;
    }
    memcpy (base, driver_file, sizeof (driver_file));

    verify_driver_local (file_name);
    return LoadDeviceDriver ("Strace", file_name, pdev);
}

/*
 * Options collected by parse_cmdline:
 *   pid    - process to trace (-p); (unsigned long)-1 means "not given".
 *   cmd    - command line to start and trace, or NULL.  It points into the
 *            buffer returned by GetCommandLine and is passed to
 *            CreateProcess, so it has that API's plain char type.
 *   filter - the -e argument, or NULL.  It is passed to apply_filter,
 *            which takes a plain char pointer.
 */
static unsigned long pid = -1;
static char *cmd;
static char *filter;

/*
 * Return a pointer into the raw command line `line` positioned just past
 * its first `count` arguments and the whitespace that follows them.
 *
 * Argument boundaries follow the Microsoft C runtime splitting rules
 * closely enough to find where each argument ends:
 *   - whitespace outside double quotes separates arguments;
 *   - a '"' toggles quoting;
 *   - a '"' preceded by an odd number of backslashes is a literal quote.
 * The first argument (the program name) never treats backslashes as
 * escapes, matching the runtime.  The returned text is the untouched
 * remainder of `line`, so quoting is preserved exactly as typed.
 */
static char *
skip_cmdline_args (char *line, int count)
{
    char *p = line;
    int n;

    for (n = 0; n < count; n++) {
        int in_quotes = 0;

        while (*p == ' ' || *p == '\t')
            p++;
        if (*p == '\0')
            break;

        while (*p != '\0' && (in_quotes || (*p != ' ' && *p != '\t'))) {
            if (*p == '\\' && n > 0) {
                char *run = p;

                while (*p == '\\')
                    p++;
                /* an odd-length backslash run makes the next quote literal */
                if (*p == '"' && ((p - run) % 2) == 1)
                    p++;
                continue;
            }
            if (*p == '"')
                in_quotes = !in_quotes;
            p++;
        }
    }
    while (*p == ' ' || *p == '\t')
        p++;
    return p;
}

/*
 * Parse strace's own options into the file-level option variables:
 *   -p <pid>   trace an existing process (decimal pid)
 *   -e <list>  filter specification, later handed to apply_filter
 *   -o <file>  write the trace to <file> instead of stdout
 * Any remaining arguments form the command to start and trace.
 *
 * Exactly one of -p and a command must be given.  Otherwise, and on any
 * malformed option, Usage() exits.  Returns 0.
 */
int
parse_cmdline (int argc, char *argv[])
{
    int c;
    char *end;

    while ((c = getopt (argc, argv, "p:e:o:")) != EOF) {
        switch (c) {
        case 'p':
            pid = strtoul (optarg, &end, 10);
            if (end == optarg || *end != '\0') {
                /* reject non-numeric pids instead of tracing pid 0 */
                Usage ();
            }
            break;

        case 'e':
            filter = optarg;
            break;

        case 'o':
            /* a repeated -o replaces (and closes) the earlier file */
            if (output && output != stdout)
                fclose (output);
            output = fopen (optarg, "w");
            if (!output) {
                fprintf (stderr, "Failed to open %s for writing.\nExiting.\n",
                         optarg);
                exit (1);
            }
            break;

        default:
            Usage ();
            break;
        }
    }

    if (optind < argc) {
        /*
         * Take the command from the raw command line, not by re-joining
         * argv, so the child sees its arguments exactly as typed.  Skipping
         * whole arguments avoids matching argv[optind] inside an earlier
         * option value (e.g. "-o notepad.log notepad").
         * NOTE(review): assumes the bundled getopt does not permute argv,
         * so the first optind raw arguments are exactly strace's own.
         */
        char *rest = skip_cmdline_args (GetCommandLine (), optind);

        if (*rest != '\0')
            cmd = rest;
    }

    if ((cmd && (pid != -1))
        || (!cmd && (pid == -1))) {
        /* can't specify both pid and prog to run */
        Usage ();
    }

    return 0;
}

int
set_filter (HANDLE dev, int filter)
{
    int dummy;
    return DeviceIoControl (dev, STRACE_setfilter,
                            &filter, 4, NULL, 0, &dummy, NULL);
}

int
set_ignore (HANDLE dev, int pid)
{
    int dummy;
    return DeviceIoControl (dev, STRACE_setignore,
                            &pid, 4, NULL, 0, &dummy, NULL);
}

/*
 * Named filter groups accepted by -e.  Each entry maps a group name (the
 * stringized macro argument) to FILTER_GROUP | FILTER_<name>.
 * apply_filter ORs FILTER_ON or FILTER_OFF into the value before sending
 * it with set_filter.  Names are string literals and are never modified,
 * so they are const.
 */
struct {
    const char *name;
    int  filt;
} filters[] = {
#define GROUP_ENTRY(a) { #a, FILTER_GROUP | FILTER_##a },
GROUP_ENTRY (ntos)
GROUP_ENTRY (win32k)
GROUP_ENTRY (system)
GROUP_ENTRY (object)
GROUP_ENTRY (memory)
GROUP_ENTRY (section)
GROUP_ENTRY (thread)
GROUP_ENTRY (process)
GROUP_ENTRY (job)
GROUP_ENTRY (token)
GROUP_ENTRY (synch)
GROUP_ENTRY (time)
GROUP_ENTRY (profile)
GROUP_ENTRY (port)
GROUP_ENTRY (file)
GROUP_ENTRY (key)
GROUP_ENTRY (security)
GROUP_ENTRY (misc)
GROUP_ENTRY (ntuser)
GROUP_ENTRY (ntgdi)
#undef GROUP_ENTRY
};
/* Number of entries in filters[]. */
#define NUM_FILTERS (sizeof (filters) / sizeof (filters[0]))


/*
 * Apply a -e filter specification to the driver.
 *
 * `filter` is a comma-separated list of group names (filters[]) and/or
 * system call names (syscall_names[]).
 *   - Without a leading '!', all tracing is first switched off and only the
 *     listed entries are switched on.
 *   - With a leading '!', the listed entries are switched off and every
 *     other setting is left as it is.
 *
 * The string is tokenized in place with strtok, so it is modified.  Names
 * matching neither a group nor a system call are reported on stderr and
 * otherwise skipped, so a typo does not silently disable tracing.
 */
void
apply_filter (HANDLE dev, char *filter)
{
    size_t i;
    int setting;
    int matched;
    char *filt;

    if (*filter == '!') {
        filter++;
        setting = FILTER_OFF;
    } else {
        setting = FILTER_ON;
        /* turn everything off for starters */
        set_filter (dev, FILTER_OFF | FILTER_GROUP | FILTER_all);
    }

    for (filt = strtok (filter, ",");
         filt;
         filt = strtok (NULL, ",")) {

        matched = 0;

        /* group names take precedence over call names */
        for (i = 0; i < NUM_FILTERS; i++) {
            if (strcmp (filt, filters[i].name) == 0) {
                set_filter (dev, setting | filters[i].filt);
                matched = 1;
                break;
            }
        }

        if (!matched) {
            /*
             * Check for a specific call.  Don't stop at the first match,
             * because some calls have multiple entries due to differing
             * argument counts in different versions.
             */
            for (i = 0; i < NUM_SYSCALLS; i++) {
                if (strcmp (filt, syscall_names[i]) == 0) {
                    set_filter (dev, setting | (int) i);
                    matched = 1;
                }
            }
        }

        if (!matched)
            fprintf (stderr, "Unknown filter name: %s\n", filt);
    }
}

/*
 * On Windows XP (5.1) or any later version, require the registry value
 * HKLM\...\Memory Management\EnforceWriteProtection to exist as a
 * REG_DWORD equal to 0.  Otherwise print instructions and exit(1).
 * Older versions are not checked.
 */
static void
check_compatibility (void)
{
    OSVERSIONINFO ver = { sizeof (ver) };
    HKEY key = 0;
    DWORD err;
    DWORD type = 0;
    DWORD value = 0;
    DWORD size = sizeof (value);

    if (!GetVersionEx (&ver)) {
        fprintf (stderr, "Unable to determine windows version.  Exiting\n");
        exit (1);
    }

    /* Compare (major, minor) as a pair: a newer major version such as
       6.0 or 10.0 counts as "XP or better" even though its minor is 0. */
    if ((ver.dwMajorVersion < 5)
        || ((ver.dwMajorVersion == 5) && (ver.dwMinorVersion < 1))) {
        return;
    }

    /* This is XP or better.  Check the necessary registry value is set */
    err = RegOpenKeyEx (HKEY_LOCAL_MACHINE,
                        "SYSTEM\\CurrentControlSet\\Control\\Session Manager\\Memory Management",
                        0, KEY_READ, &key);
    if (err != ERROR_SUCCESS) {
        fprintf (stderr, "Unable to open HKLM\\SYSTEM\\CurrentControlSet\\Control\\Session Manager\\Memory Management to check write protection.  Exiting.\n");
        exit (1);
    }
    err = RegQueryValueEx (key, "EnforceWriteProtection", NULL,
                           &type, (LPBYTE)&value, &size);
    RegCloseKey (key);

    if ((err != ERROR_SUCCESS)
        || (type != REG_DWORD)
        || (value != 0)) {
        fprintf (stderr, "\
For Windows XP, the registry value:\n\
HKLM\\SYSTEM\\CurrentControlSet\\Control\\Session Manager\\Memory Management\\EnforceWriteProtection\n\
must be set to 0 (type REG_DWORD) for strace to work.\n\
Please read the included documentation to determine whether you want to\n\
disable write protection, and, if so, set the value appropriately and\n\
reboot.\n");
        exit (1);
    }
}

int
main (int argc, char *argv[])
{
    HANDLE dev = 0;
    HANDLE hThread = 0;
    HANDLE hProcess = 0;
    DWORD bytes_out;
    int i;
    char Stats[MAX_DATA];
    DWORD StatsLen;
    int done = 0;
    unsigned int last_dir=1, last_seq=0;

    /* default to writing to stdout */
    output = stdout;
    parse_cmdline (argc, argv);

    init_arg_info ();

    enable_debug_priv ();

    check_compatibility ();

    if (load_driver (&dev)) {
        /* verify the versions match */
        DWORD ver;
        if (!DeviceIoControl (dev, STRACE_getver, NULL, 0,
                              &ver, sizeof (ver), &bytes_out, NULL)
            || (ver != STRACE_VERSION)) {
            fprintf (stderr, "app/driver version mismatch.\n");
            fprintf (stderr, "app version: %x\n", STRACE_VERSION);
            fprintf (stderr, "drv version: %x\n", ver);
            fprintf (stderr, "Please be sure you have the same versions of the app and driver.\n"
                    "If you upgraded recently, verify that the new driver replaced the old one\n"
                    "and that you rebooted to load the new driver.\n");
            exit (1);
        }

        /* turn everything on by default */
        set_filter (dev, FILTER_ON | FILTER_GROUP | FILTER_all);
        set_ignore (dev, GetCurrentProcessId ());
        if (filter) {
            apply_filter (dev, filter);
        }

        if (cmd) {
            STARTUPINFO si;
            PROCESS_INFORMATION pi;

            memset (&si,0, sizeof (si));
            si.cb = sizeof (si);

            if (CreateProcess (NULL, cmd, NULL, NULL, FALSE,
                               CREATE_SUSPENDED, NULL, NULL, &si, &pi)) {
                hThread = pi.hThread;
                hProcess = pi.hProcess;
                pid = pi.dwProcessId;
            } else {
                fprintf (stderr, "CreateProcess failed: %x\n", GetLastError ());
                exit (1);
            }
        }
        
        if (!DeviceIoControl (dev, STRACE_hook,
                               &pid, 4, NULL, 0, &bytes_out, NULL))
        {
            fprintf (stderr, "ioctl - start failed 0x%x\n", GetLastError ());
            goto exit;
        }

        if (hThread) {
            ResumeThread (hThread);
            CloseHandle (hThread);
        }

        for ( ; ; ) {
            ENTRY *ent;
            DWORD bytes_used;
            do {
                StatsLen = 0;
                if (!DeviceIoControl (dev, STRACE_getdata, NULL, 0,
                                      &Stats, sizeof Stats,
                                      &StatsLen, NULL))
                {
                    fprintf (stderr, "getstats failed: 0x%x\n", GetLastError ());
                    goto exit;
                }
                bytes_used = 0;
                ent = (ENTRY *) Stats;
                while (bytes_used < StatsLen) {
                    int arg_bytes_used;

                    if (ent->direction == 0) {
                        if (last_dir == 0) {
                            fprintf (output, "\n");
                        }
                        fprintf (output, "%d%s %d %d %s (",
                                ent->seq,
                                ent->prev_mode ? "" : "*",
                                ent->pid,
                                ent->tid,
                                syscall_names[ent->call_num]);

                        arg_bytes_used = 0;
                        for (i=0; i<all_arg_info[ent->call_num]->num; i++) {
                            if ((all_arg_info[ent->call_num]->args[i].dir == DIR_IN)
                                || (all_arg_info[ent->call_num]->args[i].dir == DIR_IN_OUT)) {

                                all_types[all_arg_info[ent->call_num]->args[i].type].copy_arg (ent, (DWORD) &arg_bytes_used);
                            }
                        }

                        fprintf (output, "... ");
                        last_dir = 0;
                        last_seq = ent->seq;
                    } else {
                        if ((last_seq != ent->seq)
                            && (last_dir != 1)) {
                            fprintf (output, "\n");
                        }
                        if (last_seq != ent->seq) {
                            fprintf (output, "%d%s %d %d %s ... ",
                                    ent->seq,
                                    ent->prev_mode ? "" : "*",
                                    ent->pid,
                                    ent->tid,
                                    syscall_names[ent->call_num]);
                        }
                        /*
                         * only decode out arguments on success.  This
                         * isn't totally correct, but doing more is tough
                         * to get right.
                         */
                        if ((signed long)ent->result >= 0) {
                            arg_bytes_used = 0;
                            for (i=0; i<all_arg_info[ent->call_num]->num; i++) {
                                if ((all_arg_info[ent->call_num]->args[i].dir == DIR_OUT)
                                    || (all_arg_info[ent->call_num]->args[i].dir == DIR_IN_OUT)) {

                                    all_types[all_arg_info[ent->call_num]->args[i].type].copy_arg (ent, (DWORD) &arg_bytes_used);
                                }
                            }
                        }
                        fprintf (output, ") == %s\n", lookup_status (ent->result));
                        last_dir = 1;
                        last_seq = ent->seq;
                    }


                    bytes_used += ENT_SIZE (ent);
                    ent = (ENTRY *)(((char *)ent)+ENT_SIZE (ent));
                }
            } while (StatsLen != 0);

            fflush (output);
            if (done)
                break;
            if (hProcess) {
                if (WaitForSingleObject (hProcess, 1000) == WAIT_OBJECT_0) {
                    done = 1;
                }
            } else {
                Sleep (1000);
            }
        }
        if (!DeviceIoControl (dev, STRACE_unhook,
                              NULL, 0, NULL, 0, &bytes_out, NULL))
        {
            fprintf (stderr, "ioctl - start failed 0x%x\n", GetLastError ());
            goto exit;
        }
    } else {
        fprintf (stderr, "Failed to load device driver: %d\n",
                 GetLastError ());
    }



 exit:
    if (dev)
        CloseHandle (dev);
    if (hProcess)
        CloseHandle (hProcess);
    if (output && (output != stdout)) {
        fflush (output);
        fclose (output);
        output = 0;
    }

   /* FIXME: unloading causes page faults currently, so don't unload */
    /* UnloadDeviceDriver (_T ("Strace")); */
}


/*
 * Try to enable the debug privilege
 */
int
enable_debug_priv (void)
{
    HANDLE hToken = 0;
    DWORD dwErr = 0;
    TOKEN_PRIVILEGES newPrivs;

    if (!OpenProcessToken (GetCurrentProcess (),
                           TOKEN_ADJUST_PRIVILEGES,
                           &hToken))
    {
        dwErr = GetLastError ();
        fprintf (stderr, "Unable to open process token: %d\n", dwErr);
        goto exit;
    }

    if (!LookupPrivilegeValue (NULL, SE_DEBUG_NAME,
                               &newPrivs.Privileges[0].Luid))
    {
        dwErr = GetLastError ();
        fprintf (stderr, "Unable to lookup privilege: %d\n", dwErr);
        goto exit;
    }

    newPrivs.Privileges[0].Attributes = SE_PRIVILEGE_ENABLED;
    newPrivs.PrivilegeCount = 1;
    
    if (!AdjustTokenPrivileges (hToken, FALSE, &newPrivs, 0, NULL, NULL))
    {
        dwErr = GetLastError ();
        fprintf (stderr, "Unable to adjust token privileges: %d\n", dwErr);
        goto exit;
    }

 exit:
    if (hToken)
        CloseHandle (hToken);

    return dwErr;
}
