feat(vision): add GCP Vision API integration (#4231)
* feat(vision): add GCP Vision API integration * refactor(vision): move GCP Vision credentials to dedicated folder * fix: clean up credentials imports and updated gitignore * followed ruff alphabetic order for credentials
This commit is contained in:
+3
-1
@@ -74,4 +74,6 @@ exports/*
|
||||
|
||||
docs/github-issues/*
|
||||
core/tests/*dumps/*
|
||||
screenshots/*
|
||||
|
||||
screenshots/*
|
||||
|
||||
|
||||
@@ -54,6 +54,7 @@ from .apollo import APOLLO_CREDENTIALS
|
||||
from .base import CredentialError, CredentialSpec
|
||||
from .browser import get_aden_auth_url, get_aden_setup_url, open_browser
|
||||
from .email import EMAIL_CREDENTIALS
|
||||
from .gcp_vision import GCP_VISION_CREDENTIALS
|
||||
from .github import GITHUB_CREDENTIALS
|
||||
from .health_check import HealthCheckResult, check_credential_health
|
||||
from .hubspot import HUBSPOT_CREDENTIALS
|
||||
@@ -74,6 +75,7 @@ CREDENTIAL_SPECS = {
|
||||
**LLM_CREDENTIALS,
|
||||
**SEARCH_CREDENTIALS,
|
||||
**EMAIL_CREDENTIALS,
|
||||
**GCP_VISION_CREDENTIALS,
|
||||
**APOLLO_CREDENTIALS,
|
||||
**GITHUB_CREDENTIALS,
|
||||
**HUBSPOT_CREDENTIALS,
|
||||
@@ -106,6 +108,7 @@ __all__ = [
|
||||
"LLM_CREDENTIALS",
|
||||
"SEARCH_CREDENTIALS",
|
||||
"EMAIL_CREDENTIALS",
|
||||
"GCP_VISION_CREDENTIALS",
|
||||
"GITHUB_CREDENTIALS",
|
||||
"HUBSPOT_CREDENTIALS",
|
||||
"SLACK_CREDENTIALS",
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
GCP Vision tool credentials.
|
||||
|
||||
Contains credentials for Google Cloud Vision API integration.
|
||||
"""
|
||||
|
||||
from .base import CredentialSpec
|
||||
|
||||
# Credential registry entry for the Google Cloud Vision tools.
# Maps the credential name "google_vision" to its CredentialSpec.
GCP_VISION_CREDENTIALS = dict(
    google_vision=CredentialSpec(
        # Environment variable that carries the API key.
        env_var="GOOGLE_CLOUD_VISION_API_KEY",
        # Every tool name that needs this credential.
        tools=[
            "vision_detect_labels",
            "vision_detect_text",
            "vision_detect_faces",
            "vision_localize_objects",
            "vision_detect_logos",
            "vision_detect_landmarks",
            "vision_image_properties",
            "vision_web_detection",
            "vision_safe_search",
        ],
        required=True,
        startup_required=False,
        help_url="https://console.cloud.google.com/apis/credentials",
        description="Google Cloud Vision API key for image analysis",
        # Authentication: direct API key only, no Aden-managed provider.
        aden_supported=False,
        aden_provider_name="",
        direct_api_key_supported=True,
        api_key_instructions="""To get a Google Cloud Vision API key:
1. Go to Google Cloud Console (console.cloud.google.com)
2. Create a new project or select existing
3. Go to APIs & Services > Library
4. Search for "Cloud Vision API" and enable it
5. Go to APIs & Services > Credentials
6. Click "Create Credentials" > "API Key"
7. Copy the API key""",
        # Health check: no endpoint is configured for this credential.
        health_check_endpoint="",
        health_check_method="GET",
        # Location of the secret inside the credential store.
        credential_id="google_vision",
        credential_key="api_key",
    ),
)
|
||||
@@ -46,6 +46,7 @@ from .pdf_read_tool import register_tools as register_pdf_read
|
||||
from .runtime_logs_tool import register_tools as register_runtime_logs
|
||||
from .serpapi_tool import register_tools as register_serpapi
|
||||
from .slack_tool import register_tools as register_slack
|
||||
from .vision_tool import register_tools as register_vision
|
||||
from .web_scrape_tool import register_tools as register_web_scrape
|
||||
from .web_search_tool import register_tools as register_web_search
|
||||
|
||||
@@ -81,6 +82,7 @@ def register_all_tools(
|
||||
register_apollo(mcp, credentials=credentials)
|
||||
register_serpapi(mcp, credentials=credentials)
|
||||
register_slack(mcp, credentials=credentials)
|
||||
register_vision(mcp, credentials=credentials)
|
||||
|
||||
# Register file system toolkits
|
||||
register_view_file(mcp)
|
||||
@@ -219,6 +221,16 @@ def register_all_tools(
|
||||
"slack_kick_user_from_channel",
|
||||
"slack_delete_file",
|
||||
"slack_get_team_stats",
|
||||
# Vision tools
|
||||
"vision_detect_labels",
|
||||
"vision_detect_text",
|
||||
"vision_detect_faces",
|
||||
"vision_localize_objects",
|
||||
"vision_detect_logos",
|
||||
"vision_detect_landmarks",
|
||||
"vision_image_properties",
|
||||
"vision_web_detection",
|
||||
"vision_safe_search",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,149 @@
|
||||
# Google Cloud Vision Tool
|
||||
|
||||
Image analysis tool using Google Cloud Vision API.
|
||||
|
||||
## Features
|
||||
|
||||
| Tool | Description |
|
||||
|------|-------------|
|
||||
| `vision_detect_labels` | Identify objects, scenes, activities |
|
||||
| `vision_detect_text` | Extract text from images (OCR) |
|
||||
| `vision_detect_faces` | Detect faces and emotions |
|
||||
| `vision_localize_objects` | Detect objects with bounding boxes |
|
||||
| `vision_detect_logos` | Identify brand logos |
|
||||
| `vision_detect_landmarks` | Identify famous places |
|
||||
| `vision_image_properties` | Get dominant colors and crop hints |
|
||||
| `vision_web_detection` | Find similar images online |
|
||||
| `vision_safe_search` | Detect inappropriate content |
|
||||
|
||||
## Setup
|
||||
|
||||
### 1. Get API Key
|
||||
|
||||
1. Go to [Google Cloud Console](https://console.cloud.google.com)
|
||||
2. Create a new project or select an existing one
|
||||
3. Go to **APIs & Services > Library**
|
||||
4. Search for "Cloud Vision API" and enable it
|
||||
5. Go to **APIs & Services > Credentials**
|
||||
6. Click **Create Credentials > API Key**
|
||||
7. Copy the API key
|
||||
|
||||
### 2. Set Environment Variable
|
||||
|
||||
```bash
|
||||
export GOOGLE_CLOUD_VISION_API_KEY=your_api_key
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Label Detection
|
||||
|
||||
```python
|
||||
result = vision_detect_labels(
|
||||
image_source="https://example.com/photo.jpg",
|
||||
max_labels=5
|
||||
)
|
||||
# {"labels": [{"description": "Dog", "score": 0.97}, ...]}
|
||||
```
|
||||
|
||||
### Text Detection (OCR)
|
||||
|
||||
```python
|
||||
result = vision_detect_text(image_source="/path/to/receipt.jpg")
|
||||
# {"text": "Store: Amazon\nTotal: $49.99", "blocks": [...]}
|
||||
```
|
||||
|
||||
### Face Detection
|
||||
|
||||
```python
|
||||
result = vision_detect_faces(image_source="https://example.com/group.jpg")
|
||||
# {"faces": [{"joy": "VERY_LIKELY", "anger": "VERY_UNLIKELY", ...}]}
|
||||
```
|
||||
|
||||
### Object Localization
|
||||
|
||||
```python
|
||||
result = vision_localize_objects(image_source="/path/to/image.jpg")
|
||||
# {"objects": [{"name": "Cat", "score": 0.92, "bounds": [...]}]}
|
||||
```
|
||||
|
||||
### Logo Detection
|
||||
|
||||
```python
|
||||
result = vision_detect_logos(image_source="https://example.com/product.jpg")
|
||||
# {"logos": [{"description": "Nike", "score": 0.95}]}
|
||||
```
|
||||
|
||||
### Landmark Detection
|
||||
|
||||
```python
|
||||
result = vision_detect_landmarks(image_source="/path/to/travel.jpg")
|
||||
# {"landmarks": [{"description": "Eiffel Tower", "location": {"latitude": 48.85, "longitude": 2.29}}]}
|
||||
```
|
||||
|
||||
### Image Properties
|
||||
|
||||
```python
|
||||
result = vision_image_properties(image_source="https://example.com/art.jpg")
|
||||
# {"colors": [{"red": 255, "green": 128, "blue": 0, "score": 0.5}], "crop_hints": [...]}
|
||||
```
|
||||
|
||||
### Web Detection
|
||||
|
||||
```python
|
||||
result = vision_web_detection(image_source="/path/to/image.jpg")
|
||||
# {"web_entities": [...], "similar_images": [...], "pages_with_image": [...]}
|
||||
```
|
||||
|
||||
### Safe Search
|
||||
|
||||
```python
|
||||
result = vision_safe_search(image_source="https://example.com/upload.jpg")
|
||||
# {"adult": "VERY_UNLIKELY", "violence": "VERY_UNLIKELY", "racy": "POSSIBLE", ...}
|
||||
```
|
||||
|
||||
## Input Types
|
||||
|
||||
| Type | Example |
|
||||
|------|---------|
|
||||
| URL | `https://example.com/image.jpg` |
|
||||
| Local file | `/path/to/image.jpg` |
|
||||
|
||||
**Supported formats:** JPEG, PNG, GIF, BMP, WEBP, ICO
|
||||
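**Max file size:** 10MB (checked for local files only; URLs are passed to the Vision API as-is, so they must be reachable by Google)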
|
||||
|
||||
## Error Handling
|
||||
|
||||
```python
|
||||
# File not found
|
||||
{"error": "File not found: /path/to/missing.jpg"}
|
||||
|
||||
# File too large
|
||||
{"error": "File exceeds 10MB limit (12.5MB)"}
|
||||
|
||||
# Missing credentials
|
||||
{"error": "GOOGLE_CLOUD_VISION_API_KEY not configured", "help": "..."}
|
||||
|
||||
# API errors
|
||||
{"error": "Invalid API key"}
|
||||
{"error": "Rate limit exceeded. Try again later."}
|
||||
```
|
||||
|
||||
## Pricing
|
||||
|
||||
- **First 1,000 units/month (per feature):** Free
- **After:** ~$1.50 per 1,000 units for most features (one unit = one feature applied to one image; some features cost more)
|
||||
|
||||
See [Cloud Vision Pricing](https://cloud.google.com/vision/pricing) for details.
|
||||
|
||||
## Likelihood Values
|
||||
|
||||
Face detection and safe search return likelihood values:
|
||||
|
||||
| Value | Meaning |
|
||||
|-------|---------|
|
||||
| `VERY_UNLIKELY` | Very unlikely |
|
||||
| `UNLIKELY` | Unlikely |
|
||||
| `POSSIBLE` | Possible |
|
||||
| `LIKELY` | Likely |
|
||||
| `VERY_LIKELY` | Very likely |
|
||||
@@ -0,0 +1,5 @@
|
||||
"""Google Cloud Vision tool for image analysis."""
|
||||
|
||||
from .vision_tool import register_tools
|
||||
|
||||
__all__ = ["register_tools"]
|
||||
@@ -0,0 +1,536 @@
|
||||
"""
|
||||
Google Cloud Vision Tool - Image analysis using Google Cloud Vision API.
|
||||
|
||||
Supports:
|
||||
- Label detection (objects, scenes, activities)
|
||||
- Text detection (OCR)
|
||||
- Face detection (emotions)
|
||||
- Object localization (bounding boxes)
|
||||
- Logo detection
|
||||
- Landmark detection
|
||||
- Image properties (colors, crop hints)
|
||||
- Web detection (similar images)
|
||||
- Safe search (content moderation)
|
||||
|
||||
API Reference: https://cloud.google.com/vision/docs
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import httpx
|
||||
from fastmcp import FastMCP
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from aden_tools.credentials import CredentialStoreAdapter
|
||||
|
||||
VISION_API_URL = "https://vision.googleapis.com/v1/images:annotate"
|
||||
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
|
||||
|
||||
|
||||
class _VisionClient:
    """Internal client for the Google Cloud Vision ``images:annotate`` endpoint.

    Each public method submits one image together with one or more feature
    requests and returns a plain dict: a simplified result on success, or
    ``{"error": "<message>"}`` on failure. Expected failures (missing or
    oversized file, network error, API error) are reported through the error
    dict instead of being raised.
    """

    # Fixed messages for the HTTP status codes that get a dedicated hint.
    _STATUS_ERRORS = {
        400: "Invalid request. Check image format and size.",
        401: "Invalid API key",
        403: "API key not authorized. Enable Vision API in Google Cloud Console.",
        429: "Rate limit exceeded. Try again later.",
    }

    def __init__(self, api_key: str):
        """Remember the API key; it is sent as the ``key`` query parameter."""
        self._api_key = api_key

    def _load_image(self, image_source: str) -> dict[str, Any] | dict[str, str]:
        """
        Load image from URL or local file.

        Args:
            image_source: An ``http(s)://`` URL or a local file path.

        Returns:
            Image dict for API request (``{"source": {"imageUri": ...}}`` for
            URLs, ``{"content": <base64>}`` for files), or error dict if failed.
        """
        # URLs are forwarded untouched; nothing is downloaded locally.
        if image_source.startswith(("http://", "https://")):
            return {"source": {"imageUri": image_source}}

        file_path = Path(image_source)
        try:
            if not file_path.exists():
                return {"error": f"File not found: {image_source}"}
            if not file_path.is_file():
                return {"error": f"Not a file: {image_source}"}

            # stat() lives inside the try: the file may disappear or become
            # unreadable between the existence check and this call.
            file_size = file_path.stat().st_size
            if file_size > MAX_FILE_SIZE:
                size_mb = file_size / (1024 * 1024)
                return {"error": f"File exceeds 10MB limit ({size_mb:.1f}MB)"}

            content = file_path.read_bytes()
        except (OSError, ValueError) as e:
            return {"error": f"Failed to read file: {e}"}

        return {"content": base64.b64encode(content).decode("utf-8")}

    def _call_api(
        self, image_data: dict[str, Any], features: list[dict[str, Any]]
    ) -> dict[str, Any]:
        """POST a single-image annotate request and return the parsed result.

        Args:
            image_data: Image dict produced by ``_load_image``.
            features: Feature request dicts, e.g. ``{"type": "LABEL_DETECTION"}``.

        Returns:
            The first entry of the API ``responses`` list, or an error dict.
        """
        try:
            response = httpx.post(
                VISION_API_URL,
                params={"key": self._api_key},
                json={"requests": [{"image": image_data, "features": features}]},
                timeout=30.0,
            )
            return self._handle_response(response)
        except httpx.TimeoutException:
            return {"error": "Request timed out"}
        except httpx.RequestError as e:
            return {"error": f"Network error: {e}"}

    def _handle_response(self, response: httpx.Response) -> dict[str, Any]:
        """Translate an HTTP response into a result dict or an error dict."""
        status = response.status_code
        if status in self._STATUS_ERRORS:
            return {"error": self._STATUS_ERRORS[status]}
        if status != 200:
            return {"error": f"Vision API error (HTTP {status})"}

        # A 200 with a non-JSON body (e.g. a proxy error page) must not raise.
        try:
            data = response.json()
        except ValueError:
            return {"error": "Invalid JSON in Vision API response"}
        if not isinstance(data, dict):
            return {"error": "Unexpected response format from API"}

        responses = data.get("responses", [])
        if not responses:
            return {"error": "Empty response from API"}

        # Exactly one image is sent per request, so only the first entry matters.
        result = responses[0]
        if "error" in result:
            return {"error": result["error"].get("message", "Unknown API error")}

        return result

    def _annotate(self, image_source: str, features: list[dict[str, Any]]) -> dict[str, Any]:
        """Load ``image_source`` and run ``features`` on it; return result or error dict."""
        image_data = self._load_image(image_source)
        if "error" in image_data:
            return image_data
        return self._call_api(image_data, features)

    def detect_labels(self, image_source: str, max_results: int = 10) -> dict[str, Any]:
        """Detect labels in image.

        Returns:
            ``{"labels": [{"description", "score"}, ...]}`` or an error dict.
        """
        result = self._annotate(
            image_source, [{"type": "LABEL_DETECTION", "maxResults": max_results}]
        )
        if "error" in result:
            return result

        # Use .get like the sibling methods so an omitted field does not raise.
        labels = [
            {
                "description": label.get("description", ""),
                "score": round(label.get("score", 0), 3),
            }
            for label in result.get("labelAnnotations", [])
        ]
        return {"labels": labels}

    def detect_text(self, image_source: str) -> dict[str, Any]:
        """Detect text in image (OCR).

        Returns:
            ``{"text": <full text>, "blocks": [{"text", "bounds"}, ...]}`` or
            an error dict.
        """
        result = self._annotate(image_source, [{"type": "TEXT_DETECTION"}])
        if "error" in result:
            return result

        annotations = result.get("textAnnotations", [])
        if not annotations:
            return {"text": "", "blocks": []}

        # First annotation is full text; the rest are the individual pieces.
        full_text = annotations[0].get("description", "")
        blocks = [
            {
                "text": ann.get("description", ""),
                "bounds": ann.get("boundingPoly", {}).get("vertices", []),
            }
            for ann in annotations[1:]
        ]
        return {"text": full_text, "blocks": blocks}

    def detect_faces(self, image_source: str, max_results: int = 10) -> dict[str, Any]:
        """Detect faces and emotions in image.

        Returns:
            ``{"faces": [...]}`` with likelihood strings per emotion, a
            detection confidence and bounding vertices, or an error dict.
        """
        result = self._annotate(
            image_source, [{"type": "FACE_DETECTION", "maxResults": max_results}]
        )
        if "error" in result:
            return result

        faces = [
            {
                "joy": face.get("joyLikelihood", "UNKNOWN"),
                "sorrow": face.get("sorrowLikelihood", "UNKNOWN"),
                "anger": face.get("angerLikelihood", "UNKNOWN"),
                "surprise": face.get("surpriseLikelihood", "UNKNOWN"),
                "confidence": round(face.get("detectionConfidence", 0), 3),
                "bounds": face.get("boundingPoly", {}).get("vertices", []),
            }
            for face in result.get("faceAnnotations", [])
        ]
        return {"faces": faces}

    def localize_objects(self, image_source: str, max_results: int = 10) -> dict[str, Any]:
        """Detect objects with bounding boxes.

        Returns:
            ``{"objects": [{"name", "score", "bounds"}, ...]}`` where bounds
            are the API's ``normalizedVertices``, or an error dict.
        """
        result = self._annotate(
            image_source, [{"type": "OBJECT_LOCALIZATION", "maxResults": max_results}]
        )
        if "error" in result:
            return result

        objects = [
            {
                "name": obj.get("name", ""),
                "score": round(obj.get("score", 0), 3),
                "bounds": obj.get("boundingPoly", {}).get("normalizedVertices", []),
            }
            for obj in result.get("localizedObjectAnnotations", [])
        ]
        return {"objects": objects}

    def detect_logos(self, image_source: str, max_results: int = 5) -> dict[str, Any]:
        """Detect logos in image.

        Returns:
            ``{"logos": [{"description", "score"}, ...]}`` or an error dict.
        """
        result = self._annotate(
            image_source, [{"type": "LOGO_DETECTION", "maxResults": max_results}]
        )
        if "error" in result:
            return result

        logos = [
            {
                "description": logo.get("description", ""),
                "score": round(logo.get("score", 0), 3),
            }
            for logo in result.get("logoAnnotations", [])
        ]
        return {"logos": logos}

    def detect_landmarks(self, image_source: str, max_results: int = 5) -> dict[str, Any]:
        """Detect landmarks in image.

        Returns:
            ``{"landmarks": [{"description", "score", "location"}, ...]}`` or
            an error dict. ``location`` is ``{}`` when the API gives none.
        """
        result = self._annotate(
            image_source, [{"type": "LANDMARK_DETECTION", "maxResults": max_results}]
        )
        if "error" in result:
            return result

        landmarks = []
        for lm in result.get("landmarkAnnotations", []):
            location = {}
            locations = lm.get("locations", [])
            if locations:
                # Only the first reported location is surfaced.
                lat_lng = locations[0].get("latLng", {})
                location = {
                    "latitude": lat_lng.get("latitude"),
                    "longitude": lat_lng.get("longitude"),
                }
            landmarks.append(
                {
                    "description": lm.get("description", ""),
                    "score": round(lm.get("score", 0), 3),
                    "location": location,
                }
            )
        return {"landmarks": landmarks}

    def get_image_properties(self, image_source: str) -> dict[str, Any]:
        """Get image properties (colors, crop hints).

        Returns:
            ``{"colors": [...], "crop_hints": [...]}`` with at most five
            dominant colors, or an error dict.
        """
        result = self._annotate(
            image_source,
            [{"type": "IMAGE_PROPERTIES"}, {"type": "CROP_HINTS"}],
        )
        if "error" in result:
            return result

        color_info = result.get("imagePropertiesAnnotation", {})
        dominant_colors = color_info.get("dominantColors", {}).get("colors", [])
        colors = []
        for color in dominant_colors[:5]:
            rgb = color.get("color", {})
            colors.append(
                {
                    "red": int(rgb.get("red", 0)),
                    "green": int(rgb.get("green", 0)),
                    "blue": int(rgb.get("blue", 0)),
                    "score": round(color.get("score", 0), 3),
                    "pixel_fraction": round(color.get("pixelFraction", 0), 3),
                }
            )

        hints_annotation = result.get("cropHintsAnnotation", {})
        crop_hints = [
            {
                "bounds": hint.get("boundingPoly", {}).get("vertices", []),
                "confidence": round(hint.get("confidence", 0), 3),
            }
            for hint in hints_annotation.get("cropHints", [])
        ]

        return {"colors": colors, "crop_hints": crop_hints}

    def web_detection(self, image_source: str) -> dict[str, Any]:
        """Find similar images and web references.

        Returns:
            ``{"web_entities", "similar_images", "pages_with_image"}`` capped
            at 10, 5 and 5 entries respectively, or an error dict.
        """
        result = self._annotate(image_source, [{"type": "WEB_DETECTION"}])
        if "error" in result:
            return result

        web = result.get("webDetection", {})

        web_entities = [
            {
                "description": entity.get("description", ""),
                "score": round(entity.get("score", 0), 3),
            }
            for entity in web.get("webEntities", [])[:10]
        ]

        similar_images = [img.get("url", "") for img in web.get("visuallySimilarImages", [])[:5]]

        pages_with_image = [
            {"url": page.get("url", ""), "title": page.get("pageTitle", "")}
            for page in web.get("pagesWithMatchingImages", [])[:5]
        ]

        return {
            "web_entities": web_entities,
            "similar_images": similar_images,
            "pages_with_image": pages_with_image,
        }

    def safe_search(self, image_source: str) -> dict[str, Any]:
        """Detect inappropriate content.

        Returns:
            Likelihood strings for ``adult``, ``spoof``, ``medical``,
            ``violence`` and ``racy`` (``"UNKNOWN"`` when absent), or an
            error dict.
        """
        result = self._annotate(image_source, [{"type": "SAFE_SEARCH_DETECTION"}])
        if "error" in result:
            return result

        safe = result.get("safeSearchAnnotation", {})
        return {
            category: safe.get(category, "UNKNOWN")
            for category in ("adult", "spoof", "medical", "violence", "racy")
        }
|
||||
|
||||
|
||||
def register_tools(
    mcp: FastMCP,
    credentials: CredentialStoreAdapter | None = None,
) -> None:
    """Attach the nine ``vision_*`` tools to ``mcp``.

    Args:
        mcp: Server object whose ``tool()`` decorator registers each tool.
        credentials: Optional credential store. When given, the API key is
            read from it under ``"google_vision"``; otherwise the
            ``GOOGLE_CLOUD_VISION_API_KEY`` environment variable is used.
    """

    def _get_api_key() -> str | None:
        """Resolve the API key from the credential store, else the environment."""
        if credentials is None:
            return os.getenv("GOOGLE_CLOUD_VISION_API_KEY")
        return credentials.get("google_vision")

    def _get_client() -> _VisionClient | dict[str, str]:
        """Build a client, or return an error dict when no key is available."""
        api_key = _get_api_key()
        if api_key:
            return _VisionClient(api_key)
        return {
            "error": "GOOGLE_CLOUD_VISION_API_KEY not configured",
            "help": "Get an API key at https://console.cloud.google.com/apis/credentials",
        }

    def _clamp(requested: int, ceiling: int) -> int:
        """Force ``requested`` into the inclusive range ``1..ceiling``."""
        return min(max(1, requested), ceiling)

    @mcp.tool()
    def vision_detect_labels(image_source: str, max_labels: int = 10) -> dict:
        """
        Detect labels (objects, scenes, activities) in an image.

        Args:
            image_source: URL or local file path to the image
            max_labels: Maximum number of labels to return (1-100, default 10)

        Returns:
            Dict with labels and confidence scores, or error dict
        """
        client = _get_client()
        if not isinstance(client, dict):
            return client.detect_labels(image_source, _clamp(max_labels, 100))
        return client

    @mcp.tool()
    def vision_detect_text(image_source: str) -> dict:
        """
        Extract text from an image (OCR).

        Args:
            image_source: URL or local file path to the image

        Returns:
            Dict with extracted text and text blocks with positions, or error dict
        """
        client = _get_client()
        if not isinstance(client, dict):
            return client.detect_text(image_source)
        return client

    @mcp.tool()
    def vision_detect_faces(image_source: str, max_faces: int = 10) -> dict:
        """
        Detect faces and emotions in an image.

        Args:
            image_source: URL or local file path to the image
            max_faces: Maximum number of faces to detect (1-100, default 10)

        Returns:
            Dict with faces including emotions (joy, sorrow, anger, surprise), or error dict
        """
        client = _get_client()
        if not isinstance(client, dict):
            return client.detect_faces(image_source, _clamp(max_faces, 100))
        return client

    @mcp.tool()
    def vision_localize_objects(image_source: str, max_objects: int = 10) -> dict:
        """
        Detect objects with bounding box coordinates in an image.

        Args:
            image_source: URL or local file path to the image
            max_objects: Maximum number of objects to detect (1-100, default 10)

        Returns:
            Dict with objects including names, scores, and normalized bounding boxes, or error dict
        """
        client = _get_client()
        if not isinstance(client, dict):
            return client.localize_objects(image_source, _clamp(max_objects, 100))
        return client

    @mcp.tool()
    def vision_detect_logos(image_source: str, max_logos: int = 5) -> dict:
        """
        Detect brand logos in an image.

        Args:
            image_source: URL or local file path to the image
            max_logos: Maximum number of logos to detect (1-20, default 5)

        Returns:
            Dict with detected logos and confidence scores, or error dict
        """
        client = _get_client()
        if not isinstance(client, dict):
            return client.detect_logos(image_source, _clamp(max_logos, 20))
        return client

    @mcp.tool()
    def vision_detect_landmarks(image_source: str, max_landmarks: int = 5) -> dict:
        """
        Detect famous landmarks in an image.

        Args:
            image_source: URL or local file path to the image
            max_landmarks: Maximum number of landmarks to detect (1-20, default 5)

        Returns:
            Dict with landmarks including names, scores, and GPS coordinates, or error dict
        """
        client = _get_client()
        if not isinstance(client, dict):
            return client.detect_landmarks(image_source, _clamp(max_landmarks, 20))
        return client

    @mcp.tool()
    def vision_image_properties(image_source: str) -> dict:
        """
        Get image properties including dominant colors and crop hints.

        Args:
            image_source: URL or local file path to the image

        Returns:
            Dict with dominant colors (RGB, score) and crop hints, or error dict
        """
        client = _get_client()
        if not isinstance(client, dict):
            return client.get_image_properties(image_source)
        return client

    @mcp.tool()
    def vision_web_detection(image_source: str) -> dict:
        """
        Find similar images and web references for an image.

        Args:
            image_source: URL or local file path to the image

        Returns:
            Dict with web entities, similar images, and pages containing the image
        """
        client = _get_client()
        if not isinstance(client, dict):
            return client.web_detection(image_source)
        return client

    @mcp.tool()
    def vision_safe_search(image_source: str) -> dict:
        """
        Detect inappropriate content in an image.

        Checks for: adult, spoof, medical, violence, racy content.
        Each category returns a likelihood: VERY_UNLIKELY, UNLIKELY, POSSIBLE, LIKELY, VERY_LIKELY.

        Args:
            image_source: URL or local file path to the image

        Returns:
            Dict with likelihood ratings for each category, or error dict
        """
        client = _get_client()
        if not isinstance(client, dict):
            return client.safe_search(image_source)
        return client
|
||||
@@ -0,0 +1,548 @@
|
||||
"""Tests for Google Cloud Vision tool."""
|
||||
|
||||
import base64
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from fastmcp import FastMCP
|
||||
|
||||
from aden_tools.tools.vision_tool import register_tools
|
||||
|
||||
|
||||
@pytest.fixture
def mcp() -> FastMCP:
    """Provide a new FastMCP server so each test registers its tools from scratch."""
    server = FastMCP("test-server")
    return server
|
||||
|
||||
|
||||
@pytest.fixture
def sample_image(tmp_path: Path) -> Path:
    """Write a tiny PNG into the pytest temp directory and return its path."""
    # Base64 payload of a minimal valid PNG (1x1 pixel).
    encoded_png = (
        "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
    )
    target = tmp_path / "test.png"
    target.write_bytes(base64.b64decode(encoded_png))
    return target
|
||||
|
||||
|
||||
@pytest.fixture
def large_file(tmp_path: Path) -> Path:
    """Return the path of an 11MB filler file, i.e. one over the 10MB limit."""
    eleven_mb = 11 * 1024 * 1024
    oversized = tmp_path / "large.png"
    oversized.write_bytes(b"x" * eleven_mb)
    return oversized
|
||||
|
||||
|
||||
# --- Credential Tests ---
|
||||
|
||||
|
||||
def test_missing_credentials(mcp: FastMCP):
    """With no credential store and an empty environment, the tool returns a config error."""
    register_tools(mcp, credentials=None)
    detect_labels = mcp._tool_manager._tools["vision_detect_labels"].fn

    # clear=True empties os.environ for the duration of the call.
    with patch.dict(os.environ, {}, clear=True):
        outcome = detect_labels(image_source="https://example.com/image.jpg")

    assert "error" in outcome
    assert "GOOGLE_CLOUD_VISION_API_KEY" in outcome["error"]
    assert "help" in outcome
|
||||
|
||||
|
||||
def test_credentials_from_env(mcp: FastMCP):
    """Test that credentials are retrieved from environment.

    Beyond checking that the call succeeds, this verifies that the key taken
    from ``GOOGLE_CLOUD_VISION_API_KEY`` is the one actually sent with the
    request, so a regression that reads the key from elsewhere is caught.
    """
    register_tools(mcp, credentials=None)
    tool_fn = mcp._tool_manager._tools["vision_detect_labels"].fn

    mock_response = {"responses": [{"labelAnnotations": []}]}

    with (
        patch.dict(os.environ, {"GOOGLE_CLOUD_VISION_API_KEY": "test-api-key"}),
        patch("httpx.post") as mock_post,
    ):
        mock_post.return_value = httpx.Response(200, json=mock_response)
        result = tool_fn(image_source="https://example.com/image.jpg")

    assert "labels" in result
    # Exactly one API request, carrying the key from the environment.
    mock_post.assert_called_once()
    assert mock_post.call_args.kwargs["params"] == {"key": "test-api-key"}
|
||||
|
||||
|
||||
# --- Image Loading Tests ---
|
||||
|
||||
|
||||
def test_file_not_found(mcp: FastMCP):
    """A local path that does not exist yields a "File not found" error."""
    register_tools(mcp, credentials=None)
    detect_labels = mcp._tool_manager._tools["vision_detect_labels"].fn

    env_with_key = {"GOOGLE_CLOUD_VISION_API_KEY": "test-api-key"}
    with patch.dict(os.environ, env_with_key):
        outcome = detect_labels(image_source="/nonexistent/path/image.jpg")

    assert "error" in outcome
    assert "File not found" in outcome["error"]
|
||||
|
||||
|
||||
def test_file_too_large(mcp: FastMCP, large_file: Path):
    """A local file above the 10MB limit is rejected with a size error."""
    register_tools(mcp, credentials=None)
    detect_labels = mcp._tool_manager._tools["vision_detect_labels"].fn

    env_with_key = {"GOOGLE_CLOUD_VISION_API_KEY": "test-api-key"}
    with patch.dict(os.environ, env_with_key):
        outcome = detect_labels(image_source=str(large_file))

    assert "error" in outcome
    assert "10MB" in outcome["error"]
|
||||
|
||||
|
||||
def test_directory_not_file(mcp: FastMCP, tmp_path: Path):
    """Passing a directory path produces a "Not a file" error."""
    register_tools(mcp, credentials=None)
    detect_labels = mcp._tool_manager._tools["vision_detect_labels"].fn

    env_with_key = {"GOOGLE_CLOUD_VISION_API_KEY": "test-api-key"}
    with patch.dict(os.environ, env_with_key):
        # tmp_path itself is a directory, which is what the tool must refuse.
        outcome = detect_labels(image_source=str(tmp_path))

    assert "error" in outcome
    assert "Not a file" in outcome["error"]
|
||||
|
||||
|
||||
# --- API Response Tests ---
|
||||
|
||||
|
||||
def test_detect_labels_success(mcp: FastMCP):
    """Label annotations from the API are returned as description/score pairs."""
    register_tools(mcp, credentials=None)
    detect_labels = mcp._tool_manager._tools["vision_detect_labels"].fn

    annotations = [
        {"description": "Dog", "score": 0.97},
        {"description": "Animal", "score": 0.95},
    ]
    api_payload = {"responses": [{"labelAnnotations": annotations}]}

    with (
        patch.dict(os.environ, {"GOOGLE_CLOUD_VISION_API_KEY": "test-api-key"}),
        patch("httpx.post") as fake_post,
    ):
        fake_post.return_value = httpx.Response(200, json=api_payload)
        outcome = detect_labels(image_source="https://example.com/dog.jpg", max_labels=5)

    assert "labels" in outcome
    assert len(outcome["labels"]) == 2
    top_label = outcome["labels"][0]
    assert top_label["description"] == "Dog"
    assert top_label["score"] == 0.97
|
||||
|
||||
|
||||
def test_detect_text_success(mcp: FastMCP):
    """OCR result carries the full text plus one block per remaining annotation."""
    register_tools(mcp, credentials=None)
    detect_text = mcp._tool_manager._tools["vision_detect_text"].fn

    # First annotation holds the full text; the following ones are individual words.
    annotations = [
        {"description": "Hello World\nLine 2"},
        {"description": "Hello", "boundingPoly": {"vertices": [{"x": 0, "y": 0}]}},
        {"description": "World", "boundingPoly": {"vertices": [{"x": 50, "y": 0}]}},
    ]
    api_payload = {"responses": [{"textAnnotations": annotations}]}
    fake_env = {"GOOGLE_CLOUD_VISION_API_KEY": "test-api-key"}

    with patch.dict(os.environ, fake_env), patch("httpx.post") as post_mock:
        post_mock.return_value = httpx.Response(200, json=api_payload)
        outcome = detect_text(image_source="https://example.com/text.jpg")

    assert "text" in outcome
    assert outcome["text"] == "Hello World\nLine 2"
    assert "blocks" in outcome
    assert len(outcome["blocks"]) == 2
|
||||
|
||||
|
||||
def test_detect_faces_success(mcp: FastMCP):
    """Face detection maps a mocked face annotation to joy and confidence fields."""
    register_tools(mcp, credentials=None)
    detect_faces = mcp._tool_manager._tools["vision_detect_faces"].fn

    face_annotation = {
        "joyLikelihood": "VERY_LIKELY",
        "sorrowLikelihood": "VERY_UNLIKELY",
        "angerLikelihood": "VERY_UNLIKELY",
        "surpriseLikelihood": "UNLIKELY",
        "detectionConfidence": 0.98,
        "boundingPoly": {"vertices": [{"x": 10, "y": 10}]},
    }
    api_payload = {"responses": [{"faceAnnotations": [face_annotation]}]}
    reply = httpx.Response(200, json=api_payload)
    fake_env = {"GOOGLE_CLOUD_VISION_API_KEY": "test-api-key"}

    with patch.dict(os.environ, fake_env), patch("httpx.post", return_value=reply):
        outcome = detect_faces(image_source="https://example.com/face.jpg")

    assert "faces" in outcome
    faces = outcome["faces"]
    assert len(faces) == 1
    assert faces[0]["joy"] == "VERY_LIKELY"
    assert faces[0]["confidence"] == 0.98
|
||||
|
||||
|
||||
def test_localize_objects_success(mcp: FastMCP):
    """Object localization returns the single object from the mocked response."""
    register_tools(mcp, credentials=None)
    localize_objects = mcp._tool_manager._tools["vision_localize_objects"].fn

    cat_annotation = {
        "name": "Cat",
        "score": 0.92,
        "boundingPoly": {
            "normalizedVertices": [
                {"x": 0.1, "y": 0.2},
                {"x": 0.9, "y": 0.8},
            ]
        },
    }
    api_payload = {"responses": [{"localizedObjectAnnotations": [cat_annotation]}]}
    reply = httpx.Response(200, json=api_payload)
    fake_env = {"GOOGLE_CLOUD_VISION_API_KEY": "test-api-key"}

    with patch.dict(os.environ, fake_env), patch("httpx.post", return_value=reply):
        outcome = localize_objects(image_source="https://example.com/cat.jpg")

    assert "objects" in outcome
    found = outcome["objects"]
    assert len(found) == 1
    assert found[0]["name"] == "Cat"
|
||||
|
||||
|
||||
def test_detect_logos_success(mcp: FastMCP):
    """Logo detection returns both logos from the mocked response, in order."""
    register_tools(mcp, credentials=None)
    detect_logos = mcp._tool_manager._tools["vision_detect_logos"].fn

    annotations = [
        {"description": "Apple", "score": 0.95},
        {"description": "Nike", "score": 0.88},
    ]
    api_payload = {"responses": [{"logoAnnotations": annotations}]}
    reply = httpx.Response(200, json=api_payload)
    fake_env = {"GOOGLE_CLOUD_VISION_API_KEY": "test-api-key"}

    with patch.dict(os.environ, fake_env), patch("httpx.post", return_value=reply):
        outcome = detect_logos(image_source="https://example.com/logos.jpg")

    assert "logos" in outcome
    logos = outcome["logos"]
    assert len(logos) == 2
    assert logos[0]["description"] == "Apple"
|
||||
|
||||
|
||||
def test_detect_landmarks_success(mcp: FastMCP):
    """Landmark detection returns the description and latitude from the mock."""
    register_tools(mcp, credentials=None)
    detect_landmarks = mcp._tool_manager._tools["vision_detect_landmarks"].fn

    eiffel = {
        "description": "Eiffel Tower",
        "score": 0.96,
        "locations": [{"latLng": {"latitude": 48.8584, "longitude": 2.2945}}],
    }
    api_payload = {"responses": [{"landmarkAnnotations": [eiffel]}]}
    reply = httpx.Response(200, json=api_payload)
    fake_env = {"GOOGLE_CLOUD_VISION_API_KEY": "test-api-key"}

    with patch.dict(os.environ, fake_env), patch("httpx.post", return_value=reply):
        outcome = detect_landmarks(image_source="https://example.com/paris.jpg")

    assert "landmarks" in outcome
    landmarks = outcome["landmarks"]
    assert len(landmarks) == 1
    assert landmarks[0]["description"] == "Eiffel Tower"
    assert landmarks[0]["location"]["latitude"] == 48.8584
|
||||
|
||||
|
||||
def test_image_properties_success(mcp: FastMCP):
    """Image properties result includes the dominant colour and a crop_hints key."""
    register_tools(mcp, credentials=None)
    image_properties = mcp._tool_manager._tools["vision_image_properties"].fn

    red_swatch = {
        "color": {"red": 255, "green": 0, "blue": 0},
        "score": 0.5,
        "pixelFraction": 0.3,
    }
    annotation = {
        "imagePropertiesAnnotation": {"dominantColors": {"colors": [red_swatch]}},
        "cropHintsAnnotation": {
            "cropHints": [{"boundingPoly": {"vertices": []}, "confidence": 0.8}]
        },
    }
    reply = httpx.Response(200, json={"responses": [annotation]})
    fake_env = {"GOOGLE_CLOUD_VISION_API_KEY": "test-api-key"}

    with patch.dict(os.environ, fake_env), patch("httpx.post", return_value=reply):
        outcome = image_properties(image_source="https://example.com/colorful.jpg")

    assert "colors" in outcome
    colors = outcome["colors"]
    assert len(colors) == 1
    assert colors[0]["red"] == 255
    assert "crop_hints" in outcome
|
||||
|
||||
|
||||
def test_web_detection_success(mcp: FastMCP):
    """Web detection result has entity, similar-image and matching-page sections."""
    register_tools(mcp, credentials=None)
    web_detection = mcp._tool_manager._tools["vision_web_detection"].fn

    detection = {
        "webEntities": [{"description": "Sunset", "score": 0.9}],
        "visuallySimilarImages": [{"url": "https://similar.com/1.jpg"}],
        "pagesWithMatchingImages": [
            {"url": "https://page.com", "pageTitle": "Sunset Photos"}
        ],
    }
    reply = httpx.Response(200, json={"responses": [{"webDetection": detection}]})
    fake_env = {"GOOGLE_CLOUD_VISION_API_KEY": "test-api-key"}

    with patch.dict(os.environ, fake_env), patch("httpx.post", return_value=reply):
        outcome = web_detection(image_source="https://example.com/sunset.jpg")

    for section in ("web_entities", "similar_images", "pages_with_image"):
        assert section in outcome
    assert outcome["web_entities"][0]["description"] == "Sunset"
|
||||
|
||||
|
||||
def test_safe_search_success(mcp: FastMCP):
    """Safe search surfaces the adult, violence and racy likelihoods from the mock."""
    register_tools(mcp, credentials=None)
    safe_search = mcp._tool_manager._tools["vision_safe_search"].fn

    annotation = {
        "adult": "VERY_UNLIKELY",
        "spoof": "UNLIKELY",
        "medical": "VERY_UNLIKELY",
        "violence": "VERY_UNLIKELY",
        "racy": "POSSIBLE",
    }
    reply = httpx.Response(200, json={"responses": [{"safeSearchAnnotation": annotation}]})
    fake_env = {"GOOGLE_CLOUD_VISION_API_KEY": "test-api-key"}

    with patch.dict(os.environ, fake_env), patch("httpx.post", return_value=reply):
        outcome = safe_search(image_source="https://example.com/photo.jpg")

    assert outcome["adult"] == "VERY_UNLIKELY"
    assert outcome["violence"] == "VERY_UNLIKELY"
    assert outcome["racy"] == "POSSIBLE"
|
||||
|
||||
|
||||
# --- Local File Tests ---
|
||||
|
||||
|
||||
def test_local_file_success(mcp: FastMCP, sample_image: Path):
    """A local image path results in a request whose image entry has "content"."""
    register_tools(mcp, credentials=None)
    detect_labels = mcp._tool_manager._tools["vision_detect_labels"].fn

    api_payload = {
        "responses": [{"labelAnnotations": [{"description": "Image", "score": 0.9}]}]
    }
    fake_env = {"GOOGLE_CLOUD_VISION_API_KEY": "test-api-key"}

    with patch.dict(os.environ, fake_env), patch("httpx.post") as post_mock:
        post_mock.return_value = httpx.Response(200, json=api_payload)
        outcome = detect_labels(image_source=str(sample_image))

    assert "labels" in outcome
    # Inspect the JSON body handed to httpx.post: local files travel inline
    # under the "content" key of the image entry.
    sent_body = post_mock.call_args.kwargs["json"]
    assert "content" in sent_body["requests"][0]["image"]
|
||||
|
||||
|
||||
# --- Error Handling Tests ---
|
||||
|
||||
|
||||
def test_api_error_401(mcp: FastMCP):
    """An HTTP 401 from the API becomes an "Invalid API key" error result."""
    register_tools(mcp, credentials=None)
    detect_labels = mcp._tool_manager._tools["vision_detect_labels"].fn

    reply = httpx.Response(401)
    fake_env = {"GOOGLE_CLOUD_VISION_API_KEY": "test-api-key"}

    with patch.dict(os.environ, fake_env), patch("httpx.post", return_value=reply):
        outcome = detect_labels(image_source="https://example.com/image.jpg")

    assert "error" in outcome
    assert "Invalid API key" in outcome["error"]
|
||||
|
||||
|
||||
def test_api_error_403(mcp: FastMCP):
    """An HTTP 403 from the API becomes a "not authorized" error result."""
    register_tools(mcp, credentials=None)
    detect_labels = mcp._tool_manager._tools["vision_detect_labels"].fn

    reply = httpx.Response(403)
    fake_env = {"GOOGLE_CLOUD_VISION_API_KEY": "test-api-key"}

    with patch.dict(os.environ, fake_env), patch("httpx.post", return_value=reply):
        outcome = detect_labels(image_source="https://example.com/image.jpg")

    assert "error" in outcome
    assert "not authorized" in outcome["error"]
|
||||
|
||||
|
||||
def test_api_error_429(mcp: FastMCP):
    """An HTTP 429 from the API becomes a "Rate limit" error result."""
    register_tools(mcp, credentials=None)
    detect_labels = mcp._tool_manager._tools["vision_detect_labels"].fn

    reply = httpx.Response(429)
    fake_env = {"GOOGLE_CLOUD_VISION_API_KEY": "test-api-key"}

    with patch.dict(os.environ, fake_env), patch("httpx.post", return_value=reply):
        outcome = detect_labels(image_source="https://example.com/image.jpg")

    assert "error" in outcome
    assert "Rate limit" in outcome["error"]
|
||||
|
||||
|
||||
def test_timeout_error(mcp: FastMCP):
    """A TimeoutException raised by httpx.post becomes a "timed out" error result."""
    register_tools(mcp, credentials=None)
    detect_labels = mcp._tool_manager._tools["vision_detect_labels"].fn

    timeout = httpx.TimeoutException("Timeout")
    fake_env = {"GOOGLE_CLOUD_VISION_API_KEY": "test-api-key"}

    with patch.dict(os.environ, fake_env), patch("httpx.post", side_effect=timeout):
        outcome = detect_labels(image_source="https://example.com/image.jpg")

    assert "error" in outcome
    assert "timed out" in outcome["error"]
|
||||
|
||||
|
||||
def test_network_error(mcp: FastMCP):
    """A RequestError raised by httpx.post becomes a "Network error" result."""
    register_tools(mcp, credentials=None)
    detect_labels = mcp._tool_manager._tools["vision_detect_labels"].fn

    failure = httpx.RequestError("Network error")
    fake_env = {"GOOGLE_CLOUD_VISION_API_KEY": "test-api-key"}

    with patch.dict(os.environ, fake_env), patch("httpx.post", side_effect=failure):
        outcome = detect_labels(image_source="https://example.com/image.jpg")

    assert "error" in outcome
    assert "Network error" in outcome["error"]
|
||||
|
||||
|
||||
def test_empty_response(mcp: FastMCP):
    """A 200 reply with an empty "responses" list becomes an "Empty response" error."""
    register_tools(mcp, credentials=None)
    detect_labels = mcp._tool_manager._tools["vision_detect_labels"].fn

    reply = httpx.Response(200, json={"responses": []})
    fake_env = {"GOOGLE_CLOUD_VISION_API_KEY": "test-api-key"}

    with patch.dict(os.environ, fake_env), patch("httpx.post", return_value=reply):
        outcome = detect_labels(image_source="https://example.com/image.jpg")

    assert "error" in outcome
    assert "Empty response" in outcome["error"]
|
||||
|
||||
|
||||
def test_api_error_in_response(mcp: FastMCP):
    """An error object inside a 200 response body is reported with its message."""
    register_tools(mcp, credentials=None)
    detect_labels = mcp._tool_manager._tools["vision_detect_labels"].fn

    # HTTP status is 200; the failure is carried in the per-image response entry.
    api_payload = {"responses": [{"error": {"message": "Image too small"}}]}
    reply = httpx.Response(200, json=api_payload)
    fake_env = {"GOOGLE_CLOUD_VISION_API_KEY": "test-api-key"}

    with patch.dict(os.environ, fake_env), patch("httpx.post", return_value=reply):
        outcome = detect_labels(image_source="https://example.com/image.jpg")

    assert "error" in outcome
    assert "Image too small" in outcome["error"]
|
||||
|
||||
|
||||
# --- Parameter Validation Tests ---
|
||||
|
||||
|
||||
def test_max_labels_clamped(mcp: FastMCP):
    """A max_labels of 200 is sent to the API as maxResults == 100."""
    register_tools(mcp, credentials=None)
    detect_labels = mcp._tool_manager._tools["vision_detect_labels"].fn

    api_payload = {"responses": [{"labelAnnotations": []}]}
    fake_env = {"GOOGLE_CLOUD_VISION_API_KEY": "test-api-key"}

    with patch.dict(os.environ, fake_env), patch("httpx.post") as post_mock:
        post_mock.return_value = httpx.Response(200, json=api_payload)
        # Ask for more than the upper bound of 100.
        detect_labels(image_source="https://example.com/image.jpg", max_labels=200)

    # The outgoing request body must carry the clamped value.
    sent_body = post_mock.call_args.kwargs["json"]
    requested_features = sent_body["requests"][0]["features"]
    assert requested_features[0]["maxResults"] == 100
|
||||
|
||||
|
||||
def test_detect_text_no_text_found(mcp: FastMCP):
    """With no text annotations the result has empty text and an empty block list."""
    register_tools(mcp, credentials=None)
    detect_text = mcp._tool_manager._tools["vision_detect_text"].fn

    reply = httpx.Response(200, json={"responses": [{"textAnnotations": []}]})
    fake_env = {"GOOGLE_CLOUD_VISION_API_KEY": "test-api-key"}

    with patch.dict(os.environ, fake_env), patch("httpx.post", return_value=reply):
        outcome = detect_text(image_source="https://example.com/image.jpg")

    assert outcome["text"] == ""
    assert outcome["blocks"] == []
|
||||
Reference in New Issue
Block a user