HAL1993 commited on
Commit
a762308
·
verified ·
1 Parent(s): d5dea30

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -15
app.py CHANGED
@@ -16,7 +16,6 @@ torch.backends.cudnn.benchmark = True
16
 
17
  # --- Constants ---
18
  MAX_SEED = np.iinfo(np.int32).max
19
- MAX_IMAGE_SIZE = 1024
20
  DEFAULT_WIDTH = 1024
21
  DEFAULT_HEIGHT = 576
22
  ASPECT_RATIOS = {
@@ -30,15 +29,18 @@ INFERENCE_STEPS = 8
30
  dtype = torch.float16
31
  device = "cuda" if torch.cuda.is_available() else "cpu"
32
 
 
33
  pipe = FluxWithCFGPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype)
34
  pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
35
  pipe.to(device)
 
36
 
37
- # --- Load Multilingual Translator (M2M100) ---
38
  try:
 
39
  tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
40
  model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M").to(device)
41
- print("✅ Loaded M2M100 Albanian-English translator")
42
  except Exception as e:
43
  print(f"❌ Failed to load M2M100: {e}")
44
  tokenizer = None
@@ -46,12 +48,14 @@ except Exception as e:
46
 
47
  def translate_sq_to_en(text):
48
  if not tokenizer or not model:
 
49
  return text
50
  try:
51
  tokenizer.src_lang = "sq"
52
  encoded = tokenizer(text, return_tensors="pt").to(device)
53
  generated = model.generate(**encoded, forced_bos_token_id=tokenizer.get_lang_id("en"))
54
  translated = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
 
55
  return translated
56
  except Exception as e:
57
  print(f"❌ Translation failed: {e}")
@@ -60,22 +64,22 @@ def translate_sq_to_en(text):
60
  def is_albanian(text):
61
  try:
62
  lang = detect(text)
 
63
  return lang == "sq"
64
  except Exception as e:
65
  print(f"⚠️ Language detection failed: {e}")
66
  return False
67
 
68
- # --- Inference Function ---
69
  @spaces.GPU
70
  def generate_image(prompt: str, seed: int = 42, aspect_ratio: str = "16:9", randomize_seed: bool = False):
71
  if pipe is None:
72
  raise gr.Error("Pipeline nuk u ngarkua.")
73
 
74
- if not prompt or prompt.strip() == "":
 
75
  return None, seed, "Gabim: Plotësoni përshkrimin."
76
 
77
- # Enforce Albanian only input
78
- if not is_albanian(prompt.strip()):
79
  return None, seed, "Ju lutemi shkruani vetëm në gjuhën shqipe."
80
 
81
  if randomize_seed:
