Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
5824d8e
Squashed commit of the following:
mayureshagashe2105 Jun 5, 2024
3d8ff35
Remove auto cache deletion
mayureshagashe2105 Jun 5, 2024
4c5a31b
Rename _to_dict --> _get_update_fields
mayureshagashe2105 Jun 5, 2024
b6d440f
Fix tests
mayureshagashe2105 Jun 5, 2024
e53dbab
Set 'CachedContent' as a public property
mayureshagashe2105 Jun 5, 2024
f53a7c6
blacken
mayureshagashe2105 Jun 5, 2024
02cba55
set 'role=user' when content is passed as a str (#4)
mayureshagashe2105 Jun 5, 2024
4c495ef
Handle ttl and expire_time separately
mayureshagashe2105 Jun 6, 2024
0f5f8eb
Remove name param
mayureshagashe2105 Jun 6, 2024
cef3fc7
Update caching_types.py
MarkDaoust Jun 6, 2024
f03a765
Update caching.py
MarkDaoust Jun 6, 2024
42b1e35
Update docstrs and error messages
mayureshagashe2105 Jun 7, 2024
e4648a7
Update model name to gemini-1.5-pro for caching tests
mayureshagashe2105 Jun 7, 2024
f2b495f
Merge branch 'magashe-caching-patch-1' of https://github.com/mayuresh…
mayureshagashe2105 Jun 7, 2024
f715ecb
Remove dafault ttl assignment
mayureshagashe2105 Jun 7, 2024
a576166
blacken
mayureshagashe2105 Jun 7, 2024
5e9b14b
Remove client arg
mayureshagashe2105 Jun 10, 2024
6ccee3e
Add 'usage_metadata' param to CachedContent class
mayureshagashe2105 Jun 11, 2024
3de6909
Add 'display_name' to CachedContent class
mayureshagashe2105 Jun 11, 2024
7fccb32
update generativelanguage version, fix tests
MarkDaoust Jun 11, 2024
2fabe67
format
MarkDaoust Jun 11, 2024
7d14bb1
fewer automatic 'role' insertions
MarkDaoust Jun 11, 2024
3982b48
cleanup
MarkDaoust Jun 11, 2024
940834a
Wrap the proto
MarkDaoust Jun 12, 2024
4a5229e
Apply suggestions from code review
MarkDaoust Jun 12, 2024
c039644
fix
MarkDaoust Jun 12, 2024
fc767a1
format
MarkDaoust Jun 12, 2024
75cc224
cleanup
MarkDaoust Jun 12, 2024
19f0384
update version
MarkDaoust Jun 12, 2024
d438860
fix
MarkDaoust Jun 12, 2024
aa12c3d
typing
MarkDaoust Jun 12, 2024
cc54a87
Merge branch 'main' into magashe-caching-patch-1
MarkDaoust Jun 13, 2024
a6f4355
Simplify update method
mayureshagashe2105 Jun 13, 2024
1c77da4
Add repr to CachedContent
mayureshagashe2105 Jun 13, 2024
0bac36e
cleanup
mayureshagashe2105 Jun 13, 2024
25f4d10
blacken
mayureshagashe2105 Jun 13, 2024
9b48863
Apply suggestions from code review
mayureshagashe2105 Jun 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 24 additions & 22 deletions google/generativeai/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@
from google.generativeai.types.model_types import idecode_time
from google.generativeai.types import caching_types
from google.generativeai.types import content_types
from google.generativeai import string_utils
from google.generativeai.utils import flatten_update_paths
from google.generativeai.client import get_default_cache_client

from google.protobuf import field_mask_pb2
import google.ai.generativelanguage as glm


@string_utils.prettyprint
@dataclasses.dataclass
class CachedContent:
"""Cached content resource."""
Expand All @@ -39,29 +41,19 @@ class CachedContent:
update_time: datetime.datetime
expire_time: datetime.datetime

# NOTE: Automatic CachedContent deletion using contextmanager is not P0(P1+).
# Adding basic support for now.
def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, exc_tb):
self.delete()

def _get_update_fields(self, **input_only_update_fields) -> protos.CachedContent:
    """Assemble the `protos.CachedContent` message used for an update request.

    The message always carries this object's `name` and `model`. Any keyword
    arguments are merged on top of those two entries, so a keyword with the
    same key replaces the value taken from the instance.

    Args:
        **input_only_update_fields: Extra `protos.CachedContent` fields to
            include in the message.

    Returns:
        A `protos.CachedContent` built from the merged fields.
    """
    fields = {"name": self.name, "model": self.model, **input_only_update_fields}
    return protos.CachedContent(**fields)

