//////////////////////////////////////////////////////////////
//                                                          //
//  DLLInjection.cpp                                        //
//  Version 1.0                                             //
//  Greg Jenkins, November 2007                             //
//  Ring3 Circus (www.ring3circus.com)                      //
//  Creative Commons Attribution 3.0 Unported License       //
//                                                          //
//  User-mode DLL-based code injection class implementation //
//  for 32-bit Windows (XP and above)                       //
//                                                          //
//////////////////////////////////////////////////////////////

#include "DLLInjection.h"

// Bookkeeping for one installed hook: where the 5-byte JMP patch was written in
// the target process and the bytes that patch overwrote (written back by RemoveHook).
struct DLLHook {
	void* remote_address;              // patched address in the target process; NULL marks the index-0 placeholder or a removed hook
	unsigned char original_opcodes[5]; // original bytes overwritten by the 5-byte JMP rel32 patch
};

void* DLLInjection::GetRemoteProcAddress(const char* proc_name) {
	// Assume injected DLL
	void* local_address = GetProcAddress(local_dll_handle, proc_name);
	if (local_address == NULL) return NULL;
	void* remote_address = reinterpret_cast<void*> (reinterpret_cast<DWORD> (local_address) - reinterpret_cast<DWORD> (local_dll_handle) + reinterpret_cast<DWORD> (remote_base));
	return remote_address;
}

void* DLLInjection::GetRemoteProcAddress(const char* proc_name, const char* module_path) {
	// Arbitrary existing remote DLL
	HMODULE local_module = LoadLibraryA(module_path);
	if (local_module == 0) return 0;
	loaded_dlls.push_back(local_module);

	HMODULE remote_module_base = GetRemoteModuleHandle(local_module);

	void* local_address = GetProcAddress(local_module, proc_name);
	if (local_address == NULL) return NULL;
	void* remote_address = reinterpret_cast<void*> (reinterpret_cast<DWORD> (local_address) - reinterpret_cast<DWORD> (local_module) + reinterpret_cast<DWORD> (remote_module_base));
	return remote_address;
}

bool DLLInjection::CallThreadProc(const char* thread_proc_name, void* parameter, DWORD timeout_ms, DWORD &exit_code) {
	LPTHREAD_START_ROUTINE start_address = reinterpret_cast<LPTHREAD_START_ROUTINE> (GetRemoteProcAddress(thread_proc_name));

	DWORD thread_id = 0;
	HANDLE thread = CreateRemoteThread(process_handle, NULL, 0, start_address, parameter, 0, &thread_id);
	if (thread == NULL || thread_id == 0) return false;
	
	bool success = true;
	if (WaitForSingleObject(thread, timeout_ms) != 0) success = false;
	if (GetExitCodeThread(thread, &exit_code) == 0) success = false;
	CloseHandle(thread);

	return success;
}

// Finds the top-level window matching class_name / window_name (as interpreted by
// FindWindowA) and returns the ID of the process that owns it, or 0 when no such
// window exists.
DWORD DLLInjection::GetProcessIDFromWindow(const char* class_name, const char* window_name) {
	DWORD owner_pid = 0;
	if (HWND target_window = FindWindowA(class_name, window_name)) {
		GetWindowThreadProcessId(target_window, &owner_pid);
	}
	return owner_pid;
}

// Prepares an injector for the DLL at dll_path. The DLL is also loaded into the
// current process so its export addresses can be resolved locally and later
// rebased onto the remote copy. Throws E_FAIL if the local load fails.
DLLInjection::DLLInjection(const char* dll_path)
	: local_dll_handle(NULL)
	, dll_path(dll_path)
	, remote_base(NULL)
	, process_handle(NULL)
{
	// Hook handles are indices into `hooks`, and 0 is reserved to mean "no hook",
	// so slot 0 holds an inert placeholder with a NULL address
	DLLHook placeholder;
	placeholder.remote_address = NULL;
	hooks.push_back(placeholder);

	HMODULE module = LoadLibraryA(dll_path);
	if (!module) throw E_FAIL;
	local_dll_handle = module;
}

// Tears down everything this object set up: hooks and the remote DLL (via
// RemoveDLL), the target process handle, the reference DLLs loaded by the
// address-lookup helpers, and the local copy of the payload DLL.
DLLInjection::~DLLInjection() {
	RemoveDLL();

	// CloseHandle on an invalid handle is an error (and raises an exception under a debugger)
	if (process_handle != NULL) CloseHandle(process_handle);
	process_handle = NULL;

	// Release the local reference copies collected in loaded_dlls
	UnloadAllTemporaryDLLs();

	if (local_dll_handle != NULL) FreeLibrary(local_dll_handle);
	local_dll_handle = NULL;
}

