Skip to content

Commit 2a007dc

Browse files
committed
add type annotations and test with mypy
1 parent c34c6bf commit 2a007dc

4 files changed

Lines changed: 104 additions & 66 deletions

File tree

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,4 @@ repos:
6161
- id: mypy
6262
language_version: python3.12
6363
additional_dependencies: [types-all]
64-
files: src/.*\.py$
64+
files: heat/.*\.py$

heat/bmi_heat.py

Lines changed: 68 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
#! /usr/bin/env python
22
"""Basic Model Interface implementation for the 2D heat model."""
33

4+
from typing import Any
5+
46
import numpy as np
57
from bmipy import Bmi
8+
from numpy.typing import NDArray
69

710
from .heat import Heat
811

@@ -15,20 +18,21 @@ class BmiHeat(Bmi):
1518
_input_var_names = ("plate_surface__temperature",)
1619
_output_var_names = ("plate_surface__temperature",)
1720

18-
def __init__(self):
21+
def __init__(self) -> None:
1922
"""Create a BmiHeat model that is ready for initialization."""
20-
self._model = None
21-
self._values = {}
22-
self._var_units = {}
23-
self._var_loc = {}
24-
self._grids = {}
25-
self._grid_type = {}
23+
# self._model: Heat | None = None
24+
self._model: Heat
25+
self._values: dict[str, NDArray[Any]] = {}
26+
self._var_units: dict[str, str] = {}
27+
self._var_loc: dict[str, str] = {}
28+
self._grids: dict[int, list[str]] = {}
29+
self._grid_type: dict[int, str] = {}
2630

2731
self._start_time = 0.0
28-
self._end_time = np.finfo("d").max
32+
self._end_time = float(np.finfo("d").max)
2933
self._time_units = "s"
3034

31-
def initialize(self, filename=None):
35+
def initialize(self, filename: str | None = None) -> None:
3236
"""Initialize the Heat model.
3337
3438
Parameters
@@ -40,7 +44,7 @@ def initialize(self, filename=None):
4044
self._model = Heat()
4145
elif isinstance(filename, str):
4246
with open(filename) as file_obj:
43-
self._model = Heat.from_file_like(file_obj.read())
47+
self._model = Heat.from_file_like(file_obj)
4448
else:
4549
self._model = Heat.from_file_like(filename)
4650

@@ -50,11 +54,11 @@ def initialize(self, filename=None):
5054
self._grids = {0: ["plate_surface__temperature"]}
5155
self._grid_type = {0: "uniform_rectilinear"}
5256

53-
def update(self):
57+
def update(self) -> None:
5458
"""Advance model by one time step."""
5559
self._model.advance_in_time()
5660

57-
def update_frac(self, time_frac):
61+
def update_frac(self, time_frac: float) -> None:
5862
"""Update model by a fraction of a time step.
5963
6064
Parameters
@@ -67,7 +71,7 @@ def update_frac(self, time_frac):
6771
self.update()
6872
self._model.time_step = time_step
6973

70-
def update_until(self, then):
74+
def update_until(self, then: float) -> None:
7175
"""Update model until a particular time.
7276
7377
Parameters
@@ -81,11 +85,12 @@ def update_until(self, then):
8185
self.update()
8286
self.update_frac(n_steps - int(n_steps))
8387

84-
def finalize(self):
88+
def finalize(self) -> None:
8589
"""Finalize model."""
86-
self._model = None
90+
del self._model
91+
# self._model = None
8792

88-
def get_var_type(self, var_name):
93+
def get_var_type(self, var_name: str) -> str:
8994
"""Data type of variable.
9095
9196
Parameters
@@ -100,7 +105,7 @@ def get_var_type(self, var_name):
100105
"""
101106
return str(self.get_value_ptr(var_name).dtype)
102107

103-
def get_var_units(self, var_name):
108+
def get_var_units(self, var_name: str) -> str:
104109
"""Get units of variable.
105110
106111
Parameters
@@ -115,7 +120,7 @@ def get_var_units(self, var_name):
115120
"""
116121
return self._var_units[var_name]
117122

118-
def get_var_nbytes(self, var_name):
123+
def get_var_nbytes(self, var_name: str) -> int:
119124
"""Get units of variable.
120125
121126
Parameters
@@ -130,13 +135,13 @@ def get_var_nbytes(self, var_name):
130135
"""
131136
return self.get_value_ptr(var_name).nbytes
132137

133-
def get_var_itemsize(self, name):
138+
def get_var_itemsize(self, name: str) -> int:
134139
return np.dtype(self.get_var_type(name)).itemsize
135140

136-
def get_var_location(self, name):
141+
def get_var_location(self, name: str) -> str:
137142
return self._var_loc[name]
138143

139-
def get_var_grid(self, var_name):
144+
def get_var_grid(self, var_name: str) -> int | None:
140145
"""Grid id for a variable.
141146
142147
Parameters
@@ -152,8 +157,9 @@ def get_var_grid(self, var_name):
152157
for grid_id, var_name_list in self._grids.items():
153158
if var_name in var_name_list:
154159
return grid_id
160+
return None
155161

