@@ -446,7 +446,7 @@ impl<
446446 . collect ( ) ;
447447
448448 // Efficient reduction with blocked memory access
449- const BLOCK_SIZE : usize = 32 ;
449+ const BLOCK_SIZE : usize = 64 ;
450450 for local_result in partial_results {
451451 // Process in blocks for better cache performance
452452 for r_block in ( 0 ..ncols) . step_by ( BLOCK_SIZE ) {
@@ -612,6 +612,266 @@ impl<
612612 }
613613 }
614614 }
615+
616+ fn multiply_transposed_by_dense ( & self , q : & DMatrix < T > , result : & mut DMatrix < T > ) {
617+ let q_rows = q. nrows ( ) ;
618+ let q_cols = q. ncols ( ) ;
619+ let masked_cols = self . ncols ( ) ;
620+
621+ assert_eq ! (
622+ q_rows,
623+ self . nrows( ) ,
624+ "Q matrix has incompatible row count: expected {}, got {}" ,
625+ self . nrows( ) ,
626+ q_rows
627+ ) ;
628+ assert_eq ! (
629+ result. nrows( ) ,
630+ q_cols,
631+ "Result matrix has incompatible row count: expected {}, got {}" ,
632+ q_cols,
633+ result. nrows( )
634+ ) ;
635+ assert_eq ! (
636+ result. ncols( ) ,
637+ masked_cols,
638+ "Result matrix has incompatible column count: expected {}, got {}" ,
639+ masked_cols,
640+ result. ncols( )
641+ ) ;
642+
643+ // Clear result matrix
644+ for i in 0 ..result. nrows ( ) {
645+ for j in 0 ..result. ncols ( ) {
646+ result[ ( i, j) ] = T :: zero ( ) ;
647+ }
648+ }
649+
650+ let ( major_offsets, minor_indices, values) = self . matrix . csr_data ( ) ;
651+ let nrows = self . matrix . nrows ( ) ;
652+ let chunk_size = determine_chunk_size ( nrows) ;
653+
654+ if self . uses_all_columns ( ) && ( nrows < 1000 && self . matrix . ncols ( ) < 1000 ) {
655+ // Fast path for small unmasked matrices
656+ let partial_results: Vec < DMatrix < T > > = ( 0 ..nrows. div_ceil ( chunk_size) )
657+ . into_par_iter ( )
658+ . map ( |chunk_idx| {
659+ let start = chunk_idx * chunk_size;
660+ let end = ( start + chunk_size) . min ( nrows) ;
661+ let mut local_result = DMatrix :: < T > :: zeros ( q_cols, masked_cols) ;
662+
663+ for row in start..end {
664+ // Process all non-zeros in this row
665+ for idx in major_offsets[ row] ..major_offsets[ row + 1 ] {
666+ let col = minor_indices[ idx] ;
667+ let sparse_val = values[ idx] ;
668+
669+ // Accumulate: local_result[q_col, col] += q[row, q_col] * sparse_val
670+ for q_col in 0 ..q_cols {
671+ local_result[ ( q_col, col) ] += q[ ( row, q_col) ] * sparse_val;
672+ }
673+ }
674+ }
675+
676+ local_result
677+ } )
678+ . collect ( ) ;
679+
680+ // Combine partial results efficiently
681+ for local_result in partial_results {
682+ for r in 0 ..q_cols {
683+ for c in 0 ..masked_cols {
684+ let val = local_result[ ( r, c) ] ;
685+ if !val. is_zero ( ) {
686+ result[ ( r, c) ] += val;
687+ }
688+ }
689+ }
690+ }
691+ } else {
692+ // Optimized path for masked matrices
693+ let partial_results: Vec < DMatrix < T > > = ( 0 ..nrows. div_ceil ( chunk_size) )
694+ . into_par_iter ( )
695+ . map ( |chunk_idx| {
696+ let start = chunk_idx * chunk_size;
697+ let end = ( start + chunk_size) . min ( nrows) ;
698+ let mut local_result = DMatrix :: < T > :: zeros ( q_cols, masked_cols) ;
699+
700+ for row in start..end {
701+ // Process all non-zeros in this row
702+ for idx in major_offsets[ row] ..major_offsets[ row + 1 ] {
703+ let original_col = minor_indices[ idx] ;
704+
705+ // Check if this column is in our mask
706+ if let Some ( masked_col) = self . original_to_masked [ original_col] {
707+ let sparse_val = values[ idx] ;
708+
709+ // Accumulate: local_result[q_col, masked_col] += q[row, q_col] * sparse_val
710+ for q_col in 0 ..q_cols {
711+ local_result[ ( q_col, masked_col) ] += q[ ( row, q_col) ] * sparse_val;
712+ }
713+ }
714+ }
715+ }
716+
717+ local_result
718+ } )
719+ . collect ( ) ;
720+
721+ // Combine partial results efficiently
722+ for local_result in partial_results {
723+ for r in 0 ..q_cols {
724+ for c in 0 ..masked_cols {
725+ let val = local_result[ ( r, c) ] ;
726+ if !val. is_zero ( ) {
727+ result[ ( r, c) ] += val;
728+ }
729+ }
730+ }
731+ }
732+ }
733+ }
734+
735+ fn multiply_transposed_by_dense_centered (
736+ & self ,
737+ q : & DMatrix < T > ,
738+ result : & mut DMatrix < T > ,
739+ means : & DVector < T > ,
740+ ) {
741+ let q_rows = q. nrows ( ) ;
742+ let q_cols = q. ncols ( ) ;
743+ let masked_cols = self . ncols ( ) ;
744+
745+ assert_eq ! (
746+ q_rows,
747+ self . nrows( ) ,
748+ "Q matrix has incompatible row count: expected {}, got {}" ,
749+ self . nrows( ) ,
750+ q_rows
751+ ) ;
752+ assert_eq ! (
753+ result. nrows( ) ,
754+ q_cols,
755+ "Result matrix has incompatible row count: expected {}, got {}" ,
756+ q_cols,
757+ result. nrows( )
758+ ) ;
759+ assert_eq ! (
760+ result. ncols( ) ,
761+ masked_cols,
762+ "Result matrix has incompatible column count: expected {}, got {}" ,
763+ masked_cols,
764+ result. ncols( )
765+ ) ;
766+ assert_eq ! (
767+ means. len( ) ,
768+ masked_cols,
769+ "Means vector has incompatible length: expected {}, got {}" ,
770+ masked_cols,
771+ means. len( )
772+ ) ;
773+
774+ // Clear result matrix
775+ for i in 0 ..result. nrows ( ) {
776+ for j in 0 ..result. ncols ( ) {
777+ result[ ( i, j) ] = T :: zero ( ) ;
778+ }
779+ }
780+
781+ let ( major_offsets, minor_indices, values) = self . matrix . csr_data ( ) ;
782+
783+ // Pre-compute column sums of Q - following the pattern from multiply_with_dense_centered
784+ let q_col_sums: Vec < T > = ( 0 ..q_cols)
785+ . into_par_iter ( )
786+ . map ( |col| {
787+ ( 0 ..q_rows) . map ( |row| q[ ( row, col) ] ) . sum ( )
788+ } )
789+ . collect ( ) ;
790+
791+ // Pre-compute mean adjustments for each masked column
792+ // For Q^T * (A - means): result[q_col, masked_col] = Q^T * A - sum(Q[q_col]) * means[masked_col]
793+ let mean_adjustments: Vec < T > = q_col_sums
794+ . iter ( )
795+ . enumerate ( )
796+ . map ( |( q_col, & q_sum) | {
797+ means
798+ . iter ( )
799+ . enumerate ( )
800+ . map ( |( masked_col_idx, & mean_val) | {
801+ if masked_col_idx < masked_cols {
802+ q_sum * mean_val
803+ } else {
804+ T :: zero ( )
805+ }
806+ } )
807+ . sum ( )
808+ } )
809+ . collect ( ) ;
810+
811+ let nrows = self . matrix . nrows ( ) ;
812+ let chunk_size = determine_chunk_size ( nrows) ;
813+
814+ // Process sparse matrix rows in chunks, similar to the transpose_self=true case
815+ let partial_results: Vec < DMatrix < T > > = ( 0 ..nrows. div_ceil ( chunk_size) )
816+ . into_par_iter ( )
817+ . map ( |chunk_idx| {
818+ let start = chunk_idx * chunk_size;
819+ let end = std:: cmp:: min ( start + chunk_size, nrows) ;
820+
821+ let mut local_result = DMatrix :: < T > :: zeros ( q_cols, masked_cols) ;
822+
823+ for row in start..end {
824+ // Process all non-zeros in this row
825+ for idx in major_offsets[ row] ..major_offsets[ row + 1 ] {
826+ let original_col = minor_indices[ idx] ;
827+
828+ // Check if this column is in our mask
829+ if let Some ( masked_col) = self . original_to_masked [ original_col] {
830+ let sparse_val = values[ idx] ;
831+
832+ // Accumulate: local_result[q_col, masked_col] += q[row, q_col] * sparse_val
833+ for q_col in 0 ..q_cols {
834+ local_result[ ( q_col, masked_col) ] += q[ ( row, q_col) ] * sparse_val;
835+ }
836+ }
837+ }
838+ }
839+
840+ // Apply mean adjustment for this chunk, following the pattern from your function
841+ let chunk_fraction = T :: from_f64 ( ( end - start) as f64 / q_rows as f64 ) . unwrap ( ) ;
842+
843+ for q_col in 0 ..q_cols {
844+ let q_sum = q_col_sums[ q_col] ;
845+ for masked_col in 0 ..masked_cols {
846+ local_result[ ( q_col, masked_col) ] -= q_sum * means[ masked_col] * chunk_fraction;
847+ }
848+ }
849+
850+ local_result
851+ } )
852+ . collect ( ) ;
853+
854+ // Combine partial results with block-wise writing for better cache locality
855+ for local_result in partial_results {
856+ const BLOCK_SIZE : usize = 64 ;
857+
858+ for r_block in 0 ..q_cols. div_ceil ( BLOCK_SIZE ) {
859+ let r_start = r_block * BLOCK_SIZE ;
860+ let r_end = std:: cmp:: min ( r_start + BLOCK_SIZE , q_cols) ;
861+
862+ for c_block in 0 ..masked_cols. div_ceil ( BLOCK_SIZE ) {
863+ let c_start = c_block * BLOCK_SIZE ;
864+ let c_end = std:: cmp:: min ( c_start + BLOCK_SIZE , masked_cols) ;
865+
866+ for r in r_start..r_end {
867+ for c in c_start..c_end {
868+ result[ ( r, c) ] += local_result[ ( r, c) ] ;
869+ }
870+ }
871+ }
872+ }
873+ }
874+ }
615875}
616876
617877#[ cfg( test) ]
0 commit comments