00001 
00002 
00003 
00004 
00005 
00006 
00007 
00008 
00009 
00010 
00011 
00012 
00013 
00014 
00015 
00016 
00017 
00018 
00019 
00020 
00021 
00022 
00023 
00024 
00025 
00026 
00027 
00028 
00029 
00030 
00031 
00032 
00033 
00034 
00035 
00036 #include "stdafx.h"
00037 #include "DNSManager.h"
00038 
00039 #include "ErrorHandlerMacros.h"
00040 #include "Array_ptr.h"
00041 #include "OSManager.h"
00042 #include "GenericCriticalSection.h"
00043 #include "GenericThread.h"
00044 #include "ThreadPool.h"
00045 
00046 #include "DNSSocket.h"
00047 #include "DNSQuery.h"
00048 
00049 #include <memory>
00050 
00051 #ifdef _MEMORY_DEBUG 
00052     #define new    DEBUG_NEW  
00053     #define malloc DEBUG_MALLOC  
00054     static char THIS_FILE[] = __FILE__;  
00055 #endif
00056 
00057 namespace KomodiaDNS
00058 {
00059 
// Class name used for error reporting (passed to SetName and to the static
// error-handler macros in this file).
#define CDNSManager_Class "CDNSManager"

// How long, in GetTickCount() milliseconds, a record that was marked as timed
// out stays in the record map before ThreadProc erases it. While it remains,
// WasTimedout() reports TRUE for its ID, so a late reply is discarded.
#define DELETE_TIME 120000
00063 
00064 CDNSManager::CDNSManager(const std::string strDNSServer,
00065                          BOOL bTCP,
00066                          BOOL bAsync,
00067                          BOOL bAsyncConnect) : CErrorHandler(),
00068                                                m_pSocket(NULL),
00069                                                m_strDNSServer(strDNSServer),
00070                                                m_iParseThreads(0),
00071                                                m_pCSection(NULL),
00072                                                m_bInitialized(FALSE),
00073                                                m_pThread(NULL),
00074                                                m_dwTimeout(0)
00075 {
00076     try
00077     {
00078         
00079         SetName(CDNSManager_Class);
00080 
00081         
00082         m_bAsync=bAsync;
00083         m_bTCP=bTCP;
00084 
00085         
00086         m_pCSection=COSManager::CreateCriticalSection();
00087 
00088         
00089         m_pSocket=new CDNSSocket(this,
00090                                 m_strDNSServer,
00091                                 m_bTCP,
00092                                 m_bAsync,
00093                                 bAsyncConnect);
00094 
00095         
00096         if (m_bAsync)
00097         {
00098             
00099             m_pThread=new CPeriodicThread(ThreadProc);
00100             
00101             
00102             m_pThread->SetPriority(CGenericThread::tpAboveNormal);
00103 
00104             
00105             if (!m_pThread->Start(500,
00106                                   (LPVOID)this))
00107                 
00108                 ReportError("CDNSManager","Failed to start periodic thread!");
00109         }
00110     }
00111     ERROR_HANDLER("CDNSManager")
00112 }
00113 
00114 CDNSManager::~CDNSManager()
00115 {
00116     try
00117     {
00118         
00119         delete m_pSocket;
00120 
00121         
00122         delete m_pThread;
00123 
00124         
00125         delete m_pCSection;
00126     }
00127     ERROR_HANDLER("~CDNSManager")
00128 }
00129 
00130 CThreadPool* CDNSManager::GetThreadManager()const
00131 {
00132     return m_pSocket->GetThreadManager();
00133 }
00134 
00135 std::string CDNSManager::ParseAddress(const std::string& rAddress)
00136 {
00137     try
00138     {
00139         
00140         std::string sAddress;
00141 
00142         
00143         int iIndex=0;
00144 
00145         
00146         std::string::size_type iLastPosition;
00147         iLastPosition=0;
00148 
00149         
00150         std::string::size_type iFoundPos;
00151 
00152         
00153         while ((iFoundPos=rAddress.find_first_of('.',iLastPosition))!=std::string::npos)
00154         {
00155             
00156             sAddress+=(char)(iFoundPos-iLastPosition);
00157 
00158             
00159             sAddress+=rAddress.substr(iLastPosition,
00160                                       iFoundPos-iLastPosition);
00161 
00162             
00163             iLastPosition=iFoundPos+1;
00164         }
00165 
00166         
00167         if (rAddress.length()!=iLastPosition)
00168         {
00169             
00170             sAddress+=(char)(rAddress.length()-iLastPosition);
00171             sAddress+=rAddress.substr(iLastPosition,
00172                                       rAddress.length()-iLastPosition);
00173         }
00174 
00175         
00176         return sAddress;
00177     }
00178     ERROR_HANDLER_STATIC_RETURN(CDNSManager_Class,"ParseAddress",NULL)
00179 }
00180 
00181 LPDnsHeaderHeader CDNSManager::GetDNSHeaderHeader(unsigned short usID,
00182                                                   unsigned short usFlags,
00183                                                   unsigned short usQuestions,
00184                                                   unsigned short usAnswers)const
00185 {
00186     try
00187     {
00188         
00189         LPDnsHeaderHeader lpHead;
00190         lpHead=new DnsHeaderHeader;
00191 
00192         lpHead->usID=htons(usID);
00193         lpHead->usOptions=usFlags;
00194         lpHead->usQDCount=htons(usQuestions);
00195         lpHead->usANCount=htons(usAnswers);
00196 
00197         
00198         return lpHead;
00199     }
00200     ERROR_HANDLER_RETURN("GetDNSHeaderHeader",NULL)
00201 }
00202 
00203 unsigned short CDNSManager::QueryDNS(const CDNSQuery& rQuery,
00204                                      unsigned short usID,
00205                                      LPVOID lpLocalIDData)
00206 {
00207     try
00208     {
00209         
00210         if (!m_bInitialized)
00211         {
00212             
00213             ReportError("QueryDNS","Not initialized!");
00214             
00215             
00216             return 0;
00217         }
00218 
00219         
00220         unsigned short usTotalSize;
00221         usTotalSize=rQuery.GetSerializationSize();
00222 
00223         
00224         if (!usTotalSize)
00225         {
00226             
00227             ReportError("QueryDNS","Bad query size!");
00228             
00229             
00230             return 0;
00231         }
00232 
00233         
00234         usTotalSize+=DnsHeaderHeaderLength;
00235 
00236         
00237         if (m_bTCP)
00238             usTotalSize+=2;
00239 
00240         
00241         char* pBuffer;
00242         pBuffer=new char[usTotalSize];
00243 
00244         
00245         CArray_ptr<char> pProtection(pBuffer);
00246 
00247         
00248         char* pBackupBuffer=pBuffer;
00249 
00250         if (m_bTCP)
00251             
00252             pBuffer+=2;
00253 
00254         
00255         if (!usID)
00256         {
00257             
00258             static long lID=0;
00259             InterlockedIncrement(&lID);
00260             usID=lID;
00261         }
00262 
00263         
00264         LPDnsHeaderHeader lpHead;
00265         lpHead=GetDNSHeaderHeader(usID,
00266                                   DNS_RECURSION,
00267                                   rQuery.GetQuestionCount());
00268 
00269         
00270         memcpy(pBuffer,
00271                lpHead,
00272                DnsHeaderHeaderLength);
00273         pBuffer+=DnsHeaderHeaderLength;
00274 
00275         
00276         delete lpHead;
00277 
00278         
00279         if (!rQuery.SerializeQuery(pBuffer))
00280         {
00281             
00282             ReportError("QueryDNS","Failed to serialize query!");
00283             
00284             
00285             return 0;
00286         }
00287 
00288         
00289         if (m_bTCP)
00290             *((unsigned short*)pBackupBuffer)=htons(usTotalSize-2);
00291 
00292         
00293         AddRecord(usID,
00294                   lpLocalIDData,
00295                   m_dwTimeout,
00296                   rQuery);
00297 
00298         
00299         BOOL bResult;
00300         bResult=m_pSocket->Send(pBackupBuffer,
00301                                 usTotalSize,
00302                                 lpLocalIDData);
00303 
00304         
00305         if (!bResult)
00306             RemoveRecord(usID);
00307 
00308         
00309         if (bResult)
00310             return usID;
00311         else
00312             return 0;
00313     }
00314     ERROR_HANDLER_RETURN("QueryDNS",0)
00315 }
00316 
00317 unsigned short  CDNSManager::GetDNSEntry(const std::string& rAddress,
00318                                          unsigned short usID,
00319                                          LPVOID lpLocalIDData)
00320 {
00321     try
00322     {
00323         
00324         unsigned long ulAddress;
00325         ulAddress=CSpoofBase::StringToLong(rAddress);
00326 
00327         
00328         return GetDNSEntry(ulAddress,
00329                            usID,
00330                            lpLocalIDData);
00331     }
00332     ERROR_HANDLER_RETURN("GetDNSEntry",0)
00333 }
00334 
00335 unsigned short CDNSManager::GetDNSEntry(unsigned long ulAddress,
00336                                         unsigned short usID,
00337                                         LPVOID lpLocalIDData)
00338 {
00339     try
00340     {
00341         
00342         long ulNewAddress;
00343         ulNewAddress=htonl(ulAddress);
00344 
00345         
00346         std::string sAddr;
00347         sAddr=CSpoofBase::LongToStdString(ulNewAddress);
00348 
00349         
00350         sAddr+=".in-addr.arpa.";
00351 
00352         
00353         CDNSQuery aQuery;
00354         aQuery.AddQuery(sAddr.c_str(),
00355                         CDNSQuery::PTR);
00356 
00357         
00358         return QueryDNS(aQuery,
00359                         usID,
00360                         lpLocalIDData);
00361     }
00362     ERROR_HANDLER_RETURN("GetDNSEntry",0)
00363 }
00364 
00365 void CDNSManager::OnDNSReceive(CDNSAnswers* pAnswers,
00366                                LPVOID lpLocalIDData)
00367 {
00368     try
00369     {
00370         
00371         delete pAnswers;
00372     }
00373     ERROR_HANDLER("OnDNSReceive")
00374 }
00375 
00376 void CDNSManager::OnDNSTimeout(const CDNSQuery& rQuery,
00377                                LPVOID lpLocalIDData)
00378 {
00379 }
00380 
00381 void CDNSManager::OnDNSError(int iErrorCode,
00382                              LPVOID lpLocalIDData)
00383 {
00384 }
00385 
00386 void CDNSManager::ParseMultithreaded(int iThreadNumber)
00387 {
00388     try
00389     {
00390         
00391         m_iParseThreads=iThreadNumber;
00392 
00393         
00394         m_pSocket->SetMultithreaded(m_iParseThreads);
00395     }
00396     ERROR_HANDLER("ParseMultithreaded")
00397 }
00398 
00399 BOOL CDNSManager::Initialize()
00400 {
00401     try
00402     {
00403         if (m_bInitialized)
00404             return TRUE;
00405         else
00406         {
00407             
00408             m_bInitialized=m_pSocket->Initialize();
00409 
00410             
00411             return m_bInitialized;
00412         }
00413     }
00414     ERROR_HANDLER_RETURN("Initialize",FALSE)
00415 }
00416 
00417 BOOL CDNSManager::IsInitialized()const
00418 {
00419     return m_bInitialized;
00420 }
00421 
00422 void CDNSManager::Uninitialize()
00423 {
00424     try
00425     {
00426         if (!m_bInitialized)
00427             return;
00428         else
00429         {
00430             
00431             m_bInitialized=FALSE;
00432             
00433             
00434             delete m_pSocket;
00435             m_pSocket=NULL;
00436         }
00437     }
00438     ERROR_HANDLER("Uninitialize")
00439 }
00440 
00441 void CDNSManager::OnDNSReceive(CDNSAnswers* pAnswers)
00442 {
00443     try
00444     {
00445         
00446         std::auto_ptr<CDNSAnswers> pProtection(pAnswers);
00447 
00448         
00449         unsigned short usID;
00450         usID=pAnswers->GetDNSID();
00451 
00452         
00453         if (WasTimedout(usID))
00454             return;
00455 
00456         
00457         LPVOID lpData;
00458         lpData=GetRecord(usID);
00459 
00460         
00461         pProtection.release();
00462 
00463         
00464         OnDNSReceive(pAnswers,lpData);
00465 
00466         
00467         RemoveRecord(usID);
00468     }
00469     ERROR_HANDLER("OnDNSReceive")
00470 }
00471 
00472 CDNSAnswers* CDNSManager::Receive()
00473 {
00474     try
00475     {
00476         
00477         if (!m_bInitialized)
00478         {
00479             
00480             ReportError("Receive","Not initialized!");
00481             
00482             
00483             return NULL;
00484         }
00485 
00486         
00487         return m_pSocket->Receive();
00488     }
00489     ERROR_HANDLER_RETURN("Receive",NULL)
00490 }
00491 
00492 void CDNSManager::AddRecord(unsigned short usID,
00493                             LPVOID lpLocalIDData,
00494                             DWORD dwTimeout,
00495                             const CDNSQuery& rQuery)
00496 {
00497     try
00498     {
00499         
00500         DNSData aData;
00501         aData.aQuery=rQuery;
00502         aData.dwTime=GetTickCount();
00503         aData.lpData=lpLocalIDData;
00504         aData.dwTimeout=dwTimeout;
00505         aData.bTimedout=FALSE;
00506         aData.bNotified=FALSE;
00507         aData.dwTimedoutTime=0;
00508 
00509         
00510         CCriticalAutoRelease aRelease(m_pCSection);
00511 
00512         
00513         m_aData.erase(usID);
00514 
00515         
00516         m_aData.insert(IDMap::value_type(usID,aData));
00517     }
00518     ERROR_HANDLER("AddRecord")
00519 }
00520 
00521 BOOL CDNSManager::WasTimedout(unsigned short usID)
00522 {
00523     try
00524     {
00525         
00526         CCriticalAutoRelease aRelease(m_pCSection);
00527 
00528         
00529         IDMap::iterator aIterator;
00530         aIterator=m_aData.find(usID);
00531 
00532         
00533         if (aIterator!=m_aData.end())
00534             if (aIterator->second.bTimedout)
00535                 return TRUE;
00536             else
00537             {
00538                 
00539                 aIterator->second.bNotified=TRUE;
00540 
00541                 
00542                 return FALSE;
00543             }
00544         else
00545             return FALSE;
00546     }
00547     ERROR_HANDLER_RETURN("WasTimedout",TRUE)
00548 }
00549 
00550 LPVOID CDNSManager::GetRecord(unsigned short usID)const
00551 {
00552     try
00553     {
00554         
00555         CCriticalAutoRelease aRelease(m_pCSection);
00556 
00557         
00558         IDMap::const_iterator aIterator;
00559         aIterator=m_aData.find(usID);
00560 
00561         
00562         if (aIterator!=m_aData.end())
00563             return aIterator->second.lpData;
00564         else
00565             return NULL;
00566     }
00567     ERROR_HANDLER_RETURN("GetRecord",NULL)
00568 }
00569 
00570 void CDNSManager::RemoveRecord(unsigned short usID)
00571 {
00572     try
00573     {
00574         
00575         CCriticalAutoRelease aRelease(m_pCSection);
00576 
00577         
00578         m_aData.erase(usID);
00579     }
00580     ERROR_HANDLER("RemoveRecord")
00581 }
00582 
00583 void CDNSManager::SetDNSTimeout(DWORD dwMS)
00584 {
00585     try
00586     {
00587         
00588         if (!m_bInitialized)
00589         {
00590             
00591             ReportError("SetTimeout","Not initialized!");
00592             
00593             
00594             return;
00595         }
00596 
00597         
00598         m_dwTimeout=dwMS;
00599     }
00600     ERROR_HANDLER("SetDNSTimeout")
00601 }
00602 
00603 BOOL CDNSManager::SetConnectionTimeout(DWORD dwMS)
00604 {
00605     try
00606     {
00607         
00608         if (!m_bInitialized)
00609         {
00610             
00611             ReportError("SetTimeout","Not initialized!");
00612             
00613             
00614             return NULL;
00615         }
00616 
00617         
00618         return m_pSocket->SetConnectionTimeout(dwMS);
00619     }
00620     ERROR_HANDLER_RETURN("SetTimeout",FALSE)
00621 }
00622 
00623 BOOL CDNSManager::ThreadProc(CPeriodicThread::ThreadStage aStage,
00624                              LPVOID lpData)
00625 {
00626     try
00627     {
00628         
00629         if (aStage!=CPeriodicThread::tsBody)
00630             return TRUE;
00631 
00632         
00633         CDNSManager* pClass;
00634         pClass=(CDNSManager*)lpData;
00635 
00636         
00637         CCriticalAutoRelease aRelease(pClass->m_pCSection);
00638 
00639         
00640         IDMap::iterator aIterator;
00641         aIterator=pClass->m_aData.begin();
00642 
00643         
00644         typedef std::vector<DNSData> TimeoutVector;
00645         TimeoutVector aTimeoutVector;
00646 
00647         
00648         while (aIterator!=pClass->m_aData.end())
00649         {
00650             
00651             if (aIterator->second.dwTimeout && 
00652                 GetTickCount()-aIterator->second.dwTime>aIterator->second.dwTimeout &&
00653                 !aIterator->second.bNotified &&
00654                 !aIterator->second.bTimedout)
00655             {
00656                 
00657                 
00658                 aIterator->second.bTimedout=TRUE;
00659                 aIterator->second.dwTimedoutTime=GetTickCount();
00660 
00661                 
00662                 aTimeoutVector.push_back(aIterator->second);
00663 
00664                 
00665                 ++aIterator;    
00666             }
00667             else if (GetTickCount()-aIterator->second.dwTimedoutTime>DELETE_TIME &&
00668                      aIterator->second.bTimedout)
00669             {
00670                 
00671                 IDMap::iterator aBackupIterator=aIterator;
00672                 ++aIterator;
00673 
00674                 
00675                 pClass->m_aData.erase(aBackupIterator);
00676             }
00677             else
00678                 
00679                 ++aIterator;
00680         }
00681 
00682         
00683         aRelease.Exit();
00684 
00685         
00686         TimeoutVector::const_iterator aTimeoutIterator;
00687         aTimeoutIterator=aTimeoutVector.begin();
00688 
00689         
00690         while (aTimeoutIterator!=aTimeoutVector.end())
00691         {
00692             
00693             pClass->OnDNSTimeout(aTimeoutIterator->aQuery,
00694                                  aTimeoutIterator->lpData);
00695 
00696             
00697             ++aTimeoutIterator;
00698         }
00699 
00700         
00701         return TRUE;
00702     }
00703     ERROR_HANDLER_STATIC_RETURN(CDNSManager_Class,"ThreadProc",THREAD_DO_NOTHING_EXIT_VALUE)
00704 }
00705 
00706 void CDNSManager::ParseMultithreaded(CThreadPool* pThreadManager)
00707 {
00708     try
00709     {
00710         
00711         m_pSocket->SetMultithreaded(pThreadManager);
00712     }
00713     ERROR_HANDLER("ParseMultithreaded")
00714 }
00715 
00716 
00717 }