156-
def get_grid_rank(self, grid_id):
162+
def get_grid_rank(self, grid_id: int) -> int:
157163
"""Rank of grid.
158164
159165
Parameters
@@ -168,7 +174,7 @@ def get_grid_rank(self, grid_id):
168174
"""
169175
return len(self._model.shape)
170176

171-
def get_grid_size(self, grid_id):
177+
def get_grid_size(self, grid_id: int) -> int:
172178
"""Size of grid.
173179
174180
Parameters
@@ -183,7 +189,7 @@ def get_grid_size(self, grid_id):
183189
"""
184190
return int(np.prod(self._model.shape))
185191

186-
def get_value_ptr(self, var_name):
192+
def get_value_ptr(self, var_name: str) -> NDArray[Any]:
187193
"""Reference to values.
188194
189195
Parameters
@@ -198,7 +204,7 @@ def get_value_ptr(self, var_name):
198204
"""
199205
return self._values[var_name]
200206

201-
def get_value(self, var_name, dest):
207+
def get_value(self, var_name: str, dest: NDArray[Any]) -> NDArray[Any]:
202208
"""Copy of values.
203209
204210
Parameters
@@ -216,7 +222,9 @@ def get_value(self, var_name, dest):
216222
dest[:] = self.get_value_ptr(var_name).flatten()
217223
return dest
218224

219-
def get_value_at_indices(self, var_name, dest, indices):
225+
def get_value_at_indices(
226+
self, var_name: str, dest: NDArray[Any], indices: NDArray[np.int_]
227+
) -> NDArray[Any]:
220228
"""Get values at particular indices.
221229
222230
Parameters
@@ -236,7 +244,7 @@ def get_value_at_indices(self, var_name, dest, indices):
236244
dest[:] = self.get_value_ptr(var_name).take(indices)
237245
return dest
238246

239-
def set_value(self, var_name, src):
247+
def set_value(self, var_name: str, src: NDArray[Any]) -> None:
240248
"""Set model values.
241249
242250
Parameters
@@ -249,7 +257,9 @@ def set_value(self, var_name, src):
249257
val = self.get_value_ptr(var_name)
250258
val[:] = src.reshape(val.shape)
251259

252-
def set_value_at_indices(self, name, inds, src):
260+
def set_value_at_indices(
261+
self, name: str, inds: NDArray[np.int_], src: NDArray[Any]
262+
) -> None:
253263
"""Set model values at particular indices.
254264
255265
Parameters
@@ -264,76 +274,80 @@ def set_value_at_indices(self, name, inds, src):
264274
val = self.get_value_ptr(name)
265275
val.flat[inds] = src
266276

267-
def get_component_name(self):
277+
def get_component_name(self) -> str:
268278
"""Name of the component."""
269279
return self._name
270280

271-
def get_input_item_count(self):
281+
def get_input_item_count(self) -> int:
272282
"""Get names of input variables."""
273283
return len(self._input_var_names)
274284

275-
def get_output_item_count(self):
285+
def get_output_item_count(self) -> int:
276286
"""Get names of output variables."""
277287
return len(self._output_var_names)
278288

279-
def get_input_var_names(self):
289+
def get_input_var_names(self) -> tuple[str, ...]:
280290
"""Get names of input variables."""
281291
return self._input_var_names
282292

283-
def get_output_var_names(self):
293+
def get_output_var_names(self) -> tuple[str, ...]:
284294
"""Get names of output variables."""
285295
return self._output_var_names
286296

