Skip to content

Commit b98882a

Browse files
authored
feat: generate response length based on a histogram when max_tokens is defined in the request (#169)
* when request contains max tokens, calculate length of the response based on a histogram - initial implementation Signed-off-by: Maya Barnea <[email protected]> * - in case max_tokens is defined in the request, finish reason will not be randomly selected, instead it will be 'length' when the response length is maxTokens, otherwise - 'stop' - fix utils_tests Signed-off-by: Maya Barnea <[email protected]> * rename variables + fix problem by PR comment Signed-off-by: Maya Barnea <[email protected]> * add more explanations to getResponseLengthByHistogram Signed-off-by: Maya Barnea <[email protected]> * fix misspelling Signed-off-by: Maya Barnea <[email protected]> * fixes in comments Signed-off-by: Maya Barnea <[email protected]> --------- Signed-off-by: Maya Barnea <[email protected]>
1 parent 938be82 commit b98882a

File tree

2 files changed

+85
-6
lines changed

2 files changed

+85
-6
lines changed

pkg/common/utils.go

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ const (
3939
RemoteDecodeFinishReason = "remote_decode"
4040
)
4141

42+
var (
	// respLenBucketsProbabilities defines the probabilities of the buckets used to generate
	// the number of tokens in a response when max tokens is defined in the request.
	// The last element is the probability of a response with exactly max tokens tokens,
	// the other elements are the probabilities of the equally sized buckets.
	respLenBucketsProbabilities = [...]float64{0.2, 0.3, 0.2, 0.05, 0.1, 0.15}
	// cumulativeBucketsProbabilities holds the running sums of respLenBucketsProbabilities,
	// element i is the sum of the probabilities of buckets 0..i; it is filled in init()
	cumulativeBucketsProbabilities []float64
)
45+
4246
// list of responses to use in random mode for completion requests
4347
var chatCompletionFakeResponses = []string{
4448
`Testing@, #testing 1$ ,2%,3^, [4&*5], 6~, 7-_ + (8 : 9) / \ < > .`,
@@ -54,6 +58,16 @@ var chatCompletionFakeResponses = []string{
5458
`Give a man a fish and you feed him for a day; teach a man to fish and you feed him for a lifetime`,
5559
}
5660

61+
func init() {
62+
cumulativeBucketsProbabilities = make([]float64, len(respLenBucketsProbabilities))
63+
sum := 0.0
64+
65+
for i, val := range respLenBucketsProbabilities {
66+
sum += val
67+
cumulativeBucketsProbabilities[i] = sum
68+
}
69+
}
70+
5771
// returns the max tokens or error if incorrect
5872
func GetMaxTokens(maxCompletionTokens *int64, maxTokens *int64) (*int64, error) {
5973
var typeToken string
@@ -154,14 +168,67 @@ func GetRandomResponseText(maxCompletionTokens *int64) (string, string) {
154168
if maxCompletionTokens == nil {
155169
numOfTokens = GetRandomResponseLen()
156170
} else {
157-
numOfTokens = int(*maxCompletionTokens)
158-
finishReason = GetRandomFinishReason()
171+
maxTokens := int(*maxCompletionTokens)
172+
// max tokens is defined - generate real length of the response based on it
173+
numOfTokens = getResponseLengthByHistogram(maxTokens)
174+
if numOfTokens == maxTokens {
175+
// if the response should be created with the maximum number of tokens - finish reason will be 'length'
176+
finishReason = LengthFinishReason
177+
}
159178
}
160179

161180
text := GetRandomText(numOfTokens)
162181
return text, finishReason
163182
}
164183

184+
// getResponseLengthByHistogram calculates the number of tokens to be returned in a response based on the max tokens value and the pre-defined buckets.
185+
// The response length is distributed according to the probabilities, defined in respLenBucketsProbabilities.
186+
// The histogram contains equally sized buckets and the last special bucket, which contains only the maxTokens value.
187+
// The last element of respLenBucketsProbabilities defines the probability of a reposnse with maxToken tokens.
188+
// Other values define probabilities for the equally sized buckets.
189+
// If maxToken is small (smaller than number of buckets) - the response length is randomly selected from the range [1, maxTokens]
190+
func getResponseLengthByHistogram(maxTokens int) int {
191+
if maxTokens <= 1 {
192+
return maxTokens
193+
}
194+
// maxTokens is small - no need to use the histogram of probabilities, just select a random value in the range [1, maxTokens]
195+
if maxTokens <= len(cumulativeBucketsProbabilities) {
196+
res := RandomInt(1, maxTokens)
197+
return res
198+
}
199+
200+
r := RandomFloat(0, 1)
201+
202+
// check if r is in the last bucket, then maxTokens should be returned
203+
if r > cumulativeBucketsProbabilities[len(cumulativeBucketsProbabilities)-2] {
204+
return maxTokens
205+
}
206+
207+
// determine which bucket to use, the bucket with a cumulative probability larger than r is the bucket to use
208+
// initialize bucketIndex with the last bucket to handle the case (which should not happen) when the probabilities sum is less than 1
209+
bucketIndex := len(cumulativeBucketsProbabilities) - 1
210+
for i, c := range cumulativeBucketsProbabilities {
211+
if r <= c {
212+
bucketIndex = i
213+
break
214+
}
215+
}
216+
217+
// calculate the size of all of the buckets (except the special last bucket)
218+
bucketSize := float64(maxTokens-1) / float64(len(cumulativeBucketsProbabilities)-1)
219+
// start is the minimum number in the required bucket
220+
start := int(bucketSize*float64(bucketIndex)) + 1
221+
// end is the maximum number in the required bucket
222+
end := int(bucketSize * float64(bucketIndex+1))
223+
// sometimes end could be maxTokens because of rounding, change the value to maxToken-1
224+
if end >= maxTokens {
225+
end = maxTokens - 1
226+
}
227+
228+
// pick uniformly within the bucket’s range
229+
return RandomInt(start, end)
230+
}
231+
165232
// GetResponseText returns response text, from a given text
166233
// considering max completion tokens if it is not nil, and a finish reason (stop or length)
167234
func GetResponseText(maxCompletionTokens *int64, text string) (string, string) {

pkg/common/utils_test.go

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,28 @@ var _ = Describe("Utils", Ordered, func() {
3838
It("should return short text", func() {
3939
maxCompletionTokens := int64(2)
4040
text, finishReason := GetRandomResponseText(&maxCompletionTokens)
41-
Expect(int64(len(Tokenize(text)))).Should(Equal(maxCompletionTokens))
42-
Expect([]string{StopFinishReason, LengthFinishReason}).Should(ContainElement(finishReason))
41+
tokensCnt := int64(len(Tokenize(text)))
42+
Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens))
43+
if tokensCnt == maxCompletionTokens {
44+
Expect(finishReason).To(Equal(LengthFinishReason))
45+
} else {
46+
Expect(tokensCnt).To(BeNumerically("<", maxCompletionTokens))
47+
Expect(finishReason).To(Equal(StopFinishReason))
48+
}
4349
})
4450
It("should return long text", func() {
4551
// max completion tokens is higher than ResponseLenMax - the response must still not exceed max completion tokens
4652
maxCompletionTokens := int64(ResponseLenMax * 5)
4753
text, finishReason := GetRandomResponseText(&maxCompletionTokens)
48-
Expect(int64(len(Tokenize(text)))).Should(Equal(maxCompletionTokens))
54+
tokensCnt := int64(len(Tokenize(text)))
55+
Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens))
4956
Expect(IsValidText(text)).To(BeTrue())
50-
Expect([]string{StopFinishReason, LengthFinishReason}).Should(ContainElement(finishReason))
57+
if tokensCnt == maxCompletionTokens {
58+
Expect(finishReason).To(Equal(LengthFinishReason))
59+
} else {
60+
Expect(tokensCnt).To(BeNumerically("<", maxCompletionTokens))
61+
Expect(finishReason).To(Equal(StopFinishReason))
62+
}
5163
})
5264
})
5365

0 commit comments

Comments
 (0)