def _apply_update(self, path, value):
    """Assign `value` to the attribute addressed by a dotted `path`.

    Args:
        path: Attribute path relative to this object; nested attributes are
            separated by "." (for example "expire_time").
        value: The value to store on the final attribute of `path`.

    A path whose last component is "ttl" is ignored and nothing is assigned.
    NOTE(review): presumably because `ttl` is only sent to the service and the
    object tracks `expire_time` instead - confirm against the full class.
    """
    *parents, leaf = path.split(".")
    # Walk down to the object that owns the final attribute (without rebinding `self`).
    target = self
    for part in parents:
        target = getattr(target, part)
    # Compare the last path *component*. Indexing the string (`path[-1]`) yields only
    # its last character, which can never equal "ttl", so that guard was always true.
    if leaf != "ttl":
        setattr(target, leaf, value)

@classmethod
def _decode_cached_content(cls, cached_content: protos.CachedContent) -> CachedContent:
Expand Down Expand Up @@ -112,7 +104,7 @@ def _prepare_create_request(
contents = content_types.to_contents(contents)

if ttl:
ttl = caching_types.to_ttl(ttl)
ttl = caching_types.to_expiration(ttl)

cached_content = protos.CachedContent(
name=name,
Expand Down Expand Up @@ -236,25 +228,35 @@ def update(
if client is None:
client = get_default_cache_client()

if "ttl" in updates and "expire_time" in updates:
raise ValueError(
"`expiration` is a _oneof field. Please provide either `ttl` or `expire_time`."
)

field_mask = field_mask_pb2.FieldMask()

updates = flatten_update_paths(updates)
for update_path in updates:
if update_path == "ttl":
if update_path == "ttl" or update_path == "expire_time":
updates = updates.copy()
update_path_val = updates.get(update_path)
updates[update_path] = caching_types.to_ttl(update_path_val)
updates[update_path] = caching_types.to_expiration(update_path_val)
else:
raise ValueError(
f"As of now, only `ttl` can be updated for `CachedContent`. Got: `{update_path}` instead."
)
field_mask = field_mask_pb2.FieldMask()

for path in updates.keys():
field_mask.paths.append(path)
field_mask.paths.append(update_path)

for path, value in updates.items():
self._apply_update(path, value)

request = protos.UpdateCachedContentRequest(
cached_content=self._to_dict(), update_mask=field_mask
cached_content=self._get_update_fields(**updates), update_mask=field_mask
)
client.update_cached_content(request)
updated_cc = client.update_cached_content(request)
updated_cc = self._decode_cached_content(updated_cc)
for path, value in dataclasses.asdict(updated_cc).items():
self._apply_update(path, value)

return self
28 changes: 10 additions & 18 deletions google/generativeai/generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,9 @@ def __init__(
self._client = None
self._async_client = None

def __new__(cls, *args, **kwargs) -> GenerativeModel:
self = super().__new__(cls)

if cached_instance := kwargs.pop("cached_content", None):
setattr(self, "_cached_content", cached_instance.name)
setattr(cls, "cached_content", property(fget=lambda self: self._cached_content))

return self
@property
def cached_content(self) -> str | None:
    """Value of the private `_cached_content` attribute, or `None` if it is unset.

    The attribute is looked up with a default, so reading this property on an
    object that never had `_cached_content` assigned returns `None` instead of
    raising `AttributeError`.
    """
    return getattr(self, "_cached_content", None)

@property
def model_name(self):
Expand All @@ -123,7 +118,7 @@ def maybe_text(content):
safety_settings={self._safety_settings},
tools={self._tools},
system_instruction={maybe_text(self._system_instruction)},
cached_content={getattr(self, "cached_content", None)}
cached_content={self.cached_content}
)"""
)

Expand All @@ -139,13 +134,11 @@ def _prepare_request(
tool_config: content_types.ToolConfigType | None,
) -> protos.GenerateContentRequest:
"""Creates a `protos.GenerateContentRequest` from raw inputs."""
if hasattr(self, "cached_content") and any([self._system_instruction, tools, tool_config]):
if hasattr(self, "_cached_content") and any([self._system_instruction, tools, tool_config]):
raise ValueError(
"`tools`, `tool_config`, `system_instruction` cannot be set on a model instantinated with `cached_content` as its context."
)

cached_content = getattr(self, "cached_content", None)

tools_lib = self._get_tools_lib(tools)
if tools_lib is not None:
tools_lib = tools_lib.to_proto()
Expand Down Expand Up @@ -174,7 +167,7 @@ def _prepare_request(
tools=tools_lib,
tool_config=tool_config,
system_instruction=self._system_instruction,
cached_content=cached_content,
cached_content=self.cached_content,
)

def _get_tools_lib(
Expand Down Expand Up @@ -221,17 +214,16 @@ def from_cached_content(
if isinstance(cached_content, str):
cached_content = caching.CachedContent.get(name=cached_content)

# call __new__ with the cached_content to set the model's context. This is done to avoid
# the exposing `cached_content` as a public attribute.
self = cls.__new__(cls, cached_content=cached_content)

# call __init__ to set the model's `generation_config`, `safety_settings`.
# `model_name` will be the name of the model for which the `cached_content` was created.
self.__init__(
self = GenerativeModel(
model_name=cached_content.model,
generation_config=generation_config,
safety_settings=safety_settings,
)

# set the model's context.
setattr(self, "_cached_content", cached_content.name)
return self

def generate_content(
Expand Down
33 changes: 26 additions & 7 deletions google/generativeai/types/caching_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing_extensions import TypedDict
import re

__all__ = ["TTL"]
__all__ = ["ExpirationTypes", "ExpireTime", "TTL"]


_VALID_CACHED_CONTENT_NAME = r"([a-z0-9-\.]+)$"
Expand All @@ -33,19 +33,38 @@ def valid_cached_content_name(name: str) -> bool:


class TTL(TypedDict):
# Represents datetime.datetime.now() + desired ttl
seconds: int
nanos: int


ExpirationTypes = Union[TTL, int, datetime.timedelta]
class ExpireTime(TypedDict):
# Represents seconds of UTC time since Unix epoch
seconds: int
nanos: int


ExpirationTypes = Union[TTL, ExpireTime, int, datetime.timedelta, datetime.datetime]


def to_expiration(expiration: Optional[ExpirationTypes]) -> TTL | ExpireTime:
    """Convert a user-supplied expiration into a `{"seconds", "nanos"}` mapping.

    Args:
        expiration: One of
            * `datetime.timedelta` - treated as a `ttl` (a duration).
            * `datetime.datetime` - treated as an `expire_time`, expressed as
              seconds since the Unix epoch. A naive datetime is interpreted by
              `datetime.timestamp()` as local time.
            * `dict` - returned unchanged.
            * `int` - treated as a `ttl` given in whole seconds.

    Returns:
        A dict with integer `seconds` and `nanos` entries, or the input dict itself.

    Raises:
        TypeError: If `expiration` is none of the supported types (this includes `None`).
    """
    if isinstance(expiration, datetime.timedelta):  # consider `ttl`
        return {
            "seconds": int(expiration.total_seconds()),
            "nanos": int(expiration.microseconds * 1000),
        }
    elif isinstance(expiration, datetime.datetime):  # consider `expire_time`
        # Whole seconds come from the instant with its sub-second part removed; the
        # sub-second part is taken from the exact `microsecond` field. Deriving nanos
        # from the already-truncated integer seconds always produced 0, and going
        # through the float timestamp would lose precision at microsecond scale.
        seconds = int(expiration.replace(microsecond=0).timestamp())
        nanos = expiration.microsecond * 1000
        return {
            "seconds": seconds,
            "nanos": nanos,
        }
    elif isinstance(expiration, dict):
        return expiration
    elif isinstance(expiration, int):  # consider `ttl`
        return {"seconds": expiration, "nanos": 0}
    else:
        # NOTE(review): the visible hunk is cut off right after the message argument
        # (which ends in a trailing comma); passing the offending value as a second
        # argument is an assumption - confirm against the full file.
        raise TypeError(
            f"Could not convert input to `expire_time` \n'" f" type: {type(expiration)}\n",
            expiration,
        )
Expand Down
26 changes: 10 additions & 16 deletions tests/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,16 @@ def test_update_cached_content_invalid_update_paths(self):
with self.assertRaises(ValueError):
cc.update(updates=update_masks)

def test_update_cached_content_valid_update_paths(self):
update_masks = dict(
ttl=datetime.timedelta(hours=2),
)
@parameterized.named_parameters(
[
dict(testcase_name="ttl", update_masks=dict(ttl=datetime.timedelta(hours=2))),
dict(
testcase_name="expire_time",
update_masks=dict(expire_time=datetime.datetime(2024, 6, 5, 12, 12, 12, 23)),
),
]
)
def test_update_cached_content_valid_update_paths(self, update_masks):

cc = caching.CachedContent.get(name="cachedContents/test-cached-content")
cc = cc.update(updates=update_masks)
Expand All @@ -229,18 +235,6 @@ def test_delete_cached_content(self):
cc.delete()
self.assertIsInstance(self.observed_requests[-1], protos.DeleteCachedContentRequest)

def test_auto_delete_cached_content_with_context_manager(self):
with caching.CachedContent.create(
name="test-cached-content",
model="models/gemini-1.0-pro-001",
contents=["Add 5 and 6"],
system_instruction="Always add 10 to the result.",
ttl=datetime.timedelta(minutes=30),
) as cc:
... # some logic

self.assertIsInstance(self.observed_requests[-1], protos.DeleteCachedContentRequest)


if __name__ == "__main__":
absltest.main()