Skip to content

Commit f0087d6

Browse files
authored
Merge pull request #135 from ArcticSnow/optimisations
Performance optimisations and pandas 2.x compatibility fixes
2 parents 5c6a9b7 + c966d2b commit f0087d6

6 files changed

Lines changed: 149 additions & 114 deletions

File tree

TopoPyScale/fetch_era5.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -955,7 +955,7 @@ def remap_CDSbeta(file_pattern, file_type='SURF'):
955955
try:
956956
ds = xr.open_dataset(nc_file)
957957
ds = ds.rename({ 'pressure_level': 'level', 'valid_time' : 'time'})
958-
ds = ds.isel(level=slice(None, None, -1)) # reverse order of levels
958+
ds = ds.sortby('level', ascending=True) # sort levels ascending (300 to 1000)
959959

960960
try:
961961
ds = ds.drop_vars('number')
@@ -1028,7 +1028,7 @@ def era5_request_surf_snowmapper( today, latN, latS, lonE, lonW, eraDir, output_
10281028

10291029
bbox = [str(latN), str(lonW), str(latS), str(lonE)]
10301030

1031-
target = eraDir + "/forecast/SURF_%04d%02d%02d.nc" % (today.year, today.month, today.day)
1031+
target = str(eraDir) + "/forecast/SURF_%04d%02d%02d.nc" % (today.year, today.month, today.day)
10321032

10331033
c = cdsapi.Client()
10341034
c.retrieve(
@@ -1080,7 +1080,7 @@ def era5_request_plev_snowmapper(today, latN, latS, lonE, lonW, eraDir, plevels,
10801080

10811081
bbox = [str(latN), str(lonW), str(latS), str(lonE)]
10821082

1083-
target = eraDir + "/forecast/PLEV_%04d%02d%02d.nc" % (today.year, today.month, today.day)
1083+
target = str(eraDir) + "/forecast/PLEV_%04d%02d%02d.nc" % (today.year, today.month, today.day)
10841084

10851085

10861086
c = cdsapi.Client()

TopoPyScale/sim_fsm2oshd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def read_fsm_ds_with_mask(path):
2727
'''
2828
function to read fsm as ds, adding a dummy cluster for masked area
2929
'''
30-
ds = xr.open_mfdataset(path, concat_dim='point_ind', combine='nested')
30+
ds = xr.open_mfdataset(path, concat_dim='point_ind', combine='nested', parallel=True)
3131
tp = ds.isel(point_ind=0).copy()
3232
tp['point_ind'] = -9999
3333
for var in list(tp.keys()):

TopoPyScale/topo_param.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,16 +87,16 @@ def extract_pts_param(df_pts, ds_param, method='nearest'):
8787
df_pts[['elevation', 'slope', 'aspect', 'aspect_cos', 'aspect_sin', 'svf']] = 0
8888

8989
if method == 'nearest':
90-
for i, row in df_pts.iterrows():
90+
for row in df_pts.itertuples():
9191
d_mini = ds_param.sel(x=row.x, y=row.y, method='nearest')
92-
df_pts.loc[i, ['elevation', 'slope', 'aspect', 'aspect_cos', 'aspect_sin', 'svf']] = np.array((d_mini.elevation.values,
92+
df_pts.loc[row.Index, ['elevation', 'slope', 'aspect', 'aspect_cos', 'aspect_sin', 'svf']] = np.array((d_mini.elevation.values,
9393
d_mini.slope.values,
9494
d_mini.aspect.values,
9595
d_mini.aspect_cos,
9696
d_mini.aspect_sin,
9797
d_mini.svf.values))
9898
elif method == 'idw' or method == 'linear':
99-
for i, row in df_pts.iterrows():
99+
for row in df_pts.itertuples():
100100
ind_lat = np.abs(ds_param.y-row.y).argmin()
101101
ind_lon = np.abs(ds_param.x-row.x).argmin()
102102
ds_param_pt = ds_param.isel(y=[ind_lat-1, ind_lat, ind_lat+1], x=[ind_lon-1, ind_lon, ind_lon+1])
@@ -119,7 +119,7 @@ def extract_pts_param(df_pts, ds_param, method='nearest'):
119119
)
120120
dw = xr.Dataset.weighted(ds_param_pt, da_idw)
121121
d_mini = dw.sum(['x', 'y'], keep_attrs=True)
122-
df_pts.loc[i, ['elevation', 'slope', 'aspect', 'aspect_cos', 'aspect_sin', 'svf']] = np.array((d_mini.elevation.values,
122+
df_pts.loc[row.Index, ['elevation', 'slope', 'aspect', 'aspect_cos', 'aspect_sin', 'svf']] = np.array((d_mini.elevation.values,
123123
d_mini.slope.values,
124124
d_mini.aspect.values,
125125
d_mini.aspect_cos,

TopoPyScale/topo_scale.py

Lines changed: 46 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from pyproj import Transformer
4343
import numpy as np
4444
import sys, time
45+
from types import SimpleNamespace
4546
import datetime as dt
4647
from TopoPyScale import meteo_util as mu
4748
from TopoPyScale import topo_utils as tu
@@ -79,7 +80,10 @@ def pt_downscale_interp(row, ds_plev_pt, ds_surf_pt, meta):
7980

8081
# convert gridcells coordinates from WGS84 to DEM projection
8182
lons, lats = np.meshgrid(ds_plev_pt.longitude.values, ds_plev_pt.latitude.values)
82-
trans = Transformer.from_crs("epsg:4326", "epsg:" + str(meta.get('target_epsg')), always_xy=True)
83+
trans = meta.get('transformer') # Use cached transformer
84+
if trans is None:
85+
# Fallback for backwards compatibility
86+
trans = Transformer.from_crs("epsg:4326", f"epsg:{meta.get('target_epsg')}", always_xy=True)
8387
Xs, Ys = trans.transform(lons.flatten(), lats.flatten())
8488
Xs = Xs.reshape(lons.shape)
8589
Ys = Ys.reshape(lons.shape)
@@ -148,7 +152,6 @@ def pt_downscale_interp(row, ds_plev_pt, ds_surf_pt, meta):
148152
except:
149153
raise ValueError(
150154
f'ERROR: Upper pressure level {plev_interp.level.min().values} hPa geopotential is lower than cluster mean elevation {row.elevation} {plev_interp.z}')
151-
152155

153156
top = plev_interp.isel(level=ind_z_top)
154157
bot = plev_interp.isel(level=ind_z_bot)
@@ -177,10 +180,10 @@ def pt_downscale_interp(row, ds_plev_pt, ds_surf_pt, meta):
177180
},
178181
coords={'month': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]}
179182
)
180-
down_pt['precip_lapse_rate'] = (1 + monthly_coeffs.coef.sel(month=down_pt.month.values).data * (
181-
row.elevation - surf_interp.z) * 1e-3) / \
182-
(1 - monthly_coeffs.coef.sel(month=down_pt.month.values).data * (
183-
row.elevation - surf_interp.z) * 1e-3)
183+
# Cache coefficient selection (avoid duplicate .sel() call)
184+
coef = monthly_coeffs.coef.sel(month=down_pt.month.values).data
185+
elev_diff = (row.elevation - surf_interp.z) * 1e-3
186+
down_pt['precip_lapse_rate'] = (1 + coef * elev_diff) / (1 - coef * elev_diff)
184187
else:
185188
down_pt['precip_lapse_rate'] = down_pt.t * 0 + 1
186189

@@ -396,8 +399,8 @@ def downscale_climate(project_directory,
396399
# =========== Open dataset with Dask =================
397400
tvec = pd.date_range(start_date, pd.to_datetime(end_date) + pd.to_timedelta('1D'), freq=tstep, inclusive='left')
398401

399-
flist_PLEV = glob.glob(f'{climate_directory}/PLEV*.nc')
400-
flist_SURF = glob.glob(f'{climate_directory}/SURF*.nc')
402+
flist_PLEV = sorted(glob.glob(f'{climate_directory}/PLEV*.nc'))
403+
flist_SURF = sorted(glob.glob(f'{climate_directory}/SURF*.nc'))
401404

402405
if 'object' in list(df_centroids.dtypes):
403406
pass
@@ -407,7 +410,7 @@ def downscale_climate(project_directory,
407410

408411
def _open_dataset_climate(flist):
409412

410-
ds_ = xr.open_mfdataset(flist, parallel=False, concat_dim="time", combine='nested', coords='minimal')
413+
ds_ = xr.open_mfdataset(flist, parallel=True, concat_dim="time", combine='nested', coords='minimal')
411414

412415
# this block handles the expver dimension that is in downloaded ERA5 data if data is ERA5/ERA5T mix. If only ERA5 or
413416
# only ERA5T it is not present. ERA5T data can be present in the timewindow T-5days to T -3months, where T is today.
@@ -473,11 +476,9 @@ def _subset_climate_dataset(ds_, row, type='plev'):
473476

474477
ds_plev = ds_plev.sel(time=tvec.values)
475478

476-
row_list = []
477-
ds_list = []
478-
for _, row in df_centroids.iterrows():
479-
row_list.append(row)
480-
ds_list.append(ds_plev)
479+
# Convert to list of row tuples (faster than iterrows)
480+
row_list = list(df_centroids.itertuples(index=False))
481+
ds_list = [ds_plev] * len(row_list)
481482

482483
fun_param = zip(ds_list, row_list, ['plev'] * len(row_list)) # construct here the tuple that goes into the pooling for arguments
483484
tu.multithread_pooling(_subset_climate_dataset, fun_param, n_threads=n_core)
@@ -509,39 +510,43 @@ def _subset_climate_dataset(ds_, row, type='plev'):
509510
print(f"Requested time range: {requested_times.min()} to {requested_times.max()}")
510511
print(f"Missing timesteps: {missing_str}")
511512

512-
ds_list = []
513-
for _, _ in df_centroids.iterrows():
514-
ds_list.append(ds_surf)
513+
# Repeat ds_surf for each centroid (no loop needed)
514+
ds_list = [ds_surf] * len(row_list)
515515

516516
fun_param = zip(ds_list, row_list, ['surf'] * len(row_list)) # construct here the tuple that goes into the pooling for arguments
517517
tu.multithread_pooling(_subset_climate_dataset, fun_param, n_threads=n_core)
518518
fun_param = None
519519
ds_surf = None
520520

521521
# Preparing list to feed into Pooling
522-
surf_pt_list = []
523-
plev_pt_list = []
524-
ds_solar_list = []
525-
horizon_da_list = []
526-
row_list = []
527-
meta_list = []
528-
i = 0
529-
for _, row in df_centroids.iterrows():
530-
surf_pt_list.append(xr.open_dataset(output_directory / f'tmp/ds_surf_pt_{row.point_name}.nc', engine='h5netcdf'))
531-
plev_pt_list.append(xr.open_dataset(output_directory / f'tmp/ds_plev_pt_{row.point_name}.nc', engine='h5netcdf'))
532-
ds_solar_list.append(ds_solar.sel(point_name=row.point_name))
533-
horizon_da_list.append(horizon_da)
534-
row_list.append(row)
535-
meta_list.append({'interp_method': interp_method,
536-
'lw_terrain_flag': lw_terrain_flag,
537-
'tstep': tstep_dict.get(tstep),
538-
'n_digits': n_digits,
539-
'file_pattern': file_pattern,
540-
'target_epsg':target_EPSG,
541-
'precip_lapse_rate_flag':precip_lapse_rate_flag,
542-
'output_directory':output_directory,
543-
'lw_terrain_flag':lw_terrain_flag})
544-
i+=1
522+
# Pre-build meta dict once (same for all points)
523+
# Pre-create transformer once (used by all points for CRS conversion)
524+
transformer = Transformer.from_crs("epsg:4326", f"epsg:{target_EPSG}", always_xy=True)
525+
526+
meta_template = {
527+
'interp_method': interp_method,
528+
'lw_terrain_flag': lw_terrain_flag,
529+
'tstep': tstep_dict.get(tstep),
530+
'n_digits': n_digits,
531+
'file_pattern': file_pattern,
532+
'target_epsg': target_EPSG,
533+
'precip_lapse_rate_flag': precip_lapse_rate_flag,
534+
'output_directory': output_directory,
535+
'transformer': transformer, # Cached CRS transformer
536+
}
537+
538+
# Build lists using itertuples (faster than iterrows)
539+
# Convert to SimpleNamespace for pickling compatibility with pandas 2.x + multiprocessing
540+
cols = df_centroids.columns.tolist()
541+
row_list = [SimpleNamespace(**dict(zip(cols, row)))
542+
for row in df_centroids.itertuples(index=False, name=None)]
543+
point_names = df_centroids.point_name.values
544+
545+
surf_pt_list = [xr.open_dataset(output_directory / f'tmp/ds_surf_pt_{pn}.nc', engine='h5netcdf') for pn in point_names]
546+
plev_pt_list = [xr.open_dataset(output_directory / f'tmp/ds_plev_pt_{pn}.nc', engine='h5netcdf') for pn in point_names]
547+
ds_solar_list = [ds_solar.sel(point_name=pn) for pn in point_names]
548+
horizon_da_list = [horizon_da] * len(row_list)
549+
meta_list = [meta_template.copy() for _ in row_list]
545550

546551
fun_param = zip(row_list, plev_pt_list, surf_pt_list, meta_list) # construct here the tuple that goes into the pooling for arguments
547552
tu.multicore_pooling(pt_downscale_interp, fun_param, n_core)
@@ -570,5 +575,5 @@ def read_downscaled(path='outputs/down_pt*.nc'):
570575
dataset: merged dataset ready to use and loaded in chunks via Dask
571576
"""
572577

573-
down_pts = xr.open_mfdataset(path, concat_dim='point_name', combine='nested', parallel=False)
578+
down_pts = xr.open_mfdataset(path, concat_dim='point_name', combine='nested', parallel=True)
574579
return down_pts

0 commit comments

Comments
 (0)