Merge pull request #1532 from TimothyZhang7/main

chore: fix lint issues
This commit is contained in:
Timothy @aden
2026-01-27 17:01:05 -08:00
committed by GitHub
11 changed files with 270 additions and 229 deletions
+25 -24
View File
@@ -7,16 +7,16 @@ containing one or more keys (e.g., api_key, access_token, refresh_token).
from __future__ import annotations
from datetime import datetime, timezone
from datetime import UTC, datetime
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field, SecretStr
def _utc_now() -> datetime:
"""Get current UTC time as timezone-aware datetime."""
return datetime.now(timezone.utc)
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, SecretStr
return datetime.now(UTC)
class CredentialType(str, Enum):
@@ -53,8 +53,8 @@ class CredentialKey(BaseModel):
name: str
value: SecretStr
expires_at: Optional[datetime] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
expires_at: datetime | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
model_config = {"extra": "allow"}
@@ -63,7 +63,7 @@ class CredentialKey(BaseModel):
"""Check if this key has expired."""
if self.expires_at is None:
return False
return datetime.now(timezone.utc) >= self.expires_at
return datetime.now(UTC) >= self.expires_at
def get_secret_value(self) -> str:
"""Get the actual secret value (use sparingly)."""
@@ -98,28 +98,29 @@ class CredentialObject(BaseModel):
id: str = Field(description="Unique identifier (e.g., 'brave_search', 'github_oauth')")
credential_type: CredentialType = CredentialType.API_KEY
keys: Dict[str, CredentialKey] = Field(default_factory=dict)
keys: dict[str, CredentialKey] = Field(default_factory=dict)
# Lifecycle management
provider_id: Optional[str] = Field(
default=None, description="ID of provider responsible for lifecycle (e.g., 'oauth2', 'static')"
provider_id: str | None = Field(
default=None,
description="ID of provider responsible for lifecycle (e.g., 'oauth2', 'static')",
)
last_refreshed: Optional[datetime] = None
last_refreshed: datetime | None = None
auto_refresh: bool = False
# Usage tracking
last_used: Optional[datetime] = None
last_used: datetime | None = None
use_count: int = 0
# Metadata
description: str = ""
tags: List[str] = Field(default_factory=list)
tags: list[str] = Field(default_factory=list)
created_at: datetime = Field(default_factory=_utc_now)
updated_at: datetime = Field(default_factory=_utc_now)
model_config = {"extra": "allow"}
def get_key(self, key_name: str) -> Optional[str]:
def get_key(self, key_name: str) -> str | None:
"""
Get a specific key's value.
@@ -138,8 +139,8 @@ class CredentialObject(BaseModel):
self,
key_name: str,
value: str,
expires_at: Optional[datetime] = None,
metadata: Optional[Dict[str, Any]] = None,
expires_at: datetime | None = None,
metadata: dict[str, Any] | None = None,
) -> None:
"""
Set or update a key.
@@ -156,7 +157,7 @@ class CredentialObject(BaseModel):
expires_at=expires_at,
metadata=metadata or {},
)
self.updated_at = datetime.now(timezone.utc)
self.updated_at = datetime.now(UTC)
def has_key(self, key_name: str) -> bool:
"""Check if a key exists."""
@@ -179,10 +180,10 @@ class CredentialObject(BaseModel):
def record_usage(self) -> None:
"""Record that this credential was used."""
self.last_used = datetime.now(timezone.utc)
self.last_used = datetime.now(UTC)
self.use_count += 1
def get_default_key(self) -> Optional[str]:
def get_default_key(self) -> str | None:
"""
Get the default key value.
@@ -232,18 +233,18 @@ class CredentialUsageSpec(BaseModel):
"""
credential_id: str = Field(description="ID of credential to use (e.g., 'brave_search')")
required_keys: List[str] = Field(default_factory=list, description="Keys that must be present")
required_keys: list[str] = Field(default_factory=list, description="Keys that must be present")
# Injection templates (bipartisan model)
headers: Dict[str, str] = Field(
headers: dict[str, str] = Field(
default_factory=dict,
description="Header templates (e.g., {'Authorization': 'Bearer {{access_token}}'})",
)
query_params: Dict[str, str] = Field(
query_params: dict[str, str] = Field(
default_factory=dict,
description="Query param templates (e.g., {'api_key': '{{api_key}}'})",
)
body_fields: Dict[str, str] = Field(
body_fields: dict[str, str] = Field(
default_factory=dict,
description="Request body field templates",
)
@@ -8,13 +8,18 @@ OAuth2 servers. OSS users can extend this class for custom providers.
from __future__ import annotations
import logging
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional
from datetime import UTC, datetime, timedelta
from typing import Any
from urllib.parse import urlencode
from ..models import CredentialObject, CredentialRefreshError, CredentialType
from ..provider import CredentialProvider
from .provider import OAuth2Config, OAuth2Error, OAuth2Token, RefreshTokenInvalidError, TokenPlacement
from .provider import (
OAuth2Config,
OAuth2Error,
OAuth2Token,
TokenPlacement,
)
logger = logging.getLogger(__name__)
@@ -72,14 +77,14 @@ class BaseOAuth2Provider(CredentialProvider):
"""
self.config = config
self._provider_id = provider_id
self._client: Optional[Any] = None
self._client: Any | None = None
@property
def provider_id(self) -> str:
return self._provider_id
@property
def supported_types(self) -> List[CredentialType]:
def supported_types(self) -> list[CredentialType]:
return [CredentialType.OAUTH2, CredentialType.BEARER_TOKEN]
def _get_client(self) -> Any:
@@ -90,7 +95,9 @@ class BaseOAuth2Provider(CredentialProvider):
self._client = httpx.Client(timeout=self.config.request_timeout)
except ImportError as e:
raise ImportError("OAuth2 provider requires 'httpx'. Install with: pip install httpx") from e
raise ImportError(
"OAuth2 provider requires 'httpx'. Install with: pip install httpx"
) from e
return self._client
def _close_client(self) -> None:
@@ -109,7 +116,7 @@ class BaseOAuth2Provider(CredentialProvider):
self,
state: str,
redirect_uri: str,
scopes: Optional[List[str]] = None,
scopes: list[str] | None = None,
**kwargs: Any,
) -> str:
"""
@@ -175,7 +182,7 @@ class BaseOAuth2Provider(CredentialProvider):
def client_credentials_grant(
self,
scopes: Optional[List[str]] = None,
scopes: list[str] | None = None,
**kwargs: Any,
) -> OAuth2Token:
"""
@@ -209,7 +216,7 @@ class BaseOAuth2Provider(CredentialProvider):
def refresh_access_token(
self,
refresh_token: str,
scopes: Optional[List[str]] = None,
scopes: list[str] | None = None,
**kwargs: Any,
) -> OAuth2Token:
"""
@@ -316,7 +323,7 @@ class BaseOAuth2Provider(CredentialProvider):
if new_token.refresh_token and new_token.refresh_token != refresh_tok:
credential.set_key("refresh_token", new_token.refresh_token)
credential.last_refreshed = datetime.now(timezone.utc)
credential.last_refreshed = datetime.now(UTC)
logger.info(f"Refreshed OAuth2 credential '{credential.id}'")
return credential
@@ -350,7 +357,7 @@ class BaseOAuth2Provider(CredentialProvider):
return False
buffer = timedelta(minutes=5)
return datetime.now(timezone.utc) >= (access_key.expires_at - buffer)
return datetime.now(UTC) >= (access_key.expires_at - buffer)
def revoke(self, credential: CredentialObject) -> bool:
"""
@@ -380,7 +387,7 @@ class BaseOAuth2Provider(CredentialProvider):
# --- Token Request Helpers ---
def _token_request(self, data: Dict[str, Any]) -> OAuth2Token:
def _token_request(self, data: dict[str, Any]) -> OAuth2Token:
"""
Make a token request to the OAuth2 server.
@@ -415,11 +422,13 @@ class BaseOAuth2Provider(CredentialProvider):
if response.status_code != 200 or "error" in response_data:
error = response_data.get("error", "unknown_error")
description = response_data.get("error_description", response.text)
raise OAuth2Error(error=error, description=description, status_code=response.status_code)
raise OAuth2Error(
error=error, description=description, status_code=response.status_code
)
return OAuth2Token.from_token_response(response_data)
def _parse_form_response(self, text: str) -> Dict[str, str]:
def _parse_form_response(self, text: str) -> dict[str, str]:
"""Parse form-encoded response (some providers use this instead of JSON)."""
from urllib.parse import parse_qs
@@ -428,7 +437,7 @@ class BaseOAuth2Provider(CredentialProvider):
# --- Token Formatting for Requests ---
def format_for_request(self, token: OAuth2Token) -> Dict[str, Any]:
def format_for_request(self, token: OAuth2Token) -> dict[str, Any]:
"""
Format token for use in HTTP requests (bipartisan model).
@@ -455,7 +464,7 @@ class BaseOAuth2Provider(CredentialProvider):
return {}
def format_credential_for_request(self, credential: CredentialObject) -> Dict[str, Any]:
def format_credential_for_request(self, credential: CredentialObject) -> dict[str, Any]:
"""
Format a credential for use in HTTP requests.
+22 -19
View File
@@ -9,13 +9,14 @@ from __future__ import annotations
import asyncio
import logging
from collections.abc import Callable
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Callable, Optional
from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING
from pydantic import SecretStr
from ..models import CredentialKey, CredentialObject, CredentialRefreshError, CredentialType
from ..models import CredentialKey, CredentialObject, CredentialType
from .base_provider import BaseOAuth2Provider
from .provider import OAuth2Token
@@ -30,8 +31,8 @@ class TokenRefreshResult:
"""Result of a token refresh operation."""
success: bool
token: Optional[OAuth2Token] = None
error: Optional[str] = None
token: OAuth2Token | None = None
error: str | None = None
needs_reauthorization: bool = False
@@ -70,10 +71,10 @@ class TokenLifecycleManager:
self,
provider: BaseOAuth2Provider,
credential_id: str,
store: "CredentialStore",
store: CredentialStore,
refresh_buffer_minutes: int = 5,
on_token_refreshed: Optional[Callable[[OAuth2Token], None]] = None,
on_refresh_failed: Optional[Callable[[str], None]] = None,
on_token_refreshed: Callable[[OAuth2Token], None] | None = None,
on_refresh_failed: Callable[[str], None] | None = None,
):
"""
Initialize the lifecycle manager.
@@ -94,12 +95,12 @@ class TokenLifecycleManager:
self.on_refresh_failed = on_refresh_failed
# In-memory cache for performance
self._cached_token: Optional[OAuth2Token] = None
self._cache_time: Optional[datetime] = None
self._cached_token: OAuth2Token | None = None
self._cache_time: datetime | None = None
# --- Async Token Access ---
async def get_valid_token(self) -> Optional[OAuth2Token]:
async def get_valid_token(self) -> OAuth2Token | None:
"""
Get a valid access token, refreshing if necessary.
@@ -137,12 +138,12 @@ class TokenLifecycleManager:
logger.warning(f"Refresh failed for {self.credential_id}, using existing token")
self._cached_token = token
self._cache_time = datetime.now(timezone.utc)
self._cache_time = datetime.now(UTC)
return token
async def acquire_token_client_credentials(
self,
scopes: Optional[list[str]] = None,
scopes: list[str] | None = None,
) -> OAuth2Token:
"""
Acquire a new token using client credentials flow.
@@ -157,7 +158,9 @@ class TokenLifecycleManager:
"""
# Run in executor to avoid blocking
loop = asyncio.get_event_loop()
token = await loop.run_in_executor(None, lambda: self.provider.client_credentials_grant(scopes=scopes))
token = await loop.run_in_executor(
None, lambda: self.provider.client_credentials_grant(scopes=scopes)
)
self._save_token_to_store(token)
self._cached_token = token
@@ -180,7 +183,7 @@ class TokenLifecycleManager:
# --- Synchronous Token Access ---
def sync_get_valid_token(self) -> Optional[OAuth2Token]:
def sync_get_valid_token(self) -> OAuth2Token | None:
"""
Synchronous version of get_valid_token().
@@ -212,12 +215,12 @@ class TokenLifecycleManager:
return None
self._cached_token = token
self._cache_time = datetime.now(timezone.utc)
self._cache_time = datetime.now(UTC)
return token
def sync_acquire_token_client_credentials(
self,
scopes: Optional[list[str]] = None,
scopes: list[str] | None = None,
) -> OAuth2Token:
"""Synchronous version of acquire_token_client_credentials()."""
token = self.provider.client_credentials_grant(scopes=scopes)
@@ -231,9 +234,9 @@ class TokenLifecycleManager:
"""Check if token needs refresh."""
if token.expires_at is None:
return False
return datetime.now(timezone.utc) >= (token.expires_at - self.refresh_buffer)
return datetime.now(UTC) >= (token.expires_at - self.refresh_buffer)
def _credential_to_token(self, credential: CredentialObject) -> Optional[OAuth2Token]:
def _credential_to_token(self, credential: CredentialObject) -> OAuth2Token | None:
"""Convert credential to OAuth2Token."""
access_token = credential.get_key("access_token")
if not access_token:
+18 -18
View File
@@ -10,9 +10,9 @@ This module defines the core OAuth2 data structures:
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from datetime import UTC, datetime, timedelta
from enum import Enum
from typing import Any, Dict, List, Optional
from typing import Any
class TokenPlacement(str, Enum):
@@ -47,10 +47,10 @@ class OAuth2Token:
access_token: str
token_type: str = "Bearer"
expires_at: Optional[datetime] = None
refresh_token: Optional[str] = None
scope: Optional[str] = None
raw_response: Dict[str, Any] = field(default_factory=dict)
expires_at: datetime | None = None
refresh_token: str | None = None
scope: str | None = None
raw_response: dict[str, Any] = field(default_factory=dict)
@property
def is_expired(self) -> bool:
@@ -63,7 +63,7 @@ class OAuth2Token:
if self.expires_at is None:
return False
buffer = timedelta(minutes=5)
return datetime.now(timezone.utc) >= (self.expires_at - buffer)
return datetime.now(UTC) >= (self.expires_at - buffer)
@property
def can_refresh(self) -> bool:
@@ -71,15 +71,15 @@ class OAuth2Token:
return self.refresh_token is not None and self.refresh_token.strip() != ""
@property
def expires_in_seconds(self) -> Optional[int]:
def expires_in_seconds(self) -> int | None:
"""Get seconds until expiration, or None if no expiration."""
if self.expires_at is None:
return None
delta = self.expires_at - datetime.now(timezone.utc)
delta = self.expires_at - datetime.now(UTC)
return max(0, int(delta.total_seconds()))
@classmethod
def from_token_response(cls, data: Dict[str, Any]) -> "OAuth2Token":
def from_token_response(cls, data: dict[str, Any]) -> OAuth2Token:
"""
Create OAuth2Token from an OAuth2 token endpoint response.
@@ -91,7 +91,7 @@ class OAuth2Token:
"""
expires_at = None
if "expires_in" in data:
expires_at = datetime.now(timezone.utc) + timedelta(seconds=data["expires_in"])
expires_at = datetime.now(UTC) + timedelta(seconds=data["expires_in"])
return cls(
access_token=data["access_token"],
@@ -137,28 +137,28 @@ class OAuth2Config:
# Endpoints (only token_url is strictly required)
token_url: str
authorization_url: Optional[str] = None
revocation_url: Optional[str] = None
introspection_url: Optional[str] = None
authorization_url: str | None = None
revocation_url: str | None = None
introspection_url: str | None = None
# Client credentials
client_id: str = ""
client_secret: str = ""
# Scopes
default_scopes: List[str] = field(default_factory=list)
default_scopes: list[str] = field(default_factory=list)
# Token placement for API calls (bipartisan model)
token_placement: TokenPlacement = TokenPlacement.HEADER_BEARER
custom_header_name: Optional[str] = None
custom_header_name: str | None = None
query_param_name: str = "access_token"
# Request configuration
extra_token_params: Dict[str, str] = field(default_factory=dict)
extra_token_params: dict[str, str] = field(default_factory=dict)
request_timeout: float = 30.0
# Additional headers for token requests
extra_headers: Dict[str, str] = field(default_factory=dict)
extra_headers: dict[str, str] = field(default_factory=dict)
def __post_init__(self) -> None:
"""Validate configuration."""
+6 -7
View File
@@ -13,8 +13,7 @@ from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from datetime import datetime, timedelta, timezone
from typing import List
from datetime import UTC, datetime, timedelta
from .models import CredentialObject, CredentialRefreshError, CredentialType
@@ -64,7 +63,7 @@ class CredentialProvider(ABC):
@property
@abstractmethod
def supported_types(self) -> List[CredentialType]:
def supported_types(self) -> list[CredentialType]:
"""
Credential types this provider can manage.
@@ -127,7 +126,7 @@ class CredentialProvider(ABC):
True if credential should be refreshed
"""
buffer = timedelta(minutes=5)
now = datetime.now(timezone.utc)
now = datetime.now(UTC)
for key in credential.keys.values():
if key.expires_at is not None:
@@ -181,7 +180,7 @@ class StaticProvider(CredentialProvider):
return "static"
@property
def supported_types(self) -> List[CredentialType]:
def supported_types(self) -> list[CredentialType]:
return [CredentialType.API_KEY, CredentialType.BASIC_AUTH, CredentialType.CUSTOM]
def refresh(self, credential: CredentialObject) -> CredentialObject:
@@ -236,7 +235,7 @@ class BearerTokenProvider(CredentialProvider):
return "bearer_token"
@property
def supported_types(self) -> List[CredentialType]:
def supported_types(self) -> list[CredentialType]:
return [CredentialType.BEARER_TOKEN]
def refresh(self, credential: CredentialObject) -> CredentialObject:
@@ -273,7 +272,7 @@ class BearerTokenProvider(CredentialProvider):
credential needs attention.
"""
buffer = timedelta(minutes=5)
now = datetime.now(timezone.utc)
now = datetime.now(UTC)
for key_name in ["access_token", "token"]:
key = credential.keys.get(key_name)
+32 -27
View File
@@ -14,9 +14,9 @@ import json
import logging
import os
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from datetime import UTC, datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any
from pydantic import SecretStr
@@ -44,7 +44,7 @@ class CredentialStorage(ABC):
pass
@abstractmethod
def load(self, credential_id: str) -> Optional[CredentialObject]:
def load(self, credential_id: str) -> CredentialObject | None:
"""
Load a credential from storage.
@@ -70,7 +70,7 @@ class CredentialStorage(ABC):
pass
@abstractmethod
def list_all(self) -> List[str]:
def list_all(self) -> list[str]:
"""
List all credential IDs in storage.
@@ -119,7 +119,7 @@ class EncryptedFileStorage(CredentialStorage):
def __init__(
self,
base_path: str | Path,
encryption_key: Optional[bytes] = None,
encryption_key: bytes | None = None,
key_env_var: str = "HIVE_CREDENTIAL_KEY",
):
"""
@@ -187,7 +187,7 @@ class EncryptedFileStorage(CredentialStorage):
self._update_index(credential.id, "save", credential.credential_type.value)
logger.debug(f"Saved encrypted credential '{credential.id}'")
def load(self, credential_id: str) -> Optional[CredentialObject]:
def load(self, credential_id: str) -> CredentialObject | None:
"""Load and decrypt credential."""
cred_path = self._cred_path(credential_id)
if not cred_path.exists():
@@ -202,7 +202,9 @@ class EncryptedFileStorage(CredentialStorage):
json_bytes = self._fernet.decrypt(encrypted)
data = json.loads(json_bytes.decode())
except Exception as e:
raise CredentialDecryptionError(f"Failed to decrypt credential '{credential_id}': {e}") from e
raise CredentialDecryptionError(
f"Failed to decrypt credential '{credential_id}': {e}"
) from e
# Deserialize
return self._deserialize_credential(data)
@@ -217,7 +219,7 @@ class EncryptedFileStorage(CredentialStorage):
return True
return False
def list_all(self) -> List[str]:
def list_all(self) -> list[str]:
"""List all credential IDs."""
index_path = self.base_path / "metadata" / "index.json"
if not index_path.exists():
@@ -230,7 +232,7 @@ class EncryptedFileStorage(CredentialStorage):
"""Check if credential exists."""
return self._cred_path(credential_id).exists()
def _serialize_credential(self, credential: CredentialObject) -> Dict[str, Any]:
def _serialize_credential(self, credential: CredentialObject) -> dict[str, Any]:
"""Convert credential to JSON-serializable dict, extracting secret values."""
data = credential.model_dump(mode="json")
@@ -244,7 +246,7 @@ class EncryptedFileStorage(CredentialStorage):
return data
def _deserialize_credential(self, data: Dict[str, Any]) -> CredentialObject:
def _deserialize_credential(self, data: dict[str, Any]) -> CredentialObject:
"""Reconstruct credential from dict, wrapping values in SecretStr."""
# Convert plain values back to SecretStr
for key_data in data.get("keys", {}).values():
@@ -257,7 +259,7 @@ class EncryptedFileStorage(CredentialStorage):
self,
credential_id: str,
operation: str,
credential_type: Optional[str] = None,
credential_type: str | None = None,
) -> None:
"""Update the metadata index."""
index_path = self.base_path / "metadata" / "index.json"
@@ -270,13 +272,13 @@ class EncryptedFileStorage(CredentialStorage):
if operation == "save":
index["credentials"][credential_id] = {
"updated_at": datetime.now(timezone.utc).isoformat(),
"updated_at": datetime.now(UTC).isoformat(),
"type": credential_type,
}
elif operation == "delete":
index["credentials"].pop(credential_id, None)
index["last_modified"] = datetime.now(timezone.utc).isoformat()
index["last_modified"] = datetime.now(UTC).isoformat()
with open(index_path, "w") as f:
json.dump(index, f, indent=2)
@@ -301,8 +303,8 @@ class EnvVarStorage(CredentialStorage):
def __init__(
self,
env_mapping: Optional[Dict[str, str]] = None,
dotenv_path: Optional[Path] = None,
env_mapping: dict[str, str] | None = None,
dotenv_path: Path | None = None,
):
"""
Initialize env var storage.
@@ -323,7 +325,7 @@ class EnvVarStorage(CredentialStorage):
# Default pattern: CREDENTIAL_ID_API_KEY
return f"{credential_id.upper().replace('-', '_')}_API_KEY"
def _read_env_value(self, env_var: str) -> Optional[str]:
def _read_env_value(self, env_var: str) -> str | None:
"""Read value from env var or .env file."""
# Check os.environ first (takes precedence)
value = os.environ.get(env_var)
@@ -346,10 +348,11 @@ class EnvVarStorage(CredentialStorage):
def save(self, credential: CredentialObject) -> None:
"""Cannot save to environment variables at runtime."""
raise NotImplementedError(
"EnvVarStorage is read-only. Set environment variables externally or use EncryptedFileStorage."
"EnvVarStorage is read-only. Set environment variables "
"externally or use EncryptedFileStorage."
)
def load(self, credential_id: str) -> Optional[CredentialObject]:
def load(self, credential_id: str) -> CredentialObject | None:
"""Load credential from environment variable."""
env_var = self._get_env_var_name(credential_id)
value = self._read_env_value(env_var)
@@ -366,9 +369,11 @@ class EnvVarStorage(CredentialStorage):
def delete(self, credential_id: str) -> bool:
"""Cannot delete environment variables at runtime."""
raise NotImplementedError("EnvVarStorage is read-only. Unset environment variables externally.")
raise NotImplementedError(
"EnvVarStorage is read-only. Unset environment variables externally."
)
def list_all(self) -> List[str]:
def list_all(self) -> list[str]:
"""List credentials that are available in environment."""
available = []
@@ -407,20 +412,20 @@ class InMemoryStorage(CredentialStorage):
credential = storage.load("test_cred")
"""
def __init__(self, initial_data: Optional[Dict[str, CredentialObject]] = None):
def __init__(self, initial_data: dict[str, CredentialObject] | None = None):
"""
Initialize in-memory storage.
Args:
initial_data: Optional dict of credential_id -> CredentialObject
"""
self._data: Dict[str, CredentialObject] = initial_data or {}
self._data: dict[str, CredentialObject] = initial_data or {}
def save(self, credential: CredentialObject) -> None:
"""Save credential to memory."""
self._data[credential.id] = credential
def load(self, credential_id: str) -> Optional[CredentialObject]:
def load(self, credential_id: str) -> CredentialObject | None:
"""Load credential from memory."""
return self._data.get(credential_id)
@@ -431,7 +436,7 @@ class InMemoryStorage(CredentialStorage):
return True
return False
def list_all(self) -> List[str]:
def list_all(self) -> list[str]:
"""List all credential IDs."""
return list(self._data.keys())
@@ -462,7 +467,7 @@ class CompositeStorage(CredentialStorage):
def __init__(
self,
primary: CredentialStorage,
fallbacks: Optional[List[CredentialStorage]] = None,
fallbacks: list[CredentialStorage] | None = None,
):
"""
Initialize composite storage.
@@ -478,7 +483,7 @@ class CompositeStorage(CredentialStorage):
"""Save to primary storage."""
self._primary.save(credential)
def load(self, credential_id: str) -> Optional[CredentialObject]:
def load(self, credential_id: str) -> CredentialObject | None:
"""Load from primary, then fallbacks."""
# Try primary first
credential = self._primary.load(credential_id)
@@ -497,7 +502,7 @@ class CompositeStorage(CredentialStorage):
"""Delete from primary storage only."""
return self._primary.delete(credential_id)
def list_all(self) -> List[str]:
def list_all(self) -> list[str]:
"""List credentials from all storages."""
all_ids = set(self._primary.list_all())
for fallback in self._fallbacks:
+35 -35
View File
@@ -13,17 +13,15 @@ from __future__ import annotations
import logging
import threading
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from datetime import UTC, datetime
from typing import Any
from pydantic import SecretStr
from .models import (
CredentialError,
CredentialKey,
CredentialObject,
CredentialRefreshError,
CredentialType,
CredentialUsageSpec,
)
from .provider import CredentialProvider, StaticProvider
@@ -69,8 +67,8 @@ class CredentialStore:
def __init__(
self,
storage: Optional[CredentialStorage] = None,
providers: Optional[List[CredentialProvider]] = None,
storage: CredentialStorage | None = None,
providers: list[CredentialProvider] | None = None,
cache_ttl_seconds: int = 300,
auto_refresh: bool = True,
):
@@ -84,11 +82,11 @@ class CredentialStore:
auto_refresh: Whether to auto-refresh expired credentials on access.
"""
self._storage = storage or EnvVarStorage()
self._providers: Dict[str, CredentialProvider] = {}
self._usage_specs: Dict[str, CredentialUsageSpec] = {}
self._providers: dict[str, CredentialProvider] = {}
self._usage_specs: dict[str, CredentialUsageSpec] = {}
# Cache: credential_id -> (CredentialObject, cached_at)
self._cache: Dict[str, tuple[CredentialObject, datetime]] = {}
self._cache: dict[str, tuple[CredentialObject, datetime]] = {}
self._cache_ttl = cache_ttl_seconds
self._lock = threading.RLock()
@@ -113,7 +111,7 @@ class CredentialStore:
self._providers[provider.provider_id] = provider
logger.debug(f"Registered credential provider: {provider.provider_id}")
def get_provider(self, provider_id: str) -> Optional[CredentialProvider]:
def get_provider(self, provider_id: str) -> CredentialProvider | None:
"""
Get a provider by ID.
@@ -125,7 +123,9 @@ class CredentialStore:
"""
return self._providers.get(provider_id)
def get_provider_for_credential(self, credential: CredentialObject) -> Optional[CredentialProvider]:
def get_provider_for_credential(
self, credential: CredentialObject
) -> CredentialProvider | None:
"""
Get the appropriate provider for a credential.
@@ -159,7 +159,7 @@ class CredentialStore:
"""
self._usage_specs[spec.credential_id] = spec
def get_usage_spec(self, credential_id: str) -> Optional[CredentialUsageSpec]:
def get_usage_spec(self, credential_id: str) -> CredentialUsageSpec | None:
"""
Get the usage spec for a credential.
@@ -177,7 +177,7 @@ class CredentialStore:
self,
credential_id: str,
refresh_if_needed: bool = True,
) -> Optional[CredentialObject]:
) -> CredentialObject | None:
"""
Get a credential by ID.
@@ -210,7 +210,7 @@ class CredentialStore:
return credential
def get_key(self, credential_id: str, key_name: str) -> Optional[str]:
def get_key(self, credential_id: str, key_name: str) -> str | None:
"""
Convenience method to get a specific key value.
@@ -226,7 +226,7 @@ class CredentialStore:
return None
return credential.get_key(key_name)
def get(self, credential_id: str) -> Optional[str]:
def get(self, credential_id: str) -> str | None:
"""
Legacy compatibility: get the primary key value.
@@ -262,7 +262,7 @@ class CredentialStore:
"""
return self._resolver.resolve(template)
def resolve_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
def resolve_headers(self, headers: dict[str, str]) -> dict[str, str]:
"""
Resolve credential templates in headers dictionary.
@@ -280,7 +280,7 @@ class CredentialStore:
"""
return self._resolver.resolve_headers(headers)
def resolve_params(self, params: Dict[str, str]) -> Dict[str, str]:
def resolve_params(self, params: dict[str, str]) -> dict[str, str]:
"""
Resolve credential templates in query parameters dictionary.
@@ -292,7 +292,7 @@ class CredentialStore:
"""
return self._resolver.resolve_params(params)
def resolve_for_usage(self, credential_id: str) -> Dict[str, Any]:
def resolve_for_usage(self, credential_id: str) -> dict[str, Any]:
"""
Get resolved request kwargs for a registered usage spec.
@@ -309,7 +309,7 @@ class CredentialStore:
if spec is None:
raise ValueError(f"No usage spec registered for '{credential_id}'")
result: Dict[str, Any] = {}
result: dict[str, Any] = {}
if spec.headers:
result["headers"] = self.resolve_headers(spec.headers)
@@ -353,7 +353,7 @@ class CredentialStore:
logger.info(f"Deleted credential '{credential_id}'")
return result
def list_credentials(self) -> List[str]:
def list_credentials(self) -> list[str]:
"""
List all available credential IDs.
@@ -376,7 +376,7 @@ class CredentialStore:
# --- Validation ---
def validate_for_usage(self, credential_id: str) -> List[str]:
def validate_for_usage(self, credential_id: str) -> list[str]:
"""
Validate that a credential meets its usage spec requirements.
@@ -401,7 +401,7 @@ class CredentialStore:
return errors
def validate_all(self) -> Dict[str, List[str]]:
def validate_all(self) -> dict[str, list[str]]:
"""
Validate all registered usage specs.
@@ -462,7 +462,7 @@ class CredentialStore:
try:
refreshed = provider.refresh(credential)
refreshed.last_refreshed = datetime.now(timezone.utc)
refreshed.last_refreshed = datetime.now(UTC)
# Persist the refreshed credential
self._storage.save(refreshed)
@@ -475,7 +475,7 @@ class CredentialStore:
logger.error(f"Failed to refresh credential '{credential.id}': {e}")
return credential
def refresh_credential(self, credential_id: str) -> Optional[CredentialObject]:
def refresh_credential(self, credential_id: str) -> CredentialObject | None:
"""
Manually refresh a credential.
@@ -496,13 +496,13 @@ class CredentialStore:
# --- Caching ---
def _get_from_cache(self, credential_id: str) -> Optional[CredentialObject]:
def _get_from_cache(self, credential_id: str) -> CredentialObject | None:
"""Get credential from cache if not expired."""
if credential_id not in self._cache:
return None
credential, cached_at = self._cache[credential_id]
age = (datetime.now(timezone.utc) - cached_at).total_seconds()
age = (datetime.now(UTC) - cached_at).total_seconds()
if age > self._cache_ttl:
del self._cache[credential_id]
@@ -512,7 +512,7 @@ class CredentialStore:
def _add_to_cache(self, credential: CredentialObject) -> None:
"""Add credential to cache."""
self._cache[credential.id] = (credential, datetime.now(timezone.utc))
self._cache[credential.id] = (credential, datetime.now(UTC))
def _remove_from_cache(self, credential_id: str) -> None:
"""Remove credential from cache."""
@@ -528,8 +528,8 @@ class CredentialStore:
@classmethod
def for_testing(
cls,
credentials: Dict[str, Dict[str, str]],
) -> "CredentialStore":
credentials: dict[str, dict[str, str]],
) -> CredentialStore:
"""
Create a credential store for testing with mock credentials.
@@ -550,7 +550,7 @@ class CredentialStore:
})
"""
# Convert test data to CredentialObjects
cred_objects: Dict[str, CredentialObject] = {}
cred_objects: dict[str, CredentialObject] = {}
for cred_id, keys in credentials.items():
cred_objects[cred_id] = CredentialObject(
@@ -567,9 +567,9 @@ class CredentialStore:
def with_encrypted_storage(
cls,
base_path: str,
providers: Optional[List[CredentialProvider]] = None,
providers: list[CredentialProvider] | None = None,
**kwargs: Any,
) -> "CredentialStore":
) -> CredentialStore:
"""
Create a credential store with encrypted file storage.
@@ -592,10 +592,10 @@ class CredentialStore:
@classmethod
def with_env_storage(
cls,
env_mapping: Optional[Dict[str, str]] = None,
providers: Optional[List[CredentialProvider]] = None,
env_mapping: dict[str, str] | None = None,
providers: list[CredentialProvider] | None = None,
**kwargs: Any,
) -> "CredentialStore":
) -> CredentialStore:
"""
Create a credential store with environment variable storage.
+15 -11
View File
@@ -17,7 +17,7 @@ Examples:
from __future__ import annotations
import re
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING
from .models import CredentialKeyNotFoundError, CredentialNotFoundError
@@ -45,7 +45,7 @@ class TemplateResolver:
# Matches {{credential_id}} or {{credential_id.key_name}}
TEMPLATE_PATTERN = re.compile(r"\{\{([a-zA-Z0-9_-]+)(?:\.([a-zA-Z0-9_-]+))?\}\}")
def __init__(self, credential_store: "CredentialStore"):
def __init__(self, credential_store: CredentialStore):
"""
Initialize the template resolver.
@@ -88,7 +88,9 @@ class TemplateResolver:
if key_name:
value = credential.get_key(key_name)
if value is None:
raise CredentialKeyNotFoundError(f"Key '{key_name}' not found in credential '{cred_id}'")
raise CredentialKeyNotFoundError(
f"Key '{key_name}' not found in credential '{cred_id}'"
)
else:
# Use default key
value = credential.get_default_key()
@@ -104,9 +106,9 @@ class TemplateResolver:
def resolve_headers(
self,
header_templates: Dict[str, str],
header_templates: dict[str, str],
fail_on_missing: bool = True,
) -> Dict[str, str]:
) -> dict[str, str]:
"""
Resolve templates in a headers dictionary.
@@ -124,13 +126,15 @@ class TemplateResolver:
... })
{"Authorization": "Bearer ghp_xxx", "X-API-Key": "BSAKxxx"}
"""
return {key: self.resolve(value, fail_on_missing) for key, value in header_templates.items()}
return {
key: self.resolve(value, fail_on_missing) for key, value in header_templates.items()
}
def resolve_params(
self,
param_templates: Dict[str, str],
param_templates: dict[str, str],
fail_on_missing: bool = True,
) -> Dict[str, str]:
) -> dict[str, str]:
"""
Resolve templates in a query parameters dictionary.
@@ -155,7 +159,7 @@ class TemplateResolver:
"""
return bool(self.TEMPLATE_PATTERN.search(text))
def extract_references(self, text: str) -> List[Tuple[str, Optional[str]]]:
def extract_references(self, text: str) -> list[tuple[str, str | None]]:
"""
Extract all credential references from text.
@@ -172,7 +176,7 @@ class TemplateResolver:
"""
return [(match.group(1), match.group(2)) for match in self.TEMPLATE_PATTERN.finditer(text)]
def validate_references(self, text: str) -> List[str]:
def validate_references(self, text: str) -> list[str]:
"""
Validate all credential references in text without resolving.
@@ -201,7 +205,7 @@ class TemplateResolver:
return errors
def get_required_credentials(self, text: str) -> List[str]:
def get_required_credentials(self, text: str) -> list[str]:
"""
Get list of credential IDs required by a template string.
@@ -12,24 +12,17 @@ Tests cover:
import os
import tempfile
from datetime import datetime, timedelta, timezone
from datetime import UTC, datetime, timedelta
from pathlib import Path
from unittest.mock import MagicMock, patch
from unittest.mock import patch
import pytest
from pydantic import SecretStr
from core.framework.credentials import (
BearerTokenProvider,
CompositeStorage,
CredentialError,
CredentialKey,
CredentialKeyNotFoundError,
CredentialNotFoundError,
CredentialObject,
CredentialProvider,
CredentialRefreshError,
CredentialStorage,
CredentialStore,
CredentialType,
CredentialUsageSpec,
@@ -39,6 +32,7 @@ from core.framework.credentials import (
StaticProvider,
TemplateResolver,
)
from pydantic import SecretStr
class TestCredentialKey:
@@ -54,13 +48,13 @@ class TestCredentialKey:
def test_key_with_expiration(self):
"""Test key with expiration time."""
future = datetime.now(timezone.utc) + timedelta(hours=1)
future = datetime.now(UTC) + timedelta(hours=1)
key = CredentialKey(name="token", value=SecretStr("xxx"), expires_at=future)
assert not key.is_expired
def test_expired_key(self):
"""Test that expired key is detected."""
past = datetime.now(timezone.utc) - timedelta(hours=1)
past = datetime.now(UTC) - timedelta(hours=1)
key = CredentialKey(name="token", value=SecretStr("xxx"), expires_at=past)
assert key.is_expired
@@ -111,13 +105,13 @@ class TestCredentialObject:
def test_set_key_with_expiration(self):
"""Test setting a key with expiration."""
cred = CredentialObject(id="test", keys={})
expires = datetime.now(timezone.utc) + timedelta(hours=1)
expires = datetime.now(UTC) + timedelta(hours=1)
cred.set_key("token", "xxx", expires_at=expires)
assert cred.keys["token"].expires_at == expires
def test_needs_refresh(self):
"""Test needs_refresh property."""
past = datetime.now(timezone.utc) - timedelta(hours=1)
past = datetime.now(UTC) - timedelta(hours=1)
cred = CredentialObject(
id="test",
keys={"token": CredentialKey(name="token", value=SecretStr("xxx"), expires_at=past)},
@@ -136,7 +130,9 @@ class TestCredentialObject:
# With access_token
cred2 = CredentialObject(
id="test",
keys={"access_token": CredentialKey(name="access_token", value=SecretStr("token-value"))},
keys={
"access_token": CredentialKey(name="access_token", value=SecretStr("token-value"))
},
)
assert cred2.get_default_key() == "token-value"
@@ -301,7 +297,9 @@ class TestEncryptedFileStorage:
key = Fernet.generate_key().decode()
with patch.dict(os.environ, {"HIVE_CREDENTIAL_KEY": key}):
storage = EncryptedFileStorage(temp_dir)
cred = CredentialObject(id="test", keys={"k": CredentialKey(name="k", value=SecretStr("v"))})
cred = CredentialObject(
id="test", keys={"k": CredentialKey(name="k", value=SecretStr("v"))}
)
storage.save(cred)
# Create new storage instance with same key
@@ -332,10 +330,18 @@ class TestCompositeStorage:
def test_read_from_primary(self):
"""Test reading from primary storage."""
primary = InMemoryStorage()
primary.save(CredentialObject(id="test", keys={"k": CredentialKey(name="k", value=SecretStr("primary"))}))
primary.save(
CredentialObject(
id="test", keys={"k": CredentialKey(name="k", value=SecretStr("primary"))}
)
)
fallback = InMemoryStorage()
fallback.save(CredentialObject(id="test", keys={"k": CredentialKey(name="k", value=SecretStr("fallback"))}))
fallback.save(
CredentialObject(
id="test", keys={"k": CredentialKey(name="k", value=SecretStr("fallback"))}
)
)
storage = CompositeStorage(primary, [fallback])
cred = storage.load("test")
@@ -347,7 +353,11 @@ class TestCompositeStorage:
"""Test fallback when credential not in primary."""
primary = InMemoryStorage()
fallback = InMemoryStorage()
fallback.save(CredentialObject(id="test", keys={"k": CredentialKey(name="k", value=SecretStr("fallback"))}))
fallback.save(
CredentialObject(
id="test", keys={"k": CredentialKey(name="k", value=SecretStr("fallback"))}
)
)
storage = CompositeStorage(primary, [fallback])
cred = storage.load("test")
@@ -383,7 +393,9 @@ class TestStaticProvider:
def test_refresh_returns_unchanged(self):
"""Test that refresh returns credential unchanged."""
provider = StaticProvider()
cred = CredentialObject(id="test", keys={"k": CredentialKey(name="k", value=SecretStr("v"))})
cred = CredentialObject(
id="test", keys={"k": CredentialKey(name="k", value=SecretStr("v"))}
)
refreshed = provider.refresh(cred)
assert refreshed.get_key("k") == "v"
@@ -391,7 +403,9 @@ class TestStaticProvider:
def test_validate_with_keys(self):
"""Test validation with keys present."""
provider = StaticProvider()
cred = CredentialObject(id="test", keys={"k": CredentialKey(name="k", value=SecretStr("v"))})
cred = CredentialObject(
id="test", keys={"k": CredentialKey(name="k", value=SecretStr("v"))}
)
assert provider.validate(cred)
@@ -592,10 +606,12 @@ class TestCredentialStore:
storage = InMemoryStorage()
store = CredentialStore(storage=storage, cache_ttl_seconds=60)
storage.save(CredentialObject(id="test", keys={"k": CredentialKey(name="k", value=SecretStr("v"))}))
storage.save(
CredentialObject(id="test", keys={"k": CredentialKey(name="k", value=SecretStr("v"))})
)
# First load
cred1 = store.get_credential("test")
store.get_credential("test")
# Delete from storage
storage.delete("test")
@@ -646,12 +662,12 @@ class TestOAuth2Module:
from core.framework.credentials.oauth2 import OAuth2Token
# Not expired
future = datetime.now(timezone.utc) + timedelta(hours=1)
future = datetime.now(UTC) + timedelta(hours=1)
token = OAuth2Token(access_token="xxx", expires_at=future)
assert not token.is_expired
# Expired
past = datetime.now(timezone.utc) - timedelta(hours=1)
past = datetime.now(UTC) - timedelta(hours=1)
expired_token = OAuth2Token(access_token="xxx", expires_at=past)
assert expired_token.is_expired
@@ -670,7 +686,9 @@ class TestOAuth2Module:
from core.framework.credentials.oauth2 import OAuth2Config, TokenPlacement
# Valid config
config = OAuth2Config(token_url="https://example.com/token", client_id="id", client_secret="secret")
config = OAuth2Config(
token_url="https://example.com/token", client_id="id", client_secret="secret"
)
assert config.token_url == "https://example.com/token"
# Missing token_url
+17 -15
View File
@@ -10,7 +10,7 @@ from __future__ import annotations
import logging
import os
from datetime import datetime
from typing import Any, Dict, List, Optional
from typing import Any
from pydantic import SecretStr
@@ -72,10 +72,10 @@ class HashiCorpVaultStorage(CredentialStorage):
def __init__(
self,
url: str,
token: Optional[str] = None,
token: str | None = None,
mount_point: str = "secret",
path_prefix: str = "hive/credentials",
namespace: Optional[str] = None,
namespace: str | None = None,
verify_ssl: bool = True,
):
"""
@@ -107,7 +107,9 @@ class HashiCorpVaultStorage(CredentialStorage):
self._namespace = namespace
if not self._token:
raise ValueError("Vault token required. Set VAULT_TOKEN env var or pass token parameter.")
raise ValueError(
"Vault token required. Set VAULT_TOKEN env var or pass token parameter."
)
self._client = hvac.Client(
url=url,
@@ -143,7 +145,7 @@ class HashiCorpVaultStorage(CredentialStorage):
logger.error(f"Failed to save credential '{credential.id}' to Vault: {e}")
raise
def load(self, credential_id: str) -> Optional[CredentialObject]:
def load(self, credential_id: str) -> CredentialObject | None:
"""Load credential from Vault."""
path = self._path(credential_id)
@@ -181,7 +183,7 @@ class HashiCorpVaultStorage(CredentialStorage):
logger.error(f"Failed to delete credential '{credential_id}' from Vault: {e}")
raise
def list_all(self) -> List[str]:
def list_all(self) -> list[str]:
"""List all credentials under the prefix."""
try:
response = self._client.secrets.kv.v2.list_secrets(
@@ -210,9 +212,9 @@ class HashiCorpVaultStorage(CredentialStorage):
except Exception:
return False
def _serialize_for_vault(self, credential: CredentialObject) -> Dict[str, Any]:
def _serialize_for_vault(self, credential: CredentialObject) -> dict[str, Any]:
"""Convert credential to Vault secret format."""
data: Dict[str, Any] = {
data: dict[str, Any] = {
"_type": credential.credential_type.value,
}
@@ -237,7 +239,7 @@ class HashiCorpVaultStorage(CredentialStorage):
return data
def _deserialize_from_vault(self, credential_id: str, data: Dict[str, Any]) -> CredentialObject:
def _deserialize_from_vault(self, credential_id: str, data: dict[str, Any]) -> CredentialObject:
"""Reconstruct credential from Vault secret."""
# Extract metadata fields
cred_type = CredentialType(data.pop("_type", "api_key"))
@@ -246,7 +248,7 @@ class HashiCorpVaultStorage(CredentialStorage):
auto_refresh = data.pop("_auto_refresh", "") == "true"
# Build keys dict
keys: Dict[str, CredentialKey] = {}
keys: dict[str, CredentialKey] = {}
# Find all non-metadata keys
key_names = [k for k in data.keys() if not k.startswith("_")]
@@ -264,7 +266,7 @@ class HashiCorpVaultStorage(CredentialStorage):
pass
# Check for metadata
metadata: Dict[str, Any] = {}
metadata: dict[str, Any] = {}
metadata_key = f"_metadata_{key_name}"
if metadata_key in data:
try:
@@ -292,7 +294,7 @@ class HashiCorpVaultStorage(CredentialStorage):
# --- Vault-Specific Operations ---
def get_secret_metadata(self, credential_id: str) -> Optional[Dict[str, Any]]:
def get_secret_metadata(self, credential_id: str) -> dict[str, Any] | None:
"""
Get Vault metadata for a secret (version info, timestamps, etc.).
@@ -313,7 +315,7 @@ class HashiCorpVaultStorage(CredentialStorage):
except Exception:
return None
def soft_delete(self, credential_id: str, versions: Optional[List[int]] = None) -> bool:
def soft_delete(self, credential_id: str, versions: list[int] | None = None) -> bool:
"""
Soft delete specific versions (can be recovered).
@@ -343,7 +345,7 @@ class HashiCorpVaultStorage(CredentialStorage):
logger.error(f"Soft delete failed for '{credential_id}': {e}")
return False
def undelete(self, credential_id: str, versions: List[int]) -> bool:
def undelete(self, credential_id: str, versions: list[int]) -> bool:
"""
Recover soft-deleted versions.
@@ -367,7 +369,7 @@ class HashiCorpVaultStorage(CredentialStorage):
logger.error(f"Undelete failed for '{credential_id}': {e}")
return False
def load_version(self, credential_id: str, version: int) -> Optional[CredentialObject]:
def load_version(self, credential_id: str, version: int) -> CredentialObject | None:
"""
Load a specific version of a credential.
@@ -26,7 +26,7 @@ Usage:
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING
from .base import CredentialError, CredentialSpec
@@ -55,8 +55,8 @@ class CredentialStoreAdapter:
def __init__(
self,
store: "CredentialStore",
specs: Optional[Dict[str, CredentialSpec]] = None,
store: CredentialStore,
specs: dict[str, CredentialSpec] | None = None,
):
"""
Initialize the adapter.
@@ -74,8 +74,8 @@ class CredentialStoreAdapter:
self._specs = specs
# Build reverse mappings for validation
self._tool_to_cred: Dict[str, str] = {}
self._node_type_to_cred: Dict[str, str] = {}
self._tool_to_cred: dict[str, str] = {}
self._node_type_to_cred: dict[str, str] = {}
for cred_name, spec in self._specs.items():
for tool_name in spec.tools:
@@ -85,7 +85,7 @@ class CredentialStoreAdapter:
# --- Existing CredentialManager API ---
def get(self, name: str) -> Optional[str]:
def get(self, name: str) -> str | None:
"""
Get a credential value by logical name.
@@ -117,7 +117,7 @@ class CredentialStoreAdapter:
value = self._store.get(name)
return value is not None and value != ""
def get_credential_for_tool(self, tool_name: str) -> Optional[str]:
def get_credential_for_tool(self, tool_name: str) -> str | None:
"""
Get the credential name required by a tool.
@@ -129,7 +129,7 @@ class CredentialStoreAdapter:
"""
return self._tool_to_cred.get(tool_name)
def get_missing_for_tools(self, tool_names: List[str]) -> List[Tuple[str, CredentialSpec]]:
def get_missing_for_tools(self, tool_names: list[str]) -> list[tuple[str, CredentialSpec]]:
"""
Get list of missing credentials for the given tools.
@@ -139,7 +139,7 @@ class CredentialStoreAdapter:
Returns:
List of (credential_name, spec) tuples for missing credentials
"""
missing: List[Tuple[str, CredentialSpec]] = []
missing: list[tuple[str, CredentialSpec]] = []
checked: set[str] = set()
for tool_name in tool_names:
@@ -156,7 +156,7 @@ class CredentialStoreAdapter:
return missing
def validate_for_tools(self, tool_names: List[str]) -> None:
def validate_for_tools(self, tool_names: list[str]) -> None:
"""
Validate that all credentials required by the given tools are available.
@@ -170,9 +170,9 @@ class CredentialStoreAdapter:
if missing:
raise CredentialError(self._format_missing_error(missing, tool_names))
def get_missing_for_node_types(self, node_types: List[str]) -> List[Tuple[str, CredentialSpec]]:
def get_missing_for_node_types(self, node_types: list[str]) -> list[tuple[str, CredentialSpec]]:
"""Get list of missing credentials for the given node types."""
missing: List[Tuple[str, CredentialSpec]] = []
missing: list[tuple[str, CredentialSpec]] = []
checked: set[str] = set()
for node_type in node_types:
@@ -189,7 +189,7 @@ class CredentialStoreAdapter:
return missing
def validate_for_node_types(self, node_types: List[str]) -> None:
def validate_for_node_types(self, node_types: list[str]) -> None:
"""
Validate that all credentials required by the given node types are available.
@@ -210,7 +210,7 @@ class CredentialStoreAdapter:
Raises:
CredentialError: If any startup-required credentials are missing
"""
missing: List[Tuple[str, CredentialSpec]] = []
missing: list[tuple[str, CredentialSpec]] = []
for cred_name, spec in self._specs.items():
if spec.startup_required and not self.is_available(cred_name):
@@ -221,7 +221,7 @@ class CredentialStoreAdapter:
# --- New CredentialStore Features ---
def get_key(self, credential_id: str, key_name: str) -> Optional[str]:
def get_key(self, credential_id: str, key_name: str) -> str | None:
"""
Get a specific key from a multi-key credential.
@@ -250,7 +250,7 @@ class CredentialStoreAdapter:
"""
return self._store.resolve(template)
def resolve_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
def resolve_headers(self, headers: dict[str, str]) -> dict[str, str]:
"""
Resolve credential templates in headers dictionary.
@@ -268,12 +268,12 @@ class CredentialStoreAdapter:
"""
return self._store.resolve_headers(headers)
def resolve_params(self, params: Dict[str, str]) -> Dict[str, str]:
def resolve_params(self, params: dict[str, str]) -> dict[str, str]:
"""Resolve credential templates in query parameters."""
return self._store.resolve_params(params)
@property
def store(self) -> "CredentialStore":
def store(self) -> CredentialStore:
"""Access the underlying credential store for advanced operations."""
return self._store
@@ -281,14 +281,14 @@ class CredentialStoreAdapter:
def _format_missing_error(
self,
missing: List[Tuple[str, CredentialSpec]],
tool_names: List[str],
missing: list[tuple[str, CredentialSpec]],
tool_names: list[str],
) -> str:
"""Format a clear, actionable error message for missing credentials."""
lines = ["Cannot run agent: Missing credentials\n"]
lines.append("The following tools require credentials that are not set:\n")
for cred_name, spec in missing:
for _cred_name, spec in missing:
affected_tools = [t for t in tool_names if t in spec.tools]
tools_str = ", ".join(affected_tools)
@@ -305,14 +305,14 @@ class CredentialStoreAdapter:
def _format_missing_node_type_error(
self,
missing: List[Tuple[str, CredentialSpec]],
node_types: List[str],
missing: list[tuple[str, CredentialSpec]],
node_types: list[str],
) -> str:
"""Format a clear, actionable error message for missing node type credentials."""
lines = ["Cannot run agent: Missing credentials\n"]
lines.append("The following node types require credentials that are not set:\n")
for cred_name, spec in missing:
for _cred_name, spec in missing:
affected_types = [t for t in node_types if t in spec.node_types]
types_str = ", ".join(affected_types)
@@ -329,12 +329,12 @@ class CredentialStoreAdapter:
def _format_startup_error(
self,
missing: List[Tuple[str, CredentialSpec]],
missing: list[tuple[str, CredentialSpec]],
) -> str:
"""Format a clear, actionable error message for missing startup credentials."""
lines = ["Server startup failed: Missing required credentials\n"]
for cred_name, spec in missing:
for _cred_name, spec in missing:
lines.append(f" {spec.env_var}")
if spec.description:
lines.append(f" {spec.description}")
@@ -351,9 +351,9 @@ class CredentialStoreAdapter:
@classmethod
def for_testing(
cls,
overrides: Dict[str, str],
specs: Optional[Dict[str, CredentialSpec]] = None,
) -> "CredentialStoreAdapter":
overrides: dict[str, str],
specs: dict[str, CredentialSpec] | None = None,
) -> CredentialStoreAdapter:
"""
Create a CredentialStoreAdapter for testing with mock credentials.
@@ -380,9 +380,9 @@ class CredentialStoreAdapter:
@classmethod
def with_env_storage(
cls,
env_mapping: Optional[Dict[str, str]] = None,
specs: Optional[Dict[str, CredentialSpec]] = None,
) -> "CredentialStoreAdapter":
env_mapping: dict[str, str] | None = None,
specs: dict[str, CredentialSpec] | None = None,
) -> CredentialStoreAdapter:
"""
Create adapter with environment variable storage (current behavior).