"""
Sequence alignment and phylogenetic tree building module.
This module provides functionality for trimming sequences based on BLAST alignments,
aligning sequences using MAFFT, and building phylogenetic trees using IQTree.
"""
import os
import sys
import subprocess
import logging
from pathlib import Path
from typing import Optional, List, Dict, Tuple
from collections import defaultdict
from .blast_analysis import BlastHit, FastaReader
from .taxonomy import SequenceExtractor
# Configure module logger
logger = logging.getLogger(__name__)
[docs]
class SequenceTrimmer:
    """
    Class for trimming sequences based on BLAST alignment coordinates.

    All members are static; the class only acts as a namespace.
    """

    # Number of residues written per line of FASTA output.
    _FASTA_LINE_WIDTH = 60

    @staticmethod
    def trim_sequence_by_coordinates(
        sequence: str,
        start: int,
        end: int
    ) -> str:
        """
        Extract the region of ``sequence`` delimited by two BLAST coordinates.

        BLAST coordinates are 1-indexed and inclusive, and a reverse-strand hit
        is reported with ``start > end``. The coordinates are put in ascending
        order, converted to a 0-indexed half-open Python slice and clamped to
        the bounds of the sequence. The region is returned exactly as stored;
        it is not reverse-complemented here.

        Args:
            sequence: The full sequence string
            start: Start position (1-indexed, inclusive)
            end: End position (1-indexed, inclusive)

        Returns:
            Trimmed sequence string (empty if the region lies outside the sequence)
        """
        low, high = (end, start) if start > end else (start, end)
        slice_start = max(0, low - 1)
        slice_end = min(len(sequence), high)
        return sequence[slice_start:slice_end]

    @staticmethod
    def _best_hits_by_accession(blast_hits: List[BlastHit]) -> Dict[str, BlastHit]:
        """
        Map each subject accession to its highest-scoring hit.

        When several hits share an accession the one with the largest
        ``bit_score`` is kept (the first one wins on ties), which is the same
        rule trim_grouped_sequences applies.
        """
        best: Dict[str, BlastHit] = {}
        for hit in blast_hits:
            accession = hit.get_accession()
            current = best.get(accession)
            if current is None or hit.bit_score > current.bit_score:
                best[accession] = hit
        return best

    @staticmethod
    def _write_fasta_record(out, header: str, sequence: str) -> None:
        """Write one FASTA record to the open text handle ``out``, wrapping the sequence."""
        width = SequenceTrimmer._FASTA_LINE_WIDTH
        out.write(f">{header}\n")
        for offset in range(0, len(sequence), width):
            out.write(sequence[offset:offset + width] + "\n")

    @staticmethod
    def trim_sequences_from_blast_hits(
        fasta_path: Path,
        blast_hits: List[BlastHit],
        output_fasta: Path,
        query_id: str,
        taxonomic_rank: str
    ) -> bool:
        """
        Trim sequences in a FASTA file based on BLAST hit coordinates.

        The query record (if present) is written first and untrimmed. Every
        other record is matched to a BLAST hit by accession, cut down to the
        hit's subject coordinates and written with the taxonomic name appended
        to its header when the hit carries one for ``taxonomic_rank``. Records
        without a matching hit are skipped with a warning.

        Args:
            fasta_path: Path to input FASTA file with full-length sequences
            blast_hits: List of BlastHit objects for this query
            output_fasta: Path to output FASTA file with trimmed sequences
            query_id: The query sequence ID to include in output
            taxonomic_rank: the taxonomic rank to use for taxonomic labels (e.g., "genus")

        Returns:
            True if successful, False if any exception occurred (it is logged)
        """
        try:
            # read_fasta is used here as a mapping of record ID -> sequence string
            sequences = FastaReader.read_fasta(fasta_path)

            # The FASTA file will have accessions (e.g., MZ387488.1) not full IDs.
            # If an accession has several hits, trim with the best-scoring one.
            hit_map = SequenceTrimmer._best_hits_by_accession(blast_hits)

            with open(output_fasta, 'w') as out:
                # First, write the query sequence if it exists
                if query_id in sequences:
                    query_seq = sequences[query_id]
                    SequenceTrimmer._write_fasta_record(out, query_id, query_seq)
                    logger.info(f"Added query sequence {query_id} ({len(query_seq)} bp)")
                else:
                    logger.warning(f"Query sequence {query_id} not found in {fasta_path}")

                # Now process subject sequences
                for seq_id, sequence in sequences.items():
                    if seq_id == query_id:
                        continue  # Skip query, already written

                    hit = hit_map.get(seq_id)
                    if hit is None:
                        logger.warning(f"No BLAST hit found for sequence {seq_id}, skipping")
                        continue

                    trimmed_seq = SequenceTrimmer.trim_sequence_by_coordinates(
                        sequence,
                        hit.subject_start,
                        hit.subject_end
                    )

                    # subject_taxonomy[rank] is indexed with [1] to obtain the name
                    header = seq_id
                    if isinstance(hit.subject_taxonomy, dict) and taxonomic_rank in hit.subject_taxonomy:
                        header = f"{seq_id} {hit.subject_taxonomy[taxonomic_rank][1]}"

                    SequenceTrimmer._write_fasta_record(out, header, trimmed_seq)
                    logger.info(
                        f"Trimmed {seq_id} from {len(sequence)} bp to {len(trimmed_seq)} bp "
                        f"(coords: {hit.subject_start}-{hit.subject_end})"
                    )
            return True
        except Exception as e:
            logger.error(f"Error trimming sequences: {e}")
            return False
[docs]
class MAFFTAligner:
    """
    Class for running MAFFT sequence alignments.

    All members are static; the class only acts as a namespace.
    """

    # Extra command-line flags selected by each supported strategy name.
    _STRATEGY_FLAGS = {
        'default': [],
        'auto': ['--auto'],
        'retree2': ['--retree', '2'],
        'fftns': ['--retree', '1'],
    }

    @staticmethod
    def check_mafft_available() -> bool:
        """
        Check if MAFFT is available in the system.

        Runs ``mafft --version`` with a 5 second timeout.

        Returns:
            True if the command ran and exited with status 0, False otherwise
        """
        try:
            result = subprocess.run(
                ['mafft', '--version'],
                capture_output=True,
                text=True,
                timeout=5
            )
        except (subprocess.SubprocessError, OSError):
            # OSError covers a missing executable as well as a non-executable one
            return False

        if result.returncode == 0:
            # Some MAFFT builds print the version banner on stderr rather than
            # stdout, so fall back to stderr when stdout is empty.
            version = (result.stdout or result.stderr).strip()
            logger.info(f"Using mafft to build the alignment ({version})")
        return result.returncode == 0

    @staticmethod
    def _build_command(
        input_fasta: Path,
        auto_orient: bool,
        num_threads: int,
        strategy: str
    ) -> List[str]:
        """Assemble the MAFFT argument list; unknown strategies fall back to MAFFT's defaults."""
        flags = MAFFTAligner._STRATEGY_FLAGS.get(strategy)
        if flags is None:
            logger.warning(f"Unknown MAFFT strategy '{strategy}', using MAFFT defaults")
            flags = []

        cmd = ['mafft', *flags]
        if auto_orient:
            cmd.append('--adjustdirection')
        cmd.extend(['--thread', str(num_threads)])
        cmd.append(str(input_fasta))
        return cmd

    @staticmethod
    def _discard_partial_output(output_fasta: Path) -> None:
        """Remove an output file left behind by a failed run so it is not mistaken for an alignment."""
        try:
            output_fasta.unlink()
        except FileNotFoundError:
            return  # nothing was created
        except OSError as e:
            logger.warning(f"Could not remove incomplete alignment {output_fasta}: {e}")

    @staticmethod
    def align_sequences(
        input_fasta: Path,
        output_fasta: Path,
        auto_orient: bool = True,
        num_threads: int = 1,
        strategy: str = 'default'
    ) -> bool:
        """
        Align sequences using MAFFT.

        MAFFT's standard output is streamed directly into ``output_fasta``. If
        the run fails, times out or raises, the incomplete output file is
        removed before returning False.

        Args:
            input_fasta: Path to input FASTA file with sequences to align
            output_fasta: Path to output aligned FASTA file
            auto_orient: Use MAFFT's auto-orient feature (default: True)
            num_threads: Number of threads to use
            strategy: MAFFT alignment strategy (default: 'default')
                Options: 'default', 'auto', 'retree2', 'fftns'
                'default': no strategy flag is passed
                'auto': passes --auto so MAFFT chooses the strategy
                'retree2': passes --retree 2
                'fftns': passes --retree 1
                Any other value logs a warning and behaves like 'default'.

        Returns:
            True if alignment was successful, False otherwise
        """
        if not MAFFTAligner.check_mafft_available():
            logger.error("MAFFT is not available. Please install MAFFT.")
            return False

        if not input_fasta.exists():
            logger.error(f"Input FASTA file not found: {input_fasta}")
            return False

        cmd = MAFFTAligner._build_command(input_fasta, auto_orient, num_threads, strategy)
        logger.info(f"Running MAFFT alignment: {' '.join(cmd)}")

        try:
            with open(output_fasta, 'w') as out:
                result = subprocess.run(
                    cmd,
                    stdout=out,
                    stderr=subprocess.PIPE,
                    text=True,
                    timeout=3600  # 1 hour timeout
                )
        except subprocess.TimeoutExpired:
            logger.error("MAFFT alignment timed out")
            MAFFTAligner._discard_partial_output(output_fasta)
            return False
        except Exception as e:
            logger.error(f"Error running MAFFT: {e}")
            MAFFTAligner._discard_partial_output(output_fasta)
            return False

        if result.returncode != 0:
            logger.error(f"MAFFT failed with error: {result.stderr}")
            MAFFTAligner._discard_partial_output(output_fasta)
            return False

        logger.info(f"MAFFT alignment completed successfully. Output: {output_fasta}")
        return True
