11#! /usr/bin/env python
22"""Basic Model Interface implementation for the 2D heat model."""
33
4+ from typing import Any
5+
46import numpy as np
57from bmipy import Bmi
8+ from numpy .typing import NDArray
69
710from .heat import Heat
811
@@ -15,20 +18,21 @@ class BmiHeat(Bmi):
1518 _input_var_names = ("plate_surface__temperature" ,)
1619 _output_var_names = ("plate_surface__temperature" ,)
1720
18- def __init__ (self ):
21+ def __init__ (self ) -> None :
1922 """Create a BmiHeat model that is ready for initialization."""
20- self ._model = None
21- self ._values = {}
22- self ._var_units = {}
23- self ._var_loc = {}
24- self ._grids = {}
25- self ._grid_type = {}
23+ # self._model: Heat | None = None
24+ self ._model : Heat
25+ self ._values : dict [str , NDArray [Any ]] = {}
26+ self ._var_units : dict [str , str ] = {}
27+ self ._var_loc : dict [str , str ] = {}
28+ self ._grids : dict [int , list [str ]] = {}
29+ self ._grid_type : dict [int , str ] = {}
2630
2731 self ._start_time = 0.0
28- self ._end_time = np .finfo ("d" ).max
32+ self ._end_time = float ( np .finfo ("d" ).max )
2933 self ._time_units = "s"
3034
31- def initialize (self , filename = None ):
35+ def initialize (self , filename : str | None = None ) -> None :
3236 """Initialize the Heat model.
3337
3438 Parameters
@@ -40,7 +44,7 @@ def initialize(self, filename=None):
4044 self ._model = Heat ()
4145 elif isinstance (filename , str ):
4246 with open (filename ) as file_obj :
43- self ._model = Heat .from_file_like (file_obj . read () )
47+ self ._model = Heat .from_file_like (file_obj )
4448 else :
4549 self ._model = Heat .from_file_like (filename )
4650
@@ -50,11 +54,11 @@ def initialize(self, filename=None):
5054 self ._grids = {0 : ["plate_surface__temperature" ]}
5155 self ._grid_type = {0 : "uniform_rectilinear" }
5256
53- def update (self ):
57+ def update (self ) -> None :
5458 """Advance model by one time step."""
5559 self ._model .advance_in_time ()
5660
57- def update_frac (self , time_frac ) :
61+ def update_frac (self , time_frac : float ) -> None :
5862 """Update model by a fraction of a time step.
5963
6064 Parameters
@@ -67,7 +71,7 @@ def update_frac(self, time_frac):
6771 self .update ()
6872 self ._model .time_step = time_step
6973
70- def update_until (self , then ) :
74+ def update_until (self , then : float ) -> None :
7175 """Update model until a particular time.
7276
7377 Parameters
@@ -81,11 +85,12 @@ def update_until(self, then):
8185 self .update ()
8286 self .update_frac (n_steps - int (n_steps ))
8387
84- def finalize (self ):
88+ def finalize (self ) -> None :
8589 """Finalize model."""
86- self ._model = None
90+ del self ._model
91+ # self._model = None
8792
88- def get_var_type (self , var_name ) :
93+ def get_var_type (self , var_name : str ) -> str :
8994 """Data type of variable.
9095
9196 Parameters
@@ -100,7 +105,7 @@ def get_var_type(self, var_name):
100105 """
101106 return str (self .get_value_ptr (var_name ).dtype )
102107
103- def get_var_units (self , var_name ) :
108+ def get_var_units (self , var_name : str ) -> str :
104109 """Get units of variable.
105110
106111 Parameters
@@ -115,7 +120,7 @@ def get_var_units(self, var_name):
115120 """
116121 return self ._var_units [var_name ]
117122
118- def get_var_nbytes (self , var_name ) :
123+ def get_var_nbytes (self , var_name : str ) -> int :
119124 """Get units of variable.
120125
121126 Parameters
@@ -130,13 +135,13 @@ def get_var_nbytes(self, var_name):
130135 """
131136 return self .get_value_ptr (var_name ).nbytes
132137
133- def get_var_itemsize (self , name ) :
138+ def get_var_itemsize (self , name : str ) -> int :
134139 return np .dtype (self .get_var_type (name )).itemsize
135140
136- def get_var_location (self , name ) :
141+ def get_var_location (self , name : str ) -> str :
137142 return self ._var_loc [name ]
138143
139- def get_var_grid (self , var_name ) :
144+ def get_var_grid (self , var_name : str ) -> int | None :
140145 """Grid id for a variable.
141146
142147 Parameters
@@ -152,8 +157,9 @@ def get_var_grid(self, var_name):
152157 for grid_id , var_name_list in self ._grids .items ():
153158 if var_name in var_name_list :
154159 return grid_id
160+ return None
155161
156- def get_grid_rank (self , grid_id ) :
162+ def get_grid_rank (self , grid_id : int ) -> int :
157163 """Rank of grid.
158164
159165 Parameters
@@ -168,7 +174,7 @@ def get_grid_rank(self, grid_id):
168174 """
169175 return len (self ._model .shape )
170176
171- def get_grid_size (self , grid_id ) :
177+ def get_grid_size (self , grid_id : int ) -> int :
172178 """Size of grid.
173179
174180 Parameters
@@ -183,7 +189,7 @@ def get_grid_size(self, grid_id):
183189 """
184190 return int (np .prod (self ._model .shape ))
185191
186- def get_value_ptr (self , var_name ) :
192+ def get_value_ptr (self , var_name : str ) -> NDArray [ Any ] :
187193 """Reference to values.
188194
189195 Parameters
@@ -198,7 +204,7 @@ def get_value_ptr(self, var_name):
198204 """
199205 return self ._values [var_name ]
200206
201- def get_value (self , var_name , dest ) :
207+ def get_value (self , var_name : str , dest : NDArray [ Any ]) -> NDArray [ Any ] :
202208 """Copy of values.
203209
204210 Parameters
@@ -216,7 +222,9 @@ def get_value(self, var_name, dest):
216222 dest [:] = self .get_value_ptr (var_name ).flatten ()
217223 return dest
218224
219- def get_value_at_indices (self , var_name , dest , indices ):
225+ def get_value_at_indices (
226+ self , var_name : str , dest : NDArray [Any ], indices : NDArray [np .int_ ]
227+ ) -> NDArray [Any ]:
220228 """Get values at particular indices.
221229
222230 Parameters
@@ -236,7 +244,7 @@ def get_value_at_indices(self, var_name, dest, indices):
236244 dest [:] = self .get_value_ptr (var_name ).take (indices )
237245 return dest
238246
239- def set_value (self , var_name , src ) :
247+ def set_value (self , var_name : str , src : NDArray [ Any ]) -> None :
240248 """Set model values.
241249
242250 Parameters
@@ -249,7 +257,9 @@ def set_value(self, var_name, src):
249257 val = self .get_value_ptr (var_name )
250258 val [:] = src .reshape (val .shape )
251259
252- def set_value_at_indices (self , name , inds , src ):
260+ def set_value_at_indices (
261+ self , name : str , inds : NDArray [np .int_ ], src : NDArray [Any ]
262+ ) -> None :
253263 """Set model values at particular indices.
254264
255265 Parameters
@@ -264,76 +274,80 @@ def set_value_at_indices(self, name, inds, src):
264274 val = self .get_value_ptr (name )
265275 val .flat [inds ] = src
266276
267- def get_component_name (self ):
277+ def get_component_name (self ) -> str :
268278 """Name of the component."""
269279 return self ._name
270280
271- def get_input_item_count (self ):
281+ def get_input_item_count (self ) -> int :
272282 """Get names of input variables."""
273283 return len (self ._input_var_names )
274284
275- def get_output_item_count (self ):
285+ def get_output_item_count (self ) -> int :
276286 """Get names of output variables."""
277287 return len (self ._output_var_names )
278288
279- def get_input_var_names (self ):
289+ def get_input_var_names (self ) -> tuple [ str , ...] :
280290 """Get names of input variables."""
281291 return self ._input_var_names
282292
283- def get_output_var_names (self ):
293+ def get_output_var_names (self ) -> tuple [ str , ...] :
284294 """Get names of output variables."""
285295 return self ._output_var_names
286296
287- def get_grid_shape (self , grid_id , shape ) :
297+ def get_grid_shape (self , grid_id : int , shape : NDArray [ np . int_ ]) -> NDArray [ np . int_ ] :
288298 """Number of rows and columns of uniform rectilinear grid."""
289299 var_name = self ._grids [grid_id ][0 ]
290300 shape [:] = self .get_value_ptr (var_name ).shape
291301 return shape
292302
293- def get_grid_spacing (self , grid_id , spacing ):
303+ def get_grid_spacing (
304+ self , grid_id : int , spacing : NDArray [np .float_ ]
305+ ) -> NDArray [np .float_ ]:
294306 """Spacing of rows and columns of uniform rectilinear grid."""
295307 spacing [:] = self ._model .spacing
296308 return spacing
297309
298- def get_grid_origin (self , grid_id , origin ):
310+ def get_grid_origin (
311+ self , grid_id : int , origin : NDArray [np .float_ ]
312+ ) -> NDArray [np .float_ ]:
299313 """Origin of uniform rectilinear grid."""
300314 origin [:] = self ._model .origin
301315 return origin
302316
303- def get_grid_type (self , grid_id ) :
317+ def get_grid_type (self , grid_id : int ) -> str :
304318 """Type of grid."""
305319 return self ._grid_type [grid_id ]
306320
307- def get_start_time (self ):
321+ def get_start_time (self ) -> float :
308322 """Start time of model."""
309323 return self ._start_time
310324
311- def get_end_time (self ):
325+ def get_end_time (self ) -> float :
312326 """End time of model."""
313327 return self ._end_time
314328
315- def get_current_time (self ):
329+ def get_current_time (self ) -> float :
316330 return self ._model .time
317331
318- def get_time_step (self ):
332+ def get_time_step (self ) -> float :
319333 return self ._model .time_step
320334
321- def get_time_units (self ):
335+ def get_time_units (self ) -> str :
322336 return self ._time_units
323337
324- def get_grid_edge_count (self , grid ) :
338+ def get_grid_edge_count (self , grid : int ) -> int :
325339 raise NotImplementedError ("get_grid_edge_count" )
326340
327- def get_grid_edge_nodes (self , grid , edge_nodes ) :
341+ def get_grid_edge_nodes (self , grid : int , edge_nodes : NDArray [ np . int_ ]) -> None :
328342 raise NotImplementedError ("get_grid_edge_nodes" )
329343
330- def get_grid_face_count (self , grid ) :
344+ def get_grid_face_count (self , grid : int ) -> None :
331345 raise NotImplementedError ("get_grid_face_count" )
332346
333- def get_grid_face_nodes (self , grid , face_nodes ) :
347+ def get_grid_face_nodes (self , grid : int , face_nodes : NDArray [ np . int_ ]) -> None :
334348 raise NotImplementedError ("get_grid_face_nodes" )
335349
336- def get_grid_node_count (self , grid ) :
350+ def get_grid_node_count (self , grid : int ) -> int :
337351 """Number of grid nodes.
338352
339353 Parameters
@@ -348,17 +362,19 @@ def get_grid_node_count(self, grid):
348362 """
349363 return self .get_grid_size (grid )
350364
351- def get_grid_nodes_per_face (self , grid , nodes_per_face ):
365+ def get_grid_nodes_per_face (
366+ self , grid : int , nodes_per_face : NDArray [np .int_ ]
367+ ) -> None :
352368 raise NotImplementedError ("get_grid_nodes_per_face" )
353369
354- def get_grid_face_edges (self , grid , face_edges ) :
370+ def get_grid_face_edges (self , grid : int , face_edges : NDArray [ np . int_ ]) -> None :
355371 raise NotImplementedError ("get_grid_face_edges" )
356372
357- def get_grid_x (self , grid , x ) :
373+ def get_grid_x (self , grid : int , x : NDArray [ np . float_ ]) -> None :
358374 raise NotImplementedError ("get_grid_x" )
359375
360- def get_grid_y (self , grid , y ) :
376+ def get_grid_y (self , grid : int , y : NDArray [ np . float_ ]) -> None :
361377 raise NotImplementedError ("get_grid_y" )
362378
363- def get_grid_z (self , grid , z ) :
379+ def get_grid_z (self , grid : int , z : NDArray [ np . float_ ]) -> None :
364380 raise NotImplementedError ("get_grid_z" )
0 commit comments