// Finds the base address (HMODULE) in the target process of the module that is
// loaded locally as local_handle, by comparing full file paths case-insensitively.
// Returns NULL if no process is open, the local path cannot be determined, the
// target's module list cannot be read, or the module is not loaded in the target.
HMODULE DLLInjection::GetRemoteModuleHandle(HMODULE local_handle) {
	// Excuse the C-style code, here, but std::string's case-insensitive comparison isn't worth the effort
	if (process_handle == NULL) return NULL;

	// Buffer sizes for the ...A functions are in chars (not TCHARs, which are 2 bytes in UNICODE builds)
	char local_module_name[MAX_PATH];
	DWORD local_length = GetModuleFileNameA(local_handle, local_module_name, MAX_PATH);
	// 0 means failure; a full buffer means the path was truncated (and possibly not NUL-terminated)
	if (local_length == 0 || local_length >= MAX_PATH) return NULL;

	// EnumProcessModules reports the bytes it NEEDS, which can exceed the buffer
	// (and the list can grow between calls), so enlarge until everything fits
	std::vector<HMODULE> modules(256);
	for (;;) {
		const DWORD bytes_available = static_cast<DWORD> (modules.size() * sizeof(HMODULE));
		DWORD bytes_needed = 0;
		if (!EnumProcessModules(process_handle, &modules[0], bytes_available, &bytes_needed)) return NULL;
		modules.resize(bytes_needed / sizeof(HMODULE));
		if (bytes_needed <= bytes_available) break;
	}

	for (size_t i = 0; i < modules.size(); ++i) {
		char module_name[MAX_PATH];
		if (GetModuleFileNameExA(process_handle, modules[i], module_name, MAX_PATH) == 0) continue;
		if (_stricmp(local_module_name, module_name) == 0) {
			// Match found
			return modules[i];
		}
	}
	return NULL;
}

// Injects the DLL at dll_path into the process identified by process_id. It opens
// the process, writes the path into a remote buffer, and runs LoadLibraryA(path)
// on a remote thread, waiting up to 5 seconds.
// On success, the remote module base is stored in remote_base and returned, and
// process_handle stays open. On failure, NULL is returned and process_handle is
// closed and reset to NULL.
// NOTE: the module base is recovered from the thread's 32-bit exit code, so this
// only works for 32-bit targets (as the file header states).
HMODULE DLLInjection::InjectDLL(DWORD process_id) {
	// Open Process
	process_handle = OpenProcess(PROCESS_ALL_ACCESS, FALSE, process_id);
	if (process_handle == NULL) return NULL;

	// Allocate space in the target for the NUL-terminated DLL path
	const SIZE_T path_bytes = (dll_path.size() + 1) * sizeof(char);
	void* remote_buffer = VirtualAllocEx(process_handle, NULL, path_bytes, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE);

	bool thread_finished = false;
	DWORD exit_code = 0;
	if (remote_buffer != NULL) {
		// Set when a remote thread may still be reading the path; freeing the
		// buffer then would pull it out from under LoadLibraryA in the target
		bool buffer_in_use = false;

		SIZE_T bytes_written = 0;
		if (WriteProcessMemory(process_handle, remote_buffer, dll_path.c_str(), path_bytes, &bytes_written) != FALSE
			&& bytes_written == path_bytes) {
			// Assumes kernel32 (and hence LoadLibraryA) is mapped at the same address in the target
			HMODULE kernel32 = GetModuleHandleA("Kernel32");
			LPTHREAD_START_ROUTINE remote_lla = reinterpret_cast<LPTHREAD_START_ROUTINE> (GetProcAddress(kernel32, "LoadLibraryA"));
			if (remote_lla != NULL) {
				HANDLE thread = CreateRemoteThread(process_handle, NULL, 0, remote_lla, remote_buffer, 0, NULL);
				if (thread != NULL) {
					if (WaitForSingleObject(thread, 5000) == WAIT_OBJECT_0) {
						thread_finished = (GetExitCodeThread(thread, &exit_code) != FALSE);
					} else {
						// Timed out or wait failed: the thread may still use the buffer
						buffer_in_use = true;
					}
					CloseHandle(thread);
				}
			}
		}

		// MEM_RELEASE requires a size of 0 (it releases the whole allocation)
		if (!buffer_in_use) VirtualFreeEx(process_handle, remote_buffer, 0, MEM_RELEASE);
	}

	// The exit code is LoadLibraryA's return value: NULL on failure, else the module
	// base. The < 0x1000 threshold also rejects other implausibly small values.
	if (!thread_finished || exit_code < 0x1000) {
		CloseHandle(process_handle);
		process_handle = NULL;
		return NULL; // Fail
	}

	remote_base = reinterpret_cast<void*> (static_cast<UINT_PTR> (exit_code));
	return reinterpret_cast<HMODULE> (remote_base);
}