[docs]
class IQTreeBuilder:
    """
    Class for building phylogenetic trees using IQTree.

    All members are static; the class only acts as a namespace.
    """

    # Executables probed by check_iqtree_available(), in order of preference.
    _CANDIDATE_COMMANDS = ('iqtree3', 'iqtree2', 'iqtree')

    # Characters that may not appear in an unquoted Newick label; each becomes "_".
    # Square brackets matter in practice: NCBI names such as "[Candida] auris"
    # would otherwise be read as a Newick comment.
    _NEWICK_UNSAFE = str.maketrans({ch: '_' for ch in " :(),;[]'"})

    @staticmethod
    def check_iqtree_available() -> Tuple[bool, Optional[str]]:
        """
        Check if IQTree is available in the system.

        Each candidate executable is run with ``--version`` (5 second timeout);
        the first one that exits with status 0 is used.

        Returns:
            Tuple of (available: bool, command: str or None)
        """
        for cmd in IQTreeBuilder._CANDIDATE_COMMANDS:
            try:
                result = subprocess.run(
                    [cmd, '--version'],
                    capture_output=True,
                    text=True,
                    timeout=5
                )
            except (subprocess.SubprocessError, OSError):
                # OSError covers a missing executable as well as a non-executable one
                continue
            if result.returncode == 0:
                logger.info(f"Using {cmd} to build the trees ({result.stdout.strip()})")
                return True, cmd
        return False, None

    @staticmethod
    def _build_command(
        iqtree_cmd: str,
        alignment_fasta: Path,
        output_prefix: Path,
        model: str,
        num_threads: Optional[int]
    ) -> List[str]:
        """Assemble the IQTree argument list; a falsy ``num_threads`` selects ``-T AUTO``."""
        return [
            iqtree_cmd,
            '-s', str(alignment_fasta),
            '-pre', str(output_prefix),
            '-m', model,
            '-T', str(num_threads) if num_threads else "AUTO"
        ]

    @staticmethod
    def _tree_file_for(output_prefix: Path) -> Path:
        """Return the ``<prefix>.treefile`` path that a run with this prefix is expected to create."""
        return Path(str(output_prefix) + ".treefile")

    @staticmethod
    def build_tree(
        alignment_fasta: Path,
        output_prefix: Path,
        model: str = "MFP",
        num_threads: Optional[int] = None
    ) -> bool:
        """
        Build a phylogenetic tree using IQTree and wait for it to finish.

        Args:
            alignment_fasta: Path to aligned FASTA file
            output_prefix: Prefix for output files
            model: Substitution model (default: "MFP" for automatic ModelFinder Plus selection)
            num_threads: Number of threads to use (default: None, which uses AUTO)

        Returns:
            True if IQTree exited with status 0 and ``<prefix>.treefile`` exists,
            False otherwise
        """
        available, iqtree_cmd = IQTreeBuilder.check_iqtree_available()
        if not available:
            logger.error("IQTree is not available. Please install IQTree or IQTree2.")
            return False

        if not alignment_fasta.exists():
            logger.error(f"Alignment file not found: {alignment_fasta}")
            return False

        cmd = IQTreeBuilder._build_command(
            iqtree_cmd, alignment_fasta, output_prefix, model, num_threads
        )
        logger.info(f"Running IQTree: {' '.join(cmd)}")

        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=14400  # 4 hour timeout to handle larger datasets
            )
            if result.returncode != 0:
                logger.error(f"IQTree failed with error: {result.stderr}")
                return False

            # A zero exit status alone is not trusted; the tree file must exist
            tree_file = IQTreeBuilder._tree_file_for(output_prefix)
            if not tree_file.exists():
                logger.error("IQTree did not produce a tree file")
                return False

            logger.info(f"IQTree completed successfully. Tree: {tree_file}")
            return True
        except subprocess.TimeoutExpired:
            logger.error("IQTree timed out")
            return False
        except Exception as e:
            logger.error(f"Error running IQTree: {e}")
            return False

    @staticmethod
    def build_tree_background(
        alignment_fasta: Path,
        output_prefix: Path,
        model: str = "MFP"
    ) -> Optional[Dict]:
        """
        Start building a phylogenetic tree using IQTree in the background.

        This method starts IQTree as a background process and returns immediately,
        allowing multiple trees to be built in parallel. Threading is always
        ``-T AUTO``.

        The child's standard output is discarded rather than piped: jobs are
        collected one at a time by wait_for_tree_jobs(), and a job that is not
        currently being waited on would stall once its stdout pipe buffer
        filled up. Standard error is still captured for error reporting.

        Args:
            alignment_fasta: Path to aligned FASTA file
            output_prefix: Prefix for output files
            model: Substitution model (default: "MFP" for automatic ModelFinder Plus selection)

        Returns:
            Dictionary with process information if successful, None otherwise:
                - 'process': subprocess.Popen object
                - 'output_prefix': output prefix path
                - 'alignment_fasta': input alignment file path
                - 'tree_file': expected tree file path
                - 'cmd': the command line as a single string
        """
        available, iqtree_cmd = IQTreeBuilder.check_iqtree_available()
        if not available:
            logger.error("IQTree is not available. Please install IQTree or IQTree2.")
            return None

        if not alignment_fasta.exists():
            logger.error(f"Alignment file not found: {alignment_fasta}")
            return None

        cmd = IQTreeBuilder._build_command(
            iqtree_cmd, alignment_fasta, output_prefix, model, None
        )
        logger.info(f"Starting IQTree in background: {' '.join(cmd)}")

        try:
            process = subprocess.Popen(
                cmd,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.PIPE,
                text=True
            )
        except Exception as e:
            logger.error(f"Error starting IQTree: {e}")
            return None

        return {
            'process': process,
            'output_prefix': output_prefix,
            'alignment_fasta': alignment_fasta,
            'tree_file': IQTreeBuilder._tree_file_for(output_prefix),
            'cmd': ' '.join(cmd)
        }

    @staticmethod
    def _kill_and_reap(process) -> bool:
        """
        Kill ``process`` and collect it, draining and closing its pipes.

        Returns:
            True if the process exited within 5 seconds of being killed, False otherwise
        """
        process.kill()
        try:
            process.communicate(timeout=5)
            return True
        except subprocess.TimeoutExpired:
            return False
        except (OSError, ValueError):
            # Pipes already closed or broken: fall back to only reaping the child
            try:
                process.wait(timeout=5)
                return True
            except subprocess.TimeoutExpired:
                return False

    @staticmethod
    def wait_for_tree_jobs(
        jobs: List[Dict],
        timeout: int = 14400
    ) -> Dict[str, bool]:
        """
        Wait for multiple IQTree jobs to complete.

        The jobs were already started by build_tree_background(), so they run
        concurrently; this method collects them one after another in list
        order. A job that exceeds ``timeout`` is killed.

        Args:
            jobs: List of job dictionaries returned by build_tree_background()
            timeout: Maximum time in seconds to wait on each individual job,
                measured from the moment this method starts waiting on it
                (default: 14400 = 4 hours)

        Returns:
            Dictionary mapping tree_file path (as a string) to success status (True/False)
        """
        results = {}
        logger.info(f"Waiting for {len(jobs)} IQTree jobs to complete...")

        for job in jobs:
            process = job['process']
            output_prefix = job['output_prefix']
            tree_file = job['tree_file']
            result_key = str(tree_file)

            try:
                # Blocks on THIS process only; the other jobs keep running meanwhile
                _, stderr = process.communicate(timeout=timeout)

                if process.returncode != 0:
                    logger.error(f"IQTree failed for {output_prefix} with error: {stderr}")
                    results[result_key] = False
                    continue

                if not tree_file.exists():
                    logger.error(f"IQTree did not produce a tree file for {output_prefix}")
                    results[result_key] = False
                    continue

                logger.info(f"IQTree completed successfully for {output_prefix}. Tree: {tree_file}")
                results[result_key] = True
            except subprocess.TimeoutExpired:
                logger.error(f"IQTree timed out for {output_prefix} after {timeout} seconds")
                if not IQTreeBuilder._kill_and_reap(process):
                    logger.error(f"Failed to terminate IQTree process for {output_prefix}")
                results[result_key] = False
            except Exception as e:
                logger.error(f"Error waiting for IQTree job {output_prefix}: {e}")
                # Ensure a still-running process is not left behind
                if process.poll() is None:
                    IQTreeBuilder._kill_and_reap(process)
                results[result_key] = False

        successful = sum(1 for success in results.values() if success)
        logger.info(f"Completed {successful}/{len(jobs)} IQTree jobs successfully")
        return results

    @staticmethod
    def _clean_newick_label(label) -> str:
        """Return ``label`` with Newick-unsafe characters replaced by "_"; empty labels become "unknown"."""
        if not label:
            return "unknown"
        return str(label).translate(IQTreeBuilder._NEWICK_UNSAFE)

    @staticmethod
    def _build_label_map(blast_hits: List[BlastHit], taxonomic_rank: str) -> Dict[str, str]:
        """
        Map sequence identifiers to Newick-safe taxonomic labels.

        Both the full ``subject_id`` and the accession of every hit are mapped
        to the same cleaned label; hits without a name for ``taxonomic_rank``
        are labelled "unknown".
        """
        label_map: Dict[str, str] = {}
        for hit in blast_hits:
            label = "unknown"
            if isinstance(hit.subject_taxonomy, dict) and taxonomic_rank in hit.subject_taxonomy:
                label = hit.subject_taxonomy[taxonomic_rank][1]
            clean_label = IQTreeBuilder._clean_newick_label(label)
            label_map[hit.subject_id] = clean_label
            # Trees will have accessions (e.g., MZ387488.1) not full IDs
            label_map[hit.get_accession()] = clean_label
        return label_map

    @staticmethod
    def _relabel_newick(tree_string: str, label_map: Dict[str, str]) -> str:
        """
        Replace leaf names in a Newick string according to ``label_map``.

        A name is only replaced where it directly follows "(", "," or a space
        and is directly followed by ":". MAFFT prepends "_R_" to sequences it
        reverses; such leaves get the label with "_R" appended.
        """
        for seq_id, tax_name in label_map.items():
            renames = (
                (seq_id, tax_name),
                (f"_R_{seq_id}", f"{tax_name}_R"),
            )
            for old_name, new_name in renames:
                for lead in ("(", ",", " "):
                    tree_string = tree_string.replace(
                        f"{lead}{old_name}:", f"{lead}{new_name}:"
                    )
        return tree_string

    @staticmethod
    def relabel_tree_with_taxonomy(
        tree_file: Path,
        blast_hits: List[BlastHit],
        output_tree: Path,
        taxonomic_rank: str,
    ) -> bool:
        """
        Relabel tree nodes with taxonomic names.

        This reads a Newick tree file and replaces sequence IDs with taxonomic names
        from the BLAST hits. Several leaves may end up with the same label when
        their hits share a taxonomic name.

        Args:
            tree_file: Path to input tree file (Newick format)
            blast_hits: List of BlastHit objects with taxonomic information
            output_tree: Path to output tree file with relabeled nodes
            taxonomic_rank: the taxonomic rank to use for relabeling (e.g., "genus")

        Returns:
            True if successful, False if any exception occurred (it is logged)
        """
        try:
            label_map = IQTreeBuilder._build_label_map(blast_hits, taxonomic_rank)

            with open(tree_file, 'r') as f:
                tree_string = f.read()

            tree_string = IQTreeBuilder._relabel_newick(tree_string, label_map)

            with open(output_tree, 'w') as f:
                f.write(tree_string)

            logger.info(f"Tree relabeled with {len(label_map)} taxonomic names")
            logger.info(f"Relabeled tree saved to: {output_tree}")
            return True
        except Exception as e:
            logger.error(f"Error relabeling tree: {e}")
            return False
