# hive/core/framework/credentials/oauth2/lifecycle.py
# Snapshot: 2026-01-27 16:58:06 -08:00, 364 lines, 11 KiB, Python

"""
Token lifecycle management for OAuth2 credentials.

This module provides the TokenLifecycleManager which coordinates
automatic token refresh with the credential store.
"""

from __future__ import annotations

import asyncio
import logging
from collections.abc import Callable
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING

from pydantic import SecretStr

from ..models import CredentialKey, CredentialObject, CredentialType
from .base_provider import BaseOAuth2Provider
from .provider import OAuth2Token

if TYPE_CHECKING:
    # Needed for annotations only; never imported at runtime.
    # NOTE(review): presumably guarded to avoid a circular import with the
    # store module -- confirm.
    from ..store import CredentialStore

# Module-level logger shared by everything defined in this file.
logger = logging.getLogger(__name__)


@dataclass
class TokenRefreshResult:
    """Result of a token refresh operation.

    Field order is part of the positional constructor signature and must
    not be changed.
    """

    # True when a new token was obtained.
    success: bool
    # The refreshed token; left as None when the refresh did not succeed.
    token: OAuth2Token | None = None
    # Human-readable failure description; None on success.
    error: str | None = None
    # True when refreshing cannot work again without a new authorization
    # (no refresh token stored, or the provider reported "invalid_grant").
    needs_reauthorization: bool = False


