22Utilities
33"""
44
5- # from itertools import product
65from typing import Hashable , Iterable , Iterator , Optional
76
87import numpy as np
@@ -80,10 +79,15 @@ def _calc_rmax(depth: DataArray) -> DataArray:
8079 return rmax .fillna (0 )
8180
8281
83- def _smooth_MB06 (depth : DataArray , rmax : float ) -> DataArray :
82+ def _smooth_MB06 (
83+ depth : DataArray ,
84+ rmax : float ,
85+ tol : float = 1.0e-8 ,
86+ max_iter : int = 10_000 ,
87+ ) -> DataArray :
8488 """
85- This is NEMO implementation of the direct iterative method
86- of Martinho and Batteen (2006) .
89+ Direct iterative method of Martinho and Batteen (2006) consistent
90+ with NEMO implementation .
8791
8892 The algorithm ensures that
8993
@@ -100,87 +104,62 @@ def _smooth_MB06(depth: DataArray, rmax: float) -> DataArray:
100104 Parameters
101105 ----------
102106 depth: DataArray
103- Bottom depth (units: m) .
107+ Bottom depth.
104108 rmax: float
105109 Maximum slope parameter allowed
110+ tol: float, default = 1.0e-8
111+ Tolerance for the iterative method
112+ max_iter: int, default = 10000
113+ Maximum number of iterations
106114
107115 Returns
108116 -------
109117 DataArray
110118 Smooth version of the bottom topography with
111- a maximum slope parameter < rmax (units: m).
112-
119+ a maximum slope parameter < rmax.
113120 """
114121
115- # set scaling factor used for smoothing
122+ # Set scaling factor used for smoothing
116123 zrfact = (1.0 - rmax ) / (1.0 + rmax )
117124
118- # getting the actual numpy array
119- # TO BE OPTIMISED
120- da_zenv = depth .copy ()
121- zenv = da_zenv .data
122- nj = zenv .shape [0 ]
123- ni = zenv .shape [1 ]
124-
125- # initialise temporary evelope depth arrays
126- ztmpi1 = zenv .copy ()
127- ztmpi2 = zenv .copy ()
128- ztmpj1 = zenv .copy ()
129- ztmpj2 = zenv .copy ()
130-
131- # Computing the initial maximum slope parameter
132- zrmax = 1.0 # np.nanmax(_calc_rmax(depth))
133- zri = np .ones (zenv .shape ) # * zrmax
134- zrj = np .ones (zenv .shape ) # * zrmax
135-
136- tol = 1.0e-8
137- itr = 0
138- max_itr = 10000
139-
140- while itr <= max_itr and (zrmax - rmax ) > tol :
141-
142- itr += 1
143- zrmax = 0.0
144- # we set zrmax from previous r-values (zri and zrj) first
145- # if set after current r-value calculation (as previously)
146- # we could exit DO WHILE prematurely before checking r-value
147- # of current zenv
148- max_zri = np .nanmax (np .absolute (zri ))
149- max_zrj = np .nanmax (np .absolute (zrj ))
150- zrmax = np .nanmax ([zrmax , max_zrj , max_zri ])
151-
152- print ("Iter:" , itr , "rmax: " , zrmax )
153-
154- zri *= 0.0
155- zrj *= 0.0
156-
157- for j in range (nj - 1 ):
158- for i in range (ni - 1 ):
159- ip1 = np .minimum (i + 1 , ni )
160- jp1 = np .minimum (j + 1 , nj )
161- if zenv [j , i ] > 0.0 and zenv [j , ip1 ] > 0.0 :
162- zri [j , i ] = (zenv [j , ip1 ] - zenv [j , i ]) / (
163- zenv [j , ip1 ] + zenv [j , i ]
164- )
165- if zenv [j , i ] > 0.0 and zenv [jp1 , i ] > 0.0 :
166- zrj [j , i ] = (zenv [jp1 , i ] - zenv [j , i ]) / (
167- zenv [jp1 , i ] + zenv [j , i ]
168- )
169- if zri [j , i ] > rmax :
170- ztmpi1 [j , i ] = zenv [j , ip1 ] * zrfact
171- if zri [j , i ] < - rmax :
172- ztmpi2 [j , ip1 ] = zenv [j , i ] * zrfact
173- if zrj [j , i ] > rmax :
174- ztmpj1 [j , i ] = zenv [jp1 , i ] * zrfact
175- if zrj [j , i ] < - rmax :
176- ztmpj2 [jp1 , i ] = zenv [j , i ] * zrfact
177-
178- ztmpi = np .maximum (ztmpi1 , ztmpi2 )
179- ztmpj = np .maximum (ztmpj1 , ztmpj2 )
180- zenv = np .maximum (zenv , np .maximum (ztmpi , ztmpj ))
181-
182- da_zenv .data = zenv
183- return da_zenv
125+ # Initialize envelope bathymetry
126+ zenv = depth
127+
128+ for _ in range (max_iter ):
129+
130+ # Initialize lists of DataArrays to concatenate
131+ all_ztmp = []
132+ all_zr = []
133+ for dim in zenv .dims :
134+
135+ # Shifted arrays
136+ zenv_m1 = zenv .shift ({dim : - 1 })
137+ zenv_p1 = zenv .shift ({dim : + 1 })
138+
139+ # Compute zr
140+ zr = (zenv_m1 - zenv ) / (zenv_m1 + zenv )
141+ zr = zr .where ((zenv > 0 ) & (zenv_m1 > 0 ), 0 )
142+ for dim_name in zenv .dims :
143+ zr [{dim_name : - 1 }] = 0
144+ all_zr += [zr ]
145+
146+ # Compute ztmp
147+ zr_p1 = zr .shift ({dim : + 1 })
148+ all_ztmp += [zenv .where (zr <= rmax , zenv_m1 * zrfact )]
149+ all_ztmp += [zenv .where (zr_p1 >= - rmax , zenv_p1 * zrfact )]
150+
151+ # Update envelope bathymetry
152+ zenv = xr .concat ([zenv ] + all_ztmp , "dummy_dim" ).max ("dummy_dim" )
153+
154+ # Check target rmax
155+ zr = xr .concat (all_zr , "dummy_dim" )
156+ if ((np .abs (zr ) - rmax ) <= tol ).all ():
157+ return zenv
158+
159+ raise ValueError (
160+ "Iterative method did NOT converge."
161+ " You might want to increase the number of iterations and/or the tolerance."
162+ )
184163
185164
186165def generate_cartesian_grid (
0 commit comments