Commit e7f2b2409cd4db20be799b9ce9995e939c0d4807
0 parents · first commit
Showing 30 changed files with 4,836 additions and 0 deletions.
| 1 | +++ a/.env.example | |
| ... | ... | @@ -0,0 +1,47 @@ |
| 1 | +# ==================== | |
| 2 | +# OpenAI Configuration | |
| 3 | +# ==================== | |
| 4 | +OPENAI_API_KEY= | |
| 5 | +OPENAI_MODEL=gpt-4o-mini | |
| 6 | +OPENAI_EMBEDDING_MODEL=text-embedding-3-small | |
| 7 | +OPENAI_TEMPERATURE=1 | |
| 8 | +OPENAI_MAX_TOKENS=1000 | |
| 9 | + | |
| 10 | +# ==================== | |
| 11 | +# CLIP Server Configuration | |
| 12 | +# ==================== | |
| 13 | +CLIP_SERVER_URL=grpc://localhost:51000 | |
| 14 | + | |
| 15 | +# ==================== | |
| 16 | +# Milvus Configuration | |
| 17 | +# ==================== | |
| 18 | +MILVUS_HOST=localhost | |
| 19 | +MILVUS_PORT=19530 | |
| 20 | + | |
| 21 | +# Collection settings | |
| 22 | +TEXT_COLLECTION_NAME=text_embeddings | |
| 23 | +IMAGE_COLLECTION_NAME=image_embeddings | |
| 24 | +TEXT_DIM=1536 | |
| 25 | +IMAGE_DIM=512 | |
| 26 | + | |
| 27 | +# ==================== | |
| 28 | +# Search Configuration | |
| 29 | +# ==================== | |
| 30 | +TOP_K_RESULTS=30 | |
| 31 | +SIMILARITY_THRESHOLD=0.6 | |
| 32 | + | |
| 33 | +# ==================== | |
| 34 | +# Application Configuration | |
| 35 | +# ==================== | |
| 36 | +APP_HOST=0.0.0.0 | |
| 37 | +APP_PORT=8000 | |
| 38 | +DEBUG=true | |
| 39 | +LOG_LEVEL=INFO | |
| 40 | + | |
| 41 | +# ==================== | |
| 42 | +# Data Paths | |
| 43 | +# ==================== | |
| 44 | +RAW_DATA_PATH=./data/raw | |
| 45 | +PROCESSED_DATA_PATH=./data/processed | |
| 46 | +IMAGE_DATA_PATH=./data/images | |
| 47 | + | ... | ... |
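These variables would typically be read into a typed settings object at application startup; below is a minimal sketch using only the standard library (the `Settings` dataclass and `load_settings` helper are illustrative, not part of this commit — the defaults mirror `.env.example`):

```python
import os
from dataclasses import dataclass


@dataclass
class Settings:
    openai_model: str
    milvus_host: str
    milvus_port: int
    top_k_results: int
    similarity_threshold: float


def load_settings() -> Settings:
    # Fall back to the defaults shown in .env.example when a variable is unset.
    return Settings(
        openai_model=os.getenv("OPENAI_MODEL", "gpt-4o-mini"),
        milvus_host=os.getenv("MILVUS_HOST", "localhost"),
        milvus_port=int(os.getenv("MILVUS_PORT", "19530")),
        top_k_results=int(os.getenv("TOP_K_RESULTS", "30")),
        similarity_threshold=float(os.getenv("SIMILARITY_THRESHOLD", "0.6")),
    )


settings = load_settings()
```

A loader like this keeps type conversions (`int`, `float`) in one place instead of scattering `os.getenv` calls through the codebase.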
| 1 | +++ a/.gitignore | |
| ... | ... | @@ -0,0 +1,83 @@ |
| 1 | +# Python | |
| 2 | +__pycache__/ | |
| 3 | +*.py[cod] | |
| 4 | +*$py.class | |
| 5 | +*.so | |
| 6 | +.Python | |
| 7 | +build/ | |
| 8 | +develop-eggs/ | |
| 9 | +dist/ | |
| 10 | +downloads/ | |
| 11 | +eggs/ | |
| 12 | +.eggs/ | |
| 13 | +lib/ | |
| 14 | +lib64/ | |
| 15 | +parts/ | |
| 16 | +sdist/ | |
| 17 | +var/ | |
| 18 | +wheels/ | |
| 19 | +*.egg-info/ | |
| 20 | +.installed.cfg | |
| 21 | +*.egg | |
| 22 | +MANIFEST | |
| 23 | + | |
| 24 | +# Virtual Environment | |
| 25 | +venv/ | |
| 26 | +env/ | |
| 27 | +ENV/ | |
| 28 | +.venv | |
| 29 | + | |
| 30 | +# Environment Variables | |
| 31 | +.env | |
| 32 | +*.env | |
| 33 | +!.env.example | |
| 34 | + | |
| 35 | +# IDEs | |
| 36 | +.vscode/ | |
| 37 | +.idea/ | |
| 38 | +.cursor/ | |
| 39 | +*.swp | |
| 40 | +*.swo | |
| 41 | +*~ | |
| 42 | +.DS_Store | |
| 43 | + | |
| 44 | +# Data Files - ignore everything in data/ except .gitkeep files | |
| 45 | +data/** | |
| 46 | +!data/ | |
| 47 | +!data/raw/ | |
| 48 | +!data/processed/ | |
| 49 | +!data/images/ | |
| 50 | +!data/**/.gitkeep | |
| 51 | + | |
| 52 | +# Database | |
| 53 | +*.db | |
| 54 | +*.sqlite | |
| 55 | +*.sqlite3 | |
| 56 | +data/milvus_lite.db | |
| 57 | + | |
| 58 | +# Docker volumes | |
| 59 | +volumes/ | |
| 60 | + | |
| 61 | +# Logs | |
| 62 | +*.log | |
| 63 | +logs/ | |
| 64 | +nohup.out | |
| 65 | + | |
| 66 | +# Testing | |
| 67 | +.pytest_cache/ | |
| 68 | +.coverage | |
| 69 | +htmlcov/ | |
| 70 | +.tox/ | |
| 71 | + | |
| 72 | +# Jupyter | |
| 73 | +.ipynb_checkpoints/ | |
| 74 | +*.ipynb | |
| 75 | + | |
| 76 | +# Model caches | |
| 77 | +.cache/ | |
| 78 | +models/ | |
| 79 | + | |
| 80 | +# Temporary files | |
| 81 | +tmp/ | |
| 82 | +temp/ | |
| 83 | +*.tmp | ... | ... |
| 1 | +++ a/LICENSE | |
| ... | ... | @@ -0,0 +1,21 @@ |
| 1 | +MIT License | |
| 2 | + | |
| 3 | +Copyright (c) 2025 zhangruotian | |
| 4 | + | |
| 5 | +Permission is hereby granted, free of charge, to any person obtaining a copy | |
| 6 | +of this software and associated documentation files (the "Software"), to deal | |
| 7 | +in the Software without restriction, including without limitation the rights | |
| 8 | +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
| 9 | +copies of the Software, and to permit persons to whom the Software is | |
| 10 | +furnished to do so, subject to the following conditions: | |
| 11 | + | |
| 12 | +The above copyright notice and this permission notice shall be included in all | |
| 13 | +copies or substantial portions of the Software. | |
| 14 | + | |
| 15 | +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
| 16 | +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
| 17 | +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
| 18 | +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
| 19 | +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
| 20 | +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
| 21 | +SOFTWARE. | ... | ... |
| 1 | +++ a/README.md | |
| ... | ... | @@ -0,0 +1,161 @@ |
| 1 | +# OmniShopAgent | |
| 2 | + | |
| 3 | +An autonomous multi-modal fashion shopping agent powered by **LangGraph** and the **ReAct pattern**. | |
| 4 | + | |
| 5 | +## Demo | |
| 6 | + | |
| 7 | +📄 **[demo.pdf](./demo.pdf)** | |
| 8 | + | |
| 9 | +## Overview | |
| 10 | + | |
| 11 | +OmniShopAgent autonomously decides which tools to call, maintains conversation state, and determines when to respond. Built with **LangGraph**, it uses agentic patterns for intelligent product discovery. | |
| 12 | + | |
| 13 | +**Key Features:** | |
| 14 | +- Autonomous tool selection and execution | |
| 15 | +- Multi-modal search (text + image) | |
| 16 | +- Conversational context awareness | |
| 17 | +- Real-time visual analysis | |
| 18 | + | |
| 19 | +## Tech Stack | |
| 20 | + | |
| 21 | +| Component | Technology | | |
| 22 | +|-----------|-----------| | |
| 23 | +| **Agent Framework** | LangGraph | | |
| 24 | +| **LLM** | Any LLM supported by LangChain | | |
| 25 | +| **Text Embedding** | text-embedding-3-small | | |
| 26 | +| **Image Embedding** | CLIP ViT-B/32 | | |
| 27 | +| **Vector Database** | Milvus | | |
| 28 | +| **Frontend** | Streamlit | | |
| 29 | +| **Dataset** | Kaggle Fashion Products | | |
| 30 | + | |
| 31 | +## Architecture | |
| 32 | + | |
| 33 | +**Agent Flow:** | |
| 34 | + | |
| 35 | +```mermaid | |
| 36 | +graph LR | |
| 37 | + START --> Agent | |
| 38 | + Agent -->|Has tool_calls| Tools | |
| 39 | + Agent -->|No tool_calls| END | |
| 40 | + Tools --> Agent | |
| 41 | + | |
| 42 | + subgraph "Agent Node" | |
| 43 | + A[Receive Messages] --> B[LLM Reasoning] | |
| 44 | + B --> C{Need Tools?} | |
| 45 | + C -->|Yes| D[Generate tool_calls] | |
| 46 | + C -->|No| E[Generate Response] | |
| 47 | + end | |
| 48 | + | |
| 49 | + subgraph "Tool Node" | |
| 50 | + F[Execute Tools] --> G[Return ToolMessage] | |
| 51 | + end | |
| 52 | +``` | |
| 53 | + | |
| 54 | +**Available Tools:** | |
| 55 | +- `search_products(query)` - Text-based semantic search | |
| 56 | +- `search_by_image(image_path)` - Visual similarity search | |
| 57 | +- `analyze_image_style(image_path)` - VLM style analysis | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
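Stripped of the framework, the agent/tool cycle in the diagram above can be sketched in plain Python (the `react_loop` function and the LLM/tool signatures are illustrative stand-ins; the actual project wires this up with LangGraph):

```python
from typing import Callable


def react_loop(llm: Callable, tools: dict, messages: list, max_steps: int = 5) -> str:
    """Agent node: ask the LLM; if it emits tool_calls, run the Tool node
    and loop back; if it emits none, its content is the final response (END)."""
    for _ in range(max_steps):
        reply = llm(messages)                 # Agent node: LLM reasoning
        tool_calls = reply.get("tool_calls", [])
        if not tool_calls:                    # No tool_calls -> END
            return reply["content"]
        for call in tool_calls:               # Tool node: execute, return result
            result = tools[call["name"]](**call["args"])
            messages.append({"role": "tool", "content": result})
    return "Reached max_steps without a final answer."
```

Because the accumulated `messages` list is fed back into every LLM call, multi-turn context (e.g. "make them formal" after "show me red dresses") falls out of the same loop for free.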
| 61 | +## Examples | |
| 62 | + | |
| 63 | +**Text Search:** | |
| 64 | +``` | |
| 65 | +User: "winter coats for women" | |
| 66 | +Agent: search_products("winter coats women") → Returns 5 products | |
| 67 | +``` | |
| 68 | + | |
| 69 | +**Image Upload:** | |
| 70 | +``` | |
| 71 | +User: [uploads sneaker photo] "find similar" | |
| 72 | +Agent: search_by_image(path) → Returns visually similar shoes | |
| 73 | +``` | |
| 74 | + | |
| 75 | +**Style Analysis + Search:** | |
| 76 | +``` | |
| 77 | +User: [uploads vintage jacket] "what style is this? find matching pants" | |
| 78 | +Agent: analyze_image_style(path) → "Vintage denim bomber..." | |
| 79 | + search_products("vintage pants casual") → Returns matching items | |
| 80 | +``` | |
| 81 | + | |
| 82 | +**Multi-turn Context:** | |
| 83 | +``` | |
| 84 | +Turn 1: "show me red dresses" | |
| 85 | +Agent: search_products("red dresses") → Results | |
| 86 | + | |
| 87 | +Turn 2: "make them formal" | |
| 88 | +Agent: [remembers context] → search_products("red formal dresses") → Results | |
| 89 | +``` | |
| 90 | + | |
| 91 | +**Complex Reasoning:** | |
| 92 | +``` | |
| 93 | +User: [uploads office outfit] "I like the shirt but need something more casual" | |
| 94 | +Agent: analyze_image_style(path) → Extracts shirt details | |
| 95 | + search_products("casual shirt [color] [style]") → Returns casual alternatives | |
| 96 | +``` | |
| 97 | + | |
| 98 | +## Installation | |
| 99 | + | |
| 100 | +**Prerequisites:** | |
| 101 | +- Python 3.12+ (LangChain 1.x requires Python 3.10+) | |
| 102 | +- OpenAI API Key | |
| 103 | +- Docker & Docker Compose | |
| 104 | + | |
| 105 | +### 1. Setup Environment | |
| 106 | +```bash | |
| 107 | +# Clone and install dependencies | |
| 108 | +git clone <repository-url> | |
| 109 | +cd OmniShopAgent | |
| 110 | +python -m venv venv | |
| 111 | +source venv/bin/activate # Windows: venv\Scripts\activate | |
| 112 | +pip install -r requirements.txt | |
| 113 | + | |
| 114 | +# Configure environment variables | |
| 115 | +cp .env.example .env | |
| 116 | +# Edit .env and add your OPENAI_API_KEY | |
| 117 | +``` | |
| 118 | + | |
| 119 | +### 2. Download Dataset | |
| 120 | +Download the [Fashion Product Images Dataset](https://www.kaggle.com/datasets/paramaggarwal/fashion-product-images-dataset) from Kaggle and extract it to `./data/`, or fetch it with the helper script: | |
| 121 | + | |
| 122 | +```bash | |
| 123 | +python scripts/download_dataset.py | |
| 124 | +``` | |
| 125 | + | |
| 126 | +Expected structure: | |
| 127 | +``` | |
| 128 | +data/ | |
| 129 | +├── images/ # ~44k product images | |
| 130 | +├── styles.csv # Product metadata | |
| 131 | +└── images.csv # Image filenames | |
| 132 | +``` | |
| 133 | + | |
| 134 | +### 3. Start Services | |
| 135 | + | |
| 136 | +```bash | |
| 137 | +docker-compose up -d   # start Milvus in the background | |
| 138 | +python -m clip_server  # start the CLIP embedding server | |
| 139 | +``` | |
| 140 | + | |
| 141 | + | |
| 142 | +### 4. Index Data | |
| 143 | + | |
| 144 | +```bash | |
| 145 | +python scripts/index_data.py | |
| 146 | +``` | |
| 147 | + | |
| 148 | +This generates and stores text/image embeddings for all 44k products in Milvus. | |
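Conceptually, the indexing script streams records through an embedding model into the vector store in batches; here is a sketch with stubbed `embed`/`insert` callables standing in for the OpenAI/CLIP clients and the Milvus insert call (none of these names are from the repo):

```python
def index_in_batches(records: list, embed, insert, batch_size: int = 512) -> int:
    """Embed each record's text and insert rows into the vector store
    in fixed-size batches; returns the number of rows written."""
    total = 0
    for start in range(0, len(records), batch_size):
        batch = records[start:start + batch_size]
        rows = [
            {"id": r["id"], "vector": embed(r["text"]), "text": r["text"]}
            for r in batch
        ]
        insert(rows)  # e.g. a MilvusClient insert into the target collection
        total += len(rows)
    return total
```

Batching matters at this scale: ~44k products means tens of thousands of embedding calls, and fixed-size batches bound both memory use and per-request payload size.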
| 149 | + | |
| 150 | +### 5. Launch Application | |
| 151 | +```bash | |
| 152 | +# Use the launch script (recommended) | |
| 153 | +./scripts/start.sh | |
| 154 | + | |
| 155 | +# Or run directly | |
| 156 | +streamlit run app.py | |
| 157 | +``` | |
| 158 | +Opens at `http://localhost:8501` | |
| 159 | + | |
| 160 | +### CentOS 8 Deployment | |
| 161 | +See [docs/DEPLOY_CENTOS8.md](docs/DEPLOY_CENTOS8.md) for details. | ... | ... |
| 1 | +++ a/app.py | |
| ... | ... | @@ -0,0 +1,732 @@ |
| 1 | +""" | |
| 2 | +OmniShopAgent - Streamlit UI | |
| 3 | +Multi-modal fashion shopping assistant with conversational AI | |
| 4 | +""" | |
| 5 | + | |
| 6 | +import logging | |
| 7 | +import re | |
| 8 | +import uuid | |
| 9 | +from pathlib import Path | |
| 10 | +from typing import Optional | |
| 11 | + | |
| 12 | +import streamlit as st | |
| 13 | +from PIL import Image, ImageOps | |
| 14 | + | |
| 15 | +from app.agents.shopping_agent import ShoppingAgent | |
| 16 | + | |
| 17 | +# Configure logging | |
| 18 | +logging.basicConfig( | |
| 19 | + level=logging.INFO, | |
| 20 | + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", | |
| 21 | +) | |
| 22 | +logger = logging.getLogger(__name__) | |
| 23 | + | |
| 24 | +# Page config | |
| 25 | +st.set_page_config( | |
| 26 | + page_title="OmniShopAgent", | |
| 27 | + page_icon="👗", | |
| 28 | + layout="centered", | |
| 29 | + initial_sidebar_state="collapsed", | |
| 30 | +) | |
| 31 | + | |
| 32 | +# Custom CSS - ChatGPT-like style | |
| 33 | +st.markdown( | |
| 34 | + """ | |
| 35 | + <style> | |
| 36 | + /* Hide default Streamlit elements */ | |
| 37 | + #MainMenu {visibility: hidden;} | |
| 38 | + footer {visibility: hidden;} | |
| 39 | + header {visibility: hidden;} | |
| 40 | + | |
| 41 | + /* Body and root container */ | |
| 42 | + .main .block-container { | |
| 43 | + padding-bottom: 180px !important; | |
| 44 | + padding-top: 2rem; | |
| 45 | + max-width: 900px; | |
| 46 | + margin: 0 auto; | |
| 47 | + } | |
| 48 | + | |
| 49 | + /* Fixed input container at bottom */ | |
| 50 | + .fixed-input-container { | |
| 51 | + position: fixed; | |
| 52 | + bottom: 0; | |
| 53 | + left: 0; | |
| 54 | + right: 0; | |
| 55 | + background: white; | |
| 56 | + border-top: 1px solid #e5e5e5; | |
| 57 | + padding: 1rem 0; | |
| 58 | + z-index: 1000; | |
| 59 | + box-shadow: 0 -2px 10px rgba(0,0,0,0.05); | |
| 60 | + } | |
| 61 | + | |
| 62 | + .fixed-input-container .block-container { | |
| 63 | + max-width: 900px; | |
| 64 | + margin: 0 auto; | |
| 65 | + padding: 0 1rem !important; | |
| 66 | + } | |
| 67 | + | |
| 68 | + /* Message bubbles */ | |
| 69 | + .message { | |
| 70 | + margin: 1rem 0; | |
| 71 | + padding: 1rem 1.5rem; | |
| 72 | + border-radius: 1rem; | |
| 73 | + animation: fadeIn 0.3s ease-in; | |
| 74 | + } | |
| 75 | + | |
| 76 | + @keyframes fadeIn { | |
| 77 | + from { opacity: 0; transform: translateY(10px); } | |
| 78 | + to { opacity: 1; transform: translateY(0); } | |
| 79 | + } | |
| 80 | + | |
| 81 | + .user-message { | |
| 82 | + background: transparent; | |
| 83 | + margin: 0 0 1rem 0; | |
| 84 | + padding: 0; | |
| 85 | + border-radius: 0; | |
| 86 | + } | |
| 87 | + | |
| 88 | + .assistant-message { | |
| 89 | + background: white; | |
| 90 | + border: 1px solid #e5e5e5; | |
| 91 | + margin-right: 3rem; | |
| 92 | + } | |
| 93 | + | |
| 94 | + /* Product cards - simplified */ | |
| 95 | + .stImage { | |
| 96 | + border-radius: 0px; | |
| 97 | + overflow: hidden; | |
| 98 | + } | |
| 99 | + | |
| 100 | + .stImage img { | |
| 101 | + transition: transform 0.2s; | |
| 102 | + } | |
| 103 | + | |
| 104 | + .stImage:hover img { | |
| 105 | + transform: scale(1.05); | |
| 106 | + } | |
| 107 | + | |
| 108 | + /* Scroll to bottom behavior */ | |
| 109 | + html { | |
| 110 | + scroll-behavior: smooth; | |
| 111 | + } | |
| 112 | + | |
| 113 | + /* Header */ | |
| 114 | + .app-header { | |
| 115 | + text-align: center; | |
| 116 | + padding: 2rem 1rem; | |
| 117 | + background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); | |
| 118 | + color: white; | |
| 119 | + border-radius: 1rem; | |
| 120 | + margin-bottom: 2rem; | |
| 121 | + } | |
| 122 | + | |
| 123 | + .app-title { | |
| 124 | + font-size: 2rem; | |
| 125 | + font-weight: 700; | |
| 126 | + margin: 0; | |
| 127 | + } | |
| 128 | + | |
| 129 | + .app-subtitle { | |
| 130 | + font-size: 1rem; | |
| 131 | + opacity: 0.9; | |
| 132 | + margin-top: 0.5rem; | |
| 133 | + } | |
| 134 | + | |
| 135 | + /* Welcome screen */ | |
| 136 | + .welcome-container { | |
| 137 | + text-align: center; | |
| 138 | + padding: 4rem 2rem; | |
| 139 | + color: #666; | |
| 140 | + } | |
| 141 | + | |
| 142 | + .welcome-title { | |
| 143 | + font-size: 2rem; | |
| 144 | + font-weight: 600; | |
| 145 | + color: #1a1a1a; | |
| 146 | + margin-bottom: 1rem; | |
| 147 | + } | |
| 148 | + | |
| 149 | + .welcome-features { | |
| 150 | + display: grid; | |
| 151 | + grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); | |
| 152 | + gap: 1.5rem; | |
| 153 | + margin: 2rem 0; | |
| 154 | + } | |
| 155 | + | |
| 156 | + .feature-card { | |
| 157 | + background: #f7f7f8; | |
| 158 | + padding: 1.5rem; | |
| 159 | + border-radius: 12px; | |
| 160 | + transition: all 0.2s; | |
| 161 | + } | |
| 162 | + | |
| 163 | + .feature-card:hover { | |
| 164 | + background: #efefef; | |
| 165 | + transform: translateY(-2px); | |
| 166 | + } | |
| 167 | + | |
| 168 | + .feature-icon { | |
| 169 | + font-size: 2rem; | |
| 170 | + margin-bottom: 0.5rem; | |
| 171 | + } | |
| 172 | + | |
| 173 | + .feature-title { | |
| 174 | + font-weight: 600; | |
| 175 | + margin-bottom: 0.25rem; | |
| 176 | + } | |
| 177 | + | |
| 178 | + /* Image preview */ | |
| 179 | + .image-preview { | |
| 180 | + position: relative; | |
| 181 | + display: inline-block; | |
| 182 | + margin: 0.5rem 0; | |
| 183 | + } | |
| 184 | + | |
| 185 | + .image-preview img { | |
| 186 | + max-width: 200px; | |
| 187 | + border-radius: 8px; | |
| 188 | + border: 2px solid #e5e5e5; | |
| 189 | + } | |
| 190 | + | |
| 191 | + .remove-image-btn { | |
| 192 | + position: absolute; | |
| 193 | + top: 5px; | |
| 194 | + right: 5px; | |
| 195 | + background: rgba(0,0,0,0.6); | |
| 196 | + color: white; | |
| 197 | + border: none; | |
| 198 | + border-radius: 50%; | |
| 199 | + width: 24px; | |
| 200 | + height: 24px; | |
| 201 | + cursor: pointer; | |
| 202 | + font-size: 14px; | |
| 203 | + } | |
| 204 | + | |
| 205 | + /* Buttons */ | |
| 206 | + .stButton>button { | |
| 207 | + border-radius: 8px; | |
| 208 | + border: 1px solid #e5e5e5; | |
| 209 | + padding: 0.5rem 1rem; | |
| 210 | + transition: all 0.2s; | |
| 211 | + } | |
| 212 | + | |
| 213 | + .stButton>button:hover { | |
| 214 | + background: #f0f0f0; | |
| 215 | + border-color: #d0d0d0; | |
| 216 | + } | |
| 217 | + | |
| 218 | + /* Hide upload button label */ | |
| 219 | + .uploadedFile { | |
| 220 | + display: none; | |
| 221 | + } | |
| 222 | + </style> | |
| 223 | + """, | |
| 224 | + unsafe_allow_html=True, | |
| 225 | +) | |
| 226 | + | |
| 227 | + | |
| 228 | +# Initialize session state | |
| 229 | +def initialize_session(): | |
| 230 | + """Initialize session state variables""" | |
| 231 | + if "session_id" not in st.session_state: | |
| 232 | + st.session_state.session_id = str(uuid.uuid4()) | |
| 233 | + | |
| 234 | + if "shopping_agent" not in st.session_state: | |
| 235 | + st.session_state.shopping_agent = ShoppingAgent( | |
| 236 | + session_id=st.session_state.session_id | |
| 237 | + ) | |
| 238 | + | |
| 239 | + if "messages" not in st.session_state: | |
| 240 | + st.session_state.messages = [] | |
| 241 | + | |
| 242 | + if "uploaded_image" not in st.session_state: | |
| 243 | + st.session_state.uploaded_image = None | |
| 244 | + | |
| 245 | + if "show_image_upload" not in st.session_state: | |
| 246 | + st.session_state.show_image_upload = False | |
| 247 | + | |
| 248 | + | |
| 249 | +def save_uploaded_image(uploaded_file) -> Optional[str]: | |
| 250 | + """Save uploaded image to temp directory""" | |
| 251 | + if uploaded_file is None: | |
| 252 | + return None | |
| 253 | + | |
| 254 | + try: | |
| 255 | + temp_dir = Path("temp_uploads") | |
| 256 | + temp_dir.mkdir(exist_ok=True) | |
| 257 | + | |
| 258 | + image_path = temp_dir / f"{st.session_state.session_id}_{uploaded_file.name}" | |
| 259 | + with open(image_path, "wb") as f: | |
| 260 | + f.write(uploaded_file.getbuffer()) | |
| 261 | + | |
| 262 | + logger.info(f"Saved uploaded image to {image_path}") | |
| 263 | + return str(image_path) | |
| 264 | + | |
| 265 | + except Exception as e: | |
| 266 | + logger.error(f"Error saving uploaded image: {e}") | |
| 267 | + st.error(f"Failed to save image: {str(e)}") | |
| 268 | + return None | |
| 269 | + | |
| 270 | + | |
| 271 | +def extract_products_from_response(response: str) -> list: | |
| 272 | + """Extract product information from agent response | |
| 273 | + | |
| 274 | + Returns list of dicts with product info | |
| 275 | + """ | |
| 276 | + products = [] | |
| 277 | + | |
| 278 | + # Pattern to match product blocks in the response | |
| 279 | + # Looking for ID, name, and other details | |
| 280 | + lines = response.split("\n") | |
| 281 | + current_product = {} | |
| 282 | + | |
| 283 | + for line in lines: | |
| 284 | + line = line.strip() | |
| 285 | + | |
| 286 | + # Match product number (e.g., "1. Product Name" or "**1. Product Name**") | |
| 287 | + if re.match(r"^\*?\*?\d+\.\s+", line): | |
| 288 | + if current_product: | |
| 289 | + products.append(current_product) | |
| 290 | + current_product = {} | |
| 291 | + # Extract product name | |
| 292 | + name = re.sub(r"^\*?\*?\d+\.\s+", "", line) | |
| 293 | + name = name.replace("**", "").strip() | |
| 294 | + current_product["name"] = name | |
| 295 | + | |
| 296 | + # Match ID | |
| 297 | + elif "ID:" in line or "id:" in line: | |
| 298 | + id_match = re.search(r"(?:ID|id):\s*(\d+)", line) | |
| 299 | + if id_match: | |
| 300 | + current_product["id"] = id_match.group(1) | |
| 301 | + | |
| 302 | + # Match Category | |
| 303 | + elif "Category:" in line: | |
| 304 | + cat_match = re.search(r"Category:\s*(.+?)(?:\n|$)", line) | |
| 305 | + if cat_match: | |
| 306 | + current_product["category"] = cat_match.group(1).strip() | |
| 307 | + | |
| 308 | + # Match Color | |
| 309 | + elif "Color:" in line: | |
| 310 | + color_match = re.search(r"Color:\s*(\w+)", line) | |
| 311 | + if color_match: | |
| 312 | + current_product["color"] = color_match.group(1) | |
| 313 | + | |
| 314 | + # Match Gender | |
| 315 | + elif "Gender:" in line: | |
| 316 | + gender_match = re.search(r"Gender:\s*(\w+)", line) | |
| 317 | + if gender_match: | |
| 318 | + current_product["gender"] = gender_match.group(1) | |
| 319 | + | |
| 320 | + # Match Season | |
| 321 | + elif "Season:" in line: | |
| 322 | + season_match = re.search(r"Season:\s*(\w+)", line) | |
| 323 | + if season_match: | |
| 324 | + current_product["season"] = season_match.group(1) | |
| 325 | + | |
| 326 | + # Match Usage | |
| 327 | + elif "Usage:" in line: | |
| 328 | + usage_match = re.search(r"Usage:\s*(\w+)", line) | |
| 329 | + if usage_match: | |
| 330 | + current_product["usage"] = usage_match.group(1) | |
| 331 | + | |
| 332 | + # Match Similarity/Relevance score | |
| 333 | + elif "Similarity:" in line or "Relevance:" in line: | |
| 334 | + score_match = re.search(r"(?:Similarity|Relevance):\s*([\d.]+)%", line) | |
| 335 | + if score_match: | |
| 336 | + current_product["score"] = score_match.group(1) | |
| 337 | + | |
| 338 | + # Add last product | |
| 339 | + if current_product: | |
| 340 | + products.append(current_product) | |
| 341 | + | |
| 342 | + return products | |
| 343 | + | |
| 344 | + | |
| 345 | +def display_product_card(product: dict): | |
| 346 | + """Display a product card with image and name""" | |
| 347 | + product_id = product.get("id", "") | |
| 348 | + name = product.get("name", "Unknown Product") | |
| 349 | + | |
| 350 | + # Debug: log what we got | |
| 351 | + logger.info(f"Displaying product: ID={product_id}, Name={name}") | |
| 352 | + | |
| 353 | + # Try to load image from data/images directory | |
| 354 | + if product_id: | |
| 355 | + image_path = Path(f"data/images/{product_id}.jpg") | |
| 356 | + | |
| 357 | + if image_path.exists(): | |
| 358 | + try: | |
| 359 | + img = Image.open(image_path) | |
| 360 | + # Fixed size for all images | |
| 361 | + target_size = (200, 200) | |
| 362 | + try: | |
| 363 | + # Try new Pillow API | |
| 364 | + img_processed = ImageOps.fit( | |
| 365 | + img, target_size, method=Image.Resampling.LANCZOS | |
| 366 | + ) | |
| 367 | + except AttributeError: | |
| 368 | + # Fallback for older Pillow versions | |
| 369 | + img_processed = ImageOps.fit( | |
| 370 | + img, target_size, method=Image.LANCZOS | |
| 371 | + ) | |
| 372 | + | |
| 373 | + # Display image with fixed width | |
| 374 | + st.image(img_processed, use_container_width=False, width=200) | |
| 375 | + st.markdown(f"**{name}**") | |
| 376 | + st.caption(f"ID: {product_id}") | |
| 377 | + return | |
| 378 | + except Exception as e: | |
| 379 | + logger.warning(f"Failed to load image {image_path}: {e}") | |
| 380 | + else: | |
| 381 | + logger.warning(f"Image not found: {image_path}") | |
| 382 | + | |
| 383 | + # Fallback: no image | |
| 384 | + st.markdown(f"**📷 {name}**") | |
| 385 | + if product_id: | |
| 386 | + st.caption(f"ID: {product_id}") | |
| 387 | + else: | |
| 388 | + st.caption("ID not available") | |
| 389 | + | |
| 390 | + | |
| 391 | +def display_message(message: dict): | |
| 392 | + """Display a chat message""" | |
| 393 | + role = message["role"] | |
| 394 | + content = message["content"] | |
| 395 | + image_path = message.get("image_path") | |
| 396 | + tool_calls = message.get("tool_calls", []) | |
| 397 | + | |
| 398 | + if role == "user": | |
| 399 | + st.markdown('<div class="message user-message">', unsafe_allow_html=True) | |
| 400 | + | |
| 401 | + if image_path and Path(image_path).exists(): | |
| 402 | + try: | |
| 403 | + img = Image.open(image_path) | |
| 404 | + st.image(img, width=200) | |
| 405 | + except Exception: | |
| 406 | + logger.warning(f"Failed to load user uploaded image: {image_path}") | |
| 407 | + | |
| 408 | + st.markdown(content) | |
| 409 | + st.markdown("</div>", unsafe_allow_html=True) | |
| 410 | + | |
| 411 | + else: # assistant | |
|  | +        st.markdown('<div class="message assistant-message">', unsafe_allow_html=True) | |
| 412 | + # Display tool calls horizontally - only tool names | |
| 413 | + if tool_calls: | |
| 414 | + tool_names = [tc['name'] for tc in tool_calls] | |
| 415 | + st.caption(" → ".join(tool_names)) | |
| 416 | + st.markdown("") | |
| 417 | + | |
| 418 | + # Extract and display products if any | |
| 419 | + products = extract_products_from_response(content) | |
| 420 | + | |
| 421 | + # Debug logging | |
| 422 | + logger.info(f"Extracted {len(products)} products from response") | |
| 423 | + for p in products: | |
| 424 | + logger.info(f"Product: {p}") | |
| 425 | + | |
| 426 | + if products: | |
| 427 | + def parse_score(product: dict) -> float: | |
| 428 | + score = product.get("score") | |
| 429 | + if score is None: | |
| 430 | + return 0.0 | |
| 431 | + try: | |
| 432 | + return float(score) | |
| 433 | + except (TypeError, ValueError): | |
| 434 | + return 0.0 | |
| 435 | + | |
| 436 | + # Sort by score and limit to 3 | |
| 437 | + products = sorted(products, key=parse_score, reverse=True)[:3] | |
| 438 | + | |
| 439 | + logger.info(f"Displaying top {len(products)} products") | |
| 440 | + | |
| 441 | + # Display the text response first (without product details) | |
| 442 | + text_lines = [] | |
| 443 | + for line in content.split("\n"): | |
| 444 | + # Skip product detail lines | |
| 445 | + if not any( | |
| 446 | + keyword in line | |
| 447 | + for keyword in [ | |
| 448 | + "ID:", | |
| 449 | + "Category:", | |
| 450 | + "Color:", | |
| 451 | + "Gender:", | |
| 452 | + "Season:", | |
| 453 | + "Usage:", | |
| 454 | + "Similarity:", | |
| 455 | + "Relevance:", | |
| 456 | + ] | |
| 457 | + ): | |
| 458 | + if not re.match(r"^\*?\*?\d+\.\s+", line): | |
| 459 | + text_lines.append(line) | |
| 460 | + | |
| 461 | + intro_text = "\n".join(text_lines).strip() | |
| 462 | + if intro_text: | |
| 463 | + st.markdown(intro_text) | |
| 464 | + | |
| 465 | + # Display product cards in grid | |
| 466 | + st.markdown("<br>", unsafe_allow_html=True) | |
| 467 | + | |
| 468 | + # Create exactly 3 columns with equal width | |
| 469 | + cols = st.columns(3) | |
| 470 | + for j, product in enumerate(products[:3]): # Ensure max 3 | |
| 471 | + with cols[j]: | |
| 472 | + display_product_card(product) | |
| 473 | + else: | |
| 474 | + # No products found, display full content | |
| 475 | + st.markdown(content) | |
| 476 | + | |
| 477 | + st.markdown("</div>", unsafe_allow_html=True) | |
| 478 | + | |
| 479 | + | |
| 480 | +def display_welcome(): | |
| 481 | + """Display welcome screen""" | |
| 482 | + | |
| 483 | + col1, col2, col3, col4 = st.columns(4) | |
| 484 | + | |
| 485 | + with col1: | |
| 486 | + st.markdown( | |
| 487 | + """ | |
| 488 | + <div class="feature-card"> | |
| 489 | + <div class="feature-icon">💬</div> | |
| 490 | + <div class="feature-title">Text Search</div> | |
| 491 | + <div>Describe what you want</div> | |
| 492 | + </div> | |
| 493 | + """, | |
| 494 | + unsafe_allow_html=True, | |
| 495 | + ) | |
| 496 | + | |
| 497 | + with col2: | |
| 498 | + st.markdown( | |
| 499 | + """ | |
| 500 | + <div class="feature-card"> | |
| 501 | + <div class="feature-icon">📸</div> | |
| 502 | + <div class="feature-title">Image Search</div> | |
| 503 | + <div>Upload product photos</div> | |
| 504 | + </div> | |
| 505 | + """, | |
| 506 | + unsafe_allow_html=True, | |
| 507 | + ) | |
| 508 | + | |
| 509 | + with col3: | |
| 510 | + st.markdown( | |
| 511 | + """ | |
| 512 | + <div class="feature-card"> | |
| 513 | + <div class="feature-icon">🔍</div> | |
| 514 | + <div class="feature-title">Visual Analysis</div> | |
| 515 | +            <div>AI analyzes product style</div> | |
| 516 | + </div> | |
| 517 | + """, | |
| 518 | + unsafe_allow_html=True, | |
| 519 | + ) | |
| 520 | + | |
| 521 | + with col4: | |
| 522 | + st.markdown( | |
| 523 | + """ | |
| 524 | + <div class="feature-card"> | |
| 525 | + <div class="feature-icon">💭</div> | |
| 526 | + <div class="feature-title">Conversational</div> | |
| 527 | + <div>Remembers context</div> | |
| 528 | + </div> | |
| 529 | + """, | |
| 530 | + unsafe_allow_html=True, | |
| 531 | + ) | |
| 532 | + | |
| 533 | + st.markdown("<br><br>", unsafe_allow_html=True) | |
| 534 | + | |
| 535 | + | |
| 536 | +def main(): | |
| 537 | + """Main Streamlit app""" | |
| 538 | + initialize_session() | |
| 539 | + | |
| 540 | + # Header | |
| 541 | + st.markdown( | |
| 542 | + """ | |
| 543 | + <div class="app-header"> | |
| 544 | + <div class="app-title">👗 OmniShopAgent</div> | |
| 545 | + <div class="app-subtitle">AI Fashion Shopping Assistant</div> | |
| 546 | + </div> | |
| 547 | + """, | |
| 548 | + unsafe_allow_html=True, | |
| 549 | + ) | |
| 550 | + | |
| 551 | + # Sidebar (collapsed by default, but accessible) | |
| 552 | + with st.sidebar: | |
| 553 | + st.markdown("### ⚙️ Settings") | |
| 554 | + | |
| 555 | + if st.button("🗑️ Clear Chat", use_container_width=True): | |
| 556 | + if "shopping_agent" in st.session_state: | |
| 557 | + st.session_state.shopping_agent.clear_history() | |
| 558 | + st.session_state.messages = [] | |
| 559 | + st.session_state.uploaded_image = None | |
| 560 | + st.rerun() | |
| 561 | + | |
| 562 | + st.markdown("---") | |
| 563 | + st.caption(f"Session: `{st.session_state.session_id[:8]}...`") | |
| 564 | + | |
| 565 | + # Chat messages container | |
| 566 | + messages_container = st.container() | |
| 567 | + | |
| 568 | + with messages_container: | |
| 569 | + if not st.session_state.messages: | |
| 570 | + display_welcome() | |
| 571 | + else: | |
| 572 | + for message in st.session_state.messages: | |
| 573 | + display_message(message) | |
| 574 | + | |
| 575 | + # Fixed input area at bottom (using container to simulate fixed position) | |
| 576 | + st.markdown('<div class="fixed-input-container">', unsafe_allow_html=True) | |
| 577 | + | |
| 578 | + input_container = st.container() | |
| 579 | + | |
| 580 | + with input_container: | |
| 581 | + # Image upload area (shown when + is clicked) | |
| 582 | + if st.session_state.show_image_upload: | |
| 583 | + uploaded_file = st.file_uploader( | |
| 584 | + "Choose an image", | |
| 585 | + type=["jpg", "jpeg", "png"], | |
| 586 | + key="file_uploader", | |
| 587 | + ) | |
| 588 | + | |
| 589 | + if uploaded_file: | |
| 590 | + st.session_state.uploaded_image = uploaded_file | |
| 591 | + # Show preview | |
| 592 | + col1, col2 = st.columns([1, 4]) | |
| 593 | + with col1: | |
| 594 | + img = Image.open(uploaded_file) | |
| 595 | + st.image(img, width=100) | |
| 596 | + with col2: | |
| 597 | + if st.button("❌ Remove"): | |
| 598 | + st.session_state.uploaded_image = None | |
| 599 | + st.session_state.show_image_upload = False | |
| 600 | + st.rerun() | |
| 601 | + | |
| 602 | + # Input row | |
| 603 | + col1, col2 = st.columns([1, 12]) | |
| 604 | + | |
| 605 | + with col1: | |
| 606 | + # Image upload toggle button | |
| 607 | + if st.button("➕", help="Add image", use_container_width=True): | |
| 608 | + st.session_state.show_image_upload = ( | |
| 609 | + not st.session_state.show_image_upload | |
| 610 | + ) | |
| 611 | + st.rerun() | |
| 612 | + | |
| 613 | + with col2: | |
| 614 | + # Text input | |
| 615 | + user_query = st.chat_input( | |
| 616 | + "Ask about fashion products...", | |
| 617 | + key="chat_input", | |
| 618 | + ) | |
| 619 | + | |
| 620 | + st.markdown("</div>", unsafe_allow_html=True) | |
| 621 | + | |
| 622 | + # Process user input | |
| 623 | + if user_query: | |
| 624 | + # Ensure shopping agent is initialized | |
| 625 | + if "shopping_agent" not in st.session_state: | |
| 626 | + st.error("Session not initialized. Please refresh the page.") | |
| 627 | + st.stop() | |
| 628 | + | |
| 629 | + # Save uploaded image if present, or get from recent history | |
| 630 | + image_path = None | |
| 631 | + if st.session_state.uploaded_image: | |
| 632 | + # User explicitly uploaded an image for this query | |
| 633 | + image_path = save_uploaded_image(st.session_state.uploaded_image) | |
| 634 | + else: | |
| 635 | + # Check if query refers to a previous image | |
| 636 | + query_lower = user_query.lower() | |
| 637 | + if any( | |
| 638 | + ref in query_lower | |
| 639 | + for ref in [ | |
| 640 | + "this", | |
| 641 | + "that", | |
| 642 | + "the image", | |
| 643 | + "the shirt", | |
| 644 | + "the product", | |
| 645 | + "it", | |
| 646 | + ] | |
| 647 | + ): | |
| 648 | + # Find the most recent message with an image | |
| 649 | + for msg in reversed(st.session_state.messages): | |
| 650 | + if msg.get("role") == "user" and msg.get("image_path"): | |
| 651 | + image_path = msg["image_path"] | |
| 652 | + logger.info(f"Using image from previous message: {image_path}") | |
| 653 | + break | |
| 654 | + | |
| 655 | + # Add user message | |
| 656 | + st.session_state.messages.append( | |
| 657 | + { | |
| 658 | + "role": "user", | |
| 659 | + "content": user_query, | |
| 660 | + "image_path": image_path, | |
| 661 | + } | |
| 662 | + ) | |
| 663 | + | |
| 664 | + # Display user message immediately | |
| 665 | + with messages_container: | |
| 666 | + display_message(st.session_state.messages[-1]) | |
| 667 | + | |
| 668 | + # Process with shopping agent | |
| 669 | + try: | |
| 670 | + shopping_agent = st.session_state.shopping_agent | |
| 671 | + | |
| 672 | + # Handle greetings | |
| 673 | + query_lower = user_query.lower().strip() | |
| 674 | + if query_lower in ["hi", "hello", "hey"]: | |
| 675 | + response = """Hello! 👋 I'm your fashion shopping assistant. | |
| 676 | + | |
| 677 | +I can help you: | |
| 678 | +- Search for products by description | |
| 679 | +- Find items similar to images you upload | |
| 680 | +- Analyze product styles | |
| 681 | + | |
| 682 | +What are you looking for today?""" | |
| 683 | + tool_calls = [] | |
| 684 | + else: | |
| 685 | + # Process with agent | |
| 686 | + result = shopping_agent.chat( | |
| 687 | + query=user_query, | |
| 688 | + image_path=image_path, | |
| 689 | + ) | |
| 690 | + response = result["response"] | |
| 691 | + tool_calls = result.get("tool_calls", []) | |
| 692 | + | |
| 693 | + # Add assistant message | |
| 694 | + st.session_state.messages.append( | |
| 695 | + { | |
| 696 | + "role": "assistant", | |
| 697 | + "content": response, | |
| 698 | + "tool_calls": tool_calls, | |
| 699 | + } | |
| 700 | + ) | |
| 701 | + | |
| 702 | + # Clear uploaded image and hide upload area after sending | |
| 703 | + st.session_state.uploaded_image = None | |
| 704 | + st.session_state.show_image_upload = False | |
| 705 | + | |
| 706 | + # Attempt to auto-scroll to bottom; note st.markdown does not | |
| 707 | + # execute <script> tags, so this is effectively a no-op in current | |
| 708 | + # Streamlit (streamlit.components.v1.html would be needed) | |
| 707 | + st.markdown( | |
| 708 | + """ | |
| 709 | + <script> | |
| 710 | + window.scrollTo(0, document.body.scrollHeight); | |
| 711 | + </script> | |
| 712 | + """, | |
| 713 | + unsafe_allow_html=True, | |
| 714 | + ) | |
| 715 | + | |
| 716 | + except Exception as e: | |
| 717 | + logger.error(f"Error processing query: {e}", exc_info=True) | |
| 718 | + error_msg = f"I apologize, but I encountered an error: {str(e)}" | |
| 719 | + | |
| 720 | + st.session_state.messages.append( | |
| 721 | + { | |
| 722 | + "role": "assistant", | |
| 723 | + "content": error_msg, | |
| 724 | + } | |
| 725 | + ) | |
| 726 | + | |
| 727 | + # Rerun to update UI | |
| 728 | + st.rerun() | |
| 729 | + | |
| 730 | + | |
| 731 | +if __name__ == "__main__": | |
| 732 | + main() | ... | ... |
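The reference-resolution logic above (reusing the most recent uploaded image when the query says "this", "it", and so on) can be exercised outside of Streamlit. The sketch below mirrors the reversed scan over session messages; the message dicts and the `/tmp/shirt.png` path are hypothetical stand-ins for `st.session_state.messages`:

```python
from typing import Optional

# Same reference words the app checks for (substring match, as in the app)
PRONOUN_REFS = ["this", "that", "the image", "the shirt", "the product", "it"]

def resolve_image_path(query: str, messages: list) -> Optional[str]:
    """If the query refers back to an image, return the most recent
    user-uploaded image path; otherwise return None."""
    query_lower = query.lower()
    if not any(ref in query_lower for ref in PRONOUN_REFS):
        return None
    # Scan newest-first for a user message that carried an image
    for msg in reversed(messages):
        if msg.get("role") == "user" and msg.get("image_path"):
            return msg["image_path"]
    return None

history = [
    {"role": "user", "content": "find shirts", "image_path": None},
    {"role": "user", "content": "like this one", "image_path": "/tmp/shirt.png"},
    {"role": "assistant", "content": "Here are similar shirts."},
]
print(resolve_image_path("show me more like it", history))  # /tmp/shirt.png
print(resolve_image_path("find red pants", history))        # None
```

The same substring caveat applies here as in the app: `"it"` also matches inside words such as "white".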
| 1 | +++ a/app/agents/shopping_agent.py | |
| ... | ... | @@ -0,0 +1,272 @@ |
| 1 | +""" | |
| 2 | +Conversational Shopping Agent with LangGraph | |
| 3 | +True ReAct agent with autonomous tool calling and message accumulation | |
| 4 | +""" | |
| 5 | + | |
| 6 | +import logging | |
| 7 | +from pathlib import Path | |
| 8 | +from typing import Optional, Sequence | |
| 9 | + | |
| 10 | +from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage | |
| 11 | +from langchain_openai import ChatOpenAI | |
| 12 | +from langgraph.checkpoint.memory import MemorySaver | |
| 13 | +from langgraph.graph import END, START, StateGraph | |
| 14 | +from langgraph.graph.message import add_messages | |
| 15 | +from langgraph.prebuilt import ToolNode | |
| 16 | +from typing_extensions import Annotated, TypedDict | |
| 17 | + | |
| 18 | +from app.config import settings | |
| 19 | +from app.tools.search_tools import get_all_tools | |
| 20 | + | |
| 21 | +logger = logging.getLogger(__name__) | |
| 22 | + | |
| 23 | + | |
| 24 | +def _extract_message_text(msg) -> str: | |
| 25 | + """Extract text from message content. | |
| 26 | + LangChain 1.0: content may be str or content_blocks (list) for multimodal.""" | |
| 27 | + content = getattr(msg, "content", "") | |
| 28 | + if isinstance(content, str): | |
| 29 | + return content | |
| 30 | + if isinstance(content, list): | |
| 31 | + parts = [] | |
| 32 | + for block in content: | |
| 33 | + if isinstance(block, dict): | |
| 34 | + parts.append(block.get("text", block.get("content", ""))) | |
| 35 | + else: | |
| 36 | + parts.append(str(block)) | |
| 37 | + return "".join(str(p) for p in parts) | |
| 38 | + return str(content) if content else "" | |
| 39 | + | |
| 40 | + | |
| 41 | +class AgentState(TypedDict): | |
| 42 | + """State for the shopping agent with message accumulation""" | |
| 43 | + | |
| 44 | + messages: Annotated[Sequence[BaseMessage], add_messages] | |
| 45 | + current_image_path: Optional[str] # Track uploaded image | |
| 46 | + | |
| 47 | + | |
| 48 | +class ShoppingAgent: | |
| 49 | + """True ReAct agent with autonomous decision making""" | |
| 50 | + | |
| 51 | + def __init__(self, session_id: Optional[str] = None): | |
| 52 | + self.session_id = session_id or "default" | |
| 53 | + | |
| 54 | + # Initialize LLM | |
| 55 | + self.llm = ChatOpenAI( | |
| 56 | + model=settings.openai_model, | |
| 57 | + temperature=settings.openai_temperature, | |
| 58 | + api_key=settings.openai_api_key, | |
| 59 | + ) | |
| 60 | + | |
| 61 | + # Get tools and bind to model | |
| 62 | + self.tools = get_all_tools() | |
| 63 | + self.llm_with_tools = self.llm.bind_tools(self.tools) | |
| 64 | + | |
| 65 | + # Build graph | |
| 66 | + self.graph = self._build_graph() | |
| 67 | + | |
| 68 | + logger.info(f"Shopping agent initialized for session: {self.session_id}") | |
| 69 | + | |
| 70 | + def _build_graph(self): | |
| 71 | + """Build the LangGraph StateGraph""" | |
| 72 | + | |
| 73 | + # System prompt for the agent | |
| 74 | + system_prompt = """You are an intelligent fashion shopping assistant. You can: | |
| 75 | +1. Search for products by text description (use search_products) | |
| 76 | +2. Find visually similar products from images (use search_by_image) | |
| 77 | +3. Analyze image style and attributes (use analyze_image_style) | |
| 78 | + | |
| 79 | +When a user asks about products: | |
| 80 | +- For text queries: use search_products directly | |
| 81 | +- For image uploads: decide if you need to analyze_image_style first, then search | |
| 82 | +- You can call multiple tools in sequence if needed | |
| 83 | +- Always provide helpful, friendly responses | |
| 84 | + | |
| 85 | +CRITICAL FORMATTING RULES: | |
| 86 | +When presenting product results, you MUST use this EXACT format for EACH product: | |
| 87 | + | |
| 88 | +1. [Product Name] | |
| 89 | + ID: [Product ID Number] | |
| 90 | + Category: [Category] | |
| 91 | + Color: [Color] | |
| 92 | + Gender: [Gender] | |
| 93 | + (Include Season, Usage, Relevance if available) | |
| 94 | + | |
| 95 | +Example: | |
| 96 | +1. Puma Men White 3/4 Length Pants | |
| 97 | + ID: 12345 | |
| 98 | + Category: Apparel > Bottomwear > Track Pants | |
| 99 | + Color: White | |
| 100 | + Gender: Men | |
| 101 | + Season: Summer | |
| 102 | + Usage: Sports | |
| 103 | + Relevance: 95.2% | |
| 104 | + | |
| 105 | +DO NOT skip the ID field! It is essential for displaying product images. | |
| 106 | +Be conversational in your introduction, but preserve the exact product format.""" | |
| 107 | + | |
| 108 | + def agent_node(state: AgentState): | |
| 109 | + """Agent decision node - decides which tools to call or when to respond""" | |
| 110 | + messages = state["messages"] | |
| 111 | + | |
| 112 | + # Add system prompt if first message | |
| 113 | + if not any(isinstance(m, SystemMessage) for m in messages): | |
| 114 | + messages = [SystemMessage(content=system_prompt)] + list(messages) | |
| 115 | + | |
| 116 | + # Image context: current_image_path is carried in the state, and | |
| 117 | + # chat() injects the path into the user message text, so no | |
| 118 | + # extra handling is needed here | |
| 121 | + | |
| 122 | + # Invoke LLM with tools | |
| 123 | + response = self.llm_with_tools.invoke(messages) | |
| 124 | + return {"messages": [response]} | |
| 125 | + | |
| 126 | + # Create tool node | |
| 127 | + tool_node = ToolNode(self.tools) | |
| 128 | + | |
| 129 | + def should_continue(state: AgentState): | |
| 130 | + """Determine if agent should continue or end""" | |
| 131 | + messages = state["messages"] | |
| 132 | + last_message = messages[-1] | |
| 133 | + | |
| 134 | + # If LLM made tool calls, continue to tools | |
| 135 | + if hasattr(last_message, "tool_calls") and last_message.tool_calls: | |
| 136 | + return "tools" | |
| 137 | + # Otherwise, end (agent has final response) | |
| 138 | + return END | |
| 139 | + | |
| 140 | + # Build graph | |
| 141 | + workflow = StateGraph(AgentState) | |
| 142 | + | |
| 143 | + workflow.add_node("agent", agent_node) | |
| 144 | + workflow.add_node("tools", tool_node) | |
| 145 | + | |
| 146 | + workflow.add_edge(START, "agent") | |
| 147 | + workflow.add_conditional_edges("agent", should_continue, ["tools", END]) | |
| 148 | + workflow.add_edge("tools", "agent") | |
| 149 | + | |
| 150 | + # Compile with memory | |
| 151 | + checkpointer = MemorySaver() | |
| 152 | + return workflow.compile(checkpointer=checkpointer) | |
| 153 | + | |
| 154 | + def chat(self, query: str, image_path: Optional[str] = None) -> dict: | |
| 155 | + """Process user query with the agent | |
| 156 | + | |
| 157 | + Args: | |
| 158 | + query: User's text query | |
| 159 | + image_path: Optional path to uploaded image | |
| 160 | + | |
| 161 | + Returns: | |
| 162 | + Dict with response and metadata | |
| 163 | + """ | |
| 164 | + try: | |
| 165 | + logger.info( | |
| 166 | + f"[{self.session_id}] Processing: '{query}' (image={'Yes' if image_path else 'No'})" | |
| 167 | + ) | |
| 168 | + | |
| 169 | + # Validate image | |
| 170 | + if image_path and not Path(image_path).exists(): | |
| 171 | + return { | |
| 172 | + "response": f"Error: Image file not found at '{image_path}'", | |
| 173 | + "error": True, | |
| 174 | + } | |
| 175 | + | |
| 176 | + # Build input message | |
| 177 | + message_content = query | |
| 178 | + if image_path: | |
| 179 | + message_content = f"{query}\n[User uploaded image: {image_path}]" | |
| 180 | + | |
| 181 | + # Invoke agent | |
| 182 | + config = {"configurable": {"thread_id": self.session_id}} | |
| 183 | + input_state = { | |
| 184 | + "messages": [HumanMessage(content=message_content)], | |
| 185 | + "current_image_path": image_path, | |
| 186 | + } | |
| 187 | + | |
| 188 | + # Track tool calls | |
| 189 | + tool_calls = [] | |
| 190 | + | |
| 191 | + # Stream events to capture tool calls | |
| 192 | + for event in self.graph.stream(input_state, config=config): | |
| 193 | + logger.info(f"Event: {event}") | |
| 194 | + | |
| 195 | + # Check for agent node (tool calls) | |
| 196 | + if "agent" in event: | |
| 197 | + agent_output = event["agent"] | |
| 198 | + if "messages" in agent_output: | |
| 199 | + for msg in agent_output["messages"]: | |
| 200 | + if hasattr(msg, "tool_calls") and msg.tool_calls: | |
| 201 | + for tc in msg.tool_calls: | |
| 202 | + tool_calls.append({ | |
| 203 | + "name": tc["name"], | |
| 204 | + "args": tc.get("args", {}), | |
| 205 | + }) | |
| 206 | + | |
| 207 | + # Check for tool node (tool results) | |
| 208 | + if "tools" in event: | |
| 209 | + tools_output = event["tools"] | |
| 210 | + # Attach each result to the first tool call without one, so | |
| 211 | + # results stay aligned across multiple tool rounds | |
| 212 | + for msg in tools_output.get("messages", []): | |
| 213 | + for tc in tool_calls: | |
| 214 | + if "result" not in tc: | |
| 215 | + content = str(msg.content) | |
| 216 | + tc["result"] = content[:200] + ("..." if len(content) > 200 else "") | |
| 217 | + break | |
| 214 | + | |
| 215 | + # Get final state | |
| 216 | + final_state = self.graph.get_state(config) | |
| 217 | + final_message = final_state.values["messages"][-1] | |
| 218 | + response_text = _extract_message_text(final_message) | |
| 219 | + | |
| 220 | + logger.info(f"[{self.session_id}] Response generated with {len(tool_calls)} tool calls") | |
| 221 | + | |
| 222 | + return { | |
| 223 | + "response": response_text, | |
| 224 | + "tool_calls": tool_calls, | |
| 225 | + "error": False, | |
| 226 | + } | |
| 227 | + | |
| 228 | + except Exception as e: | |
| 229 | + logger.error(f"Error in agent chat: {e}", exc_info=True) | |
| 230 | + return { | |
| 231 | + "response": f"I apologize, I encountered an error: {str(e)}", | |
| 232 | + "error": True, | |
| 233 | + } | |
| 234 | + | |
| 235 | + def get_conversation_history(self) -> list: | |
| 236 | + """Get conversation history for this session""" | |
| 237 | + try: | |
| 238 | + config = {"configurable": {"thread_id": self.session_id}} | |
| 239 | + state = self.graph.get_state(config) | |
| 240 | + | |
| 241 | + if not state or not state.values.get("messages"): | |
| 242 | + return [] | |
| 243 | + | |
| 244 | + messages = state.values["messages"] | |
| 245 | + result = [] | |
| 246 | + | |
| 247 | + for msg in messages: | |
| 248 | + # Skip system messages and tool messages | |
| 249 | + if isinstance(msg, SystemMessage): | |
| 250 | + continue | |
| 251 | + if hasattr(msg, "type") and msg.type in ["system", "tool"]: | |
| 252 | + continue | |
| 253 | + | |
| 254 | + role = "user" if msg.type == "human" else "assistant" | |
| 255 | + result.append({"role": role, "content": _extract_message_text(msg)}) | |
| 256 | + | |
| 257 | + return result | |
| 258 | + | |
| 259 | + except Exception as e: | |
| 260 | + logger.error(f"Error getting history: {e}") | |
| 261 | + return [] | |
| 262 | + | |
| 263 | + def clear_history(self): | |
| 264 | + """Clear conversation history for this session""" | |
| 265 | + # With MemorySaver, we can't easily clear, but we can log | |
| 266 | + logger.info(f"[{self.session_id}] History clear requested") | |
| 267 | + # In production, implement proper clearing or use new thread_id | |
| 268 | + | |
| 269 | + | |
| 270 | +def create_shopping_agent(session_id: Optional[str] = None) -> ShoppingAgent: | |
| 271 | + """Factory function to create a shopping agent""" | |
| 272 | + return ShoppingAgent(session_id=session_id) | ... | ... |
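The routing rule in `should_continue` (keep going to the tool node while the last message carries tool calls, otherwise end) can be sketched without LangGraph. `FakeMessage` below is a hypothetical stand-in for a LangChain `AIMessage`, and `END` for `langgraph.graph.END`:

```python
END = "__end__"  # stand-in for langgraph.graph.END

class FakeMessage:
    """Hypothetical stand-in for an AIMessage with optional tool calls."""
    def __init__(self, content, tool_calls=None):
        self.content = content
        self.tool_calls = tool_calls or []

def should_continue(messages):
    # Route to the tool node while the model is still requesting tools
    last = messages[-1]
    if getattr(last, "tool_calls", None):
        return "tools"
    return END

msgs = [FakeMessage("", tool_calls=[{"name": "search_products", "args": {"query": "shoes"}}])]
print(should_continue(msgs))  # tools
msgs.append(FakeMessage("Here are some running shoes!"))
print(should_continue(msgs))  # __end__
```

Because the graph wires `tools` back to `agent`, this predicate is what turns the two nodes into a ReAct loop: the loop exits only when the model produces a message with no tool calls.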
| 1 | +++ a/app/config.py | |
| ... | ... | @@ -0,0 +1,86 @@ |
| 1 | +""" | |
| 2 | +Configuration management for OmniShopAgent | |
| 3 | +Loads environment variables and provides configuration objects | |
| 4 | +""" | |
| 5 | + | |
| 6 | +import os | |
| 7 | + | |
| 8 | +from pydantic_settings import BaseSettings | |
| 9 | + | |
| 10 | + | |
| 11 | +class Settings(BaseSettings): | |
| 12 | + """Application settings loaded from environment variables | |
| 13 | + | |
| 14 | + All settings can be configured via .env file or environment variables. | |
| 15 | + """ | |
| 16 | + | |
| 17 | + # OpenAI Configuration | |
| 18 | + openai_api_key: str | |
| 19 | + openai_model: str = "gpt-4o-mini" | |
| 20 | + openai_embedding_model: str = "text-embedding-3-small" | |
| 21 | + openai_temperature: float = 0.7 | |
| 22 | + openai_max_tokens: int = 1000 | |
| 23 | + | |
| 24 | + # CLIP Server Configuration | |
| 25 | + clip_server_url: str = "grpc://localhost:51000" | |
| 26 | + | |
| 27 | + # Milvus Configuration | |
| 28 | + milvus_uri: str = "http://localhost:19530" | |
| 29 | + milvus_host: str = "localhost" | |
| 30 | + milvus_port: int = 19530 | |
| 31 | + text_collection_name: str = "text_embeddings" | |
| 32 | + image_collection_name: str = "image_embeddings" | |
| 33 | + text_dim: int = 1536 | |
| 34 | + image_dim: int = 512 | |
| 35 | + | |
| 36 | + @property | |
| 37 | + def milvus_uri_absolute(self) -> str: | |
| 38 | + """Get absolute path for Milvus URI | |
| 39 | + | |
| 40 | + Returns: | |
| 41 | + - For http/https URIs: returns as-is (Milvus Standalone) | |
| 42 | + - For file paths starting with ./: converts to absolute path (Milvus Lite) | |
| 43 | + - For other paths: returns as-is | |
| 44 | + """ | |
| 47 | + # If it's a network URI, return as-is (Milvus Standalone) | |
| 48 | + if self.milvus_uri.startswith(("http://", "https://")): | |
| 49 | + return self.milvus_uri | |
| 50 | + # If it's a relative path, convert to absolute (Milvus Lite) | |
| 51 | + if self.milvus_uri.startswith("./"): | |
| 52 | + base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | |
| 53 | + return os.path.join(base_dir, self.milvus_uri[2:]) | |
| 54 | + # Otherwise return as-is | |
| 55 | + return self.milvus_uri | |
| 56 | + | |
| 57 | + # Search Configuration | |
| 58 | + top_k_results: int = 10 | |
| 59 | + similarity_threshold: float = 0.6 | |
| 60 | + | |
| 61 | + # Application Configuration | |
| 62 | + app_host: str = "0.0.0.0" | |
| 63 | + app_port: int = 8000 | |
| 64 | + debug: bool = True | |
| 65 | + log_level: str = "INFO" | |
| 66 | + | |
| 67 | + # Data Paths | |
| 68 | + raw_data_path: str = "./data/raw" | |
| 69 | + processed_data_path: str = "./data/processed" | |
| 70 | + image_data_path: str = "./data/images" | |
| 71 | + | |
| 72 | + class Config: | |
| 73 | + env_file = ".env" | |
| 74 | + env_file_encoding = "utf-8" | |
| 75 | + case_sensitive = False | |
| 76 | + | |
| 77 | + | |
| 78 | +# Global settings instance | |
| 79 | +settings = Settings() | |
| 80 | + | |
| 81 | + | |
| 82 | +# Helper function to get absolute paths | |
| 83 | +def get_absolute_path(relative_path: str) -> str: | |
| 84 | + """Convert relative path to absolute path""" | |
| 85 | + base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | |
| 86 | + return os.path.join(base_dir, relative_path) | ... | ... |
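The branching in `milvus_uri_absolute` can be checked in isolation. `resolve_uri` below is a hypothetical free-function version of the same rule; the `/srv/app` base directory is illustrative:

```python
import os

def resolve_uri(uri: str, base_dir: str) -> str:
    # Network URIs (Milvus Standalone) pass through unchanged
    if uri.startswith(("http://", "https://")):
        return uri
    # "./"-relative paths (Milvus Lite) are anchored to the base directory
    if uri.startswith("./"):
        return os.path.join(base_dir, uri[2:])
    # Anything else (e.g. an already-absolute path) is returned as-is
    return uri

print(resolve_uri("http://localhost:19530", "/srv/app"))
print(resolve_uri("./milvus.db", "/srv/app"))
```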
| 1 | +++ a/app/services/__init__.py | |
| ... | ... | @@ -0,0 +1,14 @@ |
| 1 | +""" | |
| 2 | +Services Module | |
| 3 | +Provides database and embedding services for the application | |
| 4 | +""" | |
| 5 | + | |
| 6 | +from app.services.embedding_service import EmbeddingService, get_embedding_service | |
| 7 | +from app.services.milvus_service import MilvusService, get_milvus_service | |
| 8 | + | |
| 9 | +__all__ = [ | |
| 10 | + "EmbeddingService", | |
| 11 | + "get_embedding_service", | |
| 12 | + "MilvusService", | |
| 13 | + "get_milvus_service", | |
| 14 | +] | ... | ... |
| 1 | +++ a/app/services/embedding_service.py | |
| ... | ... | @@ -0,0 +1,293 @@ |
| 1 | +""" | |
| 2 | +Embedding Service for Text and Image Embeddings | |
| 3 | +Supports OpenAI text embeddings and CLIP image embeddings | |
| 4 | +""" | |
| 5 | + | |
| 6 | +import logging | |
| 7 | +from pathlib import Path | |
| 8 | +from typing import List, Optional, Union | |
| 9 | + | |
| 10 | +import numpy as np | |
| 11 | +from clip_client import Client as ClipClient | |
| 12 | +from openai import OpenAI | |
| 13 | + | |
| 14 | +from app.config import settings | |
| 15 | + | |
| 16 | +logger = logging.getLogger(__name__) | |
| 17 | + | |
| 18 | + | |
| 19 | +class EmbeddingService: | |
| 20 | + """Service for generating text and image embeddings""" | |
| 21 | + | |
| 22 | + def __init__( | |
| 23 | + self, | |
| 24 | + openai_api_key: Optional[str] = None, | |
| 25 | + clip_server_url: Optional[str] = None, | |
| 26 | + ): | |
| 27 | + """Initialize embedding service | |
| 28 | + | |
| 29 | + Args: | |
| 30 | + openai_api_key: OpenAI API key. If None, uses settings.openai_api_key | |
| 31 | + clip_server_url: CLIP server URL. If None, uses settings.clip_server_url | |
| 32 | + """ | |
| 33 | + # Initialize OpenAI client for text embeddings | |
| 34 | + self.openai_api_key = openai_api_key or settings.openai_api_key | |
| 35 | + self.openai_client = OpenAI(api_key=self.openai_api_key) | |
| 36 | + self.text_embedding_model = settings.openai_embedding_model | |
| 37 | + | |
| 38 | + # Initialize CLIP client for image embeddings | |
| 39 | + self.clip_server_url = clip_server_url or settings.clip_server_url | |
| 40 | + self.clip_client: Optional[ClipClient] = None | |
| 41 | + | |
| 42 | + logger.info("Embedding service initialized") | |
| 43 | + | |
| 44 | + def connect_clip(self) -> None: | |
| 45 | + """Connect to CLIP server""" | |
| 46 | + try: | |
| 47 | + self.clip_client = ClipClient(server=self.clip_server_url) | |
| 48 | + logger.info(f"Connected to CLIP server at {self.clip_server_url}") | |
| 49 | + except Exception as e: | |
| 50 | + logger.error(f"Failed to connect to CLIP server: {e}") | |
| 51 | + raise | |
| 52 | + | |
| 53 | + def disconnect_clip(self) -> None: | |
| 54 | + """Disconnect from CLIP server""" | |
| 55 | + if self.clip_client: | |
| 56 | + # Note: clip_client doesn't have explicit close method | |
| 57 | + self.clip_client = None | |
| 58 | + logger.info("Disconnected from CLIP server") | |
| 59 | + | |
| 60 | + def get_text_embedding(self, text: str) -> List[float]: | |
| 61 | + """Get embedding for a single text | |
| 62 | + | |
| 63 | + Args: | |
| 64 | + text: Input text | |
| 65 | + | |
| 66 | + Returns: | |
| 67 | + Embedding vector as list of floats | |
| 68 | + """ | |
| 69 | + try: | |
| 70 | + response = self.openai_client.embeddings.create( | |
| 71 | + input=text, model=self.text_embedding_model | |
| 72 | + ) | |
| 73 | + embedding = response.data[0].embedding | |
| 74 | + logger.debug(f"Generated text embedding for: {text[:50]}...") | |
| 75 | + return embedding | |
| 76 | + except Exception as e: | |
| 77 | + logger.error(f"Failed to generate text embedding: {e}") | |
| 78 | + raise | |
| 79 | + | |
| 80 | + def get_text_embeddings_batch( | |
| 81 | + self, texts: List[str], batch_size: int = 100 | |
| 82 | + ) -> List[List[float]]: | |
| 83 | + """Get embeddings for multiple texts in batches | |
| 84 | + | |
| 85 | + Args: | |
| 86 | + texts: List of input texts | |
| 87 | + batch_size: Number of texts to process at once | |
| 88 | + | |
| 89 | + Returns: | |
| 90 | + List of embedding vectors | |
| 91 | + """ | |
| 92 | + all_embeddings = [] | |
| 93 | + | |
| 94 | + for i in range(0, len(texts), batch_size): | |
| 95 | + batch = texts[i : i + batch_size] | |
| 96 | + | |
| 97 | + try: | |
| 98 | + response = self.openai_client.embeddings.create( | |
| 99 | + input=batch, model=self.text_embedding_model | |
| 100 | + ) | |
| 101 | + | |
| 102 | + # Extract embeddings in the correct order | |
| 103 | + embeddings = [item.embedding for item in response.data] | |
| 104 | + all_embeddings.extend(embeddings) | |
| 105 | + | |
| 106 | + logger.info( | |
| 107 | + f"Generated text embeddings for batch {i // batch_size + 1}: {len(embeddings)} embeddings" | |
| 108 | + ) | |
| 109 | + | |
| 110 | + except Exception as e: | |
| 111 | + logger.error( | |
| 112 | + f"Failed to generate text embeddings for batch {i // batch_size + 1}: {e}" | |
| 113 | + ) | |
| 114 | + raise | |
| 115 | + | |
| 116 | + return all_embeddings | |
| 117 | + | |
| 118 | + def get_image_embedding(self, image_path: Union[str, Path]) -> List[float]: | |
| 119 | + """Get CLIP embedding for a single image | |
| 120 | + | |
| 121 | + Args: | |
| 122 | + image_path: Path to image file | |
| 123 | + | |
| 124 | + Returns: | |
| 125 | + Embedding vector as list of floats | |
| 126 | + """ | |
| 127 | + if not self.clip_client: | |
| 128 | + raise RuntimeError("CLIP client not connected. Call connect_clip() first.") | |
| 129 | + | |
| 130 | + image_path = Path(image_path) | |
| 131 | + if not image_path.exists(): | |
| 132 | + raise FileNotFoundError(f"Image not found: {image_path}") | |
| 133 | + | |
| 134 | + try: | |
| 135 | + # Get embedding from CLIP server using image path (as string) | |
| 136 | + result = self.clip_client.encode([str(image_path)]) | |
| 137 | + | |
| 138 | + # Extract embedding - result is numpy array | |
| 141 | + if isinstance(result, np.ndarray): | |
| 142 | + # If result is numpy array, use first element | |
| 143 | + embedding = ( | |
| 144 | + result[0].tolist() if len(result.shape) > 1 else result.tolist() | |
| 145 | + ) | |
| 146 | + else: | |
| 147 | + # If result is DocumentArray | |
| 148 | + embedding = result[0].embedding.tolist() | |
| 149 | + | |
| 150 | + logger.debug(f"Generated image embedding for: {image_path.name}") | |
| 151 | + return embedding | |
| 152 | + | |
| 153 | + except Exception as e: | |
| 154 | + logger.error(f"Failed to generate image embedding for {image_path}: {e}") | |
| 155 | + raise | |
| 156 | + | |
| 157 | + def get_image_embeddings_batch( | |
| 158 | + self, image_paths: List[Union[str, Path]], batch_size: int = 32 | |
| 159 | + ) -> List[Optional[List[float]]]: | |
| 160 | + """Get CLIP embeddings for multiple images in batches | |
| 161 | + | |
| 162 | + Args: | |
| 163 | + image_paths: List of paths to image files | |
| 164 | + batch_size: Number of images to process at once | |
| 165 | + | |
| 166 | + Returns: | |
| 167 | + List of embedding vectors (None for failed images) | |
| 168 | + """ | |
| 169 | + if not self.clip_client: | |
| 170 | + raise RuntimeError("CLIP client not connected. Call connect_clip() first.") | |
| 171 | + | |
| 172 | + all_embeddings = [] | |
| 173 | + | |
| 174 | + for i in range(0, len(image_paths), batch_size): | |
| 175 | + batch_paths = image_paths[i : i + batch_size] | |
| 176 | + valid_paths = [] | |
| 177 | + valid_indices = [] | |
| 178 | + | |
| 179 | + # Check which images exist | |
| 180 | + for idx, path in enumerate(batch_paths): | |
| 181 | + path = Path(path) | |
| 182 | + if path.exists(): | |
| 183 | + valid_paths.append(str(path)) | |
| 184 | + valid_indices.append(idx) | |
| 185 | + else: | |
| 186 | + logger.warning(f"Image not found: {path}") | |
| 187 | + | |
| 188 | + # Get embeddings for valid images | |
| 189 | + if valid_paths: | |
| 190 | + try: | |
| 191 | + # Send paths as strings to CLIP server | |
| 192 | + result = self.clip_client.encode(valid_paths) | |
| 193 | + | |
| 194 | + # Create embeddings list with None for missing images | |
| 195 | + batch_embeddings = [None] * len(batch_paths) | |
| 196 | + | |
| 197 | + # Handle result format - could be numpy array or DocumentArray | |
| 200 | + if isinstance(result, np.ndarray): | |
| 201 | + # Result is numpy array - shape (n_images, embedding_dim) | |
| 202 | + for idx in range(len(result)): | |
| 203 | + original_idx = valid_indices[idx] | |
| 204 | + batch_embeddings[original_idx] = result[idx].tolist() | |
| 205 | + else: | |
| 206 | + # Result is DocumentArray | |
| 207 | + for idx, doc in enumerate(result): | |
| 208 | + original_idx = valid_indices[idx] | |
| 209 | + batch_embeddings[original_idx] = doc.embedding.tolist() | |
| 210 | + | |
| 211 | + all_embeddings.extend(batch_embeddings) | |
| 212 | + | |
| 213 | + logger.info( | |
| 214 | + f"Generated image embeddings for batch {i // batch_size + 1}: " | |
| 215 | + f"{len(valid_paths)}/{len(batch_paths)} successful" | |
| 216 | + ) | |
| 217 | + | |
| 218 | + except Exception as e: | |
| 219 | + logger.error( | |
| 220 | + f"Failed to generate image embeddings for batch {i // batch_size + 1}: {e}" | |
| 221 | + ) | |
| 222 | + # Add None for all images in failed batch | |
| 223 | + all_embeddings.extend([None] * len(batch_paths)) | |
| 224 | + else: | |
| 225 | + # All images in batch failed to load | |
| 226 | + all_embeddings.extend([None] * len(batch_paths)) | |
| 227 | + | |
| 228 | + return all_embeddings | |
| 229 | + | |
| 230 | + def get_text_embedding_from_image( | |
| 231 | + self, image_path: Union[str, Path] | |
| 232 | + ) -> List[float]: | |
| 233 | + """Get text-based embedding by describing the image | |
| 234 | + This is useful for cross-modal search | |
| 235 | + | |
| 236 | + Note: This is a placeholder for future implementation | |
| 237 | + that could use vision models to generate text descriptions | |
| 238 | + | |
| 239 | + Args: | |
| 240 | + image_path: Path to image file | |
| 241 | + | |
| 242 | + Returns: | |
| 243 | + Text embedding vector | |
| 244 | + """ | |
| 245 | + # A future implementation could use a vision-language model to | |
| 246 | + # generate a text description of the image and embed that text | |
| 248 | + raise NotImplementedError("Text embedding from image not yet implemented") | |
| 249 | + | |
| 250 | + def cosine_similarity( | |
| 251 | + self, embedding1: List[float], embedding2: List[float] | |
| 252 | + ) -> float: | |
| 253 | + """Calculate cosine similarity between two embeddings | |
| 254 | + | |
| 255 | + Args: | |
| 256 | + embedding1: First embedding vector | |
| 257 | + embedding2: Second embedding vector | |
| 258 | + | |
| 259 | + Returns: | |
| 260 | + Cosine similarity score (0-1) | |
| 261 | + """ | |
| 262 | + vec1 = np.array(embedding1) | |
| 263 | + vec2 = np.array(embedding2) | |
| 264 | + | |
| 265 | + # Normalize vectors | |
| 266 | + vec1_norm = vec1 / np.linalg.norm(vec1) | |
| 267 | + vec2_norm = vec2 / np.linalg.norm(vec2) | |
| 268 | + | |
| 269 | + # Calculate cosine similarity | |
| 270 | + similarity = np.dot(vec1_norm, vec2_norm) | |
| 271 | + | |
| 272 | + return float(similarity) | |
| 273 | + | |
| 274 | + def get_embedding_dimensions(self) -> dict: | |
| 275 | + """Get the dimensions of text and image embeddings | |
| 276 | + | |
| 277 | + Returns: | |
| 278 | + Dictionary with text_dim and image_dim | |
| 279 | + """ | |
| 280 | + return {"text_dim": settings.text_dim, "image_dim": settings.image_dim} | |
| 281 | + | |
| 282 | + | |
| 283 | +# Global instance | |
| 284 | +_embedding_service: Optional[EmbeddingService] = None | |
| 285 | + | |
| 286 | + | |
| 287 | +def get_embedding_service() -> EmbeddingService: | |
| 288 | + """Get or create the global embedding service instance""" | |
| 289 | + global _embedding_service | |
| 290 | + if _embedding_service is None: | |
| 291 | + _embedding_service = EmbeddingService() | |
| 292 | + _embedding_service.connect_clip() | |
| 293 | + return _embedding_service | ... | ... |
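The cosine similarity used by the service reduces to a dot product of unit-normalized vectors. A standalone check of that computation (numpy only, no OpenAI or CLIP dependency):

```python
import numpy as np

def cosine_similarity(embedding1, embedding2) -> float:
    """Cosine similarity between two embedding vectors."""
    v1 = np.asarray(embedding1, dtype=float)
    v2 = np.asarray(embedding2, dtype=float)
    # Normalize to unit length, then take the dot product
    return float(np.dot(v1 / np.linalg.norm(v1), v2 / np.linalg.norm(v2)))

print(cosine_similarity([1.0, 0.0], [1.0, 0.0]))  # 1.0 (identical direction)
print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))  # 0.0 (orthogonal)
```

Note the docstring in the service says the score is 0-1, which holds for non-negative vectors; in general cosine similarity ranges over [-1, 1].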
| 1 | +++ a/app/services/milvus_service.py | |
| ... | ... | @@ -0,0 +1,480 @@ |
| 1 | +""" | |
| 2 | +Milvus Service for Vector Storage and Similarity Search | |
| 3 | +Manages text and image embeddings in separate collections | |
| 4 | +""" | |
| 5 | + | |
| 6 | +import logging | |
| 7 | +from typing import Any, Dict, List, Optional | |
| 8 | + | |
| 9 | +from pymilvus import ( | |
| 10 | + DataType, | |
| 11 | + MilvusClient, | |
| 12 | +) | |
| 13 | + | |
| 14 | +from app.config import settings | |
| 15 | + | |
| 16 | +logger = logging.getLogger(__name__) | |
| 17 | + | |
| 18 | + | |
| 19 | +class MilvusService: | |
| 20 | + """Service for managing vector embeddings in Milvus""" | |
| 21 | + | |
| 22 | + def __init__(self, uri: Optional[str] = None): | |
| 23 | + """Initialize Milvus service | |
| 24 | + | |
| 25 | + Args: | |
| 26 | + uri: Milvus connection URI. If None, uses settings.milvus_uri | |
| 27 | + """ | |
| 28 | + if uri: | |
| 29 | + self.uri = uri | |
| 30 | + else: | |
| 31 | + # Use absolute path for Milvus Lite | |
| 32 | + self.uri = settings.milvus_uri_absolute | |
| 33 | + self.text_collection_name = settings.text_collection_name | |
| 34 | + self.image_collection_name = settings.image_collection_name | |
| 35 | + self.text_dim = settings.text_dim | |
| 36 | + self.image_dim = settings.image_dim | |
| 37 | + | |
| 38 | + # Use MilvusClient for simplified operations | |
| 39 | + self._client: Optional[MilvusClient] = None | |
| 40 | + | |
| 41 | + logger.info(f"Initializing Milvus service with URI: {self.uri}") | |
| 42 | + | |
| 43 | + def is_connected(self) -> bool: | |
| 44 | + """Check if connected to Milvus""" | |
| 45 | + return self._client is not None | |
| 46 | + | |
| 47 | + def connect(self) -> None: | |
| 48 | + """Connect to Milvus""" | |
| 49 | + if self.is_connected(): | |
| 50 | + return | |
| 51 | + try: | |
| 52 | + self._client = MilvusClient(uri=self.uri) | |
| 53 | + logger.info(f"Connected to Milvus at {self.uri}") | |
| 54 | + except Exception as e: | |
| 55 | + logger.error(f"Failed to connect to Milvus: {e}") | |
| 56 | + raise | |
| 57 | + | |
| 58 | + def disconnect(self) -> None: | |
| 59 | + """Disconnect from Milvus""" | |
| 60 | + if self._client: | |
| 61 | + self._client.close() | |
| 62 | + self._client = None | |
| 63 | + logger.info("Disconnected from Milvus") | |
| 64 | + | |
| 65 | + @property | |
| 66 | + def client(self) -> MilvusClient: | |
| 67 | + """Get the Milvus client""" | |
| 68 | + if not self._client: | |
| 69 | + raise RuntimeError("Milvus not connected. Call connect() first.") | |
| 70 | + return self._client | |
| 71 | + | |
| 72 | + def create_text_collection(self, recreate: bool = False) -> None: | |
| 73 | + """Create collection for text embeddings with product metadata | |
| 74 | + | |
| 75 | + Args: | |
| 76 | + recreate: If True, drop existing collection and recreate | |
| 77 | + """ | |
| 78 | + if recreate and self.client.has_collection(self.text_collection_name): | |
| 79 | + self.client.drop_collection(self.text_collection_name) | |
| 80 | + logger.info(f"Dropped existing collection: {self.text_collection_name}") | |
| 81 | + | |
| 82 | + if self.client.has_collection(self.text_collection_name): | |
| 83 | + logger.info(f"Text collection already exists: {self.text_collection_name}") | |
| 84 | + return | |
| 85 | + | |
| 86 | + # Create collection with schema (includes metadata fields) | |
| 87 | + schema = MilvusClient.create_schema( | |
| 88 | + auto_id=False, | |
| 89 | + enable_dynamic_field=True, # Allow additional metadata fields | |
| 90 | + ) | |
| 91 | + | |
| 92 | + # Core fields | |
| 93 | + schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True) | |
| 94 | + schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=2000) | |
| 95 | + schema.add_field( | |
| 96 | + field_name="embedding", datatype=DataType.FLOAT_VECTOR, dim=self.text_dim | |
| 97 | + ) | |
| 98 | + | |
| 99 | + # Product metadata fields | |
| 100 | + schema.add_field( | |
| 101 | + field_name="productDisplayName", datatype=DataType.VARCHAR, max_length=500 | |
| 102 | + ) | |
| 103 | + schema.add_field(field_name="gender", datatype=DataType.VARCHAR, max_length=50) | |
| 104 | + schema.add_field( | |
| 105 | + field_name="masterCategory", datatype=DataType.VARCHAR, max_length=100 | |
| 106 | + ) | |
| 107 | + schema.add_field( | |
| 108 | + field_name="subCategory", datatype=DataType.VARCHAR, max_length=100 | |
| 109 | + ) | |
| 110 | + schema.add_field( | |
| 111 | + field_name="articleType", datatype=DataType.VARCHAR, max_length=100 | |
| 112 | + ) | |
| 113 | + schema.add_field( | |
| 114 | + field_name="baseColour", datatype=DataType.VARCHAR, max_length=50 | |
| 115 | + ) | |
| 116 | + schema.add_field(field_name="season", datatype=DataType.VARCHAR, max_length=50) | |
| 117 | + schema.add_field(field_name="usage", datatype=DataType.VARCHAR, max_length=50) | |
| 118 | + | |
| 119 | + # Create index parameters | |
| 120 | + index_params = self.client.prepare_index_params() | |
| 121 | + index_params.add_index( | |
| 122 | + field_name="embedding", | |
| 123 | + index_type="AUTOINDEX", | |
| 124 | + metric_type="COSINE", | |
| 125 | + ) | |
| 126 | + | |
| 127 | + # Create collection | |
| 128 | + self.client.create_collection( | |
| 129 | + collection_name=self.text_collection_name, | |
| 130 | + schema=schema, | |
| 131 | + index_params=index_params, | |
| 132 | + ) | |
| 133 | + | |
| 134 | + logger.info( | |
| 135 | + f"Created text collection with metadata: {self.text_collection_name}" | |
| 136 | + ) | |
| 137 | + | |
| 138 | + def create_image_collection(self, recreate: bool = False) -> None: | |
| 139 | + """Create collection for image embeddings with product metadata | |
| 140 | + | |
| 141 | + Args: | |
| 142 | + recreate: If True, drop existing collection and recreate | |
| 143 | + """ | |
| 144 | + if recreate and self.client.has_collection(self.image_collection_name): | |
| 145 | + self.client.drop_collection(self.image_collection_name) | |
| 146 | + logger.info(f"Dropped existing collection: {self.image_collection_name}") | |
| 147 | + | |
| 148 | + if self.client.has_collection(self.image_collection_name): | |
| 149 | + logger.info( | |
| 150 | + f"Image collection already exists: {self.image_collection_name}" | |
| 151 | + ) | |
| 152 | + return | |
| 153 | + | |
| 154 | + # Create collection with schema (includes metadata fields) | |
| 155 | + schema = MilvusClient.create_schema( | |
| 156 | + auto_id=False, | |
| 157 | + enable_dynamic_field=True, # Allow additional metadata fields | |
| 158 | + ) | |
| 159 | + | |
| 160 | + # Core fields | |
| 161 | + schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True) | |
| 162 | + schema.add_field( | |
| 163 | + field_name="image_path", datatype=DataType.VARCHAR, max_length=500 | |
| 164 | + ) | |
| 165 | + schema.add_field( | |
| 166 | + field_name="embedding", datatype=DataType.FLOAT_VECTOR, dim=self.image_dim | |
| 167 | + ) | |
| 168 | + | |
| 169 | + # Product metadata fields | |
| 170 | + schema.add_field( | |
| 171 | + field_name="productDisplayName", datatype=DataType.VARCHAR, max_length=500 | |
| 172 | + ) | |
| 173 | + schema.add_field(field_name="gender", datatype=DataType.VARCHAR, max_length=50) | |
| 174 | + schema.add_field( | |
| 175 | + field_name="masterCategory", datatype=DataType.VARCHAR, max_length=100 | |
| 176 | + ) | |
| 177 | + schema.add_field( | |
| 178 | + field_name="subCategory", datatype=DataType.VARCHAR, max_length=100 | |
| 179 | + ) | |
| 180 | + schema.add_field( | |
| 181 | + field_name="articleType", datatype=DataType.VARCHAR, max_length=100 | |
| 182 | + ) | |
| 183 | + schema.add_field( | |
| 184 | + field_name="baseColour", datatype=DataType.VARCHAR, max_length=50 | |
| 185 | + ) | |
| 186 | + schema.add_field(field_name="season", datatype=DataType.VARCHAR, max_length=50) | |
| 187 | + schema.add_field(field_name="usage", datatype=DataType.VARCHAR, max_length=50) | |
| 188 | + | |
| 189 | + # Create index parameters | |
| 190 | + index_params = self.client.prepare_index_params() | |
| 191 | + index_params.add_index( | |
| 192 | + field_name="embedding", | |
| 193 | + index_type="AUTOINDEX", | |
| 194 | + metric_type="COSINE", | |
| 195 | + ) | |
| 196 | + | |
| 197 | + # Create collection | |
| 198 | + self.client.create_collection( | |
| 199 | + collection_name=self.image_collection_name, | |
| 200 | + schema=schema, | |
| 201 | + index_params=index_params, | |
| 202 | + ) | |
| 203 | + | |
| 204 | + logger.info( | |
| 205 | + f"Created image collection with metadata: {self.image_collection_name}" | |
| 206 | + ) | |
| 207 | + | |
| 208 | + def insert_text_embeddings( | |
| 209 | + self, | |
| 210 | + embeddings: List[Dict[str, Any]], | |
| 211 | + ) -> int: | |
| 212 | + """Insert text embeddings with metadata into collection | |
| 213 | + | |
| 214 | + Args: | |
| 215 | + embeddings: List of dictionaries with keys: | |
| 216 | + - id: unique ID (product ID) | |
| 217 | + - text: the text that was embedded | |
| 218 | + - embedding: the embedding vector | |
| 219 | + - productDisplayName, gender, masterCategory, etc. (metadata) | |
| 220 | + | |
| 221 | + Returns: | |
| 222 | + Number of inserted embeddings | |
| 223 | + """ | |
| 224 | + if not embeddings: | |
| 225 | + return 0 | |
| 226 | + | |
| 227 | + try: | |
| 228 | + # Insert data directly (all fields including metadata) | |
| 229 | + # Milvus will accept all fields defined in schema + dynamic fields | |
| 230 | + data = embeddings | |
| 231 | + | |
| 232 | + # Insert data | |
| 233 | + result = self.client.insert( | |
| 234 | + collection_name=self.text_collection_name, | |
| 235 | + data=data, | |
| 236 | + ) | |
| 237 | + | |
| 238 | + logger.info(f"Inserted {len(data)} text embeddings") | |
| 239 | + return len(data) | |
| 240 | + | |
| 241 | + except Exception as e: | |
| 242 | + logger.error(f"Failed to insert text embeddings: {e}") | |
| 243 | + raise | |
| 244 | + | |
| 245 | + def insert_image_embeddings( | |
| 246 | + self, | |
| 247 | + embeddings: List[Dict[str, Any]], | |
| 248 | + ) -> int: | |
| 249 | + """Insert image embeddings with metadata into collection | |
| 250 | + | |
| 251 | + Args: | |
| 252 | + embeddings: List of dictionaries with keys: | |
| 253 | + - id: unique ID (product ID) | |
| 254 | + - image_path: path to the image file | |
| 255 | + - embedding: the embedding vector | |
| 256 | + - productDisplayName, gender, masterCategory, etc. (metadata) | |
| 257 | + | |
| 258 | + Returns: | |
| 259 | + Number of inserted embeddings | |
| 260 | + """ | |
| 261 | + if not embeddings: | |
| 262 | + return 0 | |
| 263 | + | |
| 264 | + try: | |
| 265 | + # Insert data directly (all fields including metadata) | |
| 266 | + # Milvus will accept all fields defined in schema + dynamic fields | |
| 267 | + data = embeddings | |
| 268 | + | |
| 269 | + # Insert data | |
| 270 | + result = self.client.insert( | |
| 271 | + collection_name=self.image_collection_name, | |
| 272 | + data=data, | |
| 273 | + ) | |
| 274 | + | |
| 275 | + logger.info(f"Inserted {len(data)} image embeddings") | |
| 276 | + return len(data) | |
| 277 | + | |
| 278 | + except Exception as e: | |
| 279 | + logger.error(f"Failed to insert image embeddings: {e}") | |
| 280 | + raise | |
| 281 | + | |
| 282 | + def search_similar_text( | |
| 283 | + self, | |
| 284 | + query_embedding: List[float], | |
| 285 | + limit: int = 10, | |
| 286 | + filters: Optional[str] = None, | |
| 287 | + output_fields: Optional[List[str]] = None, | |
| 288 | + ) -> List[Dict[str, Any]]: | |
| 289 | + """Search for similar text embeddings | |
| 290 | + | |
| 291 | + Args: | |
| 292 | + query_embedding: Query embedding vector | |
| 293 | + limit: Maximum number of results | |
| 294 | + filters: Filter expression (e.g., "product_id in [1, 2, 3]") | |
| 295 | + output_fields: List of fields to return | |
| 296 | + | |
| 297 | + Returns: | |
| 298 | + List of search results with fields: | |
| 299 | + - id: embedding ID | |
| 300 | + - distance: similarity distance | |
| 301 | + - entity: the matched entity with requested fields | |
| 302 | + """ | |
| 303 | + try: | |
| 304 | + if output_fields is None: | |
| 305 | + output_fields = [ | |
| 306 | + "id", | |
| 307 | + "text", | |
| 308 | + "productDisplayName", | |
| 309 | + "gender", | |
| 310 | + "masterCategory", | |
| 311 | + "subCategory", | |
| 312 | + "articleType", | |
| 313 | + "baseColour", | |
| 314 | + ] | |
| 315 | + | |
| 316 | +            # MilvusClient.search takes the boolean filter expression via | |
| 317 | +            # the `filter` argument; search_params is for index tuning only | |
| 318 | +            filter_expr = filters or "" | |
| 319 | + | |
| 320 | +            results = self.client.search( | |
| 321 | +                collection_name=self.text_collection_name, | |
| 322 | +                data=[query_embedding], | |
| 323 | +                limit=limit, | |
| 324 | +                output_fields=output_fields, | |
| 325 | +                filter=filter_expr, | |
| 326 | +            ) | |
| 327 | + | |
| 328 | + # Format results | |
| 329 | + formatted_results = [] | |
| 330 | + if results and len(results) > 0: | |
| 331 | + for hit in results[0]: | |
| 332 | + result = {"id": hit.get("id"), "distance": hit.get("distance")} | |
| 333 | + # Extract fields from entity | |
| 334 | + entity = hit.get("entity", {}) | |
| 335 | + for field in output_fields: | |
| 336 | + if field in entity: | |
| 337 | + result[field] = entity.get(field) | |
| 338 | + formatted_results.append(result) | |
| 339 | + | |
| 340 | + logger.debug(f"Found {len(formatted_results)} similar text embeddings") | |
| 341 | + return formatted_results | |
| 342 | + | |
| 343 | + except Exception as e: | |
| 344 | + logger.error(f"Failed to search similar text: {e}") | |
| 345 | + raise | |
| 346 | + | |
| 347 | + def search_similar_images( | |
| 348 | + self, | |
| 349 | + query_embedding: List[float], | |
| 350 | + limit: int = 10, | |
| 351 | + filters: Optional[str] = None, | |
| 352 | + output_fields: Optional[List[str]] = None, | |
| 353 | + ) -> List[Dict[str, Any]]: | |
| 354 | + """Search for similar image embeddings | |
| 355 | + | |
| 356 | + Args: | |
| 357 | + query_embedding: Query embedding vector | |
| 358 | + limit: Maximum number of results | |
| 359 | + filters: Filter expression (e.g., "product_id in [1, 2, 3]") | |
| 360 | + output_fields: List of fields to return | |
| 361 | + | |
| 362 | + Returns: | |
| 363 | + List of search results with fields: | |
| 364 | + - id: embedding ID | |
| 365 | + - distance: similarity distance | |
| 366 | + - entity: the matched entity with requested fields | |
| 367 | + """ | |
| 368 | + try: | |
| 369 | + if output_fields is None: | |
| 370 | + output_fields = [ | |
| 371 | + "id", | |
| 372 | + "image_path", | |
| 373 | + "productDisplayName", | |
| 374 | + "gender", | |
| 375 | + "masterCategory", | |
| 376 | + "subCategory", | |
| 377 | + "articleType", | |
| 378 | + "baseColour", | |
| 379 | + ] | |
| 380 | + | |
| 381 | +            # MilvusClient.search takes the boolean filter expression via | |
| 382 | +            # the `filter` argument; search_params is for index tuning only | |
| 383 | +            filter_expr = filters or "" | |
| 384 | + | |
| 385 | +            results = self.client.search( | |
| 386 | +                collection_name=self.image_collection_name, | |
| 387 | +                data=[query_embedding], | |
| 388 | +                limit=limit, | |
| 389 | +                output_fields=output_fields, | |
| 390 | +                filter=filter_expr, | |
| 391 | +            ) | |
| 392 | + | |
| 393 | + # Format results | |
| 394 | + formatted_results = [] | |
| 395 | + if results and len(results) > 0: | |
| 396 | + for hit in results[0]: | |
| 397 | + result = {"id": hit.get("id"), "distance": hit.get("distance")} | |
| 398 | + # Extract fields from entity | |
| 399 | + entity = hit.get("entity", {}) | |
| 400 | + for field in output_fields: | |
| 401 | + if field in entity: | |
| 402 | + result[field] = entity.get(field) | |
| 403 | + formatted_results.append(result) | |
| 404 | + | |
| 405 | + logger.debug(f"Found {len(formatted_results)} similar image embeddings") | |
| 406 | + return formatted_results | |
| 407 | + | |
| 408 | + except Exception as e: | |
| 409 | + logger.error(f"Failed to search similar images: {e}") | |
| 410 | + raise | |
| 411 | + | |
| 412 | + def get_collection_stats(self, collection_name: str) -> Dict[str, Any]: | |
| 413 | + """Get statistics for a collection | |
| 414 | + | |
| 415 | + Args: | |
| 416 | + collection_name: Name of the collection | |
| 417 | + | |
| 418 | + Returns: | |
| 419 | + Dictionary with collection statistics | |
| 420 | + """ | |
| 421 | + try: | |
| 422 | + stats = self.client.get_collection_stats(collection_name) | |
| 423 | + return { | |
| 424 | + "collection_name": collection_name, | |
| 425 | + "row_count": stats.get("row_count", 0), | |
| 426 | + } | |
| 427 | + except Exception as e: | |
| 428 | + logger.error(f"Failed to get collection stats: {e}") | |
| 429 | + return {"collection_name": collection_name, "row_count": 0} | |
| 430 | + | |
| 431 | + def delete_by_ids(self, collection_name: str, ids: List[int]) -> int: | |
| 432 | + """Delete embeddings by IDs | |
| 433 | + | |
| 434 | + Args: | |
| 435 | + collection_name: Name of the collection | |
| 436 | + ids: List of IDs to delete | |
| 437 | + | |
| 438 | + Returns: | |
| 439 | + Number of deleted embeddings | |
| 440 | + """ | |
| 441 | + if not ids: | |
| 442 | + return 0 | |
| 443 | + | |
| 444 | + try: | |
| 445 | + self.client.delete( | |
| 446 | + collection_name=collection_name, | |
| 447 | + ids=ids, | |
| 448 | + ) | |
| 449 | + logger.info(f"Deleted {len(ids)} embeddings from {collection_name}") | |
| 450 | + return len(ids) | |
| 451 | + except Exception as e: | |
| 452 | + logger.error(f"Failed to delete embeddings: {e}") | |
| 453 | + raise | |
| 454 | + | |
| 455 | + def clear_collection(self, collection_name: str) -> None: | |
| 456 | + """Clear all data from a collection | |
| 457 | + | |
| 458 | + Args: | |
| 459 | + collection_name: Name of the collection | |
| 460 | + """ | |
| 461 | + try: | |
| 462 | + if self.client.has_collection(collection_name): | |
| 463 | + self.client.drop_collection(collection_name) | |
| 464 | + logger.info(f"Dropped collection: {collection_name}") | |
| 465 | + except Exception as e: | |
| 466 | + logger.error(f"Failed to clear collection: {e}") | |
| 467 | + raise | |
| 468 | + | |
| 469 | + | |
| 470 | +# Global instance | |
| 471 | +_milvus_service: Optional[MilvusService] = None | |
| 472 | + | |
| 473 | + | |
| 474 | +def get_milvus_service() -> MilvusService: | |
| 475 | + """Get or create the global Milvus service instance""" | |
| 476 | + global _milvus_service | |
| 477 | + if _milvus_service is None: | |
| 478 | + _milvus_service = MilvusService() | |
| 479 | + _milvus_service.connect() | |
| 480 | + return _milvus_service | ... | ... |
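The result-formatting loop shared by `search_similar_text` and `search_similar_images` can be exercised on its own. The hit shape below mirrors the `{"id", "distance", "entity"}` dictionaries the service iterates over, but the sample values are made up:

```python
def format_hits(hits, output_fields):
    """Flatten Milvus-style hits into flat dicts, keeping only requested fields."""
    formatted = []
    for hit in hits:
        result = {"id": hit.get("id"), "distance": hit.get("distance")}
        entity = hit.get("entity", {})
        for field in output_fields:
            if field in entity:          # missing fields are simply skipped
                result[field] = entity[field]
        formatted.append(result)
    return formatted

hits = [
    {"id": 1, "distance": 0.91, "entity": {"text": "red dress", "baseColour": "Red"}},
    {"id": 2, "distance": 0.85, "entity": {"text": "blue shirt"}},
]
rows = format_hits(hits, ["text", "baseColour"])
assert rows[0] == {"id": 1, "distance": 0.91, "text": "red dress", "baseColour": "Red"}
assert "baseColour" not in rows[1]   # absent metadata stays absent
```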
| 1 | +++ a/app/tools/__init__.py | |
| ... | ... | @@ -0,0 +1,17 @@ |
| 1 | +""" | |
| 2 | +LangChain Tools for Product Search and Discovery | |
| 3 | +""" | |
| 4 | + | |
| 5 | +from app.tools.search_tools import ( | |
| 6 | + analyze_image_style, | |
| 7 | + get_all_tools, | |
| 8 | + search_by_image, | |
| 9 | + search_products, | |
| 10 | +) | |
| 11 | + | |
| 12 | +__all__ = [ | |
| 13 | + "search_products", | |
| 14 | + "search_by_image", | |
| 15 | + "analyze_image_style", | |
| 16 | + "get_all_tools", | |
| 17 | +] | ... | ... |
| 1 | +++ a/app/tools/search_tools.py | |
| ... | ... | @@ -0,0 +1,294 @@ |
| 1 | +""" | |
| 2 | +Search Tools for Product Discovery | |
| 3 | +Provides text-based, image-based, and VLM reasoning capabilities | |
| 4 | +""" | |
| 5 | + | |
| 6 | +import base64 | |
| 7 | +import logging | |
| 8 | +from pathlib import Path | |
| 9 | +from typing import Optional | |
| 10 | + | |
| 11 | +from langchain_core.tools import tool | |
| 12 | +from openai import OpenAI | |
| 13 | + | |
| 14 | +from app.config import settings | |
| 15 | +from app.services.embedding_service import EmbeddingService | |
| 16 | +from app.services.milvus_service import MilvusService | |
| 17 | + | |
| 18 | +logger = logging.getLogger(__name__) | |
| 19 | + | |
| 20 | +# Initialize services as singletons | |
| 21 | +_embedding_service: Optional[EmbeddingService] = None | |
| 22 | +_milvus_service: Optional[MilvusService] = None | |
| 23 | +_openai_client: Optional[OpenAI] = None | |
| 24 | + | |
| 25 | + | |
| 26 | +def get_embedding_service() -> EmbeddingService: | |
| 27 | + global _embedding_service | |
| 28 | + if _embedding_service is None: | |
| 29 | + _embedding_service = EmbeddingService() | |
| 30 | + return _embedding_service | |
| 31 | + | |
| 32 | + | |
| 33 | +def get_milvus_service() -> MilvusService: | |
| 34 | + global _milvus_service | |
| 35 | + if _milvus_service is None: | |
| 36 | + _milvus_service = MilvusService() | |
| 37 | + _milvus_service.connect() | |
| 38 | + return _milvus_service | |
| 39 | + | |
| 40 | + | |
| 41 | +def get_openai_client() -> OpenAI: | |
| 42 | + global _openai_client | |
| 43 | + if _openai_client is None: | |
| 44 | + _openai_client = OpenAI(api_key=settings.openai_api_key) | |
| 45 | + return _openai_client | |
| 46 | + | |
| 47 | + | |
| 48 | +@tool | |
| 49 | +def search_products(query: str, limit: int = 5) -> str: | |
| 50 | + """Search for fashion products using natural language descriptions. | |
| 51 | + | |
| 52 | + Use when users describe what they want: | |
| 53 | + - "Find me red summer dresses" | |
| 54 | + - "Show me blue running shoes" | |
| 55 | + - "I want casual shirts for men" | |
| 56 | + | |
| 57 | + Args: | |
| 58 | + query: Natural language product description | |
| 59 | + limit: Maximum number of results (1-20) | |
| 60 | + | |
| 61 | + Returns: | |
| 62 | + Formatted string with product information | |
| 63 | + """ | |
| 64 | + try: | |
| 65 | + logger.info(f"Searching products: '{query}', limit: {limit}") | |
| 66 | + | |
| 67 | + embedding_service = get_embedding_service() | |
| 68 | + milvus_service = get_milvus_service() | |
| 69 | + | |
| 70 | + if not milvus_service.is_connected(): | |
| 71 | + milvus_service.connect() | |
| 72 | + | |
| 73 | + query_embedding = embedding_service.get_text_embedding(query) | |
| 74 | + | |
| 75 | + results = milvus_service.search_similar_text( | |
| 76 | + query_embedding=query_embedding, | |
| 77 | + limit=min(limit, 20), | |
| 78 | + filters=None, | |
| 79 | + output_fields=[ | |
| 80 | + "id", | |
| 81 | + "productDisplayName", | |
| 82 | + "gender", | |
| 83 | + "masterCategory", | |
| 84 | + "subCategory", | |
| 85 | + "articleType", | |
| 86 | + "baseColour", | |
| 87 | + "season", | |
| 88 | + "usage", | |
| 89 | + ], | |
| 90 | + ) | |
| 91 | + | |
| 92 | + if not results: | |
| 93 | + return "No products found matching your search." | |
| 94 | + | |
| 95 | + output = f"Found {len(results)} product(s):\n\n" | |
| 96 | + | |
| 97 | + for idx, product in enumerate(results, 1): | |
| 98 | + output += f"{idx}. {product.get('productDisplayName', 'Unknown Product')}\n" | |
| 99 | + output += f" ID: {product.get('id', 'N/A')}\n" | |
| 100 | + output += f" Category: {product.get('masterCategory', 'N/A')} > {product.get('subCategory', 'N/A')} > {product.get('articleType', 'N/A')}\n" | |
| 101 | + output += f" Color: {product.get('baseColour', 'N/A')}\n" | |
| 102 | + output += f" Gender: {product.get('gender', 'N/A')}\n" | |
| 103 | + | |
| 104 | + if product.get("season"): | |
| 105 | + output += f" Season: {product.get('season')}\n" | |
| 106 | + if product.get("usage"): | |
| 107 | + output += f" Usage: {product.get('usage')}\n" | |
| 108 | + | |
| 109 | + if "distance" in product: | |
| 110 | +                similarity = product["distance"]  # COSINE metric: score is already a similarity, higher is better | |
| 111 | + output += f" Relevance: {similarity:.2%}\n" | |
| 112 | + | |
| 113 | + output += "\n" | |
| 114 | + | |
| 115 | + return output.strip() | |
| 116 | + | |
| 117 | + except Exception as e: | |
| 118 | + logger.error(f"Error searching products: {e}", exc_info=True) | |
| 119 | + return f"Error searching products: {str(e)}" | |
| 120 | + | |
| 121 | + | |
| 122 | +@tool | |
| 123 | +def search_by_image(image_path: str, limit: int = 5) -> str: | |
| 124 | + """Find similar fashion products using an image. | |
| 125 | + | |
| 126 | + Use when users want visually similar items: | |
| 127 | + - User uploads an image and asks "find similar items" | |
| 128 | + - "Show me products that look like this" | |
| 129 | + | |
| 130 | + Args: | |
| 131 | + image_path: Path to the image file | |
| 132 | + limit: Maximum number of results (1-20) | |
| 133 | + | |
| 134 | + Returns: | |
| 135 | + Formatted string with similar products | |
| 136 | + """ | |
| 137 | + try: | |
| 138 | + logger.info(f"Image search: '{image_path}', limit: {limit}") | |
| 139 | + | |
| 140 | + img_path = Path(image_path) | |
| 141 | + if not img_path.exists(): | |
| 142 | + return f"Error: Image file not found at '{image_path}'" | |
| 143 | + | |
| 144 | + embedding_service = get_embedding_service() | |
| 145 | + milvus_service = get_milvus_service() | |
| 146 | + | |
| 147 | + if not milvus_service.is_connected(): | |
| 148 | + milvus_service.connect() | |
| 149 | + | |
| 150 | + if ( | |
| 151 | + not hasattr(embedding_service, "clip_client") | |
| 152 | + or embedding_service.clip_client is None | |
| 153 | + ): | |
| 154 | + embedding_service.connect_clip() | |
| 155 | + | |
| 156 | + image_embedding = embedding_service.get_image_embedding(image_path) | |
| 157 | + | |
| 158 | + if image_embedding is None: | |
| 159 | + return "Error: Failed to generate embedding for image" | |
| 160 | + | |
| 161 | + results = milvus_service.search_similar_images( | |
| 162 | + query_embedding=image_embedding, | |
| 163 | + limit=min(limit + 1, 21), | |
| 164 | + filters=None, | |
| 165 | + output_fields=[ | |
| 166 | + "id", | |
| 167 | + "image_path", | |
| 168 | + "productDisplayName", | |
| 169 | + "gender", | |
| 170 | + "masterCategory", | |
| 171 | + "subCategory", | |
| 172 | + "articleType", | |
| 173 | + "baseColour", | |
| 174 | + "season", | |
| 175 | + "usage", | |
| 176 | + ], | |
| 177 | + ) | |
| 178 | + | |
| 179 | + if not results: | |
| 180 | + return "No similar products found." | |
| 181 | + | |
| 182 | + # Filter out the query image itself | |
| 183 | + query_id = img_path.stem | |
| 184 | + filtered_results = [] | |
| 185 | + for result in results: | |
| 186 | + result_path = result.get("image_path", "") | |
| 187 | + if Path(result_path).stem != query_id: | |
| 188 | + filtered_results.append(result) | |
| 189 | + if len(filtered_results) >= limit: | |
| 190 | + break | |
| 191 | + | |
| 192 | + if not filtered_results: | |
| 193 | + return "No similar products found." | |
| 194 | + | |
| 195 | + output = f"Found {len(filtered_results)} visually similar product(s):\n\n" | |
| 196 | + | |
| 197 | + for idx, product in enumerate(filtered_results, 1): | |
| 198 | + output += f"{idx}. {product.get('productDisplayName', 'Unknown Product')}\n" | |
| 199 | + output += f" ID: {product.get('id', 'N/A')}\n" | |
| 200 | + output += f" Category: {product.get('masterCategory', 'N/A')} > {product.get('subCategory', 'N/A')} > {product.get('articleType', 'N/A')}\n" | |
| 201 | + output += f" Color: {product.get('baseColour', 'N/A')}\n" | |
| 202 | + output += f" Gender: {product.get('gender', 'N/A')}\n" | |
| 203 | + | |
| 204 | + if product.get("season"): | |
| 205 | + output += f" Season: {product.get('season')}\n" | |
| 206 | + if product.get("usage"): | |
| 207 | + output += f" Usage: {product.get('usage')}\n" | |
| 208 | + | |
| 209 | + if "distance" in product: | |
| 210 | +                similarity = product["distance"]  # COSINE metric: score is already a similarity, higher is better | |
| 211 | + output += f" Visual Similarity: {similarity:.2%}\n" | |
| 212 | + | |
| 213 | + output += "\n" | |
| 214 | + | |
| 215 | + return output.strip() | |
| 216 | + | |
| 217 | + except Exception as e: | |
| 218 | + logger.error(f"Error in image search: {e}", exc_info=True) | |
| 219 | + return f"Error searching by image: {str(e)}" | |
| 220 | + | |
| 221 | + | |
| 222 | +@tool | |
| 223 | +def analyze_image_style(image_path: str) -> str: | |
| 224 | + """Analyze a fashion product image using AI vision to extract detailed style information. | |
| 225 | + | |
| 226 | + Use when you need to understand style/attributes from an image: | |
| 227 | + - Understand the style, color, pattern of a product | |
| 228 | + - Extract attributes like "casual", "formal", "vintage" | |
| 229 | + - Get detailed descriptions for subsequent searches | |
| 230 | + | |
| 231 | + Args: | |
| 232 | + image_path: Path to the image file | |
| 233 | + | |
| 234 | + Returns: | |
| 235 | + Detailed text description of the product's visual attributes | |
| 236 | + """ | |
| 237 | + try: | |
| 238 | + logger.info(f"Analyzing image with VLM: '{image_path}'") | |
| 239 | + | |
| 240 | + img_path = Path(image_path) | |
| 241 | + if not img_path.exists(): | |
| 242 | + return f"Error: Image file not found at '{image_path}'" | |
| 243 | + | |
| 244 | + with open(img_path, "rb") as image_file: | |
| 245 | + image_data = base64.b64encode(image_file.read()).decode("utf-8") | |
| 246 | + | |
| 247 | + prompt = """Analyze this fashion product image and provide a detailed description. | |
| 248 | + | |
| 249 | +Include: | |
| 250 | +- Product type (e.g., shirt, dress, shoes, pants, bag) | |
| 251 | +- Primary colors | |
| 252 | +- Style/design (e.g., casual, formal, sporty, vintage, modern) | |
| 253 | +- Pattern or texture (e.g., plain, striped, checked, floral) | |
| 254 | +- Key features (e.g., collar type, sleeve length, fit) | |
| 255 | +- Material appearance (if obvious, e.g., denim, cotton, leather) | |
| 256 | +- Suitable occasion (e.g., office wear, party, casual, sports) | |
| 257 | + | |
| 258 | +Provide a comprehensive yet concise description (3-4 sentences).""" | |
| 259 | + | |
| 260 | + client = get_openai_client() | |
| 261 | + response = client.chat.completions.create( | |
| 262 | + model="gpt-4o-mini", | |
| 263 | + messages=[ | |
| 264 | + { | |
| 265 | + "role": "user", | |
| 266 | + "content": [ | |
| 267 | + {"type": "text", "text": prompt}, | |
| 268 | + { | |
| 269 | + "type": "image_url", | |
| 270 | + "image_url": { | |
| 271 | + "url": f"data:image/jpeg;base64,{image_data}", | |
| 272 | + "detail": "high", | |
| 273 | + }, | |
| 274 | + }, | |
| 275 | + ], | |
| 276 | + } | |
| 277 | + ], | |
| 278 | + max_tokens=500, | |
| 279 | + temperature=0.3, | |
| 280 | + ) | |
| 281 | + | |
| 282 | + analysis = response.choices[0].message.content.strip() | |
| 283 | + logger.info("VLM analysis completed") | |
| 284 | + | |
| 285 | + return analysis | |
| 286 | + | |
| 287 | + except Exception as e: | |
| 288 | + logger.error(f"Error analyzing image: {e}", exc_info=True) | |
| 289 | + return f"Error analyzing image: {str(e)}" | |
| 290 | + | |
| 291 | + | |
| 292 | +def get_all_tools(): | |
| 293 | + """Get all available tools for the agent""" | |
| 294 | + return [search_products, search_by_image, analyze_image_style] | ... | ... |
No preview for this file type
| 1 | +++ a/docker-compose.yml | |
| ... | ... | @@ -0,0 +1,76 @@ |
| 1 | +version: '3.5' | |
| 2 | + | |
| 3 | +services: | |
| 4 | + etcd: | |
| 5 | + container_name: milvus-etcd | |
| 6 | + image: quay.io/coreos/etcd:v3.5.5 | |
| 7 | + environment: | |
| 8 | + - ETCD_AUTO_COMPACTION_MODE=revision | |
| 9 | + - ETCD_AUTO_COMPACTION_RETENTION=1000 | |
| 10 | + - ETCD_QUOTA_BACKEND_BYTES=4294967296 | |
| 11 | + - ETCD_SNAPSHOT_COUNT=50000 | |
| 12 | + volumes: | |
| 13 | + - ./volumes/etcd:/etcd | |
| 14 | + command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd | |
| 15 | + healthcheck: | |
| 16 | + test: ["CMD", "etcdctl", "endpoint", "health"] | |
| 17 | + interval: 30s | |
| 18 | + timeout: 20s | |
| 19 | + retries: 3 | |
| 20 | + | |
| 21 | + minio: | |
| 22 | + container_name: milvus-minio | |
| 23 | + image: minio/minio:RELEASE.2023-03-20T20-16-18Z | |
| 24 | + environment: | |
| 25 | + MINIO_ACCESS_KEY: minioadmin | |
| 26 | + MINIO_SECRET_KEY: minioadmin | |
| 27 | + ports: | |
| 28 | + - "9001:9001" | |
| 29 | + - "9000:9000" | |
| 30 | + volumes: | |
| 31 | + - ./volumes/minio:/minio_data | |
| 32 | + command: minio server /minio_data --console-address ":9001" | |
| 33 | + healthcheck: | |
| 34 | + test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"] | |
| 35 | + interval: 30s | |
| 36 | + timeout: 20s | |
| 37 | + retries: 3 | |
| 38 | + | |
| 39 | + standalone: | |
| 40 | + container_name: milvus-standalone | |
| 41 | + image: milvusdb/milvus:v2.4.0 | |
| 42 | + command: ["milvus", "run", "standalone"] | |
| 43 | + security_opt: | |
| 44 | + - seccomp:unconfined | |
| 45 | + environment: | |
| 46 | + ETCD_ENDPOINTS: etcd:2379 | |
| 47 | + MINIO_ADDRESS: minio:9000 | |
| 48 | + volumes: | |
| 49 | + - ./volumes/milvus:/var/lib/milvus | |
| 50 | + healthcheck: | |
| 51 | + test: ["CMD", "curl", "-f", "http://localhost:9091/healthz"] | |
| 52 | + interval: 30s | |
| 53 | + start_period: 90s | |
| 54 | + timeout: 20s | |
| 55 | + retries: 3 | |
| 56 | + ports: | |
| 57 | + - "19530:19530" | |
| 58 | + - "9091:9091" | |
| 59 | + depends_on: | |
| 60 | + - "etcd" | |
| 61 | + - "minio" | |
| 62 | + | |
| 63 | + attu: | |
| 64 | + container_name: milvus-attu | |
| 65 | + image: zilliz/attu:v2.4 | |
| 66 | + environment: | |
| 67 | + MILVUS_URL: milvus-standalone:19530 | |
| 68 | + ports: | |
| 69 | + - "8000:3000" | |
| 70 | + depends_on: | |
| 71 | + - "standalone" | |
| 72 | + | |
| 73 | +networks: | |
| 74 | + default: | |
| 75 | + name: milvus | |
| 76 | + | ... | ... |
| 1 | +++ a/docs/DEPLOY_CENTOS8.md | |
| ... | ... | @@ -0,0 +1,216 @@ |
| 1 | +# OmniShopAgent CentOS 8 Deployment Guide | |
| 2 | + | |
| 3 | +## 1. Requirements | |
| 4 | + | |
| 5 | +| Component | Requirement | | |
| 6 | +|------|------| | |
| 7 | +| OS | CentOS 8.x | | |
| 8 | +| Python | 3.12+ (LangChain 1.x requires 3.10+) | | |
| 9 | +| Memory | 8 GB+ recommended (Milvus and CLIP are memory-intensive) | | |
| 10 | +| Disk | 20 GB+ recommended (including the dataset) | | |
| 11 | + | |
| 12 | +## 2. Quick Deployment | |
| 13 | + | |
| 14 | +### 2.1 One-Step Environment Setup (Recommended) | |
| 15 | + | |
| 16 | +```bash | |
| 17 | +cd /path/to/shop_agent | |
| 18 | +chmod +x scripts/*.sh | |
| 19 | +./scripts/setup_env_centos8.sh | |
| 20 | +``` | |
| 21 | + | |
| 22 | +This script will: | |
| 23 | +- Install system dependencies (gcc, openssl-devel, etc.) | |
| 24 | +- Install Docker (for Milvus) | |
| 25 | +- Install Python 3.12 (via conda or a source build) | |
| 26 | +- Create a virtual environment and install requirements.txt | |
| 27 | + | |
| 28 | +### 2.2 Manual Deployment (Step by Step) | |
| 29 | + | |
| 30 | +#### Step 1: Install System Dependencies | |
| 31 | + | |
| 32 | +```bash | |
| 33 | +sudo dnf install -y gcc gcc-c++ make openssl-devel bzip2-devel \ | |
| 34 | + libffi-devel sqlite-devel xz-devel zlib-devel curl wget git | |
| 35 | +``` | |
| 36 | + | |
| 37 | +#### Step 2: Install Python 3.12 | |
| 38 | + | |
| 39 | +**Option A: Miniconda (recommended)** | |
| 40 | + | |
| 41 | +```bash | |
| 42 | +wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh | |
| 43 | +bash Miniconda3-latest-Linux-x86_64.sh | |
| 44 | +# After completing the installer prompts: | |
| 45 | +conda create -n shop_agent python=3.12 | |
| 46 | +conda activate shop_agent | |
| 47 | +``` | |
| 48 | + | |
| 49 | +**Option B: Build from Source** | |
| 50 | + | |
| 51 | +```bash | |
| 52 | +sudo dnf groupinstall -y 'Development Tools' | |
| 53 | +cd /tmp | |
| 54 | +wget https://www.python.org/ftp/python/3.12.0/Python-3.12.0.tgz | |
| 55 | +tar xzf Python-3.12.0.tgz | |
| 56 | +cd Python-3.12.0 | |
| 57 | +./configure --enable-optimizations --prefix=/usr/local | |
| 58 | +make -j $(nproc) | |
| 59 | +sudo make altinstall | |
| 60 | +``` | |
| 61 | + | |
| 62 | +#### Step 3: Install Docker | |
| 63 | + | |
| 64 | +```bash | |
| 65 | +sudo dnf config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.repo | |
| 66 | +sudo dnf install -y docker-ce docker-ce-cli containerd.io docker-compose-plugin | |
| 67 | +sudo systemctl enable docker && sudo systemctl start docker | |
| 68 | +sudo usermod -aG docker $USER | |
| 69 | +# Run `newgrp docker` or log out and back in | |
| 70 | +``` | |
| 71 | + | |
| 72 | +#### Step 4: Create the Virtual Environment and Install Dependencies | |
| 73 | + | |
| 74 | +```bash | |
| 75 | +cd /path/to/shop_agent | |
| 76 | +python3.12 -m venv venv | |
| 77 | +source venv/bin/activate | |
| 78 | +pip install -U pip | |
| 79 | +pip install -r requirements.txt | |
| 80 | +``` | |
| 81 | + | |
| 82 | +#### Step 5: Configure Environment Variables | |
| 83 | + | |
| 84 | +```bash | |
| 85 | +cp .env.example .env | |
| 86 | +# Edit .env and set at least: | |
| 87 | +# OPENAI_API_KEY=sk-xxx | |
| 88 | +# MILVUS_HOST=localhost | |
| 89 | +# MILVUS_PORT=19530 | |
| 90 | +# CLIP_SERVER_URL=grpc://localhost:51000 | |
| 91 | +``` | |
| 92 | + | |
| 93 | +## 3. Data Preparation | |
| 94 | + | |
| 95 | +### 3.1 Download the Dataset | |
| 96 | + | |
| 97 | +```bash | |
| 98 | +# Requires Kaggle API credentials in ~/.kaggle/kaggle.json | |
| 99 | +python scripts/download_dataset.py | |
| 100 | +``` | |
| 101 | + | |
| 102 | +### 3.2 Start Milvus and Index the Data | |
| 103 | + | |
| 104 | +```bash | |
| 105 | +# Start Milvus | |
| 106 | +./scripts/run_milvus.sh | |
| 107 | + | |
| 108 | +# Once Milvus is ready, build the index | |
| 109 | +python scripts/index_data.py | |
| 110 | +``` | |
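If this step is scripted, a small readiness poll keeps `index_data.py` from running before Milvus is up. A minimal sketch, assuming Milvus standalone's HTTP health endpoint on port 9091 (adjust host/port for your setup; this helper is not part of the repository):

```python
import time
import urllib.request


def wait_until_ready(probe, timeout_s=120, interval_s=3):
    """Poll `probe` (a zero-arg callable returning True/False) until it
    succeeds or `timeout_s` elapses. Returns True on success."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            if probe():
                return True
        except Exception:
            pass  # service not accepting connections yet
        time.sleep(interval_s)
    return False


def milvus_healthy(url="http://localhost:9091/healthz"):
    # Milvus standalone exposes a plain-HTTP health endpoint on 9091
    with urllib.request.urlopen(url, timeout=2) as resp:
        return resp.status == 200


if __name__ == "__main__":
    if not wait_until_ready(milvus_healthy):
        raise SystemExit("Milvus did not become healthy in time")
```

The probe is injected as a callable so the same helper can wait on CLIP (a TCP connect to port 51000) or any other dependency.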
| 111 | + | |
| 112 | +## 4. Starting the Services | |
| 113 | + | |
| 114 | +### 4.1 Startup Scripts | |
| 115 | + | |
| 116 | +| Script | Purpose | | |
| 117 | +|------|------| | |
| 118 | +| `start.sh` | Main launcher: starts Milvus + Streamlit | | |
| 119 | +| `stop.sh` | Stops all services | | |
| 120 | +| `run_milvus.sh` | Starts Milvus only | | |
| 121 | +| `run_clip.sh` | Starts CLIP only (required for image search) | | |
| 122 | +| `check_services.sh` | Health check | | |
| 123 | + | |
| 124 | +### 4.2 Launch the Application | |
| 125 | + | |
| 126 | +```bash | |
| 127 | +# Option 1: use start.sh (recommended) | |
| 128 | +./scripts/start.sh | |
| 129 | + | |
| 130 | +# Option 2: start each service separately | |
| 131 | +# Terminal 1: Milvus | |
| 132 | +./scripts/run_milvus.sh | |
| 133 | + | |
| 134 | +# Terminal 2: CLIP (required for image search) | |
| 135 | +./scripts/run_clip.sh | |
| 136 | + | |
| 137 | +# Terminal 3: Streamlit | |
| 138 | +source venv/bin/activate | |
| 139 | +streamlit run app.py --server.port=8501 --server.address=0.0.0.0 | |
| 140 | +``` | |
| 141 | + | |
| 142 | +### 4.3 Access URLs | |
| 143 | + | |
| 144 | +- **Streamlit app**: http://<server-ip>:8501 | |
| 145 | +- **Milvus Attu admin UI**: http://<server-ip>:8000 | |
| 146 | + | |
| 147 | +## 5. Production Deployment Recommendations | |
| 148 | + | |
| 149 | +### 5.1 Manage Streamlit with systemd | |
| 150 | + | |
| 151 | +Create `/etc/systemd/system/omishop-agent.service`: | |
| 152 | + | |
| 153 | +```ini | |
| 154 | +[Unit] | |
| 155 | +Description=OmniShopAgent Streamlit App | |
| 156 | +After=network.target docker.service | |
| 157 | + | |
| 158 | +[Service] | |
| 159 | +Type=simple | |
| 160 | +User=your_user | |
| 161 | +WorkingDirectory=/path/to/shop_agent | |
| 162 | +Environment="PATH=/path/to/shop_agent/venv/bin" | |
| 163 | +ExecStart=/path/to/shop_agent/venv/bin/streamlit run app.py --server.port=8501 --server.address=0.0.0.0 | |
| 164 | +Restart=on-failure | |
| 165 | + | |
| 166 | +[Install] | |
| 167 | +WantedBy=multi-user.target | |
| 168 | +``` | |
| 169 | + | |
| 170 | +```bash | |
| 171 | +sudo systemctl daemon-reload | |
| 172 | +sudo systemctl enable omishop-agent | |
| 173 | +sudo systemctl start omishop-agent | |
| 174 | +``` | |
| 175 | + | |
| 176 | +### 5.2 Nginx Reverse Proxy (Optional) | |
| 177 | + | |
| 178 | +```nginx | |
| 179 | +server { | |
| 180 | + listen 80; | |
| 181 | + server_name your-domain.com; | |
| 182 | + location / { | |
| 183 | + proxy_pass http://127.0.0.1:8501; | |
| 184 | + proxy_http_version 1.1; | |
| 185 | + proxy_set_header Upgrade $http_upgrade; | |
| 186 | + proxy_set_header Connection "upgrade"; | |
| 187 | + proxy_set_header Host $host; | |
| 188 | + proxy_set_header X-Real-IP $remote_addr; | |
| 189 | + } | |
| 190 | +} | |
| 191 | +``` | |
| 192 | + | |
| 193 | +### 5.3 Firewall | |
| 194 | + | |
| 195 | +```bash | |
| 196 | +sudo firewall-cmd --permanent --add-port=8501/tcp | |
| 197 | +sudo firewall-cmd --permanent --add-port=19530/tcp | |
| 198 | +sudo firewall-cmd --reload | |
| 199 | +``` | |
| 200 | + | |
| 201 | +## 6. FAQ | |
| 202 | + | |
| 203 | +### Q: Python 3.12 fails to compile? | |
| 204 | +A: Make sure `openssl-devel` and `libffi-devel` are installed, or use Miniconda instead. | |
| 205 | + | |
| 206 | +### Q: Docker permission denied? | |
| 207 | +A: Run `sudo usermod -aG docker $USER`, then log out and back in. | |
| 208 | + | |
| 209 | +### Q: Milvus startup times out? | |
| 210 | +A: The first start has to pull images and can be slow. Check with `docker compose logs -f standalone`. | |
| 211 | + | |
| 212 | +### Q: Image search unavailable? | |
| 213 | +A: The CLIP service must be started separately: `./scripts/run_clip.sh`. | |
| 214 | + | |
| 215 | +### Q: How do I run a health check? | |
| 216 | +A: Run `./scripts/check_services.sh` to see the status of each component. | |
| 1 | +++ a/docs/LANGCHAIN_1.0_MIGRATION.md | |
| ... | ... | @@ -0,0 +1,77 @@ |
| 1 | +# LangChain 1.0 Upgrade Notes | |
| 2 | + | |
| 3 | +## 1. Overview | |
| 4 | + | |
| 5 | +This project has been upgraded from LangChain 0.3 to LangChain 1.x, with LangGraph upgraded to 1.x in step. The upgraded stack is compatible with Python 3.12. | |
| 6 | + | |
| 7 | +## 2. Dependency Changes | |
| 8 | + | |
| 9 | +| Package | Before | After | | |
| 10 | +|----|--------|--------| | |
| 11 | +| langchain | >=0.3.0 | >=1.0.0 | | |
| 12 | +| langchain-core | (transitive dependency) | >=0.3.0 | | |
| 13 | +| langchain-openai | >=0.2.0 | >=0.2.0 | | |
| 14 | +| langgraph | >=0.2.74 | >=1.0.0 | | |
| 15 | +| langchain-community | >=0.4.0 | **removed** (unused by the project) | | |
| 16 | + | |
| 17 | +## 3. Code Changes | |
| 18 | + | |
| 19 | +### 3.1 What Stays the Same | |
| 20 | + | |
| 21 | +The project uses a **custom StateGraph** architecture (not `create_react_agent`); the following imports remain compatible with LangChain/LangGraph 1.x: | |
| 22 | + | |
| 23 | +```python | |
| 24 | +from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage | |
| 25 | +from langchain_openai import ChatOpenAI | |
| 26 | +from langgraph.checkpoint.memory import MemorySaver | |
| 27 | +from langgraph.graph import END, START, StateGraph | |
| 28 | +from langgraph.graph.message import add_messages | |
| 29 | +from langgraph.prebuilt import ToolNode | |
| 30 | +from langchain_core.tools import tool | |
| 31 | +``` | |
| 32 | + | |
| 33 | +### 3.2 Changes Already Applied | |
| 34 | + | |
| 35 | +1. **Message content extraction**: LangChain 1.0 introduces `content_blocks`, so `content` may be either a string or a list of multimodal blocks. A new `_extract_message_text()` helper handles both formats uniformly. | |
| 36 | +2. **Dependency cleanup**: removed the unused `langchain-community` package to reduce dependency conflicts. | |
| 37 | + | |
| 38 | +### 3.3 Major LangChain 1.0 Changes (for Reference) | |
| 39 | + | |
| 40 | +- **Leaner package namespaces**: core functionality lives in `langchain-core`; the main `langchain` package focuses on agent building | |
| 41 | +- **create_agent**: for a future migration to `langchain.agents.create_agent`, see `docs/Skills实现方案-LangChain1.0.md` | |
| 42 | +- **langchain-classic**: legacy chains, retrievers, etc. have moved to `langchain-classic`; this project does not use them | |
| 43 | + | |
| 44 | +## 4. Environment Requirements | |
| 45 | + | |
| 46 | +- **Python**: 3.12+ (**LangChain 1.x requires Python 3.10+**; 3.9 and earlier are not supported) | |
| 47 | +- If the system's default Python is too old, use a virtual environment: | |
| 48 | + | |
| 49 | +```bash | |
| 50 | +# Option 1: conda (recommended; see scripts/setup_conda_env.sh in the project root) | |
| 51 | +conda create -n shop_agent python=3.12 | |
| 52 | +conda activate shop_agent | |
| 53 | +pip install -r requirements.txt | |
| 54 | + | |
| 55 | +# Option 2: pyenv | |
| 56 | +pyenv install 3.12 | |
| 57 | +pyenv local 3.12 | |
| 58 | +pip install -r requirements.txt | |
| 59 | + | |
| 60 | +# Option 3: venv (requires python3.12 installed on the system) | |
| 61 | +python3.12 -m venv venv | |
| 62 | +source venv/bin/activate # Linux/Mac | |
| 63 | +pip install -r requirements.txt | |
| 64 | +``` | |
| 65 | + | |
| 66 | +## 5. Verification | |
| 67 | + | |
| 68 | +```bash | |
| 69 | +# Verify imports | |
| 70 | +python -c " | |
| 71 | +from langchain_core.messages import HumanMessage | |
| 72 | +from langchain_openai import ChatOpenAI | |
| 73 | +from langgraph.graph import StateGraph | |
| 74 | +from langgraph.prebuilt import ToolNode | |
| 75 | +print('LangChain 1.x dependencies loaded successfully') | |
| 76 | +" | |
| 77 | +``` | ... | ... |
| 1 | +++ a/docs/Skills实现方案-LangChain1.0.md | |
| ... | ... | @@ -0,0 +1,318 @@ |
| 1 | +# Skills with Progressive Disclosure — Implementation Plan (LangChain 1.0+) | |
| 2 | + | |
| 3 | +## 1. Requirements | |
| 4 | + | |
| 5 | +Replace scattered tool calls with **Skills** that use **progressive disclosure**: | |
| 6 | +the agent sees only skill summaries in the system prompt and loads detailed skill content on demand, reducing token consumption and improving extensibility. | |
| 7 | + | |
| 8 | +| Skill | Identifier | Responsibility | | |
| 9 | +|------|----------|------| | |
| 10 | +| Find related products | lookup_related | Find similar or related products from text/images | | |
| 11 | +| Search products | search_products | Search products from a natural-language description | | |
| 12 | +| Check product | check_product | Verify whether a product meets the user's requirements | | |
| 13 | +| Result packaging | result_packaging | Format, sort, filter, and present results | | |
| 14 | +| After-sales | after_sales | Returns/exchanges, shipping, warranty, and other after-sales questions | | |
| 15 | + | |
| 16 | +--- | |
| 17 | + | |
| 18 | +## 2. Implementing Skills in LangChain 1.0 | |
| 19 | + | |
| 20 | +### 2.1 Two Implementation Routes | |
| 21 | + | |
| 22 | +| Approach | Best For | Dependencies | | |
| 23 | +|------|----------|------| | |
| 24 | +| **A: create_agent + custom skill middleware** | Business agents such as shopping assistants | `langchain>=1.0`, `langgraph>=1.0` | | |
| 25 | +| **B: Deep Agents + SKILL.md** | Filesystem-based, multi-skill directories | `deepagents` | | |
| 26 | + | |
| 27 | +**Approach A** is recommended for the shopping-assistant scenario: it integrates more easily with the existing Milvus, CLIP, and other services. | |
| 28 | + | |
| 29 | +### 2.2 Core Idea: Progressive Disclosure | |
| 30 | + | |
| 31 | +``` | |
| 32 | +User request → agent sees lightweight descriptions → picks the needed skill → load_skill → receives full instructions → runs tools → replies | |
| 33 | +``` | |
| 34 | + | |
| 35 | +- **At startup**: inject only skill names plus short descriptions (1–2 sentences) | |
| 36 | +- **On demand**: the agent calls `load_skill(skill_name)` to fetch the full instructions | |
| 37 | +- **Execution**: call the corresponding tools as the skill instructions direct | |
| 38 | + | |
| 39 | +--- | |
| 40 | + | |
| 41 | +## 3. Architecture | |
| 42 | + | |
| 43 | +### 3.1 Skill Definition | |
| 44 | + | |
| 45 | +```python | |
| 46 | +from typing import TypedDict | |
| 47 | + | |
| 48 | +class Skill(TypedDict): | |
| 49 | +    """A skill that can be progressively disclosed""" | |
| 50 | +    name: str         # unique identifier | |
| 51 | +    description: str  # 1-2 sentences, shown in the system prompt | |
| 52 | +    content: str      # full instructions, returned only by load_skill | |
| 53 | +``` | |
| 54 | + | |
| 55 | +### 3.2 The Five Skill Definitions | |
| 56 | + | |
| 57 | +```python | |
| 58 | +SKILLS: list[Skill] = [ | |
| 59 | +    { | |
| 60 | +        "name": "lookup_related", | |
| 61 | +        "description": "Find other products related to a given item; supports image-to-image search, text similarity, and same-category recommendations.", | |
| 62 | +        "content": """# Find Related Products | |
| 63 | + | |
| 64 | +## When to Use | |
| 65 | +- The user uploads an image asking for "something similar" | |
| 66 | +- The user says "something like this" or "pants that go with this" | |
| 67 | +- The user already owns an item and wants related styles | |
| 68 | + | |
| 69 | +## Steps | |
| 70 | +1. **With an image**: call `analyze_image_style` first to understand the style, then `search_by_image` or `search_products` | |
| 71 | +2. **Without an image**: use `search_products` with category + style + color | |
| 72 | +3. Use product IDs and categories from context for same-category recommendations | |
| 73 | + | |
| 74 | +## Available Tools | |
| 75 | +- `search_by_image(image_path, limit)`: image-to-image search | |
| 76 | +- `search_products(query, limit)`: text search | |
| 77 | +- `analyze_image_style(image_path)`: analyze an image's style""", | |
| 78 | +    }, | |
| 79 | +    { | |
| 80 | +        "name": "search_products", | |
| 81 | +        "description": "Search products from a natural-language description, such as 'red dress' or 'sneakers'.", | |
| 82 | +        "content": """# Search Products | |
| 83 | + | |
| 84 | +## When to Use | |
| 85 | +- The user describes in words what they want | |
| 86 | +- e.g. "a coat for winter", "a dress shirt", "running shoes" | |
| 87 | + | |
| 88 | +## Steps | |
| 89 | +1. Turn the user's description into a structured query (category + color + style + occasion) | |
| 90 | +2. Call `search_products(query, limit)`; limit defaults to 5-10 | |
| 91 | +3. If an image is provided, run `analyze_image_style` first to extract keywords, then search | |
| 92 | + | |
| 93 | +## Available Tools | |
| 94 | +- `search_products(query, limit)`: natural-language search""", | |
| 95 | +    }, | |
| 96 | +    { | |
| 97 | +        "name": "check_product", | |
| 98 | +        "description": "Check whether a product meets the user's requirements, such as size, material, occasion, or price range.", | |
| 99 | +        "content": """# Check Whether a Product Meets Requirements | |
| 100 | + | |
| 101 | +## When to Use | |
| 102 | +- The user asks "is this right for me?" or "do you have it in XX material?" | |
| 103 | +- The user states constraints: size, price, occasion, material | |
| 104 | + | |
| 105 | +## Steps | |
| 106 | +1. Extract the constraints from the conversation (size, material, occasion, price, etc.) | |
| 107 | +2. Filter the already-retrieved products or run a second search | |
| 108 | +3. Include the constraints in the query when calling `search_products` | |
| 109 | +4. In the reply, state clearly which products match and which do not | |
| 110 | + | |
| 111 | +## Notes | |
| 112 | +- With no dedicated tool, express constraints through the search_products query | |
| 113 | +- Product metadata (baseColour, season, usage, etc.) supports simple filtering""", | |
| 114 | +    }, | |
| 115 | +    { | |
| 116 | +        "name": "result_packaging", | |
| 117 | +        "description": "Format, sort, and filter search results and present them to the user.", | |
| 118 | +        "content": """# Result Packaging | |
| 119 | + | |
| 120 | +## When to Use | |
| 121 | +- A tool returned multiple products that need organizing before display | |
| 122 | +- The user asks to "sort by price" or "only the top 3" | |
| 123 | + | |
| 124 | +## Steps | |
| 125 | +1. Sort by relevance/similarity | |
| 126 | +2. Limit how many are shown (usually 3-5) | |
| 127 | +3. **Always present each product in the following format**: | |
| 128 | + | |
| 129 | +``` | |
| 130 | +1. [Product Name] | |
| 131 | +   ID: [Product ID Number] | |
| 132 | +   Category: [Category] | |
| 133 | +   Color: [Color] | |
| 134 | +   Gender: [Gender] | |
| 135 | +   Season: [Season] | |
| 136 | +   Usage: [Usage] | |
| 137 | +   Relevance: [XX%] | |
| 138 | +``` | |
| 139 | + | |
| 140 | +4. Never omit the ID field; the frontend uses it to display images""", | |
| 141 | +    }, | |
| 142 | +    { | |
| 143 | +        "name": "after_sales", | |
| 144 | +        "description": "Handle returns/exchanges, shipping, warranty, sizing advice, and other after-sales questions.", | |
| 145 | +        "content": """# After-Sales | |
| 146 | + | |
| 147 | +## When to Use | |
| 148 | +- Return/exchange policy, shipping fees, delivery windows | |
| 149 | +- Sizing advice, care instructions | |
| 150 | +- Warranty, customer-service contacts | |
| 151 | + | |
| 152 | +## Steps | |
| 153 | +1. These questions do not require the product-search tools | |
| 154 | +2. Answer according to the platform's standard after-sales policy | |
| 155 | +3. For a specific product, look up its details by product ID before answering | |
| 156 | +4. Route complex issues to customer service""", | |
| 157 | + }, | |
| 158 | +] | |
| 159 | +``` | |
| 160 | + | |
| 161 | +--- | |
| 162 | + | |
| 163 | +## 4. Core Implementation | |
| 164 | + | |
| 165 | +### 4.1 The load_skill Tool | |
| 166 | + | |
| 167 | +```python | |
| 168 | +from langchain.tools import tool | |
| 169 | + | |
| 170 | +@tool | |
| 171 | +def load_skill(skill_name: str) -> str: | |
| 172 | + """加载技能的完整内容到 Agent 上下文中。 | |
| 173 | + | |
| 174 | + 当需要处理特定类型请求时,调用此工具获取该技能的详细说明和操作步骤。 | |
| 175 | + | |
| 176 | + Args: | |
| 177 | + skill_name: 技能名称,可选值:lookup_related, search_products, check_product, result_packaging, after_sales | |
| 178 | + """ | |
| 179 | + for skill in SKILLS: | |
| 180 | + if skill["name"] == skill_name: | |
| 181 | + return f"Loaded skill: {skill_name}\n\n{skill['content']}" | |
| 182 | + | |
| 183 | + available = ", ".join(s["name"] for s in SKILLS) | |
| 184 | + return f"Skill '{skill_name}' not found. Available: {available}" | |
| 185 | +``` | |
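Because the lookup itself has no LangChain dependency, it can be sanity-checked standalone. A sketch with a stubbed skill list (the two entries below are placeholders, not the real skill contents):

```python
SKILLS = [
    {"name": "search_products", "description": "Search products.", "content": "# Search Products\n..."},
    {"name": "after_sales", "description": "After-sales help.", "content": "# After-Sales\n..."},
]


def load_skill(skill_name: str) -> str:
    # Same lookup logic as the @tool version: return the full content on a
    # hit, otherwise list the available skill names so the model can retry.
    for skill in SKILLS:
        if skill["name"] == skill_name:
            return f"Loaded skill: {skill_name}\n\n{skill['content']}"
    available = ", ".join(s["name"] for s in SKILLS)
    return f"Skill '{skill_name}' not found. Available: {available}"
```

Returning the list of valid names on a miss (rather than raising) matters in practice: the model sees the error as a tool result and can self-correct on the next call.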
| 186 | + | |
| 187 | +### 4.2 SkillMiddleware (Injecting Skill Descriptions) | |
| 188 | + | |
| 189 | +```python | |
| 190 | +from langchain.agents.middleware import AgentMiddleware, ModelRequest, ModelResponse | |
| 191 | +from langchain.messages import SystemMessage | |
| 192 | +from typing import Callable | |
| 193 | + | |
| 194 | +class ShoppingSkillMiddleware(AgentMiddleware): | |
| 195 | + """将技能描述注入 system prompt,使 Agent 能发现并按需加载技能""" | |
| 196 | + | |
| 197 | + tools = [load_skill] | |
| 198 | + | |
| 199 | + def __init__(self): | |
| 200 | + skills_list = [] | |
| 201 | + for skill in SKILLS: | |
| 202 | + skills_list.append(f"- **{skill['name']}**: {skill['description']}") | |
| 203 | + self.skills_prompt = "\n".join(skills_list) | |
| 204 | + | |
| 205 | + def wrap_model_call( | |
| 206 | + self, | |
| 207 | + request: ModelRequest, | |
| 208 | + handler: Callable[[ModelRequest], ModelResponse], | |
| 209 | + ) -> ModelResponse: | |
| 210 | + skills_addendum = ( | |
| 211 | +            f"\n\n## Available Skills (load on demand)\n\n{self.skills_prompt}\n\n" | |
| 212 | +            "When you need detailed instructions, load the corresponding skill with the load_skill tool." | |
| 213 | + ) | |
| 214 | + new_content = list(request.system_message.content_blocks) + [ | |
| 215 | + {"type": "text", "text": skills_addendum} | |
| 216 | + ] | |
| 217 | + new_system_message = SystemMessage(content=new_content) | |
| 218 | + modified_request = request.override(system_message=new_system_message) | |
| 219 | + return handler(modified_request) | |
| 220 | +``` | |
| 221 | + | |
| 222 | +### 4.3 Creating an Agent with Skills | |
| 223 | + | |
| 224 | +```python | |
| 225 | +from langchain.agents import create_agent | |
| 226 | +from langgraph.checkpoint.memory import MemorySaver | |
| 227 | + | |
| 228 | +# Base tools (search, image-to-image, style analysis, etc.) | |
| 229 | +from app.tools.search_tools import search_products, search_by_image, analyze_image_style | |
| 230 | + | |
| 231 | +agent = create_agent( | |
| 232 | + model="gpt-4o-mini", | |
| 233 | + tools=[ | |
| 234 | +        load_skill,  # skill loading | |
| 235 | + search_products, | |
| 236 | + search_by_image, | |
| 237 | + analyze_image_style, | |
| 238 | + ], | |
| 239 | +    system_prompt="""You are an intelligent fashion shopping assistant. Decide which skill the user's request needs, and load the skill details with load_skill when necessary. | |
| 240 | + | |
| 241 | +When presenting product results, always follow the format required by the result_packaging skill.""", | |
| 242 | + middleware=[ShoppingSkillMiddleware()], | |
| 243 | + checkpointer=MemorySaver(), | |
| 244 | +) | |
| 245 | +``` | |
| 246 | + | |
| 247 | +--- | |
| 248 | + | |
| 249 | +## 5. Skills vs. Tools | |
| 250 | + | |
| 251 | +| Capability | Skill | Tools | | |
| 252 | +|------|------|------| | |
| 253 | +| Find related | lookup_related | search_by_image, search_products, analyze_image_style | | |
| 254 | +| Search products | search_products | search_products | | |
| 255 | +| Check product | check_product | search_products (constraints expressed in the query) | | |
| 256 | +| Result packaging | result_packaging | none (pure prompt constraints) | | |
| 257 | +| After-sales | after_sales | none (or a customer-service API) | | |
| 258 | + | |
| 259 | +- **Skills**: explain "when and how" to act, loaded progressively. | |
| 260 | +- **Tools**: actually execute searches, analysis, and other operations. | |
| 261 | + | |
| 262 | +--- | |
| 263 | + | |
| 264 | +## 6. Optional: Skill Gating (Advanced) | |
| 265 | + | |
| 266 | +To require "load the skill before using its tools", combine `ToolRuntime` with agent state: | |
| 267 | + | |
| 268 | +```python | |
| 269 | +from langchain.tools import tool, ToolRuntime | |
| 270 | +from langgraph.types import Command | |
| 271 | +from langchain.messages import ToolMessage | |
| 272 | +from typing_extensions import NotRequired | |
| 273 | + | |
| 274 | +class CustomState(AgentState): | |
| 275 | + skills_loaded: NotRequired[list[str]] | |
| 276 | + | |
| 277 | +@tool | |
| 278 | +def load_skill(skill_name: str, runtime: ToolRuntime) -> Command: | |
| 279 | + """...""" | |
| 280 | + for skill in SKILLS: | |
| 281 | + if skill["name"] == skill_name: | |
| 282 | + content = f"Loaded skill: {skill_name}\n\n{skill['content']}" | |
| 283 | + return Command(update={ | |
| 284 | + "messages": [ToolMessage(content=content, tool_call_id=runtime.tool_call_id)], | |
| 285 | + "skills_loaded": [skill_name], | |
| 286 | + }) | |
| 287 | + # ... | |
| 288 | + | |
| 289 | +# Check skills_loaded in check_product and other tools | |
| 290 | +``` | |
| 291 | + | |
| 292 | +--- | |
| 293 | + | |
| 294 | +## 7. Dependencies and Versions | |
| 295 | + | |
| 296 | +```text | |
| 297 | +# requirements.txt | |
| 298 | +langchain>=1.0.0 | |
| 299 | +langchain-openai>=0.2.0 | |
| 300 | +langchain-core>=0.3.0 | |
| 301 | +langgraph>=1.0.0 | |
| 302 | +``` | |
| 303 | + | |
| 304 | +- Python 3.10+ | |
| 305 | +- Using Deep Agents' SKILL.md approach additionally requires installing `deepagents` | |
| 306 | + | |
| 307 | +--- | |
| 308 | + | |
| 309 | +## 8. Summary | |
| 310 | + | |
| 311 | +| Item | Description | | |
| 312 | +|------|------| | |
| 313 | +| **Effect** | The system prompt carries only short skill descriptions; full content loads on demand, saving tokens and easing extension | | |
| 314 | +| **Flow** | lightweight description → load_skill → full instructions → tool calls → reply | | |
| 315 | +| **Implementation** | `SkillMiddleware` + `load_skill` + `create_agent` | | |
| 316 | +| **Skills** | lookup_related, search_products, check_product, result_packaging, after_sales | | |
| 317 | + | |
| 318 | +For a complete example, see the official tutorial: [Build a SQL assistant with on-demand skills](https://docs.langchain.com/oss/python/langchain/multi-agent/skills-sql-assistant). | |
| 1 | +++ a/requirements.txt | |
| ... | ... | @@ -0,0 +1,40 @@ |
| 1 | +# Core Framework | |
| 2 | +fastapi>=0.109.0 | |
| 3 | +uvicorn[standard]>=0.27.0 | |
| 4 | +pydantic>=2.6.0 | |
| 5 | +pydantic-settings>=2.1.0 | |
| 6 | +streamlit>=1.50.0 | |
| 7 | + | |
| 8 | +# LLM & LangChain (Python 3.12, LangChain 1.x) | |
| 9 | +langchain>=1.0.0 | |
| 10 | +langchain-core>=0.3.0 | |
| 11 | +langchain-openai>=0.2.0 | |
| 12 | +langgraph>=1.0.0 | |
| 13 | +openai>=1.12.0 | |
| 14 | + | |
| 15 | +# Embeddings & Vision | |
| 16 | +clip-client>=3.5.0 # CLIP-as-Service client | |
| 17 | +Pillow>=10.2.0 # Image processing | |
| 18 | + | |
| 19 | +# Vector Database | |
| 20 | +pymilvus>=2.3.6 | |
| 21 | + | |
| 22 | +# Databases | |
| 23 | +pymongo>=4.6.1 | |
| 24 | + | |
| 25 | +# Utilities | |
| 26 | +python-dotenv>=1.0.1 | |
| 27 | +python-multipart>=0.0.9 | |
| 28 | +aiofiles>=23.2.1 | |
| 29 | +requests>=2.31.0 | |
| 30 | + | |
| 31 | +# Data Processing | |
| 32 | +pandas>=2.2.3 | |
| 33 | +numpy>=1.26.4 | |
| 34 | +tqdm>=4.66.1 | |
| 35 | + | |
| 36 | +# Development & Testing | |
| 37 | +pytest>=8.0.0 | |
| 38 | +pytest-asyncio>=0.23.4 | |
| 39 | +httpx>=0.26.0 | |
| 40 | +black>=24.1.1 | ... | ... |
| 1 | +++ a/scripts/check_services.sh | |
| ... | ... | @@ -0,0 +1,93 @@ |
| 1 | +#!/usr/bin/env bash | |
| 2 | +# ============================================================================= | |
| 3 | +# OmniShopAgent - Service health-check script | |
| 4 | +# Checks the status of Milvus, CLIP, Streamlit, and other dependent services | |
| 5 | +# ============================================================================= | |
| 6 | +set -euo pipefail | |
| 7 | + | |
| 8 | +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" | |
| 9 | +PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" | |
| 10 | +RED='\033[0;31m' | |
| 11 | +GREEN='\033[0;32m' | |
| 12 | +YELLOW='\033[1;33m' | |
| 13 | +NC='\033[0m' | |
| 14 | + | |
| 15 | +echo "==========================================" | |
| 16 | +echo "OmniShopAgent service health check" | |
| 17 | +echo "==========================================" | |
| 18 | + | |
| 19 | +# 1. Python environment | |
| 20 | +echo -n "[Python] " | |
| 21 | +if command -v python3 &>/dev/null; then | |
| 22 | +  VER=$(python3 -c 'import sys; v=sys.version_info; print(f"{v.major}.{v.minor}.{v.micro}")' 2>/dev/null) | |
| 23 | +  if [[ "$VER" =~ ^3\.1[0-9]\. ]]; then | |
| 24 | +    echo -e "${GREEN}OK${NC} $VER" | |
| 25 | +  else | |
| 26 | +    echo -e "${YELLOW}WARN${NC} $VER (3.12+ recommended)" | |
| 27 | +  fi | |
| 28 | +else | |
| 29 | +  echo -e "${RED}FAIL${NC} not found" | |
| 30 | +fi | |
| 31 | + | |
| 32 | +# 2. Virtual environment | |
| 33 | +echo -n "[Virtualenv] " | |
| 34 | +if [ -d "$PROJECT_ROOT/venv" ]; then | |
| 35 | +  echo -e "${GREEN}OK${NC} $PROJECT_ROOT/venv" | |
| 36 | +else | |
| 37 | +  echo -e "${YELLOW}WARN${NC} venv not found" | |
| 38 | +fi | |
| 39 | + | |
| 40 | +# 3. .env configuration | |
| 41 | +echo -n "[.env] " | |
| 42 | +if [ -f "$PROJECT_ROOT/.env" ]; then | |
| 43 | +  if grep -q "OPENAI_API_KEY=sk-" "$PROJECT_ROOT/.env" 2>/dev/null; then | |
| 44 | +    echo -e "${GREEN}OK${NC} configured" | |
| 45 | +  else | |
| 46 | +    echo -e "${YELLOW}WARN${NC} please set OPENAI_API_KEY" | |
| 47 | +  fi | |
| 48 | +else | |
| 49 | +  echo -e "${RED}FAIL${NC} not found" | |
| 50 | +fi | |
| 51 | + | |
| 52 | +# 4. Milvus | |
| 53 | +echo -n "[Milvus] " | |
| 54 | +if command -v docker &>/dev/null; then | |
| 55 | + if docker ps --format '{{.Names}}' 2>/dev/null | grep -q milvus-standalone; then | |
| 56 | + if curl -s -o /dev/null -w "%{http_code}" http://localhost:9091/healthz 2>/dev/null | grep -q 200; then | |
| 57 | + echo -e "${GREEN}OK${NC} localhost:19530" | |
| 58 | + else | |
| 59 | +      echo -e "${YELLOW}WARN${NC} container running, health check not responding" | |
| 60 | +    fi | |
| 61 | +  else | |
| 62 | +    echo -e "${YELLOW}WARN${NC} not running (docker compose up -d) | |
| 63 | +  fi | |
| 64 | +else | |
| 65 | +  echo -e "${YELLOW}SKIP${NC} Docker not installed" | |
| 66 | +fi | |
| 67 | + | |
| 68 | +# 5. CLIP service (optional) | |
| 69 | +echo -n "[CLIP] " | |
| 70 | +if timeout 2 bash -c 'echo >/dev/tcp/localhost/51000' 2>/dev/null; then | |
| 71 | +  echo -e "${GREEN}OK${NC} localhost:51000" | |
| 72 | +else | |
| 73 | +  echo -e "${YELLOW}WARN${NC} not running (image search requires: python -m clip_server launch)" | |
| 74 | +fi | |
| 75 | + | |
| 76 | +# 6. Data directory | |
| 77 | +echo -n "[Data] " | |
| 78 | +if [ -d "$PROJECT_ROOT/data/images" ] && [ -f "$PROJECT_ROOT/data/styles.csv" ]; then | |
| 79 | +  IMG_COUNT=$(find "$PROJECT_ROOT/data/images" -name "*.jpg" 2>/dev/null | wc -l) | |
| 80 | +  echo -e "${GREEN}OK${NC} $IMG_COUNT images" | |
| 81 | +else | |
| 82 | +  echo -e "${YELLOW}WARN${NC} data/images or data/styles.csv not found (run download_dataset.py)" | |
| 83 | +fi | |
| 84 | + | |
| 85 | +# 7. Streamlit | |
| 86 | +echo -n "[Streamlit] " | |
| 87 | +if pgrep -f "streamlit run app.py" >/dev/null 2>&1; then | |
| 88 | +  echo -e "${GREEN}OK${NC} running" | |
| 89 | +else | |
| 90 | +  echo -e "${YELLOW}WARN${NC} not running (./scripts/start.sh)" | |
| 91 | +fi | |
| 92 | + | |
| 93 | +echo "==========================================" | ... | ... |
| 1 | +++ a/scripts/download_dataset.py | |
| ... | ... | @@ -0,0 +1,95 @@ |
| 1 | +""" | |
| 2 | +Script to download the Fashion Product Images Dataset from Kaggle | |
| 3 | + | |
| 4 | +Requirements: | |
| 5 | +1. Install Kaggle CLI: pip install kaggle | |
| 6 | +2. Setup Kaggle API credentials: | |
| 7 | + - Go to https://www.kaggle.com/settings/account | |
| 8 | + - Click "Create New API Token" | |
| 9 | + - Save kaggle.json to ~/.kaggle/kaggle.json | |
| 10 | + - chmod 600 ~/.kaggle/kaggle.json | |
| 11 | + | |
| 12 | +Usage: | |
| 13 | + python scripts/download_dataset.py | |
| 14 | +""" | |
| 15 | + | |
| 16 | +import subprocess | |
| 17 | +import zipfile | |
| 18 | +from pathlib import Path | |
| 19 | + | |
| 20 | + | |
| 21 | +def download_dataset(): | |
| 22 | + """Download and extract the Fashion Product Images Dataset""" | |
| 23 | + | |
| 24 | + # Get project root | |
| 25 | + project_root = Path(__file__).parent.parent | |
| 26 | + raw_data_path = project_root / "data" / "raw" | |
| 27 | + | |
| 28 | + # Check if data already exists | |
| 29 | + if (raw_data_path / "styles.csv").exists(): | |
| 30 | + print("Dataset already exists in data/raw/") | |
| 31 | + response = input("Do you want to re-download? (y/n): ") | |
| 32 | + if response.lower() != "y": | |
| 33 | + print("Skipping download.") | |
| 34 | + return | |
| 35 | + | |
| 36 | + # Check Kaggle credentials | |
| 37 | + kaggle_json = Path.home() / ".kaggle" / "kaggle.json" | |
| 38 | + if not kaggle_json.exists(): | |
| 39 | +        print("Kaggle API credentials not found! See the setup steps in this script's docstring.") | |
| 40 | + return | |
| 41 | + | |
| 42 | + print("Downloading dataset from Kaggle...") | |
| 43 | + | |
| 44 | + try: | |
| 45 | + # Download using Kaggle API | |
| 46 | + subprocess.run( | |
| 47 | + [ | |
| 48 | + "kaggle", | |
| 49 | + "datasets", | |
| 50 | + "download", | |
| 51 | + "-d", | |
| 52 | + "paramaggarwal/fashion-product-images-dataset", | |
| 53 | + "-p", | |
| 54 | + str(raw_data_path), | |
| 55 | + ], | |
| 56 | + check=True, | |
| 57 | + ) | |
| 58 | + | |
| 59 | + print("Download complete!") | |
| 60 | + | |
| 61 | + # Extract zip file | |
| 62 | + zip_path = raw_data_path / "fashion-product-images-dataset.zip" | |
| 63 | + if zip_path.exists(): | |
| 64 | + print("Extracting files...") | |
| 65 | + with zipfile.ZipFile(zip_path, "r") as zip_ref: | |
| 66 | + zip_ref.extractall(raw_data_path) | |
| 67 | + | |
| 68 | + print("Extraction complete!") | |
| 69 | + | |
| 70 | + # Clean up zip file | |
| 71 | + zip_path.unlink() | |
| 72 | + print("Cleaned up zip file") | |
| 73 | + | |
| 74 | + # Verify files | |
| 75 | + styles_csv = raw_data_path / "styles.csv" | |
| 76 | + images_dir = raw_data_path / "images" | |
| 77 | + | |
| 78 | + if styles_csv.exists() and images_dir.exists(): | |
| 79 | +        print("\nDataset ready!") | |
| 80 | + | |
| 81 | + # Count images | |
| 82 | + image_count = len(list(images_dir.glob("*.jpg"))) | |
| 83 | + print(f"- Total images: {image_count:,}") | |
| 84 | + else: | |
| 85 | + print("Warning: Expected files not found") | |
| 86 | + | |
| 87 | + except subprocess.CalledProcessError: | |
| 88 | + print("Download failed!") | |
| 89 | + | |
| 90 | + except Exception as e: | |
| 91 | + print(f"Error: {e}") | |
| 92 | + | |
| 93 | + | |
| 94 | +if __name__ == "__main__": | |
| 95 | + download_dataset() | ... | ... |
| 1 | +++ a/scripts/index_data.py | |
| ... | ... | @@ -0,0 +1,467 @@ |
| 1 | +""" | |
| 2 | +Data Indexing Script | |
| 3 | +Generates embeddings for products and stores them in Milvus | |
| 4 | +""" | |
| 5 | + | |
| 6 | +import csv | |
| 7 | +import logging | |
| 8 | +import os | |
| 9 | +import sys | |
| 10 | +from pathlib import Path | |
| 11 | +from typing import Any, Dict, Optional | |
| 12 | + | |
| 13 | +from tqdm import tqdm | |
| 14 | + | |
| 15 | +# Add parent directory to path | |
| 16 | +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| 17 | + | |
| 18 | +# Import config and settings first | |
| 19 | +# Direct imports from files to avoid __init__.py circular issues | |
| 20 | +import importlib.util | |
| 21 | + | |
| 22 | +from app.config import get_absolute_path, settings | |
| 23 | + | |
| 24 | + | |
| 25 | +def load_service_module(module_name, file_name): | |
| 26 | + """Load a service module directly from file""" | |
| 27 | + spec = importlib.util.spec_from_file_location( | |
| 28 | + module_name, | |
| 29 | + os.path.join( | |
| 30 | + os.path.dirname(os.path.dirname(os.path.abspath(__file__))), | |
| 31 | + f"app/services/{file_name}", | |
| 32 | + ), | |
| 33 | + ) | |
| 34 | + module = importlib.util.module_from_spec(spec) | |
| 35 | + spec.loader.exec_module(module) | |
| 36 | + return module | |
| 37 | + | |
| 38 | + | |
| 39 | +embedding_module = load_service_module("embedding_service", "embedding_service.py") | |
| 40 | +milvus_module = load_service_module("milvus_service", "milvus_service.py") | |
| 41 | + | |
| 42 | +EmbeddingService = embedding_module.EmbeddingService | |
| 43 | +MilvusService = milvus_module.MilvusService | |
| 44 | + | |
| 45 | +# Configure logging | |
| 46 | +logging.basicConfig( | |
| 47 | + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" | |
| 48 | +) | |
| 49 | +logger = logging.getLogger(__name__) | |
| 50 | + | |
| 51 | + | |
| 52 | +class DataIndexer: | |
| 53 | + """Index product data by generating and storing embeddings""" | |
| 54 | + | |
| 55 | + def __init__(self): | |
| 56 | + """Initialize services""" | |
| 57 | + self.embedding_service = EmbeddingService() | |
| 58 | + self.milvus_service = MilvusService() | |
| 59 | + | |
| 60 | + self.image_dir = Path(get_absolute_path(settings.image_data_path)) | |
| 61 | + self.styles_csv = get_absolute_path("./data/styles.csv") | |
| 62 | + self.images_csv = get_absolute_path("./data/images.csv") | |
| 63 | + | |
| 64 | + # Load product data from CSV | |
| 65 | + self.products = self._load_products_from_csv() | |
| 66 | + | |
| 67 | + def _load_products_from_csv(self) -> Dict[int, Dict[str, Any]]: | |
| 68 | + """Load products from CSV files""" | |
| 69 | + products = {} | |
| 70 | + | |
| 71 | + # Load images mapping | |
| 72 | + images_dict = {} | |
| 73 | + with open(self.images_csv, "r", encoding="utf-8") as f: | |
| 74 | + reader = csv.DictReader(f) | |
| 75 | + for row in reader: | |
| 76 | + product_id = int(row["filename"].split(".")[0]) | |
| 77 | + images_dict[product_id] = row["link"] | |
| 78 | + | |
| 79 | + # Load styles/products | |
| 80 | + with open(self.styles_csv, "r", encoding="utf-8") as f: | |
| 81 | + reader = csv.DictReader(f) | |
| 82 | + for row in reader: | |
| 83 | + try: | |
| 84 | + product_id = int(row["id"]) | |
| 85 | + products[product_id] = { | |
| 86 | + "id": product_id, | |
| 87 | + "gender": row.get("gender", ""), | |
| 88 | + "masterCategory": row.get("masterCategory", ""), | |
| 89 | + "subCategory": row.get("subCategory", ""), | |
| 90 | + "articleType": row.get("articleType", ""), | |
| 91 | + "baseColour": row.get("baseColour", ""), | |
| 92 | + "season": row.get("season", ""), | |
| 93 | + "year": int(row["year"]) if row.get("year") else 0, | |
| 94 | + "usage": row.get("usage", ""), | |
| 95 | + "productDisplayName": row.get("productDisplayName", ""), | |
| 96 | + "imageUrl": images_dict.get(product_id, ""), | |
| 97 | + "imagePath": f"{product_id}.jpg", | |
| 98 | + } | |
| 99 | + except (ValueError, KeyError) as e: | |
| 100 | + logger.warning(f"Error loading product {row.get('id')}: {e}") | |
| 101 | + continue | |
| 102 | + | |
| 103 | + logger.info(f"Loaded {len(products)} products from CSV") | |
| 104 | + return products | |
| 105 | + | |
| 106 | + def setup(self) -> None: | |
| 107 | + """Setup connections and collections""" | |
| 108 | + logger.info("Setting up services...") | |
| 109 | + | |
| 110 | + # Connect to CLIP server | |
| 111 | + self.embedding_service.connect_clip() | |
| 112 | + logger.info("✓ CLIP server connected") | |
| 113 | + | |
| 114 | + # Connect to Milvus | |
| 115 | + self.milvus_service.connect() | |
| 116 | + logger.info("✓ Milvus connected") | |
| 117 | + | |
| 118 | + # Create Milvus collections | |
| 119 | + self.milvus_service.create_text_collection(recreate=False) | |
| 120 | + self.milvus_service.create_image_collection(recreate=False) | |
| 121 | + logger.info("✓ Milvus collections ready") | |
| 122 | + | |
| 123 | + def teardown(self) -> None: | |
| 124 | + """Close all connections""" | |
| 125 | + logger.info("Closing connections...") | |
| 126 | + self.embedding_service.disconnect_clip() | |
| 127 | + self.milvus_service.disconnect() | |
| 128 | + logger.info("✓ All connections closed") | |
| 129 | + | |
| 130 | + def index_text_embeddings( | |
| 131 | + self, batch_size: int = 100, skip: int = 0, limit: Optional[int] = None | |
| 132 | + ) -> Dict[str, int]: | |
| 133 | + """Generate and store text embeddings for products | |
| 134 | + | |
| 135 | + Args: | |
| 136 | + batch_size: Number of products to process at once | |
| 137 | + skip: Number of products to skip | |
| 138 | + limit: Maximum number of products to process (None for all) | |
| 139 | + | |
| 140 | + Returns: | |
| 141 | + Dictionary with indexing statistics | |
| 142 | + """ | |
| 143 | + logger.info("Starting text embedding indexing...") | |
| 144 | + | |
| 145 | + # Get products list | |
| 146 | + product_ids = list(self.products.keys())[skip:] | |
| 147 | + if limit: | |
| 148 | + product_ids = product_ids[:limit] | |
| 149 | + | |
| 150 | + total_products = len(product_ids) | |
| 151 | + processed = 0 | |
| 152 | + inserted = 0 | |
| 153 | + errors = 0 | |
| 154 | + | |
| 155 | + with tqdm(total=total_products, desc="Indexing text embeddings") as pbar: | |
| 156 | + while processed < total_products: | |
| 157 | + # Get batch of products | |
| 158 | + current_batch_size = min(batch_size, total_products - processed) | |
| 159 | + batch_ids = product_ids[processed : processed + current_batch_size] | |
| 160 | + products = [self.products[pid] for pid in batch_ids] | |
| 161 | + | |
| 162 | + if not products: | |
| 163 | + break | |
| 164 | + | |
| 165 | + try: | |
| 166 | + # Prepare texts for embedding | |
| 167 | + texts = [] | |
| 168 | + text_mappings = [] | |
| 169 | + | |
| 170 | + for product in products: | |
| 171 | + # Create text representation of product | |
| 172 | + text = self._create_product_text(product) | |
| 173 | + texts.append(text) | |
| 174 | + text_mappings.append( | |
| 175 | + {"product_id": product["id"], "text": text} | |
| 176 | + ) | |
| 177 | + | |
| 178 | + # Generate embeddings | |
| 179 | + embeddings = self.embedding_service.get_text_embeddings_batch( | |
| 180 | + texts, batch_size=50 # OpenAI batch size | |
| 181 | + ) | |
| 182 | + | |
| 183 | + # Prepare data for Milvus (with metadata) | |
| 184 | + milvus_data = [] | |
| 185 | +                    for mapping, embedding in zip(text_mappings, embeddings): | |
| 188 | + product_id = mapping["product_id"] | |
| 189 | + product = self.products[product_id] | |
| 190 | + | |
| 191 | + milvus_data.append( | |
| 192 | + { | |
| 193 | + "id": product_id, | |
| 194 | + "text": mapping["text"][ | |
| 195 | + :2000 | |
| 196 | + ], # Truncate to max length | |
| 197 | + "embedding": embedding, | |
| 198 | + # Product metadata | |
| 199 | + "productDisplayName": product["productDisplayName"][ | |
| 200 | + :500 | |
| 201 | + ], | |
| 202 | + "gender": product["gender"][:50], | |
| 203 | + "masterCategory": product["masterCategory"][:100], | |
| 204 | + "subCategory": product["subCategory"][:100], | |
| 205 | + "articleType": product["articleType"][:100], | |
| 206 | + "baseColour": product["baseColour"][:50], | |
| 207 | + "season": product["season"][:50], | |
| 208 | + "usage": product["usage"][:50], | |
| 209 | + "year": product["year"], | |
| 210 | + "imageUrl": product["imageUrl"], | |
| 211 | + "imagePath": product["imagePath"], | |
| 212 | + } | |
| 213 | + ) | |
| 214 | + | |
| 215 | + # Insert into Milvus | |
| 216 | + count = self.milvus_service.insert_text_embeddings(milvus_data) | |
| 217 | + inserted += count | |
| 218 | + | |
| 219 | + except Exception as e: | |
| 220 | + logger.error( | |
| 221 | + f"Error processing text batch at offset {processed}: {e}" | |
| 222 | + ) | |
| 223 | + errors += len(products) | |
| 224 | + | |
| 225 | + processed += len(products) | |
| 226 | + pbar.update(len(products)) | |
| 227 | + | |
| 228 | + stats = {"total_processed": processed, "inserted": inserted, "errors": errors} | |
| 229 | + | |
| 230 | + logger.info(f"Text embedding indexing completed: {stats}") | |
| 231 | + return stats | |
| 232 | + | |
| 233 | + def index_image_embeddings( | |
| 234 | + self, batch_size: int = 32, skip: int = 0, limit: Optional[int] = None | |
| 235 | + ) -> Dict[str, int]: | |
| 236 | + """Generate and store image embeddings for products | |
| 237 | + | |
| 238 | + Args: | |
| 239 | + batch_size: Number of images to process at once | |
| 240 | + skip: Number of products to skip | |
| 241 | + limit: Maximum number of products to process (None for all) | |
| 242 | + | |
| 243 | + Returns: | |
| 244 | + Dictionary with indexing statistics | |
| 245 | + """ | |
| 246 | + logger.info("Starting image embedding indexing...") | |
| 247 | + | |
| 248 | + # Get products list | |
| 249 | + product_ids = list(self.products.keys())[skip:] | |
| 250 | + if limit: | |
| 251 | + product_ids = product_ids[:limit] | |
| 252 | + | |
| 253 | + total_products = len(product_ids) | |
| 254 | + processed = 0 | |
| 255 | + inserted = 0 | |
| 256 | + errors = 0 | |
| 257 | + | |
| 258 | + with tqdm(total=total_products, desc="Indexing image embeddings") as pbar: | |
| 259 | + while processed < total_products: | |
| 260 | + # Get batch of products | |
| 261 | + current_batch_size = min(batch_size, total_products - processed) | |
| 262 | + batch_ids = product_ids[processed : processed + current_batch_size] | |
| 263 | + products = [self.products[pid] for pid in batch_ids] | |
| 264 | + | |
| 265 | + if not products: | |
| 266 | + break | |
| 267 | + | |
| 268 | + try: | |
| 269 | + # Prepare image paths | |
| 270 | + image_paths = [] | |
| 271 | + image_mappings = [] | |
| 272 | + | |
| 273 | + for product in products: | |
| 274 | + image_path = self.image_dir / product["imagePath"] | |
| 275 | + image_paths.append(image_path) | |
| 276 | + image_mappings.append( | |
| 277 | + { | |
| 278 | + "product_id": product["id"], | |
| 279 | + "image_path": product["imagePath"], | |
| 280 | + } | |
| 281 | + ) | |
| 282 | + | |
| 283 | + # Generate embeddings | |
| 284 | + embeddings = self.embedding_service.get_image_embeddings_batch( | |
| 285 | + image_paths, batch_size=batch_size | |
| 286 | + ) | |
| 287 | + | |
| 288 | + # Prepare data for Milvus (with metadata) | |
| 289 | + milvus_data = [] | |
| 290 | +                    for mapping, embedding in zip(image_mappings, embeddings): | |
| 293 | + if embedding is not None: | |
| 294 | + product_id = mapping["product_id"] | |
| 295 | + product = self.products[product_id] | |
| 296 | + | |
| 297 | + milvus_data.append( | |
| 298 | + { | |
| 299 | + "id": product_id, | |
| 300 | + "image_path": mapping["image_path"], | |
| 301 | + "embedding": embedding, | |
| 302 | + # Product metadata | |
| 303 | + "productDisplayName": product["productDisplayName"][ | |
| 304 | + :500 | |
| 305 | + ], | |
| 306 | + "gender": product["gender"][:50], | |
| 307 | + "masterCategory": product["masterCategory"][:100], | |
| 308 | + "subCategory": product["subCategory"][:100], | |
| 309 | + "articleType": product["articleType"][:100], | |
| 310 | + "baseColour": product["baseColour"][:50], | |
| 311 | + "season": product["season"][:50], | |
| 312 | + "usage": product["usage"][:50], | |
| 313 | + "year": product["year"], | |
| 314 | + "imageUrl": product["imageUrl"], | |
| 315 | + } | |
| 316 | + ) | |
| 317 | + else: | |
| 318 | + errors += 1 | |
| 319 | + | |
| 320 | + # Insert into Milvus | |
| 321 | + if milvus_data: | |
| 322 | + count = self.milvus_service.insert_image_embeddings(milvus_data) | |
| 323 | + inserted += count | |
| 324 | + | |
| 325 | + except Exception as e: | |
| 326 | + logger.error( | |
| 327 | + f"Error processing image batch at offset {processed}: {e}" | |
| 328 | + ) | |
| 329 | + errors += len(products) | |
| 330 | + | |
| 331 | + processed += len(products) | |
| 332 | + pbar.update(len(products)) | |
| 333 | + | |
| 334 | + stats = {"total_processed": processed, "inserted": inserted, "errors": errors} | |
| 335 | + | |
| 336 | + logger.info(f"Image embedding indexing completed: {stats}") | |
| 337 | + return stats | |
| 338 | + | |
| 339 | + def _create_product_text(self, product: Dict[str, Any]) -> str: | |
| 340 | + """Create text representation of product for embedding | |
| 341 | + | |
| 342 | + Args: | |
| 343 | + product: Product document | |
| 344 | + | |
| 345 | + Returns: | |
| 346 | + Text representation | |
| 347 | + """ | |
| 348 | +        # Build a natural-language description, skipping fields that are empty | |
| 349 | +        parts = [product.get("productDisplayName", "")] | |
| 350 | +        for label, key in [ | |
| 351 | +            ("Gender", "gender"), | |
| 352 | +            ("Type", "articleType"), | |
| 353 | +            ("Color", "baseColour"), | |
| 354 | +            ("Season", "season"), | |
| 355 | +            ("Usage", "usage"), | |
| 356 | +        ]: | |
| 357 | +            if product.get(key): | |
| 358 | +                parts.append(f"{label}: {product[key]}") | |
| 359 | +        if product.get("masterCategory") or product.get("subCategory"): | |
| 360 | +            parts.append( | |
| 361 | +                f"Category: {product.get('masterCategory', '')} > {product.get('subCategory', '')}" | |
| 362 | +            ) | |
| 363 | +        return " | ".join(p for p in parts if p) | |
| 363 | + | |
| 364 | + def get_stats(self) -> Dict[str, Any]: | |
| 365 | + """Get indexing statistics | |
| 366 | + | |
| 367 | + Returns: | |
| 368 | + Dictionary with statistics | |
| 369 | + """ | |
| 370 | + text_stats = self.milvus_service.get_collection_stats( | |
| 371 | + self.milvus_service.text_collection_name | |
| 372 | + ) | |
| 373 | + image_stats = self.milvus_service.get_collection_stats( | |
| 374 | + self.milvus_service.image_collection_name | |
| 375 | + ) | |
| 376 | + | |
| 377 | + return { | |
| 378 | + "total_products": len(self.products), | |
| 379 | + "milvus_text": text_stats, | |
| 380 | + "milvus_image": image_stats, | |
| 381 | + } | |
| 382 | + | |
| 383 | + | |
| 384 | +def main(): | |
| 385 | + """Main function""" | |
| 386 | + import argparse | |
| 387 | + | |
| 388 | + parser = argparse.ArgumentParser(description="Index product data for search") | |
| 389 | + parser.add_argument( | |
| 390 | + "--mode", | |
| 391 | + choices=["text", "image", "both"], | |
| 392 | + default="both", | |
| 393 | + help="Which embeddings to index", | |
| 394 | + ) | |
| 395 | + parser.add_argument( | |
| 396 | + "--batch-size", type=int, default=100, help="Batch size for processing" | |
| 397 | + ) | |
| 398 | + parser.add_argument( | |
| 399 | + "--skip", type=int, default=0, help="Number of products to skip" | |
| 400 | + ) | |
| 401 | + parser.add_argument( | |
| 402 | + "--limit", type=int, default=None, help="Maximum number of products to process" | |
| 403 | + ) | |
| 404 | + parser.add_argument("--stats", action="store_true", help="Show statistics only") | |
| 405 | + | |
| 406 | + args = parser.parse_args() | |
| 407 | + | |
| 408 | + # Create indexer | |
| 409 | + indexer = DataIndexer() | |
| 410 | + | |
| 411 | + try: | |
| 412 | + # Setup services | |
| 413 | + indexer.setup() | |
| 414 | + | |
| 415 | + if args.stats: | |
| 416 | + # Show statistics | |
| 417 | + stats = indexer.get_stats() | |
| 418 | + print("\n=== Indexing Statistics ===") | |
| 419 | + print(f"\nTotal Products in CSV: {stats['total_products']}") | |
| 420 | + | |
| 421 | + print("\nMilvus Text Embeddings:") | |
| 422 | + print(f" Collection: {stats['milvus_text']['collection_name']}") | |
| 423 | + print(f" Total embeddings: {stats['milvus_text']['row_count']}") | |
| 424 | + | |
| 425 | + print("\nMilvus Image Embeddings:") | |
| 426 | + print(f" Collection: {stats['milvus_image']['collection_name']}") | |
| 427 | + print(f" Total embeddings: {stats['milvus_image']['row_count']}") | |
| 428 | + | |
| 429 | + print( | |
| 430 | + f"\nCoverage: {stats['milvus_image']['row_count'] / stats['total_products'] * 100:.1f}%" | |
| 431 | + ) | |
| 432 | + else: | |
| 433 | + # Index data | |
| 434 | + if args.mode in ["text", "both"]: | |
| 435 | + logger.info("=== Indexing Text Embeddings ===") | |
| 436 | + text_stats = indexer.index_text_embeddings( | |
| 437 | + batch_size=args.batch_size, skip=args.skip, limit=args.limit | |
| 438 | + ) | |
| 439 | + print(f"\nText Indexing Results: {text_stats}") | |
| 440 | + | |
| 441 | + if args.mode in ["image", "both"]: | |
| 442 | + logger.info("=== Indexing Image Embeddings ===") | |
| 443 | + image_stats = indexer.index_image_embeddings( | |
| 444 | + batch_size=min(args.batch_size, 32), # Smaller batch for images | |
| 445 | + skip=args.skip, | |
| 446 | + limit=args.limit, | |
| 447 | + ) | |
| 448 | + print(f"\nImage Indexing Results: {image_stats}") | |
| 449 | + | |
| 450 | + # Show final statistics | |
| 451 | + logger.info("\n=== Final Statistics ===") | |
| 452 | + stats = indexer.get_stats() | |
| 453 | + print(f"Total products: {stats['total_products']}") | |
| 454 | + print(f"Text embeddings: {stats['milvus_text']['row_count']}") | |
| 455 | + print(f"Image embeddings: {stats['milvus_image']['row_count']}") | |
| 456 | + | |
| 457 | + except KeyboardInterrupt: | |
| 458 | + logger.info("\nIndexing interrupted by user") | |
| 459 | + except Exception as e: | |
| 460 | + logger.error(f"Error during indexing: {e}", exc_info=True) | |
| 461 | + sys.exit(1) | |
| 462 | + finally: | |
| 463 | + indexer.teardown() | |
| 464 | + | |
| 465 | + | |
| 466 | +if __name__ == "__main__": | |
| 467 | + main() | ... | ... |
| 1 | +++ a/scripts/run_clip.sh | |
| ... | ... | @@ -0,0 +1,22 @@ |
| 1 | +#!/usr/bin/env bash | |
| 2 | +# ============================================================================= | |
| 3 | +# OmniShopAgent - start the CLIP image embedding service | |
| 4 | +# Image search and search-by-image rely on this service | |
| 5 | +# ============================================================================= | |
| 6 | +set -euo pipefail | |
| 7 | + | |
| 8 | +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" | |
| 9 | +PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" | |
| 10 | +VENV_DIR="${VENV_DIR:-$PROJECT_ROOT/venv}" | |
| 11 | + | |
| 12 | +cd "$PROJECT_ROOT" | |
| 13 | + | |
| 14 | +if [ -d "$VENV_DIR" ]; then | |
| 15 | + set +u | |
| 16 | + source "$VENV_DIR/bin/activate" | |
| 17 | + set -u | |
| 18 | +fi | |
| 19 | + | |
| 20 | +echo "Starting CLIP service (port 51000)..." | |
| 21 | +echo "Press Ctrl+C to stop" | |
| 22 | +exec python -m clip_server launch | ... | ... |
| 1 | +++ a/scripts/run_milvus.sh | |
| ... | ... | @@ -0,0 +1,31 @@ |
| 1 | +#!/usr/bin/env bash | |
| 2 | +# ============================================================================= | |
| 3 | +# OmniShopAgent - start the Milvus vector database | |
| 4 | +# Uses Docker Compose to start Milvus and its dependencies | |
| 5 | +# ============================================================================= | |
| 6 | +set -euo pipefail | |
| 7 | + | |
| 8 | +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" | |
| 9 | +PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" | |
| 10 | + | |
| 11 | +cd "$PROJECT_ROOT" | |
| 12 | + | |
| 13 | +if ! command -v docker &>/dev/null; then | |
| 14 | +  echo "Error: Docker is not installed. Run setup_env_centos8.sh first" | |
| 15 | + exit 1 | |
| 16 | +fi | |
| 17 | + | |
| 18 | +echo "Starting Milvus..." | |
| 19 | +docker compose up -d 2>/dev/null || docker-compose up -d 2>/dev/null || { | |
| 20 | + echo "错误: 无法执行 docker compose。请确保已安装 Docker Compose" | |
| 21 | + exit 1 | |
| 22 | +} | |
| 23 | + | |
| 24 | +echo "Waiting for Milvus to become ready (up to 120s)..." | |
| 25 | +for _ in $(seq 1 24); do | |
| 26 | +    if curl -s -o /dev/null -w "%{http_code}" http://localhost:9091/healthz 2>/dev/null | grep -q 200; then | |
| 27 | +        echo "Milvus is ready: localhost:19530" | |
| 28 | +        exit 0 | |
| 29 | +    fi | |
| 30 | +    sleep 5 | |
| 31 | +done | |
| 32 | +echo "Note: Milvus may still be starting; run check_services.sh later to verify" | |
| 1 | +++ a/scripts/setup_env_centos8.sh | |
| ... | ... | @@ -0,0 +1,152 @@ |
| 1 | +#!/usr/bin/env bash | |
| 2 | +# ============================================================================= | |
| 3 | +# OmniShopAgent - CentOS 8 environment setup script | |
| 4 | +# Installs Python 3.12, Docker, dependencies, and a virtual environment | |
| 5 | +# ============================================================================= | |
| 6 | +set -euo pipefail | |
| 7 | + | |
| 8 | +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" | |
| 9 | +PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" | |
| 10 | +VENV_DIR="${VENV_DIR:-$PROJECT_ROOT/venv}" | |
| 11 | +PYTHON_VERSION="${PYTHON_VERSION:-3.12}" | |
| 12 | + | |
| 13 | +echo "==========================================" | |
| 14 | +echo "OmniShopAgent - CentOS 8 Environment Setup" | |
| 15 | +echo "==========================================" | |
| 16 | +echo "Project root: $PROJECT_ROOT" | |
| 17 | +echo "Virtualenv: $VENV_DIR" | |
| 18 | +echo "Python version: $PYTHON_VERSION" | |
| 19 | +echo "==========================================" | |
| 20 | + | |
| 21 | +# ----------------------------------------------------------------------------- | |
| 22 | +# 1. Install system dependencies | |
| 23 | +# ----------------------------------------------------------------------------- | |
| 24 | +echo "[1/4] Installing system dependencies..." | |
| 25 | +sudo dnf install -y \ | |
| 26 | + gcc \ | |
| 27 | + gcc-c++ \ | |
| 28 | + make \ | |
| 29 | + openssl-devel \ | |
| 30 | + bzip2-devel \ | |
| 31 | + libffi-devel \ | |
| 32 | + sqlite-devel \ | |
| 33 | + xz-devel \ | |
| 34 | + zlib-devel \ | |
| 35 | + readline-devel \ | |
| 36 | + tk-devel \ | |
| 37 | + libuuid-devel \ | |
| 38 | + curl \ | |
| 39 | + wget \ | |
| 40 | + git \ | |
| 41 | + tar | |
| 42 | + | |
| 43 | +# ----------------------------------------------------------------------------- | |
| 44 | +# 2. Install Docker (for Milvus) | |
| 45 | +# ----------------------------------------------------------------------------- | |
| 46 | +echo "[2/4] Checking/installing Docker..." | |
| 47 | +if ! command -v docker &>/dev/null; then | |
| 48 | +    echo "  Installing Docker..." | |
| 49 | + sudo dnf config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.repo 2>/dev/null || { | |
| 50 | + sudo dnf install -y dnf-plugins-core | |
| 51 | + sudo dnf config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.repo | |
| 52 | + } | |
| 53 | + sudo dnf install -y docker-ce docker-ce-cli containerd.io docker-compose-plugin | |
| 54 | + sudo systemctl enable docker | |
| 55 | + sudo systemctl start docker | |
| 56 | + sudo usermod -aG docker "$USER" 2>/dev/null || true | |
| 57 | +    echo "  Docker installed. Run 'newgrp docker' or log in again to use the docker command." | |
| 58 | +else | |
| 59 | +    echo "  Docker already installed: $(docker --version)" | |
| 60 | +fi | |
| 61 | + | |
| 62 | +# ----------------------------------------------------------------------------- | |
| 63 | +# 3. Install Python 3.12 | |
| 64 | +# ----------------------------------------------------------------------------- | |
| 65 | +echo "[3/4] Installing Python $PYTHON_VERSION..." | |
| 66 | + | |
| 67 | +USE_CONDA=false | |
| 68 | +if command -v python3.12 &>/dev/null; then | |
| 69 | +    echo "  Python 3.12 already installed" | |
| 70 | +elif command -v conda &>/dev/null; then | |
| 71 | +    echo "  Creating Python $PYTHON_VERSION environment with conda..." | |
| 72 | + conda create -n shop_agent "python=$PYTHON_VERSION" -y | |
| 73 | + USE_CONDA=true | |
| 74 | +    echo "  Conda environment created. Run: conda activate shop_agent" | |
| 75 | +    echo "  Then install dependencies manually: pip install -r $PROJECT_ROOT/requirements.txt" | |
| 76 | +    echo "  Skipping venv creation..." | |
| 77 | +else | |
| 78 | +    echo "  Building Python $PYTHON_VERSION from source..." | |
| 79 | + sudo dnf groupinstall -y 'Development Tools' | |
| 80 | + cd /tmp | |
| 81 | + PY_URL="https://www.python.org/ftp/python/${PYTHON_VERSION}/Python-${PYTHON_VERSION}.0.tgz" | |
| 82 | + PY_TGZ="Python-${PYTHON_VERSION}.0.tgz" | |
| 83 | + [ -f "$PY_TGZ" ] || wget -q "$PY_URL" -O "$PY_TGZ" | |
| 84 | + tar xzf "$PY_TGZ" | |
| 85 | + cd "Python-${PYTHON_VERSION}.0" | |
| 86 | + ./configure --enable-optimizations --prefix=/usr/local | |
| 87 | + make -j "$(nproc)" | |
| 88 | + sudo make altinstall | |
| 89 | + cd /tmp | |
| 90 | + rm -rf "Python-${PYTHON_VERSION}.0" "$PY_TGZ" | |
| 91 | +fi | |
| 92 | + | |
| 93 | +# ----------------------------------------------------------------------------- | |
| 94 | +# 4. Create the virtual environment and install dependencies (non-conda path) | |
| 95 | +# ----------------------------------------------------------------------------- | |
| 96 | +if [ "$USE_CONDA" = true ]; then | |
| 97 | +    echo "[4/4] Using conda; skipping venv creation" | |
| 98 | +else | |
| 99 | +    echo "[4/4] Creating virtual environment and installing Python dependencies..." | |
| 100 | + | |
| 101 | + PYTHON_BIN="" | |
| 102 | + for p in python3.12 python3.11 python3; do | |
| 103 | + if command -v "$p" &>/dev/null; then | |
| 104 | + VER=$("$p" -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")' 2>/dev/null || echo "0") | |
| 105 | +            if [[ "$VER" =~ ^3\.1[0-9]$ ]]; then | |
| 106 | + PYTHON_BIN="$p" | |
| 107 | + break | |
| 108 | + fi | |
| 109 | + fi | |
| 110 | + done | |
| 111 | + | |
| 112 | + if [ -z "$PYTHON_BIN" ]; then | |
| 113 | +        echo "  Error: Python 3.10+ not found. If using conda, run: conda activate shop_agent" | |
| 114 | +        echo "  Then install dependencies manually: pip install -r $PROJECT_ROOT/requirements.txt" | |
| 115 | + exit 1 | |
| 116 | + fi | |
| 117 | + | |
| 118 | + if [ ! -d "$VENV_DIR" ]; then | |
| 119 | +        echo "  Creating virtual environment: $VENV_DIR" | |
| 120 | + "$PYTHON_BIN" -m venv "$VENV_DIR" | |
| 121 | + fi | |
| 122 | + | |
| 123 | +    echo "  Activating virtual environment and installing dependencies..." | |
| 124 | + set +u | |
| 125 | + source "$VENV_DIR/bin/activate" | |
| 126 | + set -u | |
| 127 | + pip install -U pip | |
| 128 | + pip install -r "$PROJECT_ROOT/requirements.txt" | |
| 129 | +    echo "  Python dependencies installed." | |
| 130 | +fi | |
| 131 | + | |
| 132 | +# Configure .env | |
| 133 | +if [ ! -f "$PROJECT_ROOT/.env" ]; then | |
| 134 | + echo "" | |
| 135 | +    echo "  Creating .env config file..." | |
| 136 | + cp "$PROJECT_ROOT/.env.example" "$PROJECT_ROOT/.env" | |
| 137 | +    echo "  Edit $PROJECT_ROOT/.env to set OPENAI_API_KEY and other settings." | |
| 138 | +fi | |
| 139 | + | |
| 140 | +echo "" | |
| 141 | +echo "==========================================" | |
| 142 | +echo "Environment setup complete!" | |
| 143 | +echo "==========================================" | |
| 144 | +echo "Next steps:" | |
| 145 | +echo "  1. Edit .env and set OPENAI_API_KEY" | |
| 146 | +echo "  2. Download the dataset: python scripts/download_dataset.py" | |
| 147 | +echo "  3. Start Milvus: ./scripts/run_milvus.sh" | |
| 148 | +echo "  4. Index the data: python scripts/index_data.py" | |
| 149 | +echo "  5. Start the app: ./scripts/start.sh" | |
| 150 | +echo "" | |
| 151 | +echo "Activate the virtualenv: source $VENV_DIR/bin/activate" | |
| 152 | +echo "==========================================" | ... | ... |
| 1 | +++ a/scripts/start.sh | |
| ... | ... | @@ -0,0 +1,63 @@ |
| 1 | +#!/usr/bin/env bash | |
| 2 | +# ============================================================================= | |
| 3 | +# OmniShopAgent - start script | |
| 4 | +# Starts Milvus, CLIP (optional), and the Streamlit app | |
| 5 | +# ============================================================================= | |
| 6 | +set -euo pipefail | |
| 7 | + | |
| 8 | +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" | |
| 9 | +PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" | |
| 10 | +VENV_DIR="${VENV_DIR:-$PROJECT_ROOT/venv}" | |
| 11 | +STREAMLIT_PORT="${STREAMLIT_PORT:-8501}" | |
| 12 | +STREAMLIT_HOST="${STREAMLIT_HOST:-0.0.0.0}" | |
| 13 | + | |
| 14 | +cd "$PROJECT_ROOT" | |
| 15 | + | |
| 16 | +# Activate the virtual environment | |
| 17 | +if [ -d "$VENV_DIR" ]; then | |
| 18 | +    echo "Activating virtualenv: $VENV_DIR" | |
| 19 | + set +u | |
| 20 | + source "$VENV_DIR/bin/activate" | |
| 21 | + set -u | |
| 22 | +else | |
| 23 | +    echo "Warning: virtualenv $VENV_DIR not found; using current Python" | |
| 24 | +fi | |
| 25 | + | |
| 26 | +echo "==========================================" | |
| 27 | +echo "Starting OmniShopAgent" | |
| 28 | +echo "==========================================" | |
| 29 | + | |
| 30 | +# 1. Start Milvus (Docker) | |
| 31 | +if command -v docker &>/dev/null; then | |
| 32 | +    echo "[1/3] Checking Milvus..." | |
| 33 | + if ! docker ps --format '{{.Names}}' 2>/dev/null | grep -q milvus-standalone; then | |
| 34 | +        echo "  Starting Milvus (docker compose)..." | |
| 35 | + docker compose up -d 2>/dev/null || docker-compose up -d 2>/dev/null || { | |
| 36 | +            echo "  Warning: failed to start Milvus; run manually: docker compose up -d" | |
| 37 | + } | |
| 38 | +        echo "  Waiting for Milvus to become ready (30s)..." | |
| 39 | + sleep 30 | |
| 40 | + else | |
| 41 | +        echo "  Milvus is already running" | |
| 42 | + fi | |
| 43 | +else | |
| 44 | +    echo "[1/3] Skipping Milvus: Docker not installed" | |
| 45 | +fi | |
| 46 | + | |
| 47 | +# 2. Check CLIP (optional; required for image search) | |
| 48 | +echo "[2/3] Checking CLIP service..." | |
| 49 | +echo "  Note: image search requires CLIP. If it is not running, open another terminal and run: python -m clip_server launch" | |
| 50 | +echo "  Text search works without CLIP." | |
| 51 | + | |
| 52 | +# 3. Start Streamlit | |
| 53 | +echo "[3/3] Starting Streamlit (port $STREAMLIT_PORT)..." | |
| 54 | +echo "" | |
| 55 | +echo "  URL: http://$STREAMLIT_HOST:$STREAMLIT_PORT" | |
| 56 | +echo "  Press Ctrl+C to stop" | |
| 57 | +echo "==========================================" | |
| 58 | + | |
| 59 | +exec streamlit run app.py \ | |
| 60 | + --server.port="$STREAMLIT_PORT" \ | |
| 61 | + --server.address="$STREAMLIT_HOST" \ | |
| 62 | + --server.headless=true \ | |
| 63 | + --browser.gatherUsageStats=false | ... | ... |
| 1 | +++ a/scripts/stop.sh | |
| ... | ... | @@ -0,0 +1,46 @@ |
| 1 | +#!/usr/bin/env bash | |
| 2 | +# ============================================================================= | |
| 3 | +# OmniShopAgent - stop script | |
| 4 | +# Stops the Streamlit process and the Milvus containers | |
| 5 | +# ============================================================================= | |
| 6 | +set -euo pipefail | |
| 7 | + | |
| 8 | +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" | |
| 9 | +PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" | |
| 10 | +STREAMLIT_PORT="${STREAMLIT_PORT:-8501}" | |
| 11 | + | |
| 12 | +echo "==========================================" | |
| 13 | +echo "Stopping OmniShopAgent" | |
| 14 | +echo "==========================================" | |
| 15 | + | |
| 16 | +# 1. Stop the Streamlit process | |
| 17 | +echo "[1/2] Stopping Streamlit..." | |
| 18 | +if pgrep -f "streamlit run app.py" >/dev/null 2>&1; then | |
| 19 | +    pkill -f "streamlit run app.py" 2>/dev/null || true | |
| 20 | +    echo "  Streamlit stopped" | |
| 21 | +else | |
| 22 | +    echo "  Streamlit is not running" | |
| 23 | +fi | |
| 24 | + | |
| 25 | +# Find and kill by port | |
| 26 | +if command -v lsof &>/dev/null; then | |
| 27 | +    PID=$(lsof -ti:"$STREAMLIT_PORT" 2>/dev/null || true) | |
| 28 | +    if [ -n "$PID" ]; then | |
| 29 | +        kill $PID 2>/dev/null || true  # unquoted on purpose: may hold several PIDs | |
| 30 | +        echo "  Killed process on port $STREAMLIT_PORT" | |
| 31 | + fi | |
| 32 | +fi | |
| 33 | + | |
| 34 | +# 2. Optionally stop the Milvus containers | |
| 35 | +echo "[2/2] Stopping Milvus..." | |
| 36 | +if command -v docker &>/dev/null; then | |
| 37 | +    cd "$PROJECT_ROOT" | |
| 38 | +    docker compose down 2>/dev/null || docker-compose down 2>/dev/null || true | |
| 39 | +    echo "  Milvus stopped" | |
| 40 | +else | |
| 41 | +    echo "  Docker not installed; skipping" | |
| 42 | +fi | |
| 43 | + | |
| 44 | +echo "==========================================" | |
| 45 | +echo "OmniShopAgent stopped" | |
| 46 | +echo "==========================================" | ... | ... |
| 1 | +++ a/技术实现报告.md | |
| ... | ... | @@ -0,0 +1,624 @@ |
| 1 | +# OmniShopAgent Technical Implementation Report | |
| 2 | + | |
| 3 | +## 1. Project Overview | |
| 4 | + | |
| 5 | +OmniShopAgent is an autonomous multimodal fashion-shopping agent built on **LangGraph** and the **ReAct pattern**. It decides on its own which tools to call, maintains conversation state, and judges when to respond, enabling intelligent product discovery and recommendation. | |
| 6 | + | |
| 7 | +### Core Features | |
| 8 | + | |
| 9 | +- **Autonomous tool selection and execution**: the agent picks and calls tools based on user intent | |
| 10 | +- **Multimodal search**: text search plus image search | |
| 11 | +- **Conversation context awareness**: keeps context across multi-turn dialogue | |
| 12 | +- **Real-time visual analysis**: VLM-based image style analysis | |
| 13 | + | |
| 14 | +--- | |
| 15 | + | |
| 16 | +## 2. Technology Stack | |
| 17 | + | |
| 18 | +| Component | Choice | | |
| 19 | +|------|----------| | |
| 20 | +| Runtime | Python 3.12 | | |
| 21 | +| Agent framework | LangGraph 1.x | | |
| 22 | +| LLM framework | LangChain 1.x (any LLM; gpt-4o-mini by default) | | |
| 23 | +| Text embeddings | text-embedding-3-small | | |
| 24 | +| Image embeddings | CLIP ViT-B/32 | | |
| 25 | +| Vector database | Milvus | | |
| 26 | +| Frontend | Streamlit | | |
| 27 | +| Dataset | Kaggle Fashion Products | | |
| 28 | + | |
| 29 | +--- | |
| 30 | + | |
| 31 | +## 3. System Architecture | |
| 32 | + | |
| 33 | +### 3.1 Overall Architecture | |
| 34 | + | |
| 35 | +``` | |
| 36 | +┌─────────────────────────────────────────────────────────────────┐ | |
| 37 | +│ Streamlit 前端 (app.py) │ | |
| 38 | +└─────────────────────────────────────────────────────────────────┘ | |
| 39 | + │ | |
| 40 | + ▼ | |
| 41 | +┌─────────────────────────────────────────────────────────────────┐ | |
| 42 | +│ ShoppingAgent (shopping_agent.py) │ | |
| 43 | +│ ┌───────────────────────────────────────────────────────────┐ │ | |
| 44 | +│ │ LangGraph StateGraph + ReAct Pattern │ │ | |
| 45 | +│ │ START → Agent → [Has tool_calls?] → Tools → Agent → END │ │ | |
| 46 | +│ └───────────────────────────────────────────────────────────┘ │ | |
| 47 | +└─────────────────────────────────────────────────────────────────┘ | |
| 48 | + │ │ │ | |
| 49 | + ▼ ▼ ▼ | |
| 50 | +┌──────────────┐ ┌──────────────────┐ ┌─────────────────────┐ | |
| 51 | +│ search_ │ │ search_by_image │ │ analyze_image_style │ | |
| 52 | +│ products │ │ │ │ (OpenAI Vision) │ | |
| 53 | +└──────┬───────┘ └────────┬─────────┘ └──────────┬───────────┘ | |
| 54 | + │ │ │ | |
| 55 | + ▼ ▼ ▼ | |
| 56 | +┌─────────────────────────────────────────────────────────────────┐ | |
| 57 | +│ EmbeddingService (embedding_service.py) │ | |
| 58 | +│ OpenAI API (文本) │ CLIP Server (图像) │ | |
| 59 | +└─────────────────────────────────────────────────────────────────┘ | |
| 60 | + │ | |
| 61 | + ▼ | |
| 62 | +┌─────────────────────────────────────────────────────────────────┐ | |
| 63 | +│ MilvusService (milvus_service.py) │ | |
| 64 | +│ text_embeddings 集合 │ image_embeddings 集合 │ | |
| 65 | +└─────────────────────────────────────────────────────────────────┘ | |
| 66 | +``` | |
| 67 | + | |
| 68 | +### 3.2 Agent 流程图(LangGraph) | |
| 69 | + | |
| 70 | +```mermaid | |
| 71 | +graph LR | |
| 72 | + START --> Agent | |
| 73 | + Agent -->|Has tool_calls| Tools | |
| 74 | + Agent -->|No tool_calls| END | |
| 75 | + Tools --> Agent | |
| 76 | +``` | |
| 77 | + | |
| 78 | +--- | |
| 79 | + | |
| 80 | +## 四、关键代码实现 | |
| 81 | + | |
| 82 | +### 4.1 Agent 核心实现(shopping_agent.py) | |
| 83 | + | |
| 84 | +#### 4.1.1 状态定义 | |
| 85 | + | |
| 86 | +```python | |
| 87 | +from typing_extensions import Annotated, TypedDict | |
| 88 | +from langgraph.graph.message import add_messages | |
| 89 | + | |
| 90 | +class AgentState(TypedDict): | |
| 91 | + """State for the shopping agent with message accumulation""" | |
| 92 | + messages: Annotated[Sequence[BaseMessage], add_messages] | |
| 93 | + current_image_path: Optional[str] # Track uploaded image | |
| 94 | +``` | |
| 95 | + | |
| 96 | +- `messages` 使用 `add_messages` 实现消息累加,支持多轮对话 | |
| 97 | +- `current_image_path` 存储当前上传的图片路径供工具使用 | |
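
The reducer semantics can be illustrated without LangGraph. Below is a minimal stand-in for `add_messages` (a sketch of its append behavior only; the real reducer also deduplicates by message ID), showing why a node returning `{"messages": [response]}` extends the history rather than replacing it:

```python
# Minimal stand-in for LangGraph's add_messages reducer: a node returns a
# partial state update, and the reducer appends new messages to the existing
# list instead of overwriting it. (ID-based deduplication is omitted here.)
def add_messages(existing: list, updates: list) -> list:
    return list(existing) + list(updates)

def apply_update(state: dict, update: dict) -> dict:
    merged = dict(state)
    for key, value in update.items():
        if key == "messages":          # field annotated with the reducer
            merged[key] = add_messages(state.get("messages", []), value)
        else:                          # plain fields are simply replaced
            merged[key] = value
    return merged

state = {"messages": [("human", "show me red dresses")], "current_image_path": None}
state = apply_update(state, {"messages": [("ai", "Here are some options...")]})
print(len(state["messages"]))  # 2
```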

#### 4.1.2 Building the LangGraph Graph

```python
def _build_graph(self):
    """Build the LangGraph StateGraph"""

    def agent_node(state: AgentState):
        """Agent decision node - decides which tools to call or when to respond"""
        messages = state["messages"]
        if not any(isinstance(m, SystemMessage) for m in messages):
            messages = [SystemMessage(content=system_prompt)] + list(messages)
        response = self.llm_with_tools.invoke(messages)
        return {"messages": [response]}

    tool_node = ToolNode(self.tools)

    def should_continue(state: AgentState):
        """Determine if agent should continue or end"""
        last_message = state["messages"][-1]
        if hasattr(last_message, "tool_calls") and last_message.tool_calls:
            return "tools"
        return END

    workflow = StateGraph(AgentState)
    workflow.add_node("agent", agent_node)
    workflow.add_node("tools", tool_node)
    workflow.add_edge(START, "agent")
    workflow.add_conditional_edges("agent", should_continue, ["tools", END])
    workflow.add_edge("tools", "agent")

    checkpointer = MemorySaver()
    return workflow.compile(checkpointer=checkpointer)
```

Key points:
- **agent_node**: hands the messages to the LLM, which decides whether to call tools
- **should_continue**: routes to the tool node if the last message carries `tool_calls`, otherwise ends the run
- **MemorySaver**: persists conversation state per `thread_id`
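
Stripped of the framework, the control flow compiled above is just a loop: invoke the model, execute any requested tools, feed the results back, and stop once the model stops asking for tools. A plain-Python sketch (`fake_model` and the tool runner are illustrative stand-ins, not project code):

```python
# ReAct control flow as a plain loop: the conditional edge is the
# `if not response["tool_calls"]` check, and the tools->agent edge is
# the next loop iteration.
def react_loop(call_model, run_tool, messages, max_steps=5):
    for _ in range(max_steps):
        response = call_model(messages)  # {"content": ..., "tool_calls": [...]}
        messages.append(("ai", response["content"]))
        if not response["tool_calls"]:   # should_continue -> END
            return messages
        for tc in response["tool_calls"]:  # should_continue -> tools
            messages.append(("tool", run_tool(tc)))
    return messages

# Fake model: requests one tool call on the first turn, then answers.
calls = {"n": 0}
def fake_model(msgs):
    calls["n"] += 1
    if calls["n"] == 1:
        return {"content": "", "tool_calls": [{"name": "search_products", "args": {"query": "red dresses"}}]}
    return {"content": "Found 5 red dresses.", "tool_calls": []}

out = react_loop(fake_model, lambda tc: f"ran {tc['name']}", [("human", "red dresses")])
```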

#### 4.1.3 System Prompt Design

```python
system_prompt = """You are an intelligent fashion shopping assistant. You can:
1. Search for products by text description (use search_products)
2. Find visually similar products from images (use search_by_image)
3. Analyze image style and attributes (use analyze_image_style)

When a user asks about products:
- For text queries: use search_products directly
- For image uploads: decide if you need to analyze_image_style first, then search
- You can call multiple tools in sequence if needed
- Always provide helpful, friendly responses

CRITICAL FORMATTING RULES:
When presenting product results, you MUST use this EXACT format for EACH product:
1. [Product Name]
   ID: [Product ID Number]
   Category: [Category]
   Color: [Color]
   Gender: [Gender]
   (Include Season, Usage, Relevance if available)
..."""
```

The system prompt constrains tool usage and output format so the frontend can reliably parse product information from the reply.

#### 4.1.4 Chat Entry Point and Streaming

```python
def chat(self, query: str, image_path: Optional[str] = None) -> dict:
    # Build input message
    message_content = query
    if image_path:
        message_content = f"{query}\n[User uploaded image: {image_path}]"

    config = {"configurable": {"thread_id": self.session_id}}
    input_state = {
        "messages": [HumanMessage(content=message_content)],
        "current_image_path": image_path,
    }

    tool_calls = []
    for event in self.graph.stream(input_state, config=config):
        if "agent" in event:
            for msg in event["agent"].get("messages", []):
                if hasattr(msg, "tool_calls") and msg.tool_calls:
                    for tc in msg.tool_calls:
                        tool_calls.append({"name": tc["name"], "args": tc.get("args", {})})
        if "tools" in event:
            # Record tool execution results
            ...

    final_state = self.graph.get_state(config)
    response_text = final_state.values["messages"][-1].content

    return {"response": response_text, "tool_calls": tool_calls, "error": False}
```

---

### 4.2 Search Tools (search_tools.py)

#### 4.2.1 Semantic Text Search

```python
@tool
def search_products(query: str, limit: int = 5) -> str:
    """Search for fashion products using natural language descriptions."""
    try:
        embedding_service = get_embedding_service()
        milvus_service = get_milvus_service()

        query_embedding = embedding_service.get_text_embedding(query)

        results = milvus_service.search_similar_text(
            query_embedding=query_embedding,
            limit=min(limit, 20),
            filters=None,
            output_fields=[
                "id", "productDisplayName", "gender", "masterCategory",
                "subCategory", "articleType", "baseColour", "season", "usage",
            ],
        )

        if not results:
            return "No products found matching your search."

        output = f"Found {len(results)} product(s):\n\n"
        for idx, product in enumerate(results, 1):
            output += f"{idx}. {product.get('productDisplayName', 'Unknown Product')}\n"
            output += f"   ID: {product.get('id', 'N/A')}\n"
            output += f"   Category: {product.get('masterCategory')} > {product.get('subCategory')} > {product.get('articleType')}\n"
            output += f"   Color: {product.get('baseColour')}\n"
            output += f"   Gender: {product.get('gender')}\n"
            if "distance" in product:
                similarity = 1 - product["distance"]
                output += f"   Relevance: {similarity:.2%}\n"
            output += "\n"

        return output.strip()
    except Exception as e:
        return f"Error searching products: {str(e)}"
```
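
The `Relevance` line converts a Milvus distance into a similarity via `1 - distance`, which is only meaningful for a cosine-distance-style metric where distances near 0 mean near-identical vectors. A small helper (hypothetical, not in the source) that also clamps the value keeps the displayed percentage sensible even if a hit falls outside the expected range:

```python
# Distance -> relevance conversion as used by search_products, with clamping.
# Assumes a cosine-distance metric (distance in [0, 2]); with an L2 metric
# the raw distance is unbounded and this mapping would not be meaningful.
def relevance(distance: float) -> float:
    return max(0.0, min(1.0, 1.0 - distance))

print(f"{relevance(0.18):.2%}")  # 82.00%
```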

#### 4.2.2 Image Similarity Search

```python
@tool
def search_by_image(image_path: str, limit: int = 5) -> str:
    """Find similar fashion products using an image."""
    if not Path(image_path).exists():
        return f"Error: Image file not found at '{image_path}'"

    embedding_service = get_embedding_service()
    milvus_service = get_milvus_service()

    if not embedding_service.clip_client:
        embedding_service.connect_clip()

    image_embedding = embedding_service.get_image_embedding(image_path)

    results = milvus_service.search_similar_images(
        query_embedding=image_embedding,
        limit=min(limit + 1, 21),
        output_fields=[...],
    )

    # Filter out the query image itself (e.g. when the upload is a catalogue image)
    query_id = Path(image_path).stem
    filtered_results = [r for r in results if Path(r.get("image_path", "")).stem != query_id]
    filtered_results = filtered_results[:limit]
```
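
The self-match filter at the end is worth isolating: when the uploaded file is itself a catalogue image named `<id>.jpg`, the nearest neighbour is the image itself, so hits sharing the query's filename stem are dropped before truncating to `limit`. A self-contained version with sample data (the paths are illustrative):

```python
from pathlib import Path

# Standalone version of the self-match filter: drop any hit whose filename
# stem matches the query image's stem, then truncate to the requested limit.
def drop_self_match(results, image_path, limit):
    query_id = Path(image_path).stem
    kept = [r for r in results if Path(r.get("image_path", "")).stem != query_id]
    return kept[:limit]

hits = [{"id": 101, "image_path": "data/images/101.jpg"},
        {"id": 202, "image_path": "data/images/202.jpg"},
        {"id": 303, "image_path": "data/images/303.jpg"}]
print(drop_self_match(hits, "uploads/202.jpg", limit=2))  # keeps ids 101 and 303
```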

#### 4.2.3 Visual Analysis (VLM)

```python
@tool
def analyze_image_style(image_path: str) -> str:
    """Analyze a fashion product image using AI vision to extract detailed style information."""
    with open(image_path, "rb") as image_file:
        image_data = base64.b64encode(image_file.read()).decode("utf-8")

    prompt = """Analyze this fashion product image and provide a detailed description.
Include:
- Product type (e.g., shirt, dress, shoes, pants, bag)
- Primary colors
- Style/design (e.g., casual, formal, sporty, vintage, modern)
- Pattern or texture (e.g., plain, striped, checked, floral)
- Key features (e.g., collar type, sleeve length, fit)
- Material appearance (if obvious, e.g., denim, cotton, leather)
- Suitable occasion (e.g., office wear, party, casual, sports)
Provide a comprehensive yet concise description (3-4 sentences)."""

    client = get_openai_client()
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_data}", "detail": "high"}},
            ],
        }],
        max_tokens=500,
        temperature=0.3,
    )

    return response.choices[0].message.content.strip()
```

---

### 4.3 Embedding and Vector Services

#### 4.3.1 EmbeddingService (embedding_service.py)

```python
class EmbeddingService:
    def get_text_embedding(self, text: str) -> List[float]:
        """OpenAI text-embedding-3-small"""
        response = self.openai_client.embeddings.create(
            input=text, model=self.text_embedding_model
        )
        return response.data[0].embedding

    def get_image_embedding(self, image_path: Union[str, Path]) -> List[float]:
        """CLIP image embedding"""
        if not self.clip_client:
            raise RuntimeError("CLIP client not connected. Call connect_clip() first.")
        result = self.clip_client.encode([str(image_path)])
        if isinstance(result, np.ndarray):
            embedding = result[0].tolist() if len(result.shape) > 1 else result.tolist()
        else:
            embedding = result[0].embedding.tolist()
        return embedding

    def get_text_embeddings_batch(self, texts: List[str], batch_size: int = 100) -> List[List[float]]:
        """Batch text embeddings, used during indexing"""
        all_embeddings = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            response = self.openai_client.embeddings.create(
                input=batch, model=self.text_embedding_model
            )
            embeddings = [item.embedding for item in response.data]
            all_embeddings.extend(embeddings)
        return all_embeddings
```
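
The batching in `get_text_embeddings_batch` is the standard fixed-size chunking idiom; isolated, it looks like this (a generic sketch, no API calls):

```python
# Fixed-size chunking: each slice becomes one embeddings API request, keeping
# every call under the provider's per-request input limits.
def chunked(items, batch_size):
    for i in range(0, len(items), batch_size):
        yield items[i : i + batch_size]

print(list(chunked(list(range(7)), 3)))  # [[0, 1, 2], [3, 4, 5], [6]]
```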

#### 4.3.2 MilvusService (milvus_service.py)

**Text collection schema:**

```python
schema = MilvusClient.create_schema(auto_id=False, enable_dynamic_field=True)
schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=2000)
schema.add_field(field_name="embedding", datatype=DataType.FLOAT_VECTOR, dim=self.text_dim)  # 1536
schema.add_field(field_name="productDisplayName", datatype=DataType.VARCHAR, max_length=500)
schema.add_field(field_name="gender", datatype=DataType.VARCHAR, max_length=50)
schema.add_field(field_name="masterCategory", datatype=DataType.VARCHAR, max_length=100)
# ... additional metadata fields
```

**Image collection schema:**

```python
schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
schema.add_field(field_name="image_path", datatype=DataType.VARCHAR, max_length=500)
schema.add_field(field_name="embedding", datatype=DataType.FLOAT_VECTOR, dim=self.image_dim)  # 512
# ... product metadata
```

**Similarity search:**

```python
def search_similar_text(self, query_embedding, limit=10, output_fields=None):
    results = self.client.search(
        collection_name=self.text_collection_name,
        data=[query_embedding],
        limit=limit,
        output_fields=output_fields,
    )
    formatted_results = []
    for hit in results[0]:
        result = {"id": hit.get("id"), "distance": hit.get("distance")}
        entity = hit.get("entity", {})
        for field in output_fields or []:
            if field in entity:
                result[field] = entity.get(field)
        formatted_results.append(result)
    return formatted_results
```

---

### 4.4 Data Indexing Script (index_data.py)

#### 4.4.1 Loading Product Data

```python
def _load_products_from_csv(self) -> Dict[int, Dict[str, Any]]:
    products = {}
    # Load the images.csv filename-to-URL mapping
    with open(self.images_csv, "r") as f:
        images_dict = {int(row["filename"].split(".")[0]): row["link"] for row in csv.DictReader(f)}

    # Load styles.csv
    with open(self.styles_csv, "r") as f:
        for row in csv.DictReader(f):
            product_id = int(row["id"])
            products[product_id] = {
                "id": product_id,
                "gender": row.get("gender", ""),
                "masterCategory": row.get("masterCategory", ""),
                "subCategory": row.get("subCategory", ""),
                "articleType": row.get("articleType", ""),
                "baseColour": row.get("baseColour", ""),
                "season": row.get("season", ""),
                "usage": row.get("usage", ""),
                "productDisplayName": row.get("productDisplayName", ""),
                "imagePath": f"{product_id}.jpg",
            }
    return products
```

#### 4.4.2 Text Indexing

```python
def _create_product_text(self, product: Dict[str, Any]) -> str:
    """Build the product text used for embedding"""
    parts = [
        product.get("productDisplayName", ""),
        f"Gender: {product.get('gender', '')}",
        f"Category: {product.get('masterCategory', '')} > {product.get('subCategory', '')}",
        f"Type: {product.get('articleType', '')}",
        f"Color: {product.get('baseColour', '')}",
        f"Season: {product.get('season', '')}",
        f"Usage: {product.get('usage', '')}",
    ]
    return " | ".join([p for p in parts if p and p != "Gender: " and p != "Color: "])
```
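
For a concrete sense of what gets embedded, here is the builder run as a standalone function against a sample styles.csv row (the sample values are illustrative):

```python
# Standalone copy of _create_product_text, applied to a sample product row.
def create_product_text(product):
    parts = [
        product.get("productDisplayName", ""),
        f"Gender: {product.get('gender', '')}",
        f"Category: {product.get('masterCategory', '')} > {product.get('subCategory', '')}",
        f"Type: {product.get('articleType', '')}",
        f"Color: {product.get('baseColour', '')}",
        f"Season: {product.get('season', '')}",
        f"Usage: {product.get('usage', '')}",
    ]
    # Drop empty fields, including label-only "Gender: " / "Color: " entries
    return " | ".join(p for p in parts if p and p not in ("Gender: ", "Color: "))

sample = {"productDisplayName": "Nike Men Blue T-shirt", "gender": "Men",
          "masterCategory": "Apparel", "subCategory": "Topwear",
          "articleType": "Tshirts", "baseColour": "Blue",
          "season": "Summer", "usage": "Casual"}
print(create_product_text(sample))
# Nike Men Blue T-shirt | Gender: Men | Category: Apparel > Topwear | Type: Tshirts | Color: Blue | Season: Summer | Usage: Casual
```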

#### 4.4.3 Batch Indexing Flow

```python
# Text indexing
texts = [self._create_product_text(p) for p in products]
embeddings = self.embedding_service.get_text_embeddings_batch(texts, batch_size=50)
milvus_data = [{
    "id": product_id,
    "text": text[:2000],
    "embedding": embedding,
    "productDisplayName": product["productDisplayName"][:500],
    "gender": product["gender"][:50],
    # ... other metadata
} for product_id, text, embedding in zip(...)]
self.milvus_service.insert_text_embeddings(milvus_data)

# Image indexing
image_paths = [self.image_dir / p["imagePath"] for p in products]
embeddings = self.embedding_service.get_image_embeddings_batch(image_paths, batch_size=32)
# Inserted into the image_embeddings collection the same way
```

---

### 4.5 Streamlit Frontend (app.py)

#### 4.5.1 Session and Agent Initialization

```python
def initialize_session():
    if "session_id" not in st.session_state:
        st.session_state.session_id = str(uuid.uuid4())
    if "shopping_agent" not in st.session_state:
        st.session_state.shopping_agent = ShoppingAgent(session_id=st.session_state.session_id)
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "uploaded_image" not in st.session_state:
        st.session_state.uploaded_image = None
```

#### 4.5.2 Parsing Product Information

```python
def extract_products_from_response(response: str) -> list:
    """Parse product information out of the agent's reply"""
    products = []
    current_product = None
    for line in response.split("\n"):
        if re.match(r"^\*?\*?\d+\.\s+", line):
            if current_product:
                products.append(current_product)
            current_product = {"name": re.sub(r"^\*?\*?\d+\.\s+", "", line).replace("**", "").strip()}
        elif current_product and "ID:" in line:
            id_match = re.search(r"(?:ID|id):\s*(\d+)", line)
            if id_match:
                current_product["id"] = id_match.group(1)
        elif current_product and "Category:" in line:
            cat_match = re.search(r"Category:\s*(.+?)(?:\n|$)", line)
            if cat_match:
                current_product["category"] = cat_match.group(1).strip()
        # ... Color, Gender, Season, Usage, Similarity/Relevance
    if current_product:
        products.append(current_product)
    return products
```
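
The parser depends entirely on the format the system prompt mandates. The snippet below exercises the same regexes on a sample reply (the sample text is illustrative) to show the structure they recover:

```python
import re

# A sample agent reply in the mandated format, fed through the parser's
# name/ID logic.
response = """1. Nike Men Blue T-shirt
   ID: 15970
   Category: Apparel > Topwear > Tshirts
2. Puma Women Red Dress
   ID: 30039"""

products, current = [], None
for line in response.split("\n"):
    if re.match(r"^\*?\*?\d+\.\s+", line):          # numbered line starts a product
        if current:
            products.append(current)
        current = {"name": re.sub(r"^\*?\*?\d+\.\s+", "", line).replace("**", "").strip()}
    elif current and "ID:" in line:                 # attribute line for the current product
        m = re.search(r"(?:ID|id):\s*(\d+)", line)
        if m:
            current["id"] = m.group(1)
if current:
    products.append(current)
print(products)
```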

#### 4.5.3 Image References Across Turns

```python
# When the user types e.g. "make them formal", reuse the image from the most
# recent user message that carried one
query_lower = query.lower()
if any(ref in query_lower for ref in ["this", "that", "the image", "it"]):
    for msg in reversed(st.session_state.messages):
        if msg.get("role") == "user" and msg.get("image_path"):
            image_path = msg["image_path"]
            break
```

---

### 4.6 Configuration (config.py)

```python
class Settings(BaseSettings):
    openai_api_key: str
    openai_model: str = "gpt-4o-mini"
    openai_embedding_model: str = "text-embedding-3-small"
    clip_server_url: str = "grpc://localhost:51000"
    milvus_uri: str = "http://localhost:19530"
    text_collection_name: str = "text_embeddings"
    image_collection_name: str = "image_embeddings"
    text_dim: int = 1536
    image_dim: int = 512

    @property
    def milvus_uri_absolute(self) -> str:
        """Supports both Milvus Standalone and Milvus Lite"""
        if self.milvus_uri.startswith(("http://", "https://")):
            return self.milvus_uri
        if self.milvus_uri.startswith("./"):
            return os.path.join(base_dir, self.milvus_uri[2:])
        return self.milvus_uri

    class Config:
        env_file = ".env"
```
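
The URI resolution logic can be sketched as a standalone function (`base_dir` below is an illustrative value; in the project it is derived from the package location): HTTP URIs pass through untouched for Milvus Standalone, while a relative `./...` path is anchored to a base directory for Milvus Lite's local-file mode.

```python
import os

# Standalone sketch of milvus_uri_absolute. base_dir is a stand-in value.
def milvus_uri_absolute(uri, base_dir="/srv/omnishop"):
    if uri.startswith(("http://", "https://")):   # Milvus Standalone
        return uri
    if uri.startswith("./"):                      # Milvus Lite local file
        return os.path.join(base_dir, uri[2:])
    return uri

print(milvus_uri_absolute("http://localhost:19530"))  # unchanged
print(milvus_uri_absolute("./milvus.db"))             # /srv/omnishop/milvus.db
```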

---

## 5. Deployment and Operation

### 5.1 Dependent Services

```yaml
# Provided by docker-compose.yml
- etcd: metadata store
- minio: object storage
- milvus-standalone: vector database
- attu: Milvus admin UI
```

### 5.2 Startup Sequence

```bash
# 1. Environment
pip install -r requirements.txt
cp .env.example .env  # set OPENAI_API_KEY

# 2. Download the data
python scripts/download_dataset.py  # Kaggle Fashion Product Images Dataset

# 3. Start the CLIP server (runs as a separate process)
python -m clip_server

# 4. Start Milvus
docker-compose up

# 5. Index the data
python scripts/index_data.py

# 6. Launch the app
streamlit run app.py
```

---

## 6. Typical Interaction Flows

| Scenario | User Input | Agent Behavior | Tool Calls |
|------|----------|------------|----------|
| Text search | "winter coats for women" | direct text search | `search_products("winter coats women")` |
| Image search | [uploads image] "find similar" | image similarity search | `search_by_image(path)` |
| Style analysis + search | [uploads vintage jacket] "what style? find matching pants" | analyzes style first, then searches | `analyze_image_style(path)` → `search_products("vintage pants casual")` |
| Multi-turn context | [Turn 1] "show me red dresses"<br>[Turn 2] "make them formal" | combines conversation context | `search_products("red formal dresses")` |

---

## 7. Design Highlights

1. **ReAct pattern**: the agent decides on its own when to call tools, which tools to call, and whether to keep going.
2. **LangGraph state graph**: `START → Agent → [condition] → Tools → Agent → END`, supporting multiple rounds of tool calls.
3. **Multimodality**: text + images + VLM analysis, covering text search, image-to-image search, and style understanding.
4. **Two vector collections**: text_embeddings and image_embeddings are stored separately in Milvus, so each modality has its own index.
5. **Session persistence**: `MemorySaver` plus `thread_id` give the agent multi-turn memory.
6. **Format constraints**: the system prompt strictly constrains the product output format so the frontend can parse and render it.

---

## 8. Appendix: Project Layout

```
OmniShopAgent/
├── app/
│   ├── agents/
│   │   └── shopping_agent.py
│   ├── config.py
│   ├── services/
│   │   ├── embedding_service.py
│   │   └── milvus_service.py
│   └── tools/
│       └── search_tools.py
├── scripts/
│   ├── download_dataset.py
│   └── index_data.py
├── app.py
├── docker-compose.yml
└── requirements.txt
```