from __future__ import annotations
import copy
import re
import shutil
import warnings
from collections.abc import Iterable, Sequence
from os import PathLike
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Generic, Literal, TypeVar, cast, overload
import awkward as ak
import joblib
import numba as nb
import numpy as np
import polars as pl
import polars_bio as pb
import polars_config_meta # noqa: F401
import seqpro as sp
from awkward.contents import Content, NumpyArray
from hirola import HashTable
from joblib_progress import joblib_progress
from loguru import logger
from numpy.typing import ArrayLike, NDArray
from pgenlib import PgenReader
from polars._typing import IntoExpr
from pydantic import BaseModel
from rich.progress import MofNCompleteColumn, Progress
from seqpro.rag import OFFSET_TYPE, Ragged, lengths_to_offsets
from tqdm.auto import tqdm
from ._pgen import PGEN
from ._types import DOSAGE_TYPE, DTYPE, POS_MAX, POS_TYPE, V_IDX_TYPE
from ._utils import ContigNormalizer, format_memory, parse_memory
from ._var_ranges import var_ranges
from ._vcf import VCF
from .exprs import ILEN
# Type variable for numpy numeric scalar types (bound to np.number).
NUMERIC = TypeVar("NUMERIC", bound=np.number)
# Matches region strings of the form "chrom:start-end". The contig part may not
# contain ':' and both coordinates must be plain digit runs (no sign, no
# thousands separators). _normalize_regions interprets the coordinates as
# 1-based, inclusive.
_REGION_STR_RE = re.compile(r"^(?P<chrom>[^:]+):(?P<start>\d+)-(?P<end>\d+)$")
def _coerce_bed_schema(df: pl.DataFrame) -> pl.DataFrame:
    """Coerce a BED-like frame to columns chrom (Utf8), start (Int32), end (Int32).

    Recognised aliases, applied only when the canonical column is absent:

    * ``Chromosome`` / ``CHROM`` -> ``chrom`` (PyRanges / VCF style)
    * ``chromStart`` / ``chromEnd`` -> ``start`` / ``end`` (seqpro convention)
    * ``Start`` / ``End`` -> ``start`` / ``end`` (PyRanges style)

    If several aliases for the same canonical column are present, the first one
    in the list above wins. The others are left untouched and dropped by the
    final selection, as are all other extra columns.

    Casting is strict (polars' default), so non-numeric or out-of-range
    start/end values raise. Polars also raises if no column (canonical or
    alias) exists for chrom, start or end.
    """
    aliases = (
        ("Chromosome", "chrom"),
        ("CHROM", "chrom"),
        ("chromStart", "start"),
        ("chromEnd", "end"),
        ("Start", "start"),
        ("End", "end"),
    )
    cols = set(df.columns)
    rename: dict[str, str] = {}
    claimed: set[str] = set()
    for src, dst in aliases:
        # Only rename when the canonical name is free, and never map two
        # sources onto the same target: polars rejects duplicate column names.
        if src in cols and dst not in cols and dst not in claimed:
            rename[src] = dst
            claimed.add(dst)
    if rename:
        df = df.rename(rename)
    return df.select(
        pl.col("chrom").cast(pl.Utf8),
        pl.col("start").cast(pl.Int32),
        pl.col("end").cast(pl.Int32),
    )
def _frame_like_to_polars(regions: object) -> pl.DataFrame:
    """Convert a polars, pandas or PyRanges (v0 or v1) frame to a polars DataFrame.

    Raises
    ------
    TypeError
        If *regions* is none of the supported frame types.
    """
    import importlib

    if isinstance(regions, pl.DataFrame):
        return regions
    try:
        import pandas as pd
    except ImportError:
        pd = None
    if pd is not None and isinstance(regions, pd.DataFrame):
        return pl.from_pandas(regions)
    # Try pyranges (v0 or v1); either package may be missing.
    for mod_name in ("pyranges", "pyranges1"):
        try:
            pyr_mod = importlib.import_module(mod_name)
        except ImportError:
            continue
        pr_cls = getattr(pyr_mod, "PyRanges", None)
        if pr_cls is not None and isinstance(regions, pr_cls):
            # pyranges0 exposes .df; pyranges1 exposes .to_pandas()
            if hasattr(regions, "df"):
                return pl.from_pandas(regions.df)
            return pl.from_pandas(regions.to_pandas())
    raise TypeError(
        f"Unsupported regions type: {type(regions).__name__}. "
        "Expected str, tuple, PathLike, or a polars/pandas/pyranges frame."
    )


def _normalize_regions(
    regions: "str | tuple[str, int, int] | PathLike | object",
    cnorm: ContigNormalizer,
) -> pl.DataFrame:
    """Normalize *regions* to a DataFrame with columns chrom (Utf8), start (Int32),
    end (Int32) using 0-based, end-exclusive coordinates.

    Accepted input types, checked in this order:

    * ``str`` — ``"chrom:start-end"`` (1-based inclusive, converted to 0-based
      half-open). A ``str`` is always parsed as a region, never as a file path.
    * ``tuple[str, int, int]`` — ``(chrom, start, end)`` already 0-based half-open.
    * ``Path`` / ``PathLike`` — path to a BED3+ file; read via ``sp.bed.read``.
    * ``pl.DataFrame``, ``pandas.DataFrame`` or a PyRanges (v0/v1) object — must
      have ``chrom``, ``start`` and ``end`` columns or the aliases understood by
      ``_coerce_bed_schema``.

    Contig names are replaced by their *cnorm*-normalised form. Rows whose
    contig *cnorm* does not recognise are dropped with a ``UserWarning``.

    Raises
    ------
    ValueError
        If a region string does not match ``chrom:start-end``.
    TypeError
        If *regions* is of an unsupported type.
    """
    schema = {"chrom": pl.Utf8, "start": pl.Int32, "end": pl.Int32}
    if isinstance(regions, str):
        m = _REGION_STR_RE.match(regions)
        if m is None:
            raise ValueError(
                f"Region string {regions!r} does not match 'chrom:start-end'"
            )
        df = pl.DataFrame(
            {
                "chrom": [m["chrom"]],
                "start": [int(m["start"]) - 1],  # 1-based inclusive -> 0-based
                "end": [int(m["end"])],
            },
            schema=schema,
        )
    elif (
        isinstance(regions, tuple) and len(regions) == 3 and isinstance(regions[0], str)
    ):
        chrom, start, end = regions
        df = pl.DataFrame(
            {"chrom": [chrom], "start": [int(start)], "end": [int(end)]},
            schema=schema,
        )
    elif isinstance(regions, PathLike):
        df = _coerce_bed_schema(sp.bed.read(Path(regions)))
    else:
        df = _coerce_bed_schema(_frame_like_to_polars(regions))

    normed = [cnorm.norm(c) for c in df["chrom"].to_list()]
    keep_mask = [n is not None for n in normed]
    if not all(keep_mask):
        n_dropped = keep_mask.count(False)
        warnings.warn(
            f"{n_dropped} region(s) dropped: contig not in dataset.", stacklevel=2
        )
        df = df.filter(pl.Series(keep_mask))
    normed_chroms = [n for n in normed if n is not None]
    # Explicit dtype: with zero surviving rows polars would infer a Null column,
    # breaking the Utf8 schema promised above.
    return df.with_columns(pl.Series("chrom", normed_chroms, dtype=pl.Utf8))
def _normalize_samples(
samples: "str | Sequence[str] | PathLike",
available: Sequence[str],
) -> list[str]:
"""Normalize `samples` to a list of valid sample names, preserving caller order
and deduping by first occurrence. Raises ValueError on unknown samples."""
if isinstance(samples, str):
candidates: list[str] = [samples]
elif isinstance(samples, PathLike) or hasattr(samples, "__fspath__"):
candidates = Path(samples).read_text().splitlines()
candidates = [s for s in candidates if s.strip()]
else:
candidates = list(samples)
avail_set = set(available)
missing = [s for s in candidates if s not in avail_set]
if missing:
raise ValueError(f"Samples not found in dataset: {missing}")
seen: set[str] = set()
deduped: list[str] = []
for s in candidates:
if s not in seen:
seen.add(s)
deduped.append(s)
return deduped
def _validate_fields(
fields: "Sequence[str] | None",
available: "dict[str, np.dtype]",
) -> list[str]:
"""Validate field selection. `None` returns all available fields; a sequence is
validated as a subset of `available`. Raises ValueError on unknown fields."""
if fields is None:
return list(available)
fields = list(fields)
missing = [f for f in fields if f not in available]
if missing:
raise ValueError(f"Fields not found in dataset: {missing}")
return fields
def _resolve_kept_var_idxs(
    sv: "SparseVar",
    regions: pl.DataFrame,
    mode: Literal["pos", "record", "variant"],
    merge_overlapping: bool,
) -> "NDArray[V_IDX_TYPE]":
    """Return a sorted, deduplicated array of source variant indices to keep.

    Parameters
    ----------
    sv
        The SparseVar to query.
    regions
        Normalised BED-like frame with columns ``chrom`` (Utf8), ``start`` (Int32),
        ``end`` (Int32). Coordinates are 0-based, half-open.
    mode
        ``"variant"`` — any variant whose span (accounting for ILEN) overlaps the
        region is kept, as returned by ``var_ranges``.
        ``"pos"`` — keep only variants whose POS-1 (0-based) lies in
        ``[start, end)``.
        ``"record"`` — like ``"pos"`` but widen the end by 1, i.e. POS-1 in
        ``[start, end + 1)``.
    merge_overlapping
        If *True*, overlapping regions are silently merged before querying.
        If *False* and overlapping regions are detected, raise ``ValueError``.

    Returns
    -------
    NDArray[V_IDX_TYPE]
        Sorted, deduplicated 1-D array of variant indices.
    """
    if regions.height == 0:
        return np.empty(0, dtype=V_IDX_TYPE)

    # --- overlap detection / optional merge ---
    # sp.bed.to_pyr requires chromStart/chromEnd column names.
    pyr = sp.bed.to_pyr(regions.rename({"start": "chromStart", "end": "chromEnd"}))
    mod = type(pyr).__module__.split(".")[0]
    if mod == "pyranges":
        merged = pyr.merge()
    elif mod == "pyranges1":
        merged = pyr.merge_overlaps()
    else:
        raise RuntimeError(f"Unexpected PyRanges module: {type(pyr)!r}")
    if len(merged) != regions.height:
        if not merge_overlapping:
            raise ValueError("regions overlap; pass merge_overlapping=True to dedupe")
        regions = _coerce_bed_schema(sp.bed.from_pyr(merged))

    # Group regions by contig once; reused for the var_ranges lookup and the
    # POS filter below.
    region_by_contig: dict[str, tuple[NDArray, NDArray]] = {}
    for contig_key, sub in regions.group_by("chrom", maintain_order=False):
        c = contig_key[0] if isinstance(contig_key, tuple) else contig_key
        region_by_contig[c] = (sub["start"].to_numpy(), sub["end"].to_numpy())

    # --- collect candidate variant indices via var_ranges ---
    sentinel = np.iinfo(V_IDX_TYPE).max  # marks ranges with no variants
    kept_chunks: list[NDArray[V_IDX_TYPE]] = []
    for c, (starts, ends) in region_by_contig.items():
        vr = sv.var_ranges(c, starts, ends)  # shape (n_ranges, 2)
        valid = vr[:, 0] != sentinel
        kept_chunks.extend(np.arange(s, e, dtype=V_IDX_TYPE) for s, e in vr[valid])
    if not kept_chunks:
        return np.empty(0, dtype=V_IDX_TYPE)
    candidates = np.unique(np.concatenate(kept_chunks))

    # "variant" mode: var_ranges already does ILEN-aware overlap — return as-is.
    if mode == "variant":
        return candidates

    # --- pos / record mode: vectorised POS membership test per contig ---
    idx_slice = sv.index[candidates.tolist()]
    cand_pos0 = idx_slice["POS"].to_numpy().astype(np.int64) - 1  # 1-based -> 0-based
    cand_chrom = np.array(idx_slice["CHROM"].to_list(), dtype=object)
    end_offset = 0 if mode == "pos" else 1  # "record" widens end by 1
    keep_mask = np.zeros(len(candidates), dtype=bool)
    for c, (r_starts, r_ends) in region_by_contig.items():
        on_contig = cand_chrom == c
        if not on_contig.any():
            continue
        # Sort regions by start and take a running max of (widened) ends. A
        # position p lies in some region iff, among regions with start <= p,
        # the largest end exceeds p. This is exactly the original "any region
        # contains p" test, without assuming the regions are disjoint.
        order = np.argsort(r_starts, kind="stable")
        sorted_starts = r_starts[order].astype(np.int64)
        reach = np.maximum.accumulate(r_ends[order].astype(np.int64)) + end_offset
        p = cand_pos0[on_contig]
        j = np.searchsorted(sorted_starts, p, side="right") - 1
        keep_mask[on_contig] = (j >= 0) & (p < reach[np.maximum(j, 0)])
    return candidates[keep_mask]
