Skip to content

Commit d41ff2d

Browse files
author
Ian
committed
Added new optimizations for matrix multiplication.
1 parent 662cea9 commit d41ff2d

6 files changed

Lines changed: 295 additions & 137 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ description = "A Rust port of LAS2 from SVDLIBC"
44
keywords = ["svd"]
55
categories = ["algorithms", "data-structures", "mathematics", "science"]
66
name = "single-svdlib"
7-
version = "1.0.4"
7+
version = "1.0.5"
88
edition = "2021"
99
license-file = "SVDLIBC-LICENSE.txt"
1010

src/lanczos/masked.rs

Lines changed: 261 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ impl<
446446
.collect();
447447

448448
// Efficient reduction with blocked memory access
449-
const BLOCK_SIZE: usize = 32;
449+
const BLOCK_SIZE: usize = 64;
450450
for local_result in partial_results {
451451
// Process in blocks for better cache performance
452452
for r_block in (0..ncols).step_by(BLOCK_SIZE) {
@@ -612,6 +612,266 @@ impl<
612612
}
613613
}
614614
}
615+
616+
fn multiply_transposed_by_dense(&self, q: &DMatrix<T>, result: &mut DMatrix<T>) {
617+
let q_rows = q.nrows();
618+
let q_cols = q.ncols();
619+
let masked_cols = self.ncols();
620+
621+
assert_eq!(
622+
q_rows,
623+
self.nrows(),
624+
"Q matrix has incompatible row count: expected {}, got {}",
625+
self.nrows(),
626+
q_rows
627+
);
628+
assert_eq!(
629+
result.nrows(),
630+
q_cols,
631+
"Result matrix has incompatible row count: expected {}, got {}",
632+
q_cols,
633+
result.nrows()
634+
);
635+
assert_eq!(
636+
result.ncols(),
637+
masked_cols,
638+
"Result matrix has incompatible column count: expected {}, got {}",
639+
masked_cols,
640+
result.ncols()
641+
);
642+
643+
// Clear result matrix
644+
for i in 0..result.nrows() {
645+
for j in 0..result.ncols() {
646+
result[(i, j)] = T::zero();
647+
}
648+
}
649+
650+
let (major_offsets, minor_indices, values) = self.matrix.csr_data();
651+
let nrows = self.matrix.nrows();
652+
let chunk_size = determine_chunk_size(nrows);
653+
654+
if self.uses_all_columns() && (nrows < 1000 && self.matrix.ncols() < 1000) {
655+
// Fast path for small unmasked matrices
656+
let partial_results: Vec<DMatrix<T>> = (0..nrows.div_ceil(chunk_size))
657+
.into_par_iter()
658+
.map(|chunk_idx| {
659+
let start = chunk_idx * chunk_size;
660+
let end = (start + chunk_size).min(nrows);
661+
let mut local_result = DMatrix::<T>::zeros(q_cols, masked_cols);
662+
663+
for row in start..end {
664+
// Process all non-zeros in this row
665+
for idx in major_offsets[row]..major_offsets[row + 1] {
666+
let col = minor_indices[idx];
667+
let sparse_val = values[idx];
668+
669+
// Accumulate: local_result[q_col, col] += q[row, q_col] * sparse_val
670+
for q_col in 0..q_cols {
671+
local_result[(q_col, col)] += q[(row, q_col)] * sparse_val;
672+
}
673+
}
674+
}
675+
676+
local_result
677+
})
678+
.collect();
679+
680+
// Combine partial results efficiently
681+
for local_result in partial_results {
682+
for r in 0..q_cols {
683+
for c in 0..masked_cols {
684+
let val = local_result[(r, c)];
685+
if !val.is_zero() {
686+
result[(r, c)] += val;
687+
}
688+
}
689+
}
690+
}
691+
} else {
692+
// Optimized path for masked matrices
693+
let partial_results: Vec<DMatrix<T>> = (0..nrows.div_ceil(chunk_size))
694+
.into_par_iter()
695+
.map(|chunk_idx| {
696+
let start = chunk_idx * chunk_size;
697+
let end = (start + chunk_size).min(nrows);
698+
let mut local_result = DMatrix::<T>::zeros(q_cols, masked_cols);
699+
700+
for row in start..end {
701+
// Process all non-zeros in this row
702+
for idx in major_offsets[row]..major_offsets[row + 1] {
703+
let original_col = minor_indices[idx];
704+
705+
// Check if this column is in our mask
706+
if let Some(masked_col) = self.original_to_masked[original_col] {
707+
let sparse_val = values[idx];
708+
709+
// Accumulate: local_result[q_col, masked_col] += q[row, q_col] * sparse_val
710+
for q_col in 0..q_cols {
711+
local_result[(q_col, masked_col)] += q[(row, q_col)] * sparse_val;
712+
}
713+
}
714+
}
715+
}
716+
717+
local_result
718+
})
719+
.collect();
720+
721+
// Combine partial results efficiently
722+
for local_result in partial_results {
723+
for r in 0..q_cols {
724+
for c in 0..masked_cols {
725+
let val = local_result[(r, c)];
726+
if !val.is_zero() {
727+
result[(r, c)] += val;
728+
}
729+
}
730+
}
731+
}
732+
}
733+
}
734+
735+
fn multiply_transposed_by_dense_centered(
736+
&self,
737+
q: &DMatrix<T>,
738+
result: &mut DMatrix<T>,
739+
means: &DVector<T>,
740+
) {
741+
let q_rows = q.nrows();
742+
let q_cols = q.ncols();
743+
let masked_cols = self.ncols();
744+
745+
assert_eq!(
746+
q_rows,
747+
self.nrows(),
748+
"Q matrix has incompatible row count: expected {}, got {}",
749+
self.nrows(),
750+
q_rows
751+
);
752+
assert_eq!(
753+
result.nrows(),
754+
q_cols,
755+
"Result matrix has incompatible row count: expected {}, got {}",
756+
q_cols,
757+
result.nrows()
758+
);
759+
assert_eq!(
760+
result.ncols(),
761+
masked_cols,
762+
"Result matrix has incompatible column count: expected {}, got {}",
763+
masked_cols,
764+
result.ncols()
765+
);
766+
assert_eq!(
767+
means.len(),
768+
masked_cols,
769+
"Means vector has incompatible length: expected {}, got {}",
770+
masked_cols,
771+
means.len()
772+
);
773+
774+
// Clear result matrix
775+
for i in 0..result.nrows() {
776+
for j in 0..result.ncols() {
777+
result[(i, j)] = T::zero();
778+
}
779+
}
780+
781+
let (major_offsets, minor_indices, values) = self.matrix.csr_data();
782+
783+
// Pre-compute column sums of Q - following the pattern from multiply_with_dense_centered
784+
let q_col_sums: Vec<T> = (0..q_cols)
785+
.into_par_iter()
786+
.map(|col| {
787+
(0..q_rows).map(|row| q[(row, col)]).sum()
788+
})
789+
.collect();
790+
791+
// Pre-compute mean adjustments for each masked column
792+
// For Q^T * (A - means): result[q_col, masked_col] = Q^T * A - sum(Q[q_col]) * means[masked_col]
793+
let mean_adjustments: Vec<T> = q_col_sums
794+
.iter()
795+
.enumerate()
796+
.map(|(q_col, &q_sum)| {
797+
means
798+
.iter()
799+
.enumerate()
800+
.map(|(masked_col_idx, &mean_val)| {
801+
if masked_col_idx < masked_cols {
802+
q_sum * mean_val
803+
} else {
804+
T::zero()
805+
}
806+
})
807+
.sum()
808+
})
809+
.collect();
810+
811+
let nrows = self.matrix.nrows();
812+
let chunk_size = determine_chunk_size(nrows);
813+
814+
// Process sparse matrix rows in chunks, similar to the transpose_self=true case
815+
let partial_results: Vec<DMatrix<T>> = (0..nrows.div_ceil(chunk_size))
816+
.into_par_iter()
817+
.map(|chunk_idx| {
818+
let start = chunk_idx * chunk_size;
819+
let end = std::cmp::min(start + chunk_size, nrows);
820+
821+
let mut local_result = DMatrix::<T>::zeros(q_cols, masked_cols);
822+
823+
for row in start..end {
824+
// Process all non-zeros in this row
825+
for idx in major_offsets[row]..major_offsets[row + 1] {
826+
let original_col = minor_indices[idx];
827+
828+
// Check if this column is in our mask
829+
if let Some(masked_col) = self.original_to_masked[original_col] {
830+
let sparse_val = values[idx];
831+
832+
// Accumulate: local_result[q_col, masked_col] += q[row, q_col] * sparse_val
833+
for q_col in 0..q_cols {
834+
local_result[(q_col, masked_col)] += q[(row, q_col)] * sparse_val;
835+
}
836+
}
837+
}
838+
}
839+
840+
// Apply mean adjustment for this chunk, following the pattern from your function
841+
let chunk_fraction = T::from_f64((end - start) as f64 / q_rows as f64).unwrap();
842+
843+
for q_col in 0..q_cols {
844+
let q_sum = q_col_sums[q_col];
845+
for masked_col in 0..masked_cols {
846+
local_result[(q_col, masked_col)] -= q_sum * means[masked_col] * chunk_fraction;
847+
}
848+
}
849+
850+
local_result
851+
})
852+
.collect();
853+
854+
// Combine partial results with block-wise writing for better cache locality
855+
for local_result in partial_results {
856+
const BLOCK_SIZE: usize = 64;
857+
858+
for r_block in 0..q_cols.div_ceil(BLOCK_SIZE) {
859+
let r_start = r_block * BLOCK_SIZE;
860+
let r_end = std::cmp::min(r_start + BLOCK_SIZE, q_cols);
861+
862+
for c_block in 0..masked_cols.div_ceil(BLOCK_SIZE) {
863+
let c_start = c_block * BLOCK_SIZE;
864+
let c_end = std::cmp::min(c_start + BLOCK_SIZE, masked_cols);
865+
866+
for r in r_start..r_end {
867+
for c in c_start..c_end {
868+
result[(r, c)] += local_result[(r, c)];
869+
}
870+
}
871+
}
872+
}
873+
}
874+
}
615875
}
616876

