"""CycleGAN packaged as a Hugging Face ``transformers`` TF model.

Wraps the ``tensorflow_examples`` pix2pix U-Net generators and PatchGAN
discriminators in a ``PretrainedConfig`` / ``TFPreTrainedModel`` pair so a
trained RGB->thermal CycleGAN can be restored with ``from_pretrained``.
"""

import tensorflow as tf
from tensorflow_examples.models.pix2pix import pix2pix
from transformers import PretrainedConfig, TFPreTrainedModel

# Crop size used by random_crop / random_jitter (the 286x286 resize in
# random_jitter is then randomly cropped back down to this).
IMG_HEIGHT = 256
IMG_WIDTH = 256


class CycleGANConfig(PretrainedConfig):
    """Configuration for :class:`TFCycleGANModel`.

    Args:
        output_channels: Number of channels the generators emit (3 = RGB).
        norm_type: Normalization for pix2pix nets ('instancenorm' or 'batchnorm').
        generator_type: Generator architecture identifier (currently 'unet').
        discriminator_target: Whether the pix2pix discriminator is conditional
            (takes a target image alongside the input).
        lambda_: Weight of the cycle-consistency loss. NOTE: the original code
            spelled this parameter ``lambda``, which is a Python keyword and a
            syntax error; it is renamed ``lambda_`` here.
        learning_rate: Adam learning rate shared by all four optimizers.
        beta_1: Adam beta_1 shared by all four optimizers.
        epochs: Number of training epochs.
        training_checkpoint: Directory used for tf.train checkpoints.
    """

    model_type = "cyclegan"

    def __init__(
        self,
        output_channels=3,
        norm_type='instancenorm',
        generator_type='unet',
        discriminator_target=False,
        lambda_=10,
        learning_rate=2e-4,
        beta_1=0.5,
        epochs=50,
        training_checkpoint="./rgb2thermal_checkpoints/train",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.output_channels = output_channels
        self.norm_type = norm_type
        self.generator_type = generator_type
        self.discriminator_target = discriminator_target
        # BUG FIX: the original never stored the following values, yet the
        # model reads config.learning_rate, config.beta_1, config.epochs,
        # config.training_checkpoint and the cycle-loss weight.
        self.lambda_ = lambda_
        self.learning_rate = learning_rate
        self.beta_1 = beta_1
        self.epochs = epochs
        self.training_checkpoint = training_checkpoint


class TFCycleGANModel(TFPreTrainedModel):
    """CycleGAN with two generators (G: X->Y, F: Y->X) and two discriminators.

    NOTE(review): the original file defined ``TFCycleGANModel`` twice — the
    second definition (carrying only ``from_pretrained``) silently replaced
    the first, discarding ``__init__`` and all training logic.  The two
    definitions are merged into this single class.
    """

    config_class = CycleGANConfig

    def __init__(self, config):
        super().__init__(config)
        # Two generators and two discriminators, as in the CycleGAN paper.
        self.generator_g = pix2pix.unet_generator(
            config.output_channels, norm_type=config.norm_type)
        self.generator_f = pix2pix.unet_generator(
            config.output_channels, norm_type=config.norm_type)
        self.discriminator_x = pix2pix.discriminator(
            norm_type=config.norm_type, target=config.discriminator_target)
        self.discriminator_y = pix2pix.discriminator(
            norm_type=config.norm_type, target=config.discriminator_target)
        # One Adam optimizer per network, all sharing the same hyperparameters.
        self.generator_g_optimizer = tf.keras.optimizers.Adam(
            config.learning_rate, beta_1=config.beta_1)
        self.generator_f_optimizer = tf.keras.optimizers.Adam(
            config.learning_rate, beta_1=config.beta_1)
        self.discriminator_x_optimizer = tf.keras.optimizers.Adam(
            config.learning_rate, beta_1=config.beta_1)
        self.discriminator_y_optimizer = tf.keras.optimizers.Adam(
            config.learning_rate, beta_1=config.beta_1)
        # BUG FIX: ``config.lambda`` was a syntax error; use the renamed field.
        self.LAMBDA = config.lambda_
        self.loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)
        self.training_checkpoint = config.training_checkpoint
        self.EPOCHS = config.epochs

    def call(self, inputs):
        """Forward pass: translate X-domain images to Y-domain via G."""
        return self.generator_g(inputs)

    def generate(self, inputs):
        """Alias for ``call`` — translate X-domain images with generator G."""
        return self.generator_g(inputs)

    def random_crop(self, image):
        """Randomly crop ``image`` to IMG_HEIGHT x IMG_WIDTH x 3."""
        # BUG FIX: IMG_HEIGHT / IMG_WIDTH were referenced but never defined;
        # they are now module-level constants (256, matching random_jitter).
        return tf.image.random_crop(image, size=[IMG_HEIGHT, IMG_WIDTH, 3])

    def normalize(self, image):
        """Normalize uint8 pixel values to the [-1, 1] float range."""
        image = tf.cast(image, tf.float32)
        return (image / 127.5) - 1

    def more_augment(self, image):
        """Extra augmentation: random flip, brightness and contrast jitter."""
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_brightness(image, max_delta=0.1)
        image = tf.image.random_contrast(image, lower=0.9, upper=1.1)
        return image

    def random_jitter(self, image):
        """Standard CycleGAN jitter: upsize to 286, random-crop to 256, flip."""
        image = tf.image.resize(
            image, [286, 286], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        image = self.random_crop(image)
        image = tf.image.random_flip_left_right(image)
        return image

    def preprocess_image_train(self, image, label):
        """Training preprocessing: jitter + augment + normalize (label unused)."""
        image = self.random_jitter(image)
        image = self.more_augment(image)
        return self.normalize(image)

    def preprocess_image_test(self, image, label):
        """Test preprocessing: normalize only (label unused)."""
        return self.normalize(image)

    def discriminator_loss(self, real, generated):
        """BCE loss pushing real logits to 1 and fake logits to 0, halved."""
        real_loss = self.loss_obj(tf.ones_like(real), real)
        generated_loss = self.loss_obj(tf.zeros_like(generated), generated)
        return (real_loss + generated_loss) * 0.5

    def generator_loss(self, generated):
        """Adversarial loss: generator wants fake logits classified as real."""
        # BUG FIX: the original referenced a bare ``loss_obj`` (NameError).
        return self.loss_obj(tf.ones_like(generated), generated)

    def calc_cycle_loss(self, real_image, cycled_image):
        """L1 cycle-consistency loss, weighted by LAMBDA."""
        return self.LAMBDA * tf.reduce_mean(tf.abs(real_image - cycled_image))

    def identity_loss(self, real_image, same_image):
        """L1 identity loss, weighted by LAMBDA * 0.5 (per the CycleGAN paper)."""
        return self.LAMBDA * 0.5 * tf.reduce_mean(tf.abs(real_image - same_image))

    @tf.function
    def train_step(self, real_x, real_y):
        """One CycleGAN optimization step over a pair of unpaired batches.

        Computes adversarial, cycle-consistency and identity losses for both
        generators plus both discriminator losses, then applies gradients
        with the four per-network optimizers.
        """
        # persistent=True: the tape is reused to compute four separate
        # gradient sets (two generators, two discriminators).
        with tf.GradientTape(persistent=True) as tape:
            # Generator G translates X -> Y; generator F translates Y -> X.
            fake_y = self.generator_g(real_x, training=True)
            cycled_x = self.generator_f(fake_y, training=True)

            fake_x = self.generator_f(real_y, training=True)
            cycled_y = self.generator_g(fake_x, training=True)

            # same_x and same_y are used for the identity loss.
            same_x = self.generator_f(real_x, training=True)
            same_y = self.generator_g(real_y, training=True)

            disc_real_x = self.discriminator_x(real_x, training=True)
            disc_real_y = self.discriminator_y(real_y, training=True)

            disc_fake_x = self.discriminator_x(fake_x, training=True)
            disc_fake_y = self.discriminator_y(fake_y, training=True)

            # Adversarial losses.
            gen_g_loss = self.generator_loss(disc_fake_y)
            gen_f_loss = self.generator_loss(disc_fake_x)

            total_cycle_loss = (self.calc_cycle_loss(real_x, cycled_x)
                                + self.calc_cycle_loss(real_y, cycled_y))

            # Total generator loss = adversarial + cycle + identity.
            total_gen_g_loss = (gen_g_loss + total_cycle_loss
                                + self.identity_loss(real_y, same_y))
            total_gen_f_loss = (gen_f_loss + total_cycle_loss
                                + self.identity_loss(real_x, same_x))

            disc_x_loss = self.discriminator_loss(disc_real_x, disc_fake_x)
            disc_y_loss = self.discriminator_loss(disc_real_y, disc_fake_y)

        # Gradients for each of the four networks.
        generator_g_gradients = tape.gradient(
            total_gen_g_loss, self.generator_g.trainable_variables)
        generator_f_gradients = tape.gradient(
            total_gen_f_loss, self.generator_f.trainable_variables)
        discriminator_x_gradients = tape.gradient(
            disc_x_loss, self.discriminator_x.trainable_variables)
        discriminator_y_gradients = tape.gradient(
            disc_y_loss, self.discriminator_y.trainable_variables)

        # Apply the gradients with each network's own optimizer.
        self.generator_g_optimizer.apply_gradients(
            zip(generator_g_gradients, self.generator_g.trainable_variables))
        self.generator_f_optimizer.apply_gradients(
            zip(generator_f_gradients, self.generator_f.trainable_variables))
        self.discriminator_x_optimizer.apply_gradients(
            zip(discriminator_x_gradients, self.discriminator_x.trainable_variables))
        self.discriminator_y_optimizer.apply_gradients(
            zip(discriminator_y_gradients, self.discriminator_y.trainable_variables))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path="./rgb2thermal_checkpoints/train",
                        *model_args, **kwargs):
        """Load config + weights, then restore the tf.train checkpoint.

        After the standard ``transformers`` loading path, the four sub-networks
        are additionally restored from the latest tf.train checkpoint found in
        ``pretrained_model_name_or_path``.
        """
        config = kwargs.pop("config", None)
        if not isinstance(config, CycleGANConfig):
            config = CycleGANConfig.from_pretrained(
                pretrained_model_name_or_path, **kwargs)
        kwargs["config"] = config
        model = super().from_pretrained(
            pretrained_model_name_or_path, *model_args, **kwargs)

        # Restore the four sub-networks from the training checkpoint.
        # expect_partial(): the checkpoint also holds optimizer slots that
        # are intentionally not restored here.
        checkpoint = tf.train.Checkpoint(
            generator_g=model.generator_g,
            generator_f=model.generator_f,
            discriminator_x=model.discriminator_x,
            discriminator_y=model.discriminator_y,
        )
        checkpoint.restore(
            tf.train.latest_checkpoint(pretrained_model_name_or_path)
        ).expect_partial()
        print("Model restored from checkpoint")
        return model