HAL1993 commited on
Commit
d5dea30
·
verified ·
1 Parent(s): 2c4e21f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -18
app.py CHANGED
@@ -8,6 +8,7 @@ import os
8
  from diffusers import DiffusionPipeline, AutoencoderTiny
9
  from custom_pipeline import FluxWithCFGPipeline
10
  from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
 
11
 
12
  # --- Torch Optimizations ---
13
  torch.backends.cuda.matmul.allow_tf32 = True
@@ -33,13 +34,13 @@ pipe = FluxWithCFGPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", t
33
  pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
34
  pipe.to(device)
35
 
36
- # --- Load M2M100 Albanian-English Translator ---
37
  try:
38
  tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
39
  model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M").to(device)
40
  print("✅ Loaded M2M100 Albanian-English translator")
41
  except Exception as e:
42
- print(f"❌ Failed to load M2M100 translator: {e}")
43
  tokenizer = None
44
  model = None
45
 
@@ -48,37 +49,43 @@ def translate_sq_to_en(text):
48
  return text
49
  try:
50
  tokenizer.src_lang = "sq"
51
- encoded = tokenizer(text, return_tensors="pt", padding=True).to(device)
52
- generated_tokens = model.generate(**encoded, forced_bos_token_id=tokenizer.get_lang_id("en"))
53
- translated = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
54
  return translated
55
  except Exception as e:
56
  print(f"❌ Translation failed: {e}")
57
  return text
58
 
59
- # --- Main Inference Function ---
 
 
 
 
 
 
 
 
60
  @spaces.GPU
61
  def generate_image(prompt: str, seed: int = 42, aspect_ratio: str = "16:9", randomize_seed: bool = False):
62
  if pipe is None:
63
- raise gr.Error("Pipelinei nuk u ngarkua.")
64
 
65
  if not prompt or prompt.strip() == "":
66
  return None, seed, "Gabim: Plotësoni përshkrimin."
67
 
 
 
 
 
68
  if randomize_seed:
69
  seed = random.randint(0, MAX_SEED)
70
 
71
  width, height = ASPECT_RATIOS.get(aspect_ratio, (DEFAULT_WIDTH, DEFAULT_HEIGHT))
72
 
73
- prompt_final = prompt.strip()
74
-
75
- # Detect if prompt is probably Albanian by common Albanian letters or phrase starts
76
- if any(c in prompt_final for c in "ëçËÇ") or prompt_final.lower().startswith("një "):
77
- print(f"🌐 Detected likely Albanian. Translating prompt: {prompt_final}")
78
- prompt_final = translate_sq_to_en(prompt_final)
79
- print(f"✅ Translated prompt: {prompt_final}")
80
- else:
81
- print("🟢 Prompt seems English. Skipping translation.")
82
 
83
  prompt_final += ", ultra realistic, sharp, 8k resolution"
84
 
@@ -101,7 +108,7 @@ def generate_image(prompt: str, seed: int = 42, aspect_ratio: str = "16:9", rand
101
  except Exception as e:
102
  if torch.cuda.is_available():
103
  torch.cuda.empty_cache()
104
- raise gr.Error(f"Gabim gjatë gjenerimit: {e}")
105
 
106
  # --- Examples ---
107
  examples = [
@@ -133,7 +140,7 @@ button[aria-label="Download"] {
133
  with gr.Column(scale=2):
134
  output_image = gr.Image(label="Imazhi i Gjeneruar", interactive=False, show_download_button=True)
135
  with gr.Column(scale=1):
136
- prompt = gr.Text(label="Përshkrimi (Shqip ose Anglisht)", placeholder="Shkruani se çfarë doni të krijoni...", lines=3)
137
  generate_btn = gr.Button("🎨 Gjenero")
138
  aspect_ratio = gr.Radio(label="Raporti i Imazhit", choices=list(ASPECT_RATIOS.keys()), value="16:9")
139
  randomize_seed = gr.Checkbox(label="Përdor numër të rastësishëm", value=True)
 
8
  from diffusers import DiffusionPipeline, AutoencoderTiny
9
  from custom_pipeline import FluxWithCFGPipeline
10
  from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
11
+ from langdetect import detect
12
 
13
  # --- Torch Optimizations ---
14
  torch.backends.cuda.matmul.allow_tf32 = True
 
34
  pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
35
  pipe.to(device)
36
 
37
+ # --- Load Multilingual Translator (M2M100) ---
38
  try:
39
  tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
40
  model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M").to(device)
41
  print("✅ Loaded M2M100 Albanian-English translator")
42
  except Exception as e:
43
+ print(f"❌ Failed to load M2M100: {e}")
44
  tokenizer = None
45
  model = None
46
 
 
49
  return text
50
  try:
51
  tokenizer.src_lang = "sq"
52
+ encoded = tokenizer(text, return_tensors="pt").to(device)
53
+ generated = model.generate(**encoded, forced_bos_token_id=tokenizer.get_lang_id("en"))
54
+ translated = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
55
  return translated
56
  except Exception as e:
57
  print(f"❌ Translation failed: {e}")
58
  return text
59
 
60
+ def is_albanian(text):
61
+ try:
62
+ lang = detect(text)
63
+ return lang == "sq"
64
+ except Exception as e:
65
+ print(f"⚠️ Language detection failed: {e}")
66
+ return False
67
+
68
+ # --- Inference Function ---
69
  @spaces.GPU
70
  def generate_image(prompt: str, seed: int = 42, aspect_ratio: str = "16:9", randomize_seed: bool = False):
71
  if pipe is None:
72
+ raise gr.Error("Pipeline nuk u ngarkua.")
73
 
74
  if not prompt or prompt.strip() == "":
75
  return None, seed, "Gabim: Plotësoni përshkrimin."
76
 
77
+ # Enforce Albanian only input
78
+ if not is_albanian(prompt.strip()):
79
+ return None, seed, "Ju lutemi shkruani vetëm në gjuhën shqipe."
80
+
81
  if randomize_seed:
82
  seed = random.randint(0, MAX_SEED)
83
 
84
  width, height = ASPECT_RATIOS.get(aspect_ratio, (DEFAULT_WIDTH, DEFAULT_HEIGHT))
85
 
86
+ # Translate Albanian prompt to English
87
+ prompt_final = translate_sq_to_en(prompt.strip())
88
+ print(f"🌐 Translated prompt: {prompt_final}")
 
 
 
 
 
 
89
 
90
  prompt_final += ", ultra realistic, sharp, 8k resolution"
91
 
 
108
  except Exception as e:
109
  if torch.cuda.is_available():
110
  torch.cuda.empty_cache()
111
+ raise gr.Error(f"Gabim gjatë gjenerimit: {e}")
112
 
113
  # --- Examples ---
114
  examples = [
 
140
  with gr.Column(scale=2):
141
  output_image = gr.Image(label="Imazhi i Gjeneruar", interactive=False, show_download_button=True)
142
  with gr.Column(scale=1):
143
+ prompt = gr.Text(label="Përshkrimi (vetëm shqip)", placeholder="Shkruani vetëm gjuhën shqipe...", lines=3)
144
  generate_btn = gr.Button("🎨 Gjenero")
145
  aspect_ratio = gr.Radio(label="Raporti i Imazhit", choices=list(ASPECT_RATIOS.keys()), value="16:9")
146
  randomize_seed = gr.Checkbox(label="Përdor numër të rastësishëm", value=True)