Commit 732bc3c

d-w-moore authored and alanking committed

[#574] rename progress_bar to updatables, allow for genericity

Updatable objects are either bound update functions taking a number of bytes in a transfer, or progress-bar objects. In the latter case, the object's type must already be registered. (See the tests: use of the progressbar and tqdm modules is demonstrated.)

1 parent ca9ba6d   commit 732bc3c

4 files changed

Lines changed: 319 additions & 36 deletions

README.md

Lines changed: 42 additions & 0 deletions
@@ -232,6 +232,48 @@ will spawn a number of threads in order to optimize performance for
 iRODS server versions 4.2.9+ and file sizes larger than a default
 threshold value of 32 Megabytes.
 
+Progress bars
+-------------
+
+The PRC now has support for progress bars, which function on the basis of
+an "update" callback function. In the case of a tqdm progress bar (see https://github.com/tqdm/tqdm), you can simply
+pass the update method of the progress bar instance directly to the data_object
+`put` or `get` method:
+
+```python
+pbar = tqdm.tqdm(total = data_obj.size)
+session.data_objects.get(file_name, data_obj.path, updatables = pbar.update)
+```
+
+The updatables parameter can also be a list or tuple of update-enabling objects and/or bound methods.
+
+Alternatively, the tqdm progress bar object itself can be passed in, if an adapting
+function such as the following is first registered:
+
+```python
+def adapt_tqdm(pbar, l = threading.Lock()):
+    def _update(n):
+        with l:
+            pbar.update(n)
+    return _update
+irods.manager.data_object_manager.register_update_type( tqdm.tqdm, adapt_tqdm )
+session.data_objects.put( file, logical_path, updatables = [tqdm_1, tqdm_2] )  # update two tqdm's simultaneously
+```
+
+Other progress bars may be included in an updatables parameter, but may require more extensive adaptation.
+For example, the ProgressBar object (from the progressbar module) also has an update method, but that one
+takes an up-to-date cumulative byte count, rather than the size of an individual transfer in bytes,
+as its sole parameter. There can be other complications: e.g., a ProgressBar instance does not allow a weak
+reference to itself to be formed, which interferes with the Python iRODS Client's internal scheme of accounting
+for progress-bar instances "still in progress" while also preventing resource leaks.
+
+In such cases, it is probably best to implement a wrapper for the progress
+bar in question and submit the wrapper instance as the updatables parameter. Whether
+a wrapper or the progress-bar object itself is employed, it is recommended that the user take steps to
+ensure the lifetime of the updatable instance extends beyond the time needed for the transfer to complete.
+
+See `irods/test/data_obj_test.py` for examples of these and other subtleties of progress-bar usage.
+
 Working with collections
 ------------------------

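As a sketch of the wrapper approach the README recommends for bars like `progressbar.ProgressBar`: the class and names below are hypothetical (not part of the PRC API). The wrapper converts each per-chunk byte delta into the cumulative count that such a bar's `update` expects, serializes updates with a lock (transfers may be multithreaded), holds a strong reference to the bar, and is itself weak-referenceable.

```python
import threading

class CumulativeUpdateWrapper:
    """Hypothetical adapter: turns per-transfer byte deltas into the
    cumulative byte count expected by e.g. progressbar.ProgressBar.update."""

    def __init__(self, bar):
        self.bar = bar                  # strong reference keeps the bar alive
        self._total = 0
        self._lock = threading.Lock()   # updates may arrive from several threads

    def __call__(self, nbytes):
        # Accumulate the per-chunk delta, then report the running total.
        with self._lock:
            self._total += nbytes
            self.bar.update(self._total)
```

Because the wrapper is itself callable, it can be passed directly, with no type registration needed: `session.data_objects.put(local_file, logical_path, updatables = CumulativeUpdateWrapper(pbar))`. Unlike a bare ProgressBar instance, the wrapper also supports the weak references the client uses internally.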
irods/manager/data_object_manager.py

Lines changed: 85 additions & 20 deletions
@@ -1,6 +1,12 @@
 from __future__ import absolute_import
-import os
+import ast
+import collections
 import io
+import json
+import logging
+import os
+import six
+import weakref
 from irods.models import DataObject, Collection
 from irods.manager import Manager
 from irods.manager._internal import _api_impl, _logical_path
@@ -17,11 +23,72 @@
 import irods.keywords as kw
 import irods.parallel as parallel
 from irods.parallel import deferred_call
-import six
-import ast
-import json
-import logging
 
+logger = logging.getLogger(__name__)
+
+_update_types = []
+_update_functions = weakref.WeakKeyDictionary()
+
+def register_update_instance(object_, updater):
+    _update_functions[object_] = updater
+
+def register_update_type(type_, factory_):
+    """
+    Create an entry corresponding to a type_ of instance to be allowed among updatables, with processing
+    based on the factory_ callable.
+
+    Parameters:
+        type_ : a type of instance to be allowed in the updatables parameter.
+        factory_ : a function accepting the instance passed in, and yielding an update callable.
+                   If None, then remove the type from the list.
+    """
+    # Delete if already present in list
+    z = tuple(zip(*_update_types))
+    if z and type_ in z[0]:
+        _update_types.pop(z[0].index(type_))
+    # Rewrite the list
+    #   - with the new item introduced at the start of the list but otherwise in the same order, and
+    #   - preserving only pairs that do not contain 'None' as the second member.
+    _update_types[:] = list((k,v) for k,v in collections.OrderedDict([(type_,factory_)] + _update_types).items() if v is not None)
+
+
+def unregister_update_type(type_):
+    """
+    Remove type_ from the list of recognized updatable types maintained by the PRC.
+    """
+    register_update_type(type_, None)
+
+
+def do_progress_updates(updatables, n, logging_function = logger.warning):
+    """
+    Used internally by the Python iRODS Client's data transfer routines (put, get) to iterate through the updatables to be processed.
+    This, in turn, should cause the underlying corresponding progress bars or indicators to be updated.
+    """
+    if not isinstance(updatables, (list,tuple)):
+        updatables = [updatables]
+
+    for object_ in updatables:
+        # If an updatable is directly callable, we set that up to be called without further ado.
+        if callable(object_):
+            update_func = object_
+        else:
+            # If not, we search for a registered type that matches object_ and register (or look up,
+            # if previously registered) a factory-produced updater for that instance.
+            # Examine the unit tests for issue #574 in data_obj_test.py for factory examples.
+            update_func = _update_functions.get(object_)
+            if not update_func:
+                # search based on type
+                for class_,factory_ in _update_types:
+                    if isinstance(object_,class_):
+                        update_func = factory_(object_)
+                        register_update_instance(object_, update_func)
+                        break
+                else:
+                    logging_function("Could not derive an update function for: %r",object_)
+                    continue
+
+        # Do the update.
+        if update_func: update_func(n)
 
 
 def call___del__if_exists(super_):
@@ -124,7 +191,7 @@ def should_parallelize_transfer( self,
         open_options[kw.DATA_SIZE_KW] = size
 
 
-    def _download(self, obj, local_path, num_threads, progress_bar, **options):
+    def _download(self, obj, local_path, num_threads, updatables = (), **options):
         """Transfer the contents of a data object to a local file.
 
         Called from get() when a local path is named.
@@ -146,16 +213,15 @@ def _download(self, obj, local_path, num_threads, progress_bar, **options):
             if not self.parallel_get( (obj,o), local_file, num_threads = num_threads,
                                       target_resource_name = options.get(kw.RESC_NAME_KW,''),
                                       data_open_returned_values = data_open_returned_values_,
-                                      progress_bar=progress_bar):
+                                      updatables = updatables):
                 raise RuntimeError("parallel get failed")
         else:
             for chunk in chunks(o, self.READ_BUFFER_SIZE):
                 f.write(chunk)
-                if progress_bar is not None:
-                    progress_bar.update(len(chunk))
+                do_progress_updates(updatables, len(chunk))
 
 
-    def get(self, path, local_path = None, num_threads = DEFAULT_NUMBER_OF_THREADS, progress_bar = None, **options):
+    def get(self, path, local_path = None, num_threads = DEFAULT_NUMBER_OF_THREADS, updatables = (), **options):
         """
         Get a reference to the data object at the specified `path'.
 
@@ -166,7 +232,7 @@ def get(self, path, local_path = None, num_threads = DEFAULT_NUMBER_OF_THREADS,
 
         # TODO: optimize
         if local_path:
-            self._download(path, local_path, num_threads = num_threads, progress_bar=progress_bar, **options)
+            self._download(path, local_path, num_threads = num_threads, updatables = updatables, **options)
 
         query = self.sess.query(DataObject)\
             .filter(DataObject.name == irods_basename(path))\
@@ -183,7 +249,7 @@ def get(self, path, local_path = None, num_threads = DEFAULT_NUMBER_OF_THREADS,
         return iRODSDataObject(self, parent, results)
 
 
-    def put(self, local_path, irods_path, return_data_object = False, num_threads = DEFAULT_NUMBER_OF_THREADS, progress_bar = None, **options):
+    def put(self, local_path, irods_path, return_data_object = False, num_threads = DEFAULT_NUMBER_OF_THREADS, updatables = (), **options):
 
         if self.sess.collections.exists(irods_path):
             obj = iRODSCollection.normalize_path(irods_path, os.path.basename(local_path))
@@ -198,7 +264,7 @@ def put(self, local_path, irods_path, return_data_object = False, num_threads =
             if not self.parallel_put( local_path, (obj,o), total_bytes = sizelist[0], num_threads = num_threads,
                                       target_resource_name = options.get(kw.RESC_NAME_KW,'') or
                                                              options.get(kw.DEST_RESC_NAME_KW,''),
-                                      open_options = options, progress_bar = progress_bar):
+                                      open_options = options, updatables = updatables):
                 raise RuntimeError("parallel put failed")
         else:
             with self.open(obj, 'w', **options) as o:
@@ -207,8 +273,7 @@ def put(self, local_path, irods_path, return_data_object = False, num_threads =
                 options[kw.OPR_TYPE_KW] = 1 # PUT_OPR
                 for chunk in chunks(f, self.WRITE_BUFFER_SIZE):
                     o.write(chunk)
-                    if progress_bar is not None:
-                        progress_bar.update(len(chunk))
+                    do_progress_updates(updatables, len(chunk))
         if kw.ALL_KW in options:
             repl_options = options.copy()
             repl_options[kw.UPDATE_REPL_KW] = ''
@@ -265,7 +330,7 @@ def parallel_get(self,
                      target_resource_name = '',
                      data_open_returned_values = None,
                      progressQueue = False,
-                     progress_bar = None):
+                     updatables = ()):
         """Call into the irods.parallel library for multi-1247 GET.
 
         Called from a session.data_objects.get(...) (via the _download method) on
@@ -277,7 +342,7 @@ def parallel_get(self,
                                 num_threads = num_threads, target_resource_name = target_resource_name,
                                 data_open_returned_values = data_open_returned_values,
                                 queueLength = (DEFAULT_QUEUE_DEPTH if progressQueue else 0),
-                                progress_bar = progress_bar)
+                                updatables = updatables)
 
     def parallel_put(self,
                      file_ ,
@@ -287,7 +352,7 @@ def parallel_put(self,
                      num_threads = 0,
                      target_resource_name = '',
                      open_options = {},
-                     progress_bar = None,
+                     updatables = (),
                      progressQueue = False):
         """Call into the irods.parallel library for multi-1247 PUT.
 
@@ -298,8 +363,8 @@ def parallel_put(self,
         return parallel.io_main( self.sess, data_or_path_, parallel.Oper.PUT | (parallel.Oper.NONBLOCKING if async_ else 0), file_,
                                  num_threads = num_threads, total_bytes = total_bytes, target_resource_name = target_resource_name,
                                  open_options = open_options,
-                                 progress_bar = progress_bar,
-                                 queueLength = (DEFAULT_QUEUE_DEPTH if progressQueue else 0)
+                                 queueLength = (DEFAULT_QUEUE_DEPTH if progressQueue else 0),
+                                 updatables = updatables,
                                  )
 
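To see how the new dispatch behaves, here is a minimal standalone sketch of the registry logic added in this file. It is simplified: it omits the OrderedDict de-duplication, unregistration, and the miss-logging of the real `register_update_type`/`do_progress_updates`, but shows the same three paths (direct callable, cached per-instance updater, factory lookup by type).

```python
import weakref

_update_types = []                                # list of (type, factory) pairs
_update_functions = weakref.WeakKeyDictionary()   # per-instance updater cache

def register_update_type(type_, factory_):
    # Newest registrations take precedence, as in the PRC version.
    _update_types.insert(0, (type_, factory_))

def do_progress_updates(updatables, n):
    if not isinstance(updatables, (list, tuple)):
        updatables = [updatables]
    for obj in updatables:
        # Path 1: a callable (e.g. a bound update method) is used directly.
        # Path 2: a previously seen instance reuses its cached updater.
        func = obj if callable(obj) else _update_functions.get(obj)
        if func is None:
            # Path 3: match by registered type and cache the factory's product.
            for class_, factory_ in _update_types:
                if isinstance(obj, class_):
                    func = _update_functions[obj] = factory_(obj)
                    break
        if func:
            func(n)
```

The WeakKeyDictionary is what makes the "no weak reference to a ProgressBar" complication (noted in the README diff) matter: an instance that cannot be weakly referenced cannot be cached here.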
irods/parallel.py

Lines changed: 23 additions & 15 deletions
@@ -223,13 +223,14 @@ def _io_send_bytes_progress (queueObject, item):
 
 COPY_BUF_SIZE = (1024 ** 2) * 4
 
-def _copy_part( src, dst, length, queueObject, debug_info, mgr, progress_bar):
+def _copy_part( src, dst, length, queueObject, debug_info, mgr, updatables = ()):
     """
     The work-horse for performing the copy between file and data object.
 
     It also helps determine whether there has been a large enough increment of
     bytes to inform the progress bar of a need to update.
     """
+    from irods.manager.data_object_manager import do_progress_updates
     bytecount = 0
     accum = 0
     while True and bytecount < length:
@@ -240,8 +241,7 @@ def _copy_part( src, dst, length, queueObject, debug_info, mgr, progress_bar):
         bytecount += buf_len
         accum += buf_len
         if queueObject and accum and _io_send_bytes_progress(queueObject,accum): accum = 0
-        if progress_bar is not None:
-            progress_bar.update(buf_len)
+        do_progress_updates(updatables, buf_len)
         if verboseConnection:
             print ("("+debug_info+")",end='',file=sys.stderr)
             sys.stderr.flush()
@@ -303,7 +303,7 @@ def finalize(self):
         self.initial_io.close()
 
 
-def _io_part (objHandle, range_, file_, opr_, mgr_, thread_debug_id = '', queueObject = None, progress_bar = None):
+def _io_part (objHandle, range_, file_, opr_, mgr_, thread_debug_id = '', queueObject = None, updatables = None):
     """
     Runs in a separate thread to manage the transfer of a range of bytes within the data object.
 
@@ -317,12 +317,12 @@ def _io_part (objHandle, range_, file_, opr_, mgr_, thread_debug_id = '', queueO
     file_.seek(offset)
     if thread_debug_id == '': # for more succinct thread identifiers while debugging.
         thread_debug_id = str(threading.currentThread().ident)
-    return ( _copy_part (file_, objHandle, length, queueObject, thread_debug_id, mgr_, progress_bar) if Operation.isPut()
-             else _copy_part (objHandle, file_, length, queueObject, thread_debug_id, mgr_, progress_bar) )
+    return ( _copy_part (file_, objHandle, length, queueObject, thread_debug_id, mgr_, updatables) if Operation.isPut()
+             else _copy_part (objHandle, file_, length, queueObject, thread_debug_id, mgr_, updatables) )
 
 
 def _io_multipart_threaded(operation_ , dataObj_and_IO, replica_token, hier_str, session, fname,
-                           total_size, num_threads, progress_bar, **extra_options):
+                           total_size, num_threads, **extra_options):
     """Called by _io_main.
 
     Carve up (0,total_size) range into `num_threads` parts and initiate a transfer thread for each one.
@@ -344,9 +344,9 @@ def bytes_range_for_thread( i, num_threads, total_bytes, chunk ):
 
     logger.info(u"num_threads = %s ; bytes_per_thread = %s", num_threads, bytes_per_thread)
 
-    _queueLength = extra_options.get('_queueLength',0)
-    if _queueLength > 0:
-        queueObject = Queue(_queueLength)
+    queueLength = extra_options.get('queueLength',0)
+    if queueLength > 0:
+        queueObject = Queue(queueLength)
     else:
         queueObject = None
 
@@ -357,6 +357,11 @@ def bytes_range_for_thread( i, num_threads, total_bytes, chunk ):
     counter = 1
     gen_file_handle = lambda: open(fname, Operation.disk_file_mode(initial_open = (counter == 1)))
     File = gen_file_handle()
+
+    thread_opts = { 'updatables' : extra_options.get('updatables',()),
+                    'queueObject' : queueObject
+                  }
+
     for byte_range in ranges:
         if Io is None:
             Io = session.data_objects.open( Data_object.path, Operation.data_object_mode(initial_open = False),
@@ -368,12 +373,14 @@ def bytes_range_for_thread( i, num_threads, total_bytes, chunk ):
             mgr.add_io( Io )
             logger.debug(u'target_host = %s', Io.raw.session.pool.account.host)
             if File is None: File = gen_file_handle()
-            futures.append(executor.submit( _io_part, Io, byte_range, File, Operation, mgr, str(counter), queueObject, progress_bar))
+            futures.append(executor.submit(_io_part, Io, byte_range, File, Operation, mgr,
+                                           thread_debug_id = str(counter),
+                                           **thread_opts))
             counter += 1
             Io = File = None
 
     if Operation.isNonBlocking():
-        if _queueLength:
+        if queueLength:
             return futures, queueObject, mgr
         else:
             return futures
@@ -383,7 +390,7 @@ def bytes_range_for_thread( i, num_threads, total_bytes, chunk ):
 
 
 
-def io_main( session, Data, opr_, fname, R='', progress_bar = None, **kwopt):
+def io_main( session, Data, opr_, fname, R='', **kwopt):
     """
     The entry point for parallel transfers (multithreaded PUT and GET operations).
 
@@ -467,10 +474,11 @@ def io_main( session, Data, opr_, fname, R='', progress_bar = None, **kwopt):
     (replica_token , resc_hier) = rawfile.replica_access_info()
 
     queueLength = kwopt.get('queueLength',0)
+
+    pass_thru_options = ('updatables','queueLength')
     retval = _io_multipart_threaded (Operation, (Data, Io), replica_token, resc_hier, session, fname, total_bytes,
                                      num_threads = num_threads,
-                                     _queueLength = queueLength,
-                                     progress_bar = progress_bar)
+                                     **{k:v for k,v in kwopt.items() if k in pass_thru_options})
 
     # SessionObject.data_objects.parallel_{put,get} will return:
     #    - immediately with an AsyncNotify instance, if Oper.NONBLOCKING flag is used.

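The `_io_multipart_threaded` docstring above describes carving the `(0, total_size)` range into `num_threads` parts, one per transfer thread. A simplified illustration of that partitioning follows; the PRC's actual `bytes_range_for_thread` may additionally align range boundaries to a chunk size, which this sketch omits.

```python
def byte_ranges(total_bytes, num_threads):
    """Split (0, total_bytes) into num_threads contiguous (start, end) ranges.
    The final range absorbs any remainder left by integer division."""
    per_thread = total_bytes // num_threads
    return [(i * per_thread,
             total_bytes if i == num_threads - 1 else (i + 1) * per_thread)
            for i in range(num_threads)]
```

Each resulting `(start, end)` pair corresponds to the `range_` handed to one `_io_part` worker, which seeks to `start` in both the local file and the data object before copying `end - start` bytes.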