[docs]
def process_query_alignment_and_tree(
    query_id: str,
    query_dir: Path,
    blast_hits: List[BlastHit],
    query_fasta: Path,
    taxonomic_rank: str,
    num_threads: int = 1
) -> Dict[str, Optional[Path]]:
    """
    Run the whole per-query pipeline: combine, trim, align, build tree, relabel.

    The stages run in order and the function returns as soon as one of them
    fails, so the returned dictionary tells the caller how far it got: every
    stage that completed has its output path filled in, the rest stay None.

    Args:
        query_id: Query sequence identifier
        query_dir: Directory containing query-specific files; the
            ``<id>_representatives.fasta`` file is read from here and all
            outputs are written here
        blast_hits: List of BlastHit objects for this query (with taxonomy info)
        query_fasta: Path to original query FASTA file
        taxonomic_rank: The taxonomic rank to use for relabeling the tree
        num_threads: Number of threads handed to MAFFT (the IQTree call is
            made without a thread count)

    Returns:
        Dictionary with paths to generated files:
            - 'trimmed_fasta': Trimmed sequences
            - 'alignment': Aligned sequences
            - 'tree': Phylogenetic tree
            - 'labeled_tree': Tree with taxonomic labels
    """
    results = dict.fromkeys(('trimmed_fasta', 'alignment', 'tree', 'labeled_tree'))

    # Derive every file name from a filesystem-safe form of the query ID
    file_stem = query_id.replace('|', '_').replace('/', '_')
    representatives_fasta = query_dir / f"{file_stem}_representatives.fasta"
    combined_fasta = query_dir / f"{file_stem}_with_query.fasta"
    trimmed_fasta = query_dir / f"{file_stem}_trimmed.fasta"
    alignment_fasta = query_dir / f"{file_stem}_aligned.fasta"
    tree_prefix = query_dir / f"{file_stem}_tree"
    tree_file = Path(f"{tree_prefix}.treefile")
    labeled_tree = query_dir / f"{file_stem}_tree_labeled.treefile"

    # Stage 1: put the query in front of the representative sequences
    try:
        query_sequences = FastaReader.read_fasta(query_fasta)

        if query_id not in query_sequences:
            logger.error(f"Query {query_id} not found in {query_fasta}")
            return results

        if not representatives_fasta.exists():
            logger.error(f"Representatives file not found: {representatives_fasta}")
            return results

        with open(combined_fasta, 'w') as combined:
            query_seq = query_sequences[query_id]
            combined.write(f">{query_id}\n")
            combined.writelines(
                query_seq[offset:offset + 60] + "\n"
                for offset in range(0, len(query_seq), 60)
            )
            with open(representatives_fasta, 'r') as rep_handle:
                combined.write(rep_handle.read())

        logger.info(f"Combined query with representatives: {combined_fasta}")
    except Exception as e:
        logger.error(f"Error preparing sequences: {e}")
        return results

    # Stage 2: cut every subject down to its aligned region
    logger.info(f"Trimming sequences for {query_id}...")
    trimmed_ok = SequenceTrimmer.trim_sequences_from_blast_hits(
        fasta_path=combined_fasta,
        blast_hits=blast_hits,
        output_fasta=trimmed_fasta,
        query_id=query_id,
        taxonomic_rank=taxonomic_rank
    )
    if not trimmed_ok:
        logger.error(f"Failed to trim sequences for {query_id}")
        return results
    results['trimmed_fasta'] = trimmed_fasta
    logger.info(f"Trimmed sequences saved to: {trimmed_fasta}")

    # Stage 3: multiple sequence alignment
    logger.info(f"Aligning sequences for {query_id}...")
    aligned_ok = MAFFTAligner.align_sequences(
        input_fasta=trimmed_fasta,
        output_fasta=alignment_fasta,
        auto_orient=True,
        num_threads=num_threads
    )
    if not aligned_ok:
        logger.error(f"Failed to align sequences for {query_id}")
        return results
    results['alignment'] = alignment_fasta
    logger.info(f"Alignment saved to: {alignment_fasta}")

    # Stage 4: tree inference
    logger.info(f"Building phylogenetic tree for {query_id}...")
    built_ok = IQTreeBuilder.build_tree(
        alignment_fasta=alignment_fasta,
        output_prefix=tree_prefix,
    )
    if not built_ok:
        logger.error(f"Failed to build tree for {query_id}")
        return results
    results['tree'] = tree_file
    logger.info(f"Tree saved to: {tree_file}")

    # Stage 5: swap sequence IDs for taxonomic names (failure here is not fatal)
    logger.info(f"Relabeling tree with taxonomic names for {query_id}...")
    relabeled_ok = IQTreeBuilder.relabel_tree_with_taxonomy(
        tree_file=tree_file,
        blast_hits=blast_hits,
        output_tree=labeled_tree,
        taxonomic_rank=taxonomic_rank
    )
    if relabeled_ok:
        results['labeled_tree'] = labeled_tree
        logger.info(f"Labeled tree saved to: {labeled_tree}")
    else:
        logger.warning(f"Failed to relabel tree for {query_id}, but unlabeled tree is available")

    return results
