Spaces:
Runtime error
Runtime error
Witold Wydmański
commited on
Commit
·
8ee7dbf
1
Parent(s):
137a7d5
feat: Add get_esm2_embeddings function
Browse files
app.py
CHANGED
|
@@ -38,6 +38,22 @@ def fold_prot_locally(sequence):
|
|
| 38 |
pdb = convert_outputs_to_pdb(output)
|
| 39 |
return pdb
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
def get_esmfold_embeddings(sequence):
|
| 42 |
logger.info("Getting embeddings for: " + sequence)
|
| 43 |
tokenized_input = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)['input_ids'].cuda()
|
|
@@ -165,11 +181,12 @@ with gr.Blocks() as demo:
|
|
| 165 |
with gr.Row(visible=False):
|
| 166 |
with gr.Column():
|
| 167 |
gr.Markdown("## Embeddings")
|
| 168 |
-
embs = gr.JSON(label="Embeddings"
|
| 169 |
|
| 170 |
name.change(fn=suggest, inputs=name, outputs=inp)
|
| 171 |
btn.click(fold_prot_locally, inputs=[inp], outputs=[out], api_name="pdb")
|
| 172 |
btn.click(get_esmfold_embeddings, inputs=[inp], outputs=[embs], api_name="embeddings")
|
|
|
|
| 173 |
out.change(fn=molecule, inputs=[out], outputs=[out_mol], api_name="3d_fold")
|
| 174 |
|
| 175 |
demo.launch()
|
|
|
|
| 38 |
pdb = convert_outputs_to_pdb(output)
|
| 39 |
return pdb
|
| 40 |
|
| 41 |
+
def get_esm2_embeddings(sequence):
|
| 42 |
+
logger.info("Getting embeddings for: " + sequence)
|
| 43 |
+
tokenized_input = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)['input_ids'].cuda()
|
| 44 |
+
|
| 45 |
+
with torch.no_grad():
|
| 46 |
+
aa = tokenized_input
|
| 47 |
+
L = aa.shape[1]
|
| 48 |
+
device = tokenized_input.device
|
| 49 |
+
attention_mask = torch.ones_like(aa, device=device)
|
| 50 |
+
|
| 51 |
+
# === ESM ===
|
| 52 |
+
esmaa = model.af2_idx_to_esm_idx(aa, attention_mask)
|
| 53 |
+
esm_s = model.compute_language_model_representations(esmaa)
|
| 54 |
+
|
| 55 |
+
return {"res": esm_s.cpu().tolist()}
|
| 56 |
+
|
| 57 |
def get_esmfold_embeddings(sequence):
|
| 58 |
logger.info("Getting embeddings for: " + sequence)
|
| 59 |
tokenized_input = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)['input_ids'].cuda()
|
|
|
|
| 181 |
with gr.Row(visible=False):
|
| 182 |
with gr.Column():
|
| 183 |
gr.Markdown("## Embeddings")
|
| 184 |
+
embs = gr.JSON(label="Embeddings")
|
| 185 |
|
| 186 |
name.change(fn=suggest, inputs=name, outputs=inp)
|
| 187 |
btn.click(fold_prot_locally, inputs=[inp], outputs=[out], api_name="pdb")
|
| 188 |
btn.click(get_esmfold_embeddings, inputs=[inp], outputs=[embs], api_name="embeddings")
|
| 189 |
+
btn.click(get_esm2_embeddings, inputs=[inp], outputs=[embs], api_name="esm2_embeddings")
|
| 190 |
out.change(fn=molecule, inputs=[out], outputs=[out_mol], api_name="3d_fold")
|
| 191 |
|
| 192 |
demo.launch()
|
client.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#%%
|
| 2 |
+
from gradio_client import Client
|
| 3 |
+
|
| 4 |
+
#%%
|
| 5 |
+
# client = Client("https://wwydmanski-esmfold.hf.space/")
|
| 6 |
+
client = Client("http://localhost:7860")
|
| 7 |
+
|
| 8 |
+
# %%
|
| 9 |
+
result = client.predict(
|
| 10 |
+
"MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN", # str in 'sequence' Textbox component
|
| 11 |
+
api_name="/esm2_embeddings")
|
| 12 |
+
|
| 13 |
+
# %%
|
| 14 |
+
result
|