diff --git a/capture/src/capture.cpp b/capture/src/capture.cpp index 88503c9c..41338599 100644 --- a/capture/src/capture.cpp +++ b/capture/src/capture.cpp @@ -9,6 +9,7 @@ #include #include +#include "../../common/TracyProtocol.hpp" #include "../../server/TracyFileWrite.hpp" #include "../../server/TracyMemory.hpp" #include "../../server/TracyWorker.hpp" @@ -137,6 +138,16 @@ int main( int argc, char** argv ) printf( "Connecting to %s...", address ); fflush( stdout ); tracy::Worker worker( address ); + for(;;) + { + const auto handshake = worker.GetHandshakeStatus(); + if( handshake == tracy::HandshakeWelcome ) break; + if( handshake == tracy::HandshakeProtocolMismatch ) + { + printf( "\nThe client you are trying to connect to uses incompatible protocol version.\nMake sure you are using the same Tracy version on both client and server.\n" ); + return 1; + } + } while( !worker.HasData() ) std::this_thread::sleep_for( std::chrono::milliseconds( 100 ) ); printf( "\nQueue delay: %s\nTimer resolution: %s\n", TimeToString( worker.GetDelay() ), TimeToString( worker.GetResolution() ) ); diff --git a/client/TracyProfiler.cpp b/client/TracyProfiler.cpp index cef3cfe9..ecdc49a5 100644 --- a/client/TracyProfiler.cpp +++ b/client/TracyProfiler.cpp @@ -919,13 +919,31 @@ void Profiler::Worker() tv.tv_usec = 0; char shibboleth[HandshakeShibbolethSize]; - const auto res = m_sock->ReadRaw( shibboleth, HandshakeShibbolethSize, &tv ); + auto res = m_sock->ReadRaw( shibboleth, HandshakeShibbolethSize, &tv ); if( !res || memcmp( shibboleth, HandshakeShibboleth, HandshakeShibbolethSize ) != 0 ) { m_sock->~Socket(); tracy_free( m_sock ); continue; } + + uint32_t protocolVersion; + res = m_sock->ReadRaw( &protocolVersion, sizeof( protocolVersion ), &tv ); + if( !res ) + { + m_sock->~Socket(); + tracy_free( m_sock ); + continue; + } + + if( protocolVersion != ProtocolVersion ) + { + HandshakeStatus status = HandshakeProtocolMismatch; + m_sock->Send( &status, sizeof( status ) ); + m_sock->~Socket(); + tracy_free( m_sock ); + continue; + } } #ifdef TRACY_ON_DEMAND @@ -933,6 +951,9 @@ void Profiler::Worker() m_isConnected.store( true, std::memory_order_relaxed ); #endif + HandshakeStatus handshake = HandshakeWelcome; + m_sock->Send( &handshake, sizeof( handshake ) ); + LZ4_resetStream( m_stream ); m_sock->Send( &welcome, sizeof( welcome ) ); diff --git a/common/TracyProtocol.hpp b/common/TracyProtocol.hpp index f6701ddd..eae9fd8f 100644 --- a/common/TracyProtocol.hpp +++ b/common/TracyProtocol.hpp @@ -9,6 +9,8 @@ namespace tracy { +enum : uint32_t { ProtocolVersion = 0 }; + using lz4sz_t = uint32_t; enum { TargetFrameSize = 256 * 1024 }; @@ -30,6 +32,13 @@ enum ServerQuery : uint8_t enum { HandshakeShibbolethSize = 8 }; static const char HandshakeShibboleth[HandshakeShibbolethSize] = { 'T', 'r', 'a', 'c', 'y', 'P', 'r', 'f' }; +enum HandshakeStatus : uint8_t +{ + HandshakePending, + HandshakeWelcome, + HandshakeProtocolMismatch, +}; + enum { WelcomeMessageProgramNameSize = 64 }; enum { WelcomeMessageHostInfoSize = 1024 }; diff --git a/server/TracyView.cpp b/server/TracyView.cpp index b9ed99d2..c803f912 100644 --- a/server/TracyView.cpp +++ b/server/TracyView.cpp @@ -11,6 +11,7 @@ #include #include "../common/TracyMutex.hpp" +#include "../common/TracyProtocol.hpp" #include "../common/TracySystem.hpp" #include "tracy_pdqsort.h" #include "TracyBadVersion.hpp" @@ -454,6 +455,28 @@ void View::DrawTextContrast( ImDrawList* draw, const ImVec2& pos, uint32_t color bool View::Draw() { + HandshakeStatus status = (HandshakeStatus)s_instance->m_worker.GetHandshakeStatus(); + if( status == HandshakeProtocolMismatch ) + { + ImGui::OpenPopup( "Protocol mismatch" ); + } + + if( ImGui::BeginPopupModal( "Protocol mismatch", nullptr, ImGuiWindowFlags_AlwaysAutoResize ) ) + { +#ifdef TRACY_EXTENDED_FONT + TextCentered( ICON_FA_EXCLAMATION_TRIANGLE ); +#endif + ImGui::Text( "The client you are trying to connect to uses incompatible protocol version.\nMake sure you are using the same Tracy version on both client and server." ); + ImGui::Separator(); + if( ImGui::Button( "My bad" ) ) + { + ImGui::CloseCurrentPopup(); + ImGui::EndPopup(); + return false; + } + ImGui::EndPopup(); + } + return s_instance->DrawImpl(); } diff --git a/server/TracyWorker.cpp b/server/TracyWorker.cpp index 450a9585..f6581cec 100644 --- a/server/TracyWorker.cpp +++ b/server/TracyWorker.cpp @@ -201,6 +201,7 @@ Worker::Worker( const char* addr ) , m_pendingSourceLocation( 0 ) , m_pendingCallstackFrames( 0 ) , m_traceVersion( CurrentVersion ) + , m_handshake( 0 ) { m_data.sourceLocationExpand.push_back( 0 ); m_data.threadExpand.push_back( 0 ); @@ -224,6 +225,7 @@ Worker::Worker( FileRead& f, EventType::Type eventMask ) , m_crashed( false ) , m_stream( nullptr ) , m_buffer( nullptr ) + , m_handshake( 0 ) { m_data.threadExpand.push_back( 0 ); m_data.callstackPayload.push_back( nullptr ); @@ -1279,14 +1281,28 @@ void Worker::Exec() if( m_sock.Connect( m_addr.c_str(), "8086" ) ) break; } - m_sock.Send( HandshakeShibboleth, HandshakeShibbolethSize ); - auto lz4buf = std::make_unique( LZ4Size ); + std::chrono::time_point t0; uint64_t bytes = 0; uint64_t decBytes = 0; + m_sock.Send( HandshakeShibboleth, HandshakeShibbolethSize ); + uint32_t protocolVersion = ProtocolVersion; + m_sock.Send( &protocolVersion, sizeof( protocolVersion ) ); + HandshakeStatus handshake; + if( !m_sock.Read( &handshake, sizeof( handshake ), &tv, ShouldExit ) ) goto close; + m_handshake.store( handshake, std::memory_order_relaxed ); + switch( handshake ) + { + case HandshakeWelcome: + break; + case HandshakeProtocolMismatch: + default: + goto close; + } + m_data.framesBase = m_data.frames.Retrieve( 0, [this] ( uint64_t name ) { auto fd = m_slab.AllocInit(); fd->name = name; diff --git a/server/TracyWorker.hpp b/server/TracyWorker.hpp index 65923de3..71437844 100644 --- a/server/TracyWorker.hpp +++ b/server/TracyWorker.hpp @@ -264,6 +264,7 @@ public: void Write( FileWrite& f ); int GetTraceVersion() const { return m_traceVersion; } + uint8_t GetHandshakeStatus() const { return m_handshake.load( std::memory_order_relaxed ); } static const LoadProgress& GetLoadProgress() { return s_loadProgress; } @@ -415,6 +416,7 @@ private: MbpsBlock m_mbpsData; int m_traceVersion; + std::atomic m_handshake; static LoadProgress s_loadProgress; };