[docs]
def process_query_alignment_and_tree_parallel(
    query_id: str,
    query_dir: Path,
    blast_hits: List[BlastHit],
    query_fasta: Path,
    taxonomic_rank: str,
    num_threads: int = 1,
    background_tree: bool = False
) -> Dict[str, Optional[Path]]:
    """
    Run combine, trim and align for one query, then build the tree now or in the background.

    With ``background_tree=False`` the tree is built and relabeled before
    returning. With ``background_tree=True`` IQTree is only started; the job
    description is returned so the caller can wait for it and relabel the tree
    later. The function returns early, with whatever has been filled in so
    far, as soon as a stage fails.

    Args:
        query_id: Query sequence identifier
        query_dir: Directory containing query-specific files; the
            ``<id>_representatives.fasta`` file is read from here and all
            outputs are written here
        blast_hits: List of BlastHit objects for this query (with taxonomy info)
        query_fasta: Path to original query FASTA file
        taxonomic_rank: The taxonomic rank to use for relabeling the tree
        num_threads: Number of threads handed to MAFFT (the IQTree calls are
            made without a thread count)
        background_tree: If True, start tree building in background and return immediately

    Returns:
        Dictionary describing the run:
            - 'trimmed_fasta': Trimmed sequences, or None
            - 'alignment': Aligned sequences, or None
            - 'tree_job': Background job info if background_tree=True and the
              job started, None otherwise
            - 'tree_file': Tree file path (expected path for a background job), or None
            - 'labeled_tree_path': Labeled tree path (where a background job's
              relabeled tree should go, or the written file in synchronous mode), or None
            - 'blast_hits': BLAST hits for later tree relabeling
            - 'taxonomic_rank': Taxonomic rank for later tree relabeling
    """
    results = {
        'trimmed_fasta': None,
        'alignment': None,
        'tree_job': None,
        'tree_file': None,
        'labeled_tree_path': None,
        'blast_hits': blast_hits,
        'taxonomic_rank': taxonomic_rank
    }

    # Derive every file name from a filesystem-safe form of the query ID
    file_stem = query_id.replace('|', '_').replace('/', '_')
    representatives_fasta = query_dir / f"{file_stem}_representatives.fasta"
    combined_fasta = query_dir / f"{file_stem}_with_query.fasta"
    trimmed_fasta = query_dir / f"{file_stem}_trimmed.fasta"
    alignment_fasta = query_dir / f"{file_stem}_aligned.fasta"
    tree_prefix = query_dir / f"{file_stem}_tree"
    tree_file = Path(f"{tree_prefix}.treefile")
    labeled_tree = query_dir / f"{file_stem}_tree_labeled.treefile"

    # Stage 1: put the query in front of the representative sequences
    try:
        query_sequences = FastaReader.read_fasta(query_fasta)

        if query_id not in query_sequences:
            logger.error(f"Query {query_id} not found in {query_fasta}")
            return results

        if not representatives_fasta.exists():
            logger.error(f"Representatives file not found: {representatives_fasta}")
            return results

        with open(combined_fasta, 'w') as combined:
            query_seq = query_sequences[query_id]
            combined.write(f">{query_id}\n")
            combined.writelines(
                query_seq[offset:offset + 60] + "\n"
                for offset in range(0, len(query_seq), 60)
            )
            with open(representatives_fasta, 'r') as rep_handle:
                combined.write(rep_handle.read())

        logger.info(f"Combined query with representatives: {combined_fasta}")
    except Exception as e:
        logger.error(f"Error preparing sequences: {e}")
        return results

    # Stage 2: cut every subject down to its aligned region
    logger.info(f"Trimming sequences for {query_id}...")
    trimmed_ok = SequenceTrimmer.trim_sequences_from_blast_hits(
        fasta_path=combined_fasta,
        blast_hits=blast_hits,
        output_fasta=trimmed_fasta,
        query_id=query_id,
        taxonomic_rank=taxonomic_rank
    )
    if not trimmed_ok:
        logger.error(f"Failed to trim sequences for {query_id}")
        return results
    results['trimmed_fasta'] = trimmed_fasta
    logger.info(f"Trimmed sequences saved to: {trimmed_fasta}")

    # Stage 3: multiple sequence alignment
    logger.info(f"Aligning sequences for {query_id}...")
    aligned_ok = MAFFTAligner.align_sequences(
        input_fasta=trimmed_fasta,
        output_fasta=alignment_fasta,
        auto_orient=True,
        num_threads=num_threads
    )
    if not aligned_ok:
        logger.error(f"Failed to align sequences for {query_id}")
        return results
    results['alignment'] = alignment_fasta
    logger.info(f"Alignment saved to: {alignment_fasta}")

    # Stage 4a: background mode only launches IQTree and hands the job back
    if background_tree:
        logger.info(f"Starting phylogenetic tree building in background for {query_id}...")
        tree_job = IQTreeBuilder.build_tree_background(
            alignment_fasta=alignment_fasta,
            output_prefix=tree_prefix,
        )
        if not tree_job:
            logger.error(f"Failed to start tree building for {query_id}")
            return results
        results['tree_job'] = tree_job
        results['tree_file'] = tree_file
        results['labeled_tree_path'] = labeled_tree
        logger.info(f"Tree building started in background for {query_id}")
        return results

    # Stage 4b: synchronous mode waits for the tree
    logger.info(f"Building phylogenetic tree for {query_id}...")
    built_ok = IQTreeBuilder.build_tree(
        alignment_fasta=alignment_fasta,
        output_prefix=tree_prefix,
    )
    if not built_ok:
        logger.error(f"Failed to build tree for {query_id}")
        return results
    results['tree_file'] = tree_file
    logger.info(f"Tree saved to: {tree_file}")

    # Stage 5: swap sequence IDs for taxonomic names (failure here is not fatal)
    logger.info(f"Relabeling tree with taxonomic names for {query_id}...")
    relabeled_ok = IQTreeBuilder.relabel_tree_with_taxonomy(
        tree_file=tree_file,
        blast_hits=blast_hits,
        output_tree=labeled_tree,
        taxonomic_rank=taxonomic_rank
    )
    if relabeled_ok:
        results['labeled_tree_path'] = labeled_tree
        logger.info(f"Labeled tree saved to: {labeled_tree}")
    else:
        logger.warning(f"Failed to relabel tree for {query_id}, but unlabeled tree is available")

    return results
[docs]
def check_alignment_consistency(blast_hits: List[BlastHit], tolerance: int = 50) -> Dict[str, bool]:
    """
    Check if BLAST hits align to similar locations on reference sequences.

    Hits are grouped by ``subject_id``. A subject with a single hit is
    consistent by definition. For a subject with several hits, the spread of
    the region starts and the spread of the region ends must both be within
    ``tolerance``.

    Reverse-strand hits are reported with ``subject_start > subject_end``, so
    each hit's coordinates are put in ascending order before comparison. Two
    hits covering the same region on opposite strands therefore count as
    consistent.

    Args:
        blast_hits: List of BlastHit objects to check
        tolerance: Maximum allowed difference in coordinates (default: 50 bp)

    Returns:
        Dictionary mapping subject_id to consistency status (True if consistent)
    """
    # Group hits by subject sequence
    hits_by_subject = defaultdict(list)
    for hit in blast_hits:
        hits_by_subject[hit.subject_id].append(hit)

    consistency_status = {}
    for subject_id, subject_hits in hits_by_subject.items():
        if len(subject_hits) == 1:
            # Only one hit, so it's consistent by definition
            consistency_status[subject_id] = True
            continue

        # Strand-independent region boundaries of every hit
        region_starts = [min(hit.subject_start, hit.subject_end) for hit in subject_hits]
        region_ends = [max(hit.subject_start, hit.subject_end) for hit in subject_hits]
        start_range = max(region_starts) - min(region_starts)
        end_range = max(region_ends) - min(region_ends)

        is_consistent = start_range <= tolerance and end_range <= tolerance
        consistency_status[subject_id] = is_consistent

        if is_consistent:
            logger.info(
                f"Subject {subject_id} has consistent alignments across {len(subject_hits)} hits"
            )
        else:
            logger.warning(
                f"Subject {subject_id} has inconsistent alignments: "
                f"start range={start_range}, end range={end_range}"
            )

    return consistency_status