// Overwrites the first 5 bytes at remote_address in the target process with a
// JMP rel32 (opcode E9) whose displacement is `offset`. The displacement is
// relative to the end of the 5-byte instruction. The overwritten bytes are saved
// so RemoveHook can restore them.
// Returns a hook handle (an index into `hooks`, never 0) on success, or 0 on failure.
// NOTE(review): the 5-byte write is not guaranteed to be atomic with respect to
// target threads executing this code, and callers must ensure the 5 bytes cover
// whole instructions — confirm at call sites.
HDLLHOOK DLLInjection::InstallHookByOffset(void* remote_address, DWORD offset) {
	const SIZE_T PATCH_SIZE = 5;

	// Make the patched bytes writable; the original protection is restored below
	DWORD old_protect = 0;
	if (VirtualProtectEx(process_handle, remote_address, PATCH_SIZE, PAGE_EXECUTE_READWRITE, &old_protect) == FALSE) return 0;

	HDLLHOOK return_value = 0;

	// Store original opcodes so the hook can be removed later
	DLLHook hook;
	hook.remote_address = remote_address;

	SIZE_T num_bytes_read = 0;
	if (ReadProcessMemory(process_handle, remote_address, hook.original_opcodes, PATCH_SIZE, &num_bytes_read) != FALSE
		&& num_bytes_read == PATCH_SIZE) {
		// Encode JMP rel32; the x86 immediate is little-endian. Packing byte by
		// byte avoids an unaligned, type-punned DWORD store.
		unsigned char patch[PATCH_SIZE];
		patch[0] = 0xE9; // JMP Opcode
		for (int i = 0; i < 4; ++i) {
			patch[1 + i] = static_cast<unsigned char> ((offset >> (8 * i)) & 0xFF);
		}

		SIZE_T num_bytes_written = 0;
		if (WriteProcessMemory(process_handle, remote_address, patch, PATCH_SIZE, &num_bytes_written) != FALSE
			&& num_bytes_written == PATCH_SIZE) {
			// Modified code must be flushed so the target doesn't run stale instructions
			FlushInstructionCache(process_handle, remote_address, PATCH_SIZE);

			// Save hook in database
			hooks.push_back(hook);
			return_value = static_cast<HDLLHOOK> (hooks.size() - 1);
		} else if (num_bytes_written != 0) {
			// Best-effort rollback so the target isn't left with a torn instruction
			SIZE_T ignored_bytes = 0;
			WriteProcessMemory(process_handle, remote_address, hook.original_opcodes, PATCH_SIZE, &ignored_bytes);
			FlushInstructionCache(process_handle, remote_address, PATCH_SIZE);
		}
	}

	// Reset access
	DWORD ignored_protect = 0;
	VirtualProtectEx(process_handle, remote_address, PATCH_SIZE, old_protect, &ignored_protect);

	return return_value;
}

// Hooks existing_function_name, exported by the module at existing_module_path
// (which must already be loaded in the target), so that it jumps to
// hook_function_name, an export of the injected DLL.
// Returns the hook handle from InstallHookByOffset, or 0 on failure (including
// when the DLL has not been injected yet).
HDLLHOOK DLLInjection::InstallDLLHook(const char* existing_module_path, const char* existing_function_name, const char* hook_function_name) {
	// The nomenclature gets pretty hairy around here, as we have local and remote addresses
	// for both the existing and hook functions & modules

	// The jump target lives in the injected DLL, so injection must have happened first
	if (remote_base == NULL) return 0;

	// Load a local reference copy of the module containing the function to hook
	HMODULE local_existing_module = LoadLibraryA(existing_module_path);
	if (local_existing_module == NULL) return 0;
	loaded_dlls.push_back(local_existing_module);

	// Get remote module base
	HMODULE remote_existing_module = GetRemoteModuleHandle(local_existing_module);
	if (remote_existing_module == NULL) return 0;

	// Rebase the function's local address onto the remote module. This is done
	// directly, rather than via GetRemoteProcAddress(name, path), to avoid loading
	// the module and enumerating the target's modules a second time.
	void* local_function_address = reinterpret_cast<void*> (GetProcAddress(local_existing_module, existing_function_name));
	if (local_function_address == NULL) return 0;
	UINT_PTR ex_int = reinterpret_cast<UINT_PTR> (remote_existing_module)
		+ (reinterpret_cast<UINT_PTR> (local_function_address) - reinterpret_cast<UINT_PTR> (local_existing_module));

	// Address of the replacement function inside the target's copy of the injected DLL
	void* remote_hook_address = GetRemoteProcAddress(hook_function_name);
	if (remote_hook_address == NULL) return 0;
	UINT_PTR hook_int = reinterpret_cast<UINT_PTR> (remote_hook_address);

	// JMP rel32 displacement is measured from the end of the 5-byte instruction.
	// NOTE: rel32 only reaches +/-2 GB, which always holds for 32-bit targets.
	DWORD offset = static_cast<DWORD> (hook_int - ex_int - 5);

	return InstallHookByOffset(reinterpret_cast<void*> (ex_int), offset);
}

