Skip to content

Commit 08903d4

Browse files
committed
Add LD pruning support for biallelic SNP calls
1 parent 772c135 commit 08903d4

4 files changed

Lines changed: 306 additions & 0 deletions

File tree

malariagen_data/anoph/ld.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
from typing import Optional
2+
3+
import allel # type: ignore
4+
import xarray as xr
5+
from numpydoc_decorator import doc # type: ignore
6+
7+
from ..util import _check_types, _dask_compress_dataset
8+
from . import base_params, ld_params, pca_params
9+
from .snp_data import AnophelesSnpData
10+
11+
12+
class AnophelesLdAnalysis(
    AnophelesSnpData,
):
    """Adds LD pruning of biallelic SNP calls on top of `AnophelesSnpData`."""

    def __init__(
        self,
        **kwargs,
    ):
        # N.B., this class is designed to work cooperatively, and
        # so it's important that any remaining parameters are passed
        # to the superclass constructor.
        super().__init__(**kwargs)

    @_check_types
    @doc(
        summary="""
            Access biallelic SNP calls after LD pruning.
        """,
        extended_summary="""
            This function obtains biallelic SNP calls, then performs LD pruning
            using scikit-allel's `locate_unlinked` function. The resulting dataset
            can be used as input to ADMIXTURE workflows or exported to PLINK format.

            Note that `n_snps` is required to control memory usage. Without
            pre-thinning, LD pruning could attempt to materialise millions of
            variants and run out of memory.

            A `ValueError` is raised if `ld_window_size` or `ld_window_step` is
            less than 1.
        """,
        returns="""
            A dataset of LD-pruned biallelic SNP calls with the same structure as
            the output of `biallelic_snp_calls`.
        """,
    )
    def biallelic_snp_calls_ld_pruned(
        self,
        region: base_params.regions,
        n_snps: base_params.n_snps,
        ld_window_size: ld_params.ld_window_size = ld_params.ld_window_size_default,
        ld_window_step: ld_params.ld_window_step = ld_params.ld_window_step_default,
        ld_threshold: ld_params.ld_threshold = ld_params.ld_threshold_default,
        thin_offset: base_params.thin_offset = 0,
        sample_sets: Optional[base_params.sample_sets] = None,
        sample_query: Optional[base_params.sample_query] = None,
        sample_query_options: Optional[base_params.sample_query_options] = None,
        sample_indices: Optional[base_params.sample_indices] = None,
        site_mask: Optional[base_params.site_mask] = base_params.DEFAULT,
        min_minor_ac: Optional[
            base_params.min_minor_ac
        ] = pca_params.min_minor_ac_default,
        max_missing_an: Optional[
            base_params.max_missing_an
        ] = pca_params.max_missing_an_default,
        random_seed: base_params.random_seed = 42,
        inline_array: base_params.inline_array = base_params.inline_array_default,
        chunks: base_params.chunks = base_params.native_chunks,
    ) -> xr.Dataset:
        # Check that either sample_query xor sample_indices are provided.
        base_params._validate_sample_selection_params(
            sample_query=sample_query, sample_indices=sample_indices
        )

        # Fail fast on window settings that cannot describe a sliding
        # window, before any genotype data are loaded.
        if ld_window_size < 1:
            raise ValueError(
                f"ld_window_size must be a positive integer, got {ld_window_size!r}"
            )
        if ld_window_step < 1:
            raise ValueError(
                f"ld_window_step must be a positive integer, got {ld_window_step!r}"
            )

        # Obtain biallelic SNP calls with thinning applied first, so that
        # only `n_snps` variants are materialised for the LD computation.
        ds_snps = self.biallelic_snp_calls(
            region=region,
            sample_sets=sample_sets,
            sample_query=sample_query,
            sample_query_options=sample_query_options,
            sample_indices=sample_indices,
            site_mask=site_mask,
            min_minor_ac=min_minor_ac,
            max_missing_an=max_missing_an,
            n_snps=n_snps,
            thin_offset=thin_offset,
            random_seed=random_seed,
            inline_array=inline_array,
            chunks=chunks,
        )

        # Compute genotype reference counts, loading them into memory.
        # Missing calls are filled with -127.
        # NOTE(review): `locate_unlinked` documents its input as alt allele
        # counts; ref counts should give the same r-squared for strictly
        # biallelic diploid calls, and negative values are presumably
        # treated as missing -- confirm against scikit-allel.
        with self._dask_progress(desc="Computing genotype ref counts"):
            gt = ds_snps["call_genotype"].data
            gn = allel.GenotypeDaskArray(gt).to_n_ref(fill=-127).compute()

        # Perform LD pruning; the result flags the variants to retain.
        with self._spinner(desc="LD pruning"):
            loc_unlinked = allel.locate_unlinked(
                gn,
                size=ld_window_size,
                step=ld_window_step,
                threshold=ld_threshold,
            )

        # Apply the pruning mask along the variants dimension.
        ds_pruned = _dask_compress_dataset(
            ds_snps, indexer=loc_unlinked, dim="variants"
        )

        return ds_pruned

