from __future__ import annotations
from collections.abc import Callable, Generator
from functools import partial
from io import TextIOWrapper
from pathlib import Path
from types import TracebackType
from typing import Any, TypeGuard, TypeVar, cast
import numpy as np
import pgenlib
import polars as pl
from hirola import HashTable
from loguru import logger
from more_itertools import mark_ends, windowed
from numpy.typing import ArrayLike, NDArray
from phantom import Phantom
from seqpro.rag import OFFSET_TYPE
from typing_extensions import Self, assert_never
from zstandard import ZstdDecompressor
from ._types import POS_MAX, POS_TYPE
from ._utils import ContigNormalizer, format_memory, hap_ilens, parse_memory
from ._var_ranges import var_counts, var_indices
from .exprs import ILEN, is_biallelic, symbolic_ilen
# NumPy scalar type used for every PGEN variant index array in this module.
V_IDX_TYPE = np.uint32
"""Dtype (uint32) of PGEN variant indices; it bounds how many unique variants a file can address."""
def _is_genos(obj: Any) -> TypeGuard[Genos]:
return (
isinstance(obj, np.ndarray)
and obj.dtype.type == np.int32
and obj.ndim == 3
and obj.shape[1] == 2
)
class Genos(NDArray[np.int32], Phantom, predicate=_is_genos):
    """Phantom type for genotype arrays: int32, shape ``(samples, ploidy, variants)``."""

    _dtype = np.int32

    @classmethod
    def empty(cls, n_samples: int, ploidy: int, n_variants: int) -> Self:
        """Allocate an uninitialized genotype array with the requested dimensions."""
        shape = (n_samples, ploidy, n_variants)
        buffer = np.empty(shape, dtype=cls._dtype)
        return cls.parse(buffer)
def _is_dosages(obj: Any) -> TypeGuard[Dosages]:
return (
isinstance(obj, np.ndarray) and obj.dtype.type == np.float32 and obj.ndim == 2
)
class Dosages(NDArray[np.float32], Phantom, predicate=_is_dosages):
    """Phantom type for dosage arrays: float32, shape ``(samples, variants)``."""

    _dtype = np.float32

    @classmethod
    def empty(cls, n_samples: int, ploidy: int, n_variants: int) -> Self:
        """Allocate an uninitialized dosage array.

        ``ploidy`` is accepted only so every output type shares one ``empty`` signature;
        it does not affect the shape.
        """
        buffer = np.empty((n_samples, n_variants), dtype=cls._dtype)
        return cls.parse(buffer)
def _is_phasing(obj: Any) -> TypeGuard[Phasing]:
return isinstance(obj, np.ndarray) and obj.dtype.type == np.bool_ and obj.ndim == 2
class Phasing(NDArray[np.bool_], Phantom, predicate=_is_phasing):
    """Phantom type for phasing flags: bool, shape ``(samples, variants)``."""

    _dtype = np.bool_

    @classmethod
    def empty(cls, n_samples: int, ploidy: int, n_variants: int) -> Self:
        """Allocate an uninitialized phasing array (``ploidy`` is unused)."""
        buffer = np.empty((n_samples, n_variants), dtype=cls._dtype)
        return cls.parse(buffer)
def _is_genos_phasing(obj) -> TypeGuard[GenosPhasing]:
return (
isinstance(obj, tuple)
and len(obj) == 2
and isinstance(obj[0], Genos)
and isinstance(obj[1], Phasing)
)
class GenosPhasing(tuple[Genos, Phasing], Phantom, predicate=_is_genos_phasing):
    """Phantom type for a ``(genotypes, phasing)`` pair."""

    _dtypes = (np.int32, np.bool_)

    @classmethod
    def empty(cls, n_samples: int, ploidy: int, n_variants: int) -> Self:
        """Allocate uninitialized genotype and phasing arrays with matching dimensions."""
        dims = (n_samples, ploidy, n_variants)
        genos = Genos.empty(*dims)
        phasing = Phasing.empty(*dims)
        return cls.parse((genos, phasing))
def _is_genos_dosages(obj) -> TypeGuard[GenosDosages]:
return (
isinstance(obj, tuple)
and len(obj) == 2
and isinstance(obj[0], Genos)
and isinstance(obj[1], Dosages)
)
class GenosDosages(tuple[Genos, Dosages], Phantom, predicate=_is_genos_dosages):
    """Phantom type for a ``(genotypes, dosages)`` pair."""

    _dtypes = (np.int32, np.float32)

    @classmethod
    def empty(cls, n_samples: int, ploidy: int, n_variants: int) -> Self:
        """Allocate uninitialized genotype and dosage arrays with matching dimensions."""
        dims = (n_samples, ploidy, n_variants)
        genos = Genos.empty(*dims)
        dosages = Dosages.empty(*dims)
        return cls.parse((genos, dosages))
def _is_genos_phasing_dosages(obj) -> TypeGuard[GenosPhasingDosages]:
return (
isinstance(obj, tuple)
and len(obj) == 3
and isinstance(obj[0], Genos)
and isinstance(obj[1], Phasing)
and isinstance(obj[2], Dosages)
)
class GenosPhasingDosages(
    tuple[Genos, Phasing, Dosages], Phantom, predicate=_is_genos_phasing_dosages
):
    """Phantom type for a ``(genotypes, phasing, dosages)`` triple."""

    _dtypes = (np.int32, np.bool_, np.float32)

    @classmethod
    def empty(cls, n_samples: int, ploidy: int, n_variants: int) -> Self:
        """Allocate uninitialized genotype, phasing and dosage arrays with matching dimensions."""
        dims = (n_samples, ploidy, n_variants)
        genos = Genos.empty(*dims)
        phasing = Phasing.empty(*dims)
        dosages = Dosages.empty(*dims)
        return cls.parse((genos, phasing, dosages))
