Skip to content

Commit 03ae3c2

Browse files
committed
Address gemini-code-assist comments.
1 parent 25876de commit 03ae3c2

10 files changed

Lines changed: 44 additions & 30 deletions

apps/Makefile

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ CUDA_TARGETS := $(CUDA_TARGETS:%cuda.cu=%cuda.x)
77
CUSTATEVEC_TARGETS = $(shell find . -maxdepth 1 -name "*custatevec.cu")
88
CUSTATEVEC_TARGETS := $(CUSTATEVEC_TARGETS:%custatevec.cu=%custatevec.x)
99

10-
CUSTATEVEC_TARGETS = $(shell find . -maxdepth 1 -name "*custatevecex.cu")
11-
CUSTATEVEC_TARGETS := $(CUSTATEVEC_TARGETS:%custatevec.cu=%custatevecex.x)
10+
CUSTATEVECEX_TARGETS = $(shell find . -maxdepth 1 -name "*custatevecex.cu")
11+
CUSTATEVECEX_TARGETS := $(CUSTATEVEC_TARGETS:%custatevecex.cu=%custatevecex.x)
1212

1313
HIP_TARGETS = $(shell find . -maxdepth 1 -name '*cuda.cu')
1414
HIP_TARGETS := $(HIP_TARGETS:%cuda.cu=%hip.x)
@@ -22,6 +22,9 @@ qsim-cuda: $(CUDA_TARGETS)
2222
.PHONY: qsim-custatevec
2323
qsim-custatevec: $(CUSTATEVEC_TARGETS)
2424

25+
.PHONY: qsim-custatevecex
26+
qsim-custatevecex: $(CUSTATEVECEX_TARGETS)
27+
2528
.PHONY: qsim-hip
2629
qsim-hip: $(HIP_TARGETS)
2730

apps/qsim_base_custatevecex.cu

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ Options GetOptions(int argc, char* argv[]) {
5454
case 's':
5555
opt.seed = std::atoi(optarg);
5656
break;
57-
break;
5857
case 'v':
5958
opt.verbosity = std::atoi(optarg);
6059
break;

lib/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,6 @@ cuda_library(
561561
":util_cuda",
562562
":util_custatevec",
563563
":util_custatevecex",
564-
":vectorspace_custatevecex",
565564
],
566565
)
567566

lib/multiprocess_custatevecex.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,8 @@ struct MultiProcessCuStateVecEx {
162162
ErrorCheck(custatevecExConfigureStateVectorMultiProcess(
163163
&sv_config, data_type, num_qubits, num_local_qubits, -1,
164164
memory_sharing_method_, global_index_bit_classes_.data(),
165-
(int32_t*) num_global_qubits_per_layer_.data(),
166-
(int32_t) global_index_bit_classes_.size(),
165+
reinterpret_cast<const int32_t*>(num_global_qubits_per_layer_.data()),
166+
static_cast<int32_t>(global_index_bit_classes_.size()),
167167
param_.transfer_buffer_size, nullptr, 0));
168168

169169
return sv_config;

lib/run_custatevecex.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,8 @@ struct CuStateVecExRunner final {
260260
ErrorCheck(custatevecExSVUpdaterEnqueueMatrix(
261261
sv_updater, gate.matrix.data(), StateSpace::kMatrixDataType,
262262
StateSpace::kExMatrixType, StateSpace::kMatrixLayout, 0,
263-
(int32_t*) gate.qubits.data(), num_qubits, nullptr, nullptr, 0));
263+
reinterpret_cast<const int32_t*>(gate.qubits.data()),
264+
num_qubits, nullptr, nullptr, 0));
264265
}
265266
} else {
266267
std::vector<int32_t> control_bits;
@@ -273,9 +274,9 @@ struct CuStateVecExRunner final {
273274
ErrorCheck(custatevecExSVUpdaterEnqueueMatrix(
274275
sv_updater, gate.matrix.data(), StateSpace::kMatrixDataType,
275276
StateSpace::kExMatrixType, StateSpace::kMatrixLayout, 0,
276-
(int32_t*) gate.qubits.data(), num_qubits,
277-
(int32_t*) gate.controlled_by.data(), control_bits.data(),
278-
num_cqubits));
277+
reinterpret_cast<const int32_t*>(gate.qubits.data()), num_qubits,
278+
reinterpret_cast<const int32_t*>(gate.controlled_by.data()),
279+
control_bits.data(), num_cqubits));
279280
}
280281

