Spaces:
Build error
Build error
update GPU duration
Browse files
app.py
CHANGED
|
@@ -68,7 +68,7 @@ def load_models():
|
|
| 68 |
|
| 69 |
pipe_sd35, pipe_sdxl = load_models()
|
| 70 |
|
| 71 |
-
@spaces.GPU
|
| 72 |
def generate_image(
|
| 73 |
model_name,
|
| 74 |
seed,
|
|
@@ -97,7 +97,6 @@ def generate_image(
|
|
| 97 |
|
| 98 |
pipe.to(device)
|
| 99 |
pipe.enable_model_cpu_offload()
|
| 100 |
-
# os.makedirs('./weights', exist_ok=True)
|
| 101 |
os.system('huggingface-cli download sst12345/CoRe2 weights/sd35_noise_model.pth weights/sdxl_noise_model.pth --local-dir ./weights')
|
| 102 |
# TODO: load noise model
|
| 103 |
if method == 'core' or method == 'z-core':
|
|
@@ -105,7 +104,7 @@ def generate_image(
|
|
| 105 |
from diffusion_pipeline.lora import replace_linear_with_lora, lora_true
|
| 106 |
|
| 107 |
if model_name == 'sd35':
|
| 108 |
-
refine_model = PromptSD35Net()
|
| 109 |
replace_linear_with_lora(refine_model, rank=64, alpha=1.0, number_of_lora=28)
|
| 110 |
lora_true(refine_model, lora_idx=0)
|
| 111 |
checkpoint = torch.load('./weights/weights/sd35_noise_model.pth', map_location='cpu')
|
|
@@ -117,10 +116,9 @@ def generate_image(
|
|
| 117 |
checkpoint = torch.load('./weights/weights/sdxl_noise_model.pth', map_location='cpu')
|
| 118 |
refine_model.load_state_dict(checkpoint)
|
| 119 |
|
| 120 |
-
print("Load Lora Success")
|
| 121 |
-
refine_model = refine_model.to(device)
|
| 122 |
refine_model = refine_model.to(torch.bfloat16)
|
| 123 |
-
|
|
|
|
| 124 |
# 根据模型类型设置形状
|
| 125 |
if model_name == 'sdxl':
|
| 126 |
shape = (1, 4, size // 8, size // 8)
|
|
|
|
| 68 |
|
| 69 |
pipe_sd35, pipe_sdxl = load_models()
|
| 70 |
|
| 71 |
+
@spaces.GPU(duration=360)
|
| 72 |
def generate_image(
|
| 73 |
model_name,
|
| 74 |
seed,
|
|
|
|
| 97 |
|
| 98 |
pipe.to(device)
|
| 99 |
pipe.enable_model_cpu_offload()
|
|
|
|
| 100 |
os.system('huggingface-cli download sst12345/CoRe2 weights/sd35_noise_model.pth weights/sdxl_noise_model.pth --local-dir ./weights')
|
| 101 |
# TODO: load noise model
|
| 102 |
if method == 'core' or method == 'z-core':
|
|
|
|
| 104 |
from diffusion_pipeline.lora import replace_linear_with_lora, lora_true
|
| 105 |
|
| 106 |
if model_name == 'sd35':
|
| 107 |
+
refine_model = PromptSD35Net()
|
| 108 |
replace_linear_with_lora(refine_model, rank=64, alpha=1.0, number_of_lora=28)
|
| 109 |
lora_true(refine_model, lora_idx=0)
|
| 110 |
checkpoint = torch.load('./weights/weights/sd35_noise_model.pth', map_location='cpu')
|
|
|
|
| 116 |
checkpoint = torch.load('./weights/weights/sdxl_noise_model.pth', map_location='cpu')
|
| 117 |
refine_model.load_state_dict(checkpoint)
|
| 118 |
|
|
|
|
|
|
|
| 119 |
refine_model = refine_model.to(torch.bfloat16)
|
| 120 |
+
refine_model = refine_model.to(device)
|
| 121 |
+
print("Load Lora Success")
|
| 122 |
# 根据模型类型设置形状
|
| 123 |
if model_name == 'sdxl':
|
| 124 |
shape = (1, 4, size // 8, size // 8)
|