# app.py — Startup recommender + Unlike + AI name (optional tagline/description)
import os
import re
import numpy as np
import pandas as pd
from pathlib import Path
import gradio as gr
# NOTE(review): torch is imported before faiss, as in the original file; the import order of
# native libraries can matter, so verify before reordering.
import torch
import faiss
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# ---------- Paths / artifacts ----------
OUT_DIR = Path("./emb_index_e5")
FAISS_PATH = OUT_DIR / "faiss.index"
DATA_PATH = OUT_DIR / "data.parquet"

# Explicit raises instead of `assert`: asserts are stripped under `python -O`, which would
# defer the failure to an obscure error inside faiss/pandas instead of this clear message.
if not FAISS_PATH.exists():
    raise FileNotFoundError(f"Missing {FAISS_PATH}. Build & upload embeddings/index.")
if not DATA_PATH.exists():
    raise FileNotFoundError(f"Missing {DATA_PATH}. Build & upload data parquet.")

# ---------- Devices ----------
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
DEVICE_EMBED = "cuda" if torch.cuda.is_available() else "cpu"  # e5 on GPU if available
DEVICE_GEN = "cpu"  # FLAN on CPU (avoid OOM)
print(f"Embed device: {DEVICE_EMBED} | Gen device: {DEVICE_GEN}")

# ---------- Load artifacts ----------
index = faiss.read_index(str(FAISS_PATH))
df_local = pd.read_parquet(DATA_PATH)
for c in ["name", "tagline", "description"]:
    if c in df_local.columns:
        # fillna must come BEFORE astype(str): astype(str) turns NaN/None into the literal
        # strings "nan"/"None", after which fillna has nothing left to fill.
        df_local[c] = df_local[c].fillna("").astype(str)
    else:
        # The search code selects these three columns unconditionally; an empty column
        # avoids a KeyError at query time.
        df_local[c] = ""

# Search maps FAISS result ids to rows with df_local.iloc, so both must be aligned row-for-row.
if index.ntotal != len(df_local):
    print(
        f"WARNING: FAISS index holds {index.ntotal} vectors but {DATA_PATH.name} has "
        f"{len(df_local)} rows; results assume they are aligned row-for-row."
    )

# ---------- Load models ----------
EMBED_MODEL = "intfloat/e5-base-v2"
embed_model = SentenceTransformer(EMBED_MODEL, device=DEVICE_EMBED)

MODEL_BASE = "google/flan-t5-base"
MODEL_LARGE = "google/flan-t5-large"
USE_LARGE_FOR_DESCRIPTION = False  # keep False on Spaces unless you switch GEN to "cuda"

tok_base = AutoTokenizer.from_pretrained(MODEL_BASE)
# Half precision only makes sense on GPU; on CPU keep the default dtype.
base_kwargs = {"torch_dtype": torch.float16} if DEVICE_GEN == "cuda" else {}
mod_base = AutoModelForSeq2SeqLM.from_pretrained(MODEL_BASE, **base_kwargs).to(DEVICE_GEN)

if USE_LARGE_FOR_DESCRIPTION:
    tok_large = AutoTokenizer.from_pretrained(MODEL_LARGE)
    large_kwargs = {"torch_dtype": torch.float16} if DEVICE_GEN == "cuda" else {}
    mod_large = AutoModelForSeq2SeqLM.from_pretrained(MODEL_LARGE, **large_kwargs).to(DEVICE_GEN)
else:
    # Alias the base model so description code can always use mod_large/tok_large.
    tok_large, mod_large = tok_base, mod_base