malariagen_data/anoph/ld_params.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
"""Parameters for LD pruning functions."""
2+
3+
from typing_extensions import Annotated, TypeAlias
4+
5+
ld_window_size: TypeAlias = Annotated[
6+
int,
7+
"Window size in number of SNPs for LD pruning.",
8+
]
9+
10+
ld_window_size_default: ld_window_size = 500
11+
12+
ld_window_step: TypeAlias = Annotated[
13+
int,
14+
"Step size in number of SNPs for LD pruning.",
15+
]
16+
17+
ld_window_step_default: ld_window_step = 200
18+
19+
ld_threshold: TypeAlias = Annotated[
20+
float,
21+
"r-squared threshold for LD pruning. SNP pairs with r-squared above "
22+
"this threshold will be considered linked.",
23+
]
24+
25+
ld_threshold_default: ld_threshold = 0.1

malariagen_data/anopheles.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from .anoph.sample_metadata import AnophelesSampleMetadata
3737
from .anoph.snp_data import AnophelesSnpData
3838
from .anoph.to_plink import PlinkConverter
39+
from .anoph.ld import AnophelesLdAnalysis
3940
from .anoph.g123 import AnophelesG123Analysis
4041
from .anoph.fst import AnophelesFstAnalysis
4142
from .anoph.h12 import AnophelesH12Analysis
@@ -88,6 +89,7 @@ class AnophelesDataResource(
8889
AnophelesDistanceAnalysis,
8990
AnophelesPca,
9091
PlinkConverter,
92+
AnophelesLdAnalysis,
9193
AnophelesIgv,
9294
AnophelesKaryotypeAnalysis,
9395
AnophelesAimData,

tests/anoph/test_ld.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
import random
2+
3+
import pytest
4+
from pytest_cases import parametrize_with_cases
5+
6+
from malariagen_data import af1 as _af1
7+
from malariagen_data import ag3 as _ag3
8+
9+
from malariagen_data.anoph.ld import AnophelesLdAnalysis
10+
11+
12+
@pytest.fixture
def ag3_sim_api(ag3_sim_fixture):
    """Build an `AnophelesLdAnalysis` configured against the simulated Ag3 data."""
    return AnophelesLdAnalysis(
        url=ag3_sim_fixture.url,
        public_url=ag3_sim_fixture.url,
        config_path=_ag3.CONFIG_PATH,
        major_version_number=_ag3.MAJOR_VERSION_NUMBER,
        major_version_path=_ag3.MAJOR_VERSION_PATH,
        pre=True,
        aim_metadata_dtype={
            "aim_species_fraction_arab": "float64",
            "aim_species_fraction_colu": "float64",
            "aim_species_fraction_colu_no2l": "float64",
            "aim_species_gambcolu_arabiensis": object,
            "aim_species_gambiae_coluzzii": object,
            "aim_species": object,
        },
        gff_gene_type="gene",
        gff_gene_name_attribute="Name",
        gff_default_attributes=("ID", "Parent", "Name", "description"),
        default_site_mask="gamb_colu_arab",
        results_cache=ag3_sim_fixture.results_cache_path.as_posix(),
        taxon_colors=_ag3.TAXON_COLORS,
        virtual_contigs=_ag3.VIRTUAL_CONTIGS,
    )
37+
38+
39+
@pytest.fixture
def af1_sim_api(af1_sim_fixture):
    """Build an `AnophelesLdAnalysis` configured against the simulated Af1 data."""
    return AnophelesLdAnalysis(
        url=af1_sim_fixture.url,
        public_url=af1_sim_fixture.url,
        config_path=_af1.CONFIG_PATH,
        major_version_number=_af1.MAJOR_VERSION_NUMBER,
        major_version_path=_af1.MAJOR_VERSION_PATH,
        pre=False,
        gff_gene_type="protein_coding_gene",
        gff_gene_name_attribute="Note",
        gff_default_attributes=("ID", "Parent", "Note", "description"),
        default_site_mask="funestus",
        results_cache=af1_sim_fixture.results_cache_path.as_posix(),
        taxon_colors=_af1.TAXON_COLORS,
    )
55+
56+
57+
def case_ag3_sim(ag3_sim_fixture, ag3_sim_api):
    """Test case pairing the simulated Ag3 fixture with its API instance."""
    return ag3_sim_fixture, ag3_sim_api
59+
60+
61+
def case_af1_sim(af1_sim_fixture, af1_sim_api):
    """Test case pairing the simulated Af1 fixture with its API instance."""
    return af1_sim_fixture, af1_sim_api
63+
64+
65+
def _ld_pruned_calls_for_random_region(api: AnophelesLdAnalysis):
    """Run LD pruning on a randomly chosen contig and site mask.

    Skips the calling test when fewer than 10 biallelic SNPs are available.
    Returns a tuple of the pruned dataset and the `n_snps` value that was
    requested for pre-thinning.
    """
    region = random.choice(api.contigs)
    site_mask = random.choice(api.site_mask_ids)

    # Count the available SNPs first so that `n_snps` never exceeds them.
    ds_full = api.biallelic_snp_calls(
        region=region,
        site_mask=site_mask,
        min_minor_ac=1,
        max_missing_an=0,
    )
    n_available = ds_full.sizes["variants"]
    if n_available < 10:
        pytest.skip("Not enough variants for LD pruning test")

    n_snps = min(n_available, 200)

    ds_pruned = api.biallelic_snp_calls_ld_pruned(
        region=region,
        n_snps=n_snps,
        site_mask=site_mask,
        min_minor_ac=1,
        max_missing_an=0,
    )
    return ds_pruned, n_snps


@parametrize_with_cases("fixture,api", cases=".")
def test_ld_pruning_returns_fewer_snps(fixture, api: AnophelesLdAnalysis):
    """LD pruning keeps at least one variant and no more than were requested."""
    ds_pruned, n_snps = _ld_pruned_calls_for_random_region(api)

    # Pruned dataset should have fewer or equal variants.
    assert ds_pruned.sizes["variants"] <= n_snps
    assert ds_pruned.sizes["variants"] > 0


@parametrize_with_cases("fixture,api", cases=".")
def test_ld_pruned_dataset_structure(fixture, api: AnophelesLdAnalysis):
    """The pruned dataset exposes the expected coordinates, variables and dims."""
    ds_pruned, _ = _ld_pruned_calls_for_random_region(api)

    # Check expected coordinates.
    assert "sample_id" in ds_pruned.coords
    assert "variant_position" in ds_pruned.coords
    assert "variant_contig" in ds_pruned.coords

    # Check expected data variables.
    assert "variant_allele" in ds_pruned.data_vars
    assert "call_genotype" in ds_pruned.data_vars

    # Check dimensions.
    assert "variants" in ds_pruned.dims
    assert "samples" in ds_pruned.dims
    assert "ploidy" in ds_pruned.dims
    assert "alleles" in ds_pruned.dims

    # Check alleles are biallelic.
    assert ds_pruned.sizes["alleles"] == 2


@parametrize_with_cases("fixture,api", cases=".")
def test_ld_pruned_plink_compatibility(fixture, api: AnophelesLdAnalysis):
    """The pruned dataset carries the variables and shapes PLINK export needs."""
    ds_pruned, _ = _ld_pruned_calls_for_random_region(api)

    # Verify the pruned dataset has all variables required by PlinkConverter.
    assert "call_genotype" in ds_pruned
    assert "variant_allele" in ds_pruned
    assert "variant_contig" in ds_pruned.coords
    assert "variant_position" in ds_pruned.coords
    assert "sample_id" in ds_pruned.coords

    # Verify shapes are internally consistent.
    n_variants = ds_pruned.sizes["variants"]
    n_samples = ds_pruned.sizes["samples"]
    assert ds_pruned["call_genotype"].shape == (n_variants, n_samples, 2)
    assert ds_pruned["variant_allele"].shape == (n_variants, 2)

0 commit comments

Comments
 (0)