API: User-Provided Functions
Apart from the classes to define models, adapters, and trainers, users can create the following custom functions as part of their experiments.
Create Model Function
Mandatory user-provided function to create Hugging Face model and tokenizer objects based on the model type(s) and name(s) given in the RFModelConfig and multi-config specification. Also read the LoRA and Model Configs page.
A model can be imported from the Hugging Face model hub or read from a local checkpoint file. The function is passed directly to run_fit(). Also read the Experiment page.
This function is invoked when a trainer object is created for each run.
- create_model_fn(model_config: Dict[str, Any]) → Tuple[transformers.PreTrainedModel, transformers.PreTrainedTokenizer]
  - Parameters:
    model_config (Dict[str, Any]) – Dictionary injected by RapidFire AI into this user-defined function with all key-value pairs for one model config output by the config-group generator.
  - Returns:
    Tuple containing the initialized Hugging Face model (e.g., AutoModelForCausalLM, AutoModelForSequenceClassification) and tokenizer (e.g., AutoTokenizer, PreTrainedTokenizer) objects.
  - Return type:
    Tuple[transformers.PreTrainedModel, transformers.PreTrainedTokenizer]
Example:
# From the SFT tutorial notebook
def sample_create_model(model_config):
    """Function to create model object for any given config; must return tuple of (model, tokenizer)"""
    from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForMaskedLM
    model_name = model_config["model_name"]
    model_type = model_config["model_type"]
    model_kwargs = model_config["model_kwargs"]
    if model_type == "causal_lm":
        model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
    elif model_type == "seq2seq_lm":
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **model_kwargs)
    elif model_type == "masked_lm":
        model = AutoModelForMaskedLM.from_pretrained(model_name, **model_kwargs)
    elif model_type == "custom":
        # Handle custom model loading logic, e.g., loading your own checkpoints
        # model = ...
        raise NotImplementedError("Add custom model loading logic here")
    else:
        # Default to causal LM
        model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return (model, tokenizer)
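Since from_pretrained() accepts a local directory as well as a hub name, a variant of this function can load a checkpoint from disk. The sketch below is illustrative and not from the tutorial notebook; the "model_path" key is a hypothetical entry you would add to your own model config.
# Minimal sketch (not from the tutorial notebook): load a model from a local
# checkpoint directory instead of the Hugging Face model hub.
def local_checkpoint_create_model(model_config):
    from transformers import AutoModelForCausalLM, AutoTokenizer
    checkpoint_dir = model_config["model_path"]  # hypothetical config key, e.g., a local checkpoint directory
    model_kwargs = model_config.get("model_kwargs", {})
    # from_pretrained() also accepts a local directory containing config and weights
    model = AutoModelForCausalLM.from_pretrained(checkpoint_dir, **model_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
    return (model, tokenizer)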
Compute Metrics Function
Optional user-provided function specifying custom evaluation metrics based on the generated outputs and ground truth. It is passed to the compute_metrics argument of RFModelConfig. Also read: the LoRA and Model Configs page.
You can create multiple variants of this function and pass them all as a single List to your RFModelConfig to create a multi-config specification.
This function is invoked by the underlying HF trainer at a cadence controlled by the eval_strategy and eval_steps arguments. Also read: the Trainer Configs page.
Example:
# From the SFT tutorial notebook
def sample_compute_metrics(eval_preds):
    """Optional function to compute eval metrics based on predictions and labels"""
    predictions, labels = eval_preds
    # Standard text-based eval metrics: Rouge and BLEU
    import evaluate
    rouge = evaluate.load("rouge")
    bleu = evaluate.load("bleu")
    rouge_output = rouge.compute(predictions=predictions, references=labels, use_stemmer=True)
    rouge_l = rouge_output["rougeL"]
    bleu_output = bleu.compute(predictions=predictions, references=labels)
    bleu_score = bleu_output["bleu"]
    return {"rougeL": round(rouge_l, 4), "bleu": round(bleu_score, 4)}
Formatting Function
Optional user-provided function to format each example (row) of the dataset to construct the prompt and completion with the relevant roles and system prompt expected by your model. Apart from adding the system prompt, for conversational data it should format the user instruction and assistant response as separate message dictionary entries.
It is passed to the formatting_func argument of RFModelConfig. Also read: the LoRA and Model Configs page.
You can create multiple variants of this function and pass them all as a single List to your RFModelConfig to create a multi-config specification.
This function is invoked by the underlying HF trainer on the fly on all examples of the train dataset and (if given) the eval dataset.
Example:
# From the SFT tutorial notebook
def sample_formatting_function(row):
    """Function to preprocess each row from dataset"""
    # System prompt prepended to every example
    SYSTEM_PROMPT = "You are a helpful and friendly customer support assistant. Please answer the user's query to the best of your ability."
    return {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": row["instruction"]},
        ],
        "completion": [
            {"role": "assistant", "content": row["response"]}
        ],
    }
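For datasets in standard (non-conversational) format, the formatting function can instead return plain prompt and completion strings. The sketch below is illustrative, not from the tutorial notebook, and assumes the same "instruction" and "response" columns as above.
# Minimal sketch (not from the tutorial notebook): standard-format variant that
# returns plain strings instead of message dictionaries.
def sample_formatting_function_standard(row):
    SYSTEM_PROMPT = "You are a helpful and friendly customer support assistant."
    return {
        "prompt": f"{SYSTEM_PROMPT}\n\nUser: {row['instruction']}\nAssistant: ",
        "completion": row["response"],
    }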
Reward Functions
User-provided reward function(s) needed for GRPO. You can create as many reward functions as you like, with custom names. A list of such functions is passed to the reward_funcs argument of RFModelConfig. Also read: the LoRA and Model Configs page.
You can create multiple variants of this list with different subsets of functions and pass them all as a single List to your RFModelConfig to create a multi-config specification.
These functions are invoked by the underlying HF trainer on the generated outputs on the fly.
- reward_function(prompts, completions, completions_ids, trainer_state, **kwargs) → List[float]
  - Parameters:
    prompts (List[str] | List[List[Dict[str, str]]]) – List of input prompts that produced the completions.
    completions (List[str] | List[List[Dict[str, str]]]) – List of generated completions corresponding to the above prompts.
    completions_ids (List[List[int]]) – List of tokenized completions (token IDs), one list per completion.
    trainer_state (transformers.TrainerState) – Current state of the trainer. Useful for implementing dynamic reward functions such as curriculum learning, where rewards adjust based on training progress.
    kwargs (Any) – Additional keyword arguments containing all dataset columns (except "prompt"). For example, if the dataset contains a "ground_truth" column, it is passed as a keyword argument.
  - Returns:
    List of reward scores, one per completion.
  - Return type:
    List[float] | None
Examples:
# From the GRPO tutorial notebook
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    def extract_xml_answer(text: str) -> str:
        answer = text.split("<answer>")[-1]
        answer = answer.split("</answer>")[0]
        return answer.strip()
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    # Optional debug print:
    # print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    import re
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]
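As an additional illustration (a sketch, not from the tutorial notebook), the trainer_state argument described above can drive a reward that changes with training progress; the brevity bonus and the 256-token threshold here are arbitrary choices for demonstration.
# Minimal sketch (not from the tutorial notebook): curriculum-style reward that
# uses trainer_state to scale a brevity bonus as training progresses.
def brevity_curriculum_reward_func(completions, completions_ids, trainer_state, **kwargs) -> list[float]:
    # Fraction of training completed; guard against max_steps being unset or zero
    progress = 0.0
    if trainer_state is not None and trainer_state.max_steps:
        progress = trainer_state.global_step / trainer_state.max_steps
    # Reward short completions more strongly later in training (256 tokens is arbitrary)
    return [progress * 0.5 if len(ids) <= 256 else 0.0 for ids in completions_ids]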
Notes:
TRL injects lists of prompts, completions, completion IDs, and the trainer state into a reward function as keyword arguments. You can use only a subset of these in your reward function signature as long as you include **kwargs, as shown in the second example above.
Depending on the dataset format, prompts and completions will be either lists of strings (standard format) or lists of message dictionaries (conversational format). Standard format is common for text completion tasks, simple Q&A, code generation, and mathematical reasoning. Conversational format is needed for multi-turn conversations, chat models with system prompts, role-playing scenarios, and complex dialogue systems. Make sure your reward function can handle both cases if your dataset includes both types.
The return type of every reward function must be a list of floats, one per completion. It can also return None for examples where the reward function is not applicable, which is useful for multi-task training.
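A minimal sketch (not from the tutorial notebook) of that multi-task pattern is shown below; the "task" column is a hypothetical dataset column that would be injected via kwargs like any other non-prompt column, and the substring check is purely illustrative.
# Minimal sketch (not from the tutorial notebook): return None for samples this
# reward does not apply to. The "task" column is a hypothetical dataset column.
def math_only_reward_func(completions, task, answer, **kwargs) -> list:
    rewards = []
    for completion, task_name, gold in zip(completions, task, answer):
        if task_name != "math":
            # Not applicable to this sample for this reward function
            rewards.append(None)
            continue
        # Handle both conversational (list of message dicts) and standard (string) formats
        text = completion[0]["content"] if isinstance(completion, list) else completion
        rewards.append(1.0 if str(gold).strip() in text else 0.0)
    return rewards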