# ---------- Helpers (embedding + generation) ----------
def _generate_text(model, tokenizer, prompt, max_new_tokens=30, temperature=0.9, top_p=0.95, num_return_sequences=1):
    """Sample text from a seq2seq model and return the decoded outputs.

    Args:
        model: Seq2seq model exposing ``generate`` (FLAN-T5 in this app).
        tokenizer: Matching tokenizer used to encode ``prompt`` and decode outputs.
        prompt: Input text.
        max_new_tokens, temperature, top_p: Sampling controls forwarded to ``generate``.
        num_return_sequences: Number of sampled outputs.

    Returns:
        A list of decoded, whitespace-stripped strings (one per returned sequence).
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE_GEN)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        num_return_sequences=num_return_sequences,
    )
    return [tokenizer.decode(o, skip_special_tokens=True).strip() for o in outputs]


def _embed_query(q: str) -> np.ndarray:
    """Embed one query with the e5 ``query:`` prefix; returns a single unit-norm float32 vector."""
    return embed_model.encode([f"query: {q}"], convert_to_numpy=True, normalize_embeddings=True).astype("float32")[0]


def _embed_passages(texts) -> np.ndarray:
    """Embed texts with the e5 ``passage:`` prefix; returns one unit-norm float32 row per text."""
    texts = [f"passage: {t}" for t in texts]
    return embed_model.encode(texts, convert_to_numpy=True, normalize_embeddings=True).astype("float32")


# ---------- Search with per-session unlikes ----------
def search_topk_filtered_session(query: str, k: int, unliked_ids: set):
    """Return the ``k`` dataset rows most similar to ``query``, skipping unliked rows.

    FAISS is over-fetched so that enough rows survive after ``unliked_ids`` are removed.

    Args:
        query: Free-text idea (embedded with the e5 ``query:`` prefix).
        k: Maximum number of rows to return.
        unliked_ids: ``df_local`` index labels to exclude (matched against ``row_idx``).

    Returns:
        DataFrame with columns ``rank, row_idx, score, name, tagline, description``;
        at most ``k`` rows, ranked from 1.
    """
    qv = _embed_query(query)
    fetch = min(index.ntotal, max(k * 20, 50, k + len(unliked_ids)))
    if fetch > 0:
        scores, inds = index.search(qv[None, :], fetch)
        # FAISS pads missing results with id -1 (e.g. approximate indexes that find fewer
        # than `fetch` neighbours); iloc[-1] would silently return the LAST row instead.
        hits = [(int(i), float(s)) for i, s in zip(inds[0], scores[0]) if i >= 0]
    else:
        hits = []  # empty index: nothing to search (and k=0 searches are not meaningful)
    rows = df_local.iloc[[i for i, _ in hits]]
    res = rows[["name", "tagline", "description"]].copy()
    res.insert(0, "row_idx", rows.index)
    res.insert(1, "score", [s for _, s in hits])
    res = res[~res["row_idx"].isin(unliked_ids)].head(k).reset_index(drop=True)
    res.insert(0, "rank", range(1, len(res) + 1))
    return res


# ---------- Synthetic generation (length-aware) ----------
_STOPWORDS = {
    "the", "a", "an", "for", "and", "or", "to", "of", "in", "on", "with", "by", "from",
    "my", "our", "your", "their", "at", "as", "about", "into", "over", "under", "this", "that",
    "idea", "startup", "company", "product", "service", "app", "platform", "factory", "labs", "tech",
}


def _words(s: str):
    """Lowercase ``s`` and return its runs of ASCII letters."""
    return [w for w in re.findall(r"[a-z]+", str(s).lower()) if w]


def _content_words(s: str):
    """Words of length >= 3 that are not in ``_STOPWORDS`` (order and duplicates kept)."""
    return [w for w in _words(s) if len(w) >= 3 and w not in _STOPWORDS]


def _normalize_name(s: str) -> str:
    """Lowercase and drop everything but [a-z0-9]; used as a de-duplication key."""
    return re.sub(r"[^a-z0-9]+", "", str(s).lower())


def _has_vowel(s: str) -> bool:
    """True if ``s`` contains at least one ASCII vowel (a, e, i, o, u)."""
    return bool(re.search(r"[aeiou]", str(s).lower()))


def _overlap_ratio(name_tokens, banned):
    """Jaccard similarity of the two token collections; 0.0 if either is empty."""
    if not name_tokens or not banned:
        return 0.0
    inter = len(set(name_tokens) & set(banned))
    union = len(set(name_tokens) | set(banned))
    return inter / max(union, 1)


NAME_CHAR_TARGET, NAME_CHAR_TOL = 12, 3
NAME_WORDS_MIN, NAME_WORDS_MAX = 1, 3


def _len_ok(text: str, target_chars: int, tol: int, min_words: int, max_words: int):
    """True if ``text`` is within ``target_chars ± tol`` characters and ``[min_words, max_words]`` words."""
    c = len(text)
    w = len(text.split())
    return (target_chars - tol) <= c <= (target_chars + tol) and (min_words <= w <= max_words)


def _theme_hints(query: str, k: int = 6):
    """Comma-join the first ``k`` unique content words of ``query`` for the naming prompt.

    NOTE(review): the fallback when the query has no content words is education-specific.
    """
    kws = _content_words(query)
    seen, hints = set(), []
    for t in kws:
        if t not in seen:
            hints.append(t)
            seen.add(t)
    return ", ".join(hints[:k]) if hints else "education, learning, students, AI"


# T5's SentencePiece vocabulary has no newline token, so FLAN-T5 cannot actually emit
# "one name per line"; split on commas/semicolons/pipes and inline "2." / "3)" markers too.
_NAME_SPLIT_RE = re.compile(r"[\n,;|]+|\s+\d{1,2}[.)]\s+")
# A leading bullet ("-", "•", "*") or list number ("1.", "2)") — but not digits that are
# part of the name itself (e.g. "99designs").
_LIST_MARKER_RE = re.compile(r"^\s*(?:[-•*]+|\d{1,2}[.)])\s*")
_TRAILING_PUNCT_RE = re.compile(r"[^\w\s-]+$")


def _split_name_candidates(raw: str):
    """Split one raw model completion into cleaned, non-empty candidate names."""
    names = []
    for piece in _NAME_SPLIT_RE.split(raw):
        nm = _LIST_MARKER_RE.sub("", piece, count=1).strip()
        nm = _TRAILING_PUNCT_RE.sub("", nm).strip()
        if nm:
            names.append(nm)
    return names


def generate_names(base_idea: str, n: int = 10, oversample: int = 80, max_retries: int = 3):
    """Generate up to ``n`` brandable startup names for ``base_idea`` with FLAN-T5 (base).

    Each attempt samples one completion, splits it into candidates, de-duplicates them
    (case/punctuation-insensitive) together with earlier attempts, then filters by vowel,
    length and overlap with the idea's own words; the filter loosens on each retry.

    Returns:
        The first ``n`` filtered names once an attempt yields at least ``n`` of them;
        otherwise the first ``n`` unfiltered unique candidates (possibly ``[]``).
    """
    banned = sorted(set(_content_words(base_idea)))
    avoid_str = ", ".join(banned[:12]) if banned else "previous words"
    hints = _theme_hints(base_idea)
    all_candidates = []

    def _prompt(osz):
        return (
            f"Create {osz} brandable startup names for this idea:\n"
            f"\"{base_idea}\"\n\n"
            f"Guidance:\n"
            f"- Evoke these themes (without literally using the words): {hints}\n"
            f"- 1 or 2 words; aim ~{NAME_CHAR_TARGET} characters (±{NAME_CHAR_TOL})\n"
            f"- Portmanteau/blends welcome (e.g., Coursera, Udacity, Grammarly)\n"
            f"- Do NOT use: {avoid_str}\n"
            f"- Avoid generic phrases (e.g., 'Plastic Bottles', 'Online Store')\n"
            # Comma-separated instead of "one per line": T5 cannot generate newlines.
            f"- Output the names separated by commas; no numbering, no quotes."
        )

    def ok(nm: str, overlap_cap: float, tol_boost: int):
        if not _has_vowel(nm):
            return False
        if not _len_ok(nm, NAME_CHAR_TARGET, NAME_CHAR_TOL + tol_boost, NAME_WORDS_MIN, NAME_WORDS_MAX):
            return False
        toks = _content_words(nm)
        if _overlap_ratio(toks, banned) > overlap_cap:
            return False
        # Hard-coded reject for one specific generic phrase.
        if " ".join(toks) in {"plastic bottles", "bottles plastic"}:
            return False
        return True

    # Progressive relaxation: attempt 0 is strictest, attempts >= 2 use the loosest settings.
    overlap_caps = [0.25, 0.35, 0.5]
    tol_boosts = [0, 1, 2]
    for attempt in range(max_retries):
        raw = _generate_text(mod_base, tok_base, _prompt(oversample), num_return_sequences=1,
                             max_new_tokens=240, temperature=1.0 + 0.05 * attempt, top_p=0.95)[0]
        all_candidates.extend(_split_name_candidates(raw))

        # Dedup across all attempts so far, keeping first occurrence.
        seen, uniq = set(), []
        for nm in all_candidates:
            key = _normalize_name(nm)
            if key and key not in seen:
                seen.add(key)
                uniq.append(nm)
        all_candidates = uniq

        level = min(attempt, 2)
        filtered = [nm for nm in all_candidates if ok(nm, overlap_caps[level], tol_boosts[level])]
        if len(filtered) >= n:
            return filtered[:n]
    return all_candidates[:n] if all_candidates else []


# Tagline/description length targets (from your EDA)
TAG_CHAR_TARGET, TAG_CHAR_TOL = 40, 6
TAG_WORD_TARGET, TAG_WORD_TOL = 6, 2
DESC_CHAR_MIN, DESC_CHAR_MAX = 170, 230
DESC_WORD_MIN, DESC_WORD_MAX = 27, 35


def _trim_to_words(text: str, max_words: int) -> str:
    """Keep at most ``max_words`` whitespace-separated words.

    When words are dropped, trailing ``,;:`` is removed and a period is added unless the
    kept text already ends a sentence (avoids "word..").
    """
    toks = text.split()
    if len(toks) <= max_words:
        return text.strip()
    trimmed = " ".join(toks[:max_words]).rstrip(",;:")
    return trimmed if trimmed.endswith((".", "!", "?")) else trimmed + "."


def _snap_sentence_boundary(text: str, min_chars: int, max_chars: int):
    """Fit ``text`` into ``[min_chars, max_chars]`` characters, preferring sentence ends.

    Text already in range is returned stripped. Otherwise it is cut at ``max_chars`` and
    ended at the last ``.!?`` if that leaves at least ``min_chars``; failing that, the
    partial trailing word is dropped and a single period is ensured.
    """
    text = text.strip()
    if min_chars <= len(text) <= max_chars:
        return text
    overflow = len(text) > max_chars
    candidate = text[:max_chars] if overflow else text
    m = re.search(r"[\.!\?](?!.*[\.!\?])", candidate)  # last sentence terminator
    if m and len(candidate[:m.end()].strip()) >= min_chars:
        return candidate[:m.end()].strip()
    # Cut landed inside a word: back off to the previous space rather than leave "produ.".
    if overflow and not text[max_chars].isspace() and not candidate[-1:].isspace():
        space = candidate.rfind(" ")
        if space > 0:
            candidate = candidate[:space]
    candidate = candidate.rstrip(",;: ").strip()
    # Check the STRIPPED text so "Done. " does not become "Done..".
    return candidate if candidate.endswith((".", "!", "?")) else candidate + "."


def _within_ranges(text: str, cmin: int, cmax: int, wmin: int, wmax: int) -> bool:
    """True if ``text`` has ``cmin..cmax`` characters and ``wmin..wmax`` words (inclusive)."""
    c = len(text)
    w = len(text.split())
    return (cmin <= c <= cmax) and (wmin <= w <= wmax)


def _fit_description(text: str) -> str:
    """Collapse whitespace, cap at ``DESC_WORD_MAX`` words, then snap to the description char range."""
    text = re.sub(r"\s+", " ", text).strip()
    if len(text.split()) > DESC_WORD_MAX:
        text = _trim_to_words(text, DESC_WORD_MAX)
    return _snap_sentence_boundary(text, DESC_CHAR_MIN, DESC_CHAR_MAX)


def generate_tagline_and_desc(name: str, query_context: str):
    """Generate a tagline and a product description for startup ``name``.

    Each text is generated once, post-processed toward the length targets above, and
    regenerated once if still out of range; the candidate closer to the target is kept.

    Returns:
        ``(tagline, description)`` strings.
    """
    tag_prompt = (
        f"Write a short, benefit-driven tagline for a startup called '{name}'. "
        f"Audience & domain: {query_context}. "
        f"Target ~{TAG_CHAR_TARGET} characters and ~{TAG_WORD_TARGET} words. Avoid clichés."
    )
    tagline = _generate_text(mod_base, tok_base, tag_prompt, max_new_tokens=28, temperature=0.9, top_p=0.95)[0]
    tagline = re.sub(r"\s+", " ", tagline).strip()
    tagline = _trim_to_words(tagline, TAG_WORD_TARGET + TAG_WORD_TOL)
    if len(tagline) > TAG_CHAR_TARGET + TAG_CHAR_TOL:
        tagline = tagline[:TAG_CHAR_TARGET + TAG_CHAR_TOL].rstrip(",;: -") + "…"
    if not _within_ranges(tagline, TAG_CHAR_TARGET - TAG_CHAR_TOL, TAG_CHAR_TARGET + TAG_CHAR_TOL,
                          TAG_WORD_TARGET - TAG_WORD_TOL, TAG_WORD_TARGET + TAG_WORD_TOL):
        tagline2 = _generate_text(mod_base, tok_base, tag_prompt, max_new_tokens=30, temperature=1.0, top_p=0.9)[0]
        tagline2 = _trim_to_words(re.sub(r"\s+", " ", tagline2).strip(), TAG_WORD_TARGET + TAG_WORD_TOL)
        if abs(len(tagline2) - TAG_CHAR_TARGET) < abs(len(tagline) - TAG_CHAR_TARGET):
            tagline = tagline2

    desc_prompt = (
        f"Write a concise product description for the startup '{name}'. "
        f"Context: {query_context}. "
        f"Explain who it's for, what it does, and the main benefit. "
        f"Target {DESC_CHAR_MIN}–{DESC_CHAR_MAX} characters and {DESC_WORD_MIN}–{DESC_WORD_MAX} words. "
        f"Avoid fluff; keep it clear."
    )
    model, tok = (mod_large, tok_large) if USE_LARGE_FOR_DESCRIPTION else (mod_base, tok_base)
    description = _fit_description(
        _generate_text(model, tok, desc_prompt, max_new_tokens=110, temperature=1.05, top_p=0.95)[0])
    if not _within_ranges(description, DESC_CHAR_MIN, DESC_CHAR_MAX, DESC_WORD_MIN, DESC_WORD_MAX):
        description2 = _fit_description(
            _generate_text(model, tok, desc_prompt, max_new_tokens=120, temperature=1.05, top_p=0.9)[0])
        target_mid = (DESC_CHAR_MIN + DESC_CHAR_MAX) / 2
        if abs(len(description2) - target_mid) < abs(len(description) - target_mid):
            description = description2
    return tagline, description


def pick_best_synthetic_name(query: str, n_candidates: int = 10, include_copy=False):
    """Generate candidate names for ``query`` and return the best one as a one-row DataFrame.

    Score = cosine(query, name) - 0.35 * word overlap with the query - a length penalty
    (+0.3 if under 4 normalized chars, +0.2 if over 16).

    Args:
        query: The user's idea.
        n_candidates: Number of candidates to score (values below 1 are treated as 1).
        include_copy: Also generate a tagline and description for the winner.

    Returns:
        DataFrame with columns ``rank`` (always 4), ``score``, ``name``, ``tagline``, ``description``.
    """
    n_candidates = max(1, int(n_candidates))  # argmax over an empty list would raise
    names = generate_names(query, n=n_candidates, oversample=max(80, 8 * n_candidates), max_retries=3)
    if not names:
        names = generate_names(query, n=n_candidates, oversample=140, max_retries=1)
    if not names:
        toks = _content_words(query) or ["nova", "learn", "edu", "mento"]
        # dict.fromkeys de-duplicates while keeping insertion order, so the truncation below
        # is reproducible (set iteration order of strings varies with PYTHONHASHSEED).
        seeds = dict.fromkeys(s for t in toks for s in (t[:4] + "ify", t[:3] + "ora", t[:4] + "io"))
        names = list(seeds)[:n_candidates]

    qv = _embed_query(query)
    cos = _embed_passages(names) @ qv  # both sides unit-normalized -> cosine similarity
    banned = sorted(set(_content_words(query)))
    final_scores = []
    for nm, s in zip(names, cos):
        overlap = _overlap_ratio(_content_words(nm), banned)
        L = len(_normalize_name(nm))
        length_pen = 0.0
        if L < 4:
            length_pen += 0.3
        if L > 16:
            length_pen += 0.2
        final_scores.append(float(s) - 0.35 * overlap - length_pen)

    best_idx = int(np.argmax(final_scores))
    best_name = names[best_idx]
    best_score = float(final_scores[best_idx])
    tagline, description = ("", "")
    if include_copy:
        tagline, description = generate_tagline_and_desc(best_name, query_context=query)
    row = pd.DataFrame([{"rank": 4, "score": best_score, "name": best_name,
                         "tagline": tagline, "description": description}])
    return row


# ---------- UI glue ----------
EXAMPLES = [
    "AI tool to analyze customer feedback",
    "Social network for jobs",
    "Mobile fintech app for cross-border payments",
    "AI learning tool for students",
    "Marketplace for eco-friendly products",
]


def ui_search(query, state_unlikes):
    """Search handler: run a fresh top-3 search and reset the session's unlikes."""
    query = (query or "").strip()
    if not query:
        return gr.update(value=pd.DataFrame()), state_unlikes, "Please enter a short idea."
    state_unlikes = []  # reset for new query
    res = search_topk_filtered_session(query, k=3, unliked_ids=set())
    # Report the real count: a small/approximate index can return fewer than 3 rows.
    return res, state_unlikes, f"Found {len(res)} similar items. You can unlike by row_idx, then Refresh."


