| |
| |
| |
| |
| |
| import itertools |
| import sys |
| import time |
|
|
|
|
|
|
|
|
| from typing import Any, Dict, List |
|
|
|
|
|
|
|
|
| import torch |
| from torch import nn |
|
|
|
|
|
|
| from omegaconf import DictConfig |
| from PIL import Image |
|
|
| from torchtune import config, utils |
| from torchtune.utils._generation import sample |
| from torchtune.models import convert_weights |
| from torchtune.data import Message |
|
|
| from models.tokenizer import START_IMAGE, END_IMAGE, START_AUDIO, END_AUDIO, START_VIDEO, END_VIDEO |
| from imagebind.models.imagebind_model import ModalityType |
| from diffusers import DiffusionPipeline |
|
|
| from models import add_proj_convert_weights, _BASE_TRAINABLE |
| import os |
|
|
# Module-level logger for this recipe (DEBUG level).
log = utils.get_logger("DEBUG")
# Import-time side effect: registers the projection-layer weight-conversion
# hooks from the local `models` package (presumably patches torchtune's
# convert_weights mappings — confirm in models/__init__.py).
add_proj_convert_weights()
|
|
|
|
class InferenceRecipe:
    """
    Recipe for generating tokens from a dense Transformer-based LLM.

    Currently this recipe supports single-GPU generation only. Speculative
    decoding is not supported.

    For more details on how to use this recipe for generation, please see our
    tutorial: https://pytorch.org/torchtune/main/tutorials/e2e_flow.html#generation

    For using this recipe with a quantized model, please see the following
    section of the above tutorial:
    https://pytorch.org/torchtune/main/tutorials/e2e_flow.html#speeding-up-generation-using-quantization
    """

    def __init__(self, cfg: DictConfig) -> None:
        """Resolve device/dtype/quantizer settings from ``cfg`` and seed the RNG.

        Model and tokenizer construction is deferred to :meth:`setup`.
        """
        self._device = utils.get_device(device=cfg.device)
        self._dtype = utils.get_dtype(dtype=cfg.dtype)
        self._quantizer = config.instantiate(cfg.inference.quantizer)
        self._quantization_mode = utils.get_quantizer_mode(self._quantizer)
        # Raw prompt template; {image}/{audio}/{video} placeholders are expanded
        # by mm_process_prompt at generation time.
        self.prompt_template = cfg.inference.prompt_template
        perception_tokens = cfg.model.perception_tokens
        # Placeholder text inserted between modality markers.
        # NOTE(review): "0 " repeated N times then sliced to N *characters*
        # yields alternating '0'/' ' (e.g. N=4 -> "0 0 "), i.e. roughly N/2
        # tokens, not N — confirm the intended placeholder length.
        self._perception_tokens = ("0 " * perception_tokens)[:perception_tokens]
        utils.set_seed(seed=cfg.seed)

    def setup(self, cfg: DictConfig) -> None:
        """Load the checkpoint, build model + tokenizer, and allocate KV caches."""
        checkpointer = config.instantiate(cfg.checkpointer)
        if self._quantization_mode is None:
            ckpt_dict = checkpointer.load_checkpoint()
        else:
            # Quantized checkpoints may contain non-tensor state, so the
            # weights-only (safe) loader cannot be used here.
            ckpt_dict = checkpointer.load_checkpoint(weights_only=False)

        self._model = self._setup_model(
            model_cfg=cfg.model,
            model_state_dict=ckpt_dict[utils.MODEL_KEY],
        )
        # Allocate the KV caches directly on the target device.
        with self._device:
            self._model.setup_caches(max_batch_size=cfg.batch_size, dtype=self._dtype)

        self._tokenizer = config.instantiate(cfg.tokenizer)
        # Token ids that open/close the image/audio/video embedding spans.
        self._mm_ids_start = self._tokenizer.encode(START_IMAGE + START_AUDIO + START_VIDEO, add_eos=False, add_bos=False)
        self._mm_ids_end = self._tokenizer.encode(END_IMAGE + END_AUDIO + END_VIDEO, add_eos=False, add_bos=False)
        self.use_clip = cfg.model.use_clip
        if self.use_clip:
            # Optional diffusion pipeline; not referenced elsewhere in this
            # class — presumably used by callers or subclasses.
            self._clip_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-unclip-small", torch_dtype=self._dtype).to(self._device)

    def _setup_model(
        self,
        model_cfg: DictConfig,
        model_state_dict: Dict[str, Any],
    ) -> nn.Module:
        """Instantiate the model on-device, optionally quantize it, and load weights.

        Args:
            model_cfg: config node instantiated into the model module.
            model_state_dict: checkpoint weights to load into the model.

        Returns:
            The initialized model, validated to be in ``self._dtype``.
        """
        # Build parameters directly in the target dtype and on the target
        # device to avoid a CPU round-trip.
        with utils.set_default_dtype(self._dtype), self._device:
            model = config.instantiate(model_cfg)

        if self._quantization_mode is not None:
            model = self._quantizer.quantize(model)
            model = model.to(device=self._device, dtype=self._dtype)

        model.load_state_dict(model_state_dict)

        # Ensure every parameter actually landed in the expected dtype.
        utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype)
        log.debug(f"Model is initialized with precision {self._dtype}.")

        return model

    def mm_process_prompt(self, prompt):
        """Expand {image}/{audio}/{video} placeholders into marker-delimited
        perception-token spans understood by extract_mm_context."""
        return (
            prompt
            .replace("{image}", f"{START_IMAGE}{self._perception_tokens}{END_IMAGE}")
            .replace("{audio}", f"{START_AUDIO}{self._perception_tokens}{END_AUDIO}")
            .replace("{video}", f"{START_VIDEO}{self._perception_tokens}{END_VIDEO}")
        )

    def extract_mm_context(self, video_ib_embed, tokens):
        """Build a map of token index -> multimodal embedding for every token
        position strictly between a start marker and its end marker.

        Args:
            video_ib_embed: ImageBind embedding attached to each in-span position.
            tokens: the tokenized prompt to scan for marker ids.

        Returns:
            Dict mapping in-span token indices to {"ib_embed": tensor}.
        """
        context = {}
        in_mm_embed = False
        for idx, tok in enumerate(tokens):
            # An end-marker token closes the span (the marker itself excluded).
            in_mm_embed = in_mm_embed and not tok in self._mm_ids_end
            if in_mm_embed:
                context[idx] = {
                    "ib_embed": video_ib_embed.to(dtype=self._dtype, device=self._device),
                }
            # Positions *after* a start marker are inside the span.
            in_mm_embed = in_mm_embed or tok in self._mm_ids_start
        return context

    @torch.no_grad()
    def generate(self, cfg: DictConfig, video_ib_embed: List[float]):
        """Generate a caption for a single video embedding.

        Args:
            cfg: inference config providing temperature, top_k, max_new_tokens.
            video_ib_embed: ImageBind embedding conditioning the generation.

        Returns:
            The decoded caption string with special/marker ids stripped.
        """
        messages = [
            Message(
                role="user",
                content=self.mm_process_prompt(self.prompt_template),
            ),
            Message(
                role="assistant",
                content="",
            )
        ]
        tokens, mask = self._tokenizer.tokenize_messages(messages)
        # Drop the last two ids of the empty assistant turn so the model
        # continues it. NOTE(review): presumably these are end-of-turn/eos ids
        # appended by tokenize_messages — confirm against the tokenizer.
        tokens = tokens[:-2]
        # Batch-of-one list of per-position embedding maps.
        mm_context = [self.extract_mm_context(video_ib_embed, tokens)]
        prompt = torch.tensor(tokens, dtype=torch.int, device=self._device)

        # Expose the multimodal embeddings to the embedding/output layers for
        # the prefill pass.
        self._model.tok_embeddings.set_context(mm_context)
        self._model.output.set_context(mm_context)

        # The special-token ids form a contiguous block starting at
        # <|begin_of_text|>; allow only <|eot_id|> and the modality markers,
        # and treat every other id in that 256-wide block as disallowed.
        bos_id = self._tokenizer.tt_model.encode("<|begin_of_text|>", allowed_special="all")[0]
        allowed_id = self._tokenizer.tt_model.encode(f"<|eot_id|>{START_IMAGE}{END_IMAGE}{START_AUDIO}{END_AUDIO}{START_VIDEO}{END_VIDEO}", allowed_special="all")
        disallowed_tokens = list(set(range(bos_id, bos_id + 256)) - set(allowed_id))

        def custom_generate_next_token(model, input_pos, x, temperature=1.0, top_k=None):
            # The multimodal context only applies to the prompt (already
            # consumed during prefill); clear it for incremental decoding.
            model.tok_embeddings.set_context([])
            model.output.set_context([])
            # logits: [1, s, v]; sample from the last position only.
            logits = model(x, input_pos=input_pos)
            logits = logits[0, -1]
            token = sample(logits, temperature, top_k)
            # Map any disallowed special id to EOS so decoding stops cleanly.
            if token in disallowed_tokens:
                return torch.tensor([self._tokenizer.eos_id]).to(x)
            return token

        # Quantized models benefit heavily from compilation; do a tiny warmup
        # generation so compile time is not billed to the timed run below.
        if self._quantization_mode is not None:
            log.info("Starting compilation to improve generation performance ...")
            custom_generate_next_token = torch.compile(
                custom_generate_next_token, mode="max-autotune", fullgraph=True
            )
            t0 = time.perf_counter()
            _ = utils.generate(
                model=self._model,
                prompt=prompt,
                max_generated_tokens=2,
                temperature=cfg.temperature,
                top_k=cfg.top_k,
                eos_id=self._tokenizer.eos_id,
                custom_generate_next_token=custom_generate_next_token,
            )
            t = time.perf_counter() - t0
            log.info(f"Warmup run for quantized model takes: {t:.02f} sec")

        t0 = time.perf_counter()
        generated_tokens = utils.generate(
            model=self._model,
            prompt=prompt,
            max_generated_tokens=cfg.max_new_tokens,
            temperature=cfg.temperature,
            top_k=cfg.top_k,
            eos_id=self._tokenizer.eos_id,
            custom_generate_next_token=custom_generate_next_token,
        )
        t = time.perf_counter() - t0

        # Strip the prompt prefix and any special/marker ids before decoding.
        cleaned_tokens = [t for t in generated_tokens[len(prompt):] if t not in disallowed_tokens + allowed_id]
        caption = self._tokenizer.decode(cleaned_tokens)

        return caption

    @torch.no_grad()
    def generate_batch(self, cfg: DictConfig, video_ib_embed: torch.Tensor):
        """Generate captions for a batch of video embeddings.

        Args:
            cfg: inference config providing temperature, top_k, max_new_tokens.
            video_ib_embed: batch of ImageBind embeddings; first dim is batch.

        Returns:
            A list of decoded caption strings, one per batch element.
        """
        log.info(f"inside generate_batch, video_ib_embed shape: {video_ib_embed.shape}")
        batch_dim = video_ib_embed.size(0)
        messages = [
            Message(
                role="user",
                content=self.mm_process_prompt(self.prompt_template),
            ),
            Message(role="assistant", content="")
        ]
        tokens, mask = self._tokenizer.tokenize_messages(messages)
        # Drop the trailing two ids of the empty assistant turn so the model
        # continues it (mirrors `generate` above).
        tokens = tokens[:-2]
        # One per-position embedding map per batch element.
        mm_context = [self.extract_mm_context(e, tokens) for e in video_ib_embed]
        # All rows share the same prompt; clone after expand to own the memory.
        prompt = torch.tensor(tokens, dtype=torch.int, device=self._device).expand(batch_dim, -1).clone()
        prompt_length = prompt.size(1)

        self._model.tok_embeddings.set_context(mm_context)
        self._model.output.set_context(mm_context)

        # Same special-token filtering scheme as in `generate`.
        bos_id = self._tokenizer.tt_model.encode("<|begin_of_text|>", allowed_special="all")[0]
        allowed_id = self._tokenizer.tt_model.encode(f"<|eot_id|>{START_IMAGE}{END_IMAGE}{START_AUDIO}{END_AUDIO}{START_VIDEO}{END_VIDEO}", allowed_special="all")
        disallowed_tokens = list(set(range(bos_id, bos_id + 256)) - set(allowed_id))

        def generate_next_token(model, input_pos, x, temperature=1.0, top_k=None):
            # logits: [b, s, v]; sample from the last position of each row and
            # replace any disallowed special id with EOS.
            logits = model(x, input_pos=input_pos)[:, -1]
            tokens = sample(logits, temperature, top_k)
            return torch.tensor([
                [self._tokenizer.eos_id if t in disallowed_tokens else t for t in toks]
                for toks in tokens
            ]).to(x.device)

        generated_tokens = prompt.clone()
        # NOTE(review): written but never read below — candidate for removal.
        stop_token_reached = torch.zeros(batch_dim, dtype=torch.bool, device=prompt.device)

        # Prefill: run the full prompt once to populate the KV caches and
        # sample the first new token of each sequence.
        tokens = generate_next_token(
            self._model,
            input_pos=torch.arange(0, prompt_length, device=prompt.device),
            x=prompt,
            temperature=cfg.temperature,
            top_k=cfg.top_k,
        )
        eot_reached_b = tokens == self._tokenizer.eot_id
        generated_tokens = torch.cat([generated_tokens, tokens], dim=-1)

        # The multimodal context only applies to the prompt; clear it before
        # incremental decoding.
        self._model.tok_embeddings.set_context([])
        self._model.output.set_context([])

        # Incremental decoding: one token per step at position input_pos.
        input_pos = torch.tensor([prompt_length], device=prompt.device)
        for _ in range(cfg.max_new_tokens - 1):
            tokens = generate_next_token(
                self._model, input_pos=input_pos, x=tokens, temperature=cfg.temperature, top_k=cfg.top_k
            )
            eot_reached_b |= tokens == self._tokenizer.eot_id
            # Zero out tokens of sequences that already emitted <|eot_id|>.
            tokens *= ~eot_reached_b
            generated_tokens = torch.cat([generated_tokens, tokens], dim=-1)
            if eot_reached_b.all():
                print('eot_reached_b.all()')
                break
            input_pos += 1

        # Decode each row, skipping the shared prompt prefix.
        captions = []
        for caption_tokens in generated_tokens.tolist():
            captions.append(self._tokenizer.decode(caption_tokens[prompt.size(1):]))
        return captions
