# Source code for eplace_lib.taxonomy

"""
Taxonomy extraction and sequence retrieval module.

This module provides functionality for extracting taxonomic information from BLAST results,
selecting representative sequences per taxonomic rank, and extracting sequences from databases.
"""

import os
import re
import subprocess
import logging
import sys
from pathlib import Path
from typing import Optional, List, Dict
from collections import defaultdict

from .blast_analysis import BlastHit, normalize_sequence_id

import pytaxonkit


def _subject_id_matches(subject_id: str, target_id: str) -> bool:
    """Return True if *subject_id* refers to the same sequence as *target_id*.

    Exact equality is checked first so that non-NCBI pipe-delimited labels
    (e.g. ``sampleA|42``) are never conflated with an unrelated sequence that
    happens to share the same trailing segment.  Normalized comparison is used
    only as a fallback to handle cases where the same accession appears in
    different formats (e.g. ``gi|...|gb|HQ641676.1|`` vs ``HQ641676.1``, or a
    MAFFT ``_R_`` reverse-complement marker).
    """
    if subject_id == target_id:
        return True
    return normalize_sequence_id(subject_id) == normalize_sequence_id(target_id)

# Logger named after this module.
logger = logging.getLogger(__name__)

# Taxonomic ranks this module works with, broadest first.  The order is
# significant: lineage strings in the classification summary are built by
# walking this list.
VALID_RANKS = [
    'domain',
    'phylum',
    'class',
    'order',
    'family',
    'genus',
    'species',
]