def ui_unlike(query, unlike_ids_csv, state_unlikes):
    """Unlike handler: add comma-separated row_idx values to the session's exclusions and re-search."""
    query = (query or "").strip()
    if not query:
        return gr.update(value=pd.DataFrame()), state_unlikes, "Enter a query first."
    add_ids = set()
    for tok in (unlike_ids_csv or "").split(","):
        tok = tok.strip()
        # isdecimal, not isdigit: "²".isdigit() is True but int("²") raises ValueError.
        if tok.isdecimal():
            add_ids.add(int(tok))
    cur = set(state_unlikes) | add_ids
    res = search_topk_filtered_session(query, k=3, unliked_ids=cur)
    return res, list(cur), f"Excluded {sorted(add_ids)}. Currently unliked: {sorted(cur)}"


def ui_clear_unlikes(query):
    """Clear handler: drop all unlikes and re-run the search."""
    query = (query or "").strip()
    if not query:
        return gr.update(value=pd.DataFrame()), [], "Enter a query first."
    res = search_topk_filtered_session(query, k=3, unliked_ids=set())
    return res, [], "Cleared unlikes."


def ui_generate_synth(query, include_copy):
    """Synthetic handler: generate the AI-named option #4 (optionally with tagline/description)."""
    query = (query or "").strip()
    if not query:
        return gr.update(value=pd.DataFrame()), "Enter a query first."
    synth = pick_best_synthetic_name(query, n_candidates=10, include_copy=include_copy)
    return synth, "Generated AI option as #4. Combine it with your top-3."
