Merge branch 'main' into fix/identity-persistence
@@ -788,31 +788,38 @@ class AgentRunner:
                 extra_headers={"authorization": f"Bearer {api_key}"},
             )
         else:
-            # Fall back to environment variable
-            # First check api_key_env_var from config (set by quickstart)
-            api_key_env = llm_config.get("api_key_env_var") or self._get_api_key_env_var(
-                self.model
-            )
-            if api_key_env and os.environ.get(api_key_env):
+            # Local models (e.g. Ollama) don't need an API key
+            if self._is_local_model(self.model):
                 self._llm = LiteLLMProvider(
                     model=self.model,
-                    api_key=os.environ[api_key_env],
                     api_base=api_base,
                 )
             else:
-                # Fall back to credential store
-                api_key = self._get_api_key_from_credential_store()
-                if api_key:
+                # Fall back to environment variable
+                # First check api_key_env_var from config (set by quickstart)
+                api_key_env = llm_config.get("api_key_env_var") or self._get_api_key_env_var(
+                    self.model
+                )
+                if api_key_env and os.environ.get(api_key_env):
                     self._llm = LiteLLMProvider(
-                        model=self.model, api_key=api_key, api_base=api_base
+                        model=self.model,
+                        api_key=os.environ[api_key_env],
+                        api_base=api_base,
                     )
-                    # Set env var so downstream code (e.g. cleanup LLM in
-                    # node._extract_json) can also find it
-                    if api_key_env:
-                        os.environ[api_key_env] = api_key
-            elif api_key_env:
-                print(f"Warning: {api_key_env} not set. LLM calls will fail.")
-                print(f"Set it with: export {api_key_env}=your-api-key")
+                else:
+                    # Fall back to credential store
+                    api_key = self._get_api_key_from_credential_store()
+                    if api_key:
+                        self._llm = LiteLLMProvider(
+                            model=self.model, api_key=api_key, api_base=api_base
+                        )
+                        # Set env var so downstream code (e.g. cleanup LLM in
+                        # node._extract_json) can also find it
+                        if api_key_env:
+                            os.environ[api_key_env] = api_key
+                elif api_key_env:
+                    print(f"Warning: {api_key_env} not set. LLM calls will fail.")
+                    print(f"Set it with: export {api_key_env}=your-api-key")

         # Fail fast if the agent needs an LLM but none was configured
         if self._llm is None:
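The new branch gives the key lookup a clear precedence: local model first, then environment variable, then credential store. A minimal standalone sketch of that order for readers skimming the diff (the callable parameters here are hypothetical stand-ins for the instance methods above, not the actual API):

    import os
    from typing import Callable, Optional

    def resolve_api_key(
        model: str,
        api_key_env: Optional[str],
        is_local: Callable[[str], bool],
        store_lookup: Callable[[], Optional[str]],
    ) -> Optional[str]:
        """Precedence from the hunk above: local -> env var -> credential store."""
        if is_local(model):
            return None  # local servers (e.g. Ollama) need no API key
        if api_key_env and os.environ.get(api_key_env):
            return os.environ[api_key_env]
        api_key = store_lookup()
        if api_key and api_key_env:
            # Mirror the key into the environment so downstream code
            # (the cleanup LLM in node._extract_json) can also find it.
            os.environ[api_key_env] = api_key
        return api_key
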
@@ -820,6 +827,12 @@ class AgentRunner:
             if has_llm_nodes:
                 from framework.credentials.models import CredentialError

+                if self._is_local_model(self.model):
+                    raise CredentialError(
+                        f"Failed to initialize LLM for local model '{self.model}'. "
+                        f"Ensure your local LLM server is running "
+                        f"(e.g. 'ollama serve' for Ollama)."
+                    )
                 api_key_env = self._get_api_key_env_var(self.model)
                 hint = (
                     f"Set it with: export {api_key_env}=your-api-key"
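With this hunk, the fail-fast path distinguishes local from cloud models: a local model that fails to initialize now produces a "start your server" message instead of a misleading missing-API-key hint. A hedged sketch of how calling code might surface it (only the CredentialError import path comes from the hunk; the entry point is hypothetical):

    from framework.credentials.models import CredentialError  # import path from the hunk

    def start(runner) -> None:
        try:
            runner.run()  # hypothetical entry point, not from the diff
        except CredentialError as err:
            # For local models the message now points at the server
            # ("ollama serve") instead of a missing API key.
            print(f"LLM setup failed: {err}")
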
@@ -875,8 +888,8 @@ class AgentRunner:
             return "MISTRAL_API_KEY"
         elif model_lower.startswith("groq/"):
             return "GROQ_API_KEY"
-        elif model_lower.startswith("ollama/"):
-            return None  # Ollama doesn't need an API key (local)
+        elif self._is_local_model(model_lower):
+            return None  # Local models don't need an API key
         elif model_lower.startswith("azure/"):
             return "AZURE_API_KEY"
         elif model_lower.startswith("cohere/"):
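With `_is_local_model` short-circuiting the chain, the remaining branches are a pure prefix-to-env-var table. An equivalent data-driven sketch, using only the prefixes visible in this hunk (the real method covers more providers):

    from typing import Optional

    PREFIX_TO_ENV_VAR = {
        "mistral/": "MISTRAL_API_KEY",
        "groq/": "GROQ_API_KEY",
        "azure/": "AZURE_API_KEY",
    }

    def get_api_key_env_var(model: str) -> Optional[str]:
        model_lower = model.lower()
        # Stand-in for the _is_local_model() check; see the next hunk.
        if model_lower.startswith(("ollama/", "vllm/")):
            return None  # local models don't need an API key
        for prefix, env_var in PREFIX_TO_ENV_VAR.items():
            if model_lower.startswith(prefix):
                return env_var
        return None
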
@@ -916,6 +929,22 @@ class AgentRunner:
         except Exception:
             return None

+    @staticmethod
+    def _is_local_model(model: str) -> bool:
+        """Check if a model is a local model that doesn't require an API key.
+
+        Local providers like Ollama run on the user's machine and do not
+        need any authentication credentials.
+        """
+        LOCAL_PREFIXES = (
+            "ollama/",
+            "ollama_chat/",
+            "vllm/",
+            "lm_studio/",
+            "llamacpp/",
+        )
+        return model.lower().startswith(LOCAL_PREFIXES)
+
     def _setup_agent_runtime(
         self,
         tools: list,
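The one-liner works because str.startswith() accepts a tuple of prefixes and returns True if any of them match; combined with lower(), the check is case-insensitive. A standalone copy you can run directly:

    LOCAL_PREFIXES = ("ollama/", "ollama_chat/", "vllm/", "lm_studio/", "llamacpp/")

    def is_local_model(model: str) -> bool:
        # str.startswith() accepts a tuple and returns True if any prefix
        # matches; lower() makes "Ollama/Llama3" count as local too.
        return model.lower().startswith(LOCAL_PREFIXES)

    assert is_local_model("Ollama/Llama3")
    assert not is_local_model("openai/gpt-4o")
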
@@ -153,6 +153,7 @@ class AgentRuntime:
         self._config = config or AgentRuntimeConfig()
         self._runtime_log_store = runtime_log_store
         self._checkpoint_config = checkpoint_config
+        self.accounts_prompt = accounts_prompt

         # Primary graph identity
         self._graph_id: str = graph_id or "primary"
@@ -826,3 +826,52 @@ class TestAsyncComplete:
         assert call_thread_ids[0] != main_thread_id, (
             "Base acomplete() should offload sync complete() to a thread pool"
         )
+
+
+# ---------------------------------------------------------------------------
+# AgentRunner._is_local_model — parameterized tests
+# ---------------------------------------------------------------------------
+
+
+class TestIsLocalModel:
+    """Parameterized tests for AgentRunner._is_local_model()."""
+
+    @pytest.mark.parametrize(
+        "model",
+        [
+            "ollama/llama3",
+            "ollama/mistral",
+            "ollama_chat/llama3",
+            "vllm/mistral",
+            "lm_studio/phi3",
+            "llamacpp/llama-7b",
+            "Ollama/Llama3",  # case-insensitive
+            "VLLM/Mistral",
+        ],
+    )
+    def test_local_models_return_true(self, model):
+        """Local model prefixes should be recognized."""
+        from framework.runner.runner import AgentRunner
+
+        assert AgentRunner._is_local_model(model) is True
+
+    @pytest.mark.parametrize(
+        "model",
+        [
+            "anthropic/claude-3-haiku",
+            "openai/gpt-4o",
+            "gpt-4o-mini",
+            "claude-3-haiku-20240307",
+            "gemini/gemini-1.5-flash",
+            "groq/llama3-70b",
+            "mistral/mistral-large",
+            "azure/gpt-4",
+            "cohere/command-r",
+            "together/llama3-70b",
+        ],
+    )
+    def test_cloud_models_return_false(self, model):
+        """Cloud model prefixes should not be treated as local."""
+        from framework.runner.runner import AgentRunner
+
+        assert AgentRunner._is_local_model(model) is False
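Each entry in the @pytest.mark.parametrize lists above expands into its own test case. The same matrix can also be spot-checked without pytest, e.g. (illustrative only; the model lists are abbreviated from the tests above):

    from framework.runner.runner import AgentRunner

    local = ["ollama/llama3", "ollama_chat/llama3", "vllm/mistral", "Ollama/Llama3"]
    cloud = ["openai/gpt-4o", "gpt-4o-mini", "anthropic/claude-3-haiku"]

    for m in local:
        assert AgentRunner._is_local_model(m) is True, m
    for m in cloud:
        assert AgentRunner._is_local_model(m) is False, m
    print("all _is_local_model spot checks passed")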