@@ -84,16 +88,18 @@ def generate_image(prompt: str, seed: int = 42, aspect_ratio: str = "16:9", rand
84
  width, height = ASPECT_RATIOS.get(aspect_ratio, (DEFAULT_WIDTH, DEFAULT_HEIGHT))
85
 
86
  # Translate Albanian prompt to English
87
- prompt_final = translate_sq_to_en(prompt.strip())
88
- print(f"🌐 Translated prompt: {prompt_final}")
89
 
 
90
  prompt_final += ", ultra realistic, sharp, 8k resolution"
91
 
 
 
92
  try:
93
- generator = torch.Generator(device=device).manual_seed(int(float(seed)))
94
  start_time = time.time()
95
  with torch.inference_mode():
96
- image = pipe(
97
  prompt=prompt_final,
98
  width=width,
99
  height=height,
@@ -101,7 +107,8 @@ def generate_image(prompt: str, seed: int = 42, aspect_ratio: str = "16:9", rand
101
  generator=generator,
102
  output_type="pil",
103
  return_dict=False
104
- )[0][0]
 
105
  latency = time.time() - start_time
106
  status = f"Koha e përpunimit: {latency:.2f} sekonda"
107
  return image, seed, status
@@ -140,7 +147,7 @@ button[aria-label="Download"] {
140
  with gr.Column(scale=2):
141
  output_image = gr.Image(label="Imazhi i Gjeneruar", interactive=False, show_download_button=True)
142
  with gr.Column(scale=1):
143
- prompt = gr.Text(label="Përshkrimi (vetëm në shqip)", placeholder="Shkruani vetëm në gjuhën shqipe...", lines=3)
144
  generate_btn = gr.Button("🎨 Gjenero")
145
  aspect_ratio = gr.Radio(label="Raporti i Imazhit", choices=list(ASPECT_RATIOS.keys()), value="16:9")
146
  randomize_seed = gr.Checkbox(label="Përdor numër të rastësishëm", value=True)
@@ -150,7 +157,7 @@ button[aria-label="Download"] {
150
  gr.Examples(
151
  examples=examples,
152
  fn=generate_image,
153
- inputs=[prompt],
154
  outputs=[output_image, gr.Number(visible=False), latency],
155
  cache_examples=True,
156
  cache_mode="eager"
@@ -158,7 +165,7 @@ button[aria-label="Download"] {
158
 
159
  generate_btn.click(
160
  fn=generate_image,
161
- inputs=[prompt, gr.Number(value=42, visible=False), aspect_ratio, randomize_seed],
162
  outputs=[output_image, gr.Number(visible=False), latency],
163
  show_progress="full"
164
  )
 
16
 
17
  # --- Constants ---
18
  MAX_SEED = np.iinfo(np.int32).max
 
19
  DEFAULT_WIDTH = 1024
20
  DEFAULT_HEIGHT = 576
21
  ASPECT_RATIOS = {
 
29
  dtype = torch.float16
30
  device = "cuda" if torch.cuda.is_available() else "cpu"
31
 
32
+ print("⏳ Loading Flux pipeline...")
33
  pipe = FluxWithCFGPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype)
34
  pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
35
  pipe.to(device)
36
+ print("✅ Flux pipeline loaded.")
37
 
38
+ # --- Load M2M100 Translator ---
39
  try:
40
+ print("⏳ Loading M2M100 tokenizer and model...")
41
  tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
42
  model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M").to(device)
43
+ print("✅ M2M100 loaded.")
44
  except Exception as e:
45
  print(f"❌ Failed to load M2M100: {e}")
46
  tokenizer = None
 
48
 
49
  def translate_sq_to_en(text):
50
  if not tokenizer or not model:
51
+ print("⚠️ Translator not loaded, returning original text")
52
  return text
53
  try:
54
  tokenizer.src_lang = "sq"
55
  encoded = tokenizer(text, return_tensors="pt").to(device)
56
  generated = model.generate(**encoded, forced_bos_token_id=tokenizer.get_lang_id("en"))
57
  translated = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
58
+ print(f"🌐 Translation successful: {translated}")
59
  return translated
60
  except Exception as e:
61
  print(f"❌ Translation failed: {e}")
 
64
  def is_albanian(text):
65
  try:
66
  lang = detect(text)
67
+ print(f"🕵️ Language detected: {lang}")
68
  return lang == "sq"
69
  except Exception as e:
70
  print(f"⚠️ Language detection failed: {e}")
71
  return False
72
 
 
73
  @spaces.GPU
74
  def generate_image(prompt: str, seed: int = 42, aspect_ratio: str = "16:9", randomize_seed: bool = False):
75
  if pipe is None:
76
  raise gr.Error("Pipeline nuk u ngarkua.")
77
 
78
+ prompt_clean = prompt.strip()
79
+ if not prompt_clean:
80
  return None, seed, "Gabim: Plotësoni përshkrimin."
81
 
82
+ if not is_albanian(prompt_clean):
 
83
  return None, seed, "Ju lutemi shkruani vetëm në gjuhën shqipe."
84
 
85
  if randomize_seed:
 
88
  width, height = ASPECT_RATIOS.get(aspect_ratio, (DEFAULT_WIDTH, DEFAULT_HEIGHT))
89
 
90
  # Translate Albanian prompt to English
91
+ prompt_final = translate_sq_to_en(prompt_clean)
 
92
 
93
+ # Add quality tags for generation
94
  prompt_final += ", ultra realistic, sharp, 8k resolution"
95
 
96
+ print(f"🎯 Final prompt for generation: {prompt_final}")
97
+
98
  try:
99
+ generator = torch.Generator(device=device).manual_seed(int(seed))
100
  start_time = time.time()
101
  with torch.inference_mode():
102
+ images = pipe(
103
  prompt=prompt_final,
104
  width=width,
105
  height=height,
 
107
  generator=generator,
108
  output_type="pil",
109
  return_dict=False
110
+ )
111
+ image = images[0][0]
112
  latency = time.time() - start_time
113
  status = f"Koha e përpunimit: {latency:.2f} sekonda"
114
  return image, seed, status
 
147
  with gr.Column(scale=2):
148
  output_image = gr.Image(label="Imazhi i Gjeneruar", interactive=False, show_download_button=True)
149
  with gr.Column(scale=1):
150
+ prompt_input = gr.Text(label="Përshkrimi (vetëm në shqip)", placeholder="Shkruani vetëm në gjuhën shqipe...", lines=3)
151
  generate_btn = gr.Button("🎨 Gjenero")
152
  aspect_ratio = gr.Radio(label="Raporti i Imazhit", choices=list(ASPECT_RATIOS.keys()), value="16:9")
153
  randomize_seed = gr.Checkbox(label="Përdor numër të rastësishëm", value=True)
 
157
  gr.Examples(
158
  examples=examples,
159
  fn=generate_image,
160
+ inputs=[prompt_input],
161
  outputs=[output_image, gr.Number(visible=False), latency],
162
  cache_examples=True,
163
  cache_mode="eager"
 
165
 
166
  generate_btn.click(
167
  fn=generate_image,
168
+ inputs=[prompt_input, gr.Number(value=42, visible=False), aspect_ratio, randomize_seed],
169
  outputs=[output_image, gr.Number(visible=False), latency],
170
  show_progress="full"
171
  )