|
38 | 38 | )
|
39 | 39 | from ..types import PYDANTIC_V2, CogConfig
|
40 | 40 |
|
| 41 | +try: |
| 42 | + from .._version import __version__ |
| 43 | +except ImportError: |
| 44 | + __version__ = "dev" |
| 45 | + |
41 | 46 | if PYDANTIC_V2:
|
42 | 47 | from .helpers import (
|
43 | 48 | unwrap_pydantic_serialization_iterators,
|
@@ -187,6 +192,17 @@ async def wrapped(*args: "P.args", **kwargs: "P.kwargs") -> "T": # pylint: disa
|
187 | 192 |
|
188 | 193 | return wrapped
|
189 | 194 |
|
| 195 | + index_document = { |
| 196 | + "cog_version": __version__, |
| 197 | + "docs_url": "/docs", |
| 198 | + "openapi_url": "/openapi.json", |
| 199 | + "shutdown_url": "/shutdown", |
| 200 | + "healthcheck_url": "/health-check", |
| 201 | + "predictions_url": "/predictions", |
| 202 | + "predictions_idempotent_url": "/predictions/{prediction_id}", |
| 203 | + "predictions_cancel_url": "/predictions/{prediction_id}/cancel", |
| 204 | + } |
| 205 | + |
190 | 206 | if "train" in config:
|
191 | 207 | try:
|
192 | 208 | trainer_ref = get_predictor_ref(config, "train")
|
@@ -281,6 +297,14 @@ def cancel_training(
|
281 | 297 | ) -> Any:
|
282 | 298 | return cancel(training_id)
|
283 | 299 |
|
| 300 | + index_document.update( |
| 301 | + { |
| 302 | + "trainings_url": "/trainings", |
| 303 | + "trainings_idempotent_url": "/trainings/{training_id}", |
| 304 | + "trainings_cancel_url": "/trainings/{training_id}/cancel", |
| 305 | + } |
| 306 | + ) |
| 307 | + |
284 | 308 | except Exception as e: # pylint: disable=broad-exception-caught
|
285 | 309 | if isinstance(e, (PredictorNotSet, FileNotFoundError)) and not is_build:
|
286 | 310 | pass # ignore missing train.py for backward compatibility with existing "bad" models in use
|
@@ -310,11 +334,7 @@ def shutdown() -> None:
|
310 | 334 |
|
311 | 335 | @app.get("/")
|
312 | 336 | async def root() -> Any:
|
313 |
| - return { |
314 |
| - # "cog_version": "", # TODO |
315 |
| - "docs_url": "/docs", |
316 |
| - "openapi_url": "/openapi.json", |
317 |
| - } |
| 337 | + return index_document |
318 | 338 |
|
319 | 339 | @app.get("/health-check")
|
320 | 340 | async def healthcheck() -> Any:
|
@@ -570,6 +590,9 @@ def _cpu_count() -> int:
|
570 | 590 |
|
571 | 591 | if __name__ == "__main__":
|
572 | 592 | parser = argparse.ArgumentParser(description="Cog HTTP server")
|
| 593 | + parser.add_argument( |
| 594 | + "-v", "--version", action="store_true", help="Show version and exit" |
| 595 | + ) |
573 | 596 | parser.add_argument(
|
574 | 597 | "--host",
|
575 | 598 | dest="host",
|
@@ -608,6 +631,10 @@ def _cpu_count() -> int:
|
608 | 631 | )
|
609 | 632 | args = parser.parse_args()
|
610 | 633 |
|
| 634 | + if args.version: |
| 635 | + print(f"cog.server.http {__version__}") |
| 636 | + sys.exit(0) |
| 637 | + |
611 | 638 | # log level is configurable so we can make it quiet or verbose for `cog predict`
|
612 | 639 | # cog predict --debug # -> debug
|
613 | 640 | # cog predict # -> warning
|
|
0 commit comments