Skip to content

Commit 8e08398

Browse files
authored
Refactor MB06 (#51)
* refactor to avoid loops * apply Diego's suggestions * remove useless return after error
1 parent 6750533 commit 8e08398

1 file changed

Lines changed: 53 additions & 74 deletions

File tree

pydomcfg/utils.py

Lines changed: 53 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
Utilities
33
"""
44

5-
# from itertools import product
65
from typing import Hashable, Iterable, Iterator, Optional
76

87
import numpy as np
@@ -80,10 +79,15 @@ def _calc_rmax(depth: DataArray) -> DataArray:
8079
return rmax.fillna(0)
8180

8281

83-
def _smooth_MB06(depth: DataArray, rmax: float) -> DataArray:
82+
def _smooth_MB06(
83+
depth: DataArray,
84+
rmax: float,
85+
tol: float = 1.0e-8,
86+
max_iter: int = 10_000,
87+
) -> DataArray:
8488
"""
85-
This is NEMO implementation of the direct iterative method
86-
of Martinho and Batteen (2006).
89+
Direct iterative method of Martinho and Batteen (2006) consistent
90+
with NEMO implementation.
8791
8892
The algorithm ensures that
8993
@@ -100,87 +104,62 @@ def _smooth_MB06(depth: DataArray, rmax: float) -> DataArray:
100104
Parameters
101105
----------
102106
depth: DataArray
103-
Bottom depth (units: m).
107+
Bottom depth.
104108
rmax: float
105109
Maximum slope parameter allowed
110+
tol: float, default = 1.0e-8
111+
Tolerance for the iterative method
112+
max_iter: int, default = 10000
113+
Maximum number of iterations
106114
107115
Returns
108116
-------
109117
DataArray
110118
Smooth version of the bottom topography with
111-
a maximum slope parameter < rmax (units: m).
112-
119+
a maximum slope parameter < rmax.
113120
"""
114121

115-
# set scaling factor used for smoothing
122+
# Set scaling factor used for smoothing
116123
zrfact = (1.0 - rmax) / (1.0 + rmax)
117124

118-
# getting the actual numpy array
119-
# TO BE OPTIMISED
120-
da_zenv = depth.copy()
121-
zenv = da_zenv.data
122-
nj = zenv.shape[0]
123-
ni = zenv.shape[1]
124-
125-
# initialise temporary evelope depth arrays
126-
ztmpi1 = zenv.copy()
127-
ztmpi2 = zenv.copy()
128-
ztmpj1 = zenv.copy()
129-
ztmpj2 = zenv.copy()
130-
131-
# Computing the initial maximum slope parameter
132-
zrmax = 1.0 # np.nanmax(_calc_rmax(depth))
133-
zri = np.ones(zenv.shape) # * zrmax
134-
zrj = np.ones(zenv.shape) # * zrmax
135-
136-
tol = 1.0e-8
137-
itr = 0
138-
max_itr = 10000
139-
140-
while itr <= max_itr and (zrmax - rmax) > tol:
141-
142-
itr += 1
143-
zrmax = 0.0
144-
# we set zrmax from previous r-values (zri and zrj) first
145-
# if set after current r-value calculation (as previously)
146-
# we could exit DO WHILE prematurely before checking r-value
147-
# of current zenv
148-
max_zri = np.nanmax(np.absolute(zri))
149-
max_zrj = np.nanmax(np.absolute(zrj))
150-
zrmax = np.nanmax([zrmax, max_zrj, max_zri])
151-
152-
print("Iter:", itr, "rmax: ", zrmax)
153-
154-
zri *= 0.0
155-
zrj *= 0.0
156-
157-
for j in range(nj - 1):
158-
for i in range(ni - 1):
159-
ip1 = np.minimum(i + 1, ni)
160-
jp1 = np.minimum(j + 1, nj)
161-
if zenv[j, i] > 0.0 and zenv[j, ip1] > 0.0:
162-
zri[j, i] = (zenv[j, ip1] - zenv[j, i]) / (
163-
zenv[j, ip1] + zenv[j, i]
164-
)
165-
if zenv[j, i] > 0.0 and zenv[jp1, i] > 0.0:
166-
zrj[j, i] = (zenv[jp1, i] - zenv[j, i]) / (
167-
zenv[jp1, i] + zenv[j, i]
168-
)
169-
if zri[j, i] > rmax:
170-
ztmpi1[j, i] = zenv[j, ip1] * zrfact
171-
if zri[j, i] < -rmax:
172-
ztmpi2[j, ip1] = zenv[j, i] * zrfact
173-
if zrj[j, i] > rmax:
174-
ztmpj1[j, i] = zenv[jp1, i] * zrfact
175-
if zrj[j, i] < -rmax:
176-
ztmpj2[jp1, i] = zenv[j, i] * zrfact
177-
178-
ztmpi = np.maximum(ztmpi1, ztmpi2)
179-
ztmpj = np.maximum(ztmpj1, ztmpj2)
180-
zenv = np.maximum(zenv, np.maximum(ztmpi, ztmpj))
181-
182-
da_zenv.data = zenv
183-
return da_zenv
125+
# Initialize envelope bathymetry
126+
zenv = depth
127+
128+
for _ in range(max_iter):
129+
130+
# Initialize lists of DataArrays to concatenate
131+
all_ztmp = []
132+
all_zr = []
133+
for dim in zenv.dims:
134+
135+
# Shifted arrays
136+
zenv_m1 = zenv.shift({dim: -1})
137+
zenv_p1 = zenv.shift({dim: +1})
138+
139+
# Compute zr
140+
zr = (zenv_m1 - zenv) / (zenv_m1 + zenv)
141+
zr = zr.where((zenv > 0) & (zenv_m1 > 0), 0)
142+
for dim_name in zenv.dims:
143+
zr[{dim_name: -1}] = 0
144+
all_zr += [zr]
145+
146+
# Compute ztmp
147+
zr_p1 = zr.shift({dim: +1})
148+
all_ztmp += [zenv.where(zr <= rmax, zenv_m1 * zrfact)]
149+
all_ztmp += [zenv.where(zr_p1 >= -rmax, zenv_p1 * zrfact)]
150+
151+
# Update envelope bathymetry
152+
zenv = xr.concat([zenv] + all_ztmp, "dummy_dim").max("dummy_dim")
153+
154+
# Check target rmax
155+
zr = xr.concat(all_zr, "dummy_dim")
156+
if ((np.abs(zr) - rmax) <= tol).all():
157+
return zenv
158+
159+
raise ValueError(
160+
"Iterative method did NOT converge."
161+
" You might want to increase the number of iterations and/or the tolerance."
162+
)
184163

185164

186165
def generate_cartesian_grid(

0 commit comments

Comments
 (0)