1010
1111#include < assert.h>
1212#include < system_error>
13+ #include < variant>
1314#include < vector>
1415#include < sstream>
1516
@@ -146,6 +147,51 @@ std::unique_ptr<EndpointsIteratorBase> GetEndpointsIterator(const ClientOptions&
146147
147148}
148149
150+
151+ template <class ... Ts>
152+ struct overloaded : Ts... { using Ts::operator ()...; };
153+ template <class ... Ts>
154+ overloaded (Ts...) -> overloaded<Ts...>;
155+
156+ struct DataTag {
157+ Block block;
158+ };
159+ struct ExceptionTag {
160+ };
161+ struct ProfileTag {
162+ Profile profile;
163+ };
164+ struct ProgressTag {
165+ Progress progress;
166+ };
167+ struct PongTag {
168+ };
169+ struct HelloTag {
170+ };
171+ struct LogTag {
172+ Block block;
173+ };
174+ struct TableColumnsTag {
175+ };
176+ struct ProfileEventsTag {
177+ Block block;
178+ };
179+ struct EndOfStreamTag {
180+ };
181+
182+ using EncodedPacket = std::variant<
183+ std::monostate,
184+ DataTag,
185+ ExceptionTag,
186+ ProfileTag,
187+ ProgressTag,
188+ PongTag,
189+ HelloTag,
190+ LogTag,
191+ TableColumnsTag,
192+ ProfileEventsTag,
193+ EndOfStreamTag>;
194+
149195class Client ::Impl {
150196public:
151197 Impl (const ClientOptions& opts);
@@ -180,7 +226,8 @@ class Client::Impl {
180226private:
181227 bool Handshake ();
182228
183- bool ReceivePacket (uint64_t * server_packet = nullptr );
229+ EncodedPacket ReceivePacket (uint64_t * server_packet = nullptr );
230+ bool ProcessPacket (uint64_t * server_packet = nullptr );
184231
185232 void SendQuery (const Query& query, bool finalize = true );
186233 void FinalizeQuery ();
@@ -197,7 +244,7 @@ class Client::Impl {
197244 bool ReceiveHello ();
198245
199246 // / Reads data packet form input stream.
200- bool ReceiveData ();
247+ bool ReceiveData (Block * block );
201248
202249 // / Reads exception packet form input stream.
203250 bool ReceiveException (bool rethrow = false );
@@ -308,7 +355,7 @@ void Client::Impl::ExecuteQuery(Query query) {
308355
309356 SendQuery (query);
310357
311- while (ReceivePacket ()) {
358+ while (ProcessPacket ()) {
312359 ;
313360 }
314361}
@@ -333,7 +380,7 @@ void Client::Impl::SelectWithExternalData(Query query, const ExternalTables& ext
333380 SendExternalData (external_tables);
334381 FinalizeQuery ();
335382
336- while (ReceivePacket ()) {
383+ while (ProcessPacket ()) {
337384 ;
338385 }
339386}
@@ -408,7 +455,7 @@ void Client::Impl::Insert(const std::string& table_name, const std::string& quer
408455
409456 // Wait for a data packet and return
410457 uint64_t server_packet = 0 ;
411- while (ReceivePacket (&server_packet)) {
458+ while (ProcessPacket (&server_packet)) {
412459 if (server_packet == ServerCodes::Data) {
413460 SendData (block);
414461 EndInsert ();
@@ -443,7 +490,7 @@ Block Client::Impl::BeginInsert(Query query) {
443490
444491 // Wait for a data packet and return
445492 uint64_t server_packet = 0 ;
446- while (ReceivePacket (&server_packet)) {
493+ while (ProcessPacket (&server_packet)) {
447494 if (server_packet == ServerCodes::Data) {
448495 return block;
449496 }
@@ -470,7 +517,7 @@ void Client::Impl::EndInsert() {
470517
471518 // Wait for EOS.
472519 uint64_t eos_packet{0 };
473- while (ReceivePacket (&eos_packet)) {
520+ while (ProcessPacket (&eos_packet)) {
474521 ;
475522 }
476523
@@ -491,7 +538,7 @@ void Client::Impl::Ping() {
491538 output_->Flush ();
492539
493540 uint64_t server_packet;
494- const bool ret = ReceivePacket (&server_packet);
541+ const bool ret = ProcessPacket (&server_packet);
495542
496543 if (!ret || server_packet != ServerCodes::Pong) {
497544 throw ProtocolError (" fail to ping server" );
@@ -569,157 +616,176 @@ bool Client::Impl::Handshake() {
569616 return true ;
570617}
571618
572- bool Client::Impl::ReceivePacket (uint64_t * server_packet) {
619+ EncodedPacket Client::Impl::ReceivePacket (uint64_t * server_packet) {
573620 uint64_t packet_type = 0 ;
574621
575622 if (!WireFormat::ReadVarint64 (*input_, &packet_type)) {
576- return false ;
623+ return {} ;
577624 }
578625 if (server_packet) {
579626 *server_packet = packet_type;
580627 }
581628
582629 switch (packet_type) {
630+
583631 case ServerCodes::Data: {
584- if (!ReceiveData ()) {
632+ DataTag ret{};
633+ if (!ReceiveData (&ret.block )) {
585634 throw ProtocolError (" can't read data packet from input stream" );
586635 }
587- return true ;
636+ return ret ;
588637 }
589638
590639 case ServerCodes::Exception: {
640+ ExceptionTag ret{};
591641 ReceiveException ();
592- return false ;
642+ return ret ;
593643 }
594644
595645 case ServerCodes::ProfileInfo: {
596- Profile profile ;
646+ ProfileTag ret{} ;
597647
598- if (!WireFormat::ReadUInt64 (*input_, &profile.rows )) {
599- return false ;
648+ if (!WireFormat::ReadUInt64 (*input_, &ret. profile .rows )) {
649+ return {} ;
600650 }
601- if (!WireFormat::ReadUInt64 (*input_, &profile.blocks )) {
602- return false ;
651+ if (!WireFormat::ReadUInt64 (*input_, &ret. profile .blocks )) {
652+ return {} ;
603653 }
604- if (!WireFormat::ReadUInt64 (*input_, &profile.bytes )) {
605- return false ;
654+ if (!WireFormat::ReadUInt64 (*input_, &ret. profile .bytes )) {
655+ return {} ;
606656 }
607- if (!WireFormat::ReadFixed (*input_, &profile.applied_limit )) {
608- return false ;
657+ if (!WireFormat::ReadFixed (*input_, &ret. profile .applied_limit )) {
658+ return {} ;
609659 }
610- if (!WireFormat::ReadUInt64 (*input_, &profile.rows_before_limit )) {
611- return false ;
660+ if (!WireFormat::ReadUInt64 (*input_, &ret. profile .rows_before_limit )) {
661+ return {} ;
612662 }
613- if (!WireFormat::ReadFixed (*input_, &profile.calculated_rows_before_limit )) {
614- return false ;
663+ if (!WireFormat::ReadFixed (*input_, &ret. profile .calculated_rows_before_limit )) {
664+ return {} ;
615665 }
616666
617667 if (events_) {
618- events_->OnProfile (profile);
668+ events_->OnProfile (ret. profile );
619669 }
620670
621- return true ;
671+ return ret ;
622672 }
623673
624674 case ServerCodes::Progress: {
625- Progress info ;
675+ ProgressTag ret ;
626676
627- if (!WireFormat::ReadUInt64 (*input_, &info .rows )) {
628- return false ;
677+ if (!WireFormat::ReadUInt64 (*input_, &ret. progress .rows )) {
678+ return {} ;
629679 }
630- if (!WireFormat::ReadUInt64 (*input_, &info .bytes )) {
631- return false ;
680+ if (!WireFormat::ReadUInt64 (*input_, &ret. progress .bytes )) {
681+ return {} ;
632682 }
633683 if constexpr (DMBS_PROTOCOL_REVISION >= DBMS_MIN_REVISION_WITH_TOTAL_ROWS_IN_PROGRESS) {
634- if (!WireFormat::ReadUInt64 (*input_, &info .total_rows )) {
635- return false ;
684+ if (!WireFormat::ReadUInt64 (*input_, &ret. progress .total_rows )) {
685+ return {} ;
636686 }
637687 }
638688 if (server_info_.revision >= DBMS_MIN_REVISION_WITH_CLIENT_WRITE_INFO)
639689 {
640- if (!WireFormat::ReadUInt64 (*input_, &info .written_rows )) {
641- return false ;
690+ if (!WireFormat::ReadUInt64 (*input_, &ret. progress .written_rows )) {
691+ return {} ;
642692 }
643- if (!WireFormat::ReadUInt64 (*input_, &info .written_bytes )) {
644- return false ;
693+ if (!WireFormat::ReadUInt64 (*input_, &ret. progress .written_bytes )) {
694+ return {} ;
645695 }
646696 }
647697
648698 if (events_) {
649- events_->OnProgress (info );
699+ events_->OnProgress (ret. progress );
650700 }
651701
652- return true ;
702+ return ret ;
653703 }
654704
655705 case ServerCodes::Pong: {
656- return true ;
706+ return PongTag{} ;
657707 }
658708
659709 case ServerCodes::Hello: {
660- return true ;
710+ return HelloTag{} ;
661711 }
662712
663713 case ServerCodes::EndOfStream: {
664714 if (events_) {
665715 events_->OnFinish ();
666716 }
667- return false ;
717+ return EndOfStreamTag{} ;
668718 }
669719
670720 case ServerCodes::Log: {
671721 // log tag
672722 if (!WireFormat::SkipString (*input_)) {
673- return false ;
723+ return {} ;
674724 }
675- Block block ;
725+ LogTag ret ;
676726
677727 // Use uncompressed stream since log blocks usually contain only one row
678- if (!ReadBlock (*input_, &block)) {
679- return false ;
728+ if (!ReadBlock (*input_, &ret. block )) {
729+ return {} ;
680730 }
681731
682732 if (events_) {
683- events_->OnServerLog (block);
733+ events_->OnServerLog (ret. block );
684734 }
685- return true ;
735+ return ret ;
686736 }
687737
688738 case ServerCodes::TableColumns: {
689739 // external table name
690740 if (!WireFormat::SkipString (*input_)) {
691- return false ;
741+ return {} ;
692742 }
693743
694744 // columns metadata
695745 if (!WireFormat::SkipString (*input_)) {
696- return false ;
746+ return {} ;
697747 }
698- return true ;
748+ return TableColumnsTag{} ;
699749 }
700750
701751 case ServerCodes::ProfileEvents: {
702752 if (!WireFormat::SkipString (*input_)) {
703- return false ;
753+ return {} ;
704754 }
705755
706- Block block ;
707- if (!ReadBlock (*input_, &block)) {
708- return false ;
756+ ProfileEventsTag ret ;
757+ if (!ReadBlock (*input_, &ret. block )) {
758+ return {} ;
709759 }
710760
711761 if (events_) {
712- events_->OnProfileEvents (block);
762+ events_->OnProfileEvents (ret. block );
713763 }
714- return true ;
764+ return ret ;
715765 }
716766
717767 default :
718768 throw UnimplementedError (" unimplemented " + std::to_string ((int )packet_type));
719- break ;
720769 }
721770}
722771
772+ bool Client::Impl::ProcessPacket (uint64_t * server_packet) {
773+ auto packet = ReceivePacket (server_packet);
774+ return std::visit (overloaded{
775+ [](std::monostate) { return false ; },
776+ [](DataTag) { return true ; },
777+ [](ExceptionTag) {return false ; },
778+ [](ProfileTag){ return true ; },
779+ [](ProgressTag){ return true ; },
780+ [](PongTag){ return true ; },
781+ [](HelloTag){ return true ; },
782+ [](LogTag){ return true ; },
783+ [](TableColumnsTag){ return true ; },
784+ [](ProfileEventsTag){ return true ; },
785+ [](EndOfStreamTag){ return false ; },
786+ }, packet);
787+ }
788+
723789bool Client::Impl::ReadBlock (InputStream& input, Block* block) {
724790 // Additional information about block.
725791 if (server_info_.revision >= DBMS_MIN_REVISION_WITH_BLOCK_INFO) {
@@ -793,8 +859,7 @@ bool Client::Impl::ReadBlock(InputStream& input, Block* block) {
793859 return true ;
794860}
795861
796- bool Client::Impl::ReceiveData () {
797- Block block;
862+ bool Client::Impl::ReceiveData (Block * block) {
798863
799864 if (server_info_.revision >= DBMS_MIN_REVISION_WITH_TEMPORARY_TABLES) {
800865 if (!WireFormat::SkipString (*input_)) {
@@ -804,18 +869,18 @@ bool Client::Impl::ReceiveData() {
804869
805870 if (compression_ == CompressionState::Enable) {
806871 CompressedInput compressed (input_.get ());
807- if (!ReadBlock (compressed, & block)) {
872+ if (!ReadBlock (compressed, block)) {
808873 return false ;
809874 }
810875 } else {
811- if (!ReadBlock (*input_, & block)) {
876+ if (!ReadBlock (*input_, block)) {
812877 return false ;
813878 }
814879 }
815880
816881 if (events_) {
817- events_->OnData (block);
818- if (!events_->OnDataCancelable (block)) {
882+ events_->OnData (* block);
883+ if (!events_->OnDataCancelable (* block)) {
819884 SendCancel ();
820885 }
821886 }
0 commit comments