API: User-Provided Functions for Run Evals
Users can provide the following custom functions as part of their eval config to be used in run_evals().
Note that each leaf config can have its own set of functions for all of these.
Preprocess Function
Mandatory user-provided function to prepare the inputs to be given to the generator model.
It is invoked for each batch during the evaluation process before generation.
Pass it directly to the preprocess_fn key in your eval config dictionary.
The system injects into this function the batch data, as well as the RAG spec and the prompt manager of an individual leaf config.
- preprocess_fn(batch: dict[str, list], rag: RFLangChainRagSpec, prompt_manager: RFPromptManager) dict[str, list]
- Parameters:
batch (dict[str, list]) – Dictionary with a batch of examples with dataset field names as keys and lists as values
rag (RFLangChainRagSpec) – RAG specification object for document chunk retrieval and context serialization
prompt_manager (RFPromptManager) – Prompt manager object for handling instructions and few-shot examples
- Returns:
Dictionary with the preprocessed batch. It must have a reserved key
"prompts"for the fully formatted prompts for the generator. Other key-value pairs from the original batch can also be copied over if you want.- Return type:
Examples:
# Example 1 from FiQA use case: RAG-based preprocessing with document chunk retrieval
def sample_preprocess_fn(batch: dict[str, list], rag: RFLangChainRagSpec, prompt_manager: RFPromptManager) -> dict[str, list]:
"""Function to prepare the final inputs given to the generator model"""
INSTRUCTIONS = "Utilize your financial knowledge, give your answer or opinion to the input question or subject matter."
# Perform batched retrieval over all queries; returns a list of lists of k documents per query
all_context = rag.get_context(batch_queries=batch["query"], serialize=False)
# Extract the retrieved document ids from the context
retrieved_documents = [
[doc.metadata["corpus_id"] for doc in docs] for docs in all_context
]
# Serialize the retrieved documents into a single string per query using the default template
serialized_context = rag.serialize_documents(all_context)
batch["query_id"] = [int(query_id) for query_id in batch["query_id"]]
# Each batch to contain conversational prompt, retrieved documents, and original 'query_id', 'query', 'metadata'
return {
"prompts": [
[
{"role": "system", "content": INSTRUCTIONS},
{
"role": "user",
"content": f"Here is some relevant context:\n{context}. \nNow answer the following question using the context provided earlier:\n{question}",
},
]
for question, context in zip(batch["query"], serialized_context)
],
"retrieved_documents": retrieved_documents,
**batch,
}
# Example 2 from GSM8K use case: Few-shot learning preprocessing without RAG
def sample_preprocess_fn(batch: dict[str, list], rag: RFLangChainRagSpec, prompt_manager: RFPromptManager) -> dict[str, list]:
"""Function to prepare the final inputs given to the generator model"""
return {
"prompts": [
[
{"role": "system", "content": prompt_manager.get_instructions()},
{
"role": "user",
"content": f"Here are some examples: \n{examples}. \nNow answer the following question:\n{question}",
},
]
for question, examples in zip(
batch["question"],
prompt_manager.get_fewshot_examples(user_queries=batch["question"]),
)
],
**batch,
}
Postprocess Function
Optional user-provided function to postprocess a batch of examples and their respective generated outputs, which is injected into this function by the system. This is useful for, say, extracting structured information from generations, adding ground truth data, or performing any transformations needed before computing metrics.
It is invoked for each batch during the evaluation process after generation and before metric computation.
Pass it directly to the postprocess_fn key in your eval config dictionary.
Examples:
# Example 1 from FiQA use case: Adding ground truth documents for retrieval evaluation
def sample_postprocess_fn(batch: dict[str, list]) -> dict[str, list]:
"""Function to postprocess outputs produced by generator model"""
batch["ground_truth_documents"] = [
qrels[qrels["query_id"] == query_id]["corpus_id"].tolist()
for query_id in batch["query_id"]
]
return batch
# Example 2 from SciFact use case: Extracting structured answers from generated text
def extract_solution(answer):
solution = re.search(r"####\s*(SUPPORT|CONTRADICT|NOINFO)", answer, re.IGNORECASE)
if solution is None:
return "INVALID"
return solution.group(1).upper()
def sample_postprocess_fn(batch: dict[str, list]) -> dict[str, list]:
"""Function to postprocess outputs produced by generator model"""
batch["ground_truth_documents"] = [
qrels[qrels["query_id"] == query_id]["corpus_id"].tolist()
for query_id in batch["query_id"]
]
batch["answer"] = [extract_solution(answer) for answer in batch["generated_text"]]
return batch
Eval Compute Metrics Function
Mandatory user-provided function to compute eval metrics on a given batch of (postprocessed) examples, which is injected by the system. It should return metrics computed over the batch as a whole.
It is invoked for each batch during the evaluation process after generation and postprocessing (if applicable).
Pass it directly to the compute_metrics_fn key in your eval config dictionary.
- eval.compute_metrics_fn(batch: dict[str, list]) dict[str, dict[str, Any]]
- Parameters:
batch (dict[str, list]) – Dictionary containing a batch of examples, including all preprocessed fields, generated outputs, and any postprocessed fields
- Returns:
Dictionary with a metric’s name as key and a dictionary as value inside which a reserved key
"value"must exist with that corresponding metric’s value over this batch of examples.- Return type:
Example:
# Example 1 from FiQA use case: Metrics on retrieval accuracy
def sample_compute_metrics_fn(batch: Dict[str, list]) -> Dict[str, Dict[str, Any]]:
"""Function to compute all eval metrics based on retrievals and/or generations"""
true_positives, precisions, recalls, f1_scores, ndcgs, rrs = 0, [], [], [], [], []
total_queries = len(batch["query"])
for pred, gt in zip(batch["retrieved_documents"], batch["ground_truth_documents"]):
expected_set = set(gt)
retrieved_set = set(pred)
true_positives = len(expected_set.intersection(retrieved_set))
precision = true_positives / len(retrieved_set) if len(retrieved_set) > 0 else 0
recall = true_positives / len(expected_set) if len(expected_set) > 0 else 0
f1 = (
2 * precision * recall / (precision + recall)
if (precision + recall) > 0
else 0
)
precisions.append(precision)
recalls.append(recall)
f1_scores.append(f1)
ndcgs.append(compute_ndcg_at_k(retrieved_set, expected_set, k=5))
rrs.append(compute_rr(retrieved_set, expected_set))
return {
"Total": {"value": total_queries},
"Precision": {"value": sum(precisions) / total_queries},
"Recall": {"value": sum(recalls) / total_queries},
"F1 Score": {"value": sum(f1_scores) / total_queries},
"NDCG@5": {"value": sum(ndcgs) / total_queries},
"MRR": {"value": sum(rrs) / total_queries},
}
# Example 2 from GSM8K use case: Direct answer correctness check
def sample_compute_metrics_fn(batch: dict[str, list]) -> dict[str, dict[str, Any]]:
"""Function to compute all eval metrics based on retrievals and/or generations"""
correct = sum(
1
for pred, gt in zip(batch["model_answer"], batch["ground_truth"])
if pred == gt
)
total = len(batch["model_answer"])
return {
"Correct": {"value": correct},
"Total": {"value": total},
}
Eval Accumulate Metrics Function
Optional user-provided function to aggregate algebraic eval metrics across all batches of the data.
If this function is not provided, all metrics returned by eval.compute_metrics_fn()
will be assumed to be distributive (i.e., summed across batches) by default. Use this function
when metrics require (weighted) averaging or other custom dataset-wide aggregation logic.
It is invoked once at the very end of the evaluation process after all batches have been processed.
Pass it directly to the accumulate_metrics_fn key in your eval config dictionary.
- eval.accumulate_metrics_fn(aggregated_metrics: dict[str, list[dict[str, Any]]]) dict[str, dict[str, Any]]
- Parameters:
aggregated_metrics (dict[str, list[dict[str, Any]]]) – Dictionary with a metric’s name as key and a list of per-batch metric dictionaries as values from across all data batches. Inside each value dictionary, at least the reserved key
"value"will exist that was returned by youreval_compute_metrics_fnfunction.- Returns:
Dictionary with a metric’s name as key and a dictionary as value. The value dictionary can have these keys:
"value"(required): The metric’s value over the whole dataset."is_algebraic"(optional): Boolean indicating if the metric is algebraic."is_distributive"(optional): Boolean indicating if the metric is distributive."value_range"(optional): 2-tuple of floats specifying the metric’s range min and max.
For online aggregation support, see the Online Aggregation for Evals page.
- Return type:
Example:
# Example 1 from GSM8K use case: Weighted aggregation for accuracy computation
def sample_accumulate_metrics_fn(aggregated_metrics: dict[str, list]) -> dict[str, dict[str, Any]]:
"""Function to accumulate eval metrics across all batches"""
correct = sum(m.get("value", 0) for m in aggregated_metrics.get("Correct", [{}]))
total = sum(m.get("value", 0) for m in aggregated_metrics.get("Total", [{}]))
accuracy = correct / total if total > 0 else 0
return {
"Total": {"value": total},
"Correct": {
"value": correct,
"is_distributive": True,
"value_range": (0, 1),
},
"Accuracy": {
"value": accuracy,
"is_algebraic": True,
"value_range": (0, 1),
},
}
# Example 2 from FiQA use case: Multiple algebraic metrics for retrieval accuracy
def sample_accumulate_metrics_fn(aggregated_metrics: dict[str, list]) -> dict[str, dict[str, Any]]:
"""Function to accumulate eval metrics across all batches"""
num_queries_per_batch = [m["value"] for m in aggregated_metrics["Total"]]
total_queries = sum(num_queries_per_batch)
algebraic_metrics = ["Precision", "Recall", "F1 Score", "NDCG@5", "MRR"]
return {
"Total": {"value": total_queries},
**{
metric: {
"value": sum(
m["value"] * queries
for m, queries in zip(
aggregated_metrics[metric], num_queries_per_batch
)
)
/ total_queries,
"is_algebraic": True,
"value_range": (0, 1),
}
for metric in algebraic_metrics
},
}