Skip to content

Commit 4d6bfdd

Browse files
pkdashthodson-usgs
andauthored
[#116] initial implementation of type hints (#117)
* [#116] initial implementation of type hints * Apply suggestions from code review - using Tuple from typing module Co-authored-by: Timothy Hodson <34148978+thodson-usgs@users.noreply.github.com> --------- Co-authored-by: Timothy Hodson <34148978+thodson-usgs@users.noreply.github.com>
1 parent f6d18aa commit 4d6bfdd

2 files changed

Lines changed: 39 additions & 13 deletions

File tree

dataretrieval/nwis.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,15 @@
1212
1313
"""
1414

15+
import re
1516
import warnings
16-
import pandas as pd
1717
from io import StringIO
18-
import re
18+
from typing import List, Optional, Union, Tuple
19+
20+
import pandas as pd
1921

20-
from dataretrieval.utils import to_str, format_datetime, update_merge
2122
from dataretrieval.utils import BaseMetadata
23+
from dataretrieval.utils import format_datetime, to_str, update_merge
2224
from .utils import query
2325

2426
WATERDATA_BASE_URL = 'https://nwis.waterdata.usgs.gov/'
@@ -270,28 +272,31 @@ def _discharge_measurements(ssl_check=True, **kwargs):
270272
return _read_rdb(response.text), NWIS_Metadata(response, **kwargs)
271273

272274

273-
def get_discharge_peaks(sites=None, start=None, end=None,
274-
multi_index=True, ssl_check=True, **kwargs):
275+
def get_discharge_peaks(sites: Optional[Union[List[str], str]] = None,
276+
start: Optional[str] = None, end: Optional[str] = None,
277+
multi_index: bool = True,
278+
ssl_check: bool = True, **kwargs) -> Tuple[pd.DataFrame, BaseMetadata]:
275279
"""
276280
Get discharge peaks from the waterdata service.
277281
278282
Parameters
279283
----------
280-
sites: array of strings
284+
sites: list of strings, string, Optional
281285
If the waterdata parameter site_no is supplied, it will overwrite the
282286
sites parameter
283-
start: string
287+
start: string, Optional
284288
If the waterdata parameter begin_date is supplied, it will overwrite
285289
the start parameter (YYYY-MM-DD)
286-
end: string
290+
end: string, Optional
287291
If the waterdata parameter end_date is supplied, it will overwrite
288292
the end parameter (YYYY-MM-DD)
289-
multi_index: boolean
290-
If False, a dataframe with a single-level index (datetime) is returned
291-
ssl_check: bool
293+
multi_index: boolean, Optional
294+
If False, a dataframe with a single-level index (datetime) is returned,
295+
default is True
296+
ssl_check: boolean, Optional
292297
If True, check SSL certificates, if False, do not check SSL,
293298
default is True
294-
**kwargs: optional
299+
**kwargs: Optional
295300
If supplied, will be used as query parameters
296301
297302
Returns
@@ -314,6 +319,9 @@ def get_discharge_peaks(sites=None, start=None, end=None,
314319
... start='1980-01-01', end='1980-01-02', stateCd='HI')
315320
316321
"""
322+
if sites and not isinstance(sites, str):
323+
assert isinstance(sites, list), "sites must be a string or a list of strings"
324+
317325
start = kwargs.pop('begin_date', start)
318326
end = kwargs.pop('end_date', end)
319327
sites = kwargs.pop('site_no', sites)

tests/waterservices_test.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,6 @@ def test_get_gwlevels(requests_mock):
135135
assert df.size == 16
136136
assert_metadata(requests_mock, request_url, md, site, None, format)
137137

138-
139138
def test_get_discharge_peaks(requests_mock):
140139
"""Tests get_discharge_peaks method correctly generates the request url and returns the result in a DataFrame"""
141140
format = "rdb"
@@ -149,6 +148,25 @@ def test_get_discharge_peaks(requests_mock):
149148
assert df.size == 240
150149
assert_metadata(requests_mock, request_url, md, site, None, format)
151150

151+
@pytest.mark.parametrize("site_input_type_list", [True, False])
152+
def test_get_discharge_peaks_sites_value_types(requests_mock, site_input_type_list):
153+
"""Tests get_discharge_peaks for valid input types of the 'sites' parameter"""
154+
155+
format = "rdb"
156+
site = '01594440'
157+
request_url = 'https://nwis.waterdata.usgs.gov/nwis/peaks?format={}&site_no={}' \
158+
'&begin_date=2000-02-14&end_date=2020-02-15'.format(format, site)
159+
response_file_path = 'data/waterservices_peaks.txt'
160+
mock_request(requests_mock, request_url, response_file_path)
161+
if site_input_type_list:
162+
sites = [site]
163+
else:
164+
sites = site
165+
166+
df, md = get_discharge_peaks(sites=sites, start='2000-02-14', end='2020-02-15')
167+
assert type(df) is DataFrame
168+
assert df.size == 240
169+
152170

153171
def test_get_discharge_measurements(requests_mock):
154172
"""Tests get_discharge_measurements method correctly generates the request url and returns the result in a

0 commit comments

Comments
 (0)