281282
if (times_to_measure_at.size() > 0) {

lib/simulator_custatevec.h

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,9 @@ class SimulatorCuStateVec final {
8282
ErrorCheck(custatevecApplyMatrix(
8383
custatevec_handle_, state.get(), kStateType,
8484
state.num_qubits(), matrix, kMatrixType, kMatrixLayout, 0,
85-
(int32_t*) qs.data(), qs.size(), nullptr, nullptr, 0,
86-
kComputeType, workspace_, workspace_size));
85+
reinterpret_cast<const int32_t*>(qs.data()), qs.size(),
86+
nullptr, nullptr, 0, kComputeType, workspace_,
87+
workspace_size));
8788
}
8889
}
8990

@@ -118,9 +119,10 @@ class SimulatorCuStateVec final {
118119
ErrorCheck(custatevecApplyMatrix(
119120
custatevec_handle_, state.get(), kStateType,
120121
state.num_qubits(), matrix, kMatrixType, kMatrixLayout, 0,
121-
(int32_t*) qs.data(), qs.size(),
122-
(int32_t*) cqs.data(), control_bits.data(), cqs.size(),
123-
kComputeType, workspace_, workspace_size));
122+
reinterpret_cast<const int32_t*>(qs.data()), qs.size(),
123+
reinterpret_cast<const int32_t*>(cqs.data()),
124+
control_bits.data(), cqs.size(), kComputeType,
125+
workspace_, workspace_size));
124126
}
125127
}
126128

@@ -144,9 +146,12 @@ class SimulatorCuStateVec final {
144146
ErrorCheck(custatevecComputeExpectation(
145147
custatevec_handle_, state.get(), kStateType,
146148
state.num_qubits(), &eval, kExpectType, nullptr, matrix,
147-
kMatrixType, kMatrixLayout, (int32_t*) qs.data(), qs.size(),
149+
kMatrixType, kMatrixLayout,
150+
reinterpret_cast<const int32_t*>(qs.data()), qs.size(),
148151
kComputeType, workspace_, workspace_size));
149152

153+
ErrorCheck(cudaDeviceSynchronize());
154+
150155
return {cuCreal(eval), cuCimag(eval)};
151156
}
152157

lib/simulator_custatevecex.h

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ class SimulatorCuStateVecEx final {
7272

7373
ErrorCheck(custatevecExApplyMatrix(
7474
state.get(), matrix, kMatrixDataType, kExMatrixType, kMatrixLayout,
75-
0, (int32_t*) qs.data(), qs.size(), nullptr, nullptr, 0));
75+
0, reinterpret_cast<const int32_t*>(qs.data()), qs.size(),
76+
nullptr, nullptr, 0));
7677
}
7778
}
7879

@@ -112,8 +113,9 @@ class SimulatorCuStateVecEx final {
112113

113114
ErrorCheck(custatevecExApplyMatrix(
114115
state.get(), matrix, kMatrixDataType, kExMatrixType, kMatrixLayout,
115-
0, (int32_t*) qs.data(), qs.size(), (int32_t*) cqs.data(),
116-
control_bits.data(), cqs.size()));
116+
0, reinterpret_cast<const int32_t*>(qs.data()), qs.size(),
117+
reinterpret_cast<const int32_t*>(cqs.data()), control_bits.data(),
118+
cqs.size()));
117119
}
118120
}
119121

@@ -189,15 +191,17 @@ class SimulatorCuStateVecEx final {
189191

190192
if (l > 0) {
191193
ErrorCheck(custatevecExStateVectorPermuteIndexBits(
192-
state.get(), (int32_t*) perm.data(), num_qubits,
193-
CUSTATEVEC_EX_PERMUTATION_SCATTER));
194+
state.get(), reinterpret_cast<const int32_t*>(perm.data()),
195+
num_qubits, CUSTATEVEC_EX_PERMUTATION_SCATTER));
194196
}
195197