[docs]
def group_hits_by_group_rank(
    blast_hits: List[BlastHit],
    group_rank: str,
) -> Dict[str, Dict[str, List[BlastHit]]]:
    """
    Group BLAST hits by group_rank across all queries.

    Hits whose ``subject_taxonomy`` is not a dict or has no entry for
    ``group_rank`` are left out and reported with a warning.

    Args:
        blast_hits: List of BlastHit objects with group taxonomy information
        group_rank: Taxonomic rank whose name is used as the grouping key

    Returns:
        Dictionary mapping group_rank_name (taxonomy name) to another dict mapping
        query_id to list of hits. Both levels are plain dicts, so looking up a
        missing key raises KeyError instead of silently creating an entry.
        Format: {group_rank_name: {query_id: [hits]}}
    """
    grouped: Dict[str, Dict[str, List[BlastHit]]] = {}

    for hit in blast_hits:
        taxonomy = hit.subject_taxonomy
        if isinstance(taxonomy, dict) and group_rank in taxonomy:
            # taxonomy[rank] is indexed with [1] to obtain the name
            group_name = taxonomy[group_rank][1]
            grouped.setdefault(group_name, {}).setdefault(hit.query_id, []).append(hit)
        else:
            logger.warning(
                f"Hit {hit.subject_id} for query {hit.query_id} has no group taxonomy information"
            )

    logger.info(f"Grouped hits into {len(grouped)} taxonomic groups")
    for group_name, queries in grouped.items():
        logger.info(
            f" Group {group_name}: {len(queries)} queries, "
            f"{sum(len(hits) for hits in queries.values())} total hits"
        )

    return grouped
[docs]
def create_grouped_fasta_with_queries(
    group_tid: str,
    group_name: str,
    query_hits_map: Dict[str, List[BlastHit]],
    labeling_rank: str,
    query_fasta: Path,
    output_fasta: Path,
    database: str = "core_nt",
    blastdb_path: Optional[Path] = None
) -> bool:
    """
    Create a FASTA file for a taxonomic group containing all queries and unique references.

    Queries are written first, untrimmed. References are de-duplicated by
    their label at ``labeling_rank`` (falling back to the subject ID when a
    hit has no such label): for every label only the hit with the highest bit
    score is kept, and its sequence is fetched through SequenceExtractor. The
    temporary file used for the fetched references is removed in all cases.

    Args:
        group_tid: Taxonomy ID of the group (only used in log output)
        group_name: Name of the taxonomic group (only used in log output)
        query_hits_map: Dictionary mapping query_id to list of BlastHit objects
        labeling_rank: Taxonomic rank to use for labeling (e.g., "genus")
        query_fasta: Path to original query FASTA file
        output_fasta: Path to output grouped FASTA file
        database: Name of BLAST database, forwarded to extract_sequences
        blastdb_path: Path to BLAST database directory, passed to SequenceExtractor

    Returns:
        True if successful, False otherwise (including when the reference
        sequences could not be extracted)
    """
    logger.info(f"Creating grouped FASTA for {group_name} ({group_tid})")

    # Read all query sequences
    try:
        query_sequences = FastaReader.read_fasta(query_fasta)
    except Exception as e:
        logger.error(f"Error reading query FASTA: {e}")
        return False

    def write_record(out, header: str, sequence: str) -> None:
        """Write one FASTA record with the sequence wrapped at 60 characters."""
        out.write(f">{header}\n")
        for offset in range(0, len(sequence), 60):
            out.write(sequence[offset:offset + 60] + "\n")

    # Keep one reference per label: the hit with the best bit score.
    # unique_labels: label -> chosen hit; unique_references: subject_id -> chosen hit
    unique_references: Dict[str, BlastHit] = {}
    unique_labels: Dict[str, BlastHit] = {}
    for hits in query_hits_map.values():
        for hit in hits:
            label = hit.subject_id
            if isinstance(hit.subject_taxonomy, dict) and labeling_rank in hit.subject_taxonomy:
                label = hit.subject_taxonomy[labeling_rank][1]

            current = unique_labels.get(label)
            if current is None:
                unique_labels[label] = hit
                unique_references[hit.subject_id] = hit
            elif hit.bit_score > current.bit_score:
                # Swap the previous representative of this label for the better hit
                unique_references.pop(current.subject_id, None)
                unique_references[hit.subject_id] = hit
                unique_labels[label] = hit

    logger.info(f"Found {len(unique_references)} unique reference sequences")
    unique_label_keys = list(unique_labels)
    logger.info(
        "Unique labels: count=%d, example_labels=%s",
        len(unique_labels),
        unique_label_keys[:10],
    )

    # Extract reference sequences
    seq_extractor = SequenceExtractor(blastdb_path)
    temp_ref_fasta = output_fasta.parent / f"{output_fasta.stem}_temp_refs.fasta"

    try:
        success = seq_extractor.extract_sequences(
            sequence_ids=list(unique_references),
            output_fasta=temp_ref_fasta,
            database=database
        )
        if not success:
            # Report failure to the caller instead of terminating the whole process
            logger.error("Failed to extract reference sequences")
            return False

        # Read extracted references
        ref_sequences = FastaReader.read_fasta(temp_ref_fasta)

        with open(output_fasta, 'w') as out:
            # Write all query sequences first
            for query_id in query_hits_map:
                if query_id in query_sequences:
                    query_seq = query_sequences[query_id]
                    write_record(out, query_id, query_seq)
                    logger.info(f"Added query {query_id} ({len(query_seq)} bp)")
                else:
                    logger.warning(f"Query {query_id} not found in query FASTA file, skipping")

            # Write reference sequences with taxonomic labels
            for hit in unique_references.values():
                # Extracted references are looked up by accession
                accession = hit.get_accession()
                if accession not in ref_sequences:
                    logger.warning(f"Reference {accession} not found in extracted sequences, skipping")
                    continue

                ref_seq = ref_sequences[accession]
                header = accession
                if isinstance(hit.subject_taxonomy, dict) and labeling_rank in hit.subject_taxonomy:
                    header = f"{accession} {hit.subject_taxonomy[labeling_rank][1]}"
                write_record(out, header, ref_seq)
                logger.info(f"Added reference {accession} ({len(ref_seq)} bp)")

        logger.info(f"Created grouped FASTA file: {output_fasta}")
        return True
    except Exception as e:
        logger.error(f"Error creating grouped FASTA: {e}")
        return False
    finally:
        # Clean up temporary file
        if temp_ref_fasta.exists():
            temp_ref_fasta.unlink()
[docs]
def trim_grouped_sequences(
    input_fasta: Path,
    blast_hits: List[BlastHit],
    output_fasta: Path,
    query_ids: List[str]
) -> bool:
    """
    Trim the reference sequences of a grouped FASTA file to their BLAST-aligned regions.

    Query sequences are copied to the output untrimmed, in the order given by
    ``query_ids`` (repeated IDs are written only once). Every other record is cut
    down to the subject coordinates of its best-scoring BLAST hit (highest
    ``bit_score``; the first hit wins on ties) and written under its original
    header. Records without any matching hit are logged as an error and left out
    of the output.

    This is similar to trim_sequences_from_blast_hits but handles multiple queries.

    Args:
        input_fasta: Path to input FASTA file with full-length sequences
        blast_hits: List of BlastHit objects for all queries in the group
        output_fasta: Path to output FASTA file with trimmed sequences
        query_ids: List of query sequence IDs to include (untrimmed)

    Returns:
        True if successful, False if any exception was raised (the error is logged)
    """
    def write_record(handle, header, residues):
        # FASTA record with the sequence wrapped at 60 characters per line
        handle.write(f">{header}\n")
        for offset in range(0, len(residues), 60):
            handle.write(residues[offset:offset + 60] + "\n")

    try:
        # Read all sequences from the input FASTA
        sequences = FastaReader.read_fasta(input_fasta)

        # Keep only the best hit (by bit score) per subject accession.
        # The strict '>' keeps the first hit on ties, matching max() semantics.
        best_hits = {}
        for hit in blast_hits:
            accession = hit.get_accession()
            incumbent = best_hits.get(accession)
            if incumbent is None or hit.bit_score > incumbent.bit_score:
                best_hits[accession] = hit

        # Ordered, de-duplicated query IDs: O(1) membership tests below and no
        # duplicate FASTA records if the caller passes an ID twice.
        unique_query_ids = dict.fromkeys(query_ids)

        with open(output_fasta, 'w') as out:
            # First, write all query sequences (untrimmed)
            for query_id in unique_query_ids:
                if query_id not in sequences:
                    logger.warning(
                        f"Query sequence {query_id} not found in {input_fasta}, skipping"
                    )
                    continue
                query_seq = sequences[query_id]
                write_record(out, query_id, query_seq)
                logger.info(f"Added query sequence {query_id} ({len(query_seq)} bp)")

            # Now process subject sequences (trimmed)
            for seq_id, sequence in sequences.items():
                if seq_id in unique_query_ids:
                    continue  # Skip queries, already written

                # The header may carry extra text (e.g. a taxonomy label) after
                # the accession, so only the first whitespace-separated token is used.
                accession = seq_id.split()[0]
                hit = best_hits.get(accession)
                if hit is None:
                    logger.error(f"No BLAST hit found for sequence {accession}, data consistency issue")
                    continue

                # Trim the sequence based on subject coordinates
                trimmed_seq = SequenceTrimmer.trim_sequence_by_coordinates(
                    sequence,
                    hit.subject_start,
                    hit.subject_end
                )
                write_record(out, seq_id, trimmed_seq)
                logger.info(
                    f"Trimmed {accession} from {len(sequence)} bp to {len(trimmed_seq)} bp "
                    f"(coords: {hit.subject_start}-{hit.subject_end})"
                )
        return True
    except Exception as e:
        logger.error(f"Error trimming grouped sequences: {e}")
        return False
