Example transformer models (decoder-only LLMs)

Here we provide a list of popular decoder-only LLMs composed via the transformer building blocks from this library. The main purpose is to demonstrate how to construct a new PyTorch LLM model from scratch using the AI Edge Torch Generative API, and convert it to TFLite format for on-device inference.

Gemma

Gemma is Google's open-source LLM. The model comes in both 2B and 7B versions. See the model's HuggingFace page. The example we provide is Gemma 2B, and the checkpoint for the model can be found here.

TinyLlama

TinyLlama is a popular open-source, smaller version of Meta's Llama 2 model, with only 1.1B parameters. See the HuggingFace checkpoint.

Microsoft Phi-2

Microsoft Phi-2 is also a decoder-only LLM, with 2.7B parameters; see details on HuggingFace.

Overall workflow

To support a new LLM with the Edge Generative API, we go through the following process: model (re)authoring, checkpoint mapping/loading, model quantization (via PT2E), model conversion to the flatbuffer schema, model quality evaluation, benchmarking, and authoring the on-device inference pipeline.

Model (re)authoring

Model (re)authoring covers the following steps:

  1. Understanding the overall model architecture (encoder-decoder, decoder-only, etc.).
  2. Composing the model from the transformer building blocks provided by ai_edge_torch. For each example model, there is a model definition file (e.g. tiny_llama/tiny_llama.py) where an nn.Module is defined with its layers and a forward function. There is also a get_model_config function that returns a ModelConfig instance with hyper-parameters such as embedding size and layer count. Finally, there is a define_and_run function that builds the model instance and runs the forward pass with a few sample inputs.

Here we use TinyLlama as an example to walk you through the authoring steps.

Define the model's structure

class TinyLLamma(nn.Module):

  def __init__(self, config: cfg.ModelConfig):
    super().__init__()
    self.config = config
    # Construct model layers.
    self.lm_head = nn.Linear(
        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
    )
    self.tok_embedding = nn.Embedding(
        config.vocab_size, config.embedding_dim, padding_idx=0
    )
    self.transformer_blocks = nn.ModuleList(
        TransformerBlock(config) for _ in range(config.num_layers)
    )
    self.final_norm = builder.build_norm(
        config.embedding_dim,
        config.final_norm_config,
    )
    self.rope_cache = attn_utils.build_rope_cache(
        size=config.kv_cache_max,
        dim=int(config.attn_config.rotary_percentage * config.head_dim),
        base=10_000,
        condense_ratio=1,
        dtype=torch.float32,
        device=torch.device("cpu"),
    )
    self.mask_cache = attn_utils.build_causal_mask_cache(
        size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
    )
    self.config = config

Define the model's forward function

# The model's forward function takes in additional k/v cache tensors
# and returns the updated k/v cache tensors to the caller.
# This can be eliminated if we handle k/v cache updates inside the model itself.
@torch.inference_mode
def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
  B, T = idx.size()
  assert (
      self.config.max_seq_len >= T
  ), f"Cannot forward sequence of length {T}, max seq length is only {self.config.max_seq_len}"
  cos, sin = self.rope_cache
  cos = cos.index_select(0, input_pos)
  sin = sin.index_select(0, input_pos)
  mask = self.mask_cache.index_select(2, input_pos)
  mask = mask[:, :, :, : self.config.kv_cache_max]
  # Forward the model itself.
  x = self.tok_embedding(idx)  # token embeddings of shape (b, t, n_embd)
  for i, block in enumerate(self.transformer_blocks):
    x = block(x, (cos, sin), mask, input_pos)
  x = self.final_norm(x)
  res = self.lm_head(x)  # (b, t, vocab_size)
  return res
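
As a quick sanity check (the define_and_run step mentioned earlier), you can build the module and run a forward pass on a few random token ids. The sketch below is illustrative: it assumes the get_model_config() helper described above and uses arbitrary sample inputs.

import torch

# Minimal sketch of the define_and_run step: build the model from its config
# and run a forward pass on random token ids. get_model_config() is the helper
# described earlier; the sample inputs are arbitrary.
def define_and_run() -> None:
  config = get_model_config()
  model = TinyLLamma(config)
  tokens = torch.randint(0, config.vocab_size, (1, 10), dtype=torch.long)
  input_pos = torch.arange(0, 10)
  logits = model(tokens, input_pos)  # shape: (1, 10, vocab_size)
  print(logits.shape)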

Now you have the TinyLlama nn.Module defined; the next step is to restore the weights from the original checkpoint into the new model.

Checkpoint mapping/loading

After the model is defined, we need to load the original trained weights into the new model. This is necessary because the state_dict of the new model differs from the original model's state_dict. There are helper functions in place to simplify the state_dict mapping process (utilities/loader.py). The user needs to provide a layer name template (TensorNames) for the source model. This template is then used to create an updated state_dict that works with the mapped model. The tensor map includes the following fields:

@dataclass
class TensorNames:
  attn_query_proj: str
  attn_key_proj: str
  attn_value_proj: str
  attn_output_proj: str
  ff_up_proj: str
  ff_down_proj: str
  ff_gate_proj: str = None
  pre_attn_norm: str = None
  pre_ff_norm: str = None
  embedding: str = None
  final_norm: str = None
  lm_head: str = None

The fields that have a default value of None are optional and should only be populated if they are relevant to the model architecture. For TinyLlama, we will define the following TENSOR_NAMES:

TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
    ff_up_proj="model.layers.{}.mlp.up_proj",
    ff_down_proj="model.layers.{}.mlp.down_proj",
    ff_gate_proj="model.layers.{}.mlp.gate_proj",
    attn_query_proj="model.layers.{}.self_attn.q_proj",
    attn_key_proj="model.layers.{}.self_attn.k_proj",
    attn_value_proj="model.layers.{}.self_attn.v_proj",
    attn_output_proj="model.layers.{}.self_attn.o_proj",
    pre_attn_norm="model.layers.{}.input_layernorm",
    pre_ff_norm="model.layers.{}.post_attention_layernorm",
    embedding="model.embed_tokens",
    final_norm="model.norm",
    lm_head="lm_head",
)

With the TensorNames defined, a user can simply use the loading utils to load an instance of the mapped model. For instance:

model = MappedModel(config)
loader = loading_utils.ModelLoader("path_to_checkpoint", TENSOR_NAMES)
loader.load(model)

Currently, ModelLoader supports PyTorch state dictionary and SafeTensors checkpoints. We recommend testing the mapped model against your reference implementation using a few input samples before proceeding to the conversion step.
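
For example, a rough way to spot-check the mapped TinyLlama model is to compare its logits against the original HuggingFace implementation on the same inputs. The sketch below is illustrative only: the checkpoint path, sequence length, and tolerance are placeholders, and it assumes the transformers package is available.

import torch
import transformers

# Illustrative spot check: compare the mapped model's logits against the
# original HuggingFace implementation on the same token ids.
# "path_to_checkpoint" and the tolerance below are placeholders.
hf_model = transformers.AutoModelForCausalLM.from_pretrained("path_to_checkpoint")
hf_model.eval()

tokens = torch.randint(0, hf_model.config.vocab_size, (1, 16), dtype=torch.long)
input_pos = torch.arange(0, 16)

with torch.no_grad():
  reference_logits = hf_model(tokens).logits
mapped_logits = model(tokens, input_pos)  # `model` is the mapped model loaded above

assert torch.allclose(reference_logits, mapped_logits, atol=1e-2), (
    "Mapped model output diverges from the reference implementation."
)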

Model conversion

In this step, we use ai_edge_torch's standard multi-signature conversion API to convert the PyTorch nn.Module into a single TFLite flatbuffer for on-device execution. For example, tiny_llama/convert_to_tflite.py uses the following Python code to convert the TinyLlama model to a multi-signature TFLite model:

def convert_tiny_llama_to_tflite(
    checkpoint_path: str,
    prefill_seq_len: int = 512,
    kv_cache_max_len: int = 1024,
    quantize: bool = True,
):
  """An example method for converting TinyLlama model to multi-signature
  tflite model.

  Args:
    checkpoint_path (str): The filepath to the model checkpoint, or directory
      holding the checkpoint.
    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
      Defaults to 512.
    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
      including both prefill and decode. Defaults to 1024.
    quantize (bool, optional): Whether the model should be quantized.
      Defaults to True.
  """
  pytorch_model = tiny_llama.build_model(
      checkpoint_path, kv_cache_max_len=kv_cache_max_len
  )
  # Tensors used to trace the model graph during conversion.
  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
  prefill_input_pos = torch.arange(0, prefill_seq_len)
  decode_token = torch.tensor([[0]], dtype=torch.long)
  decode_input_pos = torch.tensor([0], dtype=torch.int64)
  quant_config = quant_recipes.full_linear_int8_dynamic_recipe() if quantize else None
  edge_model = (
      ai_edge_torch.signature(
          'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
      )
      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
      .convert(quant_config=quant_config)
  )
  edge_model.export(f'/tmp/tiny_llama_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite')

Once converted, you will get a .tflite model ready for on-device execution. Note that the generated .tflite model uses static shapes. It defines two signatures (two entry points into the model):

  1. prefill: takes two tensor inputs, prefill_tokens and prefill_input_pos, with shapes (BATCH_SIZE, PREFILL_SEQ_LEN) and (PREFILL_SEQ_LEN).
  2. decode: takes two tensor inputs, decode_token and decode_input_pos, with shapes (1, 1) and (1).

To learn more about TFLite signatures, please refer to this article.
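
To try the converted model from Python, you can invoke each signature through the TensorFlow Lite interpreter's signature runners. The sketch below is illustrative: the input tensor names and dtypes depend on the converted model, so inspect get_signature_list() and get_input_details() for the actual values.

import numpy as np
import tensorflow as tf

# Illustrative sketch: run the prefill and decode signatures with the TFLite
# interpreter. The keyword argument names below are assumptions; check
# interpreter.get_signature_list() for the actual input names in your model.
interpreter = tf.lite.Interpreter(
    model_path="/tmp/tiny_llama_seq512_kv1024.tflite"
)
print(interpreter.get_signature_list())

prefill = interpreter.get_signature_runner("prefill")
decode = interpreter.get_signature_runner("decode")

prefill_outputs = prefill(
    prefill_tokens=np.zeros((1, 512), dtype=np.int64),
    prefill_input_pos=np.arange(512, dtype=np.int64),
)
decode_outputs = decode(
    decode_token=np.array([[0]], dtype=np.int64),
    decode_input_pos=np.array([0], dtype=np.int64),
)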

Model quantization

To apply quantization, we need to create a quantization configuration that fully expresses how the model should be quantized. This configuration is then passed to the conversion API, which generates a quantized model.

quantize/quant_recipes.py contains a list of recipes that are known to be well-supported at runtime. For most users, these recipes are a good starting point for selecting the quantization scheme best suited to their deployment needs. After identifying the target recipe, the model can be quantized as follows. This example is extracted from generative/examples/quantize/example.py.

quant_config = quant_recipes.full_int8_dynamic_recipe()
edge_model = ai_edge_torch.convert(
    model, (tokens, input_pos), quant_config=quant_config
)

Once converted, you will get a quantized .tflite model ready for on-device execution.