//======================================================================
//
// Memtrack.c
//
// By Mark Russinovich
// http://www.sysinternals.com
//
// Functions for tracking pool usage in a driver and detecting 
// some kinds of buffer over- and under-runs.
//
//======================================================================
#include "ntddk.h"
#include "memtrack.h"

//
// Only compile our tracking allocation functions in DEBUG (DBG) builds or when MEMDBG is set
//
#if DBG || MEMDBG

//
// Spin lock guarding the four usage counters below. It is acquired around
// every counter update in the allocate/free routines and is initialized by
// MemTrackInit, which must run before the first tracked allocation.
//
KSPIN_LOCK MemLockAlloced;

//
// Define our allocation mask: the requested byte count is XORed with this
// value to form the signatures stored just before and just after each
// tracked buffer, so a zeroed or overwritten signature is unlikely to
// decode to the real size.
//
#define ALLOC_XOR_MASK 0x5A5A5A5A

//
// These variables track our current pool usages (in bytes requested by
// callers, excluding the signature overhead) and can 
// be easily viewed from inside a debugger
//
ULONG MemTotalNonpagedAlloced = 0;
ULONG MemTotalPagedAlloced = 0;

//
// These variables track the high-water mark for the 
// pool that we've used. These can be examined at the end
// of operations to see if our driver is inefficiently using
// resources or if it's dangerously allocating too much memory.
//
ULONG MemMaxNonpagedAlloced = 0;
ULONG MemMaxPagedAlloced = 0;


//----------------------------------------------------------------------
//
// MemAllocatePoolWithTag
//
//----------------------------------------------------------------------
PVOID MemAllocatePoolWithTag( IN POOL_TYPE PoolType, IN ULONG NumberOfBytes,
                              IN ULONG Tag )
{
   //
   // Each tracked allocation is over-allocated and laid out as
   // (offsets from the start of the pool block, N = NumberOfBytes):
   //
   //   [0]            pool type (low byte only)
   //   [1 .. 4]       pre-signature:  N ^ ALLOC_XOR_MASK
   //   [5 .. 4+N]     caller's buffer  <- pointer returned to the caller
   //   [5+N .. 8+N]   post-signature: N ^ ALLOC_XOR_MASK
   //
   // NOTE: the returned pointer sits 5 bytes into the pool block, so it is
   // not aligned to a ULONG boundary, and the signature stores below are
   // unaligned ULONG accesses.
   //
   const ULONG overhead = (ULONG) (2*sizeof(ULONG) + sizeof(CHAR));
   PVOID Buffer;
   PCHAR userBuffer;
   KIRQL prevIrql;

   //
   // Refuse sizes whose padded total would wrap around to a small value;
   // the post-signature write would otherwise land outside the block.
   //
   if( NumberOfBytes > (ULONG) -1 - overhead ) {

       KdPrint(("************** POOL REQUEST TOO LARGE: %u bytes\n", NumberOfBytes ));
       DbgBreakPoint();
       return NULL;
   }

   Buffer = ExAllocatePoolWithTag( PoolType, NumberOfBytes + overhead, Tag );

   if( Buffer == NULL ) {

       //      
       // It's generally not a good thing to run out of pool
       //      
       KdPrint(("************** OUT OF %s POOL!!\n",
                PoolType == NonPagedPool ? "NONPAGED" : "PAGED" ));
       DbgBreakPoint();
       return NULL;
   }

   //
   // Track pool usage only for allocations that actually succeeded, so
   // failed requests do not inflate the totals permanently. Any pool type
   // other than NonPagedPool is counted against the paged totals.
   //
   KeAcquireSpinLock( &MemLockAlloced, &prevIrql );

   if( PoolType == NonPagedPool ) {

       MemTotalNonpagedAlloced += NumberOfBytes;
       if( MemTotalNonpagedAlloced > MemMaxNonpagedAlloced ) 
           MemMaxNonpagedAlloced = MemTotalNonpagedAlloced;

   } else {

       MemTotalPagedAlloced += NumberOfBytes;
       if( MemTotalPagedAlloced > MemMaxPagedAlloced ) 
           MemMaxPagedAlloced = MemTotalPagedAlloced;
   }
   KeReleaseSpinLock( &MemLockAlloced, prevIrql );

   //
   // Record the pool type and put a signature before and after the buffer
   //
   *(PCHAR) Buffer = (CHAR) PoolType;
   userBuffer = (PCHAR) Buffer + sizeof(CHAR) + sizeof(ULONG);
   *(PULONG) (userBuffer - sizeof(ULONG))  = NumberOfBytes ^ ALLOC_XOR_MASK;
   *(PULONG) (userBuffer + NumberOfBytes)  = NumberOfBytes ^ ALLOC_XOR_MASK;

   return (PVOID) userBuffer;
}


