Skip to content

Commit 942790e

Browse files
committed
gpu/cuda/ref: mixed field assembly support
1 parent 83e27e9 commit 942790e

3 files changed

Lines changed: 344 additions & 13 deletions

File tree

backends/cuda-ref/ceed-cuda-ref-operator.c

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1504,6 +1504,118 @@ static int CeedOperatorLinearAssembleAddPointBlockDiagonal_Cuda(CeedOperator op,
15041504
return CEED_ERROR_SUCCESS;
15051505
}
15061506

1507+
//------------------------------------------------------------------------------
1508+
// Single Operator Assembly Setup
1509+
//------------------------------------------------------------------------------
1510+
static int CeedOperatorAssembleSingleBlockSetup_Cuda(CeedOperator op, CeedInt active_input, CeedInt active_output, CeedInt use_ceedsize_idx) {
1511+
Ceed ceed;
1512+
Ceed_Cuda *cuda_data;
1513+
CeedInt num_input_fields, num_output_fields, num_eval_modes_in = 0, num_eval_modes_out = 0;
1514+
CeedInt elem_size_in, num_qpts_in = 0, num_comp_in, elem_size_out, num_qpts_out, num_comp_out;
1515+
CeedSize num_output_components;
1516+
CeedSize eval_mode_offset_in = 0, eval_mode_offset_out = 0;
1517+
const CeedScalar *h_B_in, *h_B_out;
1518+
CeedElemRestriction rstr_in = NULL, rstr_out = NULL;
1519+
CeedBasis basis_in = NULL, basis_out = NULL;
1520+
CeedOperatorField *input_fields, *output_fields;
1521+
CeedOperator_Cuda *impl;
1522+
1523+
CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
1524+
CeedCallBackend(CeedOperatorGetData(op, &impl));
1525+
1526+
// Get intput and output fields
1527+
CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &input_fields, &num_output_fields, &output_fields));
1528+
1529+
{
1530+
CeedInt num_active_bases_in, *t_num_eval_modes_in, num_active_bases_out, *t_num_eval_modes_out;
1531+
CeedSize **eval_modes_offsets_in, **eval_modes_offsets_out;
1532+
CeedBasis *active_bases_in, *active_bases_out;
1533+
CeedElemRestriction *active_rstrs_in, *active_rstrs_out;
1534+
const CeedScalar **B_mats_in, **B_mats_out;
1535+
CeedOperatorAssemblyData data;
1536+
1537+
CeedCall(CeedOperatorGetOperatorAssemblyData(op, &data));
1538+
CeedCall(CeedOperatorAssemblyDataGetEvalModes(data, &num_active_bases_in, &t_num_eval_modes_in, NULL, &eval_modes_offsets_in,
1539+
&num_active_bases_out, &t_num_eval_modes_out, NULL, &eval_modes_offsets_out,
1540+
&num_output_components));
1541+
// Number of elem restrictions is the same as the number of bases
1542+
CeedCall(CeedOperatorAssemblyDataGetElemRestrictions(data, NULL, &active_rstrs_in, NULL, &active_rstrs_out));
1543+
CeedCall(CeedOperatorAssemblyDataGetBases(data, NULL, &active_bases_in, &B_mats_in, NULL, &active_bases_out, &B_mats_out));
1544+
1545+
num_eval_modes_in = t_num_eval_modes_in[active_input];
1546+
num_eval_modes_out = t_num_eval_modes_out[active_output];
1547+
CeedCheck(num_eval_modes_in > 0 && num_eval_modes_out > 0, ceed, CEED_ERROR_UNSUPPORTED, "Cannot assemble operator without inputs/outputs");
1548+
1549+
if (!impl->asmb_blocks) {
1550+
CeedCallBackend(CeedCalloc(num_active_bases_in * num_active_bases_out, &impl->asmb_blocks));
1551+
impl->num_blocks_out = num_active_bases_out;
1552+
}
1553+
1554+
rstr_in = active_rstrs_in[active_input];
1555+
basis_in = active_bases_in[active_input];
1556+
eval_mode_offset_in = eval_modes_offsets_in[active_input][0];
1557+
h_B_in = B_mats_in[active_input];
1558+
rstr_out = active_rstrs_out[active_output];
1559+
basis_out = active_bases_out[active_output];
1560+
eval_mode_offset_out = eval_modes_offsets_out[active_output][0];
1561+
h_B_out = B_mats_out[active_output];
1562+
}
1563+
1564+
CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_in, &elem_size_in));
1565+
if (basis_in == CEED_BASIS_NONE) num_qpts_in = elem_size_in;
1566+
else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts_in));
1567+
1568+
CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_out, &elem_size_out));
1569+
if (basis_out == CEED_BASIS_NONE) num_qpts_out = elem_size_out;
1570+
else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_out, &num_qpts_out));
1571+
CeedCheck(num_qpts_in == num_qpts_out, ceed, CEED_ERROR_UNSUPPORTED,
1572+
"Active input and output bases must have the same number of quadrature points");
1573+
1574+
CeedCallBackend(CeedCalloc(1, &impl->asmb_blocks[active_input * impl->num_blocks_out + active_output]));
1575+
CeedOperatorAssemble_Cuda *asmb = impl->asmb_blocks[active_input * impl->num_blocks_out + active_output];
1576+
asmb->elems_per_block = 1;
1577+
asmb->block_size_x = elem_size_in;
1578+
asmb->block_size_y = elem_size_out;
1579+
1580+
CeedCallBackend(CeedGetData(ceed, &cuda_data));
1581+
bool fallback = asmb->block_size_x * asmb->block_size_y * asmb->elems_per_block > cuda_data->device_prop.maxThreadsPerBlock;
1582+
1583+
if (fallback) {
1584+
// Use fallback kernel with 1D threadblock
1585+
asmb->block_size_y = 1;
1586+
}
1587+
1588+
// Compile kernels
1589+
const char assembly_kernel_source[] = "// Full assembly source\n#include <ceed/jit-source/cuda/cuda-ref-operator-assemble-block.h>\n";
1590+
CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr_in, &num_comp_in));
1591+
CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr_out, &num_comp_out));
1592+
CeedCallBackend(CeedCompile_Cuda(ceed, assembly_kernel_source, &asmb->module, 13, "NUM_EVAL_MODES_IN", num_eval_modes_in, "NUM_EVAL_MODES_OUT",
1593+
num_eval_modes_out, "EVAL_MODE_OFFSET_IN", eval_mode_offset_in, "EVAL_MODE_OFFSET_OUT", eval_mode_offset_out,
1594+
"NUM_COMP_IN", num_comp_in, "NUM_COMP_OUT", num_comp_out, "TOTAL_NUM_COMP_OUT", num_output_components,
1595+
"NUM_NODES_IN", elem_size_in, "NUM_NODES_OUT", elem_size_out, "NUM_QPTS", num_qpts_in, "BLOCK_SIZE",
1596+
asmb->block_size_x * asmb->block_size_y * asmb->elems_per_block, "BLOCK_SIZE_Y", asmb->block_size_y,
1597+
"USE_CEEDSIZE", use_ceedsize_idx));
1598+
CeedCallBackend(CeedGetKernel_Cuda(ceed, asmb->module, "LinearAssembleBlock", &asmb->LinearAssemble));
1599+
1600+
// Load into B_in, in order that they will be used in eval_modes_in
1601+
{
1602+
const CeedInt in_bytes = elem_size_in * num_qpts_in * num_eval_modes_in * sizeof(CeedScalar);
1603+
1604+
CeedCallCuda(ceed, cudaMalloc((void **)&asmb->d_B_in, in_bytes));
1605+
CeedCallCuda(ceed, cudaMemcpy(asmb->d_B_in, h_B_in, in_bytes, cudaMemcpyHostToDevice));
1606+
}
1607+
1608+
// Load into B_out, in order that they will be used in eval_modes_out
1609+
{
1610+
const CeedInt out_bytes = elem_size_out * num_qpts_out * num_eval_modes_out * sizeof(CeedScalar);
1611+
1612+
CeedCallCuda(ceed, cudaMalloc((void **)&asmb->d_B_out, out_bytes));
1613+
CeedCallCuda(ceed, cudaMemcpy(asmb->d_B_out, h_B_out, out_bytes, cudaMemcpyHostToDevice));
1614+
}
1615+
CeedCallBackend(CeedDestroy(&ceed));
1616+
return CEED_ERROR_SUCCESS;
1617+
}
1618+
15071619
//------------------------------------------------------------------------------
15081620
// Single Operator Assembly Setup
15091621
//------------------------------------------------------------------------------
@@ -1705,6 +1817,117 @@ static int CeedOperatorAssembleSingleSetup_Cuda(CeedOperator op, CeedInt use_cee
17051817
return CEED_ERROR_SUCCESS;
17061818
}
17071819

