11from __future__ import absolute_import
2- import os
2+ import ast
3+ import collections
34import io
5+ import json
6+ import logging
7+ import os
8+ import six
9+ import weakref
410from irods .models import DataObject , Collection
511from irods .manager import Manager
612from irods .manager ._internal import _api_impl , _logical_path
1723import irods .keywords as kw
1824import irods .parallel as parallel
1925from irods .parallel import deferred_call
20- import six
21- import ast
22- import json
23- import logging
2426
27+ logger = logging .getLogger (__name__ )
28+
29+ _update_types = []
30+ _update_functions = weakref .WeakKeyDictionary ()
31+
32+ def register_update_instance (object_ , updater ): # updater
33+ _update_functions [object_ ] = updater
34+
35+ def register_update_type (type_ , factory_ ):
36+ """
37+ Create an entry corresponding to a type_ of instance to be allowed among updatables, with processing
38+ based on the factory_ callable.
39+
40+ Parameters:
41+ type_ : a type of instance to be allowed in the updatables parameter.
42+ factory_ : a function accepting the instance passed in, and yielding an update callable.
43+ If None, then remove the type from the list.
44+ """
45+
46+ # Delete if already present in list
47+ z = tuple (zip (* _update_types ))
48+ if z and type_ in z [0 ]:
49+ _update_types .pop (z [0 ].index (type_ ))
50+ # Rewrite the list
51+ # - with the new item introduced at the start of the list but otherwise in the same order, and
52+ # - preserving only pairs that do not contain 'None' as the second member.
53+ _update_types [:] = list ((k ,v ) for k ,v in collections .OrderedDict ([(type_ ,factory_ )] + _update_types ).items () if v is not None )
54+
55+
56+ def unregister_update_type (type_ ):
57+ """
58+ Remove type_ from the listof recognized updatable types maintained by the PRC.
59+ """
60+ register_update_type (type_ , None )
61+
62+
63+ def do_progress_updates (updatables , n , logging_function = logger .warning ):
64+ """
65+ Used internally by Python iRODS Client's data transfer routines (put, get) to iterate through updatables to be processed.
66+ This, in turn, should cause the underlying corresponding progress bars or indicators to be updated.
67+ """
68+ if not isinstance (updatables , (list ,tuple )):
69+ updatables = [updatables ]
70+
71+ for object_ in updatables :
72+ # If an updatable is directly callable, we set that up to be called without further ado.
73+ if callable (object_ ):
74+ update_func = object_
75+ else :
76+ # If not, we search for a registered type that matches object_ and register (or look up if previously registered) a factory-produced updater for that instance.
77+ # Examine the unit tests for issue #574 in data_obj_test.py for factory examples.
78+ update_func = _update_functions .get (object_ )
79+ if not update_func :
80+ # search based on type
81+ for class_ ,factory_ in _update_types :
82+ if isinstance (object_ ,class_ ):
83+ update_func = factory_ (object_ )
84+ register_update_instance (object_ , update_func )
85+ break
86+ else :
87+ logging_function ("Could not derive an update function for: %r" ,object_ )
88+ continue
89+
90+ # Do the update.
91+ if update_func : update_func (n )
2592
2693
2794def call___del__if_exists (super_ ):
@@ -124,7 +191,7 @@ def should_parallelize_transfer( self,
124191 open_options [kw .DATA_SIZE_KW ] = size
125192
126193
127- def _download (self , obj , local_path , num_threads , progress_bar , ** options ):
194+ def _download (self , obj , local_path , num_threads , updatables = () , ** options ):
128195 """Transfer the contents of a data object to a local file.
129196
130197 Called from get() when a local path is named.
@@ -146,16 +213,15 @@ def _download(self, obj, local_path, num_threads, progress_bar, **options):
146213 if not self .parallel_get ( (obj ,o ), local_file , num_threads = num_threads ,
147214 target_resource_name = options .get (kw .RESC_NAME_KW ,'' ),
148215 data_open_returned_values = data_open_returned_values_ ,
149- progress_bar = progress_bar ):
216+ updatables = updatables ):
150217 raise RuntimeError ("parallel get failed" )
151218 else :
152219 for chunk in chunks (o , self .READ_BUFFER_SIZE ):
153220 f .write (chunk )
154- if progress_bar is not None :
155- progress_bar .update (len (chunk ))
221+ do_progress_updates (updatables , len (chunk ))
156222
157223
158- def get (self , path , local_path = None , num_threads = DEFAULT_NUMBER_OF_THREADS , progress_bar = None , ** options ):
224+ def get (self , path , local_path = None , num_threads = DEFAULT_NUMBER_OF_THREADS , updatables = () , ** options ):
159225 """
160226 Get a reference to the data object at the specified `path'.
161227
@@ -166,7 +232,7 @@ def get(self, path, local_path = None, num_threads = DEFAULT_NUMBER_OF_THREADS,
166232
167233 # TODO: optimize
168234 if local_path :
169- self ._download (path , local_path , num_threads = num_threads , progress_bar = progress_bar , ** options )
235+ self ._download (path , local_path , num_threads = num_threads , updatables = updatables , ** options )
170236
171237 query = self .sess .query (DataObject )\
172238 .filter (DataObject .name == irods_basename (path ))\
@@ -183,7 +249,7 @@ def get(self, path, local_path = None, num_threads = DEFAULT_NUMBER_OF_THREADS,
183249 return iRODSDataObject (self , parent , results )
184250
185251
186- def put (self , local_path , irods_path , return_data_object = False , num_threads = DEFAULT_NUMBER_OF_THREADS , progress_bar = None , ** options ):
252+ def put (self , local_path , irods_path , return_data_object = False , num_threads = DEFAULT_NUMBER_OF_THREADS , updatables = () , ** options ):
187253
188254 if self .sess .collections .exists (irods_path ):
189255 obj = iRODSCollection .normalize_path (irods_path , os .path .basename (local_path ))
@@ -198,7 +264,7 @@ def put(self, local_path, irods_path, return_data_object = False, num_threads =
198264 if not self .parallel_put ( local_path , (obj ,o ), total_bytes = sizelist [0 ], num_threads = num_threads ,
199265 target_resource_name = options .get (kw .RESC_NAME_KW ,'' ) or
200266 options .get (kw .DEST_RESC_NAME_KW ,'' ),
201- open_options = options , progress_bar = progress_bar ):
267+ open_options = options , updatables = updatables ):
202268 raise RuntimeError ("parallel put failed" )
203269 else :
204270 with self .open (obj , 'w' , ** options ) as o :
@@ -207,8 +273,7 @@ def put(self, local_path, irods_path, return_data_object = False, num_threads =
207273 options [kw .OPR_TYPE_KW ] = 1 # PUT_OPR
208274 for chunk in chunks (f , self .WRITE_BUFFER_SIZE ):
209275 o .write (chunk )
210- if progress_bar is not None :
211- progress_bar .update (len (chunk ))
276+ do_progress_updates (updatables , len (chunk ))
212277 if kw .ALL_KW in options :
213278 repl_options = options .copy ()
214279 repl_options [kw .UPDATE_REPL_KW ] = ''
@@ -265,7 +330,7 @@ def parallel_get(self,
265330 target_resource_name = '' ,
266331 data_open_returned_values = None ,
267332 progressQueue = False ,
268- progress_bar = None ):
333+ updatables = () ):
269334 """Call into the irods.parallel library for multi-1247 GET.
270335
271336 Called from a session.data_objects.get(...) (via the _download method) on
@@ -277,7 +342,7 @@ def parallel_get(self,
277342 num_threads = num_threads , target_resource_name = target_resource_name ,
278343 data_open_returned_values = data_open_returned_values ,
279344 queueLength = (DEFAULT_QUEUE_DEPTH if progressQueue else 0 ),
280- progress_bar = progress_bar )
345+ updatables = updatables )
281346
282347 def parallel_put (self ,
283348 file_ ,
@@ -287,7 +352,7 @@ def parallel_put(self,
287352 num_threads = 0 ,
288353 target_resource_name = '' ,
289354 open_options = {},
290- progress_bar = None ,
355+ updatables = () ,
291356 progressQueue = False ):
292357 """Call into the irods.parallel library for multi-1247 PUT.
293358
@@ -298,8 +363,8 @@ def parallel_put(self,
298363 return parallel .io_main ( self .sess , data_or_path_ , parallel .Oper .PUT | (parallel .Oper .NONBLOCKING if async_ else 0 ), file_ ,
299364 num_threads = num_threads , total_bytes = total_bytes , target_resource_name = target_resource_name ,
300365 open_options = open_options ,
301- progress_bar = progress_bar ,
302- queueLength = ( DEFAULT_QUEUE_DEPTH if progressQueue else 0 )
366+ queueLength = ( DEFAULT_QUEUE_DEPTH if progressQueue else 0 ) ,
367+ updatables = updatables ,
303368 )
304369
305370
0 commit comments