Update model.py
Browse files
model.py
CHANGED
|
@@ -3,6 +3,7 @@ import tensorflow as tf
|
|
| 3 |
from transformers import TFPreTrainedModel
|
| 4 |
from tensorflow_examples.models.pix2pix import pix2pix
|
| 5 |
from transformers import TFPreTrainedModel
|
|
|
|
| 6 |
|
| 7 |
class CycleGANConfig(PretrainedConfig):
|
| 8 |
model_type = "cyclegan"
|
|
@@ -143,7 +144,7 @@ class TFCycleGANModel(TFPreTrainedModel):
|
|
| 143 |
gen_g_loss = self.generator_loss(disc_fake_y)
|
| 144 |
gen_f_loss = self.generator_loss(disc_fake_x)
|
| 145 |
|
| 146 |
-
total_cycle_loss = self.calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)
|
| 147 |
|
| 148 |
# # Total generator loss = adversarial loss + cycle loss
|
| 149 |
total_gen_g_loss = gen_g_loss + total_cycle_loss + self.identity_loss(real_y, same_y)
|
|
@@ -157,27 +158,27 @@ class TFCycleGANModel(TFPreTrainedModel):
|
|
| 157 |
|
| 158 |
# Calculate the gradients for generator and discriminator
|
| 159 |
generator_g_gradients = tape.gradient(total_gen_g_loss,
|
| 160 |
-
generator_g.trainable_variables)
|
| 161 |
generator_f_gradients = tape.gradient(total_gen_f_loss,
|
| 162 |
-
generator_f.trainable_variables)
|
| 163 |
|
| 164 |
discriminator_x_gradients = tape.gradient(disc_x_loss,
|
| 165 |
-
discriminator_x.trainable_variables)
|
| 166 |
discriminator_y_gradients = tape.gradient(disc_y_loss,
|
| 167 |
-
discriminator_y.trainable_variables)
|
| 168 |
|
| 169 |
# Apply the gradients to the optimizer
|
| 170 |
-
generator_g_optimizer.apply_gradients(zip(generator_g_gradients,
|
| 171 |
-
generator_g.trainable_variables))
|
| 172 |
|
| 173 |
-
generator_f_optimizer.apply_gradients(zip(generator_f_gradients,
|
| 174 |
-
generator_f.trainable_variables))
|
| 175 |
|
| 176 |
-
discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
|
| 177 |
-
discriminator_x.trainable_variables))
|
| 178 |
|
| 179 |
-
discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
|
| 180 |
-
discriminator_y.trainable_variables))
|
| 181 |
|
| 182 |
|
| 183 |
|
|
|
|
| 3 |
from transformers import TFPreTrainedModel
|
| 4 |
from tensorflow_examples.models.pix2pix import pix2pix
|
| 5 |
from transformers import TFPreTrainedModel
|
| 6 |
+
from transformers import PretrainedConfig
|
| 7 |
|
| 8 |
class CycleGANConfig(PretrainedConfig):
|
| 9 |
model_type = "cyclegan"
|
|
|
|
| 144 |
gen_g_loss = self.generator_loss(disc_fake_y)
|
| 145 |
gen_f_loss = self.generator_loss(disc_fake_x)
|
| 146 |
|
| 147 |
+
total_cycle_loss = self.calc_cycle_loss(real_x, cycled_x) + self.calc_cycle_loss(real_y, cycled_y)
|
| 148 |
|
| 149 |
# # Total generator loss = adversarial loss + cycle loss
|
| 150 |
total_gen_g_loss = gen_g_loss + total_cycle_loss + self.identity_loss(real_y, same_y)
|
|
|
|
| 158 |
|
| 159 |
# Calculate the gradients for generator and discriminator
|
| 160 |
generator_g_gradients = tape.gradient(total_gen_g_loss,
|
| 161 |
+
self.generator_g.trainable_variables)
|
| 162 |
generator_f_gradients = tape.gradient(total_gen_f_loss,
|
| 163 |
+
self.generator_f.trainable_variables)
|
| 164 |
|
| 165 |
discriminator_x_gradients = tape.gradient(disc_x_loss,
|
| 166 |
+
self.discriminator_x.trainable_variables)
|
| 167 |
discriminator_y_gradients = tape.gradient(disc_y_loss,
|
| 168 |
+
self.discriminator_y.trainable_variables)
|
| 169 |
|
| 170 |
# Apply the gradients to the optimizer
|
| 171 |
+
self.generator_g_optimizer.apply_gradients(zip(generator_g_gradients,
|
| 172 |
+
self.generator_g.trainable_variables))
|
| 173 |
|
| 174 |
+
self.generator_f_optimizer.apply_gradients(zip(generator_f_gradients,
|
| 175 |
+
self.generator_f.trainable_variables))
|
| 176 |
|
| 177 |
+
self.discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
|
| 178 |
+
self.discriminator_x.trainable_variables))
|
| 179 |
|
| 180 |
+
self.discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
|
| 181 |
+
self.discriminator_y.trainable_variables))
|
| 182 |
|
| 183 |
|
| 184 |
|