class TaxonomyExtractor:
    """
    Resolve BLAST taxids to lineages and pick representative hits per rank.
    """

    def parse_taxids(self, tax_ids: list[str]) -> dict[str, dict[str, tuple[str, str]]]:
        """
        Parse taxonomic information from the taxonomy IDs from the BLAST hits.

        The full lineage of every distinct taxid is fetched with
        ``pytaxonkit.lineage`` and reduced to the ranks in ``VALID_RANKS``.

        Args:
            tax_ids: the taxonomy IDs reported by BLAST. Duplicates, ``None``
                and blank entries are ignored.

        Returns:
            dictionary keyed by ``str(taxid)`` whose values map each rank to a
            ``(rank_taxid, rank_name)`` tuple. An empty dictionary is returned
            when there is nothing to look up.

        Note:
            If ``pytaxonkit.lineage`` raises, the error is logged and the
            process exits with status 1.
        """
        # Remove duplicates and unusable entries before the lookup; with
        # nothing left there is no reason to invoke taxonkit at all (a failed
        # call would terminate the process below).
        tax_ids = list({t for t in tax_ids if t is not None and str(t).strip()})
        if not tax_ids:
            return {}

        taxonomy_dict = {}
        # we need to get the whole lineage, and then convert it to a dict
        try:
            df = pytaxonkit.lineage(tax_ids)
        except Exception:
            logger.exception("Error retrieving taxonomic lineages")
            sys.exit(1)

        # The three Full* columns are ';'-separated lists; exploding them
        # together yields one row per (queried taxid, lineage level).
        df['names'] = df['FullLineage'].str.split(';')
        df['taxids'] = df['FullLineageTaxIDs'].str.split(';')
        df['ranks'] = df['FullLineageRanks'].str.split(';')
        long_df = df.explode(['names', 'taxids', 'ranks'])
        filtered = long_df[long_df['ranks'].isin(VALID_RANKS)]

        for tid, rank, taxid, name in (
            filtered[['TaxID', 'ranks', 'taxids', 'names']]
            .drop_duplicates()
            .itertuples(index=False, name=None)
        ):
            taxonomy_dict.setdefault(str(tid), {})[rank] = (str(taxid), name)
        return taxonomy_dict

    def group_hits_by_query(
        self,
        hits: list[BlastHit]
    ) -> dict[str, list[BlastHit]]:
        """
        Group BLAST hits by query sequence.

        Args:
            hits: list of BlastHit objects

        Returns:
            dictionary mapping query IDs to lists of hits, in input order
        """
        grouped = defaultdict(list)
        for hit in hits:
            grouped[hit.query_id].append(hit)
        return dict(grouped)

    def select_representatives_by_rank(
        self,
        hits: list[BlastHit],
        rank: str,
        max_per_rank: int = 1,
        preferred_representatives: Optional[Dict[str, str]] = None
    ) -> list[BlastHit]:
        """
        Select representative sequences per taxonomic rank.

        Hits are grouped by their taxon name at *rank*. Hits with no taxonomy
        at all are grouped by their own ``subject_id`` so they still
        contribute a representative. Hits whose taxonomy lacks *rank* (or has
        an empty entry for it) are skipped.

        Args:
            hits: list of BlastHit objects for a single query
            rank: Taxonomic rank for representative selection
            max_per_rank: Maximum number of representatives per group when a
                new selection is made (default: 1). When a preferred
                representative is reused, only that single hit is returned
                for the group.
            preferred_representatives: Optional dictionary mapping a group key
                (the taxon name at *rank*, or a subject_id for hits without
                taxonomy) to a preferred subject_id, so the same sequence is
                reused across queries

        Returns:
            list of representative BlastHit objects
        """
        if preferred_representatives is None:
            preferred_representatives = {}

        # NOTE(review): taxon names and fallback subject_ids share one key
        # namespace here and in preferred_representatives; a subject_id equal
        # to a taxon name would be merged into that taxon's group.
        rank_groups = defaultdict(list)
        reported_hits = set()
        for hit in hits:
            if not hit.subject_taxonomy:
                # No taxonomy available (e.g. MMseqs2 database without taxonomy).
                # Fall back to grouping by subject_id so the hit still contributes
                # a representative rather than being silently dropped.
                logger.info(
                    f"No taxonomy for hit {hit.subject_id} (query {hit.query_id}); "
                    f"using subject_id as fallback group key"
                )
                rank_groups[hit.subject_id].append(hit)
                continue

            if rank not in hit.subject_taxonomy:
                logger.info(
                    f"We did not find {rank} in the taxonomy of {hit.query_id} which has subject taxid of {hit.subject_taxid}")
                continue

            entry = hit.subject_taxonomy[rank]
            if not entry:
                logger.warning(
                    f"Hit {hit.subject_id} for query {hit.query_id} has no taxonomic information at rank {rank}")
                continue

            if not isinstance(entry, tuple):
                logger.warning(f"Not really sure what {entry} of type {type(entry)} is supposed to be")
                continue

            # entry is (taxid, name); log the first time each name is seen.
            if entry[1] not in reported_hits:
                logger.info(f"Found a hit for {hit.query_id} at rank {rank}: {entry[1]} ({entry[0]})")
                reported_hits.add(entry[1])
            rank_groups[entry[1]].append(hit)

        # Select best representative(s) from each group
        representatives = []
        for rank_key, rank_hits in rank_groups.items():
            preferred_subject_id = preferred_representatives.get(rank_key)
            if preferred_subject_id:
                # Exact match first, normalized comparison as fallback (handles
                # NCBI format differences such as gi|...|gb|ACC| vs ACC).
                preferred_hit = next(
                    (hit for hit in rank_hits
                     if _subject_id_matches(hit.subject_id, preferred_subject_id)),
                    None
                )
                if preferred_hit:
                    logger.info(f"Reusing previously selected representative {preferred_subject_id} for rank {rank_key}")
                    representatives.append(preferred_hit)
                    continue

            # No usable preferred representative: best bit score first.
            rank_hits.sort(key=lambda h: h.bit_score, reverse=True)
            representatives.extend(rank_hits[:max_per_rank])

        logger.info(
            f"Selected {len(representatives)} representative sequences from {len(hits)} hits at rank '{rank}'"
        )
        return representatives


