Commit e7f2b2409cd4db20be799b9ce9995e939c0d4807
0 parents
first commit
Showing
30 changed files
with
4836 additions
and
0 deletions
Show diff stats
| 1 | +++ a/.env.example | ||
| @@ -0,0 +1,47 @@ | @@ -0,0 +1,47 @@ | ||
| 1 | +# ==================== | ||
| 2 | +# OpenAI Configuration | ||
| 3 | +# ==================== | ||
| 4 | +OPENAI_API_KEY= | ||
| 5 | +OPENAI_MODEL=gpt-4o-mini | ||
| 6 | +OPENAI_EMBEDDING_MODEL=text-embedding-3-small | ||
| 7 | +OPENAI_TEMPERATURE=1 | ||
| 8 | +OPENAI_MAX_TOKENS=1000 | ||
| 9 | + | ||
| 10 | +# ==================== | ||
| 11 | +# CLIP Server Configuration | ||
| 12 | +# ==================== | ||
| 13 | +CLIP_SERVER_URL=grpc://localhost:51000 | ||
| 14 | + | ||
| 15 | +# ==================== | ||
| 16 | +# Milvus Configuration | ||
| 17 | +# ==================== | ||
| 18 | +MILVUS_HOST=localhost | ||
| 19 | +MILVUS_PORT=19530 | ||
| 20 | + | ||
| 21 | +# Collection settings | ||
| 22 | +TEXT_COLLECTION_NAME=text_embeddings | ||
| 23 | +IMAGE_COLLECTION_NAME=image_embeddings | ||
| 24 | +TEXT_DIM=1536 | ||
| 25 | +IMAGE_DIM=512 | ||
| 26 | + | ||
| 27 | +# ==================== | ||
| 28 | +# Search Configuration | ||
| 29 | +# ==================== | ||
| 30 | +TOP_K_RESULTS=30 | ||
| 31 | +SIMILARITY_THRESHOLD=0.6 | ||
| 32 | + | ||
| 33 | +# ==================== | ||
| 34 | +# Application Configuration | ||
| 35 | +# ==================== | ||
| 36 | +APP_HOST=0.0.0.0 | ||
| 37 | +APP_PORT=8000 | ||
| 38 | +DEBUG=true | ||
| 39 | +LOG_LEVEL=INFO | ||
| 40 | + | ||
| 41 | +# ==================== | ||
| 42 | +# Data Paths | ||
| 43 | +# ==================== | ||
| 44 | +RAW_DATA_PATH=./data/raw | ||
| 45 | +PROCESSED_DATA_PATH=./data/processed | ||
| 46 | +IMAGE_DATA_PATH=./data/images | ||
| 47 | + |
| 1 | +++ a/.gitignore | ||
| @@ -0,0 +1,83 @@ | @@ -0,0 +1,83 @@ | ||
| 1 | +# Python | ||
| 2 | +__pycache__/ | ||
| 3 | +*.py[cod] | ||
| 4 | +*$py.class | ||
| 5 | +*.so | ||
| 6 | +.Python | ||
| 7 | +build/ | ||
| 8 | +develop-eggs/ | ||
| 9 | +dist/ | ||
| 10 | +downloads/ | ||
| 11 | +eggs/ | ||
| 12 | +.eggs/ | ||
| 13 | +lib/ | ||
| 14 | +lib64/ | ||
| 15 | +parts/ | ||
| 16 | +sdist/ | ||
| 17 | +var/ | ||
| 18 | +wheels/ | ||
| 19 | +*.egg-info/ | ||
| 20 | +.installed.cfg | ||
| 21 | +*.egg | ||
| 22 | +MANIFEST | ||
| 23 | + | ||
| 24 | +# Virtual Environment | ||
| 25 | +venv/ | ||
| 26 | +env/ | ||
| 27 | +ENV/ | ||
| 28 | +.venv | ||
| 29 | + | ||
| 30 | +# Environment Variables | ||
| 31 | +.env | ||
| 32 | +*.env | ||
| 33 | +!.env.example | ||
| 34 | + | ||
| 35 | +# IDEs | ||
| 36 | +.vscode/ | ||
| 37 | +.idea/ | ||
| 38 | +.cursor/ | ||
| 39 | +*.swp | ||
| 40 | +*.swo | ||
| 41 | +*~ | ||
| 42 | +.DS_Store | ||
| 43 | + | ||
| 44 | +# Data Files - ignore everything in data/ except .gitkeep files | ||
| 45 | +data/** | ||
| 46 | +!data/ | ||
| 47 | +!data/raw/ | ||
| 48 | +!data/processed/ | ||
| 49 | +!data/images/ | ||
| 50 | +!data/**/.gitkeep | ||
| 51 | + | ||
| 52 | +# Database | ||
| 53 | +*.db | ||
| 54 | +*.sqlite | ||
| 55 | +*.sqlite3 | ||
| 56 | +data/milvus_lite.db | ||
| 57 | + | ||
| 58 | +# Docker volumes | ||
| 59 | +volumes/ | ||
| 60 | + | ||
| 61 | +# Logs | ||
| 62 | +*.log | ||
| 63 | +logs/ | ||
| 64 | +nohup.out | ||
| 65 | + | ||
| 66 | +# Testing | ||
| 67 | +.pytest_cache/ | ||
| 68 | +.coverage | ||
| 69 | +htmlcov/ | ||
| 70 | +.tox/ | ||
| 71 | + | ||
| 72 | +# Jupyter | ||
| 73 | +.ipynb_checkpoints/ | ||
| 74 | +*.ipynb | ||
| 75 | + | ||
| 76 | +# Model caches | ||
| 77 | +.cache/ | ||
| 78 | +models/ | ||
| 79 | + | ||
| 80 | +# Temporary files | ||
| 81 | +tmp/ | ||
| 82 | +temp/ | ||
| 83 | +*.tmp |
| 1 | +++ a/LICENSE | ||
| @@ -0,0 +1,21 @@ | @@ -0,0 +1,21 @@ | ||
| 1 | +MIT License | ||
| 2 | + | ||
| 3 | +Copyright (c) 2025 zhangruotian | ||
| 4 | + | ||
| 5 | +Permission is hereby granted, free of charge, to any person obtaining a copy | ||
| 6 | +of this software and associated documentation files (the "Software"), to deal | ||
| 7 | +in the Software without restriction, including without limitation the rights | ||
| 8 | +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
| 9 | +copies of the Software, and to permit persons to whom the Software is | ||
| 10 | +furnished to do so, subject to the following conditions: | ||
| 11 | + | ||
| 12 | +The above copyright notice and this permission notice shall be included in all | ||
| 13 | +copies or substantial portions of the Software. | ||
| 14 | + | ||
| 15 | +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
| 16 | +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
| 17 | +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
| 18 | +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
| 19 | +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
| 20 | +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
| 21 | +SOFTWARE. |
| 1 | +++ a/README.md | ||
| @@ -0,0 +1,161 @@ | @@ -0,0 +1,161 @@ | ||
| 1 | +# OmniShopAgent | ||
| 2 | + | ||
| 3 | +An autonomous multi-modal fashion shopping agent powered by **LangGraph** and **ReAct pattern**. | ||
| 4 | + | ||
| 5 | +## Demo | ||
| 6 | + | ||
| 7 | +📄 **[demo.pdf](./demo.pdf)** | ||
| 8 | + | ||
| 9 | +## Overview | ||
| 10 | + | ||
| 11 | +OmniShopAgent autonomously decides which tools to call, maintains conversation state, and determines when to respond. Built with **LangGraph**, it uses agentic patterns for intelligent product discovery. | ||
| 12 | + | ||
| 13 | +**Key Features:** | ||
| 14 | +- Autonomous tool selection and execution | ||
| 15 | +- Multi-modal search (text + image) | ||
| 16 | +- Conversational context awareness | ||
| 17 | +- Real-time visual analysis | ||
| 18 | + | ||
| 19 | +## Tech Stack | ||
| 20 | + | ||
| 21 | +| Component | Technology | | ||
| 22 | +|-----------|-----------| | ||
| 23 | +| **Agent Framework** | LangGraph | | ||
| 24 | +| **LLM** | any LLM supported by LangChain | | ||
| 25 | +| **Text Embedding** | text-embedding-3-small | | ||
| 26 | +| **Image Embedding** | CLIP ViT-B/32 | | ||
| 27 | +| **Vector Database** | Milvus | | ||
| 28 | +| **Frontend** | Streamlit | | ||
| 29 | +| **Dataset** | Kaggle Fashion Products | | ||
| 30 | + | ||
| 31 | +## Architecture | ||
| 32 | + | ||
| 33 | +**Agent Flow:** | ||
| 34 | + | ||
| 35 | +```mermaid | ||
| 36 | +graph LR | ||
| 37 | + START --> Agent | ||
| 38 | + Agent -->|Has tool_calls| Tools | ||
| 39 | + Agent -->|No tool_calls| END | ||
| 40 | + Tools --> Agent | ||
| 41 | + | ||
| 42 | + subgraph "Agent Node" | ||
| 43 | + A[Receive Messages] --> B[LLM Reasoning] | ||
| 44 | + B --> C{Need Tools?} | ||
| 45 | + C -->|Yes| D[Generate tool_calls] | ||
| 46 | + C -->|No| E[Generate Response] | ||
| 47 | + end | ||
| 48 | + | ||
| 49 | + subgraph "Tool Node" | ||
| 50 | + F[Execute Tools] --> G[Return ToolMessage] | ||
| 51 | + end | ||
| 52 | +``` | ||
| 53 | + | ||
| 54 | +**Available Tools:** | ||
| 55 | +- `search_products(query)` - Text-based semantic search | ||
| 56 | +- `search_by_image(image_path)` - Visual similarity search | ||
| 57 | +- `analyze_image_style(image_path)` - VLM style analysis | ||
| 58 | + | ||
| 59 | + | ||
| 60 | + | ||
| 61 | +## Examples | ||
| 62 | + | ||
| 63 | +**Text Search:** | ||
| 64 | +``` | ||
| 65 | +User: "winter coats for women" | ||
| 66 | +Agent: search_products("winter coats women") → Returns 5 products | ||
| 67 | +``` | ||
| 68 | + | ||
| 69 | +**Image Upload:** | ||
| 70 | +``` | ||
| 71 | +User: [uploads sneaker photo] "find similar" | ||
| 72 | +Agent: search_by_image(path) → Returns visually similar shoes | ||
| 73 | +``` | ||
| 74 | + | ||
| 75 | +**Style Analysis + Search:** | ||
| 76 | +``` | ||
| 77 | +User: [uploads vintage jacket] "what style is this? find matching pants" | ||
| 78 | +Agent: analyze_image_style(path) → "Vintage denim bomber..." | ||
| 79 | + search_products("vintage pants casual") → Returns matching items | ||
| 80 | +``` | ||
| 81 | + | ||
| 82 | +**Multi-turn Context:** | ||
| 83 | +``` | ||
| 84 | +Turn 1: "show me red dresses" | ||
| 85 | +Agent: search_products("red dresses") → Results | ||
| 86 | + | ||
| 87 | +Turn 2: "make them formal" | ||
| 88 | +Agent: [remembers context] → search_products("red formal dresses") → Results | ||
| 89 | +``` | ||
| 90 | + | ||
| 91 | +**Complex Reasoning:** | ||
| 92 | +``` | ||
| 93 | +User: [uploads office outfit] "I like the shirt but need something more casual" | ||
| 94 | +Agent: analyze_image_style(path) → Extracts shirt details | ||
| 95 | + search_products("casual shirt [color] [style]") → Returns casual alternatives | ||
| 96 | +``` | ||
| 97 | + | ||
| 98 | +## Installation | ||
| 99 | + | ||
| 100 | +**Prerequisites:** | ||
| 101 | +- Python 3.12+ (LangChain 1.x requires Python 3.10+) | ||
| 102 | +- OpenAI API Key | ||
| 103 | +- Docker & Docker Compose | ||
| 104 | + | ||
| 105 | +### 1. Setup Environment | ||
| 106 | +```bash | ||
| 107 | +# Clone and install dependencies | ||
| 108 | +git clone <repository-url> | ||
| 109 | +cd OmniShopAgent | ||
| 110 | +python -m venv venv | ||
| 111 | +source venv/bin/activate # Windows: venv\Scripts\activate | ||
| 112 | +pip install -r requirements.txt | ||
| 113 | + | ||
| 114 | +# Configure environment variables | ||
| 115 | +cp .env.example .env | ||
| 116 | +# Edit .env and add your OPENAI_API_KEY | ||
| 117 | +``` | ||
| 118 | + | ||
| 119 | +### 2. Download Dataset | ||
| 120 | +Download the [Fashion Product Images Dataset](https://www.kaggle.com/datasets/paramaggarwal/fashion-product-images-dataset) from Kaggle and extract to `./data/`: | ||
| 121 | + | ||
| 122 | +```bash | ||
| 123 | +python scripts/download_dataset.py | ||
| 124 | +``` | ||
| 125 | + | ||
| 126 | +Expected structure: | ||
| 127 | +``` | ||
| 128 | +data/ | ||
| 129 | +├── images/ # ~44k product images | ||
| 130 | +├── styles.csv # Product metadata | ||
| 131 | +└── images.csv # Image filenames | ||
| 132 | +``` | ||
| 133 | + | ||
| 134 | +### 3. Start Services | ||
| 135 | + | ||
| 136 | +```bash | ||
| 137 | +docker-compose up | ||
| 138 | +python -m clip_server | ||
| 139 | +``` | ||
| 140 | + | ||
| 141 | + | ||
| 142 | +### 4. Index Data | ||
| 143 | + | ||
| 144 | +```bash | ||
| 145 | +python scripts/index_data.py | ||
| 146 | +``` | ||
| 147 | + | ||
| 148 | +This generates and stores text/image embeddings for all 44k products in Milvus. | ||
| 149 | + | ||
| 150 | +### 5. Launch Application | ||
| 151 | +```bash | ||
| 152 | +# Use the launch script (recommended) | ||
| 153 | +./scripts/start.sh | ||
| 154 | + | ||
| 155 | +# Or run directly | ||
| 156 | +streamlit run app.py | ||
| 157 | +``` | ||
| 158 | +Opens at `http://localhost:8501` | ||
| 159 | + | ||
| 160 | +### CentOS 8 Deployment | ||
| 161 | +See [docs/DEPLOY_CENTOS8.md](docs/DEPLOY_CENTOS8.md) for details. |
| 1 | +++ a/app.py | ||
| @@ -0,0 +1,732 @@ | @@ -0,0 +1,732 @@ | ||
| 1 | +""" | ||
| 2 | +OmniShopAgent - Streamlit UI | ||
| 3 | +Multi-modal fashion shopping assistant with conversational AI | ||
| 4 | +""" | ||
| 5 | + | ||
| 6 | +import logging | ||
| 7 | +import re | ||
| 8 | +import uuid | ||
| 9 | +from pathlib import Path | ||
| 10 | +from typing import Optional | ||
| 11 | + | ||
| 12 | +import streamlit as st | ||
| 13 | +from PIL import Image, ImageOps | ||
| 14 | + | ||
| 15 | +from app.agents.shopping_agent import ShoppingAgent | ||
| 16 | + | ||
# Configure logging
# Module-level logger shared by all helpers in this file.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Page config
# Must be the first Streamlit call in the script (Streamlit requirement).
st.set_page_config(
    page_title="OmniShopAgent",
    page_icon="👗",
    layout="centered",
    initial_sidebar_state="collapsed",
)

# Custom CSS - ChatGPT-like style
# Injected once per rerun; unsafe_allow_html is required for a raw <style> tag.
st.markdown(
    """
    <style>
    /* Hide default Streamlit elements */
    #MainMenu {visibility: hidden;}
    footer {visibility: hidden;}
    header {visibility: hidden;}

    /* Body and root container */
    .main .block-container {
        padding-bottom: 180px !important;
        padding-top: 2rem;
        max-width: 900px;
        margin: 0 auto;
    }

    /* Fixed input container at bottom */
    .fixed-input-container {
        position: fixed;
        bottom: 0;
        left: 0;
        right: 0;
        background: white;
        border-top: 1px solid #e5e5e5;
        padding: 1rem 0;
        z-index: 1000;
        box-shadow: 0 -2px 10px rgba(0,0,0,0.05);
    }

    .fixed-input-container .block-container {
        max-width: 900px;
        margin: 0 auto;
        padding: 0 1rem !important;
    }

    /* Message bubbles */
    .message {
        margin: 1rem 0;
        padding: 1rem 1.5rem;
        border-radius: 1rem;
        animation: fadeIn 0.3s ease-in;
    }

    @keyframes fadeIn {
        from { opacity: 0; transform: translateY(10px); }
        to { opacity: 1; transform: translateY(0); }
    }

    .user-message {
        background: transparent;
        margin: 0 0 1rem 0;
        padding: 0;
        border-radius: 0;
    }

    .assistant-message {
        background: white;
        border: 1px solid #e5e5e5;
        margin-right: 3rem;
    }

    /* Product cards - simplified */
    .stImage {
        border-radius: 0px;
        overflow: hidden;
    }

    .stImage img {
        transition: transform 0.2s;
    }

    .stImage:hover img {
        transform: scale(1.05);
    }

    /* Scroll to bottom behavior */
    html {
        scroll-behavior: smooth;
    }

    /* Header */
    .app-header {
        text-align: center;
        padding: 2rem 1rem;
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        color: white;
        border-radius: 1rem;
        margin-bottom: 2rem;
    }

    .app-title {
        font-size: 2rem;
        font-weight: 700;
        margin: 0;
    }

    .app-subtitle {
        font-size: 1rem;
        opacity: 0.9;
        margin-top: 0.5rem;
    }

    /* Welcome screen */
    .welcome-container {
        text-align: center;
        padding: 4rem 2rem;
        color: #666;
    }

    .welcome-title {
        font-size: 2rem;
        font-weight: 600;
        color: #1a1a1a;
        margin-bottom: 1rem;
    }

    .welcome-features {
        display: grid;
        grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
        gap: 1.5rem;
        margin: 2rem 0;
    }

    .feature-card {
        background: #f7f7f8;
        padding: 1.5rem;
        border-radius: 12px;
        transition: all 0.2s;
    }

    .feature-card:hover {
        background: #efefef;
        transform: translateY(-2px);
    }

    .feature-icon {
        font-size: 2rem;
        margin-bottom: 0.5rem;
    }

    .feature-title {
        font-weight: 600;
        margin-bottom: 0.25rem;
    }

    /* Image preview */
    .image-preview {
        position: relative;
        display: inline-block;
        margin: 0.5rem 0;
    }

    .image-preview img {
        max-width: 200px;
        border-radius: 8px;
        border: 2px solid #e5e5e5;
    }

    .remove-image-btn {
        position: absolute;
        top: 5px;
        right: 5px;
        background: rgba(0,0,0,0.6);
        color: white;
        border: none;
        border-radius: 50%;
        width: 24px;
        height: 24px;
        cursor: pointer;
        font-size: 14px;
    }

    /* Buttons */
    .stButton>button {
        border-radius: 8px;
        border: 1px solid #e5e5e5;
        padding: 0.5rem 1rem;
        transition: all 0.2s;
    }

    .stButton>button:hover {
        background: #f0f0f0;
        border-color: #d0d0d0;
    }

    /* Hide upload button label */
    .uploadedFile {
        display: none;
    }
    </style>
    """,
    unsafe_allow_html=True,
)
| 226 | + | ||
| 227 | + | ||
# Initialize session state
def initialize_session():
    """Seed ``st.session_state`` with every key the app relies on.

    Idempotent: each key is created only if it is not already present,
    so repeated Streamlit reruns keep the existing session data.
    """
    state = st.session_state

    # A stable per-browser-session id; also used to namespace uploads.
    if "session_id" not in state:
        state.session_id = str(uuid.uuid4())

    # One agent instance per session, bound to that session id.
    if "shopping_agent" not in state:
        state.shopping_agent = ShoppingAgent(session_id=state.session_id)

    # Plain defaults for the chat transcript and the upload widget state.
    for key, default in (
        ("messages", []),
        ("uploaded_image", None),
        ("show_image_upload", False),
    ):
        if key not in state:
            state[key] = default
| 247 | + | ||
| 248 | + | ||
def save_uploaded_image(uploaded_file) -> Optional[str]:
    """Persist an uploaded image under ``temp_uploads/`` and return its path.

    Args:
        uploaded_file: A Streamlit ``UploadedFile`` (or None).

    Returns:
        The saved file path as a string, or None when nothing was uploaded
        or the write failed (the failure is logged and shown via st.error).
    """
    if uploaded_file is None:
        return None

    try:
        temp_dir = Path("temp_uploads")
        temp_dir.mkdir(exist_ok=True)

        # Keep only the basename: a crafted upload name such as
        # "../../evil.png" must not escape the temp directory.
        safe_name = Path(uploaded_file.name).name
        image_path = temp_dir / f"{st.session_state.session_id}_{safe_name}"
        with open(image_path, "wb") as f:
            f.write(uploaded_file.getbuffer())

        logger.info(f"Saved uploaded image to {image_path}")
        return str(image_path)

    except Exception as e:
        logger.error(f"Error saving uploaded image: {e}")
        st.error(f"Failed to save image: {str(e)}")
        return None
| 269 | + | ||
| 270 | + | ||
def extract_products_from_response(response: str) -> list:
    """Parse the agent's reply text into a list of product dicts.

    The agent lists products as numbered entries ("1. Name" or
    "**1. Name**") followed by "Key: value" detail lines.  Each entry
    becomes a dict that may contain: name, id, category, color, gender,
    season, usage, score.

    Args:
        response: Raw assistant response text.

    Returns:
        Product dicts in the order they appear; empty list when the
        response contains no recognizable product entries.
    """
    products = []

    lines = response.split("\n")
    current_product = {}

    for line in lines:
        line = line.strip()

        # Match product number (e.g., "1. Product Name" or "**1. Product Name**")
        if re.match(r"^\*?\*?\d+\.\s+", line):
            if current_product:
                products.append(current_product)
            current_product = {}
            # Extract product name, stripping the numbering and bold markers
            name = re.sub(r"^\*?\*?\d+\.\s+", "", line)
            name = name.replace("**", "").strip()
            current_product["name"] = name

        # Match ID. The \b anchor prevents false positives on words that
        # merely end in "id:" (e.g. "Valid: 99" must not yield an ID).
        elif "ID:" in line or "id:" in line:
            id_match = re.search(r"\b(?:ID|id):\s*(\d+)", line)
            if id_match:
                current_product["id"] = id_match.group(1)

        # Match Category (may contain spaces, so take the rest of the line)
        elif "Category:" in line:
            cat_match = re.search(r"Category:\s*(.+?)(?:\n|$)", line)
            if cat_match:
                current_product["category"] = cat_match.group(1).strip()

        # Match Color
        elif "Color:" in line:
            color_match = re.search(r"Color:\s*(\w+)", line)
            if color_match:
                current_product["color"] = color_match.group(1)

        # Match Gender
        elif "Gender:" in line:
            gender_match = re.search(r"Gender:\s*(\w+)", line)
            if gender_match:
                current_product["gender"] = gender_match.group(1)

        # Match Season
        elif "Season:" in line:
            season_match = re.search(r"Season:\s*(\w+)", line)
            if season_match:
                current_product["season"] = season_match.group(1)

        # Match Usage
        elif "Usage:" in line:
            usage_match = re.search(r"Usage:\s*(\w+)", line)
            if usage_match:
                current_product["usage"] = usage_match.group(1)

        # Match Similarity/Relevance score (stored as the bare number string)
        elif "Similarity:" in line or "Relevance:" in line:
            score_match = re.search(r"(?:Similarity|Relevance):\s*([\d.]+)%", line)
            if score_match:
                current_product["score"] = score_match.group(1)

    # Add last product
    if current_product:
        products.append(current_product)

    return products
| 343 | + | ||
| 344 | + | ||
def display_product_card(product: dict):
    """Render one product: a 200x200 thumbnail (when available) plus name and ID.

    Falls back to a text-only card when the product has no ID, the image
    file is missing, or the image cannot be decoded.
    """
    product_id = product.get("id", "")
    name = product.get("name", "Unknown Product")

    # Debug: log what we got
    logger.info(f"Displaying product: ID={product_id}, Name={name}")

    if product_id:
        image_path = Path(f"data/images/{product_id}.jpg")

        if image_path.exists():
            try:
                img = Image.open(image_path)
                target_size = (200, 200)
                # Pillow >= 9.1 exposes resampling filters on Image.Resampling;
                # older versions keep them directly on Image.
                try:
                    resample = Image.Resampling.LANCZOS
                except AttributeError:
                    resample = Image.LANCZOS
                thumbnail = ImageOps.fit(img, target_size, method=resample)

                # Fixed-width rendering keeps the 3-column grid aligned.
                st.image(thumbnail, use_container_width=False, width=200)
                st.markdown(f"**{name}**")
                st.caption(f"ID: {product_id}")
                return
            except Exception as e:
                logger.warning(f"Failed to load image {image_path}: {e}")
        else:
            logger.warning(f"Image not found: {image_path}")

    # Fallback: no image
    st.markdown(f"**📷 {name}**")
    if product_id:
        st.caption(f"ID: {product_id}")
    else:
        st.caption("ID not available")
| 390 | + | ||
def display_message(message: dict):
    """Render one chat message (user or assistant) into the Streamlit page.

    Expected keys: "role" ("user" or "assistant"), "content" (markdown
    text), optional "image_path" (user-attached image) and "tool_calls"
    (list of dicts, each with at least a "name" key).
    """
    role = message["role"]
    content = message["content"]
    image_path = message.get("image_path")
    tool_calls = message.get("tool_calls", [])

    if role == "user":
        st.markdown('<div class="message user-message">', unsafe_allow_html=True)

        # Show the image the user attached to this turn, if it still exists.
        if image_path and Path(image_path).exists():
            try:
                img = Image.open(image_path)
                st.image(img, width=200)
            except Exception:
                logger.warning(f"Failed to load user uploaded image: {image_path}")

        st.markdown(content)
        st.markdown("</div>", unsafe_allow_html=True)

    else:  # assistant
        # Display tool calls horizontally - only tool names
        if tool_calls:
            tool_names = [tc['name'] for tc in tool_calls]
            st.caption(" → ".join(tool_names))
            st.markdown("")

        # Extract and display products if any
        products = extract_products_from_response(content)

        # Debug logging
        logger.info(f"Extracted {len(products)} products from response")
        for p in products:
            logger.info(f"Product: {p}")

        if products:
            def parse_score(product: dict) -> float:
                # Sort key: products with missing or unparsable scores rank lowest.
                score = product.get("score")
                if score is None:
                    return 0.0
                try:
                    return float(score)
                except (TypeError, ValueError):
                    return 0.0

            # Sort by score and limit to 3
            products = sorted(products, key=parse_score, reverse=True)[:3]

            logger.info(f"Displaying top {len(products)} products")

            # Display the text response first (without product details):
            # drop every line that looks like a product entry or detail field
            # so only the conversational intro text remains.
            text_lines = []
            for line in content.split("\n"):
                # Skip product detail lines
                if not any(
                    keyword in line
                    for keyword in [
                        "ID:",
                        "Category:",
                        "Color:",
                        "Gender:",
                        "Season:",
                        "Usage:",
                        "Similarity:",
                        "Relevance:",
                    ]
                ):
                    if not re.match(r"^\*?\*?\d+\.\s+", line):
                        text_lines.append(line)

            intro_text = "\n".join(text_lines).strip()
            if intro_text:
                st.markdown(intro_text)

            # Display product cards in grid
            st.markdown("<br>", unsafe_allow_html=True)

            # Create exactly 3 columns with equal width
            cols = st.columns(3)
            for j, product in enumerate(products[:3]):  # Ensure max 3
                with cols[j]:
                    display_product_card(product)
        else:
            # No products found, display full content
            st.markdown(content)

        st.markdown("</div>", unsafe_allow_html=True)
| 478 | + | ||
| 479 | + | ||
def display_welcome():
    """Render the four feature cards shown before the first chat message."""

    col1, col2, col3, col4 = st.columns(4)

    with col1:
        st.markdown(
            """
            <div class="feature-card">
                <div class="feature-icon">💬</div>
                <div class="feature-title">Text Search</div>
                <div>Describe what you want</div>
            </div>
            """,
            unsafe_allow_html=True,
        )

    with col2:
        st.markdown(
            """
            <div class="feature-card">
                <div class="feature-icon">📸</div>
                <div class="feature-title">Image Search</div>
                <div>Upload product photos</div>
            </div>
            """,
            unsafe_allow_html=True,
        )

    with col3:
        st.markdown(
            """
            <div class="feature-card">
                <div class="feature-icon">🔍</div>
                <div class="feature-title">Visual Analysis</div>
                <div>AI analyzes product style</div>
            </div>
            """,
            unsafe_allow_html=True,
        )
        # Fixed user-facing typo: "prodcut" -> "product"

    with col4:
        st.markdown(
            """
            <div class="feature-card">
                <div class="feature-icon">💭</div>
                <div class="feature-title">Conversational</div>
                <div>Remembers context</div>
            </div>
            """,
            unsafe_allow_html=True,
        )

    st.markdown("<br><br>", unsafe_allow_html=True)
| 534 | + | ||
| 535 | + | ||
| 536 | +def main(): | ||
| 537 | + """Main Streamlit app""" | ||
| 538 | + initialize_session() | ||
| 539 | + | ||
| 540 | + # Header | ||
| 541 | + st.markdown( | ||
| 542 | + """ | ||
| 543 | + <div class="app-header"> | ||
| 544 | + <div class="app-title">👗 OmniShopAgent</div> | ||
| 545 | + <div class="app-subtitle">AI Fashion Shopping Assistant</div> | ||
| 546 | + </div> | ||
| 547 | + """, | ||
| 548 | + unsafe_allow_html=True, | ||
| 549 | + ) | ||
| 550 | + | ||
| 551 | + # Sidebar (collapsed by default, but accessible) | ||
| 552 | + with st.sidebar: | ||
| 553 | + st.markdown("### ⚙️ Settings") | ||
| 554 | + | ||
| 555 | + if st.button("🗑️ Clear Chat", use_container_width=True): | ||
| 556 | + if "shopping_agent" in st.session_state: | ||
| 557 | + st.session_state.shopping_agent.clear_history() | ||
| 558 | + st.session_state.messages = [] | ||
| 559 | + st.session_state.uploaded_image = None | ||
| 560 | + st.rerun() | ||
| 561 | + | ||
| 562 | + st.markdown("---") | ||
| 563 | + st.caption(f"Session: `{st.session_state.session_id[:8]}...`") | ||
| 564 | + | ||
| 565 | + # Chat messages container | ||
| 566 | + messages_container = st.container() | ||
| 567 | + | ||
| 568 | + with messages_container: | ||
| 569 | + if not st.session_state.messages: | ||
| 570 | + display_welcome() | ||
| 571 | + else: | ||
| 572 | + for message in st.session_state.messages: | ||
| 573 | + display_message(message) | ||
| 574 | + | ||
| 575 | + # Fixed input area at bottom (using container to simulate fixed position) | ||
| 576 | + st.markdown('<div class="fixed-input-container">', unsafe_allow_html=True) | ||
| 577 | + | ||
| 578 | + input_container = st.container() | ||
| 579 | + | ||
| 580 | + with input_container: | ||
| 581 | + # Image upload area (shown when + is clicked) | ||
| 582 | + if st.session_state.show_image_upload: | ||
| 583 | + uploaded_file = st.file_uploader( | ||
| 584 | + "Choose an image", | ||
| 585 | + type=["jpg", "jpeg", "png"], | ||
| 586 | + key="file_uploader", | ||
| 587 | + ) | ||
| 588 | + | ||
| 589 | + if uploaded_file: | ||
| 590 | + st.session_state.uploaded_image = uploaded_file | ||
| 591 | + # Show preview | ||
| 592 | + col1, col2 = st.columns([1, 4]) | ||
| 593 | + with col1: | ||
| 594 | + img = Image.open(uploaded_file) | ||
| 595 | + st.image(img, width=100) | ||
| 596 | + with col2: | ||
| 597 | + if st.button("❌ Remove"): | ||
| 598 | + st.session_state.uploaded_image = None | ||
| 599 | + st.session_state.show_image_upload = False | ||
| 600 | + st.rerun() | ||
| 601 | + | ||
| 602 | + # Input row | ||
| 603 | + col1, col2 = st.columns([1, 12]) | ||
| 604 | + | ||
| 605 | + with col1: | ||
| 606 | + # Image upload toggle button | ||
| 607 | + if st.button("➕", help="Add image", use_container_width=True): | ||
| 608 | + st.session_state.show_image_upload = ( | ||
| 609 | + not st.session_state.show_image_upload | ||
| 610 | + ) | ||
| 611 | + st.rerun() | ||
| 612 | + | ||
| 613 | + with col2: | ||
| 614 | + # Text input | ||
| 615 | + user_query = st.chat_input( | ||
| 616 | + "Ask about fashion products...", | ||
| 617 | + key="chat_input", | ||
| 618 | + ) | ||
| 619 | + | ||
| 620 | + st.markdown("</div>", unsafe_allow_html=True) | ||
| 621 | + | ||
| 622 | + # Process user input | ||
| 623 | + if user_query: | ||
| 624 | + # Ensure shopping agent is initialized | ||
| 625 | + if "shopping_agent" not in st.session_state: | ||
| 626 | + st.error("Session not initialized. Please refresh the page.") | ||
| 627 | + st.stop() | ||
| 628 | + | ||
| 629 | + # Save uploaded image if present, or get from recent history | ||
| 630 | + image_path = None | ||
| 631 | + if st.session_state.uploaded_image: | ||
| 632 | + # User explicitly uploaded an image for this query | ||
| 633 | + image_path = save_uploaded_image(st.session_state.uploaded_image) | ||
| 634 | + else: | ||
| 635 | + # Check if query refers to a previous image | ||
| 636 | + query_lower = user_query.lower() | ||
| 637 | + if any( | ||
| 638 | + ref in query_lower | ||
| 639 | + for ref in [ | ||
| 640 | + "this", | ||
| 641 | + "that", | ||
| 642 | + "the image", | ||
| 643 | + "the shirt", | ||
| 644 | + "the product", | ||
| 645 | + "it", | ||
| 646 | + ] | ||
| 647 | + ): | ||
| 648 | + # Find the most recent message with an image | ||
| 649 | + for msg in reversed(st.session_state.messages): | ||
| 650 | + if msg.get("role") == "user" and msg.get("image_path"): | ||
| 651 | + image_path = msg["image_path"] | ||
| 652 | + logger.info(f"Using image from previous message: {image_path}") | ||
| 653 | + break | ||
| 654 | + | ||
| 655 | + # Add user message | ||
| 656 | + st.session_state.messages.append( | ||
| 657 | + { | ||
| 658 | + "role": "user", | ||
| 659 | + "content": user_query, | ||
| 660 | + "image_path": image_path, | ||
| 661 | + } | ||
| 662 | + ) | ||
| 663 | + | ||
| 664 | + # Display user message immediately | ||
| 665 | + with messages_container: | ||
| 666 | + display_message(st.session_state.messages[-1]) | ||
| 667 | + | ||
| 668 | + # Process with shopping agent | ||
| 669 | + try: | ||
| 670 | + shopping_agent = st.session_state.shopping_agent | ||
| 671 | + | ||
| 672 | + # Handle greetings | ||
| 673 | + query_lower = user_query.lower().strip() | ||
| 674 | + if query_lower in ["hi", "hello", "hey"]: | ||
| 675 | + response = """Hello! 👋 I'm your fashion shopping assistant. | ||
| 676 | + | ||
| 677 | +I can help you: | ||
| 678 | +- Search for products by description | ||
| 679 | +- Find items similar to images you upload | ||
| 680 | +- Analyze product styles | ||
| 681 | + | ||
| 682 | +What are you looking for today?""" | ||
| 683 | + tool_calls = [] | ||
| 684 | + else: | ||
| 685 | + # Process with agent | ||
| 686 | + result = shopping_agent.chat( | ||
| 687 | + query=user_query, | ||
| 688 | + image_path=image_path, | ||
| 689 | + ) | ||
| 690 | + response = result["response"] | ||
| 691 | + tool_calls = result.get("tool_calls", []) | ||
| 692 | + | ||
| 693 | + # Add assistant message | ||
| 694 | + st.session_state.messages.append( | ||
| 695 | + { | ||
| 696 | + "role": "assistant", | ||
| 697 | + "content": response, | ||
| 698 | + "tool_calls": tool_calls, | ||
| 699 | + } | ||
| 700 | + ) | ||
| 701 | + | ||
| 702 | + # Clear uploaded image and hide upload area after sending | ||
| 703 | + st.session_state.uploaded_image = None | ||
| 704 | + st.session_state.show_image_upload = False | ||
| 705 | + | ||
| 706 | + # Auto-scroll to bottom with JavaScript | ||
| 707 | + st.markdown( | ||
| 708 | + """ | ||
| 709 | + <script> | ||
| 710 | + window.scrollTo(0, document.body.scrollHeight); | ||
| 711 | + </script> | ||
| 712 | + """, | ||
| 713 | + unsafe_allow_html=True, | ||
| 714 | + ) | ||
| 715 | + | ||
| 716 | + except Exception as e: | ||
| 717 | + logger.error(f"Error processing query: {e}", exc_info=True) | ||
| 718 | + error_msg = f"I apologize, I encountered an error: {str(e)}" | ||
| 719 | + | ||
| 720 | + st.session_state.messages.append( | ||
| 721 | + { | ||
| 722 | + "role": "assistant", | ||
| 723 | + "content": error_msg, | ||
| 724 | + } | ||
| 725 | + ) | ||
| 726 | + | ||
| 727 | + # Rerun to update UI | ||
| 728 | + st.rerun() | ||
| 729 | + | ||
| 730 | + | ||
# Script entry point (e.g. `streamlit run <this file>` or `python <this file>`).
if __name__ == "__main__":
    main()
| 1 | +++ a/app/agents/shopping_agent.py | ||
| @@ -0,0 +1,272 @@ | @@ -0,0 +1,272 @@ | ||
| 1 | +""" | ||
| 2 | +Conversational Shopping Agent with LangGraph | ||
| 3 | +True ReAct agent with autonomous tool calling and message accumulation | ||
| 4 | +""" | ||
| 5 | + | ||
| 6 | +import logging | ||
| 7 | +from pathlib import Path | ||
| 8 | +from typing import Optional, Sequence | ||
| 9 | + | ||
| 10 | +from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage | ||
| 11 | +from langchain_openai import ChatOpenAI | ||
| 12 | +from langgraph.checkpoint.memory import MemorySaver | ||
| 13 | +from langgraph.graph import END, START, StateGraph | ||
| 14 | +from langgraph.graph.message import add_messages | ||
| 15 | +from langgraph.prebuilt import ToolNode | ||
| 16 | +from typing_extensions import Annotated, TypedDict | ||
| 17 | + | ||
| 18 | +from app.config import settings | ||
| 19 | +from app.tools.search_tools import get_all_tools | ||
| 20 | + | ||
| 21 | +logger = logging.getLogger(__name__) | ||
| 22 | + | ||
| 23 | + | ||
| 24 | +def _extract_message_text(msg) -> str: | ||
| 25 | + """Extract text from message content. | ||
| 26 | + LangChain 1.0: content may be str or content_blocks (list) for multimodal.""" | ||
| 27 | + content = getattr(msg, "content", "") | ||
| 28 | + if isinstance(content, str): | ||
| 29 | + return content | ||
| 30 | + if isinstance(content, list): | ||
| 31 | + parts = [] | ||
| 32 | + for block in content: | ||
| 33 | + if isinstance(block, dict): | ||
| 34 | + parts.append(block.get("text", block.get("content", ""))) | ||
| 35 | + else: | ||
| 36 | + parts.append(str(block)) | ||
| 37 | + return "".join(str(p) for p in parts) | ||
| 38 | + return str(content) if content else "" | ||
| 39 | + | ||
| 40 | + | ||
class AgentState(TypedDict):
    """Graph state for the shopping agent.

    ``messages`` accumulates across turns: the ``add_messages`` reducer
    appends each node's returned messages instead of replacing the list.
    """

    messages: Annotated[Sequence[BaseMessage], add_messages]
    # Filesystem path of the image uploaded for the current turn, if any.
    current_image_path: Optional[str]  # Track uploaded image
| 46 | + | ||
| 47 | + | ||
class ShoppingAgent:
    """True ReAct agent with autonomous decision making.

    Wraps a LangGraph ``StateGraph`` with two nodes: an LLM "agent" node that
    may emit tool calls, and a ``ToolNode`` that executes them. The graph
    loops agent -> tools -> agent until the LLM answers without tool calls.
    Conversation state is checkpointed per thread id via ``MemorySaver``.
    """

    def __init__(self, session_id: Optional[str] = None):
        """Initialize the agent.

        Args:
            session_id: Stable identifier for this conversation.
                Defaults to "default".
        """
        self.session_id = session_id or "default"

        # Active checkpoint thread. clear_history() rotates this to a fresh
        # id so the next chat() starts from empty state.
        self._thread_id = self.session_id
        self._reset_count = 0

        # Initialize LLM
        self.llm = ChatOpenAI(
            model=settings.openai_model,
            temperature=settings.openai_temperature,
            api_key=settings.openai_api_key,
        )

        # Get tools and bind to model so the LLM can emit tool calls
        self.tools = get_all_tools()
        self.llm_with_tools = self.llm.bind_tools(self.tools)

        # Build graph
        self.graph = self._build_graph()

        logger.info(f"Shopping agent initialized for session: {self.session_id}")

    def _build_graph(self):
        """Build and compile the LangGraph StateGraph.

        Returns:
            A compiled graph with a MemorySaver checkpointer.
        """

        # System prompt for the agent
        system_prompt = """You are an intelligent fashion shopping assistant. You can:
1. Search for products by text description (use search_products)
2. Find visually similar products from images (use search_by_image)
3. Analyze image style and attributes (use analyze_image_style)

When a user asks about products:
- For text queries: use search_products directly
- For image uploads: decide if you need to analyze_image_style first, then search
- You can call multiple tools in sequence if needed
- Always provide helpful, friendly responses

CRITICAL FORMATTING RULES:
When presenting product results, you MUST use this EXACT format for EACH product:

1. [Product Name]
   ID: [Product ID Number]
   Category: [Category]
   Color: [Color]
   Gender: [Gender]
   (Include Season, Usage, Relevance if available)

Example:
1. Puma Men White 3/4 Length Pants
   ID: 12345
   Category: Apparel > Bottomwear > Track Pants
   Color: White
   Gender: Men
   Season: Summer
   Usage: Sports
   Relevance: 95.2%

DO NOT skip the ID field! It is essential for displaying product images.
Be conversational in your introduction, but preserve the exact product format."""

        def agent_node(state: AgentState):
            """Agent decision node - decides which tools to call or when to respond."""
            messages = state["messages"]

            # The system prompt is never persisted (only the LLM response is
            # returned), so re-prepend it whenever it is absent.
            if not any(isinstance(m, SystemMessage) for m in messages):
                messages = [SystemMessage(content=system_prompt)] + list(messages)

            # NOTE: state["current_image_path"] is not injected here; the
            # image path travels inline in the human message (see chat()).

            # Invoke LLM with tools
            response = self.llm_with_tools.invoke(messages)
            return {"messages": [response]}

        # Create tool node that executes whatever tool calls the LLM emitted
        tool_node = ToolNode(self.tools)

        def should_continue(state: AgentState):
            """Route to tools while the last message carries tool calls, else end."""
            last_message = state["messages"][-1]
            if hasattr(last_message, "tool_calls") and last_message.tool_calls:
                return "tools"
            return END

        # Build graph: START -> agent -> (tools -> agent)* -> END
        workflow = StateGraph(AgentState)

        workflow.add_node("agent", agent_node)
        workflow.add_node("tools", tool_node)

        workflow.add_edge(START, "agent")
        workflow.add_conditional_edges("agent", should_continue, ["tools", END])
        workflow.add_edge("tools", "agent")

        # Compile with in-memory checkpointing for multi-turn memory
        checkpointer = MemorySaver()
        return workflow.compile(checkpointer=checkpointer)

    def chat(self, query: str, image_path: Optional[str] = None) -> dict:
        """Process user query with the agent.

        Args:
            query: User's text query
            image_path: Optional path to uploaded image

        Returns:
            Dict with:
                response (str): final assistant text
                tool_calls (list): name/args/result per tool invocation
                    (absent on validation/exception errors)
                error (bool): True when an error response was produced
        """
        try:
            logger.info(
                f"[{self.session_id}] Processing: '{query}' (image={'Yes' if image_path else 'No'})"
            )

            # Validate image before handing it to the agent/tools
            if image_path and not Path(image_path).exists():
                return {
                    "response": f"Error: Image file not found at '{image_path}'",
                    "error": True,
                }

            # Surface the image path inline so the LLM can pass it to tools
            message_content = query
            if image_path:
                message_content = f"{query}\n[User uploaded image: {image_path}]"

            # Invoke agent on this session's current checkpoint thread
            config = {"configurable": {"thread_id": self._thread_id}}
            input_state = {
                "messages": [HumanMessage(content=message_content)],
                "current_image_path": image_path,
            }

            # Track tool calls made during this turn
            tool_calls = []
            results_assigned = 0  # number of tool results already attached

            # Stream events to capture tool calls and their results
            for event in self.graph.stream(input_state, config=config):
                logger.info(f"Event: {event}")

                # Agent node output: collect any tool calls the LLM emitted
                if "agent" in event:
                    agent_output = event["agent"]
                    for msg in agent_output.get("messages", []):
                        if getattr(msg, "tool_calls", None):
                            for tc in msg.tool_calls:
                                tool_calls.append({
                                    "name": tc["name"],
                                    "args": tc.get("args", {}),
                                })

                # Tool node output: attach each result to the next unresolved
                # call. (A plain enumerate() index would restart at 0 on every
                # "tools" event, so a second tool round would overwrite the
                # first round's results.)
                if "tools" in event:
                    tools_output = event["tools"]
                    for msg in tools_output.get("messages", []):
                        if results_assigned < len(tool_calls):
                            text = str(msg.content)
                            # Truncate long tool output for display/log purposes
                            tool_calls[results_assigned]["result"] = (
                                text[:200] + "..." if len(text) > 200 else text
                            )
                            results_assigned += 1

            # Final state holds the full history; the last message is the answer
            final_state = self.graph.get_state(config)
            final_message = final_state.values["messages"][-1]
            response_text = _extract_message_text(final_message)

            logger.info(f"[{self.session_id}] Response generated with {len(tool_calls)} tool calls")

            return {
                "response": response_text,
                "tool_calls": tool_calls,
                "error": False,
            }

        except Exception as e:
            logger.error(f"Error in agent chat: {e}", exc_info=True)
            return {
                "response": f"I apologize, I encountered an error: {str(e)}",
                "error": True,
            }

    def get_conversation_history(self) -> list:
        """Return the conversation as [{"role": ..., "content": ...}, ...].

        System and tool messages are filtered out; only the human/assistant
        exchange is returned. Returns [] on any error or empty state.
        """
        try:
            config = {"configurable": {"thread_id": self._thread_id}}
            state = self.graph.get_state(config)

            if not state or not state.values.get("messages"):
                return []

            result = []
            for msg in state.values["messages"]:
                # Skip system messages and tool messages
                if isinstance(msg, SystemMessage):
                    continue
                if hasattr(msg, "type") and msg.type in ["system", "tool"]:
                    continue

                role = "user" if msg.type == "human" else "assistant"
                result.append({"role": role, "content": _extract_message_text(msg)})

            return result

        except Exception as e:
            logger.error(f"Error getting history: {e}")
            return []

    def clear_history(self):
        """Start a fresh conversation for this session.

        MemorySaver has no portable per-thread delete, so instead of mutating
        the checkpoint we rotate to a brand-new thread id; the old history
        simply becomes unreachable and the next chat() starts empty.
        """
        self._reset_count += 1
        self._thread_id = f"{self.session_id}::reset{self._reset_count}"
        logger.info(f"[{self.session_id}] History cleared (new thread: {self._thread_id})")
| 268 | + | ||
| 269 | + | ||
def create_shopping_agent(session_id: Optional[str] = None) -> ShoppingAgent:
    """Factory: build a ShoppingAgent bound to *session_id* (or the default)."""
    agent = ShoppingAgent(session_id=session_id)
    return agent
| 1 | +++ a/app/config.py | ||
| @@ -0,0 +1,86 @@ | @@ -0,0 +1,86 @@ | ||
| 1 | +""" | ||
| 2 | +Configuration management for OmniShopAgent | ||
| 3 | +Loads environment variables and provides configuration objects | ||
| 4 | +""" | ||
| 5 | + | ||
| 6 | +import os | ||
| 7 | + | ||
| 8 | +from pydantic_settings import BaseSettings | ||
| 9 | + | ||
| 10 | + | ||
class Settings(BaseSettings):
    """Application settings loaded from environment variables.

    All settings can be configured via .env file or environment variables.
    Matching is case-insensitive, so ``openai_api_key`` is populated from
    ``OPENAI_API_KEY``. Only ``openai_api_key`` has no default and is
    therefore required.
    """

    # OpenAI Configuration
    openai_api_key: str  # required; no default
    openai_model: str = "gpt-4o-mini"
    openai_embedding_model: str = "text-embedding-3-small"
    openai_temperature: float = 0.7
    openai_max_tokens: int = 1000

    # CLIP Server Configuration
    clip_server_url: str = "grpc://localhost:51000"

    # Milvus Configuration
    milvus_uri: str = "http://localhost:19530"
    milvus_host: str = "localhost"
    milvus_port: int = 19530
    text_collection_name: str = "text_embeddings"
    image_collection_name: str = "image_embeddings"
    text_dim: int = 1536   # OpenAI text-embedding-3-small dimension
    image_dim: int = 512   # CLIP embedding dimension

    @property
    def milvus_uri_absolute(self) -> str:
        """Get absolute path for Milvus URI.

        Returns:
            - For http/https URIs: returns as-is (Milvus Standalone)
            - For file paths starting with ./: converts to absolute path
              relative to the project root (Milvus Lite)
            - For other paths: returns as-is
        """
        # Network URI -> Milvus Standalone, use verbatim
        if self.milvus_uri.startswith(("http://", "https://")):
            return self.milvus_uri
        # Relative file path -> Milvus Lite, anchor at the project root
        # (uses the module-level ``os`` import; no local re-import needed)
        if self.milvus_uri.startswith("./"):
            base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
            return os.path.join(base_dir, self.milvus_uri[2:])
        # Anything else (absolute path, other schemes) is passed through
        return self.milvus_uri

    # Search Configuration
    top_k_results: int = 10
    similarity_threshold: float = 0.6

    # Application Configuration
    app_host: str = "0.0.0.0"
    app_port: int = 8000
    debug: bool = True
    log_level: str = "INFO"

    # Data Paths (relative to the project root by convention)
    raw_data_path: str = "./data/raw"
    processed_data_path: str = "./data/processed"
    image_data_path: str = "./data/images"

    class Config:
        env_file = ".env"
        env_file_encoding = "utf-8"
        case_sensitive = False
| 76 | + | ||
| 77 | + | ||
# Global settings instance, created at import time.
# NOTE(review): pydantic validates on construction, so importing this module
# should fail fast if required variables (e.g. OPENAI_API_KEY) are absent.
settings = Settings()
| 80 | + | ||
| 81 | + | ||
# Helper for resolving project-relative paths.
def get_absolute_path(relative_path: str) -> str:
    """Resolve *relative_path* against the project root (parent of this file's dir)."""
    project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
    return os.path.join(project_root, relative_path)
| 1 | +++ a/app/services/__init__.py | ||
| @@ -0,0 +1,14 @@ | @@ -0,0 +1,14 @@ | ||
| 1 | +""" | ||
| 2 | +Services Module | ||
| 3 | +Provides database and embedding services for the application | ||
| 4 | +""" | ||
| 5 | + | ||
| 6 | +from app.services.embedding_service import EmbeddingService, get_embedding_service | ||
| 7 | +from app.services.milvus_service import MilvusService, get_milvus_service | ||
| 8 | + | ||
| 9 | +__all__ = [ | ||
| 10 | + "EmbeddingService", | ||
| 11 | + "get_embedding_service", | ||
| 12 | + "MilvusService", | ||
| 13 | + "get_milvus_service", | ||
| 14 | +] |
| 1 | +++ a/app/services/embedding_service.py | ||
| @@ -0,0 +1,293 @@ | @@ -0,0 +1,293 @@ | ||
| 1 | +""" | ||
| 2 | +Embedding Service for Text and Image Embeddings | ||
| 3 | +Supports OpenAI text embeddings and CLIP image embeddings | ||
| 4 | +""" | ||
| 5 | + | ||
| 6 | +import logging | ||
| 7 | +from pathlib import Path | ||
| 8 | +from typing import List, Optional, Union | ||
| 9 | + | ||
| 10 | +import numpy as np | ||
| 11 | +from clip_client import Client as ClipClient | ||
| 12 | +from openai import OpenAI | ||
| 13 | + | ||
| 14 | +from app.config import settings | ||
| 15 | + | ||
| 16 | +logger = logging.getLogger(__name__) | ||
| 17 | + | ||
| 18 | + | ||
class EmbeddingService:
    """Service for generating text and image embeddings.

    Text embeddings come from the OpenAI embeddings API; image embeddings
    come from a CLIP server reached through ``clip_client``. The CLIP
    connection is lazy: call ``connect_clip()`` before any image method.
    """

    def __init__(
        self,
        openai_api_key: Optional[str] = None,
        clip_server_url: Optional[str] = None,
    ):
        """Initialize embedding service.

        Args:
            openai_api_key: OpenAI API key. If None, uses settings.openai_api_key
            clip_server_url: CLIP server URL. If None, uses settings.clip_server_url
        """
        # Initialize OpenAI client for text embeddings
        self.openai_api_key = openai_api_key or settings.openai_api_key
        self.openai_client = OpenAI(api_key=self.openai_api_key)
        self.text_embedding_model = settings.openai_embedding_model

        # CLIP client for image embeddings; None until connect_clip() is called
        self.clip_server_url = clip_server_url or settings.clip_server_url
        self.clip_client: Optional[ClipClient] = None

        logger.info("Embedding service initialized")

    def connect_clip(self) -> None:
        """Connect to the CLIP server.

        Raises:
            Exception: propagated from the underlying client on failure.
        """
        try:
            self.clip_client = ClipClient(server=self.clip_server_url)
            logger.info(f"Connected to CLIP server at {self.clip_server_url}")
        except Exception as e:
            logger.error(f"Failed to connect to CLIP server: {e}")
            raise

    def disconnect_clip(self) -> None:
        """Disconnect from the CLIP server."""
        if self.clip_client:
            # clip_client exposes no explicit close method; dropping the
            # reference is the best we can do.
            self.clip_client = None
            logger.info("Disconnected from CLIP server")

    def get_text_embedding(self, text: str) -> List[float]:
        """Get embedding for a single text.

        Args:
            text: Input text

        Returns:
            Embedding vector as list of floats

        Raises:
            Exception: propagated from the OpenAI API on failure.
        """
        try:
            response = self.openai_client.embeddings.create(
                input=text, model=self.text_embedding_model
            )
            embedding = response.data[0].embedding
            logger.debug(f"Generated text embedding for: {text[:50]}...")
            return embedding
        except Exception as e:
            logger.error(f"Failed to generate text embedding: {e}")
            raise

    def get_text_embeddings_batch(
        self, texts: List[str], batch_size: int = 100
    ) -> List[List[float]]:
        """Get embeddings for multiple texts in batches.

        Args:
            texts: List of input texts
            batch_size: Number of texts to process at once

        Returns:
            List of embedding vectors, in the same order as ``texts``

        Raises:
            Exception: propagated on the first failed batch (no partial result).
        """
        all_embeddings = []

        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]

            try:
                response = self.openai_client.embeddings.create(
                    input=batch, model=self.text_embedding_model
                )

                # Extract embeddings in the correct order
                embeddings = [item.embedding for item in response.data]
                all_embeddings.extend(embeddings)

                logger.info(
                    f"Generated text embeddings for batch {i // batch_size + 1}: {len(embeddings)} embeddings"
                )

            except Exception as e:
                logger.error(
                    f"Failed to generate text embeddings for batch {i // batch_size + 1}: {e}"
                )
                raise

        return all_embeddings

    def get_image_embedding(self, image_path: Union[str, Path]) -> List[float]:
        """Get CLIP embedding for a single image.

        Args:
            image_path: Path to image file

        Returns:
            Embedding vector as list of floats

        Raises:
            RuntimeError: if connect_clip() has not been called.
            FileNotFoundError: if the image does not exist.
        """
        if not self.clip_client:
            raise RuntimeError("CLIP client not connected. Call connect_clip() first.")

        image_path = Path(image_path)
        if not image_path.exists():
            raise FileNotFoundError(f"Image not found: {image_path}")

        try:
            # Get embedding from CLIP server using image path (as string)
            result = self.clip_client.encode([str(image_path)])

            # Result may be a numpy array or a DocumentArray depending on
            # the client/server version (np is imported at module level).
            if isinstance(result, np.ndarray):
                # If result is numpy array, use first row (2-D) or the whole
                # vector (1-D)
                embedding = (
                    result[0].tolist() if len(result.shape) > 1 else result.tolist()
                )
            else:
                # If result is DocumentArray
                embedding = result[0].embedding.tolist()

            logger.debug(f"Generated image embedding for: {image_path.name}")
            return embedding

        except Exception as e:
            logger.error(f"Failed to generate image embedding for {image_path}: {e}")
            raise

    def get_image_embeddings_batch(
        self, image_paths: List[Union[str, Path]], batch_size: int = 32
    ) -> List[Optional[List[float]]]:
        """Get CLIP embeddings for multiple images in batches.

        Args:
            image_paths: List of paths to image files
            batch_size: Number of images to process at once

        Returns:
            List of embedding vectors, positionally aligned with
            ``image_paths`` (None for missing or failed images)

        Raises:
            RuntimeError: if connect_clip() has not been called.
        """
        if not self.clip_client:
            raise RuntimeError("CLIP client not connected. Call connect_clip() first.")

        all_embeddings = []

        for i in range(0, len(image_paths), batch_size):
            batch_paths = image_paths[i : i + batch_size]
            valid_paths = []
            valid_indices = []

            # Check which images exist; missing ones keep a None slot
            for idx, path in enumerate(batch_paths):
                path = Path(path)
                if path.exists():
                    valid_paths.append(str(path))
                    valid_indices.append(idx)
                else:
                    logger.warning(f"Image not found: {path}")

            # Get embeddings for valid images
            if valid_paths:
                try:
                    # Send paths as strings to CLIP server
                    result = self.clip_client.encode(valid_paths)

                    # Start with None for every slot; fill in successes below
                    batch_embeddings = [None] * len(batch_paths)

                    # Result may be numpy array or DocumentArray (see above)
                    if isinstance(result, np.ndarray):
                        # Shape (n_images, embedding_dim)
                        for idx in range(len(result)):
                            original_idx = valid_indices[idx]
                            batch_embeddings[original_idx] = result[idx].tolist()
                    else:
                        for idx, doc in enumerate(result):
                            original_idx = valid_indices[idx]
                            batch_embeddings[original_idx] = doc.embedding.tolist()

                    all_embeddings.extend(batch_embeddings)

                    logger.info(
                        f"Generated image embeddings for batch {i // batch_size + 1}: "
                        f"{len(valid_paths)}/{len(batch_paths)} successful"
                    )

                except Exception as e:
                    logger.error(
                        f"Failed to generate image embeddings for batch {i // batch_size + 1}: {e}"
                    )
                    # Best-effort batching: a failed batch yields None slots
                    # rather than aborting the whole run
                    all_embeddings.extend([None] * len(batch_paths))
            else:
                # All images in batch were missing
                all_embeddings.extend([None] * len(batch_paths))

        return all_embeddings

    def get_text_embedding_from_image(
        self, image_path: Union[str, Path]
    ) -> List[float]:
        """Get text-based embedding by describing the image.

        Placeholder for future cross-modal search support: a vision-language
        model would caption the image and the caption would be embedded.

        Args:
            image_path: Path to image file

        Returns:
            Text embedding vector (not yet implemented)

        Raises:
            NotImplementedError: always, until implemented.
        """
        raise NotImplementedError("Text embedding from image not yet implemented")

    def cosine_similarity(
        self, embedding1: List[float], embedding2: List[float]
    ) -> float:
        """Calculate cosine similarity between two embeddings.

        Args:
            embedding1: First embedding vector
            embedding2: Second embedding vector

        Returns:
            Cosine similarity in [-1, 1]; 0.0 if either vector has zero norm
            (previously this divided by zero and produced NaN).
        """
        vec1 = np.array(embedding1)
        vec2 = np.array(embedding2)

        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)
        # A zero vector has no direction: define the similarity as 0.0
        # instead of letting the division below yield NaN.
        if norm1 == 0.0 or norm2 == 0.0:
            return 0.0

        similarity = np.dot(vec1 / norm1, vec2 / norm2)
        return float(similarity)

    def get_embedding_dimensions(self) -> dict:
        """Get the dimensions of text and image embeddings.

        Returns:
            Dictionary with text_dim and image_dim
        """
        return {"text_dim": settings.text_dim, "image_dim": settings.image_dim}
