Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 221 additions & 7 deletions codex-rs/login/src/login_with_chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@

The script should exit with a non-zero code if the user fails to navigate the
auth flow.

To test this script locally without overwriting your existing auth.json file:

```
rm -rf /tmp/codex_home && mkdir /tmp/codex_home
CODEX_HOME=/tmp/codex_home python3 codex-rs/login/src/login_with_chatgpt.py
```
"""

from __future__ import annotations
Expand All @@ -23,10 +30,12 @@
import secrets
import sys
import threading
import time
import urllib.parse
import urllib.request
import webbrowser
from dataclasses import dataclass
from typing import Any, Dict # for type hints

# Required port for OAuth client.
REQUIRED_PORT = 1455
Expand Down Expand Up @@ -244,12 +253,8 @@ def _exchange_code_for_api_key(self, code: str) -> tuple[AuthBundle, str]:
if len(access_token_parts) != 3:
raise ValueError("Invalid access token")

id_token_claims = json.loads(
base64.urlsafe_b64decode(id_token_parts[1] + "==").decode("utf-8")
)
access_token_claims = json.loads(
base64.urlsafe_b64decode(access_token_parts[1] + "==").decode("utf-8")
)
id_token_claims = _decode_jwt_segment(id_token_parts[1])
access_token_claims = _decode_jwt_segment(access_token_parts[1])

token_claims = id_token_claims.get("https://api.openai.com/auth", {})
access_claims = access_token_claims.get("https://api.openai.com/auth", {})
Expand Down Expand Up @@ -313,7 +318,20 @@ def _exchange_code_for_api_key(self, code: str) -> tuple[AuthBundle, str]:
}
success_url = f"{URL_BASE}/success?{urllib.parse.urlencode(success_url_query)}"

# TODO(mbolin): Port maybeRedeemCredits() to Python and call it here.
# Attempt to redeem complimentary API credits for eligible ChatGPT
# Plus / Pro subscribers. Any errors are logged but do not interrupt
# the login flow.

try:
maybe_redeem_credits(
issuer=self.server.issuer,
client_id=self.server.client_id,
id_token=token_data.id_token,
refresh_token=token_data.refresh_token,
codex_home=self.server.codex_home,
)
except Exception as exc: # pragma: no cover – best-effort only
eprint(f"Unable to redeem ChatGPT subscriber API credits: {exc}")

# Persist refresh_token/id_token for future use (redeem credits etc.)
last_refresh_str = (
Expand Down Expand Up @@ -417,6 +435,163 @@ def auth_url(self) -> str:
return f"{self.issuer}/oauth/authorize?" + urllib.parse.urlencode(params)


def maybe_redeem_credits(
*,
issuer: str,
client_id: str,
id_token: str | None,
refresh_token: str,
codex_home: str,
) -> None:
"""Attempt to redeem complimentary API credits for ChatGPT subscribers.

