1- use rayon:: iter:: ParallelIterator ;
21use crate :: error:: SvdLibError ;
32use crate :: { Diagnostics , SMat , SvdFloat , SvdRec } ;
43use nalgebra_sparse:: na:: { ComplexField , DMatrix , DVector , RealField } ;
5- use ndarray:: { Array1 , Array2 } ;
4+ use ndarray:: Array1 ;
65use nshare:: IntoNdarray2 ;
76use rand:: prelude:: { Distribution , StdRng } ;
87use rand:: SeedableRng ;
98use rand_distr:: Normal ;
10- use std:: ops:: Mul ;
11- use rayon:: current_num_threads;
9+ use rayon:: iter:: ParallelIterator ;
1210use rayon:: prelude:: { IndexedParallelIterator , IntoParallelIterator } ;
11+ use std:: ops:: Mul ;
12+ use crate :: utils:: determine_chunk_size;
1313
1414pub enum PowerIterationNormalizer {
1515 QR ,
1616 LU ,
1717 None ,
1818}
1919
20-
2120const PARALLEL_THRESHOLD_ROWS : usize = 5000 ;
2221const PARALLEL_THRESHOLD_COLS : usize = 1000 ;
2322const PARALLEL_THRESHOLD_ELEMENTS : usize = 100_000 ;
@@ -35,22 +34,30 @@ where
3534 M : SMat < T > ,
3635 T : ComplexField ,
3736{
37+ let start = std:: time:: Instant :: now ( ) ; // only for debugging
3838 let m_rows = m. nrows ( ) ;
3939 let m_cols = m. ncols ( ) ;
4040
4141 let rank = target_rank. min ( m_rows. min ( m_cols) ) ;
4242 let l = rank + n_oversamples;
43+ println ! ( "Basic statistics: {:?}" , start. elapsed( ) ) ;
4344
4445 let omega = generate_random_matrix ( m_cols, l, seed) ;
46+ println ! ( "Generated Random Matrix here: {:?}" , start. elapsed( ) ) ;
4547
4648 let mut y = DMatrix :: < T > :: zeros ( m_rows, l) ;
4749 multiply_matrix ( m, & omega, & mut y, false ) ;
50+ println ! (
51+ "First multiplication took: {:?}, Continuing for power iterations:" ,
52+ start. elapsed( )
53+ ) ;
4854
4955 if n_power_iters > 0 {
5056 let mut z = DMatrix :: < T > :: zeros ( m_cols, l) ;
5157
52- for _ in 0 ..n_power_iters {
58+ for w in 0 ..n_power_iters {
5359 multiply_matrix ( m, & y, & mut z, true ) ;
60+ println ! ( "{}-nd power-iteration forward: {:?}" , w, start. elapsed( ) ) ;
5461 match power_iteration_normalizer {
5562 PowerIterationNormalizer :: QR => {
5663 let qr = z. qr ( ) ;
6168 }
6269 PowerIterationNormalizer :: None => { }
6370 }
71+ println ! (
72+ "{}-nd power-iteration forward, normalization: {:?}" ,
73+ w,
74+ start. elapsed( )
75+ ) ;
6476
6577 multiply_matrix ( m, & z, & mut y, false ) ;
78+ println ! ( "{}-nd power-iteration backward: {:?}" , w, start. elapsed( ) ) ;
6679 match power_iteration_normalizer {
6780 PowerIterationNormalizer :: QR => {
6881 let qr = y. qr ( ) ;
@@ -71,16 +84,30 @@ where
7184 PowerIterationNormalizer :: LU => normalize_columns ( & mut y) ,
7285 PowerIterationNormalizer :: None => { }
7386 }
87+ println ! (
88+ "{}-nd power-iteration backward, normalization: {:?}" ,
89+ w,
90+ start. elapsed( )
91+ ) ;
7492 }
7593 }
76-
94+ println ! (
95+ "Finished power-iteration, continuing QR: {:?}" ,
96+ start. elapsed( )
97+ ) ;
7798 let qr = y. qr ( ) ;
99+ println ! ( "QR finished: {:?}" , start. elapsed( ) ) ;
78100 let q = qr. q ( ) ;
79101
80102 let mut b = DMatrix :: < T > :: zeros ( q. ncols ( ) , m_cols) ;
81103 multiply_transposed_by_matrix ( & q, m, & mut b) ;
104+ println ! (
105+ "QMB matrix multiplication transposed: {:?}" ,
106+ start. elapsed( )
107+ ) ;
82108
83109 let svd = b. svd ( true , true ) ;
110+ println ! ( "SVD decomposition took: {:?}" , start. elapsed( ) ) ;
84111 let u_b = svd
85112 . u
86113 . ok_or_else ( || SvdLibError :: Las2Error ( "SVD U computation failed" . to_string ( ) ) ) ?;
@@ -98,10 +125,15 @@ where
98125
99126 // Convert to the format required by SvdRec
100127 let d = actual_rank;
128+ println ! ( "SVD Result Cropping: {:?}" , start. elapsed( ) ) ;
101129
102130 let ut = u. transpose ( ) . into_ndarray2 ( ) ;
103- let s = convert_singular_values ( <DVector < T > >:: from ( singular_values. rows ( 0 , actual_rank) ) , actual_rank) ;
131+ let s = convert_singular_values (
132+ <DVector < T > >:: from ( singular_values. rows ( 0 , actual_rank) ) ,
133+ actual_rank,
134+ ) ;
104135 let vt = vt_subset. into_ndarray2 ( ) ;
136+ println ! ( "Translation to ndarray: {:?}" , start. elapsed( ) ) ;
105137
106138 Ok ( SvdRec {
107139 d,
@@ -203,60 +235,32 @@ fn normalize_columns<T: SvdFloat + RealField + Send + Sync>(matrix: &mut DMatrix
203235 . collect ( ) ;
204236
205237 // Apply normalization
206- scales
207- . iter ( )
208- . for_each ( |( j, scale) | {
209- for i in 0 ..rows {
210- let value = matrix. get_mut ( ( i, * j) ) . unwrap ( ) ;
211- * value = value. clone ( ) * scale. clone ( ) ;
212- }
213- } ) ;
238+ scales. iter ( ) . for_each ( |( j, scale) | {
239+ for i in 0 ..rows {
240+ let value = matrix. get_mut ( ( i, * j) ) . unwrap ( ) ;
241+ * value = value. clone ( ) * scale. clone ( ) ;
242+ }
243+ } ) ;
214244}
215245
216246// ----------------------------------------
217247// Utils Functions
218248// ----------------------------------------
219249
220-
221250fn generate_random_matrix < T : SvdFloat + RealField > (
222251 rows : usize ,
223252 cols : usize ,
224253 seed : Option < u64 > ,
225254) -> DMatrix < T > {
226- //if rows < PARALLEL_THRESHOLD_ROWS && cols < PARALLEL_THRESHOLD_COLS && rows * cols < PARALLEL_THRESHOLD_ELEMENTS {
227- let mut rng = match seed {
228- Some ( s) => StdRng :: seed_from_u64 ( s) ,
229- None => StdRng :: seed_from_u64 ( 0 ) ,
230- } ;
231-
232- let normal = Normal :: new ( 0.0 , 1.0 ) . unwrap ( ) ;
233- return DMatrix :: from_fn ( rows, cols, |_, _| {
234- T :: from_f64 ( normal. sample ( & mut rng) ) . unwrap ( )
235- } ) ;
236- //}
237-
238- /*let seed_value = seed.unwrap_or(0);
239- let mut matrix = DMatrix::<T>::zeros(rows, cols);
240- let num_threads = current_num_threads();
241- let chunk_size = (rows * cols + num_threads - 1) / num_threads;
242-
243- (0..(rows * cols)).into_par_iter()
244- .chunks(chunk_size)
245- .enumerate()
246- .for_each(|(chunk_idx, indices)| {
247- let thread_seed = seed_value.wrapping_add(chunk_idx as u64);
248- let mut rng = StdRng::seed_from_u64(thread_seed);
249- let normal = Normal::new(0.0, 1.0).unwrap();
250- for idx in indices {
251- let i = idx / cols;
252- let j = idx % cols;
253- unsafe {
254- *matrix.get_unchecked_mut((i, j)) = T::from_f64(normal.sample(&mut rng)).unwrap();
255- }
256- }
257- });
258- matrix*/
259-
255+ let mut rng = match seed {
256+ Some ( s) => StdRng :: seed_from_u64 ( s) ,
257+ None => StdRng :: seed_from_u64 ( 0 ) ,
258+ } ;
259+
260+ let normal = Normal :: new ( 0.0 , 1.0 ) . unwrap ( ) ;
261+ DMatrix :: from_fn ( rows, cols, |_, _| {
262+ T :: from_f64 ( normal. sample ( & mut rng) ) . unwrap ( )
263+ } )
260264}
261265
262266fn multiply_matrix < T : SvdFloat , M : SMat < T > > (
@@ -266,53 +270,94 @@ fn multiply_matrix<T: SvdFloat, M: SMat<T>>(
266270 transpose_sparse : bool ,
267271) {
268272 let cols = dense. ncols ( ) ;
269- //let matrix_rows = if transpose_sparse { sparse.ncols() } else { sparse.nrows() };
270273
271- //if matrix_rows < PARALLEL_THRESHOLD_ROWS && cols < PARALLEL_THRESHOLD_COLS {
272- let mut col_vec = vec ! [ T :: zero( ) ; dense. nrows( ) ] ;
273- let mut result_vec = vec ! [ T :: zero( ) ; result. nrows( ) ] ;
274+ let results: Vec < ( usize , Vec < T > ) > = ( 0 ..cols)
275+ . into_par_iter ( )
276+ . map ( |j| {
277+ let mut col_vec = vec ! [ T :: zero( ) ; dense. nrows( ) ] ;
278+ let mut result_vec = vec ! [ T :: zero( ) ; result. nrows( ) ] ;
274279
275- for j in 0 ..cols {
276- // Extract column from dense matrix
277280 for i in 0 ..dense. nrows ( ) {
278281 col_vec[ i] = dense[ ( i, j) ] ;
279282 }
280283
281- // Perform sparse matrix operation
282284 sparse. svd_opa ( & col_vec, & mut result_vec, transpose_sparse) ;
283285
284- // Store results
285- for i in 0 ..result. nrows ( ) {
286- result[ ( i, j) ] = result_vec[ i] ;
287- }
286+ ( j, result_vec)
287+ } )
288+ . collect ( ) ;
288289
289- // Clear result vector for reuse
290- result_vec. iter_mut ( ) . for_each ( |v| * v = T :: zero ( ) ) ;
290+ for ( j, col_result) in results {
291+ for i in 0 ..result. nrows ( ) {
292+ result[ ( i, j) ] = col_result[ i] ;
291293 }
292- return ;
293- //}
294-
295-
294+ }
296295}
297296
298297fn multiply_transposed_by_matrix < T : SvdFloat , M : SMat < T > > (
299- q : & DMatrix < T > ,
298+ q : & DMatrix < T > ,
300299 sparse : & M ,
301300 result : & mut DMatrix < T > ,
302301) {
303- for j in 0 ..sparse. ncols ( ) {
304- let mut unit_vec = vec ! [ T :: zero( ) ; sparse. ncols( ) ] ;
305- unit_vec[ j] = T :: one ( ) ;
306-
307- let mut col_vec = vec ! [ T :: zero( ) ; sparse. nrows( ) ] ;
308- sparse. svd_opa ( & unit_vec, & mut col_vec, false ) ;
309-
310- for i in 0 ..q. ncols ( ) {
311- let mut sum = T :: zero ( ) ;
312- for k in 0 ..q. nrows ( ) {
313- sum += q[ ( k, i) ] * col_vec[ k] ;
302+ let q_rows = q. nrows ( ) ;
303+ let q_cols = q. ncols ( ) ;
304+ let sparse_rows = sparse. nrows ( ) ;
305+ let sparse_cols = sparse. ncols ( ) ;
306+
307+ eprintln ! ( "Q dimensions: {} x {}" , q_rows, q_cols) ;
308+ eprintln ! ( "Sparse dimensions: {} x {}" , sparse_rows, sparse_cols) ;
309+ eprintln ! ( "Result dimensions: {} x {}" , result. nrows( ) , result. ncols( ) ) ;
310+
311+ assert_eq ! (
312+ q_rows, sparse_rows,
313+ "Dimension mismatch: Q has {} rows but sparse has {} rows" ,
314+ q_rows, sparse_rows
315+ ) ;
316+
317+ assert_eq ! (
318+ result. nrows( ) ,
319+ q_cols,
320+ "Result matrix has incorrect row count: expected {}, got {}" ,
321+ q_cols,
322+ result. nrows( )
323+ ) ;
324+ assert_eq ! (
325+ result. ncols( ) ,
326+ sparse_cols,
327+ "Result matrix has incorrect column count: expected {}, got {}" ,
328+ sparse_cols,
329+ result. ncols( )
330+ ) ;
331+
332+ let chunk_size = determine_chunk_size ( q_cols) ;
333+
334+ let chunk_results: Vec < Vec < ( usize , Vec < T > ) > > = ( 0 ..q_cols)
335+ . into_par_iter ( )
336+ . chunks ( chunk_size)
337+ . map ( |chunk| {
338+ let mut chunk_results = Vec :: with_capacity ( chunk. len ( ) ) ;
339+
340+ for & col_idx in & chunk {
341+ let mut q_col = vec ! [ T :: zero( ) ; q_rows] ;
342+ for i in 0 ..q_rows {
343+ q_col[ i] = q[ ( i, col_idx) ] ;
344+ }
345+
346+ let mut result_row = vec ! [ T :: zero( ) ; sparse_cols] ;
347+
348+ sparse. svd_opa ( & q_col, & mut result_row, true ) ;
349+
350+ chunk_results. push ( ( col_idx, result_row) ) ;
351+ }
352+ chunk_results
353+ } )
354+ . collect ( ) ;
355+
356+ for chunk_result in chunk_results {
357+ for ( row_idx, row_values) in chunk_result {
358+ for j in 0 ..sparse_cols {
359+ result[ ( row_idx, j) ] = row_values[ j] ;
314360 }
315- result[ ( i, j) ] = sum;
316361 }
317362 }
318363}
0 commit comments