Files
hive/core/antigravity_auth.py
Hundao 589c5b06fe fix: resolve all ruff lint and format errors across codebase (#7058)
- Auto-fixed 70 lint errors (import sorting, aliased errors, datetime.UTC)
- Fixed 85 remaining errors manually:
  - E501: wrapped long lines in queen_profiles, catalog, routes_credentials
  - F821: added missing TYPE_CHECKING imports for AgentHost, ToolRegistry,
    HookContext, HookResult; added runtime imports where needed
  - F811: removed duplicate method definitions in queen_lifecycle_tools
  - F841/B007: removed unused variables in discovery.py
  - W291: removed trailing whitespace in queen nodes
  - E402: moved import to top of queen_memory_v2.py
  - Fixed AgentRuntime -> AgentHost in example template type annotations
- Reformatted 343 files with ruff format
2026-04-16 19:30:01 +08:00

570 lines
19 KiB
Python

#!/usr/bin/env python3
"""Antigravity authentication CLI.
Implements OAuth2 flow for Google's Antigravity Code Assist gateway.
Credentials are stored in ~/.hive/antigravity-accounts.json.
Usage:
python -m antigravity_auth auth account add
python -m antigravity_auth auth account list
python -m antigravity_auth auth account remove <email>
"""
from __future__ import annotations
import argparse
import json
import logging
import os
import secrets
import socket
import sys
import time
import urllib.parse
import urllib.request
import webbrowser
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from typing import Any
# NOTE: this runs at import time, so importing the module configures the root
# logger (if nothing configured it earlier). The message-only format makes log
# output read like plain CLI text.
logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger(__name__)
# OAuth endpoints: the browser consent page and the token endpoint used for
# both the authorization-code exchange and refresh-token grants.
_OAUTH_AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
_OAUTH_TOKEN_URL = "https://oauth2.googleapis.com/token"
# Scopes for Antigravity/Cloud Code Assist; joined with spaces into the
# "scope" parameter of the authorization URL.
_OAUTH_SCOPES = [
    "https://www.googleapis.com/auth/cloud-platform",
    "https://www.googleapis.com/auth/userinfo.email",
    "https://www.googleapis.com/auth/userinfo.profile",
]
# Credentials file path in ~/.hive/ (read by load_accounts, written by
# save_accounts).
_ACCOUNTS_FILE = Path.home() / ".hive" / "antigravity-accounts.json"
# Default project ID: used when a stored refresh entry carries no
# "|<project_id>" suffix, and appended to the refresh entry of new accounts.
_DEFAULT_PROJECT_ID = "rising-fact-p41fc"
# Fixed local port for the OAuth redirect; cmd_account_add builds
# http://localhost:<port>/oauth-callback from it.
_DEFAULT_REDIRECT_PORT = 51121
# OAuth credentials fetched from the opencode-antigravity-auth project.
# This project reverse-engineered and published the public OAuth credentials
# for Google's Antigravity/Cloud Code Assist API.
# Source: https://github.com/NoeFabris/opencode-antigravity-auth
# NOTE(review): the URL tracks the third-party repository's "dev" branch, so
# the fetched values can change without any change to this file.
_CREDENTIALS_URL = "https://raw.githubusercontent.com/NoeFabris/opencode-antigravity-auth/dev/src/constants.ts"
# Cached credentials fetched from public source; populated by
# _fetch_credentials_from_public_source() and kept for the process lifetime.
_cached_client_id: str | None = None
_cached_client_secret: str | None = None
def _fetch_credentials_from_public_source() -> tuple[str | None, str | None]:
    """Download the published constants file and extract the OAuth client pair.

    The module-level cache is consulted first; a network request is only made
    while either value is still missing. Whatever the download yields is
    written back to the cache.

    Returns:
        ``(client_id, client_secret)`` as currently cached (either element may
        be ``None`` when its pattern was not found), or ``(None, None)`` when
        the download or parsing raised.
    """
    global _cached_client_id, _cached_client_secret

    already_cached = bool(_cached_client_id) and bool(_cached_client_secret)
    if already_cached:
        return _cached_client_id, _cached_client_secret

    try:
        request = urllib.request.Request(
            _CREDENTIALS_URL,
            headers={"User-Agent": "Hive-Antigravity-Auth/1.0"},
        )
        with urllib.request.urlopen(request, timeout=10) as response:
            source_text = response.read().decode("utf-8")

        import re

        found_id = re.search(r'ANTIGRAVITY_CLIENT_ID\s*=\s*"([^"]+)"', source_text)
        found_secret = re.search(r'ANTIGRAVITY_CLIENT_SECRET\s*=\s*"([^"]+)"', source_text)
        # Each value is cached independently; a missing match leaves the
        # previous cache entry untouched.
        if found_id:
            _cached_client_id = found_id.group(1)
        if found_secret:
            _cached_client_secret = found_secret.group(1)
    except Exception as e:
        logger.debug(f"Failed to fetch credentials from public source: {e}")
        return None, None
    return _cached_client_id, _cached_client_secret
def get_client_id() -> str:
    """Return the OAuth client ID to use for the Antigravity flow.

    Lookup order:
        1. The ``ANTIGRAVITY_CLIENT_ID`` environment variable.
        2. ``llm.antigravity_client_id`` in ``~/.hive/configuration.json``.
        3. The value published in the public credentials source.

    Returns:
        The first non-empty client ID found.

    Raises:
        RuntimeError: If none of the sources yields a client ID.
    """
    env_id = os.environ.get("ANTIGRAVITY_CLIENT_ID")
    if env_id:
        return env_id

    # Try hive config. Reading is best-effort: a missing, unreadable or
    # malformed file just moves on to the next source, but the reason is
    # logged instead of being silently discarded.
    hive_cfg = Path.home() / ".hive" / "configuration.json"
    cfg: Any = None
    try:
        with open(hive_cfg, encoding="utf-8") as f:
            cfg = json.load(f)
    except FileNotFoundError:
        pass  # No config file is the normal case; nothing to report.
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.debug(f"Ignoring unreadable hive config {hive_cfg}: {exc}")

    # Guard the shape explicitly so a config whose top level or "llm" entry is
    # not an object cannot raise here.
    llm_cfg = cfg.get("llm") if isinstance(cfg, dict) else None
    cfg_id = llm_cfg.get("antigravity_client_id") if isinstance(llm_cfg, dict) else None
    if cfg_id:
        return cfg_id

    # Fetch from public source
    client_id, _ = _fetch_credentials_from_public_source()
    if client_id:
        return client_id
    raise RuntimeError("Could not obtain Antigravity OAuth client ID")
def get_client_secret() -> str | None:
"""Get OAuth client secret from env, config, or public source."""
secret = os.environ.get("ANTIGRAVITY_CLIENT_SECRET")
if secret:
return secret
# Try to read from hive config
hive_cfg = Path.home() / ".hive" / "configuration.json"
if hive_cfg.exists():
try:
with open(hive_cfg) as f:
cfg = json.load(f)
secret = cfg.get("llm", {}).get("antigravity_client_secret")
if secret:
return secret
except Exception:
pass
# Fetch from public source (npm package on GitHub)
_, secret = _fetch_credentials_from_public_source()
return secret
def find_free_port() -> int:
    """Ask the OS for an unused TCP port and return its number.

    A throwaway socket is bound to port 0 so the OS picks the port; the socket
    is closed again before returning.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        probe.bind(("", 0))
        probe.listen(1)
        bound_address = probe.getsockname()
        chosen_port = bound_address[1]
    return chosen_port
class OAuthCallbackHandler(BaseHTTPRequestHandler):
"""Handle OAuth callback from browser."""
auth_code: str | None = None
state: str | None = None
error: str | None = None
def log_message(self, format: str, *args: Any) -> None:
pass # Suppress default logging
def do_GET(self) -> None:
parsed = urllib.parse.urlparse(self.path)
if parsed.path == "/oauth-callback":
query = urllib.parse.parse_qs(parsed.query)
if "error" in query:
self.error = query["error"][0]
self._send_response("Authentication failed. You can close this window.")
return
if "code" in query and "state" in query:
OAuthCallbackHandler.auth_code = query["code"][0]
OAuthCallbackHandler.state = query["state"][0]
self._send_response("Authentication successful! You can close this window and return to the terminal.")
return
self._send_response("Waiting for authentication...")
def _send_response(self, message: str) -> None:
self.send_response(200)
self.send_header("Content-Type", "text/html")
self.end_headers()
html = f"""<!DOCTYPE html>
<html>
<head><title>Antigravity Auth</title></head>
<body style="font-family: system-ui; display: flex; align-items: center;
justify-content: center; height: 100vh; margin: 0; background: #1a1a2e;
color: #eee;">
<div style="text-align: center;">
<h2>{message}</h2>
</div>
</body>
</html>"""
self.wfile.write(html.encode())
def wait_for_callback(port: int, timeout: int = 300) -> tuple[str | None, str | None, str | None]:
    """Serve ``localhost:port`` until the OAuth callback arrives or time runs out.

    Args:
        port: Local TCP port to listen on.
        timeout: Maximum number of seconds to wait for the callback.

    Returns:
        ``(auth_code, state, error)`` as recorded on ``OAuthCallbackHandler``
        once a code or an error has been received, or
        ``(None, None, "timeout")`` if neither arrived in time.

    Raises:
        OSError: If the server cannot bind to ``port`` (for example because
            it is already in use).
    """
    # Results live on the handler class, so clear whatever an earlier flow in
    # this process left behind; otherwise a stale code would be returned
    # without waiting for the new callback.
    OAuthCallbackHandler.auth_code = None
    OAuthCallbackHandler.state = None
    OAuthCallbackHandler.error = None

    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout

    # The context manager closes the listening socket on every exit path.
    with HTTPServer(("localhost", port), OAuthCallbackHandler) as server:
        server.timeout = 1  # handle_request() gives up after 1s so the deadline is re-checked
        while time.monotonic() < deadline:
            server.handle_request()
            # Check right after each request so a callback handled in the
            # final iteration is not misreported as a timeout, and stop on an
            # error callback as well as on a successful one.
            if OAuthCallbackHandler.auth_code or OAuthCallbackHandler.error:
                return (
                    OAuthCallbackHandler.auth_code,
                    OAuthCallbackHandler.state,
                    OAuthCallbackHandler.error,
                )
    return None, None, "timeout"
def exchange_code_for_tokens(
    code: str, redirect_uri: str, client_id: str, client_secret: str | None
) -> dict[str, Any] | None:
    """Trade an authorization code for tokens at the OAuth token endpoint.

    Args:
        code: Authorization code received on the redirect.
        redirect_uri: The redirect URI used in the authorization request.
        client_id: OAuth client ID.
        client_secret: OAuth client secret; left out of the request when falsy.

    Returns:
        The decoded JSON response, or ``None`` if the request or decoding
        failed (the failure is logged at error level).
    """
    form_fields = [
        ("code", code),
        ("client_id", client_id),
        ("redirect_uri", redirect_uri),
        ("grant_type", "authorization_code"),
    ]
    if client_secret:
        form_fields.append(("client_secret", client_secret))

    token_request = urllib.request.Request(
        _OAUTH_TOKEN_URL,
        data=urllib.parse.urlencode(form_fields).encode(),
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        method="POST",
    )
    try:
        with urllib.request.urlopen(token_request, timeout=30) as response:
            payload = response.read()
        return json.loads(payload)
    except Exception as e:
        logger.error(f"Token exchange failed: {e}")
        return None
def get_user_email(access_token: str) -> str | None:
    """Look up the e-mail address belonging to ``access_token``.

    Queries Google's userinfo endpoint with the token as a bearer credential.

    Returns:
        The ``email`` field of the response, or ``None`` when the field is
        absent or the request/decoding raised.
    """
    userinfo_request = urllib.request.Request(
        "https://www.googleapis.com/oauth2/v2/userinfo",
        headers={"Authorization": f"Bearer {access_token}"},
    )
    try:
        with urllib.request.urlopen(userinfo_request, timeout=10) as response:
            profile = json.loads(response.read())
        return profile.get("email")
    except Exception:
        return None
def load_accounts() -> dict[str, Any]:
    """Load the stored accounts document from ``_ACCOUNTS_FILE``.

    Returns:
        The parsed document, always a dict whose ``"accounts"`` value is a
        list. A fresh ``{"schemaVersion": 4, "accounts": []}`` document is
        returned when the file is missing, unreadable, not valid JSON, or not
        a JSON object.
    """
    empty: dict[str, Any] = {"schemaVersion": 4, "accounts": []}
    try:
        with open(_ACCOUNTS_FILE, encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return empty  # First run: nothing stored yet.
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError. Warn
        # rather than stay silent: the next save will overwrite this file.
        logger.warning(f"Could not read {_ACCOUNTS_FILE} ({exc}); starting with no accounts")
        return empty

    # Callers use dict methods on the result and iterate "accounts"; valid
    # JSON of another shape (a list, null, ...) would crash them.
    if not isinstance(data, dict):
        logger.warning(f"{_ACCOUNTS_FILE} does not contain a JSON object; starting with no accounts")
        return empty
    if not isinstance(data.get("accounts"), list):
        data["accounts"] = []
    return data
def save_accounts(data: dict[str, Any]) -> None:
    """Write the accounts document to ``_ACCOUNTS_FILE``.

    The document is serialized to a temporary file next to the target and
    then renamed over it, so a failure while serializing or writing leaves
    the previously stored credentials intact. The temporary file is created
    with owner-only permissions because the document holds OAuth tokens.

    Args:
        data: The accounts document to persist.

    Raises:
        OSError: If the directory or file cannot be written.
        TypeError: If ``data`` contains values ``json`` cannot serialize.
    """
    _ACCOUNTS_FILE.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = _ACCOUNTS_FILE.with_name(f"{_ACCOUNTS_FILE.name}.tmp")

    # 0o600: readable/writable by the owner only (subject to platform support).
    fd = os.open(tmp_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
        # Swap the complete file into place in a single rename.
        os.replace(tmp_path, _ACCOUNTS_FILE)
    except BaseException:
        # Do not leave a half-written temp file behind; the original
        # accounts file has not been touched at this point.
        tmp_path.unlink(missing_ok=True)
        raise
    logger.info(f"Saved credentials to {_ACCOUNTS_FILE}")
def validate_credentials(access_token: str, project_id: str = _DEFAULT_PROJECT_ID) -> bool:
    """Probe the Antigravity endpoint with a tiny generateContent request.

    Args:
        access_token: Bearer token to test.
        project_id: Project to send in the request body.

    Returns:
        ``True`` when the request succeeds and the response parses as JSON,
        ``False`` when anything in that sequence raises.
    """
    endpoint = "https://daily-cloudcode-pa.sandbox.googleapis.com"
    user_agent = (
        "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Antigravity/1.18.3"
    )
    prompt = {
        "contents": [{"role": "user", "parts": [{"text": "hi"}]}],
        "generationConfig": {"maxOutputTokens": 10},
    }
    body = {
        "project": project_id,
        "model": "gemini-3-flash",
        "request": prompt,
        "requestType": "agent",
        "userAgent": "antigravity",
        "requestId": "validation-test",
    }
    headers = {
        "Authorization": f"Bearer {access_token}",
        "Content-Type": "application/json",
        "User-Agent": user_agent,
        "X-Goog-Api-Client": "google-cloud-sdk vscode_cloudshelleditor/0.1",
    }
    try:
        probe = urllib.request.Request(
            f"{endpoint}/v1internal:generateContent",
            data=json.dumps(body).encode("utf-8"),
            headers=headers,
            method="POST",
        )
        with urllib.request.urlopen(probe, timeout=30) as response:
            # The parsed value is discarded; parsing only confirms a JSON reply.
            json.loads(response.read())
    except Exception:
        return False
    return True
def refresh_access_token(refresh_token: str, client_id: str, client_secret: str | None) -> dict | None:
    """Request a new access token with a refresh-token grant.

    Args:
        refresh_token: The stored refresh token.
        client_id: OAuth client ID.
        client_secret: OAuth client secret; left out of the request when falsy.

    Returns:
        The decoded JSON response, or ``None`` if the request or decoding
        failed (the failure is logged at debug level).
    """
    form_fields = [
        ("grant_type", "refresh_token"),
        ("refresh_token", refresh_token),
        ("client_id", client_id),
    ]
    if client_secret:
        form_fields.append(("client_secret", client_secret))

    token_request = urllib.request.Request(
        _OAUTH_TOKEN_URL,
        data=urllib.parse.urlencode(form_fields).encode(),
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        method="POST",
    )
    try:
        with urllib.request.urlopen(token_request, timeout=30) as response:
            payload = response.read()
        return json.loads(payload)
    except Exception as e:
        logger.debug(f"Token refresh failed: {e}")
        return None
def _split_refresh_entry(entry: Any) -> tuple[str | None, str]:
    """Split a stored ``"<refresh_token>|<project_id>"`` entry into its parts."""
    # A missing, null or non-string entry means "no refresh token".
    if not entry or not isinstance(entry, str):
        return None, _DEFAULT_PROJECT_ID
    parts = entry.split("|")
    project_id = parts[1] if len(parts) > 1 else _DEFAULT_PROJECT_ID
    return parts[0] or None, project_id


def _reuse_existing_account(accounts_data: dict[str, Any], client_id: str, client_secret: str | None) -> bool:
    """Return True when a stored account works (possibly after a refresh)."""
    accounts = accounts_data.get("accounts", [])
    if not accounts:
        return False

    # Prefer the first account that is not explicitly disabled.
    account = next((a for a in accounts if a.get("enabled", True) is not False), accounts[0])
    access_token = account.get("access")
    refresh_token, project_id = _split_refresh_entry(account.get("refresh"))
    email = account.get("email", "unknown")
    expires_ms = account.get("expires", 0)  # stored as epoch milliseconds
    expires_at = expires_ms / 1000.0 if expires_ms else 0.0

    # Treat tokens within 60 seconds of expiry as already expired.
    if access_token and expires_at and time.time() < expires_at - 60:
        logger.info(f"Found existing credentials for: {email}")
        logger.info("Validating existing credentials...")
        if validate_credentials(access_token, project_id):
            logger.info("✓ Credentials valid! Skipping OAuth.")
            return True
        logger.info("Credentials failed validation, refreshing...")
        # Actually perform the refresh announced above (when possible)
        # instead of falling straight through to a full OAuth flow.
        if not refresh_token:
            return False
    elif refresh_token:
        logger.info(f"Found expired credentials for: {email}")
        logger.info("Attempting token refresh...")
    else:
        return False

    tokens = refresh_access_token(refresh_token, client_id, client_secret)
    if not tokens:
        logger.info("Token refresh failed, proceeding with OAuth...")
        return False
    new_access = tokens.get("access_token")
    if not new_access:
        return False

    # Persist the refreshed token before validating it.
    expires_in = tokens.get("expires_in", 3600)
    account["access"] = new_access
    account["expires"] = int((time.time() + expires_in) * 1000)
    accounts_data["last_refresh"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    save_accounts(accounts_data)

    logger.info("Validating refreshed credentials...")
    if validate_credentials(new_access, project_id):
        logger.info("✓ Credentials refreshed and validated!")
        return True
    logger.info("Refreshed token failed validation, proceeding with OAuth...")
    return False


def _authorize_in_browser(client_id: str, client_secret: str | None) -> dict[str, Any] | None:
    """Run the browser consent flow; return the token response or None on failure."""
    # Use fixed port and path matching Google's expected OAuth redirect URI
    port = _DEFAULT_REDIRECT_PORT
    redirect_uri = f"http://localhost:{port}/oauth-callback"

    # Generate state for CSRF protection
    state = secrets.token_urlsafe(16)

    params = {
        "client_id": client_id,
        "redirect_uri": redirect_uri,
        "response_type": "code",
        "scope": " ".join(_OAUTH_SCOPES),
        "state": state,
        "access_type": "offline",  # ask for a refresh token
        "prompt": "consent",
    }
    auth_url = f"{_OAUTH_AUTH_URL}?{urllib.parse.urlencode(params)}"

    logger.info("Opening browser for authentication...")
    logger.info(f"If the browser doesn't open, visit: {auth_url}\n")
    webbrowser.open(auth_url)

    logger.info(f"Listening for callback on port {port}...")
    code, received_state, error = wait_for_callback(port)
    if error:
        logger.error(f"Authentication failed: {error}")
        return None
    if not code:
        logger.error("No authorization code received")
        return None
    if received_state != state:
        logger.error("State mismatch - possible CSRF attack")
        return None

    logger.info("Exchanging authorization code for tokens...")
    return exchange_code_for_tokens(code, redirect_uri, client_id, client_secret)


def cmd_account_add(args: argparse.Namespace) -> int:
    """Add a new Antigravity account via OAuth2.

    Stored credentials are tried first: an unexpired access token is
    validated, and an expired (or failing) one is refreshed and validated.
    Only when that does not produce working credentials is the browser OAuth
    flow started; its result is stored under the account's e-mail address,
    replacing an existing entry with the same address.

    Args:
        args: Parsed command-line arguments (not used by this command).

    Returns:
        ``0`` when working credentials exist or were stored, ``1`` on failure.
    """
    try:
        client_id = get_client_id()
    except RuntimeError as exc:
        # Report a missing client ID like every other failure of this
        # command instead of letting it escape as a traceback.
        logger.error(f"Authentication failed: {exc}")
        return 1
    client_secret = get_client_secret()

    # Check if credentials already exist
    if _reuse_existing_account(load_accounts(), client_id, client_secret):
        return 0

    # No valid credentials, proceed with OAuth
    if not client_secret:
        logger.warning(
            "No client secret configured. Token refresh may fail.\n"
            "Set ANTIGRAVITY_CLIENT_SECRET env var or add "
            "'antigravity_client_secret' to ~/.hive/configuration.json"
        )

    tokens = _authorize_in_browser(client_id, client_secret)
    if not tokens:
        return 1

    access_token = tokens.get("access_token")
    refresh_token = tokens.get("refresh_token")
    expires_in = tokens.get("expires_in", 3600)
    if not access_token:
        logger.error("No access token in response")
        return 1

    # Get user email
    email = get_user_email(access_token)
    if email:
        logger.info(f"Authenticated as: {email}")

    # Load existing accounts and add/update. Reloaded here because the reuse
    # attempt above may have saved a refreshed token.
    accounts_data = load_accounts()
    accounts = accounts_data.get("accounts", [])

    if not refresh_token:
        logger.warning("No refresh token in response; this account cannot be refreshed automatically.")

    # Build new account entry (V4 schema). A missing refresh token is stored
    # as an empty string, never as the text "None", so it is later read back
    # as "no refresh token".
    expires_ms = int((time.time() + expires_in) * 1000)
    refresh_entry = f"{refresh_token or ''}|{_DEFAULT_PROJECT_ID}"
    new_account = {
        "access": access_token,
        "refresh": refresh_entry,
        "expires": expires_ms,
        "email": email,
        "enabled": True,
    }

    # Update existing account or add new one
    existing_idx = next((i for i, a in enumerate(accounts) if a.get("email") == email), None)
    if existing_idx is not None:
        accounts[existing_idx] = new_account
        logger.info(f"Updated existing account: {email}")
    else:
        accounts.append(new_account)
        logger.info(f"Added new account: {email}")

    accounts_data["accounts"] = accounts
    accounts_data["schemaVersion"] = 4
    accounts_data["last_refresh"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    save_accounts(accounts_data)
    logger.info("\n✓ Authentication complete!")
    return 0
def cmd_account_list(args: argparse.Namespace) -> int:
    """Log every stored account with its position and enabled/disabled status.

    Args:
        args: Parsed command-line arguments (not used by this command).

    Returns:
        Always ``0``; having no accounts is reported, not treated as an error.
    """
    stored = load_accounts().get("accounts", [])
    if not stored:
        logger.info("No accounts configured.")
        logger.info("Run 'antigravity auth account add' to add one.")
        return 0

    logger.info("Configured accounts:\n")
    position = 0
    for entry in stored:
        position += 1
        # An account without an "enabled" key counts as enabled.
        status = "disabled" if not entry.get("enabled", True) else "enabled"
        address = entry.get("email", "unknown")
        logger.info(f" {position}. {address} ({status})")
    return 0
def cmd_account_remove(args: argparse.Namespace) -> int:
    """Delete every stored account whose e-mail equals ``args.email``.

    Args:
        args: Parsed command-line arguments; ``args.email`` names the account.

    Returns:
        ``0`` after the remaining accounts were saved, ``1`` when no stored
        account has that e-mail (nothing is written in that case).
    """
    target = args.email
    store = load_accounts()
    before = store.get("accounts", [])
    remaining = list(filter(lambda entry: entry.get("email") != target, before))

    # Same length means the filter dropped nothing.
    if len(before) == len(remaining):
        logger.error(f"No account found with email: {target}")
        return 1

    store["accounts"] = remaining
    save_accounts(store)
    logger.info(f"Removed account: {target}")
    return 0
def main() -> int:
    """Parse ``auth account <add|list|remove>`` and run the chosen command.

    Returns:
        The exit code of the selected command, or ``0`` after printing the
        top-level help when no complete command was given.
    """
    parser = argparse.ArgumentParser(
        description="Antigravity authentication CLI",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    top_level = parser.add_subparsers(dest="command", help="Commands")

    # auth -> account -> {add, list, remove}
    auth_cmd = top_level.add_parser("auth", help="Authentication commands")
    auth_level = auth_cmd.add_subparsers(dest="auth_command")
    account_cmd = auth_level.add_parser("account", help="Account management")
    account_level = account_cmd.add_subparsers(dest="account_command")

    leaf_commands = (
        ("add", "Add a new account via OAuth2", cmd_account_add),
        ("list", "List configured accounts", cmd_account_list),
        ("remove", "Remove an account", cmd_account_remove),
    )
    for command_name, help_text, handler in leaf_commands:
        leaf = account_level.add_parser(command_name, help=help_text)
        if command_name == "remove":
            leaf.add_argument("email", help="Email of account to remove")
        leaf.set_defaults(func=handler)

    args = parser.parse_args()
    selected = getattr(args, "func", None)
    if selected is None:
        # No leaf command was chosen, so no handler was attached.
        parser.print_help()
        return 0
    return selected(args)


if __name__ == "__main__":
    sys.exit(main())