Spaces:
Paused
Paused
Commit
·
5f366e1
1
Parent(s):
35e0c5e
frontend vertex key support
Browse files
- index.html +7 -7
- main.py +58 -19
index.html
CHANGED
|
@@ -173,8 +173,8 @@
|
|
| 173 |
<h1>🎨 Gemini Image Generator</h1>
|
| 174 |
|
| 175 |
<div class="input-group">
|
| 176 |
-
<label for="
|
| 177 |
-
<input type="text" id="
|
| 178 |
</div>
|
| 179 |
|
| 180 |
<div class="input-group">
|
|
@@ -299,11 +299,11 @@
|
|
| 299 |
}
|
| 300 |
|
| 301 |
async function generateImage() {
|
| 302 |
-
const
|
| 303 |
const prompt = document.getElementById('prompt').value.trim();
|
| 304 |
|
| 305 |
-
if (!
|
| 306 |
-
showError('Please enter your
|
| 307 |
return;
|
| 308 |
}
|
| 309 |
|
|
@@ -346,11 +346,11 @@
|
|
| 346 |
return;
|
| 347 |
}
|
| 348 |
|
| 349 |
-
const response = await fetch('/v1beta/models/gemini-2.5-flash-image-preview:generateContent', {
|
| 350 |
method: 'POST',
|
| 351 |
headers: {
|
| 352 |
'Content-Type': 'application/json',
|
| 353 |
-
'x-
|
| 354 |
},
|
| 355 |
body: JSON.stringify({
|
| 356 |
contents: [{
|
|
|
|
| 173 |
<h1>🎨 Gemini Image Generator</h1>
|
| 174 |
|
| 175 |
<div class="input-group">
|
| 176 |
+
<label for="vertexKey">Vertex Express Key:</label>
|
| 177 |
+
<input type="text" id="vertexKey" placeholder="Enter your Vertex Express key">
|
| 178 |
</div>
|
| 179 |
|
| 180 |
<div class="input-group">
|
|
|
|
| 299 |
}
|
| 300 |
|
| 301 |
async function generateImage() {
|
| 302 |
+
const vertexKey = document.getElementById('vertexKey').value.trim();
|
| 303 |
const prompt = document.getElementById('prompt').value.trim();
|
| 304 |
|
| 305 |
+
if (!vertexKey) {
|
| 306 |
+
showError('Please enter your Vertex Express key');
|
| 307 |
return;
|
| 308 |
}
|
| 309 |
|
|
|
|
| 346 |
return;
|
| 347 |
}
|
| 348 |
|
| 349 |
+
const response = await fetch('/frontend/v1beta/models/gemini-2.5-flash-image-preview:generateContent', {
|
| 350 |
method: 'POST',
|
| 351 |
headers: {
|
| 352 |
'Content-Type': 'application/json',
|
| 353 |
+
'x-vertex-express-key': vertexKey
|
| 354 |
},
|
| 355 |
body: JSON.stringify({
|
| 356 |
contents: [{
|
main.py
CHANGED
|
@@ -1,9 +1,10 @@
|
|
| 1 |
import os
|
| 2 |
import re
|
| 3 |
import httpx
|
| 4 |
-
from fastapi import FastAPI, Request, HTTPException, Security
|
| 5 |
from fastapi.responses import StreamingResponse, Response, FileResponse
|
| 6 |
from fastapi.security import APIKeyHeader, APIKeyQuery
|
|
|
|
| 7 |
from fastapi.staticfiles import StaticFiles
|
| 8 |
from itertools import cycle
|
| 9 |
import asyncio
|
|
@@ -23,6 +24,7 @@ app = FastAPI()
|
|
| 23 |
project_id_cache = {}
|
| 24 |
key_rotator = cycle(VERTEX_EXPRESS_KEYS)
|
| 25 |
key_lock = asyncio.Lock()
|
|
|
|
| 26 |
|
| 27 |
# --- API Key Security ---
|
| 28 |
api_key_query = APIKeyQuery(name="key", auto_error=False)
|
|
@@ -82,33 +84,35 @@ async def gif_worker():
|
|
| 82 |
worker_path = current_dir / "gif.worker.js"
|
| 83 |
return FileResponse(worker_path, media_type="application/javascript")
|
| 84 |
|
| 85 |
-
# ---
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
project_id = await get_project_id(express_key)
|
| 92 |
-
|
| 93 |
raw_request_body = await request.body()
|
| 94 |
request_body_to_send = raw_request_body
|
| 95 |
|
| 96 |
try:
|
| 97 |
request_json = json.loads(raw_request_body)
|
| 98 |
if "gemini-2.0-flash-exp-image-generation" in model_path:
|
| 99 |
-
model_path =
|
| 100 |
|
| 101 |
if "generationConfig" not in request_json:
|
| 102 |
request_json["generationConfig"] = {}
|
| 103 |
|
| 104 |
# Model-specific request body modification
|
| 105 |
if "gemini-2.5-flash-image-preview" in model_path:
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
request_json["safetySettings"] = [
|
| 114 |
{
|
|
@@ -146,7 +150,7 @@ async def proxy(request: Request, model_path: str, api_key: str = Security(get_a
|
|
| 146 |
]
|
| 147 |
request_body_to_send = json.dumps(request_json).encode('utf-8')
|
| 148 |
except json.JSONDecodeError:
|
| 149 |
-
|
| 150 |
|
| 151 |
target_url = f"https://aiplatform.googleapis.com/v1/projects/{project_id}/locations/global/publishers/google/models/{model_path}?key={express_key}"
|
| 152 |
|
|
@@ -154,7 +158,7 @@ async def proxy(request: Request, model_path: str, api_key: str = Security(get_a
|
|
| 154 |
|
| 155 |
headers_to_proxy = {
|
| 156 |
k: v for k, v in request.headers.items()
|
| 157 |
-
if k.lower() not in ['host', 'authorization', 'x-goog-api-key', 'content-length']
|
| 158 |
}
|
| 159 |
|
| 160 |
print(request_body_to_send)
|
|
@@ -208,12 +212,47 @@ async def proxy(request: Request, model_path: str, api_key: str = Security(get_a
|
|
| 208 |
return Response(
|
| 209 |
content=modified_response_data,
|
| 210 |
status_code=response.status_code,
|
| 211 |
-
headers={"content-type":response.headers.get("content-type")},
|
| 212 |
)
|
| 213 |
finally:
|
| 214 |
await response.aclose()
|
| 215 |
await client.aclose()
|
| 216 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
if __name__ == "__main__":
|
| 218 |
import uvicorn
|
| 219 |
# Hugging Face Spaces run on port 7860
|
|
|
|
| 1 |
import os
|
| 2 |
import re
|
| 3 |
import httpx
|
| 4 |
+
from fastapi import FastAPI, Request, HTTPException, Security, Header
|
| 5 |
from fastapi.responses import StreamingResponse, Response, FileResponse
|
| 6 |
from fastapi.security import APIKeyHeader, APIKeyQuery
|
| 7 |
+
import logging
|
| 8 |
from fastapi.staticfiles import StaticFiles
|
| 9 |
from itertools import cycle
|
| 10 |
import asyncio
|
|
|
|
| 24 |
project_id_cache = {}
|
| 25 |
key_rotator = cycle(VERTEX_EXPRESS_KEYS)
|
| 26 |
key_lock = asyncio.Lock()
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
|
| 29 |
# --- API Key Security ---
|
| 30 |
api_key_query = APIKeyQuery(name="key", auto_error=False)
|
|
|
|
| 84 |
worker_path = current_dir / "gif.worker.js"
|
| 85 |
return FileResponse(worker_path, media_type="application/javascript")
|
| 86 |
|
| 87 |
+
# --- Shared Model Calling Logic ---
|
| 88 |
+
async def call_model(request: Request, model_path: str, express_key: str, project_id: str):
|
| 89 |
+
"""
|
| 90 |
+
Shared function to handle model calling logic for both proxy endpoints.
|
| 91 |
+
"""
|
|
|
|
|
|
|
|
|
|
| 92 |
raw_request_body = await request.body()
|
| 93 |
request_body_to_send = raw_request_body
|
| 94 |
|
| 95 |
try:
|
| 96 |
request_json = json.loads(raw_request_body)
|
| 97 |
if "gemini-2.0-flash-exp-image-generation" in model_path:
|
| 98 |
+
model_path = model_path.replace("gemini-2.0-flash-exp-image-generation", "gemini-2.5-flash-image-preview")
|
| 99 |
|
| 100 |
if "generationConfig" not in request_json:
|
| 101 |
request_json["generationConfig"] = {}
|
| 102 |
|
| 103 |
# Model-specific request body modification
|
| 104 |
if "gemini-2.5-flash-image-preview" in model_path:
|
| 105 |
+
if "generationConfig" in request_json and "thinkingConfig" in request_json.get("generationConfig", {}):
|
| 106 |
+
del request_json["generationConfig"]["thinkingConfig"]
|
| 107 |
+
if "generationConfig" in request_json and "responseMimeType" in request_json.get("generationConfig", {}):
|
| 108 |
+
del request_json["generationConfig"]["responseMimeType"]
|
| 109 |
+
request_json["generationConfig"]["responseModalities"] = ["TEXT", "IMAGE"]
|
| 110 |
+
|
| 111 |
+
# Ensure contents have role field
|
| 112 |
+
if "contents" in request_json:
|
| 113 |
+
for content in request_json["contents"]:
|
| 114 |
+
if "role" not in content:
|
| 115 |
+
content["role"] = "user"
|
| 116 |
|
| 117 |
request_json["safetySettings"] = [
|
| 118 |
{
|
|
|
|
| 150 |
]
|
| 151 |
request_body_to_send = json.dumps(request_json).encode('utf-8')
|
| 152 |
except json.JSONDecodeError:
|
| 153 |
+
pass # Not a json body, proxy as is
|
| 154 |
|
| 155 |
target_url = f"https://aiplatform.googleapis.com/v1/projects/{project_id}/locations/global/publishers/google/models/{model_path}?key={express_key}"
|
| 156 |
|
|
|
|
| 158 |
|
| 159 |
headers_to_proxy = {
|
| 160 |
k: v for k, v in request.headers.items()
|
| 161 |
+
if k.lower() not in ['host', 'authorization', 'x-goog-api-key', 'x-vertex-express-key', 'content-length']
|
| 162 |
}
|
| 163 |
|
| 164 |
print(request_body_to_send)
|
|
|
|
| 212 |
return Response(
|
| 213 |
content=modified_response_data,
|
| 214 |
status_code=response.status_code,
|
| 215 |
+
headers={"content-type": response.headers.get("content-type")},
|
| 216 |
)
|
| 217 |
finally:
|
| 218 |
await response.aclose()
|
| 219 |
await client.aclose()
|
| 220 |
|
| 221 |
+
# --- Frontend-specific endpoint (no authentication required) ---
|
| 222 |
+
@app.post("/frontend/v1beta/models/{model_name}:{function_name}")
async def frontend_proxy(
    model_name: str,
    function_name: str,
    request: Request,
    vertex_express_key: str = Header(..., alias="x-vertex-express-key"),
):
    """
    Proxy a model call for the frontend using a caller-supplied Vertex Express key.

    This route does not use the proxy API-key dependency; the caller must send
    its own key in the ``x-vertex-express-key`` header, which is then used both
    to resolve the project ID and to call the model.

    Args:
        model_name: Model segment of the URL path (the part before the colon).
        function_name: Method segment of the URL path (the part after the colon).
        request: Incoming request, handed unchanged to ``call_model``.
        vertex_express_key: Value of the required ``x-vertex-express-key`` header.

    Returns:
        Whatever ``call_model`` returns for the reassembled model path.

    Raises:
        HTTPException: Re-raised unchanged when one is raised inside the
            ``try`` block; otherwise a 500 wrapping any other failure.
    """
    try:
        # Get or extract project ID for this key
        project_id = await get_project_id(vertex_express_key)

        # Re-join the two path parameters into the "<model>:<function>" form
        # that the shared model calling logic takes.
        model_path = f"{model_name}:{function_name}"
        return await call_model(request, model_path, vertex_express_key, project_id)

    except HTTPException:
        # A deliberate HTTP error already carries its own status code and
        # detail; wrapping it below would flatten it into a generic 500.
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception("Frontend proxy error: %s", e)
        raise HTTPException(status_code=500, detail=f"Proxy error: {str(e)}") from e
|
| 244 |
+
|
| 245 |
+
# --- Proxy Endpoint ---
|
| 246 |
+
@app.post("/v1beta/models/{model_path:path}")
async def proxy(request: Request, model_path: str, _: str = Security(get_api_key)):
    """
    Proxy a model call using the next server-side Vertex Express key.

    The ``get_api_key`` security dependency runs for every request; its result
    is bound to ``_`` and not used in the body.

    Args:
        request: Incoming request, handed unchanged to ``call_model``.
        model_path: Everything after ``/v1beta/models/`` in the URL path.

    Returns:
        Whatever ``call_model`` returns for this request.
    """
    # Advance the shared key iterator while holding the lock.
    async with key_lock:
        rotated_key = next(key_rotator)

    resolved_project = await get_project_id(rotated_key)

    # Hand off to the shared model calling logic.
    return await call_model(request, model_path, rotated_key, resolved_project)
|
| 255 |
+
|
| 256 |
if __name__ == "__main__":
|
| 257 |
import uvicorn
|
| 258 |
# Hugging Face Spaces run on port 7860
|