Skip to content

Commit 8e2ed11

Browse files
committed
Make ReceivePacket return data that it has processed
1 parent 497e6fa commit 8e2ed11

1 file changed

Lines changed: 131 additions & 66 deletions

File tree

clickhouse/client.cpp

Lines changed: 131 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <assert.h>
1212
#include <system_error>
13+
#include <variant>
1314
#include <vector>
1415
#include <sstream>
1516

@@ -146,6 +147,51 @@ std::unique_ptr<EndpointsIteratorBase> GetEndpointsIterator(const ClientOptions&
146147

147148
}
148149

150+
151+
template<class... Ts>
152+
struct overloaded : Ts... { using Ts::operator()...; };
153+
template<class... Ts>
154+
overloaded(Ts...) -> overloaded<Ts...>;
155+
156+
struct DataTag {
157+
Block block;
158+
};
159+
struct ExceptionTag {
160+
};
161+
struct ProfileTag {
162+
Profile profile;
163+
};
164+
struct ProgressTag {
165+
Progress progress;
166+
};
167+
struct PongTag {
168+
};
169+
struct HelloTag {
170+
};
171+
struct LogTag {
172+
Block block;
173+
};
174+
struct TableColumnsTag {
175+
};
176+
struct ProfileEventsTag {
177+
Block block;
178+
};
179+
struct EndOfStreamTag {
180+
};
181+
182+
using EncodedPacket = std::variant<
183+
std::monostate,
184+
DataTag,
185+
ExceptionTag,
186+
ProfileTag,
187+
ProgressTag,
188+
PongTag,
189+
HelloTag,
190+
LogTag,
191+
TableColumnsTag,
192+
ProfileEventsTag,
193+
EndOfStreamTag>;
194+
149195
class Client::Impl {
150196
public:
151197
Impl(const ClientOptions& opts);
@@ -180,7 +226,8 @@ class Client::Impl {
180226
private:
181227
bool Handshake();
182228

183-
bool ReceivePacket(uint64_t* server_packet = nullptr);
229+
EncodedPacket ReceivePacket(uint64_t* server_packet = nullptr);
230+
bool ProcessPacket(uint64_t* server_packet = nullptr);
184231

185232
void SendQuery(const Query& query, bool finalize = true);
186233
void FinalizeQuery();
@@ -197,7 +244,7 @@ class Client::Impl {
197244
bool ReceiveHello();
198245

199246
/// Reads data packet form input stream.
200-
bool ReceiveData();
247+
bool ReceiveData(Block * block);
201248

202249
/// Reads exception packet form input stream.
203250
bool ReceiveException(bool rethrow = false);
@@ -308,7 +355,7 @@ void Client::Impl::ExecuteQuery(Query query) {
308355

309356
SendQuery(query);
310357

311-
while (ReceivePacket()) {
358+
while (ProcessPacket()) {
312359
;
313360
}
314361
}
@@ -333,7 +380,7 @@ void Client::Impl::SelectWithExternalData(Query query, const ExternalTables& ext
333380
SendExternalData(external_tables);
334381
FinalizeQuery();
335382

336-
while (ReceivePacket()) {
383+
while (ProcessPacket()) {
337384
;
338385
}
339386
}
@@ -408,7 +455,7 @@ void Client::Impl::Insert(const std::string& table_name, const std::string& quer
408455

409456
// Wait for a data packet and return
410457
uint64_t server_packet = 0;
411-
while (ReceivePacket(&server_packet)) {
458+
while (ProcessPacket(&server_packet)) {
412459
if (server_packet == ServerCodes::Data) {
413460
SendData(block);
414461
EndInsert();
@@ -443,7 +490,7 @@ Block Client::Impl::BeginInsert(Query query) {
443490

444491
// Wait for a data packet and return
445492
uint64_t server_packet = 0;
446-
while (ReceivePacket(&server_packet)) {
493+
while (ProcessPacket(&server_packet)) {
447494
if (server_packet == ServerCodes::Data) {
448495
return block;
449496
}
@@ -470,7 +517,7 @@ void Client::Impl::EndInsert() {
470517

471518
// Wait for EOS.
472519
uint64_t eos_packet{0};
473-
while (ReceivePacket(&eos_packet)) {
520+
while (ProcessPacket(&eos_packet)) {
474521
;
475522
}
476523

@@ -491,7 +538,7 @@ void Client::Impl::Ping() {
491538
output_->Flush();
492539

493540
uint64_t server_packet;
494-
const bool ret = ReceivePacket(&server_packet);
541+
const bool ret = ProcessPacket(&server_packet);
495542

496543
if (!ret || server_packet != ServerCodes::Pong) {
497544
throw ProtocolError("fail to ping server");
@@ -569,157 +616,176 @@ bool Client::Impl::Handshake() {
569616
return true;
570617
}
571618

572-
bool Client::Impl::ReceivePacket(uint64_t* server_packet) {
619+
EncodedPacket Client::Impl::ReceivePacket(uint64_t* server_packet) {
573620
uint64_t packet_type = 0;
574621

575622
if (!WireFormat::ReadVarint64(*input_, &packet_type)) {
576-
return false;
623+
return {};
577624
}
578625
if (server_packet) {
579626
*server_packet = packet_type;
580627
}
581628

582629
switch (packet_type) {
630+
583631
case ServerCodes::Data: {
584-
if (!ReceiveData()) {
632+
DataTag ret{};
633+
if (!ReceiveData(&ret.block)) {
585634
throw ProtocolError("can't read data packet from input stream");
586635
}
587-
return true;
636+
return ret;
588637
}
589638

590639
case ServerCodes::Exception: {
640+
ExceptionTag ret{};
591641
ReceiveException();
592-
return false;
642+
return ret;
593643
}
594644

595645
case ServerCodes::ProfileInfo: {
596-
Profile profile;
646+
ProfileTag ret{};
597647

598-
if (!WireFormat::ReadUInt64(*input_, &profile.rows)) {
599-
return false;
648+
if (!WireFormat::ReadUInt64(*input_, &ret.profile.rows)) {
649+
return {};
600650
}
601-
if (!WireFormat::ReadUInt64(*input_, &profile.blocks)) {
602-
return false;
651+
if (!WireFormat::ReadUInt64(*input_, &ret.profile.blocks)) {
652+
return {};
603653
}
604-
if (!WireFormat::ReadUInt64(*input_, &profile.bytes)) {
605-
return false;
654+
if (!WireFormat::ReadUInt64(*input_, &ret.profile.bytes)) {
655+
return {};
606656
}
607-
if (!WireFormat::ReadFixed(*input_, &profile.applied_limit)) {
608-
return false;
657+
if (!WireFormat::ReadFixed(*input_, &ret.profile.applied_limit)) {
658+
return {};
609659
}
610-
if (!WireFormat::ReadUInt64(*input_, &profile.rows_before_limit)) {
611-
return false;
660+
if (!WireFormat::ReadUInt64(*input_, &ret.profile.rows_before_limit)) {
661+
return {};
612662
}
613-
if (!WireFormat::ReadFixed(*input_, &profile.calculated_rows_before_limit)) {
614-
return false;
663+
if (!WireFormat::ReadFixed(*input_, &ret.profile.calculated_rows_before_limit)) {
664+
return {};
615665
}
616666

617667
if (events_) {
618-
events_->OnProfile(profile);
668+
events_->OnProfile(ret.profile);
619669
}
620670

621-
return true;
671+
return ret;
622672
}
623673

624674
case ServerCodes::Progress: {
625-
Progress info;
675+
ProgressTag ret;
626676

627-
if (!WireFormat::ReadUInt64(*input_, &info.rows)) {
628-
return false;
677+
if (!WireFormat::ReadUInt64(*input_, &ret.progress.rows)) {
678+
return {};
629679
}
630-
if (!WireFormat::ReadUInt64(*input_, &info.bytes)) {
631-
return false;
680+
if (!WireFormat::ReadUInt64(*input_, &ret.progress.bytes)) {
681+
return {};
632682
}
633683
if constexpr(DMBS_PROTOCOL_REVISION >= DBMS_MIN_REVISION_WITH_TOTAL_ROWS_IN_PROGRESS) {
634-
if (!WireFormat::ReadUInt64(*input_, &info.total_rows)) {
635-
return false;
684+
if (!WireFormat::ReadUInt64(*input_, &ret.progress.total_rows)) {
685+
return {};
636686
}
637687
}
638688
if (server_info_.revision >= DBMS_MIN_REVISION_WITH_CLIENT_WRITE_INFO)
639689
{
640-
if (!WireFormat::ReadUInt64(*input_, &info.written_rows)) {
641-
return false;
690+
if (!WireFormat::ReadUInt64(*input_, &ret.progress.written_rows)) {
691+
return {};
642692
}
643-
if (!WireFormat::ReadUInt64(*input_, &info.written_bytes)) {
644-
return false;
693+
if (!WireFormat::ReadUInt64(*input_, &ret.progress.written_bytes)) {
694+
return {};
645695
}
646696
}
647697

648698
if (events_) {
649-
events_->OnProgress(info);
699+
events_->OnProgress(ret.progress);
650700
}
651701

652-
return true;
702+
return ret;
653703
}
654704

655705
case ServerCodes::Pong: {
656-
return true;
706+
return PongTag{};
657707
}
658708

659709
case ServerCodes::Hello: {
660-
return true;
710+
return HelloTag{};
661711
}
662712

663713
case ServerCodes::EndOfStream: {
664714
if (events_) {
665715
events_->OnFinish();
666716
}
667-
return false;
717+
return EndOfStreamTag{};
668718
}
669719

670720
case ServerCodes::Log: {
671721
// log tag
672722
if (!WireFormat::SkipString(*input_)) {
673-
return false;
723+
return {};
674724
}
675-
Block block;
725+
LogTag ret;
676726

677727
// Use uncompressed stream since log blocks usually contain only one row
678-
if (!ReadBlock(*input_, &block)) {
679-
return false;
728+
if (!ReadBlock(*input_, &ret.block)) {
729+
return {};
680730
}
681731

682732
if (events_) {
683-
events_->OnServerLog(block);
733+
events_->OnServerLog(ret.block);
684734
}
685-
return true;
735+
return ret;
686736
}
687737

688738
case ServerCodes::TableColumns: {
689739
// external table name
690740
if (!WireFormat::SkipString(*input_)) {
691-
return false;
741+
return {};
692742
}
693743

694744
// columns metadata
695745
if (!WireFormat::SkipString(*input_)) {
696-
return false;
746+
return {};
697747
}
698-
return true;
748+
return TableColumnsTag{};
699749
}
700750

701751
case ServerCodes::ProfileEvents: {
702752
if (!WireFormat::SkipString(*input_)) {
703-
return false;
753+
return {};
704754
}
705755

706-
Block block;
707-
if (!ReadBlock(*input_, &block)) {
708-
return false;
756+
ProfileEventsTag ret;
757+
if (!ReadBlock(*input_, &ret.block)) {
758+
return {};
709759
}
710760

711761
if (events_) {
712-
events_->OnProfileEvents(block);
762+
events_->OnProfileEvents(ret.block);
713763
}
714-
return true;
764+
return ret;
715765
}
716766

717767
default:
718768
throw UnimplementedError("unimplemented " + std::to_string((int)packet_type));
719-
break;
720769
}
721770
}
722771

772+
bool Client::Impl::ProcessPacket(uint64_t* server_packet) {
773+
auto packet = ReceivePacket(server_packet);
774+
return std::visit(overloaded{
775+
[](std::monostate) { return false; },
776+
[](DataTag) { return true; },
777+
[](ExceptionTag) {return false; },
778+
[](ProfileTag){ return true; },
779+
[](ProgressTag){ return true; },
780+
[](PongTag){ return true; },
781+
[](HelloTag){ return true; },
782+
[](LogTag){ return true; },
783+
[](TableColumnsTag){ return true; },
784+
[](ProfileEventsTag){ return true; },
785+
[](EndOfStreamTag){ return false; },
786+
}, packet);
787+
}
788+
723789
bool Client::Impl::ReadBlock(InputStream& input, Block* block) {
724790
// Additional information about block.
725791
if (server_info_.revision >= DBMS_MIN_REVISION_WITH_BLOCK_INFO) {
@@ -793,8 +859,7 @@ bool Client::Impl::ReadBlock(InputStream& input, Block* block) {
793859
return true;
794860
}
795861

796-
bool Client::Impl::ReceiveData() {
797-
Block block;
862+
bool Client::Impl::ReceiveData(Block * block) {
798863

799864
if (server_info_.revision >= DBMS_MIN_REVISION_WITH_TEMPORARY_TABLES) {
800865
if (!WireFormat::SkipString(*input_)) {
@@ -804,18 +869,18 @@ bool Client::Impl::ReceiveData() {
804869

805870
if (compression_ == CompressionState::Enable) {
806871
CompressedInput compressed(input_.get());
807-
if (!ReadBlock(compressed, &block)) {
872+
if (!ReadBlock(compressed, block)) {
808873
return false;
809874
}
810875
} else {
811-
if (!ReadBlock(*input_, &block)) {
876+
if (!ReadBlock(*input_, block)) {
812877
return false;
813878
}
814879
}
815880

816881
if (events_) {
817-
events_->OnData(block);
818-
if (!events_->OnDataCancelable(block)) {
882+
events_->OnData(*block);
883+
if (!events_->OnDataCancelable(*block)) {
819884
SendCancel();
820885
}
821886
}

0 commit comments

Comments
 (0)