[docs]
def process_grouped_alignment_and_tree(
    group_name: str,
    group_dir: Path,
    taxonomic_rank: str,
    blast_hits: List[BlastHit],
    query_ids: List[str],
    num_threads: int = 1
) -> Dict[str, Optional[Path]]:
    """
    Run the trim -> align -> tree -> relabel pipeline for one taxonomic group.

    The pipeline stops at the first failing step and returns what has been
    produced so far; entries for steps that did not complete stay None. A failed
    relabeling step only produces a warning because the unlabeled tree is still
    usable.

    Args:
        group_name: The name of the group, used for file naming (spaces, '/' and
            '|' are replaced by '_')
        group_dir: Directory containing group-specific files; the combined FASTA
            ``<group>_combined.fasta`` must already exist there
        taxonomic_rank: Taxonomic rank to use for labeling the tree
        blast_hits: List of BlastHit objects for all queries in the group
        query_ids: List of query sequence IDs in this group
        num_threads: Number of threads to use

    Returns:
        Dictionary with paths to generated files:
            - 'combined_fasta': Combined sequences (queries + references)
            - 'trimmed_fasta': Trimmed sequences
            - 'alignment': Aligned sequences
            - 'tree': Phylogenetic tree
            - 'labeled_tree': Tree with taxonomic labels
    """
    outputs = dict.fromkeys(
        ('combined_fasta', 'trimmed_fasta', 'alignment', 'tree', 'labeled_tree')
    )

    # Derive every file name from a filesystem-friendly version of the group name
    file_stem = group_name.translate(str.maketrans(' /|', '___'))
    combined_fasta = group_dir / f"{file_stem}_combined.fasta"
    trimmed_fasta = group_dir / f"{file_stem}_trimmed.fasta"
    alignment_fasta = group_dir / f"{file_stem}_aligned.fasta"
    tree_prefix = group_dir / f"{file_stem}_tree"
    tree_file = Path(str(tree_prefix) + ".treefile")
    labeled_tree = group_dir / f"{file_stem}_tree_labeled.treefile"

    # The combined FASTA is an input of this pipeline, not something it creates
    if not combined_fasta.exists():
        logger.error(f"Combined FASTA file not found: {combined_fasta}")
        return outputs
    outputs['combined_fasta'] = combined_fasta

    # Step 1: trim references to their BLAST-aligned regions
    logger.info(f"Trimming sequences for group {group_name}...")
    trimmed_ok = trim_grouped_sequences(
        input_fasta=combined_fasta,
        blast_hits=blast_hits,
        output_fasta=trimmed_fasta,
        query_ids=query_ids
    )
    if not trimmed_ok:
        logger.error(f"Failed to trim sequences for group {group_name}")
        return outputs
    outputs['trimmed_fasta'] = trimmed_fasta
    logger.info(f"Trimmed sequences saved to: {trimmed_fasta}")

    # Step 2: multiple sequence alignment with MAFFT
    logger.info(f"Aligning sequences for group {group_name}...")
    aligned_ok = MAFFTAligner.align_sequences(
        input_fasta=trimmed_fasta,
        output_fasta=alignment_fasta,
        auto_orient=True,
        num_threads=num_threads
    )
    if not aligned_ok:
        logger.error(f"Failed to align sequences for group {group_name}")
        return outputs
    outputs['alignment'] = alignment_fasta
    logger.info(f"Alignment saved to: {alignment_fasta}")

    # Step 3: tree inference with IQTree
    # NOTE(review): num_threads is forwarded to MAFFT only; the IQTree call
    # below does not receive it -- confirm that this is intended.
    logger.info(f"Building phylogenetic tree for group {group_name}...")
    tree_ok = IQTreeBuilder.build_tree(
        alignment_fasta=alignment_fasta,
        output_prefix=tree_prefix,
    )
    if not tree_ok:
        logger.error(f"Failed to build tree for group {group_name}")
        return outputs
    outputs['tree'] = tree_file
    logger.info(f"Tree saved to: {tree_file}")

    # Step 4: relabel tips with taxonomic names (failure is non-fatal)
    logger.info(f"Relabeling tree with taxonomic names for group {group_name}...")
    relabeled_ok = IQTreeBuilder.relabel_tree_with_taxonomy(
        tree_file=tree_file,
        blast_hits=blast_hits,
        output_tree=labeled_tree,
        taxonomic_rank=taxonomic_rank
    )
    if relabeled_ok:
        outputs['labeled_tree'] = labeled_tree
        logger.info(f"Labeled tree saved to: {labeled_tree}")
    else:
        logger.warning(f"Failed to relabel tree for group {group_name}, but unlabeled tree is available")
    return outputs
[docs]
def process_grouped_alignment_and_tree_parallel(
    group_name: str,
    group_dir: Path,
    taxonomic_rank: str,
    blast_hits: List[BlastHit],
    query_ids: List[str],
    num_threads: int = 1,
    background_tree: bool = False
) -> Dict[str, Optional[Path]]:
    """
    Trim and align one taxonomic group, then build its tree now or in the background.

    Trimming and alignment always run synchronously and the function returns early
    (with the remaining entries left at None) if either of them fails. With
    ``background_tree=True`` the tree build is only started and the job handle is
    returned so the caller can wait for it and relabel the tree later; otherwise
    the tree is built and relabeled before returning.

    Args:
        group_name: The name of the group, used for file naming (spaces, '/' and
            '|' are replaced by '_')
        group_dir: Directory containing group-specific files; the combined FASTA
            ``<group>_combined.fasta`` must already exist there
        taxonomic_rank: Taxonomic rank to use for labeling the tree
        blast_hits: List of BlastHit objects for all queries in the group
        query_ids: List of query sequence IDs in this group
        num_threads: Number of threads to use
        background_tree: If True, start tree building in background and return immediately

    Returns:
        Dictionary with:
            - 'combined_fasta': Combined sequences (queries + references)
            - 'trimmed_fasta': Trimmed sequences
            - 'alignment': Aligned sequences
            - 'tree_job': Value returned by IQTreeBuilder.build_tree_background when
              background_tree=True and the job started, None otherwise
            - 'tree_file': Tree file path (expected path in background mode)
            - 'labeled_tree_path': Labeled tree path (expected path in background mode)
            - 'blast_hits': BLAST hits, passed through for later tree relabeling
            - 'taxonomic_rank': Taxonomic rank, passed through for later tree relabeling
    """
    results = {
        'combined_fasta': None,
        'trimmed_fasta': None,
        'alignment': None,
        'tree_job': None,
        'tree_file': None,
        'labeled_tree_path': None,
        'blast_hits': blast_hits,
        'taxonomic_rank': taxonomic_rank
    }

    # Derive every file name from a filesystem-friendly version of the group name
    file_stem = group_name.translate(str.maketrans(' /|', '___'))
    combined_fasta = group_dir / f"{file_stem}_combined.fasta"
    trimmed_fasta = group_dir / f"{file_stem}_trimmed.fasta"
    alignment_fasta = group_dir / f"{file_stem}_aligned.fasta"
    tree_prefix = group_dir / f"{file_stem}_tree"
    tree_file = Path(str(tree_prefix) + ".treefile")
    labeled_tree = group_dir / f"{file_stem}_tree_labeled.treefile"

    # The combined FASTA is an input of this pipeline, not something it creates
    if not combined_fasta.exists():
        logger.error(f"Combined FASTA file not found: {combined_fasta}")
        return results
    results['combined_fasta'] = combined_fasta

    # Step 1: trim references to their BLAST-aligned regions
    logger.info(f"Trimming sequences for group {group_name}...")
    trimmed_ok = trim_grouped_sequences(
        input_fasta=combined_fasta,
        blast_hits=blast_hits,
        output_fasta=trimmed_fasta,
        query_ids=query_ids
    )
    if not trimmed_ok:
        logger.error(f"Failed to trim sequences for group {group_name}")
        return results
    results['trimmed_fasta'] = trimmed_fasta
    logger.info(f"Trimmed sequences saved to: {trimmed_fasta}")

    # Step 2: multiple sequence alignment with MAFFT
    logger.info(f"Aligning sequences for group {group_name}...")
    aligned_ok = MAFFTAligner.align_sequences(
        input_fasta=trimmed_fasta,
        output_fasta=alignment_fasta,
        auto_orient=True,
        num_threads=num_threads
    )
    if not aligned_ok:
        logger.error(f"Failed to align sequences for group {group_name}")
        return results
    results['alignment'] = alignment_fasta
    logger.info(f"Alignment saved to: {alignment_fasta}")

    # Step 3 (background mode): only launch the tree build and hand back the job
    if background_tree:
        logger.info(f"Starting phylogenetic tree building in background for group {group_name}...")
        tree_job = IQTreeBuilder.build_tree_background(
            alignment_fasta=alignment_fasta,
            output_prefix=tree_prefix,
        )
        if not tree_job:
            logger.error(f"Failed to start tree building for group {group_name}")
            return results
        # Paths are recorded now so the caller knows where the outputs will appear
        results.update(
            tree_job=tree_job,
            tree_file=tree_file,
            labeled_tree_path=labeled_tree,
        )
        logger.info(f"Tree building started in background for {group_name}")
        return results

    # Step 3 (synchronous mode): build the tree and wait for it
    logger.info(f"Building phylogenetic tree for group {group_name}...")
    tree_ok = IQTreeBuilder.build_tree(
        alignment_fasta=alignment_fasta,
        output_prefix=tree_prefix,
    )
    if not tree_ok:
        logger.error(f"Failed to build tree for group {group_name}")
        return results
    results['tree_file'] = tree_file
    logger.info(f"Tree saved to: {tree_file}")

    # Step 4: relabel tips with taxonomic names (failure is non-fatal)
    logger.info(f"Relabeling tree with taxonomic names for group {group_name}...")
    relabeled_ok = IQTreeBuilder.relabel_tree_with_taxonomy(
        tree_file=tree_file,
        blast_hits=blast_hits,
        output_tree=labeled_tree,
        taxonomic_rank=taxonomic_rank
    )
    if relabeled_ok:
        results['labeled_tree_path'] = labeled_tree
        logger.info(f"Labeled tree saved to: {labeled_tree}")
    else:
        logger.warning(f"Failed to relabel tree for group {group_name}, but unlabeled tree is available")
    return results
