@@ -1318,9 +1318,12 @@ def get_cbuffer_sizes(src: object) -> tuple[(int, int, int)]:
13181318
13191319
13201320# Compute a decent value for chunksize based on L3 and/or heuristics
1321- def get_chunksize (blocksize , l3_minimum = 4 * 2 ** 20 , l3_maximum = 2 ** 26 ):
1322- # Find a decent default when L3 cannot be detected by cpuinfo
1323- # Based mainly in heuristics
1321+ def get_chunksize (blocksize , l3_minimum = 4 * 2 ** 20 , l3_maximum = 2 ** 26 , reduc_factor = 4 ):
1322+ # Find a decent default when L3 cannot be detected by cpuinfo.
1323+ # `reduc_factor` means that the chunk will be divided by this factor
1324+ # 4 stands for 3 operands + 1 result, but some functions (e.g., linalg ones) may
1325+ # decide to use another one (e.g., 1 for matmul has proved to be better).
1326+ # Most of this is based mainly on heuristics and experimentation.
13241327 chunksize = blocksize
13251328 if blocksize * 32 < l3_maximum :
13261329 chunksize = blocksize * 32
@@ -1339,15 +1342,14 @@ def get_chunksize(blocksize, l3_minimum=4 * 2**20, l3_maximum=2**26):
13391342 if isinstance (l2_cache_size , int ) and l3_cache_size > l2_cache_size :
13401343 chunksize = l3_cache_size
13411344 # When computing expressions, it is convenient to keep chunks for all operands
1342- # in L3 cache, so let's divide by 4 (3 operands + result is a typical situation
1343- # for moderately complex expressions)
1344- chunksize //= 4
1345+ # in L3 cache (reduc_factor will account for this).
1346+ chunksize //= reduc_factor
13451347
13461348 # Chunksize should be at least the size of L2
13471349 l2_cache_size = cpu_info .get ("l2_cache_size" , "Not found" )
13481350 if isinstance (l2_cache_size , int ) and l2_cache_size > chunksize :
13491351 # Apple Silicon has a large L2 cache, and memory bandwidth is high,
1350- # so we can use a larger chunksize based on L2 cache size
1352+ # so we can use a larger chunksize based on L2 cache size.
13511353 chunksize = l2_cache_size * 4
13521354
13531355 # Ensure a minimum size
@@ -1577,7 +1579,8 @@ def compute_chunks_blocks( # noqa: C901
15771579 # Finally, the chunks
15781580 if chunks is None :
15791581 blocksize = math .prod (blocks ) * itemsize
1580- chunksize = get_chunksize (blocksize )
1582+ reduc_factor = kwargs .get ("_chunksize_reduc_factor" , 4 )
1583+ chunksize = get_chunksize (blocksize , reduc_factor = reduc_factor )
15811584 # Make chunksize to be a multiple of the blocksize. This allows for:
15821585 # 1. Avoid unnecessary padding in chunks
15831586 # 2. Avoid exceeding the maximum buffer size (see #392)
0 commit comments