# Any output type a PGEN read can produce: genotypes, dosages, or a tuple combining them.
T = TypeVar("T", Genos, Dosages, GenosPhasing, GenosDosages, GenosPhasingDosages)
# Output types that contain genotypes (everything except bare Dosages); used by the
# length-aware chunking code, which needs genotypes to compute haplotype lengths.
L = TypeVar("L", Genos, GenosPhasing, GenosDosages, GenosPhasingDosages)
class PGEN:
    """Create a PGEN reader.

    Parameters
    ----------
    geno_path
        Path to the PGEN file. A path whose suffix is not ``.pgen`` has its suffix replaced
        by ``.pgen``. Only used for genotypes if a dosage path is provided as well.
    filter
        Polars expression to filter variants. Should return True for variants to keep. Will have at least the columns
        `CHROM`, `POS` (1-based), `REF`, `ALT`, and `ILEN` available to use.
    dosage_path
        Path to a dosage PGEN file. If None, the genotype PGEN file will be used for both genotypes and dosages.
    load_index
        If True (default), load the variant index immediately; otherwise it is loaded on first use.
    """

    available_samples: list[str]
    """List of available samples in the PGEN file."""
    _filter: pl.Expr | None
    """Polars expression to filter variants. Should return True for variants to keep."""
    ploidy = 2
    """Ploidy of the samples. The PGEN format currently only supports diploid (2)."""
    contigs: list[str] | None = None
    """Naturally sorted list of contig names in the PGEN file."""
    _index: pl.DataFrame | None = None
    _geno_pgen: pgenlib.PgenReader
    _dose_pgen: pgenlib.PgenReader
    _s_idx: NDArray[np.uint32] | slice
    _s_unsorter: NDArray[np.intp] | slice
    """pgenlib always returns samples in the order they exist in the file. This array re-orders them to the order specified by the user."""
    _geno_path: Path
    _dose_path: Path | None
    _sei: StartsEndsIlens | None = None  # unfiltered so that var_idxs map correctly
    """Variant 0-based starts, ends, ILEN, and ALT alleles if the PGEN with filters is bi-allelic."""
    _s2i: HashTable
    _c_norm: ContigNormalizer | None = None
    # largest variant index present in the loaded index for each contig
    _c_max_idxs: dict[str, int] | None = None

    Genos = Genos
    """:code:`(samples ploidy variants) int32`"""
    Dosages = Dosages
    """:code:`(samples variants) float32`

    .. note::
        PGEN does not support multi-allelic dosages. If you attempt to write one, you will get an
        error from PLINK 2.0.
    """
    GenosPhasing = GenosPhasing
    """:code:`(samples ploidy variants) int32` and :code:`(samples variants) bool`"""
    GenosDosages = GenosDosages
    """:code:`(samples ploidy variants) int32` and :code:`(samples variants) float32`"""
    GenosPhasingDosages = GenosPhasingDosages
    """:code:`(samples ploidy variants) int32`, :code:`(samples variants) bool`, and :code:`(samples variants) float32`"""

    def __init__(
        self,
        geno_path: str | Path,
        filter: pl.Expr | None = None,
        dosage_path: str | Path | None = None,
        load_index: bool = True,
    ):
        self._filter = filter

        geno_path = Path(geno_path)
        if geno_path.suffix != ".pgen":
            geno_path = geno_path.with_suffix(".pgen")
        self._geno_path = geno_path
        if not self._geno_path.exists():
            raise FileNotFoundError(f"PGEN file {self._geno_path} does not exist.")

        samples = _read_psam(geno_path.with_suffix(".psam"))
        self.available_samples = cast(list[str], samples.tolist())
        self._s2i = HashTable(
            max=len(samples) * 2,  # type: ignore
            dtype=samples.dtype,
        )
        self._s2i.add(samples)
        self._s_idx = slice(None)
        self._s_unsorter = slice(None)

        self._geno_pgen = pgenlib.PgenReader(bytes(geno_path), len(samples))
        if dosage_path is not None:
            dosage_path = Path(dosage_path)
            self._dose_pgen = self._open_dosage_reader(dosage_path)
        else:
            # the genotype reader also serves dosages
            self._dose_pgen = self._geno_pgen
        self._dose_path = dosage_path

        if load_index:
            self._init_index()

    def _open_dosage_reader(self, dosage_path: Path) -> pgenlib.PgenReader:
        """Open a dosage PGEN reader after checking its PSAM matches the genotype samples.

        The new reader gets the currently selected sample subset (if any) so dosage reads
        stay aligned with genotype reads.

        Raises
        ------
        ValueError
            If the dosage PSAM lists different samples (or a different number of samples).
        """
        dose_samples = _read_psam(dosage_path.with_suffix(".psam"))
        # array_equal also handles differing lengths, where `!=` would fail to broadcast
        if not np.array_equal(np.asarray(self.available_samples), dose_samples):
            raise ValueError(
                "Samples in dosage file do not match those in genotype file."
            )
        reader = pgenlib.PgenReader(bytes(dosage_path), len(dose_samples))
        if isinstance(self._s_idx, np.ndarray):
            # pgenlib expects the subset sorted in file order
            reader.change_sample_subset(np.sort(self._s_idx))
        return reader

    def _init_index(self):
        """Initialize the index and all derived data structures if they have not been yet."""
        if self._index is not None:
            return
        self._set_index(*_load_index(self._index_path(), self._filter))

    def _set_index(
        self,
        index: pl.DataFrame,
        sei: StartsEndsIlens | None,
        contigs: list[str],
    ):
        """Store a freshly loaded index and rebuild every structure derived from it."""
        self._index, self._sei, self.contigs = index, sei, contigs
        self._c_norm = ContigNormalizer(contigs)
        # Largest variant index on each contig. Build the mapping from the grouped frame
        # itself instead of zipping with `contigs`: that list is naturally sorted while the
        # groups follow index order, and (presumably, when a filter is set) the index may
        # lack some contigs entirely.
        c_max = index.group_by("CHROM", maintain_order=True).agg(
            max_idx=pl.col("index").max()
        )
        self._c_max_idxs = dict(
            zip(c_max["CHROM"].to_list(), c_max["max_idx"].to_list())
        )

    def _free_index(self):
        """Free large allocations from the index."""
        self._index = None
        self._sei = None

    @property
    def current_samples(self) -> list[str]:
        """List of samples that are currently being used, in order."""
        if isinstance(self._s_unsorter, slice):
            return self.available_samples
        return cast(list[str], self._s2i.keys[self._s_idx].tolist())

    @property
    def n_samples(self) -> int:
        """Number of samples currently being used."""
        if isinstance(self._s_unsorter, slice):
            return len(self.available_samples)
        return len(self._s_unsorter)

    @property
    def nbytes(self) -> int:
        """Total in-memory footprint, in bytes, of resident (non-mmap'd) data
        structures held by this reader. Sums the gvi variant index dataframe
        and the StartsEndsIlens cache. Returns 0 after `_free_index()`.
        """
        n = 0
        if self._index is not None:
            n += self._index.estimated_size()
        if self._sei is not None:
            n += (
                self._sei.v_starts.nbytes
                + self._sei.v_ends.nbytes
                + self._sei.ilens.nbytes
                + self._sei.alt.estimated_size()
            )
        return n

    @property
    def filter(self) -> pl.Expr | None:
        """Polars expression to filter variants. Should return True for variants to keep."""
        return self._filter

    @filter.setter
    def filter(self, filter: pl.Expr | None):
        """Set the Polars expression to filter variants and reload the index with it."""
        # Load first so a failing load leaves the previous filter and index intact; then
        # rebuild contigs/normalizer/per-contig maxima, which also depend on the filter.
        loaded = _load_index(self._index_path(), filter)
        self._filter = filter
        self._set_index(*loaded)

    def _index_path(self) -> Path:
        """Path to the genoray index (``.gvi``) next to the PVAR file.

        Raises
        ------
        FileNotFoundError
            If neither a ``.pvar`` nor a ``.pvar.zst`` file exists next to the PGEN.
        """
        # check whether pvar or pvar.zst
        pvar = self._geno_path.with_suffix(".pvar")
        if not pvar.exists():
            pvar = self._geno_path.with_suffix(".pvar.zst")
        if not pvar.exists():
            raise FileNotFoundError(
                f"No PVAR file found for {self._geno_path} (looked for .pvar and .pvar.zst)."
            )
        return pvar.with_suffix(f"{pvar.suffix}.gvi")

    def set_samples(self, samples: ArrayLike | None) -> Self:
        """Set the samples to use.

        Parameters
        ----------
        samples
            List of sample names to use, in the order data should be returned. If None, all
            samples will be used in file order.

        Raises
        ------
        ValueError
            If an empty selection is given, a sample is not in the file, or a sample is repeated.
        """
        if samples is not None:
            samples = np.atleast_1d(samples)

        if samples is None or (
            len(samples) == len(self.available_samples)
            and (samples == np.asarray(self.available_samples)).all()
        ):
            self._s_idx = slice(None)
            self._s_unsorter = slice(None)
            # Undo any subset applied by an earlier call; otherwise pgenlib would keep
            # returning fewer samples than `n_samples` reports.
            self._geno_pgen.change_sample_subset(None)
            if self._dose_path is not None:
                self._dose_pgen.change_sample_subset(None)
            return self

        if len(samples) == 0:
            raise ValueError("At least one sample must be selected.")

        s_idx = self._s2i.get(samples)
        if (s_idx == -1).any():
            missing = samples[s_idx == -1]
            raise ValueError(f"Samples {missing} not found in the file.")
        s_idx = s_idx.astype(np.uint32)
        if len(np.unique(s_idx)) != len(s_idx):
            raise ValueError("Samples must be unique.")

        self._s_idx = s_idx
        # pgenlib needs the subset in ascending file order; keep the inverse permutation
        # so reads can be put back into the user's order.
        sorter = np.argsort(s_idx, kind="stable")
        sorted_s_idx = s_idx[sorter]
        self._s_unsorter = np.argsort(sorter, kind="stable")
        # if dose path is None, then dose pgen is just a reference to geno pgen so
        # we're also (somewhat unsafely) mutating the dose pgen here
        self._geno_pgen.change_sample_subset(sorted_s_idx)
        if self._dose_path is not None:
            self._dose_pgen.change_sample_subset(sorted_s_idx)
        return self

    @property
    def geno_path(self) -> Path:
        """Path to the genotype file."""
        return self._geno_path

    @property
    def dosage_path(self) -> Path | None:
        """Path to the dosage file."""
        return self._dose_path

    @dosage_path.setter
    def dosage_path(self, dosage_path: str | Path | None):
        """Set the path to the dosage file (None to read dosages from the genotype file)."""
        if dosage_path is not None:
            dosage_path = Path(dosage_path)
            new_reader = self._open_dosage_reader(dosage_path)
        else:
            new_reader = self._geno_pgen
        old_reader = self._dose_pgen
        self._dose_pgen = new_reader
        self._dose_path = dosage_path
        # release a previously opened, separate dosage reader
        if old_reader is not self._geno_pgen:
            old_reader.close()

    def __del__(self):
        geno = getattr(self, "_geno_pgen", None)
        dose = getattr(self, "_dose_pgen", None)
        # the dosage reader may be the genotype reader itself; close it only once
        if dose is not None and dose is not geno:
            dose.close()
        if geno is not None:
            geno.close()

    def n_vars_in_ranges(
        self,
        contig: str,
        starts: ArrayLike = 0,
        ends: ArrayLike = POS_MAX,
    ) -> NDArray[np.uint32]:
        """Return the number of variants in each of the given ranges.

        Parameters
        ----------
        contig
            Contig name.
        starts
            0-based start positions of the ranges.
        ends
            0-based, exclusive end positions of the ranges.

        Returns
        -------
        n_variants
            Shape: :code:`(ranges)`. Number of variants in the given ranges.
        """
        if self._c_norm is None or self._index is None:
            self._init_index()
        assert self._c_norm is not None and self._index is not None
        return var_counts(self._c_norm, self._index, contig, starts, ends)

    def var_idxs(
        self,
        contig: str,
        starts: ArrayLike = 0,
        ends: ArrayLike = POS_MAX,
    ) -> tuple[NDArray[V_IDX_TYPE], NDArray[OFFSET_TYPE]]:
        """Get variant indices and the number of indices per range.

        Parameters
        ----------
        contig
            Contig name.
        starts
            0-based start positions of the ranges.
        ends
            0-based, exclusive end positions of the ranges.

        Returns
        -------
            Shape: (tot_variants). Variant indices for the given ranges.

            Shape: (ranges+1). Offsets to get variant indices for each range.
        """
        if self._c_norm is None or self._index is None:
            self._init_index()
        assert self._c_norm is not None and self._index is not None
        return var_indices(V_IDX_TYPE, self._c_norm, self._index, contig, starts, ends)

    def read(
        self,
        contig: str,
        start: int | np.integer = 0,
        end: int | np.integer = POS_MAX,
        mode: type[T] = Genos,
    ) -> T:
        """Read genotypes and/or dosages for a range.

        Parameters
        ----------
        contig
            Contig name.
        start
            0-based start position.
        end
            0-based, exclusive end position.
        mode
            Type of data to read. Can be :code:`Genos`, :code:`Dosages`, :code:`GenosPhasing`,
            :code:`GenosDosages`, or :code:`GenosPhasingDosages`.

        Returns
        -------
            Genotypes and/or dosages. Genotypes have shape :code:`(samples ploidy variants)` and
            dosages have shape :code:`(samples variants)`. Missing genotypes have value -1 and missing dosages
            have value np.nan. If just using genotypes or dosages, will be a single array, otherwise
            will be a tuple of arrays.
        """
        if self._c_norm is None or self._index is None:
            self._init_index()
        assert self._c_norm is not None and self._index is not None

        c = self._c_norm.norm(contig)
        if c is None:
            return mode.empty(self.n_samples, self.ploidy, 0)

        var_idxs, _ = self.var_idxs(c, start, end)
        if len(var_idxs) == 0:
            return mode.empty(self.n_samples, self.ploidy, 0)

        return self._reader_for(mode)(var_idxs)

    def chunk(
        self,
        contig: str,
        start: int | np.integer = 0,
        end: int | np.integer = POS_MAX,
        max_mem: int | str = "4g",
        mode: type[T] = Genos,
    ) -> Generator[T]:
        """Iterate over genotypes and/or dosages for a range in chunks limited by :code:`max_mem`.

        Parameters
        ----------
        contig
            Contig name.
        start
            0-based start position.
        end
            0-based, exclusive end position.
        max_mem
            Maximum memory to use for each chunk. Can be an integer or a string with a suffix
            (e.g. "4g", "2 MB").
        mode
            Type of data to read. Can be :code:`Genos`, :code:`Dosages`, :code:`GenosPhasing`,
            :code:`GenosDosages`, or :code:`GenosPhasingDosages`.

        Yields
        ------
            Genotypes and/or dosages. Genotypes have shape :code:`(samples ploidy variants)` and
            dosages have shape :code:`(samples variants)`. Missing genotypes have value -1 and missing dosages
            have value np.nan. If just using genotypes or dosages, will be a single array, otherwise
            will be a tuple of arrays.

        Raises
        ------
        ValueError
            If :code:`max_mem` cannot hold even a single variant.
        """
        max_mem = parse_memory(max_mem)
        if self._c_norm is None or self._index is None:
            self._init_index()
        assert self._c_norm is not None and self._index is not None

        c = self._c_norm.norm(contig)
        if c is None:
            logger.warning(
                f"Query contig {contig} not found in PGEN file, even after normalizing for UCSC/Ensembl nomenclature."
            )
            yield mode.empty(self.n_samples, self.ploidy, 0)
            return

        var_idxs, _ = self.var_idxs(c, start, end)
        n_variants = len(var_idxs)
        if n_variants == 0:
            yield mode.empty(self.n_samples, self.ploidy, 0)
            return

        mem_per_v = self._mem_per_variant(mode)
        vars_per_chunk = min(max_mem // mem_per_v, n_variants)
        if vars_per_chunk == 0:
            raise ValueError(
                f"Maximum memory {format_memory(max_mem)} insufficient to read a single variant."
                f" Memory per variant: {format_memory(mem_per_v)}."
            )

        n_chunks = -(-n_variants // vars_per_chunk)  # ceiling division
        read = self._reader_for(mode)
        for var_idx in np.array_split(var_idxs, n_chunks):
            yield mode.parse(read(var_idx))

    def read_ranges(
        self,
        contig: str,
        starts: ArrayLike = 0,
        ends: ArrayLike = POS_MAX,
        mode: type[T] = Genos,
    ) -> tuple[T, NDArray[OFFSET_TYPE]]:
        """Read genotypes and/or dosages for multiple ranges.

        Parameters
        ----------
        contig
            Contig name.
        starts
            0-based start positions.
        ends
            0-based, exclusive end positions.
        mode
            Type of data to read. Can be :code:`Genos`, :code:`Dosages`, :code:`GenosPhasing`,
            :code:`GenosDosages`, or :code:`GenosPhasingDosages`.

        Returns
        -------
            Genotypes and/or dosages. Genotypes have shape :code:`(samples ploidy variants)` and
            dosages have shape :code:`(samples variants)`. Missing genotypes have value -1 and missing dosages
            have value np.nan. If just using genotypes or dosages, will be a single array, otherwise
            will be a tuple of arrays.

            Shape: (ranges+1). Offsets to slice out data for each range from the variants axis like so:

        Examples
        --------
        .. code-block:: python

            data, offsets = reader.read_ranges(...)
            data[..., offsets[i] : offsets[i + 1]]  # data for range i

        Note that the number of variants for range :code:`i` is :code:`np.diff(offsets)[i]`.
        """
        starts = np.atleast_1d(np.asarray(starts, POS_TYPE))
        n_ranges = len(starts)

        if self._c_norm is None or self._index is None:
            self._init_index()
        assert self._c_norm is not None and self._index is not None

        c = self._c_norm.norm(contig)
        if c is None:
            logger.warning(
                f"Query contig {contig} not found in PGEN file, even after normalizing for UCSC/Ensembl nomenclature."
            )
            return mode.empty(self.n_samples, self.ploidy, 0), np.zeros(
                n_ranges + 1, OFFSET_TYPE
            )

        var_idxs, offsets = self.var_idxs(c, starts, ends)
        if len(var_idxs) == 0:
            return mode.empty(self.n_samples, self.ploidy, 0), np.zeros(
                n_ranges + 1, OFFSET_TYPE
            )

        return self._reader_for(mode)(var_idxs), offsets

    def chunk_ranges(
        self,
        contig: str,
        starts: ArrayLike = 0,
        ends: ArrayLike = POS_MAX,
        max_mem: int | str = "4g",
        mode: type[T] = Genos,
    ) -> Generator[Generator[T]]:
        """Read genotypes and/or dosages for multiple ranges in chunks limited by :code:`max_mem`.

        Parameters
        ----------
        contig
            Contig name.
        starts
            0-based start positions.
        ends
            0-based, exclusive end positions.
        max_mem
            Maximum memory to use for each chunk. Can be an integer or a string with a suffix
            (e.g. "4g", "2 MB").
        mode
            Type of data to read. Can be :code:`Genos`, :code:`Dosages`, :code:`GenosPhasing`,
            :code:`GenosDosages`, or :code:`GenosPhasingDosages`.

        Yields
        ------
            One generator per range, each yielding that range's data in chunks. Genotypes have shape
            :code:`(samples ploidy variants)` and dosages have shape :code:`(samples variants)`. Missing
            genotypes have value -1 and missing dosages have value np.nan. If just using genotypes or
            dosages, will be a single array, otherwise will be a tuple of arrays. Ranges without
            variants yield a single empty chunk.

        Raises
        ------
        ValueError
            If :code:`max_mem` cannot hold even a single variant.

        Examples
        --------
        .. code-block:: python

            for range_ in reader.chunk_ranges(...):
                for chunk in range_:
                    # do something with chunk
                    pass
        """
        max_mem = parse_memory(max_mem)
        starts = np.atleast_1d(np.asarray(starts, POS_TYPE))
        if self._c_norm is None or self._index is None:
            self._init_index()
        assert self._c_norm is not None and self._index is not None

        c = self._c_norm.norm(contig)
        if c is None:
            logger.warning(
                f"Query contig {contig} not found in PGEN file, even after normalizing for UCSC/Ensembl nomenclature."
            )
            for _ in range(len(starts)):
                yield (mode.empty(self.n_samples, self.ploidy, 0) for _ in range(1))
            return

        ends = np.atleast_1d(np.asarray(ends, POS_TYPE))
        var_idxs, offsets = self.var_idxs(c, starts, ends)
        if len(var_idxs) == 0:
            for _ in range(len(starts)):
                yield (mode.empty(self.n_samples, self.ploidy, 0) for _ in range(1))
            return

        mem_per_v = self._mem_per_variant(mode)
        # Only the memory budget bounds the chunk size; empty ranges are handled below and
        # must not trigger the "insufficient memory" error or a division by zero.
        max_vars_per_chunk = max_mem // mem_per_v
        if max_vars_per_chunk == 0:
            raise ValueError(
                f"Maximum memory {format_memory(max_mem)} insufficient to read a single variant."
                f" Memory per variant: {format_memory(mem_per_v)}."
            )

        read = self._reader_for(mode)
        for o_s, o_e in zip(offsets[:-1], offsets[1:]):
            if o_s == o_e:
                yield (mode.empty(self.n_samples, self.ploidy, 0) for _ in range(1))
                continue
            range_idxs = var_idxs[o_s:o_e]
            n_chunks = -(-len(range_idxs) // max_vars_per_chunk)  # ceiling division
            v_chunks = np.array_split(range_idxs, n_chunks)
            yield (read(var_idx) for var_idx in v_chunks)

    def _chunk_ranges_with_length(
        self,
        contig: str,
        starts: ArrayLike = 0,
        ends: ArrayLike = POS_MAX,
        max_mem: int | str = "4g",
        mode: type[L] = Genos,
    ) -> Generator[
        Generator[
            tuple[L, POS_TYPE, NDArray[V_IDX_TYPE]]  # data, end, chunk_idxs
        ]
    ]:
        """Read genotypes (optionally with phasing/dosages) for multiple ranges in chunks approximately
        limited by :code:`max_mem`.

        Will extend the ranges so that the returned data corresponds to haplotypes that have at least as much
        length as the original ranges.

        .. note::
            Genotypes are always read because they are needed to compute haplotype lengths, so
            :code:`Dosages` alone is not an accepted mode.

        Parameters
        ----------
        contig
            Contig name.
        starts
            0-based start positions.
        ends
            0-based, exclusive end positions.
        max_mem
            Maximum memory to use for each chunk. Can be an integer or a string with a suffix
            (e.g. "4g", "2 MB").
        mode
            Type of data to read. Can be :code:`Genos`, :code:`GenosPhasing`,
            :code:`GenosDosages`, or :code:`GenosPhasingDosages`.

        Yields
        ------
            One generator per range, each yielding tuples of ``(data, end, var_idx)``: the chunk's
            data, the 0-based end position of the final variant in the chunk, and the variant
            indices along the data's variant axis. Genotypes have shape
            :code:`(samples ploidy variants)` and dosages have shape :code:`(samples variants)`. Missing genotypes
            have value -1 and missing dosages have value np.nan.

        Raises
        ------
        ValueError
            If variant start/end/ILEN info is unavailable, or :code:`max_mem` cannot hold a single variant.

        Examples
        --------
        .. code-block:: python

            for range_ in reader._chunk_ranges_with_length(...):
                for data, end, var_idx in range_:
                    # do something with chunk
                    pass
        """
        max_mem = parse_memory(max_mem)
        starts = np.atleast_1d(np.asarray(starts, POS_TYPE))
        ends = np.atleast_1d(np.asarray(ends, POS_TYPE))

        # Load the index first: `_sei` is only populated by loading it.
        if self._c_norm is None or self._index is None or self._c_max_idxs is None:
            self._init_index()
        assert (
            self._c_norm is not None
            and self._index is not None
            and self._c_max_idxs is not None
        )
        sei = self._sei
        if sei is None:
            raise ValueError(
                "Cannot use chunk_ranges_with_length without variant start, end, and ilen info, which usually happens when multi-allelic"
                " variants are present."
            )

        def empty_range(end):
            # `end` is bound as an argument so each generator keeps its own range end
            # even if the outer generator advances before this one is consumed.
            return (
                (
                    mode.empty(self.n_samples, self.ploidy, 0),
                    end,
                    np.empty(0, dtype=V_IDX_TYPE),
                )
                for _ in range(1)
            )

        c = self._c_norm.norm(contig)
        if c is None:
            logger.warning(
                f"Query contig {contig} not found in PGEN file, even after normalizing for UCSC/Ensembl nomenclature."
            )
            for e in ends:
                # we have full length, no deletions in any of the ranges
                yield empty_range(e)
            return

        var_idxs, offsets = self.var_idxs(c, starts, ends)
        tot_variants = len(var_idxs)
        if tot_variants == 0:
            for e in ends:
                # we have full length, no deletions in any of the ranges
                yield empty_range(e)
            return

        mem_per_v = self._mem_per_variant(mode)
        vars_per_chunk = min(max_mem // mem_per_v, tot_variants)
        if vars_per_chunk == 0:
            raise ValueError(
                f"Maximum memory {format_memory(max_mem)} insufficient to read a single variant."
                f" Memory per variant: {format_memory(mem_per_v)}."
            )

        if issubclass(mode, Genos):
            read = self._read_genos
        elif issubclass(mode, GenosPhasing):
            read = self._read_genos_phasing
        elif issubclass(mode, GenosDosages):
            read = self._read_genos_dosages
        elif issubclass(mode, GenosPhasingDosages):
            read = self._read_genos_phasing_dosages
        else:
            assert_never(mode)
        read = cast(Callable[[NDArray[np.uint32]], L], read)

        contig_max_idx = self._c_max_idxs[c]
        for i, (s, e) in enumerate(zip(starts, ends)):
            range_idxs = var_idxs[offsets[i] : offsets[i + 1]]
            n_variants = len(range_idxs)
            if n_variants == 0:
                # we have full length, no deletions in any of the ranges
                yield empty_range(e)
                continue
            n_chunks = -(-n_variants // vars_per_chunk)  # ceiling division
            yield _gen_with_length(
                v_chunks=np.array_split(range_idxs, n_chunks),
                q_start=s,
                q_end=e,
                read=read,
                v_starts=sei.v_starts,
                v_ends=sei.v_ends,
                ilens=sei.ilens,
                contig_max_idx=contig_max_idx,
            )

    def _reader_for(self, mode: type[T]) -> Callable[[NDArray[V_IDX_TYPE]], T]:
        """Return the private read method producing data of type ``mode``."""
        if issubclass(mode, Genos):
            read = self._read_genos
        elif issubclass(mode, Dosages):
            read = self._read_dosages
        elif issubclass(mode, GenosPhasing):
            read = self._read_genos_phasing
        elif issubclass(mode, GenosDosages):
            read = self._read_genos_dosages
        elif issubclass(mode, GenosPhasingDosages):
            read = self._read_genos_phasing_dosages
        else:
            assert_never(mode)
        return cast(Callable[[NDArray[V_IDX_TYPE]], T], read)

    def _mem_per_variant(self, mode: type[T]) -> int:
        """Bytes needed to hold one variant's worth of ``mode`` data for the current samples."""
        mem = 0
        if issubclass(mode, Genos):
            mem += self.n_samples * self.ploidy * mode._dtype().itemsize
        elif issubclass(mode, Dosages):
            mem += self.n_samples * mode._dtype().itemsize
        elif issubclass(mode, GenosPhasing) or issubclass(mode, GenosDosages):
            mem += self.n_samples * self.ploidy * mode._dtypes[0]().itemsize
            mem += self.n_samples * mode._dtypes[1]().itemsize
        elif issubclass(mode, GenosPhasingDosages):
            mem += self.n_samples * self.ploidy * mode._dtypes[0]().itemsize
            mem += self.n_samples * mode._dtypes[1]().itemsize
            mem += self.n_samples * mode._dtypes[2]().itemsize
        else:
            assert_never(mode)

        if isinstance(self._s_unsorter, np.ndarray):
            mem *= 2  # have to make a copy to sort by samples

        return mem

    def _read_genos(self, var_idxs: NDArray[V_IDX_TYPE]) -> Genos:
        """Read genotypes as ``(samples, ploidy, variants)`` with missing calls set to -1."""
        out = np.empty(
            (len(var_idxs), self.n_samples * self.ploidy), dtype=Genos._dtype
        )
        self._geno_pgen.read_alleles_list(var_idxs, out)
        out = out.reshape(len(var_idxs), self.n_samples, self.ploidy).transpose(
            1, 2, 0
        )[self._s_unsorter]
        out[out == -9] = -1  # pgenlib marks missing alleles with -9
        return cast(Genos, out)

    def _read_dosages(self, var_idxs: NDArray[V_IDX_TYPE]) -> Dosages:
        """Read dosages as ``(samples, variants)`` with missing values set to NaN."""
        out = np.empty((len(var_idxs), self.n_samples), dtype=Dosages._dtype)
        self._dose_pgen.read_dosages_list(var_idxs, out)
        out = out.transpose(1, 0)[self._s_unsorter]
        out[out == -9] = np.nan  # pgenlib marks missing dosages with -9
        return cast(Dosages, out)

    def _read_genos_dosages(self, var_idxs: NDArray[V_IDX_TYPE]) -> GenosDosages:
        genos = self._read_genos(var_idxs)
        dosages = self._read_dosages(var_idxs)
        return cast(GenosDosages, (genos, dosages))

    def _read_genos_phasing(self, var_idxs: NDArray[V_IDX_TYPE]) -> GenosPhasing:
        """Read genotypes plus per-sample phase-present flags."""
        genos = np.empty(
            (len(var_idxs), self.n_samples * self.ploidy), dtype=Genos._dtype
        )
        phasing = np.empty((len(var_idxs), self.n_samples), dtype=Phasing._dtype)
        self._geno_pgen.read_alleles_and_phasepresent_list(var_idxs, genos, phasing)
        genos = genos.reshape(len(var_idxs), self.n_samples, self.ploidy).transpose(
            1, 2, 0
        )[self._s_unsorter]
        genos[genos == -9] = -1
        phasing = phasing.transpose(1, 0)[self._s_unsorter]
        return cast(GenosPhasing, (genos, phasing))

    def _read_genos_phasing_dosages(
        self, var_idxs: NDArray[V_IDX_TYPE]
    ) -> GenosPhasingDosages:
        genos_phasing = self._read_genos_phasing(var_idxs)
        dosages = self._read_dosages(var_idxs)
        return cast(GenosPhasingDosages, (*genos_phasing, dosages))
def _gen_with_length(
v_chunks: list[NDArray[V_IDX_TYPE]],
q_start: int,
q_end: int,
read: Callable[[NDArray[V_IDX_TYPE]], L],
v_starts: NDArray[POS_TYPE], # full dataset v_starts
v_ends: NDArray[POS_TYPE], # full dataset v_ends
ilens: NDArray[np.int32], # full dataset ilens
contig_max_idx: int,
) -> Generator[tuple[L, POS_TYPE, NDArray[V_IDX_TYPE]]]:
# * This implementation computes haplotype lengths as shorter than they actually are if a spanning deletion is present
# * This will result in including more variants than needed, which is fine since we're extending var_idx by more than we
# * need to anyway.
#! Assume len(v_chunks) > 0 and all len(var_idx) > 0 is guaranteed by caller
length = q_end - q_start
_idx_extension = 20
for _, is_last, var_idx in mark_ends(v_chunks):
last_end = cast(POS_TYPE, v_ends[var_idx[-1]])
if not is_last:
yield read(var_idx), last_end, var_idx
continue
ext_s_idx: int = min(var_idx[-1] + 1, contig_max_idx)
# end idx is 0-based inclusive
ext_e_idx = min(ext_s_idx + _idx_extension - 1, contig_max_idx)
_idx_extension *= 2
if ext_s_idx == ext_e_idx:
# no extension needed
yield read(var_idx), last_end, var_idx
return
var_idx = np.concatenate(
[var_idx, np.arange(ext_s_idx, ext_e_idx + 1, dtype=V_IDX_TYPE)]
)
last_idx: V_IDX_TYPE = var_idx[-1]
last_end = cast(POS_TYPE, v_ends[var_idx[-1]])
# (s p v)
out = read(var_idx)
if ext_s_idx == ext_e_idx:
yield out, last_end, var_idx
return
initial_len = max(length, last_end - q_start) # type: ignore
if isinstance(out, Genos):
# (s p)
hap_lens = np.full(out.shape[:-1], initial_len, dtype=np.int32)
hap_lens += hap_ilens(out, ilens[var_idx])
else:
# (s p)
hap_lens = np.full(out[0].shape[:-1], initial_len, dtype=np.int32)
hap_lens += hap_ilens(out[0], ilens[var_idx])
ls_ext: list[L] = []
while (hap_lens < length).any():
ext_s_idx = min(last_idx + 1, contig_max_idx)
# end idx is 0-based inclusive
ext_e_idx = min(ext_s_idx + _idx_extension - 1, contig_max_idx)
_idx_extension *= 2
if ext_s_idx == ext_e_idx:
break
ext_idx = np.arange(ext_s_idx, ext_e_idx + 1, dtype=V_IDX_TYPE)
last_idx = ext_idx[-1]
ext_out = read(ext_idx)
ls_ext.append(ext_out)
if isinstance(ext_out, Genos):
ext_genos = ext_out
else:
ext_genos = ext_out[0]
dist = v_starts[ext_idx[-1]] - last_end
hap_lens += dist + hap_ilens(ext_genos, ilens[ext_idx])
last_end = cast(POS_TYPE, v_ends[ext_idx[-1]])
if len(ls_ext) == 0:
yield out, last_end, var_idx
return
if isinstance(out, Genos):
out = np.concatenate([out, *ls_ext], axis=-1)
else:
out = tuple(
np.concatenate([o, *ls], axis=-1) for o, ls in zip(out, zip(*ls_ext))
)
var_idx = np.arange(var_idx[0], last_idx + 1, dtype=V_IDX_TYPE)
yield (
out, # type: ignore
last_end,
var_idx,
)
def _read_psam(path: Path) -> NDArray[np.str_]:
    """Read the sample IDs (``IID`` column) from a PLINK 2 ``.psam`` file.

    Parameters
    ----------
    path
        Path to the PSAM file; its suffix is replaced with ``.psam``.

    Returns
    -------
    Sample IDs as a string array, in file order.

    Raises
    ------
    ValueError
        If the file is empty or its header line has no ``IID`` column.
    """
    psam_path = path.with_suffix(".psam")
    # Read the header with the same encoding polars uses for the body.
    with open(psam_path, encoding="utf-8") as f:
        header = f.readline()
    # Only the first header token carries the leading '#', e.g. "#FID" or "#IID".
    cols = [c.lstrip("#") for c in header.strip().split()]
    if not cols:
        raise ValueError(f"PSAM file {psam_path} is empty.")
    if "IID" not in cols:
        raise ValueError(
            f"PSAM file {psam_path} has no IID column in its header line."
        )

    psam = pl.read_csv(
        psam_path,
        separator="\t",
        has_header=False,
        skip_rows=1,
        new_columns=cols,
        schema_overrides={
            "FID": pl.Utf8,
            "IID": pl.Utf8,
            "SID": pl.Utf8,
            "PAT": pl.Utf8,
            "MAT": pl.Utf8,
            "SEX": pl.Utf8,
        },
    )
    return psam["IID"].to_numpy().astype(str)
class StartsEndsIlens:
    """Per-variant coordinates, indel lengths and ALT alleles, all ordered by variant start."""

    v_starts: NDArray[POS_TYPE]
    """0-based variant starts, in ascending order."""
    v_ends: NDArray[POS_TYPE]
    """0-based exclusive variant ends, in start order."""
    ilens: NDArray[np.int32]
    """Indel lengths, in start order."""
    alt: pl.Series
    """ALT alleles, in start order."""

    def __init__(
        self,
        v_starts: NDArray[POS_TYPE],
        v_ends: NDArray[POS_TYPE],
        ilens: NDArray[np.int32],
        alt: pl.Series,
    ):
        self.v_starts, self.v_ends, self.ilens, self.alt = v_starts, v_ends, ilens, alt
def _valid_index(index_path: Path) -> bool:
"""Check if the index is valid. Needs to exist and have a modified time greater than
the PVAR file."""
if not index_path.exists():
return False
pvar_mtime = index_path.with_suffix("").stat().st_mtime_ns
index_mtime = index_path.stat().st_mtime_ns
return index_mtime > pvar_mtime
def _load_index(
    index_path: Path, filter: pl.Expr | None
) -> tuple[pl.DataFrame, StartsEndsIlens | None, list[str]]:
    """Load the genoray PVAR index, (re)building it first if it is missing or stale.

    Args:
        index_path: Path of the index file. The PVAR it indexes is this path with its
            last suffix removed (the file ``_valid_index`` compares mtimes against).
        filter: Optional polars predicate selecting which variants to keep, or None
            to keep every variant.

    Returns:
        A tuple ``(index, sei, contigs)``:

        - ``index``: collected DataFrame with columns ``index``, ``CHROM``, ``POS``,
          ``REF``, ``ALT`` and ``ILEN``. When ``filter`` is given, only matching
          rows are kept. ``index`` is the UInt32 row number assigned before any
          filtering, so it still refers to the row's position in the full file.
        - ``sei``: ``StartsEndsIlens`` built from *every* row of the index, with
          start, end, first-ALT ILEN (null -> 0) and first ALT. When ``filter`` is
          given, rows failing it get ILEN 0. ``sei`` is None when any row (any
          filtered row, if ``filter`` is given) fails ``is_biallelic``.
        - ``contigs``: distinct CHROM values of ``index`` in order of first
          appearance.
    """
    if not _valid_index(index_path):
        logger.info("Genoray PVAR index not found or out-of-date, creating index.")
        _write_index(index_path)
    logger.info("Loading genoray index.")
    # The row index is attached at scan time, before any filtering, so it records
    # each variant's row position in the full index file.
    index = pl.scan_ipc(
        index_path, row_index_name="index", memory_map=False
    ).with_columns(
        pl.col("index").cast(pl.UInt32),
        # 0-based start and 0-based exclusive end: POS - 1 + len(REF in bytes).
        start=pl.col("POS") - 1,
        end=pl.col("POS") + pl.col("REF").str.len_bytes() - 1,
    )
    schema = index.collect_schema()
    if schema["ALT"] == pl.Utf8:
        # ALT is stored as a comma-joined string; turn it into a list of alleles.
        index = index.with_columns(pl.col("ALT").str.split(","))
    # Only derive ILEN when the index file does not already carry it.
    if "ILEN" not in schema:
        if "INFO" in schema.names():
            # ILEN is intentionally recomputed from the persisted INFO string on
            # each _load_index call. PGEN does NOT persist the computed ILEN in
            # the .gvi file (unlike the VCF path, which writes ILEN at index-build
            # time). Do NOT "optimise" this into persistence without also updating
            # _write_index to store ILEN — otherwise old .gvi files would silently
            # miss symbolic-SV ILEN corrections.
            #
            # Regex-extract SVLEN/END/IMPRECISE from the PVAR INFO string, then
            # use the shared symbolic_ilen() helper so symbolic SVs get correct
            # sign-adjusted lengths (DEL→-|len|, INS/DUP→+|len|, others→null).
            info_col = pl.col("INFO").fill_null("")
            index = index.with_columns(
                info_col.str.extract(r"(?:^|;)SVLEN=(-?\d+)", 1)
                .cast(pl.Int64)
                .alias("SVLEN"),
                info_col.str.extract(r"(?:^|;)END=(\d+)", 1)
                .cast(pl.Int64)
                .alias("END"),
                # IMPRECISE is a VCF Flag (Number=0): match the bare token only.
                # Do NOT broaden to IMPRECISE=… — that would wrongly treat
                # IMPRECISE=0 as set.
                info_col.str.contains(r"(?:^|;)IMPRECISE(?:;|$)").alias("IMPRECISE"),
            )
            index = index.with_columns(ILEN=symbolic_ilen())
            # The helper columns were only inputs to symbolic_ilen(); drop them.
            index = index.drop("SVLEN", "END", "IMPRECISE")
        else:
            # No INFO column: ILEN comes solely from the shared ILEN expression.
            index = index.with_columns(ILEN=ILEN)
    # Multiallelic detection is restricted to the rows the filter keeps.
    # NOTE(review): presumably is_biallelic is true for rows with exactly one ALT;
    # confirm in .exprs.
    if filter is None:
        has_multiallelics = index.select((~is_biallelic).any()).collect().item()
    else:
        has_multiallelics = (
            index.filter(filter).select((~is_biallelic).any()).collect().item()
        )
    if has_multiallelics:
        # The flat start/end/ILEN/ALT arrays below keep only the first ALT per site,
        # so they are not built when a kept site has several ALTs.
        sei = None
    else:
        # can keep the first alt for multiallelic sites since they're getting filtered out
        # anyway, so they won't be accessed
        # if the filter is changed, the index is invalidated and re-read (see filter setter)
        if filter is None:
            data = index.select(
                "start",
                "end",
                pl.col("ILEN").list.first().fill_null(0).alias("ILEN"),
                pl.col("ALT").list.first(),
            )
        else:
            # All rows are kept here (no .filter), so the arrays stay aligned with
            # the unfiltered row index. Rows that fail the filter get ILEN 0.
            data = index.with_columns(
                ILEN=pl.when(filter)
                .then(pl.col("ILEN").list.first().fill_null(0))
                .otherwise(pl.lit(0))
            )
            data = data.select("start", "end", "ILEN", pl.col("ALT").list.first())
        data = data.collect()
        v_starts = data["start"].to_numpy()
        v_ends = data["end"].to_numpy()
        ilens = data["ILEN"].to_numpy()
        alt = data["ALT"]
        sei = StartsEndsIlens(v_starts, v_ends, ilens, alt)
    # Only now is the returned table itself restricted to the filtered rows.
    if filter is not None:
        index = index.filter(filter)
    index = index.select("index", "CHROM", "POS", "REF", "ALT", "ILEN").collect()
    # PVAR contigs are not necessarily sorted, only guaranteed to be sorted within a contig
    contigs = index["CHROM"].unique(maintain_order=True).to_list()
    return index, sei, contigs
def _write_index(index_path: Path):
    """Write the PVAR index to *index_path*.

    The variant file that is scanned is *index_path* with its last suffix removed.
    Its ``#CHROM`` column is renamed to ``CHROM``, and the frame is written in
    Arrow IPC format.

    The data is first written to a temporary sibling file and then atomically
    moved into place. ``_valid_index`` accepts any existing index whose mtime is
    newer than the PVAR file, so writing straight to *index_path* could leave a
    truncated index after a crash or error that would never be rebuilt. On
    failure the temporary file is removed and the exception is re-raised.
    """
    import os

    # Include the PID so concurrent builders do not clobber each other's temp file.
    tmp_path = index_path.with_name(f"{index_path.name}.{os.getpid()}.tmp")
    try:
        (
            _scan_pvar(index_path.with_suffix(""))
            .rename({"#CHROM": "CHROM"})
            .sink_ipc(tmp_path)
        )
        # os.replace semantics: atomic on the same filesystem, overwrites any
        # existing stale index.
        tmp_path.replace(index_path)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise
def _scan_pvar(pvar: Path):
    """Lazily scan a PVAR file, falling back to BIM parsing for headerless input.

    Only the leading comment lines are inspected:

    - Any ``##`` meta-line marks the file as a PVAR.
    - The first single-``#`` line is taken as the tab-separated header.
    - If the first line that is not a ``##`` meta-line is a data line, the sniff
      stops. If no comment lines were seen, the file is delegated to
      ``_scan_bim``.

    Paths ending in ``.zst`` are read through ``ZstdFile`` for this sniff.

    NOTE(review): ``pl.scan_csv`` is handed the original path, even when it is
    ``.zst`` compressed. Confirm the installed polars can scan compressed input.

    Raises:
        ValueError: if ``##`` meta-lines are present but no ``#`` header line
            appears before the first data line (or end of file).
        RuntimeError: if the header contains a FORMAT column.
        KeyError: if the header names a column that is not in ``pvar_schema``.
    """
    pvar_schema = {
        "#CHROM": pl.Utf8,
        "POS": pl.Int64,
        "ID": pl.Utf8,
        "REF": pl.Utf8,
        "ALT": pl.Utf8,
        "QUAL": pl.Float64,
        "FILTER": pl.Utf8,
        "INFO": pl.Utf8,
        "CM": pl.Float64,
    }
    opener = ZstdFile if pvar.suffix == ".zst" else partial(open, mode="r")

    cols: list[str] | None = None
    is_pvar = False
    with opener(pvar) as f:
        for line in f:
            if line.startswith("##"):
                is_pvar = True
                continue
            if line.startswith("#"):
                is_pvar = True
                cols = line.strip().split("\t")
            # Meta-lines and the header precede all data. The first line that is
            # not a "##" meta-line therefore ends the sniff. This avoids reading
            # an entire headerless (BIM-style) file just to learn it is not a PVAR.
            break

    if not is_pvar:
        return _scan_bim(pvar)
    if cols is None:
        raise ValueError(f"No non-comment lines in PVAR file: {pvar}")
    if "FORMAT" in cols:
        raise RuntimeError("PVAR does not support the FORMAT column.")
    return pl.scan_csv(
        pvar,
        separator="\t",
        comment_prefix="##",
        schema={c: pvar_schema[c] for c in cols},
        null_values=".",
    )
class ZstdFile(TextIOWrapper):
    """UTF-8 text reader over a zstd-compressed file.

    Lines end only at LF and are not newline-translated. The decompressing
    stream is kept on ``self.reader``; the original path is kept on ``self.path``.
    """

    def __init__(self, path: Path):
        self.path = path
        decompressor = ZstdDecompressor()
        # NOTE(review): the raw handle is never closed here directly. This relies on
        # zstandard's stream_reader closing its source when it is closed — confirm.
        raw_handle = open(path, "rb")
        self.reader = decompressor.stream_reader(raw_handle)
        super().__init__(self.reader, newline="\n", encoding="utf-8")

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Close the decompressing reader, then defer to TextIOWrapper's exit."""
        reader = self.reader
        reader.close()
        outcome = super().__exit__(exc_type, exc_val, exc_tb)
        return outcome
def _scan_bim(bim: Path):
    """Lazily scan a headerless, tab-separated ``.bim``-style variant file.

    The field count of the first line selects the layout. Exactly five fields
    means there is no CM column; anything else uses the six-column layout
    (#CHROM, ID, CM, POS, ALT, REF). ``.`` is read as null, and rows whose POS is
    not positive are dropped.
    """
    with open(bim, "r") as handle:
        first_line = handle.readline()
    has_cm_column = len(first_line.strip().split("\t")) != 5

    schema = {"#CHROM": pl.Categorical, "ID": pl.Utf8}
    if has_cm_column:
        schema["CM"] = pl.Float64
    schema.update({"POS": pl.Int32, "ALT": pl.Utf8, "REF": pl.Utf8})

    lazy_frame = pl.scan_csv(
        bim,
        separator="\t",
        has_header=False,
        schema=schema,
        null_values=".",
    )
    return lazy_frame.filter(pl.col("POS") > 0)