//----------------------------------------------------------------------
//
//   MemAllocatePool
//
//----------------------------------------------------------------------
PVOID MemAllocatePool( IN POOL_TYPE PoolType, IN ULONG NumberOfBytes )
{
   //
   // Untagged counterpart of MemAllocatePoolWithTag. Each tracked
   // allocation is over-allocated and laid out as
   // (offsets from the start of the pool block, N = NumberOfBytes):
   //
   //   [0]            pool type (low byte only)
   //   [1 .. 4]       pre-signature:  N ^ ALLOC_XOR_MASK
   //   [5 .. 4+N]     caller's buffer  <- pointer returned to the caller
   //   [5+N .. 8+N]   post-signature: N ^ ALLOC_XOR_MASK
   //
   // NOTE: the returned pointer sits 5 bytes into the pool block, so it is
   // not aligned to a ULONG boundary.
   //
   const ULONG overhead = (ULONG) (2*sizeof(ULONG) + sizeof(CHAR));
   PVOID Buffer;
   PCHAR userBuffer;
   KIRQL prevIrql;

   //
   // Refuse sizes whose padded total would wrap around to a small value;
   // the post-signature write would otherwise land outside the block.
   //
   if( NumberOfBytes > (ULONG) -1 - overhead ) {

       KdPrint(("************** POOL REQUEST TOO LARGE: %u bytes\n", NumberOfBytes ));
       DbgBreakPoint();
       return NULL;
   }

   Buffer = ExAllocatePool( PoolType, NumberOfBytes + overhead );

   if( Buffer == NULL ) {

       KdPrint(("************** OUT OF %s POOL!!\n",
                PoolType == NonPagedPool ? "NONPAGED" : "PAGED" ));
       DbgBreakPoint();
       return NULL;
   }

   //
   // Track pool usage only for allocations that actually succeeded. Any
   // pool type other than NonPagedPool is counted as paged.
   //
   KeAcquireSpinLock( &MemLockAlloced, &prevIrql );

   if( PoolType == NonPagedPool ) {

       MemTotalNonpagedAlloced += NumberOfBytes;
       if( MemTotalNonpagedAlloced > MemMaxNonpagedAlloced ) 
           MemMaxNonpagedAlloced = MemTotalNonpagedAlloced;

   } else {

       MemTotalPagedAlloced += NumberOfBytes;
       if( MemTotalPagedAlloced > MemMaxPagedAlloced ) 
           MemMaxPagedAlloced = MemTotalPagedAlloced;
   }
   KeReleaseSpinLock( &MemLockAlloced, prevIrql );

   //
   // Record the pool type and put the (masked) size before and after the buffer
   //
   *(PCHAR) Buffer = (CHAR) PoolType;
   userBuffer = (PCHAR) Buffer + sizeof(CHAR) + sizeof(ULONG);
   *(PULONG) (userBuffer - sizeof(ULONG))  = NumberOfBytes ^ ALLOC_XOR_MASK;
   *(PULONG) (userBuffer + NumberOfBytes)  = NumberOfBytes ^ ALLOC_XOR_MASK;

   return (PVOID) userBuffer;
}


