 from pyproj import Transformer
 import numpy as np
 import sys, time
+from types import SimpleNamespace
 import datetime as dt
 from TopoPyScale import meteo_util as mu
 from TopoPyScale import topo_utils as tu
@@ -79,7 +80,10 @@ def pt_downscale_interp(row, ds_plev_pt, ds_surf_pt, meta):
 
     # convert grid-cell coordinates from WGS84 to the DEM projection
     lons, lats = np.meshgrid(ds_plev_pt.longitude.values, ds_plev_pt.latitude.values)
-    trans = Transformer.from_crs("epsg:4326", "epsg:" + str(meta.get('target_epsg')), always_xy=True)
+    trans = meta.get('transformer')  # use the cached transformer
+    if trans is None:
+        # fall back for backwards compatibility
+        trans = Transformer.from_crs("epsg:4326", f"epsg:{meta.get('target_epsg')}", always_xy=True)
     Xs, Ys = trans.transform(lons.flatten(), lats.flatten())
     Xs = Xs.reshape(lons.shape)
     Ys = Ys.reshape(lons.shape)
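
Note: constructing a pyproj Transformer is relatively expensive (the CRS pipeline has to be resolved), so building it once and shipping it through meta amortizes the cost over all points. A minimal sketch of the reuse pattern, assuming an arbitrary EPSG:32632 (UTM 32N) target CRS:

    from pyproj import Transformer

    # Build the transformer once, outside any per-point loop.
    trans = Transformer.from_crs("epsg:4326", "epsg:32632", always_xy=True)

    # Reuse the same instance for every point.
    for lon, lat in [(10.0, 46.0), (10.5, 46.2)]:
        x, y = trans.transform(lon, lat)
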
@@ -148,7 +152,6 @@ def pt_downscale_interp(row, ds_plev_pt, ds_surf_pt, meta):
     except:
         raise ValueError(
             f'ERROR: Upper pressure level {plev_interp.level.min().values} hPa geopotential is lower than cluster mean elevation {row.elevation} {plev_interp.z}')
-
 
     top = plev_interp.isel(level=ind_z_top)
     bot = plev_interp.isel(level=ind_z_bot)
@@ -177,10 +180,10 @@ def pt_downscale_interp(row, ds_plev_pt, ds_surf_pt, meta):
             },
             coords={'month': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]}
         )
-        down_pt['precip_lapse_rate'] = (1 + monthly_coeffs.coef.sel(month=down_pt.month.values).data * (
-            row.elevation - surf_interp.z) * 1e-3) / \
-            (1 - monthly_coeffs.coef.sel(month=down_pt.month.values).data * (
-                row.elevation - surf_interp.z) * 1e-3)
+        # cache the coefficient selection to avoid a duplicate .sel() call
+        coef = monthly_coeffs.coef.sel(month=down_pt.month.values).data
+        elev_diff = (row.elevation - surf_interp.z) * 1e-3
+        down_pt['precip_lapse_rate'] = (1 + coef * elev_diff) / (1 - coef * elev_diff)
     else:
         down_pt['precip_lapse_rate'] = down_pt.t * 0 + 1
 
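
Note: the refactor leaves the formula unchanged: precip_lapse_rate = (1 + γΔz) / (1 − γΔz), with γ a monthly coefficient in 1/km and Δz the elevation difference between the cluster and the grid cell in km. A worked sketch with an illustrative γ = 0.27 km⁻¹ (not the module's actual monthly table):

    # Illustrative values only.
    gamma = 0.27               # monthly precipitation coefficient [1/km]
    elev_diff = 0.5            # cluster minus grid-cell elevation [km]
    lapse = (1 + gamma * elev_diff) / (1 - gamma * elev_diff)
    print(round(lapse, 2))     # 1.31 -> ~31% more precipitation at +500 m
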
@@ -396,8 +399,8 @@ def downscale_climate(project_directory,
     # =========== Open dataset with Dask =================
     tvec = pd.date_range(start_date, pd.to_datetime(end_date) + pd.to_timedelta('1D'), freq=tstep, inclusive='left')
 
-    flist_PLEV = glob.glob(f'{climate_directory}/PLEV*.nc')
-    flist_SURF = glob.glob(f'{climate_directory}/SURF*.nc')
+    flist_PLEV = sorted(glob.glob(f'{climate_directory}/PLEV*.nc'))
+    flist_SURF = sorted(glob.glob(f'{climate_directory}/SURF*.nc'))
 
     if 'object' in list(df_centroids.dtypes):
         pass
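
Note: glob.glob() returns paths in arbitrary filesystem order, so sorting keeps the nested time concatenation chronological whenever the filenames embed dates. A tiny illustration with hypothetical filenames:

    files = ['PLEV_202003.nc', 'PLEV_202001.nc', 'PLEV_202002.nc']
    print(sorted(files))  # ['PLEV_202001.nc', 'PLEV_202002.nc', 'PLEV_202003.nc']
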
@@ -407,7 +410,7 @@ def downscale_climate(project_directory,
 
     def _open_dataset_climate(flist):
 
-        ds_ = xr.open_mfdataset(flist, parallel=False, concat_dim="time", combine='nested', coords='minimal')
+        ds_ = xr.open_mfdataset(flist, parallel=True, concat_dim="time", combine='nested', coords='minimal')
 
         # This block handles the expver dimension present in downloaded ERA5 data when it mixes ERA5 and ERA5T. With only ERA5
         # or only ERA5T the dimension is absent. ERA5T data can appear in the window from T-5 days to T-3 months, where T is today.
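
Note: parallel=True lets Dask open and preprocess the files concurrently rather than serially. For the ERA5/ERA5T mix described in the comment, a common pattern is to prefer final ERA5 and fill gaps with preliminary ERA5T; a sketch of that idea (the module's actual handling may differ), assuming expver values 1 (ERA5) and 5 (ERA5T):

    if 'expver' in ds_.dims:
        # Prefer final ERA5 (expver=1); fall back to preliminary ERA5T (expver=5).
        ds_ = ds_.sel(expver=1).combine_first(ds_.sel(expver=5))
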
@@ -473,11 +476,9 @@ def _subset_climate_dataset(ds_, row, type='plev'):
 
     ds_plev = ds_plev.sel(time=tvec.values)
 
-    row_list = []
-    ds_list = []
-    for _, row in df_centroids.iterrows():
-        row_list.append(row)
-        ds_list.append(ds_plev)
+    # Convert to a list of row tuples (faster than iterrows)
+    row_list = list(df_centroids.itertuples(index=False))
+    ds_list = [ds_plev] * len(row_list)
 
     fun_param = zip(ds_list, row_list, ['plev'] * len(row_list))  # build the argument tuples fed into the pooling
     tu.multithread_pooling(_subset_climate_dataset, fun_param, n_threads=n_core)
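
Note: itertuples() yields lightweight namedtuples instead of building a pandas Series per row, which is typically much faster, and [ds_plev] * len(row_list) repeats references to one read-only dataset rather than copying it. A small sketch of the access pattern, using a hypothetical centroids frame:

    import pandas as pd

    df = pd.DataFrame({'point_name': ['p1', 'p2'], 'elevation': [1200.0, 2300.0]})
    for row in df.itertuples(index=False):
        # Attribute access matches what iterrows offered, minus the Series overhead.
        print(row.point_name, row.elevation)
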
@@ -509,39 +510,43 @@ def _subset_climate_dataset(ds_, row, type='plev'):
     print(f"Requested time range: {requested_times.min()} to {requested_times.max()}")
     print(f"Missing timesteps: {missing_str}")
 
-    ds_list = []
-    for _, _ in df_centroids.iterrows():
-        ds_list.append(ds_surf)
+    # Repeat ds_surf for each centroid (no loop needed)
+    ds_list = [ds_surf] * len(row_list)
 
     fun_param = zip(ds_list, row_list, ['surf'] * len(row_list))  # build the argument tuples fed into the pooling
     tu.multithread_pooling(_subset_climate_dataset, fun_param, n_threads=n_core)
     fun_param = None
     ds_surf = None
 
     # Prepare the lists fed into the pooling
-    surf_pt_list = []
-    plev_pt_list = []
-    ds_solar_list = []
-    horizon_da_list = []
-    row_list = []
-    meta_list = []
-    i = 0
-    for _, row in df_centroids.iterrows():
-        surf_pt_list.append(xr.open_dataset(output_directory / f'tmp/ds_surf_pt_{row.point_name}.nc', engine='h5netcdf'))
-        plev_pt_list.append(xr.open_dataset(output_directory / f'tmp/ds_plev_pt_{row.point_name}.nc', engine='h5netcdf'))
-        ds_solar_list.append(ds_solar.sel(point_name=row.point_name))
-        horizon_da_list.append(horizon_da)
-        row_list.append(row)
-        meta_list.append({'interp_method': interp_method,
-                          'lw_terrain_flag': lw_terrain_flag,
-                          'tstep': tstep_dict.get(tstep),
-                          'n_digits': n_digits,
-                          'file_pattern': file_pattern,
-                          'target_epsg': target_EPSG,
-                          'precip_lapse_rate_flag': precip_lapse_rate_flag,
-                          'output_directory': output_directory,
-                          'lw_terrain_flag': lw_terrain_flag})
-        i += 1
+    # Pre-build the meta dict once (identical for all points)
+    # Pre-create the transformer once (used by all points for CRS conversion)
+    transformer = Transformer.from_crs("epsg:4326", f"epsg:{target_EPSG}", always_xy=True)
+
+    meta_template = {
+        'interp_method': interp_method,
+        'lw_terrain_flag': lw_terrain_flag,
+        'tstep': tstep_dict.get(tstep),
+        'n_digits': n_digits,
+        'file_pattern': file_pattern,
+        'target_epsg': target_EPSG,
+        'precip_lapse_rate_flag': precip_lapse_rate_flag,
+        'output_directory': output_directory,
+        'transformer': transformer,  # cached CRS transformer
+    }
+
+    # Build the lists using itertuples (faster than iterrows)
+    # Convert rows to SimpleNamespace for pickling compatibility with pandas 2.x + multiprocessing
+    cols = df_centroids.columns.tolist()
+    row_list = [SimpleNamespace(**dict(zip(cols, row)))
+                for row in df_centroids.itertuples(index=False, name=None)]
+    point_names = df_centroids.point_name.values
+
+    surf_pt_list = [xr.open_dataset(output_directory / f'tmp/ds_surf_pt_{pn}.nc', engine='h5netcdf') for pn in point_names]
+    plev_pt_list = [xr.open_dataset(output_directory / f'tmp/ds_plev_pt_{pn}.nc', engine='h5netcdf') for pn in point_names]
+    ds_solar_list = [ds_solar.sel(point_name=pn) for pn in point_names]
+    horizon_da_list = [horizon_da] * len(row_list)
+    meta_list = [meta_template.copy() for _ in row_list]
 
     fun_param = zip(row_list, plev_pt_list, surf_pt_list, meta_list)  # build the argument tuples fed into the pooling
     tu.multicore_pooling(pt_downscale_interp, fun_param, n_core)
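
Note: the SimpleNamespace conversion keeps attribute access (row.point_name, row.elevation) for pt_downscale_interp while producing plain objects that pickle cleanly for multicore_pooling; dynamically created namedtuple classes, and pandas row objects in some pandas 2.x setups, can fail to pickle in spawned worker processes. A minimal sketch of the round trip, with hypothetical columns:

    import pickle
    from types import SimpleNamespace

    cols = ['point_name', 'elevation']
    rows = [('p1', 1200.0), ('p2', 2300.0)]
    row_list = [SimpleNamespace(**dict(zip(cols, r))) for r in rows]

    # SimpleNamespace instances survive pickling, as multiprocessing requires.
    restored = pickle.loads(pickle.dumps(row_list[0]))
    print(restored.point_name, restored.elevation)
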
@@ -570,5 +575,5 @@ def read_downscaled(path='outputs/down_pt*.nc'):
         dataset: merged dataset, ready to use and loaded in chunks via Dask
     """
 
-    down_pts = xr.open_mfdataset(path, concat_dim='point_name', combine='nested', parallel=False)
+    down_pts = xr.open_mfdataset(path, concat_dim='point_name', combine='nested', parallel=True)
     return down_pts
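
Note: a usage sketch for the updated reader, assuming downscaled point files already exist under outputs/:

    # Lazily opens all matching files; data is only loaded on access.
    down_pts = read_downscaled(path='outputs/down_pt*.nc')
    print(down_pts.point_name.values)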