The operation is best-effort: any error results in a warning being printed
and the function returning early without raising.
"""
id_claims: Dict[str, Any] | None = parse_id_token_claims(id_token or "")

# Refresh expired ID token, if possible
token_expired = True
if id_claims and isinstance(id_claims.get("exp"), int):
token_expired = _current_timestamp_ms() >= int(id_claims["exp"]) * 1000

if token_expired:
eprint("Refreshing credentials...")
new_refresh_token: str | None = None
new_id_token: str | None = None

try:
payload = json.dumps(
{
"client_id": client_id,
"grant_type": "refresh_token",
"refresh_token": refresh_token,
"scope": "openid profile email",
}
).encode()

req = urllib.request.Request(
url="https://auth.openai.com/oauth/token",
data=payload,
method="POST",
headers={"Content-Type": "application/json"},
)

with urllib.request.urlopen(req) as resp:
refresh_data = json.loads(resp.read().decode())
new_id_token = refresh_data.get("id_token")
new_id_claims = parse_id_token_claims(new_id_token or "")
new_refresh_token = refresh_data.get("refresh_token")
except Exception as err:
eprint("Unable to refresh ID token via token-exchange:", err)
return

if not new_id_token or not new_refresh_token:
return

# Update auth.json with new tokens.
try:
auth_dir = codex_home
auth_path = os.path.join(auth_dir, "auth.json")
with open(auth_path, "r", encoding="utf-8") as fp:
existing = json.load(fp)

tokens = existing.setdefault("tokens", {})
tokens["id_token"] = new_id_token
# Note this does not touch the access_token?
tokens["refresh_token"] = new_refresh_token
tokens["last_refresh"] = (
datetime.datetime.now(datetime.timezone.utc)
.isoformat()
.replace("+00:00", "Z")
)

with open(auth_path, "w", encoding="utf-8") as fp:
if hasattr(os, "fchmod"):
os.fchmod(fp.fileno(), 0o600)
json.dump(existing, fp, indent=2)
except Exception as err:
eprint("Unable to update refresh token in auth file:", err)

if not new_id_claims:
# Still couldn't parse claims.
return

id_token = new_id_token
id_claims = new_id_claims

# Done refreshing credentials: now try to redeem credits.
if not id_token:
eprint("No ID token available, cannot redeem credits.")
return

auth_claims = id_claims.get("https://api.openai.com/auth", {})

# Subscription eligibility check (Plus or Pro, >7 days active)
sub_start_str = auth_claims.get("chatgpt_subscription_active_start")
if isinstance(sub_start_str, str):
try:
sub_start_ts = datetime.datetime.fromisoformat(sub_start_str.rstrip("Z"))
if datetime.datetime.now(
datetime.timezone.utc
) - sub_start_ts < datetime.timedelta(days=7):
eprint(
"Sorry, your subscription must be active for more than 7 days to redeem credits."
)
return
except ValueError:
# Malformed; ignore
pass

completed_onboarding = bool(auth_claims.get("completed_platform_onboarding"))
is_org_owner = bool(auth_claims.get("is_org_owner"))
needs_setup = not completed_onboarding and is_org_owner
plan_type = auth_claims.get("chatgpt_plan_type")

if needs_setup or plan_type not in {"plus", "pro"}:
eprint("Only users with Plus or Pro subscriptions can redeem free API credits.")
return

api_host = (
"https://api.openai.com"
if issuer == "https://auth.openai.com"
else "https://api.openai.org"
)

try:
redeem_payload = json.dumps({"id_token": id_token}).encode()
req = urllib.request.Request(
url=f"{api_host}/v1/billing/redeem_credits",
data=redeem_payload,
method="POST",
headers={"Content-Type": "application/json"},
)

with urllib.request.urlopen(req) as resp:
redeem_data = json.loads(resp.read().decode())

granted = redeem_data.get("granted_chatgpt_subscriber_api_credits", 0)
if granted and granted > 0:
eprint(
f"""Thanks for being a ChatGPT {'Plus' if plan_type=='plus' else 'Pro'} subscriber!
If you haven't already redeemed, you should receive {'$5' if plan_type=='plus' else '$50'} in API credits.

Credits: https://platform.openai.com/settings/organization/billing/credit-grants
More info: https://help.openai.com/en/articles/11381614""",
)
else:
eprint(
f"""It looks like no credits were granted:

{json.dumps(redeem_data, indent=2)}

Credits: https://platform.openai.com/settings/organization/billing/credit-grants
More info: https://help.openai.com/en/articles/11381614"""
)
except Exception as err:
eprint("Credit redemption request failed:", err)


def _generate_pkce() -> PkceCodes:
"""Generate PKCE *code_verifier* and *code_challenge* (S256)."""
code_verifier = secrets.token_hex(64)
Expand All @@ -429,6 +604,45 @@ def eprint(*args, **kwargs) -> None:
print(*args, file=sys.stderr, **kwargs)


# Parse ID-token claims (if provided)
#
# interface IDTokenClaims {
# "exp": number; // specifically, an int
# "https://api.openai.com/auth": {
# organization_id: string;
# project_id: string;
# completed_platform_onboarding: boolean;
# is_org_owner: boolean;
# chatgpt_subscription_active_start: string;
# chatgpt_subscription_active_until: string;
# chatgpt_plan_type: string;
# };
# }
def parse_id_token_claims(id_token: str) -> Dict[str, Any] | None:
if id_token:
parts = id_token.split(".")
if len(parts) == 3:
return _decode_jwt_segment(parts[1])
return None


def _decode_jwt_segment(segment: str) -> Dict[str, Any]:
"""Return the decoded JSON payload from a JWT segment.

Adds required padding for urlsafe_b64decode.
"""
padded = segment + "=" * (-len(segment) % 4)
try:
data = base64.urlsafe_b64decode(padded.encode())
return json.loads(data.decode())
except Exception:
return {}


def _current_timestamp_ms() -> int:
return int(time.time() * 1000)


LOGIN_SUCCESS_HTML = """<!DOCTYPE html>
<html lang="en">
<head>
Expand Down