Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -82,7 +82,7 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
|
|
| 82 |
return vec
|
| 83 |
|
| 84 |
###############################################################################
|
| 85 |
-
# 3.
|
| 86 |
###############################################################################
|
| 87 |
|
| 88 |
def calculate_shap_values(model, x_tensor):
|
|
@@ -105,7 +105,7 @@ def calculate_shap_values(model, x_tensor):
|
|
| 105 |
|
| 106 |
|
| 107 |
###############################################################################
|
| 108 |
-
# 4. PER-BASE
|
| 109 |
###############################################################################
|
| 110 |
|
| 111 |
def compute_positionwise_scores(sequence, shap_values, k=4):
|
|
@@ -125,7 +125,7 @@ def compute_positionwise_scores(sequence, shap_values, k=4):
|
|
| 125 |
return shap_means
|
| 126 |
|
| 127 |
###############################################################################
|
| 128 |
-
# 5. FIND EXTREME
|
| 129 |
###############################################################################
|
| 130 |
|
| 131 |
def find_extreme_subregion(shap_means, window_size=500, mode="max"):
|
|
@@ -166,7 +166,7 @@ def get_zero_centered_cmap():
|
|
| 166 |
colors = [(0.0, 'blue'), (0.5, 'white'), (1.0, 'red')]
|
| 167 |
return mcolors.LinearSegmentedColormap.from_list("blue_white_red", colors)
|
| 168 |
|
| 169 |
-
def plot_linear_heatmap(shap_means, title="Per-base
|
| 170 |
if start is not None and end is not None:
|
| 171 |
local_shap = shap_means[start:end]
|
| 172 |
subtitle = f" (positions {start}-{end})"
|
|
@@ -184,7 +184,7 @@ def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, e
|
|
| 184 |
cax = ax.imshow(heatmap_data, aspect='auto', cmap=cmap, vmin=-extent, vmax=extent)
|
| 185 |
cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
|
| 186 |
cbar.ax.tick_params(labelsize=8)
|
| 187 |
-
cbar.set_label('
|
| 188 |
ax.set_yticks([])
|
| 189 |
ax.set_xlabel('Position in Sequence', fontsize=10)
|
| 190 |
ax.set_title(f"{title}{subtitle}", pad=10)
|
|
@@ -200,17 +200,17 @@ def create_importance_bar_plot(shap_values, kmers, top_k=10):
|
|
| 200 |
colors = ['#99ccff' if v < 0 else '#ff9999' for v in values]
|
| 201 |
plt.barh(range(len(values)), values, color=colors)
|
| 202 |
plt.yticks(range(len(values)), features)
|
| 203 |
-
plt.xlabel('
|
| 204 |
plt.title(f'Top {top_k} Most Influential k-mers')
|
| 205 |
plt.gca().invert_yaxis()
|
| 206 |
plt.tight_layout()
|
| 207 |
return fig
|
| 208 |
|
| 209 |
-
def plot_shap_histogram(shap_array, title="
|
| 210 |
fig, ax = plt.subplots(figsize=(6, 4))
|
| 211 |
ax.hist(shap_array, bins=num_bins, color='gray', edgecolor='black')
|
| 212 |
ax.axvline(0, color='red', linestyle='--', label='0.0')
|
| 213 |
-
ax.set_xlabel("
|
| 214 |
ax.set_ylabel("Count")
|
| 215 |
ax.set_title(title)
|
| 216 |
ax.legend()
|
|
@@ -227,23 +227,23 @@ def compute_gc_content(sequence):
|
|
| 227 |
# 7. MAIN ANALYSIS STEP (Gradio Step 1)
|
| 228 |
###############################################################################
|
| 229 |
def create_kmer_shap_csv(kmers, shap_values):
|
| 230 |
-
"""Create a CSV file with k-mer
|
| 231 |
-
# Create DataFrame with k-mers and
|
| 232 |
kmer_df = pd.DataFrame({
|
| 233 |
'kmer': kmers,
|
| 234 |
-
'
|
| 235 |
-
'
|
| 236 |
})
|
| 237 |
|
| 238 |
-
# Sort by absolute
|
| 239 |
-
kmer_df = kmer_df.sort_values('
|
| 240 |
|
| 241 |
-
# Drop the
|
| 242 |
-
kmer_df = kmer_df[['kmer', '
|
| 243 |
|
| 244 |
# Save to temporary file
|
| 245 |
temp_dir = tempfile.gettempdir()
|
| 246 |
-
temp_path = os.path.join(temp_dir, f"
|
| 247 |
kmer_df.to_csv(temp_path, index=False)
|
| 248 |
|
| 249 |
return temp_path # Return only the file path, not a tuple
|
|
@@ -296,19 +296,19 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
|
|
| 296 |
f"(Human Probability: {prob_human:.3f}, Non-human Probability: {prob_nonhuman:.3f})\n\n"
|
| 297 |
f"---\n"
|
| 298 |
f"**Most Human-Pushing {window_size}-bp Subregion**:\n"
|
| 299 |
-
f"Start: {max_start}, End: {max_end}, Avg
|
| 300 |
f"**Most Non-Human–Pushing {window_size}-bp Subregion**:\n"
|
| 301 |
-
f"Start: {min_start}, End: {min_end}, Avg
|
| 302 |
)
|
| 303 |
|
| 304 |
kmers = [''.join(p) for p in product("ACGT", repeat=4)]
|
| 305 |
bar_fig = create_importance_bar_plot(shap_values, kmers, top_kmers)
|
| 306 |
bar_img = fig_to_image(bar_fig)
|
| 307 |
|
| 308 |
-
heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide
|
| 309 |
heatmap_img = fig_to_image(heatmap_fig)
|
| 310 |
|
| 311 |
-
# Create CSV with k-mer
|
| 312 |
kmer_shap_csv = create_kmer_shap_csv(kmers, shap_values)
|
| 313 |
|
| 314 |
# State dictionary for subregion analysis
|
|
@@ -347,14 +347,14 @@ def analyze_subregion(state, header, region_start, region_end):
|
|
| 347 |
f"Analyzing subregion of {header} from {region_start} to {region_end}\n"
|
| 348 |
f"Region length: {len(region_seq)} bases\n"
|
| 349 |
f"GC content: {gc_percent:.2f}%\n"
|
| 350 |
-
f"Average
|
| 351 |
-
f"Fraction with
|
| 352 |
-
f"Fraction with
|
| 353 |
f"Subregion interpretation: {region_classification}\n"
|
| 354 |
)
|
| 355 |
-
heatmap_fig = plot_linear_heatmap(shap_means, title="Subregion
|
| 356 |
heatmap_img = fig_to_image(heatmap_fig)
|
| 357 |
-
hist_fig = plot_shap_histogram(region_shap, title="
|
| 358 |
hist_img = fig_to_image(hist_fig)
|
| 359 |
|
| 360 |
# For demonstration, returning None for the file download as well
|
|
@@ -370,10 +370,10 @@ def get_zero_centered_cmap():
|
|
| 370 |
return mcolors.LinearSegmentedColormap.from_list("blue_white_red", colors)
|
| 371 |
|
| 372 |
def compute_shap_difference(shap1_norm, shap2_norm):
|
| 373 |
-
"""Compute the
|
| 374 |
return shap2_norm - shap1_norm
|
| 375 |
|
| 376 |
-
def plot_comparative_heatmap(shap_diff, title="
|
| 377 |
"""
|
| 378 |
Plot heatmap using relative positions (0-100%)
|
| 379 |
"""
|
|
@@ -393,7 +393,7 @@ def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
|
|
| 393 |
|
| 394 |
cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
|
| 395 |
cbar.ax.tick_params(labelsize=8)
|
| 396 |
-
cbar.set_label('
|
| 397 |
|
| 398 |
ax.set_yticks([])
|
| 399 |
ax.set_xlabel('Relative Position in Sequence', fontsize=10)
|
|
@@ -402,14 +402,14 @@ def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
|
|
| 402 |
|
| 403 |
return fig
|
| 404 |
|
| 405 |
-
def plot_shap_histogram(shap_array, title="
|
| 406 |
"""
|
| 407 |
-
Plot histogram of
|
| 408 |
"""
|
| 409 |
fig, ax = plt.subplots(figsize=(6, 4))
|
| 410 |
ax.hist(shap_array, bins=num_bins, color='gray', edgecolor='black', alpha=0.7)
|
| 411 |
ax.axvline(0, color='red', linestyle='--', label='0.0')
|
| 412 |
-
ax.set_xlabel("
|
| 413 |
ax.set_ylabel("Count")
|
| 414 |
ax.set_title(title)
|
| 415 |
ax.legend()
|
|
@@ -483,7 +483,7 @@ def sliding_window_smooth(values, window_size=50):
|
|
| 483 |
|
| 484 |
def normalize_shap_lengths(shap1, shap2):
|
| 485 |
"""
|
| 486 |
-
Normalize and smooth
|
| 487 |
"""
|
| 488 |
# Calculate adaptive parameters
|
| 489 |
num_points, smooth_window, _ = calculate_adaptive_parameters(len(shap1), len(shap2))
|
|
@@ -517,7 +517,7 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
|
|
| 517 |
if isinstance(res2[0], str) and "Error" in res2[0]:
|
| 518 |
return (f"Error in sequence 2: {res2[0]}", None, None, None)
|
| 519 |
|
| 520 |
-
# Extract
|
| 521 |
shap1 = res1[3]["shap_means"]
|
| 522 |
shap2 = res2[3]["shap_means"]
|
| 523 |
|
|
@@ -567,7 +567,7 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
|
|
| 567 |
f"Smoothing Window: {smooth_window} points\n"
|
| 568 |
f"Adaptive Threshold: {adaptive_threshold:.3f}\n\n"
|
| 569 |
"Statistics:\n"
|
| 570 |
-
f"Average
|
| 571 |
f"Standard deviation: {std_diff:.4f}\n"
|
| 572 |
f"Max difference: {max_diff:.4f} (Seq2 more human-like)\n"
|
| 573 |
f"Min difference: {min_diff:.4f} (Seq1 more human-like)\n"
|
|
@@ -582,7 +582,7 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
|
|
| 582 |
# Generate visualizations
|
| 583 |
heatmap_fig = plot_comparative_heatmap(
|
| 584 |
shap_diff,
|
| 585 |
-
title=f"
|
| 586 |
)
|
| 587 |
heatmap_img = fig_to_image(heatmap_fig)
|
| 588 |
|
|
@@ -590,7 +590,7 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
|
|
| 590 |
num_bins = max(20, min(50, int(np.sqrt(len(shap_diff)))))
|
| 591 |
hist_fig = plot_shap_histogram(
|
| 592 |
shap_diff,
|
| 593 |
-
title="Distribution of
|
| 594 |
num_bins=num_bins
|
| 595 |
)
|
| 596 |
hist_img = fig_to_image(hist_fig)
|
|
@@ -680,7 +680,7 @@ def parse_location(location_str: str) -> Tuple[Optional[int], Optional[int]]:
|
|
| 680 |
return None, None
|
| 681 |
|
| 682 |
def compute_gene_statistics(gene_shap: np.ndarray) -> Dict[str, float]:
|
| 683 |
-
"""Compute statistical measures for gene
|
| 684 |
return {
|
| 685 |
'avg_shap': float(np.mean(gene_shap)),
|
| 686 |
'median_shap': float(np.median(gene_shap)),
|
|
@@ -693,7 +693,7 @@ def compute_gene_statistics(gene_shap: np.ndarray) -> Dict[str, float]:
|
|
| 693 |
def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_length: int) -> Image.Image:
|
| 694 |
"""
|
| 695 |
Create a simple genome diagram using PIL, forcing a minimum color intensity
|
| 696 |
-
so that small
|
| 697 |
"""
|
| 698 |
from PIL import Image, ImageDraw, ImageFont
|
| 699 |
|
|
@@ -730,7 +730,7 @@ def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_leng
|
|
| 730 |
title_font = ImageFont.load_default()
|
| 731 |
|
| 732 |
# Draw title
|
| 733 |
-
draw.text((margin, margin // 2), "Genome
|
| 734 |
|
| 735 |
# Draw genome line
|
| 736 |
line_y = height // 2
|
|
@@ -755,7 +755,7 @@ def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_leng
|
|
| 755 |
], fill='black', width=1)
|
| 756 |
draw.text((int(x_coord - 20), int(line_y + 10)), f"{i:,}", fill='black', font=font)
|
| 757 |
|
| 758 |
-
# Sort genes by absolute
|
| 759 |
sorted_genes = sorted(gene_results, key=lambda x: abs(x['avg_shap']))
|
| 760 |
|
| 761 |
# Draw genes
|
|
@@ -764,10 +764,10 @@ def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_leng
|
|
| 764 |
start_x = margin + int(gene['start'] * scale)
|
| 765 |
end_x = margin + int(gene['end'] * scale)
|
| 766 |
|
| 767 |
-
# Calculate color based on
|
| 768 |
avg_shap = gene['avg_shap']
|
| 769 |
|
| 770 |
-
# Convert
|
| 771 |
# Then clamp to a minimum intensity so it never ends up plain white
|
| 772 |
intensity = int(abs(avg_shap) * 500)
|
| 773 |
intensity = max(50, min(255, intensity)) # clamp between 50 and 255
|
|
@@ -813,7 +813,7 @@ def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_leng
|
|
| 813 |
# Draw legend
|
| 814 |
legend_x = margin
|
| 815 |
legend_y = height - margin
|
| 816 |
-
draw.text((int(legend_x), int(legend_y - 60)), "
|
| 817 |
|
| 818 |
# Draw legend boxes
|
| 819 |
box_width = 20
|
|
@@ -858,13 +858,13 @@ def analyze_gene_features(sequence_file: str,
|
|
| 858 |
features_file: str,
|
| 859 |
fasta_text: str = "",
|
| 860 |
features_text: str = "") -> Tuple[str, Optional[str], Optional[Image.Image]]:
|
| 861 |
-
"""Analyze
|
| 862 |
# First analyze whole sequence
|
| 863 |
sequence_results = analyze_sequence(sequence_file, top_kmers=10, fasta_text=fasta_text)
|
| 864 |
if isinstance(sequence_results[0], str) and "Error" in sequence_results[0]:
|
| 865 |
return f"Error in sequence analysis: {sequence_results[0]}", None, None
|
| 866 |
|
| 867 |
-
# Get
|
| 868 |
shap_means = sequence_results[3]["shap_means"]
|
| 869 |
|
| 870 |
# Parse gene features
|
|
@@ -889,7 +889,7 @@ def analyze_gene_features(sequence_file: str,
|
|
| 889 |
if start is None or end is None:
|
| 890 |
continue
|
| 891 |
|
| 892 |
-
# Get
|
| 893 |
gene_shap = shap_means[start:end]
|
| 894 |
stats = compute_gene_statistics(gene_shap)
|
| 895 |
|
|
@@ -916,7 +916,7 @@ def analyze_gene_features(sequence_file: str,
|
|
| 916 |
if not gene_results:
|
| 917 |
return "No valid genes could be processed", None, None
|
| 918 |
|
| 919 |
-
# Sort genes by absolute
|
| 920 |
sorted_genes = sorted(gene_results, key=lambda x: abs(x['avg_shap']), reverse=True)
|
| 921 |
|
| 922 |
# Create results text
|
|
@@ -932,11 +932,11 @@ def analyze_gene_features(sequence_file: str,
|
|
| 932 |
f"Location: {gene['location']}\n"
|
| 933 |
f"Classification: {gene['classification']} "
|
| 934 |
f"(confidence: {gene['confidence']:.4f})\n"
|
| 935 |
-
f"Average
|
| 936 |
)
|
| 937 |
|
| 938 |
# Create CSV content
|
| 939 |
-
csv_content = "gene_name,location,
|
| 940 |
csv_content += "pos_fraction,classification,confidence,locus_tag\n"
|
| 941 |
|
| 942 |
for gene in gene_results:
|
|
@@ -1020,11 +1020,11 @@ with gr.Blocks(css=css) as iface:
|
|
| 1020 |
gr.Markdown("""
|
| 1021 |
# Virus Host Classifier
|
| 1022 |
**Step 1**: Predict overall viral sequence origin (human vs non-human) and identify extreme regions.
|
| 1023 |
-
**Step 2**: Explore subregions to see local
|
| 1024 |
**Step 3**: Analyze gene features and their contributions.
|
| 1025 |
**Step 4**: Compare sequences and analyze differences.
|
| 1026 |
|
| 1027 |
-
**Color Scale**: Negative
|
| 1028 |
""")
|
| 1029 |
|
| 1030 |
with gr.Tab("1) Full-Sequence Analysis"):
|
|
@@ -1043,11 +1043,11 @@ with gr.Blocks(css=css) as iface:
|
|
| 1043 |
|
| 1044 |
with gr.Column(scale=2):
|
| 1045 |
results_box = gr.Textbox(label="Classification Results", lines=12, interactive=False)
|
| 1046 |
-
kmer_img = gr.Image(label="Top k-mer
|
| 1047 |
-
genome_img = gr.Image(label="Genome-wide
|
| 1048 |
|
| 1049 |
# File components with the correct type parameter
|
| 1050 |
-
download_kmer_shap = gr.File(label="Download k-mer
|
| 1051 |
download_results = gr.File(label="Download Results", visible=True, elem_classes="download-button")
|
| 1052 |
|
| 1053 |
seq_state = gr.State()
|
|
@@ -1071,7 +1071,7 @@ with gr.Blocks(css=css) as iface:
|
|
| 1071 |
with gr.Tab("2) Subregion Exploration"):
|
| 1072 |
gr.Markdown("""
|
| 1073 |
**Subregion Analysis**
|
| 1074 |
-
Select start/end positions to view local
|
| 1075 |
The heatmap uses the same Blue-White-Red scale.
|
| 1076 |
""")
|
| 1077 |
with gr.Row():
|
|
@@ -1080,8 +1080,8 @@ with gr.Blocks(css=css) as iface:
|
|
| 1080 |
region_btn = gr.Button("Analyze Subregion")
|
| 1081 |
subregion_info = gr.Textbox(label="Subregion Analysis", lines=7, interactive=False)
|
| 1082 |
with gr.Row():
|
| 1083 |
-
subregion_img = gr.Image(label="Subregion
|
| 1084 |
-
subregion_hist_img = gr.Image(label="
|
| 1085 |
download_subregion = gr.File(label="Download Subregion Analysis", visible=False, elem_classes="download-button")
|
| 1086 |
|
| 1087 |
region_btn.click(
|
|
@@ -1093,12 +1093,11 @@ with gr.Blocks(css=css) as iface:
|
|
| 1093 |
with gr.Tab("3) Gene Features Analysis"):
|
| 1094 |
gr.Markdown("""
|
| 1095 |
**Analyze Gene Features**
|
| 1096 |
-
Upload a FASTA file and corresponding gene features file to analyze
|
| 1097 |
Gene features should be in the format:
|
| 1098 |
|
| 1099 |
>gene_name [gene=X] [locus_tag=Y] [location=start..end] or [location=complement(start..end)]
|
| 1100 |
SEQUENCE
|
| 1101 |
-
|
| 1102 |
The genome viewer will show genes color-coded by their contribution:
|
| 1103 |
- Red: Genes pushing toward human origin
|
| 1104 |
- Blue: Genes pushing toward non-human origin
|
|
@@ -1126,7 +1125,7 @@ with gr.Blocks(css=css) as iface:
|
|
| 1126 |
with gr.Tab("4) Comparative Analysis"):
|
| 1127 |
gr.Markdown("""
|
| 1128 |
**Compare Two Sequences**
|
| 1129 |
-
Upload or paste two FASTA sequences to compare their
|
| 1130 |
The sequences will be normalized to the same length for comparison.
|
| 1131 |
|
| 1132 |
**Color Scale**:
|
|
@@ -1144,8 +1143,8 @@ with gr.Blocks(css=css) as iface:
|
|
| 1144 |
compare_btn = gr.Button("Compare Sequences", variant="primary")
|
| 1145 |
comparison_text = gr.Textbox(label="Comparison Results", lines=12, interactive=False)
|
| 1146 |
with gr.Row():
|
| 1147 |
-
diff_heatmap = gr.Image(label="
|
| 1148 |
-
diff_hist = gr.Image(label="Distribution of
|
| 1149 |
download_comparison = gr.File(label="Download Comparison Results", visible=False, elem_classes="download-button")
|
| 1150 |
|
| 1151 |
compare_btn.click(
|
|
@@ -1157,8 +1156,8 @@ with gr.Blocks(css=css) as iface:
|
|
| 1157 |
gr.Markdown("""
|
| 1158 |
### Interface Features
|
| 1159 |
- **Overall Classification** (human vs non-human) using k-mer frequencies
|
| 1160 |
-
- **
|
| 1161 |
-
- **White-Centered
|
| 1162 |
- Negative (blue), 0 (white), Positive (red)
|
| 1163 |
- Symmetrical color range around 0
|
| 1164 |
- **Identify Subregions** with strongest push for human or non-human
|
|
@@ -1172,7 +1171,7 @@ with gr.Blocks(css=css) as iface:
|
|
| 1172 |
- Statistical summary of differences
|
| 1173 |
- **Data Export**:
|
| 1174 |
- Download results as CSV files
|
| 1175 |
-
- Download k-mer
|
| 1176 |
- Save analysis outputs for further processing
|
| 1177 |
""")
|
| 1178 |
|
|
|
|
| 82 |
return vec
|
| 83 |
|
| 84 |
###############################################################################
|
| 85 |
+
# 3. FEATURE IMPORTANCE (ABLATION) CALCULATION
|
| 86 |
###############################################################################
|
| 87 |
|
| 88 |
def calculate_shap_values(model, x_tensor):
|
|
|
|
| 105 |
|
| 106 |
|
| 107 |
###############################################################################
|
| 108 |
+
# 4. PER-BASE FEATURE IMPORTANCE AGGREGATION
|
| 109 |
###############################################################################
|
| 110 |
|
| 111 |
def compute_positionwise_scores(sequence, shap_values, k=4):
|
|
|
|
| 125 |
return shap_means
|
| 126 |
|
| 127 |
###############################################################################
|
| 128 |
+
# 5. FIND EXTREME IMPORTANCE REGIONS
|
| 129 |
###############################################################################
|
| 130 |
|
| 131 |
def find_extreme_subregion(shap_means, window_size=500, mode="max"):
|
|
|
|
| 166 |
colors = [(0.0, 'blue'), (0.5, 'white'), (1.0, 'red')]
|
| 167 |
return mcolors.LinearSegmentedColormap.from_list("blue_white_red", colors)
|
| 168 |
|
| 169 |
+
def plot_linear_heatmap(shap_means, title="Per-base Feature Importance Heatmap", start=None, end=None):
|
| 170 |
if start is not None and end is not None:
|
| 171 |
local_shap = shap_means[start:end]
|
| 172 |
subtitle = f" (positions {start}-{end})"
|
|
|
|
| 184 |
cax = ax.imshow(heatmap_data, aspect='auto', cmap=cmap, vmin=-extent, vmax=extent)
|
| 185 |
cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
|
| 186 |
cbar.ax.tick_params(labelsize=8)
|
| 187 |
+
cbar.set_label('Feature Importance', fontsize=9, labelpad=5)
|
| 188 |
ax.set_yticks([])
|
| 189 |
ax.set_xlabel('Position in Sequence', fontsize=10)
|
| 190 |
ax.set_title(f"{title}{subtitle}", pad=10)
|
|
|
|
| 200 |
colors = ['#99ccff' if v < 0 else '#ff9999' for v in values]
|
| 201 |
plt.barh(range(len(values)), values, color=colors)
|
| 202 |
plt.yticks(range(len(values)), features)
|
| 203 |
+
plt.xlabel('Feature Importance (impact on model output)')
|
| 204 |
plt.title(f'Top {top_k} Most Influential k-mers')
|
| 205 |
plt.gca().invert_yaxis()
|
| 206 |
plt.tight_layout()
|
| 207 |
return fig
|
| 208 |
|
| 209 |
+
def plot_shap_histogram(shap_array, title="Feature Importance Distribution in Region", num_bins=30):
|
| 210 |
fig, ax = plt.subplots(figsize=(6, 4))
|
| 211 |
ax.hist(shap_array, bins=num_bins, color='gray', edgecolor='black')
|
| 212 |
ax.axvline(0, color='red', linestyle='--', label='0.0')
|
| 213 |
+
ax.set_xlabel("Feature Importance Value")
|
| 214 |
ax.set_ylabel("Count")
|
| 215 |
ax.set_title(title)
|
| 216 |
ax.legend()
|
|
|
|
| 227 |
# 7. MAIN ANALYSIS STEP (Gradio Step 1)
|
| 228 |
###############################################################################
|
| 229 |
def create_kmer_shap_csv(kmers, shap_values):
|
| 230 |
+
"""Create a CSV file with k-mer importance values and return the filepath"""
|
| 231 |
+
# Create DataFrame with k-mers and importance values
|
| 232 |
kmer_df = pd.DataFrame({
|
| 233 |
'kmer': kmers,
|
| 234 |
+
'importance_value': shap_values,
|
| 235 |
+
'abs_importance': np.abs(shap_values)
|
| 236 |
})
|
| 237 |
|
| 238 |
+
# Sort by absolute importance value (most influential first)
|
| 239 |
+
kmer_df = kmer_df.sort_values('abs_importance', ascending=False)
|
| 240 |
|
| 241 |
+
# Drop the abs_importance column used for sorting
|
| 242 |
+
kmer_df = kmer_df[['kmer', 'importance_value']]
|
| 243 |
|
| 244 |
# Save to temporary file
|
| 245 |
temp_dir = tempfile.gettempdir()
|
| 246 |
+
temp_path = os.path.join(temp_dir, f"kmer_importance_values_{os.urandom(4).hex()}.csv")
|
| 247 |
kmer_df.to_csv(temp_path, index=False)
|
| 248 |
|
| 249 |
return temp_path # Return only the file path, not a tuple
|
|
|
|
| 296 |
f"(Human Probability: {prob_human:.3f}, Non-human Probability: {prob_nonhuman:.3f})\n\n"
|
| 297 |
f"---\n"
|
| 298 |
f"**Most Human-Pushing {window_size}-bp Subregion**:\n"
|
| 299 |
+
f"Start: {max_start}, End: {max_end}, Avg Importance: {max_avg:.4f}\n\n"
|
| 300 |
f"**Most Non-Human–Pushing {window_size}-bp Subregion**:\n"
|
| 301 |
+
f"Start: {min_start}, End: {min_end}, Avg Importance: {min_avg:.4f}"
|
| 302 |
)
|
| 303 |
|
| 304 |
kmers = [''.join(p) for p in product("ACGT", repeat=4)]
|
| 305 |
bar_fig = create_importance_bar_plot(shap_values, kmers, top_kmers)
|
| 306 |
bar_img = fig_to_image(bar_fig)
|
| 307 |
|
| 308 |
+
heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide Feature Importance")
|
| 309 |
heatmap_img = fig_to_image(heatmap_fig)
|
| 310 |
|
| 311 |
+
# Create CSV with k-mer importance values and return the file path
|
| 312 |
kmer_shap_csv = create_kmer_shap_csv(kmers, shap_values)
|
| 313 |
|
| 314 |
# State dictionary for subregion analysis
|
|
|
|
| 347 |
f"Analyzing subregion of {header} from {region_start} to {region_end}\n"
|
| 348 |
f"Region length: {len(region_seq)} bases\n"
|
| 349 |
f"GC content: {gc_percent:.2f}%\n"
|
| 350 |
+
f"Average importance in region: {avg_shap:.4f}\n"
|
| 351 |
+
f"Fraction with importance > 0 (toward human): {positive_fraction:.2f}\n"
|
| 352 |
+
f"Fraction with importance < 0 (toward non-human): {negative_fraction:.2f}\n"
|
| 353 |
f"Subregion interpretation: {region_classification}\n"
|
| 354 |
)
|
| 355 |
+
heatmap_fig = plot_linear_heatmap(shap_means, title="Subregion Feature Importance", start=region_start, end=region_end)
|
| 356 |
heatmap_img = fig_to_image(heatmap_fig)
|
| 357 |
+
hist_fig = plot_shap_histogram(region_shap, title="Feature Importance Distribution in Subregion")
|
| 358 |
hist_img = fig_to_image(hist_fig)
|
| 359 |
|
| 360 |
# For demonstration, returning None for the file download as well
|
|
|
|
| 370 |
return mcolors.LinearSegmentedColormap.from_list("blue_white_red", colors)
|
| 371 |
|
| 372 |
def compute_shap_difference(shap1_norm, shap2_norm):
|
| 373 |
+
"""Compute the feature importance difference between normalized sequences"""
|
| 374 |
return shap2_norm - shap1_norm
|
| 375 |
|
| 376 |
+
def plot_comparative_heatmap(shap_diff, title="Feature Importance Difference Heatmap"):
|
| 377 |
"""
|
| 378 |
Plot heatmap using relative positions (0-100%)
|
| 379 |
"""
|
|
|
|
| 393 |
|
| 394 |
cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25, aspect=40, shrink=0.8)
|
| 395 |
cbar.ax.tick_params(labelsize=8)
|
| 396 |
+
cbar.set_label('Feature Importance Difference (Seq2 - Seq1)', fontsize=9, labelpad=5)
|
| 397 |
|
| 398 |
ax.set_yticks([])
|
| 399 |
ax.set_xlabel('Relative Position in Sequence', fontsize=10)
|
|
|
|
| 402 |
|
| 403 |
return fig
|
| 404 |
|
| 405 |
+
def plot_shap_histogram(shap_array, title="Feature Importance Distribution", num_bins=30):
|
| 406 |
"""
|
| 407 |
+
Plot histogram of feature importance values with configurable number of bins
|
| 408 |
"""
|
| 409 |
fig, ax = plt.subplots(figsize=(6, 4))
|
| 410 |
ax.hist(shap_array, bins=num_bins, color='gray', edgecolor='black', alpha=0.7)
|
| 411 |
ax.axvline(0, color='red', linestyle='--', label='0.0')
|
| 412 |
+
ax.set_xlabel("Feature Importance Value")
|
| 413 |
ax.set_ylabel("Count")
|
| 414 |
ax.set_title(title)
|
| 415 |
ax.legend()
|
|
|
|
| 483 |
|
| 484 |
def normalize_shap_lengths(shap1, shap2):
|
| 485 |
"""
|
| 486 |
+
Normalize and smooth feature importance values with dynamic adaptation
|
| 487 |
"""
|
| 488 |
# Calculate adaptive parameters
|
| 489 |
num_points, smooth_window, _ = calculate_adaptive_parameters(len(shap1), len(shap2))
|
|
|
|
| 517 |
if isinstance(res2[0], str) and "Error" in res2[0]:
|
| 518 |
return (f"Error in sequence 2: {res2[0]}", None, None, None)
|
| 519 |
|
| 520 |
+
# Extract feature importance values and sequence info
|
| 521 |
shap1 = res1[3]["shap_means"]
|
| 522 |
shap2 = res2[3]["shap_means"]
|
| 523 |
|
|
|
|
| 567 |
f"Smoothing Window: {smooth_window} points\n"
|
| 568 |
f"Adaptive Threshold: {adaptive_threshold:.3f}\n\n"
|
| 569 |
"Statistics:\n"
|
| 570 |
+
f"Average feature importance difference: {avg_diff:.4f}\n"
|
| 571 |
f"Standard deviation: {std_diff:.4f}\n"
|
| 572 |
f"Max difference: {max_diff:.4f} (Seq2 more human-like)\n"
|
| 573 |
f"Min difference: {min_diff:.4f} (Seq1 more human-like)\n"
|
|
|
|
| 582 |
# Generate visualizations
|
| 583 |
heatmap_fig = plot_comparative_heatmap(
|
| 584 |
shap_diff,
|
| 585 |
+
title=f"Feature Importance Difference Heatmap (window: {smooth_window})"
|
| 586 |
)
|
| 587 |
heatmap_img = fig_to_image(heatmap_fig)
|
| 588 |
|
|
|
|
| 590 |
num_bins = max(20, min(50, int(np.sqrt(len(shap_diff)))))
|
| 591 |
hist_fig = plot_shap_histogram(
|
| 592 |
shap_diff,
|
| 593 |
+
title="Distribution of Feature Importance Differences",
|
| 594 |
num_bins=num_bins
|
| 595 |
)
|
| 596 |
hist_img = fig_to_image(hist_fig)
|
|
|
|
| 680 |
return None, None
|
| 681 |
|
| 682 |
def compute_gene_statistics(gene_shap: np.ndarray) -> Dict[str, float]:
|
| 683 |
+
"""Compute statistical measures for gene feature importance values"""
|
| 684 |
return {
|
| 685 |
'avg_shap': float(np.mean(gene_shap)),
|
| 686 |
'median_shap': float(np.median(gene_shap)),
|
|
|
|
| 693 |
def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_length: int) -> Image.Image:
|
| 694 |
"""
|
| 695 |
Create a simple genome diagram using PIL, forcing a minimum color intensity
|
| 696 |
+
so that small feature importance values don't appear white.
|
| 697 |
"""
|
| 698 |
from PIL import Image, ImageDraw, ImageFont
|
| 699 |
|
|
|
|
| 730 |
title_font = ImageFont.load_default()
|
| 731 |
|
| 732 |
# Draw title
|
| 733 |
+
draw.text((margin, margin // 2), "Genome Feature Importance Analysis", fill='black', font=title_font or font)
|
| 734 |
|
| 735 |
# Draw genome line
|
| 736 |
line_y = height // 2
|
|
|
|
| 755 |
], fill='black', width=1)
|
| 756 |
draw.text((int(x_coord - 20), int(line_y + 10)), f"{i:,}", fill='black', font=font)
|
| 757 |
|
| 758 |
+
# Sort genes by absolute feature importance value for drawing
|
| 759 |
sorted_genes = sorted(gene_results, key=lambda x: abs(x['avg_shap']))
|
| 760 |
|
| 761 |
# Draw genes
|
|
|
|
| 764 |
start_x = margin + int(gene['start'] * scale)
|
| 765 |
end_x = margin + int(gene['end'] * scale)
|
| 766 |
|
| 767 |
+
# Calculate color based on feature importance value
|
| 768 |
avg_shap = gene['avg_shap']
|
| 769 |
|
| 770 |
+
# Convert importance -> color intensity (0 to 255)
|
| 771 |
# Then clamp to a minimum intensity so it never ends up plain white
|
| 772 |
intensity = int(abs(avg_shap) * 500)
|
| 773 |
intensity = max(50, min(255, intensity)) # clamp between 50 and 255
|
|
|
|
| 813 |
# Draw legend
|
| 814 |
legend_x = margin
|
| 815 |
legend_y = height - margin
|
| 816 |
+
draw.text((int(legend_x), int(legend_y - 60)), "Feature Importance Values:", fill='black', font=font)
|
| 817 |
|
| 818 |
# Draw legend boxes
|
| 819 |
box_width = 20
|
|
|
|
| 858 |
features_file: str,
|
| 859 |
fasta_text: str = "",
|
| 860 |
features_text: str = "") -> Tuple[str, Optional[str], Optional[Image.Image]]:
|
| 861 |
+
"""Analyze feature importance values for each gene feature"""
|
| 862 |
# First analyze whole sequence
|
| 863 |
sequence_results = analyze_sequence(sequence_file, top_kmers=10, fasta_text=fasta_text)
|
| 864 |
if isinstance(sequence_results[0], str) and "Error" in sequence_results[0]:
|
| 865 |
return f"Error in sequence analysis: {sequence_results[0]}", None, None
|
| 866 |
|
| 867 |
+
# Get feature importance values
|
| 868 |
shap_means = sequence_results[3]["shap_means"]
|
| 869 |
|
| 870 |
# Parse gene features
|
|
|
|
| 889 |
if start is None or end is None:
|
| 890 |
continue
|
| 891 |
|
| 892 |
+
# Get feature importance values for this region
|
| 893 |
gene_shap = shap_means[start:end]
|
| 894 |
stats = compute_gene_statistics(gene_shap)
|
| 895 |
|
|
|
|
| 916 |
if not gene_results:
|
| 917 |
return "No valid genes could be processed", None, None
|
| 918 |
|
| 919 |
+
# Sort genes by absolute feature importance value
|
| 920 |
sorted_genes = sorted(gene_results, key=lambda x: abs(x['avg_shap']), reverse=True)
|
| 921 |
|
| 922 |
# Create results text
|
|
|
|
| 932 |
f"Location: {gene['location']}\n"
|
| 933 |
f"Classification: {gene['classification']} "
|
| 934 |
f"(confidence: {gene['confidence']:.4f})\n"
|
| 935 |
+
f"Average Feature Importance: {gene['avg_shap']:.4f}\n\n"
|
| 936 |
)
|
| 937 |
|
| 938 |
# Create CSV content
|
| 939 |
+
csv_content = "gene_name,location,avg_importance,median_importance,std_importance,max_importance,min_importance,"
|
| 940 |
csv_content += "pos_fraction,classification,confidence,locus_tag\n"
|
| 941 |
|
| 942 |
for gene in gene_results:
|
|
|
|
| 1020 |
gr.Markdown("""
|
| 1021 |
# Virus Host Classifier
|
| 1022 |
**Step 1**: Predict overall viral sequence origin (human vs non-human) and identify extreme regions.
|
| 1023 |
+
**Step 2**: Explore subregions to see local feature influence, distribution, GC content, etc.
|
| 1024 |
**Step 3**: Analyze gene features and their contributions.
|
| 1025 |
**Step 4**: Compare sequences and analyze differences.
|
| 1026 |
|
| 1027 |
+
**Color Scale**: Negative values = Blue, Zero = White, Positive values = Red.
|
| 1028 |
""")
|
| 1029 |
|
| 1030 |
with gr.Tab("1) Full-Sequence Analysis"):
|
|
|
|
| 1043 |
|
| 1044 |
with gr.Column(scale=2):
|
| 1045 |
results_box = gr.Textbox(label="Classification Results", lines=12, interactive=False)
|
| 1046 |
+
kmer_img = gr.Image(label="Top k-mer Importance")
|
| 1047 |
+
genome_img = gr.Image(label="Genome-wide Feature Importance Heatmap (Blue=neg, White=0, Red=pos)")
|
| 1048 |
|
| 1049 |
# File components with the correct type parameter
|
| 1050 |
+
download_kmer_shap = gr.File(label="Download k-mer Importance Values (CSV)", visible=True, type="filepath")
|
| 1051 |
download_results = gr.File(label="Download Results", visible=True, elem_classes="download-button")
|
| 1052 |
|
| 1053 |
seq_state = gr.State()
|
|
|
|
| 1071 |
with gr.Tab("2) Subregion Exploration"):
|
| 1072 |
gr.Markdown("""
|
| 1073 |
**Subregion Analysis**
|
| 1074 |
+
Select start/end positions to view local feature importance, distribution, GC content, etc.
|
| 1075 |
The heatmap uses the same Blue-White-Red scale.
|
| 1076 |
""")
|
| 1077 |
with gr.Row():
|
|
|
|
| 1080 |
region_btn = gr.Button("Analyze Subregion")
|
| 1081 |
subregion_info = gr.Textbox(label="Subregion Analysis", lines=7, interactive=False)
|
| 1082 |
with gr.Row():
|
| 1083 |
+
subregion_img = gr.Image(label="Subregion Feature Importance Heatmap (B-W-R)")
|
| 1084 |
+
subregion_hist_img = gr.Image(label="Feature Importance Distribution (Histogram)")
|
| 1085 |
download_subregion = gr.File(label="Download Subregion Analysis", visible=False, elem_classes="download-button")
|
| 1086 |
|
| 1087 |
region_btn.click(
|
|
|
|
| 1093 |
with gr.Tab("3) Gene Features Analysis"):
|
| 1094 |
gr.Markdown("""
|
| 1095 |
**Analyze Gene Features**
|
| 1096 |
+
Upload a FASTA file and corresponding gene features file to analyze feature importance values per gene.
|
| 1097 |
Gene features should be in the format:
|
| 1098 |
|
| 1099 |
>gene_name [gene=X] [locus_tag=Y] [location=start..end] or [location=complement(start..end)]
|
| 1100 |
SEQUENCE
|
|
|
|
| 1101 |
The genome viewer will show genes color-coded by their contribution:
|
| 1102 |
- Red: Genes pushing toward human origin
|
| 1103 |
- Blue: Genes pushing toward non-human origin
|
|
|
|
| 1125 |
with gr.Tab("4) Comparative Analysis"):
|
| 1126 |
gr.Markdown("""
|
| 1127 |
**Compare Two Sequences**
|
| 1128 |
+
Upload or paste two FASTA sequences to compare their feature importance patterns.
|
| 1129 |
The sequences will be normalized to the same length for comparison.
|
| 1130 |
|
| 1131 |
**Color Scale**:
|
|
|
|
| 1143 |
compare_btn = gr.Button("Compare Sequences", variant="primary")
|
| 1144 |
comparison_text = gr.Textbox(label="Comparison Results", lines=12, interactive=False)
|
| 1145 |
with gr.Row():
|
| 1146 |
+
diff_heatmap = gr.Image(label="Feature Importance Difference Heatmap")
|
| 1147 |
+
diff_hist = gr.Image(label="Distribution of Feature Importance Differences")
|
| 1148 |
download_comparison = gr.File(label="Download Comparison Results", visible=False, elem_classes="download-button")
|
| 1149 |
|
| 1150 |
compare_btn.click(
|
|
|
|
| 1156 |
gr.Markdown("""
|
| 1157 |
### Interface Features
|
| 1158 |
- **Overall Classification** (human vs non-human) using k-mer frequencies
|
| 1159 |
+
- **Feature Importance Analysis** shows which k-mers push classification toward or away from human
|
| 1160 |
+
- **White-Centered Gradient**:
|
| 1161 |
- Negative (blue), 0 (white), Positive (red)
|
| 1162 |
- Symmetrical color range around 0
|
| 1163 |
- **Identify Subregions** with strongest push for human or non-human
|
|
|
|
| 1171 |
- Statistical summary of differences
|
| 1172 |
- **Data Export**:
|
| 1173 |
- Download results as CSV files
|
| 1174 |
+
- Download k-mer importance values
|
| 1175 |
- Save analysis outputs for further processing
|
| 1176 |
""")
|
| 1177 |
|