4

There's a lot of tutorials online that uses raw text affix with arcane syntax to indicate document boundary and accessed through Huggingface datasets.Dataset object through the text key. E.g.

from datasets import load_dataset

dataset_name = "mlabonne/guanaco-llama2-1k"

dataset = load_dataset(dataset_name, split="train")
dataset["text"][42]

[out]:

<s>[INST] ¿Cuáles son los actuales presidentes de la región de Sur América? Enumérelos en una lista con su respectivo país. [/INST] A fecha del 13 de febrero de 2023, estos son los presidentes de los países de Sudamérica, según Wikipedia:
-Argentina: Alberto Fernández
-Bolivia: Luis Arce
-Brasil: Luiz Inácio Lula da Silva
-Chile: Gabriel Boric
-Colombia: Gustavo Petro
-Ecuador: Guillermo Lasso
-Paraguay: Mario Abdo Benítez
-Perú: Dina Boluarte
-Uruguay: Luis Lacalle Pou
-Venezuela: Nicolás Maduro
-Guyana: Irfaan Ali
-Surinam: Chan Santokhi
-Trinidad y Tobago: Paula-Mae Weekes </s>

But machine translation datasets are usually structured in 2 parts, source and target text with sentence_eng_Latn and sentence_deu_Latn keys, e.g.


valid_data = load_dataset("facebook/flores", "eng_Latn-deu_Latn", streaming=False, 
                          split="dev")
valid_data[42]

[out]:

{'id': 43,
 'URL': 'https://en.wikinews.org/wiki/Hurricane_Fred_churns_the_Atlantic',
 'domain': 'wikinews',
 'topic': 'disaster',
 'has_image': 0,
 'has_hyperlink': 0,
 'sentence_eng_Latn': 'The storm, situated about 645 miles (1040 km) west of the Cape Verde islands, is likely to dissipate before threatening any land areas, forecasters say.',
 'sentence_deu_Latn': 'Prognostiker sagen, dass sich der Sturm, der etwa 645 Meilen (1040 km) westlich der Kapverdischen Inseln befindet, wahrscheinlich auflösen wird, bevor er Landflächen bedroht.'}

How to fine-tune a Mistral-7b model for the machine translation task?

1 Answer 1

5

The key is to re-format the data from a traditional machine translation dataset that splits the source and target text and piece them up together in a format that the model expects.

For the Mistral 7B model specifically, it usually expects:

  • each row of data would be encapsulated between <s> and where
    • the input source sentence would be embedded between the [INST] ... [/INST]
    • the output target sentence would be after the [/INST] symbol
  • any pre-data prompts before the [INST] ... [/INST]

E.g. if we want to use a translation prompt as such "Translate English to German:",

valid_data = load_dataset("facebook/flores", "eng_Latn-deu_Latn", streaming=False, split="dev")

def preprocess_func(row):
  return {'text': "Translate from English to German: <s>[INST] " + row['sentence_eng_Latn'] + " [INST] " + row['sentence_deu_Latn'] + " </s>"}

valid_dataset = valid_data.map(preprocess_func)

valid_dataset[42]

[out]:

{'id': 43,
 'URL': 'https://en.wikinews.org/wiki/Hurricane_Fred_churns_the_Atlantic',
 'domain': 'wikinews',
 'topic': 'disaster',
 'has_image': 0,
 'has_hyperlink': 0,
 'sentence_eng_Latn': 'The storm, situated about 645 miles (1040 km) west of the Cape Verde islands, is likely to dissipate before threatening any land areas, forecasters say.',
 'sentence_deu_Latn': 'Prognostiker sagen, dass sich der Sturm, der etwa 645 Meilen (1040 km) westlich der Kapverdischen Inseln befindet, wahrscheinlich auflösen wird, bevor er Landflächen bedroht.',
 'text': 'Translate from English to German: <s>[INST] The storm, situated about 645 miles (1040 km) west of the Cape Verde islands, is likely to dissipate before threatening any land areas, forecasters say. [INST] Prognostiker sagen, dass sich der Sturm, der etwa 645 Meilen (1040 km) westlich der Kapverdischen Inseln befindet, wahrscheinlich auflösen wird, bevor er Landflächen bedroht. </s>'}

Then the normal fine-tuning Mistral-7b scripts could just read the text key in the dataset, e.g.

Requires

!pip install -U transformers sentencepiece datasets
!pip install -q -U transformers
!pip install -q -U accelerate
!pip install -q -U bitsandbytes

!pip install -U peft
!pip install -U trl

And if you are in a Jupyter environment, you'll need to reset the kernel after installing accelerate, so:

import os
os._exit(00)

Then:

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,HfArgumentParser,TrainingArguments,pipeline, logging
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
import os,torch
from datasets import load_dataset
from trl import SFTTrainer


base_model = "mistralai/Mistral-7B-Instruct-v0.2"
new_model = "mistral_7b_flores_dev_en_de"


bnb_config = BitsAndBytesConfig(  
    load_in_4bit= True,
    bnb_4bit_quant_type= "nf4",
    bnb_4bit_compute_dtype= torch.bfloat16,
    bnb_4bit_use_double_quant= False,
)
model = AutoModelForCausalLM.from_pretrained(
        base_model,
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
)
model.config.use_cache = False 
model.config.pretraining_tp = 1
model.gradient_checkpointing_enable()



tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
tokenizer.padding_side = 'right'
tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_eos_token = True
tokenizer.add_bos_token, tokenizer.add_eos_token



valid_data = load_dataset("facebook/flores", "eng_Latn-deu_Latn", streaming=False, 
                          split="dev")

test_data = load_dataset("facebook/flores", "eng_Latn-deu_Latn", streaming=False, 
                          split="devtest")



def preprocess_func(row):
  return {'text': "Translate from English to German: <s>[INST] " + row['sentence_eng_Latn'] + " [INST] " + row['sentence_deu_Latn'] + " </s>"}


valid_dataset = valid_data.map(preprocess_func)
test_dataset = test_data.map(preprocess_func)



model = prepare_model_for_kbit_training(model)
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj","gate_proj"]
)
model = get_peft_model(model, peft_config)


training_arguments = TrainingArguments(
    output_dir="./results",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    optim="paged_adamw_32bit",
    save_steps=25,
    logging_steps=25,
    learning_rate=2e-4,
    weight_decay=0.001,
    max_grad_norm=0.3,
    max_steps=-1,
    warmup_ratio=0.03,
    group_by_length=True,
    lr_scheduler_type="constant",
    report_to=None
)


trainer = SFTTrainer(
    model=model,
    train_dataset=valid_dataset,
    peft_config=peft_config,
    max_seq_length= None,
    dataset_text_field="text",
    tokenizer=tokenizer,
    args=training_arguments,
    packing= False,
)

trainer.train()
3
  • Shouldn't "Translate from English to German:" be inside "<s>[INST]" as well?
    – r000bin
    Commented Mar 18 at 19:51
  • Good question! Thus the bounty. I don't really know too but I've seen both versions floating around where the pre-source/target prompt is outside of the <s>...</s> and sometimes inside. *shrugs...
    – alvas
    Commented Mar 19 at 1:30
  • 1
    TL;DR it depends on your usecase - Ah alright. I'm heavily prompting with mixtral and can say that inside works better for me. I saw a lot of discussions around that topic and there is no clear answer. Someone on reddit compared different prompt format on mixtral and as mixtral is built on top of mistral I'd assume this is also true for mistral: reddit.com/r/LocalLLaMA/comments/18ljvxb/…
    – r000bin
    Commented Mar 19 at 7:47

Not the answer you're looking for? Browse other questions tagged or ask your own question.