Skip to content

Commit f220d36

Browse files
committed
Implement basic .where method
1 parent 9aaeb2c commit f220d36

2 files changed

Lines changed: 102 additions & 119 deletions

File tree

mdio/dataset.h

Lines changed: 48 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -984,52 +984,21 @@ class Dataset {
984984
return absl::OkStatus();
985985
}
986986

987-
// nlohmann::json transform_to_json(const tensorstore::IndexTransform<>& transform) {
988-
// nlohmann::json j;
989-
// j["input_inclusive_min"] = std::vector<Index>(transform.input_origin().begin(), transform.input_origin().end());
990-
// j["input_shape"] = std::vector<Index>(transform.input_shape().begin(), transform.input_shape().end());
991-
992-
// // Serialize each output‐index map:
993-
// j["output"] = nlohmann::json::array();
994-
// for (DimensionIndex d = 0; d < transform.output_rank(); ++d) {
995-
// const auto& map = transform.output_index_map(d);
996-
// nlohmann::json mo;
997-
// switch (map.method()) {
998-
// case tensorstore::OutputIndexMethod::constant:
999-
// mo["method"] = "constant";
1000-
// mo["offset"] = map.offset();
1001-
// break;
1002-
1003-
// case tensorstore::OutputIndexMethod::single_input_dimension:
1004-
// mo["method"] = "single_input_dimension";
1005-
// mo["input_dimension"] = map.input_dimension(); // no *
1006-
// mo["stride"] = map.stride(); // no *
1007-
// mo["offset"] = map.offset();
1008-
// break;
1009-
1010-
// case tensorstore::OutputIndexMethod::array: {
1011-
// mo["method"] = "array";
1012-
// mo["offset"] = map.offset();
1013-
// mo["stride"] = map.stride();
1014-
// auto arr = map.index_array(); // ArrayView<const Index,1>
1015-
// std::size_t n = static_cast<std::size_t>(arr.shape()[0]);
1016-
// const auto* ptr = arr.data();
1017-
// std::vector<Index> vec(ptr, ptr + n);
1018-
// mo["index_array"] = std::move(vec);
1019-
// break;
1020-
// }
1021-
1022-
// default:
1023-
// // e.g. index_range, if you need it
1024-
// break;
1025-
// }
1026-
// j["output"].push_back(std::move(mo));
1027-
// }
1028-
// return j;
1029-
// }
987+
/**
 * @brief Advances a multi-dimensional position by one step, with the last
 * dimension varying fastest.
 *
 * Walks the dimensions from the last index to the first. The first dimension
 * whose `inclusive_min + 1` is still below the matching
 * `interval[d].exclusive_max` is incremented and the function returns. Every
 * dimension visited before that one is reset to `interval[d].inclusive_min`.
 *
 * Only the `inclusive_min` field of each `positionInterval` entry is modified.
 *
 * @tparam T Element type used to select `Variable<T>::Interval`.
 * @param positionInterval Current position, updated in place.
 * @param interval Bounds for each dimension. It is indexed with every
 * `d < positionInterval.size()` and no size check is made, so it needs at
 * least as many entries as `positionInterval`.
 */
template <typename T>
988+
void _current_position_increment(std::vector<typename Variable<T>::Interval>& positionInterval, const std::vector<typename Variable<T>::Interval>& interval) {
989+
for (std::size_t d = positionInterval.size(); d-- > 0; ) {
990+
if (positionInterval[d].inclusive_min + 1 < interval[d].exclusive_max) {
991+
++positionInterval[d].inclusive_min;
992+
return;
993+
}
994+
positionInterval[d].inclusive_min = interval[d].inclusive_min;
995+
}
996+
997+
// Control arrives here when no dimension could be advanced; by then every
// entry has been reset to its interval minimum.
// NOTE(review): the original comment read "Should be unreachable.", but a
// call made while already at the final position appears to fall through to
// this point -- confirm whether callers rely on that wrap-around.
998+
}
1030999

