
Commit d51da6e

Add RedisObjectStream class for batch operations

When you have a very large amount of data stored in a single Redis key, you may want an easy way to break the data into batches so that you don't saturate the network interface or tie up Redis for too long at any one time. This is an initial stab at adding that; there's still a lot of work to do on it.

1 parent 96fed46 commit d51da6e
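To make the intent concrete, here is a minimal usage sketch. The `RedisObject` constructor and the import path are assumptions based on this diff (django/sierra/utils/redisobjs.py), not a documented API:

    # Hypothetical usage -- assumes RedisObject('some_key') constructs an
    # object exposing the pipe/setter/getter attributes this diff relies on.
    from utils.redisobjs import RedisObject, RedisObjectStream

    obj = RedisObject('my_big_list')                  # assumed constructor
    stream = RedisObjectStream(obj, target_batch_size=500,
                               execute_every_nth_batch=10)

    # Queue ~200 batches of 500 items; the pipeline is flushed every 10th
    # batch and once more on the final batch.
    stream.set([f'item-{n}' for n in range(100000)])
    items = stream.get()    # read back in 500-item chunks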

2 files changed: 888 additions & 1 deletion

File tree

django/sierra/utils/redisobjs.py: 242 additions & 1 deletion
@@ -38,7 +38,7 @@ def from_ptype(cls, ptype, multi=False):
                 return cls(ptype, ptype.extend)
             return cls(ptype, ptype.append)
         if ptype == str:
-            return cls(ptype, lambda coll, new: ''.join([coll, new]))
+            return cls(ptype, lambda coll, new: f"{coll}{new}")
 
     def __init__(self, acctype, accmethod):
         """
@@ -1083,3 +1083,244 @@ def get_value(self, start, end=None):
         """
         lookup = start if end is None else (start, end)
         return self.get(lookup, 'index')
+
+
+class _SendBatches(object):
+
+    def __init__(self, obj, data, force_unique, target_batch_size):
+        self.obj = obj
+        self.data = data
+        self.num_items = len(data)
+        nbatches = self.num_items / target_batch_size
+        self.num_batches = int(nbatches) + (0 if nbatches.is_integer() else 1)
+        self.encoder = obj.setter.encoder
+        self.rtype = obj.setter.get_rtype_from_data(data, force_unique)
+        if self.rtype == 'encoded_obj':
+            raise ValueError(
+                "cannot batch update an 'encoded_obj' type object -- please "
+                "convert to a string first if you really need to stream this "
+                "object to Redis"
+            )
+        self.batch_type = obj.getter.rtypes_to_ptypes[self.rtype]
+        self.target_batch_size = target_batch_size
+        self.iterator = getattr(
+            self, f'_{self.rtype}_iterator', self._default_iterator
+        )
+        self.make_batch = getattr(
+            self, f'_make_{self.rtype}_batch', self._default_make_batch
+        )
+
+    def _zset_iterator(self, data):
+        return (self.encoder.member(item) for item in data)
+
+    def _hash_iterator(self, data):
+        return self.encoder.hash(data)
+
+    def _default_iterator(self, data):
+        return self.encoder.list(data)
+
+    def _make_string_batch(self):
+        total_size = len(self.data)
+        index = 0
+        while index < total_size:
+            yield self.data[(index):(index + self.target_batch_size)]
+            index += self.target_batch_size
+
+    def _default_make_batch(self):
+        batch = []
+        for i, item in enumerate(self.iterator(self.data)):
+            batch.append(item)
+            if (i + 1) % self.target_batch_size == 0:
+                yield self.batch_type(batch)
+                batch = []
+        if batch:
+            yield self.batch_type(batch)
+
+    def __call__(self):
+        return self.make_batch()
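An aside on the batch math in `_SendBatches.__init__` above; a standalone sketch with hypothetical numbers, not part of the commit:

    num_items, target_batch_size = 1050, 500
    nbatches = num_items / target_batch_size          # 2.1
    num_batches = int(nbatches) + (0 if nbatches.is_integer() else 1)
    assert num_batches == 3    # two full batches plus one partial batch

    # _default_make_batch slices the same way a list comprehension would:
    data = list(range(num_items))
    batches = [data[i:i + target_batch_size]
               for i in range(0, num_items, target_batch_size)]
    assert [len(b) for b in batches] == [500, 500, 50]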
+
+
+class _AccumulatorFactory(object):
+    accumulated_types = {
+        'dict_all': dict,
+        'str_index': str,
+    }
+
+    @staticmethod
+    def _flatten(accumulated, new_values):
+        accumulated.extend([
+            item for vals in new_values for item in (vals or [])
+        ])
+
+    @staticmethod
+    def dict_all(accumulated, new_values):
+        for keys, vals in zip(*new_values):
+            accumulated.update(dict(zip(keys, vals)))
+
+    @staticmethod
+    def dict_field(accumulated, new_values):
+        accumulated.extend([
+            item for vals in new_values[1] for item in (vals or [])
+        ])
+
+    @staticmethod
+    def str_index(collected, new_vals):
+        return f"{collected}{''.join([str(v) if v else '' for v in new_vals])}"
+
+    def __call__(self, ptype, lookup_type):
+        full_lookup_name = f'{ptype.__name__}_{lookup_type}'
+        acc_type = self.accumulated_types.get(full_lookup_name, list)
+        method = getattr(self, full_lookup_name, self._flatten)
+        return Accumulator(acc_type, method)
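The factory's `__call__` picks an accumulated type and a merge method by the `'{ptype}_{lookup_type}'` name, falling back to `list` and `_flatten`. A standalone sketch of what `dict_all` expects, with hypothetical data (the `Accumulator` class itself is defined earlier in this module):

    # new_values for dict_all is a (keys_stack, values_stack) pair, with
    # one entry per executed batch of hash lookups.
    accumulated = {}
    keys_stack = [['a', 'b'], ['c']]     # hash field names, per batch
    values_stack = [[1, 2], [3]]         # pipeline results, per batch
    for keys, vals in zip(keys_stack, values_stack):
        accumulated.update(dict(zip(keys, vals)))
    assert accumulated == {'a': 1, 'b': 2, 'c': 3}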
+
+
+class RedisObjectStream(object):
+
+    def __init__(self, obj, target_batch_size, execute_every_nth_batch=None):
+        self.obj = obj
+        self.target_batch_size = target_batch_size
+        self.execute_every = execute_every_nth_batch
+        self.pipe = obj.pipe
+        self.key = obj.key
+        self.batches = None
+        self.accumulator_factory = _AccumulatorFactory()
+
+    def set(self, data, force_unique=None, update=False, index=None):
+        prev_defer = self.obj.defer
+        prev_bypass = self.obj.setter.bypass_encoding
+        self.obj.defer = True
+        self.obj.setter.bypass_encoding = True
+        self.batches = _SendBatches(
+            self.obj, data, force_unique, self.target_batch_size
+        )
+        execute_rvals = []
+        try:
+            for i, batch in enumerate(self.batches()):
+                update = update if i == 0 else True
+                if index is None:
+                    offset = index
+                else:
+                    offset = (self.target_batch_size * i) + index
+                self.obj.set(batch, force_unique, update, offset)
+                nbatch = i + 1
+                is_final = nbatch == self.batches.num_batches
+                do_it = self.execute_every and nbatch % self.execute_every == 0
+                if is_final or do_it:
+                    execute_rvals.extend(self.pipe.execute())
+        except Exception:
+            self.pipe.reset()
+            raise
+        self.obj.defer = prev_defer
+        self.obj.setter.bypass_encoding = prev_bypass
+        return execute_rvals
+
+    def _get_str(self, lookup, lookup_type, multi):
+        accumulator = self.accumulator_factory(str, lookup_type)
+        batches = [
+            (num, min(num + self.target_batch_size - 1, lookup[1]))
+            for num in range(
+                lookup[0], lookup[1] + 1, self.target_batch_size
+            )
+        ]
+        lookup_type = 'index'
+        try:
+            for i, batch in enumerate(batches):
+                self.obj.getter.add_to_pipe(batch, lookup_type, multi)
+                nbatch = i + 1
+                is_final = nbatch == len(batches)
+                do_it = self.execute_every and nbatch % self.execute_every == 0
+                if is_final or do_it:
+                    accumulator.push(self.obj.pipe.execute())
+        except Exception:
+            self.pipe.reset()
+            raise
+        return accumulator.pop_all()
+
+    def _get_set(self, lookup, lookup_type, multi):
+        if lookup_type == 'all':
+            raise TypeError(
+                "cannot use 'RedisObjectStream' to get entire Redis 'set' "
+                "objects in batches, as they have no methods for doing this "
+                "-- use 'RedisObject' instead"
+            )
+        accumulator = self.accumulator_factory(set, lookup_type)
+        batches = [
+            lookup[(num):(num + self.target_batch_size)]
+            for num in range(0, len(lookup), self.target_batch_size)
+        ]
+        try:
+            for i, batch in enumerate(batches):
+                self.obj.getter.add_to_pipe(batch, lookup_type, multi)
+                nbatch = i + 1
+                is_final = nbatch == len(batches)
+                do_it = self.execute_every and nbatch % self.execute_every == 0
+                if is_final or do_it:
+                    accumulator.push(self.obj.pipe.execute())
+        except Exception:
+            self.pipe.reset()
+            raise
+        return accumulator.pop_all()
+
+    def _get_list(self, lookup, lookup_type, multi):
+        accumulator = self.accumulator_factory(list, lookup_type)
+        if lookup_type == 'value':
+            batches = [
+                lookup[(num):(num + self.target_batch_size)]
+                for num in range(0, len(lookup), self.target_batch_size)
+            ]
+        else:
+            batches = [
+                (num, min(num + self.target_batch_size - 1, lookup[1]))
+                for num in range(
+                    lookup[0], lookup[1] + 1, self.target_batch_size
+                )
+            ]
+            lookup_type = 'index'
+        try:
+            for i, batch in enumerate(batches):
+                self.obj.getter.add_to_pipe(batch, lookup_type, multi)
+                nbatch = i + 1
+                is_final = nbatch == len(batches)
+                do_it = self.execute_every and nbatch % self.execute_every == 0
+                if is_final or do_it:
+                    accumulator.push(self.obj.pipe.execute())
+        except Exception:
+            self.pipe.reset()
+            raise
+        return accumulator.pop_all()
+
+    def _get_dict(self, lookup, lookup_type, multi):
+        accumulator = self.accumulator_factory(dict, lookup_type)
+        hash_keys = lookup or self.obj.conn.hkeys(self.obj.key)
+        keys_stack = []
+        key_batches = [
+            hash_keys[(num):(num + self.target_batch_size)]
+            for num in range(0, len(hash_keys), self.target_batch_size)
+        ]
+        try:
+            for i, batch_keys in enumerate(key_batches):
+                self.obj.getter.add_to_pipe(batch_keys, 'field', True)
+                keys_stack.append(batch_keys)
+                nbatch = i + 1
+                is_final = nbatch == len(key_batches)
+                do_it = self.execute_every and nbatch % self.execute_every == 0
+                if is_final or do_it:
+                    accumulator.push((keys_stack, self.obj.pipe.execute()))
+                    keys_stack = []
+        except Exception:
+            self.pipe.reset()
+            raise
+        return accumulator.pop_all()
+
+    def get(self, lookup=None, lookup_type=None):
+        lookup_args = self.obj.getter.configure_lookup(lookup, lookup_type)
+        lookup, lookup_type, multi = lookup_args
+        if lookup_type is None:
+            return None
+        if not multi:
+            self.obj.getter.add_to_pipe(lookup, lookup_type, multi)
+            return self.pipe.execute()[-1]
+        ptype = self.obj.getter.get_ptype_from_obj_rtype()
+        # accumulator = Accumulator.from_ptype(ptype, multi=True)
+        data = getattr(self, f'_get_{ptype.__name__}')(*lookup_args)
+        return data
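The `set` and `_get_*` methods all follow the same execute-every-Nth-batch pipeline pattern. Stripped of this module's classes, the pattern looks roughly like this in bare redis-py (a hypothetical standalone sketch, with assumed key name and data):

    import redis

    conn = redis.Redis()
    pipe = conn.pipeline()
    data = [f'item-{n}' for n in range(10000)]
    batch_size, execute_every = 500, 5
    batches = [data[i:i + batch_size]
               for i in range(0, len(data), batch_size)]

    results = []
    try:
        for i, batch in enumerate(batches):
            pipe.rpush('my_big_list', *batch)     # queue only; nothing sent yet
            nbatch = i + 1
            if nbatch == len(batches) or nbatch % execute_every == 0:
                results.extend(pipe.execute())    # flush queued commands
    except Exception:
        pipe.reset()    # drop anything still queued, as set()/get() do above
        raise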