//----------------------------------------------------------------------
//
// MemFreePool
//
//----------------------------------------------------------------------
VOID MemFreePool( PVOID Buffer )
{
   //
   // Value written over the post-signature just before the block goes back
   // to the pool, so a second free of the same pointer can be recognized.
   //
   const ULONG freedMarker = 0xDEAD;
   ULONG numberOfBytesBefore, numberOfBytesAfter;
   ULONG postSignature;
   KIRQL prevIrql;
   POOL_TYPE poolType;
   PCHAR header;

   //
   // Flag a problem if we're freeing a NULL buffer pointer. There is no
   // header to inspect, so stop here instead of dereferencing NULL.
   //
   if( !Buffer ) {

      KdPrint(("***************FREEING NULL POINTER\n"));
      DbgBreakPoint();
      return;
   }

   //
   // Start of the pool block: 1 pool-type byte + 1 ULONG pre-signature
   // precede the caller's buffer.
   //
   header = (PCHAR) Buffer - sizeof(ULONG) - sizeof(CHAR);

   //
   // Make sure the signature is still there
   //
   numberOfBytesBefore = (*(PULONG) ((PCHAR) Buffer - sizeof(ULONG))) ^
      ALLOC_XOR_MASK;
   postSignature       = *(PULONG) ((PCHAR) Buffer + numberOfBytesBefore);
   numberOfBytesAfter  = postSignature ^ ALLOC_XOR_MASK;

   if( numberOfBytesBefore != numberOfBytesAfter ) {

      //
      // Someone corrupted the data past the end or before the beginning
      // of the buffer. This might be because the buffer is being freed
      // again, which we can conveniently detect.
      //
      if( postSignature == freedMarker ) {

         KdPrint(("*******************BUFFER FREED TWICE: %p\n", Buffer ));
         DbgBreakPoint();

         //
         // The block was already released and its size already subtracted;
         // adjusting the counters or freeing it again would only make
         // things worse.
         //
         return;
      }

      KdPrint(("*******************BUFFER CORRUPTION: %p\n", Buffer ));
      DbgBreakPoint();
   }

   //
   // Reach back to get the pool type so that we can adjust the 
   // appropriate pool. Only the low byte was stored at allocation time.
   //
   // NOTE(review): the allocators count every type other than NonPagedPool
   // as paged, but only NonPagedPool and PagedPool are accepted here; other
   // pool types end up in the corruption branch - confirm intended.
   //
   poolType = (POOL_TYPE) (*header & 0xFF);

   if( poolType != NonPagedPool && poolType != PagedPool ) {

       KdPrint(("*******************BUFFER CORRUPTION: %p\n", Buffer ));
       DbgBreakPoint();

   } else {

       //
       // Use the size from the pre-signature: a buffer overrun (the common
       // case) destroys the post-signature but leaves this value intact.
       //
       KeAcquireSpinLock( &MemLockAlloced, &prevIrql );

       if( poolType == NonPagedPool ) {

           MemTotalNonpagedAlloced -= numberOfBytesBefore;

       } else {

           MemTotalPagedAlloced -= numberOfBytesBefore;
       }
       KeReleaseSpinLock( &MemLockAlloced, prevIrql );
   }

   //
   // Zap bytes after so that we trap if someone tries to free
   // this same buffer
   //
   *(PULONG) ((PCHAR) Buffer + numberOfBytesBefore) = freedMarker;

   ExFreePool( (PVOID) header );
}


//----------------------------------------------------------------------
//
// MemAllocateFromNPagedLookasideList
//
//----------------------------------------------------------------------
PVOID MemAllocateFromNPagedLookasideList( PNPAGED_LOOKASIDE_LIST  Lookaside )
{
   PVOID entry;

   //
   // Never hand out a cached entry: every request goes straight to the
   // list's allocation routine, using the list's own pool type, entry size,
   // and tag. By default that routine is MemAllocatePoolWithTag in these
   // builds.
   //
   entry = (Lookaside->L.Allocate)( Lookaside->L.Type,
                                    Lookaside->L.Size,
                                    Lookaside->L.Tag );
   return entry;
}

//----------------------------------------------------------------------
//
// MemFreeToNPagedLookasideList
//
//----------------------------------------------------------------------
VOID MemFreeToNPagedLookasideList( PNPAGED_LOOKASIDE_LIST  Lookaside, PVOID Buffer )
{
   //
   // Bypass the cache entirely: hand the entry straight to the list's free
   // routine instead of pushing it onto the lookaside list.
   //
   (Lookaside->L.Free)( Buffer );
}

#endif // DBG || MEMDBG

//======================================================================
//======================================================================
// Functions below these lines are *always* enabled - debugging or not



