/*++
    Copyright  (c) 2002 Sten
    Contact information:
        mail: stenri@mail.ru

    This program is free software; you can redistribute it and/or
    modify it under the terms of the GNU General Public License
    as published by the Free Software Foundation; either version 2
    of the License, or (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.

 
Module Name:
    bpr.cpp

Abstract:  Breakpoint on Range (BPR) helper library: routines to set, check, activate/deactivate, enable/disable, clear and list breakpoints on memory ranges.

Revision History:

 Sten        05/06/2002
      Initial release


--*/

extern "C"{
#pragma warning ( push, 3 )
#include <ntddk.h>
#pragma warning ( pop )
}

#pragma warning ( disable: 4514 ) // unreferenced inline function has been removed

#define  __BPR_C__
#include "bpr.h"
#undef   __BPR_C__

#include "defs.h"
#include "softice.h"
#include "stdlib.h"

/////////////////////////////////////////////////////////////////////////////
//
// bpr_Check 
//  Check if any break point on range matches conditions
//
/////////////////////////////////////////////////////////////////////////////

// Classifies a page fault at _CR2 in address space _CR3 against the BPR table.
// Return value:
//   i               - fault lies inside breakpoint i's range and the access
//                     type matches ('W' for a write, 'R' for a read)
//   i | 0x80000000  - fault belongs to breakpoint i but must be emulated:
//                     either the access type does not match, or the address is
//                     outside the range yet on one of the range's pages
//   0xFFFFFFFF      - no breakpoint of this address space covers _CR2
// ErrCode bit 1 (0x02) set means the faulting access was a write.
// The first matching table entry wins.
ULONG bpr_Check(ULONG _CR3, ULONG _CR2, ULONG ErrCode)
{
    const ULONG   Space   = _CR3 & 0xFFFFF000;
    const BOOLEAN IsWrite = (ErrCode & 0x02) == 0x02;

    for (ULONG Idx = 0; Idx < bpr_Count; ++Idx)
    {
        if (BPR[Idx].CR3 != Space)
            continue;                           // breakpoint of another process

        const ULONG RangeFirst = BPR[Idx].VA;
        const ULONG RangeLimit = BPR[Idx].VA + BPR[Idx].Len;

        if (_CR2 >= RangeFirst && _CR2 < RangeLimit)
        {
            // Reads fire only 'R' breakpoints, writes only 'W' ones;
            // every other combination is emulated.
            const int Wanted = IsWrite ? 'W' : 'R';
            return (BPR[Idx].Type == Wanted) ? Idx : (Idx | 0x80000000);
        }

        // Outside the range, but possibly on one of the pages made invalid for it.
        const ULONG PageCount = (BPR[Idx].Len + PAGE_SIZE - 1) / PAGE_SIZE;
        const ULONG PageFirst = BPR[Idx].VA & 0xFFFFF000;
        const ULONG PageLimit = PageFirst + PageCount * PAGE_SIZE;

        if (_CR2 >= PageFirst && _CR2 < PageLimit)
            return Idx | 0x80000000;            // emulate memory access
    }

    return 0xFFFFFFFF;                          // no break point matches conditions
}

/////////////////////////////////////////////////////////////////////////////
//
// bpr_Set 
//  Set breakpoint on range.
//
/////////////////////////////////////////////////////////////////////////////

void bpr_Set(ULONG bpr_Type,
             ULONG bpr_Addr,
             ULONG bpr_Len,
             ULONG bpr_CR3)