class SequenceExtractor:
    """
    Class for extracting sequences from BLAST databases with ``blastdbcmd``.
    """

    def __init__(self, blastdb_path: Optional[Path] = None):
        """
        Initialize the SequenceExtractor.

        Args:
            blastdb_path: Path to BLAST database directory. If None, the
                ``BLASTDB`` environment variable is used, falling back to
                ``~/blastdb`` when that is unset or empty.
        """
        self.blastdb_path = blastdb_path
        if self.blastdb_path is None:
            blastdb_env = os.environ.get('BLASTDB')
            if blastdb_env:
                self.blastdb_path = Path(blastdb_env)
            else:
                self.blastdb_path = Path.home() / "blastdb"

    def check_blastdbcmd_available(self) -> bool:
        """
        Check if blastdbcmd is available in the system.

        Returns:
            True if ``blastdbcmd -version`` runs and exits with status 0,
            False otherwise (missing or non-executable binary, timeout, or
            non-zero exit status).
        """
        try:
            result = subprocess.run(
                ['blastdbcmd', '-version'],
                capture_output=True,
                text=True,
                timeout=5
            )
        except (subprocess.SubprocessError, OSError):
            # OSError covers FileNotFoundError as well as e.g. PermissionError
            # when a non-executable blastdbcmd is found on PATH.
            return False
        return result.returncode == 0

    def extract_sequences(
        self,
        sequence_ids: list[str],
        output_fasta: Path,
        database: str = "core_nt"
    ) -> bool:
        """
        Extract sequences from BLAST database using blastdbcmd.

        Repeated IDs are requested only once (first occurrence order is kept).

        Args:
            sequence_ids: list of sequence IDs to extract
            output_fasta: Path to output FASTA file
            database: Name of BLAST database (default: "core_nt"), resolved
                relative to ``self.blastdb_path``

        Returns:
            True if extraction was successful, False otherwise

        Raises:
            RuntimeError: If blastdbcmd is not available
        """
        if not self.check_blastdbcmd_available():
            raise RuntimeError("blastdbcmd is not available. Please install BLAST+ tools.")

        if not sequence_ids:
            logger.warning("No sequence IDs provided for extraction")
            return False

        # Request each ID once; repeated IDs would otherwise risk duplicate
        # FASTA records, which downstream alignment/tree tools tend to reject.
        unique_ids = list(dict.fromkeys(sequence_ids))

        db_path = self.blastdb_path / database

        # Temporary batch file with one ID per line, removed in `finally`.
        id_file = output_fasta.parent / f"{output_fasta.stem}_ids.txt"

        try:
            with open(id_file, 'w') as f:
                f.writelines(f"{seq_id}\n" for seq_id in unique_ids)

            cmd = [
                'blastdbcmd',
                '-db', str(db_path),
                '-entry_batch', str(id_file),
                '-out', str(output_fasta)
            ]

            logger.info(f"Extracting {len(unique_ids)} sequences from database")

            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=600  # 10 minute timeout
            )

            if result.returncode != 0:
                logger.error(f"blastdbcmd failed with error: {result.stderr}")
                return False

            logger.info(f"Sequences extracted successfully to {output_fasta}")
            return True

        except subprocess.TimeoutExpired:
            logger.error("Sequence extraction timed out")
            return False
        except Exception as e:
            logger.error(f"Error extracting sequences (taxonomy): {e}")
            return False
        finally:
            # Clean up temporary ID file
            if id_file.exists():
                id_file.unlink()

    def extract_representatives_for_query(
        self,
        query_id: str,
        representative_hits: list[BlastHit],
        output_dir: Path,
        database: str = "core_nt"
    ) -> Optional[Path]:
        """
        Extract representative sequences for a single query to a FASTA file.

        The file is written to
        ``output_dir / "<query_id>_representatives.fasta"``, with ``|`` and
        ``/`` in the query ID replaced by ``_``.

        Args:
            query_id: Query sequence identifier
            representative_hits: list of representative BlastHit objects
            output_dir: Output directory for FASTA files (created if missing)
            database: Name of BLAST database

        Returns:
            Path to output FASTA file if successful, None otherwise
        """
        if not representative_hits:
            logger.warning(f"No representative hits for query {query_id}")
            return None

        output_dir.mkdir(parents=True, exist_ok=True)

        safe_query_id = query_id.replace('|', '_').replace('/', '_')
        output_fasta = output_dir / f"{safe_query_id}_representatives.fasta"

        sequence_ids = [hit.subject_id for hit in representative_hits]

        success = self.extract_sequences(
            sequence_ids=sequence_ids,
            output_fasta=output_fasta,
            database=database
        )
        return output_fasta if success else None