1820+
//------------------------------------------------------------------------------
1821+
// Assemble matrix data for one block of a COO matrix of assembled operator.
1822+
// The sparsity pattern is set by CeedOperatorLinearAssembleSymbolic.
1823+
//------------------------------------------------------------------------------
1824+
static int CeedOperatorAssembleSingleBlock_Cuda(CeedOperator op, CeedInt offset, CeedInt active_input, CeedInt active_output, CeedVector values) {
1825+
Ceed ceed;
1826+
CeedSize values_length = 0, assembled_qf_length = 0;
1827+
CeedInt use_ceedsize_idx = 0, num_elem_in, num_elem_out, elem_size_in, elem_size_out;
1828+
CeedScalar *values_array;
1829+
const CeedScalar *assembled_qf_array;
1830+
CeedVector assembled_qf = NULL;
1831+
CeedElemRestriction assembled_rstr = NULL, rstr_in, rstr_out;
1832+
CeedRestrictionType rstr_type_in, rstr_type_out;
1833+
const bool *orients_in = NULL, *orients_out = NULL;
1834+
const CeedInt8 *curl_orients_in = NULL, *curl_orients_out = NULL;
1835+
CeedOperator_Cuda *impl;
1836+
1837+
CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
1838+
CeedCallBackend(CeedOperatorGetData(op, &impl));
1839+
1840+
// Assemble QFunction
1841+
CeedCallBackend(CeedOperatorLinearAssembleQFunctionBuildOrUpdate(op, &assembled_qf, &assembled_rstr, CEED_REQUEST_IMMEDIATE));
1842+
CeedCallBackend(CeedElemRestrictionDestroy(&assembled_rstr));
1843+
CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &assembled_qf_array));
1844+
1845+
CeedCallBackend(CeedVectorGetLength(values, &values_length));
1846+
CeedCallBackend(CeedVectorGetLength(assembled_qf, &assembled_qf_length));
1847+
if ((values_length > INT_MAX) || (assembled_qf_length > INT_MAX)) use_ceedsize_idx = 1;
1848+
1849+
// Setup
1850+
if (!impl->asmb_blocks || (impl->asmb_blocks && !impl->asmb_blocks[active_input * impl->num_blocks_out + active_output])) {
1851+
CeedCallBackend(CeedOperatorAssembleSingleBlockSetup_Cuda(op, active_input, active_output, use_ceedsize_idx));
1852+
}
1853+
CeedOperatorAssemble_Cuda *asmb = impl->asmb_blocks[active_input * impl->num_blocks_out + active_output];
1854+
1855+
assert(asmb != NULL);
1856+
1857+
// Assemble element operator
1858+
CeedCallBackend(CeedVectorGetArray(values, CEED_MEM_DEVICE, &values_array));
1859+
values_array += offset;
1860+
1861+
CeedElemRestriction *active_rstrs_in, *active_rstrs_out;
1862+
CeedOperatorAssemblyData data;
1863+
1864+
CeedCall(CeedOperatorGetOperatorAssemblyData(op, &data));
1865+
CeedCall(CeedOperatorAssemblyDataGetElemRestrictions(data, NULL, &active_rstrs_in, NULL, &active_rstrs_out));
1866+
1867+
rstr_in = active_rstrs_in[active_input];
1868+
rstr_out = active_rstrs_out[active_output];
1869+
CeedCallBackend(CeedElemRestrictionGetNumElements(rstr_in, &num_elem_in));
1870+
CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_in, &elem_size_in));
1871+
1872+
CeedCallBackend(CeedElemRestrictionGetType(rstr_in, &rstr_type_in));
1873+
if (rstr_type_in == CEED_RESTRICTION_ORIENTED) {
1874+
CeedCallBackend(CeedElemRestrictionGetOrientations(rstr_in, CEED_MEM_DEVICE, &orients_in));
1875+
} else if (rstr_type_in == CEED_RESTRICTION_CURL_ORIENTED) {
1876+
CeedCallBackend(CeedElemRestrictionGetCurlOrientations(rstr_in, CEED_MEM_DEVICE, &curl_orients_in));
1877+
}
1878+
1879+
if (rstr_in != rstr_out) {
1880+
CeedCallBackend(CeedElemRestrictionGetNumElements(rstr_out, &num_elem_out));
1881+
CeedCheck(num_elem_in == num_elem_out, ceed, CEED_ERROR_UNSUPPORTED,
1882+
"Active input and output operator restrictions must have the same number of elements");
1883+
CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_out, &elem_size_out));
1884+
1885+
CeedCallBackend(CeedElemRestrictionGetType(rstr_out, &rstr_type_out));
1886+
if (rstr_type_out == CEED_RESTRICTION_ORIENTED) {
1887+
CeedCallBackend(CeedElemRestrictionGetOrientations(rstr_out, CEED_MEM_DEVICE, &orients_out));
1888+
} else if (rstr_type_out == CEED_RESTRICTION_CURL_ORIENTED) {
1889+
CeedCallBackend(CeedElemRestrictionGetCurlOrientations(rstr_out, CEED_MEM_DEVICE, &curl_orients_out));
1890+
}
1891+
} else {
1892+
elem_size_out = elem_size_in;
1893+
orients_out = orients_in;
1894+
curl_orients_out = curl_orients_in;
1895+
}
1896+
1897+
// Compute B^T D B
1898+
CeedInt shared_mem =
1899+
((curl_orients_in || curl_orients_out ? elem_size_in * elem_size_out : 0) + (curl_orients_in ? elem_size_in * asmb->block_size_y : 0)) *
1900+
sizeof(CeedScalar);
1901+
CeedInt grid = CeedDivUpInt(num_elem_in, asmb->elems_per_block);
1902+
void *args[] = {(void *)&num_elem_in, &asmb->d_B_in, &asmb->d_B_out, &orients_in, &curl_orients_in,
1903+
&orients_out, &curl_orients_out, &assembled_qf_array, &values_array};
1904+
1905+
CeedCallBackend(CeedRunKernelDimShared_Cuda(ceed, asmb->LinearAssemble, NULL, grid, asmb->block_size_x, asmb->block_size_y, asmb->elems_per_block,
1906+
shared_mem, args));
1907+
CeedCallCuda(ceed, cudaDeviceSynchronize());
1908+
1909+
// Restore arrays
1910+
CeedCallBackend(CeedVectorRestoreArray(values, &values_array));
1911+
CeedCallBackend(CeedVectorRestoreArrayRead(assembled_qf, &assembled_qf_array));
1912+
1913+
// Cleanup
1914+
CeedCallBackend(CeedVectorDestroy(&assembled_qf));
1915+
if (rstr_type_in == CEED_RESTRICTION_ORIENTED) {
1916+
CeedCallBackend(CeedElemRestrictionRestoreOrientations(rstr_in, &orients_in));
1917+
} else if (rstr_type_in == CEED_RESTRICTION_CURL_ORIENTED) {
1918+
CeedCallBackend(CeedElemRestrictionRestoreCurlOrientations(rstr_in, &curl_orients_in));
1919+
}
1920+
if (rstr_in != rstr_out) {
1921+
if (rstr_type_out == CEED_RESTRICTION_ORIENTED) {
1922+
CeedCallBackend(CeedElemRestrictionRestoreOrientations(rstr_out, &orients_out));
1923+
} else if (rstr_type_out == CEED_RESTRICTION_CURL_ORIENTED) {
1924+
CeedCallBackend(CeedElemRestrictionRestoreCurlOrientations(rstr_out, &curl_orients_out));
1925+
}
1926+
}
1927+
CeedCallBackend(CeedDestroy(&ceed));
1928+
return CEED_ERROR_SUCCESS;
1929+
}
1930+
17081931
//------------------------------------------------------------------------------
17091932
// Assemble matrix data for COO matrix of assembled operator.
17101933
// The sparsity pattern is set by CeedOperatorLinearAssembleSymbolic.
@@ -2090,6 +2313,7 @@ int CeedOperatorCreate_Cuda(CeedOperator op) {
20902313
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddPointBlockDiagonal",
20912314
CeedOperatorLinearAssembleAddPointBlockDiagonal_Cuda));
20922315
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleSingle", CeedOperatorAssembleSingle_Cuda));
2316+
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleSingleBlock", CeedOperatorAssembleSingleBlock_Cuda));
20932317
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Cuda));
20942318
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Cuda));
20952319
CeedCallBackend(CeedDestroy(&ceed));