[docs]
def concatenate_all_groups_and_build_tree(
    output_dir: Path,
    query_fasta: Path,
    classification_file: Path,
    blast_hits: List[BlastHit],
    combined_tree_label_rank: str = "genus",
    num_threads: int = 1,
    alignment_strategy: str = "auto"
) -> Dict[str, Optional[Path]]:
    """
    Merge every group's trimmed FASTA, add zero-hit queries, then align and build one tree.

    Steps:
        1. Collect all ``*_trimmed.fasta`` files from the immediate subdirectories
           of ``output_dir``
        2. Read the classification file and pick the queries whose hit count is 0
        3. Write all sequences into a single FASTA (a sequence ID is written only
           the first time it is seen)
        4. Align the combined FASTA with MAFFT
        5. Build a tree with IQTree and relabel its nodes with taxonomic names

    The function stops at the first failing step and returns what was produced
    so far. Any exception is logged together with its traceback and also results
    in the partially filled dictionary being returned.

    Args:
        output_dir: Output directory containing group subdirectories
        query_fasta: Original query FASTA file
        classification_file: Path to classifications.tsv file (tab-separated,
            one header line, query ID in column 1 and BLAST hit count in column 2)
        blast_hits: List of all BlastHit objects with taxonomy information
        combined_tree_label_rank: Taxonomic rank for tree labeling (default: genus)
        num_threads: Number of threads for alignment and tree building (default: 1)
        alignment_strategy: MAFFT alignment strategy (default: 'auto')
            Options: 'default', 'auto', 'retree2', 'fftns'

    Returns:
        Dictionary with paths to generated files:
            - 'combined_fasta': Combined sequences from all groups + zero-hit queries
            - 'alignment': Aligned sequences
            - 'tree': Phylogenetic tree
            - 'labeled_tree': Tree with taxonomic labels
    """
    results = dict.fromkeys(('combined_fasta', 'alignment', 'tree', 'labeled_tree'))

    # Number of steps shown in the "[Step i/N]" log prefixes
    total_steps = 5
    rule = "=" * 60

    def write_record(handle, header, residues):
        # FASTA record with the sequence wrapped at 60 characters per line
        handle.write(f">{header}\n")
        for offset in range(0, len(residues), 60):
            handle.write(residues[offset:offset + 60] + "\n")

    logger.info("\n" + rule)
    logger.info("Building combined tree from all groups")
    logger.info(rule)

    # Output file paths
    combined_fasta = output_dir / "all_groups_combined.fasta"
    alignment_fasta = output_dir / "all_groups_aligned.fasta"
    tree_prefix = output_dir / "all_groups_tree"
    tree_file = Path(str(tree_prefix) + ".treefile")
    labeled_tree = output_dir / "all_groups_tree_labeled.treefile"

    try:
        # Step 1: collect the trimmed FASTA of every group directory
        logger.info(f"\n[Step 1/{total_steps}] Finding all group trimmed FASTA files...")
        trimmed_files = []
        for entry in output_dir.iterdir():
            if not entry.is_dir():
                continue
            for fasta_file in entry.glob("*_trimmed.fasta"):
                trimmed_files.append(fasta_file)
                logger.info(f" Found: {fasta_file}")
        if not trimmed_files:
            logger.error("No _trimmed.fasta files found in group directories")
            return results
        logger.info(f"Found {len(trimmed_files)} trimmed FASTA files")

        # Step 2: queries without BLAST hits never reached a group, so find them here
        logger.info(f"\n[Step 2/{total_steps}] Identifying queries with 0 blast hits...")
        zero_hit_queries = []
        if not classification_file.exists():
            logger.warning(f"Classification file not found: {classification_file}")
        else:
            with open(classification_file, 'r') as table:
                next(table, None)  # Skip header
                for row in table:
                    fields = row.strip().split('\t')
                    if len(fields) < 2:
                        continue
                    query_id, count_text = fields[0], fields[1]
                    try:
                        hit_count = int(count_text)
                    except ValueError:
                        # Hit count is not a valid integer (e.g., 'N/A')
                        logger.debug(f"Skipping query {query_id} with non-numeric hit count: {count_text}")
                        continue
                    if hit_count == 0:
                        zero_hit_queries.append(query_id)
                        logger.info(f" Query with 0 hits: {query_id}")

        if zero_hit_queries:
            logger.info(f"Found {len(zero_hit_queries)} queries with 0 blast hits")
        else:
            logger.info("No queries with 0 blast hits")

        # Step 3: write everything into one FASTA, first occurrence of an ID wins
        logger.info(f"\n[Step 3/{total_steps}] Concatenating all sequences...")
        query_sequences = FastaReader.read_fasta(query_fasta)
        written_sequences = set()
        with open(combined_fasta, 'w') as out:
            # Sorted so the output order does not depend on directory listing order
            for trimmed_file in sorted(trimmed_files):
                logger.info(f" Reading: {trimmed_file}")
                group_records = FastaReader.read_fasta(trimmed_file)
                for seq_id, sequence in group_records.items():
                    if seq_id in written_sequences:
                        continue
                    write_record(out, seq_id, sequence)
                    written_sequences.add(seq_id)

            # Zero-hit queries are appended with their full, untrimmed sequence
            for query_id in zero_hit_queries:
                if query_id not in query_sequences or query_id in written_sequences:
                    continue
                write_record(out, query_id, query_sequences[query_id])
                written_sequences.add(query_id)
                logger.info(f" Added zero-hit query: {query_id}")

        results['combined_fasta'] = combined_fasta
        logger.info(f"Combined {len(written_sequences)} sequences into: {combined_fasta}")

        # Step 4: align with MAFFT using the requested strategy
        logger.info(f"\n[Step 4/{total_steps}] Aligning sequences with MAFFT...")
        logger.info(f"Using '{alignment_strategy}' strategy for alignment")
        aligned_ok = MAFFTAligner.align_sequences(
            input_fasta=combined_fasta,
            output_fasta=alignment_fasta,
            auto_orient=True,
            num_threads=num_threads,
            strategy=alignment_strategy
        )
        if not aligned_ok:
            logger.error("Failed to align combined sequences")
            return results
        results['alignment'] = alignment_fasta
        logger.info(f"Alignment saved to: {alignment_fasta}")

        # Step 5: build the tree, then relabel it (relabel failure is non-fatal)
        logger.info(f"\n[Step 5/{total_steps}] Building phylogenetic tree...")
        tree_ok = IQTreeBuilder.build_tree(
            alignment_fasta=alignment_fasta,
            output_prefix=tree_prefix,
            num_threads=num_threads
        )
        if not tree_ok:
            logger.error("Failed to build phylogenetic tree")
            return results
        results['tree'] = tree_file
        logger.info(f"Tree saved to: {tree_file}")

        logger.info("Relabeling tree with taxonomic names...")
        relabeled_ok = IQTreeBuilder.relabel_tree_with_taxonomy(
            tree_file=tree_file,
            blast_hits=blast_hits,
            output_tree=labeled_tree,
            taxonomic_rank=combined_tree_label_rank
        )
        if relabeled_ok:
            results['labeled_tree'] = labeled_tree
            logger.info(f"Labeled tree saved to: {labeled_tree}")
        else:
            logger.warning("Failed to relabel tree, but unlabeled tree is available")

        logger.info("\n" + rule)
        logger.info("Combined tree building completed successfully!")
        logger.info(rule)
    except Exception as e:
        logger.error(f"Error in concatenate_all_groups_and_build_tree: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return results
    return results
[docs]
class SimpleNewickNode:
    """
    Simple Newick tree node representation for finding nearest neighbors.

    Attributes:
        name: Node label (empty string when the node is unlabeled)
        distance: Branch length stored on this node (0.0 by default)
        children: Child nodes in the order they were added
        parent: Parent node, or None when no parent has been assigned
    """

    def __init__(self, name: str = "", distance: float = 0.0):
        self.name = name
        self.distance = distance
        self.children: List['SimpleNewickNode'] = []
        self.parent: Optional['SimpleNewickNode'] = None

    def is_leaf(self) -> bool:
        """Check if this node is a leaf (i.e. has no children)."""
        return len(self.children) == 0

    def get_leaves(self) -> List['SimpleNewickNode']:
        """
        Get all leaf nodes under this node, in left-to-right order.

        A leaf returns a list containing only itself. The traversal uses an
        explicit stack instead of recursion so that very deep (ladder-like)
        trees cannot exceed Python's recursion limit.
        """
        leaves = []
        pending = [self]
        while pending:
            node = pending.pop()
            if node.is_leaf():
                leaves.append(node)
            else:
                # Reversed so the first child is popped (and visited) first
                pending.extend(reversed(node.children))
        return leaves
[docs]
def parse_simple_newick(newick_str: str) -> Optional[SimpleNewickNode]:
    """
    Build a SimpleNewickNode tree from a basic Newick string.

    This is a lightweight, single-pass parser for plain Newick with optional
    branch lengths, e.g. ``((A:0.1,B:0.2):0.3,C:0.4);``. Quoted labels and
    bracketed comments get no special treatment.

    Args:
        newick_str: Newick format tree string

    Returns:
        Root node of the parsed tree (the node opened by the first '('), or None
        if the string contains no '(' or an exception occurs while parsing
    """
    try:
        # Drop the trailing semicolon and surrounding whitespace
        text = newick_str.strip().rstrip(';').strip()
        text_length = len(text)
        terminators = '(),;'

        open_nodes = []   # internal nodes that have been opened; index 0 is the root
        active = None     # node currently receiving children
        pending = ""      # characters of a leaf description collected so far

        def attach_leaf(parent, raw):
            # Turn 'name:length' text into a leaf and hang it under parent
            leaf_name, leaf_distance = _parse_node_info(raw)
            leaf = SimpleNewickNode(name=leaf_name, distance=leaf_distance)
            leaf.parent = parent
            parent.children.append(leaf)

        pos = 0
        while pos < text_length:
            symbol = text[pos]
            if symbol == '(':
                # Open a new internal node below the active one
                fresh = SimpleNewickNode()
                if active is not None:
                    active.children.append(fresh)
                    fresh.parent = active
                open_nodes.append(fresh)
                active = fresh
                pending = ""
                pos += 1
            elif symbol == ',':
                # A comma ends the description of the previous sibling
                if pending:
                    attach_leaf(active, pending)
                    pending = ""
                pos += 1
            elif symbol == ')':
                # Close the active node: flush its last leaf first
                if pending:
                    attach_leaf(active, pending)
                    pending = ""

                # Everything up to the next structural character is this node's
                # own label and/or branch length
                label_end = pos + 1
                while label_end < text_length and text[label_end] not in terminators:
                    label_end += 1
                label = text[pos + 1:label_end]
                if label and active:
                    node_name, node_distance = _parse_node_info(label)
                    if node_name:
                        active.name = node_name
                    active.distance = node_distance

                # Step back up to the parent; the root is never popped
                if len(open_nodes) > 1:
                    open_nodes.pop()
                    active = open_nodes[-1]

                # Resume at the character that ended the label
                pos = label_end
            else:
                # Part of a leaf name / branch length
                pending += symbol
                pos += 1

        # Text left over after the loop is applied to the active node
        if pending and active:
            node_name, node_distance = _parse_node_info(pending)
            if node_name:
                active.name = node_name
            active.distance = node_distance

        return open_nodes[0] if open_nodes else None
    except Exception as e:
        logger.error(f"Error parsing Newick tree: {e}")
        return None
def _parse_node_info(token: str) -> Tuple[str, float]:
    """
    Split a Newick node description into its name and branch length.

    Accepts tokens such as 'name:0.123', 'name' or ':0.123'. The token is split
    at the first colon; a missing or non-numeric length yields 0.0.

    Args:
        token: Token string to parse

    Returns:
        Tuple of (name, distance)
    """
    label, separator, length_text = token.strip().partition(':')
    if not separator:
        # No colon at all: the whole token is the name
        return label, 0.0
    try:
        branch_length = float(length_text.strip())
    except ValueError:
        branch_length = 0.0
    return label.strip(), branch_length
[docs]
def find_nearest_neighbor_in_tree(
    tree_file: Path,
    query_id: str
) -> Optional[str]:
    """
    Return the name of the leaf closest to a query sequence in a Newick tree.

    The tree is parsed from ``tree_file``, the node named ``query_id`` is looked
    up, and every leaf whose name differs from ``query_id`` is compared by its
    summed branch-length distance to that node. The first leaf with the smallest
    distance wins.

    Args:
        tree_file: Path to the Newick tree file (.treefile)
        query_id: Query sequence identifier to find neighbors for

    Returns:
        Name of the nearest neighbor leaf node, or None if the file is missing or
        empty, the tree cannot be parsed, the query is not in the tree, there is
        no other leaf, or an error occurs
    """
    try:
        if not tree_file.exists():
            logger.warning(f"Tree file not found: {tree_file}")
            return None

        newick_text = tree_file.read_text().strip()
        if not newick_text:
            logger.warning(f"Empty tree file: {tree_file}")
            return None

        tree_root = parse_simple_newick(newick_text)
        if not tree_root:
            logger.warning(f"Failed to parse tree file: {tree_file}")
            return None

        anchor = _find_node_by_name(tree_root, query_id)
        if not anchor:
            logger.warning(f"Query {query_id} not found in tree {tree_file}")
            return None

        # Every leaf except the query itself is a candidate neighbor
        candidates = [
            leaf for leaf in tree_root.get_leaves() if leaf.name != query_id
        ]

        closest_name = None
        closest_distance = float('inf')
        for candidate in candidates:
            separation = _calculate_tree_distance(anchor, candidate)
            # Strict '<' keeps the earliest leaf when distances tie
            if separation < closest_distance:
                closest_distance = separation
                closest_name = candidate.name
        return closest_name
    except Exception as e:
        logger.error(f"Error finding nearest neighbor in tree {tree_file}: {e}")
        return None
def _find_node_by_name(node: SimpleNewickNode, name: str) -> Optional[SimpleNewickNode]:
    """
    Find a node with a given name in the tree.

    The subtree rooted at ``node`` is searched depth-first in pre-order (a node
    before its children, children left to right) and the first match is
    returned. Internal nodes are matched as well as leaves. An explicit stack is
    used instead of recursion so that very deep (ladder-like) trees cannot
    exceed Python's recursion limit.

    Args:
        node: Node to start search from
        name: Name to search for

    Returns:
        Node with matching name, or None if not found
    """
    pending = [node]
    while pending:
        candidate = pending.pop()
        if candidate.name == name:
            return candidate
        # Reversed so the first child is popped (and examined) first
        pending.extend(reversed(candidate.children))
    return None
def _calculate_tree_distance(node1: SimpleNewickNode, node2: SimpleNewickNode) -> float:
    """
    Calculate the phylogenetic (branch-length) distance between two nodes.

    The distance is the sum of the ``distance`` values of every node on the way
    from node1 up to, but not including, the most recent common ancestor (MRCA),
    plus the same sum for node2.

    Args:
        node1: First node
        node2: Second node

    Returns:
        Total distance between the nodes
    """
    # Lineages are ordered from the node itself up to the root
    lineage1 = _get_path_to_root(node1)
    lineage2 = _get_path_to_root(node2)

    # Walk both lineages from the root downwards and count the shared ancestors
    shared = 0
    for ancestor1, ancestor2 in zip(reversed(lineage1), reversed(lineage2)):
        if ancestor1 is not ancestor2:
            break
        shared += 1

    # Depth of the MRCA counted from the root; stays 0 when nothing is shared
    mrca_index = max(shared - 1, 0)

    total = 0.0
    # Branches on node1's side, strictly below the MRCA
    for step in lineage1[:len(lineage1) - mrca_index - 1]:
        total += step.distance
    # Branches on node2's side, strictly below the MRCA
    for step in lineage2[:len(lineage2) - mrca_index - 1]:
        total += step.distance
    return total
def _get_path_to_root(node: SimpleNewickNode) -> List[SimpleNewickNode]:
    """
    Collect a node and all of its ancestors by following ``parent`` links.

    Args:
        node: Starting node

    Returns:
        List of nodes from the starting node (first) to the root (last); empty
        if the starting node is falsy
    """
    def climb(start):
        # Yield each node on the way up until the parent chain ends
        while start:
            yield start
            start = start.parent

    return list(climb(node))