00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036 #pragma warning(disable:4786)
00037
00038 #include <winsock2.h>
00039 #include <windows.h>
00040
00041 #ifdef _DEBUG
00042 #include "DebugHeap.h"
00043 #endif
00044
00045 #include "TCPSocketAsyncSSL.h"
00046
00047 #include "GenericCriticalSection.h"
00048 #include "OSManager.h"
00049
00050 #include <openssl\ssl.h>
00051 #include <openssl\err.h>
00052
00053 KOMODIA_NAMESPACE_START
00054
00055
00056 typedef struct SSLData
00057 {
00058 SSL* m_pConnection;
00059 BIO* m_pIn;
00060 BIO* m_pOut;
00061 SSL_CTX* m_pContext;
00062 } _SSLData;
00063
00064 CTCPSocketAsyncSSL::CTCPSocketAsyncSSL(bool bDisableSSL) : m_bDisableSSL(bDisableSSL)
00065 {
00066
00067 m_bVerify=false;
00068 m_bSend=false;
00069 m_pFather=NULL;
00070 m_bSession=false;
00071 m_pData=NULL;
00072 m_pCS=NULL;
00073 m_pCSWrite=NULL;
00074 m_bEvent=false;
00075 m_bEventSpawned=false;
00076 m_bDontQueue=false;
00077 }
00078
00079 CTCPSocketAsyncSSL::~CTCPSocketAsyncSSL()
00080 {
00081
00082 if (m_pData)
00083 {
00084
00085 SSL_free(m_pData->m_pConnection);
00086
00087
00088 if (!m_pFather)
00089
00090 SSL_CTX_free(m_pData->m_pContext);
00091 }
00092
00093
00094 delete m_pData;
00095
00096
00097 delete m_pCS;
00098 delete m_pCSWrite;
00099 }
00100
00101 void CTCPSocketAsyncSSL::CreateSSLData()
00102 {
00103
00104 m_pCS=COSManager::CreateCriticalSection();
00105 m_pCSWrite=COSManager::CreateCriticalSection();
00106
00107
00108 m_pData=new SSLData;
00109
00110
00111
00112 if (!m_pFather)
00113
00114 m_pData->m_pContext=SSL_CTX_new(SSLv3_method());
00115 else
00116
00117 m_pData->m_pContext=m_pFather->m_pData->m_pContext;
00118
00119
00120 m_pData->m_pConnection=SSL_new(m_pData->m_pContext);
00121
00122
00123 m_pData->m_pIn=BIO_new(BIO_s_mem());
00124 m_pData->m_pOut=BIO_new(BIO_s_mem());
00125
00126
00127 SSL_set_bio(m_pData->m_pConnection,
00128 m_pData->m_pIn,
00129 m_pData->m_pOut);
00130 }
00131
00132 void CTCPSocketAsyncSSL::InitializeSSL()
00133 {
00134
00135 SSL_load_error_strings();
00136 SSL_library_init();
00137 OpenSSL_add_all_algorithms();
00138 }
00139
00140 void CTCPSocketAsyncSSL::UninitializeSSL()
00141 {
00142
00143 ERR_free_strings();
00144 }
00145
00146 int CTCPSocketAsyncSSL::LocalSend(const char* pBuffer,
00147 unsigned long ulBufferLength,
00148 bool bDontDeque,
00149 bool bDontFlush)
00150 {
00151
00152 m_bSend=true;
00153
00154
00155 m_bCanSend=false;
00156
00157
00158 int iResult;
00159
00160 {
00161 {
00162
00163 CCriticalAutoRelease aRelease(m_pCS);
00164
00165
00166 iResult=SSL_write(m_pData->m_pConnection,
00167 pBuffer,
00168 ulBufferLength);
00169 }
00170
00171
00172 bool bDisc;
00173 bool bRead;
00174
00175
00176 if (iResult<=0 &&
00177 !HandleSSLError("Send",
00178 "Failure calling SSL_write",
00179 iResult,
00180 bDisc,
00181 bRead))
00182 {
00183
00184 ::SetLastError(2);
00185
00186
00187 if (bDisc)
00188 {
00189
00190 Close();
00191
00192
00193 if (m_bEventSpawned)
00194
00195 SocketConnected(-1,
00196 false);
00197 else
00198
00199 SocketClosed(-1,
00200 false);
00201 }
00202
00203
00204 return -1;
00205 }
00206 else if (iResult<=0 &&
00207 !bDontDeque)
00208 {
00209
00210 {
00211
00212 CCriticalAutoRelease aRelease(m_pCSWrite);
00213
00214
00215 AddDataToDeque(m_aPendingDataWrite,
00216 pBuffer,
00217 ulBufferLength);
00218 }
00219
00220 if (!bDontFlush)
00221
00222 FlushSSLWrite();
00223
00224
00225 return 0;
00226 }
00227 else if (iResult<=0 &&
00228 bDontDeque)
00229 {
00230
00231 if (!bDontFlush)
00232
00233 FlushSSLWrite();
00234
00235
00236 return -1;
00237 }
00238 }
00239
00240
00241 if (!bDontDeque ||
00242 iResult>0)
00243
00244 return FlushSSLWrite();
00245 else
00246 return 1;
00247 }
00248
00249 int CTCPSocketAsyncSSL::Send(const char* pBuffer,
00250 unsigned long ulBufferLength)
00251 {
00252
00253 if (m_bDisableSSL)
00254 return SSLBaseClass::Send(pBuffer,
00255 ulBufferLength);
00256 else
00257 return LocalSend(pBuffer,
00258 ulBufferLength,
00259 m_bDontQueue,
00260 false);
00261 }
00262
00263 bool CTCPSocketAsyncSSL::HandleSSLError(const std::string& rMethod,
00264 const std::string& rMessage,
00265 int iResult,
00266 bool& rKeepAlive,
00267 bool& rSSLWantRead)const
00268 {
00269
00270 rKeepAlive=true;
00271
00272
00273 rSSLWantRead=false;
00274
00275
00276 std::string sSSLError;
00277
00278
00279 if (iResult<=0)
00280 {
00281
00282 int iError;
00283 iError=SSL_get_error(m_pData->m_pConnection,
00284 iResult);
00285
00286
00287
00288 if (iError==SSL_ERROR_WANT_READ)
00289
00290 rSSLWantRead=true;
00291
00292
00293 if (iError==SSL_ERROR_ZERO_RETURN ||
00294 iError==SSL_ERROR_NONE ||
00295 iError==SSL_ERROR_WANT_READ)
00296 return true;
00297
00298
00299 int iCounter;
00300 iCounter=1;
00301
00302
00303 while (iError)
00304 {
00305
00306 char aTmp[256];
00307 ERR_error_string_n(iError,
00308 aTmp,
00309 sizeof(aTmp));
00310
00311
00312 char aTmp2[512];
00313 sprintf(aTmp2,"SSL error number: %i, error code: %i error string: %s",iCounter,iError,aTmp);
00314
00315
00316 if (!sSSLError.empty())
00317 sSSLError+=',';
00318
00319
00320 sSSLError+=aTmp2;
00321
00322
00323 iError=ERR_get_error();
00324 }
00325
00326
00327 ReportError(rMethod,
00328 rMessage,
00329 sSSLError);
00330
00331
00332 if (m_bEventSpawned)
00333
00334 rKeepAlive=OnSSLError(rMethod,
00335 rMessage,
00336 sSSLError);
00337 else
00338
00339 rKeepAlive=false;
00340 }
00341 else
00342 return true;
00343
00344
00345 return false;
00346 }
00347
00348 int CTCPSocketAsyncSSL::GetDataFromDeque(DataDeque& rDeque,
00349 char* pData,
00350 int iDataSize)const
00351 {
00352
00353 if (rDeque.empty() ||
00354 !iDataSize)
00355 return 0;
00356
00357
00358 int iSize;
00359 if (iDataSize>rDeque.size())
00360 iSize=rDeque.size();
00361 else
00362 iSize=iDataSize;
00363
00364
00365 for (int iCounter=0;
00366 iCounter<iSize;
00367 ++iCounter)
00368 {
00369
00370 pData[iCounter]=rDeque.back();
00371 rDeque.pop_back();
00372 }
00373
00374
00375 return iCounter;
00376 }
00377
00378 void CTCPSocketAsyncSSL::AddDataToDeque(DataDeque& rDeque,
00379 const char* pData,
00380 int iDataSize,
00381 bool bBack)const
00382 {
00383
00384 if (bBack)
00385 for (int iCounter=iDataSize;
00386 iCounter;
00387 --iCounter)
00388
00389 rDeque.push_back(pData[iCounter-1]);
00390 else
00391 for (int iCounter=0;
00392 iCounter<iDataSize;
00393 ++iCounter)
00394
00395 rDeque.push_front(pData[iCounter]);
00396 }
00397
00398 int CTCPSocketAsyncSSL::FlushData()
00399 {
00400
00401 FlushWrite();
00402 return FlushSSLWrite();
00403 }
00404
00405 int CTCPSocketAsyncSSL::FlushWrite()
00406 {
00407
00408 if (!m_aPendingDataWrite.empty())
00409 {
00410
00411 char aChunk[2048];
00412 int iReceive;
00413
00414 {
00415
00416 CCriticalAutoRelease aRelease(m_pCSWrite);
00417
00418
00419 iReceive=GetDataFromDeque(m_aPendingDataWrite,
00420 aChunk,
00421 sizeof(aChunk));
00422 }
00423
00424
00425 if (iReceive)
00426 {
00427
00428 int iResult;
00429 iResult=LocalSend(aChunk,
00430 iReceive,
00431 true,
00432 true);
00433
00434
00435 if (iResult>0)
00436
00437 return iResult;
00438 else
00439 {
00440
00441 CCriticalAutoRelease aRelease(m_pCSWrite);
00442
00443
00444 AddDataToDeque(m_aPendingDataWrite,
00445 aChunk,
00446 iReceive,
00447 true);
00448
00449
00450 return -1;
00451 }
00452 }
00453 }
00454
00455
00456 return 0;
00457 }
00458
00459 int CTCPSocketAsyncSSL::FlushSSLWrite()
00460 {
00461
00462 int iDataSent;
00463 iDataSent=0;
00464
00465
00466 while (!m_aDataToWrite.empty())
00467 {
00468
00469
00470 char aChunk[2048];
00471 int iBytesToSend;
00472 iBytesToSend=GetDataFromDeque(m_aDataToWrite,
00473 aChunk,
00474 sizeof(aChunk));
00475
00476
00477 int iResult;
00478 iResult=SSLBaseClass::Send(aChunk,
00479 iBytesToSend);
00480
00481
00482 if (iResult<=0)
00483 {
00484
00485 AddDataToDeque(m_aDataToWrite,
00486 aChunk,
00487 iBytesToSend,
00488 true);
00489
00490
00491 return iResult;
00492 }
00493 else
00494
00495 iDataSent+=iResult;
00496 }
00497
00498
00499 if (BIO_ctrl_pending(m_pData->m_pOut))
00500 {
00501
00502 CCriticalAutoRelease aRelease(m_pCS);
00503
00504
00505 int iPending;
00506 while ((iPending=BIO_ctrl_pending(m_pData->m_pOut))>0)
00507 {
00508
00509 char aChunk[1024];
00510 int iBytesToSend;
00511 iBytesToSend=BIO_read(m_pData->m_pOut,
00512 (void*)aChunk,
00513 sizeof(aChunk));
00514
00515
00516 bool bDisc;
00517 bool bRead;
00518
00519
00520 if (iBytesToSend>0)
00521
00522 AddDataToDeque(m_aDataToWrite,
00523 aChunk,
00524 iBytesToSend);
00525 else if (!BIO_should_retry(m_pData->m_pOut) &&
00526 !HandleSSLError("FlushData",
00527 "Failed to call BIO_read",
00528 iBytesToSend,
00529 bDisc,
00530 bRead))
00531 {
00532
00533 ::SetLastError(2);
00534
00535
00536 if (bDisc)
00537 {
00538
00539 aRelease.Exit();
00540
00541
00542 Close();
00543
00544
00545 if (m_bEventSpawned)
00546
00547 SocketConnected(-1,
00548 false);
00549 else
00550
00551 SocketClosed(-1,
00552 false);
00553 }
00554
00555
00556 return -1;
00557 }
00558 }
00559
00560
00561 aRelease.Exit();
00562
00563
00564 return FlushData()+iDataSent;
00565 }
00566
00567
00568 return iDataSent;
00569 }
00570
// Reads ciphertext from the socket, runs it through the SSL engine and returns
// decrypted bytes.
//
// @param pBuffer        destination for plaintext (may be NULL when ulBufferLength is 0)
// @param ulBufferLength capacity of pBuffer
// @param bFlush         when SSL_read has drained its output, call FlushData()
// @return with SSL disabled: the base-class Receive result. Otherwise:
//         bytes copied from m_aDataToRead when decrypted data is queued;
//         -1 on a fatal SSL_read error;
//         -3 when no plaintext is available but the raw read filled the whole
//            2048-byte buffer (more data may be waiting);
//         else the raw socket result (<= 0), or -2 when ciphertext was processed
//         but produced no plaintext.
int CTCPSocketAsyncSSL::LocalReceive(char* pBuffer,
                                     unsigned long ulBufferLength,
                                     bool bFlush)
{
	if (m_bDisableSSL)
		return SSLBaseClass::Receive(pBuffer,
		                             ulBufferLength);

	// Raw ciphertext from the socket.
	char aChunk[2048];
	int iReceive;
	iReceive=SSLBaseClass::Receive(aChunk,
	                               sizeof(aChunk));

	bool bReadAgain;

	// A completely filled buffer suggests more data is waiting.
	// NOTE(review): int vs. size_t comparison - a negative iReceive converts to a
	// huge unsigned value, so a failed raw read also sets bReadAgain=true. Confirm
	// whether that is intended.
	if (iReceive<sizeof(aChunk))
		bReadAgain=false;
	else
		bReadAgain=true;

	if (iReceive>0)
	{
		{
			CCriticalAutoRelease aRelease(m_pCS);

			// Hand the ciphertext to the SSL engine's input memory BIO.
			BIO_write(m_pData->m_pIn,
			          aChunk,
			          iReceive);
		}

		m_bCanSend=true;

		// First time the handshake is seen as finished (checked before this batch is
		// processed by SSL_read below): record it and raise the SSL event, reporting
		// a failed certificate verification when LoadCertificateStore() enabled it.
		if (!m_bSession &&
		    SSL_is_init_finished(m_pData->m_pConnection))
		{
			m_bSession=true;

			if (m_bVerify)
			{
				m_bVerify=false;

				if (SSL_get_verify_result(m_pData->m_pConnection)!=X509_V_OK)
					OnSSLEvent(seBadCertificate);
				else
					OnSSLEvent(seHandshakeFinished);
			}
			else
				OnSSLEvent(seHandshakeFinished);
		}

		// Decrypt everything available into m_aDataToRead. The loop ends with
		// iReceive=-2 once SSL_read yields nothing more (non-fatal).
		do
		{
			{
				CCriticalAutoRelease aRelease(m_pCS);

				iReceive=SSL_read(m_pData->m_pConnection,
				                  (void*)aChunk,
				                  sizeof(aChunk));
			}

			bool bDisc;
			bool bRead;

			if (!HandleSSLError("Receive",
			                    "Failed calling SSL_read",
			                    iReceive,
			                    bDisc,
			                    bRead))
			{
				// Fatal SSL error: optionally drop the connection and notify.
				::SetLastError(2);

				if (bDisc)
				{
					Close();

					if (m_bEventSpawned)
						SocketConnected(-1,
						                false);
					else
						SocketClosed(-1,
						             false);
				}

				return -1;
			}

			if (iReceive>0)
				AddDataToDeque(m_aDataToRead,
				               aChunk,
				               iReceive);
			else
				iReceive=-2;
		} while (iReceive>0);
	}

	if (m_aDataToRead.empty())
		if (iReceive<=0)
		{
			// SSL produced no plaintext: send any handshake output if requested or if
			// plaintext writes are pending. During the handshake (before the connected
			// event), raise a write notification instead.
			if ((bFlush || !m_aPendingDataWrite.empty()) &&
			    iReceive==-2)
				FlushData();
			else if (!bFlush &&
			         iReceive==-2 &&
			         m_bEventSpawned &&
			         !m_bSession)
				SocketWrite(0);

			if (bReadAgain)
				return -3;
			else
				return iReceive;
		}
		// NOTE(review): apparently unreachable - after the block above iReceive is
		// either the raw result (<= 0) or -2.
		else
			return 0;
	else
		// Hand queued plaintext to the caller.
		return GetDataFromDeque(m_aDataToRead,
		                        pBuffer,
		                        ulBufferLength);
}
00728
00729 int CTCPSocketAsyncSSL::Receive(char* pBuffer,
00730 unsigned long ulBufferLength)
00731 {
00732 return LocalReceive(pBuffer,
00733 ulBufferLength,
00734 !m_bEventSpawned);
00735 }
00736
00737 int CTCPSocketAsyncSSL::Peek(char* pBuffer,
00738 unsigned long ulBufferLength)
00739 {
00740
00741 if (m_bDisableSSL)
00742 return SSLBaseClass::Peek(pBuffer,
00743 ulBufferLength);
00744
00745
00746 int iBytes=0;
00747
00748
00749 for (int iCount=0;
00750 iCount<ulBufferLength &&
00751 m_aDataToRead.size()>iCount;
00752 ++iCount)
00753 {
00754
00755 pBuffer[iCount]=m_aDataToRead[m_aDataToRead.size()-iCount-1];
00756
00757
00758 iBytes++;
00759 }
00760
00761
00762 return iBytes;
00763 }
00764
00765 BOOL CTCPSocketAsyncSSL::Connect(unsigned short usSourcePort,
00766 IP aDestinationAddress,
00767 unsigned short usDestinationPort,
00768 BOOL bDisableAsync,
00769 BOOL bForceErrorEvent)
00770 {
00771
00772 if (m_bDisableSSL)
00773 return SSLBaseClass::Connect(usSourcePort,
00774 aDestinationAddress,
00775 usDestinationPort,
00776 bDisableAsync,
00777 bForceErrorEvent);
00778
00779
00780 SSL_set_connect_state(m_pData->m_pConnection);
00781
00782
00783 if (SSLBaseClass::Connect(usSourcePort,
00784 aDestinationAddress,
00785 usDestinationPort,
00786 bDisableAsync,
00787 bForceErrorEvent))
00788 {
00789
00790 if (bDisableAsync ||
00791 IsBlocking())
00792
00793 StartSSLHandshake();
00794
00795
00796 return true;
00797 }
00798 else
00799 return false;
00800 }
00801
00802 BOOL CTCPSocketAsyncSSL::Connect(IP aDestinationAddress,
00803 unsigned short usDestinationPort,
00804 BOOL bDisableAsync,
00805 BOOL bForceErrorEvent)
00806 {
00807
00808 return Connect(0,
00809 aDestinationAddress,
00810 usDestinationPort,
00811 bDisableAsync,
00812 bForceErrorEvent);
00813 }
00814
00815 BOOL CTCPSocketAsyncSSL::Connect(const std::string& rDestinationAddress,
00816 unsigned short usDestinationPort,
00817 BOOL bDisableAsync,
00818 BOOL bForceErrorEvent)
00819 {
00820
00821 return Connect(0,
00822 rDestinationAddress,
00823 usDestinationPort,
00824 bDisableAsync,
00825 bForceErrorEvent);
00826 }
00827
00828 BOOL CTCPSocketAsyncSSL::Connect(unsigned short usSourcePort,
00829 const std::string& rDestinationAddress,
00830 unsigned short usDestinationPort,
00831 BOOL bDisableAsync,
00832 BOOL bForceErrorEvent)
00833 {
00834
00835 return Connect(usSourcePort,
00836 StringToLong(rDestinationAddress),
00837 usDestinationPort,
00838 bDisableAsync,
00839 bForceErrorEvent);
00840 }
00841
00842 bool CTCPSocketAsyncSSL::OnSSLError(const std::string& rMethod,
00843 const std::string& rMessage,
00844 const std::string& rSSLError)const
00845 {
00846
00847 return true;
00848 }
00849
00850 bool CTCPSocketAsyncSSL::LoadCertificateStore(const std::string& rPEMPath)
00851 {
00852 return (m_bVerify=SSL_CTX_load_verify_locations(m_pData->m_pContext,
00853 rPEMPath.c_str(),
00854 NULL));
00855 }
00856
00857 void CTCPSocketAsyncSSL::OnSSLEvent(SSLEvents)const
00858 {
00859
00860 }
00861
00862 bool CTCPSocketAsyncSSL::LoadCertificatesForServer(const std::string& rPublicKey,
00863 const std::string& rPrivateKey)
00864 {
00865
00866 SSL_CTX_use_certificate_file(m_pData->m_pContext,
00867 rPublicKey.c_str(),
00868 SSL_FILETYPE_PEM);
00869
00870
00871 SSL_CTX_use_PrivateKey_file(m_pData->m_pContext,
00872 rPrivateKey.c_str(),
00873 SSL_FILETYPE_PEM);
00874
00875
00876 return (m_bServerCert=SSL_CTX_check_private_key(m_pData->m_pContext));
00877 }
00878
00879 void CTCPSocketAsyncSSL::StartSSLHandshake()
00880 {
00881
00882 if (!m_bSend)
00883 {
00884
00885 LocalSend("Bogus",
00886 5,
00887 true,
00888 true);
00889
00890
00891 m_bSend=true;
00892
00893
00894 FlushSSLWrite();
00895 }
00896 }
00897
00898 BOOL CTCPSocketAsyncSSL::Accept(CTCPSocket* pNewSocket)
00899 {
00900
00901 if (!m_bDisableSSL)
00902
00903 ((CTCPSocketAsyncSSL*)pNewSocket)->SetFather(this);
00904
00905
00906 return SSLBaseClass::Accept(pNewSocket);
00907 }
00908
00909 void CTCPSocketAsyncSSL::SetFather(CTCPSocketAsyncSSL* pFather)
00910 {
00911
00912 m_pFather=pFather;
00913
00914
00915 CreateSSLData();
00916
00917
00918 SSL_set_accept_state(m_pData->m_pConnection);
00919 }
00920
00921 BOOL CTCPSocketAsyncSSL::Create()
00922 {
00923
00924 if (!SSLBaseClass::Create())
00925 return false;
00926 else if (!m_bDisableSSL)
00927 {
00928
00929 CreateSSLData();
00930
00931
00932 return true;
00933 }
00934 else
00935
00936 return true;
00937 }
00938
00939 BOOL CTCPSocketAsyncSSL::SocketConnected(int iErrorCode,
00940 BOOL bNoEvent)
00941 {
00942
00943 m_bEvent=bNoEvent;
00944
00945
00946 if (!m_bDisableSSL)
00947 {
00948
00949 SSLBaseClass::SocketConnected(iErrorCode,
00950 !iErrorCode);
00951
00952
00953 if (!iErrorCode)
00954
00955 StartSSLHandshake();
00956
00957
00958 return TRUE;
00959 }
00960 else
00961
00962
00963 return SSLBaseClass::SocketConnected(iErrorCode,
00964 bNoEvent);
00965 }
00966
// Socket "data available" notification.
//
// The notification goes straight to the base class in these cases: the SSL
// session is established, there is a socket error, SSL is disabled, or the
// deferred connected event has already been raised. Otherwise the handshake is
// in progress: the incoming data is fed to the SSL engine here, and the base
// class is only reached through the recursive call below.
BOOL CTCPSocketAsyncSSL::SocketReceive(int iErrorCode,
                                       BOOL bNoEvent)
{
	if (m_bSession ||
	    iErrorCode ||
	    m_bDisableSSL ||
	    m_bEventSpawned)
		return SSLBaseClass::SocketReceive(iErrorCode,
		                                   bNoEvent);
	// NOTE(review): when reached, this condition is always true (the branch above
	// already handled m_bSession and iErrorCode), so the final "return TRUE" at the
	// end of the function is unreachable.
	else if (!m_bSession &&
	         !iErrorCode)
	{
		// Feed the received handshake data into the SSL engine. Nothing is copied
		// out (NULL buffer, zero length) and no flush is requested from inside.
		int iReceive;
		iReceive=LocalReceive(NULL,
		                      0,
		                      false);

		// Send any handshake output. Raise the deferred connected event now (with
		// the bNoEvent flag remembered in m_bEvent by SocketConnected()) when all of
		// these hold: nothing was flushed, the session flag is still unset, the event
		// has not been raised yet, and LocalReceive did not ask for another read (-3).
		if (!FlushData() &&
		    !m_bSession &&
		    !m_bEventSpawned &&
		    iReceive!=-3)
		{
			m_bEventSpawned=true;

			SSLBaseClass::SocketConnected(0,
			                              m_bEvent);
		}

		// The session became established during this receive: re-dispatch, which
		// now takes the base-class path above.
		if (iReceive>=0 &&
		    m_bSession)
			SocketReceive(0,
			              bNoEvent);

		return TRUE;
	}

	return TRUE;
}
01016
01017 void CTCPSocketAsyncSSL::DisableSSL()
01018 {
01019 m_bDisableSSL=true;
01020 }
01021
01022 bool CTCPSocketAsyncSSL::IsHandshakeComplete()const
01023 {
01024 return m_bSession;
01025 }
01026
01027 void CTCPSocketAsyncSSL::DontQueue()
01028 {
01029 m_bDontQueue=true;
01030 }
01031
01032 bool CTCPSocketAsyncSSL::IsReadyForSend()const
01033 {
01034 return m_bSession ||
01035 m_bCanSend ||
01036 m_bDisableSSL;
01037 }
01038
01039 BOOL CTCPSocketAsyncSSL::SocketWrite(int iErrorCode)
01040 {
01041
01042 if (m_bEventSpawned)
01043 return SSLBaseClass::SocketWrite(iErrorCode);
01044 else
01045 return TRUE;
01046 }
01047
01048 KOMODIA_NAMESPACE_END