Spaces:
Runtime error
Runtime error
ugmSorcero
commited on
Commit
Β·
39503cb
1
Parent(s):
8d3aacc
Adds linter and fixes linting
Browse files- app.py +1 -3
- core/pipelines.py +14 -4
- core/search_index.py +9 -5
- interface/components.py +23 -13
- interface/pages.py +25 -19
- linter.sh +1 -0
- requirements.txt +2 -1
app.py
CHANGED
|
@@ -5,9 +5,7 @@ st.set_page_config(
|
|
| 5 |
page_icon="π",
|
| 6 |
layout="wide",
|
| 7 |
initial_sidebar_state="expanded",
|
| 8 |
-
menu_items={
|
| 9 |
-
'About': "https://github.com/ugm2/neural-search-demo"
|
| 10 |
-
}
|
| 11 |
)
|
| 12 |
|
| 13 |
from streamlit_option_menu import option_menu
|
|
|
|
| 5 |
page_icon="π",
|
| 6 |
layout="wide",
|
| 7 |
initial_sidebar_state="expanded",
|
| 8 |
+
menu_items={"About": "https://github.com/ugm2/neural-search-demo"},
|
|
|
|
|
|
|
| 9 |
)
|
| 10 |
|
| 11 |
from streamlit_option_menu import option_menu
|
core/pipelines.py
CHANGED
|
@@ -9,9 +9,10 @@ from haystack.nodes.retriever import DensePassageRetriever, TfidfRetriever
|
|
| 9 |
from haystack.nodes.preprocessor import PreProcessor
|
| 10 |
import streamlit as st
|
| 11 |
|
|
|
|
| 12 |
@st.cache(allow_output_mutation=True)
|
| 13 |
def keyword_search(
|
| 14 |
-
index=
|
| 15 |
):
|
| 16 |
document_store = InMemoryDocumentStore(index=index)
|
| 17 |
keyword_retriever = TfidfRetriever(document_store=(document_store))
|
|
@@ -31,16 +32,25 @@ def keyword_search(
|
|
| 31 |
# INDEXING PIPELINE
|
| 32 |
index_pipeline = Pipeline()
|
| 33 |
index_pipeline.add_node(processor, name="Preprocessor", inputs=["File"])
|
| 34 |
-
index_pipeline.add_node(
|
|
|
|
|
|
|
| 35 |
index_pipeline.add_node(
|
| 36 |
document_store, name="DocumentStore", inputs=["TfidfRetriever"]
|
| 37 |
)
|
| 38 |
|
| 39 |
return search_pipeline, index_pipeline
|
| 40 |
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
def dense_passage_retrieval(
|
| 43 |
-
index=
|
| 44 |
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
|
| 45 |
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
|
| 46 |
):
|
|
|
|
| 9 |
from haystack.nodes.preprocessor import PreProcessor
|
| 10 |
import streamlit as st
|
| 11 |
|
| 12 |
+
|
| 13 |
@st.cache(allow_output_mutation=True)
|
| 14 |
def keyword_search(
|
| 15 |
+
index="documents",
|
| 16 |
):
|
| 17 |
document_store = InMemoryDocumentStore(index=index)
|
| 18 |
keyword_retriever = TfidfRetriever(document_store=(document_store))
|
|
|
|
| 32 |
# INDEXING PIPELINE
|
| 33 |
index_pipeline = Pipeline()
|
| 34 |
index_pipeline.add_node(processor, name="Preprocessor", inputs=["File"])
|
| 35 |
+
index_pipeline.add_node(
|
| 36 |
+
keyword_retriever, name="TfidfRetriever", inputs=["Preprocessor"]
|
| 37 |
+
)
|
| 38 |
index_pipeline.add_node(
|
| 39 |
document_store, name="DocumentStore", inputs=["TfidfRetriever"]
|
| 40 |
)
|
| 41 |
|
| 42 |
return search_pipeline, index_pipeline
|
| 43 |
|
| 44 |
+
|
| 45 |
+
@st.cache(
|
| 46 |
+
hash_funcs={
|
| 47 |
+
tokenizers.Tokenizer: lambda _: None,
|
| 48 |
+
tokenizers.AddedToken: lambda _: None,
|
| 49 |
+
},
|
| 50 |
+
allow_output_mutation=True,
|
| 51 |
+
)
|
| 52 |
def dense_passage_retrieval(
|
| 53 |
+
index="documents",
|
| 54 |
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
|
| 55 |
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
|
| 56 |
):
|
core/search_index.py
CHANGED
|
@@ -6,9 +6,9 @@ def format_docs(documents):
|
|
| 6 |
"""Given a list of documents, format the documents and return the documents and doc ids."""
|
| 7 |
db_docs: list = []
|
| 8 |
for doc in documents:
|
| 9 |
-
doc_id = doc[
|
| 10 |
db_doc = {
|
| 11 |
-
"content": doc[
|
| 12 |
"content_type": "text",
|
| 13 |
"id": str(uuid.uuid4()),
|
| 14 |
"meta": {"id": doc_id},
|
|
@@ -16,11 +16,13 @@ def format_docs(documents):
|
|
| 16 |
db_docs.append(Document(**db_doc))
|
| 17 |
return db_docs, [doc.meta["id"] for doc in db_docs]
|
| 18 |
|
|
|
|
| 19 |
def index(documents, pipeline):
|
| 20 |
documents, doc_ids = format_docs(documents)
|
| 21 |
pipeline.run(documents=documents)
|
| 22 |
return doc_ids
|
| 23 |
|
|
|
|
| 24 |
def search(queries, pipeline):
|
| 25 |
results = []
|
| 26 |
matches_queries = pipeline.run_batch(queries=queries)
|
|
@@ -35,10 +37,12 @@ def search(queries, pipeline):
|
|
| 35 |
"text": res.content,
|
| 36 |
"score": res.score,
|
| 37 |
"id": res.meta["id"],
|
| 38 |
-
"fragment_id": res.id
|
| 39 |
}
|
| 40 |
)
|
| 41 |
if not score_is_empty:
|
| 42 |
-
query_results = sorted(
|
|
|
|
|
|
|
| 43 |
results.append(query_results)
|
| 44 |
-
return results
|
|
|
|
| 6 |
"""Given a list of documents, format the documents and return the documents and doc ids."""
|
| 7 |
db_docs: list = []
|
| 8 |
for doc in documents:
|
| 9 |
+
doc_id = doc["id"] if doc["id"] is not None else str(uuid.uuid4())
|
| 10 |
db_doc = {
|
| 11 |
+
"content": doc["text"],
|
| 12 |
"content_type": "text",
|
| 13 |
"id": str(uuid.uuid4()),
|
| 14 |
"meta": {"id": doc_id},
|
|
|
|
| 16 |
db_docs.append(Document(**db_doc))
|
| 17 |
return db_docs, [doc.meta["id"] for doc in db_docs]
|
| 18 |
|
| 19 |
+
|
| 20 |
def index(documents, pipeline):
|
| 21 |
documents, doc_ids = format_docs(documents)
|
| 22 |
pipeline.run(documents=documents)
|
| 23 |
return doc_ids
|
| 24 |
|
| 25 |
+
|
| 26 |
def search(queries, pipeline):
|
| 27 |
results = []
|
| 28 |
matches_queries = pipeline.run_batch(queries=queries)
|
|
|
|
| 37 |
"text": res.content,
|
| 38 |
"score": res.score,
|
| 39 |
"id": res.meta["id"],
|
| 40 |
+
"fragment_id": res.id,
|
| 41 |
}
|
| 42 |
)
|
| 43 |
if not score_is_empty:
|
| 44 |
+
query_results = sorted(
|
| 45 |
+
query_results, key=lambda x: x["score"], reverse=True
|
| 46 |
+
)
|
| 47 |
results.append(query_results)
|
| 48 |
+
return results
|
interface/components.py
CHANGED
|
@@ -3,36 +3,47 @@ import core.pipelines as pipelines_functions
|
|
| 3 |
from inspect import getmembers, isfunction
|
| 4 |
from networkx.drawing.nx_agraph import to_agraph
|
| 5 |
|
|
|
|
| 6 |
def component_select_pipeline(container):
|
| 7 |
-
pipeline_names, pipeline_funcs = list(
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
with container:
|
| 10 |
selected_pipeline = st.selectbox(
|
| 11 |
-
|
| 12 |
pipeline_names,
|
| 13 |
-
index=pipeline_names.index(
|
|
|
|
|
|
|
| 14 |
)
|
| 15 |
-
|
| 16 |
-
st.session_state[
|
| 17 |
-
|
|
|
|
|
|
|
| 18 |
|
| 19 |
def component_show_pipeline(container, pipeline):
|
| 20 |
"""Draw the pipeline"""
|
| 21 |
-
with st.expander(
|
| 22 |
graphviz = to_agraph(pipeline.graph)
|
| 23 |
graphviz.layout("dot")
|
| 24 |
st.graphviz_chart(graphviz.string())
|
| 25 |
-
|
|
|
|
| 26 |
def component_show_search_result(container, results):
|
| 27 |
with container:
|
| 28 |
for idx, document in enumerate(results):
|
| 29 |
st.markdown(f"### Match {idx+1}")
|
| 30 |
st.markdown(f"**Text**: {document['text']}")
|
| 31 |
st.markdown(f"**Document**: {document['id']}")
|
| 32 |
-
if document[
|
| 33 |
st.markdown(f"**Score**: {document['score']:.3f}")
|
| 34 |
st.markdown("---")
|
| 35 |
|
|
|
|
| 36 |
def component_text_input(container):
|
| 37 |
"""Draw the Text Input widget"""
|
| 38 |
with container:
|
|
@@ -48,7 +59,6 @@ def component_text_input(container):
|
|
| 48 |
else:
|
| 49 |
break
|
| 50 |
corpus = [
|
| 51 |
-
{"text": doc["text"], "id": doc_id}
|
| 52 |
-
for doc_id, doc in enumerate(texts)
|
| 53 |
]
|
| 54 |
-
return corpus
|
|
|
|
| 3 |
from inspect import getmembers, isfunction
|
| 4 |
from networkx.drawing.nx_agraph import to_agraph
|
| 5 |
|
| 6 |
+
|
| 7 |
def component_select_pipeline(container):
|
| 8 |
+
pipeline_names, pipeline_funcs = list(
|
| 9 |
+
zip(*getmembers(pipelines_functions, isfunction))
|
| 10 |
+
)
|
| 11 |
+
pipeline_names = [
|
| 12 |
+
" ".join([n.capitalize() for n in name.split("_")]) for name in pipeline_names
|
| 13 |
+
]
|
| 14 |
with container:
|
| 15 |
selected_pipeline = st.selectbox(
|
| 16 |
+
"Select pipeline",
|
| 17 |
pipeline_names,
|
| 18 |
+
index=pipeline_names.index("Keyword Search")
|
| 19 |
+
if "Keyword Search" in pipeline_names
|
| 20 |
+
else 0,
|
| 21 |
)
|
| 22 |
+
(
|
| 23 |
+
st.session_state["search_pipeline"],
|
| 24 |
+
st.session_state["index_pipeline"],
|
| 25 |
+
) = pipeline_funcs[pipeline_names.index(selected_pipeline)]()
|
| 26 |
+
|
| 27 |
|
| 28 |
def component_show_pipeline(container, pipeline):
|
| 29 |
"""Draw the pipeline"""
|
| 30 |
+
with st.expander("Show pipeline"):
|
| 31 |
graphviz = to_agraph(pipeline.graph)
|
| 32 |
graphviz.layout("dot")
|
| 33 |
st.graphviz_chart(graphviz.string())
|
| 34 |
+
|
| 35 |
+
|
| 36 |
def component_show_search_result(container, results):
|
| 37 |
with container:
|
| 38 |
for idx, document in enumerate(results):
|
| 39 |
st.markdown(f"### Match {idx+1}")
|
| 40 |
st.markdown(f"**Text**: {document['text']}")
|
| 41 |
st.markdown(f"**Document**: {document['id']}")
|
| 42 |
+
if document["score"] is not None:
|
| 43 |
st.markdown(f"**Score**: {document['score']:.3f}")
|
| 44 |
st.markdown("---")
|
| 45 |
|
| 46 |
+
|
| 47 |
def component_text_input(container):
|
| 48 |
"""Draw the Text Input widget"""
|
| 49 |
with container:
|
|
|
|
| 59 |
else:
|
| 60 |
break
|
| 61 |
corpus = [
|
| 62 |
+
{"text": doc["text"], "id": doc_id} for doc_id, doc in enumerate(texts)
|
|
|
|
| 63 |
]
|
| 64 |
+
return corpus
|
interface/pages.py
CHANGED
|
@@ -1,7 +1,12 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
from streamlit_option_menu import option_menu
|
| 3 |
from core.search_index import index, search
|
| 4 |
-
from interface.components import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
def page_landing_page(container):
|
| 7 |
with container:
|
|
@@ -22,33 +27,34 @@ def page_landing_page(container):
|
|
| 22 |
"\n - Include file/url indexing"
|
| 23 |
"\n - [Optional] Include text to audio to read responses"
|
| 24 |
)
|
| 25 |
-
|
|
|
|
| 26 |
def page_search(container):
|
| 27 |
with container:
|
| 28 |
st.title("Query me!")
|
| 29 |
-
|
| 30 |
## SEARCH ##
|
| 31 |
query = st.text_input("Query")
|
| 32 |
-
|
| 33 |
-
component_show_pipeline(container, st.session_state[
|
| 34 |
-
|
| 35 |
if st.button("Search"):
|
| 36 |
-
st.session_state[
|
| 37 |
queries=[query],
|
| 38 |
-
pipeline=st.session_state[
|
| 39 |
)
|
| 40 |
-
if
|
| 41 |
component_show_search_result(
|
| 42 |
-
container=container,
|
| 43 |
-
results=st.session_state['search_results'][0]
|
| 44 |
)
|
| 45 |
-
|
|
|
|
| 46 |
def page_index(container):
|
| 47 |
with container:
|
| 48 |
st.title("Index time!")
|
| 49 |
-
|
| 50 |
-
component_show_pipeline(container, st.session_state[
|
| 51 |
-
|
| 52 |
input_funcs = {
|
| 53 |
"Raw Text": (component_text_input, "card-text"),
|
| 54 |
}
|
|
@@ -60,15 +66,15 @@ def page_index(container):
|
|
| 60 |
default_index=0,
|
| 61 |
orientation="horizontal",
|
| 62 |
)
|
| 63 |
-
|
| 64 |
corpus = input_funcs[selected_input][0](container)
|
| 65 |
-
|
| 66 |
if len(corpus) > 0:
|
| 67 |
index_results = None
|
| 68 |
if st.button("Index"):
|
| 69 |
index_results = index(
|
| 70 |
corpus,
|
| 71 |
-
st.session_state[
|
| 72 |
)
|
| 73 |
if index_results:
|
| 74 |
-
st.write(index_results)
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
from streamlit_option_menu import option_menu
|
| 3 |
from core.search_index import index, search
|
| 4 |
+
from interface.components import (
|
| 5 |
+
component_show_pipeline,
|
| 6 |
+
component_show_search_result,
|
| 7 |
+
component_text_input,
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
|
| 11 |
def page_landing_page(container):
|
| 12 |
with container:
|
|
|
|
| 27 |
"\n - Include file/url indexing"
|
| 28 |
"\n - [Optional] Include text to audio to read responses"
|
| 29 |
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
def page_search(container):
|
| 33 |
with container:
|
| 34 |
st.title("Query me!")
|
| 35 |
+
|
| 36 |
## SEARCH ##
|
| 37 |
query = st.text_input("Query")
|
| 38 |
+
|
| 39 |
+
component_show_pipeline(container, st.session_state["search_pipeline"])
|
| 40 |
+
|
| 41 |
if st.button("Search"):
|
| 42 |
+
st.session_state["search_results"] = search(
|
| 43 |
queries=[query],
|
| 44 |
+
pipeline=st.session_state["search_pipeline"],
|
| 45 |
)
|
| 46 |
+
if "search_results" in st.session_state:
|
| 47 |
component_show_search_result(
|
| 48 |
+
container=container, results=st.session_state["search_results"][0]
|
|
|
|
| 49 |
)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
def page_index(container):
|
| 53 |
with container:
|
| 54 |
st.title("Index time!")
|
| 55 |
+
|
| 56 |
+
component_show_pipeline(container, st.session_state["index_pipeline"])
|
| 57 |
+
|
| 58 |
input_funcs = {
|
| 59 |
"Raw Text": (component_text_input, "card-text"),
|
| 60 |
}
|
|
|
|
| 66 |
default_index=0,
|
| 67 |
orientation="horizontal",
|
| 68 |
)
|
| 69 |
+
|
| 70 |
corpus = input_funcs[selected_input][0](container)
|
| 71 |
+
|
| 72 |
if len(corpus) > 0:
|
| 73 |
index_results = None
|
| 74 |
if st.button("Index"):
|
| 75 |
index_results = index(
|
| 76 |
corpus,
|
| 77 |
+
st.session_state["index_pipeline"],
|
| 78 |
)
|
| 79 |
if index_results:
|
| 80 |
+
st.write(index_results)
|
linter.sh
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
python -m black app.py interface core
|
requirements.txt
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
streamlit
|
| 2 |
streamlit_option_menu
|
| 3 |
farm-haystack
|
| 4 |
-
pygraphviz
|
|
|
|
|
|
| 1 |
streamlit
|
| 2 |
streamlit_option_menu
|
| 3 |
farm-haystack
|
| 4 |
+
pygraphviz
|
| 5 |
+
black
|