| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import math |
| import os |
| import os.path as osp |
| import warnings |
| from dataclasses import asdict |
| from typing import Any, Dict, List, Optional, Sequence, Tuple |
|
|
| import torch |
| import transformers |
| from huggingface_hub import file_exists, repo_exists |
| from huggingface_hub.utils import HFValidationError |
| from transformers import ( |
| AutoConfig, |
| AutoModelForCausalLM, |
| AutoTokenizer, |
| PretrainedConfig, |
| PreTrainedModel, |
| PreTrainedTokenizer, |
| ) |
| from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled |
|
|
| |
| from .conversation import SeparatorStyle, default_conversation |
|
|
# Placeholder token inserted into a probe conversation so we can locate the
# end of an assistant turn when inferring stop tokens (see infer_stop_tokens).
SENTINEL_TOKEN = "<vila/sentinel>"

# Special tokens that mark media (image / video) positions inside a prompt.
# These are registered on the tokenizer in build_llm_and_tokenizer.
MEDIA_TOKENS = {
    "image": "<image>",
    "video": "<vila/video>",
}


# Minimal multi-turn conversation used purely to probe a tokenizer's chat
# template (10 repeated human/gpt turns; the text content is irrelevant).
DUMMY_CONVERSATION = [
    {"from": "human", "value": "question"},
    {"from": "gpt", "value": "answer"},
] * 10
|
|
|
|
def tokenizer_image_token(prompt, tokenizer, return_tensors=None):
    """Encode *prompt* with *tokenizer* and return the first row of input ids.

    NOTE(review): despite the name, no special <image>-token splitting happens
    here — the prompt is passed straight through; presumably media tokens are
    already registered on the tokenizer (confirm against build_llm_and_tokenizer).
    """
    encoding = tokenizer(prompt, return_tensors=return_tensors)
    return encoding.input_ids[0]
|
|
|
|
def has_tokenizer(repo_id_or_path: str) -> bool:
    """Return True if a tokenizer_config.json is available for this path/repo.

    Checks a local checkpoint directory first, then falls back to treating the
    string as a Hugging Face Hub repo id.
    """
    # Local checkpoint directory case.
    local_config = osp.join(repo_id_or_path, "tokenizer_config.json")
    if osp.exists(local_config):
        return True

    # Hub repo case; an invalid repo id simply means "no tokenizer".
    try:
        if not repo_exists(repo_id_or_path):
            return False
        return file_exists(repo_id_or_path, "tokenizer_config.json")
    except HFValidationError:
        return False
|
|
|
|
def _maybe_add_sentinel_token(tokenizer: transformers.PreTrainedTokenizer) -> None:
    """Register the sentinel token on *tokenizer*, at most once.

    After this call the tokenizer carries ``sentinel_token`` and
    ``sentinel_token_id`` attributes; a tokenizer that already has
    ``sentinel_token`` is left untouched.
    """
    if hasattr(tokenizer, "sentinel_token"):
        return
    tokenizer.add_tokens([SENTINEL_TOKEN], special_tokens=True)
    tokenizer.sentinel_token = SENTINEL_TOKEN
    tokenizer.sentinel_token_id = tokenizer.convert_tokens_to_ids(SENTINEL_TOKEN)
|
|
|
|
def tokenize_conversation_legacy(
    messages: Sequence[Dict[str, str]],
    tokenizer: transformers.PreTrainedTokenizer,
    add_generation_prompt: bool = False,
    overrides: Optional[Dict[str, str]] = None,
    no_system_prompt: bool = False,
) -> torch.Tensor:
    """Tokenize *messages* using the legacy separator-style conversation template.

    Args:
        messages: Conversation turns; each dict has "from" ("human"/"gpt")
            and "value" keys.
        tokenizer: Tokenizer used to encode the rendered prompt.
        add_generation_prompt: If True, append an empty assistant turn so the
            rendered prompt ends ready for generation.
        overrides: Optional per-sender replacement text, keyed by "from".
        no_system_prompt: If True, blank out the template's system prompt.

    Returns:
        1-D tensor of input ids for the rendered conversation.
    """
    conv = default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    if no_system_prompt:
        conv.system = ""

    # Fix: work on a shallow copy. The original appended the generation-prompt
    # turn directly onto the caller's list (declared Sequence), mutating it.
    messages = list(messages)

    # Skip a leading non-human turn (e.g. a system message) if present.
    # Guarding on emptiness also avoids an IndexError on an empty conversation.
    if messages and messages[0]["from"] != "human":
        messages = messages[1:]

    # Add an open assistant turn so generation starts from the right place.
    if add_generation_prompt:
        messages.append({"from": "gpt", "value": None})

    conv.messages = []
    for turn, message in enumerate(messages):
        role = roles[message["from"]]
        # Turns must strictly alternate human/gpt for this template.
        assert role == conv.roles[turn % 2]
        if overrides is not None and message["from"] in overrides:
            conv.append_message(role, overrides[message["from"]])
        else:
            conv.append_message(role, message["value"])

    return tokenizer_image_token(conv.get_prompt(), tokenizer, return_tensors="pt")
