@@ -66,7 +66,6 @@ impl<'a, T: Float> MaskedCSRMatrix<'a, T> {
6666}
6767
6868impl <
69- ' a ,
7069 T : Float
7170 + AddAssign
7271 + Sync
7776 + std:: iter:: Sum
7877 + std:: ops:: SubAssign
7978 + num_traits:: FromPrimitive ,
80- > SMat < T > for MaskedCSRMatrix < ' a , T >
79+ > SMat < T > for MaskedCSRMatrix < ' _ , T >
8180{
8281 fn nrows ( & self ) -> usize {
8382 self . matrix . nrows ( )
@@ -296,81 +295,174 @@ impl<
296295 "Result matrix has incompatible column count"
297296 ) ;
298297
299- // Determine if we can use optimized path
300- //if self.ensure_identical_results_mode() {
301- // For small matrices, use the default implementation
302- // return <Self as SMat<T>>::multiply_matrix(self, dense, result, transpose_self);
303- //}
304-
305298 let ( major_offsets, minor_indices, values) = self . matrix . csr_data ( ) ;
306299
307300 if !transpose_self {
308301 let rows = self . matrix . nrows ( ) ;
309302 let dense_cols = dense. ncols ( ) ;
310303
311- let partial_results: Vec < ( usize , DMatrix < T > ) > = ( 0 ..rows)
304+ // Pre-filter valid column mappings to avoid repeated lookups
305+ let valid_cols: Vec < Option < usize > > = ( 0 ..self . matrix . ncols ( ) )
306+ . map ( |col| self . original_to_masked . get ( col) . copied ( ) . flatten ( ) )
307+ . collect ( ) ;
308+
309+ // Compute results in parallel, then apply to result matrix
310+ let row_results: Vec < ( usize , Vec < T > ) > = ( 0 ..rows)
312311 . into_par_iter ( )
313312 . map ( |row| {
314- let mut local_result = DMatrix :: < T > :: zeros ( 1 , dense_cols) ;
313+ let mut row_result = vec ! [ T :: zero ( ) ; dense_cols] ;
315314
316- for j in major_offsets[ row] ..major_offsets[ row + 1 ] {
315+ // Process sparse row with blocked inner loop for better vectorization
316+ let row_start = major_offsets[ row] ;
317+ let row_end = major_offsets[ row + 1 ] ;
318+
319+ // Unroll the sparse elements loop by 4 for better ILP
320+ let mut j = row_start;
321+ while j + 4 <= row_end {
322+ // Process 4 sparse elements at once
323+ for offset in 0 ..4 {
324+ let idx = j + offset;
325+ let col = minor_indices[ idx] ;
326+ if let Some ( masked_col) = valid_cols[ col] {
327+ let val = values[ idx] ;
328+
329+ // Vectorized dense column update
330+ for c in 0 ..dense_cols {
331+ row_result[ c] += val * dense[ ( masked_col, c) ] ;
332+ }
333+ }
334+ }
335+ j += 4 ;
336+ }
337+
338+ // Handle remaining elements
339+ while j < row_end {
317340 let col = minor_indices[ j] ;
318- if let Some ( masked_col) = self . original_to_masked [ col] {
341+ if let Some ( masked_col) = valid_cols [ col] {
319342 let val = values[ j] ;
320343
321344 for c in 0 ..dense_cols {
322- local_result [ ( 0 , c ) ] += val * dense[ ( masked_col, c) ] ;
345+ row_result [ c ] += val * dense[ ( masked_col, c) ] ;
323346 }
324347 }
348+ j += 1 ;
325349 }
326350
327- ( row, local_result )
351+ ( row, row_result )
328352 } )
329353 . collect ( ) ;
330354
331- for ( row, local_result) in partial_results {
355+ // Apply results to output matrix
356+ for ( row, row_values) in row_results {
332357 for c in 0 ..dense_cols {
333- result[ ( row, c) ] = local_result [ ( 0 , c ) ] ;
358+ result[ ( row, c) ] = row_values [ c ] ;
334359 }
335360 }
336361 } else {
337362 let nrows = self . matrix . nrows ( ) ;
338363 let ncols = self . ncols ( ) ;
339364 let dense_cols = dense. ncols ( ) ;
340365
366+ // Clear result matrix once at the beginning
367+ result. fill ( T :: zero ( ) ) ;
368+
369+ // Pre-filter valid column mappings
370+ let valid_cols: Vec < Option < usize > > = ( 0 ..self . matrix . ncols ( ) )
371+ . map ( |col| self . original_to_masked . get ( col) . copied ( ) . flatten ( ) )
372+ . collect ( ) ;
373+
341374 let chunk_size = determine_chunk_size ( nrows) ;
342375
343- let partial_results: Vec < DMatrix < T > > = ( 0 ..nrows. div_ceil ( chunk_size) )
376+ // Use atomic-free approach with proper synchronization
377+ let partial_results: Vec < Vec < T > > = ( 0 ..nrows. div_ceil ( chunk_size) )
344378 . into_par_iter ( )
345379 . map ( |chunk_idx| {
346380 let start = chunk_idx * chunk_size;
347381 let end = ( start + chunk_size) . min ( nrows) ;
348382
349- let mut local_result = DMatrix :: < T > :: zeros ( ncols, dense_cols) ;
383+ // Use flat vector for better cache performance
384+ let mut local_result = vec ! [ T :: zero( ) ; ncols * dense_cols] ;
350385
386+ // Process chunk with better memory access patterns
351387 for i in start..end {
352- for j in major_offsets[ i] ..major_offsets[ i + 1 ] {
388+ let dense_row = unsafe {
389+ std:: slice:: from_raw_parts (
390+ dense. as_ptr ( ) . add ( i * dense_cols) ,
391+ dense_cols,
392+ )
393+ } ;
394+
395+ // Block processing for better cache usage
396+ let row_start = major_offsets[ i] ;
397+ let row_end = major_offsets[ i + 1 ] ;
398+
399+ // Process sparse elements in blocks of 8 for better vectorization
400+ let mut j = row_start;
401+ while j + 8 <= row_end {
402+ for offset in 0 ..8 {
403+ let idx = j + offset;
404+ let col = minor_indices[ idx] ;
405+ if let Some ( masked_col) = valid_cols[ col] {
406+ let val = values[ idx] ;
407+ let base_offset = masked_col * dense_cols;
408+
409+ // Vectorized update with manual loop unrolling
410+ let mut c = 0 ;
411+ while c + 4 <= dense_cols {
412+ local_result[ base_offset + c] += val * dense_row[ c] ;
413+ local_result[ base_offset + c + 1 ] += val * dense_row[ c + 1 ] ;
414+ local_result[ base_offset + c + 2 ] += val * dense_row[ c + 2 ] ;
415+ local_result[ base_offset + c + 3 ] += val * dense_row[ c + 3 ] ;
416+ c += 4 ;
417+ }
418+
419+ // Handle remaining columns
420+ while c < dense_cols {
421+ local_result[ base_offset + c] += val * dense_row[ c] ;
422+ c += 1 ;
423+ }
424+ }
425+ }
426+ j += 8 ;
427+ }
428+
429+ // Handle remaining sparse elements
430+ while j < row_end {
353431 let col = minor_indices[ j] ;
354- if let Some ( masked_col) = self . original_to_masked [ col] {
432+ if let Some ( masked_col) = valid_cols [ col] {
355433 let val = values[ j] ;
434+ let base_offset = masked_col * dense_cols;
356435
357436 for c in 0 ..dense_cols {
358- local_result[ ( masked_col , c ) ] += val * dense [ ( i , c ) ] ;
437+ local_result[ base_offset + c ] += val * dense_row [ c ] ;
359438 }
360439 }
440+ j += 1 ;
361441 }
362442 }
363443
364444 local_result
365445 } )
366446 . collect ( ) ;
367447
448+ // Efficient reduction with blocked memory access
449+ const BLOCK_SIZE : usize = 32 ;
368450 for local_result in partial_results {
369- for r in 0 ..ncols {
370- for c in 0 ..dense_cols {
371- let val = local_result[ ( r, c) ] ;
372- if !val. is_zero ( ) {
373- result[ ( r, c) ] += val;
451+ // Process in blocks for better cache performance
452+ for r_block in ( 0 ..ncols) . step_by ( BLOCK_SIZE ) {
453+ let r_end = ( r_block + BLOCK_SIZE ) . min ( ncols) ;
454+
455+ for c_block in ( 0 ..dense_cols) . step_by ( BLOCK_SIZE ) {
456+ let c_end = ( c_block + BLOCK_SIZE ) . min ( dense_cols) ;
457+
458+ // Update result block
459+ for r in r_block..r_end {
460+ for c in c_block..c_end {
461+ let val = local_result[ r * dense_cols + c] ;
462+ if !val. is_zero ( ) {
463+ result[ ( r, c) ] += val;
464+ }
465+ }
374466 }
375467 }
376468 }
@@ -416,8 +508,6 @@ impl<
416508 } )
417509 . collect ( ) ;
418510
419- let chunk_size = std:: cmp:: max ( 16 , rows / ( rayon:: current_num_threads ( ) * 4 ) ) ;
420-
421511 let row_updates: Vec < ( usize , Vec < T > ) > = ( 0 ..rows)
422512 . into_par_iter ( )
423513 . map ( |row| {
0 commit comments