10311000
template <typename T>
1032-
mdio::Result<Dataset> where(const mdio::ValueDescriptor<T>& coord_desc) const {
1001+
Result<Dataset> where(const ValueDescriptor<T>& coord_desc) {
10331002
// 1) Lookup the coordinate Variable<T>
10341003
auto varRes =
10351004
variables.get<T>(std::string(coord_desc.label.label()));
@@ -1038,82 +1007,44 @@ class Dataset {
10381007
}
10391008
auto var = varRes.value();
10401009

1041-
// auto store = var.get_store();
1042-
1043-
// // MDIO_ASSIGN_OR_RETURN(auto transform, tensorstore::GetIndexTransform(store));
1044-
// auto transform_res = ApplyIndexTransform(
1045-
// [](tensorstore::IndexTransform<> t) { return t; },
1046-
// store);
1047-
// if (!transform_res.status().ok()) {
1048-
// return transform_res.status();
1049-
// }
1050-
// auto transform = transform_res.value();
1051-
1052-
// // std::cout << "transform: " << transform << std::endl;
1053-
// auto json = transform_to_json(transform);
1054-
// std::cout << json.dump(4) << std::endl;
1055-
1056-
MDIO_ASSIGN_OR_RETURN(auto spec, var.get_spec());
1010+
MDIO_ASSIGN_OR_RETURN(auto interval, var.get_intervals());
1011+
std::vector<typename Variable<T>::Interval> currentPos; // Hacky, we will use this to track our current position in the dataset.
1012+
for (const auto& i : interval) {
1013+
currentPos.push_back(i);
1014+
}
10571015

1058-
std::cout << spec.dump(4) << std::endl;
1016+
// 2) Read its data
1017+
auto varFut = var.Read();
1018+
if (!varFut.status().ok()) {
1019+
return varFut.status();
1020+
}
1021+
auto varDat = varFut.value();
1022+
1023+
// **Use the flattened data pointer + offset for N‑D arrays:**
1024+
auto* data_ptr = varDat.get_data_accessor().data();
1025+
Index offset = varDat.get_flattened_offset();
1026+
Index nSamples = var.num_samples();
1027+
1028+
// 3) Collect all flat indices where coord == target value
1029+
std::vector<Index> indices;
1030+
std::vector<RangeDescriptor<Index>> elementwiseSlices;
1031+
for (Index idx = offset; idx < offset + nSamples; ++idx) {
1032+
if (data_ptr[idx] == coord_desc.value) {
1033+
indices.push_back(idx);
1034+
for (const auto& pos : currentPos) {
1035+
elementwiseSlices.emplace_back(RangeDescriptor<Index>({coord_desc.label.label(), pos.inclusive_min, pos.inclusive_min+1, 1}));
1036+
}
1037+
}
1038+
this->_current_position_increment<T>(currentPos, interval);
1039+
}
10591040

1060-
return absl::UnimplementedError("where() is not yet fully implemented.");
1041+
if (indices.empty()) {
1042+
return absl::NotFoundError(
1043+
"where(): no entries match coordinate '" +
1044+
std::string(coord_desc.label.label()) + "'");
1045+
}
10611046

1062-
// 2) Read its data
1063-
// auto varFut = var.Read();
1064-
// if (!varFut.status().ok()) {
1065-
// return varFut.status();
1066-
// }
1067-
// auto varDat = varFut.value();
1068-
1069-
// // **Use the flattened data pointer + offset for N‑D arrays:**
1070-
// auto* data_ptr = varDat.get_data_accessor().data();
1071-
// Index offset = varDat.get_flattened_offset();
1072-
// Index nSamples = var.num_samples();
1073-
1074-
// // 3) Collect all flat indices where coord == target value
1075-
// std::vector<Index> indices;
1076-
// for (Index idx = offset; idx < offset + nSamples; ++idx) {
1077-
// if (data_ptr[idx] == coord_desc.value) {
1078-
// indices.push_back(idx);
1079-
// }
1080-
// }
1081-
// if (indices.empty()) {
1082-
// return absl::NotFoundError(
1083-
// "where(): no entries match coordinate '" +
1084-
// std::string(coord_desc.label.label()) + "'");
1085-
// }
1086-
1087-
// // 4) Collapse into contiguous [start,stop) runs
1088-
// std::vector<mdio::RangeDescriptor<Index>> runs;
1089-
// runs.reserve(4);
1090-
// Index run_start = indices[0];
1091-
// Index prev = indices[0];
1092-
1093-
// for (size_t j = 1; j < indices.size(); ++j) {
1094-
// Index cur = indices[j];
1095-
// if (cur == prev + 1) {
1096-
// prev = cur;
1097-
// } else {
1098-
// runs.push_back({
1099-
// /* label = */ coord_desc.label,
1100-
// /* start = */ run_start,
1101-
// /* stop = */ prev + 1,
1102-
// /* step = */ 1
1103-
// });
1104-
// run_start = prev = cur;
1105-
// }
1106-
// }
1107-
// // final run
1108-
// runs.push_back({
1109-
// coord_desc.label,
1110-
// run_start,
1111-
// prev + 1,
1112-
// 1
1113-
// });
1114-
1115-
// // 5) Apply via your existing isel(...)
1116-
// return this->isel(runs);
1047+
return isel(elementwiseSlices);
11171048
}
11181049

11191050
/**

mdio/dataset_test.cc

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -656,8 +656,11 @@ TEST(Dataset, where) {
656656
auto ds = dsRes.value();
657657

658658
// mdio::ListDescriptor<mdio::Index> sliceIndices = {"inline", {1,3,7}};
659-
mdio::RangeDescriptor<mdio::Index> sliceIndices = {"inline", 1, 7, 2};
660-
auto sliceRes = ds.isel(sliceIndices);
659+
// mdio::RangeDescriptor<mdio::Index> sliceIndices = {"inline", 1, 7, 2};
660+
mdio::RangeDescriptor<mdio::Index> one = {"inline", 1, 2, 1};
661+
mdio::RangeDescriptor<mdio::Index> two = {"inline", 3, 4, 1};
662+
mdio::RangeDescriptor<mdio::Index> three = {"inline", 7, 8, 1};
663+
auto sliceRes = ds.isel(one, two, three);
661664
ASSERT_TRUE(sliceRes.status().ok()) << sliceRes.status();
662665

663666

@@ -667,14 +670,63 @@ TEST(Dataset, where) {
667670

668671
std::cout << "=================Full inline spec=================" << std::endl;
669672
auto sliceRes1 = ds.where(ilValue);
673+
std::cout << ds.variables.at("data").value().get_spec().value()["transform"].dump(4) << std::endl;
674+
std::cout << ds.variables.at("inline").value() << std::endl;
675+
std::cout << ds.variables.at("data").value() << std::endl;
670676
std::cout << "=================Picked inline spec=================" << std::endl;
671677
auto sliceRes2 = sliceRes.value().where(ilValue);
678+
std::cout << sliceRes.value().variables.at("data").value().get_spec().value()["transform"].dump(4) << std::endl;
679+
auto il = sliceRes.value().variables.get<mdio::dtypes::int32_t>("inline").value();
680+
std::cout << il << std::endl;
681+
std::cout << sliceRes.value().variables.at("data").value() << std::endl;
682+
auto ilr = il.Read().value();
683+
for (auto i=0; i < il.num_samples(); i++) {
684+
std::cout << "[" << i << "]: " << ilr.get_data_accessor().data()[i+ilr.get_flattened_offset()] << std::endl;
685+
}
672686
std::cout << "=================Picked inline spec=================" << std::endl;
673687
// ASSERT_TRUE(sliceRes.ok()) << sliceRes.status();
674688
ASSERT_FALSE(sliceRes1.status().ok());
675689
ASSERT_FALSE(sliceRes2.status().ok()) << sliceRes2.status();
676690
}
677691

692+
// Exercises where() on datasets that were already sliced with isel(): first
// on a strided "inline" slice, then on a second slice taken from that result.
// NOTE(review): nothing is asserted about sliceRes2 or sliceRes4, so this test
// only shows that the calls run; the expected outcome still needs to be pinned.
TEST(Dataset, where2) {
693+
std::string path = "zarrs/selTester.mdio";
694+
auto dsRes = makePopulated(path);
695+
ASSERT_TRUE(dsRes.ok()) << dsRes.status();
696+
auto ds = dsRes.value();
697+
// Value descriptor handed to where(): label "inline", value 1.
mdio::ValueDescriptor<mdio::dtypes::int32_t> ilValue = {"inline", 1};
698+
699+
// First slice of "inline": start 0, stop 10, step 2.
mdio::RangeDescriptor<mdio::Index> ilRange = {"inline", 0, 10, 2};
700+
auto sliceRes = ds.isel(ilRange);
701+
ASSERT_TRUE(sliceRes.ok()) << sliceRes.status();
702+
auto slicedDs = sliceRes.value();
703+
std::cout << slicedDs.variables.at("data").value() << std::endl;
704+
// Result status is not checked.
auto sliceRes2 = slicedDs.where(ilValue);
705+
// Reuse the descriptor for a second slice of the sliced dataset: start 2,
// stop 4, step 1.
ilRange.start = 2;
706+
ilRange.stop = 4;
707+
ilRange.step = 1;
708+
auto sliceRes3 = slicedDs.isel(ilRange);
709+
// NOTE(review): sliceRes3.value() is used here and below without first
// checking that sliceRes3 is ok.
std::cout << sliceRes3.value().variables.at("data").value() << std::endl;
710+
auto sliceRes4 = sliceRes3.value().where(ilValue);
711+
// std::cout << "===================Sliced Data=================" << std::endl;
712+
// std::cout << "===================Whered Data=================" << std::endl;
713+
// std::cout << sliceRes2.value().variables.at("data").value() << std::endl;
714+
}
715+
716+
// Calls where() on the unsliced dataset with label "inline" and value 3,
// requires the call to succeed, and prints the resulting dataset.
// NOTE(review): only the status is asserted; the contents of the returned
// dataset are printed but not verified.
TEST(Dataset, where3) {
717+
std::string path = "zarrs/selTester.mdio";
718+
auto dsRes = makePopulated(path);
719+
ASSERT_TRUE(dsRes.ok()) << dsRes.status();
720+
auto ds = dsRes.value();
721+
722+
mdio::ValueDescriptor<mdio::dtypes::int32_t> ilValue = {"inline", 3};
723+
auto sliceRes = ds.where(ilValue);
724+
ASSERT_TRUE(sliceRes.ok()) << sliceRes.status();
725+
auto slicedDs = sliceRes.value();
726+
std::cout << slicedDs << std::endl;
727+
728+
}
729+
678730
TEST(Dataset, selValue) {
679731
std::string path = "zarrs/selTester.mdio";
680732
auto dsRes = makePopulated(path);

0 commit comments

Comments
 (0)