bibibi12345 committed on
Commit
5f366e1
·
1 Parent(s): 35e0c5e

frontend vertex key support

Browse files
Files changed (2) hide show
  1. index.html +7 -7
  2. main.py +58 -19
index.html CHANGED
@@ -173,8 +173,8 @@
173
  <h1>🎨 Gemini Image Generator</h1>
174
 
175
  <div class="input-group">
176
- <label for="apiKey">API Key:</label>
177
- <input type="text" id="apiKey" placeholder="Enter your API key">
178
  </div>
179
 
180
  <div class="input-group">
@@ -299,11 +299,11 @@
299
  }
300
 
301
  async function generateImage() {
302
- const apiKey = document.getElementById('apiKey').value.trim();
303
  const prompt = document.getElementById('prompt').value.trim();
304
 
305
- if (!apiKey) {
306
- showError('Please enter your API key');
307
  return;
308
  }
309
 
@@ -346,11 +346,11 @@
346
  return;
347
  }
348
 
349
- const response = await fetch('/v1beta/models/gemini-2.5-flash-image-preview:generateContent', {
350
  method: 'POST',
351
  headers: {
352
  'Content-Type': 'application/json',
353
- 'x-goog-api-key': apiKey
354
  },
355
  body: JSON.stringify({
356
  contents: [{
 
173
  <h1>🎨 Gemini Image Generator</h1>
174
 
175
  <div class="input-group">
176
+ <label for="vertexKey">Vertex Express Key:</label>
177
+ <input type="text" id="vertexKey" placeholder="Enter your Vertex Express key">
178
  </div>
179
 
180
  <div class="input-group">
 
299
  }
300
 
301
  async function generateImage() {
302
+ const vertexKey = document.getElementById('vertexKey').value.trim();
303
  const prompt = document.getElementById('prompt').value.trim();
304
 
305
+ if (!vertexKey) {
306
+ showError('Please enter your Vertex Express key');
307
  return;
308
  }
309
 
 
346
  return;
347
  }
348
 
349
+ const response = await fetch('/frontend/v1beta/models/gemini-2.5-flash-image-preview:generateContent', {
350
  method: 'POST',
351
  headers: {
352
  'Content-Type': 'application/json',
353
+ 'x-vertex-express-key': vertexKey
354
  },
355
  body: JSON.stringify({
356
  contents: [{
main.py CHANGED
@@ -1,9 +1,10 @@
1
  import os
2
  import re
3
  import httpx
4
- from fastapi import FastAPI, Request, HTTPException, Security
5
  from fastapi.responses import StreamingResponse, Response, FileResponse
6
  from fastapi.security import APIKeyHeader, APIKeyQuery
 
7
  from fastapi.staticfiles import StaticFiles
8
  from itertools import cycle
9
  import asyncio
@@ -23,6 +24,7 @@ app = FastAPI()
23
  project_id_cache = {}
24
  key_rotator = cycle(VERTEX_EXPRESS_KEYS)
25
  key_lock = asyncio.Lock()
 
26
 
27
  # --- API Key Security ---
28
  api_key_query = APIKeyQuery(name="key", auto_error=False)
@@ -82,33 +84,35 @@ async def gif_worker():
82
  worker_path = current_dir / "gif.worker.js"
83
  return FileResponse(worker_path, media_type="application/javascript")
84
 
85
- # --- Proxy Endpoint ---
86
- @app.post("/v1beta/models/{model_path:path}")
87
- async def proxy(request: Request, model_path: str, api_key: str = Security(get_api_key)):
88
- async with key_lock:
89
- express_key = next(key_rotator)
90
-
91
- project_id = await get_project_id(express_key)
92
-
93
  raw_request_body = await request.body()
94
  request_body_to_send = raw_request_body
95
 
96
  try:
97
  request_json = json.loads(raw_request_body)
98
  if "gemini-2.0-flash-exp-image-generation" in model_path:
99
- model_path = model_path.replace("gemini-2.0-flash-exp-image-generation", "gemini-2.5-flash-image-preview")
100
 
101
  if "generationConfig" not in request_json:
102
  request_json["generationConfig"] = {}
103
 
104
  # Model-specific request body modification
105
  if "gemini-2.5-flash-image-preview" in model_path:
106
- if "generationConfig" in request_json and "thinkingConfig" in request_json.get("generationConfig", {}):
107
- del request_json["generationConfig"]["thinkingConfig"]
108
- if "generationConfig" in request_json and "responseMimeType" in request_json.get("generationConfig", {}):
109
- del request_json["generationConfig"]["responseMimeType"]
110
- request_json["generationConfig"]
111
- request_json["generationConfig"]["responseModalities"] = ["TEXT", "IMAGE"]
 
 
 
 
 
112
 
113
  request_json["safetySettings"] = [
114
  {
@@ -146,7 +150,7 @@ async def proxy(request: Request, model_path: str, api_key: str = Security(get_a
146
  ]
147
  request_body_to_send = json.dumps(request_json).encode('utf-8')
148
  except json.JSONDecodeError:
149
- pass # Not a json body, proxy as is
150
 
151
  target_url = f"https://aiplatform.googleapis.com/v1/projects/{project_id}/locations/global/publishers/google/models/{model_path}?key={express_key}"
152
 
@@ -154,7 +158,7 @@ async def proxy(request: Request, model_path: str, api_key: str = Security(get_a
154
 
155
  headers_to_proxy = {
156
  k: v for k, v in request.headers.items()
157
- if k.lower() not in ['host', 'authorization', 'x-goog-api-key', 'content-length']
158
  }
159
 
160
  print(request_body_to_send)
@@ -208,12 +212,47 @@ async def proxy(request: Request, model_path: str, api_key: str = Security(get_a
208
  return Response(
209
  content=modified_response_data,
210
  status_code=response.status_code,
211
- headers={"content-type":response.headers.get("content-type")},
212
  )
213
  finally:
214
  await response.aclose()
215
  await client.aclose()
216
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  if __name__ == "__main__":
218
  import uvicorn
219
  # Hugging Face Spaces run on port 7860
 
1
  import os
2
  import re
3
  import httpx
4
+ from fastapi import FastAPI, Request, HTTPException, Security, Header
5
  from fastapi.responses import StreamingResponse, Response, FileResponse
6
  from fastapi.security import APIKeyHeader, APIKeyQuery
7
+ import logging
8
  from fastapi.staticfiles import StaticFiles
9
  from itertools import cycle
10
  import asyncio
 
24
  project_id_cache = {}
25
  key_rotator = cycle(VERTEX_EXPRESS_KEYS)
26
  key_lock = asyncio.Lock()
27
+ logger = logging.getLogger(__name__)
28
 
29
  # --- API Key Security ---
30
  api_key_query = APIKeyQuery(name="key", auto_error=False)
 
84
  worker_path = current_dir / "gif.worker.js"
85
  return FileResponse(worker_path, media_type="application/javascript")
86
 
87
+ # --- Shared Model Calling Logic ---
88
+ async def call_model(request: Request, model_path: str, express_key: str, project_id: str):
89
+ """
90
+ Shared function to handle model calling logic for both proxy endpoints.
91
+ """
 
 
 
92
  raw_request_body = await request.body()
93
  request_body_to_send = raw_request_body
94
 
95
  try:
96
  request_json = json.loads(raw_request_body)
97
  if "gemini-2.0-flash-exp-image-generation" in model_path:
98
+ model_path = model_path.replace("gemini-2.0-flash-exp-image-generation", "gemini-2.5-flash-image-preview")
99
 
100
  if "generationConfig" not in request_json:
101
  request_json["generationConfig"] = {}
102
 
103
  # Model-specific request body modification
104
  if "gemini-2.5-flash-image-preview" in model_path:
105
+ if "generationConfig" in request_json and "thinkingConfig" in request_json.get("generationConfig", {}):
106
+ del request_json["generationConfig"]["thinkingConfig"]
107
+ if "generationConfig" in request_json and "responseMimeType" in request_json.get("generationConfig", {}):
108
+ del request_json["generationConfig"]["responseMimeType"]
109
+ request_json["generationConfig"]["responseModalities"] = ["TEXT", "IMAGE"]
110
+
111
+ # Ensure contents have role field
112
+ if "contents" in request_json:
113
+ for content in request_json["contents"]:
114
+ if "role" not in content:
115
+ content["role"] = "user"
116
 
117
  request_json["safetySettings"] = [
118
  {
 
150
  ]
151
  request_body_to_send = json.dumps(request_json).encode('utf-8')
152
  except json.JSONDecodeError:
153
+ pass # Not a json body, proxy as is
154
 
155
  target_url = f"https://aiplatform.googleapis.com/v1/projects/{project_id}/locations/global/publishers/google/models/{model_path}?key={express_key}"
156
 
 
158
 
159
  headers_to_proxy = {
160
  k: v for k, v in request.headers.items()
161
+ if k.lower() not in ['host', 'authorization', 'x-goog-api-key', 'x-vertex-express-key', 'content-length']
162
  }
163
 
164
  print(request_body_to_send)
 
212
  return Response(
213
  content=modified_response_data,
214
  status_code=response.status_code,
215
+ headers={"content-type": response.headers.get("content-type")},
216
  )
217
  finally:
218
  await response.aclose()
219
  await client.aclose()
220
 
221
# --- Frontend-specific endpoint (no proxy authentication required) ---
@app.post("/frontend/v1beta/models/{model_name}:{function_name}")
async def frontend_proxy(
    model_name: str,
    function_name: str,
    request: Request,
    vertex_express_key: str = Header(..., alias="x-vertex-express-key"),
):
    """
    Frontend-specific proxy endpoint that only requires a Vertex Express key.

    No proxy API-key authentication is performed; the browser client sends
    its own key in the ``x-vertex-express-key`` header, which is used both
    to resolve the project ID and to call the model.

    Raises:
        HTTPException: 500 on unexpected failures; any HTTPException raised
            downstream (e.g. validation/auth errors) is propagated as-is.
    """
    try:
        # Resolve (and cache) the project ID associated with this key.
        project_id = await get_project_id(vertex_express_key)

        # Delegate to the shared model-calling logic used by both endpoints.
        model_path = f"{model_name}:{function_name}"
        return await call_model(request, model_path, vertex_express_key, project_id)

    except HTTPException:
        # Preserve deliberate HTTP errors (status code and detail) instead of
        # collapsing them into a generic 500 via the broad handler below.
        raise
    except Exception as e:
        # Log with full traceback for debugging, then surface a 500.
        logger.exception("Frontend proxy error: %s", e)
        raise HTTPException(status_code=500, detail=f"Proxy error: {str(e)}")
244
+
245
# --- Proxy Endpoint ---
@app.post("/v1beta/models/{model_path:path}")
async def proxy(request: Request, model_path: str, _: str = Security(get_api_key)):
    """
    Authenticated proxy endpoint.

    Draws the next Vertex Express key from the rotator (guarded by the
    shared lock so rotation stays atomic under concurrent requests),
    resolves the key's project ID, and hands the request off to the
    shared model-calling logic.
    """
    # Pick a key; only the rotation itself needs to hold the lock.
    async with key_lock:
        selected_key = next(key_rotator)

    resolved_project = await get_project_id(selected_key)
    return await call_model(request, model_path, selected_key, resolved_project)
255
+
256
  if __name__ == "__main__":
257
  import uvicorn
258
  # Hugging Face Spaces run on port 7860