Skip to content

Simplify try_acquire method #245

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 42 additions & 42 deletions packages/lmi/src/lmi/rate_limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,26 @@ async def rate_limit_status(self):
}
return limit_status

async def _get_resource_and_rate_limit(
self,
namespace_and_key: tuple[str, str],
rate_limit: RateLimitItem | str | None = None,
machine_id: int = 0,
) -> tuple[str, str, RateLimitItem]:
namespace, primary_key = await self.parse_namespace_and_primary_key(
namespace_and_key, machine_id=machine_id
)

rate_limit_, new_namespace = self.parse_rate_limits_and_namespace(
namespace, primary_key
)

if isinstance(rate_limit, str):
rate_limit = limit_parse(rate_limit)

rate_limit = rate_limit or rate_limit_
return new_namespace, primary_key, rate_limit

async def try_acquire(
self,
namespace_and_key: tuple[str, str],
Expand Down Expand Up @@ -338,59 +358,39 @@ async def try_acquire(
TimeoutError: if the acquire_timeout is exceeded.
ValueError: if the weight exceeds the rate limit and raise_impossible_limits is True.
"""
namespace, primary_key = await self.parse_namespace_and_primary_key(
namespace_and_key, machine_id=machine_id
)

rate_limit_, new_namespace = self.parse_rate_limits_and_namespace(
namespace, primary_key
(
new_namespace,
primary_key,
rate_limit,
) = await self._get_resource_and_rate_limit(
namespace_and_key, rate_limit, machine_id
)

if isinstance(rate_limit, str):
rate_limit = limit_parse(rate_limit)

rate_limit = rate_limit or rate_limit_

if rate_limit.amount < weight and raise_impossible_limits:
raise ValueError(
f"Weight ({weight}) > RateLimit ({rate_limit}), cannot satisfy rate"
" limit."
)
while True:
elapsed = 0.0
while (
not (
await self.rate_limiter.test(
rate_limit,
new_namespace,
primary_key,
cost=min(weight, rate_limit.amount),
)
)
and elapsed < acquire_timeout
):
await asyncio.sleep(self.WAIT_INCREMENT)
elapsed += self.WAIT_INCREMENT
if elapsed >= acquire_timeout:
raise TimeoutError(
f"Timeout ({elapsed} secs): rate limit for key: {namespace_and_key}"
)

# If the rate limit hit is False, then we're violating the limit, so we
# need to wait again. This can happen in race conditions.
if await self.rate_limiter.hit(
elapsed = 0.0
while elapsed < acquire_timeout and weight > 0:
cost = min(weight, rate_limit.amount)
could_consume = await self.rate_limiter.hit(
rate_limit,
new_namespace,
primary_key,
cost=min(weight, rate_limit.amount),
):
# we need to keep trying when we have an "impossible" limit
if rate_limit.amount < weight:
weight -= rate_limit.amount
acquire_timeout = max(acquire_timeout - elapsed, 1.0)
continue
break
acquire_timeout = max(acquire_timeout - elapsed, 1.0)
cost=cost,
)
if could_consume:
weight -= cost
else:
await asyncio.sleep(self.WAIT_INCREMENT)
elapsed += self.WAIT_INCREMENT

if weight > 0:
raise TimeoutError(
f"Timeout ({elapsed} secs): rate limit for key: {namespace_and_key}"
)


GLOBAL_LIMITER = GlobalRateLimiter()
21 changes: 21 additions & 0 deletions packages/lmi/tests/test_rate_limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import time
from itertools import product
from typing import Any
from unittest.mock import patch

import pytest
from aviary.core import Message
Expand All @@ -10,6 +11,7 @@
from lmi.constants import CHARACTERS_PER_TOKEN_ASSUMPTION
from lmi.embeddings import LiteLLMEmbeddingModel
from lmi.llms import CommonLLMNames, LiteLLMModel
from lmi.rate_limiter import GLOBAL_LIMITER
from lmi.types import LLMResult

LLM_CONFIG_W_RATE_LIMITS = [
Expand Down Expand Up @@ -293,3 +295,22 @@ async def test_embedding_rate_limits(
)
else:
assert estimated_tokens_per_second > 0


@pytest.mark.asyncio
async def test_try_acquire():
TEST_RATE_CONFIG = {
("get", "test"): RateLimitItemPerSecond(30, 4),
}

with patch.object(GLOBAL_LIMITER, "rate_config", TEST_RATE_CONFIG):
# We can acquire 30 in less than 10 seconds
for _ in range(30):
await GLOBAL_LIMITER.try_acquire(("get", "test"))

# But if we try for one more we hit the limit
with pytest.raises(TimeoutError):
await GLOBAL_LIMITER.try_acquire(("get", "test"), acquire_timeout=2)

# Then we pass an impossible limit, but with a timeour high enough to succeed
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Then we pass an impossible limit, but with a timeour high enough to succeed
# Then we pass an impossible limit, but with a timeout high enough to succeed

Re-reading this, I just noticed a typo

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For fun I made codespell-project/codespell#3656 for this, lol

await GLOBAL_LIMITER.try_acquire(("get", "test"), weight=40, acquire_timeout=20)