Update handler.py
handler.py  CHANGED  (+17 -36)

--- a/handler.py
+++ b/handler.py
@@ -1,57 +1,38 @@
+# handler.py
 from typing import Dict, Any
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from accelerate import init_empty_weights, load_checkpoint_and_dispatch
 
-
 class EndpointHandler:
-    def __init__(self, model_dir: str, **…
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            model_dir, trust_remote_code=True
-        )
+    def __init__(self, model_dir: str, **kw):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
 
-        # ① …
+        # ① Empty-shell model (no weights materialized yet)
         with init_empty_weights():
             base = AutoModelForCausalLM.from_pretrained(
-                model_dir,
-                torch_dtype=torch.float16,
-                trust_remote_code=True,
+                model_dir, torch_dtype=torch.float16, trust_remote_code=True
             )
 
-        # ② …
+        # ② Sharded checkpoint loading
         self.model = load_checkpoint_and_dispatch(
-            base,
-            checkpoint=model_dir,
-            device_map="auto",
-            dtype=torch.float16,
+            base, checkpoint=model_dir, device_map="auto", dtype=torch.float16
         ).eval()
 
-        # ③ …
-        self.…
-        torch.cuda.set_device(self.…
-
-        # ④ Generation parameters
-        self.generation_kwargs = dict(
-            max_new_tokens=512,     # 🛈 2k tokens is very VRAM-hungry; cap at 512 first, then raise it step by step
-            do_sample=True,
-            temperature=0.7,
-            top_p=0.9,
-        )
+        # ③ Pin the "default GPU" to the GPU holding the word embeddings
+        self.embed_device = self.model.get_input_embeddings().weight.device
+        torch.cuda.set_device(self.embed_device)     # ← key 1
+        print(">>> embedding on", self.embed_device)
 
-        # …
-        …
+        # Generation parameters
+        self.gen_kwargs = dict(max_new_tokens=512, temperature=0.7, top_p=0.9, do_sample=True)
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
         prompt = data["inputs"]
 
-        # Move *all* input tensors to …
-        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.…
-
+        # Move *all* input tensors to embed_device
+        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.embed_device)  # ← key 2
         with torch.inference_mode():
-            output_ids = self.model.generate(**inputs, **self.generation_kwargs)
+            out_ids = self.model.generate(**inputs, **self.gen_kwargs)
 
-        return {
-            "generated_text": self.tokenizer.decode(
-                output_ids[0], skip_special_tokens=True
-            )
-        }
+        return {"generated_text": self.tokenizer.decode(out_ids[0], skip_special_tokens=True)}
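For reference, a minimal local smoke test of the new handler might look like the sketch below. It is not part of the commit: MODEL_DIR is a placeholder (Hugging Face Inference Endpoints typically mounts the model repository at /repository, but any local checkpoint directory works), and it assumes at least one CUDA GPU and a sharded fp16 checkpoint in that directory.

# smoke_test.py -- a sketch, not part of this commit
from handler import EndpointHandler

MODEL_DIR = "/repository"  # placeholder path; point it at your checkpoint directory

h = EndpointHandler(MODEL_DIR)

# Optional: inspect how accelerate spread the layers across devices.
# dispatch_model() (called by load_checkpoint_and_dispatch) records the
# final placement on model.hf_device_map.
print(h.model.hf_device_map)

# Inference Endpoints delivers the request body as {"inputs": ...},
# which is exactly the shape __call__ expects.
out = h({"inputs": "Explain device_map='auto' in one sentence."})
print(out["generated_text"])

This also shows why the commit derives embed_device instead of hard-coding cuda:0 (the two "key" points in the diff): with device_map="auto", accelerate may place the input-embedding layer on any GPU, and generate() needs the input IDs on that same device; otherwise PyTorch raises an "Expected all tensors to be on the same device" error.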

