Fine-tuning Mistral 7B for Function Calling (Advanced)

by Grayson Adkins, updated February 6, 2024

This notebook demonstrates fine-tuning an open-source model (Mistral 7B). It uses Hugging Face's transformers and PEFT libraries (with bitsandbytes for quantization) for LoRA and training, along with a custom-built data set for function calling.


Overview

This notebook builds on the basic fine-tuning example by introducing the following techniques:

  • A prompt loss mask, so the model is graded only on its responses, which encourages structured responses
  • A stop sequence after responses to encourage conciseness
  • A small but high-quality function-calling data set to fine-tune the model to respond with functions and query parameters
  • A chat template customized for fine-tuning

Notes

  • The example data set used in this notebook is for function calling, but these techniques work for any Q&A data set.
  • The system prompts have been omitted, but you can add them back if you wish to fine-tune for a certain system message.
  • While you can run this notebook on an NVIDIA T4 GPU (free on Google Colab), I recommend using an A6000 or A100 to get better results. These larger GPUs are available in Google Colab Pro or from providers such as RunPod and Lambda Labs.

Recommended Reading

  • Tokenization
  • Low-Rank Adaptation (LoRA)

Attribution

  • Some functions in this notebook were adapted from Trelis examples with modifications.
  • A related, but simpler training notebook by Hugging Face is available here.

Why should you read this notebook?

You want to learn how to:

  • Fine-tune an open-source model for structured and concise responses
  • Fine-tune using just a single GPU
  • Use prompt loss masks to control which tokens the model is trained on

Source Code

The source code for this notebook is available in the ai-cookbook repo on my GitHub.

Key Concepts

Typically, a model is graded on how well it predicts the next token across both the question and the answer. Our goal, however, is for the model to pay full attention to the question while being graded solely on how well it predicts the answer. These two goals are handled by the attention mask and the loss mask, respectively.

Attention mask

Attention is the mechanism that lets each token weigh the other tokens in the input (e.g., a question or its context) when computing its representation, during both training and inference. An attention mask tells the model which positions it may attend to: it is a sequence of 1s and 0s aligned with the input IDs, where 0 marks a position to ignore (i.e., masked). The input IDs themselves are not modified. Instead, the attention scores at masked positions are set to a large negative value before the softmax, so those tokens receive zero attention weight.

{'input_ids': tensor([[9204, 18, 3763, 456, 222, 13563, 22580, 584]]),
 'attention_mask': tensor([[1, 1, 1, 0, 1, 1, 0, 1]])}

Here the tokens with IDs 456 and 22580 are masked: no other token attends to them, but they remain unchanged in input_ids.

As an example, we usually want to make sure that PAD tokens are masked.
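To make the mechanism concrete, here is a minimal pure-Python sketch (not the transformers implementation) of how a 0/1 attention mask is typically applied: masked positions get a score of negative infinity before the softmax, so they receive zero attention weight. The scores are made-up numbers for illustration.

```python
import math


def masked_softmax(scores, mask):
    """Softmax over attention scores, giving zero weight where mask == 0.

    Masked scores are replaced with -inf before normalizing, which is the
    usual way a 0/1 attention mask is applied. Assumes at least one
    position is unmasked.
    """
    masked = [s if m == 1 else float("-inf") for s, m in zip(scores, mask)]
    peak = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in masked]  # exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical attention scores for one query over four key positions.
scores = [2.0, 1.0, 0.5, 3.0]
mask = [1, 1, 1, 0]  # the last position is a PAD token
weights = masked_softmax(scores, mask)
# weights[3] == 0.0: the PAD position gets no attention, despite its high score.
```

Note that the highest raw score belongs to the PAD position, yet it contributes nothing once masked.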

Loss mask

A loss mask determines which tokens contribute to the loss (error) during training, i.e., which positions of the model's output are graded. The per-token losses are multiplied by the loss mask (1 = keep, 0 = ignore), and the result is averaged over the kept positions.
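A minimal pure-Python sketch of that computation, with made-up per-token losses. The masked losses are summed and divided by the number of kept tokens, one common convention that keeps the loss scale independent of how long the prompt is.

```python
def masked_mean_loss(token_losses, loss_mask):
    """Average per-token losses over the positions where loss_mask == 1."""
    if len(token_losses) != len(loss_mask):
        raise ValueError("token_losses and loss_mask must have the same length")
    n_kept = sum(loss_mask)
    if n_kept == 0:
        raise ValueError("loss_mask selects no tokens")
    kept = sum(loss * m for loss, m in zip(token_losses, loss_mask))
    return kept / n_kept


# Hypothetical per-token losses: two prompt tokens followed by two answer tokens.
token_losses = [2.3, 1.9, 0.7, 0.4]
loss_mask = [0, 0, 1, 1]  # grade only the answer tokens
loss = masked_mean_loss(token_losses, loss_mask)  # (0.7 + 0.4) / 2 = 0.55
```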

To improve model performance, this notebook masks the losses associated with the prompt tokens, so that the model is trained on answering the question rather than on predicting the next tokens of the question itself.
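In Hugging Face transformers, this mask is usually expressed through the labels rather than an explicit multiply: label positions set to -100 are ignored by PyTorch's cross-entropy loss (its default ignore_index). The sketch below uses hypothetical token IDs and a helper name of our own (build_example). Causal LM models in transformers shift the labels internally, so the labels stay aligned with input_ids.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_example(prompt_ids, completion_ids):
    """Concatenate prompt and completion; mask the prompt out of the labels."""
    input_ids = list(prompt_ids) + list(completion_ids)
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(completion_ids)
    attention_mask = [1] * len(input_ids)  # the model still attends to the prompt
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


# Hypothetical token IDs: three prompt tokens and two completion tokens.
example = build_example([1, 1824, 349], [5781, 2])
```

The attention mask stays all 1s: the prompt is visible to the model, it just isn't graded.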

Stop sequence

Have you ever noticed how verbose some models are? By fine-tuning with a stop sequence, such as USER:, at the end of each completion, we can teach the model to be more concise:

{
  prompt: "What is the stock price of Apple?\n\nBOT:",
  completion: "Apple stock price is $188.04.\n\nUSER: ",
},
...
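Once the model has learned to emit the stop sequence, the generated text still needs to be cut there at inference time (unless your generation setup stops on it for you). A minimal post-processing sketch, using a hypothetical helper name and a made-up generation:

```python
def truncate_at_stop(text, stop_sequences):
    """Cut generated text at the earliest occurrence of any stop sequence."""
    cut = len(text)
    for stop in stop_sequences:
        idx = text.find(stop)
        if idx != -1:
            cut = min(cut, idx)
    return text[:cut].rstrip()


# Hypothetical model output that runs past the answer into a new user turn.
generated = "Apple stock price is $188.04.\n\nUSER: And Microsoft?"
answer = truncate_at_stop(generated, ["USER:"])  # "Apple stock price is $188.04."
```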

Mistral 7B Instruct

Mistral 7B Instruct is an instruction fine-tuned version of Mistral 7B available on Hugging Face.

Per the HF model card:

Instruction format

The template used to build a prompt for the Instruct model is defined as follows:

<s> [INST] Instruction [/INST] Model answer</s> [INST] Follow-up instruction [/INST]
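As a minimal sketch of that format, the hypothetical helper below assembles the prompt string turn by turn, reproducing the template literally, including the <s> prefix. In practice you would normally let the tokenizer add BOS (or use tokenizer.apply_chat_template), because writing <s> as text and also tokenizing with add_special_tokens=True can produce a duplicate BOS.

```python
def build_mistral_prompt(turns):
    """Build a prompt string following the instruction format shown above.

    `turns` is a list of (instruction, answer) pairs; use None as the answer
    of the final turn to leave it open for the model to complete.
    """
    prompt = "<s>"
    for instruction, answer in turns:
        prompt += f" [INST] {instruction} [/INST]"
        if answer is not None:
            prompt += f" {answer}</s>"
    return prompt


prompt = build_mistral_prompt([
    ("Instruction", "Model answer"),
    ("Follow-up instruction", None),
])
```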

Model architecture

This instruction model is based on Mistral-7B-v0.1, a transformer model with the following architecture choices:

  • Grouped-Query Attention
  • Sliding-Window Attention
  • Byte-fallback BPE tokenizer
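To illustrate Sliding-Window Attention, here is a toy sketch that builds the mask explicitly: each position attends causally to at most `window` positions, itself included. The exact boundary convention is our assumption, and real implementations do not materialize a full seq_len × seq_len mask like this. Mistral's config uses a window of 4096, and stacking layers lets information travel further than a single window.

```python
def sliding_window_causal_mask(seq_len, window):
    """Return mask[i][j] == 1 if query position i may attend to key position j.

    Causal (j <= i) and limited to the `window` most recent positions,
    including position i itself.
    """
    return [
        [1 if 0 <= i - j < window else 0 for j in range(seq_len)]
        for i in range(seq_len)
    ]


mask = sliding_window_causal_mask(seq_len=5, window=3)
for row in mask:
    print(row)
```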

Setup

  • You can run QLoRA training for 7B models on a free Google Colab notebook.
  • To configure a GPU on Google Colab, navigate to Connect to a new runtime and select T4 High-RAM.
  • In the code below, flash attention is commented out when loading the model. Leave it that way on a T4: Flash Attention 2 is only supported on Ampere or newer GPUs (A6000, A100, H100, etc.).
  • (Optional) Mount Google Drive (see the "If using Google Colab + Google Drive" section) to cache the downloaded model in your Drive. This reduces start time in later sessions.
  • If you don't already have one, create a Hugging Face account and create an Access Token called "Notebooks" or similar with write permissions.
In [ ]:
# # Print GPU info
# gpu_info = !nvidia-smi
# gpu_info = '\n'.join(gpu_info)
# if gpu_info.find('failed') >= 0:
#   print('Not connected to a GPU')
# else:
#   print(gpu_info)
In [ ]:
# # Print system RAM (not GPU VRAM)
# from psutil import virtual_memory
# ram_gb = virtual_memory().total / 1e9
# print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

# if ram_gb < 20:
#   print('Not using a high-RAM runtime')
# else:
#   print('You are using a high-RAM runtime!')

Install

In [50]:
# Authenticate to Hugging Face to pull and push models
!pip install huggingface_hub -q
from huggingface_hub import notebook_login

notebook_login()
In [3]:
# (Optional) Configure Weights & Biases (wandb) to track training runs
!pip install wandb -q -U
import wandb
wandb.login()
wandb: Currently logged in as: gadkins. Use `wandb login --relogin` to force relogin
Out[3]:
True
In [4]:
# base_model = "./Mistral-7B-Instruct-v0.1-function-calling-v2"
# base_model = "meta-llama/Llama-2-7b-hf"
# base_model = "meta-llama/Llama-2-7b-chat-hf"
# base_model = "meta-llama/Llama-2-13b-chat-hf"
# base_model = "codellama/CodeLlama-34b-Instruct-hf"
# base_model = "meta-llama/Llama-2-70b-chat-hf"
base_model = "mistralai/Mistral-7B-Instruct-v0.1"
# base_model = "deepseek-ai/deepseek-coder-1.3b-instruct"
# base_model = "deepseek-ai/deepseek-coder-6.7b-instruct"
# base_model = "deepseek-ai/deepseek-coder-33b-instruct"
# base_model = "larryvrh/Yi-34B-200K-Llamafied"
# base_model = "./Yi-34B-200K-Llamafied-chat-SFT"
# base_model = "openchat/openchat_3.5"
# base_model = "SUSTech/SUS-Chat-34B"
# base_model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
# base_model = "microsoft/phi-2"

cache_dir = None # Use the default Hugging Face cache location
# (Optionally, you can set Google Drive as the cache_dir below)
In [5]:
# stable versions

!python -m pip install --upgrade pip
!pip install -U -q transformers
!pip install -q -U bitsandbytes
!pip install -q -U peft
!pip install -q -U accelerate
!pip install -q datasets
!pip install -q -U scipy
!pip install -q -U trl
!pip install -U flash-attn -q
In [6]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline, AutoConfig
import transformers
import torch
from torch.utils.data import DataLoader, Dataset

If using Google Colab + Google Drive

In [7]:
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
In [8]:
import os
cache_dir = "/content/drive/My Drive/huggingface_cache"
os.makedirs(cache_dir, exist_ok=True) # Ensure the directory exists
In [9]:
# https://stackoverflow.com/questions/56081324/why-are-google-colab-shell-commands-not-working
import locale
def getpreferredencoding(do_setlocale = True):
    return "UTF-8"
locale.getpreferredencoding = getpreferredencoding

Load model

Note about quantization:

In this section, you have the option to load a quantized version of the model (see the QLoRA notebook for quantization details) to reduce memory requirements so that it fits on a free T4 GPU in Google Colab. If cost matters most to you, I recommend this option: just uncomment the quantization_config line below.

However, I've observed slightly better performance in function-calling fine-tunes when loading the model unquantized (here in bfloat16, i.e., 16-bit rather than 4-bit weights). If you skip quantization, you'll need a larger GPU such as an A100. On Google Colab that means upgrading to Pro; alternatively, use a service like RunPod or Lambda Labs (which are a bit cheaper).
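A back-of-the-envelope sketch of why quantization matters on a T4 (16 GB of VRAM): weights alone take roughly parameters × bits / 8 bytes. The parameter count of about 7.24 billion is our estimate from the layer shapes printed later in this notebook, and the figure ignores activations, gradients, optimizer state, and quantization constants, so real usage during training is higher.

```python
def weight_memory_gb(num_params, bits_per_param):
    """Approximate memory for model weights alone, in gigabytes (1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


NUM_PARAMS = 7.24e9  # approximate parameter count of Mistral 7B (our estimate)

bf16_gb = weight_memory_gb(NUM_PARAMS, 16)  # unquantized bfloat16
nf4_gb = weight_memory_gb(NUM_PARAMS, 4)    # 4-bit NF4, ignoring quantization constants
print(f"bfloat16: ~{bf16_gb:.1f} GB, 4-bit: ~{nf4_gb:.1f} GB")
# bfloat16: ~14.5 GB, 4-bit: ~3.6 GB
```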

In [10]:
# QLoRA config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# Instantiate model
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    # quantization_config=bnb_config, # Uncomment to use quantized version
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
    #attn_implementation="flash_attention_2", # Supported in Ampere GPUs or newer
    cache_dir=cache_dir
)

Tokenization

In [11]:
# # Required for certain tokenizers like Yi
# !pip install sentencepiece -q -U
In [12]:
tokenizer = AutoTokenizer.from_pretrained(base_model, cache_dir=cache_dir, trust_remote_code=True)
# tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", cache_dir=cache_dir)
In [13]:
print("EOS token:", tokenizer.eos_token)
print("EOS token id:", tokenizer.eos_token_id)
EOS token: </s>
EOS token id: 2
In [14]:
# If pad token is None, we'll need to set one in the next section
print("Pad token: ", tokenizer.pad_token)
print("Pad token ID: ", tokenizer.pad_token_id)
Pad token:  None
Pad token ID:  None
In [15]:
# Pad on the right of (i.e., after) the prompt and response; in my experience this gives better training results
tokenizer.padding_side='right'
print(tokenizer)
LlamaTokenizerFast(name_or_path='mistralai/Mistral-7B-Instruct-v0.1', vocab_size=32000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

Set pad token if none exists

Some models already have a pad token set; the tokenizer print statement above shows whether yours does. If it does, you don't need to do anything further.

If no pad token exists, then you have three options:

Options

  1. Use an existing token in the vocab as the pad token instead of introducing a new one. This avoids adding a token to the tokenizer and resizing the model's embeddings. For this option, we use the existing <unk> (i.e., "unknown") token to pad; note that this assumes the <unk> token exists in the vocab.
  2. The next option is to use the EOS token. One caveat: if padding is masked out of the loss by token ID, EOS gets masked too, and the model may never learn when to stop.
  3. The last option is to add a new pad token. This grows the model's embedding matrix so that its size is no longer a multiple of 16 (e.g., 32,001 rows instead of 32,000), which can slow down inference. That's why it comes last.
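The multiple-of-16 point is simple arithmetic. A quick sketch (hypothetical helper) shows what adding one token does to Mistral's 32,000-entry vocabulary and how far you would need to round up. Recent transformers versions let resize_token_embeddings pad to a multiple via a pad_to_multiple_of argument, but check that your installed version supports it.

```python
def round_up_to_multiple(n, multiple=16):
    """Smallest multiple of `multiple` that is >= n."""
    return ((n + multiple - 1) // multiple) * multiple


vocab_size = 32000               # Mistral's vocab, already a multiple of 16
with_pad_token = vocab_size + 1  # after adding one new <pad> token
print(with_pad_token % 16, round_up_to_multiple(with_pad_token))
# 1 32016
```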
In [16]:
## (Recommended) OPTION 1
# If <unk> is in the tokenizer, set the pad token to <unk>
# Else, set pad token to EOS token
if '<unk>' in tokenizer.get_vocab():
    print('Found \'<unk>\' token in tokenizer. Using \'<unk>\' for pad.')
    # Set the pad token
    tokenizer.pad_token = '<unk>'
else:
    print(f'Using EOS token, \'{tokenizer.eos_token}\', for padding')
    tokenizer.pad_token = tokenizer.eos_token

## OPTION 3 (add a new pad token)
# # Check if the pad token is already in the tokenizer vocabulary
# if '<pad>' not in tokenizer.get_vocab():
#     print('pad token not in the tokenizer')

#     # Add the pad token
#     tokenizer.add_tokens(['<pad>'])

# # Set the pad token
# tokenizer.pad_token = '<pad>'

# # Resize token embeddings
# model.resize_token_embeddings(len(tokenizer))
Found '<unk>' token in tokenizer. Using '<unk>' for pad.
In [17]:
# Update pad token id in model and its config
model.pad_token_id = tokenizer.pad_token_id
model.config.pad_token_id = tokenizer.pad_token_id

# Check if they are equal
assert model.pad_token_id == tokenizer.pad_token_id, "The model's pad token ID \
does not match the tokenizer's pad token ID!"

# Print the pad token ids
print('Tokenizer pad token ID:', tokenizer.pad_token_id)
print('Model pad token ID:', model.pad_token_id)
print('Model config pad token ID:', model.config.pad_token_id)
print('Number of tokens now in tokenizer:', len(tokenizer))
Tokenizer pad token ID: 0
Model pad token ID: 0
Model config pad token ID: 0
Number of tokens now in tokenizer: 32000
In [18]:
# Print model configuration
print(model.config)
MistralConfig {
  "_name_or_path": "mistralai/Mistral-7B-Instruct-v0.1",
  "architectures": [
    "MistralForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 32768,
  "model_type": "mistral",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pad_token_id": 0,
  "rms_norm_eps": 1e-05,
  "rope_theta": 10000.0,
  "sliding_window": 4096,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.37.2",
  "use_cache": true,
  "vocab_size": 32000
}
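The config values above explain the Grouped-Query Attention mentioned earlier: 32 query heads share 8 key/value heads. A small arithmetic sketch (values copied from the printed config) shows why k_proj and v_proj in the model printout produce 1,024 features while q_proj produces 4,096. With 4× fewer key/value heads, the KV cache is also 4× smaller than full multi-head attention at the same head size.

```python
# Values copied from the MistralConfig printed above.
config = {
    "hidden_size": 4096,
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
}

head_dim = config["hidden_size"] // config["num_attention_heads"]  # 128
queries_per_kv_head = config["num_attention_heads"] // config["num_key_value_heads"]  # 4
q_proj_out = config["num_attention_heads"] * head_dim   # 4096
kv_proj_out = config["num_key_value_heads"] * head_dim  # 1024 for both k_proj and v_proj
```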

In [19]:
# Sample string
# sample_string = ['hello [/INST]', 'my good friend</s>']
sample_string = ['Caio!']

# Tokenize the sample string
encoded_sample = tokenizer(sample_string, truncation=True, padding=True, max_length=1024, return_tensors='pt', add_special_tokens=True)

BOS_token_id = tokenizer.bos_token_id
EOS_token_id = tokenizer.eos_token_id
BOS_token = tokenizer.decode([BOS_token_id])
EOS_token = tokenizer.decode([EOS_token_id])

print(f"Beginning of the sequence: {sample_string[0]} (BOS token: {BOS_token}, id: {BOS_token_id})")
print(f"End of the sequence: {sample_string[-1]} (EOS token: {EOS_token}, id: {EOS_token_id})")

# Count tokens in input_ids (len(encoded_sample) would count the dict keys instead)
token_count = encoded_sample['input_ids'].shape[1]

print(f"Tokens in the string: {token_count}")
print(f"Token IDs: {encoded_sample}")

# Decode the input_ids
decoded_sample = tokenizer.decode(encoded_sample['input_ids'][0], skip_special_tokens=False)

# Print the decoded string
print(f"Decoded string: {decoded_sample}")

# Print the attention mask
print(f"Attention mask: {encoded_sample['attention_mask']}")
Beginning of the sequence: Caio! (BOS token: <s>, id: 1)
End of the sequence: Caio! (EOS token: </s>, id: 2)
Tokens in the string: 4
Token IDs: {'input_ids': tensor([[    1, 11013,   691, 28808]]), 'attention_mask': tensor([[1, 1, 1, 1]])}
Decoded string: <s> Caio!
Attention mask: tensor([[1, 1, 1, 1]])

Set up LoRA

In [20]:
# # If loading with adapters
# # Note: Instead, it's often faster to download base model then add adapters
# from peft import PeftModel

# # adapter_model = f'{base_model}' + '-function-calling-adapters' # replace

# # Load peft model with adapters
# model = PeftModel.from_pretrained(
#     model,
#     adapter_model,
# )
In [21]:
# Trade extra compute for lower VRAM usage (supported by most models)
model.gradient_checkpointing_enable()

# If using quantized model
# from peft import prepare_model_for_kbit_training
# model = prepare_model_for_kbit_training(model)
In [22]:
# Print parameter names (state dict keys)
print(model.state_dict().keys())
odict_keys(['model.embed_tokens.weight', 'model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.0.mlp.gate_proj.weight', 'model.layers.0.mlp.up_proj.weight', 'model.layers.0.mlp.down_proj.weight', 'model.layers.0.input_layernorm.weight', 'model.layers.0.post_attention_layernorm.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.1.mlp.gate_proj.weight', 'model.layers.1.mlp.up_proj.weight', 'model.layers.1.mlp.down_proj.weight', 'model.layers.1.input_layernorm.weight', 'model.layers.1.post_attention_layernorm.weight', 'model.layers.2.self_attn.q_proj.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.2.self_attn.o_proj.weight', 'model.layers.2.mlp.gate_proj.weight', 'model.layers.2.mlp.up_proj.weight', 'model.layers.2.mlp.down_proj.weight', 'model.layers.2.input_layernorm.weight', 'model.layers.2.post_attention_layernorm.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.3.mlp.gate_proj.weight', 'model.layers.3.mlp.up_proj.weight', 'model.layers.3.mlp.down_proj.weight', 'model.layers.3.input_layernorm.weight', 'model.layers.3.post_attention_layernorm.weight', 'model.layers.4.self_attn.q_proj.weight', 'model.layers.4.self_attn.k_proj.weight', 'model.layers.4.self_attn.v_proj.weight', 'model.layers.4.self_attn.o_proj.weight', 'model.layers.4.mlp.gate_proj.weight', 'model.layers.4.mlp.up_proj.weight', 'model.layers.4.mlp.down_proj.weight', 'model.layers.4.input_layernorm.weight', 'model.layers.4.post_attention_layernorm.weight', 'model.layers.5.self_attn.q_proj.weight', 'model.layers.5.self_attn.k_proj.weight', 
'model.layers.5.self_attn.v_proj.weight', 'model.layers.5.self_attn.o_proj.weight', 'model.layers.5.mlp.gate_proj.weight', 'model.layers.5.mlp.up_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.5.input_layernorm.weight', 'model.layers.5.post_attention_layernorm.weight', 'model.layers.6.self_attn.q_proj.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.6.self_attn.v_proj.weight', 'model.layers.6.self_attn.o_proj.weight', 'model.layers.6.mlp.gate_proj.weight', 'model.layers.6.mlp.up_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.6.input_layernorm.weight', 'model.layers.6.post_attention_layernorm.weight', 'model.layers.7.self_attn.q_proj.weight', 'model.layers.7.self_attn.k_proj.weight', 'model.layers.7.self_attn.v_proj.weight', 'model.layers.7.self_attn.o_proj.weight', 'model.layers.7.mlp.gate_proj.weight', 'model.layers.7.mlp.up_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.7.input_layernorm.weight', 'model.layers.7.post_attention_layernorm.weight', 'model.layers.8.self_attn.q_proj.weight', 'model.layers.8.self_attn.k_proj.weight', 'model.layers.8.self_attn.v_proj.weight', 'model.layers.8.self_attn.o_proj.weight', 'model.layers.8.mlp.gate_proj.weight', 'model.layers.8.mlp.up_proj.weight', 'model.layers.8.mlp.down_proj.weight', 'model.layers.8.input_layernorm.weight', 'model.layers.8.post_attention_layernorm.weight', 'model.layers.9.self_attn.q_proj.weight', 'model.layers.9.self_attn.k_proj.weight', 'model.layers.9.self_attn.v_proj.weight', 'model.layers.9.self_attn.o_proj.weight', 'model.layers.9.mlp.gate_proj.weight', 'model.layers.9.mlp.up_proj.weight', 'model.layers.9.mlp.down_proj.weight', 'model.layers.9.input_layernorm.weight', 'model.layers.9.post_attention_layernorm.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.10.mlp.gate_proj.weight', 
'model.layers.10.mlp.up_proj.weight', 'model.layers.10.mlp.down_proj.weight', 'model.layers.10.input_layernorm.weight', 'model.layers.10.post_attention_layernorm.weight', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.11.mlp.gate_proj.weight', 'model.layers.11.mlp.up_proj.weight', 'model.layers.11.mlp.down_proj.weight', 'model.layers.11.input_layernorm.weight', 'model.layers.11.post_attention_layernorm.weight', 'model.layers.12.self_attn.q_proj.weight', 'model.layers.12.self_attn.k_proj.weight', 'model.layers.12.self_attn.v_proj.weight', 'model.layers.12.self_attn.o_proj.weight', 'model.layers.12.mlp.gate_proj.weight', 'model.layers.12.mlp.up_proj.weight', 'model.layers.12.mlp.down_proj.weight', 'model.layers.12.input_layernorm.weight', 'model.layers.12.post_attention_layernorm.weight', 'model.layers.13.self_attn.q_proj.weight', 'model.layers.13.self_attn.k_proj.weight', 'model.layers.13.self_attn.v_proj.weight', 'model.layers.13.self_attn.o_proj.weight', 'model.layers.13.mlp.gate_proj.weight', 'model.layers.13.mlp.up_proj.weight', 'model.layers.13.mlp.down_proj.weight', 'model.layers.13.input_layernorm.weight', 'model.layers.13.post_attention_layernorm.weight', 'model.layers.14.self_attn.q_proj.weight', 'model.layers.14.self_attn.k_proj.weight', 'model.layers.14.self_attn.v_proj.weight', 'model.layers.14.self_attn.o_proj.weight', 'model.layers.14.mlp.gate_proj.weight', 'model.layers.14.mlp.up_proj.weight', 'model.layers.14.mlp.down_proj.weight', 'model.layers.14.input_layernorm.weight', 'model.layers.14.post_attention_layernorm.weight', 'model.layers.15.self_attn.q_proj.weight', 'model.layers.15.self_attn.k_proj.weight', 'model.layers.15.self_attn.v_proj.weight', 'model.layers.15.self_attn.o_proj.weight', 'model.layers.15.mlp.gate_proj.weight', 'model.layers.15.mlp.up_proj.weight', 'model.layers.15.mlp.down_proj.weight', 
'model.layers.15.input_layernorm.weight', 'model.layers.15.post_attention_layernorm.weight', 'model.layers.16.self_attn.q_proj.weight', 'model.layers.16.self_attn.k_proj.weight', 'model.layers.16.self_attn.v_proj.weight', 'model.layers.16.self_attn.o_proj.weight', 'model.layers.16.mlp.gate_proj.weight', 'model.layers.16.mlp.up_proj.weight', 'model.layers.16.mlp.down_proj.weight', 'model.layers.16.input_layernorm.weight', 'model.layers.16.post_attention_layernorm.weight', 'model.layers.17.self_attn.q_proj.weight', 'model.layers.17.self_attn.k_proj.weight', 'model.layers.17.self_attn.v_proj.weight', 'model.layers.17.self_attn.o_proj.weight', 'model.layers.17.mlp.gate_proj.weight', 'model.layers.17.mlp.up_proj.weight', 'model.layers.17.mlp.down_proj.weight', 'model.layers.17.input_layernorm.weight', 'model.layers.17.post_attention_layernorm.weight', 'model.layers.18.self_attn.q_proj.weight', 'model.layers.18.self_attn.k_proj.weight', 'model.layers.18.self_attn.v_proj.weight', 'model.layers.18.self_attn.o_proj.weight', 'model.layers.18.mlp.gate_proj.weight', 'model.layers.18.mlp.up_proj.weight', 'model.layers.18.mlp.down_proj.weight', 'model.layers.18.input_layernorm.weight', 'model.layers.18.post_attention_layernorm.weight', 'model.layers.19.self_attn.q_proj.weight', 'model.layers.19.self_attn.k_proj.weight', 'model.layers.19.self_attn.v_proj.weight', 'model.layers.19.self_attn.o_proj.weight', 'model.layers.19.mlp.gate_proj.weight', 'model.layers.19.mlp.up_proj.weight', 'model.layers.19.mlp.down_proj.weight', 'model.layers.19.input_layernorm.weight', 'model.layers.19.post_attention_layernorm.weight', 'model.layers.20.self_attn.q_proj.weight', 'model.layers.20.self_attn.k_proj.weight', 'model.layers.20.self_attn.v_proj.weight', 'model.layers.20.self_attn.o_proj.weight', 'model.layers.20.mlp.gate_proj.weight', 'model.layers.20.mlp.up_proj.weight', 'model.layers.20.mlp.down_proj.weight', 'model.layers.20.input_layernorm.weight', 
'model.layers.20.post_attention_layernorm.weight', 'model.layers.21.self_attn.q_proj.weight', 'model.layers.21.self_attn.k_proj.weight', 'model.layers.21.self_attn.v_proj.weight', 'model.layers.21.self_attn.o_proj.weight', 'model.layers.21.mlp.gate_proj.weight', 'model.layers.21.mlp.up_proj.weight', 'model.layers.21.mlp.down_proj.weight', 'model.layers.21.input_layernorm.weight', 'model.layers.21.post_attention_layernorm.weight', 'model.layers.22.self_attn.q_proj.weight', 'model.layers.22.self_attn.k_proj.weight', 'model.layers.22.self_attn.v_proj.weight', 'model.layers.22.self_attn.o_proj.weight', 'model.layers.22.mlp.gate_proj.weight', 'model.layers.22.mlp.up_proj.weight', 'model.layers.22.mlp.down_proj.weight', 'model.layers.22.input_layernorm.weight', 'model.layers.22.post_attention_layernorm.weight', 'model.layers.23.self_attn.q_proj.weight', 'model.layers.23.self_attn.k_proj.weight', 'model.layers.23.self_attn.v_proj.weight', 'model.layers.23.self_attn.o_proj.weight', 'model.layers.23.mlp.gate_proj.weight', 'model.layers.23.mlp.up_proj.weight', 'model.layers.23.mlp.down_proj.weight', 'model.layers.23.input_layernorm.weight', 'model.layers.23.post_attention_layernorm.weight', 'model.layers.24.self_attn.q_proj.weight', 'model.layers.24.self_attn.k_proj.weight', 'model.layers.24.self_attn.v_proj.weight', 'model.layers.24.self_attn.o_proj.weight', 'model.layers.24.mlp.gate_proj.weight', 'model.layers.24.mlp.up_proj.weight', 'model.layers.24.mlp.down_proj.weight', 'model.layers.24.input_layernorm.weight', 'model.layers.24.post_attention_layernorm.weight', 'model.layers.25.self_attn.q_proj.weight', 'model.layers.25.self_attn.k_proj.weight', 'model.layers.25.self_attn.v_proj.weight', 'model.layers.25.self_attn.o_proj.weight', 'model.layers.25.mlp.gate_proj.weight', 'model.layers.25.mlp.up_proj.weight', 'model.layers.25.mlp.down_proj.weight', 'model.layers.25.input_layernorm.weight', 'model.layers.25.post_attention_layernorm.weight', 
'model.layers.26.self_attn.q_proj.weight', 'model.layers.26.self_attn.k_proj.weight', 'model.layers.26.self_attn.v_proj.weight', 'model.layers.26.self_attn.o_proj.weight', 'model.layers.26.mlp.gate_proj.weight', 'model.layers.26.mlp.up_proj.weight', 'model.layers.26.mlp.down_proj.weight', 'model.layers.26.input_layernorm.weight', 'model.layers.26.post_attention_layernorm.weight', 'model.layers.27.self_attn.q_proj.weight', 'model.layers.27.self_attn.k_proj.weight', 'model.layers.27.self_attn.v_proj.weight', 'model.layers.27.self_attn.o_proj.weight', 'model.layers.27.mlp.gate_proj.weight', 'model.layers.27.mlp.up_proj.weight', 'model.layers.27.mlp.down_proj.weight', 'model.layers.27.input_layernorm.weight', 'model.layers.27.post_attention_layernorm.weight', 'model.layers.28.self_attn.q_proj.weight', 'model.layers.28.self_attn.k_proj.weight', 'model.layers.28.self_attn.v_proj.weight', 'model.layers.28.self_attn.o_proj.weight', 'model.layers.28.mlp.gate_proj.weight', 'model.layers.28.mlp.up_proj.weight', 'model.layers.28.mlp.down_proj.weight', 'model.layers.28.input_layernorm.weight', 'model.layers.28.post_attention_layernorm.weight', 'model.layers.29.self_attn.q_proj.weight', 'model.layers.29.self_attn.k_proj.weight', 'model.layers.29.self_attn.v_proj.weight', 'model.layers.29.self_attn.o_proj.weight', 'model.layers.29.mlp.gate_proj.weight', 'model.layers.29.mlp.up_proj.weight', 'model.layers.29.mlp.down_proj.weight', 'model.layers.29.input_layernorm.weight', 'model.layers.29.post_attention_layernorm.weight', 'model.layers.30.self_attn.q_proj.weight', 'model.layers.30.self_attn.k_proj.weight', 'model.layers.30.self_attn.v_proj.weight', 'model.layers.30.self_attn.o_proj.weight', 'model.layers.30.mlp.gate_proj.weight', 'model.layers.30.mlp.up_proj.weight', 'model.layers.30.mlp.down_proj.weight', 'model.layers.30.input_layernorm.weight', 'model.layers.30.post_attention_layernorm.weight', 'model.layers.31.self_attn.q_proj.weight', 
'model.layers.31.self_attn.k_proj.weight', 'model.layers.31.self_attn.v_proj.weight', 'model.layers.31.self_attn.o_proj.weight', 'model.layers.31.mlp.gate_proj.weight', 'model.layers.31.mlp.up_proj.weight', 'model.layers.31.mlp.down_proj.weight', 'model.layers.31.input_layernorm.weight', 'model.layers.31.post_attention_layernorm.weight', 'model.norm.weight', 'lm_head.weight'])
In [23]:
print(model)
MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (norm): MistralRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
In [24]:
# # If extending model context
# def set_added_trainable_params(model):
#     """
#     Sets the parameters with names containing "embed" or "norm" as trainable.
#     """
#     trainable_params_dict = {}

#     for name, param in model.named_parameters():
#         if "embed" in name or "norm" in name: #for most models
#         # if "ln" in name or "embd" in name: #for Phi-2
#             param.requires_grad_()
#             trainable_params_dict[name] = param

#     return trainable_params_dict

def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()

    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable %: {100 * trainable_params / all_param}"
    )

from peft import LoraConfig, get_peft_model

# Initialize LoRA configuration
config = LoraConfig(
    # Lower rank results in smaller update matrices with fewer trainable params
    r=8,  # Use 8 for models with 7B or more parameters, else 128
    lora_alpha=32,
    target_modules=[
    #     "Wqkv", #for Phi-2
    #     "fc1", #for Phi-2
    #     "fc2" #for Phi-2
      "self_attn.q_proj",
      "self_attn.k_proj",
      "self_attn.v_proj",
      "self_attn.o_proj",
      # "self_attn.rotary_emb.inv_freq",
      "mlp.gate_proj",
      "mlp.up_proj",
      "mlp.down_proj",
      # "input_layernorm.weight",
      # "post_attention_layernorm.weight",
      # "model.norm.weight",
      # "lm_head.weight"
    ],
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM"
)

# Apply LoRA to the model
model = get_peft_model(model, config)

# # Set added parameters with names containing "embed" or "norm" as trainable.
# # Recommended if you are extending an LLM's context window.
# set_added_trainable_params(model)

# Print out the number of trainable parameters
print_trainable_parameters(model)
trainable params: 20971520 || all params: 7262703616 || trainable %: 0.2887563792882719
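As a sanity check, the trainable-parameter count can be reproduced by hand. LoRA leaves each targeted linear layer frozen and adds two small matrices, A (r × in_features) and B (out_features × r), so each targeted layer gains r × (in_features + out_features) parameters. The sketch below plugs in the layer shapes from print(model) and r=8. The dictionary and helper names are my own, not part of PEFT.

```python
# Linear layer shapes (in_features, out_features) copied from print(model) above
LAYER_SHAPES = {
    "q_proj": (4096, 4096),
    "k_proj": (4096, 1024),
    "v_proj": (4096, 1024),
    "o_proj": (4096, 4096),
    "gate_proj": (4096, 14336),
    "up_proj": (4096, 14336),
    "down_proj": (14336, 4096),
}
NUM_LAYERS = 32
VOCAB, HIDDEN = 32000, 4096


def lora_params(r: int) -> int:
    """LoRA adds A (r x in) and B (out x r) to every targeted linear layer."""
    per_layer = sum(r * (fan_in + fan_out) for fan_in, fan_out in LAYER_SHAPES.values())
    return per_layer * NUM_LAYERS


def base_params() -> int:
    """Frozen weights: projections, per-layer RMSNorms, embeddings, lm_head, final norm."""
    linear = sum(fan_in * fan_out for fan_in, fan_out in LAYER_SHAPES.values())
    norms = 2 * HIDDEN  # input_layernorm + post_attention_layernorm
    per_layer = linear + norms
    return per_layer * NUM_LAYERS + 2 * VOCAB * HIDDEN + HIDDEN


trainable = lora_params(r=8)
total = base_params() + trainable
print(f"trainable: {trainable:,} | total: {total:,} | trainable %: {100 * trainable / total:.4f}")
```

This reproduces the 20,971,520 trainable and 7,262,703,616 total parameters printed above. Because the count scales linearly with r, switching to r=128 would multiply the trainable parameters by 16.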

Prepare data

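Each function in the data set is stored as JSON in its own file. Each file follows OpenAI's function metadata format, plus a samplePromptResponsePairs list of example prompts and the function call expected for each.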

JSON data format

{
    "type": "function",
    "function": {
        "name": "function_name",
        "description": "function description",
        "parameters": {
            "type": "object",
            "properties": {
                "property_1": {
                    "type": "property_type", // e.g. "string"
                    "description": "property description"
                },
                "property_2": {
                    "type": "property_type", // e.g. "string"
                    "description": "property description"
                }
            },
            "required": ["property_1","property_2"]
        }
    },
    "samplePromptResponsePairs": [
        {
            "prompt": "sample_prompt",
            "response": {
                "name": "function_name",
                "arguments": {
                    "property_1": "property_value",
                    "property_2": "property_value"
                }
            }
        },
        ...
    ]
}
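The notebook loads the already-flattened Trelis/function_calling_v3 data set (columns functionList, userPrompt, assistantResponse), and this section doesn't show how per-function files map onto those columns. The following is a hypothetical sketch, assuming one row per sample prompt, a function list that contains only the OpenAI-style metadata, and a JSON-serialized call as the assistant response. The to_training_rows helper and the sample values are illustrative.

```python
import json

# Hypothetical function file following the format above (values are illustrative)
func_file = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string", "description": "The city and country, e.g. Dublin, Ireland"},
            },
            "required": ["location"],
        },
    },
    "samplePromptResponsePairs": [
        {
            "prompt": "What's the weather like in Dublin?",
            "response": {"name": "get_current_weather", "arguments": {"location": "Dublin, Ireland"}},
        }
    ],
}


def to_training_rows(files):
    """Hypothetical helper: flatten function files into functionList/userPrompt/assistantResponse rows."""
    # Only the OpenAI-style metadata goes into the function list, not the sample pairs
    metadata = [{"type": f["type"], "function": f["function"]} for f in files]
    function_list = json.dumps(metadata, indent=4)
    rows = []
    for f in files:
        declared = f["function"]["name"]
        for pair in f["samplePromptResponsePairs"]:
            # Guard against sample responses that call a different function than the file declares
            if pair["response"]["name"] != declared:
                raise ValueError(f"response calls {pair['response']['name']!r}, expected {declared!r}")
            rows.append({
                "functionList": function_list,
                "userPrompt": pair["prompt"],
                "assistantResponse": json.dumps(pair["response"], indent=4),
            })
    return rows


rows = to_training_rows([func_file])
print(rows[0]["assistantResponse"])
```

The name check catches the most common data-entry slip in hand-written files: a sample response that names a function other than the one the file defines.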
In [25]:
!pip install -q -U datasets
In [26]:
from datasets import load_dataset

# From Hugging Face Hub
data = load_dataset("Trelis/function_calling_v3")
In [27]:
print(data)
DatasetDict({
    train: Dataset({
        features: ['functionList', 'userPrompt', 'assistantResponse'],
        num_rows: 66
    })
    validation: Dataset({
        features: ['functionList', 'userPrompt', 'assistantResponse'],
        num_rows: 19
    })
    test: Dataset({
        features: ['functionList', 'userPrompt', 'assistantResponse'],
        num_rows: 7
    })
})
In [28]:
import torch
from torch.utils.data import Dataset


class TextDataset(Dataset):
    def __init__(self, encodings, response_lengths, input_lengths):
        self.encodings = encodings
        self.response_lengths = response_lengths
        self.input_lengths = input_lengths

    def __getitem__(self, idx):
        item = {key: val[idx].clone().detach() for key, val in self.encodings.items()}

        # Set labels to be the same as input_ids
        item["labels"] = item["input_ids"].clone()

        # Calculate the start and end positions of the response
        response_start_position = self.input_lengths[idx]
        response_end_position = self.input_lengths[idx] + self.response_lengths[idx]

        # Create a loss mask that covers only the response tokens
        item["loss_mask"] = torch.zeros_like(item["input_ids"])
        item["loss_mask"][response_start_position:response_end_position] = 1

        # Shift the loss mask to the left by one position
        shifted_loss_mask = torch.cat([item["loss_mask"][1:], torch.tensor([0])])
        item["loss_mask"] = shifted_loss_mask

        # Shift the labels to the left by one position
        item["labels"][:-1] = item["input_ids"][1:]

        # The last response token should predict EOS (uses the global `tokenizer`)
        item["labels"][response_end_position - 1] = tokenizer.eos_token_id

        # Include that EOS prediction in the loss
        item["loss_mask"][response_end_position - 1] = 1

        return item

    def __len__(self):
        return len(self.encodings["input_ids"])
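The shifting logic in TextDataset is easiest to follow on a toy sequence. This pure-Python mirror (the build_labels_and_mask name is hypothetical) applies the same steps to lists instead of tensors: the loss mask at position i decides whether predicting labels[i], i.e. the token at position i + 1, is scored.

```python
EOS_ID = 2  # Mistral's </s> id, as seen in the labels printed for the real sample
PAD_ID = 0


def build_labels_and_mask(input_ids, input_length, response_length, eos_id=EOS_ID):
    """Pure-Python mirror of the label/loss-mask logic in TextDataset.__getitem__."""
    start, end = input_length, input_length + response_length
    # Mark the response tokens
    loss_mask = [1 if start <= i < end else 0 for i in range(len(input_ids))]
    loss_mask = loss_mask[1:] + [0]            # shift left by one position
    labels = input_ids[1:] + input_ids[-1:]    # labels[i] = input_ids[i + 1]; last entry unchanged
    labels[end - 1] = eos_id                   # last response token should predict EOS
    loss_mask[end - 1] = 1                     # ...and that prediction is scored
    return labels, loss_mask


# Toy sequence: BOS, three prompt tokens, two response tokens, two pad tokens
ids = [1, 10, 11, 12, 20, 21, PAD_ID, PAD_ID]
labels, mask = build_labels_and_mask(ids, input_length=4, response_length=2)
print(labels)  # [10, 11, 12, 20, 21, 2, 0, 0]
print(mask)    # [0, 0, 0, 1, 1, 1, 0, 0]
```

Only positions 3, 4, and 5 are trained: the last prompt token predicts the first response token, and the last response token predicts EOS. The same pattern appears at the end of the real sample examined later, where the label after 'clear' is </s> (id 2).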
In [29]:
# Define the function start and end strings
# Note: prepare_dataset appends \n\n after E_INST so that E_INST is tokenized
# the same way regardless of the text that follows it.
B_FUNC, E_FUNC = "You have access to the following functions. Use them if required:\n\n", "\n\n"

# Define the user prompt start and end strings
# B_INST, E_INST = "GPT4 Correct User: ", "<|end_of_turn|>GPT4 Correct Assistant:" # OpenChat style
B_INST, E_INST = "[INST] ", " [/INST]" # Llama 2 or Mistral style
# B_INST, E_INST = "Instruct:", "\nOutput:" # Phi 2
# B_INST, E_INST = "\n### Instruction:\n", "\n### Response:\n" # DeepSeek Coder style
# B_INST, E_INST = "Human: ", " Assistant:" # Yi style for function calling, no training space
# B_INST, E_INST = "### Human: ", "\n\n### Assistant: " # SUSChat
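To see exactly what the model is trained on, here is the string layout that prepare_dataset (in the next cell) builds for one row before tokenization, using the Mistral-style template strings above. The format_example helper is an illustrative stand-in for the lambda inside prepare_dataset; the prompt and response are taken from the training sample printed later, and the function list is a placeholder.

```python
# Template strings copied from the cell above (Mistral style)
B_FUNC, E_FUNC = "You have access to the following functions. Use them if required:\n\n", "\n\n"
B_INST, E_INST = "[INST] ", " [/INST]"


def format_example(function_list: str, user_prompt: str, assistant_response: str) -> str:
    """Build the same input_text string that prepare_dataset creates for one row."""
    return "".join([
        f"{B_INST}{B_FUNC}{function_list.strip()}{E_FUNC}",
        f"{user_prompt.strip()}{E_INST}\n\n",
        f"{assistant_response.strip()}",
    ])


text = format_example(
    '[{"type": "function"}]',  # placeholder function list
    "Where did fortune cookies originate?",
    "The precise origin of fortune cookies is unclear",
)
print(text)
```

Note that the system-style function list sits inside the [INST] block, and the response starts after a blank line following [/INST]. That blank line is what keeps the tokenization of [/INST] stable across examples.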
In [30]:
def prepare_dataset(dataset, tokenizer):
    # Create the formatted text with the correct roles for each part of the dialogue
    formatted_dataset = dataset.map(
        lambda x: {
            "input_text": "".join([
                f"{B_INST}{B_FUNC}{x['functionList'].strip()}{E_FUNC}",
                f"{x['userPrompt'].strip()}{E_INST}\n\n",
                f"{x['assistantResponse'].strip()}",  # append EOS token in TextData...
            ]),
            "response_text": "".join([
                f"{x['assistantResponse'].strip()}",  # append EOS token in TextData...
            ]),
        }
    )

    # Tokenize the datasets
    encodings = tokenizer(
        [dialogue["input_text"] for dialogue in formatted_dataset],
        truncation=True,
        padding=True,
        max_length=1024,
        return_tensors="pt",
        add_special_tokens=True,
    )

    # Tokenize the response one by one without padding and special tokens for
    # the purpose of calculating length
    response_lengths = [
        len(tokenizer.encode(dialogue["response_text"], truncation=True,
                             max_length=1024, padding=False,
                             add_special_tokens=False))
        for dialogue in formatted_dataset
    ]

    # Tokenize the input one by one without padding and with the initial
    # special token for the purpose of calculating length
    total_lengths = [
        len(tokenizer.encode(dialogue["input_text"], truncation=True,
                             max_length=1024, padding=False,
                             add_special_tokens=True))
        for dialogue in formatted_dataset
    ]
    input_lengths = [
        total_length - response_length
        for total_length, response_length in zip(total_lengths, response_lengths)
    ]

    # Create TextDataset
    text_dataset = TextDataset(encodings, response_lengths, input_lengths)

    return text_dataset
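TextDataset only produces loss_mask; something still has to apply it during training. Because labels are already shifted left by one, a loss that uses them should score logits[i] against labels[i] directly, rather than shifting again the way the built-in causal-LM loss in transformers does. Here is a minimal pure-Python sketch of masked cross entropy under that assumption (the masked_cross_entropy name is mine). In PyTorch, the same idea is torch.nn.functional.cross_entropy with reduction="none", multiplied by the mask and divided by the mask sum.

```python
import math


def masked_cross_entropy(logits, labels, loss_mask):
    """Mean token-level cross entropy over positions where loss_mask == 1.

    logits: per-position score lists (seq_len x vocab_size)
    labels: already-shifted targets, so logits[i] is scored against labels[i]
    """
    total, count = 0.0, 0
    for row, target, keep in zip(logits, labels, loss_mask):
        if not keep:
            continue  # prompt and padding positions contribute nothing
        # Log-sum-exp with max subtraction for numerical stability
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z - row[target]  # negative log-probability of the target
        count += 1
    return total / max(count, 1)  # avoid division by zero if nothing is masked in


# Toy example: vocabulary of 3 tokens, 4 positions, only the last two are scored
logits = [[0.0, 0.0, 0.0], [5.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]]
labels = [1, 1, 1, 2]
loss_mask = [0, 0, 1, 1]
print(masked_cross_entropy(logits, labels, loss_mask))
```

Position 1 would incur a large loss (the model strongly prefers token 0 but the label is 1), yet it has no effect on the result because it falls in the masked-out prompt region.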
In [31]:
# Apply function to your datasets
train_dataset = prepare_dataset(data['train'], tokenizer)
test_dataset = prepare_dataset(data['test'], tokenizer)
validation_dataset = prepare_dataset(data['validation'], tokenizer)

Examine the datasets

In [32]:
# Print the number of items in the dataset
print(f"Number of samples in the dataset: {len(train_dataset)}")

# Get a sample item
sample_item = train_dataset[1]  # replace with the index of any sample

# Print the dimensions of the sample item
print(f"Dimensions of input_ids: {sample_item['input_ids'].shape}")
print(f"Dimensions of attention_mask: {sample_item['attention_mask'].shape}")
print(f"Dimensions of loss_mask: {sample_item['loss_mask'].shape}")
print(f"Dimensions of labels: {sample_item['labels'].shape}")

# Print some tokens from the start and end of the sample
num_tokens_to_print = 336  # replace with the number of tokens you want to print

print("\nTokens at the start of the sample:")
print(sample_item['input_ids'][:num_tokens_to_print].tolist())
print(tokenizer.convert_ids_to_tokens(sample_item['input_ids'][:num_tokens_to_print].tolist()))

print("\nLabels at the start of the sample:")
print(sample_item['labels'][:num_tokens_to_print].tolist())
print(tokenizer.convert_ids_to_tokens(sample_item['labels'][:num_tokens_to_print].tolist()))

print("Attention mask at the start of the sample:")
print(sample_item['attention_mask'][:num_tokens_to_print].tolist())

print("Loss mask at the start of the sample:")
print(sample_item['loss_mask'][:num_tokens_to_print].tolist())

print("\nTokens at the end of the sample:")
print(sample_item['input_ids'][-num_tokens_to_print:].tolist())
print(tokenizer.convert_ids_to_tokens(sample_item['input_ids'][-num_tokens_to_print:].tolist()))

print("\nLabels at the end of the sample:")
print(sample_item['labels'][-num_tokens_to_print:].tolist())
print(tokenizer.convert_ids_to_tokens(sample_item['labels'][-num_tokens_to_print:].tolist()))

print("Attention mask at the end of the sample:")
print(sample_item['attention_mask'][-num_tokens_to_print:].tolist())

print("Loss mask at the end of the sample:")
print(sample_item['loss_mask'][-num_tokens_to_print:].tolist())
Number of samples in the dataset: 66
Dimensions of input_ids: torch.Size([677])
Dimensions of attention_mask: torch.Size([677])
Dimensions of loss_mask: torch.Size([677])
Dimensions of labels: torch.Size([677])

Tokens at the start of the sample:
[1, 733, 16289, 28793, 995, 506, 2735, 298, 272, 2296, 5572, 28723, 5938, 706, 513, 3030, 28747, 13, 13, 28792, 13, 2287, 371, 13, 5390, 345, 1123, 1264, 345, 2628, 548, 13, 5390, 345, 2628, 1264, 371, 13, 17422, 345, 861, 1264, 345, 2360, 28730, 283, 28744, 449, 548, 13, 17422, 345, 6518, 1264, 345, 7009, 354, 3332, 10374, 356, 1010, 28814, 449, 28723, 6746, 938, 302, 5771, 28725, 3994, 304, 5457, 12765, 390, 7658, 298, 5175, 3471, 2373, 272, 5709, 9191, 13, 17422, 345, 11438, 1264, 371, 13, 1417, 28705, 345, 1123, 1264, 345, 2814, 548, 13, 1417, 28705, 345, 10723, 1264, 371, 13, 359, 2287, 345, 3385, 1264, 371, 13, 359, 5390, 345, 1123, 1264, 345, 1427, 548, 13, 359, 5390, 345, 6518, 1264, 345, 1014, 3472, 5709, 1423, 28739, 13, 359, 2287, 443, 13, 1417, 28705, 1630, 13, 1417, 28705, 345, 10893, 1264, 733, 13, 359, 2287, 345, 3385, 28739, 13, 1417, 28705, 4709, 13, 17422, 443, 13, 5390, 443, 13, 2287, 1630, 13, 2287, 371, 13, 5390, 345, 1123, 1264, 345, 2628, 548, 13, 5390, 345, 2628, 1264, 371, 13, 17422, 345, 861, 1264, 345, 527, 28730, 3022, 28730, 769, 1223, 548, 13, 17422, 345, 6518, 1264, 345, 1458, 272, 1868, 8086, 297, 264, 2078, 4723, 548, 13, 17422, 345, 11438, 1264, 371, 13, 1417, 28705, 345, 1123, 1264, 345, 2814, 548, 13, 1417, 28705, 345, 10723, 1264, 371, 13, 359, 2287, 345, 2733, 1264, 371, 13, 359, 5390, 345, 1123, 1264, 345, 1427, 548, 13, 359, 5390, 345, 6518, 1264, 345, 1014, 2990, 304, 2939, 28725, 317, 28723, 28721, 28723, 22263, 28725, 11170, 28739, 13, 359, 2287, 1630, 13, 359, 2287, 345, 5306, 1264, 371, 13, 359, 5390, 345, 1123, 1264, 345, 1427, 548, 13, 359, 5390, 345, 6518, 1264, 345, 25241, 466, 5028, 354, 272, 8086, 28723, 19641, 28747, 464, 28717, 1190, 3170, 647, 464, 28722, 18657, 12307, 21236, 13, 359, 2287, 443, 13, 1417, 28705, 1630, 13, 1417, 28705, 345, 10893, 1264, 733, 13, 359, 2287, 345, 2733, 28739]
['<s>', '▁[', 'INST', ']', '▁You', '▁have', '▁access', '▁to', '▁the', '▁following', '▁functions', '.', '▁Use', '▁them', '▁if', '▁required', ':', '<0x0A>', '<0x0A>', '[', '<0x0A>', '▁▁▁', '▁{', '<0x0A>', '▁▁▁▁▁▁▁', '▁"', 'type', '":', '▁"', 'function', '",', '<0x0A>', '▁▁▁▁▁▁▁', '▁"', 'function', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁"', 'name', '":', '▁"', 'search', '_', 'ar', 'x', 'iv', '",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁"', 'description', '":', '▁"', 'Search', '▁for', '▁research', '▁papers', '▁on', '▁Ar', 'X', 'iv', '.', '▁Make', '▁use', '▁of', '▁AND', ',', '▁OR', '▁and', '▁NOT', '▁operators', '▁as', '▁appropriate', '▁to', '▁join', '▁terms', '▁within', '▁the', '▁query', '.",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁"', 'parameters', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁"', 'type', '":', '▁"', 'object', '",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁"', 'properties', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁', '▁"', 'query', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁▁▁▁▁', '▁"', 'type', '":', '▁"', 'string', '",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁▁▁▁▁', '▁"', 'description', '":', '▁"', 'The', '▁search', '▁query', '▁string', '"', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁', '▁}', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁},', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁"', 'required', '":', '▁[', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁', '▁"', 'query', '"', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁]', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁}', '<0x0A>', '▁▁▁▁▁▁▁', '▁}', '<0x0A>', '▁▁▁', '▁},', '<0x0A>', '▁▁▁', '▁{', '<0x0A>', '▁▁▁▁▁▁▁', '▁"', 'type', '":', '▁"', 'function', '",', '<0x0A>', '▁▁▁▁▁▁▁', '▁"', 'function', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁"', 'name', '":', '▁"', 'get', '_', 'current', '_', 'we', 'ather', '",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁"', 'description', '":', '▁"', 'Get', '▁the', '▁current', '▁weather', '▁in', '▁a', '▁given', '▁location', '",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁"', 'parameters', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁"', 'type', '":', '▁"', 'object', '",', '<0x0A>', 
'▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁"', 'properties', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁', '▁"', 'location', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁▁▁▁▁', '▁"', 'type', '":', '▁"', 'string', '",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁▁▁▁▁', '▁"', 'description', '":', '▁"', 'The', '▁city', '▁and', '▁country', ',', '▁e', '.', 'g', '.', '▁Dublin', ',', '▁Ireland', '"', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁', '▁},', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁', '▁"', 'unit', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁▁▁▁▁', '▁"', 'type', '":', '▁"', 'string', '",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁▁▁▁▁', '▁"', 'description', '":', '▁"', 'Measure', 'ment', '▁unit', '▁for', '▁the', '▁weather', '.', '▁Options', ':', "▁'", 'c', 'els', 'ius', "',", "▁'", 'f', 'ahren', 'heit', '\'"', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁', '▁}', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁},', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁"', 'required', '":', '▁[', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁', '▁"', 'location', '"']

Labels at the start of the sample:
[733, 16289, 28793, 995, 506, 2735, 298, 272, 2296, 5572, 28723, 5938, 706, 513, 3030, 28747, 13, 13, 28792, 13, 2287, 371, 13, 5390, 345, 1123, 1264, 345, 2628, 548, 13, 5390, 345, 2628, 1264, 371, 13, 17422, 345, 861, 1264, 345, 2360, 28730, 283, 28744, 449, 548, 13, 17422, 345, 6518, 1264, 345, 7009, 354, 3332, 10374, 356, 1010, 28814, 449, 28723, 6746, 938, 302, 5771, 28725, 3994, 304, 5457, 12765, 390, 7658, 298, 5175, 3471, 2373, 272, 5709, 9191, 13, 17422, 345, 11438, 1264, 371, 13, 1417, 28705, 345, 1123, 1264, 345, 2814, 548, 13, 1417, 28705, 345, 10723, 1264, 371, 13, 359, 2287, 345, 3385, 1264, 371, 13, 359, 5390, 345, 1123, 1264, 345, 1427, 548, 13, 359, 5390, 345, 6518, 1264, 345, 1014, 3472, 5709, 1423, 28739, 13, 359, 2287, 443, 13, 1417, 28705, 1630, 13, 1417, 28705, 345, 10893, 1264, 733, 13, 359, 2287, 345, 3385, 28739, 13, 1417, 28705, 4709, 13, 17422, 443, 13, 5390, 443, 13, 2287, 1630, 13, 2287, 371, 13, 5390, 345, 1123, 1264, 345, 2628, 548, 13, 5390, 345, 2628, 1264, 371, 13, 17422, 345, 861, 1264, 345, 527, 28730, 3022, 28730, 769, 1223, 548, 13, 17422, 345, 6518, 1264, 345, 1458, 272, 1868, 8086, 297, 264, 2078, 4723, 548, 13, 17422, 345, 11438, 1264, 371, 13, 1417, 28705, 345, 1123, 1264, 345, 2814, 548, 13, 1417, 28705, 345, 10723, 1264, 371, 13, 359, 2287, 345, 2733, 1264, 371, 13, 359, 5390, 345, 1123, 1264, 345, 1427, 548, 13, 359, 5390, 345, 6518, 1264, 345, 1014, 2990, 304, 2939, 28725, 317, 28723, 28721, 28723, 22263, 28725, 11170, 28739, 13, 359, 2287, 1630, 13, 359, 2287, 345, 5306, 1264, 371, 13, 359, 5390, 345, 1123, 1264, 345, 1427, 548, 13, 359, 5390, 345, 6518, 1264, 345, 25241, 466, 5028, 354, 272, 8086, 28723, 19641, 28747, 464, 28717, 1190, 3170, 647, 464, 28722, 18657, 12307, 21236, 13, 359, 2287, 443, 13, 1417, 28705, 1630, 13, 1417, 28705, 345, 10893, 1264, 733, 13, 359, 2287, 345, 2733, 28739, 13]
['▁[', 'INST', ']', '▁You', '▁have', '▁access', '▁to', '▁the', '▁following', '▁functions', '.', '▁Use', '▁them', '▁if', '▁required', ':', '<0x0A>', '<0x0A>', '[', '<0x0A>', '▁▁▁', '▁{', '<0x0A>', '▁▁▁▁▁▁▁', '▁"', 'type', '":', '▁"', 'function', '",', '<0x0A>', '▁▁▁▁▁▁▁', '▁"', 'function', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁"', 'name', '":', '▁"', 'search', '_', 'ar', 'x', 'iv', '",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁"', 'description', '":', '▁"', 'Search', '▁for', '▁research', '▁papers', '▁on', '▁Ar', 'X', 'iv', '.', '▁Make', '▁use', '▁of', '▁AND', ',', '▁OR', '▁and', '▁NOT', '▁operators', '▁as', '▁appropriate', '▁to', '▁join', '▁terms', '▁within', '▁the', '▁query', '.",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁"', 'parameters', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁"', 'type', '":', '▁"', 'object', '",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁"', 'properties', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁', '▁"', 'query', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁▁▁▁▁', '▁"', 'type', '":', '▁"', 'string', '",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁▁▁▁▁', '▁"', 'description', '":', '▁"', 'The', '▁search', '▁query', '▁string', '"', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁', '▁}', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁},', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁"', 'required', '":', '▁[', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁', '▁"', 'query', '"', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁]', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁}', '<0x0A>', '▁▁▁▁▁▁▁', '▁}', '<0x0A>', '▁▁▁', '▁},', '<0x0A>', '▁▁▁', '▁{', '<0x0A>', '▁▁▁▁▁▁▁', '▁"', 'type', '":', '▁"', 'function', '",', '<0x0A>', '▁▁▁▁▁▁▁', '▁"', 'function', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁"', 'name', '":', '▁"', 'get', '_', 'current', '_', 'we', 'ather', '",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁"', 'description', '":', '▁"', 'Get', '▁the', '▁current', '▁weather', '▁in', '▁a', '▁given', '▁location', '",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁"', 'parameters', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁"', 'type', '":', '▁"', 'object', '",', '<0x0A>', 
'▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁"', 'properties', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁', '▁"', 'location', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁▁▁▁▁', '▁"', 'type', '":', '▁"', 'string', '",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁▁▁▁▁', '▁"', 'description', '":', '▁"', 'The', '▁city', '▁and', '▁country', ',', '▁e', '.', 'g', '.', '▁Dublin', ',', '▁Ireland', '"', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁', '▁},', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁', '▁"', 'unit', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁▁▁▁▁', '▁"', 'type', '":', '▁"', 'string', '",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁▁▁▁▁', '▁"', 'description', '":', '▁"', 'Measure', 'ment', '▁unit', '▁for', '▁the', '▁weather', '.', '▁Options', ':', "▁'", 'c', 'els', 'ius', "',", "▁'", 'f', 'ahren', 'heit', '\'"', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁', '▁}', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁},', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁"', 'required', '":', '▁[', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁', '▁"', 'location', '"', '<0x0A>']
Attention mask at the start of the sample:
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Loss mask at the start of the sample:
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

Tokens at the end of the sample:
[17422, 443, 13, 5390, 443, 13, 2287, 443, 13, 28793, 13, 13, 9607, 863, 19808, 15648, 2172, 4296, 28804, 733, 28748, 16289, 28793, 13, 13, 1014, 17008, 5016, 302, 19808, 15648, 349, 521, 6206, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
['▁▁▁▁▁▁▁▁▁▁▁', '▁}', '<0x0A>', '▁▁▁▁▁▁▁', '▁}', '<0x0A>', '▁▁▁', '▁}', '<0x0A>', ']', '<0x0A>', '<0x0A>', 'Where', '▁did', '▁fortune', '▁cookies', '▁orig', 'inate', '?', '▁[', '/', 'INST', ']', '<0x0A>', '<0x0A>', 'The', '▁precise', '▁origin', '▁of', '▁fortune', '▁cookies', '▁is', '▁un', 'clear', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', 
'<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>']

Labels at the end of the sample:
[443, 13, 5390, 443, 13, 2287, 443, 13, 28793, 13, 13, 9607, 863, 19808, 15648, 2172, 4296, 28804, 733, 28748, 16289, 28793, 13, 13, 1014, 17008, 5016, 302, 19808, 15648, 349, 521, 6206, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
['▁}', '<0x0A>', '▁▁▁▁▁▁▁', '▁}', '<0x0A>', '▁▁▁', '▁}', '<0x0A>', ']', '<0x0A>', '<0x0A>', 'Where', '▁did', '▁fortune', '▁cookies', '▁orig', 'inate', '?', '▁[', '/', 'INST', ']', '<0x0A>', '<0x0A>', 'The', '▁precise', '▁origin', '▁of', '▁fortune', '▁cookies', '▁is', '▁un', 'clear', '</s>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', 
'<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>']
Attention mask at the end of the sample:
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Loss mask at the end of the sample:
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
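The printout shows the pattern behind the two masks. The attention mask is 1 for every real token (the end of the prompt, the answer, and the closing </s>) and 0 for the padding. The loss mask is 1 only for the answer tokens and </s>. The sketch below reproduces that pattern for one right-padded sample. It is a minimal illustration, not the notebook's own masking code: build_masks is a hypothetical helper, and it assumes you already know how many real tokens the sample has and where the answer starts.

```python
def build_masks(num_real_tokens, answer_start, seq_len):
    """Return (attention_mask, loss_mask) for one right-padded sample.

    Hypothetical helper. Assumes positions [0, answer_start) hold the prompt,
    [answer_start, num_real_tokens) hold the answer including the final </s>,
    and [num_real_tokens, seq_len) hold padding.
    """
    if not 0 <= answer_start <= num_real_tokens <= seq_len:
        raise ValueError("need 0 <= answer_start <= num_real_tokens <= seq_len")

    # Attend to every real token (prompt + answer + </s>), ignore padding
    attention_mask = [1] * num_real_tokens + [0] * (seq_len - num_real_tokens)

    # Score only the answer tokens and </s>; prompt and padding are masked out
    loss_mask = [1 if answer_start <= i < num_real_tokens else 0 for i in range(seq_len)]
    return attention_mask, loss_mask


# A shortened version of the printed window: 24 prompt tokens, then the 10 answer
# tokens "The precise origin of fortune cookies is unclear</s>", then padding
attention_mask, loss_mask = build_masks(num_real_tokens=34, answer_start=24, seq_len=40)
```

Positions here are relative to the printed window, not to the full sample.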

Generate a sample¶

In [33]:
import textwrap
wrapper = textwrap.TextWrapper(width=80)
In [34]:
import re  # import regular expressions module
In [35]:
import gc  # import Python's garbage collection module

def generate(index, data_split="test"):
    # Look up the sample by index in the chosen split
    functionList = data[data_split][index]['functionList']
    user_prompt = data[data_split][index]['userPrompt']
    correct_answer = data[data_split][index]['assistantResponse']

    # model.config.use_cache = True    # Unsure this is needed

    # Format your prompt template
    prompt = f"{B_INST}{B_FUNC}{functionList.strip()}\
    {E_FUNC}{user_prompt.strip()}{E_INST}\n\n"

    print(f"Using the {data_split} data split.\n\nPrompt:")
    print(prompt)

    inputs = tokenizer([prompt], return_tensors="pt").to("cuda")

    # Remove token_type_ids if the tokenizer returns them; the model doesn't use them
    if "token_type_ids" in inputs:
        del inputs["token_type_ids"]

    # print(f'model is on: {next(model.parameters()).device}')  # Debug
    # print(f'input_ids is on: {inputs["input_ids"].device}')  # Debug

    output = model.generate(**inputs,
                            max_new_tokens=200,
                            # do_sample=False,
                            pad_token_id=tokenizer.pad_token_id,
                            eos_token_id=tokenizer.eos_token_id,
                            # temperature=0.01,
                            # top_k=0
                           )

    print()

    # Skip the prompt tokens so only the newly generated response is decoded
    prompt_length = inputs["input_ids"].shape[1]
    output_text = tokenizer.decode(output[0, prompt_length:],
                                   skip_special_tokens=False)
    output_text = re.sub(r'\n+', '\n', output_text)  # collapse repeated newlines

    print("**Generated Assistant Response:**")
    print(output_text)
    print()

    print("**Correct Assistant Response:**")
    print(correct_answer)
    print()

    # Free GPU memory between samples
    torch.cuda.empty_cache()  # Clear GPU cache
    gc.collect()  # Run garbage collection
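A note on the prompt template in generate(): a backslash at the end of a line inside a string literal continues the string, but the next line's leading indentation becomes part of it. That is why the prompts printed below have four trailing spaces after the closing ] of the function list. If that whitespace isn't in your training prompts, joining adjacent string literals inside parentheses avoids it. A quick check, with hypothetical placeholder values standing in for functionList and E_FUNC (the notebook defines the real E_FUNC earlier):

```python
# Placeholder values (hypothetical); the notebook defines the real ones earlier
function_list = "[...]"
end_func = "\n\n"

# Backslash continuation inside the literal: the next line's 4 spaces are kept
with_backslash = f"{function_list}\
    {end_func}"

# Adjacent f-strings inside parentheses are joined with no extra whitespace
with_parens = (
    f"{function_list}"
    f"{end_func}"
)

print(repr(with_backslash))  # '[...]    \n\n'
print(repr(with_parens))     # '[...]\n\n'
```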

Run validation before fine-tuning¶

Before fine-tuning, let's look at how the base model responds to the prompts in the test split.

Ideally the model would reply with just a function name and its arguments. Instead, it tries to write the code itself and surrounds it with lots of extra explanation. Responses that stop mid-sentence ran into the max_new_tokens=200 limit before emitting </s>.
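"A function name and its arguments" can be made concrete with a mechanical check: the reply should parse as a JSON object whose name is one of the offered functions and whose arguments field is an object, as in the correct responses below. A rough sketch of such a check; is_valid_function_call is a hypothetical helper, and it assumes the {"name": ..., "arguments": {...}} shape used by this data set.

```python
import json


def is_valid_function_call(response, function_list_json):
    """Return True if `response` looks like a call to one of the listed functions.

    Hypothetical helper. Assumes function_list_json is a JSON array of
    {"type": "function", "function": {"name": ...}} entries, as in the prompts here.
    """
    # Drop a trailing end-of-sequence marker kept by skip_special_tokens=False
    text = response.strip()
    if text.endswith("</s>"):
        text = text[: -len("</s>")].strip()

    try:
        call = json.loads(text)
    except json.JSONDecodeError:
        return False  # prose or code instead of a JSON function call

    if not isinstance(call, dict) or not isinstance(call.get("arguments"), dict):
        return False

    allowed = {entry["function"]["name"] for entry in json.loads(function_list_json)}
    return call.get("name") in allowed


functions = (
    '[{"type": "function", "function": {"name": "get_big_stocks"}},'
    ' {"type": "function", "function": {"name": "get_stock_price"}}]'
)
is_valid_function_call('{"name": "get_big_stocks", "arguments": {"number": "5"}}', functions)  # True
is_valid_function_call("To get the names of the five largest stocks...", functions)  # False
```

For the conversational prompts (Greetings!, shop, and the planets question) the expected answer is plain text, so this check only applies to the function-calling samples.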

In [41]:
# Run validation before training
for index in range(len(test_dataset)):
    print(f'---Running index {index}---')
    generate(index, "test")
---Running index 0---
Using the test data split.

Prompt:
[INST] You have access to the following functions. Use them if required:

[
    {
        "type": "function",
        "function": {
            "name": "get_stock_price",
            "description": "Get the stock price of an array of stocks",
            "parameters": {
                "type": "object",
                "properties": {
                    "names": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        },
                        "description": "An array of stocks"
                    }
                },
                "required": [
                    "names"
                ]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "get_big_stocks",
            "description": "Get the names of the largest N stocks by market cap",
            "parameters": {
                "type": "object",
                "properties": {
                    "number": {
                        "type": "integer",
                        "description": "The number of largest stocks to get the names of, e.g. 25"
                    },
                    "region": {
                        "type": "string",
                        "description": "The region to consider, can be \"US\" or \"World\"."
                    }
                },
                "required": [
                    "number"
                ]
            }
        }
    }
]    

Get the names of the five largest stocks by market cap [/INST]



**Generated Assistant Response:**
To get the names of the five largest stocks by market cap, we can use the `get_big_stocks` function with the following parameters:
* `number`: 5 (to get the names of the five largest stocks)
* `region`: "World" (to consider all stocks in the world)
Here's the code to do that:
```python
import requests
def get_big_stocks(number, region):
    url = f"https://api.alphavantage.co/query?function=TIME_SERIES_DAILY&symbol={region}&apikey=YOUR_API_KEY"
    response = requests.get(url)
    data = response.json()
    market_cap = {}
    for symbol in data["Time Series (Daily)"].keys():
        market_cap[symbol] = data["Time Series (

**Correct Assistant Response:**
{
    "name": "get_big_stocks",
    "arguments": {
        "number": "5"
    }
}

---Running index 1---
Using the test data split.

Prompt:
[INST] You have access to the following functions. Use them if required:

[
    {
        "type": "function",
        "function": {
            "name": "get_big_stocks",
            "description": "Get the names of the largest N stocks by market cap",
            "parameters": {
                "type": "object",
                "properties": {
                    "number": {
                        "type": "integer",
                        "description": "The number of largest stocks to get the names of, e.g. 25"
                    },
                    "region": {
                        "type": "string",
                        "description": "The region to consider, can be \"US\" or \"World\"."
                    }
                },
                "required": [
                    "number"
                ]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "get_stock_price",
            "description": "Get the stock price of an array of stocks",
            "parameters": {
                "type": "object",
                "properties": {
                    "names": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        },
                        "description": "An array of stocks"
                    }
                },
                "required": [
                    "names"
                ]
            }
        }
    }
]    

Get the names of the five largest stocks in the US by market cap [/INST]



**Generated Assistant Response:**
Here's the code to get the names of the five largest stocks in the US by market cap:
```python
import requests
def get_big_stocks(params):
    url = "https://api.alphavantage.co/query?function=TIME_SERIES_DAILY_ADJUSTED&symbol={}&apikey=YOUR_API_KEY".format(params["names"][0])
    response = requests.get(url)
    data = response.json()
    return data["Time Series (Daily)"]
def get_stock_price(params):
    url = "https://api.alphavantage.co/query?function=GLOBAL_QUOTE&symbol={}&apikey=YOUR_API_KEY".format(params["names"][0])
    response = requests.get(url)
    data = response

**Correct Assistant Response:**
{
    "name": "get_big_stocks",
    "arguments": {
        "number": "5",
        "region": "US"
    }
}

---Running index 2---
Using the test data split.

Prompt:
[INST] You have access to the following functions. Use them if required:

[
    {
        "type": "function",
        "function": {
            "name": "get_big_stocks",
            "description": "Get the names of the largest N stocks by market cap",
            "parameters": {
                "type": "object",
                "properties": {
                    "number": {
                        "type": "integer",
                        "description": "The number of largest stocks to get the names of, e.g. 25"
                    },
                    "region": {
                        "type": "string",
                        "description": "The region to consider, can be \"US\" or \"World\"."
                    }
                },
                "required": [
                    "number"
                ]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "get_stock_price",
            "description": "Get the stock price of an array of stocks",
            "parameters": {
                "type": "object",
                "properties": {
                    "names": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        },
                        "description": "An array of stocks"
                    }
                },
                "required": [
                    "names"
                ]
            }
        }
    }
]    

Get the stock prices of the ten largest stocks in the world [/INST]



**Generated Assistant Response:**
To get the stock prices of the ten largest stocks in the world, you can use the following steps:
1. Get the names of the ten largest stocks in the world using the `get_big_stocks` function.
2. Use the `get_stock_price` function to get the stock prices of the names obtained in step 1.
Here's the code to implement this:
```python
import requests
def get_big_stocks(params):
    url = "https://api.alphavantage.co/query?function=TIME_SERIES_DAILY_ADJUSTED&symbol={}&apikey=YOUR_API_KEY".format(params["names"][0])
    response = requests.get(url)
    data = response.json()
    return data["Time Series (Daily)"]
def get_stock_price(

**Correct Assistant Response:**
{
    "name": "get_big_stocks",
    "arguments": {
        "number": "10",
        "region": "World"
    }
}

---Running index 3---
Using the test data split.

Prompt:
[INST] You have access to the following functions. Use them if required:

[
    {
        "type": "function",
        "function": {
            "name": "get_stock_price",
            "description": "Get the stock price of an array of stocks",
            "parameters": {
                "type": "object",
                "properties": {
                    "names": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        },
                        "description": "An array of stocks"
                    }
                },
                "required": [
                    "names"
                ]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "get_big_stocks",
            "description": "Get the names of the largest N stocks by market cap",
            "parameters": {
                "type": "object",
                "properties": {
                    "number": {
                        "type": "integer",
                        "description": "The number of largest stocks to get the names of, e.g. 25"
                    },
                    "region": {
                        "type": "string",
                        "description": "The region to consider, can be \"US\" or \"World\"."
                    }
                },
                "required": [
                    "number"
                ]
            }
        }
    }
]    

Get the price of Apple's stock [/INST]



**Generated Assistant Response:**
To get the price of Apple's stock, you can use the `get_stock_price` function. Here's an example of how you can use it:
```
const stockNames = ["AAPL"];
const stockPrices = get_stock_price({ names: stockNames });
console.log(stockPrices);
```
This will output the current price of Apple's stock. Note that you will need to have an API key to use this function.</s>

**Correct Assistant Response:**
{
    "name": "get_stock_price",
    "arguments": {
        "names": [
            "Apple"
        ]
    }
}

---Running index 4---
Using the test data split.

Prompt:
[INST] You have access to the following functions. Use them if required:

[
    {
        "type": "function",
        "function": {
            "name": "get_stock_price",
            "description": "Get the stock price of an array of stocks",
            "parameters": {
                "type": "object",
                "properties": {
                    "names": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        },
                        "description": "An array of stocks"
                    }
                },
                "required": [
                    "names"
                ]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "get_big_stocks",
            "description": "Get the names of the largest N stocks by market cap",
            "parameters": {
                "type": "object",
                "properties": {
                    "number": {
                        "type": "integer",
                        "description": "The number of largest stocks to get the names of, e.g. 25"
                    },
                    "region": {
                        "type": "string",
                        "description": "The region to consider, can be \"US\" or \"World\"."
                    }
                },
                "required": [
                    "number"
                ]
            }
        }
    }
]    

Greetings! [/INST]



**Generated Assistant Response:**
Hello! How can I assist you today?</s>

**Correct Assistant Response:**
Greetings to you too!

---Running index 5---
Using the test data split.

Prompt:
[INST] You have access to the following functions. Use them if required:

[
    {
        "type": "function",
        "function": {
            "name": "get_big_stocks",
            "description": "Get the names of the largest N stocks by market cap",
            "parameters": {
                "type": "object",
                "properties": {
                    "number": {
                        "type": "integer",
                        "description": "The number of largest stocks to get the names of, e.g. 25"
                    },
                    "region": {
                        "type": "string",
                        "description": "The region to consider, can be \"US\" or \"World\"."
                    }
                },
                "required": [
                    "number"
                ]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "get_stock_price",
            "description": "Get the stock price of an array of stocks",
            "parameters": {
                "type": "object",
                "properties": {
                    "names": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        },
                        "description": "An array of stocks"
                    }
                },
                "required": [
                    "names"
                ]
            }
        }
    }
]    

shop [/INST]



**Generated Assistant Response:**
To get the names of the largest 25 stocks by market cap in the US region, you can use the following code:
```
const bigStocks = await get_big_stocks({
    number: 25,
    region: "US"
});
console.log(bigStocks);
```
To get the stock price of an array of stocks, you can use the following code:
```
const stockPrices = await get_stock_price({
    names: ["AAPL", "GOOG", "AMZN"]
});
console.log(stockPrices);
```
Note that the `get_stock_price` function requires an array of stock names as input, and returns an object with the stock prices as keys and their corresponding values.</s>

**Correct Assistant Response:**
What would you like to know about shops?

---Running index 6---
Using the test data split.

Prompt:
[INST] You have access to the following functions. Use them if required:

[
    {
        "type": "function",
        "function": {
            "name": "get_stock_price",
            "description": "Get the stock price of an array of stocks",
            "parameters": {
                "type": "object",
                "properties": {
                    "names": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        },
                        "description": "An array of stocks"
                    }
                },
                "required": [
                    "names"
                ]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "get_big_stocks",
            "description": "Get the names of the largest N stocks by market cap",
            "parameters": {
                "type": "object",
                "properties": {
                    "number": {
                        "type": "integer",
                        "description": "The number of largest stocks to get the names of, e.g. 25"
                    },
                    "region": {
                        "type": "string",
                        "description": "The region to consider, can be \"US\" or \"World\"."
                    }
                },
                "required": [
                    "number"
                ]
            }
        }
    }
]    

What are the planets in our solar system? [/INST]



**Generated Assistant Response:**
I'm sorry, but the provided functions do not seem to be related to the question about the planets in our solar system. Can you please provide more context or clarify your question?</s>

**Correct Assistant Response:**
The planets are Mercury, Venus, Earth, Mars, Jupiter, Saturn, Uranus and Neptune.

Training¶

In [37]:
import torch.nn as nn
from torch.utils.data import DataLoader  # used by the custom dataloader methods below
In [38]:
class CustomTrainer(transformers.Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments (such as num_items_in_batch) that
        # newer transformers releases pass to compute_loss

        # Number of tokens to display at the end of each sequence when the
        # token-analysis printout below is enabled
        num_tokens = 25

        labels = inputs.pop("labels")

        # # Get the first 200 label IDs for each sequence in the batch
        # first_label_ids = labels[:, :200]

        # # Convert to tokens
        # first_tokens = [tokenizer.convert_ids_to_tokens(label_ids) \
        # for label_ids in first_label_ids]

        # # Print them
        # for batch_idx, tokens in enumerate(first_tokens):
        #     print(f"First 200 decoded tokens for sequence {batch_idx + 1}: {tokens}")

        # Pop loss_mask so it isn't passed to the model's forward()
        loss_mask = inputs.pop("loss_mask")

        # Forward pass
        outputs = model(**inputs)
        logits = outputs.logits

        # Check for NaN in the logits
        if torch.isnan(logits).any():
            print("NaN detected in logits")
            print(logits)

        # Most probable token at each position (used only by the debug printout);
        # argmax over logits equals argmax over softmax probabilities
        predicted_token_ids = torch.argmax(logits, dim=-1)

        # Per-token cross-entropy, unreduced so the mask can be applied.
        # Logits at position t are scored against labels at position t (no shift).
        loss_fct = nn.CrossEntropyLoss(reduction='none')
        losses = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        # Reshape the losses to [batch_size, seq_length]
        losses = losses.view(-1, inputs['input_ids'].size(1))

        # Apply the loss mask: prompt and padding positions contribute zero
        masked_loss = losses * loss_mask

        # Check for NaN in the losses and for an all-zero loss mask
        if torch.isnan(losses).any():
            print("NaN detected in losses")
            # print(losses)

        if loss_mask.sum() == 0:
            print("Sum of loss_mask is zero")
            # Zero loss that stays attached to the graph, so backward() still works
            zero_loss = logits.sum() * 0.0
            return (zero_loss, outputs) if return_outputs else zero_loss  # Early return

        # Average the masked losses over the number of unmasked tokens
        # (epsilon prevents division by zero)
        loss = masked_loss.sum() / (loss_mask.sum() + 1e-9)

        # Shapes used by the token-analysis printout
        batch_size, seq_length = inputs['input_ids'].size()

        # num_tokens = len(inputs['input_ids'][0])

        # # Useful for debugging training
        # # Recommend training a small number of steps
        # print("-" * 120)
        # print(f"Token analysis for last {num_tokens} tokens:")
        # header_format = "{:<10}{:<20}{:<20}{:<20}{:<20}{:<30}{:<30}".format("Index", "Input Token", "Predicted Token", "True Token", "Loss Mask", "Raw Loss", "Masked Loss")

        # for batch_idx in range(batch_size):
        #     input_tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][batch_idx])  # Using batch_idx
        #     predicted_tokens = tokenizer.convert_ids_to_tokens(predicted_token_ids[batch_idx])  # Using batch_idx
        #     true_tokens = tokenizer.convert_ids_to_tokens(labels[batch_idx])  # Using batch_idx

        #     print(f"\nBatch {batch_idx + 1} of {batch_size}:")
        #     print(header_format)
        #     for i in range(-num_tokens, 0, 1):
        #         index = seq_length + i  # Correct index based on sequence length
        #         print("{:<10}{:<20}{:<20}{:<20}{:<20.1f}{:<30.6f}{:<30.6f}".format(index, input_tokens[index], predicted_tokens[index], true_tokens[index], loss_mask[batch_idx, i].item(), losses[batch_idx, i], masked_loss[batch_idx, i]))
        #     print("-" * 120)

        return (loss, outputs) if return_outputs else loss

    def get_train_dataloader(self):
        # Build the DataLoader directly. This skips the Trainer's removal of
        # unused columns (applied by default to datasets.Dataset inputs), so
        # loss_mask reaches the collator.
        train_dataset = self.train_dataset
        data_collator = self.data_collator

        dataloader_params = {
            "batch_size": self.args.train_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
        }

        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_train_sampler()
            dataloader_params["drop_last"] = self.args.dataloader_drop_last

        return DataLoader(train_dataset, **dataloader_params)

    def get_eval_dataloader(self, eval_dataset=None):
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        if eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")

        data_collator = self.data_collator

        # Parameters for the DataLoader
        dataloader_params = {
            "batch_size": self.args.eval_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
        }

        # Samplers and drop_last only apply to map-style datasets,
        # not to torch IterableDatasets
        if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
            # Typically we don't drop the last batch for evaluation
            dataloader_params["drop_last"] = False

        return DataLoader(eval_dataset, **dataloader_params)
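Without the debugging code, compute_loss boils down to three steps: an unreduced per-token cross-entropy, multiplication by the loss mask, and division by the number of unmasked tokens. The same arithmetic on a toy example in plain Python follows (masked_mean_cross_entropy is an illustrative name, not part of the notebook). Like the trainer, it scores the logits at position t against the label at position t, so it assumes the labels were already aligned with the next-token targets during data preparation.

```python
import math


def cross_entropy(logits, target):
    """Negative log of the softmax probability of `target` at one position."""
    # Subtract the max logit before exponentiating, for numerical stability
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def masked_mean_cross_entropy(logits_seq, labels, loss_mask):
    """Average the per-token loss over positions where loss_mask == 1."""
    losses = [cross_entropy(row, y) for row, y in zip(logits_seq, labels)]
    masked = [loss * m for loss, m in zip(losses, loss_mask)]
    # Same normalization as the trainer: number of unmasked tokens + epsilon
    return sum(masked) / (sum(loss_mask) + 1e-9)


# Toy sequence: 4 positions, vocabulary of 3 tokens
logits_seq = [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0], [1.0, 1.0, 1.0]]
labels = [0, 1, 2, 0]
loss_mask = [0, 0, 1, 1]  # only the last two positions count as "answer" tokens

loss = masked_mean_cross_entropy(logits_seq, labels, loss_mask)
```

Because masked positions are multiplied by zero, the prompt tokens contribute nothing to the loss or its gradient, while the model still attends to them through the attention mask.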
In [39]:
class CustomDataCollator: # Needed if the EOS token is included in training
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

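    def __call__(self, batch):
        # torch.stack needs every sample to have the same length, so this relies
        # on the samples being padded to a fixed length during tokenization.
        # loss_mask is passed through unchanged for CustomTrainer.compute_loss.
        input_ids = torch.stack([item['input_ids'] for item in batch])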
        attention_mask = torch.stack([item['attention_mask'] for item in batch])
        labels = torch.stack([item['labels'] for item in batch])
        loss_mask = torch.stack([item['loss_mask'] for item in batch])

        # # Debugging: print details of the first sequence in the batch
        # num_elements_to_view = 20  # Number of last elements to view

        # # Decoding the input_ids
        # decoded_input_tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

        # print("Debugging last", num_elements_to_view, "elements of the first sequence in the batch:")
        # print("{:<20}{:<20}{:<20}{:<20}".format("Token", "Input ID", "Label", "Loss Mask"))
        # for i in range(-num_elements_to_view, 0, 1):
        #   print("{:<20}{:<20}{:<20}{:<20}".format(decoded_input_tokens[i], input_ids[0, i].item(), labels[0, i].item(), loss_mask[0, i].item()))

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
            'loss_mask': loss_mask
        }

data_collator = CustomDataCollator(tokenizer)
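The comment on CustomDataCollator refers to how the stock collator builds labels. With mlm=False, transformers.DataCollatorForLanguageModeling copies input_ids into labels and sets every position equal to pad_token_id to -100, the index that PyTorch's CrossEntropyLoss ignores by default. If the pad token is the same as the EOS token (a common workaround when a tokenizer has no pad token), the </s> at the end of each answer is ignored as well, and the model is never trained to stop. A simplified plain-Python imitation of that rule (lm_labels is a hypothetical name; this is not the library code):

```python
IGNORE_INDEX = -100  # label value skipped by CrossEntropyLoss by default


def lm_labels(input_ids, pad_token_id):
    """Mimic the causal-LM labeling rule: copy ids, ignore pad positions."""
    return [IGNORE_INDEX if tok == pad_token_id else tok for tok in input_ids]


EOS_ID = 2  # id of </s> in the token dump above
UNK_ID = 0  # id of <unk>, used as padding in the token dump above

answer = [349, 521, 6206, EOS_ID]  # "▁is", "▁un", "clear", "</s>"

# Pad token distinct from EOS: the real </s> keeps its label
print(lm_labels(answer + [UNK_ID, UNK_ID], pad_token_id=UNK_ID))
# → [349, 521, 6206, 2, -100, -100]

# Pad token == EOS: the real </s> is ignored along with the padding
print(lm_labels(answer + [EOS_ID, EOS_ID], pad_token_id=EOS_ID))
# → [349, 521, 6206, -100, -100, -100]
```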
In [40]:
trainer = CustomTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    args=transformers.TrainingArguments(
        # max_steps=1,
        num_train_epochs=1, # Larger models typically only need 1 epoch
        per_device_train_batch_size=1,
        gradient_accumulation_steps=1,
        evaluation_strategy="steps", # Renamed eval_strategy in newer transformers releases
        max_grad_norm=1,
        warmup_ratio=0.1,
        eval_steps=0.2,
        learning_rate=1e-4, # 1e-4 for LoRA
        # learning_rate=1e-5, # 1e-5 for full fine-tuning
        # fp16=True, # If your GPU doesn't support bf16 (pre-Ampere, e.g. T4 or V100)
        bf16=True,
        logging_steps=1,
        output_dir="outputs",
        # optim="paged_adamw_8bit", # For training in 4bit (quantized)
        optim="adamw_torch", # For training in full fp16/bf16 precision
        lr_scheduler_type='constant', # Ignores warmup_ratio; use 'constant_with_warmup' to apply warmup
        hub_private_repo=True
    ),
    data_collator=data_collator,
    # data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
model.config.use_cache = False  # Silence warnings (Set to True for inference!)
In [42]:
trainer.train()
torch.cuda.empty_cache()
Tracking run with wandb version 0.16.2
Run data is saved locally in /content/wandb/run-20240206_060238-j0g979ti
Syncing run still-capybara-2 to Weights & Biases (docs)
View project at https://wandb.ai/gadkins/huggingface
View run at https://wandb.ai/gadkins/huggingface/runs/j0g979ti
/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.
  warnings.warn(
[66/66 00:29, Epoch 1/1]
Step  Training Loss  Validation Loss
14    0.109000       0.818482
28    0.030600       0.809609
42    0.001300       0.767155
56    0.151700       0.759370
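The evaluation rows follow from eval_steps=0.2: when eval_steps is below 1, the Trainer treats it as a fraction of the total number of training steps and rounds up. With one epoch, a batch size of 1, and no gradient accumulation, each step processes one training example (on a single GPU), so the 66 steps on the progress bar give an evaluation every 14 steps:

```python
import math

total_steps = 66  # from the progress bar: [66/66, Epoch 1/1]
eval_ratio = 0.2  # eval_steps=0.2 in TrainingArguments

eval_every = math.ceil(total_steps * eval_ratio)  # ceil(13.2) = 14
eval_at = list(range(eval_every, total_steps + 1, eval_every))
print(eval_at)  # [14, 28, 42, 56]
```

Step 66 is not a multiple of 14, which is why there is no evaluation row at the very end of training.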

Example after fine-tuning¶

In [43]:
model.config.use_cache = True
model.eval()
Out[43]:
PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): MistralForCausalLM(
      (model): MistralModel(
        (embed_tokens): Embedding(32000, 4096)
        (layers): ModuleList(
          (0-31): 32 x MistralDecoderLayer(
            (self_attn): MistralAttention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k_proj): lora.Linear(
                (base_layer): Linear(in_features=4096, out_features=1024, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=1024, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (v_proj): lora.Linear(
                (base_layer): Linear(in_features=4096, out_features=1024, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=1024, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (o_proj): lora.Linear(
                (base_layer): Linear(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (rotary_emb): MistralRotaryEmbedding()
            )
            (mlp): MistralMLP(
              (gate_proj): lora.Linear(
                (base_layer): Linear(in_features=4096, out_features=14336, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=14336, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (up_proj): lora.Linear(
                (base_layer): Linear(in_features=4096, out_features=14336, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=14336, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (down_proj): lora.Linear(
                (base_layer): Linear(in_features=14336, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=14336, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (act_fn): SiLU()
            )
            (input_layernorm): MistralRMSNorm()
            (post_attention_layernorm): MistralRMSNorm()
          )
        )
        (norm): MistralRMSNorm()
      )
      (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
    )
  )
)
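The printout above shows rank-8 LoRA adapters on all seven projection layers in each of the 32 decoder layers: `q_proj`, `k_proj`, `v_proj`, `o_proj`, `gate_proj`, `up_proj` and `down_proj`. This quick check counts the adapter weights from those shapes to show how small the trainable part is. It makes the following assumptions:

- r = 8, which is the `out_features` of every `lora_A`.
- Only the `lora_A`/`lora_B` matrices are counted.
- `TARGET_SHAPES` and `lora_params` are illustrative names, not part of PEFT.

```python
# (in_features, out_features) of each LoRA-wrapped base layer, read off the printout above.
TARGET_SHAPES = {
    "q_proj": (4096, 4096),
    "k_proj": (4096, 1024),
    "v_proj": (4096, 1024),
    "o_proj": (4096, 4096),
    "gate_proj": (4096, 14336),
    "up_proj": (4096, 14336),
    "down_proj": (14336, 4096),
}


def lora_params(shapes, r=8, num_layers=32):
    """Count LoRA adapter weights across all decoder layers.

    lora_A maps in_features -> r and lora_B maps r -> out_features (both bias-free),
    so each adapted layer adds r * (in_features + out_features) weights.
    """
    per_layer = sum(r * (fin + fout) for fin, fout in shapes.values())
    return per_layer * num_layers


print(lora_params(TARGET_SHAPES))  # 20971520, i.e. about 21M trainable weights
```

If you want to cross-check, `PeftModel.print_trainable_parameters()` reports the trainable count directly.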
In [44]:
# Run validation
for index in range(len(test_dataset)):
    print(f'---Running index {index}---')
    generate(index, "test")
---Running index 0---
Using the test data split.

Prompt:
[INST] You have access to the following functions. Use them if required:

[
    {
        "type": "function",
        "function": {
            "name": "get_stock_price",
            "description": "Get the stock price of an array of stocks",
            "parameters": {
                "type": "object",
                "properties": {
                    "names": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        },
                        "description": "An array of stocks"
                    }
                },
                "required": [
                    "names"
                ]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "get_big_stocks",
            "description": "Get the names of the largest N stocks by market cap",
            "parameters": {
                "type": "object",
                "properties": {
                    "number": {
                        "type": "integer",
                        "description": "The number of largest stocks to get the names of, e.g. 25"
                    },
                    "region": {
                        "type": "string",
                        "description": "The region to consider, can be \"US\" or \"World\"."
                    }
                },
                "required": [
                    "number"
                ]
            }
        }
    }
]    

Get the names of the five largest stocks by market cap [/INST]



**Generated Assistant Response:**
{
    "name": "get_big_stocks",
    "arguments": {
        "number": 5
    }
}</s>

**Correct Assistant Response:**
{
    "name": "get_big_stocks",
    "arguments": {
        "number": "5"
    }
}

---Running index 1---
Using the test data split.

Prompt:
[INST] You have access to the following functions. Use them if required:

[
    {
        "type": "function",
        "function": {
            "name": "get_big_stocks",
            "description": "Get the names of the largest N stocks by market cap",
            "parameters": {
                "type": "object",
                "properties": {
                    "number": {
                        "type": "integer",
                        "description": "The number of largest stocks to get the names of, e.g. 25"
                    },
                    "region": {
                        "type": "string",
                        "description": "The region to consider, can be \"US\" or \"World\"."
                    }
                },
                "required": [
                    "number"
                ]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "get_stock_price",
            "description": "Get the stock price of an array of stocks",
            "parameters": {
                "type": "object",
                "properties": {
                    "names": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        },
                        "description": "An array of stocks"
                    }
                },
                "required": [
                    "names"
                ]
            }
        }
    }
]    

Get the names of the five largest stocks in the US by market cap [/INST]



**Generated Assistant Response:**
{
    "name": "get_big_stocks",
    "arguments": {
        "number": 5
    }
}</s>

**Correct Assistant Response:**
{
    "name": "get_big_stocks",
    "arguments": {
        "number": "5",
        "region": "US"
    }
}

---Running index 2---
Using the test data split.

Prompt:
[INST] You have access to the following functions. Use them if required:

[
    {
        "type": "function",
        "function": {
            "name": "get_big_stocks",
            "description": "Get the names of the largest N stocks by market cap",
            "parameters": {
                "type": "object",
                "properties": {
                    "number": {
                        "type": "integer",
                        "description": "The number of largest stocks to get the names of, e.g. 25"
                    },
                    "region": {
                        "type": "string",
                        "description": "The region to consider, can be \"US\" or \"World\"."
                    }
                },
                "required": [
                    "number"
                ]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "get_stock_price",
            "description": "Get the stock price of an array of stocks",
            "parameters": {
                "type": "object",
                "properties": {
                    "names": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        },
                        "description": "An array of stocks"
                    }
                },
                "required": [
                    "names"
                ]
            }
        }
    }
]    

Get the stock prices of the ten largest stocks in the world [/INST]



**Generated Assistant Response:**
{
    "name": "get_big_stocks",
    "arguments": {
        "number": 10
    }
}</s>

**Correct Assistant Response:**
{
    "name": "get_big_stocks",
    "arguments": {
        "number": "10",
        "region": "World"
    }
}

---Running index 3---
Using the test data split.

Prompt:
[INST] You have access to the following functions. Use them if required:

[
    {
        "type": "function",
        "function": {
            "name": "get_stock_price",
            "description": "Get the stock price of an array of stocks",
            "parameters": {
                "type": "object",
                "properties": {
                    "names": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        },
                        "description": "An array of stocks"
                    }
                },
                "required": [
                    "names"
                ]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "get_big_stocks",
            "description": "Get the names of the largest N stocks by market cap",
            "parameters": {
                "type": "object",
                "properties": {
                    "number": {
                        "type": "integer",
                        "description": "The number of largest stocks to get the names of, e.g. 25"
                    },
                    "region": {
                        "type": "string",
                        "description": "The region to consider, can be \"US\" or \"World\"."
                    }
                },
                "required": [
                    "number"
                ]
            }
        }
    }
]    

Get the price of Apple's stock [/INST]



**Generated Assistant Response:**
{
    "name": "get_stock_price",
    "arguments": {
        "names": [
            "AAPL"
        ]
    }
}</s>

**Correct Assistant Response:**
{
    "name": "get_stock_price",
    "arguments": {
        "names": [
            "Apple"
        ]
    }
}

---Running index 4---
Using the test data split.

Prompt:
[INST] You have access to the following functions. Use them if required:

[
    {
        "type": "function",
        "function": {
            "name": "get_stock_price",
            "description": "Get the stock price of an array of stocks",
            "parameters": {
                "type": "object",
                "properties": {
                    "names": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        },
                        "description": "An array of stocks"
                    }
                },
                "required": [
                    "names"
                ]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "get_big_stocks",
            "description": "Get the names of the largest N stocks by market cap",
            "parameters": {
                "type": "object",
                "properties": {
                    "number": {
                        "type": "integer",
                        "description": "The number of largest stocks to get the names of, e.g. 25"
                    },
                    "region": {
                        "type": "string",
                        "description": "The region to consider, can be \"US\" or \"World\"."
                    }
                },
                "required": [
                    "number"
                ]
            }
        }
    }
]    

Greetings! [/INST]



**Generated Assistant Response:**
Hello! How can I assist you today?</s>

**Correct Assistant Response:**
Greetings to you too!

---Running index 5---
Using the test data split.

Prompt:
[INST] You have access to the following functions. Use them if required:

[
    {
        "type": "function",
        "function": {
            "name": "get_big_stocks",
            "description": "Get the names of the largest N stocks by market cap",
            "parameters": {
                "type": "object",
                "properties": {
                    "number": {
                        "type": "integer",
                        "description": "The number of largest stocks to get the names of, e.g. 25"
                    },
                    "region": {
                        "type": "string",
                        "description": "The region to consider, can be \"US\" or \"World\"."
                    }
                },
                "required": [
                    "number"
                ]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "get_stock_price",
            "description": "Get the stock price of an array of stocks",
            "parameters": {
                "type": "object",
                "properties": {
                    "names": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        },
                        "description": "An array of stocks"
                    }
                },
                "required": [
                    "names"
                ]
            }
        }
    }
]    

shop [/INST]



**Generated Assistant Response:**
{
    "name": "get_big_stocks",
    "arguments": {
        "number": 25
    }
}</s>

**Correct Assistant Response:**
What would you like to know about shops?

---Running index 6---
Using the test data split.

Prompt:
[INST] You have access to the following functions. Use them if required:

[
    {
        "type": "function",
        "function": {
            "name": "get_stock_price",
            "description": "Get the stock price of an array of stocks",
            "parameters": {
                "type": "object",
                "properties": {
                    "names": {
                        "type": "array",
                        "items": {
                            "type": "string"
                        },
                        "description": "An array of stocks"
                    }
                },
                "required": [
                    "names"
                ]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "get_big_stocks",
            "description": "Get the names of the largest N stocks by market cap",
            "parameters": {
                "type": "object",
                "properties": {
                    "number": {
                        "type": "integer",
                        "description": "The number of largest stocks to get the names of, e.g. 25"
                    },
                    "region": {
                        "type": "string",
                        "description": "The region to consider, can be \"US\" or \"World\"."
                    }
                },
                "required": [
                    "number"
                ]
            }
        }
    }
]    

What are the planets in our solar system? [/INST]



**Generated Assistant Response:**
The planets in our solar system are: Mercury, Venus, Earth, Mars, Jupiter, Saturn, Uranus, and Neptune</s>

**Correct Assistant Response:**
The planets are Mercury, Venus, Earth, Mars, Jupiter, Saturn, Uranus and Neptune.
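The differences in the runs above fall into a few buckets:

- A correct call whose argument type differs from the reference (`5` vs `"5"`).
- A missing optional argument (`region`).
- A function call where a text reply was expected (`shop`).
- Free-text replies that are simply worded differently.

One rough way to tally these is to parse both sides as JSON and compare names and arguments. Scalars are compared as strings, so `5` and `"5"` count as equal. This is a minimal sketch: `parse_call` and `grade` are hypothetical helpers. The notebook's `generate()` prints its text rather than returning it, so the sketch assumes you collect the generated and reference strings yourself.

```python
import json


def parse_call(text):
    """Return the response as a dict if it is a JSON object, else None (plain text)."""
    body = text.strip().removesuffix("</s>").strip()
    try:
        obj = json.loads(body)
    except json.JSONDecodeError:
        return None
    return obj if isinstance(obj, dict) else None


def _norm(value):
    # Compare scalars as strings so that 5 and "5" are treated as the same value.
    if isinstance(value, list):
        return [_norm(v) for v in value]
    if isinstance(value, dict):
        return {k: _norm(v) for k, v in value.items()}
    return str(value)


def grade(generated, expected):
    gen, exp = parse_call(generated), parse_call(expected)
    if gen is None and exp is None:
        return "text"            # both are plain-text replies; judge these manually
    if gen is None or exp is None:
        return "mismatch_type"   # one side is a function call, the other is text
    if gen.get("name") != exp.get("name"):
        return "wrong_function"
    if _norm(gen.get("arguments", {})) != _norm(exp.get("arguments", {})):
        return "wrong_arguments"
    return "match"


print(grade('{"name": "get_big_stocks", "arguments": {"number": 5}}</s>',
            '{"name": "get_big_stocks", "arguments": {"number": "5"}}'))  # match
```

Treating `"text"` as its own bucket is deliberate. Exact string matching would mark "Hello! How can I assist you today?" as wrong, even though it is a reasonable reply to "Greetings!".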

Merge Adapters and Save Model to Hub¶

In [53]:
# Extract the model name (the part after the last "/") from base_model
base_model_name = base_model.split("/")[-1]

# Replace "gadkins" with your own Hugging Face username
adapter_model = f"gadkins/{base_model_name}-function-calling-adapters"
new_model = f"gadkins/{base_model_name}-function-calling"

print(f"Adapter Model: {adapter_model}\nNew Model: {new_model}")
Adapter Model: gadkins/Mistral-7B-Instruct-v0.1-function-calling-adapters
New Model: gadkins/Mistral-7B-Instruct-v0.1-function-calling
In [56]:
# (Optional) Create the model repo plus a branch for GGUF files
# (uncomment the lines at the end to also create AWQ or GPTQ branches)

from huggingface_hub import create_branch, create_repo

# exist_ok=True makes the cell safe to re-run if the repo or branch already exists
create_repo(new_model, private=False, exist_ok=True)

create_branch(new_model, repo_type="model", branch="gguf", exist_ok=True)

# create_branch(new_model, repo_type="model", branch="awq", exist_ok=True)

# create_branch(new_model, repo_type="model", branch="gptq", exist_ok=True)
In [57]:
# model.config._name_or_path="gadkins/Yi-34B-200K-Llamafied-chat-SFT"
# print(model.config._name_or_path)
In [58]:
# Save the LoRA adapters (not the full model) under the adapter_model name
model.save_pretrained(adapter_model, push_to_hub=True, use_auth_token=True)
In [77]:
# Push the model to the hub
# model.push_to_hub(adapter_model, use_auth_token=True)
In [80]:
# (Optional) Reload the base model before merging.
# You may need a high-RAM environment (e.g., Colab Pro), because this loads the full original model, not the quantized one.
# Use device_map='auto' if you have a GPU, or device_map='cpu' otherwise.
# If you trained in full precision (not quantized), you may not need to reload; you can merge and unload directly.
# If you are training very large models, you may need to restart the kernel and reload the base model, because there may not be enough GPU memory.

# from transformers import AutoModelForCausalLM
# from peft import PeftModel
# import torch

# model = AutoModelForCausalLM.from_pretrained(base_model, device_map='auto', trust_remote_code=True, torch_dtype=torch.float16, cache_dir=cache_dir)

# # Load the PEFT model with the new adapters saved above
# model = PeftModel.from_pretrained(
#     model,
#     adapter_model,
# )
In [78]:
model = model.merge_and_unload() # merge adapters with the base model.
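`merge_and_unload()` folds each adapter into its base weight. The result is a plain `MistralForCausalLM` with no separate LoRA modules at inference time. For a layer with frozen weight $W_0$, the standard LoRA merge is:

```latex
W_{\text{merged}} = W_0 + \frac{\alpha}{r}\, B A,
\qquad A \in \mathbb{R}^{r \times d_{\text{in}}},\quad
B \in \mathbb{R}^{d_{\text{out}} \times r},\quad
BA \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}}
```

Here $r = 8$, matching the printout. $\alpha$ is the `lora_alpha` value from the LoRA config, which is not shown in this section. For `q_proj`, $d_{\text{in}} = d_{\text{out}} = 4096$, so $BA$ is a full $4096 \times 4096$ update that is added to $W_0$ once, after which the adapters are no longer needed.

If $W_0$ were stored in 4-bit, the merged weights would pick up quantization rounding. That is one reason the commented-out cell above suggests reloading the base model in float16 before merging.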
In [79]:
# (Optional) Save the merged model locally so you can run inference without downloading it
model.save_pretrained(f"gadkins/{base_model_name}-function-calling-v3")
In [81]:
model.push_to_hub(new_model, token=True, max_shard_size="10GB", safe_serialization=True)
Out[81]:
CommitInfo(commit_url='https://huggingface.co/gadkins/Mistral-7B-Instruct-v0.1-function-calling/commit/ca16aa19012184d3f5721a3a4ba7829876a385d1', commit_message='Upload MistralForCausalLM', commit_description='', oid='ca16aa19012184d3f5721a3a4ba7829876a385d1', pr_url=None, pr_revision=None, pr_num=None)

Copy the base model's README.md and tokenizer.model to the new repo (needed for GGUF and GPTQ)¶

In [82]:
import os
import requests
from huggingface_hub import HfApi

def download_file_from_huggingface(model_id, filename, save_path):
    """Download one file from the main branch of a Hub repo; return True on success."""
    url = f"https://huggingface.co/{model_id}/resolve/main/{filename}"
    r = requests.get(url, timeout=60)  # avoid hanging indefinitely on a stalled request
    if r.status_code != 200:
        print(f"Failed to download {filename}. HTTP Status Code: {r.status_code}")
        return False
    with open(os.path.join(save_path, filename), 'wb') as f:
        f.write(r.content)
    return True

def main():
    # Files to copy from the base model repo to the new model repo
    files_to_process = ["tokenizer.model", "README.md"]

    # Local directory for the downloaded files
    save_path = "./models"
    os.makedirs(save_path, exist_ok=True)

    # Initialize HfApi class
    api = HfApi()

    # Target repository, in the format "username/repo"
    repo_id = new_model

    for filename in files_to_process:
        # Skip the upload if the download failed (the helper already printed why)
        if not download_file_from_huggingface(base_model, filename, save_path):
            continue
        print(f"Successfully downloaded {filename}")

        # Upload the file to the same path in the new repo
        api.upload_file(
            path_or_fileobj=os.path.join(save_path, filename),
            path_in_repo=filename,
            repo_id=repo_id,
            repo_type="model",  # can also be "dataset" or "space"
        )
        print(f"Uploaded {filename} to {repo_id}")

if __name__ == "__main__":
    main()
Successfully downloaded tokenizer.model
Uploaded tokenizer.model to gadkins/Mistral-7B-Instruct-v0.1-function-calling
Successfully downloaded README.md
Uploaded README.md to gadkins/Mistral-7B-Instruct-v0.1-function-calling

Set up chat template (advanced option)¶

This is an advanced step that lets you customize the chat template for function calling.

Start by copying the chat_template from the base model's tokenizer_config.json. It is also available as tokenizer.chat_template, printed below. Then extend the template to handle three additional roles: function_metadata, function_call and function_response. The example below targets the Llama 2 / Mistral [INST] format and will not be correct for all models.

In [64]:
print(tokenizer.chat_template)
{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}
In [65]:
print(tokenizer.bos_token)
print(tokenizer.eos_token)
<s>
</s>
In [66]:
import json
In [67]:
function_metadata = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "This function gets the current weather in a given city",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The city, e.g., San Francisco"
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The temperature unit to use."
                    }
                },
                "required": ["city"]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "get_clothes",
            "description": "This function provides a suggestion of clothes to wear based on the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "temperature": {
                        "type": "string",
                        "description": "The temperature, e.g., 15 C or 59 F"
                    },
                    "condition": {
                        "type": "string",
                        "description": "The weather condition, e.g., 'Cloudy', 'Sunny', 'Rainy'"
                    }
                },
                "required": ["temperature", "condition"]
            }
        }
    }
]
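Because function_metadata is a JSON Schema-style description, you can check a parsed call against it before executing anything. This sketch handles only the subset of the schema used in this notebook:

- `type` for string, integer, array and object
- `required`
- `enum`

`validate_call` is a hypothetical helper; a full JSON Schema validator (for example, the `jsonschema` package) would be more thorough. Note that the reference answers in the validation run above encode `number` as the string `"5"`, although the schema declares it as an integer. A check like this would flag those references, not the model's integer output.

```python
# Python types for the JSON Schema "type" values used in this notebook's metadata.
JSON_TYPES = {"string": str, "integer": int, "array": list, "object": dict}


def validate_call(call, metadata):
    """Return a list of problems with a parsed function call (an empty list means valid)."""
    specs = {tool["function"]["name"]: tool["function"]["parameters"] for tool in metadata}
    name = call.get("name")
    if name not in specs:
        return [f"unknown function: {name!r}"]

    params = specs[name]
    args = call.get("arguments", {})
    problems = [f"missing required argument: {key}"
                for key in params.get("required", []) if key not in args]

    for key, value in args.items():
        prop = params.get("properties", {}).get(key)
        if prop is None:
            problems.append(f"unexpected argument: {key}")
            continue
        expected = JSON_TYPES.get(prop.get("type"))
        # bool is a subclass of int in Python, so reject it explicitly for "integer".
        bool_as_int = expected is int and isinstance(value, bool)
        if expected is not None and (not isinstance(value, expected) or bool_as_int):
            problems.append(f"{key}: expected {prop['type']}, got {type(value).__name__}")
        if "enum" in prop and value not in prop["enum"]:
            problems.append(f"{key}: {value!r} is not one of {prop['enum']}")
    return problems
```

Usage: `validate_call({"name": "get_current_weather", "arguments": {"city": "London"}}, function_metadata)` returns an empty list.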
In [68]:
# Uncomment later messages to test the template at various stages of the conversation.

sample_messages = [
    # {
    #     "role": "system",
    #     "content": "you are a helpful assistant"
    # },
    {
        "role": "function_metadata",
        "content": "FUNCTION_METADATA"
    },
    {
        "role": "user",
        "content": "What is the current weather in London?"
    },
    # {
    #     "role": "function_call",
    #     "content": "{\n    \"name\": \"get_current_weather\",\n    \"arguments\": {\n        \"city\": \"London\"\n    }\n}</s>"
    # },
    # {
    #     "role": "function_response",
    #     "content": "{\n    \"temperature\": \"15 C\",\n    \"condition\": \"Cloudy\"\n}"
    # },
    # {
    #     "role": "assistant",
    #     "content": "The current weather in London is Cloudy with a temperature of 15 Celsius.</s>"
    # },
    # {
    #     "role": "user",
    #     "content": "That's great. Now say hello."
    # },
    # {
    #     "role": "assistant",
    #     "content": "Hello!</s>"
    # }
]
In [69]:
# Iterate through each message in the list
for message in sample_messages:
    if message['role'] == 'function_metadata':
        # Replace the FUNCTION_METADATA placeholder with the JSON-serialized function_metadata list
        message['content'] = message['content'].replace('FUNCTION_METADATA', json.dumps(function_metadata, indent=4))
In [70]:
# Custom function-calling template for the Llama 2 / Mistral [INST] format
tokenizer.chat_template = """{{ bos_token }} [INST] {% for message in messages %}{% if message['role'] == 'system' %}<<SYS>>\n{{ message['content'] }}\n<</SYS>>\n\n{% elif message['role'] == 'function_metadata' %}You have access to the following functions. Use them if required:\n\n{{ message['content'] }}\n\n{% elif message['role'] == 'user' %}{{ message['content'] }} [/INST]\n\n{% elif message['role'] == 'assistant' %}{{ message['content'] }} [INST] {% elif message['role'] == 'function_call' %}{{ message['content'] }} [INST] {% elif message['role'] == 'function_response' %}Here is the response to the function call. If helpful, use it to respond to my question:\n\n{{ message['content'] }} [/INST]\n\n{% endif %}{% endfor %}"""
In [71]:
print(tokenizer.chat_template)
{{ bos_token }} [INST] {% for message in messages %}{% if message['role'] == 'system' %}<<SYS>>
{{ message['content'] }}
<</SYS>>

{% elif message['role'] == 'function_metadata' %}You have access to the following functions. Use them if required:

{{ message['content'] }}

{% elif message['role'] == 'user' %}{{ message['content'] }} [/INST]

{% elif message['role'] == 'assistant' %}{{ message['content'] }} [INST] {% elif message['role'] == 'function_call' %}{{ message['content'] }} [INST] {% elif message['role'] == 'function_response' %}Here is the response to the function call. If helpful, use it to respond to my question:

{{ message['content'] }} [/INST]

{% endif %}{% endfor %}
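The Jinja control flow above is easier to follow when written out as plain Python. The following is a hypothetical mirror of the custom template, for illustration only; `tokenizer.apply_chat_template` remains the source of truth. It assumes `<s>` as the BOS token, as printed earlier.

```python
FUNCTION_METADATA_PREFIX = "You have access to the following functions. Use them if required:\n\n"
FUNCTION_RESPONSE_PREFIX = ("Here is the response to the function call. "
                            "If helpful, use it to respond to my question:\n\n")


def render_prompt(messages, bos_token="<s>"):
    """Plain-Python equivalent of the custom chat template (illustrative only)."""
    parts = [bos_token, " [INST] "]
    for message in messages:
        role, content = message["role"], message["content"]
        if role == "system":
            parts.append(f"<<SYS>>\n{content}\n<</SYS>>\n\n")
        elif role == "function_metadata":
            parts.append(FUNCTION_METADATA_PREFIX + content + "\n\n")
        elif role == "user":
            parts.append(f"{content} [/INST]\n\n")
        elif role in ("assistant", "function_call"):
            # Model turns end by opening the next [INST] block.
            parts.append(f"{content} [INST] ")
        elif role == "function_response":
            parts.append(FUNCTION_RESPONSE_PREFIX + content + " [/INST]\n\n")
        # Any other role is skipped silently, as in the Jinja template.
    return "".join(parts)
```

Writing it out makes two design choices visible. First, the function metadata and the first user message share a single [INST] block. Second, unlike the base template, this one does not append eos_token after assistant or function_call turns. That is why the commented-out sample messages include a literal `</s>` in their content.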
In [72]:
# View the template applied without tokenization
prompt = tokenizer.apply_chat_template(sample_messages, tokenize=False)
print(prompt)
<s> [INST] You have access to the following functions. Use them if required:

[
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "This function gets the current weather in a given city",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The city, e.g., San Francisco"
                    },
                    "format": {
                        "type": "string",
                        "enum": [
                            "celsius",
                            "fahrenheit"
                        ],
                        "description": "The temperature unit to use."
                    }
                },
                "required": [
                    "city"
                ]
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "get_clothes",
            "description": "This function provides a suggestion of clothes to wear based on the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "temperature": {
                        "type": "string",
                        "description": "The temperature, e.g., 15 C or 59 F"
                    },
                    "condition": {
                        "type": "string",
                        "description": "The weather condition, e.g., 'Cloudy', 'Sunny', 'Rainy'"
                    }
                },
                "required": [
                    "temperature",
                    "condition"
                ]
            }
        }
    }
]

What is the current weather in London? [/INST]


In [73]:
## Test generation

# The rendered prompt already starts with "<s>", so skip the tokenizer's own BOS token
# to avoid a duplicated <s> at the start of the input.
inputs = tokenizer([prompt], return_tensors="pt", add_special_tokens=False).to(model.device)

if "token_type_ids" in inputs:
    del inputs["token_type_ids"]

# print(f'model is on: {next(model.parameters()).device}')  # Debug line
# print(f'input_ids is on: {inputs["input_ids"].device}')  # Debug line

# Greedy decoding (do_sample=False), so temperature and top_k would have no effect
output = model.generate(**inputs,
                        max_new_tokens=200,
                        do_sample=False,
                        pad_token_id=tokenizer.pad_token_id,
                        eos_token_id=tokenizer.eos_token_id,
                        # temperature=0.01,
                        # top_k=0
                       )

print()

# Slice off the prompt tokens so only the model's newly generated tokens are decoded
output_text = tokenizer.decode(output[0, len(inputs.input_ids[0]):], skip_special_tokens=False)
print(output_text)
{
    "name": "get_current_weather",
    "arguments": {
        "city": "London"
    }
}</s>
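The output above is only the first half of a function-calling round trip. The commented-out sample_messages show the intended continuation:

1. The model's call is recorded as a function_call message.
2. Your code runs the function.
3. The result goes back as a function_response message before you generate again.

This is a minimal sketch of that step. It assumes a stubbed get_current_weather that returns fixed data; a real implementation would call a weather API. `handle_model_output` and `TOOLS` are hypothetical names.

```python
import json


def get_current_weather(city, format="celsius"):
    # Hypothetical stub; "format" mirrors the parameter name in function_metadata.
    return {"temperature": "15 C", "condition": "Cloudy"}


TOOLS = {"get_current_weather": get_current_weather}


def handle_model_output(output_text, messages, tools=TOOLS):
    """Append the model's turn to messages; run the function if the turn is a call.

    Returns True if a function was called (so the caller should generate again).
    """
    text = output_text.strip()
    try:
        call = json.loads(text.removesuffix("</s>").strip())
    except json.JSONDecodeError:
        call = None

    if not isinstance(call, dict) or call.get("name") not in tools:
        # Plain-text answer (or unknown function): record it as a normal assistant turn.
        messages.append({"role": "assistant", "content": text})
        return False

    result = tools[call["name"]](**call.get("arguments", {}))
    # Keep the trailing </s>, matching the sample function_call message.
    messages.append({"role": "function_call", "content": text})
    messages.append({"role": "function_response", "content": json.dumps(result, indent=4)})
    return True
```

When it returns True, re-render the conversation with `tokenizer.apply_chat_template(sample_messages, tokenize=False)` and call `model.generate` again. This should produce the natural-language answer.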

Push Tokenizer¶

In [74]:
# (Optional) Save the tokenizer locally next to the merged model so you can run inference immediately without downloading
tokenizer.save_pretrained(f"gadkins/{base_model_name}-function-calling-v3")
Out[74]:
('gadkins/Mistral-7B-Instruct-v0.1-function-calling-v3/tokenizer_config.json',
 'gadkins/Mistral-7B-Instruct-v0.1-function-calling-v3/special_tokens_map.json',
 'gadkins/Mistral-7B-Instruct-v0.1-function-calling-v3/tokenizer.model',
 'gadkins/Mistral-7B-Instruct-v0.1-function-calling-v3/added_tokens.json',
 'gadkins/Mistral-7B-Instruct-v0.1-function-calling-v3/tokenizer.json')
In [75]:
# Push the tokenizer (including the custom chat template) to the Hub
tokenizer.push_to_hub(new_model, token=True)

## RELOAD IF NEEDED (not recommended if tokenizer.chat_template was updated, because reloading restores the base template)
# from transformers import AutoTokenizer
# # Reload the tokenizer so you don't end up with a tokenizer whose size was changed by added pad tokens.
# tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
Out[75]:
CommitInfo(commit_url='https://huggingface.co/gadkins/Mistral-7B-Instruct-v0.1-function-calling/commit/dfe2015a6083826389d11212b55d530816a0e0c6', commit_message='Upload tokenizer', commit_description='', oid='dfe2015a6083826389d11212b55d530816a0e0c6', pr_url=None, pr_revision=None, pr_num=None)