backends/cuda-ref/ceed-cuda-ref.h

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -132,19 +132,21 @@ typedef struct {
132132
} CeedOperatorAssemble_Cuda;
133133

134134
// Backend data attached to a CUDA reference operator.
// NOTE(review): this span is a diff hunk; lines following "NNN-" markers look like the
// pre-change field list and lines following "NNN+" markers the post-change list - confirm.
typedef struct {
135-
bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out;
136-
uint64_t *input_states, points_state; // State tracking for passive inputs
137-
CeedVector *e_vecs_in, *e_vecs_out;
138-
CeedVector *q_vecs_in, *q_vecs_out;
139-
CeedInt num_inputs, num_outputs;
140-
CeedInt num_active_in, num_active_out;
141-
CeedInt *input_field_order, *output_field_order;
142-
CeedSize max_active_e_vec_len;
143-
CeedInt max_num_points;
144-
CeedInt *num_points;
145-
CeedVector *qf_active_in, point_coords_elem;
146-
CeedOperatorDiag_Cuda *diag;
147-
CeedOperatorAssemble_Cuda *asmb;
135+
bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out;
136+
uint64_t *input_states, points_state; // State tracking for passive inputs
137+
CeedVector *e_vecs_in, *e_vecs_out;
138+
CeedVector *q_vecs_in, *q_vecs_out;
139+
CeedInt num_inputs, num_outputs;
140+
CeedInt num_active_in, num_active_out;
141+
CeedInt *input_field_order, *output_field_order;
142+
CeedSize max_active_e_vec_len;
143+
CeedInt max_num_points;
144+
CeedInt *num_points;
145+
CeedVector *qf_active_in, point_coords_elem;
146+
CeedOperatorDiag_Cuda *diag;
147+
CeedOperatorAssemble_Cuda *asmb;
148+
// Per-block assembly data; the block setup routine allocates this as a flat table of
// (number of active input bases) x (number of active output bases) pointers and fills
// entry [active_input * num_blocks_out + active_output] on first use of that block
CeedOperatorAssemble_Cuda **asmb_blocks;
149+
// Row stride of asmb_blocks; set to the number of active output bases when the table is allocated
CeedInt num_blocks_out;
148150
} CeedOperator_Cuda;
149151

150152
CEED_INTERN int CeedGetCublasHandle_Cuda(Ceed ceed, cublasHandle_t *handle);

0 commit comments

Comments
 (0)