Skip to content

Commit 51147fc

Browse files
Robbie1977 and claude committed
Use OWLERY subclass expansion for class connectivity queries
Connectivity data is now cached per direct neuron class only (no SUBCLASSOF in the indexer). The query functions now follow the classic VFB pattern: OWLERY resolves subclasses, Solr is batch-queried for all subclass docs, and results are merged by partner class with summed statistics. Also removes SUBCLASSOF expansion from the Neo4j fallback in vfb_connectivity.py for consistency. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 3c886eb commit 51147fc

2 files changed

Lines changed: 148 additions & 90 deletions

File tree

src/vfbquery/vfb_connectivity.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -151,18 +151,18 @@ def _build_connectivity_cypher(upstream_id, downstream_id, weight,
151151
"""
152152
clauses = []
153153

154-
# Match upstream class and subclasses
154+
# Match upstream class directly (subclass expansion is handled at
155+
# query time via OWLERY in the Solr-cached path; this Neo4j fallback
156+
# keeps only direct INSTANCEOF matching for consistency).
155157
if upstream_id is not None:
156158
clauses.append(
157-
f"MATCH (:Class:Neuron {{short_form:'{upstream_id}'}})"
158-
f"<-[:SUBCLASSOF*0..]-(c1:Class:Neuron)"
159+
f"MATCH (c1:Class:Neuron {{short_form:'{upstream_id}'}})"
159160
)
160161

161-
# Match downstream class and subclasses
162+
# Match downstream class directly
162163
if downstream_id is not None:
163164
clauses.append(
164-
f"MATCH (:Class:Neuron {{short_form:'{downstream_id}'}})"
165-
f"<-[:SUBCLASSOF*0..]-(c2:Class:Neuron)"
165+
f"MATCH (c2:Class:Neuron {{short_form:'{downstream_id}'}})"
166166
)
167167

168168
# Core synapse matching

src/vfbquery/vfb_queries.py

Lines changed: 142 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -3009,14 +3009,130 @@ def get_neuron_region_connectivity(short_form: str, return_dataframe=True, limit
30093009
}
30103010

30113011

3012+
def _fetch_connectivity_entries(short_form: str, solr_field: str):
3013+
"""Fetch connectivity entries from Solr for a neuron class and all its
3014+
OWLERY subclasses.
3015+
3016+
Returns a flat list of parsed JSON entries (dicts) from the Solr
3017+
connectivity field, collected across every subclass doc.
3018+
"""
3019+
# Step 1: OWLERY subclass expansion (includes the class itself)
3020+
owl_query = f"<{short_form}>"
3021+
try:
3022+
subclass_ids = vc.vfb.oc.get_subclasses(
3023+
query=owl_query, query_by_label=False, verbose=False
3024+
)
3025+
except Exception as e:
3026+
print(f"Owlery subclass query failed for {short_form}: {e}")
3027+
subclass_ids = []
3028+
3029+
# Always include the queried class itself
3030+
if short_form not in subclass_ids:
3031+
subclass_ids.insert(0, short_form)
3032+
3033+
if not subclass_ids:
3034+
return []
3035+
3036+
# Step 2: Batch-fetch from Solr using {!terms f=id}
3037+
id_list = ','.join(subclass_ids)
3038+
try:
3039+
results = vfb_solr.search(
3040+
q='id:*',
3041+
fq=f'{{!terms f=id}}{id_list}',
3042+
fl=solr_field,
3043+
rows=len(subclass_ids),
3044+
)
3045+
except Exception as e:
3046+
print(f"Error querying Solr for {solr_field}: {e}")
3047+
return []
3048+
3049+
# Step 3: Parse all connectivity JSON from all returned docs
3050+
all_entries = []
3051+
for doc in results.docs:
3052+
if solr_field not in doc:
3053+
continue
3054+
raw = doc[solr_field]
3055+
field_json = raw[0] if isinstance(raw, list) else raw
3056+
try:
3057+
entries = json.loads(field_json)
3058+
except (json.JSONDecodeError, TypeError):
3059+
continue
3060+
if isinstance(entries, list):
3061+
all_entries.extend(entries)
3062+
else:
3063+
all_entries.append(entries)
3064+
3065+
return all_entries
3066+
3067+
3068+
def _merge_connectivity_rows(entries, partner_key, partner_id_key, partner_label_key):
3069+
"""Merge connectivity entries by partner class, summing statistics.
3070+
3071+
Returns a list of merged row dicts ready for DataFrame / dict output.
3072+
``partner_key`` is the output column name (e.g. 'downstream_class'),
3073+
``partner_id_key`` / ``partner_label_key`` are the keys inside
3074+
``class_connectivity`` to read partner id and label from.
3075+
"""
3076+
# Accumulate by partner class id
3077+
merged = {} # partner_id -> {label, total_n, connected_n, pw, tw}
3078+
for entry in entries:
3079+
cc = entry.get('class_connectivity', {})
3080+
obj = entry.get('object', {})
3081+
pid = obj.get('short_form', cc.get(partner_id_key, ''))
3082+
plabel = obj.get('label', cc.get(partner_label_key, ''))
3083+
if not pid:
3084+
continue
3085+
if pid not in merged:
3086+
merged[pid] = {
3087+
'label': plabel,
3088+
'total_n': 0,
3089+
'connected_n': 0,
3090+
'pairwise_connections': 0,
3091+
'total_weight': 0,
3092+
}
3093+
m = merged[pid]
3094+
m['total_n'] += _num(cc.get('total_upstream_count', 0))
3095+
m['connected_n'] += _num(cc.get('connected_upstream_count', 0))
3096+
m['pairwise_connections'] += _num(cc.get('pairwise_connections', 0))
3097+
m['total_weight'] += _num(cc.get('total_weight', 0))
3098+
3099+
rows = []
3100+
for pid, m in merged.items():
3101+
total_n = m['total_n']
3102+
connected_n = m['connected_n']
3103+
pw = m['pairwise_connections']
3104+
tw = m['total_weight']
3105+
pct = round((connected_n / total_n) * 100) if total_n else 0
3106+
avg = tw / pw if pw else 0
3107+
rows.append({
3108+
'id': pid,
3109+
partner_key: f"[{m['label']}]({pid})" if pid else m['label'],
3110+
'total_n': total_n,
3111+
'connected_n': connected_n,
3112+
'percent_connected': pct,
3113+
'pairwise_connections': pw,
3114+
'total_weight': tw,
3115+
'avg_weight': avg,
3116+
})
3117+
return rows
3118+
3119+
3120+
def _num(v):
3121+
"""Coerce a value to a number, defaulting to 0."""
3122+
try:
3123+
return float(v)
3124+
except (TypeError, ValueError):
3125+
return 0
3126+
3127+
30123128
@with_solr_cache('downstream_class_connectivity_query')
30133129
def get_downstream_class_connectivity(short_form: str, return_dataframe=True, limit: int = -1):
30143130
"""
30153131
Retrieves downstream connectivity classes for the specified neuron class.
30163132
3017-
Reads the downstream_connectivity_query Solr field, which contains a JSON array
3018-
of vfb_query-format objects populated by the neuron_downstream_connectivity_indexer.
3019-
Each element represents one (primary_class → downstream_class) connection summary.
3133+
Uses OWLERY to expand subclasses of the queried class, fetches the
3134+
downstream_connectivity_query Solr field for each, and merges results
3135+
by downstream partner class.
30203136
30213137
Matching criteria: Class + Neuron
30223138
@@ -3025,50 +3141,21 @@ def get_downstream_class_connectivity(short_form: str, return_dataframe=True, li
30253141
:param limit: maximum number of results to return (default -1, returns all results)
30263142
:return: Downstream partner neuron classes with connectivity statistics
30273143
"""
3028-
solr_field = 'downstream_connectivity_query'
3029-
try:
3030-
results = vfb_solr.search(f'id:{short_form}', fl=solr_field, rows=1)
3031-
except Exception as e:
3032-
print(f"Error querying Solr for downstream class connectivity: {e}")
3033-
if return_dataframe:
3034-
return pd.DataFrame()
3035-
return {'headers': {}, 'rows': [], 'count': 0}
3036-
3037-
if not results.hits or not results.docs or solr_field not in results.docs[0]:
3144+
entries = _fetch_connectivity_entries(short_form, 'downstream_connectivity_query')
3145+
if not entries:
30383146
if return_dataframe:
30393147
return pd.DataFrame()
30403148
return {'headers': {}, 'rows': [], 'count': 0}
30413149

3042-
raw = results.docs[0][solr_field]
3043-
field_json = raw[0] if isinstance(raw, list) else raw
3044-
try:
3045-
entries = json.loads(field_json)
3046-
except (json.JSONDecodeError, TypeError) as e:
3047-
print(f"Error parsing downstream_connectivity_query JSON for {short_form}: {e}")
3048-
if return_dataframe:
3049-
return pd.DataFrame()
3050-
return {'headers': {}, 'rows': [], 'count': 0}
3051-
3052-
if not isinstance(entries, list):
3053-
entries = [entries]
3150+
rows = _merge_connectivity_rows(
3151+
entries,
3152+
partner_key='downstream_class',
3153+
partner_id_key='downstream_class_id',
3154+
partner_label_key='downstream_class',
3155+
)
30543156

3055-
rows = []
3056-
for entry in entries:
3057-
cc = entry.get('class_connectivity', {})
3058-
obj = entry.get('object', {})
3059-
ds_id = obj.get('short_form', cc.get('downstream_class_id', ''))
3060-
ds_label = obj.get('label', cc.get('downstream_class', ''))
3061-
row = {
3062-
'id': ds_id,
3063-
'downstream_class': f"[{ds_label}]({ds_id})" if ds_id else ds_label,
3064-
'total_n': cc.get('total_upstream_count', ''),
3065-
'connected_n': cc.get('connected_upstream_count', ''),
3066-
'percent_connected': cc.get('percent_connected', ''),
3067-
'pairwise_connections': cc.get('pairwise_connections', ''),
3068-
'total_weight': cc.get('total_weight', ''),
3069-
'avg_weight': cc.get('average_weight', ''),
3070-
}
3071-
rows.append(row)
3157+
# Sort by pairwise_connections descending
3158+
rows.sort(key=lambda r: r.get('pairwise_connections', 0), reverse=True)
30723159

30733160
total_count = len(rows)
30743161
if limit != -1:
@@ -3097,9 +3184,9 @@ def get_upstream_class_connectivity(short_form: str, return_dataframe=True, limi
30973184
"""
30983185
Retrieves upstream connectivity classes for the specified neuron class.
30993186
3100-
Reads the upstream_connectivity_query Solr field, which contains a JSON array
3101-
of vfb_query-format objects populated by the neuron_upstream_connectivity_indexer.
3102-
Each element represents one (upstream_class → primary_class) connection summary.
3187+
Uses OWLERY to expand subclasses of the queried class, fetches the
3188+
upstream_connectivity_query Solr field for each, and merges results
3189+
by upstream partner class.
31033190
31043191
Matching criteria: Class + Neuron
31053192
@@ -3108,50 +3195,21 @@ def get_upstream_class_connectivity(short_form: str, return_dataframe=True, limi
31083195
:param limit: maximum number of results to return (default -1, returns all results)
31093196
:return: Upstream partner neuron classes with connectivity statistics
31103197
"""
3111-
solr_field = 'upstream_connectivity_query'
3112-
try:
3113-
results = vfb_solr.search(f'id:{short_form}', fl=solr_field, rows=1)
3114-
except Exception as e:
3115-
print(f"Error querying Solr for upstream class connectivity: {e}")
3116-
if return_dataframe:
3117-
return pd.DataFrame()
3118-
return {'headers': {}, 'rows': [], 'count': 0}
3119-
3120-
if not results.hits or not results.docs or solr_field not in results.docs[0]:
3121-
if return_dataframe:
3122-
return pd.DataFrame()
3123-
return {'headers': {}, 'rows': [], 'count': 0}
3124-
3125-
raw = results.docs[0][solr_field]
3126-
field_json = raw[0] if isinstance(raw, list) else raw
3127-
try:
3128-
entries = json.loads(field_json)
3129-
except (json.JSONDecodeError, TypeError) as e:
3130-
print(f"Error parsing upstream_connectivity_query JSON for {short_form}: {e}")
3198+
entries = _fetch_connectivity_entries(short_form, 'upstream_connectivity_query')
3199+
if not entries:
31313200
if return_dataframe:
31323201
return pd.DataFrame()
31333202
return {'headers': {}, 'rows': [], 'count': 0}
31343203

3135-
if not isinstance(entries, list):
3136-
entries = [entries]
3204+
rows = _merge_connectivity_rows(
3205+
entries,
3206+
partner_key='upstream_class',
3207+
partner_id_key='upstream_class_id',
3208+
partner_label_key='upstream_class',
3209+
)
31373210

3138-
rows = []
3139-
for entry in entries:
3140-
cc = entry.get('class_connectivity', {})
3141-
obj = entry.get('object', {})
3142-
us_id = obj.get('short_form', cc.get('upstream_class_id', ''))
3143-
us_label = obj.get('label', cc.get('upstream_class', ''))
3144-
row = {
3145-
'id': us_id,
3146-
'upstream_class': f"[{us_label}]({us_id})" if us_id else us_label,
3147-
'total_n': cc.get('total_upstream_count', ''),
3148-
'connected_n': cc.get('connected_upstream_count', ''),
3149-
'percent_connected': cc.get('percent_connected', ''),
3150-
'pairwise_connections': cc.get('pairwise_connections', ''),
3151-
'total_weight': cc.get('total_weight', ''),
3152-
'avg_weight': cc.get('average_weight', ''),
3153-
}
3154-
rows.append(row)
3211+
# Sort by pairwise_connections descending
3212+
rows.sort(key=lambda r: r.get('pairwise_connections', 0), reverse=True)
31553213

31563214
total_count = len(rows)
31573215
if limit != -1:

0 commit comments

Comments (0)