@overload
def dense2sparse(
    genos: NDArray[np.int8],
    var_idxs: NDArray[V_IDX_TYPE],
    dosages: None = None,
) -> Ragged[V_IDX_TYPE]: ...
@overload
def dense2sparse(
    genos: NDArray[np.int8],
    var_idxs: NDArray[V_IDX_TYPE],
    dosages: NDArray[DOSAGE_TYPE],
) -> tuple[Ragged[V_IDX_TYPE], Ragged[DOSAGE_TYPE]]: ...
def dense2sparse(
    genos: NDArray[np.int8],
    var_idxs: NDArray[V_IDX_TYPE],
    dosages: NDArray[DOSAGE_TYPE] | None = None,
) -> Ragged[V_IDX_TYPE] | tuple[Ragged[V_IDX_TYPE], Ragged[DOSAGE_TYPE]]:
    """Convert dense genotypes (and optionally dosages) to sparse genotypes.

    Parameters
    ----------
    genos
        Dense genotypes with shape ``(..., samples, ploidy, variants)``. Only
        entries equal to 1 are kept; every other value is treated as absent.
    var_idxs
        Variant index for each position along the last axis of ``genos``. These
        are the values stored in the sparse output.
    dosages
        Optional dense dosages with shape ``(..., samples, variants)``. Each
        kept genotype receives its sample's dosage for that variant.

    Returns
    -------
    ``Ragged[V_IDX_TYPE]`` shaped like ``genos`` with a ragged last axis, or a
    tuple with a matching ``Ragged[DOSAGE_TYPE]`` when ``dosages`` is given.

    Raises
    ------
    ValueError
        If ``genos``/``dosages`` have too few dimensions or mismatched
        sample/variant counts.
    """
    # (... s p v)
    if genos.ndim < 3:
        raise ValueError(
            "Sparse genotypes must have at least 3 dimensions, with the final three dimensions corresponding"
            + " to (samples, ploidy, variants)"
        )
    if dosages is not None:
        if dosages.ndim < 2:
            raise ValueError(
                "Sparse dosages must have at least 2 dimensions, with the final two dimensions corresponding"
                + " to (samples, variants)"
            )
        if dosages.shape[-1] != genos.shape[-1]:
            raise ValueError(
                "Sparse dosages must have the same number of variants as the genotypes"
            )
        if dosages.shape[-2] != genos.shape[-3]:
            raise ValueError(
                "Sparse dosages must have the same number of samples as the genotypes"
            )
    keep = genos == 1
    # nonzero() walks in C order, matching the flattened offsets from `lengths`.
    data = var_idxs[keep.nonzero()[-1]]
    lengths = keep.sum(-1)
    shape = (*lengths.shape, None)
    offsets = lengths_to_offsets(lengths)
    rag = Ragged[V_IDX_TYPE].from_offsets(data, shape, offsets)
    if dosages is None:
        return rag
    # (... s v) -> (... s 1 v): insert the ploidy axis just before variants so
    # broadcasting stays correct when genos has leading batch dimensions.
    dosage_data = np.broadcast_to(dosages[..., None, :], genos.shape)[keep]
    _dosages = Ragged[DOSAGE_TYPE].from_offsets(dosage_data, shape, offsets)
    return rag, _dosages
def _dense2sparse_with_length(
    genos: NDArray[np.integer],
    var_idxs: NDArray[V_IDX_TYPE],
    q_start: int,
    q_end: int,
    v_starts: NDArray[np.int32],
    ilens: NDArray[np.int32],
    dosages: NDArray[DOSAGE_TYPE] | None = None,
) -> Ragged[V_IDX_TYPE] | tuple[Ragged[V_IDX_TYPE], Ragged[DOSAGE_TYPE]]:
    """Convert a dense ``with_length`` window (shared, over-extended across all
    samples/haplotypes) into per-haplotype-minimal sparse output, identical to
    ``SparseVar.read_ranges_with_length`` for the same query.

    Parameters
    ----------
    genos
        Dense genotypes for the window. Shape: (samples, ploidy, variants).
    var_idxs
        Global variant indices of the window, used only to populate the sparse
        output. Shape: (variants,).
    q_start, q_end
        0-based, half-open original query span (before extension).
    v_starts
        0-based start positions of the window's variants (i.e. POS - 1).
        Window-aligned: same length as the ``genos`` variant axis and
        positionally aligned with ``var_idxs`` (NOT a global per-dataset array).
        Shape: (variants,).
    ilens
        ILEN of the window's variants (ALT - REF length). Window-aligned, like
        ``v_starts``. Shape: (variants,).
    dosages
        Optional dense dosages. Shape: (samples, variants).

    Returns
    -------
    ``Ragged[V_IDX_TYPE]`` of shape (samples, ploidy, ~variants), or a tuple
    with a matching ``Ragged[DOSAGE_TYPE]`` when ``dosages`` is given.

    Raises
    ------
    ValueError
        If ``genos`` is not 3-D, or if ``var_idxs``/``v_starts``/``ilens``/
        ``dosages`` do not match the shapes documented above.
    """
    # single-range only: exactly (samples, ploidy, variants), no batch dimension
    if genos.ndim != 3:
        raise ValueError("Dense genotypes must have shape (samples, ploidy, variants).")
    n_samples, ploidy, n_variants = genos.shape
    v_starts = np.ascontiguousarray(v_starts, dtype=np.int32)
    ilens = np.ascontiguousarray(ilens, dtype=np.int32)
    var_idxs = np.ascontiguousarray(var_idxs, dtype=V_IDX_TYPE)
    # Numba does not bounds-check array accesses by default, so a misaligned
    # array (e.g. a global per-dataset one) could silently read out of range
    # in the kernels below. Fail loudly instead.
    for name, arr in (("var_idxs", var_idxs), ("v_starts", v_starts), ("ilens", ilens)):
        if arr.shape != (n_variants,):
            raise ValueError(
                f"{name} must have shape ({n_variants},) to match the genotype "
                f"variant axis, got {arr.shape}."
            )
    if dosages is not None and np.shape(dosages) != (n_samples, n_variants):
        raise ValueError(
            f"Dosages must have shape ({n_samples}, {n_variants}), "
            f"got {np.shape(dosages)}."
        )
    # Pass 1: per-haplotype kept counts (reuses the sparse path's length walk).
    lengths = np.empty((n_samples, ploidy), dtype=np.int64)
    _dense2sparse_count(
        genos, v_starts, ilens, POS_TYPE(q_start), POS_TYPE(q_end), lengths
    )
    flat_offsets = lengths_to_offsets(lengths)
    total = int(flat_offsets[-1])
    shape = (n_samples, ploidy, None)
    # Pass 2: fill the variant-index output (and optionally the dosage output)
    # in disjoint per-haplotype ranges.
    out_data = np.empty(total, dtype=V_IDX_TYPE)
    has_dose = dosages is not None
    dose_in = (
        np.ascontiguousarray(dosages, dtype=DOSAGE_TYPE)
        if has_dose
        else np.empty((0, 0), dtype=DOSAGE_TYPE)
    )
    out_dose = np.empty(total if has_dose else 0, dtype=DOSAGE_TYPE)
    _dense2sparse_fill(
        genos, var_idxs, dose_in, lengths, flat_offsets, out_data, out_dose, has_dose
    )
    rag = Ragged[V_IDX_TYPE].from_offsets(out_data, shape, flat_offsets)
    if has_dose:
        drag = Ragged[DOSAGE_TYPE].from_offsets(out_dose, shape, flat_offsets)
        return rag, drag
    return rag
# SVAR format version stamped into metadata.json by from_vcf / from_pgen.
CURRENT_VERSION = 1
class SparseVarMetadata(BaseModel):
    """Schema of an SVAR directory's ``metadata.json``.

    Validated when a :class:`SparseVar` is opened and serialized by
    ``SparseVar.from_vcf`` / ``SparseVar.from_pgen``.
    """

    # Format version; None presumably marks directories written before
    # versioning existed — TODO confirm.
    version: int | None = None
    # Sample names; SparseVar maps each name to its position in this list.
    samples: list[str]
    # Number of haplotypes per sample.
    ploidy: int
    # Contig names, in the order they appear in the dataset.
    contigs: list[str]
    fields: dict[str, str] = {}  # field_name -> numpy dtype name (e.g. "float32")
# Type of the value returned by SparseVar.read_ranges / read_ranges_with_length
# (bare Ragged genotypes, or a record array when fields are loaded).
_SRT = TypeVar("_SRT")
[docs]
class SparseVar(Generic[_SRT]):
"""Open a Sparse Variant (SVAR) directory.
Parameters
----------
path
Path to the SVAR directory.
attrs
Expression of attributes to load in addition to the ALT and ILEN columns.
fields
Names of fields to load from the SVAR directory. Must be keys of
:attr:`available_fields`. Only VCF FORMAT fields with ``Number=G`` are currently
supported as custom fields.
"""
path: Path
version: int | None
available_samples: list[str]
ploidy: int
contigs: list[str]
"""Contigs in the order they appear in the dataset. Variants are only sorted within each contig."""
genos: Ragged[V_IDX_TYPE]
available_fields: dict[str, np.dtype[np.number]]
fields: dict[str, Ragged[np.number]]
index: pl.DataFrame
"""Table of variants with columns: `CHROM`, `POS`, `REF`, `ALT`, `ILEN`, and any additional
attributes specified in `attrs` on construction."""
_c_norm: ContigNormalizer
_s2i: HashTable
_c_max_idxs: dict[str, int]
_is_biallelic: bool
@property
def n_samples(self) -> int:
"""Number of samples in the dataset."""
return len(self.available_samples)
@property
def n_variants(self) -> int:
"""Number of variants in the dataset."""
return self.index.height
@property
def nbytes(self) -> int:
"""Total in-memory footprint, in bytes, of resident (non-mmap'd) data
held by this reader. Only the polars variant index counts; `genos`
and `fields` are memory-mapped and excluded.
"""
return self.index.estimated_size()
@overload
def __init__(
    self: SparseVar[Ragged[V_IDX_TYPE]],
    path: str | Path,
    attrs: IntoExpr | None = None,
    fields: None = None,
) -> None: ...
@overload
def __init__(
    self: SparseVar[Ragged[np.void]],
    path: str | Path,
    attrs: IntoExpr | None = None,
    fields: Sequence[str] = ...,
) -> None: ...
def __init__(
    self,
    path: str | Path,
    attrs: IntoExpr | None = None,
    fields: Sequence[str] | None = None,
):
    """Open an existing SVAR directory (parameters as documented on the class).

    Raises
    ------
    FileNotFoundError
        If ``path`` does not exist.
    ValueError
        If a requested field is not listed in the directory's metadata.
    """
    root = Path(path)
    self.path = root
    if not root.exists():
        raise FileNotFoundError(f"SVAR directory {self.path} does not exist.")

    meta = SparseVarMetadata.model_validate_json((root / "metadata.json").read_bytes())
    self.version = meta.version
    self.contigs = meta.contigs
    self.available_samples = meta.samples
    self.ploidy = meta.ploidy
    self.available_fields = {}
    for field_name, dtype_name in meta.fields.items():
        self.available_fields[field_name] = np.dtype(dtype_name)

    if fields is not None:
        unknown = set(fields) - set(self.available_fields)
        if unknown:
            raise ValueError(f"Fields {unknown} not found in the dataset.")

    # sample name -> row position lookup
    sample_arr = np.array(self.available_samples)
    self._s2i = HashTable(len(sample_arr) * 2, dtype=sample_arr.dtype)  # type: ignore
    self._s2i.add(sample_arr)
    self._c_norm = ContigNormalizer(meta.contigs)

    hap_shape = (self.n_samples, self.ploidy, None)
    self.genos = _open_genos(root, hap_shape, "r")
    self.fields = {}
    for field_name in fields or []:
        self.fields[field_name] = _open_fmt(
            field_name, self.available_fields[field_name], root, hap_shape, "r"
        )

    logger.info("Loading genoray index")
    self.index = self._load_index(attrs)
    self._is_biallelic = (self.index["ALT"].list.len() == 1).all()

    # Last global variant index of each contig (rows are grouped by contig in
    # order of appearance); contigs without variants map to 0.
    counts = self.index.group_by("CHROM", maintain_order=True).agg(
        n_variants=pl.len()
    )
    self._c_max_idxs = {
        chrom: running_total - 1
        for chrom, running_total in zip(counts["CHROM"], counts["n_variants"].cum_sum())
    }
    for contig in self.contigs:
        self._c_max_idxs.setdefault(contig, 0)
[docs]
def var_ranges(
    self,
    contig: str,
    starts: ArrayLike = 0,
    ends: ArrayLike = POS_MAX,
) -> NDArray[V_IDX_TYPE]:
    """Return, for each query range, the span of variant indices that may overlap it.

    Each row gives the first and last (exclusive) variant index touching the
    query. Because the span is contiguous, it can include variants that do not
    themselves overlap the query, e.g. when a deletion spans the query start.

    Parameters
    ----------
    contig
        Contig name (normalised internally).
    starts
        0-based start positions of the ranges.
    ends
        0-based, exclusive end positions of the ranges.

    Returns
    -------
    Shape: :code:`(ranges, 2)`. Column 0 is the start variant index and column 1
    the end variant index. Callers in this module treat a start equal to
    ``np.iinfo(V_IDX_TYPE).max`` as "no variants" — confirm in ``_var_ranges``.
    """
    normalizer, table = self._c_norm, self.index
    return var_ranges(normalizer, table, contig, starts, ends)
