Spaces:
Sleeping
Sleeping
| import threading | |
class Callbacks:
    """
    Handles all registered callbacks for YOLOv5 Hooks.

    Each hook name maps to a list of actions. An action is a dict with a
    'name' (for later reference) and a 'callback' (the callable to fire).
    Only the hook names defined in ``__init__`` can be registered to or run.
    """

    def __init__(self):
        # Define the available callbacks (hook name -> list of registered actions)
        self._callbacks = {
            'on_pretrain_routine_start': [],
            'on_pretrain_routine_end': [],
            'on_train_start': [],
            'on_train_epoch_start': [],
            'on_train_batch_start': [],
            'optimizer_step': [],
            'on_before_zero_grad': [],
            'on_train_batch_end': [],
            'on_train_epoch_end': [],
            'on_val_start': [],
            'on_val_batch_start': [],
            'on_val_image_end': [],
            'on_val_batch_end': [],
            'on_val_end': [],
            'on_fit_epoch_end': [],  # fit = train + val
            'on_model_save': [],
            'on_train_end': [],
            'on_params_update': [],
            'teardown': [],
        }
        self.stop_training = False  # set True to interrupt training

    def register_action(self, hook, name='', callback=None):
        """
        Register a new action to a callback hook.

        Args:
            hook: The callback hook name to register the action to
            name: The name of the action for later reference
            callback: The callback to fire

        Raises:
            AssertionError: If `hook` is not a known hook or `callback` is not
                callable. These checks are assert statements, so they are
                skipped when Python runs with -O.
        """
        assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
        assert callable(callback), f"callback '{callback}' is not callable"
        self._callbacks[hook].append({'name': name, 'callback': callback})

    def get_registered_actions(self, hook=None):
        """
        Return the registered actions, either for one hook or for all hooks.

        Args:
            hook: The name of the hook to check, defaults to all

        Returns:
            The list of actions registered to `hook`, or the whole
            hook -> actions dict when `hook` is falsy. These are the live
            internal containers, not copies.

        Raises:
            KeyError: If `hook` is truthy but not a known hook name.
        """
        return self._callbacks[hook] if hook else self._callbacks

    def run(self, hook, *args, thread=False, **kwargs):
        """
        Loop through the registered actions of `hook` and fire every callback.

        Callbacks are called on the calling thread, in registration order,
        unless `thread` is True. In that case each callback is started in its
        own daemon thread, and the threads are not joined.

        Args:
            hook: The name of the hook whose actions should be fired (required)
            args: Arguments to receive from YOLOv5
            thread: (boolean) Run callbacks in daemon thread
            kwargs: Keyword Arguments to receive from YOLOv5

        Raises:
            AssertionError: If `hook` is not a known hook (skipped under -O).
        """
        assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
        # Iterate over a snapshot. If a callback registers another action on
        # this same hook, the new action must not extend the loop in progress;
        # otherwise the loop could run without bound.
        for logger in list(self._callbacks[hook]):
            if thread:
                threading.Thread(target=logger['callback'], args=args, kwargs=kwargs, daemon=True).start()
            else:
                logger['callback'](*args, **kwargs)