| 281 | + | ||
| 282 | + | ||
| 283 | +# Global instance | ||
| 284 | +_embedding_service: Optional[EmbeddingService] = None | ||
| 285 | + | ||
| 286 | + | ||
def get_embedding_service() -> EmbeddingService:
    """Return the process-wide EmbeddingService, creating and connecting it lazily."""
    global _embedding_service
    if _embedding_service is None:
        # Build and wire up CLIP before publishing the instance.
        service = EmbeddingService()
        service.connect_clip()
        _embedding_service = service
    return _embedding_service
| 1 | +++ a/app/services/milvus_service.py | ||
| @@ -0,0 +1,480 @@ | @@ -0,0 +1,480 @@ | ||
| 1 | +""" | ||
| 2 | +Milvus Service for Vector Storage and Similarity Search | ||
| 3 | +Manages text and image embeddings in separate collections | ||
| 4 | +""" | ||
| 5 | + | ||
| 6 | +import logging | ||
| 7 | +from typing import Any, Dict, List, Optional | ||
| 8 | + | ||
| 9 | +from pymilvus import ( | ||
| 10 | + DataType, | ||
| 11 | + MilvusClient, | ||
| 12 | +) | ||
| 13 | + | ||
| 14 | +from app.config import settings | ||
| 15 | + | ||
| 16 | +logger = logging.getLogger(__name__) | ||
| 17 | + | ||
| 18 | + | ||
class MilvusService:
    """Service for managing vector embeddings in Milvus.

    Maintains two collections — one for text embeddings and one for image
    embeddings — that share an identical set of product-metadata columns.
    The text/image method pairs delegate to private helpers so the schema
    and search logic exist in exactly one place.
    """

    # (field_name, max_length) VARCHAR metadata columns shared by both
    # collections.
    _METADATA_FIELDS = (
        ("productDisplayName", 500),
        ("gender", 50),
        ("masterCategory", 100),
        ("subCategory", 100),
        ("articleType", 100),
        ("baseColour", 50),
        ("season", 50),
        ("usage", 50),
    )

    # Metadata fields returned by default from similarity searches.
    _DEFAULT_SEARCH_FIELDS = (
        "productDisplayName",
        "gender",
        "masterCategory",
        "subCategory",
        "articleType",
        "baseColour",
    )

    def __init__(self, uri: Optional[str] = None):
        """Initialize Milvus service.

        Args:
            uri: Milvus connection URI. If None, uses
                settings.milvus_uri_absolute (absolute path for Milvus Lite,
                so resolution does not depend on the working directory).
        """
        self.uri = uri if uri else settings.milvus_uri_absolute
        self.text_collection_name = settings.text_collection_name
        self.image_collection_name = settings.image_collection_name
        self.text_dim = settings.text_dim
        self.image_dim = settings.image_dim

        # Created lazily by connect(); all operations go through MilvusClient.
        self._client: Optional[MilvusClient] = None

        logger.info(f"Initializing Milvus service with URI: {self.uri}")

    def is_connected(self) -> bool:
        """Check if connected to Milvus."""
        return self._client is not None

    def connect(self) -> None:
        """Connect to Milvus (no-op if already connected)."""
        if self.is_connected():
            return
        try:
            self._client = MilvusClient(uri=self.uri)
            logger.info(f"Connected to Milvus at {self.uri}")
        except Exception as e:
            logger.error(f"Failed to connect to Milvus: {e}")
            raise

    def disconnect(self) -> None:
        """Disconnect from Milvus and release the client."""
        if self._client:
            self._client.close()
            self._client = None
            logger.info("Disconnected from Milvus")

    @property
    def client(self) -> MilvusClient:
        """Get the Milvus client, raising if connect() has not been called."""
        if not self._client:
            raise RuntimeError("Milvus not connected. Call connect() first.")
        return self._client

    def _create_collection(
        self,
        collection_name: str,
        content_field: str,
        content_max_length: int,
        dim: int,
        recreate: bool,
        label: str,
    ) -> None:
        """Create a collection with the shared schema layout.

        Args:
            collection_name: Name of the collection to create
            content_field: Per-collection payload field ("text" or "image_path")
            content_max_length: VARCHAR length for the payload field
            dim: Dimensionality of the embedding vector field
            recreate: If True, drop an existing collection first
            label: Collection kind ("Text"/"Image") used in log messages
        """
        if recreate and self.client.has_collection(collection_name):
            self.client.drop_collection(collection_name)
            logger.info(f"Dropped existing collection: {collection_name}")

        if self.client.has_collection(collection_name):
            logger.info(f"{label} collection already exists: {collection_name}")
            return

        # enable_dynamic_field lets callers attach metadata keys beyond the
        # explicitly declared schema fields.
        schema = MilvusClient.create_schema(
            auto_id=False,
            enable_dynamic_field=True,
        )

        # Core fields: product-id primary key, payload, embedding vector.
        schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
        schema.add_field(
            field_name=content_field,
            datatype=DataType.VARCHAR,
            max_length=content_max_length,
        )
        schema.add_field(
            field_name="embedding", datatype=DataType.FLOAT_VECTOR, dim=dim
        )

        # Shared product-metadata columns.
        for field_name, max_length in self._METADATA_FIELDS:
            schema.add_field(
                field_name=field_name,
                datatype=DataType.VARCHAR,
                max_length=max_length,
            )

        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="embedding",
            index_type="AUTOINDEX",
            metric_type="COSINE",
        )

        self.client.create_collection(
            collection_name=collection_name,
            schema=schema,
            index_params=index_params,
        )

        logger.info(
            f"Created {label.lower()} collection with metadata: {collection_name}"
        )

    def create_text_collection(self, recreate: bool = False) -> None:
        """Create collection for text embeddings with product metadata.

        Args:
            recreate: If True, drop existing collection and recreate
        """
        self._create_collection(
            collection_name=self.text_collection_name,
            content_field="text",
            content_max_length=2000,
            dim=self.text_dim,
            recreate=recreate,
            label="Text",
        )

    def create_image_collection(self, recreate: bool = False) -> None:
        """Create collection for image embeddings with product metadata.

        Args:
            recreate: If True, drop existing collection and recreate
        """
        self._create_collection(
            collection_name=self.image_collection_name,
            content_field="image_path",
            content_max_length=500,
            dim=self.image_dim,
            recreate=recreate,
            label="Image",
        )

    def _insert(
        self,
        collection_name: str,
        embeddings: List[Dict[str, Any]],
        label: str,
    ) -> int:
        """Insert embedding rows into a collection.

        Rows are passed through as-is: Milvus accepts every field declared
        in the schema plus dynamic metadata fields.

        Args:
            collection_name: Target collection
            embeddings: Row dictionaries (id, payload, embedding, metadata)
            label: Kind of embedding ("text"/"image") used in log messages

        Returns:
            Number of inserted rows (0 for an empty input)
        """
        if not embeddings:
            return 0

        try:
            self.client.insert(
                collection_name=collection_name,
                data=embeddings,
            )
            logger.info(f"Inserted {len(embeddings)} {label} embeddings")
            return len(embeddings)
        except Exception as e:
            logger.error(f"Failed to insert {label} embeddings: {e}")
            raise

    def insert_text_embeddings(
        self,
        embeddings: List[Dict[str, Any]],
    ) -> int:
        """Insert text embeddings with metadata into collection.

        Args:
            embeddings: List of dictionaries with keys:
                - id: unique ID (product ID)
                - text: the text that was embedded
                - embedding: the embedding vector
                - productDisplayName, gender, masterCategory, etc. (metadata)

        Returns:
            Number of inserted embeddings
        """
        return self._insert(self.text_collection_name, embeddings, "text")

    def insert_image_embeddings(
        self,
        embeddings: List[Dict[str, Any]],
    ) -> int:
        """Insert image embeddings with metadata into collection.

        Args:
            embeddings: List of dictionaries with keys:
                - id: unique ID (product ID)
                - image_path: path to the image file
                - embedding: the embedding vector
                - productDisplayName, gender, masterCategory, etc. (metadata)

        Returns:
            Number of inserted embeddings
        """
        return self._insert(self.image_collection_name, embeddings, "image")

    def _search(
        self,
        collection_name: str,
        query_embedding: List[float],
        limit: int,
        filters: Optional[str],
        output_fields: List[str],
        label: str,
    ) -> List[Dict[str, Any]]:
        """Run a vector similarity search and flatten the hits.

        Args:
            collection_name: Collection to search
            query_embedding: Query embedding vector
            limit: Maximum number of results
            filters: Optional boolean filter expression
            output_fields: Entity fields to copy into each result
            label: Kind of embedding ("text"/"image") used in log messages

        Returns:
            One dict per hit containing "id", "distance", and any of
            output_fields present on the matched entity.
        """
        try:
            # FIX: MilvusClient.search takes the boolean expression via its
            # `filter` argument; the previous code put it in
            # search_params["expr"], where it was silently ignored.
            results = self.client.search(
                collection_name=collection_name,
                data=[query_embedding],
                filter=filters or "",
                limit=limit,
                output_fields=output_fields,
            )

            formatted_results: List[Dict[str, Any]] = []
            if results:
                for hit in results[0]:
                    row = {"id": hit.get("id"), "distance": hit.get("distance")}
                    entity = hit.get("entity", {})
                    for field in output_fields:
                        if field in entity:
                            row[field] = entity.get(field)
                    formatted_results.append(row)

            logger.debug(
                f"Found {len(formatted_results)} similar {label} embeddings"
            )
            return formatted_results

        except Exception as e:
            logger.error(f"Failed to search similar {label}: {e}")
            raise

    def search_similar_text(
        self,
        query_embedding: List[float],
        limit: int = 10,
        filters: Optional[str] = None,
        output_fields: Optional[List[str]] = None,
    ) -> List[Dict[str, Any]]:
        """Search for similar text embeddings.

        Args:
            query_embedding: Query embedding vector
            limit: Maximum number of results
            filters: Filter expression (e.g., "product_id in [1, 2, 3]")
            output_fields: List of fields to return

        Returns:
            List of flattened search results with "id", "distance", and the
            requested entity fields.
        """
        if output_fields is None:
            output_fields = ["id", "text", *self._DEFAULT_SEARCH_FIELDS]
        return self._search(
            self.text_collection_name,
            query_embedding,
            limit,
            filters,
            output_fields,
            "text",
        )

    def search_similar_images(
        self,
        query_embedding: List[float],
        limit: int = 10,
        filters: Optional[str] = None,
        output_fields: Optional[List[str]] = None,
    ) -> List[Dict[str, Any]]:
        """Search for similar image embeddings.

        Args:
            query_embedding: Query embedding vector
            limit: Maximum number of results
            filters: Filter expression (e.g., "product_id in [1, 2, 3]")
            output_fields: List of fields to return

        Returns:
            List of flattened search results with "id", "distance", and the
            requested entity fields.
        """
        if output_fields is None:
            output_fields = ["id", "image_path", *self._DEFAULT_SEARCH_FIELDS]
        return self._search(
            self.image_collection_name,
            query_embedding,
            limit,
            filters,
            output_fields,
            "image",
        )

    def get_collection_stats(self, collection_name: str) -> Dict[str, Any]:
        """Get statistics for a collection.

        Args:
            collection_name: Name of the collection

        Returns:
            Dictionary with the collection name and row count; falls back to
            a zero row count on error (best-effort, never raises).
        """
        try:
            stats = self.client.get_collection_stats(collection_name)
            return {
                "collection_name": collection_name,
                "row_count": stats.get("row_count", 0),
            }
        except Exception as e:
            logger.error(f"Failed to get collection stats: {e}")
            return {"collection_name": collection_name, "row_count": 0}

    def delete_by_ids(self, collection_name: str, ids: List[int]) -> int:
        """Delete embeddings by IDs.

        Args:
            collection_name: Name of the collection
            ids: List of IDs to delete

        Returns:
            Number of deleted embeddings (0 for an empty input)
        """
        if not ids:
            return 0

        try:
            self.client.delete(
                collection_name=collection_name,
                ids=ids,
            )
            logger.info(f"Deleted {len(ids)} embeddings from {collection_name}")
            return len(ids)
        except Exception as e:
            logger.error(f"Failed to delete embeddings: {e}")
            raise

    def clear_collection(self, collection_name: str) -> None:
        """Clear all data from a collection by dropping it entirely.

        Args:
            collection_name: Name of the collection
        """
        try:
            if self.client.has_collection(collection_name):
                self.client.drop_collection(collection_name)
                logger.info(f"Dropped collection: {collection_name}")
        except Exception as e:
            logger.error(f"Failed to clear collection: {e}")
            raise