{
    // Appends a new breakpoint on range to the BPR table.
    //   bpr_Type : access type, compared against 'R' / 'W' by bpr_Check
    //   bpr_Addr : first virtual address of the range (user space only)
    //   bpr_Len  : length in bytes; 0 is treated as one page (PAGE_SIZE)
    //   bpr_CR3  : address space of the range; stored page-aligned
    // On success the entry gets Flags = SET | DEACTIVATED (its PTEs are left
    // untouched until bpr_Activate succeeds). On any failure nothing is added.
    if (bpr_Count >= MAX_BPR) return;

    BPR[bpr_Count].Type  = bpr_Type;
    BPR[bpr_Count].VA    = bpr_Addr;
    BPR[bpr_Count].Len   = bpr_Len; 
    if (BPR[bpr_Count].Len == 0) BPR[bpr_Count].Len = PAGE_SIZE;
    BPR[bpr_Count].CR3   = bpr_CR3  & 0xFFFFF000; // align to the page boundary

    // Reject ranges that reach past the user address space. This is the
    // wrap-safe form of (VA + Len > NT_HIGHEST_USER_ADDRESS): a plain 32-bit
    // VA + Len could wrap around and let a huge range pass the check.
    if ((BPR[bpr_Count].Len > NT_HIGHEST_USER_ADDRESS) ||
        (BPR[bpr_Count].VA  > NT_HIGHEST_USER_ADDRESS - BPR[bpr_Count].Len))
    {
          DbgPrint("BPR ERROR: break points on range aren't allowed in system code.\n");
          return; 
    }

    if (bpr_CheckPTEs(bpr_Count, TRUE) == FALSE) return;

    // Buffer for the saved PTE values: Len/PAGE_SIZE + 1 ULONGs, which is at
    // least ceil(Len / PAGE_SIZE), the number of pages the range covers.
    BPR[bpr_Count].PTEs = (PULONG)ExAllocatePool(NonPagedPool,
                                                 sizeof(ULONG)*(BPR[bpr_Count].Len/PAGE_SIZE+1));
    if (BPR[bpr_Count].PTEs == NULL)
    {
          // Without this buffer the PTEs could never be saved/restored;
          // do not add the entry.
          DbgPrint("BPR ERROR: not enough memory to set break point on range.\n");
          return;
    }

    DbgPrint("BPR: %c %x %x (CR3=%08X)\n", BPR[bpr_Count].Type, 
                                           BPR[bpr_Count].VA, 
                                           BPR[bpr_Count].Len,
                                           BPR[bpr_Count].CR3);

    BPR[bpr_Count].Flags = BPRFLG_SET | BPRFLG_DEACTIVATED;

    bpr_Count++;
}

/////////////////////////////////////////////////////////////////////////////
//
// bpr_CheckPTEs 
//  Check PTEs of given break point for range for validity.
//  Print error messages if asked.
//
/////////////////////////////////////////////////////////////////////////////

// Verifies that every page of breakpoint bpNo's range, starting at VA and
// stepping by PAGE_SIZE while below VA + Len, has a PTE that is
// present, user-accessible and not global.
// If Info is TRUE, the first failing check is reported through DbgPrint.
// Returns TRUE when all pages pass and FALSE at the first page that fails.
BOOLEAN bpr_CheckPTEs(ULONG bpNo, BOOLEAN Info)
{
    // x86 PTE bits inspected below.
    const ULONG PteValidBit  = 0x001;   // P   : page present
    const ULONG PteUserBit   = 0x004;   // U/S : clear means supervisor page
    const ULONG PteGlobalBit = 0x100;   // G   : global page; a CR3 reload
                                        //       (used to flush the TLB) does
                                        //       not evict global TLB entries

    const ULONG RangeEnd = BPR[bpNo].VA + BPR[bpNo].Len;

    for (ULONG Addr = BPR[bpNo].VA; Addr < RangeEnd; Addr += PAGE_SIZE)
    {
        const PULONG Entry = (PULONG)GetPte((void*)Addr);

        if (Entry == NULL)
        {
            if (Info)
                DbgPrint("Error: invalid PTE at virtual address VA=%08X.\n", Addr);
            return FALSE;
        }

        const ULONG Bits = *Entry;

        if (!(Bits & PteValidBit))
        {
            if (Info)
                DbgPrint("Error: some page in range is not physically present in memory (VA=%08X).\n", Addr);
            return FALSE;
        }

        if (!(Bits & PteUserBit))
        {
            if (Info)
                DbgPrint("Error: some page in range is supervisor page (VA=%08X).\n", Addr);
            return FALSE;
        }

        if (Bits & PteGlobalBit)
        {
            if (Info)
                DbgPrint("Error: some page in range is global page (VA=%08X).\n", Addr);
            return FALSE;
        }
    }

    return TRUE;
}

/////////////////////////////////////////////////////////////////////////////
//
// bpr_SavePTEsAndMakeInvalid
//  Saves PTEs of the given breakpoint and makes them invalid
//
/////////////////////////////////////////////////////////////////////////////