def _apply_example(example_text, state_unlikes):
    """Put an example idea into the query box and immediately search for it.

    Returns a 4-tuple matching the outputs ``[query, results_tbl, state_unlikes, status]``:
    the example text, the search table, the (reset) unlike list, and a status message.
    """
    table, new_unlikes, search_msg = ui_search(example_text, state_unlikes)
    status_msg = f"Example selected: “{example_text}”. {search_msg}"
    return example_text, table, new_unlikes, status_msg


# NOTE(review): this section was reconstructed from a whitespace-collapsed source; the
# components placed inside each gr.Row() below are inferred — confirm the intended layout.
with gr.Blocks(title="Startup Recommender + AI Name") as app:
    gr.Markdown(
        "## Startup Recommender → Unlike → AI Name\n"
        "Enter a short idea. Get 3 similar startups, unlike what doesn’t fit, "
        "then generate an AI name (and optional tagline & description)."
    )
    query = gr.Textbox(
        label="Your idea (short description)",
        placeholder="e.g., AI tool to analyze student essays and give feedback",
    )

    # One secondary button per canned example idea.
    with gr.Row():
        gr.Markdown("**Try an example:**")
        example_buttons = []
        for example in EXAMPLES:
            example_buttons.append(gr.Button(example, variant="secondary"))

    # Search / unlike controls.
    with gr.Row():
        btn_search = gr.Button("Search Top-3")
        unlike_ids = gr.Textbox(
            label="Unlike by row_idx (comma-separated)",
            placeholder="e.g., 123, 456",
        )
        btn_unlike = gr.Button("Refresh after Unlike")
        btn_clear = gr.Button("Clear Unlikes")

    results_tbl = gr.Dataframe(label="Top-3 Similar (after excludes)", interactive=False, wrap=True)

    gr.Markdown("### AI-Generated Option (#4)")
    include_copy = gr.Checkbox(label="Also generate tagline & description", value=True)
    btn_synth = gr.Button("Generate #4 (AI)")
    synth_tbl = gr.Dataframe(label="Synthetic #4", interactive=False, wrap=True)

    status = gr.Markdown("")
    state_unlikes = gr.State([])  # per-session list of unliked row_idx values

    # ---- event wiring ----
    btn_search.click(
        fn=ui_search,
        inputs=[query, state_unlikes],
        outputs=[results_tbl, state_unlikes, status],
    )
    btn_unlike.click(
        fn=ui_unlike,
        inputs=[query, unlike_ids, state_unlikes],
        outputs=[results_tbl, state_unlikes, status],
    )
    btn_clear.click(
        fn=ui_clear_unlikes,
        inputs=[query],
        outputs=[results_tbl, state_unlikes, status],
    )
    for btn, ex in zip(example_buttons, EXAMPLES):
        # `ex_=ex` binds the current example at definition time (avoids late binding).
        btn.click(
            fn=lambda st, ex_=ex: _apply_example(ex_, st),
            inputs=[state_unlikes],
            outputs=[query, results_tbl, state_unlikes, status],
        )
    btn_synth.click(
        fn=ui_generate_synth,
        inputs=[query, include_copy],
        outputs=[synth_tbl, status],
    )

# On Spaces, just calling launch() is fine; no explicit port.
if __name__ == "__main__":
    app.launch()