| 468 | + | ||
| 469 | + | ||
| 470 | +# Global instance | ||
| 471 | +_milvus_service: Optional[MilvusService] = None | ||
| 472 | + | ||
| 473 | + | ||
def get_milvus_service() -> MilvusService:
    """Return the process-wide MilvusService, creating and connecting it lazily."""
    global _milvus_service
    if _milvus_service is None:
        # Build and connect before publishing the instance.
        service = MilvusService()
        service.connect()
        _milvus_service = service
    return _milvus_service
| 1 | +++ a/app/tools/__init__.py | ||
| @@ -0,0 +1,17 @@ | @@ -0,0 +1,17 @@ | ||
| 1 | +""" | ||
| 2 | +LangChain Tools for Product Search and Discovery | ||
| 3 | +""" | ||
| 4 | + | ||
| 5 | +from app.tools.search_tools import ( | ||
| 6 | + analyze_image_style, | ||
| 7 | + get_all_tools, | ||
| 8 | + search_by_image, | ||
| 9 | + search_products, | ||
| 10 | +) | ||
| 11 | + | ||
| 12 | +__all__ = [ | ||
| 13 | + "search_products", | ||
| 14 | + "search_by_image", | ||
| 15 | + "analyze_image_style", | ||
| 16 | + "get_all_tools", | ||
| 17 | +] |
| 1 | +++ a/app/tools/search_tools.py | ||
| @@ -0,0 +1,294 @@ | @@ -0,0 +1,294 @@ | ||
| 1 | +""" | ||
| 2 | +Search Tools for Product Discovery | ||
| 3 | +Provides text-based, image-based, and VLM reasoning capabilities | ||
| 4 | +""" | ||
| 5 | + | ||
| 6 | +import base64 | ||
| 7 | +import logging | ||
| 8 | +from pathlib import Path | ||
| 9 | +from typing import Optional | ||
| 10 | + | ||
| 11 | +from langchain_core.tools import tool | ||
| 12 | +from openai import OpenAI | ||
| 13 | + | ||
| 14 | +from app.config import settings | ||
| 15 | +from app.services.embedding_service import EmbeddingService | ||
| 16 | +from app.services.milvus_service import MilvusService | ||
| 17 | + | ||
| 18 | +logger = logging.getLogger(__name__) | ||
| 19 | + | ||
| 20 | +# Initialize services as singletons | ||
| 21 | +_embedding_service: Optional[EmbeddingService] = None | ||
| 22 | +_milvus_service: Optional[MilvusService] = None | ||
| 23 | +_openai_client: Optional[OpenAI] = None | ||
| 24 | + | ||
| 25 | + | ||
def get_embedding_service() -> EmbeddingService:
    """Return the lazily-created module-level EmbeddingService singleton."""
    global _embedding_service
    service = _embedding_service
    if service is None:
        service = EmbeddingService()
        _embedding_service = service
    return service