def _find_starts_ends(
    self,
    contig: str,
    starts: ArrayLike = 0,
    ends: ArrayLike = POS_MAX,
    samples: ArrayLike | None = None,
    out: NDArray[OFFSET_TYPE] | None = None,
) -> NDArray[OFFSET_TYPE]:
    """Find the start and end offsets of the sparse genotypes for each range.

    Parameters
    ----------
    contig
        Contig name.
    starts
        0-based start positions of the ranges.
    ends
        0-based, exclusive end positions of the ranges.
    samples
        List of sample names to read. If None, read all samples.
    out
        Output array of shape (2, ranges, samples, ploidy) to write to. If None,
        a new array will be created.

    Returns
    -------
    Shape: (2, ranges, samples, ploidy). ``out[0]`` holds start offsets and
    ``out[1]`` end offsets into ``self.genos.data`` for every (range, sample,
    haplotype). For a contig not present in the dataset every entry is -1,
    i.e. an empty slice.

    Raises
    ------
    ValueError
        If any requested sample is not in the dataset.
    """
    if samples is None:
        samples = np.atleast_1d(np.array(self.available_samples))
    else:
        samples = np.atleast_1d(np.array(samples))
        if missing := set(samples) - set(self.available_samples):  # type: ignore
            raise ValueError(f"Samples {missing} not found in the dataset.")
    s_idxs = cast(NDArray[np.int64], self._s2i[samples])
    starts = np.atleast_1d(np.asarray(starts, POS_TYPE))
    n_ranges = len(starts)
    out_shape = (2, n_ranges, len(samples), self.ploidy)
    if self._c_norm.norm(contig) is None:
        # Unknown contig: every range is empty. Use the same (2, r, s, p) layout
        # as the normal path so callers can rely on out[0] / out[1].
        if out is None:
            return np.full(out_shape, -1, OFFSET_TYPE)
        out[:] = -1
        return out
    ends = np.atleast_1d(np.asarray(ends, POS_TYPE))
    # (r 2)
    v_ranges = self.var_ranges(contig, starts, ends)
    if out is None:
        # (2 r s p)
        out = np.empty(out_shape, dtype=OFFSET_TYPE)
    _find_starts_ends(
        self.genos.data,
        self.genos.offsets,
        v_ranges,
        s_idxs,
        self.ploidy,
        out_offsets=out,
    )
    return out
def _find_starts_ends_with_length(
    self,
    contig: str,
    starts: ArrayLike = 0,
    ends: ArrayLike = POS_MAX,
    samples: ArrayLike | None = None,
    out: NDArray[OFFSET_TYPE] | None = None,
) -> NDArray[OFFSET_TYPE]:
    """Find per-haplotype start and end offsets of the sparse genotypes for each
    range, using the ``with_length`` semantics of :meth:`read_ranges_with_length`.

    Parameters
    ----------
    contig
        Contig name.
    starts
        0-based start positions of the ranges.
    ends
        0-based, exclusive end positions of the ranges.
    samples
        List of sample names to read. If None, read all samples.
    out
        Output array of shape (2, ranges, samples, ploidy) to write to. If None,
        a new array will be created.

    Returns
    -------
    Shape: (2, ranges, samples, ploidy). ``out[0]`` holds start offsets and
    ``out[1]`` end offsets into ``self.genos.data``. For a contig not present
    in the dataset every entry is -1, i.e. an empty slice.

    Raises
    ------
    ValueError
        If the dataset contains multiallelic variants, or any requested sample
        is not in the dataset.
    """
    if not self._is_biallelic:
        raise ValueError(
            "Cannot use with_length operations with multiallelic variants."
        )
    if samples is None:
        samples = np.atleast_1d(np.array(self.available_samples))
    else:
        samples = np.atleast_1d(np.array(samples))
        if missing := set(samples) - set(self.available_samples):  # type: ignore
            raise ValueError(f"Samples {missing} not found in the dataset.")
    s_idxs = cast(NDArray[np.int64], self._s2i[samples])
    starts = np.atleast_1d(np.asarray(starts, POS_TYPE))
    n_ranges = len(starts)
    out_shape = (2, n_ranges, len(samples), self.ploidy)
    c = self._c_norm.norm(contig)
    if c is None:
        # Unknown contig: every range is empty. Match the normal (2, r, s, p)
        # layout and honour a caller-supplied output buffer.
        if out is None:
            return np.full(out_shape, -1, OFFSET_TYPE)
        out[:] = -1
        return out
    ends = np.atleast_1d(np.asarray(ends, POS_TYPE))
    # (r 2)
    v_ranges = self.var_ranges(contig, starts, ends)
    v_starts = (self.index["POS"] - 1).to_numpy()
    ilens = self.index["ILEN"].list.first().fill_null(0).to_numpy()
    # (2 r s p)
    out = _find_starts_ends_with_length(
        self.genos.data,
        self.genos.offsets,
        starts,
        ends,
        v_ranges,
        v_starts,
        ilens,
        s_idxs,
        self.ploidy,
        self._c_max_idxs[c],
        out,
    )
    return out
[docs]
def read_ranges(
    self,
    contig: str,
    starts: ArrayLike = 0,
    ends: ArrayLike = POS_MAX,
    samples: ArrayLike | None = None,
) -> _SRT:
    """Read genotypes (and any loaded fields) overlapping the given ranges.

    Parameters
    ----------
    contig
        Contig name.
    starts
        0-based start positions of the ranges.
    ends
        0-based, exclusive end positions of the ranges.
    samples
        Sample names to read; all samples when None.

    Returns
    -------
    Without loaded fields: ``Ragged[V_IDX_TYPE]`` shaped
    ``(ranges, samples, ploidy, ~variants)``. With loaded fields: an awkward
    record array of the same outer shape whose ``genos`` member is that Ragged
    and whose other members (e.g. ``dosages``) are Raggeds of the field dtypes.
    Data buffers are the memory-mapped arrays; only offsets are built in RAM.

    Raises
    ------
    ValueError
        If any requested sample is not in the dataset.
    """
    if samples is None:
        sample_arr = np.atleast_1d(np.array(self.available_samples))
    else:
        sample_arr = np.atleast_1d(np.array(samples))
        unknown = set(sample_arr) - set(self.available_samples)  # type: ignore
        if unknown:
            raise ValueError(f"Samples {unknown} not found in the dataset.")
    starts_arr = np.atleast_1d(np.asarray(starts, POS_TYPE))

    # (2 r s p) -> (2, r*s*p): one start/end pair per haplotype slice
    offsets = self._find_starts_ends(contig, starts_arr, ends, sample_arr).reshape(2, -1)
    layout = (len(starts_arr), len(sample_arr), self.ploidy, None)
    genos = Ragged[V_IDX_TYPE].from_offsets(self.genos.data, layout, offsets)
    if not self.fields:
        return genos  # type: ignore[return-value]

    columns = {"genos": genos}
    for name, field in self.fields.items():
        columns[name] = Ragged.from_offsets(field.data, layout, offsets)
    return ak.zip(columns)  # type: ignore[return-value]
[docs]
def read_ranges_with_length(
    self,
    contig: str,
    starts: ArrayLike = 0,
    ends: ArrayLike = POS_MAX,
    samples: ArrayLike | None = None,
) -> _SRT:
    """Read the genotypes for the given ranges such that each entry of variants is guaranteed to have
    the minimum amount of variants to reach the query length. This can mean either fewer or more variants
    than would be returned than by :code:`read_ranges`, depending on the presence of indels.

    Parameters
    ----------
    contig
        Contig name.
    starts
        0-based start positions of the ranges.
    ends
        0-based, exclusive end positions of the ranges.
    samples
        Sample name or list of sample names to read. If None, read all samples.

    Returns
    -------
    Same return structure as :meth:`read_ranges`.

    Raises
    ------
    ValueError
        If the dataset is multiallelic or any requested sample is not in it.
    """
    if samples is None:
        samples = np.atleast_1d(np.array(self.available_samples))
    else:
        # Convert before validating: set() over a bare string would iterate
        # its characters and reject a valid single sample name.
        samples = np.atleast_1d(np.array(samples))
        if missing := set(samples) - set(self.available_samples):  # type: ignore
            raise ValueError(f"Samples {missing} not found in the dataset.")
    n_samples = len(samples)
    starts = np.atleast_1d(np.asarray(starts, POS_TYPE))
    n_ranges = len(starts)
    # (2 r s p)
    starts_ends = self._find_starts_ends_with_length(contig, starts, ends, samples)
    shape = (n_ranges, n_samples, self.ploidy, None)
    flat_offsets = starts_ends.reshape(2, -1)
    genos_result = Ragged[V_IDX_TYPE].from_offsets(
        self.genos.data, shape, flat_offsets
    )
    if not self.fields:
        return genos_result  # type: ignore[return-value]
    field_results = {
        name: Ragged.from_offsets(field.data, shape, flat_offsets)
        for name, field in self.fields.items()
    }
    return ak.zip({"genos": genos_result, **field_results})  # type: ignore[return-value]
# Typing-only overloads: the type parameter of the returned SparseVar tracks
# the ``fields`` argument.
# A sequence of field names -> reads return record arrays.
@overload
def with_fields(self, fields: Sequence[str]) -> SparseVar[Ragged[np.void]]: ...
# ``False`` drops all fields -> reads return bare genotype Raggeds.
@overload
def with_fields(self, fields: Literal[False]) -> SparseVar[Ragged[V_IDX_TYPE]]: ...
# ``None`` keeps the current fields, hence the current read type.
@overload
def with_fields(self, fields: None = None) -> SparseVar[_SRT]: ...
[docs]
def with_fields(
self,
fields: Sequence[str] | Literal[False] | None = None,
) -> SparseVar:
"""Return a shallow copy of this ``SparseVar`` with updated fields.
Parameters
----------
fields
- ``None``: leave fields unchanged (returns shallow copy).
- ``Sequence[str]``: names of fields to load from the SVAR directory.
Must be keys of :attr:`available_fields`.
- ``False``: drop all fields, returning a ``SparseVar[Ragged[V_IDX_TYPE]]``.
"""
new = copy.copy(self)
if fields is None:
return new
if fields is False:
new.fields = {}
return new
if missing := set(fields) - set(self.available_fields):
raise ValueError(f"Fields {missing} not found in the dataset.")
shape = (self.n_samples, self.ploidy, None)
new.fields = {
name: _open_fmt(name, self.available_fields[name], self.path, shape, "r")
for name in fields
}
return new
[docs]
@classmethod
def from_vcf(
    cls,
    out: str | Path,
    vcf: VCF,
    max_mem: int | str,
    overwrite: bool = False,
    with_dosages: bool = False,
    n_jobs: int = -1,
):
    """Create a Sparse Variant (.svar) from a VCF/BCF.

    Parameters
    ----------
    out
        Path to the output directory.
    vcf
        VCF file to write from.
    max_mem
        Maximum memory to use while writing; split evenly across jobs.
    overwrite
        Whether to overwrite the output directory if it exists.
    with_dosages
        Whether to write dosages (requires ``vcf.dosage_field``).
    n_jobs
        Number of jobs to use for parallel processing, following joblib's
        convention: ``-1`` uses every CPU and other negative values count back
        from there (``-2`` = all but one). The result is capped at the number
        of contigs and is always at least 1.

    Raises
    ------
    ValueError
        If ``n_jobs`` is 0, or dosages are requested but the VCF has no
        dosage field.
    FileExistsError
        If ``out`` exists and ``overwrite`` is False.
    """
    out = Path(out)
    if n_jobs == 0:
        raise ValueError("n_jobs must be a non-zero integer.")
    if with_dosages and vcf.dosage_field is None:
        raise ValueError("VCF does not have a dosage field specified.")
    if out.exists() and not overwrite:
        raise FileExistsError(
            f"Output path {out} already exists. Use overwrite=True to overwrite."
        )
    out.mkdir(parents=True, exist_ok=True)

    if not vcf._index_path().exists():
        logger.info("Genoray VCF index not found, creating index.")
        vcf._write_gvi_index()
    _write_filtered_index(vcf._index_path(), cls._index_path(out), vcf._pl_filter)

    contigs = vcf.contigs
    metadata = SparseVarMetadata(
        version=CURRENT_VERSION,
        contigs=contigs,
        samples=vcf.available_samples,
        ploidy=vcf.ploidy,
        fields={"dosages": np.dtype(DOSAGE_TYPE).name} if with_dosages else {},
    )
    with open(out / "metadata.json", "w") as f:
        f.write(metadata.model_dump_json())

    max_mem = parse_memory(max_mem)
    # joblib semantics for negative n_jobs: n_cpus + 1 + n_jobs (-1 == all CPUs).
    requested_jobs = joblib.cpu_count() + 1 + n_jobs if n_jobs < 0 else n_jobs
    # No more workers than contigs; the floor of 1 also keeps the memory split
    # below from dividing by zero when there are no contigs.
    effective_n_jobs = max(1, min(requested_jobs, len(contigs)))
    job_mem = max_mem // effective_n_jobs
    with TemporaryDirectory() as chunk_dir:
        chunk_dir = Path(chunk_dir)
        shape = (vcf.n_samples, vcf.ploidy)
        tasks = []
        for chunk_idx, c in enumerate(contigs):
            task = joblib.delayed(_process_contig_vcf)(
                vcf.path,
                dosage_field=vcf.dosage_field if with_dosages else None,
                max_mem=job_mem,
                contig=c,
                chunk_dir=chunk_dir,
                chunk_idx=chunk_idx,
                cyvcf2_filter=vcf._filter,
                pl_filter=vcf._pl_filter,
            )
            tasks.append(task)
        with (
            joblib_progress(
                description=f"Processing contigs using {effective_n_jobs} jobs",
                total=len(tasks),
            ),
            joblib.Parallel(n_jobs=effective_n_jobs) as parallel,
        ):
            results: list[tuple[int, int]] = list(parallel(tasks))  # type: ignore
        logger.info("Concatenating intermediate chunks")
        _concat_data(out, chunk_dir, shape, results, with_dosages=with_dosages)
