Spaces:
Sleeping
Sleeping
Update inference.py
Browse files- inference.py +7 -6
inference.py
CHANGED
|
@@ -37,15 +37,16 @@ checkpoint_cc12m = torch.load(cc12m_model_path, map_location=torch.device(device
|
|
| 37 |
|
| 38 |
# Create a new Generator model and initialize it with the pre-trained weights
|
| 39 |
netG = NetG(64, 100, 512, 256, 3, False, clip_model).to(device)
|
| 40 |
-
|
| 41 |
-
|
|
|
|
| 42 |
|
| 43 |
# Function to generate images from text
|
| 44 |
def generate_image_from_text(caption, model, batch_size=4):
|
| 45 |
if model == "CUB":
|
| 46 |
-
generator =
|
| 47 |
else:
|
| 48 |
-
generator =
|
| 49 |
|
| 50 |
# Create the noise tensor
|
| 51 |
noise = torch.randn((batch_size, 100)).to(device)
|
|
@@ -82,9 +83,9 @@ def generate_image_from_text(caption, model, batch_size=4):
|
|
| 82 |
# Function to generate images from text
|
| 83 |
def generate_image_from_text_with_persistent_storage(caption, model, batch_size=4):
|
| 84 |
if model == "CUB":
|
| 85 |
-
generator =
|
| 86 |
else:
|
| 87 |
-
generator =
|
| 88 |
|
| 89 |
# Create the noise tensor
|
| 90 |
noise = torch.randn((batch_size, 100)).to(device)
|
|
|
|
| 37 |
|
| 38 |
# Create a new Generator model and initialize it with the pre-trained weights
|
| 39 |
netG = NetG(64, 100, 512, 256, 3, False, clip_model).to(device)
|
| 40 |
+
netG1 = NetG(64, 100, 512, 256, 3, False, clip_model).to(device)
|
| 41 |
+
cub = load_model_weights(netG, checkpoint_cub['model']['netG'], multi_gpus=False)
|
| 42 |
+
cc12m = load_model_weights(netG1, checkpoint_cc12m['model']['netG'], multi_gpus=False)
|
| 43 |
|
| 44 |
# Function to generate images from text
|
| 45 |
def generate_image_from_text(caption, model, batch_size=4):
|
| 46 |
if model == "CUB":
|
| 47 |
+
generator = cub
|
| 48 |
else:
|
| 49 |
+
generator = cc12m
|
| 50 |
|
| 51 |
# Create the noise tensor
|
| 52 |
noise = torch.randn((batch_size, 100)).to(device)
|
|
|
|
| 83 |
# Function to generate images from text
|
| 84 |
def generate_image_from_text_with_persistent_storage(caption, model, batch_size=4):
|
| 85 |
if model == "CUB":
|
| 86 |
+
generator = cub
|
| 87 |
else:
|
| 88 |
+
generator = cc12m
|
| 89 |
|
| 90 |
# Create the noise tensor
|
| 91 |
noise = torch.randn((batch_size, 100)).to(device)
|