|
|
|
|
@config.parse
def main(cfg: DictConfig) -> None:
    """Entry point: apply hard-coded overrides to ``cfg``, then run a batched
    caption generation over four random ImageBind-sized embeddings.

    NOTE(review): overrides below use absolute/local paths and replace the
    model section wholesale — intended for ad-hoc sandbox runs, not production.
    """
    config.log_config(recipe_name="InferenceRecipe", cfg=cfg)

    # Read from the incoming model config *before* it is replaced below.
    n_perception = cfg.model.perception_tokens
    cfg.model = DictConfig({
        "_component_": "models.mmllama3_8b",
        "use_clip": False,
        "perception_tokens": n_perception,
    })
    cfg.batch_size = 4
    ckpt_path = "/home/salman/tezuesh/omegalabs-anytoany-bittensor/sandboxing/cache/xzistance_omega-a2a-hotkey/meta_model_0.pth"
    cfg.checkpointer.checkpoint_dir = os.path.dirname(ckpt_path)
    cfg.checkpointer.checkpoint_files = ["models/meta_model_0.pt"]
    # NOTE(review): generate_batch reads cfg.max_new_tokens, not
    # cfg.inference.max_new_tokens — confirm the base config defines it.
    cfg.inference.max_new_tokens = 300
    cfg.tokenizer.path = "./models/tokenizer.model"

    recipe = InferenceRecipe(cfg)
    recipe.setup(cfg=cfg)
    embeddings = torch.randn(4, 1024)
    print(recipe.generate_batch(cfg=cfg, video_ib_embed=embeddings))


if __name__ == "__main__":
    sys.exit(main())