Spaces:
Runtime error
Runtime error
| import traceback | |
| from queue import Queue | |
| from threading import Thread | |
| import collections.abc | |
| import torch | |
| from transformers import StoppingCriteria | |
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once stop-token sequences have been seen enough times.

    Only the first sequence of the batch (``input_ids[0]``) is inspected. On
    each call, every stop whose tokens equal the tail of that sequence has its
    counter incremented; generation stops as soon as a counter reaches its
    required count. Counters are never reset, so reusing an instance carries
    counts over into the next generation.

    Args:
        stops: token-id tensors to watch for; each is compared against the
            last ``len(stop)`` tokens of ``input_ids[0]``. Empty stops are
            ignored.
        encounters: required match count per stop. The list is cycled over
            ``stops`` (stop ``i`` uses ``encounters[i % len(encounters)]``),
            so ``len(stops)`` must be a multiple of ``len(encounters)``.
            When omitted or empty, every stop triggers on its first match.
        device: if given, the stops are moved to this device up front.
            Otherwise each stop is moved to the device of ``input_ids`` the
            first time it is compared.

    Raises:
        ValueError: if ``len(stops)`` is not a multiple of ``len(encounters)``.
    """

    def __init__(self, stops=None, encounters=None, device=None):
        super().__init__()
        # ``None`` defaults avoid sharing one mutable list across instances.
        stops = [] if stops is None else list(stops)
        # With no counts at all the modulo checks would divide by zero.
        if encounters is None or len(encounters) == 0:
            encounters = [1]
        else:
            encounters = list(encounters)
        if len(stops) % len(encounters) != 0:
            raise ValueError(
                "Number of stops must be a multiple of the number of encounters"
            )
        self.encounters = encounters
        self.stops = [stop if device is None else stop.to(device) for stop in stops]
        self.num_stops = [0] * len(self.stops)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        sequence = input_ids[0]
        for stopi, stop in enumerate(self.stops):
            stop_len = len(stop)
            # An empty stop would be compared with the whole sequence, and a
            # stop longer than the sequence cannot match yet (its shape would
            # not broadcast against the shorter tail), so skip both.
            if stop_len == 0 or stop_len > len(sequence):
                continue
            if stop.device != sequence.device:
                # Move once and cache so later steps compare on the same device.
                stop = self.stops[stopi] = stop.to(sequence.device)
            if torch.all(stop == sequence[-stop_len:]).item():
                self.num_stops[stopi] += 1
                if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
                    return True
        return False
class Stream(StoppingCriteria):
    """Stopping criterion that never stops; it only forwards tokens to a callback.

    On every generation step the first sequence of the batch
    (``input_ids[0]``) is handed to ``func``. Keep in mind that for
    decoder-only transformers this sequence includes the initial prompt
    tokens.

    Args:
        func (`callable`, optional): invoked once per step with
            ``input_ids[0]``. When ``None``, the criterion does nothing.
    """

    def __init__(self, func=None):
        self.func = func

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        hook = self.func
        if hook is None:
            return False
        # Only the first of possibly several responses is consumed.
        leading_sequence = input_ids[0]
        hook(leading_sequence)
        # Generation is never asked to stop by this criterion.
        return False
class CallbackToGenerator(collections.abc.Generator):
    """
    A generator wrapper for a function that invokes a callback multiple times.

    ``func(*args, callback=cb, **kwargs)`` runs on a background thread. Each
    time it calls ``cb(value)``, ``value`` is returned to the consumer from
    ``next()``/``send()``. The consumer's following ``send(x)`` becomes the
    return value of that ``cb`` call; ``throw(exc)`` makes ``cb`` raise
    ``exc`` instead. When ``func`` returns ``ret``, the consumer receives
    ``StopIteration(ret)`` (plain ``StopIteration`` for ``None``), and any
    exception ``func`` raises is re-raised in the consumer.

    Note this starts a background thread in ``__init__``. The two sides hand
    control back and forth through two single-slot queues.

    Args:
        func: callable accepting a ``callback`` keyword argument.
        *args: positional arguments forwarded to ``func``.
        callback: stored as ``self.callback``; not used within this class.
        **kwargs: keyword arguments forwarded to ``func``.
    """

    def __init__(self, func, *args, callback=None, **kwargs):
        self.func = func
        self.args = args
        self.kwargs = kwargs
        self.callback = callback

        # consumer -> worker: ('send', value) or ('throw', exc)
        self._ready_queue = Queue(1)
        # worker -> consumer: (is_exception, value_or_exception)
        self._done_queue = Queue(1)
        # mutable flag set once the worker has finished
        self._done_holder = [False]

        # Bind everything the worker needs to locals so that neither closure
        # references ``self``. Otherwise a blocked worker thread would keep
        # this object alive and ``__del__`` could never shut it down.
        ready_queue = self._ready_queue
        done_queue = self._done_queue
        done_holder = self._done_holder

        def val_callback(value):
            # Hand ``value`` to the consumer, then block until it sends/throws.
            done_queue.put((False, value))
            cmd, val = ready_queue.get()
            if cmd == 'send':
                return val
            elif cmd == 'throw':
                raise val
            else:
                assert False  # pragma: no cover

        def thread_func():
            # Mirror generator semantics: the first send() must carry None.
            while True:
                cmd, val = ready_queue.get()
                if cmd == 'send' and val is not None:
                    done_queue.put((True, TypeError("can't send non-None value to a just-started generator")))
                    continue
                break
            try:
                if cmd == 'throw':
                    raise val
                ret = func(*args, callback=val_callback, **kwargs)
                raise StopIteration(ret) if ret is not None else StopIteration
            except BaseException as e:
                # Any outcome (return or exception) finishes the worker.
                done_holder[0] = True
                done_queue.put((True, e))

        self._thread = Thread(target=thread_func)
        self._thread.start()

    def _put(self, *args):
        """Send a ``(cmd, value)`` command to the worker and wait for its reply.

        Returns the next value the worker passed to its callback. Re-raises
        any exception the worker reported (``StopIteration`` when ``func``
        finished). Raises ``StopIteration`` immediately if the worker is
        already done.
        """
        if self._done_holder[0]:
            raise StopIteration
        self._ready_queue.put(args)
        is_exception, val = self._done_queue.get()
        if is_exception:
            try:
                raise val
            finally:
                # prevent val's traceback containing a reference cycle
                del val
        else:
            return val

    def send(self, value):
        """Resume the worker with ``value`` and return its next callback value."""
        return self._put('send', value)

    def throw(self, exc):
        """Raise ``exc`` inside the worker and return its next callback value."""
        return self._put('throw', exc)

    def close(self):
        """Throw ``GeneratorExit`` into the worker and join its thread.

        Raises:
            RuntimeError: if ``func`` calls the callback again after
                receiving ``GeneratorExit`` (the thread cannot be joined).
            BaseException: any other exception from ``func`` is re-raised
                after the thread has been joined.
        """
        try:
            self.throw(GeneratorExit)
        except StopIteration:
            self._thread.join()
        except GeneratorExit:
            self._thread.join()
        except BaseException:
            self._thread.join()
            raise
        else:
            # yielded again, can't clean up the thread
            raise RuntimeError('Task with callback ignored GeneratorExit')

    def __del__(self):
        # If __init__ failed before the worker started, there is no thread to
        # stop, and close() would block forever waiting for a reply. A worker
        # that already finished needs no cleanup either.
        thread = getattr(self, '_thread', None)
        if thread is not None and thread.is_alive():
            self.close()