|
|
|
|
def tokenize_conversation(
    messages: Sequence[Dict[str, str]],
    tokenizer: transformers.PreTrainedTokenizer,
    add_generation_prompt: bool = False,
    overrides: Optional[Dict[str, str]] = None,
    no_system_prompt: bool = False,
) -> torch.Tensor:
    """Tokenize a conversation via the tokenizer's chat template.

    Falls back to :func:`tokenize_conversation_legacy` when the default
    conversation does not use the AUTO separator style.

    Args:
        messages: Conversation turns; each dict has "from" ("human"/"gpt")
            and "value" keys.
        tokenizer: Tokenizer whose chat template renders the conversation.
        add_generation_prompt: If True, end the prompt ready for generation.
        overrides: Optional per-sender replacement text, keyed by "from".
        no_system_prompt: If True, prepend an empty system message to suppress
            the template's default system prompt.

    Returns:
        1-D tensor of input ids for the rendered conversation.

    Raises:
        ValueError: If a message has an unexpected "from" sender.
    """
    # Fix: normalize whitespace on copies. The original wrote the stripped
    # value back into the caller's dicts, mutating the input in place.
    messages = [{**m, "value": m["value"].strip()} for m in messages]

    if default_conversation.sep_style != SeparatorStyle.AUTO:
        return tokenize_conversation_legacy(
            messages,
            tokenizer,
            add_generation_prompt=add_generation_prompt,
            overrides=overrides,
            no_system_prompt=no_system_prompt,
        )

    # Map the project's "human"/"gpt" senders onto chat-template roles.
    conversation = []
    for m in messages:
        message = {}
        if m["from"] == "human":
            message["role"] = "user"
        elif m["from"] == "gpt":
            message["role"] = "assistant"
        else:
            raise ValueError(f"Unexpected sender '{m['from']}' in conversation entry.")

        message["content"] = m["value"]
        # Per-sender override text takes precedence over the message body.
        if overrides is not None and m["from"] in overrides:
            message["content"] = overrides[m["from"]]
        conversation.append(message)

    if no_system_prompt:
        # An explicit empty system message suppresses the template's default.
        conversation = [{"role": "system", "content": ""}] + conversation

    text = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=add_generation_prompt,
        tokenize=False,
    )
    return tokenizer_image_token(text, tokenizer, return_tensors="pt")
|
|
|
|
def infer_stop_tokens(tokenizer: transformers.PreTrainedTokenizer) -> List[str]:
    """Infer generation stop tokens by probing the tokenizer's chat template.

    Renders a dummy conversation whose assistant turns are replaced by the
    sentinel token; whatever token immediately follows each sentinel in the
    encoding is a turn terminator, collected alongside the EOS token.
    """
    _maybe_add_sentinel_token(tokenizer)
    template = tokenize_conversation(DUMMY_CONVERSATION, tokenizer, overrides={"gpt": SENTINEL_TOKEN})

    stop_tokens = {tokenizer.eos_token}
    sentinel_id = tokenizer.sentinel_token_id
    # Scan for sentinels; the token right after each one terminates a turn.
    for idx in range(template.size(0) - 1):
        if template[idx] == sentinel_id:
            stop_tokens.add(tokenizer.decode(template[idx + 1]))
    return list(stop_tokens)
|
|
|
|
def context_length_extension(config):
    """Enable linear RoPE scaling on *config* when the requested context
    length exceeds the model's native maximum.

    Mutates ``config.rope_scaling`` in place (only when scaling is needed)
    and returns *config* for convenience.

    Args:
        config: A model config exposing ``max_position_embeddings`` and
            ``model_max_length`` attributes (either may be absent or None).

    Returns:
        The same *config* object.
    """
    orig_ctx_len = getattr(config, "max_position_embeddings", None)
    model_max_length = getattr(config, "model_max_length", None)
    # Fix: also guard model_max_length — comparing None > int raises TypeError
    # when the config has a native context length but no model_max_length set.
    if orig_ctx_len and model_max_length and model_max_length > orig_ctx_len:
        print(f"Scaling RoPE from {orig_ctx_len} to {model_max_length}")
        # Ceil so the scaled context always covers model_max_length.
        scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
        config.rope_scaling = {"type": "linear", "factor": scaling_factor}
    return config
