00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017 #include "instlsp.h"
00018
00019
00020
00021
00022
00023
00024
00025
00026 BOOL
00027 RemoveIdFromChain(
00028 WSAPROTOCOL_INFOW *pInfo,
00029 DWORD dwCatalogId
00030 )
00031 {
00032 int i,
00033 j;
00034
00035 for(i=0; i < pInfo->ProtocolChain.ChainLen ;i++)
00036 {
00037 if ( pInfo->ProtocolChain.ChainEntries[ i ] == dwCatalogId )
00038 {
00039 for(j=i; j < pInfo->ProtocolChain.ChainLen-1 ; j++)
00040 {
00041 pInfo->ProtocolChain.ChainEntries[ j ] =
00042 pInfo->ProtocolChain.ChainEntries[ j+1 ];
00043 }
00044 pInfo->ProtocolChain.ChainLen--;
00045 return TRUE;
00046 }
00047 }
00048 return FALSE;
00049 }
00050
00051
00052
00053
00054
00055
00056
00057
00058 BOOL
00059 IsIdInChain(
00060 WSAPROTOCOL_INFOW *pInfo,
00061 DWORD dwId)
00062 {
00063 int i;
00064
00065 for(i=0; i < pInfo->ProtocolChain.ChainLen ;i++)
00066 {
00067 if ( pInfo->ProtocolChain.ChainEntries[ i ] == dwId )
00068 return TRUE;
00069 }
00070 return FALSE;
00071 }
00072
00073
00074
00075
00076
00077
00078
00079
00080
00081
00082 int
00083 GetProviderCount(
00084 WSAPROTOCOL_INFOW *pProviders,
00085 int iProviderCount,
00086 int iProviderType
00087 )
00088 {
00089 int Count, i;
00090
00091 Count = 0;
00092 for(i=0; i < iProviderCount ;i++)
00093 {
00094 if ( ( LAYERED_CHAIN == iProviderType ) && ( pProviders[ i ].ProtocolChain.ChainLen > 1 ) )
00095 Count++;
00096 else if ( ( LAYERED_CHAIN != iProviderType) && ( pProviders[ i ].ProtocolChain.ChainLen == iProviderType ) )
00097 Count++;
00098 }
00099 return Count;
00100 }
00101
00102
00103
00104
00105
00106
00107
00108
00109
00110
00111
00112
00113
00114
00115
00116
00117 int
00118 GetLayeredEntriesByGuid(
00119 WSAPROTOCOL_INFOW *pMatchLayers,
00120 int *iLayeredCount,
00121 WSAPROTOCOL_INFOW *pEntries,
00122 int iEntryCount,
00123 GUID *MatchGuid
00124 )
00125 {
00126 int count,
00127 err = SOCKET_ERROR,
00128 i;
00129
00130
00131 count = 0;
00132 for(i=0; i < iEntryCount ;i++)
00133 {
00134 if ( 0 == memcmp( MatchGuid, &pEntries[i].ProviderId, sizeof( GUID ) ) )
00135 count++;
00136 }
00137
00138
00139 if ( count > *iLayeredCount )
00140 {
00141 *iLayeredCount = count;
00142 goto cleanup;
00143 }
00144
00145
00146 count = 0;
00147 for(i=0; i < iEntryCount ;i++)
00148 {
00149 if ( 0 == memcmp( MatchGuid, &pEntries[ i ].ProviderId, sizeof( GUID ) ) )
00150 {
00151 memcpy( &pMatchLayers[ count++ ], &pEntries[ i ], sizeof( WSAPROTOCOL_INFOW ) );
00152 }
00153 }
00154
00155 *iLayeredCount = count;
00156
00157 err = NO_ERROR;
00158
00159 cleanup:
00160
00161 return err;
00162 }
00163
00164
00165
00166
00167
00168
00169
00170
00171
00172
00173
00174
00175
00176 BOOL
00177 IsEqualProtocolEntries(
00178 WSAPROTOCOL_INFOW *pInfo1,
00179 WSAPROTOCOL_INFOW *pInfo2
00180 )
00181 {
00182 if ( (memcmp(&pInfo1->ProviderId, &pInfo2->ProviderId, sizeof(GUID)) == 0) &&
00183 (pInfo1->dwServiceFlags1 == pInfo2->dwServiceFlags1) &&
00184 (pInfo1->dwServiceFlags2 == pInfo2->dwServiceFlags2) &&
00185 (pInfo1->dwServiceFlags3 == pInfo2->dwServiceFlags3) &&
00186 (pInfo1->dwServiceFlags4 == pInfo2->dwServiceFlags4) &&
00187 (pInfo1->ProtocolChain.ChainLen == pInfo2->ProtocolChain.ChainLen) &&
00188 (pInfo1->iVersion == pInfo2->iVersion) &&
00189 (pInfo1->iAddressFamily == pInfo2->iAddressFamily) &&
00190 (pInfo1->iMaxSockAddr == pInfo2->iMaxSockAddr) &&
00191 (pInfo1->iMinSockAddr == pInfo2->iMinSockAddr) &&
00192 (pInfo1->iSocketType == pInfo2->iSocketType) &&
00193 (pInfo1->iProtocol == pInfo2->iProtocol) &&
00194 (pInfo1->iProtocolMaxOffset == pInfo2->iProtocolMaxOffset) &&
00195 (pInfo1->iNetworkByteOrder == pInfo2->iNetworkByteOrder) &&
00196 (pInfo1->iSecurityScheme == pInfo2->iSecurityScheme) &&
00197 (pInfo1->dwMessageSize == pInfo2->dwMessageSize)
00198 )
00199 {
00200 return TRUE;
00201 }
00202 else
00203 {
00204 return FALSE;
00205 }
00206 }
00207
00208
00209
00210
00211
00212
00213
00214
00215
00216
00217 int
00218 RetrieveLspGuid(
00219 __in_z char *LspPath,
00220 GUID *Guid
00221 )
00222 {
00223 HMODULE hMod = NULL;
00224 LPFN_GETLSPGUID fnGetLspGuid = NULL;
00225 int retval = SOCKET_ERROR;
00226
00227
00228 hMod = LoadLibraryA( LspPath );
00229 if ( NULL == hMod )
00230 {
00231 fprintf( stderr, "RetrieveLspGuid: LoadLibraryA failed: %d\n", GetLastError() );
00232 goto cleanup;
00233 }
00234
00235
00236 fnGetLspGuid = (LPFN_GETLSPGUID) GetProcAddress( hMod, "GetLspGuid" );
00237 if ( NULL == fnGetLspGuid )
00238 {
00239 fprintf( stderr, "RetrieveLspGuid: GetProcAddress failed: %d\n", GetLastError() );
00240 goto cleanup;
00241 }
00242
00243
00244 fnGetLspGuid( Guid );
00245
00246 retval = NO_ERROR;
00247
00248 cleanup:
00249
00250 if ( NULL != hMod )
00251 FreeLibrary( hMod );
00252
00253 return retval;
00254 }
00255
00256
00257 #pragma warning(push)
00258 #pragma warning(disable: 4127)
00259
00260
00261
00262
00263
00264
00265
00266
00267
00268
00269 BOOL
00270 IsNonIfsProvider(
00271 WSAPROTOCOL_INFOW *pProvider,
00272 int iProviderCount,
00273 DWORD dwProviderId
00274 )
00275 {
00276 int i;
00277
00278 for(i=0; i < iProviderCount ;i++)
00279 {
00280 if ( pProvider[ i ].dwCatalogEntryId == dwProviderId )
00281 {
00282 return !( pProvider[ i ].dwServiceFlags1 & XP1_IFS_HANDLES );
00283 }
00284 }
00285
00286 return FALSE;
00287 }
00288
00289 #pragma warning(pop)
00290
00291
00292
00293
00294
00295
00296
00297
00298
00299
00300
00301 HMODULE
00302 LoadUpdateProviderFunction()
00303 {
00304 HMODULE hModule = NULL;
00305 HRESULT hr;
00306 char WinsockLibraryPath[ MAX_PATH+1 ],
00307 szExpandPath[ MAX_PATH+1 ];
00308
00309
00310
00311
00312
00313
00314 if ( GetSystemDirectoryA( WinsockLibraryPath, MAX_PATH+1 ) == 0 )
00315 {
00316 hr = StringCchCopyA( szExpandPath, MAX_PATH+1, "%SYSTEMROOT%\\system32" );
00317 if ( FAILED( hr ) )
00318 {
00319 fprintf( stderr, "LoadUpdateProviderFunctions: StringCchCopyA failed: 0x%x\n", hr );
00320 goto cleanup;
00321 }
00322
00323 if ( ExpandEnvironmentStringsA( WinsockLibraryPath, szExpandPath, MAX_PATH+1 ) == 0 )
00324 {
00325 fprintf(stderr, "LoadUpdateProviderFunctions: Unable to expand environment string: %d\n",
00326 GetLastError()
00327 );
00328 goto cleanup;
00329 }
00330 }
00331
00332 hr = StringCchCatA( WinsockLibraryPath, MAX_PATH+1, WINSOCK_DLL );
00333 if ( FAILED( hr ) )
00334 {
00335 fprintf( stderr, "LoadUpdateProviderFunctions: StringCchCatA failed: 0x%x\n", hr );
00336 goto cleanup;
00337 }
00338
00339 hModule = LoadLibraryA( WinsockLibraryPath );
00340 if (hModule == NULL)
00341 {
00342 fprintf(stderr, "LoadUpdateProviderFunctions: Unable to load %s: %d\n",
00343 WinsockLibraryPath, GetLastError()
00344 );
00345 goto cleanup;
00346 }
00347 #ifdef _WIN64
00348 fnWscUpdateProvider = (LPWSCUPDATEPROVIDER)GetProcAddress(hModule, "WSCUpdateProvider");
00349
00350 fnWscUpdateProvider32 = (LPWSCUPDATEPROVIDER)GetProcAddress(hModule, "WSCUpdateProvider32");
00351 #else
00352 fnWscUpdateProvider = (LPWSCUPDATEPROVIDER)GetProcAddress(hModule, "WSCUpdateProvider");
00353 #endif
00354
00355 return hModule;
00356
00357 cleanup:
00358
00359 if ( NULL != hModule )
00360 {
00361 FreeLibrary( hModule );
00362 hModule = NULL;
00363 }
00364
00365 return NULL;
00366 }
00367
00368
00369
00370
00371
00372
00373
00374
00375
00376
00377
00378
00379 int
00380 CountOrphanedChainEntries(
00381 WSAPROTOCOL_INFOW *pCatalog,
00382 int iCatalogCount
00383 )
00384 {
00385 int orphanCount = 0,
00386 i, j;
00387
00388 for(i=0; i < iCatalogCount ;i++)
00389 {
00390 if ( pCatalog[ i ].ProtocolChain.ChainLen > 1 )
00391 {
00392 for(j=0; j < iCatalogCount ;j++)
00393 {
00394 if ( i == j )
00395 continue;
00396 if ( pCatalog[ j ].dwCatalogEntryId == pCatalog[ i ].ProtocolChain.ChainEntries[ 0 ] )
00397 {
00398 break;
00399 }
00400 }
00401 if ( j >= iCatalogCount )
00402 orphanCount++;
00403 }
00404 }
00405
00406 return orphanCount;
00407 }
00408
00409
00410
00411
00412
00413
00414
00415
00416 WSAPROTOCOL_INFOW *
00417 FindProviderById(
00418 DWORD CatalogId,
00419 WSAPROTOCOL_INFOW *Catalog,
00420 int CatalogCount
00421 )
00422 {
00423 int i;
00424
00425 for(i=0; i < CatalogCount ;i++)
00426 {
00427 if ( Catalog[ i ].dwCatalogEntryId == CatalogId )
00428 return &Catalog[ i ];
00429 }
00430 return NULL;
00431 }
00432
00433
00434
00435
00436
00437
00438
00439
00440
00441 WSAPROTOCOL_INFOW *
00442 FindProviderByGuid(
00443 GUID *Guid,
00444 WSAPROTOCOL_INFOW *Catalog,
00445 int CatalogCount
00446 )
00447 {
00448 int i;
00449
00450 for(i=0; i < CatalogCount ;i++)
00451 {
00452 if ( 0 == memcmp( &Catalog[ i ].ProviderId, Guid, sizeof( GUID ) ) )
00453 {
00454 return &Catalog[ i ];
00455 }
00456 }
00457
00458 return NULL;
00459 }
00460
00461
00462
00463
00464
00465
00466
00467
00468 DWORD
00469 GetCatalogIdForProviderGuid(
00470 GUID *Guid,
00471 WSAPROTOCOL_INFOW *Catalog,
00472 int CatalogCount
00473 )
00474 {
00475 WSAPROTOCOL_INFOW *match = NULL;
00476
00477 match = FindProviderByGuid( Guid, Catalog, CatalogCount );
00478 if ( NULL != match )
00479 {
00480 return match->dwCatalogEntryId;
00481 }
00482
00483 return 0;
00484 }
00485
00486 #pragma warning(push)
00487 #pragma warning(disable: 4127 )
00488
00489
00490
00491
00492
00493
00494
00495
00496
00497
00498
00499
00500 DWORD
00501 FindDummyIdFromProtocolChainId(
00502 DWORD CatalogId,
00503 WSAPROTOCOL_INFOW *Catalog,
00504 int CatalogCount
00505 )
00506 {
00507 int i;
00508
00509 for(i=0; i < CatalogCount ;i++)
00510 {
00511 if ( CatalogId == Catalog[ i ].dwCatalogEntryId )
00512 {
00513 if ( Catalog[ i ].ProtocolChain.ChainLen == LAYERED_PROTOCOL )
00514 return Catalog[ i ].dwCatalogEntryId;
00515 else
00516 return Catalog[ i ].ProtocolChain.ChainEntries[ 0 ];
00517 }
00518 }
00519
00520 ASSERT( 0 );
00521
00522 return 0;
00523 }
00524
00525 #pragma warning(pop)
00526
00527
00528
00529
00530
00531
00532
00533
00534 void
00535 InsertIdIntoProtocolChain(
00536 WSAPROTOCOL_INFOW *Entry,
00537 int Index,
00538 DWORD InsertId
00539 )
00540 {
00541 int i;
00542
00543 for(i=Entry->ProtocolChain.ChainLen; i > Index ;i--)
00544 {
00545 Entry->ProtocolChain.ChainEntries[ i ] = Entry->ProtocolChain.ChainEntries[ i - 1 ];
00546 }
00547
00548 Entry->ProtocolChain.ChainEntries[ Index ] = InsertId;
00549 Entry->ProtocolChain.ChainLen++;
00550 }
00551
00552
00553
00554
00555
00556
00557
00558
00559
00560
00561
00562
00563
00564
00565
00566
00567
00568
00569 void
00570 BuildSubsetLspChain(
00571 WSAPROTOCOL_INFOW *Entry,
00572 int Index,
00573 DWORD DummyId
00574 )
00575 {
00576 int Idx, i;
00577
00578 for(i=Index,Idx=1; i < Entry->ProtocolChain.ChainLen ;i++,Idx++)
00579 {
00580 Entry->ProtocolChain.ChainEntries[ Idx ] = Entry->ProtocolChain.ChainEntries[ i ];
00581 }
00582
00583 Entry->ProtocolChain.ChainEntries[ 0 ] = DummyId;
00584 Entry->ProtocolChain.ChainLen = Entry->ProtocolChain.ChainLen - Index + 1;
00585 }