sywise98 commited on
Commit
fbc1d53
·
verified ·
1 Parent(s): 1e6638d

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +14 -13
model.py CHANGED
@@ -3,6 +3,7 @@ import tensorflow as tf
3
  from transformers import TFPreTrainedModel
4
  from tensorflow_examples.models.pix2pix import pix2pix
5
  from transformers import TFPreTrainedModel
 
6
 
7
  class CycleGANConfig(PretrainedConfig):
8
  model_type = "cyclegan"
@@ -143,7 +144,7 @@ class TFCycleGANModel(TFPreTrainedModel):
143
  gen_g_loss = self.generator_loss(disc_fake_y)
144
  gen_f_loss = self.generator_loss(disc_fake_x)
145
 
146
- total_cycle_loss = self.calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)
147
 
148
  # # Total generator loss = adversarial loss + cycle loss
149
  total_gen_g_loss = gen_g_loss + total_cycle_loss + self.identity_loss(real_y, same_y)
@@ -157,27 +158,27 @@ class TFCycleGANModel(TFPreTrainedModel):
157
 
158
  # Calculate the gradients for generator and discriminator
159
  generator_g_gradients = tape.gradient(total_gen_g_loss,
160
- generator_g.trainable_variables)
161
  generator_f_gradients = tape.gradient(total_gen_f_loss,
162
- generator_f.trainable_variables)
163
 
164
  discriminator_x_gradients = tape.gradient(disc_x_loss,
165
- discriminator_x.trainable_variables)
166
  discriminator_y_gradients = tape.gradient(disc_y_loss,
167
- discriminator_y.trainable_variables)
168
 
169
  # Apply the gradients to the optimizer
170
- generator_g_optimizer.apply_gradients(zip(generator_g_gradients,
171
- generator_g.trainable_variables))
172
 
173
- generator_f_optimizer.apply_gradients(zip(generator_f_gradients,
174
- generator_f.trainable_variables))
175
 
176
- discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
177
- discriminator_x.trainable_variables))
178
 
179
- discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
180
- discriminator_y.trainable_variables))
181
 
182
 
183
 
 
3
  from transformers import TFPreTrainedModel
4
  from tensorflow_examples.models.pix2pix import pix2pix
5
  from transformers import TFPreTrainedModel
6
+ from transformers import PretrainedConfig
7
 
8
  class CycleGANConfig(PretrainedConfig):
9
  model_type = "cyclegan"
 
144
  gen_g_loss = self.generator_loss(disc_fake_y)
145
  gen_f_loss = self.generator_loss(disc_fake_x)
146
 
147
+ total_cycle_loss = self.calc_cycle_loss(real_x, cycled_x) + self.calc_cycle_loss(real_y, cycled_y)
148
 
149
  # # Total generator loss = adversarial loss + cycle loss
150
  total_gen_g_loss = gen_g_loss + total_cycle_loss + self.identity_loss(real_y, same_y)
 
158
 
159
  # Calculate the gradients for generator and discriminator
160
  generator_g_gradients = tape.gradient(total_gen_g_loss,
161
+ self.generator_g.trainable_variables)
162
  generator_f_gradients = tape.gradient(total_gen_f_loss,
163
+ self.generator_f.trainable_variables)
164
 
165
  discriminator_x_gradients = tape.gradient(disc_x_loss,
166
+ self.discriminator_x.trainable_variables)
167
  discriminator_y_gradients = tape.gradient(disc_y_loss,
168
+ self.discriminator_y.trainable_variables)
169
 
170
  # Apply the gradients to the optimizer
171
+ self.generator_g_optimizer.apply_gradients(zip(generator_g_gradients,
172
+ self.generator_g.trainable_variables))
173
 
174
+ self.generator_f_optimizer.apply_gradients(zip(generator_f_gradients,
175
+ self.generator_f.trainable_variables))
176
 
177
+ self.discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
178
+ self.discriminator_x.trainable_variables))
179
 
180
+ self.discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
181
+ self.discriminator_y.trainable_variables))
182
 
183
 
184