import os
import time

from dotenv import load_dotenv
from distutils.util import strtobool

from llama2_wrapper import LLAMA2_WRAPPER


def main():
    load_dotenv()

    DEFAULT_SYSTEM_PROMPT = (
        os.getenv("DEFAULT_SYSTEM_PROMPT")
        if os.getenv("DEFAULT_SYSTEM_PROMPT") is not None
        else ""
    )
    MAX_MAX_NEW_TOKENS = (
        int(os.getenv("MAX_MAX_NEW_TOKENS"))
        if os.getenv("MAX_MAX_NEW_TOKENS") is not None
        else 2048
    )
    DEFAULT_MAX_NEW_TOKENS = (
        int(os.getenv("DEFAULT_MAX_NEW_TOKENS"))
        if os.getenv("DEFAULT_MAX_NEW_TOKENS") is not None
        else 1024
    )
    MAX_INPUT_TOKEN_LENGTH = (
        int(os.getenv("MAX_INPUT_TOKEN_LENGTH"))
        if os.getenv("MAX_INPUT_TOKEN_LENGTH") is not None
        else 4000
    )

    MODEL_PATH = os.getenv("MODEL_PATH")
    assert MODEL_PATH is not None, f"MODEL_PATH is required, got: {MODEL_PATH}"

    LOAD_IN_8BIT = bool(strtobool(os.getenv("LOAD_IN_8BIT", "True")))
    LOAD_IN_4BIT = bool(strtobool(os.getenv("LOAD_IN_4BIT", "True")))
    LLAMA_CPP = bool(strtobool(os.getenv("LLAMA_CPP", "True")))
    if LLAMA_CPP:
        print("Running on CPU with llama.cpp.")
    else:
        import torch

        if torch.cuda.is_available():
            print("Running on GPU with torch transformers.")
        else:
            print("CUDA not found.")
    config = {
        "model_name": MODEL_PATH,
        "load_in_8bit": LOAD_IN_8BIT,
        "load_in_4bit": LOAD_IN_4BIT,
        "llama_cpp": LLAMA_CPP,
        "MAX_INPUT_TOKEN_LENGTH": MAX_INPUT_TOKEN_LENGTH,
    }
    tic = time.perf_counter()
    llama2_wrapper = LLAMA2_WRAPPER(config)
    llama2_wrapper.init_tokenizer()
    llama2_wrapper.init_model()
    toc = time.perf_counter()
    print(f"Initialized the model in {toc - tic:0.4f} seconds.")
    example = "Can you explain briefly to me what is the Python programming language?"
    # run() streams the generation; the positional arguments are the prompt,
    # the chat history, the system prompt, max new tokens, and the sampling
    # parameters (temperature, top_p, top_k).
    generator = llama2_wrapper.run(
        example, [], DEFAULT_SYSTEM_PROMPT, DEFAULT_MAX_NEW_TOKENS, 1, 0.95, 50
    )
    tic = time.perf_counter()
    # The generator yields the response as it is produced; keep the last item
    # as the complete output (guard against an empty stream).
    response = ""
    try:
        first_response = next(generator)
        response = first_response
        # history += [(example, first_response)]
        # print(first_response)
    except StopIteration:
        pass
        # history += [(example, "")]
        # print(history)
    for response in generator:
        # history += [(example, response)]
        # print(response)
        pass
    print(response)
    toc = time.perf_counter()

    output_token_length = llama2_wrapper.get_token_length(response)
    print(f"Generated the output in {toc - tic:0.4f} seconds.")
    print(f"Speed: {output_token_length / (toc - tic):0.4f} tokens/sec.")

if __name__ == "__main__":
    main()
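
For reference, the script reads its configuration from a .env file via load_dotenv(). A minimal sketch of such a file is below; only the variable names come from the script itself, and the model path and values are illustrative placeholders:

MODEL_PATH=./models/llama-2-7b-chat.ggmlv3.q4_0.bin
LLAMA_CPP=True
LOAD_IN_8BIT=False
LOAD_IN_4BIT=False
MAX_MAX_NEW_TOKENS=2048
DEFAULT_MAX_NEW_TOKENS=1024
MAX_INPUT_TOKEN_LENGTH=4000
DEFAULT_SYSTEM_PROMPT=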