// Hooks code at an arbitrary address in the target process (existing_function_address
// is a TARGET-process address, e.g. a non-exported function) so that it jumps to
// hook_function_name, an export of the injected DLL.
// Returns the hook handle from InstallHookByOffset, or 0 on failure (including
// when the DLL has not been injected yet).
HDLLHOOK DLLInjection::InstallCodeHook(void* existing_function_address, const char* hook_function_name) {
	// The jump executes inside the target, so it needs the hook's REMOTE address.
	// The local address only coincides if both processes mapped the DLL at the same base.
	if (remote_base == NULL) return 0;
	void* hook_address = GetRemoteProcAddress(hook_function_name);
	if (hook_address == NULL) return 0;

	// JMP rel32 displacement is measured from the end of the 5-byte instruction
	UINT_PTR ex_int = reinterpret_cast<UINT_PTR> (existing_function_address);
	UINT_PTR hook_int = reinterpret_cast<UINT_PTR> (hook_address);
	DWORD offset = static_cast<DWORD> (hook_int - ex_int - 5);

	return InstallHookByOffset(existing_function_address, offset);
}

bool DLLInjection::RemoveAllHooks() {
	bool success = true;
	for (size_t i = 0; i < hooks.size(); ++i) {
		success &= RemoveHook(i);
	}
	hooks.clear();

	// Reinsert dummy hook
	DLLHook empty_hook;
	empty_hook.remote_address = NULL;
	hooks.push_back(empty_hook);

	return success;
}
// Restores the original bytes of the hook identified by `handle` and marks its
// slot as removed. Returns false for an unknown handle, the reserved handle 0, an
// already-removed hook, or if the target's memory could not be rewritten.
bool DLLInjection::RemoveHook(HDLLHOOK handle) {
	const SIZE_T PATCH_SIZE = 5;

	// Reject handles that were never issued (the cast also maps negatives out of range)
	if (static_cast<size_t> (handle) >= hooks.size()) return false;

	void* remote_address = hooks[handle].remote_address;
	// NULL marks the index-0 placeholder or a hook that was already removed
	if (remote_address == NULL) return false;
	const unsigned char* original_opcodes = hooks[handle].original_opcodes;

	// Set access
	DWORD old_protect = 0;
	if (VirtualProtectEx(process_handle, remote_address, PATCH_SIZE, PAGE_EXECUTE_READWRITE, &old_protect) == FALSE) return false;

	// Restore opcodes
	SIZE_T num_bytes_written = 0;
	const bool success = (WriteProcessMemory(process_handle, remote_address, original_opcodes, PATCH_SIZE, &num_bytes_written) != FALSE)
		&& num_bytes_written == PATCH_SIZE;
	if (success) {
		// Modified code must be flushed so the target doesn't run the stale JMP
		FlushInstructionCache(process_handle, remote_address, PATCH_SIZE);
		// Remove hook from database
		hooks[handle].remote_address = NULL;
	}

	// Reset access
	DWORD ignored_protect = 0;
	VirtualProtectEx(process_handle, remote_address, PATCH_SIZE, old_protect, &ignored_protect);
	return success;
}

void DLLInjection::RemoveDLL() {
	if (process_handle == NULL) return;
	RemoveAllHooks();

	DWORD thread_id = 0;
	HMODULE kernel32 = GetModuleHandleA("Kernel32");
	LPTHREAD_START_ROUTINE remote_lla = reinterpret_cast<LPTHREAD_START_ROUTINE> (GetProcAddress(kernel32, "FreeLibrary"));
	if (remote_lla != NULL) {
		HANDLE thread = CreateRemoteThread(process_handle, NULL, 0, remote_lla, reinterpret_cast<void*> (remote_base), 0, &thread_id);
		CloseHandle(thread);
	}
}

void DLLInjection::UnloadAllTemporaryDLLs() {
	for (std::vector<HMODULE>::iterator it = loaded_dlls.begin(); it != loaded_dlls.end(); ++it) {
		FreeLibrary(*it);
	}
	loaded_dlls.clear();
}