// Saves the PTE of every page of breakpoint bpNo into BPR[bpNo].PTEs and then
// zeroes those PTEs (not present), so that any access faults.
// Only acts when bpNo is the current address space (CR3) and all PTEs pass
// bpr_CheckPTEs. Returns 1 on success, 0 if nothing was changed.
ULONG bpr_SavePTEsAndMakeInvalid(ULONG bpNo)
{
    PULONG Pte;
    ULONG _CR3;

    // Validate the index itself; a check on bpr_Count >= MAX_BPR would refuse
    // every breakpoint as soon as the table is full.
    if (bpNo >= bpr_Count) return 0;

    // No buffer to save into (ExAllocatePool failed in bpr_Set).
    if (BPR[bpNo].PTEs == NULL) return 0;

    __asm
    {
         mov         eax, cr3
         mov        _CR3, eax
    }

    // BPR[].CR3 is stored page-aligned, so compare the page-aligned CR3
    // (as bpr_Check and bpr_ClearAll do).
    if (BPR[bpNo].CR3 != (_CR3 & 0xFFFFF000)) return 0;

    if (bpr_CheckPTEs(bpNo, FALSE) == FALSE) return 0;

    // NOTE(review): counts ceil(Len / PAGE_SIZE) pages starting at VA; when VA
    // is not page-aligned the range may touch one more page - confirm intended.
    ULONG nPTEs = (BPR[bpNo].Len + PAGE_SIZE - 1)/PAGE_SIZE;

    for (ULONG i = 0; i < nPTEs; i++)
    {
        // Non-NULL: bpr_CheckPTEs just validated every page of the range.
        Pte = (PULONG)GetPte((void*)(BPR[bpNo].VA + i*PAGE_SIZE));
        BPR[bpNo].PTEs[i] = *Pte;   // save the PTE (PTEs is a PULONG: index i, not 4*i)
        *Pte = 0;                   // make PTE invalid
    }

    // Flush the TLB by reloading CR3 (global pages were rejected above,
    // so reloading CR3 is enough).
    __asm
    {
         mov        eax,cr3
         mov        cr3,eax
    }

    return 1; 
}

/////////////////////////////////////////////////////////////////////////////
//
// bpr_RestorePTEs
//  Restores PTEs of the given breakpoint
//
/////////////////////////////////////////////////////////////////////////////

// Writes the PTE values saved by bpr_SavePTEsAndMakeInvalid back into the page
// tables. Returns 1 on success (or when nothing was ever saved because the
// SET flag is still present), 0 if bpNo is invalid or is not the current
// address space.
ULONG bpr_RestorePTEs(ULONG bpNo)
{
    PULONG Pte;
    ULONG _CR3;

    if (bpNo >= bpr_Count) return 0;

    // PTEs can only be restored when SET flag is cleared
    // (while SET is present they were never saved or invalidated).
    if ((BPR[bpNo].Flags & BPRFLG_SET) == BPRFLG_SET) return 1;

    if (BPR[bpNo].PTEs == NULL) return 0;

    __asm
    {
         mov         eax, cr3
         mov        _CR3, eax
    }

    // BPR[].CR3 is stored page-aligned; compare the page-aligned CR3.
    if (BPR[bpNo].CR3 != (_CR3 & 0xFFFFF000)) return 0;

    ULONG nPTEs = (BPR[bpNo].Len + PAGE_SIZE - 1)/PAGE_SIZE;

    for (ULONG i = 0; i < nPTEs; i++)
    {
        Pte = (PULONG)GetPte((void*)(BPR[bpNo].VA + i*PAGE_SIZE));
        if (Pte == NULL) continue;  // no page table for this VA: nothing to write
        *Pte = BPR[bpNo].PTEs[i];   // restore the PTE saved at index i
    }

    // Flush the TLB by reloading CR3.
    __asm
    {
         mov        eax,cr3
         mov        cr3,eax
    }

    return 1;
}

/////////////////////////////////////////////////////////////////////////////
//
// bpr_Disable 
//  Disables given breakpoint on range.
//
/////////////////////////////////////////////////////////////////////////////

