Merge pull request #1532 from TimothyZhang7/main
chore: fix lint issues
This commit is contained in:
@@ -7,16 +7,16 @@ containing one or more keys (e.g., api_key, access_token, refresh_token).
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
|
||||
def _utc_now() -> datetime:
|
||||
"""Get current UTC time as timezone-aware datetime."""
|
||||
return datetime.now(timezone.utc)
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
class CredentialType(str, Enum):
|
||||
@@ -53,8 +53,8 @@ class CredentialKey(BaseModel):
|
||||
|
||||
name: str
|
||||
value: SecretStr
|
||||
expires_at: Optional[datetime] = None
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
expires_at: datetime | None = None
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
@@ -63,7 +63,7 @@ class CredentialKey(BaseModel):
|
||||
"""Check if this key has expired."""
|
||||
if self.expires_at is None:
|
||||
return False
|
||||
return datetime.now(timezone.utc) >= self.expires_at
|
||||
return datetime.now(UTC) >= self.expires_at
|
||||
|
||||
def get_secret_value(self) -> str:
|
||||
"""Get the actual secret value (use sparingly)."""
|
||||
@@ -98,28 +98,29 @@ class CredentialObject(BaseModel):
|
||||
|
||||
id: str = Field(description="Unique identifier (e.g., 'brave_search', 'github_oauth')")
|
||||
credential_type: CredentialType = CredentialType.API_KEY
|
||||
keys: Dict[str, CredentialKey] = Field(default_factory=dict)
|
||||
keys: dict[str, CredentialKey] = Field(default_factory=dict)
|
||||
|
||||
# Lifecycle management
|
||||
provider_id: Optional[str] = Field(
|
||||
default=None, description="ID of provider responsible for lifecycle (e.g., 'oauth2', 'static')"
|
||||
provider_id: str | None = Field(
|
||||
default=None,
|
||||
description="ID of provider responsible for lifecycle (e.g., 'oauth2', 'static')",
|
||||
)
|
||||
last_refreshed: Optional[datetime] = None
|
||||
last_refreshed: datetime | None = None
|
||||
auto_refresh: bool = False
|
||||
|
||||
# Usage tracking
|
||||
last_used: Optional[datetime] = None
|
||||
last_used: datetime | None = None
|
||||
use_count: int = 0
|
||||
|
||||
# Metadata
|
||||
description: str = ""
|
||||
tags: List[str] = Field(default_factory=list)
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
created_at: datetime = Field(default_factory=_utc_now)
|
||||
updated_at: datetime = Field(default_factory=_utc_now)
|
||||
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
def get_key(self, key_name: str) -> Optional[str]:
|
||||
def get_key(self, key_name: str) -> str | None:
|
||||
"""
|
||||
Get a specific key's value.
|
||||
|
||||
@@ -138,8 +139,8 @@ class CredentialObject(BaseModel):
|
||||
self,
|
||||
key_name: str,
|
||||
value: str,
|
||||
expires_at: Optional[datetime] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
expires_at: datetime | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Set or update a key.
|
||||
@@ -156,7 +157,7 @@ class CredentialObject(BaseModel):
|
||||
expires_at=expires_at,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
self.updated_at = datetime.now(timezone.utc)
|
||||
self.updated_at = datetime.now(UTC)
|
||||
|
||||
def has_key(self, key_name: str) -> bool:
|
||||
"""Check if a key exists."""
|
||||
@@ -179,10 +180,10 @@ class CredentialObject(BaseModel):
|
||||
|
||||
def record_usage(self) -> None:
|
||||
"""Record that this credential was used."""
|
||||
self.last_used = datetime.now(timezone.utc)
|
||||
self.last_used = datetime.now(UTC)
|
||||
self.use_count += 1
|
||||
|
||||
def get_default_key(self) -> Optional[str]:
|
||||
def get_default_key(self) -> str | None:
|
||||
"""
|
||||
Get the default key value.
|
||||
|
||||
@@ -232,18 +233,18 @@ class CredentialUsageSpec(BaseModel):
|
||||
"""
|
||||
|
||||
credential_id: str = Field(description="ID of credential to use (e.g., 'brave_search')")
|
||||
required_keys: List[str] = Field(default_factory=list, description="Keys that must be present")
|
||||
required_keys: list[str] = Field(default_factory=list, description="Keys that must be present")
|
||||
|
||||
# Injection templates (bipartisan model)
|
||||
headers: Dict[str, str] = Field(
|
||||
headers: dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="Header templates (e.g., {'Authorization': 'Bearer {{access_token}}'})",
|
||||
)
|
||||
query_params: Dict[str, str] = Field(
|
||||
query_params: dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="Query param templates (e.g., {'api_key': '{{api_key}}'})",
|
||||
)
|
||||
body_fields: Dict[str, str] = Field(
|
||||
body_fields: dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="Request body field templates",
|
||||
)
|
||||
|
||||
@@ -8,13 +8,18 @@ OAuth2 servers. OSS users can extend this class for custom providers.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from ..models import CredentialObject, CredentialRefreshError, CredentialType
|
||||
from ..provider import CredentialProvider
|
||||
from .provider import OAuth2Config, OAuth2Error, OAuth2Token, RefreshTokenInvalidError, TokenPlacement
|
||||
from .provider import (
|
||||
OAuth2Config,
|
||||
OAuth2Error,
|
||||
OAuth2Token,
|
||||
TokenPlacement,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -72,14 +77,14 @@ class BaseOAuth2Provider(CredentialProvider):
|
||||
"""
|
||||
self.config = config
|
||||
self._provider_id = provider_id
|
||||
self._client: Optional[Any] = None
|
||||
self._client: Any | None = None
|
||||
|
||||
@property
|
||||
def provider_id(self) -> str:
|
||||
return self._provider_id
|
||||
|
||||
@property
|
||||
def supported_types(self) -> List[CredentialType]:
|
||||
def supported_types(self) -> list[CredentialType]:
|
||||
return [CredentialType.OAUTH2, CredentialType.BEARER_TOKEN]
|
||||
|
||||
def _get_client(self) -> Any:
|
||||
@@ -90,7 +95,9 @@ class BaseOAuth2Provider(CredentialProvider):
|
||||
|
||||
self._client = httpx.Client(timeout=self.config.request_timeout)
|
||||
except ImportError as e:
|
||||
raise ImportError("OAuth2 provider requires 'httpx'. Install with: pip install httpx") from e
|
||||
raise ImportError(
|
||||
"OAuth2 provider requires 'httpx'. Install with: pip install httpx"
|
||||
) from e
|
||||
return self._client
|
||||
|
||||
def _close_client(self) -> None:
|
||||
@@ -109,7 +116,7 @@ class BaseOAuth2Provider(CredentialProvider):
|
||||
self,
|
||||
state: str,
|
||||
redirect_uri: str,
|
||||
scopes: Optional[List[str]] = None,
|
||||
scopes: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""
|
||||
@@ -175,7 +182,7 @@ class BaseOAuth2Provider(CredentialProvider):
|
||||
|
||||
def client_credentials_grant(
|
||||
self,
|
||||
scopes: Optional[List[str]] = None,
|
||||
scopes: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> OAuth2Token:
|
||||
"""
|
||||
@@ -209,7 +216,7 @@ class BaseOAuth2Provider(CredentialProvider):
|
||||
def refresh_access_token(
|
||||
self,
|
||||
refresh_token: str,
|
||||
scopes: Optional[List[str]] = None,
|
||||
scopes: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> OAuth2Token:
|
||||
"""
|
||||
@@ -316,7 +323,7 @@ class BaseOAuth2Provider(CredentialProvider):
|
||||
if new_token.refresh_token and new_token.refresh_token != refresh_tok:
|
||||
credential.set_key("refresh_token", new_token.refresh_token)
|
||||
|
||||
credential.last_refreshed = datetime.now(timezone.utc)
|
||||
credential.last_refreshed = datetime.now(UTC)
|
||||
logger.info(f"Refreshed OAuth2 credential '{credential.id}'")
|
||||
|
||||
return credential
|
||||
@@ -350,7 +357,7 @@ class BaseOAuth2Provider(CredentialProvider):
|
||||
return False
|
||||
|
||||
buffer = timedelta(minutes=5)
|
||||
return datetime.now(timezone.utc) >= (access_key.expires_at - buffer)
|
||||
return datetime.now(UTC) >= (access_key.expires_at - buffer)
|
||||
|
||||
def revoke(self, credential: CredentialObject) -> bool:
|
||||
"""
|
||||
@@ -380,7 +387,7 @@ class BaseOAuth2Provider(CredentialProvider):
|
||||
|
||||
# --- Token Request Helpers ---
|
||||
|
||||
def _token_request(self, data: Dict[str, Any]) -> OAuth2Token:
|
||||
def _token_request(self, data: dict[str, Any]) -> OAuth2Token:
|
||||
"""
|
||||
Make a token request to the OAuth2 server.
|
||||
|
||||
@@ -415,11 +422,13 @@ class BaseOAuth2Provider(CredentialProvider):
|
||||
if response.status_code != 200 or "error" in response_data:
|
||||
error = response_data.get("error", "unknown_error")
|
||||
description = response_data.get("error_description", response.text)
|
||||
raise OAuth2Error(error=error, description=description, status_code=response.status_code)
|
||||
raise OAuth2Error(
|
||||
error=error, description=description, status_code=response.status_code
|
||||
)
|
||||
|
||||
return OAuth2Token.from_token_response(response_data)
|
||||
|
||||
def _parse_form_response(self, text: str) -> Dict[str, str]:
|
||||
def _parse_form_response(self, text: str) -> dict[str, str]:
|
||||
"""Parse form-encoded response (some providers use this instead of JSON)."""
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
@@ -428,7 +437,7 @@ class BaseOAuth2Provider(CredentialProvider):
|
||||
|
||||
# --- Token Formatting for Requests ---
|
||||
|
||||
def format_for_request(self, token: OAuth2Token) -> Dict[str, Any]:
|
||||
def format_for_request(self, token: OAuth2Token) -> dict[str, Any]:
|
||||
"""
|
||||
Format token for use in HTTP requests (bipartisan model).
|
||||
|
||||
@@ -455,7 +464,7 @@ class BaseOAuth2Provider(CredentialProvider):
|
||||
|
||||
return {}
|
||||
|
||||
def format_credential_for_request(self, credential: CredentialObject) -> Dict[str, Any]:
|
||||
def format_credential_for_request(self, credential: CredentialObject) -> dict[str, Any]:
|
||||
"""
|
||||
Format a credential for use in HTTP requests.
|
||||
|
||||
|
||||
@@ -9,13 +9,14 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING, Callable, Optional
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from ..models import CredentialKey, CredentialObject, CredentialRefreshError, CredentialType
|
||||
from ..models import CredentialKey, CredentialObject, CredentialType
|
||||
from .base_provider import BaseOAuth2Provider
|
||||
from .provider import OAuth2Token
|
||||
|
||||
@@ -30,8 +31,8 @@ class TokenRefreshResult:
|
||||
"""Result of a token refresh operation."""
|
||||
|
||||
success: bool
|
||||
token: Optional[OAuth2Token] = None
|
||||
error: Optional[str] = None
|
||||
token: OAuth2Token | None = None
|
||||
error: str | None = None
|
||||
needs_reauthorization: bool = False
|
||||
|
||||
|
||||
@@ -70,10 +71,10 @@ class TokenLifecycleManager:
|
||||
self,
|
||||
provider: BaseOAuth2Provider,
|
||||
credential_id: str,
|
||||
store: "CredentialStore",
|
||||
store: CredentialStore,
|
||||
refresh_buffer_minutes: int = 5,
|
||||
on_token_refreshed: Optional[Callable[[OAuth2Token], None]] = None,
|
||||
on_refresh_failed: Optional[Callable[[str], None]] = None,
|
||||
on_token_refreshed: Callable[[OAuth2Token], None] | None = None,
|
||||
on_refresh_failed: Callable[[str], None] | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the lifecycle manager.
|
||||
@@ -94,12 +95,12 @@ class TokenLifecycleManager:
|
||||
self.on_refresh_failed = on_refresh_failed
|
||||
|
||||
# In-memory cache for performance
|
||||
self._cached_token: Optional[OAuth2Token] = None
|
||||
self._cache_time: Optional[datetime] = None
|
||||
self._cached_token: OAuth2Token | None = None
|
||||
self._cache_time: datetime | None = None
|
||||
|
||||
# --- Async Token Access ---
|
||||
|
||||
async def get_valid_token(self) -> Optional[OAuth2Token]:
|
||||
async def get_valid_token(self) -> OAuth2Token | None:
|
||||
"""
|
||||
Get a valid access token, refreshing if necessary.
|
||||
|
||||
@@ -137,12 +138,12 @@ class TokenLifecycleManager:
|
||||
logger.warning(f"Refresh failed for {self.credential_id}, using existing token")
|
||||
|
||||
self._cached_token = token
|
||||
self._cache_time = datetime.now(timezone.utc)
|
||||
self._cache_time = datetime.now(UTC)
|
||||
return token
|
||||
|
||||
async def acquire_token_client_credentials(
|
||||
self,
|
||||
scopes: Optional[list[str]] = None,
|
||||
scopes: list[str] | None = None,
|
||||
) -> OAuth2Token:
|
||||
"""
|
||||
Acquire a new token using client credentials flow.
|
||||
@@ -157,7 +158,9 @@ class TokenLifecycleManager:
|
||||
"""
|
||||
# Run in executor to avoid blocking
|
||||
loop = asyncio.get_event_loop()
|
||||
token = await loop.run_in_executor(None, lambda: self.provider.client_credentials_grant(scopes=scopes))
|
||||
token = await loop.run_in_executor(
|
||||
None, lambda: self.provider.client_credentials_grant(scopes=scopes)
|
||||
)
|
||||
|
||||
self._save_token_to_store(token)
|
||||
self._cached_token = token
|
||||
@@ -180,7 +183,7 @@ class TokenLifecycleManager:
|
||||
|
||||
# --- Synchronous Token Access ---
|
||||
|
||||
def sync_get_valid_token(self) -> Optional[OAuth2Token]:
|
||||
def sync_get_valid_token(self) -> OAuth2Token | None:
|
||||
"""
|
||||
Synchronous version of get_valid_token().
|
||||
|
||||
@@ -212,12 +215,12 @@ class TokenLifecycleManager:
|
||||
return None
|
||||
|
||||
self._cached_token = token
|
||||
self._cache_time = datetime.now(timezone.utc)
|
||||
self._cache_time = datetime.now(UTC)
|
||||
return token
|
||||
|
||||
def sync_acquire_token_client_credentials(
|
||||
self,
|
||||
scopes: Optional[list[str]] = None,
|
||||
scopes: list[str] | None = None,
|
||||
) -> OAuth2Token:
|
||||
"""Synchronous version of acquire_token_client_credentials()."""
|
||||
token = self.provider.client_credentials_grant(scopes=scopes)
|
||||
@@ -231,9 +234,9 @@ class TokenLifecycleManager:
|
||||
"""Check if token needs refresh."""
|
||||
if token.expires_at is None:
|
||||
return False
|
||||
return datetime.now(timezone.utc) >= (token.expires_at - self.refresh_buffer)
|
||||
return datetime.now(UTC) >= (token.expires_at - self.refresh_buffer)
|
||||
|
||||
def _credential_to_token(self, credential: CredentialObject) -> Optional[OAuth2Token]:
|
||||
def _credential_to_token(self, credential: CredentialObject) -> OAuth2Token | None:
|
||||
"""Convert credential to OAuth2Token."""
|
||||
access_token = credential.get_key("access_token")
|
||||
if not access_token:
|
||||
|
||||
@@ -10,9 +10,9 @@ This module defines the core OAuth2 data structures:
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
|
||||
class TokenPlacement(str, Enum):
|
||||
@@ -47,10 +47,10 @@ class OAuth2Token:
|
||||
|
||||
access_token: str
|
||||
token_type: str = "Bearer"
|
||||
expires_at: Optional[datetime] = None
|
||||
refresh_token: Optional[str] = None
|
||||
scope: Optional[str] = None
|
||||
raw_response: Dict[str, Any] = field(default_factory=dict)
|
||||
expires_at: datetime | None = None
|
||||
refresh_token: str | None = None
|
||||
scope: str | None = None
|
||||
raw_response: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
@@ -63,7 +63,7 @@ class OAuth2Token:
|
||||
if self.expires_at is None:
|
||||
return False
|
||||
buffer = timedelta(minutes=5)
|
||||
return datetime.now(timezone.utc) >= (self.expires_at - buffer)
|
||||
return datetime.now(UTC) >= (self.expires_at - buffer)
|
||||
|
||||
@property
|
||||
def can_refresh(self) -> bool:
|
||||
@@ -71,15 +71,15 @@ class OAuth2Token:
|
||||
return self.refresh_token is not None and self.refresh_token.strip() != ""
|
||||
|
||||
@property
|
||||
def expires_in_seconds(self) -> Optional[int]:
|
||||
def expires_in_seconds(self) -> int | None:
|
||||
"""Get seconds until expiration, or None if no expiration."""
|
||||
if self.expires_at is None:
|
||||
return None
|
||||
delta = self.expires_at - datetime.now(timezone.utc)
|
||||
delta = self.expires_at - datetime.now(UTC)
|
||||
return max(0, int(delta.total_seconds()))
|
||||
|
||||
@classmethod
|
||||
def from_token_response(cls, data: Dict[str, Any]) -> "OAuth2Token":
|
||||
def from_token_response(cls, data: dict[str, Any]) -> OAuth2Token:
|
||||
"""
|
||||
Create OAuth2Token from an OAuth2 token endpoint response.
|
||||
|
||||
@@ -91,7 +91,7 @@ class OAuth2Token:
|
||||
"""
|
||||
expires_at = None
|
||||
if "expires_in" in data:
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(seconds=data["expires_in"])
|
||||
expires_at = datetime.now(UTC) + timedelta(seconds=data["expires_in"])
|
||||
|
||||
return cls(
|
||||
access_token=data["access_token"],
|
||||
@@ -137,28 +137,28 @@ class OAuth2Config:
|
||||
|
||||
# Endpoints (only token_url is strictly required)
|
||||
token_url: str
|
||||
authorization_url: Optional[str] = None
|
||||
revocation_url: Optional[str] = None
|
||||
introspection_url: Optional[str] = None
|
||||
authorization_url: str | None = None
|
||||
revocation_url: str | None = None
|
||||
introspection_url: str | None = None
|
||||
|
||||
# Client credentials
|
||||
client_id: str = ""
|
||||
client_secret: str = ""
|
||||
|
||||
# Scopes
|
||||
default_scopes: List[str] = field(default_factory=list)
|
||||
default_scopes: list[str] = field(default_factory=list)
|
||||
|
||||
# Token placement for API calls (bipartisan model)
|
||||
token_placement: TokenPlacement = TokenPlacement.HEADER_BEARER
|
||||
custom_header_name: Optional[str] = None
|
||||
custom_header_name: str | None = None
|
||||
query_param_name: str = "access_token"
|
||||
|
||||
# Request configuration
|
||||
extra_token_params: Dict[str, str] = field(default_factory=dict)
|
||||
extra_token_params: dict[str, str] = field(default_factory=dict)
|
||||
request_timeout: float = 30.0
|
||||
|
||||
# Additional headers for token requests
|
||||
extra_headers: Dict[str, str] = field(default_factory=dict)
|
||||
extra_headers: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate configuration."""
|
||||
|
||||
@@ -13,8 +13,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from .models import CredentialObject, CredentialRefreshError, CredentialType
|
||||
|
||||
@@ -64,7 +63,7 @@ class CredentialProvider(ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def supported_types(self) -> List[CredentialType]:
|
||||
def supported_types(self) -> list[CredentialType]:
|
||||
"""
|
||||
Credential types this provider can manage.
|
||||
|
||||
@@ -127,7 +126,7 @@ class CredentialProvider(ABC):
|
||||
True if credential should be refreshed
|
||||
"""
|
||||
buffer = timedelta(minutes=5)
|
||||
now = datetime.now(timezone.utc)
|
||||
now = datetime.now(UTC)
|
||||
|
||||
for key in credential.keys.values():
|
||||
if key.expires_at is not None:
|
||||
@@ -181,7 +180,7 @@ class StaticProvider(CredentialProvider):
|
||||
return "static"
|
||||
|
||||
@property
|
||||
def supported_types(self) -> List[CredentialType]:
|
||||
def supported_types(self) -> list[CredentialType]:
|
||||
return [CredentialType.API_KEY, CredentialType.BASIC_AUTH, CredentialType.CUSTOM]
|
||||
|
||||
def refresh(self, credential: CredentialObject) -> CredentialObject:
|
||||
@@ -236,7 +235,7 @@ class BearerTokenProvider(CredentialProvider):
|
||||
return "bearer_token"
|
||||
|
||||
@property
|
||||
def supported_types(self) -> List[CredentialType]:
|
||||
def supported_types(self) -> list[CredentialType]:
|
||||
return [CredentialType.BEARER_TOKEN]
|
||||
|
||||
def refresh(self, credential: CredentialObject) -> CredentialObject:
|
||||
@@ -273,7 +272,7 @@ class BearerTokenProvider(CredentialProvider):
|
||||
credential needs attention.
|
||||
"""
|
||||
buffer = timedelta(minutes=5)
|
||||
now = datetime.now(timezone.utc)
|
||||
now = datetime.now(UTC)
|
||||
|
||||
for key_name in ["access_token", "token"]:
|
||||
key = credential.keys.get(key_name)
|
||||
|
||||
@@ -14,9 +14,9 @@ import json
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime, timezone
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
@@ -44,7 +44,7 @@ class CredentialStorage(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load(self, credential_id: str) -> Optional[CredentialObject]:
|
||||
def load(self, credential_id: str) -> CredentialObject | None:
|
||||
"""
|
||||
Load a credential from storage.
|
||||
|
||||
@@ -70,7 +70,7 @@ class CredentialStorage(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_all(self) -> List[str]:
|
||||
def list_all(self) -> list[str]:
|
||||
"""
|
||||
List all credential IDs in storage.
|
||||
|
||||
@@ -119,7 +119,7 @@ class EncryptedFileStorage(CredentialStorage):
|
||||
def __init__(
|
||||
self,
|
||||
base_path: str | Path,
|
||||
encryption_key: Optional[bytes] = None,
|
||||
encryption_key: bytes | None = None,
|
||||
key_env_var: str = "HIVE_CREDENTIAL_KEY",
|
||||
):
|
||||
"""
|
||||
@@ -187,7 +187,7 @@ class EncryptedFileStorage(CredentialStorage):
|
||||
self._update_index(credential.id, "save", credential.credential_type.value)
|
||||
logger.debug(f"Saved encrypted credential '{credential.id}'")
|
||||
|
||||
def load(self, credential_id: str) -> Optional[CredentialObject]:
|
||||
def load(self, credential_id: str) -> CredentialObject | None:
|
||||
"""Load and decrypt credential."""
|
||||
cred_path = self._cred_path(credential_id)
|
||||
if not cred_path.exists():
|
||||
@@ -202,7 +202,9 @@ class EncryptedFileStorage(CredentialStorage):
|
||||
json_bytes = self._fernet.decrypt(encrypted)
|
||||
data = json.loads(json_bytes.decode())
|
||||
except Exception as e:
|
||||
raise CredentialDecryptionError(f"Failed to decrypt credential '{credential_id}': {e}") from e
|
||||
raise CredentialDecryptionError(
|
||||
f"Failed to decrypt credential '{credential_id}': {e}"
|
||||
) from e
|
||||
|
||||
# Deserialize
|
||||
return self._deserialize_credential(data)
|
||||
@@ -217,7 +219,7 @@ class EncryptedFileStorage(CredentialStorage):
|
||||
return True
|
||||
return False
|
||||
|
||||
def list_all(self) -> List[str]:
|
||||
def list_all(self) -> list[str]:
|
||||
"""List all credential IDs."""
|
||||
index_path = self.base_path / "metadata" / "index.json"
|
||||
if not index_path.exists():
|
||||
@@ -230,7 +232,7 @@ class EncryptedFileStorage(CredentialStorage):
|
||||
"""Check if credential exists."""
|
||||
return self._cred_path(credential_id).exists()
|
||||
|
||||
def _serialize_credential(self, credential: CredentialObject) -> Dict[str, Any]:
|
||||
def _serialize_credential(self, credential: CredentialObject) -> dict[str, Any]:
|
||||
"""Convert credential to JSON-serializable dict, extracting secret values."""
|
||||
data = credential.model_dump(mode="json")
|
||||
|
||||
@@ -244,7 +246,7 @@ class EncryptedFileStorage(CredentialStorage):
|
||||
|
||||
return data
|
||||
|
||||
def _deserialize_credential(self, data: Dict[str, Any]) -> CredentialObject:
|
||||
def _deserialize_credential(self, data: dict[str, Any]) -> CredentialObject:
|
||||
"""Reconstruct credential from dict, wrapping values in SecretStr."""
|
||||
# Convert plain values back to SecretStr
|
||||
for key_data in data.get("keys", {}).values():
|
||||
@@ -257,7 +259,7 @@ class EncryptedFileStorage(CredentialStorage):
|
||||
self,
|
||||
credential_id: str,
|
||||
operation: str,
|
||||
credential_type: Optional[str] = None,
|
||||
credential_type: str | None = None,
|
||||
) -> None:
|
||||
"""Update the metadata index."""
|
||||
index_path = self.base_path / "metadata" / "index.json"
|
||||
@@ -270,13 +272,13 @@ class EncryptedFileStorage(CredentialStorage):
|
||||
|
||||
if operation == "save":
|
||||
index["credentials"][credential_id] = {
|
||||
"updated_at": datetime.now(timezone.utc).isoformat(),
|
||||
"updated_at": datetime.now(UTC).isoformat(),
|
||||
"type": credential_type,
|
||||
}
|
||||
elif operation == "delete":
|
||||
index["credentials"].pop(credential_id, None)
|
||||
|
||||
index["last_modified"] = datetime.now(timezone.utc).isoformat()
|
||||
index["last_modified"] = datetime.now(UTC).isoformat()
|
||||
|
||||
with open(index_path, "w") as f:
|
||||
json.dump(index, f, indent=2)
|
||||
@@ -301,8 +303,8 @@ class EnvVarStorage(CredentialStorage):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env_mapping: Optional[Dict[str, str]] = None,
|
||||
dotenv_path: Optional[Path] = None,
|
||||
env_mapping: dict[str, str] | None = None,
|
||||
dotenv_path: Path | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize env var storage.
|
||||
@@ -323,7 +325,7 @@ class EnvVarStorage(CredentialStorage):
|
||||
# Default pattern: CREDENTIAL_ID_API_KEY
|
||||
return f"{credential_id.upper().replace('-', '_')}_API_KEY"
|
||||
|
||||
def _read_env_value(self, env_var: str) -> Optional[str]:
|
||||
def _read_env_value(self, env_var: str) -> str | None:
|
||||
"""Read value from env var or .env file."""
|
||||
# Check os.environ first (takes precedence)
|
||||
value = os.environ.get(env_var)
|
||||
@@ -346,10 +348,11 @@ class EnvVarStorage(CredentialStorage):
|
||||
def save(self, credential: CredentialObject) -> None:
|
||||
"""Cannot save to environment variables at runtime."""
|
||||
raise NotImplementedError(
|
||||
"EnvVarStorage is read-only. Set environment variables externally or use EncryptedFileStorage."
|
||||
"EnvVarStorage is read-only. Set environment variables "
|
||||
"externally or use EncryptedFileStorage."
|
||||
)
|
||||
|
||||
def load(self, credential_id: str) -> Optional[CredentialObject]:
|
||||
def load(self, credential_id: str) -> CredentialObject | None:
|
||||
"""Load credential from environment variable."""
|
||||
env_var = self._get_env_var_name(credential_id)
|
||||
value = self._read_env_value(env_var)
|
||||
@@ -366,9 +369,11 @@ class EnvVarStorage(CredentialStorage):
|
||||
|
||||
def delete(self, credential_id: str) -> bool:
|
||||
"""Cannot delete environment variables at runtime."""
|
||||
raise NotImplementedError("EnvVarStorage is read-only. Unset environment variables externally.")
|
||||
raise NotImplementedError(
|
||||
"EnvVarStorage is read-only. Unset environment variables externally."
|
||||
)
|
||||
|
||||
def list_all(self) -> List[str]:
|
||||
def list_all(self) -> list[str]:
|
||||
"""List credentials that are available in environment."""
|
||||
available = []
|
||||
|
||||
@@ -407,20 +412,20 @@ class InMemoryStorage(CredentialStorage):
|
||||
credential = storage.load("test_cred")
|
||||
"""
|
||||
|
||||
def __init__(self, initial_data: Optional[Dict[str, CredentialObject]] = None):
|
||||
def __init__(self, initial_data: dict[str, CredentialObject] | None = None):
|
||||
"""
|
||||
Initialize in-memory storage.
|
||||
|
||||
Args:
|
||||
initial_data: Optional dict of credential_id -> CredentialObject
|
||||
"""
|
||||
self._data: Dict[str, CredentialObject] = initial_data or {}
|
||||
self._data: dict[str, CredentialObject] = initial_data or {}
|
||||
|
||||
def save(self, credential: CredentialObject) -> None:
|
||||
"""Save credential to memory."""
|
||||
self._data[credential.id] = credential
|
||||
|
||||
def load(self, credential_id: str) -> Optional[CredentialObject]:
|
||||
def load(self, credential_id: str) -> CredentialObject | None:
|
||||
"""Load credential from memory."""
|
||||
return self._data.get(credential_id)
|
||||
|
||||
@@ -431,7 +436,7 @@ class InMemoryStorage(CredentialStorage):
|
||||
return True
|
||||
return False
|
||||
|
||||
def list_all(self) -> List[str]:
|
||||
def list_all(self) -> list[str]:
|
||||
"""List all credential IDs."""
|
||||
return list(self._data.keys())
|
||||
|
||||
@@ -462,7 +467,7 @@ class CompositeStorage(CredentialStorage):
|
||||
def __init__(
|
||||
self,
|
||||
primary: CredentialStorage,
|
||||
fallbacks: Optional[List[CredentialStorage]] = None,
|
||||
fallbacks: list[CredentialStorage] | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize composite storage.
|
||||
@@ -478,7 +483,7 @@ class CompositeStorage(CredentialStorage):
|
||||
"""Save to primary storage."""
|
||||
self._primary.save(credential)
|
||||
|
||||
def load(self, credential_id: str) -> Optional[CredentialObject]:
|
||||
def load(self, credential_id: str) -> CredentialObject | None:
|
||||
"""Load from primary, then fallbacks."""
|
||||
# Try primary first
|
||||
credential = self._primary.load(credential_id)
|
||||
@@ -497,7 +502,7 @@ class CompositeStorage(CredentialStorage):
|
||||
"""Delete from primary storage only."""
|
||||
return self._primary.delete(credential_id)
|
||||
|
||||
def list_all(self) -> List[str]:
|
||||
def list_all(self) -> list[str]:
|
||||
"""List credentials from all storages."""
|
||||
all_ids = set(self._primary.list_all())
|
||||
for fallback in self._fallbacks:
|
||||
|
||||
@@ -13,17 +13,15 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from .models import (
|
||||
CredentialError,
|
||||
CredentialKey,
|
||||
CredentialObject,
|
||||
CredentialRefreshError,
|
||||
CredentialType,
|
||||
CredentialUsageSpec,
|
||||
)
|
||||
from .provider import CredentialProvider, StaticProvider
|
||||
@@ -69,8 +67,8 @@ class CredentialStore:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
storage: Optional[CredentialStorage] = None,
|
||||
providers: Optional[List[CredentialProvider]] = None,
|
||||
storage: CredentialStorage | None = None,
|
||||
providers: list[CredentialProvider] | None = None,
|
||||
cache_ttl_seconds: int = 300,
|
||||
auto_refresh: bool = True,
|
||||
):
|
||||
@@ -84,11 +82,11 @@ class CredentialStore:
|
||||
auto_refresh: Whether to auto-refresh expired credentials on access.
|
||||
"""
|
||||
self._storage = storage or EnvVarStorage()
|
||||
self._providers: Dict[str, CredentialProvider] = {}
|
||||
self._usage_specs: Dict[str, CredentialUsageSpec] = {}
|
||||
self._providers: dict[str, CredentialProvider] = {}
|
||||
self._usage_specs: dict[str, CredentialUsageSpec] = {}
|
||||
|
||||
# Cache: credential_id -> (CredentialObject, cached_at)
|
||||
self._cache: Dict[str, tuple[CredentialObject, datetime]] = {}
|
||||
self._cache: dict[str, tuple[CredentialObject, datetime]] = {}
|
||||
self._cache_ttl = cache_ttl_seconds
|
||||
self._lock = threading.RLock()
|
||||
|
||||
@@ -113,7 +111,7 @@ class CredentialStore:
|
||||
self._providers[provider.provider_id] = provider
|
||||
logger.debug(f"Registered credential provider: {provider.provider_id}")
|
||||
|
||||
def get_provider(self, provider_id: str) -> Optional[CredentialProvider]:
|
||||
def get_provider(self, provider_id: str) -> CredentialProvider | None:
|
||||
"""
|
||||
Get a provider by ID.
|
||||
|
||||
@@ -125,7 +123,9 @@ class CredentialStore:
|
||||
"""
|
||||
return self._providers.get(provider_id)
|
||||
|
||||
def get_provider_for_credential(self, credential: CredentialObject) -> Optional[CredentialProvider]:
|
||||
def get_provider_for_credential(
|
||||
self, credential: CredentialObject
|
||||
) -> CredentialProvider | None:
|
||||
"""
|
||||
Get the appropriate provider for a credential.
|
||||
|
||||
@@ -159,7 +159,7 @@ class CredentialStore:
|
||||
"""
|
||||
self._usage_specs[spec.credential_id] = spec
|
||||
|
||||
def get_usage_spec(self, credential_id: str) -> Optional[CredentialUsageSpec]:
|
||||
def get_usage_spec(self, credential_id: str) -> CredentialUsageSpec | None:
|
||||
"""
|
||||
Get the usage spec for a credential.
|
||||
|
||||
@@ -177,7 +177,7 @@ class CredentialStore:
|
||||
self,
|
||||
credential_id: str,
|
||||
refresh_if_needed: bool = True,
|
||||
) -> Optional[CredentialObject]:
|
||||
) -> CredentialObject | None:
|
||||
"""
|
||||
Get a credential by ID.
|
||||
|
||||
@@ -210,7 +210,7 @@ class CredentialStore:
|
||||
|
||||
return credential
|
||||
|
||||
def get_key(self, credential_id: str, key_name: str) -> Optional[str]:
|
||||
def get_key(self, credential_id: str, key_name: str) -> str | None:
|
||||
"""
|
||||
Convenience method to get a specific key value.
|
||||
|
||||
@@ -226,7 +226,7 @@ class CredentialStore:
|
||||
return None
|
||||
return credential.get_key(key_name)
|
||||
|
||||
def get(self, credential_id: str) -> Optional[str]:
|
||||
def get(self, credential_id: str) -> str | None:
|
||||
"""
|
||||
Legacy compatibility: get the primary key value.
|
||||
|
||||
@@ -262,7 +262,7 @@ class CredentialStore:
|
||||
"""
|
||||
return self._resolver.resolve(template)
|
||||
|
||||
def resolve_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
|
||||
def resolve_headers(self, headers: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
Resolve credential templates in headers dictionary.
|
||||
|
||||
@@ -280,7 +280,7 @@ class CredentialStore:
|
||||
"""
|
||||
return self._resolver.resolve_headers(headers)
|
||||
|
||||
def resolve_params(self, params: Dict[str, str]) -> Dict[str, str]:
|
||||
def resolve_params(self, params: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
Resolve credential templates in query parameters dictionary.
|
||||
|
||||
@@ -292,7 +292,7 @@ class CredentialStore:
|
||||
"""
|
||||
return self._resolver.resolve_params(params)
|
||||
|
||||
def resolve_for_usage(self, credential_id: str) -> Dict[str, Any]:
|
||||
def resolve_for_usage(self, credential_id: str) -> dict[str, Any]:
|
||||
"""
|
||||
Get resolved request kwargs for a registered usage spec.
|
||||
|
||||
@@ -309,7 +309,7 @@ class CredentialStore:
|
||||
if spec is None:
|
||||
raise ValueError(f"No usage spec registered for '{credential_id}'")
|
||||
|
||||
result: Dict[str, Any] = {}
|
||||
result: dict[str, Any] = {}
|
||||
|
||||
if spec.headers:
|
||||
result["headers"] = self.resolve_headers(spec.headers)
|
||||
@@ -353,7 +353,7 @@ class CredentialStore:
|
||||
logger.info(f"Deleted credential '{credential_id}'")
|
||||
return result
|
||||
|
||||
def list_credentials(self) -> List[str]:
|
||||
def list_credentials(self) -> list[str]:
|
||||
"""
|
||||
List all available credential IDs.
|
||||
|
||||
@@ -376,7 +376,7 @@ class CredentialStore:
|
||||
|
||||
# --- Validation ---
|
||||
|
||||
def validate_for_usage(self, credential_id: str) -> List[str]:
|
||||
def validate_for_usage(self, credential_id: str) -> list[str]:
|
||||
"""
|
||||
Validate that a credential meets its usage spec requirements.
|
||||
|
||||
@@ -401,7 +401,7 @@ class CredentialStore:
|
||||
|
||||
return errors
|
||||
|
||||
def validate_all(self) -> Dict[str, List[str]]:
|
||||
def validate_all(self) -> dict[str, list[str]]:
|
||||
"""
|
||||
Validate all registered usage specs.
|
||||
|
||||
@@ -462,7 +462,7 @@ class CredentialStore:
|
||||
|
||||
try:
|
||||
refreshed = provider.refresh(credential)
|
||||
refreshed.last_refreshed = datetime.now(timezone.utc)
|
||||
refreshed.last_refreshed = datetime.now(UTC)
|
||||
|
||||
# Persist the refreshed credential
|
||||
self._storage.save(refreshed)
|
||||
@@ -475,7 +475,7 @@ class CredentialStore:
|
||||
logger.error(f"Failed to refresh credential '{credential.id}': {e}")
|
||||
return credential
|
||||
|
||||
def refresh_credential(self, credential_id: str) -> Optional[CredentialObject]:
|
||||
def refresh_credential(self, credential_id: str) -> CredentialObject | None:
|
||||
"""
|
||||
Manually refresh a credential.
|
||||
|
||||
@@ -496,13 +496,13 @@ class CredentialStore:
|
||||
|
||||
# --- Caching ---
|
||||
|
||||
def _get_from_cache(self, credential_id: str) -> Optional[CredentialObject]:
|
||||
def _get_from_cache(self, credential_id: str) -> CredentialObject | None:
|
||||
"""Get credential from cache if not expired."""
|
||||
if credential_id not in self._cache:
|
||||
return None
|
||||
|
||||
credential, cached_at = self._cache[credential_id]
|
||||
age = (datetime.now(timezone.utc) - cached_at).total_seconds()
|
||||
age = (datetime.now(UTC) - cached_at).total_seconds()
|
||||
|
||||
if age > self._cache_ttl:
|
||||
del self._cache[credential_id]
|
||||
@@ -512,7 +512,7 @@ class CredentialStore:
|
||||
|
||||
def _add_to_cache(self, credential: CredentialObject) -> None:
|
||||
"""Add credential to cache."""
|
||||
self._cache[credential.id] = (credential, datetime.now(timezone.utc))
|
||||
self._cache[credential.id] = (credential, datetime.now(UTC))
|
||||
|
||||
def _remove_from_cache(self, credential_id: str) -> None:
|
||||
"""Remove credential from cache."""
|
||||
@@ -528,8 +528,8 @@ class CredentialStore:
|
||||
@classmethod
|
||||
def for_testing(
|
||||
cls,
|
||||
credentials: Dict[str, Dict[str, str]],
|
||||
) -> "CredentialStore":
|
||||
credentials: dict[str, dict[str, str]],
|
||||
) -> CredentialStore:
|
||||
"""
|
||||
Create a credential store for testing with mock credentials.
|
||||
|
||||
@@ -550,7 +550,7 @@ class CredentialStore:
|
||||
})
|
||||
"""
|
||||
# Convert test data to CredentialObjects
|
||||
cred_objects: Dict[str, CredentialObject] = {}
|
||||
cred_objects: dict[str, CredentialObject] = {}
|
||||
|
||||
for cred_id, keys in credentials.items():
|
||||
cred_objects[cred_id] = CredentialObject(
|
||||
@@ -567,9 +567,9 @@ class CredentialStore:
|
||||
def with_encrypted_storage(
|
||||
cls,
|
||||
base_path: str,
|
||||
providers: Optional[List[CredentialProvider]] = None,
|
||||
providers: list[CredentialProvider] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> "CredentialStore":
|
||||
) -> CredentialStore:
|
||||
"""
|
||||
Create a credential store with encrypted file storage.
|
||||
|
||||
@@ -592,10 +592,10 @@ class CredentialStore:
|
||||
@classmethod
|
||||
def with_env_storage(
|
||||
cls,
|
||||
env_mapping: Optional[Dict[str, str]] = None,
|
||||
providers: Optional[List[CredentialProvider]] = None,
|
||||
env_mapping: dict[str, str] | None = None,
|
||||
providers: list[CredentialProvider] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> "CredentialStore":
|
||||
) -> CredentialStore:
|
||||
"""
|
||||
Create a credential store with environment variable storage.
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ Examples:
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .models import CredentialKeyNotFoundError, CredentialNotFoundError
|
||||
|
||||
@@ -45,7 +45,7 @@ class TemplateResolver:
|
||||
# Matches {{credential_id}} or {{credential_id.key_name}}
|
||||
TEMPLATE_PATTERN = re.compile(r"\{\{([a-zA-Z0-9_-]+)(?:\.([a-zA-Z0-9_-]+))?\}\}")
|
||||
|
||||
def __init__(self, credential_store: "CredentialStore"):
|
||||
def __init__(self, credential_store: CredentialStore):
|
||||
"""
|
||||
Initialize the template resolver.
|
||||
|
||||
@@ -88,7 +88,9 @@ class TemplateResolver:
|
||||
if key_name:
|
||||
value = credential.get_key(key_name)
|
||||
if value is None:
|
||||
raise CredentialKeyNotFoundError(f"Key '{key_name}' not found in credential '{cred_id}'")
|
||||
raise CredentialKeyNotFoundError(
|
||||
f"Key '{key_name}' not found in credential '{cred_id}'"
|
||||
)
|
||||
else:
|
||||
# Use default key
|
||||
value = credential.get_default_key()
|
||||
@@ -104,9 +106,9 @@ class TemplateResolver:
|
||||
|
||||
def resolve_headers(
|
||||
self,
|
||||
header_templates: Dict[str, str],
|
||||
header_templates: dict[str, str],
|
||||
fail_on_missing: bool = True,
|
||||
) -> Dict[str, str]:
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Resolve templates in a headers dictionary.
|
||||
|
||||
@@ -124,13 +126,15 @@ class TemplateResolver:
|
||||
... })
|
||||
{"Authorization": "Bearer ghp_xxx", "X-API-Key": "BSAKxxx"}
|
||||
"""
|
||||
return {key: self.resolve(value, fail_on_missing) for key, value in header_templates.items()}
|
||||
return {
|
||||
key: self.resolve(value, fail_on_missing) for key, value in header_templates.items()
|
||||
}
|
||||
|
||||
def resolve_params(
|
||||
self,
|
||||
param_templates: Dict[str, str],
|
||||
param_templates: dict[str, str],
|
||||
fail_on_missing: bool = True,
|
||||
) -> Dict[str, str]:
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Resolve templates in a query parameters dictionary.
|
||||
|
||||
@@ -155,7 +159,7 @@ class TemplateResolver:
|
||||
"""
|
||||
return bool(self.TEMPLATE_PATTERN.search(text))
|
||||
|
||||
def extract_references(self, text: str) -> List[Tuple[str, Optional[str]]]:
|
||||
def extract_references(self, text: str) -> list[tuple[str, str | None]]:
|
||||
"""
|
||||
Extract all credential references from text.
|
||||
|
||||
@@ -172,7 +176,7 @@ class TemplateResolver:
|
||||
"""
|
||||
return [(match.group(1), match.group(2)) for match in self.TEMPLATE_PATTERN.finditer(text)]
|
||||
|
||||
def validate_references(self, text: str) -> List[str]:
|
||||
def validate_references(self, text: str) -> list[str]:
|
||||
"""
|
||||
Validate all credential references in text without resolving.
|
||||
|
||||
@@ -201,7 +205,7 @@ class TemplateResolver:
|
||||
|
||||
return errors
|
||||
|
||||
def get_required_credentials(self, text: str) -> List[str]:
|
||||
def get_required_credentials(self, text: str) -> list[str]:
|
||||
"""
|
||||
Get list of credential IDs required by a template string.
|
||||
|
||||
|
||||
@@ -12,24 +12,17 @@ Tests cover:
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from core.framework.credentials import (
|
||||
BearerTokenProvider,
|
||||
CompositeStorage,
|
||||
CredentialError,
|
||||
CredentialKey,
|
||||
CredentialKeyNotFoundError,
|
||||
CredentialNotFoundError,
|
||||
CredentialObject,
|
||||
CredentialProvider,
|
||||
CredentialRefreshError,
|
||||
CredentialStorage,
|
||||
CredentialStore,
|
||||
CredentialType,
|
||||
CredentialUsageSpec,
|
||||
@@ -39,6 +32,7 @@ from core.framework.credentials import (
|
||||
StaticProvider,
|
||||
TemplateResolver,
|
||||
)
|
||||
from pydantic import SecretStr
|
||||
|
||||
|
||||
class TestCredentialKey:
|
||||
@@ -54,13 +48,13 @@ class TestCredentialKey:
|
||||
|
||||
def test_key_with_expiration(self):
|
||||
"""Test key with expiration time."""
|
||||
future = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
future = datetime.now(UTC) + timedelta(hours=1)
|
||||
key = CredentialKey(name="token", value=SecretStr("xxx"), expires_at=future)
|
||||
assert not key.is_expired
|
||||
|
||||
def test_expired_key(self):
|
||||
"""Test that expired key is detected."""
|
||||
past = datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
past = datetime.now(UTC) - timedelta(hours=1)
|
||||
key = CredentialKey(name="token", value=SecretStr("xxx"), expires_at=past)
|
||||
assert key.is_expired
|
||||
|
||||
@@ -111,13 +105,13 @@ class TestCredentialObject:
|
||||
def test_set_key_with_expiration(self):
|
||||
"""Test setting a key with expiration."""
|
||||
cred = CredentialObject(id="test", keys={})
|
||||
expires = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
expires = datetime.now(UTC) + timedelta(hours=1)
|
||||
cred.set_key("token", "xxx", expires_at=expires)
|
||||
assert cred.keys["token"].expires_at == expires
|
||||
|
||||
def test_needs_refresh(self):
|
||||
"""Test needs_refresh property."""
|
||||
past = datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
past = datetime.now(UTC) - timedelta(hours=1)
|
||||
cred = CredentialObject(
|
||||
id="test",
|
||||
keys={"token": CredentialKey(name="token", value=SecretStr("xxx"), expires_at=past)},
|
||||
@@ -136,7 +130,9 @@ class TestCredentialObject:
|
||||
# With access_token
|
||||
cred2 = CredentialObject(
|
||||
id="test",
|
||||
keys={"access_token": CredentialKey(name="access_token", value=SecretStr("token-value"))},
|
||||
keys={
|
||||
"access_token": CredentialKey(name="access_token", value=SecretStr("token-value"))
|
||||
},
|
||||
)
|
||||
assert cred2.get_default_key() == "token-value"
|
||||
|
||||
@@ -301,7 +297,9 @@ class TestEncryptedFileStorage:
|
||||
key = Fernet.generate_key().decode()
|
||||
with patch.dict(os.environ, {"HIVE_CREDENTIAL_KEY": key}):
|
||||
storage = EncryptedFileStorage(temp_dir)
|
||||
cred = CredentialObject(id="test", keys={"k": CredentialKey(name="k", value=SecretStr("v"))})
|
||||
cred = CredentialObject(
|
||||
id="test", keys={"k": CredentialKey(name="k", value=SecretStr("v"))}
|
||||
)
|
||||
storage.save(cred)
|
||||
|
||||
# Create new storage instance with same key
|
||||
@@ -332,10 +330,18 @@ class TestCompositeStorage:
|
||||
def test_read_from_primary(self):
|
||||
"""Test reading from primary storage."""
|
||||
primary = InMemoryStorage()
|
||||
primary.save(CredentialObject(id="test", keys={"k": CredentialKey(name="k", value=SecretStr("primary"))}))
|
||||
primary.save(
|
||||
CredentialObject(
|
||||
id="test", keys={"k": CredentialKey(name="k", value=SecretStr("primary"))}
|
||||
)
|
||||
)
|
||||
|
||||
fallback = InMemoryStorage()
|
||||
fallback.save(CredentialObject(id="test", keys={"k": CredentialKey(name="k", value=SecretStr("fallback"))}))
|
||||
fallback.save(
|
||||
CredentialObject(
|
||||
id="test", keys={"k": CredentialKey(name="k", value=SecretStr("fallback"))}
|
||||
)
|
||||
)
|
||||
|
||||
storage = CompositeStorage(primary, [fallback])
|
||||
cred = storage.load("test")
|
||||
@@ -347,7 +353,11 @@ class TestCompositeStorage:
|
||||
"""Test fallback when credential not in primary."""
|
||||
primary = InMemoryStorage()
|
||||
fallback = InMemoryStorage()
|
||||
fallback.save(CredentialObject(id="test", keys={"k": CredentialKey(name="k", value=SecretStr("fallback"))}))
|
||||
fallback.save(
|
||||
CredentialObject(
|
||||
id="test", keys={"k": CredentialKey(name="k", value=SecretStr("fallback"))}
|
||||
)
|
||||
)
|
||||
|
||||
storage = CompositeStorage(primary, [fallback])
|
||||
cred = storage.load("test")
|
||||
@@ -383,7 +393,9 @@ class TestStaticProvider:
|
||||
def test_refresh_returns_unchanged(self):
|
||||
"""Test that refresh returns credential unchanged."""
|
||||
provider = StaticProvider()
|
||||
cred = CredentialObject(id="test", keys={"k": CredentialKey(name="k", value=SecretStr("v"))})
|
||||
cred = CredentialObject(
|
||||
id="test", keys={"k": CredentialKey(name="k", value=SecretStr("v"))}
|
||||
)
|
||||
|
||||
refreshed = provider.refresh(cred)
|
||||
assert refreshed.get_key("k") == "v"
|
||||
@@ -391,7 +403,9 @@ class TestStaticProvider:
|
||||
def test_validate_with_keys(self):
|
||||
"""Test validation with keys present."""
|
||||
provider = StaticProvider()
|
||||
cred = CredentialObject(id="test", keys={"k": CredentialKey(name="k", value=SecretStr("v"))})
|
||||
cred = CredentialObject(
|
||||
id="test", keys={"k": CredentialKey(name="k", value=SecretStr("v"))}
|
||||
)
|
||||
|
||||
assert provider.validate(cred)
|
||||
|
||||
@@ -592,10 +606,12 @@ class TestCredentialStore:
|
||||
storage = InMemoryStorage()
|
||||
store = CredentialStore(storage=storage, cache_ttl_seconds=60)
|
||||
|
||||
storage.save(CredentialObject(id="test", keys={"k": CredentialKey(name="k", value=SecretStr("v"))}))
|
||||
storage.save(
|
||||
CredentialObject(id="test", keys={"k": CredentialKey(name="k", value=SecretStr("v"))})
|
||||
)
|
||||
|
||||
# First load
|
||||
cred1 = store.get_credential("test")
|
||||
store.get_credential("test")
|
||||
|
||||
# Delete from storage
|
||||
storage.delete("test")
|
||||
@@ -646,12 +662,12 @@ class TestOAuth2Module:
|
||||
from core.framework.credentials.oauth2 import OAuth2Token
|
||||
|
||||
# Not expired
|
||||
future = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
future = datetime.now(UTC) + timedelta(hours=1)
|
||||
token = OAuth2Token(access_token="xxx", expires_at=future)
|
||||
assert not token.is_expired
|
||||
|
||||
# Expired
|
||||
past = datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
past = datetime.now(UTC) - timedelta(hours=1)
|
||||
expired_token = OAuth2Token(access_token="xxx", expires_at=past)
|
||||
assert expired_token.is_expired
|
||||
|
||||
@@ -670,7 +686,9 @@ class TestOAuth2Module:
|
||||
from core.framework.credentials.oauth2 import OAuth2Config, TokenPlacement
|
||||
|
||||
# Valid config
|
||||
config = OAuth2Config(token_url="https://example.com/token", client_id="id", client_secret="secret")
|
||||
config = OAuth2Config(
|
||||
token_url="https://example.com/token", client_id="id", client_secret="secret"
|
||||
)
|
||||
assert config.token_url == "https://example.com/token"
|
||||
|
||||
# Missing token_url
|
||||
|
||||
@@ -10,7 +10,7 @@ from __future__ import annotations
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
@@ -72,10 +72,10 @@ class HashiCorpVaultStorage(CredentialStorage):
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
token: Optional[str] = None,
|
||||
token: str | None = None,
|
||||
mount_point: str = "secret",
|
||||
path_prefix: str = "hive/credentials",
|
||||
namespace: Optional[str] = None,
|
||||
namespace: str | None = None,
|
||||
verify_ssl: bool = True,
|
||||
):
|
||||
"""
|
||||
@@ -107,7 +107,9 @@ class HashiCorpVaultStorage(CredentialStorage):
|
||||
self._namespace = namespace
|
||||
|
||||
if not self._token:
|
||||
raise ValueError("Vault token required. Set VAULT_TOKEN env var or pass token parameter.")
|
||||
raise ValueError(
|
||||
"Vault token required. Set VAULT_TOKEN env var or pass token parameter."
|
||||
)
|
||||
|
||||
self._client = hvac.Client(
|
||||
url=url,
|
||||
@@ -143,7 +145,7 @@ class HashiCorpVaultStorage(CredentialStorage):
|
||||
logger.error(f"Failed to save credential '{credential.id}' to Vault: {e}")
|
||||
raise
|
||||
|
||||
def load(self, credential_id: str) -> Optional[CredentialObject]:
|
||||
def load(self, credential_id: str) -> CredentialObject | None:
|
||||
"""Load credential from Vault."""
|
||||
path = self._path(credential_id)
|
||||
|
||||
@@ -181,7 +183,7 @@ class HashiCorpVaultStorage(CredentialStorage):
|
||||
logger.error(f"Failed to delete credential '{credential_id}' from Vault: {e}")
|
||||
raise
|
||||
|
||||
def list_all(self) -> List[str]:
|
||||
def list_all(self) -> list[str]:
|
||||
"""List all credentials under the prefix."""
|
||||
try:
|
||||
response = self._client.secrets.kv.v2.list_secrets(
|
||||
@@ -210,9 +212,9 @@ class HashiCorpVaultStorage(CredentialStorage):
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _serialize_for_vault(self, credential: CredentialObject) -> Dict[str, Any]:
|
||||
def _serialize_for_vault(self, credential: CredentialObject) -> dict[str, Any]:
|
||||
"""Convert credential to Vault secret format."""
|
||||
data: Dict[str, Any] = {
|
||||
data: dict[str, Any] = {
|
||||
"_type": credential.credential_type.value,
|
||||
}
|
||||
|
||||
@@ -237,7 +239,7 @@ class HashiCorpVaultStorage(CredentialStorage):
|
||||
|
||||
return data
|
||||
|
||||
def _deserialize_from_vault(self, credential_id: str, data: Dict[str, Any]) -> CredentialObject:
|
||||
def _deserialize_from_vault(self, credential_id: str, data: dict[str, Any]) -> CredentialObject:
|
||||
"""Reconstruct credential from Vault secret."""
|
||||
# Extract metadata fields
|
||||
cred_type = CredentialType(data.pop("_type", "api_key"))
|
||||
@@ -246,7 +248,7 @@ class HashiCorpVaultStorage(CredentialStorage):
|
||||
auto_refresh = data.pop("_auto_refresh", "") == "true"
|
||||
|
||||
# Build keys dict
|
||||
keys: Dict[str, CredentialKey] = {}
|
||||
keys: dict[str, CredentialKey] = {}
|
||||
|
||||
# Find all non-metadata keys
|
||||
key_names = [k for k in data.keys() if not k.startswith("_")]
|
||||
@@ -264,7 +266,7 @@ class HashiCorpVaultStorage(CredentialStorage):
|
||||
pass
|
||||
|
||||
# Check for metadata
|
||||
metadata: Dict[str, Any] = {}
|
||||
metadata: dict[str, Any] = {}
|
||||
metadata_key = f"_metadata_{key_name}"
|
||||
if metadata_key in data:
|
||||
try:
|
||||
@@ -292,7 +294,7 @@ class HashiCorpVaultStorage(CredentialStorage):
|
||||
|
||||
# --- Vault-Specific Operations ---
|
||||
|
||||
def get_secret_metadata(self, credential_id: str) -> Optional[Dict[str, Any]]:
|
||||
def get_secret_metadata(self, credential_id: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
Get Vault metadata for a secret (version info, timestamps, etc.).
|
||||
|
||||
@@ -313,7 +315,7 @@ class HashiCorpVaultStorage(CredentialStorage):
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def soft_delete(self, credential_id: str, versions: Optional[List[int]] = None) -> bool:
|
||||
def soft_delete(self, credential_id: str, versions: list[int] | None = None) -> bool:
|
||||
"""
|
||||
Soft delete specific versions (can be recovered).
|
||||
|
||||
@@ -343,7 +345,7 @@ class HashiCorpVaultStorage(CredentialStorage):
|
||||
logger.error(f"Soft delete failed for '{credential_id}': {e}")
|
||||
return False
|
||||
|
||||
def undelete(self, credential_id: str, versions: List[int]) -> bool:
|
||||
def undelete(self, credential_id: str, versions: list[int]) -> bool:
|
||||
"""
|
||||
Recover soft-deleted versions.
|
||||
|
||||
@@ -367,7 +369,7 @@ class HashiCorpVaultStorage(CredentialStorage):
|
||||
logger.error(f"Undelete failed for '{credential_id}': {e}")
|
||||
return False
|
||||
|
||||
def load_version(self, credential_id: str, version: int) -> Optional[CredentialObject]:
|
||||
def load_version(self, credential_id: str, version: int) -> CredentialObject | None:
|
||||
"""
|
||||
Load a specific version of a credential.
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ Usage:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .base import CredentialError, CredentialSpec
|
||||
|
||||
@@ -55,8 +55,8 @@ class CredentialStoreAdapter:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
store: "CredentialStore",
|
||||
specs: Optional[Dict[str, CredentialSpec]] = None,
|
||||
store: CredentialStore,
|
||||
specs: dict[str, CredentialSpec] | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the adapter.
|
||||
@@ -74,8 +74,8 @@ class CredentialStoreAdapter:
|
||||
self._specs = specs
|
||||
|
||||
# Build reverse mappings for validation
|
||||
self._tool_to_cred: Dict[str, str] = {}
|
||||
self._node_type_to_cred: Dict[str, str] = {}
|
||||
self._tool_to_cred: dict[str, str] = {}
|
||||
self._node_type_to_cred: dict[str, str] = {}
|
||||
|
||||
for cred_name, spec in self._specs.items():
|
||||
for tool_name in spec.tools:
|
||||
@@ -85,7 +85,7 @@ class CredentialStoreAdapter:
|
||||
|
||||
# --- Existing CredentialManager API ---
|
||||
|
||||
def get(self, name: str) -> Optional[str]:
|
||||
def get(self, name: str) -> str | None:
|
||||
"""
|
||||
Get a credential value by logical name.
|
||||
|
||||
@@ -117,7 +117,7 @@ class CredentialStoreAdapter:
|
||||
value = self._store.get(name)
|
||||
return value is not None and value != ""
|
||||
|
||||
def get_credential_for_tool(self, tool_name: str) -> Optional[str]:
|
||||
def get_credential_for_tool(self, tool_name: str) -> str | None:
|
||||
"""
|
||||
Get the credential name required by a tool.
|
||||
|
||||
@@ -129,7 +129,7 @@ class CredentialStoreAdapter:
|
||||
"""
|
||||
return self._tool_to_cred.get(tool_name)
|
||||
|
||||
def get_missing_for_tools(self, tool_names: List[str]) -> List[Tuple[str, CredentialSpec]]:
|
||||
def get_missing_for_tools(self, tool_names: list[str]) -> list[tuple[str, CredentialSpec]]:
|
||||
"""
|
||||
Get list of missing credentials for the given tools.
|
||||
|
||||
@@ -139,7 +139,7 @@ class CredentialStoreAdapter:
|
||||
Returns:
|
||||
List of (credential_name, spec) tuples for missing credentials
|
||||
"""
|
||||
missing: List[Tuple[str, CredentialSpec]] = []
|
||||
missing: list[tuple[str, CredentialSpec]] = []
|
||||
checked: set[str] = set()
|
||||
|
||||
for tool_name in tool_names:
|
||||
@@ -156,7 +156,7 @@ class CredentialStoreAdapter:
|
||||
|
||||
return missing
|
||||
|
||||
def validate_for_tools(self, tool_names: List[str]) -> None:
|
||||
def validate_for_tools(self, tool_names: list[str]) -> None:
|
||||
"""
|
||||
Validate that all credentials required by the given tools are available.
|
||||
|
||||
@@ -170,9 +170,9 @@ class CredentialStoreAdapter:
|
||||
if missing:
|
||||
raise CredentialError(self._format_missing_error(missing, tool_names))
|
||||
|
||||
def get_missing_for_node_types(self, node_types: List[str]) -> List[Tuple[str, CredentialSpec]]:
|
||||
def get_missing_for_node_types(self, node_types: list[str]) -> list[tuple[str, CredentialSpec]]:
|
||||
"""Get list of missing credentials for the given node types."""
|
||||
missing: List[Tuple[str, CredentialSpec]] = []
|
||||
missing: list[tuple[str, CredentialSpec]] = []
|
||||
checked: set[str] = set()
|
||||
|
||||
for node_type in node_types:
|
||||
@@ -189,7 +189,7 @@ class CredentialStoreAdapter:
|
||||
|
||||
return missing
|
||||
|
||||
def validate_for_node_types(self, node_types: List[str]) -> None:
|
||||
def validate_for_node_types(self, node_types: list[str]) -> None:
|
||||
"""
|
||||
Validate that all credentials required by the given node types are available.
|
||||
|
||||
@@ -210,7 +210,7 @@ class CredentialStoreAdapter:
|
||||
Raises:
|
||||
CredentialError: If any startup-required credentials are missing
|
||||
"""
|
||||
missing: List[Tuple[str, CredentialSpec]] = []
|
||||
missing: list[tuple[str, CredentialSpec]] = []
|
||||
|
||||
for cred_name, spec in self._specs.items():
|
||||
if spec.startup_required and not self.is_available(cred_name):
|
||||
@@ -221,7 +221,7 @@ class CredentialStoreAdapter:
|
||||
|
||||
# --- New CredentialStore Features ---
|
||||
|
||||
def get_key(self, credential_id: str, key_name: str) -> Optional[str]:
|
||||
def get_key(self, credential_id: str, key_name: str) -> str | None:
|
||||
"""
|
||||
Get a specific key from a multi-key credential.
|
||||
|
||||
@@ -250,7 +250,7 @@ class CredentialStoreAdapter:
|
||||
"""
|
||||
return self._store.resolve(template)
|
||||
|
||||
def resolve_headers(self, headers: Dict[str, str]) -> Dict[str, str]:
|
||||
def resolve_headers(self, headers: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
Resolve credential templates in headers dictionary.
|
||||
|
||||
@@ -268,12 +268,12 @@ class CredentialStoreAdapter:
|
||||
"""
|
||||
return self._store.resolve_headers(headers)
|
||||
|
||||
def resolve_params(self, params: Dict[str, str]) -> Dict[str, str]:
|
||||
def resolve_params(self, params: dict[str, str]) -> dict[str, str]:
|
||||
"""Resolve credential templates in query parameters."""
|
||||
return self._store.resolve_params(params)
|
||||
|
||||
@property
|
||||
def store(self) -> "CredentialStore":
|
||||
def store(self) -> CredentialStore:
|
||||
"""Access the underlying credential store for advanced operations."""
|
||||
return self._store
|
||||
|
||||
@@ -281,14 +281,14 @@ class CredentialStoreAdapter:
|
||||
|
||||
def _format_missing_error(
|
||||
self,
|
||||
missing: List[Tuple[str, CredentialSpec]],
|
||||
tool_names: List[str],
|
||||
missing: list[tuple[str, CredentialSpec]],
|
||||
tool_names: list[str],
|
||||
) -> str:
|
||||
"""Format a clear, actionable error message for missing credentials."""
|
||||
lines = ["Cannot run agent: Missing credentials\n"]
|
||||
lines.append("The following tools require credentials that are not set:\n")
|
||||
|
||||
for cred_name, spec in missing:
|
||||
for _cred_name, spec in missing:
|
||||
affected_tools = [t for t in tool_names if t in spec.tools]
|
||||
tools_str = ", ".join(affected_tools)
|
||||
|
||||
@@ -305,14 +305,14 @@ class CredentialStoreAdapter:
|
||||
|
||||
def _format_missing_node_type_error(
|
||||
self,
|
||||
missing: List[Tuple[str, CredentialSpec]],
|
||||
node_types: List[str],
|
||||
missing: list[tuple[str, CredentialSpec]],
|
||||
node_types: list[str],
|
||||
) -> str:
|
||||
"""Format a clear, actionable error message for missing node type credentials."""
|
||||
lines = ["Cannot run agent: Missing credentials\n"]
|
||||
lines.append("The following node types require credentials that are not set:\n")
|
||||
|
||||
for cred_name, spec in missing:
|
||||
for _cred_name, spec in missing:
|
||||
affected_types = [t for t in node_types if t in spec.node_types]
|
||||
types_str = ", ".join(affected_types)
|
||||
|
||||
@@ -329,12 +329,12 @@ class CredentialStoreAdapter:
|
||||
|
||||
def _format_startup_error(
|
||||
self,
|
||||
missing: List[Tuple[str, CredentialSpec]],
|
||||
missing: list[tuple[str, CredentialSpec]],
|
||||
) -> str:
|
||||
"""Format a clear, actionable error message for missing startup credentials."""
|
||||
lines = ["Server startup failed: Missing required credentials\n"]
|
||||
|
||||
for cred_name, spec in missing:
|
||||
for _cred_name, spec in missing:
|
||||
lines.append(f" {spec.env_var}")
|
||||
if spec.description:
|
||||
lines.append(f" {spec.description}")
|
||||
@@ -351,9 +351,9 @@ class CredentialStoreAdapter:
|
||||
@classmethod
|
||||
def for_testing(
|
||||
cls,
|
||||
overrides: Dict[str, str],
|
||||
specs: Optional[Dict[str, CredentialSpec]] = None,
|
||||
) -> "CredentialStoreAdapter":
|
||||
overrides: dict[str, str],
|
||||
specs: dict[str, CredentialSpec] | None = None,
|
||||
) -> CredentialStoreAdapter:
|
||||
"""
|
||||
Create a CredentialStoreAdapter for testing with mock credentials.
|
||||
|
||||
@@ -380,9 +380,9 @@ class CredentialStoreAdapter:
|
||||
@classmethod
|
||||
def with_env_storage(
|
||||
cls,
|
||||
env_mapping: Optional[Dict[str, str]] = None,
|
||||
specs: Optional[Dict[str, CredentialSpec]] = None,
|
||||
) -> "CredentialStoreAdapter":
|
||||
env_mapping: dict[str, str] | None = None,
|
||||
specs: dict[str, CredentialSpec] | None = None,
|
||||
) -> CredentialStoreAdapter:
|
||||
"""
|
||||
Create adapter with environment variable storage (current behavior).
|
||||
|
||||
|
||||
Reference in New Issue
Block a user