|
| 1 | +#include "ppr_forward_push.h" |
| 2 | + |
| 3 | +PPRForwardPushState::PPRForwardPushState( |
| 4 | + torch::Tensor seed_nodes, int32_t seed_node_type_id, double alpha, |
| 5 | + double requeue_threshold_factor, |
| 6 | + std::vector<std::vector<int32_t>> node_type_to_edge_type_ids, |
| 7 | + std::vector<int32_t> edge_type_to_dst_ntype_id, std::vector<torch::Tensor> degree_tensors) |
| 8 | + : alpha_(alpha), |
| 9 | + one_minus_alpha_(1.0 - alpha), |
| 10 | + requeue_threshold_factor_(requeue_threshold_factor), |
| 11 | + // std::move transfers ownership of each vector into the member variable |
| 12 | + // without copying its contents — equivalent to Python's list hand-off |
| 13 | + // when you no longer need the original. |
| 14 | + node_type_to_edge_type_ids_(std::move(node_type_to_edge_type_ids)), |
| 15 | + edge_type_to_dst_ntype_id_(std::move(edge_type_to_dst_ntype_id)), |
| 16 | + degree_tensors_(std::move(degree_tensors)) { |
| 17 | + TORCH_CHECK(seed_nodes.dim() == 1, "seed_nodes must be 1D"); |
| 18 | + batch_size_ = static_cast<int32_t>(seed_nodes.size(0)); |
| 19 | + num_node_types_ = static_cast<int32_t>(node_type_to_edge_type_ids_.size()); |
| 20 | + |
| 21 | + // Allocate per-seed, per-node-type tables. |
| 22 | + // .assign(n, val) fills a vector with n copies of val — like [val] * n in Python. |
| 23 | + ppr_scores_.assign(batch_size_, |
| 24 | + std::vector<std::unordered_map<int32_t, double>>(num_node_types_)); |
| 25 | + residuals_.assign(batch_size_, |
| 26 | + std::vector<std::unordered_map<int32_t, double>>(num_node_types_)); |
| 27 | + queue_.assign(batch_size_, std::vector<std::unordered_set<int32_t>>(num_node_types_)); |
| 28 | + queued_nodes_.assign(batch_size_, |
| 29 | + std::vector<std::unordered_set<int32_t>>(num_node_types_)); |
| 30 | + |
| 31 | + // accessor<dtype, ndim>() returns a typed view into the tensor's data that |
| 32 | + // supports [i] indexing with bounds checking in debug builds. |
| 33 | + auto acc = seed_nodes.accessor<int64_t, 1>(); |
| 34 | + num_nodes_in_queue_ = batch_size_; |
| 35 | + for (int32_t i = 0; i < batch_size_; ++i) { |
| 36 | + int32_t seed = static_cast<int32_t>(acc[i]); |
| 37 | + // PPR initialisation: each seed starts with residual = alpha (the |
| 38 | + // restart probability). The first push will move alpha into ppr_score |
| 39 | + // and distribute (1-alpha)*alpha to the seed's neighbors. |
| 40 | + residuals_[i][seed_node_type_id][seed] = alpha_; |
| 41 | + queue_[i][seed_node_type_id].insert(seed); |
| 42 | + } |
| 43 | +} |
| 44 | + |
| 45 | +std::optional<std::unordered_map<int32_t, torch::Tensor>> PPRForwardPushState::drain_queue() { |
| 46 | + if (num_nodes_in_queue_ == 0) { |
| 47 | + return std::nullopt; |
| 48 | + } |
| 49 | + |
| 50 | + // Reset the snapshot from the previous iteration. |
| 51 | + for (int32_t s = 0; s < batch_size_; ++s) |
| 52 | + for (auto& qs : queued_nodes_[s]) |
| 53 | + qs.clear(); |
| 54 | + |
| 55 | + // nodes_to_lookup[eid] = set of node IDs that need a neighbor fetch for |
| 56 | + // edge type eid this round. Using a set deduplicates nodes that appear |
| 57 | + // in multiple seeds' queues: we only fetch each (node, etype) pair once. |
| 58 | + std::unordered_map<int32_t, std::unordered_set<int32_t>> nodes_to_lookup; |
| 59 | + |
| 60 | + for (int32_t s = 0; s < batch_size_; ++s) { |
| 61 | + for (int32_t nt = 0; nt < num_node_types_; ++nt) { |
| 62 | + if (queue_[s][nt].empty()) |
| 63 | + continue; |
| 64 | + |
| 65 | + // Move the live queue into the snapshot (no data copy — O(1)). |
| 66 | + queued_nodes_[s][nt] = std::move(queue_[s][nt]); |
| 67 | + queue_[s][nt].clear(); |
| 68 | + num_nodes_in_queue_ -= static_cast<int32_t>(queued_nodes_[s][nt].size()); |
| 69 | + |
| 70 | + for (int32_t node_id : queued_nodes_[s][nt]) { |
| 71 | + for (int32_t eid : node_type_to_edge_type_ids_[nt]) { |
| 72 | + if (neighbor_cache_.find(pack_key(node_id, eid)) == neighbor_cache_.end()) { |
| 73 | + nodes_to_lookup[eid].insert(node_id); |
| 74 | + } |
| 75 | + } |
| 76 | + } |
| 77 | + } |
| 78 | + } |
| 79 | + |
| 80 | + std::unordered_map<int32_t, torch::Tensor> result; |
| 81 | + for (auto& [eid, node_set] : nodes_to_lookup) { |
| 82 | + std::vector<int64_t> ids(node_set.begin(), node_set.end()); |
| 83 | + result[eid] = torch::tensor(ids, torch::kLong); |
| 84 | + } |
| 85 | + return result; |
| 86 | +} |
| 87 | + |
// Applies one forward-push iteration for every node snapshotted by the last
// drain_queue() call: absorb each node's residual into its PPR score, then
// spread (1 - alpha) * residual equally across its neighbors.
//
// @param fetched_by_etype_id Maps edge type id -> (node_ids, flat_neighbors,
//        neighbor_counts) tensors returned by the external fetcher for the
//        ids drain_queue() handed out. flat_neighbors is the concatenation of
//        every node's neighbor list; neighbor_counts[i] is the length of
//        node_ids[i]'s slice. NOTE(review): all three are read via
//        accessor<int64_t, 1>, so they are assumed 1-D int64 — confirm with
//        the caller.
void PPRForwardPushState::push_residuals(
    const std::unordered_map<int32_t, std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>&
        fetched_by_etype_id) {
  // Step 1: Unpack the input map into a C++ map keyed by pack_key(node_id, etype_id)
  // for fast lookup during the residual-push loop below.
  std::unordered_map<uint64_t, std::vector<int32_t>> fetched;
  for (const auto& [eid, tup] : fetched_by_etype_id) {
    const auto& node_ids_t = std::get<0>(tup);
    const auto& flat_nbrs_t = std::get<1>(tup);
    const auto& counts_t = std::get<2>(tup);

    // accessor<int64_t, 1>() gives a bounds-checked, typed 1-D view into
    // each tensor's data — equivalent to iterating over a NumPy array.
    auto node_acc = node_ids_t.accessor<int64_t, 1>();
    auto nbr_acc = flat_nbrs_t.accessor<int64_t, 1>();
    auto cnt_acc = counts_t.accessor<int64_t, 1>();

    // Walk the flat neighbor list, slicing out each node's neighbors using
    // the running offset into the concatenated flat buffer.
    int64_t offset = 0;
    for (int64_t i = 0; i < node_ids_t.size(0); ++i) {
      int32_t nid = static_cast<int32_t>(node_acc[i]);
      int64_t count = cnt_acc[i];
      std::vector<int32_t> nbrs(count);
      for (int64_t j = 0; j < count; ++j)
        nbrs[j] = static_cast<int32_t>(nbr_acc[offset + j]);
      fetched[pack_key(nid, eid)] = std::move(nbrs);
      offset += count;
    }
  }

  // Step 2: For every node that was in the queue (captured in queued_nodes_
  // by drain_queue()), apply one PPR push step:
  //   a. Absorb residual into the PPR score.
  //   b. Distribute (1-alpha) * residual equally to each neighbor.
  //   c. Enqueue any neighbor whose residual now exceeds the requeue threshold.
  for (int32_t s = 0; s < batch_size_; ++s) {
    for (int32_t nt = 0; nt < num_node_types_; ++nt) {
      if (queued_nodes_[s][nt].empty())
        continue;

      for (int32_t src : queued_nodes_[s][nt]) {
        auto& src_res = residuals_[s][nt];
        auto it = src_res.find(src);
        // A queued node may have residual 0 if an earlier push this batch
        // already absorbed it; treat missing entries the same way.
        double res = (it != src_res.end()) ? it->second : 0.0;

        // a. Absorb: move residual into the PPR score. The entry is zeroed
        // (not erased) so accumulation below can use operator[] freely.
        ppr_scores_[s][nt][src] += res;
        src_res[src] = 0.0;

        int32_t total_deg = get_total_degree(src, nt);
        // Destination-only nodes absorb residual but do not push further.
        if (total_deg == 0)
          continue;

        // b. Distribute: each neighbor receives an equal share.
        double res_per_nbr = one_minus_alpha_ * res / static_cast<double>(total_deg);

        for (int32_t eid : node_type_to_edge_type_ids_[nt]) {
          // Invariant: fetched and neighbor_cache_ are mutually exclusive for
          // any given (node, etype) key within one iteration. drain_queue()
          // only requests a fetch for nodes absent from neighbor_cache_, so a
          // key is in at most one of the two.
          const std::vector<int32_t>* nbr_list = nullptr;
          auto fi = fetched.find(pack_key(src, eid));
          if (fi != fetched.end()) {
            nbr_list = &fi->second;
          } else {
            auto ci = neighbor_cache_.find(pack_key(src, eid));
            if (ci != neighbor_cache_.end())
              nbr_list = &ci->second;
          }
          // No neighbors along this edge type: nothing to distribute.
          if (!nbr_list || nbr_list->empty())
            continue;

          int32_t dst_nt = edge_type_to_dst_ntype_id_[eid];

          // c. Accumulate residual for each neighbor and re-enqueue if threshold
          // exceeded.
          for (int32_t nbr : *nbr_list) {
            residuals_[s][dst_nt][nbr] += res_per_nbr;

            // Degree-scaled threshold: high-degree nodes need more residual
            // mass before another push is worthwhile.
            double threshold = requeue_threshold_factor_ *
                               static_cast<double>(get_total_degree(nbr, dst_nt));

            if (queue_[s][dst_nt].find(nbr) == queue_[s][dst_nt].end() &&
                residuals_[s][dst_nt][nbr] >= threshold) {
              queue_[s][dst_nt].insert(nbr);
              // Keep the global counter in sync with the per-seed queues;
              // drain_queue() uses it to detect convergence.
              ++num_nodes_in_queue_;

              // Promote neighbor lists to the persistent cache: this node will
              // be processed next iteration, so caching avoids a re-fetch.
              for (int32_t peid : node_type_to_edge_type_ids_[dst_nt]) {
                uint64_t pk = pack_key(nbr, peid);
                if (neighbor_cache_.find(pk) == neighbor_cache_.end()) {
                  auto pfi = fetched.find(pk);
                  if (pfi != fetched.end())
                    neighbor_cache_[pk] = pfi->second;
                }
              }
            }
          }
        }
      }
    }
  }
}
| 195 | + |
| 196 | +std::unordered_map<int32_t, std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>> |
| 197 | +PPRForwardPushState::extract_top_k(int32_t max_ppr_nodes) { |
| 198 | + std::unordered_set<int32_t> active; |
| 199 | + for (int32_t s = 0; s < batch_size_; ++s) |
| 200 | + for (int32_t nt = 0; nt < num_node_types_; ++nt) |
| 201 | + if (!ppr_scores_[s][nt].empty()) |
| 202 | + active.insert(nt); |
| 203 | + |
| 204 | + std::unordered_map<int32_t, std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>> result; |
| 205 | + for (int32_t nt : active) { |
| 206 | + std::vector<int64_t> flat_ids; |
| 207 | + std::vector<float> flat_weights; |
| 208 | + std::vector<int64_t> valid_counts; |
| 209 | + |
| 210 | + for (int32_t s = 0; s < batch_size_; ++s) { |
| 211 | + const auto& scores = ppr_scores_[s][nt]; |
| 212 | + int32_t k = std::min(max_ppr_nodes, static_cast<int32_t>(scores.size())); |
| 213 | + if (k > 0) { |
| 214 | + std::vector<std::pair<int32_t, double>> items(scores.begin(), scores.end()); |
| 215 | + std::partial_sort( |
| 216 | + items.begin(), items.begin() + k, items.end(), |
| 217 | + [](const auto& a, const auto& b) { return a.second > b.second; }); |
| 218 | + |
| 219 | + for (int32_t i = 0; i < k; ++i) { |
| 220 | + flat_ids.push_back(static_cast<int64_t>(items[i].first)); |
| 221 | + // Cast to float32 for output; internal scores stay double to |
| 222 | + // avoid accumulated rounding errors in the push loop. |
| 223 | + flat_weights.push_back(static_cast<float>(items[i].second)); |
| 224 | + } |
| 225 | + } |
| 226 | + valid_counts.push_back(static_cast<int64_t>(k)); |
| 227 | + } |
| 228 | + |
| 229 | + result[nt] = {torch::tensor(flat_ids, torch::kLong), |
| 230 | + torch::tensor(flat_weights, torch::kFloat), |
| 231 | + torch::tensor(valid_counts, torch::kLong)}; |
| 232 | + } |
| 233 | + return result; |
| 234 | +} |
| 235 | + |
| 236 | +int32_t PPRForwardPushState::get_total_degree(int32_t node_id, int32_t ntype_id) const { |
| 237 | + if (ntype_id >= static_cast<int32_t>(degree_tensors_.size())) |
| 238 | + return 0; |
| 239 | + const auto& t = degree_tensors_[ntype_id]; |
| 240 | + if (t.numel() == 0) |
| 241 | + return 0; |
| 242 | + TORCH_CHECK(node_id < static_cast<int32_t>(t.size(0)), "Node ID ", node_id, |
| 243 | + " out of range for degree tensor of ntype_id ", ntype_id, " (size=", t.size(0), |
| 244 | + "). This indicates corrupted graph data or a sampler bug."); |
| 245 | + // data_ptr<int32_t>() returns a raw C pointer to the tensor's int32 data buffer. |
| 246 | + return t.data_ptr<int32_t>()[node_id]; |
| 247 | +} |
0 commit comments