// Puts back the original PTEs of breakpoint bpNo and marks it DISABLED.
// Unknown indexes and already disabled breakpoints are ignored. The flag is
// only set when bpr_RestorePTEs reports success.
void bpr_Disable(ULONG bpNo)
{
   if (bpNo >= bpr_Count)
       return;                                   // no such breakpoint

   if (BPR[bpNo].Flags & BPRFLG_DISABLED)
       return;                                   // nothing to do

   if (bpr_RestorePTEs(bpNo))
       BPR[bpNo].Flags |= BPRFLG_DISABLED;
}

/////////////////////////////////////////////////////////////////////////////
//
// bpr_Enable 
//  Enables given breakpoint on range.
//
/////////////////////////////////////////////////////////////////////////////

// Re-arms a DISABLED breakpoint bpNo by saving and invalidating its PTEs.
// Unknown indexes and breakpoints that are not disabled are ignored.
void bpr_Enable(ULONG bpNo)
{
   if (bpNo >= bpr_Count) return;

   if ((BPR[bpNo].Flags & BPRFLG_DISABLED) == BPRFLG_DISABLED)
   {
       if (bpr_SavePTEsAndMakeInvalid(bpNo) != 0)
       {
           // The PTEs are now invalid, i.e. the breakpoint is armed. Clear
           // DISABLED, and also DEACTIVATED/SET, which describe a breakpoint
           // whose PTEs are still valid. Use bitwise ~ (logical ! would turn
           // the mask into 0 and wipe every flag).
           BPR[bpNo].Flags &= ~(BPRFLG_DISABLED | BPRFLG_DEACTIVATED | BPRFLG_SET);
       }
   }
}

/////////////////////////////////////////////////////////////////////////////
//
// bpr_Activate 
//  Activates given breakpoint on range.
//
/////////////////////////////////////////////////////////////////////////////

// Arms breakpoint bpNo (saves and invalidates its PTEs) if it is currently
// SET or DEACTIVATED. Disabled breakpoints are left alone; only bpr_Enable
// re-arms them.
// Returns 1 if the breakpoint was armed by this call, 0 otherwise.
ULONG bpr_Activate(ULONG bpNo)
{
   if (bpNo >= bpr_Count) return 0;

   // Arming a user-disabled breakpoint would make it fire again while
   // still marked DISABLED.
   if ((BPR[bpNo].Flags & BPRFLG_DISABLED) == BPRFLG_DISABLED) return 0;

   if (
        ((BPR[bpNo].Flags & BPRFLG_DEACTIVATED) == BPRFLG_DEACTIVATED) ||
        ((BPR[bpNo].Flags & BPRFLG_SET) == BPRFLG_SET)
      )
   {
       if (bpr_SavePTEsAndMakeInvalid(bpNo) != 0)
       {
           // Clear only DEACTIVATED and SET (bitwise ~; logical ! would
           // zero the whole flag word).
           BPR[bpNo].Flags &= ~(BPRFLG_DEACTIVATED | BPRFLG_SET); 
           return 1;
       }
   }

   return 0;
}

/////////////////////////////////////////////////////////////////////////////
//
// bpr_Deactivate 
//  Deactivates given breakpoint on range.
//
/////////////////////////////////////////////////////////////////////////////

// Restores the original PTEs of breakpoint bpNo and marks it DEACTIVATED.
// Unknown indexes and already deactivated breakpoints are ignored. The flag
// is only set when bpr_RestorePTEs reports success.
void bpr_Deactivate(ULONG bpNo)
{
   if (bpNo >= bpr_Count)
       return;                                   // no such breakpoint

   if (BPR[bpNo].Flags & BPRFLG_DEACTIVATED)
       return;                                   // already deactivated

   if (bpr_RestorePTEs(bpNo))
       BPR[bpNo].Flags |= BPRFLG_DEACTIVATED;
}

/////////////////////////////////////////////////////////////////////////////
//
// bpr_ActivateAll 
//
//   Tries to activate all breakpoints on range.
//
/////////////////////////////////////////////////////////////////////////////

