@@ -1504,6 +1504,118 @@ static int CeedOperatorLinearAssembleAddPointBlockDiagonal_Cuda(CeedOperator op,
15041504 return CEED_ERROR_SUCCESS ;
15051505}
15061506
1507+ //------------------------------------------------------------------------------
1508+ // Single Operator Assembly Setup
1509+ //------------------------------------------------------------------------------
1510+ static int CeedOperatorAssembleSingleBlockSetup_Cuda (CeedOperator op , CeedInt active_input , CeedInt active_output , CeedInt use_ceedsize_idx ) {
1511+ Ceed ceed ;
1512+ Ceed_Cuda * cuda_data ;
1513+ CeedInt num_input_fields , num_output_fields , num_eval_modes_in = 0 , num_eval_modes_out = 0 ;
1514+ CeedInt elem_size_in , num_qpts_in = 0 , num_comp_in , elem_size_out , num_qpts_out , num_comp_out ;
1515+ CeedSize num_output_components ;
1516+ CeedSize eval_mode_offset_in = 0 , eval_mode_offset_out = 0 ;
1517+ const CeedScalar * h_B_in , * h_B_out ;
1518+ CeedElemRestriction rstr_in = NULL , rstr_out = NULL ;
1519+ CeedBasis basis_in = NULL , basis_out = NULL ;
1520+ CeedOperatorField * input_fields , * output_fields ;
1521+ CeedOperator_Cuda * impl ;
1522+
1523+ CeedCallBackend (CeedOperatorGetCeed (op , & ceed ));
1524+ CeedCallBackend (CeedOperatorGetData (op , & impl ));
1525+
1526+ // Get intput and output fields
1527+ CeedCallBackend (CeedOperatorGetFields (op , & num_input_fields , & input_fields , & num_output_fields , & output_fields ));
1528+
1529+ {
1530+ CeedInt num_active_bases_in , * t_num_eval_modes_in , num_active_bases_out , * t_num_eval_modes_out ;
1531+ CeedSize * * eval_modes_offsets_in , * * eval_modes_offsets_out ;
1532+ CeedBasis * active_bases_in , * active_bases_out ;
1533+ CeedElemRestriction * active_rstrs_in , * active_rstrs_out ;
1534+ const CeedScalar * * B_mats_in , * * B_mats_out ;
1535+ CeedOperatorAssemblyData data ;
1536+
1537+ CeedCall (CeedOperatorGetOperatorAssemblyData (op , & data ));
1538+ CeedCall (CeedOperatorAssemblyDataGetEvalModes (data , & num_active_bases_in , & t_num_eval_modes_in , NULL , & eval_modes_offsets_in ,
1539+ & num_active_bases_out , & t_num_eval_modes_out , NULL , & eval_modes_offsets_out ,
1540+ & num_output_components ));
1541+ // Number of elem restrictions is the same as the number of bases
1542+ CeedCall (CeedOperatorAssemblyDataGetElemRestrictions (data , NULL , & active_rstrs_in , NULL , & active_rstrs_out ));
1543+ CeedCall (CeedOperatorAssemblyDataGetBases (data , NULL , & active_bases_in , & B_mats_in , NULL , & active_bases_out , & B_mats_out ));
1544+
1545+ num_eval_modes_in = t_num_eval_modes_in [active_input ];
1546+ num_eval_modes_out = t_num_eval_modes_out [active_output ];
1547+ CeedCheck (num_eval_modes_in > 0 && num_eval_modes_out > 0 , ceed , CEED_ERROR_UNSUPPORTED , "Cannot assemble operator without inputs/outputs" );
1548+
1549+ if (!impl -> asmb_blocks ) {
1550+ CeedCallBackend (CeedCalloc (num_active_bases_in * num_active_bases_out , & impl -> asmb_blocks ));
1551+ impl -> num_blocks_out = num_active_bases_out ;
1552+ }
1553+
1554+ rstr_in = active_rstrs_in [active_input ];
1555+ basis_in = active_bases_in [active_input ];
1556+ eval_mode_offset_in = eval_modes_offsets_in [active_input ][0 ];
1557+ h_B_in = B_mats_in [active_input ];
1558+ rstr_out = active_rstrs_out [active_output ];
1559+ basis_out = active_bases_out [active_output ];
1560+ eval_mode_offset_out = eval_modes_offsets_out [active_output ][0 ];
1561+ h_B_out = B_mats_out [active_output ];
1562+ }
1563+
1564+ CeedCallBackend (CeedElemRestrictionGetElementSize (rstr_in , & elem_size_in ));
1565+ if (basis_in == CEED_BASIS_NONE ) num_qpts_in = elem_size_in ;
1566+ else CeedCallBackend (CeedBasisGetNumQuadraturePoints (basis_in , & num_qpts_in ));
1567+
1568+ CeedCallBackend (CeedElemRestrictionGetElementSize (rstr_out , & elem_size_out ));
1569+ if (basis_out == CEED_BASIS_NONE ) num_qpts_out = elem_size_out ;
1570+ else CeedCallBackend (CeedBasisGetNumQuadraturePoints (basis_out , & num_qpts_out ));
1571+ CeedCheck (num_qpts_in == num_qpts_out , ceed , CEED_ERROR_UNSUPPORTED ,
1572+ "Active input and output bases must have the same number of quadrature points" );
1573+
1574+ CeedCallBackend (CeedCalloc (1 , & impl -> asmb_blocks [active_input * impl -> num_blocks_out + active_output ]));
1575+ CeedOperatorAssemble_Cuda * asmb = impl -> asmb_blocks [active_input * impl -> num_blocks_out + active_output ];
1576+ asmb -> elems_per_block = 1 ;
1577+ asmb -> block_size_x = elem_size_in ;
1578+ asmb -> block_size_y = elem_size_out ;
1579+
1580+ CeedCallBackend (CeedGetData (ceed , & cuda_data ));
1581+ bool fallback = asmb -> block_size_x * asmb -> block_size_y * asmb -> elems_per_block > cuda_data -> device_prop .maxThreadsPerBlock ;
1582+
1583+ if (fallback ) {
1584+ // Use fallback kernel with 1D threadblock
1585+ asmb -> block_size_y = 1 ;
1586+ }
1587+
1588+ // Compile kernels
1589+ const char assembly_kernel_source [] = "// Full assembly source\n#include <ceed/jit-source/cuda/cuda-ref-operator-assemble-block.h>\n" ;
1590+ CeedCallBackend (CeedElemRestrictionGetNumComponents (rstr_in , & num_comp_in ));
1591+ CeedCallBackend (CeedElemRestrictionGetNumComponents (rstr_out , & num_comp_out ));
1592+ CeedCallBackend (CeedCompile_Cuda (ceed , assembly_kernel_source , & asmb -> module , 13 , "NUM_EVAL_MODES_IN" , num_eval_modes_in , "NUM_EVAL_MODES_OUT" ,
1593+ num_eval_modes_out , "EVAL_MODE_OFFSET_IN" , eval_mode_offset_in , "EVAL_MODE_OFFSET_OUT" , eval_mode_offset_out ,
1594+ "NUM_COMP_IN" , num_comp_in , "NUM_COMP_OUT" , num_comp_out , "TOTAL_NUM_COMP_OUT" , num_output_components ,
1595+ "NUM_NODES_IN" , elem_size_in , "NUM_NODES_OUT" , elem_size_out , "NUM_QPTS" , num_qpts_in , "BLOCK_SIZE" ,
1596+ asmb -> block_size_x * asmb -> block_size_y * asmb -> elems_per_block , "BLOCK_SIZE_Y" , asmb -> block_size_y ,
1597+ "USE_CEEDSIZE" , use_ceedsize_idx ));
1598+ CeedCallBackend (CeedGetKernel_Cuda (ceed , asmb -> module , "LinearAssembleBlock" , & asmb -> LinearAssemble ));
1599+
1600+ // Load into B_in, in order that they will be used in eval_modes_in
1601+ {
1602+ const CeedInt in_bytes = elem_size_in * num_qpts_in * num_eval_modes_in * sizeof (CeedScalar );
1603+
1604+ CeedCallCuda (ceed , cudaMalloc ((void * * )& asmb -> d_B_in , in_bytes ));
1605+ CeedCallCuda (ceed , cudaMemcpy (asmb -> d_B_in , h_B_in , in_bytes , cudaMemcpyHostToDevice ));
1606+ }
1607+
1608+ // Load into B_out, in order that they will be used in eval_modes_out
1609+ {
1610+ const CeedInt out_bytes = elem_size_out * num_qpts_out * num_eval_modes_out * sizeof (CeedScalar );
1611+
1612+ CeedCallCuda (ceed , cudaMalloc ((void * * )& asmb -> d_B_out , out_bytes ));
1613+ CeedCallCuda (ceed , cudaMemcpy (asmb -> d_B_out , h_B_out , out_bytes , cudaMemcpyHostToDevice ));
1614+ }
1615+ CeedCallBackend (CeedDestroy (& ceed ));
1616+ return CEED_ERROR_SUCCESS ;
1617+ }
1618+
15071619//------------------------------------------------------------------------------
15081620// Single Operator Assembly Setup
15091621//------------------------------------------------------------------------------
@@ -1705,6 +1817,117 @@ static int CeedOperatorAssembleSingleSetup_Cuda(CeedOperator op, CeedInt use_cee
17051817 return CEED_ERROR_SUCCESS ;
17061818}
17071819
1820+ //------------------------------------------------------------------------------
1821+ // Assemble matrix data for one block of a COO matrix of assembled operator.
1822+ // The sparsity pattern is set by CeedOperatorLinearAssembleSymbolic.
1823+ //------------------------------------------------------------------------------
1824+ static int CeedOperatorAssembleSingleBlock_Cuda (CeedOperator op , CeedInt offset , CeedInt active_input , CeedInt active_output , CeedVector values ) {
1825+ Ceed ceed ;
1826+ CeedSize values_length = 0 , assembled_qf_length = 0 ;
1827+ CeedInt use_ceedsize_idx = 0 , num_elem_in , num_elem_out , elem_size_in , elem_size_out ;
1828+ CeedScalar * values_array ;
1829+ const CeedScalar * assembled_qf_array ;
1830+ CeedVector assembled_qf = NULL ;
1831+ CeedElemRestriction assembled_rstr = NULL , rstr_in , rstr_out ;
1832+ CeedRestrictionType rstr_type_in , rstr_type_out ;
1833+ const bool * orients_in = NULL , * orients_out = NULL ;
1834+ const CeedInt8 * curl_orients_in = NULL , * curl_orients_out = NULL ;
1835+ CeedOperator_Cuda * impl ;
1836+
1837+ CeedCallBackend (CeedOperatorGetCeed (op , & ceed ));
1838+ CeedCallBackend (CeedOperatorGetData (op , & impl ));
1839+
1840+ // Assemble QFunction
1841+ CeedCallBackend (CeedOperatorLinearAssembleQFunctionBuildOrUpdate (op , & assembled_qf , & assembled_rstr , CEED_REQUEST_IMMEDIATE ));
1842+ CeedCallBackend (CeedElemRestrictionDestroy (& assembled_rstr ));
1843+ CeedCallBackend (CeedVectorGetArrayRead (assembled_qf , CEED_MEM_DEVICE , & assembled_qf_array ));
1844+
1845+ CeedCallBackend (CeedVectorGetLength (values , & values_length ));
1846+ CeedCallBackend (CeedVectorGetLength (assembled_qf , & assembled_qf_length ));
1847+ if ((values_length > INT_MAX ) || (assembled_qf_length > INT_MAX )) use_ceedsize_idx = 1 ;
1848+
1849+ // Setup
1850+ if (!impl -> asmb_blocks || (impl -> asmb_blocks && !impl -> asmb_blocks [active_input * impl -> num_blocks_out + active_output ])) {
1851+ CeedCallBackend (CeedOperatorAssembleSingleBlockSetup_Cuda (op , active_input , active_output , use_ceedsize_idx ));
1852+ }
1853+ CeedOperatorAssemble_Cuda * asmb = impl -> asmb_blocks [active_input * impl -> num_blocks_out + active_output ];
1854+
1855+ assert (asmb != NULL );
1856+
1857+ // Assemble element operator
1858+ CeedCallBackend (CeedVectorGetArray (values , CEED_MEM_DEVICE , & values_array ));
1859+ values_array += offset ;
1860+
1861+ CeedElemRestriction * active_rstrs_in , * active_rstrs_out ;
1862+ CeedOperatorAssemblyData data ;
1863+
1864+ CeedCall (CeedOperatorGetOperatorAssemblyData (op , & data ));
1865+ CeedCall (CeedOperatorAssemblyDataGetElemRestrictions (data , NULL , & active_rstrs_in , NULL , & active_rstrs_out ));
1866+
1867+ rstr_in = active_rstrs_in [active_input ];
1868+ rstr_out = active_rstrs_out [active_output ];
1869+ CeedCallBackend (CeedElemRestrictionGetNumElements (rstr_in , & num_elem_in ));
1870+ CeedCallBackend (CeedElemRestrictionGetElementSize (rstr_in , & elem_size_in ));
1871+
1872+ CeedCallBackend (CeedElemRestrictionGetType (rstr_in , & rstr_type_in ));
1873+ if (rstr_type_in == CEED_RESTRICTION_ORIENTED ) {
1874+ CeedCallBackend (CeedElemRestrictionGetOrientations (rstr_in , CEED_MEM_DEVICE , & orients_in ));
1875+ } else if (rstr_type_in == CEED_RESTRICTION_CURL_ORIENTED ) {
1876+ CeedCallBackend (CeedElemRestrictionGetCurlOrientations (rstr_in , CEED_MEM_DEVICE , & curl_orients_in ));
1877+ }
1878+
1879+ if (rstr_in != rstr_out ) {
1880+ CeedCallBackend (CeedElemRestrictionGetNumElements (rstr_out , & num_elem_out ));
1881+ CeedCheck (num_elem_in == num_elem_out , ceed , CEED_ERROR_UNSUPPORTED ,
1882+ "Active input and output operator restrictions must have the same number of elements" );
1883+ CeedCallBackend (CeedElemRestrictionGetElementSize (rstr_out , & elem_size_out ));
1884+
1885+ CeedCallBackend (CeedElemRestrictionGetType (rstr_out , & rstr_type_out ));
1886+ if (rstr_type_out == CEED_RESTRICTION_ORIENTED ) {
1887+ CeedCallBackend (CeedElemRestrictionGetOrientations (rstr_out , CEED_MEM_DEVICE , & orients_out ));
1888+ } else if (rstr_type_out == CEED_RESTRICTION_CURL_ORIENTED ) {
1889+ CeedCallBackend (CeedElemRestrictionGetCurlOrientations (rstr_out , CEED_MEM_DEVICE , & curl_orients_out ));
1890+ }
1891+ } else {
1892+ elem_size_out = elem_size_in ;
1893+ orients_out = orients_in ;
1894+ curl_orients_out = curl_orients_in ;
1895+ }
1896+
1897+ // Compute B^T D B
1898+ CeedInt shared_mem =
1899+ ((curl_orients_in || curl_orients_out ? elem_size_in * elem_size_out : 0 ) + (curl_orients_in ? elem_size_in * asmb -> block_size_y : 0 )) *
1900+ sizeof (CeedScalar );
1901+ CeedInt grid = CeedDivUpInt (num_elem_in , asmb -> elems_per_block );
1902+ void * args [] = {(void * )& num_elem_in , & asmb -> d_B_in , & asmb -> d_B_out , & orients_in , & curl_orients_in ,
1903+ & orients_out , & curl_orients_out , & assembled_qf_array , & values_array };
1904+
1905+ CeedCallBackend (CeedRunKernelDimShared_Cuda (ceed , asmb -> LinearAssemble , NULL , grid , asmb -> block_size_x , asmb -> block_size_y , asmb -> elems_per_block ,
1906+ shared_mem , args ));
1907+ CeedCallCuda (ceed , cudaDeviceSynchronize ());
1908+
1909+ // Restore arrays
1910+ CeedCallBackend (CeedVectorRestoreArray (values , & values_array ));
1911+ CeedCallBackend (CeedVectorRestoreArrayRead (assembled_qf , & assembled_qf_array ));
1912+
1913+ // Cleanup
1914+ CeedCallBackend (CeedVectorDestroy (& assembled_qf ));
1915+ if (rstr_type_in == CEED_RESTRICTION_ORIENTED ) {
1916+ CeedCallBackend (CeedElemRestrictionRestoreOrientations (rstr_in , & orients_in ));
1917+ } else if (rstr_type_in == CEED_RESTRICTION_CURL_ORIENTED ) {
1918+ CeedCallBackend (CeedElemRestrictionRestoreCurlOrientations (rstr_in , & curl_orients_in ));
1919+ }
1920+ if (rstr_in != rstr_out ) {
1921+ if (rstr_type_out == CEED_RESTRICTION_ORIENTED ) {
1922+ CeedCallBackend (CeedElemRestrictionRestoreOrientations (rstr_out , & orients_out ));
1923+ } else if (rstr_type_out == CEED_RESTRICTION_CURL_ORIENTED ) {
1924+ CeedCallBackend (CeedElemRestrictionRestoreCurlOrientations (rstr_out , & curl_orients_out ));
1925+ }
1926+ }
1927+ CeedCallBackend (CeedDestroy (& ceed ));
1928+ return CEED_ERROR_SUCCESS ;
1929+ }
1930+
17081931//------------------------------------------------------------------------------
17091932// Assemble matrix data for COO matrix of assembled operator.
17101933// The sparsity pattern is set by CeedOperatorLinearAssembleSymbolic.
@@ -2090,6 +2313,7 @@ int CeedOperatorCreate_Cuda(CeedOperator op) {
20902313 CeedCallBackend (CeedSetBackendFunction (ceed , "Operator" , op , "LinearAssembleAddPointBlockDiagonal" ,
20912314 CeedOperatorLinearAssembleAddPointBlockDiagonal_Cuda ));
20922315 CeedCallBackend (CeedSetBackendFunction (ceed , "Operator" , op , "LinearAssembleSingle" , CeedOperatorAssembleSingle_Cuda ));
2316+ CeedCallBackend (CeedSetBackendFunction (ceed , "Operator" , op , "LinearAssembleSingleBlock" , CeedOperatorAssembleSingleBlock_Cuda ));
20932317 CeedCallBackend (CeedSetBackendFunction (ceed , "Operator" , op , "ApplyAdd" , CeedOperatorApplyAdd_Cuda ));
20942318 CeedCallBackend (CeedSetBackendFunction (ceed , "Operator" , op , "Destroy" , CeedOperatorDestroy_Cuda ));
20952319 CeedCallBackend (CeedDestroy (& ceed ));
0 commit comments