617877
#[cfg(test)]

src/lanczos/mod.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1493,6 +1493,14 @@ impl<T: Float + Zero + AddAssign + Clone + Sync> SMat<T> for nalgebra_sparse::cs
14931493
) {
14941494
todo!()
14951495
}
1496+
1497+
// Qᵀ * dense product: not yet implemented for this sparse-matrix backend;
// panics via `todo!()` if called.
fn multiply_transposed_by_dense(&self, q: &DMatrix<T>, result: &mut DMatrix<T>) {
    todo!()
}

// Centered Qᵀ * dense product: not yet implemented for this sparse-matrix
// backend; panics via `todo!()` if called.
fn multiply_transposed_by_dense_centered(&self, q: &DMatrix<T>, result: &mut DMatrix<T>, means: &DVector<T>) {
    todo!()
}
14961504
}
14971505

14981506
impl<T: Float + Zero + AddAssign + Clone + Sync + Send + std::ops::MulAssign> SMat<T>
@@ -1636,6 +1644,14 @@ impl<T: Float + Zero + AddAssign + Clone + Sync + Send + std::ops::MulAssign> SM
16361644
) {
16371645
todo!()
16381646
}
1647+
1648+
// Qᵀ * dense product: not yet implemented for this sparse-matrix backend;
// panics via `todo!()` if called.
fn multiply_transposed_by_dense(&self, q: &DMatrix<T>, result: &mut DMatrix<T>) {
    todo!()
}

// Centered Qᵀ * dense product: not yet implemented for this sparse-matrix
// backend; panics via `todo!()` if called.
fn multiply_transposed_by_dense_centered(&self, q: &DMatrix<T>, result: &mut DMatrix<T>, means: &DVector<T>) {
    todo!()
}
16391655
}
16401656

16411657
impl<T: Float + Zero + AddAssign + Clone + Sync> SMat<T> for nalgebra_sparse::coo::CooMatrix<T> {
@@ -1713,4 +1729,12 @@ impl<T: Float + Zero + AddAssign + Clone + Sync> SMat<T> for nalgebra_sparse::co
17131729
) {
17141730
todo!()
17151731
}
1732+
1733+
// Qᵀ * dense product: not yet implemented for CooMatrix; panics via
// `todo!()` if called.
fn multiply_transposed_by_dense(&self, q: &DMatrix<T>, result: &mut DMatrix<T>) {
    todo!()
}

// Centered Qᵀ * dense product: not yet implemented for CooMatrix; panics via
// `todo!()` if called.
fn multiply_transposed_by_dense_centered(&self, q: &DMatrix<T>, result: &mut DMatrix<T>, means: &DVector<T>) {
    todo!()
}
17161740
}

0 commit comments

Comments
 (0)