Skip to content

Commit 3e0dc79

Browse files
authored
Add version and more URLs to index document (#2029)
plus supporting `python -m cog.http.server --version` for improved runtime inspectability.
1 parent eb04c7b commit 3e0dc79

File tree

2 files changed

+50
-5
lines changed

2 files changed

+50
-5
lines changed

python/cog/server/http.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@
3838
)
3939
from ..types import PYDANTIC_V2, CogConfig
4040

41+
try:
42+
from .._version import __version__
43+
except ImportError:
44+
__version__ = "dev"
45+
4146
if PYDANTIC_V2:
4247
from .helpers import (
4348
unwrap_pydantic_serialization_iterators,
@@ -187,6 +192,17 @@ async def wrapped(*args: "P.args", **kwargs: "P.kwargs") -> "T": # pylint: disa
187192

188193
return wrapped
189194

195+
index_document = {
196+
"cog_version": __version__,
197+
"docs_url": "/docs",
198+
"openapi_url": "/openapi.json",
199+
"shutdown_url": "/shutdown",
200+
"healthcheck_url": "/health-check",
201+
"predictions_url": "/predictions",
202+
"predictions_idempotent_url": "/predictions/{prediction_id}",
203+
"predictions_cancel_url": "/predictions/{prediction_id}/cancel",
204+
}
205+
190206
if "train" in config:
191207
try:
192208
trainer_ref = get_predictor_ref(config, "train")
@@ -281,6 +297,14 @@ def cancel_training(
281297
) -> Any:
282298
return cancel(training_id)
283299

300+
index_document.update(
301+
{
302+
"trainings_url": "/trainings",
303+
"trainings_idempotent_url": "/trainings/{training_id}",
304+
"trainings_cancel_url": "/trainings/{training_id}/cancel",
305+
}
306+
)
307+
284308
except Exception as e: # pylint: disable=broad-exception-caught
285309
if isinstance(e, (PredictorNotSet, FileNotFoundError)) and not is_build:
286310
pass # ignore missing train.py for backward compatibility with existing "bad" models in use
@@ -310,11 +334,7 @@ def shutdown() -> None:
310334

311335
@app.get("/")
312336
async def root() -> Any:
313-
return {
314-
# "cog_version": "", # TODO
315-
"docs_url": "/docs",
316-
"openapi_url": "/openapi.json",
317-
}
337+
return index_document
318338

319339
@app.get("/health-check")
320340
async def healthcheck() -> Any:
@@ -570,6 +590,9 @@ def _cpu_count() -> int:
570590

571591
if __name__ == "__main__":
572592
parser = argparse.ArgumentParser(description="Cog HTTP server")
593+
parser.add_argument(
594+
"-v", "--version", action="store_true", help="Show version and exit"
595+
)
573596
parser.add_argument(
574597
"--host",
575598
dest="host",
@@ -608,6 +631,10 @@ def _cpu_count() -> int:
608631
)
609632
args = parser.parse_args()
610633

634+
if args.version:
635+
print(f"cog.server.http {__version__}")
636+
sys.exit(0)
637+
611638
# log level is configurable so we can make it quiet or verbose for `cog predict`
612639
# cog predict --debug # -> debug
613640
# cog predict # -> warning

python/tests/server/test_http.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,24 @@
1818
)
1919

2020

21+
def test_index_document():
22+
client = make_client(fixture_name="slow_setup")
23+
resp = client.get("/")
24+
data = resp.json()
25+
for field in (
26+
"cog_version",
27+
"docs_url",
28+
"openapi_url",
29+
"shutdown_url",
30+
"healthcheck_url",
31+
"predictions_url",
32+
"predictions_idempotent_url",
33+
"predictions_cancel_url",
34+
):
35+
assert field in data
36+
assert data[field] is not None
37+
38+
2139
def test_setup_healthcheck():
2240
client = make_client(fixture_name="slow_setup")
2341
resp = client.get("/health-check")

0 commit comments

Comments
 (0)