Skip to content

Commit f0189dc

Browse files
Hot start functionality added for GSSHA functions
1 parent 0afacb7 commit f0189dc

2 files changed

Lines changed: 44 additions & 11 deletions

File tree

src/xarray_data_accessor/data_converters/to_gssha.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import warnings
22
import logging
3-
import pyproj
43
import pandas as pd
54
import numpy as np
65
import xarray as xr
@@ -28,21 +27,27 @@
2827
Dict,
2928
Union,
3029
Optional,
30+
Literal,
3131
TypedDict,
3232
)
3333

34+
OPEN_MODES: Dict[bool, Literal['a', 'w']] = {
35+
True: 'a',
36+
False: 'w',
37+
}
38+
3439

3540
class EventIntervals(TypedDict):
3641
name: str
3742
start: datetime
3843
end: datetime
3944

4045

41-
@ DataConversionFactory.register
46+
@DataConversionFactory.register
4247
class ConvertToGSSHA(DataConverterBase):
4348
"""Converts xarray datasets to GSSHA input files."""
4449

45-
@ staticmethod
50+
@staticmethod
4651
def _get_file_path(
4752
file_dir: Optional[Union[str, Path]] = None,
4853
file_name: Optional[str] = None,
@@ -63,7 +68,7 @@ def _get_file_path(
6368
# make sure the file name is valid
6469
if not file_name:
6570
file_name = 'gssha_input'
66-
logging.warn(
71+
logging.warning(
6772
f'No file name was provided! Using default file name {file_name}.',
6873
)
6974
if not isinstance(file_name, str):
@@ -86,16 +91,17 @@ def _get_file_path(
8691
# return the file path
8792
return Path(file_dir / f'{file_name}{file_suffix}')
8893

89-
@ staticmethod
94+
@staticmethod
9095
def _write_ascii_file(
9196
text_content: str,
9297
file_path: Path,
98+
hot_start: Optional[bool] = False,
9399
) -> None:
94100
"""Writes the text content to the file path."""
95-
# write the text content to the file path
101+
96102
with open(
97103
file_path,
98-
'w',
104+
OPEN_MODES[hot_start],
99105
encoding='ascii',
100106
) as file:
101107
file.write(text_content)
@@ -117,7 +123,7 @@ def _write_ascii_file(
117123
f'Something went wrong - File {file_path} is not a valid ASCII file.',
118124
)
119125

120-
@ staticmethod
126+
@staticmethod
121127
def _write_precip_coords(
122128
easting: np.ndarray,
123129
northing: np.ndarray,
@@ -254,6 +260,7 @@ def make_gssha_precipitation_input(
254260
file_dir: Optional[Union[str, Path]] = None,
255261
file_name: Optional[str] = None,
256262
file_suffix: Optional[str] = None,
263+
hot_start: Optional[bool] = False,
257264
) -> Path:
258265
"""Creates a GSSHA precipitation input file from an xarray dataset.
259266
@@ -269,6 +276,8 @@ def make_gssha_precipitation_input(
269276
file_dir: The directory to save the file to.
270277
file_name: The name of the file to save.
271278
file_suffix: The file suffix to use.
279+
hot_start: If true data is appended to the end of the file.
280+
Otherwise, the file is overwritten.
272281
273282
Returns:
274283
The path of the output precipitation ASCII input file.
@@ -356,11 +365,12 @@ def make_gssha_precipitation_input(
356365
cls._write_ascii_file(
357366
text_content=ascii_text,
358367
file_path=file_path,
368+
hot_start=hot_start,
359369
)
360370
logging.info(f'Precipitation ASCII file saved @ {file_path}.')
361371
return file_path
362372

363-
@ classmethod
373+
@classmethod
364374
def make_gssha_grass_ascii(
365375
cls,
366376
xarray_dataset: xr.Dataset,
@@ -477,7 +487,7 @@ def make_gssha_grass_ascii(
477487
)
478488
return file_paths
479489

480-
@ classmethod
490+
@classmethod
481491
def make_gssha_hmet_wes(
482492
cls,
483493
xarray_dataset: xr.Dataset,
@@ -487,6 +497,7 @@ def make_gssha_hmet_wes(
487497
file_dir: Optional[Union[str, Path]] = None,
488498
file_name: Optional[str] = None,
489499
file_suffix: Optional[str] = None,
500+
hot_start: Optional[bool] = False,
490501
how: Optional[HMETAggregationFunctions] = None,
491502
xy_coords: Optional[Tuple[str, str]] = None,
492503
) -> Path:
@@ -505,6 +516,8 @@ def make_gssha_hmet_wes(
505516
NOTE: The file name is automatically generated.
506517
file_name: The name of the file to save.
507518
file_suffix: The file suffix to use.
519+
hot_start: If true data is appended to the end of the file.
520+
Otherwise, the file is overwritten.
508521
how: The method to use to aggregate the data at each time step.
509522
Options include: 'mean', 'median', 'min', 'max', 'sum'.
510523
xy_coords: The x and y coordinate names to use for aggregation.
@@ -586,12 +599,13 @@ def make_gssha_hmet_wes(
586599
cls._write_ascii_file(
587600
text_content=data_str,
588601
file_path=file_path,
602+
hot_start=hot_start,
589603
)
590604

591605
logging.info(f'HMET WES ASCII file saved @ {file_path}.')
592606
return file_path
593607

594-
@ classmethod
608+
@classmethod
595609
def get_conversion_functions(
596610
cls,
597611
) -> Dict[str, DataConverterBase.ConversionFunctionType]:

testing/test_5_gssha.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@ def test_dataset(test_dir) -> xr.Dataset:
2424
return ds
2525

2626

27+
def count_lines(filename: Path):
28+
with open(filename, 'r') as file:
29+
lines = file.readlines()
30+
return len(lines)
31+
32+
2733
def test_factory() -> None:
2834
"""Make sure the function was correctly registered."""
2935
assert 'ConvertToGSSHA' in xda.DataConversionFactory.get_converter_classes().keys()
@@ -43,11 +49,24 @@ def test_precipitation_input(test_dataset) -> None:
4349
out_path = xda.DataConversionFunctions.make_gssha_precipitation_input(
4450
test_dataset,
4551
precipitation_variable='2m_temperature',
52+
precipitation_type='GAGE',
4653
output_epsg=26915,
4754
)
4855

4956
assert out_path.exists()
5057
assert out_path.suffix == '.gag'
58+
l1 = count_lines(out_path)
59+
60+
# test the hot start
61+
out_path = xda.DataConversionFunctions.make_gssha_precipitation_input(
62+
test_dataset,
63+
precipitation_variable='2m_temperature',
64+
precipitation_type='GAGE',
65+
output_epsg=26915,
66+
hot_start=True,
67+
)
68+
l2 = count_lines(out_path)
69+
assert l1 < l2
5170
out_path.unlink()
5271

5372

0 commit comments

Comments
 (0)