| 31 | + | ||
| 32 | + | ||
def get_milvus_service() -> MilvusService:
    """Return the lazily-created module-level MilvusService singleton, connected."""
    global _milvus_service
    service = _milvus_service
    if service is None:
        service = MilvusService()
        service.connect()
        _milvus_service = service
    return service
| 39 | + | ||
| 40 | + | ||
def get_openai_client() -> OpenAI:
    """Return the lazily-created module-level OpenAI client singleton."""
    global _openai_client
    client = _openai_client
    if client is None:
        client = OpenAI(api_key=settings.openai_api_key)
        _openai_client = client
    return client
| 46 | + | ||
| 47 | + | ||
@tool
def search_products(query: str, limit: int = 5) -> str:
    """Search for fashion products using natural language descriptions.

    Use when users describe what they want:
    - "Find me red summer dresses"
    - "Show me blue running shoes"
    - "I want casual shirts for men"

    Args:
        query: Natural language product description
        limit: Maximum number of results (1-20)

    Returns:
        Formatted string with product information
    """
    try:
        logger.info(f"Searching products: '{query}', limit: {limit}")

        embedding_service = get_embedding_service()
        milvus_service = get_milvus_service()

        if not milvus_service.is_connected():
            milvus_service.connect()

        query_embedding = embedding_service.get_text_embedding(query)

        results = milvus_service.search_similar_text(
            query_embedding=query_embedding,
            # FIX: clamp to the documented 1-20 range; previously only the
            # upper bound was enforced, letting 0/negative limits through.
            limit=max(1, min(limit, 20)),
            filters=None,
            output_fields=[
                "id",
                "productDisplayName",
                "gender",
                "masterCategory",
                "subCategory",
                "articleType",
                "baseColour",
                "season",
                "usage",
            ],
        )

        if not results:
            return "No products found matching your search."

        output = f"Found {len(results)} product(s):\n\n"

        for idx, product in enumerate(results, 1):
            output += f"{idx}. {product.get('productDisplayName', 'Unknown Product')}\n"
            output += f" ID: {product.get('id', 'N/A')}\n"
            output += f" Category: {product.get('masterCategory', 'N/A')} > {product.get('subCategory', 'N/A')} > {product.get('articleType', 'N/A')}\n"
            output += f" Color: {product.get('baseColour', 'N/A')}\n"
            output += f" Gender: {product.get('gender', 'N/A')}\n"

            # Optional attributes are only shown when present and non-empty.
            if product.get("season"):
                output += f" Season: {product.get('season')}\n"
            if product.get("usage"):
                output += f" Usage: {product.get('usage')}\n"

            if "distance" in product:
                # NOTE(review): with the COSINE metric Milvus reports a
                # similarity-style score — confirm that 1 - distance is the
                # intended relevance display.
                similarity = 1 - product["distance"]
                output += f" Relevance: {similarity:.2%}\n"

            output += "\n"

        return output.strip()

    except Exception as e:
        logger.error(f"Error searching products: {e}", exc_info=True)
        return f"Error searching products: {str(e)}"
| 120 | + | ||
| 121 | + | ||
@tool
def search_by_image(image_path: str, limit: int = 5) -> str:
    """Find similar fashion products using an image.

    Use when users want visually similar items:
    - User uploads an image and asks "find similar items"
    - "Show me products that look like this"

    Args:
        image_path: Path to the image file
        limit: Maximum number of results (1-20)

    Returns:
        Formatted string with similar products
    """
    try:
        logger.info(f"Image search: '{image_path}', limit: {limit}")

        img_path = Path(image_path)
        if not img_path.exists():
            return f"Error: Image file not found at '{image_path}'"

        embedding_service = get_embedding_service()
        milvus_service = get_milvus_service()

        if not milvus_service.is_connected():
            milvus_service.connect()

        # CLIP is only needed for image embeddings; connect lazily here.
        if (
            not hasattr(embedding_service, "clip_client")
            or embedding_service.clip_client is None
        ):
            embedding_service.connect_clip()

        image_embedding = embedding_service.get_image_embedding(image_path)

        if image_embedding is None:
            return "Error: Failed to generate embedding for image"

        # FIX: clamp to the documented 1-20 range (previously 0/negative
        # limits slipped through). Over-fetch one extra hit so the query
        # image itself can be dropped without shrinking the result set.
        requested = max(1, min(limit, 20))

        results = milvus_service.search_similar_images(
            query_embedding=image_embedding,
            limit=requested + 1,
            filters=None,
            output_fields=[
                "id",
                "image_path",
                "productDisplayName",
                "gender",
                "masterCategory",
                "subCategory",
                "articleType",
                "baseColour",
                "season",
                "usage",
            ],
        )

        if not results:
            return "No similar products found."

        # Filter out the query image itself (matched by filename stem).
        query_id = img_path.stem
        filtered_results = []
        for result in results:
            result_path = result.get("image_path", "")
            if Path(result_path).stem != query_id:
                filtered_results.append(result)
            if len(filtered_results) >= requested:
                break

        if not filtered_results:
            return "No similar products found."

        output = f"Found {len(filtered_results)} visually similar product(s):\n\n"

        for idx, product in enumerate(filtered_results, 1):
            output += f"{idx}. {product.get('productDisplayName', 'Unknown Product')}\n"
            output += f" ID: {product.get('id', 'N/A')}\n"
            output += f" Category: {product.get('masterCategory', 'N/A')} > {product.get('subCategory', 'N/A')} > {product.get('articleType', 'N/A')}\n"
            output += f" Color: {product.get('baseColour', 'N/A')}\n"
            output += f" Gender: {product.get('gender', 'N/A')}\n"

            if product.get("season"):
                output += f" Season: {product.get('season')}\n"
            if product.get("usage"):
                output += f" Usage: {product.get('usage')}\n"

            if "distance" in product:
                # NOTE(review): with the COSINE metric Milvus reports a
                # similarity-style score — confirm 1 - distance is intended.
                similarity = 1 - product["distance"]
                output += f" Visual Similarity: {similarity:.2%}\n"

            output += "\n"

        return output.strip()

    except Exception as e:
        logger.error(f"Error in image search: {e}", exc_info=True)
        return f"Error searching by image: {str(e)}"
| 220 | + | ||
| 221 | + | ||
@tool
def analyze_image_style(image_path: str) -> str:
    """Analyze a fashion product image using AI vision to extract detailed style information.

    Use when you need to understand style/attributes from an image:
    - Understand the style, color, pattern of a product
    - Extract attributes like "casual", "formal", "vintage"
    - Get detailed descriptions for subsequent searches

    Args:
        image_path: Path to the image file

    Returns:
        Detailed text description of the product's visual attributes
    """
    try:
        logger.info(f"Analyzing image with VLM: '{image_path}'")

        img_path = Path(image_path)
        if not img_path.exists():
            return f"Error: Image file not found at '{image_path}'"

        # Inline the image as base64 so no file upload is required.
        with open(img_path, "rb") as image_file:
            image_data = base64.b64encode(image_file.read()).decode("utf-8")

        prompt = """Analyze this fashion product image and provide a detailed description.

Include:
- Product type (e.g., shirt, dress, shoes, pants, bag)
- Primary colors
- Style/design (e.g., casual, formal, sporty, vintage, modern)
- Pattern or texture (e.g., plain, striped, checked, floral)
- Key features (e.g., collar type, sleeve length, fit)
- Material appearance (if obvious, e.g., denim, cotton, leather)
- Suitable occasion (e.g., office wear, party, casual, sports)

Provide a comprehensive yet concise description (3-4 sentences)."""

        client = get_openai_client()
        response = client.chat.completions.create(
            # FIX: use the configured model (OPENAI_MODEL) instead of the
            # hard-coded "gpt-4o-mini", consistent with the rest of the app
            # which reads its OpenAI configuration from settings.
            model=settings.openai_model,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            "image_url": {
                                # NOTE(review): assumes JPEG input; switch to
                                # mimetypes.guess_type if PNGs are expected.
                                "url": f"data:image/jpeg;base64,{image_data}",
                                "detail": "high",
                            },
                        },
                    ],
                }
            ],
            max_tokens=500,
            temperature=0.3,
        )

        analysis = response.choices[0].message.content.strip()
        logger.info("VLM analysis completed")

        return analysis

    except Exception as e:
        logger.error(f"Error analyzing image: {e}", exc_info=True)
        return f"Error analyzing image: {str(e)}"
| 290 | + | ||
| 291 | + | ||
def get_all_tools():
    """Return the list of tool callables exposed to the shopping agent.

    Currently: search_products (text search), search_by_image (image
    similarity search) and analyze_image_style (VLM style analysis).
    """
    return [search_products, search_by_image, analyze_image_style]
