@@ -72,7 +72,8 @@ class SimulatorCuStateVecEx final {
7272
7373 ErrorCheck (custatevecExApplyMatrix (
7474 state.get (), matrix, kMatrixDataType , kExMatrixType , kMatrixLayout ,
75- 0 , (int32_t *) qs.data (), qs.size (), nullptr , nullptr , 0 ));
75+ 0 , reinterpret_cast <const int32_t *>(qs.data ()), qs.size (),
76+ nullptr , nullptr , 0 ));
7677 }
7778 }
7879
@@ -112,8 +113,9 @@ class SimulatorCuStateVecEx final {
112113
113114 ErrorCheck (custatevecExApplyMatrix (
114115 state.get (), matrix, kMatrixDataType , kExMatrixType , kMatrixLayout ,
115- 0 , (int32_t *) qs.data (), qs.size (), (int32_t *) cqs.data (),
116- control_bits.data (), cqs.size ()));
116+ 0 , reinterpret_cast <const int32_t *>(qs.data ()), qs.size (),
117+ reinterpret_cast <const int32_t *>(cqs.data ()), control_bits.data (),
118+ cqs.size ()));
117119 }
118120 }
119121
@@ -189,15 +191,17 @@ class SimulatorCuStateVecEx final {
189191
190192 if (l > 0 ) {
191193 ErrorCheck (custatevecExStateVectorPermuteIndexBits (
192- state.get (), ( int32_t *) perm.data (), num_qubits ,
193- CUSTATEVEC_EX_PERMUTATION_SCATTER));
194+ state.get (), reinterpret_cast < const int32_t *>( perm.data ()) ,
195+ num_qubits, CUSTATEVEC_EX_PERMUTATION_SCATTER));
194196 }
195197
196198 auto f = [&matrix, &state, &num_local_qubits, &qs2](
197199 unsigned i, const auto & r) {
198200 void * workspace;
199201 size_t workspace_size;
200202
203+ ErrorCheck (cudaSetDevice (r.device_id ));
204+
201205 ErrorCheck (custatevecComputeExpectationGetWorkspaceSize (
202206 r.custatevec_handle , kStateDataType , num_local_qubits, matrix,
203207 kMatrixDataType , kMatrixLayout , qs2.size (), kComputeType ,
@@ -211,9 +215,11 @@ class SimulatorCuStateVecEx final {
211215 ErrorCheck (custatevecComputeExpectation (
212216 r.custatevec_handle , r.device_ptr , kStateDataType , num_local_qubits,
213217 &eval, kExpectDataType , nullptr , matrix, kMatrixDataType ,
214- kMatrixLayout , ( int32_t *) qs2.data (), qs2. size (), kComputeType ,
215- workspace, workspace_size));
218+ kMatrixLayout , reinterpret_cast < const int32_t *>( qs2.data ()) ,
219+ qs2. size (), kComputeType , workspace, workspace_size));
216220
221+ // TODO: make it faster, e.g. by avoiding a state vector synchronization on each invocation of this lambda.
222+ ErrorCheck (custatevecExStateVectorSynchronize (state.get ()));
217223 ErrorCheck (cudaFree (workspace));
218224
219225 return std::complex <double >{cuCreal (eval), cuCimag (eval)};
0 commit comments