Skip to content

Commit 0f7fd0f

Browse files
committed
vec - add PointwiseMult and AXPY convenience functions
1 parent 14e5f01 commit 0f7fd0f

6 files changed

Lines changed: 222 additions & 1 deletion

File tree

include/ceed-impl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,8 @@ struct CeedVector_private {
138138
int (*RestoreArray)(CeedVector);
139139
int (*RestoreArrayRead)(CeedVector);
140140
int (*Norm)(CeedVector, CeedNormType, CeedScalar *);
141+
int (*AXPY)(CeedVector, CeedScalar, CeedVector);
142+
int (*PointwiseMult)(CeedVector, CeedVector, CeedVector);
141143
int (*Reciprocal)(CeedVector);
142144
int (*Destroy)(CeedVector);
143145
int ref_count;

include/ceed/ceed.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ CEED_EXTERN int CeedGetVersion(int *major, int *minor, int *patch,
236236
/// data is likely modified or corrupted.
237237
/// @ingroup Ceed
238238
typedef enum {
239-
/// Sucess error code
239+
/// Success error code
240240
CEED_ERROR_SUCCESS = 0,
241241
/// Minor error, generic
242242
CEED_ERROR_MINOR = 1,
@@ -321,6 +321,8 @@ CEED_EXTERN int CeedVectorRestoreArrayRead(CeedVector vec,
321321
const CeedScalar **array);
322322
CEED_EXTERN int CeedVectorNorm(CeedVector vec, CeedNormType type,
323323
CeedScalar *norm);
324+
CEED_EXTERN int CeedVectorAXPY(CeedVector y, CeedScalar alpha, CeedVector x);
325+
CEED_EXTERN int CeedVectorPointwiseMult(CeedVector w, CeedVector x, CeedVector y);
324326
CEED_EXTERN int CeedVectorReciprocal(CeedVector vec);
325327
CEED_EXTERN int CeedVectorView(CeedVector vec, const char *fp_fmt, FILE *stream);
326328
CEED_EXTERN int CeedVectorGetLength(CeedVector vec, CeedInt *length);

interface/ceed-vector.c

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,112 @@ int CeedVectorNorm(CeedVector vec, CeedNormType norm_type, CeedScalar *norm) {
531531
return CEED_ERROR_SUCCESS;
532532
}
533533

534+
/**
535+
@brief Compute y = alpha x + y
536+
537+
@param y[in,out] target vector for sum
538+
@param alpha[in] scaling factor
539+
@param x[in] second vector, must be different than y
540+
541+
@return An error code: 0 - success, otherwise - failure
542+
543+
@ref User
544+
**/
545+
int CeedVectorAXPY(CeedVector y, CeedScalar alpha, CeedVector x) {
546+
int ierr;
547+
CeedScalar *y_array;
548+
CeedScalar const *x_array;
549+
CeedInt n_x, n_y;
550+
551+
ierr = CeedVectorGetLength(y, &n_y); CeedChk(ierr);
552+
ierr = CeedVectorGetLength(x, &n_x); CeedChk(ierr);
553+
if (n_x != n_y)
554+
// LCOV_EXCL_START
555+
return CeedError(y->ceed, CEED_ERROR_UNSUPPORTED,
556+
"Cannot add vector of different lengths");
557+
// LCOV_EXCL_STOP
558+
if (x == y)
559+
// LCOV_EXCL_START
560+
return CeedError(y->ceed, CEED_ERROR_UNSUPPORTED,
561+
"Cannot use same vector for x and y in CeedVectorAXPY");
562+
// LCOV_EXCL_STOP
563+
564+
// Backend implementation
565+
if (y->AXPY)
566+
return y->AXPY(y, alpha, x);
567+
568+
// Default implementation
569+
ierr = CeedVectorGetArray(y, CEED_MEM_HOST, &y_array); CeedChk(ierr);
570+
ierr = CeedVectorGetArrayRead(x, CEED_MEM_HOST, &x_array); CeedChk(ierr);
571+
572+
for (CeedInt i=0; i<n_y; i++)
573+
y_array[i] += alpha * x_array[i];
574+
575+
ierr = CeedVectorRestoreArray(y, &y_array); CeedChk(ierr);
576+
ierr = CeedVectorRestoreArrayRead(x, &x_array); CeedChk(ierr);
577+
578+
return CEED_ERROR_SUCCESS;
579+
}
580+
581+
/**
582+
@brief Compute the pointwise multiplication w = x * y. Any
583+
subset of x, y, and w may be the same vector.
584+
585+
@param w[out] target vector for the product
586+
@param x[in] first vector for product
587+
@param y[in] second vector for the product
588+
589+
@return An error code: 0 - success, otherwise - failure
590+
591+
@ ref User
592+
**/
593+
int CeedVectorPointwiseMult(CeedVector w, CeedVector x, CeedVector y) {
594+
int ierr;
595+
CeedScalar *w_array;
596+
CeedScalar const *x_array, *y_array;
597+
CeedInt n_x, n_y, n_w;
598+
599+
ierr = CeedVectorGetLength(w, &n_w); CeedChk(ierr);
600+
ierr = CeedVectorGetLength(x, &n_x); CeedChk(ierr);
601+
ierr = CeedVectorGetLength(y, &n_y); CeedChk(ierr);
602+
if (n_w != n_x || n_w != n_y)
603+
// LCOV_EXCL_START
604+
return CeedError(w->ceed, CEED_ERROR_UNSUPPORTED,
605+
"Cannot multiply vectors of different lengths");
606+
// LCOV_EXCL_STOP
607+
608+
// Backend implementation
609+
if (w->PointwiseMult)
610+
return w->PointwiseMult(w, x, y);
611+
612+
// Default implementation
613+
ierr = CeedVectorGetArray(w, CEED_MEM_HOST, &w_array); CeedChk(ierr);
614+
if (x != w) {
615+
ierr = CeedVectorGetArrayRead(x, CEED_MEM_HOST, &x_array); CeedChk(ierr);
616+
} else {
617+
x_array = w_array;
618+
}
619+
if (y != w && y != x) {
620+
ierr = CeedVectorGetArrayRead(y, CEED_MEM_HOST, &y_array); CeedChk(ierr);
621+
} else if (y != x) {
622+
y_array = w_array;
623+
} else {
624+
y_array = x_array;
625+
}
626+
627+
for (CeedInt i=0; i<n_w; i++)
628+
w_array[i] = x_array[i] * y_array[i];
629+
630+
if (y != w && y != x) {
631+
ierr = CeedVectorRestoreArrayRead(y, &y_array); CeedChk(ierr);
632+
}
633+
if (x != w) {
634+
ierr = CeedVectorRestoreArrayRead(x, &x_array); CeedChk(ierr);
635+
}
636+
ierr = CeedVectorRestoreArray(w, &w_array); CeedChk(ierr);
637+
return CEED_ERROR_SUCCESS;
638+
}
639+
534640
/**
535641
@brief Take the reciprocal of a CeedVector.
536642

interface/ceed.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -770,6 +770,8 @@ int CeedInit(const char *resource, Ceed *ceed) {
770770
CEED_FTABLE_ENTRY(CeedVector, RestoreArray),
771771
CEED_FTABLE_ENTRY(CeedVector, RestoreArrayRead),
772772
CEED_FTABLE_ENTRY(CeedVector, Norm),
773+
CEED_FTABLE_ENTRY(CeedVector, AXPY),
774+
CEED_FTABLE_ENTRY(CeedVector, PointwiseMult),
773775
CEED_FTABLE_ENTRY(CeedVector, Reciprocal),
774776
CEED_FTABLE_ENTRY(CeedVector, Destroy),
775777
CEED_FTABLE_ENTRY(CeedElemRestriction, Apply),

tests/t121-vector.c

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/// @file
2+
/// Test summing of a pair of vectors
3+
/// \test Test summing of a pair of vectors
4+
#include <ceed.h>
5+
#include <math.h>
6+
7+
int main(int argc, char **argv) {
8+
Ceed ceed;
9+
CeedVector x, y;
10+
CeedInt n;
11+
CeedScalar a[10];
12+
const CeedScalar *b;
13+
14+
CeedInit(argv[1], &ceed);
15+
16+
n = 10;
17+
CeedVectorCreate(ceed, n, &x);
18+
CeedVectorCreate(ceed, n, &y);
19+
for (CeedInt i=0; i<n; i++)
20+
a[i] = 10 + i;
21+
CeedVectorSetArray(x, CEED_MEM_HOST, CEED_COPY_VALUES, a);
22+
CeedVectorSetArray(y, CEED_MEM_HOST, CEED_COPY_VALUES, a);
23+
24+
CeedVectorAXPY(y, -0.5, x);
25+
26+
CeedVectorGetArrayRead(y, CEED_MEM_HOST, &b);
27+
for (CeedInt i=0; i<n; i++)
28+
if (fabs(b[i] - (10.0 + i)/2 ) > 1e-14)
29+
// LCOV_EXCL_START
30+
printf("Error in alpha x + y, computed: %f actual: %f\n", b[i],
31+
(10.0 + i)/2);
32+
// LCOV_EXCL_STOP
33+
CeedVectorRestoreArrayRead(y, &b);
34+
35+
CeedVectorDestroy(&x);
36+
CeedVectorDestroy(&y);
37+
CeedDestroy(&ceed);
38+
return 0;
39+
}

tests/t122-vector.c

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/// @file
2+
/// Test poinwise muliplication of a pair of vectors
3+
/// \test Test poinwise muliplication of a pair of vectors
4+
#include <ceed.h>
5+
#include <math.h>
6+
7+
int main(int argc, char **argv) {
8+
Ceed ceed;
9+
CeedVector x, y, w;
10+
CeedInt n;
11+
CeedScalar a[10];
12+
const CeedScalar *b;
13+
14+
CeedInit(argv[1], &ceed);
15+
16+
n = 10;
17+
CeedVectorCreate(ceed, n, &x);
18+
CeedVectorCreate(ceed, n, &y);
19+
CeedVectorCreate(ceed, n, &w);
20+
for (CeedInt i=0; i<n; i++)
21+
a[i] = i;
22+
CeedVectorSetArray(x, CEED_MEM_HOST, CEED_COPY_VALUES, a);
23+
CeedVectorSetArray(y, CEED_MEM_HOST, CEED_COPY_VALUES, a);
24+
25+
// Test multiplying two vectors into third
26+
CeedVectorPointwiseMult(w, x, y);
27+
CeedVectorGetArrayRead(w, CEED_MEM_HOST, &b);
28+
for (CeedInt i=0; i<n; i++)
29+
if (fabs(b[i] - i*i ) > 1e-14)
30+
// LCOV_EXCL_START
31+
printf("Error in alpha x + y, computed: %f actual: %f\n", b[i], 1.0*i*i);
32+
// LCOV_EXCL_STOP
33+
CeedVectorRestoreArrayRead(w, &b);
34+
35+
// Test multiplying two vectors into one of the two
36+
CeedVectorPointwiseMult(w, w, y);
37+
CeedVectorGetArrayRead(w, CEED_MEM_HOST, &b);
38+
for (CeedInt i=0; i<n; i++)
39+
if (fabs(b[i] - i*i*i ) > 1e-14)
40+
// LCOV_EXCL_START
41+
printf("Error in alpha x + y, computed: %f actual: %f\n", b[i], 1.0*i*i*i);
42+
// LCOV_EXCL_STOP
43+
CeedVectorRestoreArrayRead(w, &b);
44+
45+
// Test multiplying two vectors into one of the two
46+
CeedVectorPointwiseMult(w, x, w);
47+
CeedVectorGetArrayRead(w, CEED_MEM_HOST, &b);
48+
for (CeedInt i=0; i<n; i++)
49+
if (fabs(b[i] - i*i*i*i ) > 1e-14)
50+
// LCOV_EXCL_START
51+
printf("Error in alpha x + y, computed: %f actual: %f\n", b[i], 1.0*i*i*i*i);
52+
// LCOV_EXCL_STOP
53+
CeedVectorRestoreArrayRead(w, &b);
54+
55+
// Test multiplying vector by itself and putting product into self
56+
CeedVectorPointwiseMult(y, y, y);
57+
CeedVectorGetArrayRead(y, CEED_MEM_HOST, &b);
58+
for (CeedInt i=0; i<n; i++)
59+
if (fabs(b[i] - i*i ) > 1e-14)
60+
// LCOV_EXCL_START
61+
printf("Error in alpha x + y, computed: %f actual: %f\n", b[i], 1.0*i*i);
62+
// LCOV_EXCL_STOP
63+
CeedVectorRestoreArrayRead(y, &b);
64+
65+
CeedVectorDestroy(&x);
66+
CeedVectorDestroy(&y);
67+
CeedVectorDestroy(&w);
68+
CeedDestroy(&ceed);
69+
return 0;
70+
}

0 commit comments

Comments
 (0)