196198
auto f = [&matrix, &state, &num_local_qubits, &qs2](
197199
unsigned i, const auto& r) {
198200
void* workspace;
199201
size_t workspace_size;
200202

203+
ErrorCheck(cudaSetDevice(r.device_id));
204+
201205
ErrorCheck(custatevecComputeExpectationGetWorkspaceSize(
202206
r.custatevec_handle, kStateDataType, num_local_qubits, matrix,
203207
kMatrixDataType, kMatrixLayout, qs2.size(), kComputeType,
@@ -211,9 +215,11 @@ class SimulatorCuStateVecEx final {
211215
ErrorCheck(custatevecComputeExpectation(
212216
r.custatevec_handle, r.device_ptr, kStateDataType, num_local_qubits,
213217
&eval, kExpectDataType, nullptr, matrix, kMatrixDataType,
214-
kMatrixLayout, (int32_t*) qs2.data(), qs2.size(), kComputeType,
215-
workspace, workspace_size));
218+
kMatrixLayout, reinterpret_cast<const int32_t*>(qs2.data()),
219+
qs2.size(), kComputeType, workspace, workspace_size));
216220

221+
// TODO: make it faster.
222+
ErrorCheck(custatevecExStateVectorSynchronize(state.get()));
217223
ErrorCheck(cudaFree(workspace));
218224

219225
return std::complex<double>{cuCreal(eval), cuCimag(eval)};

lib/statespace_custatevec.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -306,8 +306,10 @@ class StateSpaceCuStateVec :
306306

307307
ErrorCheck(custatevecBatchMeasure(
308308
custatevec_handle_, state.get(), kStateType,
309-
state.num_qubits(), (int*) result.bitstring.data(),
310-
(int*) qubits.data(), qubits.size(), r, collapse));
309+
state.num_qubits(),
310+
reinterpret_cast<int*>(result.bitstring.data()),
311+
reinterpret_cast<const int*>(qubits.data()), qubits.size(),
312+
r, collapse));
311313

312314
for (std::size_t i = 0; i < result.bitstring.size(); ++i) {
313315
result.bits |= result.bitstring[i] << qubits[i];

lib/statespace_custatevecex.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -402,8 +402,8 @@ class StateSpaceCuStateVecEx :
402402
custatevecIndex_t bits;
403403

404404
ErrorCheck(custatevecExMeasure(
405-
state.get(), &bits, (int32_t*) qubits.data(), qubits.size(),
406-
r, collapse, nullptr));
405+
state.get(), &bits, reinterpret_cast<const int32_t*>(qubits.data()),
406+
qubits.size(), r, collapse, nullptr));
407407
ErrorCheck(custatevecExStateVectorSynchronize(state.get()));
408408

409409
for (std::size_t i = 0; i < qubits.size(); ++i) {

lib/vectorspace_custatevecex.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,7 @@ class VectorSpaceCuStateVecEx {
163163
const auto& get_wire_ordering() const {
164164
ErrorCheck(custatevecExStateVectorGetProperty(
165165
ptr_, CUSTATEVEC_EX_SV_PROP_WIRE_ORDERING,
166-
const_cast<int32_t*>(wire_ordering_.data()),
167-
sizeof(int32_t) * num_qubits_));
166+
wire_ordering_.data(), sizeof(int32_t) * num_qubits_));
168167

169168
return wire_ordering_;
170169
}
@@ -373,7 +372,7 @@ class VectorSpaceCuStateVecEx {
373372
private:
374373
const MultiProcessCuStateVecEx* mp_;
375374
custatevecExStateVectorDescriptor_t ptr_;
376-
std::vector<int32_t> wire_ordering_;
375+
mutable std::vector<int32_t> wire_ordering_;
377376
unsigned num_qubits_;
378377
unsigned num_substates_;
379378
DistributionType distr_type_;

0 commit comments

Comments
 (0)