287-
def get_grid_shape(self, grid_id, shape):
297+
def get_grid_shape(self, grid_id: int, shape: NDArray[np.int_]) -> NDArray[np.int_]:
288298
"""Number of rows and columns of uniform rectilinear grid."""
289299
var_name = self._grids[grid_id][0]
290300
shape[:] = self.get_value_ptr(var_name).shape
291301
return shape
292302

293-
def get_grid_spacing(self, grid_id, spacing):
303+
def get_grid_spacing(
304+
self, grid_id: int, spacing: NDArray[np.float_]
305+
) -> NDArray[np.float_]:
294306
"""Spacing of rows and columns of uniform rectilinear grid."""
295307
spacing[:] = self._model.spacing
296308
return spacing
297309

298-
def get_grid_origin(self, grid_id, origin):
310+
def get_grid_origin(
311+
self, grid_id: int, origin: NDArray[np.float_]
312+
) -> NDArray[np.float_]:
299313
"""Origin of uniform rectilinear grid."""
300314
origin[:] = self._model.origin
301315
return origin
302316

303-
def get_grid_type(self, grid_id):
317+
def get_grid_type(self, grid_id: int) -> str:
304318
"""Type of grid."""
305319
return self._grid_type[grid_id]
306320

307-
def get_start_time(self):
321+
def get_start_time(self) -> float:
308322
"""Start time of model."""
309323
return self._start_time
310324

311-
def get_end_time(self):
325+
def get_end_time(self) -> float:
312326
"""End time of model."""
313327
return self._end_time
314328

315-
def get_current_time(self):
329+
def get_current_time(self) -> float:
316330
return self._model.time
317331

318-
def get_time_step(self):
332+
def get_time_step(self) -> float:
319333
return self._model.time_step
320334

321-
def get_time_units(self):
335+
def get_time_units(self) -> str:
322336
return self._time_units
323337

324-
def get_grid_edge_count(self, grid):
338+
def get_grid_edge_count(self, grid: int) -> int:
325339
raise NotImplementedError("get_grid_edge_count")
326340

327-
def get_grid_edge_nodes(self, grid, edge_nodes):
341+
def get_grid_edge_nodes(self, grid: int, edge_nodes: NDArray[np.int_]) -> None:
328342
raise NotImplementedError("get_grid_edge_nodes")
329343

330-
def get_grid_face_count(self, grid):
344+
def get_grid_face_count(self, grid: int) -> None:
331345
raise NotImplementedError("get_grid_face_count")
332346

333-
def get_grid_face_nodes(self, grid, face_nodes):
347+
def get_grid_face_nodes(self, grid: int, face_nodes: NDArray[np.int_]) -> None:
334348
raise NotImplementedError("get_grid_face_nodes")
335349

336-
def get_grid_node_count(self, grid):
350+
def get_grid_node_count(self, grid: int) -> int:
337351
"""Number of grid nodes.
338352
339353
Parameters
@@ -348,17 +362,19 @@ def get_grid_node_count(self, grid):
348362
"""
349363
return self.get_grid_size(grid)
350364

351-
def get_grid_nodes_per_face(self, grid, nodes_per_face):
365+
def get_grid_nodes_per_face(
366+
self, grid: int, nodes_per_face: NDArray[np.int_]
367+
) -> None:
352368
raise NotImplementedError("get_grid_nodes_per_face")
353369

354-
def get_grid_face_edges(self, grid, face_edges):
370+
def get_grid_face_edges(self, grid: int, face_edges: NDArray[np.int_]) -> None:
355371
raise NotImplementedError("get_grid_face_edges")
356372

357-
def get_grid_x(self, grid, x):
373+
def get_grid_x(self, grid: int, x: NDArray[np.float_]) -> None:
358374
raise NotImplementedError("get_grid_x")
359375

360-
def get_grid_y(self, grid, y):
376+
def get_grid_y(self, grid: int, y: NDArray[np.float_]) -> None:
361377
raise NotImplementedError("get_grid_y")
362378

363-
def get_grid_z(self, grid, z):
379+
def get_grid_z(self, grid: int, z: NDArray[np.float_]) -> None:
364380
raise NotImplementedError("get_grid_z")

0 commit comments

Comments
 (0)