Source code for eplace_lib.blast_analysis

"""
BLAST analysis module for sequence comparison.

This module provides functionality for running BLAST searches and filtering results
based on sequence identity and coverage criteria.
"""

import os
import re
import subprocess
import logging
from typing import Optional, Dict, Tuple
from pathlib import Path
from dataclasses import dataclass

# Configure module logger
logger = logging.getLogger(__name__)


[docs] @dataclass class BlastHit: """ Represents a single BLAST hit result. Attributes: query_id: Query sequence identifier subject_id: Subject (database) sequence identifier percent_identity: Percentage of identical matches alignment_length: Length of alignment query_length: Length of query sequence subject_length: Length of subject sequence query_start: Start position in query query_end: End position in query subject_start: Start position in subject subject_end: End position in subject evalue: Expectation value bit_score: Bit score query_coverage: Percentage of query covered by alignment subject_taxonomy: The subjects taxonomy information. A dictionary with rank as key and a tuple of (taxid, name) as value. """ query_id: str subject_id: str percent_identity: float alignment_length: int query_length: int subject_length: int query_start: int query_end: int subject_start: int subject_end: int evalue: float bit_score: float query_coverage: float subject_taxid: str subject_taxids: str subject_taxonomy: Optional[Dict[str, Tuple[str, str]]] = None
[docs] def get_subject_taxonomy(self, rank: str) -> Optional[tuple[str, str]]: """ Return the taxonomy information as a tuple of (taxid, name) for the given rank. If the rank is not found, return None. Args: rank: The rank to return the taxonomy information for. Returns: tuple of (taxid, name) for the given rank, or None if the rank is not found. """ if not self.subject_taxonomy: logger.warning("Taxonomy information is not available. Please parse it before calling this function.") return None if rank in self.subject_taxonomy: return self.subject_taxonomy[rank] else: return None
[docs] def get_accession(self) -> str: """ Extract the accession number from the subject_id. BLAST IDs can be in various formats: - gi|2273658778|gb|MZ387488.1| -> MZ387488.1 - ref|NZ_CP123456.1| -> NZ_CP123456.1 - gb|MZ387488.1| -> MZ387488.1 - MZ387488.1 -> MZ387488.1 (already in accession format) Note: gnl|database|identifier format is handled by returning the identifier, but these may not be standard accessions. Returns: The accession number extracted from subject_id, or the full subject_id if no standard format is detected """ return _extract_accession_from_pipe_id(self.subject_id)
def _extract_accession_from_pipe_id(seq_id: str) -> str: """ Extract the accession number from a pipe-delimited NCBI-style sequence ID. Handles formats such as: - gi|2273658778|gb|MZ387488.1| -> MZ387488.1 - ref|NZ_CP123456.1| -> NZ_CP123456.1 - gb|MZ387488.1| -> MZ387488.1 - gnl|BL_ORD_ID|12345 -> 12345 - MZ387488.1 -> MZ387488.1 (returned unchanged) - sampleA|42 -> sampleA|42 (no known pattern, returned unchanged) No generic fallback is applied: if the ID contains pipes but does not match a known NCBI format (gnl or a standard db-prefix like gb/ref/emb/…), the original string is returned unchanged. This prevents unrelated IDs that share the same trailing pipe-segment (e.g. ``sampleA|42`` and ``sampleB|42``) from being incorrectly considered equivalent. Args: seq_id: Sequence identifier, which may or may not be pipe-delimited. Returns: Extracted accession, or *seq_id* itself if no pipe-delimited structure is recognised. """ if '|' in seq_id: parts = seq_id.split('|') # gnl|database|identifier format if len(parts) >= 3 and parts[0] == 'gnl': return parts[2] # Known database prefixes: gi|123|gb|ACC.1|, ref|ACC.1|, gb|ACC.1|, etc. db_identifiers = ['gb', 'ref', 'emb', 'dbj', 'pdb', 'prf', 'sp', 'tr'] for i, part in enumerate(parts): if part in db_identifiers: if i + 1 < len(parts) and parts[i + 1]: return parts[i + 1] return seq_id
[docs] def normalize_sequence_id(seq_id: str) -> str: """ Normalize an arbitrary sequence or tree label to a canonical accession-like identifier. This is used to compare IDs from different sources (BLAST subject IDs, FASTA headers, tree leaf labels) that may be formatted differently but refer to the same sequence. Normalization steps: 1. Strip a leading '>' (FASTA header prefix). 2. Take only the first whitespace-delimited token. 3. Remove MAFFT reverse-complement markers: a leading '_R_' prefix or a trailing '_R_' suffix. 4. If the token contains pipes ('|'), extract the accession via _extract_accession_from_pipe_id() (gi|...|gb|ACC|, ref|ACC|, gb|ACC|, etc.). 5. Otherwise return the token unchanged. Args: seq_id: Raw sequence identifier from any source. Returns: Canonical accession string suitable for exact comparison. """ if not seq_id: return seq_id # Step 1: strip leading '>' token = seq_id.lstrip('>') # Step 2: take first whitespace-delimited token parts = token.split() token = parts[0] if parts else token # Step 3: remove MAFFT reverse-complement markers if token.startswith('_R_'): token = token[3:] if token.endswith('_R_'): token = token[:-3] # Step 4: parse pipe-delimited NCBI-style identifiers return _extract_accession_from_pipe_id(token)
[docs] class FastaReader: """ Class for reading FASTA files. """
[docs] @staticmethod def read_fasta(fasta_path: Path) -> dict[str, str]: """ Read sequences from a FASTA file. Args: fasta_path: Path to the FASTA file Returns: dictionary mapping sequence IDs to sequences Raises: FileNotFoundError: If FASTA file doesn't exist ValueError: If FASTA file is malformed """ if not fasta_path.exists(): raise FileNotFoundError(f"FASTA file not found: {fasta_path}") sequences = {} current_id = None current_seq = [] with open(fasta_path, 'r') as f: for line in f: line = line.strip() if not line: continue if line.startswith('>'): # Save previous sequence if exists if current_id is not None: sequences[current_id] = ''.join(current_seq) # Start new sequence # Remove '>' and take first word as ID current_id = line[1:].split()[0] current_seq = [] else: current_seq.append(line) # Save last sequence if current_id is not None: sequences[current_id] = ''.join(current_seq) if not sequences: raise ValueError(f"No sequences found in FASTA file: {fasta_path}") return sequences
[docs] @staticmethod def get_sequence_lengths(fasta_path: Path) -> dict[str, int]: """ Get the length of each sequence in a FASTA file. Args: fasta_path: Path to the FASTA file Returns: dictionary mapping sequence IDs to their lengths """ sequences = FastaReader.read_fasta(fasta_path) return {seq_id: len(seq) for seq_id, seq in sequences.items()}
[docs] class BlastRunner: """ Class for running BLAST searches and parsing results. """
[docs] def __init__(self, blastdb_path: Optional[Path] = None): """ Initialize the BlastRunner. Args: blastdb_path: Path to BLAST database directory. If None, uses BLASTDB env var. """ self.blastdb_path = blastdb_path if self.blastdb_path is None: blastdb_env = os.environ.get('BLASTDB') if blastdb_env: self.blastdb_path = Path(blastdb_env) else: self.blastdb_path = Path.home() / "blastdb"
[docs] def check_blastn_available(self) -> bool: """ Check if blastn is available in the system. Returns: True if blastn is available, False otherwise """ try: result = subprocess.run( ['blastn', '-version'], capture_output=True, text=True, timeout=5 ) return result.returncode == 0 except (subprocess.SubprocessError, FileNotFoundError): return False
[docs] def run_blastn( self, query_fasta: Path, output_file: Path, database: str = "core_nt", num_threads: int = 1, max_target_seqs: int = 100, evalue: float = 1e-5, outfmt: str = "6 qseqid sseqid pident length qlen slen qstart qend sstart send evalue bitscore staxid staxids" ) -> bool: """ Run blastn search. Args: query_fasta: Path to query FASTA file output_file: Path to output file database: Name of BLAST database (default: "core_nt") num_threads: Number of threads to use max_target_seqs: Maximum number of target sequences to report evalue: E-value threshold outfmt: Output format string Returns: True if BLAST ran successfully, False otherwise Raises: FileNotFoundError: If query file doesn't exist RuntimeError: If blastn is not available """ if not query_fasta.exists(): raise FileNotFoundError(f"Query FASTA file not found: {query_fasta}") if not self.check_blastn_available(): raise RuntimeError("blastn is not available. Please install BLAST+ tools.") # Build database path db_path = self.blastdb_path / database # Build blastn command cmd = [ 'blastn', '-query', str(query_fasta), '-db', str(db_path), '-out', str(output_file), '-outfmt', outfmt, '-num_threads', str(num_threads), '-max_target_seqs', str(max_target_seqs), '-evalue', str(evalue) ] logger.info(f"Running BLAST search: {' '.join(cmd)}") try: result = subprocess.run( cmd, capture_output=True, text=True, timeout=3600 # 1 hour timeout ) if result.returncode != 0: logger.error(f"BLAST failed with error: {result.stderr}") return False logger.info(f"BLAST search completed successfully. Output: {output_file}") return True except subprocess.TimeoutExpired: logger.error("BLAST search timed out") return False except Exception as e: logger.error(f"Error running BLAST: {e}") return False
[docs] def parse_blast_results( self, blast_output: Path, query_lengths: Optional[dict[str, int]] = None ) -> list[BlastHit]: """ Parse BLAST tabular output. Args: blast_output: Path to BLAST output file (tabular format) query_lengths: dictionary of query sequence lengths. If None, uses qlen from results. Returns: list of BlastHit objects Raises: FileNotFoundError: If BLAST output file doesn't exist ValueError: If BLAST output is malformed """ if not blast_output.exists(): raise FileNotFoundError(f"BLAST output file not found: {blast_output}") hits = [] with open(blast_output, 'r') as f: for line_num, line in enumerate(f, 1): line = line.strip() if not line or line.startswith('#'): continue fields = line.split('\t') if len(fields) < 12: raise ValueError( f"Invalid BLAST output format at line {line_num}: " f"expected at least 12 fields, got {len(fields)}" ) try: query_id = fields[0] subject_id = fields[1] percent_identity = float(fields[2]) alignment_length = int(fields[3]) query_length = int(fields[4]) subject_length = int(fields[5]) query_start = int(fields[6]) query_end = int(fields[7]) subject_start = int(fields[8]) subject_end = int(fields[9]) evalue = float(fields[10]) bit_score = float(fields[11]) staxid = fields[12] staxids = fields[13] # Calculate query coverage query_coverage = (abs(query_end - query_start) + 1) / query_length * 100 hit = BlastHit( query_id=query_id, subject_id=subject_id, percent_identity=percent_identity, alignment_length=alignment_length, query_length=query_length, subject_length=subject_length, query_start=query_start, query_end=query_end, subject_start=subject_start, subject_end=subject_end, evalue=evalue, bit_score=bit_score, query_coverage=query_coverage, subject_taxid=staxid, subject_taxids=staxids ) hits.append(hit) except (ValueError, IndexError) as e: raise ValueError( f"Error parsing BLAST output at line {line_num}: {e}" ) return hits
[docs] def filter_blast_hits( self, hits: list[BlastHit], min_identity: float = 90.0, min_coverage: float = 80.0, min_alignment_length: Optional[int] = None ) -> list[BlastHit]: """ Filter BLAST hits based on identity and coverage thresholds. Args: hits: list of BlastHit objects min_identity: Minimum percent identity (default: 90.0) min_coverage: Minimum query coverage percentage (default: 80.0) min_alignment_length: Minimum alignment length (optional) Returns: Filtered list of BlastHit objects """ filtered_hits = [] for hit in hits: if hit.percent_identity < min_identity: continue if hit.query_coverage < min_coverage: continue if min_alignment_length and hit.alignment_length < min_alignment_length: continue filtered_hits.append(hit) logger.info( f"Filtered {len(hits)} hits to {len(filtered_hits)} hits " f"(min_identity={min_identity}%, min_coverage={min_coverage}%)" ) return filtered_hits
[docs] class MMseqs2Runner: """ Class for running MMseqs2 searches and parsing results. MMseqs2 (Many-against-Many sequence searching) is an alternative to BLAST for sequence similarity searching, offering improved speed and sensitivity. Results are parsed into BlastHit objects for compatibility with the rest of the ePLACE pipeline. The target database can be either a pre-built MMseqs2 database (created with ``mmseqs createdb``) or a FASTA file that MMseqs2 indexes automatically. Taxonomy fields (taxid) are populated only when the database was built with taxonomy information (``mmseqs createtaxdb``); otherwise they default to "0". **Database selection**: To keep results comparable with the BLAST workflow (which uses NCBI ``core_nt``), the recommended MMseqs2 database should be built from the same underlying sequence collection as ``core_nt``. This means creating an MMseqs2 database from the FASTA sequences that make up NCBI ``core_nt`` (e.g. by exporting them with ``blastdbcmd -db core_nt -entry all`` and then running ``mmseqs createdb``). Using a different nucleotide collection will change the search space and may produce classification differences that reflect the database rather than the search algorithm. There is no official pre-built MMseqs2 ``core_nt`` database; users must provide their own. """
[docs] def __init__(self, db_path: Optional[Path] = None): """ Initialize the MMseqs2Runner. Args: db_path: Path to the MMseqs2 database directory. If None the ``MMSEQS_DB_DIR`` environment variable is used; if unset, ``MMSEQS2DB`` is used as a legacy fallback; if both are unset, the directory ``~/mmseqs2db`` is used. """ self.db_path = db_path if self.db_path is None: mmseqs_env = os.environ.get('MMSEQS_DB_DIR') or os.environ.get('MMSEQS2DB') if mmseqs_env: self.db_path = Path(mmseqs_env) else: self.db_path = Path.home() / "mmseqs2db"
[docs] def check_mmseqs_available(self) -> bool: """ Check if mmseqs is available in the system PATH. Returns: True if mmseqs is available, False otherwise """ try: result = subprocess.run( ['mmseqs', 'version'], capture_output=True, text=True, timeout=5 ) return result.returncode == 0 except (subprocess.SubprocessError, FileNotFoundError): return False
[docs] def parse_mmseqs_results( self, mmseqs_output: Path, query_lengths: Optional[dict[str, int]] = None ) -> list[BlastHit]: """ Parse MMseqs2 tabular output into BlastHit objects. Expects output generated with ``--format-output`` set to: ``query,target,pident,alnlen,qlen,tlen,qstart,qend,tstart,tend,evalue,bits,taxid,taxlineage`` The ``taxid`` and ``taxlineage`` columns are optional; if absent or set to "N/A" / "0", ``subject_taxid`` will be stored as "0". Args: mmseqs_output: Path to MMseqs2 output file query_lengths: Unused; kept for API compatibility with BlastRunner.parse_blast_results. Returns: list of BlastHit objects Raises: FileNotFoundError: If the output file doesn't exist ValueError: If the output is malformed """ if not mmseqs_output.exists(): raise FileNotFoundError(f"MMseqs2 output file not found: {mmseqs_output}") hits = [] with open(mmseqs_output, 'r') as f: for line_num, line in enumerate(f, 1): line = line.strip() if not line or line.startswith('#'): continue fields = line.split('\t') if len(fields) < 12: raise ValueError( f"Invalid MMseqs2 output format at line {line_num}: " f"expected at least 12 fields, got {len(fields)}" ) try: query_id = fields[0] subject_id = fields[1] percent_identity = float(fields[2]) alignment_length = int(fields[3]) query_length = int(fields[4]) subject_length = int(fields[5]) query_start = int(fields[6]) query_end = int(fields[7]) subject_start = int(fields[8]) subject_end = int(fields[9]) evalue = float(fields[10]) bit_score = float(fields[11]) # taxid field (may be absent or "N/A" when database has no taxonomy) taxid = "0" if len(fields) > 12 and fields[12] not in ("", "N/A", "0"): taxid = fields[12] # Calculate query coverage query_coverage = (abs(query_end - query_start) + 1) / query_length * 100 hit = BlastHit( query_id=query_id, subject_id=subject_id, percent_identity=percent_identity, alignment_length=alignment_length, query_length=query_length, subject_length=subject_length, query_start=query_start, query_end=query_end, subject_start=subject_start, subject_end=subject_end, evalue=evalue, bit_score=bit_score, query_coverage=query_coverage, subject_taxid=taxid, subject_taxids=taxid ) hits.append(hit) except (ValueError, IndexError) as e: raise ValueError( f"Error parsing MMseqs2 output at line {line_num}: {e}" ) return hits
[docs] def filter_hits( self, hits: list[BlastHit], min_identity: float = 90.0, min_coverage: float = 80.0, min_alignment_length: Optional[int] = None ) -> list[BlastHit]: """ Filter MMseqs2 hits based on identity and coverage thresholds. Args: hits: list of BlastHit objects min_identity: Minimum percent identity (default: 90.0) min_coverage: Minimum query coverage percentage (default: 80.0) min_alignment_length: Minimum alignment length (optional) Returns: Filtered list of BlastHit objects """ filtered_hits = [] for hit in hits: if hit.percent_identity < min_identity: continue if hit.query_coverage < min_coverage: continue if min_alignment_length and hit.alignment_length < min_alignment_length: continue filtered_hits.append(hit) logger.info( f"Filtered {len(hits)} hits to {len(filtered_hits)} hits " f"(min_identity={min_identity}%, min_coverage={min_coverage}%)" ) return filtered_hits
[docs] def validate_mmseqs_memory_limit(value: str) -> str: """Validate a MMseqs2-style memory limit string. Accepts a positive integer (no leading zeros) followed by a single unit suffix ``K``, ``M``, ``G``, or ``T`` (case-sensitive, no space). Examples of valid values:: 64G 128G 400G 1T 512M Args: value: The memory limit string to validate. Returns: The validated string unchanged. Raises: ValueError: If the string is empty, missing units, has an invalid unit suffix, or is otherwise malformed. """ if not value: raise ValueError("MMseqs2 memory limit must not be empty.") if not re.fullmatch(r'[1-9][0-9]*[KMGT]', value): raise ValueError( f"Invalid MMseqs2 memory limit: '{value}'. " "Expected a positive integer (no leading zeros) followed by a unit (K, M, G, or T), " "e.g. '400G' or '1T'." ) return value