Skip to content

Commit dd118ef

Browse files
committed
Move PPR C++ to gigl/csrc following PyTorch csrc conventions
1 parent 906df01 commit dd118ef

8 files changed

Lines changed: 441 additions & 468 deletions

File tree

gigl/csrc/distributed/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Re-export the compiled PPR forward-push kernel, failing loudly with a
# build hint when the C++ extension has not been compiled yet.
try:
    from gigl.csrc.distributed.ppr_forward_push import PPRForwardPushState
except ImportError as e:
    _msg = (
        "PPR C++ extension not compiled. "
        "Run `make build_cpp_extensions` from the GiGL root to build it."
    )
    raise ImportError(_msg) from e

__all__ = ["PPRForwardPushState"]
Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
#include "ppr_forward_push.h"
2+
3+
// Construct per-seed PPR state for a batch of seed nodes.
//
// seed_nodes:                 1-D int64 tensor of seed node IDs.
// seed_node_type_id:          node type shared by every seed.
// alpha:                      restart probability of the PPR random walk.
// requeue_threshold_factor:   multiplied by a node's degree to get its
//                             re-enqueue threshold.
// node_type_to_edge_type_ids: outgoing edge type IDs per node type.
// edge_type_to_dst_ntype_id:  destination node type per edge type.
// degree_tensors:             per-node-type int32 degree lookup tensors.
PPRForwardPushState::PPRForwardPushState(
    torch::Tensor seed_nodes, int32_t seed_node_type_id, double alpha,
    double requeue_threshold_factor,
    std::vector<std::vector<int32_t>> node_type_to_edge_type_ids,
    std::vector<int32_t> edge_type_to_dst_ntype_id, std::vector<torch::Tensor> degree_tensors)
    : alpha_(alpha),
      one_minus_alpha_(1.0 - alpha),
      requeue_threshold_factor_(requeue_threshold_factor),
      // std::move transfers ownership of each vector into the member
      // variable without copying its contents.
      node_type_to_edge_type_ids_(std::move(node_type_to_edge_type_ids)),
      edge_type_to_dst_ntype_id_(std::move(edge_type_to_dst_ntype_id)),
      degree_tensors_(std::move(degree_tensors)) {
  TORCH_CHECK(seed_nodes.dim() == 1, "seed_nodes must be 1D");
  // The accessor below assumes int64 data; fail fast with a clear message
  // instead of tripping an internal accessor assert on a wrong dtype.
  TORCH_CHECK(seed_nodes.scalar_type() == torch::kLong, "seed_nodes must be int64");
  // An out-of-range seed node type would index past the per-type vectors
  // allocated below — undefined behavior. Validate up front.
  TORCH_CHECK(seed_node_type_id >= 0 &&
                  seed_node_type_id <
                      static_cast<int32_t>(node_type_to_edge_type_ids_.size()),
              "seed_node_type_id ", seed_node_type_id, " out of range");
  batch_size_ = static_cast<int32_t>(seed_nodes.size(0));
  num_node_types_ = static_cast<int32_t>(node_type_to_edge_type_ids_.size());

  // Allocate per-seed, per-node-type tables.
  // .assign(n, val) fills a vector with n copies of val.
  ppr_scores_.assign(batch_size_,
                     std::vector<std::unordered_map<int32_t, double>>(num_node_types_));
  residuals_.assign(batch_size_,
                    std::vector<std::unordered_map<int32_t, double>>(num_node_types_));
  queue_.assign(batch_size_, std::vector<std::unordered_set<int32_t>>(num_node_types_));
  queued_nodes_.assign(batch_size_,
                       std::vector<std::unordered_set<int32_t>>(num_node_types_));

  // accessor<dtype, ndim>() returns a typed view into the tensor's data
  // that supports [i] indexing (bounds-checked in debug builds).
  auto acc = seed_nodes.accessor<int64_t, 1>();
  num_nodes_in_queue_ = batch_size_;
  for (int32_t i = 0; i < batch_size_; ++i) {
    int32_t seed = static_cast<int32_t>(acc[i]);
    // PPR initialisation: each seed starts with residual = alpha (the
    // restart probability). The first push will move alpha into ppr_score
    // and distribute (1-alpha)*alpha to the seed's neighbors.
    residuals_[i][seed_node_type_id][seed] = alpha_;
    queue_[i][seed_node_type_id].insert(seed);
  }
}
44+
45+
// Drain every seed's live queue into the queued_nodes_ snapshot and report
// which (node, edge type) neighbor lists still need to be fetched.
//
// Returns:
//   std::nullopt     — all queues were empty: the push loop has converged.
//   empty map        — nodes were drained but every neighbor list is cached.
//   {eid -> tensor}  — 1-D int64 node-ID tensors needing a neighbor fetch.
std::optional<std::unordered_map<int32_t, torch::Tensor>> PPRForwardPushState::drain_queue() {
  if (num_nodes_in_queue_ == 0) {
    return std::nullopt;
  }

  // Discard the snapshot left over from the previous iteration.
  for (auto& per_seed : queued_nodes_) {
    for (auto& snapshot : per_seed) {
      snapshot.clear();
    }
  }

  // pending[eid] = deduplicated node IDs that need a neighbor fetch for
  // edge type eid this round; a node queued by several seeds is only
  // fetched once per edge type.
  std::unordered_map<int32_t, std::unordered_set<int32_t>> pending;

  for (int32_t seed_idx = 0; seed_idx < batch_size_; ++seed_idx) {
    for (int32_t ntype = 0; ntype < num_node_types_; ++ntype) {
      auto& live = queue_[seed_idx][ntype];
      if (live.empty()) {
        continue;
      }

      // O(1) hand-off of the live queue into the snapshot (no data copy).
      auto& snapshot = queued_nodes_[seed_idx][ntype];
      snapshot = std::move(live);
      live.clear();
      num_nodes_in_queue_ -= static_cast<int32_t>(snapshot.size());

      for (int32_t node_id : snapshot) {
        for (int32_t eid : node_type_to_edge_type_ids_[ntype]) {
          if (neighbor_cache_.count(pack_key(node_id, eid)) == 0) {
            pending[eid].insert(node_id);
          }
        }
      }
    }
  }

  std::unordered_map<int32_t, torch::Tensor> lookup_tensors;
  for (const auto& [eid, node_set] : pending) {
    std::vector<int64_t> ids(node_set.begin(), node_set.end());
    lookup_tensors[eid] = torch::tensor(ids, torch::kLong);
  }
  return lookup_tensors;
}
87+
88+
void PPRForwardPushState::push_residuals(
89+
const std::unordered_map<int32_t, std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>&
90+
fetched_by_etype_id) {
91+
// Step 1: Unpack the input map into a C++ map keyed by pack_key(node_id, etype_id)
92+
// for fast lookup during the residual-push loop below.
93+
std::unordered_map<uint64_t, std::vector<int32_t>> fetched;
94+
for (const auto& [eid, tup] : fetched_by_etype_id) {
95+
const auto& node_ids_t = std::get<0>(tup);
96+
const auto& flat_nbrs_t = std::get<1>(tup);
97+
const auto& counts_t = std::get<2>(tup);
98+
99+
// accessor<int64_t, 1>() gives a bounds-checked, typed 1-D view into
100+
// each tensor's data — equivalent to iterating over a NumPy array.
101+
auto node_acc = node_ids_t.accessor<int64_t, 1>();
102+
auto nbr_acc = flat_nbrs_t.accessor<int64_t, 1>();
103+
auto cnt_acc = counts_t.accessor<int64_t, 1>();
104+
105+
// Walk the flat neighbor list, slicing out each node's neighbors using
106+
// the running offset into the concatenated flat buffer.
107+
int64_t offset = 0;
108+
for (int64_t i = 0; i < node_ids_t.size(0); ++i) {
109+
int32_t nid = static_cast<int32_t>(node_acc[i]);
110+
int64_t count = cnt_acc[i];
111+
std::vector<int32_t> nbrs(count);
112+
for (int64_t j = 0; j < count; ++j)
113+
nbrs[j] = static_cast<int32_t>(nbr_acc[offset + j]);
114+
fetched[pack_key(nid, eid)] = std::move(nbrs);
115+
offset += count;
116+
}
117+
}
118+
119+
// Step 2: For every node that was in the queue (captured in queued_nodes_
120+
// by drain_queue()), apply one PPR push step:
121+
// a. Absorb residual into the PPR score.
122+
// b. Distribute (1-alpha) * residual equally to each neighbor.
123+
// c. Enqueue any neighbor whose residual now exceeds the requeue threshold.
124+
for (int32_t s = 0; s < batch_size_; ++s) {
125+
for (int32_t nt = 0; nt < num_node_types_; ++nt) {
126+
if (queued_nodes_[s][nt].empty())
127+
continue;
128+
129+
for (int32_t src : queued_nodes_[s][nt]) {
130+
auto& src_res = residuals_[s][nt];
131+
auto it = src_res.find(src);
132+
double res = (it != src_res.end()) ? it->second : 0.0;
133+
134+
// a. Absorb: move residual into the PPR score.
135+
ppr_scores_[s][nt][src] += res;
136+
src_res[src] = 0.0;
137+
138+
int32_t total_deg = get_total_degree(src, nt);
139+
// Destination-only nodes absorb residual but do not push further.
140+
if (total_deg == 0)
141+
continue;
142+
143+
// b. Distribute: each neighbor receives an equal share.
144+
double res_per_nbr = one_minus_alpha_ * res / static_cast<double>(total_deg);
145+
146+
for (int32_t eid : node_type_to_edge_type_ids_[nt]) {
147+
// Invariant: fetched and neighbor_cache_ are mutually exclusive for
148+
// any given (node, etype) key within one iteration. drain_queue()
149+
// only requests a fetch for nodes absent from neighbor_cache_, so a
150+
// key is in at most one of the two.
151+
const std::vector<int32_t>* nbr_list = nullptr;
152+
auto fi = fetched.find(pack_key(src, eid));
153+
if (fi != fetched.end()) {
154+
nbr_list = &fi->second;
155+
} else {
156+
auto ci = neighbor_cache_.find(pack_key(src, eid));
157+
if (ci != neighbor_cache_.end())
158+
nbr_list = &ci->second;
159+
}
160+
if (!nbr_list || nbr_list->empty())
161+
continue;
162+
163+
int32_t dst_nt = edge_type_to_dst_ntype_id_[eid];
164+
165+
// c. Accumulate residual for each neighbor and re-enqueue if threshold
166+
// exceeded.
167+
for (int32_t nbr : *nbr_list) {
168+
residuals_[s][dst_nt][nbr] += res_per_nbr;
169+
170+
double threshold = requeue_threshold_factor_ *
171+
static_cast<double>(get_total_degree(nbr, dst_nt));
172+
173+
if (queue_[s][dst_nt].find(nbr) == queue_[s][dst_nt].end() &&
174+
residuals_[s][dst_nt][nbr] >= threshold) {
175+
queue_[s][dst_nt].insert(nbr);
176+
++num_nodes_in_queue_;
177+
178+
// Promote neighbor lists to the persistent cache: this node will
179+
// be processed next iteration, so caching avoids a re-fetch.
180+
for (int32_t peid : node_type_to_edge_type_ids_[dst_nt]) {
181+
uint64_t pk = pack_key(nbr, peid);
182+
if (neighbor_cache_.find(pk) == neighbor_cache_.end()) {
183+
auto pfi = fetched.find(pk);
184+
if (pfi != fetched.end())
185+
neighbor_cache_[pk] = pfi->second;
186+
}
187+
}
188+
}
189+
}
190+
}
191+
}
192+
}
193+
}
194+
}
195+
196+
// Select the top-k highest-scoring PPR nodes per seed, per node type.
//
// Returns {ntype_id: (flat_ids, flat_weights, valid_counts)}; only node
// types that received at least one PPR score appear. For a batch of B
// seeds, seed 0's nodes occupy flat_ids[0 : valid_counts[0]], seed 1's the
// next valid_counts[1] entries, and so on.
std::unordered_map<int32_t, std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>
PPRForwardPushState::extract_top_k(int32_t max_ppr_nodes) {
  // Collect node types that accumulated a score for at least one seed.
  std::unordered_set<int32_t> active_ntypes;
  for (int32_t seed_idx = 0; seed_idx < batch_size_; ++seed_idx) {
    for (int32_t ntype = 0; ntype < num_node_types_; ++ntype) {
      if (!ppr_scores_[seed_idx][ntype].empty()) {
        active_ntypes.insert(ntype);
      }
    }
  }

  std::unordered_map<int32_t, std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>> result;
  for (int32_t ntype : active_ntypes) {
    std::vector<int64_t> ids_out;
    std::vector<float> weights_out;
    std::vector<int64_t> counts_out;

    for (int32_t seed_idx = 0; seed_idx < batch_size_; ++seed_idx) {
      const auto& score_map = ppr_scores_[seed_idx][ntype];
      int32_t k = std::min(max_ppr_nodes, static_cast<int32_t>(score_map.size()));
      if (k > 0) {
        // Only the first k entries need to be ordered, so partial_sort
        // (descending by score) beats a full sort.
        std::vector<std::pair<int32_t, double>> ranked(score_map.begin(), score_map.end());
        auto by_score_desc = [](const auto& lhs, const auto& rhs) {
          return lhs.second > rhs.second;
        };
        std::partial_sort(ranked.begin(), ranked.begin() + k, ranked.end(), by_score_desc);

        for (int32_t i = 0; i < k; ++i) {
          ids_out.push_back(static_cast<int64_t>(ranked[i].first));
          // Emit float32; internal accumulation stays double to limit
          // rounding error in the push loop.
          weights_out.push_back(static_cast<float>(ranked[i].second));
        }
      }
      counts_out.push_back(static_cast<int64_t>(k));
    }

    result[ntype] = {torch::tensor(ids_out, torch::kLong),
                     torch::tensor(weights_out, torch::kFloat),
                     torch::tensor(counts_out, torch::kLong)};
  }
  return result;
}
235+
236+
int32_t PPRForwardPushState::get_total_degree(int32_t node_id, int32_t ntype_id) const {
237+
if (ntype_id >= static_cast<int32_t>(degree_tensors_.size()))
238+
return 0;
239+
const auto& t = degree_tensors_[ntype_id];
240+
if (t.numel() == 0)
241+
return 0;
242+
TORCH_CHECK(node_id < static_cast<int32_t>(t.size(0)), "Node ID ", node_id,
243+
" out of range for degree tensor of ntype_id ", ntype_id, " (size=", t.size(0),
244+
"). This indicates corrupted graph data or a sampler bug.");
245+
// data_ptr<int32_t>() returns a raw C pointer to the tensor's int32 data buffer.
246+
return t.data_ptr<int32_t>()[node_id];
247+
}
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
#pragma once
2+
3+
#include <torch/torch.h>
4+
5+
#include <algorithm> // std::partial_sort, std::min
6+
#include <cstdint> // Fixed-width integer types: int32_t, int64_t, uint32_t, uint64_t
7+
#include <optional> // std::optional for nullable return values
8+
#include <tuple> // std::tuple for multi-value returns
9+
#include <unordered_map> // std::unordered_map — like Python dict, O(1) average lookup
10+
#include <unordered_set> // std::unordered_set — like Python set, O(1) average lookup
11+
#include <vector> // std::vector — like Python list, contiguous in memory
12+
13+
// Pack (node_id, etype_id) into a single 64-bit hash-map key. One 64-bit
// integer is cheaper to hash than a pair of two integers (std::unordered_map
// has no built-in pair hash).
//
// Bit layout: node_id occupies bits 63-32, etype_id bits 31-0.
//
// Each input is cast through uint32_t before widening: a direct
// int32_t -> uint64_t cast would sign-extend negative values (e.g.
// -1 = 0xFFFFFFFF) and corrupt the upper half after the shift, whereas
// uint32_t reinterprets the bit pattern as-is.
static inline uint64_t pack_key(int32_t node_id, int32_t etype_id) {
  const uint64_t hi = static_cast<uint32_t>(node_id);
  const uint64_t lo = static_cast<uint32_t>(etype_id);
  return (hi << 32) | lo;
}
29+
30+
// C++ kernel for the PPR Forward Push algorithm (Andersen et al., 2006).
31+
//
32+
// All hot-loop state (scores, residuals, queue, neighbor cache) lives inside
33+
// this object. The distributed neighbor fetch is kept in Python because it
34+
// involves async RPC calls that C++ cannot drive directly.
35+
//
36+
// Owned state: ppr_scores, residuals, queue, queued_nodes, neighbor_cache.
37+
// Python retains ownership of: the distributed neighbor fetch (_batch_fetch_neighbors).
38+
//
39+
// Typical call sequence per batch:
40+
// 1. PPRForwardPushState(seed_nodes, ...) — init per-seed residuals / queue
41+
// while True:
42+
// 2. drain_queue() — drain queue → nodes needing lookup
43+
// 3. <Python: _batch_fetch_neighbors(...)> — distributed RPC fetch (stays in Python)
44+
// 4. push_residuals(fetched_by_etype_id) — push residuals, update queue
45+
// 5. extract_top_k(max_ppr_nodes) — top-k selection per seed per node type
46+
class PPRForwardPushState {
47+
public:
48+
PPRForwardPushState(torch::Tensor seed_nodes, int32_t seed_node_type_id, double alpha,
49+
double requeue_threshold_factor,
50+
std::vector<std::vector<int32_t>> node_type_to_edge_type_ids,
51+
std::vector<int32_t> edge_type_to_dst_ntype_id,
52+
std::vector<torch::Tensor> degree_tensors);
53+
54+
// Drain all queued nodes and return {etype_id: tensor[node_ids]} for batch
55+
// neighbor lookup. Also snapshots the drained nodes into queued_nodes_ for
56+
// use by push_residuals().
57+
//
58+
// Return value semantics:
59+
// - std::nullopt → queue was already empty; convergence achieved; stop the loop.
60+
// - empty map → nodes were drained but all were cached; call push_residuals({}).
61+
// - non-empty map → {etype_id → 1-D int64 tensor of node IDs} needing neighbor lookup.
62+
std::optional<std::unordered_map<int32_t, torch::Tensor>> drain_queue();
63+
64+
// Push residuals to neighbors given the fetched neighbor data.
65+
//
66+
// fetched_by_etype_id: {etype_id: (node_ids_tensor, flat_nbrs_tensor, counts_tensor)}
67+
// - node_ids_tensor: [N] int64 — source node IDs fetched for this edge type
68+
// - flat_nbrs_tensor: [sum(counts)] int64 — all neighbor lists concatenated flat
69+
// - counts_tensor: [N] int64 — neighbor count for each source node
70+
void push_residuals(const std::unordered_map<
71+
int32_t, std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>&
72+
fetched_by_etype_id);
73+
74+
// Extract top-k PPR nodes per seed per node type.
75+
//
76+
// Returns {ntype_id: (flat_ids_tensor, flat_weights_tensor, valid_counts_tensor)}.
77+
// Only node types that received any PPR score are included in the output.
78+
//
79+
// Output layout for a batch of B seeds:
80+
// flat_ids[0 : valid_counts[0]] → top-k nodes for seed 0
81+
// flat_ids[valid_counts[0] : valid_counts[0]+valid_counts[1]] → top-k for seed 1
82+
// ...
83+
std::unordered_map<int32_t, std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>
84+
extract_top_k(int32_t max_ppr_nodes);
85+
86+
private:
87+
// Look up the total (across all edge types) out-degree of a node.
88+
// Returns 0 for destination-only node types (no outgoing edges).
89+
int32_t get_total_degree(int32_t node_id, int32_t ntype_id) const;
90+
91+
// -------------------------------------------------------------------------
92+
// Scalar algorithm parameters
93+
// -------------------------------------------------------------------------
94+
double alpha_; // Restart probability
95+
double one_minus_alpha_; // 1 - alpha, precomputed to avoid repeated subtraction
96+
double requeue_threshold_factor_; // alpha * eps; multiplied by degree to get per-node threshold
97+
98+
int32_t batch_size_; // Number of seeds in the current batch
99+
int32_t num_node_types_; // Total number of node types (homo + hetero)
100+
int32_t num_nodes_in_queue_{0}; // Running count of nodes across all seeds / types
101+
102+
// -------------------------------------------------------------------------
103+
// Graph structure (read-only after construction)
104+
// -------------------------------------------------------------------------
105+
std::vector<std::vector<int32_t>> node_type_to_edge_type_ids_;
106+
std::vector<int32_t> edge_type_to_dst_ntype_id_;
107+
std::vector<torch::Tensor> degree_tensors_;
108+
109+
// -------------------------------------------------------------------------
110+
// Per-seed, per-node-type PPR state (indexed [seed_idx][ntype_id])
111+
// -------------------------------------------------------------------------
112+
std::vector<std::vector<std::unordered_map<int32_t, double>>> ppr_scores_;
113+
std::vector<std::vector<std::unordered_map<int32_t, double>>> residuals_;
114+
std::vector<std::vector<std::unordered_set<int32_t>>> queue_;
115+
std::vector<std::vector<std::unordered_set<int32_t>>> queued_nodes_;
116+
117+
// -------------------------------------------------------------------------
118+
// Neighbor cache
119+
// -------------------------------------------------------------------------
120+
std::unordered_map<uint64_t, std::vector<int32_t>> neighbor_cache_;
121+
};
File renamed without changes.

0 commit comments

Comments
 (0)