No preview for this file type
| 1 | +++ a/docker-compose.yml | ||
| @@ -0,0 +1,76 @@ | @@ -0,0 +1,76 @@ | ||
| 1 | +version: '3.5' | ||
| 2 | + | ||
| 3 | +services: | ||
| 4 | + etcd: | ||
| 5 | + container_name: milvus-etcd | ||
| 6 | + image: quay.io/coreos/etcd:v3.5.5 | ||
| 7 | + environment: | ||
| 8 | + - ETCD_AUTO_COMPACTION_MODE=revision | ||
| 9 | + - ETCD_AUTO_COMPACTION_RETENTION=1000 | ||
| 10 | + - ETCD_QUOTA_BACKEND_BYTES=4294967296 | ||
| 11 | + - ETCD_SNAPSHOT_COUNT=50000 | ||
| 12 | + volumes: | ||
| 13 | + - ./volumes/etcd:/etcd | ||
| 14 | + command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd | ||
| 15 | + healthcheck: | ||
| 16 | + test: ["CMD", "etcdctl", "endpoint", "health"] | ||
| 17 | + interval: 30s | ||
| 18 | + timeout: 20s | ||
| 19 | + retries: 3 | ||
| 20 | + | ||
| 21 | + minio: | ||
| 22 | + container_name: milvus-minio | ||
| 23 | + image: minio/minio:RELEASE.2023-03-20T20-16-18Z | ||
| 24 | + environment: | ||
| 25 | + MINIO_ACCESS_KEY: minioadmin | ||
| 26 | + MINIO_SECRET_KEY: minioadmin | ||
| 27 | + ports: | ||
| 28 | + - "9001:9001" | ||
| 29 | + - "9000:9000" | ||
| 30 | + volumes: | ||
| 31 | + - ./volumes/minio:/minio_data | ||
| 32 | + command: minio server /minio_data --console-address ":9001" | ||
| 33 | + healthcheck: | ||
| 34 | + test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"] | ||
| 35 | + interval: 30s | ||
| 36 | + timeout: 20s | ||
| 37 | + retries: 3 | ||
| 38 | + | ||
| 39 | + standalone: | ||
| 40 | + container_name: milvus-standalone | ||
| 41 | + image: milvusdb/milvus:v2.4.0 | ||
| 42 | + command: ["milvus", "run", "standalone"] | ||
| 43 | + security_opt: | ||
| 44 | + - seccomp:unconfined | ||
| 45 | + environment: | ||
| 46 | + ETCD_ENDPOINTS: etcd:2379 | ||
| 47 | + MINIO_ADDRESS: minio:9000 | ||
| 48 | + volumes: | ||
| 49 | + - ./volumes/milvus:/var/lib/milvus | ||
| 50 | + healthcheck: | ||
| 51 | + test: ["CMD", "curl", "-f", "http://localhost:9091/healthz"] | ||
| 52 | + interval: 30s | ||
| 53 | + start_period: 90s | ||
| 54 | + timeout: 20s | ||
| 55 | + retries: 3 | ||
| 56 | + ports: | ||
| 57 | + - "19530:19530" | ||
| 58 | + - "9091:9091" | ||
| 59 | + depends_on: | ||
| 60 | + - "etcd" | ||
| 61 | + - "minio" | ||
| 62 | + | ||
| 63 | + attu: | ||
| 64 | + container_name: milvus-attu | ||
| 65 | + image: zilliz/attu:v2.4 | ||
| 66 | + environment: | ||
| 67 | + MILVUS_URL: milvus-standalone:19530 | ||
| 68 | + ports: | ||
| 69 | + - "8000:3000" | ||
| 70 | + depends_on: | ||
| 71 | + - "standalone" | ||
| 72 | + | ||
| 73 | +networks: | ||
| 74 | + default: | ||
| 75 | + name: milvus | ||
| 76 | + |
| 1 | +++ a/docs/DEPLOY_CENTOS8.md | ||
| @@ -0,0 +1,216 @@ | @@ -0,0 +1,216 @@ | ||
| 1 | +# OmniShopAgent CentOS 8 部署指南 | ||
| 2 | + | ||
| 3 | +## 一、环境要求 | ||
| 4 | + | ||
| 5 | +| 组件 | 要求 | | ||
| 6 | +|------|------| | ||
| 7 | +| 操作系统 | CentOS 8.x | | ||
| 8 | +| Python | 3.12+(LangChain 1.x 要求 3.10+) | | ||
| 9 | +| 内存 | 建议 8GB+(Milvus + CLIP 较占内存) | | ||
| 10 | +| 磁盘 | 建议 20GB+(含数据集) | | ||
| 11 | + | ||
| 12 | +## 二、快速部署步骤 | ||
| 13 | + | ||
| 14 | +### 2.1 一键环境准备(推荐) | ||
| 15 | + | ||
| 16 | +```bash | ||
| 17 | +cd /path/to/shop_agent | ||
| 18 | +chmod +x scripts/*.sh | ||
| 19 | +./scripts/setup_env_centos8.sh | ||
| 20 | +``` | ||
| 21 | + | ||
| 22 | +该脚本会: | ||
| 23 | +- 安装系统依赖(gcc、openssl-devel 等) | ||
| 24 | +- 安装 Docker(用于 Milvus) | ||
| 25 | +- 安装 Python 3.12(conda 或源码编译) | ||
| 26 | +- 创建虚拟环境并安装 requirements.txt | ||
| 27 | + | ||
| 28 | +### 2.2 手动部署(分步执行) | ||
| 29 | + | ||
| 30 | +#### 步骤 1:安装系统依赖 | ||
| 31 | + | ||
| 32 | +```bash | ||
| 33 | +sudo dnf install -y gcc gcc-c++ make openssl-devel bzip2-devel \ | ||
| 34 | + libffi-devel sqlite-devel xz-devel zlib-devel curl wget git | ||
| 35 | +``` | ||
| 36 | + | ||
| 37 | +#### 步骤 2:安装 Python 3.12 | ||
| 38 | + | ||
| 39 | +**方式 A:Miniconda(推荐)** | ||
| 40 | + | ||
| 41 | +```bash | ||
| 42 | +wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh | ||
| 43 | +bash Miniconda3-latest-Linux-x86_64.sh | ||
| 44 | +# 按提示安装后 | ||
| 45 | +conda create -n shop_agent python=3.12 | ||
| 46 | +conda activate shop_agent | ||
| 47 | +``` | ||
| 48 | + | ||
| 49 | +**方式 B:从源码编译** | ||
| 50 | + | ||
| 51 | +```bash | ||
| 52 | +sudo dnf groupinstall -y 'Development Tools' | ||
| 53 | +cd /tmp | ||
| 54 | +wget https://www.python.org/ftp/python/3.12.0/Python-3.12.0.tgz | ||
| 55 | +tar xzf Python-3.12.0.tgz | ||
| 56 | +cd Python-3.12.0 | ||
| 57 | +./configure --enable-optimizations --prefix=/usr/local | ||
| 58 | +make -j $(nproc) | ||
| 59 | +sudo make altinstall | ||
| 60 | +``` | ||
| 61 | + | ||
| 62 | +#### 步骤 3:安装 Docker | ||
| 63 | + | ||
| 64 | +```bash | ||
| 65 | +sudo dnf config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.repo | ||
| 66 | +sudo dnf install -y docker-ce docker-ce-cli containerd.io docker-compose-plugin | ||
| 67 | +sudo systemctl enable docker && sudo systemctl start docker | ||
| 68 | +sudo usermod -aG docker $USER | ||
| 69 | +# 执行 newgrp docker 或重新登录 | ||
| 70 | +``` | ||
| 71 | + | ||
| 72 | +#### 步骤 4:创建虚拟环境并安装依赖 | ||
| 73 | + | ||
| 74 | +```bash | ||
| 75 | +cd /path/to/shop_agent | ||
| 76 | +python3.12 -m venv venv | ||
| 77 | +source venv/bin/activate | ||
| 78 | +pip install -U pip | ||
| 79 | +pip install -r requirements.txt | ||
| 80 | +``` | ||
| 81 | + | ||
| 82 | +#### 步骤 5:配置环境变量 | ||
| 83 | + | ||
| 84 | +```bash | ||
| 85 | +cp .env.example .env | ||
| 86 | +# 编辑 .env,至少配置: | ||
| 87 | +# OPENAI_API_KEY=sk-xxx | ||
| 88 | +# MILVUS_HOST=localhost | ||
| 89 | +# MILVUS_PORT=19530 | ||
| 90 | +# CLIP_SERVER_URL=grpc://localhost:51000 | ||
| 91 | +``` | ||
| 92 | + | ||
| 93 | +## 三、数据准备 | ||
| 94 | + | ||
| 95 | +### 3.1 下载数据集 | ||
| 96 | + | ||
| 97 | +```bash | ||
| 98 | +# 需先配置 Kaggle API:~/.kaggle/kaggle.json | ||
| 99 | +python scripts/download_dataset.py | ||
| 100 | +``` | ||
| 101 | + | ||
| 102 | +### 3.2 启动 Milvus 并索引数据 | ||
| 103 | + | ||
| 104 | +```bash | ||
| 105 | +# 启动 Milvus | ||
| 106 | +./scripts/run_milvus.sh | ||
| 107 | + | ||
| 108 | +# 等待就绪后,创建索引 | ||
| 109 | +python scripts/index_data.py | ||
| 110 | +``` | ||
| 111 | + | ||
| 112 | +## 四、启动服务 | ||
| 113 | + | ||
| 114 | +### 4.1 启动脚本说明 | ||
| 115 | + | ||
| 116 | +| 脚本 | 用途 | | ||
| 117 | +|------|------| | ||
| 118 | +| `start.sh` | 主启动脚本:启动 Milvus + Streamlit | | ||
| 119 | +| `stop.sh` | 停止所有服务 | | ||
| 120 | +| `run_milvus.sh` | 仅启动 Milvus | | ||
| 121 | +| `run_clip.sh` | 仅启动 CLIP(图像搜索需此服务) | | ||
| 122 | +| `check_services.sh` | 健康检查 | | ||
| 123 | + | ||
| 124 | +### 4.2 启动应用 | ||
| 125 | + | ||
| 126 | +```bash | ||
| 127 | +# 方式 1:使用 start.sh(推荐) | ||
| 128 | +./scripts/start.sh | ||
| 129 | + | ||
| 130 | +# 方式 2:分步启动 | ||
| 131 | +# 终端 1:Milvus | ||
| 132 | +./scripts/run_milvus.sh | ||
| 133 | + | ||
| 134 | +# 终端 2:CLIP(图像搜索需要) | ||
| 135 | +./scripts/run_clip.sh | ||
| 136 | + | ||
| 137 | +# 终端 3:Streamlit | ||
| 138 | +source venv/bin/activate | ||
| 139 | +streamlit run app.py --server.port=8501 --server.address=0.0.0.0 | ||
| 140 | +``` | ||
| 141 | + | ||
| 142 | +### 4.3 访问地址 | ||
| 143 | + | ||
| 144 | +- **Streamlit 应用**:http://服务器IP:8501 | ||
| 145 | +- **Milvus Attu 管理界面**:http://服务器IP:8000 | ||
| 146 | + | ||
| 147 | +## 五、生产部署建议 | ||
| 148 | + | ||
| 149 | +### 5.1 使用 systemd 管理 Streamlit | ||
| 150 | + | ||
| 151 | +创建 `/etc/systemd/system/omishop-agent.service`: | ||
| 152 | + | ||
| 153 | +```ini | ||
| 154 | +[Unit] | ||
| 155 | +Description=OmniShopAgent Streamlit App | ||
| 156 | +After=network.target docker.service | ||
| 157 | + | ||
| 158 | +[Service] | ||
| 159 | +Type=simple | ||
| 160 | +User=your_user | ||
| 161 | +WorkingDirectory=/path/to/shop_agent | ||
| 162 | +Environment="PATH=/path/to/shop_agent/venv/bin" | ||
| 163 | +ExecStart=/path/to/shop_agent/venv/bin/streamlit run app.py --server.port=8501 --server.address=0.0.0.0 | ||
| 164 | +Restart=on-failure | ||
| 165 | + | ||
| 166 | +[Install] | ||
| 167 | +WantedBy=multi-user.target | ||
| 168 | +``` | ||
| 169 | + | ||
| 170 | +```bash | ||
| 171 | +sudo systemctl daemon-reload | ||
| 172 | +sudo systemctl enable omishop-agent | ||
| 173 | +sudo systemctl start omishop-agent | ||
| 174 | +``` | ||
| 175 | + | ||
| 176 | +### 5.2 使用 Nginx 反向代理(可选) | ||
| 177 | + | ||
| 178 | +```nginx | ||
| 179 | +server { | ||
| 180 | + listen 80; | ||
| 181 | + server_name your-domain.com; | ||
| 182 | + location / { | ||
| 183 | + proxy_pass http://127.0.0.1:8501; | ||
| 184 | + proxy_http_version 1.1; | ||
| 185 | + proxy_set_header Upgrade $http_upgrade; | ||
| 186 | + proxy_set_header Connection "upgrade"; | ||
| 187 | + proxy_set_header Host $host; | ||
| 188 | + proxy_set_header X-Real-IP $remote_addr; | ||
| 189 | + } | ||
| 190 | +} | ||
| 191 | +``` | ||
| 192 | + | ||
| 193 | +### 5.3 防火墙 | ||
| 194 | + | ||
| 195 | +```bash | ||
| 196 | +sudo firewall-cmd --permanent --add-port=8501/tcp | ||
| 197 | +sudo firewall-cmd --permanent --add-port=19530/tcp | ||
| 198 | +sudo firewall-cmd --reload | ||
| 199 | +``` | ||
| 200 | + | ||
| 201 | +## 六、常见问题 | ||
| 202 | + | ||
| 203 | +### Q: Python 3.12 编译失败? | ||
| 204 | +A: 确保已安装 `openssl-devel`、`libffi-devel`,或直接使用 Miniconda。 | ||
| 205 | + | ||
| 206 | +### Q: Docker 权限不足? | ||
| 207 | +A: 执行 `sudo usermod -aG docker $USER` 后重新登录。 | ||
| 208 | + | ||
| 209 | +### Q: Milvus 启动超时? | ||
| 210 | +A: 首次启动需拉取镜像,可能较慢。可检查 `docker compose logs -f standalone`。 | ||
| 211 | + | ||
| 212 | +### Q: 图像搜索不可用? | ||
| 213 | +A: 需单独启动 CLIP 服务:`./scripts/run_clip.sh`。 | ||
| 214 | + | ||
| 215 | +### Q: 健康检查? | ||
| 216 | +A: 执行 `./scripts/check_services.sh` 查看各组件状态。 |
| 1 | +++ a/docs/LANGCHAIN_1.0_MIGRATION.md | ||
| @@ -0,0 +1,77 @@ | @@ -0,0 +1,77 @@ | ||
| 1 | +# LangChain 1.0 升级说明 | ||
| 2 | + | ||
| 3 | +## 一、升级概览 | ||
| 4 | + | ||
| 5 | +本项目已完成从 LangChain 0.3 到 LangChain 1.x 的升级,并同步升级 LangGraph 至 1.x。升级后兼容 Python 3.12。 | ||
| 6 | + | ||
| 7 | +## 二、依赖变更 | ||
| 8 | + | ||
| 9 | +| 包 | 升级前 | 升级后 | | ||
| 10 | +|----|--------|--------| | ||
| 11 | +| langchain | >=0.3.0 | >=1.0.0 | | ||
| 12 | +| langchain-core | (间接依赖) | >=0.3.0 | | ||
| 13 | +| langchain-openai | >=0.2.0 | >=0.2.0 | | ||
| 14 | +| langgraph | >=0.2.74 | >=1.0.0 | | ||
| 15 | +| langchain-community | >=0.4.0 | **已移除**(项目未使用) | | ||
| 16 | + | ||
| 17 | +## 三、代码改造说明 | ||
| 18 | + | ||
| 19 | +### 3.1 保持不变的部分 | ||
| 20 | + | ||
| 21 | +项目采用 **自定义 StateGraph** 架构(非 `create_react_agent`),以下导入在 LangChain/LangGraph 1.x 中保持兼容: | ||
| 22 | + | ||
| 23 | +```python | ||
| 24 | +from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage | ||
| 25 | +from langchain_openai import ChatOpenAI | ||
| 26 | +from langgraph.checkpoint.memory import MemorySaver | ||
| 27 | +from langgraph.graph import END, START, StateGraph | ||
| 28 | +from langgraph.graph.message import add_messages | ||
| 29 | +from langgraph.prebuilt import ToolNode | ||
| 30 | +from langchain_core.tools import tool | ||
| 31 | +``` | ||
| 32 | + | ||
| 33 | +### 3.2 已适配的变更 | ||
| 34 | + | ||
| 35 | +1. **消息内容提取**:LangChain 1.0 引入 `content_blocks`,`content` 可能为字符串或 multimodal 列表。新增 `_extract_message_text()` 辅助函数,统一处理两种格式。 | ||
| 36 | +2. **依赖精简**:移除未使用的 `langchain-community`,减少依赖冲突。 | ||
| 37 | + | ||
| 38 | +### 3.3 LangChain 1.0 主要变更(参考) | ||
| 39 | + | ||
| 40 | +- **包命名空间精简**:核心功能移至 `langchain-core`,`langchain` 主包聚焦 Agent 构建 | ||
| 41 | +- **create_agent**:若未来迁移到 `langchain.agents.create_agent`,可参考 `docs/Skills实现方案-LangChain1.0.md` | ||
| 42 | +- **langchain-classic**:Legacy chains、Retrievers 等已迁移至 `langchain-classic`,本项目未使用 | ||
| 43 | + | ||
| 44 | +## 四、环境要求 | ||
| 45 | + | ||
| 46 | +- **Python**:3.12+(**LangChain 1.x 要求 Python 3.10+**,不支持 3.9 及以下) | ||
| 47 | +- 若系统默认 Python 版本过低,需使用虚拟环境: | ||
| 48 | + | ||
| 49 | +```bash | ||
| 50 | +# 方式 1:使用 conda(推荐,项目根目录 scripts/setup_conda_env.sh) | ||
| 51 | +conda create -n shop_agent python=3.12 | ||
| 52 | +conda activate shop_agent | ||
| 53 | +pip install -r requirements.txt | ||
| 54 | + | ||
| 55 | +# 方式 2:使用 pyenv | ||
| 56 | +pyenv install 3.12 | ||
| 57 | +pyenv local 3.12 | ||
| 58 | +pip install -r requirements.txt | ||
| 59 | + | ||
| 60 | +# 方式 3:使用 venv(需系统已安装 python3.12) | ||
| 61 | +python3.12 -m venv venv | ||
| 62 | +source venv/bin/activate # Linux/Mac | ||
| 63 | +pip install -r requirements.txt | ||
| 64 | +``` | ||
| 65 | + | ||
| 66 | +## 五、验证 | ||
| 67 | + | ||
| 68 | +```bash | ||
| 69 | +# 验证导入 | ||
| 70 | +python -c " | ||
| 71 | +from langchain_core.messages import HumanMessage | ||
| 72 | +from langchain_openai import ChatOpenAI | ||
| 73 | +from langgraph.graph import StateGraph | ||
| 74 | +from langgraph.prebuilt import ToolNode | ||
| 75 | +print('LangChain 1.x 依赖加载成功') | ||
| 76 | +" | ||
| 77 | +``` |
| 1 | +++ a/docs/Skills实现方案-LangChain1.0.md | ||
| @@ -0,0 +1,318 @@ | @@ -0,0 +1,318 @@ | ||
| 1 | +# Skills 渐进式展开实现方案(LangChain 1.0+) | ||
| 2 | + | ||
| 3 | +## 一、需求概述 | ||
| 4 | + | ||
| 5 | +用 **Skills** 替代零散的工具调用,实现**渐进式展开**(Progressive Disclosure): | ||
| 6 | +Agent 在 system prompt 中只看到技能摘要,按需加载详细技能内容,减少 token 消耗、提升扩展性。 | ||
| 7 | + | ||
| 8 | +| 技能 | 英文标识 | 职责 | | ||
| 9 | +|------|----------|------| | ||
| 10 | +| 查找相关商品 | lookup_related | 基于文本/图片查找相似或相关商品 | | ||
| 11 | +| 搜索商品 | search_products | 按自然语言描述搜索商品 | | ||
| 12 | +| 检验商品 | check_product | 检验商品是否符合用户要求 | | ||
| 13 | +| 结果包装 | result_packaging | 格式化、排序、筛选并呈现结果 | | ||
| 14 | +| 售后相关 | after_sales | 退换货、物流、保修等售后问题 | | ||
| 15 | + | ||
| 16 | +--- | ||
| 17 | + | ||
| 18 | +## 二、LangChain 1.0 中的 Skills 实现方式 | ||
| 19 | + | ||
| 20 | +### 2.1 两种实现路线 | ||
| 21 | + | ||
| 22 | +| 方式 | 适用场景 | 依赖 | | ||
| 23 | +|------|----------|------| | ||
| 24 | +| **方式 A:create_agent + 自定义 Skill 中间件** | 购物导购等业务 Agent | `langchain>=1.0`、`langgraph>=1.0` | | ||
| 25 | +| **方式 B:Deep Agents + SKILL.md** | 依赖文件系统、多技能目录 | `deepagents` | | ||
| 26 | + | ||
| 27 | +购物导购场景推荐**方式 A**,更易与现有 Milvus、CLIP 等服务集成。 | ||
| 28 | + | ||
| 29 | +### 2.2 核心思路:Progressive Disclosure | ||
| 30 | + | ||
| 31 | +``` | ||
| 32 | +用户请求 → Agent 看轻量描述 → 判断需要的技能 → load_skill → 拿到完整说明 → 执行工具 → 回复 | ||
| 33 | +``` | ||
| 34 | + | ||
| 35 | +- **启动时**:只注入技能名称 + 简短描述(1–2 句) | ||
| 36 | +- **按需加载**:Agent 调用 `load_skill(skill_name)` 获取完整指令 | ||
| 37 | +- **执行**:按技能说明调用对应工具 | ||
| 38 | + | ||
| 39 | +--- | ||
| 40 | + | ||
| 41 | +## 三、实现架构 | ||
| 42 | + | ||
| 43 | +### 3.1 技能定义结构 | ||
| 44 | + | ||
| 45 | +```python | ||
| 46 | +from typing import TypedDict | ||
| 47 | + | ||
| 48 | +class Skill(TypedDict): | ||
| 49 | + """可渐进式展开的技能""" | ||
| 50 | + name: str # 唯一标识 | ||
| 51 | + description: str # 1-2 句,展示在 system prompt | ||
| 52 | + content: str # 完整指令,仅在 load_skill 时返回 | ||
| 53 | +``` | ||
| 54 | + | ||
| 55 | +### 3.2 五个技能定义示例 | ||
| 56 | + | ||
| 57 | +```python | ||
| 58 | +SKILLS: list[Skill] = [ | ||
| 59 | + { | ||
| 60 | + "name": "lookup_related", | ||
| 61 | + "description": "查找与某商品相关的其他商品,支持以图搜图、文本相似、同品类推荐。", | ||
| 62 | + "content": """# 查找相关商品 | ||
| 63 | + | ||
| 64 | +## 适用场景 | ||
| 65 | +- 用户上传图片要求「找类似的」 | ||
| 66 | +- 用户说「和这个差不多」「搭配的裤子」 | ||
| 67 | +- 用户已有一件商品,想找相关款 | ||
| 68 | + | ||
| 69 | +## 操作步骤 | ||
| 70 | +1. **有图片**:先调用 `analyze_image_style` 理解风格,再调用 `search_by_image` 或 `search_products` | ||
| 71 | +2. **无图片**:用 `search_products` 描述品类+风格+颜色 | ||
| 72 | +3. 可结合上下文中的商品 ID、品类做同品类推荐 | ||
| 73 | + | ||
| 74 | +## 可用工具 | ||
| 75 | +- `search_by_image(image_path, limit)`:以图搜图 | ||
| 76 | +- `search_products(query, limit)`:文本搜索 | ||
| 77 | +- `analyze_image_style(image_path)`:分析图片风格""", | ||
| 78 | + }, | ||
| 79 | + { | ||
| 80 | + "name": "search_products", | ||
| 81 | + "description": "按自然语言描述搜索商品,如「红色连衣裙」「运动鞋」等。", | ||
| 82 | + "content": """# 搜索商品 | ||
| 83 | + | ||
| 84 | +## 适用场景 | ||
| 85 | +- 用户用文字描述想要什么 | ||
| 86 | +- 如「冬天穿的外套」「正装衬衫」「跑步鞋」 | ||
| 87 | + | ||
| 88 | +## 操作步骤 | ||
| 89 | +1. 将用户描述整理成结构化 query(品类+颜色+风格+场景) | ||
| 90 | +2. 调用 `search_products(query, limit)`,limit 默认 5–10 | ||
| 91 | +3. 如有图片,可先 `analyze_image_style` 提炼关键词再搜索 | ||
| 92 | + | ||
| 93 | +## 可用工具 | ||
| 94 | +- `search_products(query, limit)`:自然语言搜索""", | ||
| 95 | + }, | ||
| 96 | + { | ||
| 97 | + "name": "check_product", | ||
| 98 | + "description": "检验商品是否符合用户要求,如尺寸、材质、场合、价格区间等。", | ||
| 99 | + "content": """# 检验商品是否符合要求 | ||
| 100 | + | ||
| 101 | +## 适用场景 | ||
| 102 | +- 用户问「这款适合我吗」「有没有 XX 材质的」 | ||
| 103 | +- 用户提出约束:尺寸、价格、场合、材质 | ||
| 104 | + | ||
| 105 | +## 操作步骤 | ||
| 106 | +1. 从对话中提取约束条件(尺寸、材质、场合、价格等) | ||
| 107 | +2. 对已召回商品做筛选或二次搜索 | ||
| 108 | +3. 调用 `search_products` 时在 query 中带上约束 | ||
| 109 | +4. 回复时明确说明哪些符合、哪些不符合 | ||
| 110 | + | ||
| 111 | +## 注意 | ||
| 112 | +- 无专门工具时,用 search_products 的 query 表达约束 | ||
| 113 | +- 可结合商品元数据(baseColour, season, usage 等)做简单筛选""", | ||
| 114 | + }, | ||
| 115 | + { | ||
| 116 | + "name": "result_packaging", | ||
| 117 | + "description": "对搜索结果进行格式化、排序、筛选并呈现给用户。", | ||
| 118 | + "content": """# 结果包装 | ||
| 119 | + | ||
| 120 | +## 适用场景 | ||
| 121 | +- 工具返回多条商品后需要整理呈现 | ||
| 122 | +- 用户要求「按价格排序」「只要前 3 个」 | ||
| 123 | + | ||
| 124 | +## 操作步骤 | ||
| 125 | +1. 按相关性/相似度排序 | ||
| 126 | +2. 限制展示数量(通常 3–5 个) | ||
| 127 | +3. **必须使用以下格式**呈现每个商品: | ||
| 128 | + | ||
| 129 | +``` | ||
| 130 | +1. [Product Name] | ||
| 131 | + ID: [Product ID Number] | ||
| 132 | + Category: [Category] | ||
| 133 | + Color: [Color] | ||
| 134 | + Gender: [Gender] | ||
| 135 | + Season: [Season] | ||
| 136 | + Usage: [Usage] | ||
| 137 | + Relevance: [XX%] | ||
| 138 | +``` | ||
| 139 | + | ||
| 140 | +4. ID 字段不可省略,用于前端展示图片""", | ||
| 141 | + }, | ||
| 142 | + { | ||
| 143 | + "name": "after_sales", | ||
| 144 | + "description": "处理退换货、物流、保修、尺码建议等售后问题。", | ||
| 145 | + "content": """# 售后相关 | ||
| 146 | + | ||
| 147 | +## 适用场景 | ||
| 148 | +- 退换货政策、运费、签收时间 | ||
| 149 | +- 尺码建议、洗涤说明 | ||
| 150 | +- 保修、客服联系方式 | ||
| 151 | + | ||
| 152 | +## 操作步骤 | ||
| 153 | +1. 此类问题无需调用商品搜索工具 | ||
| 154 | +2. 按平台统一售后政策回答 | ||
| 155 | +3. 涉及具体商品时,可结合商品 ID 查询详情后再回答 | ||
| 156 | +4. 复杂问题引导用户联系客服""", | ||
| 157 | + }, | ||
| 158 | +] | ||
| 159 | +``` | ||
| 160 | + | ||
| 161 | +--- | ||
| 162 | + | ||
| 163 | +## 四、核心代码实现 | ||
| 164 | + | ||
| 165 | +### 4.1 load_skill 工具 | ||
| 166 | + | ||
| 167 | +```python | ||
| 168 | +from langchain.tools import tool | ||
| 169 | + | ||
| 170 | +@tool | ||
| 171 | +def load_skill(skill_name: str) -> str: | ||
| 172 | + """加载技能的完整内容到 Agent 上下文中。 | ||
| 173 | + | ||
| 174 | + 当需要处理特定类型请求时,调用此工具获取该技能的详细说明和操作步骤。 | ||
| 175 | + | ||
| 176 | + Args: | ||
| 177 | + skill_name: 技能名称,可选值:lookup_related, search_products, check_product, result_packaging, after_sales | ||
| 178 | + """ | ||
| 179 | + for skill in SKILLS: | ||
| 180 | + if skill["name"] == skill_name: | ||
| 181 | + return f"Loaded skill: {skill_name}\n\n{skill['content']}" | ||
| 182 | + | ||
| 183 | + available = ", ".join(s["name"] for s in SKILLS) | ||
| 184 | + return f"Skill '{skill_name}' not found. Available: {available}" | ||
| 185 | +``` | ||
| 186 | + | ||
| 187 | +### 4.2 SkillMiddleware(注入技能描述) | ||
| 188 | + | ||
| 189 | +```python | ||
| 190 | +from langchain.agents.middleware import AgentMiddleware, ModelRequest, ModelResponse | ||
| 191 | +from langchain.messages import SystemMessage | ||
| 192 | +from typing import Callable | ||
| 193 | + | ||
| 194 | +class ShoppingSkillMiddleware(AgentMiddleware): | ||
| 195 | + """将技能描述注入 system prompt,使 Agent 能发现并按需加载技能""" | ||
| 196 | + | ||
| 197 | + tools = [load_skill] | ||
| 198 | + | ||
| 199 | + def __init__(self): | ||
| 200 | + skills_list = [] | ||
| 201 | + for skill in SKILLS: | ||
| 202 | + skills_list.append(f"- **{skill['name']}**: {skill['description']}") | ||
| 203 | + self.skills_prompt = "\n".join(skills_list) | ||
| 204 | + | ||
| 205 | + def wrap_model_call( | ||
| 206 | + self, | ||
| 207 | + request: ModelRequest, | ||
| 208 | + handler: Callable[[ModelRequest], ModelResponse], | ||
| 209 | + ) -> ModelResponse: | ||
| 210 | + skills_addendum = ( | ||
| 211 | + f"\n\n## 可用技能(按需加载)\n\n{self.skills_prompt}\n\n" | ||
| 212 | + "在需要详细说明时,使用 load_skill 工具加载对应技能。" | ||
| 213 | + ) | ||
| 214 | + new_content = list(request.system_message.content_blocks) + [ | ||
| 215 | + {"type": "text", "text": skills_addendum} | ||
| 216 | + ] | ||
| 217 | + new_system_message = SystemMessage(content=new_content) | ||
| 218 | + modified_request = request.override(system_message=new_system_message) | ||
| 219 | + return handler(modified_request) | ||
| 220 | +``` | ||
| 221 | + | ||
| 222 | +### 4.3 创建带 Skills 的 Agent | ||
| 223 | + | ||
| 224 | +```python | ||
| 225 | +from langchain.agents import create_agent | ||
| 226 | +from langgraph.checkpoint.memory import MemorySaver | ||
| 227 | + | ||
| 228 | +# 基础工具(搜索、以图搜图、风格分析等) | ||
| 229 | +from app.tools.search_tools import search_products, search_by_image, analyze_image_style | ||
| 230 | + | ||
| 231 | +agent = create_agent( | ||
| 232 | + model="gpt-4o-mini", | ||
| 233 | + tools=[ | ||
| 234 | + load_skill, # 技能加载 | ||
| 235 | + search_products, | ||
| 236 | + search_by_image, | ||
| 237 | + analyze_image_style, | ||
| 238 | + ], | ||
| 239 | + system_prompt="""你是智能时尚购物助手。根据用户需求,先判断使用哪个技能,必要时用 load_skill 加载技能详情。 | ||
| 240 | + | ||
| 241 | +处理商品结果时,必须遵守 result_packaging 技能中的格式要求。""", | ||
| 242 | + middleware=[ShoppingSkillMiddleware()], | ||
| 243 | + checkpointer=MemorySaver(), | ||
| 244 | +) | ||
| 245 | +``` | ||
| 246 | + | ||
| 247 | +--- | ||
| 248 | + | ||
| 249 | +## 五、与工具的关系 | ||
| 250 | + | ||
| 251 | +| 能力 | 技能 | 工具 | | ||
| 252 | +|------|------|------| | ||
| 253 | +| 查找相关 | lookup_related | search_by_image, search_products, analyze_image_style | | ||
| 254 | +| 搜索商品 | search_products | search_products | | ||
| 255 | +| 检验商品 | check_product | search_products(用 query 表达约束) | | ||
| 256 | +| 结果包装 | result_packaging | 无(纯 prompt 约束) | | ||
| 257 | +| 售后 | after_sales | 无(或对接客服 API) | | ||
| 258 | + | ||
| 259 | +- **技能**:提供「何时用、怎么用」的说明,支持渐进式加载。 | ||
| 260 | +- **工具**:实际执行搜索、分析等操作。 | ||
| 261 | + | ||
| 262 | +--- | ||
| 263 | + | ||
| 264 | +## 六、可选:技能约束(进阶) | ||
| 265 | + | ||
| 266 | +若希望「先加载技能再使用工具」,可结合 `ToolRuntime` 和 state 做约束: | ||
| 267 | + | ||
| 268 | +```python | ||
| 269 | +from langchain.tools import tool, ToolRuntime | ||
| 270 | +from langgraph.types import Command | ||
| 271 | +from langchain.messages import ToolMessage | ||
| 272 | +from typing_extensions import NotRequired | ||
| 273 | + | ||
| 274 | +class CustomState(AgentState): | ||
| 275 | + skills_loaded: NotRequired[list[str]] | ||
| 276 | + | ||
| 277 | +@tool | ||
| 278 | +def load_skill(skill_name: str, runtime: ToolRuntime) -> Command: | ||
| 279 | + """...""" | ||
| 280 | + for skill in SKILLS: | ||
| 281 | + if skill["name"] == skill_name: | ||
| 282 | + content = f"Loaded skill: {skill_name}\n\n{skill['content']}" | ||
| 283 | + return Command(update={ | ||
| 284 | + "messages": [ToolMessage(content=content, tool_call_id=runtime.tool_call_id)], | ||
| 285 | + "skills_loaded": [skill_name], | ||
| 286 | + }) | ||
| 287 | + # ... | ||
| 288 | + | ||
| 289 | +# 在 check_product 等工具中检查 skills_loaded | ||
| 290 | +``` | ||
| 291 | + | ||
| 292 | +--- | ||
| 293 | + | ||
| 294 | +## 七、依赖与版本 | ||
| 295 | + | ||
| 296 | +```text | ||
| 297 | +# requirements.txt | ||
| 298 | +langchain>=1.0.0 | ||
| 299 | +langchain-openai>=0.2.0 | ||
| 300 | +langchain-core>=0.3.0 | ||
| 301 | +langgraph>=1.0.0 | ||
| 302 | +``` | ||
| 303 | + | ||
| 304 | +- Python 3.10+ | ||
| 305 | +- 若使用 Deep Agents 的 SKILL.md,需额外安装 `deepagents` | ||
| 306 | + | ||
| 307 | +--- | ||
| 308 | + | ||
| 309 | +## 八、总结 | ||
| 310 | + | ||
| 311 | +| 项目 | 说明 | | ||
| 312 | +|------|------| | ||
| 313 | +| **效果** | 系统 prompt 只放简短技能描述,按需加载完整内容,减少 token、便于扩展 | | ||
| 314 | +| **流程** | 轻量描述 → load_skill → 完整说明 → 调用工具 → 回复 | | ||
| 315 | +| **实现** | `SkillMiddleware` + `load_skill` + `create_agent` | | ||
| 316 | +| **技能** | lookup_related, search_products, check_product, result_packaging, after_sales | | ||
| 317 | + | ||
| 318 | +完整示例可参考官方教程:[Build a SQL assistant with on-demand skills](https://docs.langchain.com/oss/python/langchain/multi-agent/skills-sql-assistant)。 |
| 1 | +++ a/requirements.txt | ||
| @@ -0,0 +1,40 @@ | @@ -0,0 +1,40 @@ | ||
| 1 | +# Core Framework | ||
| 2 | +fastapi>=0.109.0 | ||
| 3 | +uvicorn[standard]>=0.27.0 | ||
| 4 | +pydantic>=2.6.0 | ||
| 5 | +pydantic-settings>=2.1.0 | ||
| 6 | +streamlit>=1.50.0 | ||
| 7 | + | ||
| 8 | +# LLM & LangChain (Python 3.12, LangChain 1.x) | ||
| 9 | +langchain>=1.0.0 | ||
| 10 | +langchain-core>=0.3.0 | ||
| 11 | +langchain-openai>=0.2.0 | ||
| 12 | +langgraph>=1.0.0 | ||
| 13 | +openai>=1.12.0 | ||
| 14 | + | ||
| 15 | +# Embeddings & Vision | ||
| 16 | +clip-client>=3.5.0 # CLIP-as-Service client | ||
| 17 | +Pillow>=10.2.0 # Image processing | ||
| 18 | + | ||
| 19 | +# Vector Database | ||
| 20 | +pymilvus>=2.3.6 | ||
| 21 | + | ||
| 22 | +# Databases | ||
| 23 | +pymongo>=4.6.1 | ||
| 24 | + | ||
| 25 | +# Utilities | ||
| 26 | +python-dotenv>=1.0.1 | ||
| 27 | +python-multipart>=0.0.9 | ||
| 28 | +aiofiles>=23.2.1 | ||
| 29 | +requests>=2.31.0 | ||
| 30 | + | ||
| 31 | +# Data Processing | ||
| 32 | +pandas>=2.2.3 | ||
| 33 | +numpy>=1.26.4 | ||
| 34 | +tqdm>=4.66.1 | ||
| 35 | + | ||
| 36 | +# Development & Testing | ||
| 37 | +pytest>=8.0.0 | ||
| 38 | +pytest-asyncio>=0.23.4 | ||
| 39 | +httpx>=0.26.0 | ||
| 40 | +black>=24.1.1 |
| 1 | +++ a/scripts/check_services.sh | ||
| @@ -0,0 +1,93 @@ | @@ -0,0 +1,93 @@ | ||
#!/usr/bin/env bash
# =============================================================================
# OmniShopAgent - service health-check script
# Checks the status of the Python env, Milvus, CLIP, Streamlit and data files.
# =============================================================================
set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m'

echo "=========================================="
echo "OmniShopAgent 服务健康检查"
echo "=========================================="

# 1. Python interpreter version
echo -n "[Python] "
if command -v python3 &>/dev/null; then
    VER=$(python3 -c 'import sys; v=sys.version_info; print(f"{v.major}.{v.minor}.{v.micro}")' 2>/dev/null)
    # BUG FIX: the original pattern  "3.1"*  also matched Python 3.1.x (and
    # made the separate "3.12"* clause redundant). Compare major/minor
    # numerically instead: 3.12+ is the recommended version.
    MAJOR="${VER%%.*}"
    REST="${VER#*.}"
    MINOR="${REST%%.*}"
    if [[ "$MAJOR" =~ ^[0-9]+$ ]] && [[ "$MINOR" =~ ^[0-9]+$ ]] \
        && [ "$MAJOR" -eq 3 ] && [ "$MINOR" -ge 12 ]; then
        echo -e "${GREEN}OK${NC} $VER"
    else
        echo -e "${YELLOW}WARN${NC} $VER (建议 3.12+)"
    fi
else
    echo -e "${RED}FAIL${NC} 未找到"
fi

# 2. Project virtualenv
echo -n "[Virtualenv] "
if [ -d "$PROJECT_ROOT/venv" ]; then
    echo -e "${GREEN}OK${NC} $PROJECT_ROOT/venv"
else
    echo -e "${YELLOW}WARN${NC} 未找到 venv"
fi

# 3. .env configuration (must contain a real-looking OpenAI key)
echo -n "[.env] "
if [ -f "$PROJECT_ROOT/.env" ]; then
    if grep -q "OPENAI_API_KEY=sk-" "$PROJECT_ROOT/.env" 2>/dev/null; then
        echo -e "${GREEN}OK${NC} 已配置"
    else
        echo -e "${YELLOW}WARN${NC} 请配置 OPENAI_API_KEY"
    fi
else
    echo -e "${RED}FAIL${NC} 未找到"
fi

# 4. Milvus (docker container + HTTP health endpoint on :9091)
echo -n "[Milvus] "
if command -v docker &>/dev/null; then
    if docker ps --format '{{.Names}}' 2>/dev/null | grep -q milvus-standalone; then
        if curl -s -o /dev/null -w "%{http_code}" http://localhost:9091/healthz 2>/dev/null | grep -q 200; then
            echo -e "${GREEN}OK${NC} localhost:19530"
        else
            echo -e "${YELLOW}WARN${NC} 容器运行中,健康检查未响应"
        fi
    else
        echo -e "${YELLOW}WARN${NC} 未运行 (docker compose up -d)"
    fi
else
    echo -e "${YELLOW}SKIP${NC} Docker 未安装"
fi

# 5. CLIP server (optional; raw TCP probe on the gRPC port)
echo -n "[CLIP] "
if timeout 2 bash -c 'echo >/dev/tcp/localhost/51000' 2>/dev/null; then
    echo -e "${GREEN}OK${NC} localhost:51000"
else
    echo -e "${YELLOW}WARN${NC} 未运行 (图像搜索需启动: python -m clip_server launch)"
fi

# 6. Dataset files
echo -n "[数据] "
if [ -d "$PROJECT_ROOT/data/images" ] && [ -f "$PROJECT_ROOT/data/styles.csv" ]; then
    IMG_COUNT=$(find "$PROJECT_ROOT/data/images" -name "*.jpg" 2>/dev/null | wc -l)
    echo -e "${GREEN}OK${NC} $IMG_COUNT 张图片"
else
    echo -e "${YELLOW}WARN${NC} 未找到 data/images 或 data/styles.csv (运行 download_dataset.py)"
fi

# 7. Streamlit app process
echo -n "[Streamlit] "
if pgrep -f "streamlit run app.py" >/dev/null 2>&1; then
    echo -e "${GREEN}OK${NC} 运行中"
else
    echo -e "${YELLOW}WARN${NC} 未运行 (./scripts/start.sh)"
fi

echo "=========================================="
| 1 | +++ a/scripts/download_dataset.py | ||
| @@ -0,0 +1,95 @@ | @@ -0,0 +1,95 @@ | ||
| 1 | +""" | ||
| 2 | +Script to download the Fashion Product Images Dataset from Kaggle | ||
| 3 | + | ||
| 4 | +Requirements: | ||
| 5 | +1. Install Kaggle CLI: pip install kaggle | ||
| 6 | +2. Setup Kaggle API credentials: | ||
| 7 | + - Go to https://www.kaggle.com/settings/account | ||
| 8 | + - Click "Create New API Token" | ||
| 9 | + - Save kaggle.json to ~/.kaggle/kaggle.json | ||
| 10 | + - chmod 600 ~/.kaggle/kaggle.json | ||
| 11 | + | ||
| 12 | +Usage: | ||
| 13 | + python scripts/download_dataset.py | ||
| 14 | +""" | ||
| 15 | + | ||
| 16 | +import subprocess | ||
| 17 | +import zipfile | ||
| 18 | +from pathlib import Path | ||
| 19 | + | ||
| 20 | + | ||
def download_dataset():
    """Download and extract the Fashion Product Images Dataset from Kaggle.

    Skips the download (after an interactive confirmation) if
    data/raw/styles.csv already exists, requires Kaggle API credentials in
    ~/.kaggle/kaggle.json, extracts the archive into data/raw/ and removes
    the zip afterwards. Errors are reported to stdout; nothing is raised.
    """
    # Resolve data/raw relative to the project root (scripts/..).
    project_root = Path(__file__).parent.parent
    raw_data_path = project_root / "data" / "raw"

    # Offer to skip when the dataset appears to be present already.
    if (raw_data_path / "styles.csv").exists():
        print("Dataset already exists in data/raw/")
        response = input("Do you want to re-download? (y/n): ")
        if response.lower() != "y":
            print("Skipping download.")
            return

    # The Kaggle CLI requires API credentials in ~/.kaggle/kaggle.json.
    kaggle_json = Path.home() / ".kaggle" / "kaggle.json"
    if not kaggle_json.exists():
        print(" Kaggle API credentials not found!")
        return

    # Make sure the target directory exists before downloading into it.
    raw_data_path.mkdir(parents=True, exist_ok=True)

    print("Downloading dataset from Kaggle...")

    try:
        # Download via the Kaggle CLI into data/raw/.
        subprocess.run(
            [
                "kaggle",
                "datasets",
                "download",
                "-d",
                "paramaggarwal/fashion-product-images-dataset",
                "-p",
                str(raw_data_path),
            ],
            check=True,
        )

        print("Download complete!")

        # Extract and then delete the downloaded archive.
        zip_path = raw_data_path / "fashion-product-images-dataset.zip"
        if zip_path.exists():
            print("Extracting files...")
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(raw_data_path)

            print("Extraction complete!")

            zip_path.unlink()
            print("Cleaned up zip file")

        # Verify the expected dataset layout.
        styles_csv = raw_data_path / "styles.csv"
        images_dir = raw_data_path / "images"

        if styles_csv.exists() and images_dir.exists():
            # BUG FIX: the original printed "\Dataset ready!" — "\D" is not a
            # valid escape sequence (literal backslash, DeprecationWarning on
            # modern Python); "\n" for a leading newline was intended.
            print("\nDataset ready!")

            # Count the extracted product images.
            image_count = len(list(images_dir.glob("*.jpg")))
            print(f"- Total images: {image_count:,}")
        else:
            print("Warning: Expected files not found")

    except subprocess.CalledProcessError:
        print("Download failed!")

    except Exception as e:
        print(f"Error: {e}")


if __name__ == "__main__":
    download_dataset()
| 1 | +++ a/scripts/index_data.py | ||
| @@ -0,0 +1,467 @@ | @@ -0,0 +1,467 @@ | ||
| 1 | +""" | ||
| 2 | +Data Indexing Script | ||
| 3 | +Generates embeddings for products and stores them in Milvus | ||
| 4 | +""" | ||
| 5 | + | ||
| 6 | +import csv | ||
| 7 | +import logging | ||
| 8 | +import os | ||
| 9 | +import sys | ||
| 10 | +from pathlib import Path | ||
| 11 | +from typing import Any, Dict, Optional | ||
| 12 | + | ||
| 13 | +from tqdm import tqdm | ||
| 14 | + | ||
| 15 | +# Add parent directory to path | ||
| 16 | +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | ||
| 17 | + | ||
| 18 | +# Import config and settings first | ||
| 19 | +# Direct imports from files to avoid __init__.py circular issues | ||
| 20 | +import importlib.util | ||
| 21 | + | ||
| 22 | +from app.config import get_absolute_path, settings | ||
| 23 | + | ||
| 24 | + | ||
def load_service_module(module_name, file_name):
    """Import a module under app/services/ straight from its source file.

    Loading by file path bypasses the package import machinery, so
    app.services.__init__ (and any circular imports it would trigger)
    is never executed.
    """
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    module_path = os.path.join(project_root, f"app/services/{file_name}")
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
| 37 | + | ||
| 38 | + | ||
# Load the service modules directly from their source files so that the
# app.services package __init__ (which can trigger circular imports) never runs.
embedding_module = load_service_module("embedding_service", "embedding_service.py")
milvus_module = load_service_module("milvus_service", "milvus_service.py")

# Service classes used by DataIndexer below
EmbeddingService = embedding_module.EmbeddingService
MilvusService = milvus_module.MilvusService

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
| 50 | + | ||
| 51 | + | ||
class DataIndexer:
    """Index product data by generating and storing embeddings.

    Loads the product catalogue from CSV files, then generates text
    embeddings (via the embedding service) and image embeddings (via CLIP)
    and inserts them, together with truncated product metadata, into the
    Milvus collections.
    """

    def __init__(self):
        """Initialize services and load the product catalogue."""
        self.embedding_service = EmbeddingService()
        self.milvus_service = MilvusService()

        # Data locations resolved from application settings / project layout
        self.image_dir = Path(get_absolute_path(settings.image_data_path))
        self.styles_csv = get_absolute_path("./data/styles.csv")
        self.images_csv = get_absolute_path("./data/images.csv")

        # Product catalogue keyed by integer product id
        self.products = self._load_products_from_csv()

    def _load_products_from_csv(self) -> Dict[int, Dict[str, Any]]:
        """Load products from the styles/images CSV files.

        Returns:
            Mapping of product id to a product metadata dictionary.
        """
        products: Dict[int, Dict[str, Any]] = {}

        # Build product id -> image URL mapping (filenames are "<id>.jpg")
        images_dict: Dict[int, str] = {}
        with open(self.images_csv, "r", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            for row in reader:
                # ROBUSTNESS: skip rows whose filename is not "<numeric id>.<ext>"
                # instead of crashing the whole load (original raised ValueError).
                try:
                    product_id = int(row["filename"].split(".")[0])
                    images_dict[product_id] = row["link"]
                except (ValueError, KeyError) as e:
                    logger.warning(f"Skipping malformed image row {row}: {e}")

        # Load product rows from styles.csv
        with open(self.styles_csv, "r", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            for row in reader:
                try:
                    product_id = int(row["id"])
                    products[product_id] = {
                        "id": product_id,
                        "gender": row.get("gender", ""),
                        "masterCategory": row.get("masterCategory", ""),
                        "subCategory": row.get("subCategory", ""),
                        "articleType": row.get("articleType", ""),
                        "baseColour": row.get("baseColour", ""),
                        "season": row.get("season", ""),
                        "year": int(row["year"]) if row.get("year") else 0,
                        "usage": row.get("usage", ""),
                        "productDisplayName": row.get("productDisplayName", ""),
                        "imageUrl": images_dict.get(product_id, ""),
                        "imagePath": f"{product_id}.jpg",
                    }
                except (ValueError, KeyError) as e:
                    logger.warning(f"Error loading product {row.get('id')}: {e}")
                    continue

        logger.info(f"Loaded {len(products)} products from CSV")
        return products

    def setup(self) -> None:
        """Connect to CLIP and Milvus and ensure the collections exist."""
        logger.info("Setting up services...")

        # Connect to CLIP server
        self.embedding_service.connect_clip()
        logger.info("✓ CLIP server connected")

        # Connect to Milvus
        self.milvus_service.connect()
        logger.info("✓ Milvus connected")

        # recreate=False keeps any embeddings that were indexed previously
        self.milvus_service.create_text_collection(recreate=False)
        self.milvus_service.create_image_collection(recreate=False)
        logger.info("✓ Milvus collections ready")

    def teardown(self) -> None:
        """Close all service connections."""
        logger.info("Closing connections...")
        self.embedding_service.disconnect_clip()
        self.milvus_service.disconnect()
        logger.info("✓ All connections closed")

    def _product_metadata(self, product: Dict[str, Any]) -> Dict[str, Any]:
        """Return the product metadata fields shared by both collections.

        String values are truncated to the same maximum lengths the original
        code applied (presumably the Milvus VARCHAR field sizes — confirm
        against the collection schema).
        """
        return {
            "productDisplayName": product["productDisplayName"][:500],
            "gender": product["gender"][:50],
            "masterCategory": product["masterCategory"][:100],
            "subCategory": product["subCategory"][:100],
            "articleType": product["articleType"][:100],
            "baseColour": product["baseColour"][:50],
            "season": product["season"][:50],
            "usage": product["usage"][:50],
            "year": product["year"],
            "imageUrl": product["imageUrl"],
        }

    def index_text_embeddings(
        self, batch_size: int = 100, skip: int = 0, limit: Optional[int] = None
    ) -> Dict[str, int]:
        """Generate and store text embeddings for products.

        Args:
            batch_size: Number of products to process at once
            skip: Number of products to skip
            limit: Maximum number of products to process (None for all)

        Returns:
            Dictionary with indexing statistics
            (total_processed / inserted / errors)
        """
        logger.info("Starting text embedding indexing...")

        # Select the slice of the catalogue to process
        product_ids = list(self.products.keys())[skip:]
        if limit:
            product_ids = product_ids[:limit]

        total_products = len(product_ids)
        processed = 0
        inserted = 0
        errors = 0

        with tqdm(total=total_products, desc="Indexing text embeddings") as pbar:
            while processed < total_products:
                # Take the next batch of products
                current_batch_size = min(batch_size, total_products - processed)
                batch_ids = product_ids[processed : processed + current_batch_size]
                products = [self.products[pid] for pid in batch_ids]

                if not products:
                    break

                try:
                    # Build the text representation for each product
                    texts = []
                    text_mappings = []
                    for product in products:
                        text = self._create_product_text(product)
                        texts.append(text)
                        text_mappings.append(
                            {"product_id": product["id"], "text": text}
                        )

                    # Generate embeddings (sub-batches of 50 for the API)
                    embeddings = self.embedding_service.get_text_embeddings_batch(
                        texts, batch_size=50
                    )

                    # Assemble Milvus rows: id + text + embedding + metadata
                    milvus_data = []
                    for mapping, embedding in zip(text_mappings, embeddings):
                        product = self.products[mapping["product_id"]]
                        entry = {
                            "id": mapping["product_id"],
                            # Truncate to the collection's max text length
                            "text": mapping["text"][:2000],
                            "embedding": embedding,
                        }
                        entry.update(self._product_metadata(product))
                        entry["imagePath"] = product["imagePath"]
                        milvus_data.append(entry)

                    # Insert into Milvus
                    inserted += self.milvus_service.insert_text_embeddings(
                        milvus_data
                    )

                except Exception as e:
                    # Count the whole batch as failed and keep going
                    logger.error(
                        f"Error processing text batch at offset {processed}: {e}"
                    )
                    errors += len(products)

                processed += len(products)
                pbar.update(len(products))

        stats = {"total_processed": processed, "inserted": inserted, "errors": errors}

        logger.info(f"Text embedding indexing completed: {stats}")
        return stats

    def index_image_embeddings(
        self, batch_size: int = 32, skip: int = 0, limit: Optional[int] = None
    ) -> Dict[str, int]:
        """Generate and store image embeddings for products.

        Args:
            batch_size: Number of images to process at once
            skip: Number of products to skip
            limit: Maximum number of products to process (None for all)

        Returns:
            Dictionary with indexing statistics
            (total_processed / inserted / errors)
        """
        logger.info("Starting image embedding indexing...")

        # Select the slice of the catalogue to process
        product_ids = list(self.products.keys())[skip:]
        if limit:
            product_ids = product_ids[:limit]

        total_products = len(product_ids)
        processed = 0
        inserted = 0
        errors = 0

        with tqdm(total=total_products, desc="Indexing image embeddings") as pbar:
            while processed < total_products:
                # Take the next batch of products
                current_batch_size = min(batch_size, total_products - processed)
                batch_ids = product_ids[processed : processed + current_batch_size]
                products = [self.products[pid] for pid in batch_ids]

                if not products:
                    break

                try:
                    # Resolve the on-disk path for each product image
                    image_paths = []
                    image_mappings = []
                    for product in products:
                        image_paths.append(self.image_dir / product["imagePath"])
                        image_mappings.append(
                            {
                                "product_id": product["id"],
                                "image_path": product["imagePath"],
                            }
                        )

                    # Generate embeddings (a None marks a failed image)
                    embeddings = self.embedding_service.get_image_embeddings_batch(
                        image_paths, batch_size=batch_size
                    )

                    # Assemble Milvus rows; skip images that failed to embed
                    milvus_data = []
                    for mapping, embedding in zip(image_mappings, embeddings):
                        if embedding is None:
                            errors += 1
                            continue
                        product = self.products[mapping["product_id"]]
                        entry = {
                            "id": mapping["product_id"],
                            "image_path": mapping["image_path"],
                            "embedding": embedding,
                        }
                        entry.update(self._product_metadata(product))
                        milvus_data.append(entry)

                    # Insert into Milvus
                    if milvus_data:
                        inserted += self.milvus_service.insert_image_embeddings(
                            milvus_data
                        )

                except Exception as e:
                    logger.error(
                        f"Error processing image batch at offset {processed}: {e}"
                    )
                    errors += len(products)

                processed += len(products)
                pbar.update(len(products))

        stats = {"total_processed": processed, "inserted": inserted, "errors": errors}

        logger.info(f"Image embedding indexing completed: {stats}")
        return stats

    def _create_product_text(self, product: Dict[str, Any]) -> str:
        """Create the text representation of a product for embedding.

        CONSISTENCY FIX: the original filtered out only empty "Gender: " and
        "Color: " labels; empty values for every attribute are now omitted,
        so sparse products do not embed dangling labels like "Season: ".

        Args:
            product: Product document

        Returns:
            Pipe-separated natural-language description
        """
        parts = []

        name = product.get("productDisplayName", "")
        if name:
            parts.append(name)

        if product.get("gender"):
            parts.append(f"Gender: {product['gender']}")

        master = product.get("masterCategory", "")
        sub = product.get("subCategory", "")
        if master or sub:
            parts.append(f"Category: {master} > {sub}")

        for label, key in (
            ("Type", "articleType"),
            ("Color", "baseColour"),
            ("Season", "season"),
            ("Usage", "usage"),
        ):
            if product.get(key):
                parts.append(f"{label}: {product[key]}")

        return " | ".join(parts)

    def get_stats(self) -> Dict[str, Any]:
        """Get indexing statistics.

        Returns:
            Dictionary with the catalogue size and per-collection stats
        """
        text_stats = self.milvus_service.get_collection_stats(
            self.milvus_service.text_collection_name
        )
        image_stats = self.milvus_service.get_collection_stats(
            self.milvus_service.image_collection_name
        )

        return {
            "total_products": len(self.products),
            "milvus_text": text_stats,
            "milvus_image": image_stats,
        }
| 382 | + | ||
| 383 | + | ||
def main():
    """CLI entry point: index text/image embeddings or print statistics."""
    import argparse

    parser = argparse.ArgumentParser(description="Index product data for search")
    parser.add_argument(
        "--mode",
        choices=["text", "image", "both"],
        default="both",
        help="Which embeddings to index",
    )
    parser.add_argument(
        "--batch-size", type=int, default=100, help="Batch size for processing"
    )
    parser.add_argument(
        "--skip", type=int, default=0, help="Number of products to skip"
    )
    parser.add_argument(
        "--limit", type=int, default=None, help="Maximum number of products to process"
    )
    parser.add_argument("--stats", action="store_true", help="Show statistics only")

    args = parser.parse_args()

    # Create indexer (loads the product catalogue from CSV)
    indexer = DataIndexer()

    try:
        # Connect services and ensure collections exist
        indexer.setup()

        if args.stats:
            # Report-only mode: print counts and coverage, index nothing
            stats = indexer.get_stats()
            print("\n=== Indexing Statistics ===")
            print(f"\nTotal Products in CSV: {stats['total_products']}")

            print("\nMilvus Text Embeddings:")
            print(f" Collection: {stats['milvus_text']['collection_name']}")
            print(f" Total embeddings: {stats['milvus_text']['row_count']}")

            print("\nMilvus Image Embeddings:")
            print(f" Collection: {stats['milvus_image']['collection_name']}")
            print(f" Total embeddings: {stats['milvus_image']['row_count']}")

            # BUGFIX: guard against an empty catalogue (division by zero)
            if stats["total_products"]:
                coverage = (
                    stats["milvus_image"]["row_count"]
                    / stats["total_products"]
                    * 100
                )
                print(f"\nCoverage: {coverage:.1f}%")
            else:
                print("\nCoverage: n/a (no products loaded)")
        else:
            # Index data according to --mode
            if args.mode in ["text", "both"]:
                logger.info("=== Indexing Text Embeddings ===")
                text_stats = indexer.index_text_embeddings(
                    batch_size=args.batch_size, skip=args.skip, limit=args.limit
                )
                print(f"\nText Indexing Results: {text_stats}")

            if args.mode in ["image", "both"]:
                logger.info("=== Indexing Image Embeddings ===")
                image_stats = indexer.index_image_embeddings(
                    batch_size=min(args.batch_size, 32),  # Smaller batch for images
                    skip=args.skip,
                    limit=args.limit,
                )
                print(f"\nImage Indexing Results: {image_stats}")

            # Show final statistics
            logger.info("\n=== Final Statistics ===")
            stats = indexer.get_stats()
            print(f"Total products: {stats['total_products']}")
            print(f"Text embeddings: {stats['milvus_text']['row_count']}")
            print(f"Image embeddings: {stats['milvus_image']['row_count']}")

    except KeyboardInterrupt:
        logger.info("\nIndexing interrupted by user")
    except Exception as e:
        logger.error(f"Error during indexing: {e}", exc_info=True)
        sys.exit(1)
    finally:
        # BUGFIX: teardown() can itself raise (e.g. setup() failed before the
        # connections were opened); that must not mask the original error.
        try:
            indexer.teardown()
        except Exception as e:
            logger.warning(f"Error while closing connections: {e}")


if __name__ == "__main__":
    main()
| 1 | +++ a/scripts/run_clip.sh | ||
| @@ -0,0 +1,22 @@ | @@ -0,0 +1,22 @@ | ||
#!/usr/bin/env bash
# =============================================================================
# OmniShopAgent - launch the CLIP embedding server
# The image-search / search-by-image features depend on this service.
# =============================================================================
set -euo pipefail

SELF_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
ROOT_DIR="$(cd "$SELF_DIR/.." && pwd)"
VENV_DIR="${VENV_DIR:-$ROOT_DIR/venv}"

cd "$ROOT_DIR"

# Activate the project virtualenv when present. `set -u` is relaxed around
# sourcing because activate scripts may reference unset variables.
if [[ -d "$VENV_DIR" ]]; then
    set +u
    # shellcheck disable=SC1091
    source "$VENV_DIR/bin/activate"
    set -u
fi

echo "启动 CLIP 服务 (端口 51000)..."
echo "按 Ctrl+C 停止"
exec python -m clip_server launch
| 1 | +++ a/scripts/run_milvus.sh | ||
| @@ -0,0 +1,31 @@ | @@ -0,0 +1,31 @@ | ||
#!/usr/bin/env bash
# =============================================================================
# OmniShopAgent - start the Milvus vector database
# Brings up Milvus and its dependencies via Docker Compose.
# =============================================================================
set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"

cd "$PROJECT_ROOT"

if ! command -v docker &>/dev/null; then
    echo "错误: 未安装 Docker。请先运行 setup_env_centos8.sh"
    exit 1
fi

echo "启动 Milvus..."
# Prefer the compose plugin; fall back to the legacy standalone binary.
docker compose up -d 2>/dev/null || docker-compose up -d 2>/dev/null || {
    echo "错误: 无法执行 docker compose。请确保已安装 Docker Compose"
    exit 1
}

echo "等待 Milvus 就绪 (约 60 秒)..."
# IMPROVED: poll the health endpoint (12 x 5 s = same 60 s budget) instead of
# a fixed `sleep 60`, so the script returns as soon as Milvus is actually up.
ready=false
for _ in $(seq 1 12); do
    if curl -s -o /dev/null -w "%{http_code}" http://localhost:9091/healthz 2>/dev/null | grep -q 200; then
        ready=true
        break
    fi
    sleep 5
done

if [ "$ready" = true ]; then
    echo "Milvus 已就绪: localhost:19530"
else
    echo "提示: Milvus 可能仍在启动,请稍后执行 check_services.sh 检查"
fi
| 1 | +++ a/scripts/setup_env_centos8.sh | ||
| @@ -0,0 +1,152 @@ | @@ -0,0 +1,152 @@ | ||
#!/usr/bin/env bash
# =============================================================================
# OmniShopAgent - CentOS 8 environment preparation script
# Installs Python 3.12, Docker, system dependencies, and a project virtualenv.
# =============================================================================
set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
# Both VENV_DIR and PYTHON_VERSION can be overridden from the environment
VENV_DIR="${VENV_DIR:-$PROJECT_ROOT/venv}"
PYTHON_VERSION="${PYTHON_VERSION:-3.12}"

echo "=========================================="
echo "OmniShopAgent - CentOS 8 环境准备"
echo "=========================================="
echo "项目目录: $PROJECT_ROOT"
echo "虚拟环境: $VENV_DIR"
echo "Python 版本: $PYTHON_VERSION"
echo "=========================================="

# -----------------------------------------------------------------------------
# 1. Install system dependencies: build toolchain and -devel libraries needed
#    to compile CPython from source, plus basic download tools.
# -----------------------------------------------------------------------------
echo "[1/4] 安装系统依赖..."
sudo dnf install -y \
    gcc \
    gcc-c++ \
    make \
    openssl-devel \
    bzip2-devel \
    libffi-devel \
    sqlite-devel \
    xz-devel \
    zlib-devel \
    readline-devel \
    tk-devel \
    libuuid-devel \
    curl \
    wget \
    git \
    tar

# -----------------------------------------------------------------------------
# 2. Install Docker (required to run Milvus via docker compose)
# -----------------------------------------------------------------------------
echo "[2/4] 检查/安装 Docker..."
if ! command -v docker &>/dev/null; then
    echo " 安装 Docker..."
    # config-manager may be unavailable until dnf-plugins-core is installed,
    # hence the retry with the plugin package first.
    sudo dnf config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.repo 2>/dev/null || {
        sudo dnf install -y dnf-plugins-core
        sudo dnf config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.repo
    }
    sudo dnf install -y docker-ce docker-ce-cli containerd.io docker-compose-plugin
    sudo systemctl enable docker
    sudo systemctl start docker
    # Allow the current user to run docker without sudo (effective after re-login)
    sudo usermod -aG docker "$USER" 2>/dev/null || true
    echo " Docker 已安装。请执行 'newgrp docker' 或重新登录以使用 docker 命令。"
else
    echo " Docker 已安装: $(docker --version)"
fi

# -----------------------------------------------------------------------------
# 3. Install Python 3.12: prefer an existing binary, then conda, else build
#    from source.
# -----------------------------------------------------------------------------
echo "[3/4] 安装 Python $PYTHON_VERSION..."

USE_CONDA=false
if command -v python3.12 &>/dev/null; then
    echo " Python 3.12 已安装"
elif command -v conda &>/dev/null; then
    echo " 使用 conda 创建 Python $PYTHON_VERSION 环境..."
    conda create -n shop_agent "python=$PYTHON_VERSION" -y
    USE_CONDA=true
    echo " Conda 环境已创建。请执行: conda activate shop_agent"
    echo " 然后手动执行: pip install -r $PROJECT_ROOT/requirements.txt"
    echo " 跳过 venv 创建..."
else
    echo " 从源码编译 Python $PYTHON_VERSION..."
    sudo dnf groupinstall -y 'Development Tools'
    cd /tmp
    # NOTE(review): assumes the ".0" patch release exists for $PYTHON_VERSION —
    # confirm before overriding PYTHON_VERSION with a value whose .0 is gone.
    PY_URL="https://www.python.org/ftp/python/${PYTHON_VERSION}/Python-${PYTHON_VERSION}.0.tgz"
    PY_TGZ="Python-${PYTHON_VERSION}.0.tgz"
    [ -f "$PY_TGZ" ] || wget -q "$PY_URL" -O "$PY_TGZ"
    tar xzf "$PY_TGZ"
    cd "Python-${PYTHON_VERSION}.0"
    ./configure --enable-optimizations --prefix=/usr/local
    make -j "$(nproc)"
    # altinstall avoids clobbering the system python3 binary
    sudo make altinstall
    cd /tmp
    rm -rf "Python-${PYTHON_VERSION}.0" "$PY_TGZ"
fi

# -----------------------------------------------------------------------------
# 4. Create virtualenv and install dependencies (skipped when conda was used)
# -----------------------------------------------------------------------------
if [ "$USE_CONDA" = true ]; then
    echo "[4/4] 已使用 conda,跳过 venv 创建"
else
    echo "[4/4] 创建虚拟环境与安装 Python 依赖..."

    # Pick the first interpreter whose reported version is acceptable
    PYTHON_BIN=""
    for p in python3.12 python3.11 python3; do
        if command -v "$p" &>/dev/null; then
            VER=$("$p" -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")' 2>/dev/null || echo "0")
            # NOTE(review): "3.1"* already matches 3.10-3.19 (and a literal
            # "3.1"), so the second pattern is redundant — confirm the intended
            # minimum really is 3.10+.
            if [[ "$VER" == "3.1"* ]] || [[ "$VER" == "3.12"* ]]; then
                PYTHON_BIN="$p"
                break
            fi
        fi
    done

    if [ -z "$PYTHON_BIN" ]; then
        echo " 错误: 未找到 Python 3.10+。若使用 conda,请先执行: conda activate shop_agent"
        echo " 然后手动执行: pip install -r $PROJECT_ROOT/requirements.txt"
        exit 1
    fi

    if [ ! -d "$VENV_DIR" ]; then
        echo " 创建虚拟环境: $VENV_DIR"
        "$PYTHON_BIN" -m venv "$VENV_DIR"
    fi

    echo " 激活虚拟环境并安装依赖..."
    # Relax nounset: activate scripts may reference unset variables
    set +u
    source "$VENV_DIR/bin/activate"
    set -u
    pip install -U pip
    pip install -r "$PROJECT_ROOT/requirements.txt"
    echo " Python 依赖安装完成。"
fi

# Create .env from the template on first run
if [ ! -f "$PROJECT_ROOT/.env" ]; then
    echo ""
    echo " 创建 .env 配置文件..."
    cp "$PROJECT_ROOT/.env.example" "$PROJECT_ROOT/.env"
    echo " 请编辑 $PROJECT_ROOT/.env 配置 OPENAI_API_KEY 等参数。"
fi

echo ""
echo "=========================================="
echo "环境准备完成!"
echo "=========================================="
echo "下一步:"
echo " 1. 编辑 .env 配置 OPENAI_API_KEY"
echo " 2. 下载数据: python scripts/download_dataset.py"
echo " 3. 启动 Milvus: ./scripts/run_milvus.sh"
echo " 4. 索引数据: python scripts/index_data.py"
echo " 5. 启动应用: ./scripts/start.sh"
echo ""
echo "激活虚拟环境: source $VENV_DIR/bin/activate"
echo "=========================================="
| 1 | +++ a/scripts/start.sh | ||
| @@ -0,0 +1,63 @@ | @@ -0,0 +1,63 @@ | ||
#!/usr/bin/env bash
# =============================================================================
# OmniShopAgent - startup script
# Starts Milvus, checks CLIP (optional), then launches the Streamlit app.
# =============================================================================
set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
# All three are overridable from the environment
VENV_DIR="${VENV_DIR:-$PROJECT_ROOT/venv}"
STREAMLIT_PORT="${STREAMLIT_PORT:-8501}"
STREAMLIT_HOST="${STREAMLIT_HOST:-0.0.0.0}"

cd "$PROJECT_ROOT"

# Activate the virtualenv when present; `set -u` is relaxed around sourcing
# because activate scripts may reference unset variables.
if [ -d "$VENV_DIR" ]; then
    echo "激活虚拟环境: $VENV_DIR"
    set +u
    source "$VENV_DIR/bin/activate"
    set -u
else
    echo "警告: 未找到虚拟环境 $VENV_DIR,使用当前 Python"
fi

echo "=========================================="
echo "OmniShopAgent 启动"
echo "=========================================="

# 1. Start Milvus via Docker (skipped entirely when Docker is not installed)
if command -v docker &>/dev/null; then
    echo "[1/3] 检查 Milvus..."
    if ! docker ps --format '{{.Names}}' 2>/dev/null | grep -q milvus-standalone; then
        echo " 启动 Milvus (docker compose)..."
        # Try the compose plugin first, then the legacy standalone binary
        docker compose up -d 2>/dev/null || docker-compose up -d 2>/dev/null || {
            echo " 警告: 无法启动 Milvus,请手动执行: docker compose up -d"
        }
        echo " 等待 Milvus 就绪 (30s)..."
        sleep 30
    else
        echo " Milvus 已运行"
    fi
else
    echo "[1/3] 跳过 Milvus: 未安装 Docker"
fi

# 2. CLIP check (optional — only image search needs the CLIP server)
echo "[2/3] 检查 CLIP 服务..."
echo " 提示: 图像搜索需 CLIP。若未启动,请另开终端执行: python -m clip_server launch"
echo " 文本搜索可无需 CLIP。"

# 3. Launch Streamlit; exec replaces this shell so Ctrl+C reaches Streamlit
echo "[3/3] 启动 Streamlit (端口 $STREAMLIT_PORT)..."
echo ""
echo " 访问: http://$STREAMLIT_HOST:$STREAMLIT_PORT"
echo " 按 Ctrl+C 停止"
echo "=========================================="

exec streamlit run app.py \
    --server.port="$STREAMLIT_PORT" \
    --server.address="$STREAMLIT_HOST" \
    --server.headless=true \
    --browser.gatherUsageStats=false
| 1 | +++ a/scripts/stop.sh | ||
| @@ -0,0 +1,46 @@ | @@ -0,0 +1,46 @@ | ||
#!/usr/bin/env bash
# =============================================================================
# OmniShopAgent - stop script
# Terminates the Streamlit process and brings down the Milvus containers.
# =============================================================================
set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
STREAMLIT_PORT="${STREAMLIT_PORT:-8501}"

echo "=========================================="
echo "OmniShopAgent 停止"
echo "=========================================="

# Step 1: stop Streamlit, matching the process by its command line.
echo "[1/2] 停止 Streamlit..."
if ! pgrep -f "streamlit run app.py" >/dev/null 2>&1; then
    echo " Streamlit 未在运行"
else
    pkill -f "streamlit run app.py" 2>/dev/null || true
    echo " Streamlit 已停止"
fi

# Fallback: terminate anything still listening on the Streamlit port.
if command -v lsof &>/dev/null; then
    listener_pid=$(lsof -ti:$STREAMLIT_PORT 2>/dev/null || true)
    if [ -n "$listener_pid" ]; then
        # Intentionally unquoted: lsof may return several PIDs, one per line.
        kill $listener_pid 2>/dev/null || true
        echo " 已终止端口 $STREAMLIT_PORT 上的进程"
    fi
fi

# Step 2: shut down the Milvus stack (optional; requires Docker).
echo "[2/2] 停止 Milvus..."
if ! command -v docker &>/dev/null; then
    echo " Docker 未安装,跳过"
else
    cd "$PROJECT_ROOT"
    # Try the compose plugin first, then the legacy binary; never fail the script.
    docker compose down 2>/dev/null || docker-compose down 2>/dev/null || true
    echo " Milvus 已停止"
fi

echo "=========================================="
echo "OmniShopAgent 已停止"
echo "=========================================="
| 1 | +++ a/技术实现报告.md | ||
| @@ -0,0 +1,624 @@ | @@ -0,0 +1,624 @@ | ||
| 1 | +# OmniShopAgent 项目技术实现报告 | ||
| 2 | + | ||
| 3 | +## 一、项目概述 | ||
| 4 | + | ||
| 5 | +OmniShopAgent 是一个基于 **LangGraph** 和 **ReAct 模式** 的自主多模态时尚购物智能体。系统能够自主决定调用哪些工具、维护对话状态、判断何时回复,实现智能化的商品发现与推荐。 | ||
| 6 | + | ||
| 7 | +### 核心特性 | ||
| 8 | + | ||
| 9 | +- **自主工具选择与执行**:Agent 根据用户意图自主选择并调用工具 | ||
| 10 | +- **多模态搜索**:支持文本搜索 + 图像搜索 | ||
| 11 | +- **对话上下文感知**:多轮对话中保持上下文记忆 | ||
| 12 | +- **实时视觉分析**:基于 VLM 的图片风格分析 | ||
| 13 | + | ||
| 14 | +--- | ||
| 15 | + | ||
| 16 | +## 二、技术栈 | ||
| 17 | + | ||
| 18 | +| 组件 | 技术选型 | | ||
| 19 | +|------|----------| | ||
| 20 | +| 运行环境 | Python 3.12 | | ||
| 21 | +| Agent 框架 | LangGraph 1.x | | ||
| 22 | +| LLM 框架 | LangChain 1.x(支持任意 LLM,默认 gpt-4o-mini) | | ||
| 23 | +| 文本向量 | text-embedding-3-small | | ||
| 24 | +| 图像向量 | CLIP ViT-B/32 | | ||
| 25 | +| 向量数据库 | Milvus | | ||
| 26 | +| 前端 | Streamlit | | ||
| 27 | +| 数据集 | Kaggle Fashion Products | | ||
| 28 | + | ||
| 29 | +--- | ||
| 30 | + | ||
| 31 | +## 三、系统架构 | ||
| 32 | + | ||
| 33 | +### 3.1 整体架构图 | ||
| 34 | + | ||
| 35 | +``` | ||
| 36 | +┌─────────────────────────────────────────────────────────────────┐ | ||
| 37 | +│ Streamlit 前端 (app.py) │ | ||
| 38 | +└─────────────────────────────────────────────────────────────────┘ | ||
| 39 | + │ | ||
| 40 | + ▼ | ||
| 41 | +┌─────────────────────────────────────────────────────────────────┐ | ||
| 42 | +│ ShoppingAgent (shopping_agent.py) │ | ||
| 43 | +│ ┌───────────────────────────────────────────────────────────┐ │ | ||
| 44 | +│ │ LangGraph StateGraph + ReAct Pattern │ │ | ||
| 45 | +│ │ START → Agent → [Has tool_calls?] → Tools → Agent → END │ │ | ||
| 46 | +│ └───────────────────────────────────────────────────────────┘ │ | ||
| 47 | +└─────────────────────────────────────────────────────────────────┘ | ||
| 48 | + │ │ │ | ||
| 49 | + ▼ ▼ ▼ | ||
| 50 | +┌──────────────┐ ┌──────────────────┐ ┌─────────────────────┐ | ||
| 51 | +│ search_ │ │ search_by_image │ │ analyze_image_style │ | ||
| 52 | +│ products │ │ │ │ (OpenAI Vision) │ | ||
| 53 | +└──────┬───────┘ └────────┬─────────┘ └──────────┬───────────┘ | ||
| 54 | + │ │ │ | ||
| 55 | + ▼ ▼ ▼ | ||
| 56 | +┌─────────────────────────────────────────────────────────────────┐ | ||
| 57 | +│ EmbeddingService (embedding_service.py) │ | ||
| 58 | +│ OpenAI API (文本) │ CLIP Server (图像) │ | ||
| 59 | +└─────────────────────────────────────────────────────────────────┘ | ||
| 60 | + │ | ||
| 61 | + ▼ | ||
| 62 | +┌─────────────────────────────────────────────────────────────────┐ | ||
| 63 | +│ MilvusService (milvus_service.py) │ | ||
| 64 | +│ text_embeddings 集合 │ image_embeddings 集合 │ | ||
| 65 | +└─────────────────────────────────────────────────────────────────┘ | ||
| 66 | +``` | ||
| 67 | + | ||
| 68 | +### 3.2 Agent 流程图(LangGraph) | ||
| 69 | + | ||
| 70 | +```mermaid | ||
| 71 | +graph LR | ||
| 72 | + START --> Agent | ||
| 73 | + Agent -->|Has tool_calls| Tools | ||
| 74 | + Agent -->|No tool_calls| END | ||
| 75 | + Tools --> Agent | ||
| 76 | +``` | ||
| 77 | + | ||
| 78 | +--- | ||
| 79 | + | ||
| 80 | +## 四、关键代码实现 | ||
| 81 | + | ||
| 82 | +### 4.1 Agent 核心实现(shopping_agent.py) | ||
| 83 | + | ||
| 84 | +#### 4.1.1 状态定义 | ||
| 85 | + | ||
| 86 | +```python | ||
| 87 | +from typing_extensions import Annotated, TypedDict | ||
| 88 | +from langgraph.graph.message import add_messages | ||
| 89 | + | ||
| 90 | +class AgentState(TypedDict): | ||
| 91 | + """State for the shopping agent with message accumulation""" | ||
| 92 | + messages: Annotated[Sequence[BaseMessage], add_messages] | ||
| 93 | + current_image_path: Optional[str] # Track uploaded image | ||
| 94 | +``` | ||
| 95 | + | ||
| 96 | +- `messages` 使用 `add_messages` 实现消息累加,支持多轮对话 | ||
| 97 | +- `current_image_path` 存储当前上传的图片路径供工具使用 | ||
| 98 | + | ||
| 99 | +#### 4.1.2 LangGraph 图构建 | ||
| 100 | + | ||
| 101 | +```python | ||
| 102 | +def _build_graph(self): | ||
| 103 | + """Build the LangGraph StateGraph""" | ||
| 104 | + | ||
| 105 | + def agent_node(state: AgentState): | ||
| 106 | + """Agent decision node - decides which tools to call or when to respond""" | ||
| 107 | + messages = state["messages"] | ||
| 108 | + if not any(isinstance(m, SystemMessage) for m in messages): | ||
| 109 | + messages = [SystemMessage(content=system_prompt)] + list(messages) | ||
| 110 | + response = self.llm_with_tools.invoke(messages) | ||
| 111 | + return {"messages": [response]} | ||
| 112 | + | ||
| 113 | + tool_node = ToolNode(self.tools) | ||
| 114 | + | ||
| 115 | + def should_continue(state: AgentState): | ||
| 116 | + """Determine if agent should continue or end""" | ||
| 117 | + last_message = state["messages"][-1] | ||
| 118 | + if hasattr(last_message, "tool_calls") and last_message.tool_calls: | ||
| 119 | + return "tools" | ||
| 120 | + return END | ||
| 121 | + | ||
| 122 | + workflow = StateGraph(AgentState) | ||
| 123 | + workflow.add_node("agent", agent_node) | ||
| 124 | + workflow.add_node("tools", tool_node) | ||
| 125 | + workflow.add_edge(START, "agent") | ||
| 126 | + workflow.add_conditional_edges("agent", should_continue, ["tools", END]) | ||
| 127 | + workflow.add_edge("tools", "agent") | ||
| 128 | + | ||
| 129 | + checkpointer = MemorySaver() | ||
| 130 | + return workflow.compile(checkpointer=checkpointer) | ||
| 131 | +``` | ||
| 132 | + | ||
| 133 | +关键点: | ||
| 134 | +- **agent_node**:将消息传入 LLM,由 LLM 决定是否调用工具 | ||
| 135 | +- **should_continue**:若有 `tool_calls` 则进入工具节点,否则结束 | ||
| 136 | +- **MemorySaver**:按 `thread_id` 持久化对话状态 | ||
| 137 | + | ||
| 138 | +#### 4.1.3 System Prompt 设计 | ||
| 139 | + | ||
| 140 | +```python | ||
| 141 | +system_prompt = """You are an intelligent fashion shopping assistant. You can: | ||
| 142 | +1. Search for products by text description (use search_products) | ||
| 143 | +2. Find visually similar products from images (use search_by_image) | ||
| 144 | +3. Analyze image style and attributes (use analyze_image_style) | ||
| 145 | + | ||
| 146 | +When a user asks about products: | ||
| 147 | +- For text queries: use search_products directly | ||
| 148 | +- For image uploads: decide if you need to analyze_image_style first, then search | ||
| 149 | +- You can call multiple tools in sequence if needed | ||
| 150 | +- Always provide helpful, friendly responses | ||
| 151 | + | ||
| 152 | +CRITICAL FORMATTING RULES: | ||
| 153 | +When presenting product results, you MUST use this EXACT format for EACH product: | ||
| 154 | +1. [Product Name] | ||
| 155 | + ID: [Product ID Number] | ||
| 156 | + Category: [Category] | ||
| 157 | + Color: [Color] | ||
| 158 | + Gender: [Gender] | ||
| 159 | + (Include Season, Usage, Relevance if available) | ||
| 160 | +...""" | ||
| 161 | +``` | ||
| 162 | + | ||
| 163 | +通过 system prompt 约束工具使用和输出格式,保证前端可正确解析产品信息。 | ||
| 164 | + | ||
| 165 | +#### 4.1.4 对话入口与流式处理 | ||
| 166 | + | ||
| 167 | +```python | ||
| 168 | +def chat(self, query: str, image_path: Optional[str] = None) -> dict: | ||
| 169 | + # Build input message | ||
| 170 | + message_content = query | ||
| 171 | + if image_path: | ||
| 172 | + message_content = f"{query}\n[User uploaded image: {image_path}]" | ||
| 173 | + | ||
| 174 | + config = {"configurable": {"thread_id": self.session_id}} | ||
| 175 | + input_state = { | ||
| 176 | + "messages": [HumanMessage(content=message_content)], | ||
| 177 | + "current_image_path": image_path, | ||
| 178 | + } | ||
| 179 | + | ||
| 180 | + tool_calls = [] | ||
| 181 | + for event in self.graph.stream(input_state, config=config): | ||
| 182 | + if "agent" in event: | ||
| 183 | + for msg in event["agent"].get("messages", []): | ||
| 184 | + if hasattr(msg, "tool_calls") and msg.tool_calls: | ||
| 185 | + for tc in msg.tool_calls: | ||
| 186 | + tool_calls.append({"name": tc["name"], "args": tc.get("args", {})}) | ||
| 187 | + if "tools" in event: | ||
| 188 | + # 记录工具执行结果 | ||
| 189 | + ... | ||
| 190 | + | ||
| 191 | + final_state = self.graph.get_state(config) | ||
| 192 | + response_text = final_state.values["messages"][-1].content | ||
| 193 | + | ||
| 194 | + return {"response": response_text, "tool_calls": tool_calls, "error": False} | ||
| 195 | +``` | ||
| 196 | + | ||
| 197 | +--- | ||
| 198 | + | ||
| 199 | +### 4.2 搜索工具实现(search_tools.py) | ||
| 200 | + | ||
| 201 | +#### 4.2.1 文本语义搜索 | ||
| 202 | + | ||
| 203 | +```python | ||
| 204 | +@tool | ||
| 205 | +def search_products(query: str, limit: int = 5) -> str: | ||
| 206 | + """Search for fashion products using natural language descriptions.""" | ||
| 207 | + try: | ||
| 208 | + embedding_service = get_embedding_service() | ||
| 209 | + milvus_service = get_milvus_service() | ||
| 210 | + | ||
| 211 | + query_embedding = embedding_service.get_text_embedding(query) | ||
| 212 | + | ||
| 213 | + results = milvus_service.search_similar_text( | ||
| 214 | + query_embedding=query_embedding, | ||
| 215 | + limit=min(limit, 20), | ||
| 216 | + filters=None, | ||
| 217 | + output_fields=[ | ||
| 218 | + "id", "productDisplayName", "gender", "masterCategory", | ||
| 219 | + "subCategory", "articleType", "baseColour", "season", "usage", | ||
| 220 | + ], | ||
| 221 | + ) | ||
| 222 | + | ||
| 223 | + if not results: | ||
| 224 | + return "No products found matching your search." | ||
| 225 | + | ||
| 226 | + output = f"Found {len(results)} product(s):\n\n" | ||
| 227 | + for idx, product in enumerate(results, 1): | ||
| 228 | + output += f"{idx}. {product.get('productDisplayName', 'Unknown Product')}\n" | ||
| 229 | + output += f" ID: {product.get('id', 'N/A')}\n" | ||
| 230 | + output += f" Category: {product.get('masterCategory')} > {product.get('subCategory')} > {product.get('articleType')}\n" | ||
| 231 | + output += f" Color: {product.get('baseColour')}\n" | ||
| 232 | + output += f" Gender: {product.get('gender')}\n" | ||
| 233 | + if "distance" in product: | ||
| 234 | + similarity = 1 - product["distance"] | ||
| 235 | + output += f" Relevance: {similarity:.2%}\n" | ||
| 236 | + output += "\n" | ||
| 237 | + | ||
| 238 | + return output.strip() | ||
| 239 | + except Exception as e: | ||
| 240 | + return f"Error searching products: {str(e)}" | ||
| 241 | +``` | ||
| 242 | + | ||
| 243 | +#### 4.2.2 图像相似度搜索 | ||
| 244 | + | ||
| 245 | +```python | ||
| 246 | +@tool | ||
| 247 | +def search_by_image(image_path: str, limit: int = 5) -> str: | ||
| 248 | + """Find similar fashion products using an image.""" | ||
| 249 | + if not Path(image_path).exists(): | ||
| 250 | + return f"Error: Image file not found at '{image_path}'" | ||
| 251 | + | ||
| 252 | + embedding_service = get_embedding_service() | ||
| 253 | + milvus_service = get_milvus_service() | ||
| 254 | + | ||
| 255 | + if not embedding_service.clip_client: | ||
| 256 | + embedding_service.connect_clip() | ||
| 257 | + | ||
| 258 | + image_embedding = embedding_service.get_image_embedding(image_path) | ||
| 259 | + | ||
| 260 | + results = milvus_service.search_similar_images( | ||
| 261 | + query_embedding=image_embedding, | ||
| 262 | + limit=min(limit + 1, 21), | ||
| 263 | + output_fields=[...], | ||
| 264 | + ) | ||
| 265 | + | ||
| 266 | + # 过滤掉查询图像本身(如上传的是商品库中的图) | ||
| 267 | + query_id = Path(image_path).stem | ||
| 268 | + filtered_results = [r for r in results if Path(r.get("image_path", "")).stem != query_id] | ||
| 269 | + filtered_results = filtered_results[:limit] | ||
| 270 | + | ||
| 271 | + | ||
| 272 | +``` | ||
| 273 | + | ||
| 274 | +#### 4.2.3 视觉分析(VLM) | ||
| 275 | + | ||
| 276 | +```python | ||
| 277 | +@tool | ||
| 278 | +def analyze_image_style(image_path: str) -> str: | ||
| 279 | + """Analyze a fashion product image using AI vision to extract detailed style information.""" | ||
| 280 | +    with open(image_path, "rb") as image_file: | ||
| 281 | + image_data = base64.b64encode(image_file.read()).decode("utf-8") | ||
| 282 | + | ||
| 283 | + prompt = """Analyze this fashion product image and provide a detailed description. | ||
| 284 | +Include: | ||
| 285 | +- Product type (e.g., shirt, dress, shoes, pants, bag) | ||
| 286 | +- Primary colors | ||
| 287 | +- Style/design (e.g., casual, formal, sporty, vintage, modern) | ||
| 288 | +- Pattern or texture (e.g., plain, striped, checked, floral) | ||
| 289 | +- Key features (e.g., collar type, sleeve length, fit) | ||
| 290 | +- Material appearance (if obvious, e.g., denim, cotton, leather) | ||
| 291 | +- Suitable occasion (e.g., office wear, party, casual, sports) | ||
| 292 | +Provide a comprehensive yet concise description (3-4 sentences).""" | ||
| 293 | + | ||
| 294 | + client = get_openai_client() | ||
| 295 | + response = client.chat.completions.create( | ||
| 296 | + model="gpt-4o-mini", | ||
| 297 | + messages=[{ | ||
| 298 | + "role": "user", | ||
| 299 | + "content": [ | ||
| 300 | + {"type": "text", "text": prompt}, | ||
| 301 | + {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_data}", "detail": "high"}}, | ||
| 302 | + ], | ||
| 303 | + }], | ||
| 304 | + max_tokens=500, | ||
| 305 | + temperature=0.3, | ||
| 306 | + ) | ||
| 307 | + | ||
| 308 | + return response.choices[0].message.content.strip() | ||
| 309 | +``` | ||
| 310 | + | ||
| 311 | +--- | ||
| 312 | + | ||
| 313 | +### 4.3 向量服务实现 | ||
| 314 | + | ||
| 315 | +#### 4.3.1 EmbeddingService(embedding_service.py) | ||
| 316 | + | ||
| 317 | +```python | ||
| 318 | +class EmbeddingService: | ||
| 319 | + def get_text_embedding(self, text: str) -> List[float]: | ||
| 320 | + """OpenAI text-embedding-3-small""" | ||
| 321 | + response = self.openai_client.embeddings.create( | ||
| 322 | + input=text, model=self.text_embedding_model | ||
| 323 | + ) | ||
| 324 | + return response.data[0].embedding | ||
| 325 | + | ||
| 326 | + def get_image_embedding(self, image_path: Union[str, Path]) -> List[float]: | ||
| 327 | + """CLIP 图像向量""" | ||
| 328 | + if not self.clip_client: | ||
| 329 | + raise RuntimeError("CLIP client not connected. Call connect_clip() first.") | ||
| 330 | + result = self.clip_client.encode([str(image_path)]) | ||
| 331 | + if isinstance(result, np.ndarray): | ||
| 332 | + embedding = result[0].tolist() if len(result.shape) > 1 else result.tolist() | ||
| 333 | + else: | ||
| 334 | + embedding = result[0].embedding.tolist() | ||
| 335 | + return embedding | ||
| 336 | + | ||
| 337 | + def get_text_embeddings_batch(self, texts: List[str], batch_size: int = 100) -> List[List[float]]: | ||
| 338 | + """批量文本嵌入,用于索引""" | ||
| 339 | + for i in range(0, len(texts), batch_size): | ||
| 340 | + batch = texts[i : i + batch_size] | ||
| 341 | + response = self.openai_client.embeddings.create(input=batch, ...) | ||
| 342 | + embeddings = [item.embedding for item in response.data] | ||
| 343 | + all_embeddings.extend(embeddings) | ||
| 344 | + return all_embeddings | ||
| 345 | +``` | ||
| 346 | + | ||
| 347 | +#### 4.3.2 MilvusService(milvus_service.py) | ||
| 348 | + | ||
| 349 | +**文本集合 Schema:** | ||
| 350 | + | ||
| 351 | +```python | ||
| 352 | +schema = MilvusClient.create_schema(auto_id=False, enable_dynamic_field=True) | ||
| 353 | +schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True) | ||
| 354 | +schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=2000) | ||
| 355 | +schema.add_field(field_name="embedding", datatype=DataType.FLOAT_VECTOR, dim=self.text_dim) # 1536 | ||
| 356 | +schema.add_field(field_name="productDisplayName", datatype=DataType.VARCHAR, max_length=500) | ||
| 357 | +schema.add_field(field_name="gender", datatype=DataType.VARCHAR, max_length=50) | ||
| 358 | +schema.add_field(field_name="masterCategory", datatype=DataType.VARCHAR, max_length=100) | ||
| 359 | +# ... 更多元数据字段 | ||
| 360 | +``` | ||
| 361 | + | ||
| 362 | +**图像集合 Schema:** | ||
| 363 | + | ||
| 364 | +```python | ||
| 365 | +schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True) | ||
| 366 | +schema.add_field(field_name="image_path", datatype=DataType.VARCHAR, max_length=500) | ||
| 367 | +schema.add_field(field_name="embedding", datatype=DataType.FLOAT_VECTOR, dim=self.image_dim) # 512 | ||
| 368 | +# ... 产品元数据 | ||
| 369 | +``` | ||
| 370 | + | ||
| 371 | +**相似度搜索:** | ||
| 372 | + | ||
| 373 | +```python | ||
| 374 | +def search_similar_text(self, query_embedding, limit=10, output_fields=None): | ||
| 375 | + results = self.client.search( | ||
| 376 | + collection_name=self.text_collection_name, | ||
| 377 | + data=[query_embedding], | ||
| 378 | + limit=limit, | ||
| 379 | + output_fields=output_fields, | ||
| 380 | + ) | ||
| 381 | + formatted_results = [] | ||
| 382 | + for hit in results[0]: | ||
| 383 | + result = {"id": hit.get("id"), "distance": hit.get("distance")} | ||
| 384 | + entity = hit.get("entity", {}) | ||
| 385 | + for field in output_fields: | ||
| 386 | + if field in entity: | ||
| 387 | + result[field] = entity.get(field) | ||
| 388 | + formatted_results.append(result) | ||
| 389 | + return formatted_results | ||
| 390 | +``` | ||
| 391 | + | ||
| 392 | +--- | ||
| 393 | + | ||
| 394 | +### 4.4 数据索引脚本(index_data.py) | ||
| 395 | + | ||
| 396 | +#### 4.4.1 产品数据加载 | ||
| 397 | + | ||
| 398 | +```python | ||
| 399 | +def _load_products_from_csv(self) -> Dict[int, Dict[str, Any]]: | ||
| 400 | + products = {} | ||
| 401 | + # 加载 images.csv 映射 | ||
| 402 | + with open(self.images_csv, "r") as f: | ||
| 403 | + images_dict = {int(row["filename"].split(".")[0]): row["link"] for row in csv.DictReader(f)} | ||
| 404 | + | ||
| 405 | + # 加载 styles.csv | ||
| 406 | + with open(self.styles_csv, "r") as f: | ||
| 407 | + for row in csv.DictReader(f): | ||
| 408 | + product_id = int(row["id"]) | ||
| 409 | + products[product_id] = { | ||
| 410 | + "id": product_id, | ||
| 411 | + "gender": row.get("gender", ""), | ||
| 412 | + "masterCategory": row.get("masterCategory", ""), | ||
| 413 | + "subCategory": row.get("subCategory", ""), | ||
| 414 | + "articleType": row.get("articleType", ""), | ||
| 415 | + "baseColour": row.get("baseColour", ""), | ||
| 416 | + "season": row.get("season", ""), | ||
| 417 | + "usage": row.get("usage", ""), | ||
| 418 | + "productDisplayName": row.get("productDisplayName", ""), | ||
| 419 | + "imagePath": f"{product_id}.jpg", | ||
| 420 | + } | ||
| 421 | + return products | ||
| 422 | +``` | ||
| 423 | + | ||
| 424 | +#### 4.4.2 文本索引 | ||
| 425 | + | ||
| 426 | +```python | ||
| 427 | +def _create_product_text(self, product: Dict[str, Any]) -> str: | ||
| 428 | + """构造产品文本用于 embedding""" | ||
| 429 | + parts = [ | ||
| 430 | + product.get("productDisplayName", ""), | ||
| 431 | + f"Gender: {product.get('gender', '')}", | ||
| 432 | + f"Category: {product.get('masterCategory', '')} > {product.get('subCategory', '')}", | ||
| 433 | + f"Type: {product.get('articleType', '')}", | ||
| 434 | + f"Color: {product.get('baseColour', '')}", | ||
| 435 | + f"Season: {product.get('season', '')}", | ||
| 436 | + f"Usage: {product.get('usage', '')}", | ||
| 437 | + ] | ||
| 438 | + return " | ".join([p for p in parts if p and p != "Gender: " and p != "Color: "]) | ||
| 439 | +``` | ||
| 440 | + | ||
| 441 | +#### 4.4.3 批量索引流程 | ||
| 442 | + | ||
| 443 | +```python | ||
| 444 | +# 文本索引 | ||
| 445 | +texts = [self._create_product_text(p) for p in products] | ||
| 446 | +embeddings = self.embedding_service.get_text_embeddings_batch(texts, batch_size=50) | ||
| 447 | +milvus_data = [{ | ||
| 448 | + "id": product_id, | ||
| 449 | + "text": text[:2000], | ||
| 450 | + "embedding": embedding, | ||
| 451 | + "productDisplayName": product["productDisplayName"][:500], | ||
| 452 | + "gender": product["gender"][:50], | ||
| 453 | + # ... 其他元数据 | ||
| 454 | +} for product_id, text, embedding in zip(...)] | ||
| 455 | +self.milvus_service.insert_text_embeddings(milvus_data) | ||
| 456 | + | ||
| 457 | +# 图像索引 | ||
| 458 | +image_paths = [self.image_dir / p["imagePath"] for p in products] | ||
| 459 | +embeddings = self.embedding_service.get_image_embeddings_batch(image_paths, batch_size=32) | ||
| 460 | +# 类似插入 image_embeddings 集合 | ||
| 461 | +``` | ||
| 462 | + | ||
| 463 | +--- | ||
| 464 | + | ||
| 465 | +### 4.5 Streamlit 前端(app.py) | ||
| 466 | + | ||
| 467 | +#### 4.5.1 会话与 Agent 初始化 | ||
| 468 | + | ||
| 469 | +```python | ||
| 470 | +def initialize_session(): | ||
| 471 | + if "session_id" not in st.session_state: | ||
| 472 | + st.session_state.session_id = str(uuid.uuid4()) | ||
| 473 | + if "shopping_agent" not in st.session_state: | ||
| 474 | + st.session_state.shopping_agent = ShoppingAgent(session_id=st.session_state.session_id) | ||
| 475 | + if "messages" not in st.session_state: | ||
| 476 | + st.session_state.messages = [] | ||
| 477 | + if "uploaded_image" not in st.session_state: | ||
| 478 | + st.session_state.uploaded_image = None | ||
| 479 | +``` | ||
| 480 | + | ||
| 481 | +#### 4.5.2 产品信息解析 | ||
| 482 | + | ||
| 483 | +```python | ||
| 484 | +def extract_products_from_response(response: str) -> list: | ||
| 485 | + """从 Agent 回复中解析产品信息""" | ||
| 486 | + products = [] | ||
| 487 | + for line in response.split("\n"): | ||
| 488 | + if re.match(r"^\*?\*?\d+\.\s+", line): | ||
| 489 | + if current_product: | ||
| 490 | + products.append(current_product) | ||
| 491 | + current_product = {"name": re.sub(r"^\*?\*?\d+\.\s+", "", line).replace("**", "").strip()} | ||
| 492 | + elif "ID:" in line: | ||
| 493 | + id_match = re.search(r"(?:ID|id):\s*(\d+)", line) | ||
| 494 | + if id_match: | ||
| 495 | + current_product["id"] = id_match.group(1) | ||
| 496 | + elif "Category:" in line: | ||
| 497 | + cat_match = re.search(r"Category:\s*(.+?)(?:\n|$)", line) | ||
| 498 | + if cat_match: | ||
| 499 | + current_product["category"] = cat_match.group(1).strip() | ||
| 500 | + # ... Color, Gender, Season, Usage, Similarity/Relevance | ||
| 501 | + return products | ||
| 502 | +``` | ||
| 503 | + | ||
| 504 | +#### 4.5.3 多轮对话中的图片引用 | ||
| 505 | + | ||
| 506 | +```python | ||
| 507 | +# 用户输入 "make them formal" 时,若上一条消息有图片,则引用该图片 | ||
| 508 | +if any(ref in query_lower for ref in ["this", "that", "the image", "it"]): | ||
| 509 | + for msg in reversed(st.session_state.messages): | ||
| 510 | + if msg.get("role") == "user" and msg.get("image_path"): | ||
| 511 | + image_path = msg["image_path"] | ||
| 512 | + break | ||
| 513 | +``` | ||
| 514 | + | ||
| 515 | +--- | ||
| 516 | + | ||
| 517 | +### 4.6 配置管理(config.py) | ||
| 518 | + | ||
| 519 | +```python | ||
| 520 | +class Settings(BaseSettings): | ||
| 521 | + openai_api_key: str | ||
| 522 | + openai_model: str = "gpt-4o-mini" | ||
| 523 | + openai_embedding_model: str = "text-embedding-3-small" | ||
| 524 | + clip_server_url: str = "grpc://localhost:51000" | ||
| 525 | + milvus_uri: str = "http://localhost:19530" | ||
| 526 | + text_collection_name: str = "text_embeddings" | ||
| 527 | + image_collection_name: str = "image_embeddings" | ||
| 528 | + text_dim: int = 1536 | ||
| 529 | + image_dim: int = 512 | ||
| 530 | + | ||
| 531 | + @property | ||
| 532 | + def milvus_uri_absolute(self) -> str: | ||
| 533 | + """支持 Milvus Standalone 和 Milvus Lite""" | ||
| 534 | + if self.milvus_uri.startswith(("http://", "https://")): | ||
| 535 | + return self.milvus_uri | ||
| 536 | + if self.milvus_uri.startswith("./"): | ||
| 537 | + return os.path.join(base_dir, self.milvus_uri[2:]) | ||
| 538 | + return self.milvus_uri | ||
| 539 | + | ||
| 540 | + class Config: | ||
| 541 | + env_file = ".env" | ||
| 542 | +``` | ||
| 543 | + | ||
| 544 | +--- | ||
| 545 | + | ||
| 546 | +## 五、部署与运行 | ||
| 547 | + | ||
| 548 | +### 5.1 依赖服务 | ||
| 549 | + | ||
| 550 | +```yaml | ||
| 551 | +# docker-compose.yml 提供 | ||
| 552 | +- etcd: 元数据存储 | ||
| 553 | +- minio: 对象存储 | ||
| 554 | +- milvus-standalone: 向量数据库 | ||
| 555 | +- attu: Milvus 管理界面 | ||
| 556 | +``` | ||
| 557 | + | ||
| 558 | +### 5.2 启动流程 | ||
| 559 | + | ||
| 560 | +```bash | ||
| 561 | +# 1. 环境 | ||
| 562 | +pip install -r requirements.txt | ||
| 563 | +cp .env.example .env # 配置 OPENAI_API_KEY | ||
| 564 | + | ||
| 565 | +# 2. 下载数据 | ||
| 566 | +python scripts/download_dataset.py # Kaggle Fashion Product Images Dataset | ||
| 567 | + | ||
| 568 | +# 3. 启动 CLIP 服务(需单独运行) | ||
| 569 | +python -m clip_server | ||
| 570 | + | ||
| 571 | +# 4. 启动 Milvus | ||
| 572 | +docker-compose up | ||
| 573 | + | ||
| 574 | +# 5. 索引数据 | ||
| 575 | +python scripts/index_data.py | ||
| 576 | + | ||
| 577 | +# 6. 启动应用 | ||
| 578 | +streamlit run app.py | ||
| 579 | +``` | ||
| 580 | + | ||
| 581 | +--- | ||
| 582 | + | ||
| 583 | +## 六、典型交互流程 | ||
| 584 | + | ||
| 585 | +| 场景 | 用户输入 | Agent 行为 | 工具调用 | | ||
| 586 | +|------|----------|------------|----------| | ||
| 587 | +| 文本搜索 | "winter coats for women" | 直接文本搜索 | `search_products("winter coats women")` | | ||
| 588 | +| 图像搜索 | [上传图片] "find similar" | 图像相似度搜索 | `search_by_image(path)` | | ||
| 589 | +| 风格分析+搜索 | [上传复古夹克] "what style? find matching pants" | 先分析风格再搜索 | `analyze_image_style(path)` → `search_products("vintage pants casual")` | | ||
| 590 | +| 多轮上下文 | [第1轮] "show me red dresses"<br>[第2轮] "make them formal" | 结合上下文 | `search_products("red formal dresses")` | | ||
| 591 | + | ||
| 592 | +--- | ||
| 593 | + | ||
| 594 | +## 七、设计要点总结 | ||
| 595 | + | ||
| 596 | +1. **ReAct 模式**:Agent 自主决定何时调用工具、调用哪些工具、是否继续调用。 | ||
| 597 | +2. **LangGraph 状态图**:`START → Agent → [条件] → Tools → Agent → END`,支持多轮工具调用。 | ||
| 598 | +3. **多模态**:文本 + 图像 + VLM 分析,覆盖文本搜索、以图搜图、风格理解。 | ||
| 599 | +4. **双向量集合**:Milvus 中 text_embeddings / image_embeddings 分别存储,支持不同模态的检索。 | ||
| 600 | +5. **会话持久化**:`MemorySaver` + `thread_id` 实现多轮对话记忆。 | ||
| 601 | +6. **格式约束**:System prompt 严格限制产品输出格式,便于前端解析和展示。 | ||
| 602 | + | ||
| 603 | +--- | ||
| 604 | + | ||
| 605 | +## 八、附录:项目结构 | ||
| 606 | + | ||
| 607 | +``` | ||
| 608 | +OmniShopAgent/ | ||
| 609 | +├── app/ | ||
| 610 | +│ ├── agents/ | ||
| 611 | +│ │ └── shopping_agent.py | ||
| 612 | +│ ├── config.py | ||
| 613 | +│ ├── services/ | ||
| 614 | +│ │ ├── embedding_service.py | ||
| 615 | +│ │ └── milvus_service.py | ||
| 616 | +│ └── tools/ | ||
| 617 | +│ └── search_tools.py | ||
| 618 | +├── scripts/ | ||
| 619 | +│ ├── download_dataset.py | ||
| 620 | +│ └── index_data.py | ||
| 621 | +├── app.py | ||
| 622 | +├── docker-compose.yml | ||
| 623 | +└── requirements.txt | ||
| 624 | +``` |