Skip to content
229 changes: 171 additions & 58 deletions heracles/dices/jackknife.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,85 +63,198 @@ def jackknife_cls(
returns:
cls (dict): Dictionary of data Cls
"""
if nd < 0 or nd > 2:
raise ValueError("number of deletions must be 0, 1, or 2")
"""Alms calculated and save then cls calculated from saved alms."""

if progress is None:
progress = NoProgress()

# calc save alms if don't exist
compute_jk_alms(
data_maps,
vis_maps,
jk_map,
fields,
dir=dir,
progress=progress,
)

# calculate cls from saved alms
return compute_jk_cls_from_alms(
jk_map,
fields,
mask_correction=mask_correction,
unmixed=unmixed,
nd=nd,
dir=dir,
progress=progress,
)

def compute_jk_alms(
data_maps,
vis_maps,
jk_map,
fields,
dir="./dices",
progress=None,
):
"""Compute and save ALMs each JK region."""

if progress is None:
progress = NoProgress()

cls = {}
njk = len(np.unique(jk_map)[np.unique(jk_map) != 0])
os.makedirs(dir, exist_ok=True)

all_regions = list(combinations(range(1, njk + 1), nd))
total = (njk + 1) + len(all_regions)
njk = len(np.unique(jk_map)[np.unique(jk_map) != 0])

total = njk + 1
current = 0
progress.update(current, total)

# Compute ALMs
for k in range(0, njk + 1):
data_path = os.path.join(dir, f"data_alms_{k}.fits")
vis_path = os.path.join(dir, f"vis_alms_{k}.fits")
with progress.task(f"ALMs {k}"):
if not (os.path.exists(data_path) and os.path.exists(vis_path)):
if k == 0:
data_alms_k = transform(fields, data_maps)
vis_alms_k = transform(fields, vis_maps)
else:
data_alms_k = transform(
fields, _get_region_maps(data_maps, jk_map, k)
)
vis_alms_k = transform(
fields, _get_region_maps(vis_maps, jk_map, k)
)
write_alms(data_path, data_alms_k, clobber=True)
write_alms(vis_path, vis_alms_k, clobber=True)
_compute_single_jk_alm(
k,
data_maps,
vis_maps,
jk_map,
fields,
dir,
)

current += 1
progress.update(current, total)

data_alms_full = read_alms(os.path.join(dir, "data_alms_0.fits"))
vis_alms_full = read_alms(os.path.join(dir, "vis_alms_0.fits"))
mls0 = angular_power_spectra(vis_alms_full)

# Compute Cls
def _compute_single_jk_alm(
k,
data_maps,
vis_maps,
jk_map,
fields,
dir="./dices",
):
data_path = os.path.join(dir, f"data_alms_{k}.fits")
vis_path = os.path.join(dir, f"vis_alms_{k}.fits")

if os.path.exists(data_path) and os.path.exists(vis_path):
return k, False # nothing done

if k == 0:
data_alms_k = transform(fields, data_maps)
vis_alms_k = transform(fields, vis_maps)
else:
data_alms_k = transform(
fields, _get_region_maps(data_maps, jk_map, k)
)
vis_alms_k = transform(
fields, _get_region_maps(vis_maps, jk_map, k)
)

write_alms(data_path, data_alms_k, clobber=True)
write_alms(vis_path, vis_alms_k, clobber=True)

return k, True # processed


def compute_jk_cls_from_alms(
jk_map,
fields,
mask_correction="Fast",
unmixed=False,
nd=1,
dir="./dices",
progress=None,
):
if nd == 0:
data_alms_full = read_alms(os.path.join(dir, "data_alms_0.fits"))
cls0 = angular_power_spectra(data_alms_full)
return {(): cls0}

if nd < 1 or nd > 2:
raise ValueError("number of deletions must be 1 or 2")

if progress is None:
progress = NoProgress()

cls = {}

njk = len(np.unique(jk_map)[np.unique(jk_map) != 0])
all_regions = list(combinations(range(1, njk + 1), nd))

total = len(all_regions)
current = 0
progress.update(current, total)

for regions in all_regions:
regions_tag = "_".join(map(str, regions))
cls_path = os.path.join(dir, f"cls_{regions_tag}_unmixed_{unmixed}.fits")
with progress.task(f"Cls {regions}"):
if os.path.exists(cls_path):
cls[regions] = read(cls_path)
else:
alms_jk = _subtract_alms(
data_alms_full,
_accumulate_alms(
os.path.join(dir, f"data_alms_{r}.fits") for r in regions
),
)
_cls = angular_power_spectra(alms_jk)
_cls = correct_bias(_cls, jk_map, fields, *regions)
if mask_correction == "Full":
vis_alms_jk = _subtract_alms(
vis_alms_full,
_accumulate_alms(
os.path.join(dir, f"vis_alms_{r}.fits") for r in regions
),
)
_cls_mm = angular_power_spectra(vis_alms_jk)
_cls = correct_footprint_naturalspice(
_cls, _cls_mm, mls0, fields, unmixed=unmixed
)
elif mask_correction == "Fast":
_cls = correct_footprint_fsky(
_cls, jk_map, *regions, unmixed=unmixed
)
else:
raise ValueError("mask_correction must be 'Fast' or 'Full'")
write(cls_path, _cls, clobber=True)
cls[regions] = _cls

cls[regions] = _compute_single_jk_cls(
regions,
jk_map,
fields,
mask_correction,
unmixed,
dir,
)

current += 1
progress.update(current, total)

return cls

def _compute_single_jk_cls(
regions,
jk_map,
fields,
mask_correction="Fast",
unmixed=False,
dir="./dices",
):
"""Compute Cls for a single jackknife region combination."""

regions_tag = "_".join(map(str, regions))
cls_path = os.path.join(dir, f"cls_{regions_tag}_unmixed_{unmixed}.fits")

if os.path.exists(cls_path):
return read(cls_path)

data_alms_full = read_alms(os.path.join(dir, "data_alms_0.fits"))
vis_alms_full = read_alms(os.path.join(dir, "vis_alms_0.fits"))
mls0 = angular_power_spectra(vis_alms_full)

alms_jk = _subtract_alms(
data_alms_full,
_accumulate_alms(
os.path.join(dir, f"data_alms_{r}.fits") for r in regions
),
)

_cls = angular_power_spectra(alms_jk)
_cls = correct_bias(_cls, jk_map, fields, *regions)

if mask_correction == "Full":
vis_alms_jk = _subtract_alms(
vis_alms_full,
_accumulate_alms(
os.path.join(dir, f"vis_alms_{r}.fits") for r in regions
),
)
_cls_mm = angular_power_spectra(vis_alms_jk)
_cls = correct_footprint_naturalspice(
_cls, _cls_mm, mls0, fields, unmixed=unmixed
)

elif mask_correction == "Fast":
_cls = correct_footprint_fsky(
_cls, jk_map, *regions, unmixed=unmixed
)

else:
raise ValueError("mask_correction must be 'Fast' or 'Full'")

write(cls_path, _cls, clobber=True)

return _cls

def _get_region_maps(maps, jk_map, jk):
"""
Expand Down
Loading