//----------------------------------------------------------------------
//
//   MemInitializeNPagedLookasideList
//
// Description:
//   We have our own initialization function so that our lists
//   are not linked to the system-wide lookaside lists, which are
//   auto-tuned. We like to know what we're getting so we
//   keep our lookasides off the list.
//
//   NOTE: there is no need to implement private versions of
//   AllocateFromLookaside and FreeToLookaside since these
//   functions do no tuning and are actually in-line functions
//   defined in ntddk.h.
//
//----------------------------------------------------------------------
VOID
MemInitializeNPagedLookasideList( IN PNPAGED_LOOKASIDE_LIST Lookaside,
                                  IN PALLOCATE_FUNCTION Allocate,
                                  IN PFREE_FUNCTION Free,
                                  IN ULONG Flags,
                                  IN ULONG Size,
                                  IN ULONG Tag,
                                  IN USHORT Depth )
{
   //
   // Flags is accepted for signature compatibility but is not used.
   //
   (void) Flags;

   //
   // Begin with every field of the list cleared.
   //
   RtlZeroMemory( Lookaside, sizeof(NPAGED_LOOKASIDE_LIST));

   //
   // Describe the entries this list hands out.
   //
   Lookaside->L.Type  = NonPagedPool;
   Lookaside->L.Size  = Size;
   Lookaside->L.Tag   = Tag;
   Lookaside->L.Depth = Depth;

   //
   // Use the caller's routines when given; otherwise fall back to the pool
   // routines, which are the tracking versions in DBG/MEMDBG builds.
   //
   if( Allocate != NULL ) {
      Lookaside->L.Allocate = Allocate;
   } else {
#if DBG || MEMDBG
      Lookaside->L.Allocate = MemAllocatePoolWithTag;
#else
      Lookaside->L.Allocate = ExAllocatePoolWithTag;
#endif
   }

   if( Free != NULL ) {
      Lookaside->L.Free = Free;
   } else {
#if DBG || MEMDBG
      Lookaside->L.Free = MemFreePool;
#else
      Lookaside->L.Free = ExFreePool;
#endif
   }

   //
   // Empty entry list plus the spin lock that protects it.
   //
   ExInitializeSListHead( &Lookaside->L.ListHead );
   KeInitializeSpinLock( &Lookaside->Lock );
}


//----------------------------------------------------------------------
//
// MemDeleteNPagedLookasideList
//
// See comments in MemInitializeNPagedLookasideList.
// MemDummyAllocate is part of this procedure.
//
//----------------------------------------------------------------------
PVOID MemDummyAllocate( POOL_TYPE PoolType, ULONG NumberOfBytes, ULONG Tag )
{
   //
   // Allocation routine that always fails. It is installed while draining a
   // lookaside list so that the drain loop ends once the cache is empty
   // instead of allocating fresh entries.
   //
   (void) PoolType;
   (void) NumberOfBytes;
   (void) Tag;

   return NULL;
}

VOID
MemDeleteNPagedLookasideList( PNPAGED_LOOKASIDE_LIST Lookaside )
{
   PVOID cached;

   //
   // With a failing allocator installed, ExAllocateFromNPagedLookasideList
   // can only return entries already cached on the list. Pop and release
   // them until it reports NULL.
   //
   // NOTE(review): this relies on ExAllocateFromNPagedLookasideList being
   // the real lookaside pop, not a mapping onto an always-allocate debug
   // routine - confirm against memtrack.h.
   //
   Lookaside->L.Allocate = MemDummyAllocate;

   for( cached = ExAllocateFromNPagedLookasideList( Lookaside );
        cached != NULL;
        cached = ExAllocateFromNPagedLookasideList( Lookaside ) ) {

      (Lookaside->L.Free)( cached );
   }
}