[docs]
@classmethod
def from_pgen(
    cls,
    out: str | Path,
    pgen: PGEN,
    max_mem: int | str,
    overwrite: bool = False,
    with_dosages: bool = False,
    n_jobs: int = -1,
):
    """Create a Sparse Variant (.svar) from a PGEN.

    Parameters
    ----------
    out
        Path to the output directory.
    pgen
        PGEN file to write from.
    max_mem
        Maximum memory to use while writing; split evenly across jobs.
    overwrite
        Whether to overwrite the output directory if it exists.
    with_dosages
        Whether to write dosages (requires ``pgen.dosage_path``).
    n_jobs
        Number of jobs to use for parallel processing, following joblib's
        convention: ``-1`` uses every CPU and other negative values count back
        from there (``-2`` = all but one). The result is capped at the number
        of contigs and is always at least 1.

    Raises
    ------
    ValueError
        If ``n_jobs`` is 0, or dosages are requested but the PGEN has no
        dosage file or is not bi-allelic with filters applied.
    FileExistsError
        If ``out`` exists and ``overwrite`` is False.
    """
    out = Path(out)
    if n_jobs == 0:
        raise ValueError("n_jobs must be a non-zero integer.")
    if with_dosages and pgen.dosage_path is None:
        raise ValueError("PGEN does not have a dosage file specified.")
    if out.exists() and not overwrite:
        raise FileExistsError(
            f"Output path {out} already exists. Use overwrite=True to overwrite."
        )
    pgen._init_index()
    assert pgen.contigs is not None
    assert pgen._c_max_idxs is not None
    # Validate before touching `out`, so a failure cannot leave a half-built
    # SVAR (metadata + index but no genotypes) behind.
    if with_dosages and pgen._sei is None:
        raise ValueError("PGEN must be bi-allelic with filters applied")
    contigs = pgen.contigs

    out.mkdir(parents=True, exist_ok=True)
    metadata = SparseVarMetadata(
        version=CURRENT_VERSION,
        contigs=contigs,
        samples=pgen.available_samples,
        ploidy=pgen.ploidy,
        fields={"dosages": np.dtype(DOSAGE_TYPE).name} if with_dosages else {},
    )
    with open(out / "metadata.json", "w") as f:
        f.write(metadata.model_dump_json())
    _write_filtered_index(pgen._index_path(), cls._index_path(out), pgen._filter)

    max_mem = parse_memory(max_mem)
    # joblib semantics for negative n_jobs: n_cpus + 1 + n_jobs (-1 == all CPUs).
    requested_jobs = joblib.cpu_count() + 1 + n_jobs if n_jobs < 0 else n_jobs
    # No more workers than contigs; the floor of 1 also keeps the memory split
    # below from dividing by zero when there are no contigs.
    effective_n_jobs = max(1, min(requested_jobs, len(contigs)))
    job_mem = max_mem // effective_n_jobs
    mem_per_var = pgen._mem_per_variant(
        pgen.GenosDosages if with_dosages else pgen.Genos  # type: ignore
    )
    shape = (pgen.n_samples, pgen.ploidy)
    with TemporaryDirectory() as contig_dir:
        contig_dir = Path(contig_dir)
        keep_by_contig = {
            chrom: np.asarray(idxs, dtype=np.uint32)
            for chrom, idxs in (
                pgen._index.group_by("CHROM", maintain_order=True)
                .agg(pl.col("index"))
                .iter_rows()
            )
        }
        tasks = []
        for c in contigs:
            keep_idxs = keep_by_contig.get(c)
            if keep_idxs is None or len(keep_idxs) == 0:
                continue
            task = joblib.delayed(_process_contig_pgen)(
                geno_path=pgen.geno_path,
                dosage_path=pgen.dosage_path if with_dosages else None,
                max_mem=job_mem,
                keep_idxs=keep_idxs,
                mem_per_var=mem_per_var,
                n_samples=pgen.n_samples,
                ploidy=pgen.ploidy,
                chunk_dir=contig_dir,
                chunk_idx=len(tasks),
            )
            tasks.append(task)
        pgen._free_index()
        # PgenReaders can be multi-GB allocations, close them to free memory
        pgen._geno_pgen.close()
        if pgen.dosage_path is not None:
            pgen._dose_pgen.close()
        with (
            joblib_progress(
                description=f"Processing contigs using {effective_n_jobs} jobs",
                total=len(tasks),
            ),
            joblib.Parallel(n_jobs=effective_n_jobs) as parallel,
        ):
            results: list[tuple[int, int]] = list(parallel(tasks))  # type: ignore
        logger.info("Concatenating intermediate chunks")
        _concat_data(out, contig_dir, shape, results, with_dosages=with_dosages)
