@@ -38,7 +38,7 @@ def from_ptype(cls, ptype, multi=False):
3838 return cls (ptype , ptype .extend )
3939 return cls (ptype , ptype .append )
4040 if ptype == str :
41- return cls (ptype , lambda coll , new : '' . join ([ coll , new ]) )
41+ return cls (ptype , lambda coll , new : f" { coll } { new } " )
4242
4343 def __init__ (self , acctype , accmethod ):
4444 """
@@ -1083,3 +1083,244 @@ def get_value(self, start, end=None):
10831083 """
10841084 lookup = start if end is None else (start , end )
10851085 return self .get (lookup , 'index' )
1086+
1087+
1088+ class _SendBatches (object ):
1089+
1090+ def __init__ (self , obj , data , force_unique , target_batch_size ):
1091+ self .obj = obj
1092+ self .data = data
1093+ self .num_items = len (data )
1094+ nbatches = self .num_items / target_batch_size
1095+ self .num_batches = int (nbatches ) + (0 if nbatches .is_integer () else 1 )
1096+ self .encoder = obj .setter .encoder
1097+ self .rtype = obj .setter .get_rtype_from_data (data , force_unique )
1098+ if self .rtype == 'encoded_obj' :
1099+ raise ValueError (
1100+ "cannot batch update an 'encoded_obj' type object -- please "
1101+ "convert to a string first if you really need to stream this "
1102+ "object to Redis"
1103+ )
1104+ self .batch_type = obj .getter .rtypes_to_ptypes [self .rtype ]
1105+ self .target_batch_size = target_batch_size
1106+ self .iterator = getattr (
1107+ self , f'_{ self .rtype } _iterator' , self ._default_iterator
1108+ )
1109+ self .make_batch = getattr (
1110+ self , f'_make_{ self .rtype } _batch' , self ._default_make_batch
1111+ )
1112+
1113+ def _zset_iterator (self , data ):
1114+ return (self .encoder .member (item ) for item in data )
1115+
1116+ def _hash_iterator (self , data ):
1117+ return self .encoder .hash (data )
1118+
1119+ def _default_iterator (self , data ):
1120+ return self .encoder .list (data )
1121+
1122+ def _make_string_batch (self ):
1123+ total_size = len (self .data )
1124+ index = 0
1125+ while index < total_size :
1126+ yield self .data [(index ):(index + self .target_batch_size )]
1127+ index += self .target_batch_size
1128+
1129+ def _default_make_batch (self ):
1130+ batch = []
1131+ for i , item in enumerate (self .iterator (self .data )):
1132+ batch .append (item )
1133+ if (i + 1 ) % self .target_batch_size == 0 :
1134+ yield self .batch_type (batch )
1135+ batch = []
1136+ if batch :
1137+ yield self .batch_type (batch )
1138+
1139+ def __call__ (self ):
1140+ return self .make_batch ()
1141+
1142+
1143+ class _AccumulatorFactory (object ):
1144+ accumulated_types = {
1145+ 'dict_all' : dict ,
1146+ 'str_index' : str ,
1147+ }
1148+
1149+ @staticmethod
1150+ def _flatten (accumulated , new_values ):
1151+ accumulated .extend ([
1152+ item for vals in new_values for item in (vals or [])
1153+ ])
1154+
1155+ @staticmethod
1156+ def dict_all (accumulated , new_values ):
1157+ for keys , vals in zip (* new_values ):
1158+ accumulated .update (dict (zip (keys , vals )))
1159+
1160+ @staticmethod
1161+ def dict_field (accumulated , new_values ):
1162+ accumulated .extend ([
1163+ item for vals in new_values [1 ] for item in (vals or [])
1164+ ])
1165+
1166+ @staticmethod
1167+ def str_index (collected , new_vals ):
1168+ return f"{ collected } { '' .join ([str (v ) if v else '' for v in new_vals ])} "
1169+
1170+ def __call__ (self , ptype , lookup_type ):
1171+ full_lookup_name = f'{ ptype .__name__ } _{ lookup_type } '
1172+ acc_type = self .accumulated_types .get (full_lookup_name , list )
1173+ method = getattr (self , full_lookup_name , self ._flatten )
1174+ return Accumulator (acc_type , method )
1175+
1176+
class RedisObjectStream(object):
    """Stream reads and writes of large Redis objects in batches.

    Wraps a RedisObject-like ``obj``, splitting ``set`` and ``get``
    operations into pipeline commands of roughly ``target_batch_size``
    items. If ``execute_every_nth_batch`` is given, the pipeline is
    executed every that many batches; it is always executed after the
    final batch.
    """

    def __init__(self, obj, target_batch_size, execute_every_nth_batch=None):
        self.obj = obj
        self.target_batch_size = target_batch_size
        self.execute_every = execute_every_nth_batch
        self.pipe = obj.pipe
        self.key = obj.key
        self.batches = None
        self.accumulator_factory = _AccumulatorFactory()

    def _flush_due(self, nbatch, total_batches):
        # Execute the pipeline on the final batch, and on every
        # `execute_every`-th batch when periodic execution is set.
        if nbatch == total_batches:
            return True
        return bool(self.execute_every and nbatch % self.execute_every == 0)

    def _index_batches(self, lookup):
        # Split an inclusive (start, end) index range into inclusive
        # sub-ranges covering at most `target_batch_size` positions.
        return [
            (num, min(num + self.target_batch_size - 1, lookup[1]))
            for num in range(lookup[0], lookup[1] + 1, self.target_batch_size)
        ]

    def _slice_batches(self, seq):
        # Split a sequence into consecutive chunks of at most
        # `target_batch_size` items.
        return [
            seq[num:num + self.target_batch_size]
            for num in range(0, len(seq), self.target_batch_size)
        ]

    def set(self, data, force_unique=None, update=False, index=None):
        """Send `data` to Redis in batches.

        Returns the accumulated pipeline ``execute()`` results. On
        error the pipeline is reset and the exception re-raised.
        """
        prev_defer = self.obj.defer
        prev_bypass = self.obj.setter.bypass_encoding
        self.obj.defer = True
        self.obj.setter.bypass_encoding = True
        self.batches = _SendBatches(
            self.obj, data, force_unique, self.target_batch_size
        )
        execute_rvals = []
        try:
            for i, batch in enumerate(self.batches()):
                # Only the first batch honors the caller's `update`
                # flag; later batches must append to what came before.
                update = update if i == 0 else True
                if index is None:
                    offset = index
                else:
                    offset = (self.target_batch_size * i) + index
                self.obj.set(batch, force_unique, update, offset)
                if self._flush_due(i + 1, self.batches.num_batches):
                    execute_rvals.extend(self.pipe.execute())
        except Exception:
            self.pipe.reset()
            raise
        finally:
            # BUGFIX: restore the object's caller-visible flags even on
            # error -- previously an exception left `defer` and
            # `bypass_encoding` permanently set to True.
            self.obj.defer = prev_defer
            self.obj.setter.bypass_encoding = prev_bypass
        return execute_rvals

    def _get_str(self, lookup, lookup_type, multi):
        accumulator = self.accumulator_factory(str, lookup_type)
        batches = self._index_batches(lookup)
        lookup_type = 'index'  # str ranges are always fetched by index
        try:
            for i, batch in enumerate(batches):
                self.obj.getter.add_to_pipe(batch, lookup_type, multi)
                if self._flush_due(i + 1, len(batches)):
                    accumulator.push(self.obj.pipe.execute())
        except Exception:
            self.pipe.reset()
            raise
        return accumulator.pop_all()

    def _get_set(self, lookup, lookup_type, multi):
        if lookup_type == 'all':
            raise TypeError(
                "cannot use 'RedisObjectStream' to get entire Redis 'set' "
                "objects in batches, as they have no methods for doing this "
                "-- use 'RedisObject' instead"
            )
        accumulator = self.accumulator_factory(set, lookup_type)
        batches = self._slice_batches(lookup)
        try:
            for i, batch in enumerate(batches):
                self.obj.getter.add_to_pipe(batch, lookup_type, multi)
                if self._flush_due(i + 1, len(batches)):
                    accumulator.push(self.obj.pipe.execute())
        except Exception:
            self.pipe.reset()
            raise
        return accumulator.pop_all()

    def _get_list(self, lookup, lookup_type, multi):
        accumulator = self.accumulator_factory(list, lookup_type)
        if lookup_type == 'value':
            batches = self._slice_batches(lookup)
        else:
            batches = self._index_batches(lookup)
            lookup_type = 'index'
        try:
            for i, batch in enumerate(batches):
                self.obj.getter.add_to_pipe(batch, lookup_type, multi)
                if self._flush_due(i + 1, len(batches)):
                    accumulator.push(self.obj.pipe.execute())
        except Exception:
            self.pipe.reset()
            raise
        return accumulator.pop_all()

    def _get_dict(self, lookup, lookup_type, multi):
        accumulator = self.accumulator_factory(dict, lookup_type)
        # An empty/None lookup means fetch every hash key.
        hash_keys = lookup or self.obj.conn.hkeys(self.obj.key)
        keys_stack = []
        key_batches = self._slice_batches(hash_keys)
        try:
            for i, batch_keys in enumerate(key_batches):
                self.obj.getter.add_to_pipe(batch_keys, 'field', True)
                keys_stack.append(batch_keys)
                if self._flush_due(i + 1, len(key_batches)):
                    # Push (key batches, value batches) pairs so the
                    # accumulator can zip them back together.
                    accumulator.push((keys_stack, self.obj.pipe.execute()))
                    keys_stack = []
        except Exception:
            self.pipe.reset()
            raise
        return accumulator.pop_all()

    def get(self, lookup=None, lookup_type=None):
        """Fetch data from Redis in batches, dispatching on ptype."""
        lookup_args = self.obj.getter.configure_lookup(lookup, lookup_type)
        lookup, lookup_type, multi = lookup_args
        if lookup_type is None:
            return None
        if not multi:
            # A single-command lookup needs no batching at all.
            self.obj.getter.add_to_pipe(lookup, lookup_type, multi)
            return self.pipe.execute()[-1]
        ptype = self.obj.getter.get_ptype_from_obj_rtype()
        return getattr(self, f'_get_{ptype.__name__}')(*lookup_args)
0 commit comments