
Commit fa3bf8c

RAFT Enhancements: Improved robustness, logging, checkpointing, threading, Llama support, Azure auth and eval (#604)
This pull request introduces a comprehensive set of updates and improvements to the RAFT project, enhancing robustness, logging, progress monitoring, checkpointing, multi-threading, Llama support, Azure authentication, and evaluation processes.

**Note**: These updates were developed for the most part to prepare the MS Build 2024 talk [Practicalities of Fine-Tuning Llama 2 with AI Studio](https://aka.ms/build24-ft-practical) with @ShishirPatil and Bala Venkataraman.

Key updates include:

### RAFT Script Improvements

This PR introduces significant updates to the `raft.py` script, expanding its functionality, improving its configurability, and removing deprecated options. Below is a summary of the key changes:

- **Logging Enhancements:** Improved logging configuration, including more granular logging for various operations.
- **Checkpointing Overhaul:** Significant refactoring of checkpointing logic in `raft.py`, including the introduction of multi-threading, better directory handling, and optimization of chunk processing. The `--fast` mode, which deactivated checkpointing, was removed in favor of a more efficient implementation that allows checkpointing to remain active at all times.
- **Multi-Worker Support:** Added a `--workers` parameter to enable parallel processing, improving efficiency and reliability during various operations.
- **Llama Instruction Support:** Added support for Llama instructions in addition to GPT instructions, enhancing the versatility of the script for different model types.
- **Dataset Processing:** Added more robust handling and filtering of datasets, including support for customized field names, empty row filtering, and threshold-based early stopping.
- **Authentication Updates:** Added support for Azure OpenAI keyless and managed identity authentication, along with related environment variable handling.
- **Content Safety Handling:** Updated the content generation process to skip chunks that fail content safety compliance checks, allowing the process to continue without interruption.
- **Progress Logging Enhancements:** Improved progress logging with `tqdm`, including enhanced stats support in `client_utils.py`, providing better insights into the process flow.
- **Bug Fixes and Cleanup:** Fixed various bugs across the project, cleaned up help messages, and removed outdated or redundant components.

#### New Features and Options

1. **Output Format Expansion:** Added a new output format option: `eval`. This format is intended for evaluation purposes, providing an additional way to format datasets.
2. **Enhanced Output Configuration:** Introduced the `--output-completion-prompt-column` and `--output-completion-completion-column` options to allow users to specify custom column names for prompts and completions when using the `completion` format.
3. **System Prompt Customization:** Added the `--system-prompt-key` option to allow users to select between different system prompt keys (`gpt` or `llama`) based on the model they intend to use for dataset generation.
4. **Worker Thread Management:** Introduced the `--workers` option to allow parallel processing by specifying the number of worker threads, improving the script's efficiency in handling large datasets.
5. **Checkpoint Management:** Added the `--auto-clean-checkpoints` option, giving users the ability to automatically clean up checkpoints after dataset generation, reducing the need for manual intervention.
6. **Question/Answer Sample Threshold:** Introduced the `--qa-threshold` option, which allows users to specify a threshold for the number of Question/Answer samples to generate before stopping. This provides more control over the dataset generation process, particularly in large-scale operations (an illustrative worker-pool sketch follows this description).

#### Removed Options

1. **`--fast`:** The `--fast` option has been removed. It was previously used to run the script in a fast mode with no recovery implemented. The script has been optimized to improve performance without the need for a separate fast mode, rendering this option obsolete.

#### Default Value Updates

Several options now have default values set, including `--output-type`, `--output-format`, `--doctype`, `--embedding_model`, `--completion_model`, `--workers`, and more. These defaults aim to make the script more user-friendly by reducing the need for extensive configuration.

---

### Evaluation Script Improvements

- **Stop Keyword:** Added stop keyword functionality to allow controlled early termination of evaluation processes when specific conditions are met.
- **Retry Mechanism:** Introduced a retry mechanism for failed tasks, improving reliability during evaluations.
- **Improved Robustness:** Enhanced the script's robustness, particularly in handling errors and edge cases, ensuring a smoother evaluation process.
- **Logging Retry Statistics:** Implemented logging for retry attempts, providing detailed insights and transparency into the evaluation process.
- **Main Thread Exception Handling:** Fixed an issue where exceptions in the main thread could cause silent failures, ensuring that all errors are properly reported and handled.
- **Support for Chat and Completion Models:** Extended the script to support both chat and completion models, increasing its versatility across different use cases.
- **Environment Prefix Handling:** Enabled the script to accept an environment prefix as a parameter, enhancing its adaptability to different deployment environments.
- **Progress Monitoring:** Integrated progress monitoring with `tqdm`, allowing for real-time tracking of the evaluation process.
- **Configurable Workers:** Made the number of workers configurable using the `--workers` option, allowing for fine-tuned parallel processing during evaluations.

#### Enhanced CLI Options for `eval.py`

This PR introduces several new command-line options to the `eval.py` script, providing enhanced functionality and flexibility for model evaluation:

- **`--model MODEL`**: Added support for specifying the model to be evaluated.
- **`--mode MODE`**: Introduced a new option to select the API mode, either `chat` or `completion`. The default mode is `chat`.
- **`--input-prompt-key INPUT_PROMPT_KEY`**: Added the ability to define which column in the dataset should be used as the input prompt.
- **`--output-answer-key OUTPUT_ANSWER_KEY`**: Added the ability to define which column in the dataset should be used as the output answer.
- **`--workers WORKERS`**: Introduced multi-threading support, allowing users to specify the number of worker threads for evaluating the dataset, improving processing efficiency.
- **`--env-prefix ENV_PREFIX`**: Added an option to customize the prefix for environment variables used for API keys and base URLs. The default prefix is `EVAL`.

These enhancements provide greater control over the evaluation process, allowing for more customized and efficient use of the `eval.py` script.

## Testing

```
pytest
```
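The interaction between the new `--workers` pool and the `--qa-threshold` early stop is easiest to see in code. The sketch below is illustrative only and is not taken from `raft.py` (whose generation loop is not part of this page); `generate_qa_for_chunk` is a hypothetical stand-in for the real per-chunk question/answer generation.

```python
# Illustrative sketch only: how a --workers thread pool and a --qa-threshold
# early stop could interact. generate_qa_for_chunk is a hypothetical placeholder,
# not a function from this repository.
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Optional

def generate_qa_for_chunk(chunk: str) -> list:
    # Stand-in for the real QA generation against the completion model.
    return [{"question": f"Q about: {chunk}", "answer": "A"}]

def generate_dataset(chunks: list, workers: int = 2, qa_threshold: Optional[int] = None) -> list:
    samples = []
    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures = [pool.submit(generate_qa_for_chunk, chunk) for chunk in chunks]
        for future in as_completed(futures):
            samples.extend(future.result())
            if qa_threshold is not None and len(samples) >= qa_threshold:
                break  # threshold reached: stop collecting further results
    return samples

print(len(generate_dataset(["chunk one", "chunk two"], workers=2, qa_threshold=1)))
```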
1 parent 41fdee6 commit fa3bf8c

11 files changed (+906 / -284 lines)

raft/.gitignore

Lines changed: 1 addition & 0 deletions
@@ -1 +1,2 @@
 .venv/
+output/

raft/README.md

Lines changed: 29 additions & 3 deletions
@@ -21,11 +21,13 @@ pip install -r requirements.txt
 ```
 
 Arguments:
-- `--datapath` - the path at which the document is located
+- `--datapath` - if a file, the path at which the document is located. If a folder, the path at which to load all documents
 - `--output` - the path at which to save the dataset
-- `--output-format` - the format of the output dataset. Defaults to `hf` for HuggingFace. Can be one of `hf`, `completion`, `chat`.
+- `--output-format` - the format of the output dataset. Defaults to `hf` for HuggingFace. Can be one of `hf`, `completion`, `chat`, `eval`.
 - `--output-type` - the type of the output dataset file. Defaults to `jsonl`. Can be one of `jsonl`, `parquet`.
 - `--output-chat-system-prompt` - The system prompt to use when the output format is `chat`. Optional.
+- `--output-completion-prompt-column` - The column (json field name) for the `prompt` / `instruction` when using the `completion` output format. Defaults to `prompt`.
+- `--output-completion-completion-column` - The column (json field name) for the `completion` when using the `completion` output format. Defaults to `completion`.
 - `--distractors` - the number of distractor documents to include per data point / triplet
 - `--doctype` - the type of the document, must be one of the accepted doctypes
   - currently accepted doctypes: `pdf`, `txt`, `json`, `api`
@@ -37,8 +39,11 @@ Arguments:
 - `--openai_key` - your OpenAI key used to make queries to GPT-3.5 or GPT-4
 - `--embedding-model` - The embedding model to use to encode documents chunks. Defaults to `text-embedding-ada-002`.
 - `--completion-model` - The model to use to generate questions and answers. Defaults to `gpt-4`.
-- `--fast` - Fast mode flag. By default, this flag is not included and the script runs in safe mode, where it saves checkpoint datasets, allowing the script to recover and continue where it left off in the case of an interruption. Include this flag to run RAFT without recovery.
+- `--system-prompt-key` - The system prompt key to use to generate the dataset. Defaults to `gpt`. Can be one of `gpt`, `llama`.
+- `--workers` - The number of worker threads to use to generate the dataset. Defaults to 2.
+- `--auto-clean-checkpoints` - Whether to auto clean the checkpoints after the dataset is generated. Defaults to `false`.
 
+*Note*: The `--fast` mode flag has been removed; checkpointing is now always active.
 
 ## Usage
 
@@ -219,6 +224,27 @@ python3 format.py --input output/data-00000-of-00001.arrow --output output.compl
 
 ```
 python3 format.py --help
+
+usage: format.py [-h] --input INPUT [--input-type {arrow,jsonl}] --output OUTPUT --output-format {hf,completion,chat,eval} [--output-type {parquet,jsonl}] [--output-chat-system-prompt OUTPUT_CHAT_SYSTEM_PROMPT] [--output-completion-prompt-column OUTPUT_COMPLETION_PROMPT_COLUMN] [--output-completion-completion-column OUTPUT_COMPLETION_COMPLETION_COLUMN] [--output-completion-stop OUTPUT_COMPLETION_STOP]
+
+options:
+  -h, --help            show this help message and exit
+  --input INPUT         Input HuggingFace dataset file (default: None)
+  --input-type {arrow,jsonl}
+                        Format of the input dataset. Defaults to arrow. (default: arrow)
+  --output OUTPUT       Output file (default: None)
+  --output-format {hf,completion,chat,eval}
+                        Format to convert the dataset to (default: None)
+  --output-type {parquet,jsonl}
+                        Type to export the dataset to. Defaults to jsonl. (default: jsonl)
+  --output-chat-system-prompt OUTPUT_CHAT_SYSTEM_PROMPT
+                        The system prompt to use when the output format is chat (default: None)
+  --output-completion-prompt-column OUTPUT_COMPLETION_PROMPT_COLUMN
+                        The prompt column name to use for the completion format (default: prompt)
+  --output-completion-completion-column OUTPUT_COMPLETION_COMPLETION_COLUMN
+                        The completion column name to use for the completion format (default: completion)
+  --output-completion-stop OUTPUT_COMPLETION_STOP
+                        The stop keyword to use for the completion format (default: <STOP>)
 ```
 
 **Note**: If fine tuning a chat model, then you need to use `--output-format chat` and optionally add the `--output-chat-system-prompt` parameter to configure the system prompt included in the dataset.

raft/checkpointing.py

Lines changed: 77 additions & 0 deletions
from dataclasses import dataclass
from pathlib import Path
from typing import List
from datasets import Dataset, concatenate_datasets
import logging
import shutil

logger = logging.getLogger("raft")

@dataclass
class Checkpoint:
    path: Path
    num: int

    def load(self) -> Dataset:
        return Dataset.load_from_disk(self.path)

    def __lt__(self, other: 'Checkpoint') -> bool:
        return self.num < other.num

    def __eq__(self, other: 'Checkpoint') -> bool:
        return self.num == other.num

    def __hash__(self) -> int:
        return hash(self.num)

class Checkpointing:

    def __init__(self, checkpoints_dir: Path) -> None:
        self.checkpoints_dir = checkpoints_dir

    def missing_checkpoints(self, num) -> List[int]:
        return [n for n in range(0, num) if not (self.checkpoints_dir / f"checkpoint-{n}").exists()]

    def save_checkpoint(self, ds: Dataset, num: int):
        checkpoint_path = self.checkpoints_dir / ("checkpoint-" + str(num))
        ds.save_to_disk(checkpoint_path)

    def load_checkpoint(self, num: int):
        checkpoint_path = self.checkpoints_dir / ("checkpoint-" + str(num))
        if checkpoint_path.exists():
            return Dataset.load_from_disk(checkpoint_path)
        return None

    def get_checkpoints(self) -> List[Checkpoint]:
        checkpoints = []
        if not self.checkpoints_dir.exists():
            return checkpoints
        for dir_path in self.checkpoints_dir.iterdir():
            if dir_path.is_dir() and dir_path.name.startswith("checkpoint-"):
                num = int(dir_path.name.split("-")[1])
                checkpoints.append(Checkpoint(dir_path, num))
        return checkpoints

    def has_checkpoints(self) -> bool:
        return len(self.get_checkpoints()) > 0

    def collect_checkpoints(self) -> Dataset:
        ds_list = list([checkpoint.load() for checkpoint in self.get_checkpoints()])
        ds = concatenate_datasets(ds_list)
        return ds

    def delete_checkpoints(self):
        shutil.rmtree(self.checkpoints_dir)

def checkpointed(checkpointing: Checkpointing):
    def wrapped(func):
        def wrapper(chunk_id, *args, **kwargs):
            ds = checkpointing.load_checkpoint(chunk_id)
            if ds:
                return ds
            ds = func(chunk_id=chunk_id, *args, **kwargs)
            if ds.num_rows > 0:
                checkpointing.save_checkpoint(ds, chunk_id)
            return ds
        return wrapper
    return wrapped
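To make the checkpoint/resume behaviour concrete, here is a small usage sketch. It is not code from `raft.py` (the actual wiring is not shown in this commit); `process_chunk`, the chunk list, and the checkpoint directory are hypothetical, while the `Checkpointing`/`checkpointed` API itself comes from the new `checkpointing.py` above.

```python
# Usage sketch: process_chunk and the paths below are hypothetical;
# only the Checkpointing/checkpointed API comes from checkpointing.py.
from pathlib import Path
from datasets import Dataset
from checkpointing import Checkpointing, checkpointed

checkpointing = Checkpointing(Path("output") / "checkpoints")

@checkpointed(checkpointing)
def process_chunk(chunk_id: int, chunk: str) -> Dataset:
    # Placeholder for the real per-chunk QA generation; one row per chunk here.
    return Dataset.from_dict({"chunk_id": [chunk_id], "text": [chunk]})

chunks = ["first chunk", "second chunk", "third chunk"]

# Only chunks without an existing checkpoint-<n> directory are reprocessed,
# so an interrupted run resumes where it left off.
for chunk_id in checkpointing.missing_checkpoints(len(chunks)):
    process_chunk(chunk_id=chunk_id, chunk=chunks[chunk_id])

# Merge the per-chunk datasets back into one; deleting the checkpoints afterwards
# mirrors the --auto-clean-checkpoints behaviour described above.
ds = checkpointing.collect_checkpoints()
checkpointing.delete_checkpoints()
print(ds.num_rows)
```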

raft/client_utils.py

Lines changed: 119 additions & 9 deletions
@@ -1,24 +1,29 @@
+from abc import ABC
 from typing import Any
-from dotenv import load_dotenv
 from langchain_openai import OpenAIEmbeddings, AzureOpenAIEmbeddings
 from openai import AzureOpenAI, OpenAI
 import logging
 from env_config import read_env_config, set_env
-from os import environ
+from os import environ, getenv
+import time
+from threading import Lock
+from azure.identity import DefaultAzureCredential, ManagedIdentityCredential
+from azure.identity import get_bearer_token_provider
 
-logger = logging.getLogger("client_utils")
 
-load_dotenv()  # take environment variables from .env.
+logger = logging.getLogger("client_utils")
 
-def build_openai_client(**kwargs: Any) -> OpenAI:
+def build_openai_client(env_prefix : str = "COMPLETION", **kwargs: Any) -> OpenAI:
     """
     Build OpenAI client based on the environment variables.
     """
 
-    env = read_env_config("COMPLETION")
+    kwargs = _remove_empty_values(kwargs)
+    env = read_env_config(env_prefix)
     with set_env(**env):
         if is_azure():
-            client = AzureOpenAI(**kwargs)
+            auth_args = _get_azure_auth_client_args()
+            client = AzureOpenAI(**auth_args, **kwargs)
         else:
             client = OpenAI(**kwargs)
         return client
@@ -28,19 +33,124 @@ def build_langchain_embeddings(**kwargs: Any) -> OpenAIEmbeddings:
     Build OpenAI embeddings client based on the environment variables.
     """
 
+    kwargs = _remove_empty_values(kwargs)
     env = read_env_config("EMBEDDING")
-
     with set_env(**env):
         if is_azure():
-            client = AzureOpenAIEmbeddings(**kwargs)
+            auth_args = _get_azure_auth_client_args()
+            client = AzureOpenAIEmbeddings(**auth_args, **kwargs)
         else:
             client = OpenAIEmbeddings(**kwargs)
         return client
 
+def _remove_empty_values(d: dict) -> dict:
+    return {k: v for k, v in d.items() if v is not None}
+
+def _get_azure_auth_client_args() -> dict:
+    """Handle Azure OpenAI Keyless, Managed Identity and Key based authentication
+    https://techcommunity.microsoft.com/t5/microsoft-developer-community/using-keyless-authentication-with-azure-openai/ba-p/4111521
+    """
+    client_args = {}
+    if getenv("AZURE_OPENAI_KEY"):
+        logger.info("Using Azure OpenAI Key based authentication")
+        client_args["api_key"] = getenv("AZURE_OPENAI_KEY")
+    else:
+        if client_id := getenv("AZURE_OPENAI_CLIENT_ID"):
+            # Authenticate using a user-assigned managed identity on Azure
+            logger.info("Using Azure OpenAI Managed Identity Keyless authentication")
+            azure_credential = ManagedIdentityCredential(client_id=client_id)
+        else:
+            # Authenticate using the default Azure credential chain
+            logger.info("Using Azure OpenAI Default Azure Credential Keyless authentication")
+            azure_credential = DefaultAzureCredential()
+
+        client_args["azure_ad_token_provider"] = get_bearer_token_provider(
+            azure_credential, "https://cognitiveservices.azure.com/.default")
+    client_args["api_version"] = getenv("AZURE_OPENAI_API_VERSION") or "2024-02-15-preview"
+    client_args["azure_endpoint"] = getenv("AZURE_OPENAI_ENDPOINT")
+    client_args["azure_deployment"] = getenv("AZURE_OPENAI_DEPLOYMENT")
+    return client_args
+
 def is_azure():
     azure = "AZURE_OPENAI_ENDPOINT" in environ or "AZURE_OPENAI_KEY" in environ or "AZURE_OPENAI_AD_TOKEN" in environ
     if azure:
         logger.debug("Using Azure OpenAI environment variables")
     else:
         logger.debug("Using OpenAI environment variables")
     return azure
+
+def safe_min(a: Any, b: Any) -> Any:
+    if a is None:
+        return b
+    if b is None:
+        return a
+    return min(a, b)
+
+def safe_max(a: Any, b: Any) -> Any:
+    if a is None:
+        return b
+    if b is None:
+        return a
+    return max(a, b)
+
+class UsageStats:
+    def __init__(self) -> None:
+        self.start = time.time()
+        self.completion_tokens = 0
+        self.prompt_tokens = 0
+        self.total_tokens = 0
+        self.end = None
+        self.duration = 0
+        self.calls = 0
+
+    def __add__(self, other: 'UsageStats') -> 'UsageStats':
+        stats = UsageStats()
+        stats.start = safe_min(self.start, other.start)
+        stats.end = safe_max(self.end, other.end)
+        stats.completion_tokens = self.completion_tokens + other.completion_tokens
+        stats.prompt_tokens = self.prompt_tokens + other.prompt_tokens
+        stats.total_tokens = self.total_tokens + other.total_tokens
+        stats.duration = self.duration + other.duration
+        stats.calls = self.calls + other.calls
+        return stats
+
+class StatsCompleter(ABC):
+    def __init__(self, create_func):
+        self.create_func = create_func
+        self.stats = None
+        self.lock = Lock()
+
+    def __call__(self, *args: Any, **kwds: Any) -> Any:
+        response = self.create_func(*args, **kwds)
+        self.lock.acquire()
+        try:
+            if not self.stats:
+                self.stats = UsageStats()
+            self.stats.completion_tokens += response.usage.completion_tokens
+            self.stats.prompt_tokens += response.usage.prompt_tokens
+            self.stats.total_tokens += response.usage.total_tokens
+            self.stats.calls += 1
+            return response
+        finally:
+            self.lock.release()
+
+    def get_stats_and_reset(self) -> UsageStats:
+        self.lock.acquire()
+        try:
+            end = time.time()
+            stats = self.stats
            if stats:
+                stats.end = end
+                stats.duration = end - self.stats.start
+                self.stats = None
+            return stats
+        finally:
+            self.lock.release()
+
+class ChatCompleter(StatsCompleter):
+    def __init__(self, client):
+        super().__init__(client.chat.completions.create)
+
+class CompletionsCompleter(StatsCompleter):
+    def __init__(self, client):
+        super().__init__(client.completions.create)
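A brief sketch of how the new stats wrappers can be used; the exact call sites in `raft.py` and `eval.py` are not shown here, the model name and prompt are placeholders, and the usual `OPENAI_*`/`AZURE_OPENAI_*` environment variables are assumed to be configured.

```python
# Hedged usage sketch: ChatCompleter forwards calls to client.chat.completions.create
# while accumulating token usage under a lock, so it is safe across worker threads.
# Assumes OPENAI_* or AZURE_OPENAI_* environment variables are already set.
from client_utils import build_openai_client, ChatCompleter

client = build_openai_client(env_prefix="COMPLETION")
chat_completer = ChatCompleter(client)

response = chat_completer(
    model="gpt-4",  # placeholder model/deployment name
    messages=[{"role": "user", "content": "Generate a question about RAFT."}],
)

# Periodically drain the accumulated stats, e.g. to annotate a tqdm progress bar.
stats = chat_completer.get_stats_and_reset()
if stats:
    tokens_per_minute = stats.total_tokens / ((stats.duration / 60) or 1)
    print(f"{stats.calls} calls, {stats.total_tokens} tokens, ~{tokens_per_minute:.0f} tok/min")
```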

raft/env_config.py

Lines changed: 19 additions & 0 deletions
@@ -1,12 +1,30 @@
 import contextlib
 import os
+import logging
+
+logger = logging.getLogger("env_config")
 
 # List of environment variables prefixes that are allowed to be used for configuration.
 env_prefix_whitelist = [
     'OPENAI',
     'AZURE_OPENAI'
 ]
 
+def _obfuscate(secret):
+    l = len(secret)
+    return '.' * (l - 4) + secret[-4:]
+
+def _log_env(use_prefix: str, env: dict):
+    """
+    Logs each name value pair of the given environment. If the name indicates that it might store a secret such as an API key, then obfuscate the value.
+    """
+    log_prefix = f"'{use_prefix}'" if use_prefix else "no"
+    logger.info(f"Resolved OpenAI env vars with {log_prefix} prefix:")
+    for key, value in env.items():
+        if any(prefix in key for prefix in ['KEY', 'SECRET', 'TOKEN']):
+            value = _obfuscate(value)
+        logger.info(f" - {key}={value}")
+
 def read_env_config(use_prefix: str, env: dict = os.environ) -> str:
     """
     Read whitelisted environment variables and return them in a dictionary.
@@ -15,6 +33,7 @@ def read_env_config(use_prefix: str, env: dict = os.environ) -> str:
     config = {}
     for prefix in [None, use_prefix]:
         read_env_config_prefixed(prefix, config, env)
+    _log_env(use_prefix, config)
     return config
 
 def read_env_config_prefixed(use_prefix: str, config: dict, env: dict = os.environ) -> str:
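Based on `read_env_config`'s loop over `[None, use_prefix]` and the `--env-prefix` option described above, prefixed variables are resolved after unprefixed ones. The sketch below assumes that `read_env_config_prefixed` (unchanged by this commit) strips the prefix, e.g. mapping `COMPLETION_OPENAI_API_KEY` to `OPENAI_API_KEY`; the key values are made up.

```python
# Hedged sketch of the expected prefix resolution (assumption: read_env_config_prefixed,
# not changed in this commit, maps COMPLETION_OPENAI_API_KEY -> OPENAI_API_KEY).
import os
from env_config import read_env_config, set_env

os.environ["OPENAI_API_KEY"] = "sk-default-0000"                # made-up value
os.environ["COMPLETION_OPENAI_API_KEY"] = "sk-completion-1234"  # made-up value

# Unprefixed whitelisted variables are read first, then COMPLETION_-prefixed ones
# override them; _log_env logs the result with KEY/SECRET/TOKEN values obfuscated.
env = read_env_config("COMPLETION")

with set_env(**env):
    print(os.environ["OPENAI_API_KEY"])  # expected: sk-completion-1234
```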