@classmethod
def _index_path(cls, root: Path):
"""Path to the index file."""
return root / "index.arrow"
def _load_index(self, attrs: IntoExpr | None = None) -> pl.DataFrame:
    """Load the variant index (``index.arrow``) into memory.

    Columns, in order: ``index`` (row number), ``CHROM``, ``POS``, ``REF``,
    ``ALT`` (cast to a list of strings if stored as plain strings), then any
    extra *attrs* in the order given, then ``ILEN``. ``ILEN`` is read from the
    file if present, otherwise computed with the ``ILEN`` expression.

    Parameters
    ----------
    attrs
        Extra column name(s) or expression(s) to load. Named attrs must be
        numeric columns. Names of always-loaded columns (``index``, ``CHROM``,
        ``POS``, ``REF``, ``ALT``, ``ILEN``) are ignored rather than selected
        twice.

    Raises
    ------
    ValueError
        If a named attr refers to a non-numeric column.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        index = pl.scan_ipc(self._index_path(self.path), row_index_name="index")
    schema = index.collect_schema()
    if schema["ALT"] == pl.Utf8:
        index = index.with_columns(pl.col("ALT").cast(pl.List(pl.Utf8)))

    base_cols = ("index", "CHROM", "POS", "REF")
    # An insertion-ordered dict instead of a set: set iteration order of str
    # keys is hash-randomised, which made the index column order vary per run.
    extra: dict[IntoExpr, None] = {"ALT": None}
    if attrs is not None:
        if not isinstance(attrs, str) and isinstance(attrs, Iterable):
            extra.update(dict.fromkeys(attrs))
        else:
            extra[attrs] = None
    # Always-loaded columns would otherwise be selected twice (duplicate error).
    for name in (*base_cols, "ILEN"):
        extra.pop(name, None)

    # isinstance first: comparing a polars Expr with != yields an Expr, not a bool.
    user_attr_names = [a for a in extra if isinstance(a, str) and a != "ALT"]
    if non_numeric := [a for a in user_attr_names if not schema[a].is_numeric()]:
        raise ValueError(f"Attrs {non_numeric} must be numeric.")

    columns = list(extra)
    columns.append("ILEN" if "ILEN" in schema else ILEN.alias("ILEN"))
    return index.select(*base_cols, *columns).collect()
[docs]
def annotate_with_gtf(
    self,
    gtf: str | pl.DataFrame,
    level_filter: int | None = 1,
    write_back: bool = True,
    *,
    strand_encoding: dict[str | None, int] | None = None,
    codon_null_token: int | None = None,
) -> pl.DataFrame:
    """
    Annotate variants with gene_id, strand, and codon_pos from GTF CDS features.

    Computes codon position for SNVs only; indels receive strand but null codon_pos.

    Parameters
    ----------
    gtf : str or pl.DataFrame
        Path to GTF file (.gtf or .gtf.gz) or pre-loaded Polars DataFrame.
    level_filter : int or None, default 1
        If set, keep rows with GTF 'level' <= level_filter (1 = highest quality).
    write_back : bool, default True
        If True, update ``self.index`` in-place and write it to the index.arrow
        file. Annotation columns left by a previous call are replaced rather
        than duplicated.
    strand_encoding : dict or None, optional
        Encode strand as integers. Example: {'+': 0, '-': 1, None: 2}
    codon_null_token : int or None, optional
        Replace null codon_pos with this integer for ML models.

    Returns
    -------
    pl.DataFrame
        Columns: varID (UInt32), gene_id (Utf8), strand (Utf8/Int16), codon_pos (Int8/Int16)

    Raises
    ------
    TypeError
        If ``level_filter``, ``strand_encoding`` or ``codon_null_token`` has an
        unsupported type.

    Examples
    --------
    >>> svar = SparseVar("data.svar")
    >>> annot = svar.annotate_with_gtf("gencode.v45.gtf.gz")
    >>> annot.head()
    """
    # Validate inputs
    if level_filter is not None and not isinstance(level_filter, int):
        raise TypeError(
            f"level_filter must be int or None, got {type(level_filter)}"
        )
    if strand_encoding is not None and not isinstance(strand_encoding, dict):
        raise TypeError("strand_encoding must be dict or None")
    if codon_null_token is not None and not isinstance(codon_null_token, int):
        raise TypeError("codon_null_token must be int or None")

    logger.info("Loading GTF for CDS annotation")
    with tqdm(total=3, desc="GTF annotation", unit="step") as pbar:
        # Load GTF
        pbar.set_description("Loading GTF")
        gtf_df = _load_gtf(gtf)
        if level_filter is not None and "level" in gtf_df.columns:
            gtf_df = gtf_df.filter(pl.col("level").cast(pl.Int32) <= level_filter)
        pbar.update(1)

        # CDS Annotation
        pbar.set_description("CDS annotation")
        # Extract CDS features with gene_biotype
        cds_df = gtf_df.filter(pl.col("feature") == "CDS").select(
            "chrom",
            "start",
            "end",
            "strand",
            "frame",
            "gene_id",
            "transcript_id",
            "gene_biotype",
            "transcript_support_level",
            "tag",
        )
        if len(cds_df) == 0:
            annot = _empty_annot()
        else:
            annot = _get_strand_and_codon_pos(cds_df, self.index, self._c_norm)
        pbar.update(1)

        # Apply encoding
        pbar.set_description("Finalizing")
        if strand_encoding is not None:
            str_map = {k: v for k, v in strand_encoding.items() if k is not None}
            null_val = strand_encoding.get(None)
            strand_expr = pl.col("strand").replace_strict(str_map, default=null_val)
            annot = annot.with_columns(strand_expr.cast(pl.Int16).alias("strand"))
        if codon_null_token is not None:
            annot = annot.with_columns(
                pl.col("codon_pos").fill_null(codon_null_token).cast(pl.Int16)
            )

        # Write back if requested
        if write_back:
            self._load_all_attrs()
            # Drop annotation columns from an earlier annotation (in memory or
            # reloaded from disk) so the join replaces them instead of adding
            # "*_right" duplicates on every call.
            stale = [
                c for c in annot.columns if c != "varID" and c in self.index.columns
            ]
            self.index = (
                self.index.lazy()
                .drop(stale)
                .with_row_index("varID")
                .join(annot.lazy(), on="varID", how="left")
                .drop("varID")
                .collect()
            )
            df = self._to_df()
            df.write_ipc(self._index_path(self.path))
            logger.info("Wrote gene_id, strand, codon_pos to index.arrow")
        pbar.update(1)

    return annot
[docs]
def cache_afs(self):
    """Compute allele frequencies, keep them as the in-memory ``AF`` column and persist them.

    Every on-disk attribute is loaded first so that rewriting ``index.arrow``
    does not drop columns that were not yet in memory.
    """
    self._load_all_attrs()
    af_column = pl.Series(self._compute_afs())
    self.index = self.index.with_columns(AF=af_column)
    self._write_afs()
def _load_all_attrs(self):
    """Load every on-disk index column that is not yet present in ``self.index``.

    Missing columns are appended in on-disk schema order, so the resulting
    column order (and therefore the layout of any index later written via
    ``_to_df``) is deterministic across runs.
    """
    idx_df = pl.scan_ipc(self._index_path(self.path))
    schema = idx_df.collect_schema()
    present = set(self.index.columns)
    # Iterate the schema (ordered) rather than a set difference (hash order).
    missing = [name for name in schema if name not in present]
    if not missing:
        return
    missing_attrs = idx_df.select(missing).collect()
    self.index = self.index.hstack(missing_attrs)
def _compute_afs(self) -> NDArray[np.float32]:
    """Return one float32 allele frequency per variant.

    Counts the (sample, ploidy) slots containing each variant index and
    divides by ``n_samples * ploidy``.
    """
    samples_n, ploidy_n, _unused = cast(tuple[int, int, None], self.genos.shape)
    freqs = np.zeros(self.n_variants, np.float32)
    _nb_af_helper(
        freqs, self.genos.data, self.genos.offsets, samples_n * ploidy_n
    )
    return freqs
def _write_afs(self):
    """Persist the in-memory index (including any ``AF`` column) to ``index.arrow``."""
    self._to_df().write_ipc(self._index_path(self.path))
def _to_df(self) -> pl.DataFrame:
    """Return the index as written to disk, i.e. without the ``index`` row-number column."""
    on_disk = self.index.drop("index")
    return on_disk
def _load_genos(self):
    """Replace memory-mapped genotype buffers in ``self.genos`` with in-memory copies.

    Every ``NumpyArray`` leaf whose data is an ``np.memmap`` is copied into a
    regular ndarray. Other layouts return None so ``ak.transform`` keeps
    recursing into them.
    """

    def memmap2array(layout: Content, **kwargs):
        if isinstance(layout, NumpyArray):
            data = layout.data
            if isinstance(data, np.memmap):
                # ``data[:]`` would just be another file-backed memmap view;
                # np.array copies into a plain in-memory ndarray.
                data = np.array(data)
            return NumpyArray(data)
        return None

    self.genos = ak.transform(memmap2array, self.genos)  # type: ignore
[docs]
def write_view(
    self,
    regions: str | tuple[str, int, int] | Path,
    samples: str | Sequence[str] | Path,
    output: str | Path,
    fields: Sequence[str] | None = None,
    merge_overlapping: bool = False,
    regions_overlap: Literal["pos", "record", "variant"] = "pos",
    overwrite: bool = False,
    threads: int | None = None,
) -> None:
    """Write a subset of this SparseVar to a new directory.

    Parameters
    ----------
    regions
        Region(s) to include. Accepts the same input types as
        :func:`_normalize_regions`: a ``"chrom:start-end"`` string, a
        ``(chrom, start, end)`` tuple, a BED file path, or a
        polars/pandas/pyranges frame.
    samples
        Samples to include. Accepts a single sample name, a list, or a
        path to a file of newline-separated names.
    output
        Destination directory for the new SparseVar.
    fields
        Fields to carry over (``None`` = all available; ``[]`` = none).
    merge_overlapping
        If ``True`` silently merge overlapping regions; if ``False``
        raise ``ValueError`` when overlaps are detected.
    regions_overlap
        How variants are matched to regions — ``"pos"``, ``"record"``, or
        ``"variant"``. See :func:`_resolve_kept_var_idxs`.
    overwrite
        Whether to overwrite *output* if it already exists.
    threads
        Number of Numba threads to use. ``None`` uses all available CPUs.

    Raises
    ------
    ValueError
        If no samples are given, ``regions`` selects no variants, or every
        selected variant has MAC=0 in the chosen samples.
    FileExistsError
        If *output* exists and ``overwrite`` is False.

    Notes
    -----
    Variants whose minor allele count is 0 in the chosen sample subset are
    dropped from the output. If every candidate variant drops, a
    :class:`ValueError` is raised — the same code path that fires when
    ``regions`` itself selects no variants.

    *output* is only removed/created after all of these checks pass, so a
    failing call never deletes an existing *output* or leaves an empty
    directory behind.
    """
    from ._utils import _resolve_threads, numba_threads

    output = Path(output)

    # --- 1. Normalize inputs ---
    regions_df = _normalize_regions(regions, self._c_norm)
    caller_samples = _normalize_samples(samples, self.available_samples)
    fields_to_write = _validate_fields(fields, self.available_fields)
    if not caller_samples:
        raise ValueError("write_view requires at least one sample")

    # --- 2. Resolve kept variant indices ---
    kept_var_idxs = _resolve_kept_var_idxs(
        self, regions_df, regions_overlap, merge_overlapping
    )
    if len(kept_var_idxs) == 0:
        raise ValueError("no variants selected by `regions`")

    # --- 3. Refuse to clobber early (cheap check), but do not touch the
    # filesystem until the MAC validation below has also passed. ---
    if output.exists() and not overwrite:
        raise FileExistsError(
            f"Output path {output} already exists. Use overwrite=True to overwrite."
        )

    # --- 4. Setup ---
    n_out = len(caller_samples)
    ploidy = self.ploidy
    threads_resolved = _resolve_threads(threads)
    src_sample_idxs = self._s2i[np.array(caller_samples)].astype(np.int64)

    # --- 4.5. Pre-pass: drop variants whose MAC across kept samples is 0 ---
    mac_per_kept = np.zeros(len(kept_var_idxs), dtype=np.int64)
    with numba_threads(threads_resolved):
        _nb_count_mac_per_kept(
            self.genos.data,
            self.genos.offsets,
            src_sample_idxs,
            ploidy,
            kept_var_idxs,
            mac_per_kept,
        )
    keep_mask = mac_per_kept > 0
    n_dropped = int((~keep_mask).sum())
    if n_dropped:
        warnings.warn(
            f"write_view: dropping {n_dropped} variant(s) with MAC=0 in the output sample set",
            stacklevel=2,
        )
    kept_var_idxs = kept_var_idxs[keep_mask]
    if len(kept_var_idxs) == 0:
        raise ValueError(
            "all variants in the selected regions have MAC=0 in the "
            "chosen sample subset; nothing to write"
        )

    # --- 4.9. Output directory (after ALL validation, so no partial dir on error) ---
    if output.exists():
        shutil.rmtree(output)
    output.mkdir(parents=True)

    # --- 5. Pass 1: count kept entries per output slot ---
    out_lengths = np.zeros(n_out * ploidy, dtype=np.int64)
    with numba_threads(threads_resolved):
        _nb_count_kept(
            self.genos.data,
            self.genos.offsets,
            src_sample_idxs,
            ploidy,
            kept_var_idxs,
            out_lengths,
        )
    new_offsets = lengths_to_offsets(out_lengths.reshape(n_out, ploidy))

    # --- 6. Write offsets.npy ---
    offsets_mm = np.memmap(
        output / "offsets.npy",
        dtype=np.int64,
        mode="w+",
        shape=new_offsets.shape,
    )
    offsets_mm[:] = new_offsets
    offsets_mm.flush()

    # Allocate output variant_idxs memmap
    n_entries = int(new_offsets[-1])
    out_var_idxs_mm = np.memmap(
        output / "variant_idxs.npy",
        dtype=V_IDX_TYPE,
        mode="w+",
        shape=(n_entries,),
    )

    # --- 7. Pass 2 (genos): write remapped variant indices ---
    with numba_threads(threads_resolved):
        _nb_write_var_idxs(
            self.genos.data,
            self.genos.offsets,
            src_sample_idxs,
            ploidy,
            kept_var_idxs,
            new_offsets.ravel(),
            out_var_idxs_mm,
        )
    out_var_idxs_mm.flush()

    # --- 8. Pass 2 (fields): write each field ---
    for name in fields_to_write:
        dtype = self.available_fields[name]
        src_field_rag = _open_fmt(
            name, dtype, self.path, (self.n_samples, ploidy, None), "r"
        )
        out_field_mm = np.memmap(
            output / f"{name}.npy",
            dtype=dtype,
            mode="w+",
            shape=(n_entries,),
        )
        with numba_threads(threads_resolved):
            _nb_write_field(
                src_field_rag.data,
                self.genos.data,
                self.genos.offsets,
                src_sample_idxs,
                ploidy,
                kept_var_idxs,
                new_offsets.ravel(),
                out_field_mm,
            )
        out_field_mm.flush()
        del src_field_rag

    # --- 9. Build new index ---
    new_index = self.index[kept_var_idxs.tolist()]
    # Drop existing AF and row-index columns if present
    cols_to_drop = [c for c in ("AF", "index") if c in new_index.columns]
    if cols_to_drop:
        new_index = new_index.drop(cols_to_drop)
    # Compute AFs over the written genos
    n_alleles = n_out * ploidy
    afs = np.zeros(len(kept_var_idxs), dtype=np.float32)
    _nb_af_helper(afs, out_var_idxs_mm, new_offsets.ravel(), n_alleles)
    new_index = new_index.with_columns(AF=pl.Series(afs))
    new_index.write_ipc(SparseVar._index_path(output))

    # --- 10. Write metadata.json ---
    with open(output / "metadata.json", "w") as f:
        json_str = SparseVarMetadata(
            version=CURRENT_VERSION,
            samples=caller_samples,
            ploidy=ploidy,
            contigs=self.contigs,
            fields={n: self.available_fields[n].name for n in fields_to_write},
        ).model_dump_json()
        f.write(json_str)
@nb.njit(nogil=True, cache=True)
def _nb_af_helper(afs, v_idxs, offsets, max_count):
    """Accumulate per-variant slot counts into *afs*, then divide by *max_count*.

    ``offsets`` delimits one segment of ``v_idxs`` per (sample, ploidy) slot;
    ``afs`` is incremented at the variant indices of every segment and is
    modified in place.
    """
    n_slots = len(offsets) - 1
    for slot in range(n_slots):
        seg_start = offsets[slot]
        seg_stop = offsets[slot + 1]
        afs[v_idxs[seg_start:seg_stop]] += 1
    afs /= max_count
@nb.njit(parallel=True, nogil=True, cache=True)
def _nb_count_kept(
    src_data, src_offsets, src_sample_idxs, ploidy, kept_var_idxs, out_lengths
):
    """Pass 1: count, per output (sample, ploidy) slot, how many source variant
    indices fall in `kept_var_idxs`.

    Parameters
    ----------
    src_data
        Flat source variant indices; source slot ``s * ploidy + p`` is
        ``src_data[src_offsets[slot]:src_offsets[slot + 1]]``.
    src_offsets
        Source segment boundaries.
    src_sample_idxs
        Source sample index of each output sample, in output order.
    ploidy
        Haplotypes per sample.
    kept_var_idxs
        Variant indices to keep. Assumed sorted ascending, as required by the
        ``np.searchsorted`` membership test.
    out_lengths
        Output of length ``n_out * ploidy``; written in place.
    """
    n_out = src_sample_idxs.shape[0]
    n_kept = kept_var_idxs.shape[0]
    # Each prange iteration handles one output sample and writes only that
    # sample's out_lengths slots, so iterations never touch the same element.
    for i in nb.prange(n_out):
        s = src_sample_idxs[i]
        for p in range(ploidy):
            src_slot = s * ploidy + p
            count = 0
            lo = src_offsets[src_slot]
            hi = src_offsets[src_slot + 1]
            for j in range(lo, hi):
                v = src_data[j]
                # Binary-search membership test: v is kept iff it is found
                # exactly at its insertion point.
                k = np.searchsorted(kept_var_idxs, v)
                if k < n_kept and kept_var_idxs[k] == v:
                    count += 1
            out_lengths[i * ploidy + p] = count
@nb.njit(parallel=True, nogil=True, cache=True)
def _nb_count_mac_per_kept(
    src_data, src_offsets, src_sample_idxs, ploidy, kept_var_idxs, mac_out
):
    """Count, per kept variant, the number of non-ref entries across (sample, ploidy)
    in the output. Outer prange is over kept variants so each writes its own slot —
    no atomics needed.

    Parameters
    ----------
    src_data
        Flat source variant indices; source slot ``s * ploidy + p`` is
        ``src_data[src_offsets[slot]:src_offsets[slot + 1]]``. Each segment is
        assumed sorted ascending (required by ``np.searchsorted``).
    src_offsets
        Source segment boundaries.
    src_sample_idxs
        Source sample index of each output sample.
    ploidy
        Haplotypes per sample.
    kept_var_idxs
        Variant indices to count.
    mac_out
        Output, one count per entry of ``kept_var_idxs``; written in place.
    """
    n_kept = kept_var_idxs.shape[0]
    n_samples = src_sample_idxs.shape[0]
    for k in nb.prange(n_kept):
        v = kept_var_idxs[k]
        count = 0
        for i in range(n_samples):
            s = src_sample_idxs[i]
            for p in range(ploidy):
                src_slot = s * ploidy + p
                lo = src_offsets[src_slot]
                hi = src_offsets[src_slot + 1]
                # Search for v within this slot's segment only; idx is
                # relative to lo.
                idx = np.searchsorted(src_data[lo:hi], v)
                if idx < (hi - lo) and src_data[lo + idx] == v:
                    count += 1
        mac_out[k] = count
@nb.njit(parallel=True, nogil=True, cache=True)
def _nb_write_var_idxs(
    src_data,
    src_offsets,
    src_sample_idxs,
    ploidy,
    kept_var_idxs,
    new_offsets,
    out_var_idxs,
):
    """Pass 2: write remapped variant indices.

    Each source index ``v`` found in ``kept_var_idxs`` is written as its
    position ``k`` in ``kept_var_idxs``, so the output indices refer to the
    filtered variant table.

    Parameters
    ----------
    src_data, src_offsets
        Flat source variant indices and their per-slot boundaries.
    src_sample_idxs
        Source sample index of each output sample.
    ploidy
        Haplotypes per sample.
    kept_var_idxs
        Variant indices to keep; assumed sorted ascending (``np.searchsorted``).
    new_offsets
        Flat output offsets, one start per output slot ``i * ploidy + p``.
        Presumably built from :func:`_nb_count_kept` lengths so that each
        slot's range holds exactly its kept entries — confirm at call sites.
    out_var_idxs
        Output buffer; written in place.
    """
    n_out = src_sample_idxs.shape[0]
    n_kept = kept_var_idxs.shape[0]
    for i in nb.prange(n_out):
        s = src_sample_idxs[i]
        for p in range(ploidy):
            src_slot = s * ploidy + p
            out_slot = i * ploidy + p
            # Write cursor into this output slot's range.
            wp = new_offsets[out_slot]
            lo = src_offsets[src_slot]
            hi = src_offsets[src_slot + 1]
            for j in range(lo, hi):
                v = src_data[j]
                k = np.searchsorted(kept_var_idxs, v)
                if k < n_kept and kept_var_idxs[k] == v:
                    out_var_idxs[wp] = k
                    wp += 1
@nb.njit(parallel=True, nogil=True, cache=True)
def _nb_write_field(
    src_field,
    src_data,
    src_offsets,
    src_sample_idxs,
    ploidy,
    kept_var_idxs,
    new_offsets,
    out_field,
):
    """Pass 2 (field variant): writes src_field values at filter-kept positions.

    ``src_field`` is aligned element-wise with ``src_data`` (same offsets). For
    every entry whose variant index is in ``kept_var_idxs``, the field value is
    copied to the output in the same order :func:`_nb_write_var_idxs` uses, so
    the output field stays aligned with the output variant indices.

    Parameters
    ----------
    src_field
        Flat source field values, indexed like ``src_data``.
    src_data, src_offsets
        Flat source variant indices and their per-slot boundaries.
    src_sample_idxs
        Source sample index of each output sample.
    ploidy
        Haplotypes per sample.
    kept_var_idxs
        Variant indices to keep; assumed sorted ascending (``np.searchsorted``).
    new_offsets
        Flat output offsets, one start per output slot ``i * ploidy + p``.
    out_field
        Output buffer; written in place.
    """
    n_out = src_sample_idxs.shape[0]
    n_kept = kept_var_idxs.shape[0]
    for i in nb.prange(n_out):
        s = src_sample_idxs[i]
        for p in range(ploidy):
            src_slot = s * ploidy + p
            out_slot = i * ploidy + p
            # Write cursor into this output slot's range.
            wp = new_offsets[out_slot]
            lo = src_offsets[src_slot]
            hi = src_offsets[src_slot + 1]
            for j in range(lo, hi):
                v = src_data[j]
                k = np.searchsorted(kept_var_idxs, v)
                if k < n_kept and kept_var_idxs[k] == v:
                    out_field[wp] = src_field[j]
                    wp += 1
def _write_filtered_index(src: Path, dst: Path, pl_filter: pl.Expr | None) -> None:
    """Copy a genoray index from ``src`` to ``dst``, optionally filtering rows.

    Without a filter the file is copied verbatim. With a filter, the index is
    streamed lazily. A comma-joined Utf8 ``ALT`` column is temporarily split
    into ``list[str]``, and a missing ``ILEN`` column is temporarily computed,
    so expressions that depend on either (e.g. ``is_snp``) evaluate correctly.
    Both temporary changes are undone before writing, so ``dst`` keeps the
    on-disk schema of ``src``. ILEN is computed whenever it is absent, because
    the filter expression is not inspected to see whether it needs it.
    """
    if pl_filter is None:
        shutil.copy(src, dst)
        return

    lazy = pl.scan_ipc(src)
    on_disk = lazy.collect_schema()
    split_alt = on_disk["ALT"] == pl.Utf8
    need_ilen = "ILEN" not in on_disk

    # Expose ALT as a list and ILEN as a column for the filter.
    if split_alt:
        lazy = lazy.with_columns(pl.col("ALT").str.split(","))
    if need_ilen:
        lazy = lazy.with_columns(ILEN=ILEN)

    lazy = lazy.filter(pl_filter)

    # Restore the original on-disk representation.
    if need_ilen:
        lazy = lazy.drop("ILEN")
    if split_alt:
        lazy = lazy.with_columns(pl.col("ALT").list.join(","))

    lazy.sink_ipc(dst, compression="zstd")
def _process_contig_vcf(
    path: str | Path,
    dosage_field: str | None,
    max_mem: int | str,
    contig: str,
    chunk_dir: Path,
    chunk_idx: int,
    cyvcf2_filter=None,
    pl_filter=None,
) -> tuple[int, int]:
    """Convert one contig of a VCF into sparse genotype chunks on disk.

    Each non-empty chunk yielded by ``VCF.chunk`` is written to
    ``chunk_dir / f"c{chunk_idx}" / str(n)``. ``n`` counts only chunks that
    were actually written, so the directories are exactly
    ``0 .. n_chunks - 1`` and every one contains ``variant_idxs.npy`` and
    ``offsets.npy`` — the layout ``_concat_data`` expects.

    Parameters
    ----------
    path
        VCF path.
    dosage_field
        FORMAT field to read as dosages, or None for genotypes only.
    max_mem
        Memory budget passed to ``VCF.chunk``.
    contig
        Contig to process.
    chunk_dir, chunk_idx
        Output root and the index naming this contig's ``c{chunk_idx}`` subdir.
    cyvcf2_filter, pl_filter
        Record filters forwarded to ``VCF``.

    Returns
    -------
    tuple[int, int]
        ``(total_vars, n_chunks)``: variants written and chunk dirs created.
    """
    vcf = VCF(
        path,
        filter=cyvcf2_filter,
        pl_filter=pl_filter,
        dosage_field=dosage_field,
        with_gvi_index=False,
    )
    mode = VCF.Genos8Dosages if dosage_field is not None else VCF.Genos8
    chunker = vcf.chunk(contig, max_mem=max_mem, mode=mode)

    total_vars = 0
    n_chunks = 0

    # Create a subdirectory for this contig to avoid collision
    contig_dir = chunk_dir / f"c{chunk_idx}"
    contig_dir.mkdir(parents=True, exist_ok=True)

    for data in chunker:
        if isinstance(data, tuple):
            genos, dosages = data
        else:
            genos = data
            dosages = None

        n_vars = genos.shape[-1]
        if n_vars == 0:
            # Neither create nor count a directory for an empty chunk: it would
            # have no offsets.npy and make _concat_data fail.
            continue

        out_path = contig_dir / str(n_chunks)
        out_path.mkdir(parents=True, exist_ok=True)
        n_chunks += 1

        var_idxs = np.arange(total_vars, total_vars + n_vars, dtype=np.int32)
        if dosages is not None:
            sp_genos, sp_dosages = dense2sparse(genos, var_idxs, dosages)
            _write_genos(out_path, sp_genos)
            _write_dosages(out_path, sp_dosages.data)
        else:
            sp_genos = dense2sparse(genos, var_idxs)
            _write_genos(out_path, sp_genos)
        total_vars += n_vars

    return total_vars, n_chunks
def _process_contig_pgen(
    geno_path: str | Path,
    dosage_path: str | Path | None,
    max_mem: int,
    keep_idxs: np.ndarray,
    mem_per_var: int,
    n_samples: int,
    ploidy: int,
    chunk_dir: Path,
    chunk_idx: int,
) -> tuple[int, int]:
    """Convert the selected PGEN variants of one contig into sparse chunks on disk.

    Variants ``keep_idxs`` are read in chunks of at most
    ``max_mem // mem_per_var`` variants and written to
    ``chunk_dir / f"c{chunk_idx}" / str(n)`` for ``n = 0 .. n_chunks - 1``.
    The PgenReaders are always closed on return or error, because they can
    hold very large allocations inside the worker process.

    Returns
    -------
    tuple[int, int]
        ``(total_vars, n_chunks)``.

    Raises
    ------
    ValueError
        If ``max_mem`` cannot hold a single variant.
    """
    geno_reader = PgenReader(bytes(Path(geno_path)), n_samples)
    dose_reader = None
    try:
        if dosage_path is not None:
            dose_reader = PgenReader(bytes(Path(dosage_path)))

        keep_idxs = np.ascontiguousarray(keep_idxs, dtype=np.uint32)
        n_total = int(len(keep_idxs))

        vars_per_chunk = min(max_mem // mem_per_var, n_total) if n_total else 0
        if n_total and vars_per_chunk == 0:
            raise ValueError(
                f"Maximum memory {format_memory(max_mem)} insufficient to read a single variant."
                + f" Memory per variant: {format_memory(mem_per_var)}."
            )

        # Create a subdirectory for this contig to avoid collision
        contig_dir = chunk_dir / f"c{chunk_idx}"
        contig_dir.mkdir(parents=True, exist_ok=True)

        total_vars = 0
        n_chunks = 0
        chunk_starts = range(0, n_total, vars_per_chunk) if n_total else range(0)
        for c0 in chunk_starts:
            idxs = keep_idxs[c0 : c0 + vars_per_chunk]
            n_vars = int(len(idxs))
            if n_vars == 0:
                continue

            # Name by written-chunk count so dirs are always 0..n_chunks-1.
            out_path = contig_dir / str(n_chunks)
            out_path.mkdir(parents=True, exist_ok=True)
            n_chunks += 1

            # Read genotypes for exactly the kept variant indices.
            # (v, s*p)
            genos = np.empty((n_vars, n_samples * ploidy), dtype=np.int32)
            geno_reader.read_alleles_list(idxs, genos)
            genos = genos.astype(np.int8)
            # (v, s, p) -> (s, p, v)
            genos = genos.reshape(n_vars, n_samples, ploidy).transpose(1, 2, 0)
            genos[genos == -9] = -1

            dosages = None
            if dose_reader is not None:
                dosages = np.empty((n_vars, n_samples), dtype=np.float32)
                dose_reader.read_dosages_list(idxs, dosages)
                dosages = dosages.transpose(1, 0)
                dosages[dosages == -9] = np.nan

            # Convert to sparse
            var_idxs = np.arange(total_vars, total_vars + n_vars, dtype=np.int32)
            if dosages is not None:
                sp_genos, sp_dosages = dense2sparse(genos, var_idxs, dosages)
                _write_genos(out_path, sp_genos)
                _write_dosages(out_path, sp_dosages.data)
            else:
                sp_genos = dense2sparse(genos, var_idxs)
                _write_genos(out_path, sp_genos)

            total_vars += n_vars

        return total_vars, n_chunks
    finally:
        geno_reader.close()
        if dose_reader is not None:
            dose_reader.close()
def _open_genos(path: Path, shape: tuple[int | None, ...], mode: Literal["r", "r+"]):
    """Memory-map the sparse genotypes stored under *path* as a ``Ragged`` array.

    ``variant_idxs.npy`` is mapped as ``V_IDX_TYPE`` and ``offsets.npy`` as
    int64, both with the given *mode*.
    """

    def _map(file_name, dtype):
        return np.memmap(path / file_name, dtype=dtype, mode=mode)

    data = _map("variant_idxs.npy", V_IDX_TYPE)
    offs = _map("offsets.npy", np.int64)
    return Ragged[V_IDX_TYPE].from_offsets(data, shape, offs)
def _open_fmt(
    name: str,
    type_: NUMERIC | np.dtype[NUMERIC] | type[NUMERIC],
    path: Path,
    shape: tuple[int | None, ...],
    mode: Literal["r", "r+"],
) -> Ragged[NUMERIC]:
    """Memory-map the per-entry field *name* under *path* as a ``Ragged`` array.

    ``{name}.npy`` is mapped with dtype *type_* and shares the int64
    ``offsets.npy`` of the genotypes in the same directory.
    """
    field_values = np.memmap(path / f"{name}.npy", dtype=type_, mode=mode)
    field_offsets = np.memmap(path / "offsets.npy", dtype=np.int64, mode=mode)
    return Ragged.from_offsets(field_values, shape, field_offsets)
def _write_genos(path: Path, sp_genos: Ragged[V_IDX_TYPE]):
    """Write ``sp_genos.data`` and ``sp_genos.offsets`` under *path* as raw memmaps.

    *path* is created if needed. ``variant_idxs.npy`` and ``offsets.npy`` keep
    the shapes and dtypes of the source arrays and are flushed before returning.
    """
    path.mkdir(parents=True, exist_ok=True)
    targets = (
        ("variant_idxs.npy", sp_genos.data),
        ("offsets.npy", sp_genos.offsets),
    )
    for file_name, source in targets:
        mapped = np.memmap(
            path / file_name,
            shape=source.shape,
            dtype=source.dtype,
            mode="w+",
        )
        mapped[:] = source
        mapped.flush()
def _write_dosages(path: Path, dosages: NDArray[DOSAGE_TYPE]):
    """Dump *dosages* to ``path / "dosages.npy"`` as a raw memmap.

    *path* is created first if needed; the file keeps the array's shape and
    dtype and is flushed before returning.
    """
    path.mkdir(parents=True, exist_ok=True)
    target = path / "dosages.npy"
    mapped = np.memmap(target, dtype=dosages.dtype, shape=dosages.shape, mode="w+")
    mapped[:] = dosages
    mapped.flush()
def _concat_data(
    out_path: Path,
    chunk_dir: Path,
    shape: tuple[int, int],
    contig_results: list[tuple[int, int]],
    with_dosages: bool = False,
):
    """Concatenate per-contig sparse chunks into one store under *out_path*.

    Parameters
    ----------
    out_path
        Destination directory; receives ``offsets.npy`` (int64),
        ``variant_idxs.npy`` and, if ``with_dosages``, ``dosages.npy``.
    chunk_dir
        Directory holding ``c{i}/{j}`` chunk subdirectories, where ``i`` is the
        position in ``contig_results``. Chunks are moved to ``chunk_dir/{k}``
        (global numbering) and the ``c{i}`` directories are removed.
    shape
        ``(n_samples, ploidy)``.
    contig_results
        ``(n_vars, n_chunks)`` per contig, in contig order. Variant indices in
        each contig's chunks are shifted by the total ``n_vars`` of all
        preceding contigs.
    with_dosages
        Also concatenate the chunks' ``dosages.npy``.
    """
    out_path.mkdir(parents=True, exist_ok=True)

    # Flatten chunk directories and calculate offsets
    chunk_offsets: list[int] = []
    contig_offset = 0
    global_chunk_idx = 0
    for chunk_idx, (n_vars, n_chunks) in enumerate(contig_results):
        contig_subdir = chunk_dir / f"c{chunk_idx}"
        for i in range(n_chunks):
            src = contig_subdir / str(i)
            dest = chunk_dir / str(global_chunk_idx)
            src.rename(dest)
            chunk_offsets.append(contig_offset)
            global_chunk_idx += 1
        if contig_subdir.exists():
            shutil.rmtree(contig_subdir)
        contig_offset += n_vars

    # [1, 2, 3, ...]
    chunk_dirs = [chunk_dir / str(i) for i in range(len(chunk_offsets))]

    # Pass 1: Compute per-haplotype lengths from the (small) offsets files only,
    # avoiding mapping the potentially large variant_idxs. int64 accumulators
    # avoid any overflow concern for very deep haplotypes.
    vars_per_sp = np.zeros(shape, dtype=np.int64)
    for c_dir in chunk_dirs:
        chunk_offsets_arr = np.memmap(c_dir / "offsets.npy", dtype=np.int64, mode="r")
        # (n_samples * ploidy,) -> (n_samples, ploidy)
        vars_per_sp += np.diff(chunk_offsets_arr).reshape(shape)
        del chunk_offsets_arr

    # offsets scale O(n_samples * ploidy). Force int64 on disk: every reader
    # (_open_genos, _open_fmt, the pass above) maps offsets.npy as int64.
    offsets = lengths_to_offsets(vars_per_sp).astype(np.int64, copy=False)
    offsets_memmap = np.memmap(
        out_path / "offsets.npy", dtype=np.int64, mode="w+", shape=offsets.shape
    )
    offsets_memmap[:] = offsets
    offsets_memmap.flush()

    var_idxs_memmap = np.memmap(
        out_path / "variant_idxs.npy", dtype=V_IDX_TYPE, mode="w+", shape=offsets[-1]
    )

    # Use in-memory array for write offsets to avoid disk I/O
    write_offsets = offsets[:-1].copy()

    # Context manager guarantees the live progress display is stopped even if
    # a copy step raises.
    with Progress(*Progress.get_default_columns(), MofNCompleteColumn()) as pbar:
        # Pass 2: Copy Genotypes, one chunk at a time to bound memory use
        for offset, c_dir in pbar.track(
            zip(chunk_offsets, chunk_dirs),
            total=len(chunk_dirs),
            description="Copying genotypes",
        ):
            sp_genos = _open_genos(c_dir, (*shape, None), mode="r")
            _copy_chunk_helper(
                var_idxs_memmap,
                write_offsets,
                sp_genos.data,
                sp_genos.offsets,
                offset,
                shape[0],
                shape[1],
            )
            # Close memmaps
            del sp_genos
        var_idxs_memmap.flush()

        if with_dosages:
            # Reset write offsets
            write_offsets = offsets[:-1].copy()
            dosages_memmap = np.memmap(
                out_path / "dosages.npy", dtype=DOSAGE_TYPE, mode="w+", shape=offsets[-1]
            )
            for c_dir in pbar.track(
                chunk_dirs, total=len(chunk_dirs), description="Copying dosages"
            ):
                sp_dosages = _open_fmt(
                    "dosages", DOSAGE_TYPE, c_dir, (*shape, None), mode="r"
                )
                _copy_chunk_dosages_helper(
                    dosages_memmap,
                    write_offsets,
                    sp_dosages.data,
                    sp_dosages.offsets,
                    shape[0],
                    shape[1],
                )
                del sp_dosages
            dosages_memmap.flush()
@nb.njit(parallel=True, nogil=True, cache=True)
def _copy_chunk_helper(
    out_data: NDArray[DTYPE],
    write_offsets: NDArray[OFFSET_TYPE],
    in_data: NDArray[DTYPE],
    in_offsets: NDArray[OFFSET_TYPE],
    variant_offset: int,
    n_samples: int,
    ploidy: int,
):
    """Append one chunk's variant indices to each haplotype's output segment.

    For every (sample, ploidy) slot, the slot's entries of ``in_data`` are
    shifted by ``variant_offset`` and written starting at
    ``write_offsets[slot]``. That cursor is then advanced past them (it is
    modified in place). Each prange iteration touches only its own sample's slots.
    """
    for s in nb.prange(n_samples):
        base_slot = s * ploidy
        for p in range(ploidy):
            slot = base_slot + p
            src_start = in_offsets[slot]
            n_entries = in_offsets[slot + 1] - src_start
            dst_start = write_offsets[slot]
            for k in range(n_entries):
                out_data[dst_start + k] = in_data[src_start + k] + variant_offset  # type: ignore
            write_offsets[slot] = dst_start + n_entries
@nb.njit(parallel=True, nogil=True, cache=True)
def _copy_chunk_dosages_helper(
    out_data: NDArray[DOSAGE_TYPE],
    write_offsets: NDArray[OFFSET_TYPE],
    in_data: NDArray[DOSAGE_TYPE],
    in_offsets: NDArray[OFFSET_TYPE],
    n_samples: int,
    ploidy: int,
):
    """Append one chunk's dosages to each haplotype's output segment.

    Same slot layout and cursor semantics as the genotype copy, but values are
    copied unchanged. ``write_offsets`` is advanced in place.
    """
    for s in nb.prange(n_samples):
        base_slot = s * ploidy
        for p in range(ploidy):
            slot = base_slot + p
            src_start = in_offsets[slot]
            src_stop = in_offsets[slot + 1]
            n_entries = src_stop - src_start
            dst_start = write_offsets[slot]
            out_data[dst_start : dst_start + n_entries] = in_data[src_start:src_stop]
            write_offsets[slot] = dst_start + n_entries
@nb.njit(parallel=True, nogil=True, cache=True)
def _find_starts_ends(
    genos: NDArray[V_IDX_TYPE],
    geno_offsets: NDArray[OFFSET_TYPE],
    var_ranges: NDArray[V_IDX_TYPE],
    sample_idxs: NDArray[np.int64],
    ploidy: int,
    out_offsets: NDArray[OFFSET_TYPE] | None = None,
):
    """Find the start and end offsets of the sparse genotypes for each range.

    Parameters
    ----------
    genos
        Sparse genotypes (flat variant indices). Assumed sorted within each
        (sample, ploidy) slot, as ``np.searchsorted`` requires.
    geno_offsets
        Genotype offsets; slot ``s_idx * ploidy + p`` spans
        ``genos[geno_offsets[slot]:geno_offsets[slot + 1]]``.
    var_ranges
        Shape = (ranges 2) Variant index ranges.
    sample_idxs
        Sample indices
    ploidy
        Ploidy
    out_offsets
        Output array to write to. If None, a new array will be created.

    Returns
    -------
    Shape: (2 ranges samples ploidy). ``out_offsets[0]`` holds the start offset
    and ``out_offsets[1]`` the end offset into ``genos``, both absolute (not
    slot-relative). Ranges where ``var_ranges[:, 0] == var_ranges[:, 1]`` are
    filled with ``np.iinfo(OFFSET_TYPE).max`` in both entries.
    """
    n_ranges = len(var_ranges)
    n_samples = len(sample_idxs)
    if out_offsets is None:
        out_offsets = np.empty((2, n_ranges, n_samples, ploidy), dtype=OFFSET_TYPE)

    # Work on ranges ordered by start variant index; the result is un-permuted
    # back to input order at the end.
    sorter = np.argsort(var_ranges[:, 0])
    var_ranges = var_ranges[sorter]

    # Only the outer prange runs in parallel; numba serializes nested pranges.
    for s in nb.prange(n_samples):
        for p in nb.prange(ploidy):
            s_idx = sample_idxs[s]
            sp = s_idx * ploidy + p
            o_s, o_e = geno_offsets[sp], geno_offsets[sp + 1]
            sp_genos = genos[o_s:o_e]
            # add o_s to make indices relative to whole array
            out_offsets[..., s, p] = np.searchsorted(sp_genos, var_ranges).T + o_s

    # Mark empty ranges with the sentinel value.
    no_vars = var_ranges[:, 0] == var_ranges[:, 1]
    out_offsets[:, no_vars] = np.iinfo(OFFSET_TYPE).max

    unsorter = np.argsort(sorter)
    out_offsets[:] = out_offsets[:, unsorter]
    return out_offsets
@nb.njit(nogil=True, cache=True)
def _length_walk_n_keep(
    sp_genos: NDArray[V_IDX_TYPE],
    v_starts: NDArray[np.int32],
    ilens: NDArray[np.int32],
    start_idx: int,
    max_idx: int,
    q_start: POS_TYPE,
    q_end: POS_TYPE,
) -> int:
    """Number of leading variants in ``sp_genos[start_idx:max_idx]`` to include
    so one haplotype reaches ``q_end - q_start`` in length, extending past
    ``q_end`` only as needed. Variants strictly inside ``[q_start, q_end)`` are
    always included; the length budget only gates extension past ``q_end``.
    Returns a count in ``[0, max_idx - start_idx]``.

    Parameters
    ----------
    sp_genos
        Variant indices of one haplotype; entries index ``v_starts``/``ilens``.
    v_starts, ilens
        Start position and indel length of each variant (ilen > 0 insertion,
        ilen < 0 deletion, per the max/min uses below).
    start_idx, max_idx
        Half-open window of ``sp_genos`` to walk.
    q_start, q_end
        Query interval.
    """
    q_len = q_end - q_start
    # Reference position up to which the haplotype has been accounted for.
    last_v_end = q_start
    # Haplotype length accumulated from q_start so far.
    written_len = 0
    for j in range(start_idx, max_idx):
        v_idx = sp_genos[j]
        v_start = v_starts[v_idx]
        ilen = ilens[v_idx]
        # 1 for variants starting at/after q_start, else 0; presumably counts
        # the variant's own anchor base — confirm against the haplotype builder.
        maybe_add_one = POS_TYPE(v_start >= q_start)
        if v_start >= q_start:
            past_query = v_start >= q_end
            # Reference bases between the last accounted end and this variant.
            written_len += v_start - last_v_end
            if past_query and written_len >= q_len:
                return j - start_idx  # exclude this variant
            # Inserted bases (positive ilen) plus the optional extra base.
            written_len += max(0, ilen) + maybe_add_one
            if past_query and written_len >= q_len:
                return j - start_idx + 1  # include this variant
        # Deletions (negative ilen) push the accounted reference end forward,
        # including ones that start before q_start and reach into the query.
        v_end = v_start - min(0, ilen) + maybe_add_one
        last_v_end = max(last_v_end, v_end)
    return max_idx - start_idx
@nb.njit(parallel=True, nogil=True, cache=True)
def _dense2sparse_count(
    genos: NDArray[np.integer],
    v_starts: NDArray[np.int32],
    ilens: NDArray[np.int32],
    q_start: POS_TYPE,
    q_end: POS_TYPE,
    out_lengths: NDArray[np.int64],
) -> None:
    """Pass 1: per (sample, haplotype), count the carried ALT calls to keep.
    Gathers each haplotype's carried (``== 1``) window-local positions in order
    and routes them through :func:`_length_walk_n_keep` (the SAME walk the sparse
    path uses, so the two cannot drift). Writes the kept count to ``out_lengths``.

    Parameters
    ----------
    genos
        Dense calls of shape ``(n_samples, ploidy, n_var)``; a value of 1 marks
        a carried ALT allele.
    v_starts, ilens
        Start position and indel length per window-local variant position.
    q_start, q_end
        Query interval forwarded to :func:`_length_walk_n_keep`.
    out_lengths
        Output of shape ``(n_samples, ploidy)``; written in place.
    """
    n_samples, ploidy, n_var = genos.shape
    for s in nb.prange(n_samples):
        # Scratch buffer of carried positions, allocated once per sample (per
        # prange iteration) and reused for each of that sample's haplotypes.
        carriers = np.empty(n_var, dtype=V_IDX_TYPE)
        for p in range(ploidy):
            nc = 0
            for v in range(n_var):
                if genos[s, p, v] == 1:
                    carriers[nc] = v
                    nc += 1
            # Only carriers[0:nc] is valid for this haplotype.
            out_lengths[s, p] = _length_walk_n_keep(
                carriers, v_starts, ilens, 0, nc, q_start, q_end
            )
@nb.njit(parallel=True, nogil=True, cache=True)
def _dense2sparse_fill(
    genos: NDArray[np.integer],
    var_idxs: NDArray[V_IDX_TYPE],
    dosages: NDArray[DOSAGE_TYPE],
    out_lengths: NDArray[np.int64],
    flat_offsets: NDArray[OFFSET_TYPE],
    out_data: NDArray[V_IDX_TYPE],
    out_dose: NDArray[DOSAGE_TYPE],
    has_dose: bool,
) -> None:
    """Pass 2: emit the first ``out_lengths[s, p]`` carried ALT calls per
    haplotype into the disjoint output range ``[flat_offsets[slot], ...)``.

    Parameters
    ----------
    genos
        Dense calls of shape ``(n_samples, ploidy, n_var)``; 1 marks a carried ALT.
    var_idxs
        Output variant index for each window-local position ``v``.
    dosages
        Indexed as ``dosages[s, v]``; only read when ``has_dose`` is True.
    out_lengths
        Per-(sample, ploidy) counts from :func:`_dense2sparse_count`.
    flat_offsets
        Output start per flat slot ``s * ploidy + p``.
    out_data, out_dose
        Output buffers; ``out_dose`` is only written when ``has_dose`` is True.
    has_dose
        Whether to copy dosages alongside variant indices.
    """
    n_samples, ploidy, n_var = genos.shape
    for s in nb.prange(n_samples):
        for p in range(ploidy):
            slot = s * ploidy + p
            n_keep = out_lengths[s, p]
            # Write cursor into this slot's output range.
            w = flat_offsets[slot]
            kept = 0
            for v in range(n_var):
                # Stop once the length-walk budget from pass 1 is exhausted.
                if kept >= n_keep:
                    break
                if genos[s, p, v] == 1:
                    out_data[w] = var_idxs[v]
                    if has_dose:
                        out_dose[w] = dosages[s, v]
                    w += 1
                    kept += 1
@nb.njit(parallel=False, nogil=True, cache=True)
def _find_starts_ends_with_length(
    genos: NDArray[V_IDX_TYPE],
    geno_offsets: NDArray[OFFSET_TYPE],
    q_starts: NDArray[POS_TYPE],
    q_ends: NDArray[POS_TYPE],
    var_ranges: NDArray[V_IDX_TYPE],
    v_starts: NDArray[np.int32],
    ilens: NDArray[np.int32],
    sample_idxs: NDArray[np.int64],
    ploidy: int,
    contig_max_idx: int,
    out: NDArray[OFFSET_TYPE] | None = None,
):
    """Find the start and end offsets of the sparse genotypes for each range,
    extending ends past the query so each haplotype reaches the query length.

    Parameters
    ----------
    genos
        Sparse genotypes (flat variant indices, sorted within each slot).
    geno_offsets
        Genotype offsets; slot ``s_idx * ploidy + p`` spans
        ``genos[geno_offsets[slot]:geno_offsets[slot + 1]]``.
    q_starts, q_ends
        Query start/end positions, one per range.
    var_ranges
        Shape = (ranges 2) Variant index ranges.
    v_starts, ilens
        Per-variant start positions and indel lengths used by the length walk.
    sample_idxs
        Sample indices.
    ploidy
        Ploidy.
    contig_max_idx
        Largest variant index on the contig; the walk never extends past it.
    out
        Output array to write to. If None, a new array is created.

    Notes
    -----
    ``q_starts[i]``, ``q_ends[i]`` and ``var_ranges[i]`` must describe the same
    query for every ``i``. The ranges are processed in order of
    ``var_ranges[:, 0]``, and all three arrays are permuted with the same
    sorter, so any input order is handled correctly. ``np.argsort`` is not
    stable, so even sorted input with tied starts could be reordered.

    Returns
    -------
    Shape: (2 ranges samples ploidy). ``out[0]`` holds the start offset and
    ``out[1]`` the end offset into ``genos`` (absolute, not slot-relative).
    Ranges with ``var_ranges[:, 0] == var_ranges[:, 1]`` are filled with
    ``np.iinfo(OFFSET_TYPE).max``.
    """
    n_ranges = len(q_starts)
    n_samples = len(sample_idxs)
    if out is None:
        out = np.empty((2, n_ranges, n_samples, ploidy), dtype=OFFSET_TYPE)

    # Permute queries together with their variant ranges so position r of the
    # sorted arrays always refers to the same query.
    sorter = np.argsort(var_ranges[:, 0])
    var_ranges = var_ranges[sorter]
    sorted_q_starts = q_starts[sorter]
    sorted_q_ends = q_ends[sorter]

    for s in nb.prange(n_samples):
        for p in nb.prange(ploidy):
            s_idx = sample_idxs[s]
            sp = s_idx * ploidy + p
            o_s, o_e = geno_offsets[sp], geno_offsets[sp + 1]
            sp_genos = genos[o_s:o_e]
            # First slot entry beyond the contig: the walk's upper bound.
            max_idx = np.searchsorted(sp_genos, contig_max_idx + 1)
            start_idxs = np.searchsorted(sp_genos, var_ranges[:, 0])
            for r in range(n_ranges):
                start_idx: np.intp = start_idxs[r]
                if var_ranges[r, 0] == var_ranges[r, 1]:
                    out[:, r, s, p] = np.iinfo(OFFSET_TYPE).max
                    continue

                # add o_s to make indices relative to whole array
                out[0, r, s, p] = start_idx + o_s
                if start_idx == max_idx:
                    # no variants in this range
                    out[1, r, s, p] = start_idx + o_s
                    continue

                n_keep = _length_walk_n_keep(
                    sp_genos,
                    v_starts,
                    ilens,
                    start_idx,
                    max_idx,
                    sorted_q_starts[r],
                    sorted_q_ends[r],
                )
                out[1, r, s, p] = start_idx + o_s + n_keep

    unsorter = np.argsort(sorter)
    out[:] = out[:, unsorter]
    return out
def _empty_annot() -> pl.DataFrame:
    """Build a zero-row annotation frame with the canonical annotation schema."""
    annot_schema = {
        "varID": pl.UInt32,
        "gene_id": pl.Utf8,
        "strand": pl.Utf8,
        "codon_pos": pl.Int8,
    }
    return pl.DataFrame(
        {name: [] for name in annot_schema},
        schema=annot_schema,
    )
def _get_strand_and_codon_pos(
    cds_df: pl.DataFrame, var_table: pl.DataFrame, contig_normalizer: ContigNormalizer
) -> pl.DataFrame:
    """
    Calculate strand and codon position for variants overlapping CDS regions.

    Parameters
    ----------
    cds_df : pl.DataFrame
        CDS features from GTF with columns: chrom, start, end, strand, frame,
        gene_id, transcript_id, gene_biotype, transcript_support_level, tag
        coordinates should be 1-based
    var_table : pl.DataFrame
        Variant table with columns: index, CHROM, POS, ILEN, ...
        POS should be 1-based
    contig_normalizer : ContigNormalizer
        Normalizer to match chromosome names between CDS and granges

    Returns
    -------
    pl.DataFrame
        Annotation with varID (UInt32), gene_id, strand, codon_pos (Int8) — the
        same schema as ``_empty_annot`` — one row per overlapping variant. When
        several CDS features overlap a variant, the best one is chosen by, in
        order: protein_coding biotype, canonical/APPRIS tag, lowest numeric TSL
        (missing or unparseable TSL ranks last), longest CDS span, then
        transcript_id.

    Notes
    -----
    ``codon_pos`` (0, 1, 2) is counted in the coding direction, shifted by the
    GTF ``frame``. On "+" it is ``(pos - cds_start - frame) % 3``, measured from
    the feature's 5' end ``cds_start``. Otherwise it is
    ``(cds_end - pos - frame) % 3``, measured from the 5' end ``cds_end``. Only
    SNVs (first ILEN == 0) with a frame in 0..2 get a value.
    """
    # Normalize CDS chromosome names to match granges
    # Cast to string first to avoid categorical comparison issues
    cds_df = cds_df.with_columns(
        pl.col("chrom").cast(pl.Utf8).replace(contig_normalizer.contig_map)
    )
    # Filter out CDS features with chromosomes not in granges
    cds_df = cds_df.filter(pl.col("chrom").is_in(contig_normalizer.contigs))
    cds_df.config_meta.set(coordinate_system_zero_based=False)  # type: ignore

    # Prepare var_table for pb.overlap by creating interval columns; deletions
    # (negative ILEN) extend the interval end by the deleted length.
    var_intervals = var_table.select(
        pl.col("ILEN").list.first(),
        var_id="index",
        chrom="CHROM",
        start=pl.col("POS"),
        end=pl.col("POS")
        - pl.col("ILEN").list.first().clip(upper_bound=0).fill_null(0),
    )
    var_intervals.config_meta.set(coordinate_system_zero_based=False)  # type: ignore

    # Check if CDS or var_table is empty
    if cds_df.is_empty() or var_table.is_empty():
        return _empty_annot()

    joined_cds = (
        cast(
            pl.LazyFrame,
            pb.overlap(var_intervals, cds_df, projection_pushdown=True),
        )
        .rename(
            {
                "start_1": "pos",
                "start_2": "cds_start",
                "end_2": "cds_end",
            }
        )
        .drop("end_1", "chrom_1", "chrom_2")
        .rename(lambda c: c.replace("_2", "").replace("_1", ""))
        .collect()
    )

    if joined_cds.height == 0:
        return _empty_annot()

    frame = pl.col("frame")
    plus_codon = (pl.col("pos") - pl.col("cds_start") - frame) % 3
    # On the minus strand the coding direction starts at cds_end.
    minus_codon = (pl.col("cds_end") - pl.col("pos") - frame) % 3
    codon_pos = (
        pl.when(frame.is_not_null() & (frame <= 2) & (pl.col("ILEN") == 0))
        .then(pl.when(pl.col("strand") == "+").then(plus_codon).otherwise(minus_codon))
        .cast(pl.Int8)
        .alias("codon_pos")
    )

    annot = (
        joined_cds.with_columns(codon_pos)
        # Rank candidates (0 is best) so duplicates resolve deterministically.
        .with_columns(
            pl.when(pl.col("gene_biotype") == "protein_coding")
            .then(0)
            .otherwise(1)
            .alias("rank_pc"),
            pl.when(
                pl.col("tag").is_not_null()
                & pl.col("tag").str.contains(r"^(canonical|appris_principal)")
            )
            .then(0)
            .otherwise(1)
            .alias("rank_canonical"),
            # Missing or unparseable TSL (e.g. "NA") must rank last; a null
            # would otherwise sort first and beat TSL 1.
            pl.col("transcript_support_level")
            .str.extract(r"(\d+)", 1)
            .cast(pl.Int16, strict=False)
            .fill_null(9999)
            .alias("rank_tsl"),
            # Negative span so larger spans get rank 0 (best)
            -(pl.col("cds_end") - pl.col("cds_start")).alias("rank_span"),
        )
        .sort(
            [
                "var_id",
                "rank_pc",
                "rank_canonical",
                "rank_tsl",
                "rank_span",
                "transcript_id",
            ]
        )
        .group_by("var_id", maintain_order=True)
        .agg(pl.col("gene_id", "strand", "codon_pos").first())
        # Match the documented / _empty_annot schema so callers can join on varID.
        .select(
            pl.col("var_id").cast(pl.UInt32).alias("varID"),
            "gene_id",
            "strand",
            pl.col("codon_pos").cast(pl.Int8),
        )
    )
    return annot
def _load_gtf(gtf: str | pl.DataFrame) -> pl.DataFrame:
    """Return the GTF as a 1-based polars DataFrame whose contig column is ``chrom``.

    A DataFrame input is only renamed (``seqname`` -> ``chrom`` when present).
    A path is scanned with seqpro, and the attributes used downstream are
    extracted into their own columns.
    """
    rename_map = {"seqname": "chrom"}
    if isinstance(gtf, pl.DataFrame):
        return gtf.rename(rename_map, strict=False)

    wanted_attrs = (
        "gene_id",
        "transcript_id",
        "gene_name",
        "gene_biotype",
        "transcript_support_level",
        "level",
        "tag",
    )
    scanned = sp.gtf.scan(str(gtf)).with_columns(
        *(sp.gtf.attr(attr_name) for attr_name in wanted_attrs)
    )
    return scanned.collect().rename(rename_map, strict=False)