def rewrite_blast_hits(
        blast_hits: List[BlastHit],
        output_file: Path,
        header: bool = True) -> bool:
    """
    Write annotated BLAST hits back out as a tab-separated table.

    One line per hit, columns in a fixed order; attributes that are ``None``
    become empty cells and everything else is written via ``str()``.

    Args:
        blast_hits: list of BlastHit objects
        output_file: the file to write to (overwritten)
        header: whether to emit a first line holding the column names

    Returns:
        True on success
    """
    columns = (
        "query_id", "subject_id", "percent_identity", "alignment_length",
        "query_length", "subject_length", "query_start", "query_end",
        "subject_start", "subject_end", "evalue", "bit_score",
        "query_coverage", "subject_taxid", "subject_taxids", "subject_taxonomy",
    )

    def _cell(hit, column):
        # Render one attribute; missing values become an empty field.
        value = getattr(hit, column)
        return "" if value is None else str(value)

    with open(output_file, 'w') as handle:
        if header:
            handle.write("\t".join(columns) + "\n")
        for hit in blast_hits:
            handle.write("\t".join(_cell(hit, column) for column in columns) + "\n")
    return True


def process_blast_results_for_taxonomy(
        blast_hits: List[BlastHit],
        output_dir: Path,
        rank: str = "genus",
        database: str = "core_nt",
        blastdb_path: Optional[Path] = None
) -> Dict[str, Optional[Path]]:
    """
    Process BLAST hits to extract representative sequences per taxonomic rank.

    Side effect: ``subject_taxonomy`` is (re)assigned on every hit in
    *blast_hits*; it is ``None`` when the hit's taxid could not be resolved.

    Args:
        blast_hits: list of BlastHit objects
        output_dir: Output directory; each query gets its own subdirectory
        rank: Taxonomic rank for representative selection
        database: Name of BLAST database
        blastdb_path: Path to BLAST database directory

    Returns:
        dictionary mapping query IDs to output FASTA file paths (``None`` when
        extraction failed). Queries for which no representative could be
        selected are omitted.

    Raises:
        ValueError: if *rank* is not in ``VALID_RANKS``
    """
    if rank not in VALID_RANKS:
        raise ValueError(f"Rank: {rank} is not a valid rank. It must be one of: {VALID_RANKS}")

    tax_extractor = TaxonomyExtractor()
    seq_extractor = SequenceExtractor(blastdb_path)

    # Resolve every distinct taxid once. Hits without a taxid cannot be
    # resolved, and when nothing is left to resolve taxonkit is not invoked.
    subject_taxids = {
        hit.subject_taxid for hit in blast_hits
        if hit.subject_taxid not in (None, '')
    }
    tax_dict = tax_extractor.parse_taxids(list(subject_taxids)) if subject_taxids else {}

    # parse_taxids keys its result by str(taxid); normalise the lookup key so
    # taxids stored as ints are matched too.
    for h in blast_hits:
        if h.subject_taxid in (None, ''):
            h.subject_taxonomy = None
        else:
            h.subject_taxonomy = tax_dict.get(str(h.subject_taxid))

    grouped_hits = tax_extractor.group_hits_by_query(blast_hits)

    # Track selected representatives across queries to ensure consistency.
    # Maps the taxon name at `rank` -> subject_id of the chosen representative.
    preferred_representatives = {}

    results = {}
    for query_id, query_hits in grouped_hits.items():
        logger.info(f"Processing query {query_id} with {len(query_hits)} hits")

        # Select representatives, preferring previously selected ones
        representatives = tax_extractor.select_representatives_by_rank(
            hits=query_hits,
            rank=rank,
            preferred_representatives=preferred_representatives
        )

        if len(representatives) == 0:
            logger.warning(f"Error: No representative sequences for {query_id} at rank {rank}")
            continue

        # Record newly chosen representatives; the first choice for a taxon wins.
        for rep in representatives:
            if (
                rep.subject_taxonomy
                and rank in rep.subject_taxonomy
                and isinstance(rep.subject_taxonomy[rank], tuple)
                and rep.subject_taxonomy[rank][1] not in preferred_representatives
            ):
                preferred_representatives[rep.subject_taxonomy[rank][1]] = rep.subject_id
                logger.info(f"Recording {rep.subject_id} as representative for rank {rep.subject_taxonomy[rank][1]}")

        query_output_dir = output_dir / query_id.replace('|', '_').replace('/', '_')

        output_fasta = seq_extractor.extract_representatives_for_query(
            query_id=query_id,
            representative_hits=representatives,
            output_dir=query_output_dir,
            database=database
        )

        results[query_id] = output_fasta

    return results


