/* * Copyright (c) 2006-2010 Andy Polyakov * * Build with: * * cl -Ox -GD -GF -Zl -MD -LD keepalive.c ws2_32.lib kernel32.lib ntdll.lib * * Pre-load as: * * [HKLM\SOFTWARE\Microsoft\Windows NT\CurrentVersion\DLL_PRELOAD] * "mstsc.exe"="keepalive.dll" * * See http://fy.chalmers.se/~appro/nt/DLL_PRELOAD/ for further details. * */ #ifndef _DLL #error "_DLL is not defined." #endif #ifdef _WIN64 #pragma comment(linker,"/entry:DllMain") #pragma comment(linker,"/merge:.rdata=.text") #else #pragma comment(linker,"/entry:DllMain@12") #pragma comment(linker,"/section:.text,erw") #pragma comment(linker,"/merge:.rdata=.text") #pragma comment(linker,"/merge:.data=.text") #endif #define UNICODE #define _UNICODE #if defined(WIN32) && !defined(_WIN32) #define _WIN32 #endif #define _WIN32_WINNT 0x0500 #define NTDDI_VERSION 0x05000100 //NTDDI_WIN2KSP1 #include #include #include #include #include #include #include #include #include #include static VOID DebugOutputA (const char *fmt,...) { va_list argv; char buf[256]; va_start(argv,fmt); _vsnprintf (buf,sizeof(buf)/sizeof(buf[0])-1,fmt,argv); buf[sizeof(buf)/sizeof(buf[0])-1]='\0'; OutputDebugStringA (buf); va_end(argv); } /* WSP wrappers */ static LPWSPSTARTUP _WSPStartup; static WSPPROC_TABLE WSPDispatch_; static int WSPAPI WSPSetSockOpt_( SOCKET s, int level, int optname, const char *optval, int optlen, LPINT lpErrno ) { /* don't let application reset SO_KEEPALIVE */ int one=1; if (level==SOL_SOCKET && optname==SO_KEEPALIVE && !*optval) DebugOutputA("filtering SO_KEEPALIVE"), optval = (const char *)&one, optlen = sizeof(one); return WSPDispatch_.lpWSPSetSockOpt(s,level,optname,optval,optlen,lpErrno); } static int WSPAPI WSPConnect_( SOCKET s, const struct sockaddr FAR * name, int namelen, LPWSABUF lpCallerData, LPWSABUF lpCalleeData, LPQOS lpSQOS, LPQOS lpGQOS, LPINT lpErrno ) { if (name->sa_family == AF_INET) { struct tcp_keepalive ka = { 1, 90*1000, 1000 }; DWORD len,err; int one=1; if (WSPDispatch_.lpWSPSetSockOpt(s,SOL_SOCKET,SO_KEEPALIVE, (const char *)&one,sizeof(one),&err)) DebugOutputA("failed to SO_KEEPALIVE %d",err); else DebugOutputA("SO_KEEPALIVE successful"); if (WSPDispatch_.lpWSPIoctl(s,SIO_KEEPALIVE_VALS, &ka,sizeof(ka),&ka,sizeof(ka),&len,NULL,NULL,NULL,&err)) DebugOutputA("failed to SIO_KEEPALIVE_VALS with %d",err); else DebugOutputA("SIO_KEEPALIVE_VALS successful"); } return WSPDispatch_.lpWSPConnect (s,(void *)name,namelen, lpCallerData,lpCalleeData,lpSQOS,lpGQOS,lpErrno); } static int WSPAPI WSPIoctl_( IN SOCKET s, IN DWORD dwIoControlCode, IN LPVOID lpvInBuffer, IN DWORD cbInBuffer, OUT LPVOID lpvOutBuffer, IN DWORD cbOutBuffer, OUT LPDWORD lpcbBytesReturned, IN LPWSAOVERLAPPED lpOverlapped, IN LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine, IN LPWSATHREADID lpThreadId, OUT LPINT lpErrno ) { struct tcp_keepalive ka = { 1, 90*1000, 1000 }; if (dwIoControlCode == SIO_KEEPALIVE_VALS) { DebugOutputA("filtering SIO_KEEPALIVE_VALS"); lpvInBuffer = &ka; cbInBuffer = sizeof(ka); } return WSPDispatch_.lpWSPIoctl(s,dwIoControlCode,lpvInBuffer,cbInBuffer, lpvOutBuffer,cbOutBuffer,lpcbBytesReturned,lpOverlapped, lpCompletionRoutine,lpThreadId,lpErrno); } static int WSPAPI WSPStartup_( IN WORD wVersionRequested, OUT LPWSPDATA lpWSPData, IN LPWSAPROTOCOL_INFOW lpProtocolInfo, IN WSPUPCALLTABLE UpcallTable, OUT LPWSPPROC_TABLE lpProcTable ) { int ret = (*_WSPStartup)( wVersionRequested, lpWSPData, lpProtocolInfo, UpcallTable, &WSPDispatch_); /* Copy whole table... */ *lpProcTable = WSPDispatch_; /* ... and subclass selected functions */ lpProcTable->lpWSPConnect = WSPConnect_; lpProcTable->lpWSPSetSockOpt = WSPSetSockOpt_; lpProcTable->lpWSPIoctl = WSPIoctl_; return ret; } static PVOID WINAPI GetProcAddress_ ( HMODULE hModule, LPCSTR lpProcName ) { PVOID ret = GetProcAddress(hModule,lpProcName); if (ret && !strcmp(lpProcName,"WSPStartup")) { /* Place myself between WS2_32 and WSP */ _WSPStartup = ret; ret = WSPStartup_; } return ret; } PVOID *__GetProcAddress =NULL; BOOL WINAPI DllMain (HINSTANCE h, DWORD reason, LPVOID junk) { DWORD acc0,acc1,sess; HMODULE hmod; IMAGE_DOS_HEADER *dos_header; IMAGE_NT_HEADERS *nt_headers; IMAGE_DATA_DIRECTORY *dir; IMAGE_IMPORT_DESCRIPTOR *idesc; IMAGE_THUNK_DATA *thunk; static void *page0,*page1; static size_t plen0,plen1; switch (reason) { case DLL_PROCESS_ATTACH: DisableThreadLibraryCalls(h); /* make sure WS32_32.DLL is linked */ plen0 = htonl(0); if (!(hmod=GetModuleHandle(_T("WS2_32.DLL")))) { OutputDebugStringA("WS2_32.DLL not found?"); return FALSE; } dos_header = (IMAGE_DOS_HEADER *)hmod; nt_headers = (IMAGE_NT_HEADERS *)((char *)hmod + dos_header->e_lfanew); dir=&nt_headers->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT]; idesc=(IMAGE_IMPORT_DESCRIPTOR *)((char *)hmod + dir->VirtualAddress); while (idesc->Name) { if (!_stricmp((char *)hmod+idesc->Name,"KERNEL32.DLL")) break; idesc++; } if (!idesc->Name) { OutputDebugStringA("Can't locate KERNEL32.DLL import descriptor"); return FALSE; } page1 = (char *)hmod+idesc->FirstThunk; for (thunk=(IMAGE_THUNK_DATA *)page1; thunk->u1.Function && thunk->u1.Function!=(ULONG_PTR)GetProcAddress; thunk++) ; if (thunk->u1.Function) __GetProcAddress=(PVOID *)thunk; for (thunk=(IMAGE_THUNK_DATA *)page1;thunk->u1.Function;thunk++) ; plen1 = (size_t)thunk-(size_t)page1; if (!VirtualProtect (page1,plen1,PAGE_EXECUTE_READWRITE,&acc1)) { OutputDebugStringA("Unable to unlock WS2_32.DLL Thunk Table"); return FALSE; } if (__GetProcAddress) *__GetProcAddress =GetProcAddress_; VirtualProtect (page1,plen1,acc1,&acc1); break; case DLL_PROCESS_DETACH: VirtualProtect (page1,plen1,PAGE_EXECUTE_READWRITE,&acc1); if (__GetProcAddress) *__GetProcAddress =GetProcAddress; VirtualProtect (page1,plen1,acc1,&acc1); break; } return TRUE; }