11use crate :: error:: * ;
22use crate :: { Dimension , Ix0 , Ix1 , Ix2 , Ix3 , Ix4 , Ix5 , Ix6 , IxDyn } ;
33
4- /// Calculate the co_broadcast shape of two dimensions. Return error if shapes are
5- /// not compatible.
6- fn broadcast_shape < D1 , D2 , Output > ( shape1 : & D1 , shape2 : & D2 ) -> Result < Output , ShapeError >
7- where
8- D1 : Dimension ,
9- D2 : Dimension ,
10- Output : Dimension ,
4+ /// Calculate the common shape for a pair of array shapes, which can be broadcasted
5+ /// to each other. Return an error if shapes are not compatible.
6+ ///
7+ /// Uses the [NumPy broadcasting rules]
8+ // (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules).
9+ fn co_broadcasting < D1 , D2 , Output > ( shape1 : & D1 , shape2 : & D2 ) -> Result < Output , ShapeError >
10+ where
11+ D1 : Dimension ,
12+ D2 : Dimension ,
13+ Output : Dimension ,
1114{
1215 let ( k, overflow) = shape1. ndim ( ) . overflowing_sub ( shape2. ndim ( ) ) ;
1316 // Swap the order if d2 is longer.
1417 if overflow {
15- return broadcast_shape :: < D2 , D1 , Output > ( shape2, shape1) ;
18+ return co_broadcasting :: < D2 , D1 , Output > ( shape2, shape1) ;
1619 }
1720 // The output should be the same length as shape1.
1821 let mut out = Output :: zeros ( shape1. ndim ( ) ) ;
19- // Uses the [NumPy broadcasting rules]
20- // (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules).
21- //
22- // Zero dimension element is not in the original rules of broadcasting.
23- // We currently treat it like any other number greater than 1. As numpy does.
2422 for ( out, s) in izip ! ( out. slice_mut( ) , shape1. slice( ) ) {
2523 * out = * s;
2624 }
@@ -42,10 +40,7 @@ pub trait BroadcastShape<Other: Dimension> {
4240
4341 /// Determines the shape after broadcasting the dimensions together.
4442 ///
45- /// If the dimensions are not compatible, returns `Err`.
46- ///
47- /// Uses the [NumPy broadcasting rules]
48- /// (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules).
43+ /// If the shapes are not compatible, returns `Err`.
4944 fn broadcast_shape ( & self , other : & Other ) -> Result < Self :: Output , ShapeError > ;
5045}
5146
@@ -56,7 +51,7 @@ impl<D: Dimension> BroadcastShape<D> for D {
5651 type Output = D ;
5752
5853 fn broadcast_shape ( & self , other : & D ) -> Result < Self :: Output , ShapeError > {
59- broadcast_shape :: < D , D , Self :: Output > ( self , other)
54+ co_broadcasting :: < D , D , Self :: Output > ( self , other)
6055 }
6156}
6257
@@ -66,15 +61,15 @@ macro_rules! impl_broadcast_distinct_fixed {
6661 type Output = $larger;
6762
6863 fn broadcast_shape( & self , other: & $larger) -> Result <Self :: Output , ShapeError > {
69- broadcast_shape :: <Self , $larger, Self :: Output >( self , other)
64+ co_broadcasting :: <Self , $larger, Self :: Output >( self , other)
7065 }
7166 }
7267
7368 impl BroadcastShape <$smaller> for $larger {
7469 type Output = $larger;
7570
7671 fn broadcast_shape( & self , other: & $smaller) -> Result <Self :: Output , ShapeError > {
77- broadcast_shape :: <Self , $smaller, Self :: Output >( self , other)
72+ co_broadcasting :: <Self , $smaller, Self :: Output >( self , other)
7873 }
7974 }
8075 } ;
0 commit comments