def sort_strings_and_numbers(s: str):
    """
    Build a natural-sort key from a leading text prefix and the number after it.

    Args:
        s: string to derive the key from

    Returns:
        ``(prefix, number)`` when *s* starts with one or more non-digits
        followed by digits (anything after those digits is ignored), e.g.
        ``"seq12"`` -> ``("seq", 12)``; otherwise ``(s, 0)``.
    """
    found = re.match(r'(\D+)(\d+)', s)
    if found is None:
        return (s, 0)
    prefix, digits = found.groups()
    return (prefix, int(digits))


def _summary_lineage(taxonomy: dict) -> str:
    """Join the taxon names of *taxonomy* in ``VALID_RANKS`` order; missing ranks stay empty."""
    return ';'.join(taxonomy[r][1] if r in taxonomy else "" for r in VALID_RANKS)


def _summary_fill_ranks(classification: dict, prefix: str, taxonomy, rank_slots) -> list:
    """Fill ``<prefix>_<slot>_taxid``/``_name`` from *taxonomy*; return the ranks that were missing."""
    missing = []
    for slot, slot_rank in rank_slots:
        if taxonomy and slot_rank in taxonomy:
            taxid, name = taxonomy[slot_rank]
            classification[f'{prefix}_{slot}_taxid'] = taxid
            classification[f'{prefix}_{slot}_name'] = name
        else:
            missing.append(slot_rank)
    return missing


def _summary_tree_hit(query_id: str, query_hits: list, tree_files):
    """Return the BLAST hit matching *query_id*'s nearest tree neighbour, or None."""
    if not tree_files or query_id not in tree_files:
        return None
    tree_file = tree_files[query_id]
    if not tree_file or not tree_file.exists():
        return None

    # Import here to avoid circular dependency
    from .alignment import find_nearest_neighbor_in_tree

    nearest_neighbor = find_nearest_neighbor_in_tree(tree_file, query_id)
    if not nearest_neighbor:
        return None

    # Exact match first, normalized comparison as fallback (NCBI format
    # differences and MAFFT _R_ markers).
    tree_best_hit = next(
        (hit for hit in query_hits if _subject_id_matches(hit.subject_id, nearest_neighbor)),
        None
    )
    if tree_best_hit is not None:
        logger.info(f"Tree-based nearest neighbor for {query_id}: {nearest_neighbor}")
    else:
        logger.debug(f"Tree nearest neighbor {nearest_neighbor} not found in BLAST hits for {query_id}")
    return tree_best_hit


