The key is to re-format the data from a traditional machine translation dataset, which keeps the source and target text separate, and piece the two together in the format that the model expects.
For the Mistral 7B model specifically, it usually expects:
- each row of data would be encapsulated between
<s> and </s>
where
- the input source sentence would be embedded between the
[INST] ... [/INST]
- the output target sentence would be after the
[/INST]
symbol
- any pre-data prompts would go before the
[INST] ... [/INST]
E.g. if we want to use a translation prompt such as "Translate from English to German:",
# Fetch the English-German FLORES pair, "dev" split, fully materialised
# (streaming is switched off).
valid_data = load_dataset(
    "facebook/flores",
    "eng_Latn-deu_Latn",
    split="dev",
    streaming=False,
)
def preprocess_func(row):
    """Build the single training string for one English-German row.

    The English sentence is wrapped in ``[INST] ... [/INST]`` and the German
    sentence follows the closing ``[/INST]`` tag, terminated by ``</s>``.

    Args:
        row: Mapping with the keys ``'sentence_eng_Latn'`` and
            ``'sentence_deu_Latn'``, both holding strings.

    Returns:
        A dict with a single ``'text'`` key holding the formatted example.

    Raises:
        KeyError: If either sentence key is missing from a dict ``row``.
    """
    return {
        'text': (
            "Translate from English to German: <s>[INST] "
            + row['sentence_eng_Latn']
            # The instruction must be closed with [/INST]; a second [INST]
            # would leave the model without a marker for where the answer starts.
            + " [/INST] "
            + row['sentence_deu_Latn']
            + " </s>"
        )
    }
# Run preprocess_func over every row; the resulting rows keep their original
# columns and gain the formatted 'text' column.
valid_dataset = valid_data.map(preprocess_func)
# Notebook-style inspection of a single formatted example.
valid_dataset[42]
[out]:
{'id': 43,
'URL': 'https://en.wikinews.org/wiki/Hurricane_Fred_churns_the_Atlantic',
'domain': 'wikinews',
'topic': 'disaster',
'has_image': 0,
'has_hyperlink': 0,
'sentence_eng_Latn': 'The storm, situated about 645 miles (1040 km) west of the Cape Verde islands, is likely to dissipate before threatening any land areas, forecasters say.',
'sentence_deu_Latn': 'Prognostiker sagen, dass sich der Sturm, der etwa 645 Meilen (1040 km) westlich der Kapverdischen Inseln befindet, wahrscheinlich auflösen wird, bevor er Landflächen bedroht.',
'text': 'Translate from English to German: <s>[INST] The storm, situated about 645 miles (1040 km) west of the Cape Verde islands, is likely to dissipate before threatening any land areas, forecasters say. [INST] Prognostiker sagen, dass sich der Sturm, der etwa 645 Meilen (1040 km) westlich der Kapverdischen Inseln befindet, wahrscheinlich auflösen wird, bevor er Landflächen bedroht. </s>'}
Then the normal fine-tuning Mistral-7b scripts could just read the text
key in the dataset, e.g.
Requires
!pip install -U transformers sentencepiece datasets
!pip install -q -U transformers
!pip install -q -U accelerate
!pip install -q -U bitsandbytes
!pip install -U peft
!pip install -U trl
And if you are in a Jupyter environment, you'll need to reset the kernel after installing accelerate, so:
import os

# Hard-exit the current interpreter process with status 0 so the Jupyter
# kernel comes back up with the freshly installed packages.
os._exit(0)
Then:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,HfArgumentParser,TrainingArguments,pipeline, logging
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
import os,torch
from datasets import load_dataset
from trl import SFTTrainer
# Hub id of the checkpoint to fine-tune.
base_model = "mistralai/Mistral-7B-Instruct-v0.2"
# Presumably the name for the fine-tuned result; it is not used in the
# lines shown here.
new_model = "mistral_7b_flores_dev_en_de"

# 4-bit NF4 quantization with a bfloat16 compute dtype; double quantization
# is switched off.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the base checkpoint with the quantization settings above and let
# device_map="auto" decide where the weights are placed.
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
    trust_remote_code=True,
)

# Training-time switches: no KV cache, gradient checkpointing on.
model.config.pretraining_tp = 1
model.config.use_cache = False
model.gradient_checkpointing_enable()
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
# Reuse the EOS token for padding, and pad on the right-hand side.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'right'
tokenizer.add_eos_token = True
# Notebook-style peek at the two flags; has no effect in a plain script.
tokenizer.add_bos_token, tokenizer.add_eos_token

# English-German FLORES pair: "dev" split and "devtest" split, both loaded
# without streaming.
valid_data = load_dataset(
    "facebook/flores",
    "eng_Latn-deu_Latn",
    split="dev",
    streaming=False,
)
test_data = load_dataset(
    "facebook/flores",
    "eng_Latn-deu_Latn",
    split="devtest",
    streaming=False,
)
def preprocess_func(row):
    """Build the single training string for one English-German row.

    The English sentence is wrapped in ``[INST] ... [/INST]`` and the German
    sentence follows the closing ``[/INST]`` tag, terminated by ``</s>``.

    Args:
        row: Mapping with the keys ``'sentence_eng_Latn'`` and
            ``'sentence_deu_Latn'``, both holding strings.

    Returns:
        A dict with a single ``'text'`` key holding the formatted example.

    Raises:
        KeyError: If either sentence key is missing from a dict ``row``.
    """
    return {
        'text': (
            "Translate from English to German: <s>[INST] "
            + row['sentence_eng_Latn']
            # The instruction must be closed with [/INST]; a second [INST]
            # would leave the model without a marker for where the answer starts.
            + " [/INST] "
            + row['sentence_deu_Latn']
            + " </s>"
        )
    }
# Attach the formatted 'text' column to both splits.
valid_dataset = valid_data.map(preprocess_func)
test_dataset = test_data.map(preprocess_func)

# LoRA adapter settings: rank 64, alpha 16, 10% dropout, no bias terms,
# applied to the listed projection modules.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"],
)

# Prepare the quantized model for k-bit training, then wrap it with the
# LoRA adapters.
model = get_peft_model(prepare_model_for_kbit_training(model), peft_config)
# Hyper-parameters for a single pass over the training data.
training_arguments = TrainingArguments(
    output_dir="./results",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    optim="paged_adamw_32bit",
    # Checkpoint and log at the same cadence.
    save_steps=25,
    logging_steps=25,
    learning_rate=2e-4,
    weight_decay=0.001,
    max_grad_norm=0.3,
    # -1 leaves the run length to num_train_epochs.
    max_steps=-1,
    warmup_ratio=0.03,
    group_by_length=True,
    # NOTE(review): with a plain "constant" schedule the warmup_ratio above
    # may have no effect; confirm against the scheduler documentation.
    lr_scheduler_type="constant",
    # The string "none" is the documented way to disable reporting
    # integrations; Python None is replaced by the library default, which
    # in Transformers 4.x is every installed integration.
    report_to="none",
)
# Supervised fine-tuning on the formatted "dev" split (valid_dataset); the
# examples are read from the 'text' column and are not packed together.
trainer = SFTTrainer(
    model=model,
    args=training_arguments,
    train_dataset=valid_dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=None,
    packing=False,
)

# Kick off the training run.
trainer.train()