class TokenLifecycleManager:
"""
Manages the complete lifecycle of OAuth2 tokens.
Responsibilities:
- Coordinate with CredentialStore for persistence
- Automatically refresh expired tokens
- Handle refresh failures gracefully
- Provide callbacks for monitoring
This class is useful when you need more control over token management
than the basic auto-refresh in CredentialStore provides.
Usage:
manager = TokenLifecycleManager(
provider=github_provider,
credential_id="github_oauth",
store=credential_store,
)
# Get valid token (auto-refreshes if needed)
token = await manager.get_valid_token()
# Use token
headers = provider.format_for_request(token)
Synchronous usage:
# For synchronous code, use sync_ methods
token = manager.sync_get_valid_token()
"""
def __init__(
self,
provider: BaseOAuth2Provider,
credential_id: str,
store: CredentialStore,
refresh_buffer_minutes: int = 5,
on_token_refreshed: Callable[[OAuth2Token], None] | None = None,
on_refresh_failed: Callable[[str], None] | None = None,
):
"""
Initialize the lifecycle manager.
Args:
provider: OAuth2 provider for token operations
credential_id: ID of the credential in the store
store: Credential store for persistence
refresh_buffer_minutes: Minutes before expiry to trigger refresh
on_token_refreshed: Callback when token is refreshed
on_refresh_failed: Callback when refresh fails
"""
self.provider = provider
self.credential_id = credential_id
self.store = store
self.refresh_buffer = timedelta(minutes=refresh_buffer_minutes)
self.on_token_refreshed = on_token_refreshed
self.on_refresh_failed = on_refresh_failed
# In-memory cache for performance
self._cached_token: OAuth2Token | None = None
self._cache_time: datetime | None = None
# --- Async Token Access ---
async def get_valid_token(self) -> OAuth2Token | None:
"""
Get a valid access token, refreshing if necessary.
This is the main entry point for async code.
Returns:
Valid OAuth2Token or None if unavailable
"""
# Check cache first
if self._cached_token and not self._needs_refresh(self._cached_token):
return self._cached_token
# Load from store
credential = self.store.get_credential(self.credential_id, refresh_if_needed=False)
if credential is None:
return None
# Convert to OAuth2Token
token = self._credential_to_token(credential)
if token is None:
return None
# Refresh if needed
if self._needs_refresh(token):
result = await self._async_refresh_token(credential)
if result.success and result.token:
token = result.token
elif result.needs_reauthorization:
logger.warning(f"Token for {self.credential_id} needs reauthorization")
return None
else:
# Use existing token if still technically valid
if token.is_expired:
return None
logger.warning(f"Refresh failed for {self.credential_id}, using existing token")
self._cached_token = token
self._cache_time = datetime.now(UTC)
return token
async def acquire_token_client_credentials(
self,
scopes: list[str] | None = None,
) -> OAuth2Token:
"""
Acquire a new token using client credentials flow.
For service-to-service authentication.
Args:
scopes: Scopes to request
Returns:
New OAuth2Token
"""
# Run in executor to avoid blocking
loop = asyncio.get_event_loop()
token = await loop.run_in_executor(
None, lambda: self.provider.client_credentials_grant(scopes=scopes)
)
self._save_token_to_store(token)
self._cached_token = token
return token
async def revoke(self) -> bool:
"""
Revoke tokens and clear from store.
Returns:
True if revocation succeeded
"""
credential = self.store.get_credential(self.credential_id, refresh_if_needed=False)
if credential:
self.provider.revoke(credential)
self.store.delete_credential(self.credential_id)
self._cached_token = None
return True
# --- Synchronous Token Access ---
def sync_get_valid_token(self) -> OAuth2Token | None:
"""
Synchronous version of get_valid_token().
For use in synchronous code.
"""
# Check cache
if self._cached_token and not self._needs_refresh(self._cached_token):
return self._cached_token
# Load from store
credential = self.store.get_credential(self.credential_id, refresh_if_needed=False)
if credential is None:
return None
token = self._credential_to_token(credential)
if token is None:
return None
# Refresh if needed
if self._needs_refresh(token):
result = self._sync_refresh_token(credential)
if result.success and result.token:
token = result.token
elif result.needs_reauthorization:
logger.warning(f"Token for {self.credential_id} needs reauthorization")
return None
else:
if token.is_expired:
return None
self._cached_token = token
self._cache_time = datetime.now(UTC)
return token
def sync_acquire_token_client_credentials(
self,
scopes: list[str] | None = None,
) -> OAuth2Token:
"""Synchronous version of acquire_token_client_credentials()."""
token = self.provider.client_credentials_grant(scopes=scopes)
self._save_token_to_store(token)
self._cached_token = token
return token
# --- Helper Methods ---
def _needs_refresh(self, token: OAuth2Token) -> bool:
"""Check if token needs refresh."""
if token.expires_at is None:
return False
return datetime.now(UTC) >= (token.expires_at - self.refresh_buffer)
def _credential_to_token(self, credential: CredentialObject) -> OAuth2Token | None:
"""Convert credential to OAuth2Token."""
access_token = credential.get_key("access_token")
if not access_token:
return None
expires_at = None
access_key = credential.keys.get("access_token")
if access_key:
expires_at = access_key.expires_at
return OAuth2Token(
access_token=access_token,
token_type="Bearer",
expires_at=expires_at,
refresh_token=credential.get_key("refresh_token"),
scope=credential.get_key("scope"),
)
def _save_token_to_store(self, token: OAuth2Token) -> None:
"""Save token to credential store."""
credential = CredentialObject(
id=self.credential_id,
credential_type=CredentialType.OAUTH2,
keys={
"access_token": CredentialKey(
name="access_token",
value=SecretStr(token.access_token),
expires_at=token.expires_at,
),
},
provider_id=self.provider.provider_id,
auto_refresh=True,
)
if token.refresh_token:
credential.keys["refresh_token"] = CredentialKey(
name="refresh_token",
value=SecretStr(token.refresh_token),
)
if token.scope:
credential.keys["scope"] = CredentialKey(
name="scope",
value=SecretStr(token.scope),
)
self.store.save_credential(credential)
async def _async_refresh_token(self, credential: CredentialObject) -> TokenRefreshResult:
"""Async wrapper for token refresh."""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, lambda: self._sync_refresh_token(credential))
def _sync_refresh_token(self, credential: CredentialObject) -> TokenRefreshResult:
"""Synchronously refresh token."""
refresh_token = credential.get_key("refresh_token")
if not refresh_token:
return TokenRefreshResult(
success=False,
error="No refresh token available",
needs_reauthorization=True,
)
try:
new_token = self.provider.refresh_access_token(refresh_token)
# Save to store
self._save_token_to_store(new_token)
# Notify callback
if self.on_token_refreshed:
self.on_token_refreshed(new_token)
logger.info(f"Token refreshed for {self.credential_id}")
return TokenRefreshResult(success=True, token=new_token)
except Exception as e:
error_msg = str(e)
# Check for refresh token revocation
if "invalid_grant" in error_msg.lower():
return TokenRefreshResult(
success=False,
error=error_msg,
needs_reauthorization=True,
)
if self.on_refresh_failed:
self.on_refresh_failed(error_msg)
logger.error(f"Token refresh failed for {self.credential_id}: {e}")
return TokenRefreshResult(success=False, error=error_msg)
def invalidate_cache(self) -> None:
"""Clear cached token."""
self._cached_token = None
self._cache_time = None
# --- Convenience Methods ---
def get_request_headers(self) -> dict[str, str]:
"""
Get headers for HTTP request with current token.
Returns empty dict if no valid token.
"""
token = self.sync_get_valid_token()
if token is None:
return {}
result = self.provider.format_for_request(token)
return result.get("headers", {})
def get_request_kwargs(self) -> dict:
"""
Get kwargs for HTTP request (headers, params, etc.).
Returns empty dict if no valid token.
"""
token = self.sync_get_valid_token()
if token is None:
return {}
return self.provider.format_for_request(token)