Skip to content

Commit 2d04630

Browse files
committed
vec - add tests for ceed compatibility for AXPY and PointwiseMult
1 parent 0f7fd0f commit 2d04630

1 file changed

Lines changed: 21 additions & 1 deletion

File tree

interface/ceed-vector.c

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,15 @@ int CeedVectorAXPY(CeedVector y, CeedScalar alpha, CeedVector x) {
561561
"Cannot use same vector for x and y in CeedVectorAXPY");
562562
// LCOV_EXCL_STOP
563563

564+
Ceed ceed_parent_x, ceed_parent_y;
565+
ierr = CeedGetParent(x->ceed, &ceed_parent_x); CeedChk(ierr);
566+
ierr = CeedGetParent(y->ceed, &ceed_parent_y); CeedChk(ierr);
567+
if (ceed_parent_x != ceed_parent_y)
568+
// LCOV_EXCL_START
569+
return CeedError(y->ceed, CEED_ERROR_INCOMPATIBLE,
570+
"Vectors x and y must be created by the same Ceed context");
571+
// LCOV_EXCL_STOP
572+
564573
// Backend implementation
565574
if (y->AXPY)
566575
return y->AXPY(y, alpha, x);
@@ -594,7 +603,7 @@ int CeedVectorPointwiseMult(CeedVector w, CeedVector x, CeedVector y) {
594603
int ierr;
595604
CeedScalar *w_array;
596605
CeedScalar const *x_array, *y_array;
597-
CeedInt n_x, n_y, n_w;
606+
CeedInt n_w, n_x, n_y;
598607

599608
ierr = CeedVectorGetLength(w, &n_w); CeedChk(ierr);
600609
ierr = CeedVectorGetLength(x, &n_x); CeedChk(ierr);
@@ -605,6 +614,17 @@ int CeedVectorPointwiseMult(CeedVector w, CeedVector x, CeedVector y) {
605614
"Cannot multiply vectors of different lengths");
606615
// LCOV_EXCL_STOP
607616

617+
Ceed ceed_parent_w, ceed_parent_x, ceed_parent_y;
618+
ierr = CeedGetParent(w->ceed, &ceed_parent_w); CeedChk(ierr);
619+
ierr = CeedGetParent(x->ceed, &ceed_parent_x); CeedChk(ierr);
620+
ierr = CeedGetParent(y->ceed, &ceed_parent_y); CeedChk(ierr);
621+
if ((ceed_parent_w != ceed_parent_y) ||
622+
(ceed_parent_w != ceed_parent_y))
623+
// LCOV_EXCL_START
624+
return CeedError(w->ceed, CEED_ERROR_INCOMPATIBLE,
625+
"Vectors w, x, and y must be created by the same Ceed context");
626+
// LCOV_EXCL_STOP
627+
608628
// Backend implementation
609629
if (w->PointwiseMult)
610630
return w->PointwiseMult(w, x, y);

0 commit comments

Comments
 (0)