|
|
|
|
def build_llm_and_tokenizer(
    model_name_or_path: str,
    config: PretrainedConfig,
    attn_implementation=None,
    model_max_length=None,
    *args,
    **kwargs,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    """Load the causal LM and its tokenizer, and wire up chat template,
    stop tokens, and media tokens on the tokenizer.

    Args:
        model_name_or_path: Local checkpoint directory or HF Hub repo id.
        config: Parent model config; ``hidden_size`` is copied back onto it,
            and ``model_dtype`` / ``chat_template`` are read from it.
        attn_implementation: Forwarded to the LLM config (e.g. "flash_attention_2").
        model_max_length: If set, extends context length (RoPE scaling) and
            overrides the tokenizer's max length.
        *args, **kwargs: Forwarded to ``AutoModelForCausalLM.from_pretrained``.

    Returns:
        Tuple of (model, tokenizer).

    Raises:
        ValueError: If no tokenizer can be located for the checkpoint.
    """
    llm_cfg = AutoConfig.from_pretrained(model_name_or_path)
    llm_cfg._attn_implementation = attn_implementation
    llm_cfg.model_max_length = model_max_length
    if model_max_length is not None:
        context_length_extension(llm_cfg)

    # Placeholder: FP8 checkpoint restoration is currently disabled.
    quantization_restore_from_checkpoint = False

    # HACK: eval() on a config-provided string (e.g. "torch.float16") executes
    # arbitrary code from the checkpoint config — consider resolving the dtype
    # via getattr(torch, ...) instead of eval.
    if quantization_restore_from_checkpoint:
        fp8_model_name_or_path = kwargs.pop("fp8_llm_cfg", None)
        llm = AutoModelForCausalLM.from_pretrained(
            fp8_model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
        )
    else:
        if is_deepspeed_zero3_enabled():
            # DeepSpeed ZeRO-3 manages parameter placement itself; a device_map
            # would conflict. Fix: use a default so a missing key cannot raise
            # KeyError when the caller never passed device_map.
            kwargs.pop("device_map", None)
        llm = AutoModelForCausalLM.from_pretrained(
            model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
        )

    # Locate the tokenizer: either alongside the model or under an "llm" subdir.
    llm_path = model_name_or_path
    if not has_tokenizer(llm_path):
        llm_path = osp.join(llm_path, "llm")
    if not has_tokenizer(llm_path):
        raise ValueError(f"Cannot find tokenizer in {llm_path}.")

    tokenizer = AutoTokenizer.from_pretrained(llm_path, padding_side="right", use_fast=True, legacy=False)
    if model_max_length is not None:
        tokenizer.model_max_length = model_max_length

    # Optional chat-template override: look in the packaged templates first,
    # then next to the checkpoint directory.
    if getattr(config, "chat_template", None) is not None:
        print(f"Using chat template: {config.chat_template}")
        fpath = os.path.join(os.path.dirname(__file__), "chat_templates", f"{config.chat_template}.jinja")
        if not os.path.exists(fpath):
            fpath = os.path.join(os.path.dirname(model_name_or_path), f"{config.chat_template}.jinja")
        with open(fpath) as fd:
            chat_template = fd.read()
        # Templates are stored pretty-printed; strip all spaces/newlines so
        # rendering does not inject stray whitespace into prompts.
        tokenizer.chat_template = chat_template.replace(" ", "").replace("\n", "")

    # Cache inferred stop tokens on the tokenizer for generation-time use.
    tokenizer.stop_tokens = infer_stop_tokens(tokenizer)
    tokenizer.stop_token_ids = tokenizer.convert_tokens_to_ids(tokenizer.stop_tokens)

    # Register media tokens (<image>, <vila/video>) as special tokens and
    # record their ids for downstream media splicing.
    tokenizer.media_tokens = MEDIA_TOKENS
    tokenizer.media_token_ids = {}
    for name, token in MEDIA_TOKENS.items():
        tokenizer.add_tokens([token], special_tokens=True)
        tokenizer.media_token_ids[name] = tokenizer.convert_tokens_to_ids(token)

    config.hidden_size = llm.config.hidden_size
    return llm, tokenizer
|
|