//----------------------------------------------------------------------
//
//   MemInitializePagedLookasideList
//
// Description:
//   We have our own initialization function so that our lists
//   are not linked to the system-wide lookaside lists, which are
//   auto-tuned. We like to know what we're getting so we
//   keep our lookasides off the list.
//
//   NOTE: there is no need to implement private versions of
//   AllocateFromLookaside and FreeToLookaside since these
//   functions do no tuning and are actually in-line functions
//   defined in ntddk.h.
//
//----------------------------------------------------------------------
VOID
MemInitializePagedLookasideList( IN PNPAGED_LOOKASIDE_LIST Lookaside,
                                 IN PALLOCATE_FUNCTION Allocate,
                                 IN PFREE_FUNCTION Free,
                                 IN ULONG Flags,
                                 IN ULONG Size,
                                 IN ULONG Tag,
                                 IN USHORT Depth )
{
   //
   // Flags is accepted for signature compatibility but is not used.
   //
   (void) Flags;

   //
   // Begin with every field of the list cleared. The list structure is the
   // nonpaged variant; only the entry pool type differs.
   //
   RtlZeroMemory( Lookaside, sizeof(NPAGED_LOOKASIDE_LIST));

   //
   // Describe the entries this list hands out.
   //
   Lookaside->L.Type  = PagedPool;
   Lookaside->L.Size  = Size;
   Lookaside->L.Tag   = Tag;
   Lookaside->L.Depth = Depth;

   //
   // Use the caller's routines when given; otherwise fall back to the pool
   // routines, which are the tracking versions in DBG/MEMDBG builds.
   //
   if( Allocate != NULL ) {
      Lookaside->L.Allocate = Allocate;
   } else {
#if DBG || MEMDBG
      Lookaside->L.Allocate = MemAllocatePoolWithTag;
#else
      Lookaside->L.Allocate = ExAllocatePoolWithTag;
#endif
   }

   if( Free != NULL ) {
      Lookaside->L.Free = Free;
   } else {
#if DBG || MEMDBG
      Lookaside->L.Free = MemFreePool;
#else
      Lookaside->L.Free = ExFreePool;
#endif
   }

   //
   // Empty entry list plus the spin lock that protects it.
   //
   ExInitializeSListHead( &Lookaside->L.ListHead );
   KeInitializeSpinLock( &Lookaside->Lock );
}


//----------------------------------------------------------------------
//
// MemDeletePagedLookasideList
//
// See comments in MemInitializePagedLookasideList.
// MemDummyAllocate is part of this procedure.
//
//----------------------------------------------------------------------
VOID
MemDeletePagedLookasideList( PNPAGED_LOOKASIDE_LIST Lookaside )
{
   PVOID cached;

   //
   // Same drain technique as the nonpaged variant: install an allocator
   // that always fails, then pop and release cached entries until the list
   // reports NULL.
   //
   Lookaside->L.Allocate = MemDummyAllocate;

   for( cached = ExAllocateFromNPagedLookasideList( Lookaside );
        cached != NULL;
        cached = ExAllocateFromNPagedLookasideList( Lookaside ) ) {

      (Lookaside->L.Free)( cached );
   }
}

//----------------------------------------------------------------------
//
// MemTrackPrintStats
//
//----------------------------------------------------------------------
VOID MemTrackPrintStats( PCHAR Text )
{
    //
    // Dump the current and high-water pool usage, each line prefixed with
    // the caller-supplied Text. The counters are only defined when
    // DBG || MEMDBG, so the body is compiled only in those builds. The
    // counters are ULONGs and are therefore printed unsigned.
    //
#if DBG || MEMDBG
    KdPrint(("%s Nonpaged Memory Usage:\n", Text));
    KdPrint(("   Current: %8u bytes\n", MemTotalNonpagedAlloced ));
    KdPrint(("   Max:     %8u bytes\n", MemMaxNonpagedAlloced ));
    KdPrint(("%s Paged Memory Usage:\n", Text));
    KdPrint(("   Current: %8u bytes\n", MemTotalPagedAlloced ));
    KdPrint(("   Max:     %8u bytes\n", MemMaxPagedAlloced ));
#endif
}

//----------------------------------------------------------------------
//
// MemTrackInit
//
// Just initialize our spin lock
//
//----------------------------------------------------------------------
VOID MemTrackInit( VOID )
{
    //
    // Initialize the spin lock that guards the pool usage counters. Call
    // this once before any tracked allocation; it does nothing in builds
    // where tracking is compiled out.
    //
#if DBG || MEMDBG
    KeInitializeSpinLock( &MemLockAlloced );
#endif
}
