from __future__ import annotations
import re
import warnings
from collections.abc import Callable, Generator
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Literal, TypeGuard, TypeVar, cast, overload
import cyvcf2
from hirola import HashTable
import numpy as np
import oxbow
import polars as pl
import pyranges as pr
from loguru import logger
from more_itertools import mark_ends
from natsort import natsorted
from numpy.typing import ArrayLike, NDArray
from phantom import Phantom
from seqpro.rag import OFFSET_TYPE
from tqdm.auto import tqdm
from typing_extensions import Self, assert_never
from ._types import POS_MAX, POS_TYPE
from ._utils import ContigNormalizer, format_memory, hap_ilens, parse_memory
from ._var_ranges import var_counts, var_indices
from .exprs import ILEN, symbolic_ilen
"""Dtype for VCF range indices. This determines the maximum size of a contig in genoray.
We have to use int64 because this is what htslib uses for CSI indexes."""
V_IDX_TYPE = np.uint32
"""Dtype for VCF variant indices (uint32). This determines the maximum number of unique variants in a file."""
class DosageFieldError(RuntimeError): ...
class MultiallelicDosageError(RuntimeError): ...
GDTYPE = TypeVar("GDTYPE", np.int8, np.int16)
def _is_genos8(obj: Any) -> TypeGuard[NDArray[np.int8]]:
return (
isinstance(obj, np.ndarray)
and obj.dtype.type == np.int8
and obj.ndim == 3
and obj.shape[1] in (2, 3) # diploid with/without phasing
)
class Genos8(NDArray[np.int8], Phantom, predicate=_is_genos8):
_gdtype = np.int8
@classmethod
def empty(
cls, n_samples: int, ploidy: int, n_variants: int, phasing: bool
) -> Genos8:
return cls.parse(np.empty((n_samples, ploidy + phasing, n_variants), np.int8))
def _is_genos16(obj: Any) -> TypeGuard[NDArray[np.int16]]:
return (
isinstance(obj, np.ndarray)
and obj.dtype.type == np.int16
and obj.ndim == 3
and obj.shape[1] in (2, 3) # diploid with/without phasing
)
class Genos16(NDArray[np.int16], Phantom, predicate=_is_genos16):
_gdtype = np.int16
@classmethod
def empty(
cls, n_samples: int, ploidy: int, n_variants: int, phasing: bool
) -> Genos16:
return cls.parse(np.empty((n_samples, ploidy + phasing, n_variants), np.int16))
def _is_dosages(obj: Any) -> TypeGuard[NDArray[np.float32]]:
return (
isinstance(obj, np.ndarray) and obj.dtype.type == np.float32 and obj.ndim == 2
)
class Dosages(NDArray[np.float32], Phantom, predicate=_is_dosages):
@classmethod
def empty(
cls, n_samples: int, ploidy: int, n_variants: int, phasing: bool
) -> Dosages:
return cls.parse(np.empty((n_samples, n_variants), np.float32))
def _is_genos8_dosages(obj: Any) -> TypeGuard[tuple[Genos8, Dosages]]:
"""Check if the object is a tuple of genotypes and dosages.
Parameters
----------
obj
Object to check.
Returns
-------
bool
True if the object is a tuple of genotypes and dosages, False otherwise.
"""
return (
isinstance(obj, tuple)
and len(obj) == 2
and isinstance(obj[0], Genos8)
and isinstance(obj[1], Dosages)
)
class Genos8Dosages(tuple[Genos8, Dosages], Phantom, predicate=_is_genos8_dosages):
_gdtype = np.int8
@classmethod
def empty(
cls, n_samples: int, ploidy: int, n_variants: int, phasing: bool
) -> Genos8Dosages:
return cls.parse(
(
Genos8.empty(n_samples, ploidy, n_variants, phasing),
Dosages.empty(n_samples, ploidy, n_variants, phasing),
)
)
def _is_genos16_dosages(obj) -> TypeGuard[tuple[Genos8, Dosages]]:
"""Check if the object is a tuple of genotypes and dosages.
Parameters
----------
obj
Object to check.
Returns
-------
bool
True if the object is a tuple of genotypes and dosages, False otherwise.
"""
return (
isinstance(obj, tuple)
and len(obj) == 2
and isinstance(obj[0], Genos16)
and isinstance(obj[1], Dosages)
)
class Genos16Dosages(tuple[Genos16, Dosages], Phantom, predicate=_is_genos16_dosages):
_gdtype = np.int16
@classmethod
def empty(
cls, n_samples: int, ploidy: int, n_variants: int, phasing: bool
) -> Genos16Dosages:
return cls.parse(
(
Genos16.empty(n_samples, ploidy, n_variants, phasing),
Dosages.empty(n_samples, ploidy, n_variants, phasing),
)
)
T = TypeVar("T", Genos8, Genos16, Dosages, Genos8Dosages, Genos16Dosages)
L = TypeVar("L", Genos8, Genos16, Genos8Dosages, Genos16Dosages)
G = TypeVar("G", Genos8, Genos16)
GD = TypeVar("GD", Genos8Dosages, Genos16Dosages)
class _Index:
gr: pr.PyRanges
"""PyRanges for range queries, just has Chromosome, Start, End, and index columns."""
df: pl.DataFrame
"""All the other columns in the index that aren't #CHROM, start, end, or index. Facilitates
index -> attribute lookups."""
def __init__(self, gr: pr.PyRanges, df: pl.DataFrame):
self.gr = gr
self.df = df
[docs]
class VCF:
"""Create a VCF reader.
Parameters
----------
path
Path to the VCF file.
filter
Function to filter variants. Should return True for variants to keep.
.. note::
To avoid KeyErrors, this function needs to be tolerant to missing fields. For example, if you
access an INFO or FORMAT field, not all variants are guaranteed to have the same fields.
The `cyvcf2.Variant <https://brentp.github.io/cyvcf2/docstrings.html#cyvcf2.cyvcf2.Variant>`_
API provides the :meth:`.get <dict.get>` method on the INFO and FORMAT attributes. For example,
:code:`lambda v: v.INFO.get("AF", 0) > 0.05` will skip any variants with an AF <= 0.05 or a
missing AF by treating missing AFs as 0.
pl_filter
Polars expression to filter variants. Should return True for variants to keep. Must match the filter function.
.. note::
This expression will be applied to the polars DataFrame returned by :meth:`get_record_info`.
It is not applied to the VCF file itself, so it will not be able to use the cyvcf2.Variant API.
For example, if you want to filter variants by INFO field, you can use:
:code:`pl.col("AF") > 0.05`
but you can not use:
:code:`lambda v: v.INFO.get("AF", 0) > 0.05`
because the expression will be applied to the polars DataFrame, not the VCF file.
read_as
Type of data to read from the VCF file. Can be VCF.Genos, VCF.Dosages, or VCF.GenosDosages.
phasing
Whether to include phasing information on genotypes. If True, the ploidy axis will be length 3 such that
phasing is indicated by the 3rd value: 0 = unphased, 1 = phased. If False, the ploidy axis will be length 2.
dosage_field
Name of the dosage field to read from the VCF file. Required if read_as is VCF.Dosages, VCF.Genos8Dosages,
or VCF.Genos16Dosages.
progress
Whether to show a progress bar while reading the VCF file.
"""
path: Path
"""Path to the VCF file."""
available_samples: list[str]
"""List of available samples in the VCF file."""
contigs: list[str]
"""Naturally sorted list of available contigs in the VCF file."""
ploidy: int = 2
"""Ploidy of the VCF file. This is currently always 2 since we use cyvcf2."""
_filter: Callable[[cyvcf2.Variant], bool] | None
"""Function to filter variants. Should return True for variants to keep."""
_pl_filter: pl.Expr | None
"""Polars expression to filter variants. Should return True for variants to keep. Must match the filter function."""
phasing: bool
"""Whether to include phasing information on genotypes. If True, the ploidy axis will be length 3 such that
phasing is indicated by the 3rd value: 0 = unphased, 1 = phased. If False, the ploidy axis will be length 2."""
dosage_field: str | None
"""Name of the dosage field to read from the VCF file. Required if you want to use modes that include dosages."""
_pbar: tqdm | None
"""A progress bar to use while reading variants. This will be incremented per variant
during any calls to a read function."""
_s_sorter: NDArray[np.intp] | slice
_samples: list[str]
_c_norm: ContigNormalizer
_index: pl.DataFrame | None
_vcf: cyvcf2.VCF
Genos8 = Genos8
"""Mode for :code:`int8` genotypes :code:`(samples ploidy variants)`"""
Genos16 = Genos16
"""Mode for :code:`int16` genotypes :code:`(samples ploidy variants)`"""
Dosages = Dosages
"""Mode for dosages :code:`(samples variants) float32`"""
Genos8Dosages = Genos8Dosages
"""Mode for :code:`int8` genotypes :code:`(samples ploidy variants) int8` and dosages :code:`(samples variants) float32`"""
Genos16Dosages = Genos16Dosages
"""Mode for :code:`int16` genotypes :code:`(samples ploidy variants) int16` and dosages :code:`(samples variants) float32`"""
def __init__(
self,
path: str | Path,
filter: Callable[[cyvcf2.Variant], bool] | None = None,
pl_filter: pl.Expr | None = None,
phasing: bool = False,
dosage_field: str | None = None,
progress: bool = False,
with_gvi_index: bool = True,
):
self._check_filter_pair(filter, pl_filter)
self.path = Path(path)
if not self.path.exists():
raise FileNotFoundError(f"VCF file {self.path} does not exist.")
self._filter = filter
self._pl_filter = pl_filter
self.phasing = phasing
self.dosage_field = dosage_field
self.progress = progress
self._pbar = None
self._index = None
vcf = cyvcf2.VCF(path)
self.available_samples = vcf.samples
self.contigs = natsorted(vcf.seqnames)
self._c_norm = ContigNormalizer(vcf.seqnames)
avail = np.asarray(self.available_samples)
self._s2i = HashTable(max=len(avail) * 2, dtype=avail.dtype)
self._s2i.add(avail)
self.set_samples(None)
if with_gvi_index and self._valid_index() and self._filter is None:
self._load_index()
def _open(self) -> cyvcf2.VCF:
return cyvcf2.VCF(self.path, samples=self._samples, lazy=True)
@staticmethod
def _check_filter_pair(
filter: Callable[[cyvcf2.Variant], bool] | None,
pl_filter: pl.Expr | None,
) -> None:
"""Enforce the both-or-neither invariant on a (filter, pl_filter) pair."""
if (filter is not None and pl_filter is None) or (
filter is None and pl_filter is not None
):
raise ValueError(
"If a filter function is provided, a polars expression must also be provided, and vice versa."
)
@property
def filter(self) -> Callable[[cyvcf2.Variant], bool] | None:
"""Function to filter variants. Should return True for variants to keep."""
return self._filter
def _index_path(self) -> Path:
"""Path to the index file."""
base = Path(f"{self.path}.gvi")
if base.exists():
return base
else:
return base.with_suffix(".gvi.zst")
@filter.setter
def filter(
self,
value: tuple[Callable[[cyvcf2.Variant], bool] | None, pl.Expr | None] | None,
):
"""Set the record filter and its matching polars expression together.
Assign a ``(filter, pl_filter)`` pair, or ``None`` to clear both. The VCF
path requires both a cyvcf2 record callable (for the genotype scan) and a
matching polars expression (for the ``.gvi`` index); they must be set
together, mirroring the constructor's both-or-neither invariant. Changing
the filter invalidates the in-memory index.
"""
if value is None:
filter = pl_filter = None
elif isinstance(value, tuple) and len(value) == 2:
filter, pl_filter = value
else:
raise TypeError(
"VCF.filter must be assigned a (filter, pl_filter) tuple or None; "
f"got {type(value).__name__}."
)
self._check_filter_pair(filter, pl_filter)
self._index = None
self._filter = filter
self._pl_filter = pl_filter
@property
def nbytes(self) -> int:
"""Total in-memory footprint, in bytes, of resident (non-mmap'd) data
structures held by this reader. Currently this is the gvi variant
index (CHROM/POS/REF/ALT/ILEN). Returns 0 before the index is loaded.
"""
if self._index is None:
return 0
return self._index.estimated_size()
@property
def current_samples(self) -> list[str]:
"""List of samples currently being read from the VCF file."""
return self._samples
@property
def n_samples(self) -> int:
"""Number of samples currently selected."""
return len(self._samples)
[docs]
def set_samples(self, samples: ArrayLike | None) -> Self:
"""Set the samples to read from the VCF file. Modifies the VCF reader in place and returns it.
Parameters
----------
samples
List of sample names to read from the VCF file.
Returns
-------
The VCF reader with the specified samples.
"""
if samples is not None:
samples = cast(list[str], np.atleast_1d(samples).tolist())
if samples is None or samples == self.available_samples:
self._samples = self.available_samples
self._s_sorter = slice(None)
self._vcf = self._open()
return self
if missing := set(samples).difference(self.available_samples):
raise ValueError(
f"Samples {missing} not found in the VCF file. "
f"Available samples: {self.available_samples}"
)
self._samples = samples
avail_indices = self._s2i.get(np.asarray(samples))
vcf_order = np.argsort(avail_indices, kind="stable")
if np.all(vcf_order == np.arange(len(samples))):
self._s_sorter = slice(None)
else:
self._s_sorter = np.argsort(vcf_order, kind="stable")
self._vcf = self._open()
return self
[docs]
@contextmanager
def using_pbar(self, pbar: tqdm):
"""Create a context where the given progress bar will be incremented by any calls to a read method.
Parameters
----------
pbar
Progress bar to use while reading variants. This will be incremented per variant
during any calls to a read function.
"""
self._pbar = pbar
try:
yield self
finally:
self._pbar = None
[docs]
def n_vars_in_ranges(
self,
contig: str,
starts: ArrayLike = 0,
ends: ArrayLike = POS_MAX,
) -> NDArray[np.uint32]:
"""Return the start and end indices of the variants in the given ranges.
Parameters
----------
contig
Contig name.
starts
0-based start positions of the ranges.
ends
0-based, exclusive end positions of the ranges.
Returns
-------
Shape: :code:`(ranges)`. Number of variants in the given ranges.
"""
if self._index is None:
return self._n_vars_no_index(contig, starts, ends)
else:
return self._n_vars_with_index(contig, starts, ends)
def _n_vars_no_index(
self,
contig: str,
starts: ArrayLike = 0,
ends: ArrayLike = POS_MAX,
) -> NDArray[np.uint32]:
"""Return the start and end indices of the variants in the given ranges.
Parameters
----------
contig
Contig name.
starts
0-based start positions of the ranges.
ends
0-based, exclusive end positions of the ranges.
Returns
-------
Shape: :code:`(ranges)`. Number of variants in the given ranges.
"""
starts = np.atleast_1d(np.asarray(starts, POS_TYPE)).clip(min=0)
ends = np.atleast_1d(np.asarray(ends, POS_TYPE))
c = self._c_norm.norm(contig)
if c is None:
return np.zeros_like(starts, np.uint32)
out = np.empty_like(starts, np.uint32)
starts = starts + 1 # 1-based
for i, (s, e) in enumerate(zip(starts, ends)):
coord = f"{c}:{s}-{e}"
if self._filter is None:
out[i] = sum(1 for _ in self._vcf(coord))
else:
out[i] = sum(self._filter(v) for v in self._vcf(coord))
return out
def _n_vars_with_index(
self,
contig: str,
starts: ArrayLike = 0,
ends: ArrayLike = POS_MAX,
) -> NDArray[np.uint32]:
"""Return the start and end indices of the variants in the given ranges.
Parameters
----------
index
Index to use for counting variants.
contig
Contig name.
starts
0-based start positions of the ranges.
ends
0-based, exclusive end positions of the ranges.
Returns
-------
n_variants
Shape: :code:`(ranges)`. Number of variants in the given ranges.
"""
if self._index is None:
raise RuntimeError(
"Index not loaded. Call `_load_index()` before using this method."
)
return var_counts(self._c_norm, self._index, contig, starts, ends)
def _var_idxs(
self,
contig: str,
starts: ArrayLike = 0,
ends: ArrayLike = POS_MAX,
) -> tuple[NDArray[np.integer], NDArray[OFFSET_TYPE]]:
"""Get variant indices and the number of indices per range.
Parameters
----------
contig
Contig name.
starts
0-based start positions of the ranges.
ends
0-based, exclusive end positions of the ranges.
Returns
-------
Shape: (tot_variants). Variant indices for the given ranges.
Shape: (ranges+1). Offsets to get variant indices for each range.
"""
if self._index is None:
raise RuntimeError(
"Index not loaded. Call `_load_index()` before using this method."
)
return var_indices(V_IDX_TYPE, self._c_norm, self._index, contig, starts, ends)
[docs]
def read(
self,
contig: str,
start: int | np.integer = 0,
end: int | np.integer = POS_MAX,
mode: type[T] = Genos16,
out: T | None = None,
) -> T:
"""Read genotypes and/or dosages for a range.
Parameters
----------
contig
Contig name.
start
0-based start position.
end
0-based, exclusive end position.
mode
Type of data to read.
out
Output array to fill with genotypes and/or dosages. If None, a new array will be created.
Returns
-------
Genotypes and/or dosages. Genotypes have shape :code:`(samples ploidy variants)` and
dosages have shape :code:`(samples variants)`. Missing genotypes have value -1 and missing dosages
have value np.nan. If just using genotypes or dosages, will be a single array, otherwise
will be a tuple of arrays.
"""
if (
issubclass(mode, (Dosages, Genos8Dosages, Genos16Dosages))
and self.dosage_field is None
):
raise ValueError(
"Dosage field not specified. Set the VCF reader's `dosage_field` parameter."
)
c = self._c_norm.norm(contig)
if c is None:
logger.warning(
f"Query contig {contig} not found in VCF file, even after normalizing for UCSC/Ensembl nomenclature."
)
return mode.empty(self.n_samples, self.ploidy, 0, self.phasing)
start = max(0, start) # type: ignore
vcf = self._vcf(f"{c}:{int(start + 1)}-{end}") # range string is 1-based
if out is None:
if self._index is not None:
n_variants = self.n_vars_in_ranges(c, start, end)[0]
if n_variants == 0:
return mode.empty(self.n_samples, self.ploidy, 0, self.phasing)
else:
n_variants = None
if issubclass(mode, (Genos8, Genos16)):
if n_variants is None:
data = None
else:
data = mode.empty(
self.n_samples, self.ploidy, n_variants, self.phasing
)
data, _ = self._fill_genos(vcf, data, mode=mode)
elif issubclass(mode, Dosages):
assert self.dosage_field is not None
if n_variants is None:
data = None
else:
data = mode.empty(
self.n_samples, self.ploidy, n_variants, self.phasing
)
data, _ = self._fill_dosages(vcf, data, self.dosage_field)
elif issubclass(mode, (Genos8Dosages, Genos16Dosages)):
assert self.dosage_field is not None
if n_variants is None:
data = None
else:
data = mode.empty(
self.n_samples, self.ploidy, n_variants, self.phasing
)
data, _ = self._fill_genos_and_dosages(
vcf, data, self.dosage_field, mode=mode
)
else:
assert_never(mode)
out = cast(T, data)
else:
if isinstance(out, (Genos8, Genos16)):
self._fill_genos(vcf, out)
elif isinstance(out, Dosages):
assert self.dosage_field is not None
self._fill_dosages(vcf, out, self.dosage_field)
elif isinstance(out, (Genos8Dosages, Genos16Dosages)):
assert self.dosage_field is not None
self._fill_genos_and_dosages(vcf, out, self.dosage_field)
else:
assert_never(mode)
return out
[docs]
def chunk(
self,
contig: str,
start: int | np.integer = 0,
end: int | np.integer = POS_MAX,
max_mem: int | str = "4g",
mode: type[T] = Genos16,
) -> Generator[T]:
"""Iterate over genotypes and/or dosages for a range in chunks limited by :code:`max_mem`.
Parameters
----------
contig
Contig name.
start
0-based start position.
end
0-based, exclusive end position.
max_mem
Maximum memory to use for each chunk. Can be an integer or a string with a suffix
(e.g. "4g", "2 MB").
mode
Type of data to read.
Returns
-------
Generator of genotypes and/or dosages. Genotypes have shape :code:`(samples ploidy variants)` and
dosages have shape :code:`(samples variants)`. Missing genotypes have value -1 and missing dosages
have value np.nan. If just using genotypes or dosages, will be a single array, otherwise
will be a tuple of arrays.
"""
if (
issubclass(mode, (Dosages, Genos8Dosages, Genos16Dosages))
and self.dosage_field is None
):
raise ValueError(
"Dosage field not specified. Set the VCF reader's `dosage_field` parameter."
)
max_mem = parse_memory(max_mem)
c = self._c_norm.norm(contig)
if c is None:
logger.warning(
f"Query contig {contig} not found in VCF file, even after normalizing for UCSC/Ensembl nomenclature."
)
yield mode.empty(self.n_samples, self.ploidy, 0, self.phasing)
return
start = max(0, start) # type: ignore
mem_per_v = self._mem_per_variant(mode)
vars_per_chunk = max_mem // mem_per_v
if vars_per_chunk == 0:
raise ValueError(
f"Maximum memory {format_memory(max_mem)} insufficient to read a single variant."
+ f" Memory per variant: {format_memory(mem_per_v)}."
)
buffer = mode.empty(self.n_samples, self.ploidy, vars_per_chunk, self.phasing)
if isinstance(buffer, (Genos8, Genos16)):
gt_buffer = buffer
ds_buffer = None
elif isinstance(buffer, Dosages):
gt_buffer = None
ds_buffer = buffer
else:
gt_buffer, ds_buffer = buffer
vcf = self._vcf(f"{c}:{int(start + 1)}-{end}") # range string is 1-based
if self._filter is not None:
vcf = filter(self._filter, vcf)
if self.progress and self._pbar is None:
vcf = tqdm(vcf, desc="Reading VCF", unit=" variant")
i = 0
for v in vcf:
if gt_buffer is not None:
if self.phasing:
# (s p+1) np.int16
gt_buffer[..., i] = v.genotype.array()[self._s_sorter]
else:
gt_buffer[..., i] = v.genotype.array()[
self._s_sorter, : self.ploidy
]
if ds_buffer is not None:
d = v.format(self.dosage_field)
if d is None:
raise DosageFieldError(
f"Dosage field '{self.dosage_field}' not found for record {v!r}"
)
if d.shape[1] > 1:
raise MultiallelicDosageError(
f"Multiallelic dosages are not supported, encountered in VCF record {v!r}"
)
ds_buffer[..., i] = d.squeeze(1)[self._s_sorter]
i += 1
if self._pbar is not None:
self._pbar.update()
if i == vars_per_chunk:
yield buffer
i = 0
if i != 0:
buffer = []
if gt_buffer is not None:
gt_buffer = gt_buffer[..., :i]
buffer.append(gt_buffer)
if ds_buffer is not None:
ds_buffer = ds_buffer[..., :i]
buffer.append(ds_buffer)
buffer = tuple(buffer)
if len(buffer) == 1:
yield buffer[0]
else:
yield buffer # type: ignore
def _chunk_ranges_with_length(
self,
contig: str,
starts: ArrayLike = 0,
ends: ArrayLike = POS_MAX,
max_mem: int | str = "4g",
mode: type[L] = Genos16,
) -> Generator[
Generator[
tuple[L, int, int] # data, end, n_extension_vars
]
]:
"""Read genotypes and/or dosages in chunks approximately limited by :code:`max_mem`.
Will extend the range so that the returned data corresponds to haplotypes that have at least as much
length as the original range.
Parameters
----------
contig
Contig name.
start
0-based start positions.
end
0-based, exclusive end positions.
max_mem
Maximum memory to use for each chunk. Can be an integer or a string with a suffix
(e.g. "4g", "2 MB").
mode
Type of data to read.
Returns
-------
Generator of chunks of genotypes and/or dosages and the 0-based end position of the final variant
in the chunk. Genotypes have shape :code:`(samples ploidy variants)` and
dosages have shape :code:`(samples variants)`. Missing genotypes have value -1 and missing dosages
have value np.nan. If just using genotypes or dosages, will be a single array, otherwise will be a
tuple of arrays.
"""
if (
issubclass(mode, (Genos8Dosages, Genos16Dosages))
and self.dosage_field is None
):
raise ValueError(
"Dosage field not specified. Set the VCF reader's `dosage_field` parameter."
)
mode = cast(type[L], mode)
max_mem = parse_memory(max_mem)
starts = np.atleast_1d(np.asarray(starts, POS_TYPE)).clip(min=0)
ends = np.atleast_1d(np.asarray(ends, POS_TYPE))
c = self._c_norm.norm(contig)
if c is None:
logger.warning(
f"Query contig {contig} not found in VCF file, even after normalizing for UCSC/Ensembl nomenclature."
)
for e in ends:
yield (
(mode.empty(self.n_samples, self.ploidy, 0, self.phasing), e, 0)
for _ in range(1)
)
return
n_variants = self.n_vars_in_ranges(c, starts, ends)
tot_variants = n_variants.sum()
if tot_variants == 0:
for e in ends:
yield (
(mode.empty(self.n_samples, self.ploidy, 0, self.phasing), e, 0)
for _ in range(1)
)
return
mem_per_v = self._mem_per_variant(mode)
vars_per_chunk = min(max_mem // mem_per_v, tot_variants)
if vars_per_chunk == 0:
raise ValueError(
f"Maximum memory {format_memory(max_mem)} insufficient to read a single variant."
f" Memory per variant: {format_memory(mem_per_v)}."
)
starts = starts + 1 # cyvcf2 queries are 1-based
ends = np.atleast_1d(np.asarray(ends, POS_TYPE))
for s, e, n in zip(starts, ends, n_variants):
if n == 0:
yield (
(mode.empty(self.n_samples, self.ploidy, 0, self.phasing), e, 0)
for _ in range(1)
)
continue
yield self._chunk_with_length_helper(n, vars_per_chunk, c, s, e, mode)
def _chunk_with_length_helper(
self,
n: int,
vars_per_chunk: int,
contig: str,
start: POS_TYPE,
end: POS_TYPE,
mode: type[L],
) -> Generator[tuple[L, int, int]]:
if (
issubclass(mode, (Genos8Dosages, Genos16Dosages))
and self.dosage_field is None
):
raise ValueError(
"Dosage field not specified. Set the VCF reader's `dosage_field` parameter."
)
n_chunks, final_chunk = divmod(n, vars_per_chunk)
if final_chunk == 0:
# perfectly divisible so there is no final chunk
chunk_sizes = np.full(n_chunks, vars_per_chunk)
elif n_chunks == 0:
# n_vars < vars_per_chunk, so we just use the remainder
chunk_sizes = np.array([final_chunk])
else:
# have a final chunk that is smaller than the rest
chunk_sizes = np.full(n_chunks + 1, vars_per_chunk)
chunk_sizes[-1] = final_chunk
vcf = self._vcf(f"{contig}:{start}-{end}")
hap_lens = np.full((self.n_samples, self.ploidy), end - start, dtype=np.int32)
for _, is_last, chunk_size in mark_ends(chunk_sizes):
ilens = np.empty(chunk_size, dtype=np.int32)
if issubclass(mode, (Genos8, Genos16)):
out = cast(
Genos8 | Genos16,
mode.empty(self.n_samples, self.ploidy, chunk_size, self.phasing),
)
out, last_end = self._fill_genos(vcf, out, ilens)
hap_lens += hap_ilens(out[:, : self.ploidy], ilens)
elif issubclass(mode, (Genos8Dosages, Genos16Dosages)):
self.dosage_field = cast(str, self.dosage_field)
out = mode.empty(self.n_samples, self.ploidy, chunk_size, self.phasing)
out, last_end = self._fill_genos_and_dosages(
vcf, out, self.dosage_field, ilens
)
hap_lens += hap_ilens(out[0][:, : self.ploidy], ilens)
else:
assert_never(mode)
if not is_last:
yield cast(L, out), last_end, 0
continue
if issubclass(mode, (Genos8, Genos16)):
ls_ext, last_end = self._ext_genos_with_length(
contig, start, end, hap_lens, mode, last_end
)
elif issubclass(mode, (Genos8Dosages, Genos16Dosages)):
self.dosage_field = cast(str, self.dosage_field)
ls_ext, last_end = self._ext_genos_dosages_with_length(
contig,
start,
end,
hap_lens,
mode,
self.dosage_field,
last_end,
)
else:
assert_never(mode)
if len(ls_ext) > 0:
if issubclass(mode, (Genos8, Genos16)):
out = np.concatenate([out, *ls_ext], axis=-1)
else:
out = tuple(
np.concatenate([o, *ls], axis=-1)
for o, ls in zip(out, zip(*ls_ext))
)
yield (
cast(L, out),
last_end,
len(ls_ext),
)
@overload
def get_record_info(
self,
contig: str | None = None,
start: int | np.integer | None = None,
end: int | np.integer | None = None,
fields: list[str] | None = None,
info: list[str] | None = None,
lazy: Literal[False] = ...,
) -> pl.DataFrame: ...
@overload
def get_record_info(
self,
contig: str | None = None,
start: int | np.integer | None = None,
end: int | np.integer | None = None,
fields: list[str] | None = None,
info: list[str] | None = None,
*,
lazy: Literal[True],
) -> pl.LazyFrame: ...
@overload
def get_record_info(
self,
contig: str | None = None,
start: int | np.integer | None = None,
end: int | np.integer | None = None,
fields: list[str] | None = None,
info: list[str] | None = None,
lazy: bool = False,
) -> pl.DataFrame | pl.LazyFrame: ...
def _oxbow_reader(self) -> Callable:
"""Return the oxbow reader callable appropriate for this file's extension."""
if self.path.suffix == ".bcf":
return oxbow.from_bcf
elif re.search(r"\.vcf(\.gz)?$", self.path.name) is not None:
return oxbow.from_vcf
else:
raise ValueError(f"Unsupported file extension: {self.path.suffix}")
[docs]
def get_record_info(
self,
contig: str | None = None,
start: int | np.integer | None = None,
end: int | np.integer | None = None,
fields: list[str] | None = None,
info: list[str] | None = None,
lazy: bool = False,
) -> pl.DataFrame | pl.LazyFrame:
"""Get a DataFrame of any non-FORMAT fields in the VCF for a given range or the entire VCF.
Will filter variants if the VCF instance has a filter function.
Parameters
----------
contig
Contig name. If None, will read the entire VCF.
start
0-based start position.
end
0-based, exclusive end position.
fields
List of non-FORMAT, non-INFO fields to include. Returns all by default.
info
List of INFO fields to include. Returns all by default.
"""
if (start is not None or end is not None) and contig is None:
raise ValueError("start and end must be None if no contig is specified.")
if start is None:
start = 0
if end is None:
end = POS_MAX
if contig is not None:
region = f"{contig}:{start + 1}-{end}"
else:
region = None
if fields is not None:
fields = [f.lower() for f in fields]
if info is not None:
info = [f.lower() for f in info]
reader = self._oxbow_reader()
df = (
cast(
pl.LazyFrame,
reader(
self.path,
samples=[],
fields=fields,
info_fields=info,
regions=region,
).pl(lazy=True),
)
.rename(lambda c: c.upper())
.with_columns(pl.col("CHROM").cast(pl.Enum(self.contigs)))
)
if self._pl_filter is not None:
df = df.filter(self._pl_filter)
if not lazy:
df = df.collect()
return df
def _declared_info_fields(self, candidates: tuple[str, ...]) -> list[str]:
"""Return which of ``candidates`` are declared as INFO fields in the VCF header.
Uses ``header_iter()`` rather than ``get_header_type()`` because the latter
matches both INFO and FORMAT declarations; a FORMAT-only field must NOT be
treated as an INFO field (it would error when passed to oxbow's info_fields=).
"""
info_ids: set[str] = {
h.info()["ID"]
for h in self._vcf.header_iter()
if h.info().get("HeaderType") == "INFO"
}
return [c for c in candidates if c in info_ids]
def _fetch_info_cols(self, info_names: list[str]) -> pl.LazyFrame:
"""Fetch a set of uppercase INFO field names directly from oxbow and unnest the
returned struct, returning a LazyFrame with POS and those INFO columns as
top-level columns. POS is retained so the caller can cross-check alignment
against the base frame before positional concat.
Returns a LazyFrame with columns: POS, <info_names...>
"""
reader = self._oxbow_reader()
# oxbow requires uppercase INFO field names; returns them nested in an 'info' struct
raw = (
cast(
pl.LazyFrame,
reader(
self.path,
samples=[],
fields=["pos"],
info_fields=info_names,
).pl(lazy=True),
)
.with_columns(pl.col("pos").alias("POS"))
.drop("pos")
.unnest("info")
)
return raw
def _write_gvi_index(
self,
fields: list[str] | None = None,
info: list[str] | None = None,
overwrite: bool = True,
only_biallelic: bool = False,
) -> None:
"""Writes record information to disk, ignoring any filtering. At a minimum this index will
include columns `CHROM`, `POS` (1-based), `REF`, `ALT`, and `ILEN`.
Parameters
----------
fields
List of non-FORMAT, non-INFO fields to include. At a minimum this index will include
columns `CHROM`, `POS` (1-based), `REF`, `ALT`, and `ILEN`.
info
List of INFO fields to include.
overwrite
Whether to overwrite the index file if it exists.
only_biallelic
Whether to only use the first ALT alleles for each variant (i.e. assume all variants are biallelic). Better compression if True.
"""
if self._valid_index() and not overwrite:
raise FileExistsError(
f"A valid index file {self._index_path()} already exists. Use overwrite=True to overwrite."
)
_fields: set[str] = {"CHROM", "POS", "REF", "ALT"}
if fields is not None:
_fields.update(fields)
# Pull SVLEN/END/IMPRECISE when the header declares them so symbolic SVs
# can be sized. Requesting an undeclared INFO field can error in oxbow.
sv_info = self._declared_info_fields(("SVLEN", "END", "IMPRECISE"))
user_info_upper = {i.upper() for i in info} if info else set()
extra_sv = [f for f in sv_info if f.upper() not in user_info_upper]
filt = self._pl_filter
self._pl_filter = None
try:
index = self.get_record_info(fields=list(_fields), info=info, lazy=True)
finally:
self._pl_filter = filt
# Fetch SV helper columns directly (oxbow requires uppercase and returns a struct).
# Both oxbow reads cover the identical full record set (no region, no filter, same
# file order), which is what makes the positional horizontal concat correct.
# WARNING: region-scoping or pre-concat filtering either call would silently
# misalign SVLEN/END/IMPRECISE to wrong variants, corrupting ILEN.
# POS is cross-checked element-wise to confirm the two reads are in identical order.
if extra_sv:
index_df = index.collect()
sv_cols_df = self._fetch_info_cols(extra_sv).collect()
if index_df.height != sv_cols_df.height:
raise ValueError(
f"Row count mismatch between base index ({index_df.height}) and SV INFO "
f"columns ({sv_cols_df.height}); positional concat would misalign ILEN."
)
base_pos = index_df.get_column("POS")
sv_pos = sv_cols_df.get_column("POS")
if not base_pos.equals(sv_pos):
raise ValueError(
"POS mismatch between base index and SV INFO columns; "
"positional concat would misalign ILEN. This is a bug — please report it."
)
# Drop POS from sv_cols_df to avoid duplicate column before horizontal concat
sv_cols_df = sv_cols_df.drop("POS")
index = pl.concat([index_df, sv_cols_df], how="horizontal").lazy()
# Ensure the columns symbolic_ilen references exist (nulls when absent).
schema = index.collect_schema()
for col in ("SVLEN", "END", "IMPRECISE"):
if col not in schema.names():
dtype = pl.Boolean if col == "IMPRECISE" else pl.Int64
index = index.with_columns(pl.lit(None, dtype=dtype).alias(col))
# SVLEN from oxbow is List(Int32) (Number=A); coerce to scalar via list.first()
schema = index.collect_schema()
coerce: list[pl.Expr] = []
for col in ("SVLEN", "END"):
if col in schema.names() and isinstance(schema[col], pl.List):
coerce.append(pl.col(col).list.first().alias(col))
if coerce:
index = index.with_columns(coerce)
index = index.with_columns(ILEN=symbolic_ilen())
# Drop ALL of {SVLEN, END, IMPRECISE} that the user did not explicitly request
# via info=. This covers both fetched helper cols AND null-placeholder cols added
# above, so non-SV indexes don't gain stray all-null columns.
sv_cols_in_frame = {"SVLEN", "END", "IMPRECISE"} & set(
index.collect_schema().names()
)
drop_cols = [c for c in sv_cols_in_frame if c not in user_info_upper]
if drop_cols:
index = index.drop(drop_cols)
index.with_columns(pl.col("ALT").list.join(",")).collect().write_ipc(
self._index_path(), compression="zstd"
)
def _load_index(self) -> Self:
"""Load the index from disk, applying the filter expression if provided. You must
ensure that the filter expression is exactly equivalent to the vcf.filter function.
If a filter expression is not given and the VCF has a filter function, then one pass
over the VCF will be made to infer what records should be filtered.
Parameters
----------
filter
Filter expression to apply to the index. This should be a pl.Expr object that
is equivalent to the VCF filter function. If None, the filter function will be
used to filter the index.
"""
if not self._valid_index():
raise FileNotFoundError(
f"Index file {self._index_path()} does not exist or is out-of-date. "
"Please (re)create the index using `_write_gvi_index()`."
)
logger.info("Loading genoray index.")
with warnings.catch_warnings():
warnings.simplefilter("ignore")
index = pl.scan_ipc(
self._index_path(), row_index_name="index"
).with_columns(pl.col("CHROM").cast(pl.Enum(self.contigs)))
# Normalize ALT (on-disk comma-Utf8) to list[str] BEFORE applying the
# filter so the in-memory schema documented in genoray.exprs holds and
# list-typed expressions (is_symbolic, is_biallelic) work on this path.
schema = index.collect_schema()
if schema["ALT"] == pl.Utf8:
index = index.with_columns(pl.col("ALT").str.split(","))
if self._pl_filter is not None:
index = index.filter(self._pl_filter)
if "ILEN" not in schema:
index = index.with_columns(ILEN=ILEN)
self._index = index.collect()
return self
def _valid_index(self) -> bool:
"""Check if the index is valid. Needs to exist and have a modified time greater than
the VCF file."""
if not self._index_path().exists():
return False
vcf_mtime = self.path.stat().st_mtime_ns
index_mtime = self._index_path().stat().st_mtime_ns
return index_mtime > vcf_mtime
def _fill_genos(
self,
vcf: cyvcf2.VCF,
out: Genos8 | Genos16 | None,
ilens: NDArray[np.int32] | None = None,
mode: type[Genos8 | Genos16] | None = None,
) -> tuple[Genos8 | Genos16, int]:
if self._filter is not None:
vcf = filter(self._filter, vcf)
if out is None:
assert mode is not None
assert ilens is None, "caller should not provide ilens if out is None"
out_ls = []
for i, v in enumerate(vcf):
if self.phasing:
# (s p+1) np.int16
out_ls.append(v.genotype.array())
else:
# (s p) np.int16
out_ls.append(v.genotype.array()[:, : self.ploidy])
if self._pbar is not None:
self._pbar.update()
if len(out_ls) == 0:
return mode.empty(self.n_samples, self.ploidy, 0, self.phasing), 0
# (s p v)
out = cast(
Genos8 | Genos16,
np.stack(out_ls, axis=-1, dtype=mode._gdtype)[self._s_sorter],
)
return out, v.end # type: ignore
#! assumes n_variants > 0
n_variants = out.shape[-1]
if self.progress and self._pbar is None:
vcf = tqdm(vcf, total=n_variants, desc="Reading VCF", unit=" variant")
elif self._pbar is not None and self._pbar.total is None:
self._pbar.total = n_variants
self._pbar.refresh()
i = 0
for i, v in enumerate(vcf):
if self.phasing:
# (s p+1) np.int16
out[..., i] = v.genotype.array()[self._s_sorter]
else:
# (s p) np.int16
out[..., i] = v.genotype.array()[self._s_sorter, : self.ploidy]
if ilens is not None:
ilens[i] = len(v.ALT[0]) - len(v.REF)
if self._pbar is not None:
self._pbar.update()
if i == n_variants - 1:
break
if i != n_variants - 1:
raise ValueError("Not enough variants found in the given range.")
return out, v.end # type: ignore
def _fill_dosages(
self, vcf: cyvcf2.VCF, out: Dosages | None, dosage_field: str
) -> tuple[Dosages, int]:
if self._filter is not None:
vcf = filter(self._filter, vcf)
if out is None:
out_ls = []
for v in vcf:
d = v.format(dosage_field)
if d is None:
raise DosageFieldError(
f"Dosage field '{dosage_field}' not found for record {v!r}"
)
if d.shape[1] > 1:
raise MultiallelicDosageError(
f"Multiallelic dosages are not supported, encountered in VCF record {v!r}"
)
out_ls.append(d.squeeze(1))
if self._pbar is not None:
self._pbar.update()
if len(out_ls) == 0:
return Dosages.empty(self.n_samples, self.ploidy, 0, self.phasing), 0
_out = cast(
Dosages, np.stack(out_ls, axis=-1, dtype=np.float32)[self._s_sorter]
)
return _out, v.end # type: ignore
#! assumes n_variants > 0
n_variants = out.shape[-1]
if self.progress and self._pbar is None:
vcf = tqdm(vcf, total=n_variants, desc="Reading VCF", unit=" variant")
elif self._pbar is not None and self._pbar.total is None:
self._pbar.total = n_variants
self._pbar.refresh()
i = 0
for i, v in enumerate(vcf):
# (samples alts)
d = v.format(dosage_field)
if d is None:
raise DosageFieldError(
f"Dosage field '{dosage_field}' not found for record {v!r}"
)
if d.shape[1] > 1:
raise MultiallelicDosageError(
f"Multiallelic dosages are not supported, encountered in VCF record {v!r}"
)
out[..., i] = d.squeeze(1)[self._s_sorter]
if self._pbar is not None:
self._pbar.update()
if i == n_variants - 1:
break
if i != n_variants - 1:
raise ValueError("Not enough variants found in the given range.")
return out, v.end # type: ignore
def _fill_genos_and_dosages(
self,
vcf: cyvcf2.VCF,
out: Genos8Dosages | Genos16Dosages | None,
dosage_field: str,
ilens: NDArray[np.int32] | None = None,
mode: type[Genos8Dosages | Genos16Dosages] | None = None,
) -> tuple[Genos8Dosages | Genos16Dosages, int]:
if out is None:
assert mode is not None
assert ilens is None, "caller should not provide ilens if out is None"
geno_ls = []
dosage_ls = []
for i, v in enumerate(vcf):
if self.phasing:
# (s p+1) np.int16
geno_ls.append(v.genotype.array())
else:
# (s p) np.int16
geno_ls.append(v.genotype.array()[:, : self.ploidy])
d = v.format(dosage_field)
if d is None:
raise DosageFieldError(
f"Dosage field '{dosage_field}' not found for record {v!r}"
)
if d.shape[1] > 1:
raise MultiallelicDosageError(
f"Multiallelic dosages are not supported, encountered in VCF record {v!r}"
)
dosage_ls.append(d.squeeze(1))
if self._pbar is not None:
self._pbar.update()
if len(geno_ls) == 0:
out = mode.empty(self.n_samples, self.ploidy, 0, self.phasing)
return out, 0
genos = cast(
Genos8 | Genos16,
np.stack(geno_ls, axis=-1, dtype=mode._gdtype)[self._s_sorter],
)
dosages = cast(
Dosages, np.stack(dosage_ls, axis=-1, dtype=np.float32)[self._s_sorter]
)
out = cast(Genos8Dosages | Genos16Dosages, (genos, dosages))
return out, v.end # type: ignore
#! assumes n_variants > 0
n_variants = out[0].shape[-1]
if self._filter is not None:
vcf = filter(self._filter, vcf)
if self.progress and self._pbar is None:
vcf = tqdm(vcf, total=n_variants, desc="Reading VCF", unit=" variant")
elif self._pbar is not None and self._pbar.total is None:
self._pbar.total = n_variants
self._pbar.refresh()
i = 0
for i, v in enumerate(vcf):
if self.phasing:
# (s p+1) np.int16
out[0][..., i] = v.genotype.array()[self._s_sorter]
else:
out[0][..., i] = v.genotype.array()[self._s_sorter, : self.ploidy]
d = v.format(dosage_field)
if d is None:
raise DosageFieldError(
f"Dosage field '{dosage_field}' not found for record {v!r}"
)
if d.shape[1] > 1:
raise MultiallelicDosageError(
f"Multiallelic dosages are not supported, encountered in VCF record {v!r}"
)
out[1][..., i] = d.squeeze(1)[self._s_sorter]
if ilens is not None:
ilens[i] = len(v.ALT[0]) - len(v.REF)
if self._pbar is not None:
self._pbar.update()
if i == n_variants - 1:
break
if i != n_variants - 1:
raise ValueError("Not enough variants found in the given range.")
return out, v.end # type: ignore
def _mem_per_variant(self, mode: type[T]) -> int:
"""Calculate the memory required per variant for the given genotypes and dosages.
Parameters
----------
genotypes
Whether to include genotypes.
dosages
Whether to include dosages.
Returns
-------
int
Memory required per variant in bytes.
"""
mem = 0
ploidy = self.ploidy + self.phasing
if issubclass(mode, (Genos8, Genos16)):
mem += self.n_samples * ploidy * mode._gdtype().itemsize
elif issubclass(mode, Dosages):
mem += self.n_samples * np.float32().itemsize
elif issubclass(mode, (Genos8Dosages, Genos16Dosages)):
mem += self.n_samples * ploidy * mode._gdtype().itemsize
mem += self.n_samples * np.float32().itemsize
else:
assert_never(mode)
return mem
def _ext_genos_with_length(
self,
contig: str,
start: int | np.integer,
end: int | np.integer,
hap_lens: NDArray[np.int32],
mode: type[G],
last_end: int,
) -> tuple[list[G], int]:
ploidy = self.ploidy + self.phasing
length = end - start
ext_start = end
coord = f"{contig}:{ext_start + 1}"
_CHECK_LEN_EVERY_N = 20
ls_genos: list[G] = []
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message="no intervals found for", category=UserWarning
)
for i, v in enumerate(self._vcf(coord)):
if v.start < ext_start or (
self._filter is not None and not self._filter(v)
):
continue
# (s p, 1)
genos = v.genotype.array()[:, :ploidy, None]
genos = genos.astype(mode._gdtype)
ls_genos.append(genos)
if v.is_indel:
ilen = len(v.ALT[0]) - len(v.REF)
dist = v.start - last_end
hap_lens += dist + np.where(
genos[:, : self.ploidy] == 1, ilen, 0
).squeeze(-1)
last_end = cast(int, v.end)
if i % _CHECK_LEN_EVERY_N == 0 and (hap_lens >= length).all():
break
if len(ls_genos) > 0:
last_end = cast(int, v.end) # type: ignore | guaranteed bound by len(ls) > 0
return ls_genos, last_end
def _ext_genos_dosages_with_length(
self,
contig: str,
start: int | np.integer,
end: int | np.integer,
hap_lens: NDArray[np.int32],
mode: type[GD],
dosage_field: str,
last_end: int,
) -> tuple[list[GD], int]:
ploidy = self.ploidy + self.phasing
length = end - start
ext_start = end
coord = f"{contig}:{ext_start + 1}"
_CHECK_LEN_EVERY_N = 20
ls_geno_dosages: list[GD] = []
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message="no intervals found for", category=UserWarning
)
for i, v in enumerate(self._vcf(coord)):
if v.start < ext_start or (
self._filter is not None and not self._filter(v)
):
continue
# (s p 1)
genos = v.genotype.array()[:, :ploidy, None]
genos = genos.astype(mode._gdtype)
dosages = v.format(dosage_field)
if dosages is None:
raise DosageFieldError(
f"Dosage field '{dosage_field}' not found for record {v!r}"
)
# (s, 1, 1) or (s, 1)? -> (s)
dosages = dosages.squeeze(1)[self._s_sorter, None]
ls_geno_dosages.append((genos, dosages)) # type: ignore
if v.is_indel:
ilen = len(v.ALT[0]) - len(v.REF)
dist = v.start - last_end
# (s p 1)
hap_lens += dist + np.where(
genos[:, : self.ploidy] == 1, ilen, 0
).squeeze(-1)
last_end = cast(int, v.end)
if i % _CHECK_LEN_EVERY_N == 0 and (hap_lens >= length).all():
break
if len(ls_geno_dosages) > 0:
last_end = cast(int, v.end) # type: ignore | guaranteed bound by len(ls) > 0
return ls_geno_dosages, last_end