def _summary_group_pairs(query_hits: list, group_rank: str) -> list:
    """Return the distinct ``(name, taxid)`` pairs at *group_rank*, sorted by name then taxid."""
    pairs = set()
    for hit in query_hits:
        if hit.subject_taxonomy and group_rank in hit.subject_taxonomy:
            taxid, name = hit.subject_taxonomy[group_rank]
            pairs.add((name, str(taxid)))
    return sorted(pairs)


def generate_classification_summary(
        sequences: dict[str, str],
        blast_hits: List[BlastHit],
        output_file: Path,
        rank: str = "genus",
        group_rank: str = "class",
        tree_label_rank: str = "genus",
        tree_files: Optional[dict[str, Path]] = None
) -> bool:
    """
    Generate a classification summary TSV file for each query sequence.

    One row is written per key of *sequences* (queries without hits included),
    in natural-sort order of the query ID. Each row reports:
    - Query sequence ID and number of BLAST hits
    - The ';'-joined lineage (one field per rank in ``VALID_RANKS``) of the
      best BLAST hit and, when available, of the tree nearest neighbour
    - The taxid/name at the classification rank (--rank), the grouping rank
      (--group-rank) and the tree labeling rank (--tree-label-rank), both
      BLAST-based and tree-based
    - Whether the hits span multiple groups at the grouping rank (in which
      case the BLAST group columns list every group, names and taxids in
      matching order)
    - Whether the BLAST-based classification is complete ('Yes'), partial
      ('Partial') or absent ('No')

    The BLAST-based columns use the hit with the highest bit score. The
    tree-based columns use the BLAST hit matching the phylogenetically
    nearest neighbour in the query's tree, if one is found.

    Args:
        sequences: dictionary of sequences that we read from the fasta file
        blast_hits: List of BlastHit objects with taxonomy information
        output_file: Path to output TSV file
        rank: Taxonomic rank for classification (default: genus)
        group_rank: Taxonomic rank for grouping (default: class)
        tree_label_rank: Taxonomic rank for tree labeling (default: genus)
        tree_files: Optional dict mapping query_id to tree file paths for finding nearest neighbors

    Returns:
        True if successful, False otherwise (invalid rank or write failure)
    """
    logger.info(f"Generating classification summary TSV to {output_file}")

    # Validate ranks
    for r, r_name in [(rank, 'rank'), (group_rank, 'group_rank'), (tree_label_rank, 'tree_label_rank')]:
        if r not in VALID_RANKS:
            logger.error(f"{r_name}: {r} is not a valid rank. It must be one of: {VALID_RANKS}")
            return False

    # (column slot, rank) pairs shared by the BLAST- and tree-based columns.
    rank_slots = [('classification', rank), ('group', group_rank), ('tree_label', tree_label_rank)]
    # Placeholder with one empty field per rank, so the lineage columns always
    # have the same number of ';'-separated fields as a real lineage string.
    empty_lineage = ';' * (len(VALID_RANKS) - 1)

    query_hits_map = defaultdict(list)
    for hit in blast_hits:
        query_hits_map[hit.query_id].append(hit)

    # Natural sort; the full ID breaks ties (e.g. "seq1a" vs "seq1b" share
    # the key ("seq", 1)) so row order does not depend on set iteration order.
    ordered_query_ids = sorted(set(sequences), key=lambda q: (sort_strings_and_numbers(q), q))

    summary_data = []
    for query_id in ordered_query_ids:
        query_hits = query_hits_map.get(query_id, [])

        classification = {
            'query_id': query_id,
            'blast_hits': 0,
            'taxonomy_blast': empty_lineage,
            'blast_classification_rank': rank,
            'blast_classification_taxid': 'N/A',
            'blast_classification_name': 'N/A',
            'blast_group_rank': group_rank,
            'blast_group_taxid': 'N/A',
            'blast_group_name': 'N/A',
            'blast_tree_label_rank': tree_label_rank,
            'blast_tree_label_taxid': 'N/A',
            'blast_tree_label_name': 'N/A',
            'taxonomy_tree': empty_lineage,
            'tree_classification_rank': rank,
            'tree_classification_taxid': 'N/A',
            'tree_classification_name': 'N/A',
            'tree_group_rank': group_rank,
            'tree_group_taxid': 'N/A',
            'tree_group_name': 'N/A',
            'tree_tree_label_rank': tree_label_rank,
            'tree_tree_label_taxid': 'N/A',
            'tree_tree_label_name': 'N/A',
            'tree_based_classification': 'No',
            'appears_in_multiple_groups': 'No',
            'has_classification': 'Yes'
        }

        if not query_hits:
            classification['has_classification'] = 'No'
            summary_data.append(classification)
            continue

        classification['blast_hits'] = len(query_hits)

        # BLAST-based classification from the highest bit score hit
        blast_best_hit = max(query_hits, key=lambda h: h.bit_score)
        if blast_best_hit.subject_taxonomy:
            classification['taxonomy_blast'] = _summary_lineage(blast_best_hit.subject_taxonomy)
        blast_missing_ranks = _summary_fill_ranks(
            classification, 'blast', blast_best_hit.subject_taxonomy, rank_slots
        )

        # Tree-based classification from the nearest neighbour, if any
        tree_best_hit = _summary_tree_hit(query_id, query_hits, tree_files)
        if tree_best_hit is not None:
            classification['tree_based_classification'] = 'Yes'
            if tree_best_hit.subject_taxonomy:
                classification['taxonomy_tree'] = _summary_lineage(tree_best_hit.subject_taxonomy)
            _summary_fill_ranks(classification, 'tree', tree_best_hit.subject_taxonomy, rank_slots)

        # has_classification reflects the BLAST-based columns only
        if blast_missing_ranks:
            if len(blast_missing_ranks) == len(rank_slots):
                classification['has_classification'] = 'No'
            else:
                classification['has_classification'] = 'Partial'

        # Multiple groups at group_rank: list all of them, keeping each taxid
        # aligned with its name.
        group_pairs = _summary_group_pairs(query_hits, group_rank)
        if len({name for name, _ in group_pairs}) > 1:
            classification['appears_in_multiple_groups'] = 'Yes'
            classification['blast_group_name'] = '; '.join(name for name, _ in group_pairs)
            classification['blast_group_taxid'] = '; '.join(taxid for _, taxid in group_pairs)

        summary_data.append(classification)

    headers = [
        'query_id', 'blast_hits', 'taxonomy_blast',
        'blast_classification_rank', 'blast_classification_taxid', 'blast_classification_name',
        'blast_group_rank', 'blast_group_taxid', 'blast_group_name',
        'blast_tree_label_rank', 'blast_tree_label_taxid', 'blast_tree_label_name',
        'tree_based_classification', 'taxonomy_tree',
        'tree_classification_rank', 'tree_classification_taxid', 'tree_classification_name',
        'tree_group_rank', 'tree_group_taxid', 'tree_group_name',
        'tree_tree_label_rank', 'tree_tree_label_taxid', 'tree_tree_label_name',
        'appears_in_multiple_groups', 'has_classification'
    ]

    try:
        with open(output_file, 'w') as f:
            f.write('\t'.join(headers) + '\n')
            for entry in summary_data:
                # str() every cell so a non-string taxid cannot abort the write
                f.write('\t'.join(str(entry[h]) for h in headers) + '\n')

        logger.info(f"Successfully wrote classification summary for {len(summary_data)} queries to {output_file}")
        return True

    except Exception as e:
        logger.error(f"Error writing classification summary TSV: {e}")
        return False