589c5b06fe
- Auto-fixed 70 lint errors (import sorting, aliased errors, datetime.UTC)
- Fixed 85 remaining errors manually:
- E501: wrapped long lines in queen_profiles, catalog, routes_credentials
- F821: added missing TYPE_CHECKING imports for AgentHost, ToolRegistry,
HookContext, HookResult; added runtime imports where needed
- F811: removed duplicate method definitions in queen_lifecycle_tools
- F841/B007: removed unused variables in discovery.py
- W291: removed trailing whitespace in queen nodes
- E402: moved import to top of queen_memory_v2.py
- Fixed AgentRuntime -> AgentHost in example template type annotations
- Reformatted 343 files with ruff format
362 lines
11 KiB
Python
362 lines
11 KiB
Python
"""
|
|
Token lifecycle management for OAuth2 credentials.
|
|
|
|
This module provides the TokenLifecycleManager which coordinates
|
|
automatic token refresh with the credential store.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
from collections.abc import Callable
|
|
from dataclasses import dataclass
|
|
from datetime import UTC, datetime, timedelta
|
|
from typing import TYPE_CHECKING
|
|
|
|
from pydantic import SecretStr
|
|
|
|
from ..models import CredentialKey, CredentialObject, CredentialType
|
|
from .base_provider import BaseOAuth2Provider
|
|
from .provider import OAuth2Token
|
|
|
|
if TYPE_CHECKING:
|
|
from ..store import CredentialStore
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass()
class TokenRefreshResult:
    """Outcome of a single attempt to refresh an OAuth2 access token.

    ``success`` tells whether a new token was obtained, in which case the
    refresh path stores it in ``token``. On failure, ``error`` carries a
    human-readable reason and ``needs_reauthorization`` is raised when the
    refresh token is missing or was rejected, i.e. retrying cannot help.
    """

    # Whether the refresh produced a new token.
    success: bool
    # The newly issued token; left as None when the refresh failed.
    token: OAuth2Token | None = None
    # Failure description (for example the provider's error text).
    error: str | None = None
    # True when the user must go through authorization again.
    needs_reauthorization: bool = False
|
|
|
|
|
|
class TokenLifecycleManager:
|
|
"""
|
|
Manages the complete lifecycle of OAuth2 tokens.
|
|
|
|
Responsibilities:
|
|
- Coordinate with CredentialStore for persistence
|
|
- Automatically refresh expired tokens
|
|
- Handle refresh failures gracefully
|
|
- Provide callbacks for monitoring
|
|
|
|
This class is useful when you need more control over token management
|
|
than the basic auto-refresh in CredentialStore provides.
|
|
|
|
Usage:
|
|
manager = TokenLifecycleManager(
|
|
provider=github_provider,
|
|
credential_id="github_oauth",
|
|
store=credential_store,
|
|
)
|
|
|
|
# Get valid token (auto-refreshes if needed)
|
|
token = await manager.get_valid_token()
|
|
|
|
# Use token
|
|
headers = provider.format_for_request(token)
|
|
|
|
Synchronous usage:
|
|
# For synchronous code, use sync_ methods
|
|
token = manager.sync_get_valid_token()
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
provider: BaseOAuth2Provider,
|
|
credential_id: str,
|
|
store: CredentialStore,
|
|
refresh_buffer_minutes: int = 5,
|
|
on_token_refreshed: Callable[[OAuth2Token], None] | None = None,
|
|
on_refresh_failed: Callable[[str], None] | None = None,
|
|
):
|
|
"""
|
|
Initialize the lifecycle manager.
|
|
|
|
Args:
|
|
provider: OAuth2 provider for token operations
|
|
credential_id: ID of the credential in the store
|
|
store: Credential store for persistence
|
|
refresh_buffer_minutes: Minutes before expiry to trigger refresh
|
|
on_token_refreshed: Callback when token is refreshed
|
|
on_refresh_failed: Callback when refresh fails
|
|
"""
|
|
self.provider = provider
|
|
self.credential_id = credential_id
|
|
self.store = store
|
|
self.refresh_buffer = timedelta(minutes=refresh_buffer_minutes)
|
|
self.on_token_refreshed = on_token_refreshed
|
|
self.on_refresh_failed = on_refresh_failed
|
|
|
|
# In-memory cache for performance
|
|
self._cached_token: OAuth2Token | None = None
|
|
self._cache_time: datetime | None = None
|
|
|
|
# --- Async Token Access ---
|
|
|
|
async def get_valid_token(self) -> OAuth2Token | None:
|
|
"""
|
|
Get a valid access token, refreshing if necessary.
|
|
|
|
This is the main entry point for async code.
|
|
|
|
Returns:
|
|
Valid OAuth2Token or None if unavailable
|
|
"""
|
|
# Check cache first
|
|
if self._cached_token and not self._needs_refresh(self._cached_token):
|
|
return self._cached_token
|
|
|
|
# Load from store
|
|
credential = self.store.get_credential(self.credential_id, refresh_if_needed=False)
|
|
if credential is None:
|
|
return None
|
|
|
|
# Convert to OAuth2Token
|
|
token = self._credential_to_token(credential)
|
|
if token is None:
|
|
return None
|
|
|
|
# Refresh if needed
|
|
if self._needs_refresh(token):
|
|
result = await self._async_refresh_token(credential)
|
|
if result.success and result.token:
|
|
token = result.token
|
|
elif result.needs_reauthorization:
|
|
logger.warning(f"Token for {self.credential_id} needs reauthorization")
|
|
return None
|
|
else:
|
|
# Use existing token if still technically valid
|
|
if token.is_expired:
|
|
return None
|
|
logger.warning(f"Refresh failed for {self.credential_id}, using existing token")
|
|
|
|
self._cached_token = token
|
|
self._cache_time = datetime.now(UTC)
|
|
return token
|
|
|
|
async def acquire_token_client_credentials(
|
|
self,
|
|
scopes: list[str] | None = None,
|
|
) -> OAuth2Token:
|
|
"""
|
|
Acquire a new token using client credentials flow.
|
|
|
|
For service-to-service authentication.
|
|
|
|
Args:
|
|
scopes: Scopes to request
|
|
|
|
Returns:
|
|
New OAuth2Token
|
|
"""
|
|
# Run in executor to avoid blocking
|
|
loop = asyncio.get_event_loop()
|
|
token = await loop.run_in_executor(None, lambda: self.provider.client_credentials_grant(scopes=scopes))
|
|
|
|
self._save_token_to_store(token)
|
|
self._cached_token = token
|
|
return token
|
|
|
|
async def revoke(self) -> bool:
|
|
"""
|
|
Revoke tokens and clear from store.
|
|
|
|
Returns:
|
|
True if revocation succeeded
|
|
"""
|
|
credential = self.store.get_credential(self.credential_id, refresh_if_needed=False)
|
|
if credential:
|
|
self.provider.revoke(credential)
|
|
|
|
self.store.delete_credential(self.credential_id)
|
|
self._cached_token = None
|
|
return True
|
|
|
|
# --- Synchronous Token Access ---
|
|
|
|
def sync_get_valid_token(self) -> OAuth2Token | None:
|
|
"""
|
|
Synchronous version of get_valid_token().
|
|
|
|
For use in synchronous code.
|
|
"""
|
|
# Check cache
|
|
if self._cached_token and not self._needs_refresh(self._cached_token):
|
|
return self._cached_token
|
|
|
|
# Load from store
|
|
credential = self.store.get_credential(self.credential_id, refresh_if_needed=False)
|
|
if credential is None:
|
|
return None
|
|
|
|
token = self._credential_to_token(credential)
|
|
if token is None:
|
|
return None
|
|
|
|
# Refresh if needed
|
|
if self._needs_refresh(token):
|
|
result = self._sync_refresh_token(credential)
|
|
if result.success and result.token:
|
|
token = result.token
|
|
elif result.needs_reauthorization:
|
|
logger.warning(f"Token for {self.credential_id} needs reauthorization")
|
|
return None
|
|
else:
|
|
if token.is_expired:
|
|
return None
|
|
|
|
self._cached_token = token
|
|
self._cache_time = datetime.now(UTC)
|
|
return token
|
|
|
|
def sync_acquire_token_client_credentials(
|
|
self,
|
|
scopes: list[str] | None = None,
|
|
) -> OAuth2Token:
|
|
"""Synchronous version of acquire_token_client_credentials()."""
|
|
token = self.provider.client_credentials_grant(scopes=scopes)
|
|
self._save_token_to_store(token)
|
|
self._cached_token = token
|
|
return token
|
|
|
|
# --- Helper Methods ---
|
|
|
|
def _needs_refresh(self, token: OAuth2Token) -> bool:
|
|
"""Check if token needs refresh."""
|
|
if token.expires_at is None:
|
|
return False
|
|
return datetime.now(UTC) >= (token.expires_at - self.refresh_buffer)
|
|
|
|
def _credential_to_token(self, credential: CredentialObject) -> OAuth2Token | None:
|
|
"""Convert credential to OAuth2Token."""
|
|
access_token = credential.get_key("access_token")
|
|
if not access_token:
|
|
return None
|
|
|
|
expires_at = None
|
|
access_key = credential.keys.get("access_token")
|
|
if access_key:
|
|
expires_at = access_key.expires_at
|
|
|
|
return OAuth2Token(
|
|
access_token=access_token,
|
|
token_type="Bearer",
|
|
expires_at=expires_at,
|
|
refresh_token=credential.get_key("refresh_token"),
|
|
scope=credential.get_key("scope"),
|
|
)
|
|
|
|
def _save_token_to_store(self, token: OAuth2Token) -> None:
|
|
"""Save token to credential store."""
|
|
credential = CredentialObject(
|
|
id=self.credential_id,
|
|
credential_type=CredentialType.OAUTH2,
|
|
keys={
|
|
"access_token": CredentialKey(
|
|
name="access_token",
|
|
value=SecretStr(token.access_token),
|
|
expires_at=token.expires_at,
|
|
),
|
|
},
|
|
provider_id=self.provider.provider_id,
|
|
auto_refresh=True,
|
|
)
|
|
|
|
if token.refresh_token:
|
|
credential.keys["refresh_token"] = CredentialKey(
|
|
name="refresh_token",
|
|
value=SecretStr(token.refresh_token),
|
|
)
|
|
|
|
if token.scope:
|
|
credential.keys["scope"] = CredentialKey(
|
|
name="scope",
|
|
value=SecretStr(token.scope),
|
|
)
|
|
|
|
self.store.save_credential(credential)
|
|
|
|
async def _async_refresh_token(self, credential: CredentialObject) -> TokenRefreshResult:
|
|
"""Async wrapper for token refresh."""
|
|
loop = asyncio.get_event_loop()
|
|
return await loop.run_in_executor(None, lambda: self._sync_refresh_token(credential))
|
|
|
|
def _sync_refresh_token(self, credential: CredentialObject) -> TokenRefreshResult:
|
|
"""Synchronously refresh token."""
|
|
refresh_token = credential.get_key("refresh_token")
|
|
if not refresh_token:
|
|
return TokenRefreshResult(
|
|
success=False,
|
|
error="No refresh token available",
|
|
needs_reauthorization=True,
|
|
)
|
|
|
|
try:
|
|
new_token = self.provider.refresh_access_token(refresh_token)
|
|
|
|
# Save to store
|
|
self._save_token_to_store(new_token)
|
|
|
|
# Notify callback
|
|
if self.on_token_refreshed:
|
|
self.on_token_refreshed(new_token)
|
|
|
|
logger.info(f"Token refreshed for {self.credential_id}")
|
|
return TokenRefreshResult(success=True, token=new_token)
|
|
|
|
except Exception as e:
|
|
error_msg = str(e)
|
|
|
|
# Check for refresh token revocation
|
|
if "invalid_grant" in error_msg.lower():
|
|
return TokenRefreshResult(
|
|
success=False,
|
|
error=error_msg,
|
|
needs_reauthorization=True,
|
|
)
|
|
|
|
if self.on_refresh_failed:
|
|
self.on_refresh_failed(error_msg)
|
|
|
|
logger.error(f"Token refresh failed for {self.credential_id}: {e}")
|
|
return TokenRefreshResult(success=False, error=error_msg)
|
|
|
|
def invalidate_cache(self) -> None:
|
|
"""Clear cached token."""
|
|
self._cached_token = None
|
|
self._cache_time = None
|
|
|
|
# --- Convenience Methods ---
|
|
|
|
def get_request_headers(self) -> dict[str, str]:
|
|
"""
|
|
Get headers for HTTP request with current token.
|
|
|
|
Returns empty dict if no valid token.
|
|
"""
|
|
token = self.sync_get_valid_token()
|
|
if token is None:
|
|
return {}
|
|
|
|
result = self.provider.format_for_request(token)
|
|
return result.get("headers", {})
|
|
|
|
def get_request_kwargs(self) -> dict:
|
|
"""
|
|
Get kwargs for HTTP request (headers, params, etc.).
|
|
|
|
Returns empty dict if no valid token.
|
|
"""
|
|
token = self.sync_get_valid_token()
|
|
if token is None:
|
|
return {}
|
|
|
|
return self.provider.format_for_request(token)
|