// Calls bpr_Activate on every table entry.
// Returns how many of those calls reported success.
ULONG bpr_ActivateAll()
{
     ULONG Armed = 0;
     ULONG Idx   = 0;

     while (Idx < bpr_Count)
     {
         if (bpr_Activate(Idx) != 0)
             ++Armed;
         ++Idx;
     }

     return Armed;
}

/////////////////////////////////////////////////////////////////////////////
//
// bpr_DeactivateAll 
//
//   Tries to deactivate all breakpoints on range.
//
/////////////////////////////////////////////////////////////////////////////

void bpr_DeactivateAll()
{
     for (ULONG i=0; i<bpr_Count; i++)
     {
           bpr_Deactivate(i);
     }
}

/////////////////////////////////////////////////////////////////////////////
//
// bpr_IsDisabled 
//  Check if given breakpoint is disabled.
//
/////////////////////////////////////////////////////////////////////////////

// Returns the BPRFLG_DISABLED bit of breakpoint bpNo (non-zero if disabled).
// An index outside the table returns 0 instead of reading past the
// in-use entries, matching the bounds checks of the other bpr_* routines.
ULONG bpr_IsDisabled(ULONG bpNo)
{
   if (bpNo >= bpr_Count) return 0;

   return (BPR[bpNo].Flags & BPRFLG_DISABLED);
}

/////////////////////////////////////////////////////////////////////////////
//
// bpr_Clear 
//  Clear given breakpoint on range.
//
/////////////////////////////////////////////////////////////////////////////

// Removes breakpoint bpNo from the table. Its PTEs are restored first (via
// bpr_Deactivate), later entries are shifted down by one, and the removed
// entry's PTE buffer is freed.
void bpr_Clear(ULONG bpNo)
{
     if (bpr_Count == 0)
     {
        DbgPrint("Warning: nothing to clear.\n");
        return;
     }

     if (bpNo >= bpr_Count)
     {
        DbgPrint("ERROR: no such breakpoint (%d).\n",bpNo);
        return;
     }

     bpr_Deactivate(bpNo);

     // Take the removed entry's buffer before the move below overwrites
     // BPR[bpNo]; otherwise it would be leaked.
     PULONG RemovedPTEs = BPR[bpNo].PTEs;

     bpr_Count--;     // Decrease BPRs count

     // Delete a structure member: shift entries bpNo+1 .. old last down by one.
     RtlMoveMemory((void*)&BPR[bpNo],(void*)&BPR[bpNo+1],(bpr_Count - bpNo)*sizeof(BPR[0]));

     // The old last slot now holds either the removed entry itself or a stale
     // copy of the entry moved to BPR[bpr_Count - 1]. In the latter case its
     // PTEs pointer is owned by the moved entry, so it must only be forgotten,
     // never freed through this slot.
     BPR[bpr_Count].Flags  = 0;
     BPR[bpr_Count].CR3    = 0;
     BPR[bpr_Count].VA     = 0;
     BPR[bpr_Count].Len    = 0;
     BPR[bpr_Count].Type   = 0;
     BPR[bpr_Count].PTEs   = NULL;

     if (RemovedPTEs != NULL)
     {
         ExFreePool(RemovedPTEs);
     }
}

/////////////////////////////////////////////////////////////////////////////
//
// bpr_ClearAll
//  Clear all breakpoints for the current process.
//
/////////////////////////////////////////////////////////////////////////////

void bpr_ClearAll()
{
    ULONG _CR3;

    __asm
    {
          mov        eax, CR3
          mov       _CR3, eax
    }

    for(ULONG i=0;i<bpr_Count;i++)
    {
        if (BPR[i].CR3 == (_CR3 & 0xFFFFF000))
        {
            bpr_Clear(i);
        }
    }
}

/////////////////////////////////////////////////////////////////////////////
//
// bpr_ListAll
//  Lists all breakpoints on range
//
/////////////////////////////////////////////////////////////////////////////

void bpr_ListAll ()
{
    for(ULONG i=0;i<bpr_Count;i++)
    {
        DbgPrint("%02d)   BPR %c %x %x (CR3=%08X, Flags=%08X)\n", i,
                                                   BPR[i].Type, 
                                                   BPR[i].VA, 
                                                   BPR[i].Len,
                                                   BPR[i].CR3,
                                                   BPR[i].Flags);
    }
}