From e7f2b2409cd4db20be799b9ce9995e939c0d4807 Mon Sep 17 00:00:00 2001 From: tangwang Date: Thu, 12 Feb 2026 17:25:05 +0800 Subject: [PATCH] first commit --- .env.example | 47 +++++++++++++++++++++++++++++++++++++++++++++++ .gitignore | 83 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ .python-version | 1 + LICENSE | 21 +++++++++++++++++++++ README.md | 161 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ app.py | 732 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ app/__init__.py | 5 +++++ app/agents/__init__.py | 10 ++++++++++ app/agents/shopping_agent.py | 272 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ app/config.py | 86 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ app/services/__init__.py | 14 ++++++++++++++ app/services/embedding_service.py | 293 
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ app/services/milvus_service.py | 480 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ app/tools/__init__.py | 17 +++++++++++++++++ app/tools/search_tools.py | 294 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ demo.pdf | Bin 0 -> 452576 bytes docker-compose.yml | 76 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ docs/DEPLOY_CENTOS8.md | 216 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ docs/LANGCHAIN_1.0_MIGRATION.md | 77 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ docs/Skills实现方案-LangChain1.0.md | 318 
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ requirements.txt | 40 ++++++++++++++++++++++++++++++++++++++++ scripts/check_services.sh | 93 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ scripts/download_dataset.py | 95 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ scripts/index_data.py | 467 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ scripts/run_clip.sh | 22 ++++++++++++++++++++++ scripts/run_milvus.sh | 31 +++++++++++++++++++++++++++++++ scripts/setup_env_centos8.sh | 152 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ scripts/start.sh | 63 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ scripts/stop.sh | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 技术实现报告.md | 624 
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 30 files changed, 4836 insertions(+), 0 deletions(-) create mode 100644 .env.example create mode 100644 .gitignore create mode 100644 .python-version create mode 100644 LICENSE create mode 100644 README.md create mode 100644 app.py create mode 100644 app/__init__.py create mode 100644 app/agents/__init__.py create mode 100644 app/agents/shopping_agent.py create mode 100644 app/config.py create mode 100644 app/services/__init__.py create mode 100644 app/services/embedding_service.py create mode 100644 app/services/milvus_service.py create mode 100644 app/tools/__init__.py create mode 100644 app/tools/search_tools.py create mode 100644 demo.pdf create mode 100644 docker-compose.yml create mode 100644 docs/DEPLOY_CENTOS8.md create mode 100644 docs/LANGCHAIN_1.0_MIGRATION.md create mode 100644 docs/Skills实现方案-LangChain1.0.md create mode 100644 requirements.txt create mode 100755 scripts/check_services.sh create mode 100644 scripts/download_dataset.py create mode 100644 scripts/index_data.py create mode 100755 scripts/run_clip.sh create mode 100755 scripts/run_milvus.sh create mode 100755 scripts/setup_env_centos8.sh create mode 100755 scripts/start.sh create mode 100755 scripts/stop.sh create mode 100644 技术实现报告.md diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..1996389 --- /dev/null +++ b/.env.example @@ -0,0 +1,47 @@ +# 
==================== +# OpenAI Configuration +# ==================== +OPENAI_API_KEY= +OPENAI_MODEL=gpt-4o-mini +OPENAI_EMBEDDING_MODEL=text-embedding-3-small +OPENAI_TEMPERATURE=1 +OPENAI_MAX_TOKENS=1000 + +# ==================== +# CLIP Server Configuration +# ==================== +CLIP_SERVER_URL=grpc://localhost:51000 + +# ==================== +# Milvus Configuration +# ==================== +MILVUS_HOST=localhost +MILVUS_PORT=19530 + +# Collection settings +TEXT_COLLECTION_NAME=text_embeddings +IMAGE_COLLECTION_NAME=image_embeddings +TEXT_DIM=1536 +IMAGE_DIM=512 + +# ==================== +# Search Configuration +# ==================== +TOP_K_RESULTS=30 +SIMILARITY_THRESHOLD=0.6 + +# ==================== +# Application Configuration +# ==================== +APP_HOST=0.0.0.0 +APP_PORT=8000 +DEBUG=true +LOG_LEVEL=INFO + +# ==================== +# Data Paths +# ==================== +RAW_DATA_PATH=./data/raw +PROCESSED_DATA_PATH=./data/processed +IMAGE_DATA_PATH=./data/images + diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9908c6b --- /dev/null +++ b/.gitignore @@ -0,0 +1,83 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# Virtual Environment +venv/ +env/ +ENV/ +.venv + +# Environment Variables +.env +*.env +!.env.example + +# IDEs +.vscode/ +.idea/ +.cursor/ +*.swp +*.swo +*~ +.DS_Store + +# Data Files - ignore everything in data/ except .gitkeep files +data/** +!data/ +!data/raw/ +!data/processed/ +!data/images/ +!data/**/.gitkeep + +# Database +*.db +*.sqlite +*.sqlite3 +data/milvus_lite.db + +# Docker volumes +volumes/ + +# Logs +*.log +logs/ +nohup.out + +# Testing +.pytest_cache/ +.coverage +htmlcov/ +.tox/ + +# Jupyter +.ipynb_checkpoints/ +*.ipynb + +# Model caches +.cache/ +models/ + +# Temporary files +tmp/ +temp/ +*.tmp diff --git a/.python-version 
b/.python-version new file mode 100644 index 0000000..e4fba21 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.12 diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..487d37d --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 zhangruotian + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..c060d0a --- /dev/null +++ b/README.md @@ -0,0 +1,161 @@ +# OmniShopAgent + +An autonomous multi-modal fashion shopping agent powered by **LangGraph** and **ReAct pattern**. + +## Demo + +📄 **[demo.pdf](./demo.pdf)** + +## Overview + +OmniShopAgent autonomously decides which tools to call, maintains conversation state, and determines when to respond. Built with **LangGraph**, it uses agentic patterns for intelligent product discovery. 
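The decision loop this describes — call tools until the model decides to answer — can be sketched framework-free. The following is a minimal illustration only, not the project's code: `fake_llm`, its canned replies, and the stub `search_products` are stand-ins for the real LLM and tools (the project drives this loop with LangGraph):

```python
def fake_llm(messages):
    """Stand-in for the LLM: requests a search on the first turn,
    then answers once a tool result is present in the history."""
    if not any(m["role"] == "tool" for m in messages):
        return {"role": "assistant",
                "tool_calls": [{"name": "search_products",
                                "args": {"query": "red dresses"}}]}
    return {"role": "assistant",
            "content": "Here are some red dresses I found.",
            "tool_calls": []}

def search_products(query):
    """Stub tool; the real one queries Milvus."""
    return f"3 results for '{query}'"

TOOLS = {"search_products": search_products}

def run_agent(user_query):
    messages = [{"role": "user", "content": user_query}]
    while True:
        reply = fake_llm(messages)
        messages.append(reply)
        if not reply.get("tool_calls"):        # no tool calls -> final answer
            return reply["content"]
        for call in reply["tool_calls"]:       # execute each requested tool
            result = TOOLS[call["name"]](**call["args"])
            messages.append({"role": "tool", "content": result})

print(run_agent("show me red dresses"))  # → Here are some red dresses I found.
```

LangGraph replaces the hand-written `while` loop with a graph (agent node, tool node, conditional edge) and adds checkpointed state, but the control flow is the same.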
+ +**Key Features:** +- Autonomous tool selection and execution +- Multi-modal search (text + image) +- Conversational context awareness +- Real-time visual analysis + +## Tech Stack + +| Component | Technology | +|-----------|-----------| +| **Agent Framework** | LangGraph | +| **LLM** | any LLM supported by LangChain | +| **Text Embedding** | text-embedding-3-small | +| **Image Embedding** | CLIP ViT-B/32 | +| **Vector Database** | Milvus | +| **Frontend** | Streamlit | +| **Dataset** | Kaggle Fashion Products | + +## Architecture + +**Agent Flow:** + +```mermaid +graph LR + START --> Agent + Agent -->|Has tool_calls| Tools + Agent -->|No tool_calls| END + Tools --> Agent + + subgraph "Agent Node" + A[Receive Messages] --> B[LLM Reasoning] + B --> C{Need Tools?} + C -->|Yes| D[Generate tool_calls] + C -->|No| E[Generate Response] + end + + subgraph "Tool Node" + F[Execute Tools] --> G[Return ToolMessage] + end +``` + +**Available Tools:** +- `search_products(query)` - Text-based semantic search +- `search_by_image(image_path)` - Visual similarity search +- `analyze_image_style(image_path)` - VLM style analysis + + + +## Examples + +**Text Search:** +``` +User: "winter coats for women" +Agent: search_products("winter coats women") → Returns 5 products +``` + +**Image Upload:** +``` +User: [uploads sneaker photo] "find similar" +Agent: search_by_image(path) → Returns visually similar shoes +``` + +**Style Analysis + Search:** +``` +User: [uploads vintage jacket] "what style is this? find matching pants" +Agent: analyze_image_style(path) → "Vintage denim bomber..." 
+              search_products("vintage pants casual") → Returns matching items
+```
+
+**Multi-turn Context:**
+```
+Turn 1: "show me red dresses"
+Agent: search_products("red dresses") → Results
+
+Turn 2: "make them formal"
+Agent: [remembers context] → search_products("red formal dresses") → Results
+```
+
+**Complex Reasoning:**
+```
+User: [uploads office outfit] "I like the shirt but need something more casual"
+Agent: analyze_image_style(path) → Extracts shirt details
+       search_products("casual shirt [color] [style]") → Returns casual alternatives
+```
+
+## Installation
+
+**Prerequisites:**
+- Python 3.12+ (LangChain 1.x requires Python 3.10+)
+- OpenAI API Key
+- Docker & Docker Compose
+
+### 1. Setup Environment
+```bash
+# Clone and install dependencies
+git clone
+cd OmniShopAgent
+python -m venv venv
+source venv/bin/activate  # Windows: venv\Scripts\activate
+pip install -r requirements.txt
+
+# Configure environment variables
+cp .env.example .env
+# Edit .env and add your OPENAI_API_KEY
+```
+
+### 2. Download Dataset
+Download the [Fashion Product Images Dataset](https://www.kaggle.com/datasets/paramaggarwal/fashion-product-images-dataset) from Kaggle and extract it to `./data/`, or use the helper script:
+
+```bash
+python scripts/download_dataset.py
+```
+
+Expected structure:
+```
+data/
+├── images/      # ~44k product images
+├── styles.csv   # Product metadata
+└── images.csv   # Image filenames
+```
+
+### 3. Start Services
+
+```bash
+docker-compose up
+python -m clip_server
+```
+
+### 4. Index Data
+
+```bash
+python scripts/index_data.py
+```
+
+This generates and stores text/image embeddings for all 44k products in Milvus.
+
+### 5.
Launch Application
+```bash
+# Use the start script (recommended)
+./scripts/start.sh
+
+# Or run directly
+streamlit run app.py
+```
+Opens at `http://localhost:8501`
+
+### CentOS 8 Deployment
+See [docs/DEPLOY_CENTOS8.md](docs/DEPLOY_CENTOS8.md)

diff --git a/app.py b/app.py
new file mode 100644
index 0000000..e9eb877
--- /dev/null
+++ b/app.py
@@ -0,0 +1,732 @@
+"""
+OmniShopAgent - Streamlit UI
+Multi-modal fashion shopping assistant with conversational AI
+"""
+
+import logging
+import re
+import uuid
+from pathlib import Path
+from typing import Optional
+
+import streamlit as st
+from PIL import Image, ImageOps
+
+from app.agents.shopping_agent import ShoppingAgent
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+logger = logging.getLogger(__name__)
+
+# Page config
+st.set_page_config(
+    page_title="OmniShopAgent",
+    page_icon="👗",
+    layout="centered",
+    initial_sidebar_state="collapsed",
+)
+
+# Custom CSS - ChatGPT-like style
+st.markdown(
+    """
+
+    """,
+    unsafe_allow_html=True,
+)
+
+
+# Initialize session state
+def initialize_session():
+    """Initialize session state variables"""
+    if "session_id" not in st.session_state:
+        st.session_state.session_id = str(uuid.uuid4())
+
+    if "shopping_agent" not in st.session_state:
+        st.session_state.shopping_agent = ShoppingAgent(
+            session_id=st.session_state.session_id
+        )
+
+    if "messages" not in st.session_state:
+        st.session_state.messages = []
+
+    if "uploaded_image" not in st.session_state:
+        st.session_state.uploaded_image = None
+
+    if "show_image_upload" not in st.session_state:
+        st.session_state.show_image_upload = False
+
+
+def save_uploaded_image(uploaded_file) -> Optional[str]:
+    """Save uploaded image to temp directory"""
+    if uploaded_file is None:
+        return None
+
+    try:
+        temp_dir = Path("temp_uploads")
+        temp_dir.mkdir(exist_ok=True)
+
+        image_path = temp_dir / f"{st.session_state.session_id}_{uploaded_file.name}"
+        with open(image_path, "wb") as f:
f.write(uploaded_file.getbuffer()) + + logger.info(f"Saved uploaded image to {image_path}") + return str(image_path) + + except Exception as e: + logger.error(f"Error saving uploaded image: {e}") + st.error(f"Failed to save image: {str(e)}") + return None + + +def extract_products_from_response(response: str) -> list: + """Extract product information from agent response + + Returns list of dicts with product info + """ + products = [] + + # Pattern to match product blocks in the response + # Looking for ID, name, and other details + lines = response.split("\n") + current_product = {} + + for line in lines: + line = line.strip() + + # Match product number (e.g., "1. Product Name" or "**1. Product Name**") + if re.match(r"^\*?\*?\d+\.\s+", line): + if current_product: + products.append(current_product) + current_product = {} + # Extract product name + name = re.sub(r"^\*?\*?\d+\.\s+", "", line) + name = name.replace("**", "").strip() + current_product["name"] = name + + # Match ID + elif "ID:" in line or "id:" in line: + id_match = re.search(r"(?:ID|id):\s*(\d+)", line) + if id_match: + current_product["id"] = id_match.group(1) + + # Match Category + elif "Category:" in line: + cat_match = re.search(r"Category:\s*(.+?)(?:\n|$)", line) + if cat_match: + current_product["category"] = cat_match.group(1).strip() + + # Match Color + elif "Color:" in line: + color_match = re.search(r"Color:\s*(\w+)", line) + if color_match: + current_product["color"] = color_match.group(1) + + # Match Gender + elif "Gender:" in line: + gender_match = re.search(r"Gender:\s*(\w+)", line) + if gender_match: + current_product["gender"] = gender_match.group(1) + + # Match Season + elif "Season:" in line: + season_match = re.search(r"Season:\s*(\w+)", line) + if season_match: + current_product["season"] = season_match.group(1) + + # Match Usage + elif "Usage:" in line: + usage_match = re.search(r"Usage:\s*(\w+)", line) + if usage_match: + current_product["usage"] = usage_match.group(1) + + # 
Match Similarity/Relevance score + elif "Similarity:" in line or "Relevance:" in line: + score_match = re.search(r"(?:Similarity|Relevance):\s*([\d.]+)%", line) + if score_match: + current_product["score"] = score_match.group(1) + + # Add last product + if current_product: + products.append(current_product) + + return products + + +def display_product_card(product: dict): + """Display a product card with image and name""" + product_id = product.get("id", "") + name = product.get("name", "Unknown Product") + + # Debug: log what we got + logger.info(f"Displaying product: ID={product_id}, Name={name}") + + # Try to load image from data/images directory + if product_id: + image_path = Path(f"data/images/{product_id}.jpg") + + if image_path.exists(): + try: + img = Image.open(image_path) + # Fixed size for all images + target_size = (200, 200) + try: + # Try new Pillow API + img_processed = ImageOps.fit( + img, target_size, method=Image.Resampling.LANCZOS + ) + except AttributeError: + # Fallback for older Pillow versions + img_processed = ImageOps.fit( + img, target_size, method=Image.LANCZOS + ) + + # Display image with fixed width + st.image(img_processed, use_container_width=False, width=200) + st.markdown(f"**{name}**") + st.caption(f"ID: {product_id}") + return + except Exception as e: + logger.warning(f"Failed to load image {image_path}: {e}") + else: + logger.warning(f"Image not found: {image_path}") + + # Fallback: no image + st.markdown(f"**📷 {name}**") + if product_id: + st.caption(f"ID: {product_id}") + else: + st.caption("ID not available") + + +def display_message(message: dict): + """Display a chat message""" + role = message["role"] + content = message["content"] + image_path = message.get("image_path") + tool_calls = message.get("tool_calls", []) + + if role == "user": + st.markdown('
', unsafe_allow_html=True) + + if image_path and Path(image_path).exists(): + try: + img = Image.open(image_path) + st.image(img, width=200) + except Exception: + logger.warning(f"Failed to load user uploaded image: {image_path}") + + st.markdown(content) + st.markdown("
", unsafe_allow_html=True) + + else: # assistant + # Display tool calls horizontally - only tool names + if tool_calls: + tool_names = [tc['name'] for tc in tool_calls] + st.caption(" → ".join(tool_names)) + st.markdown("") + + # Extract and display products if any + products = extract_products_from_response(content) + + # Debug logging + logger.info(f"Extracted {len(products)} products from response") + for p in products: + logger.info(f"Product: {p}") + + if products: + def parse_score(product: dict) -> float: + score = product.get("score") + if score is None: + return 0.0 + try: + return float(score) + except (TypeError, ValueError): + return 0.0 + + # Sort by score and limit to 3 + products = sorted(products, key=parse_score, reverse=True)[:3] + + logger.info(f"Displaying top {len(products)} products") + + # Display the text response first (without product details) + text_lines = [] + for line in content.split("\n"): + # Skip product detail lines + if not any( + keyword in line + for keyword in [ + "ID:", + "Category:", + "Color:", + "Gender:", + "Season:", + "Usage:", + "Similarity:", + "Relevance:", + ] + ): + if not re.match(r"^\*?\*?\d+\.\s+", line): + text_lines.append(line) + + intro_text = "\n".join(text_lines).strip() + if intro_text: + st.markdown(intro_text) + + # Display product cards in grid + st.markdown("
", unsafe_allow_html=True) + + # Create exactly 3 columns with equal width + cols = st.columns(3) + for j, product in enumerate(products[:3]): # Ensure max 3 + with cols[j]: + display_product_card(product) + else: + # No products found, display full content + st.markdown(content) + + st.markdown("", unsafe_allow_html=True) + + +def display_welcome(): + """Display welcome screen""" + + col1, col2, col3, col4 = st.columns(4) + + with col1: + st.markdown( + """ +
+            <div class="feature-card">
+                <div class="feature-icon">💬</div>
+                <div class="feature-title">Text Search</div>
+                <div class="feature-desc">Describe what you want</div>
+            </div>
+ """, + unsafe_allow_html=True, + ) + + with col2: + st.markdown( + """ +
+            <div class="feature-card">
+                <div class="feature-icon">📸</div>
+                <div class="feature-title">Image Search</div>
+                <div class="feature-desc">Upload product photos</div>
+            </div>
+ """, + unsafe_allow_html=True, + ) + + with col3: + st.markdown( + """ +
+            <div class="feature-card">
+                <div class="feature-icon">🔍</div>
+                <div class="feature-title">Visual Analysis</div>
+                <div class="feature-desc">AI analyzes product style</div>
+            </div>
+ """, + unsafe_allow_html=True, + ) + + with col4: + st.markdown( + """ +
+            <div class="feature-card">
+                <div class="feature-icon">💭</div>
+                <div class="feature-title">Conversational</div>
+                <div class="feature-desc">Remembers context</div>
+            </div>
+ """, + unsafe_allow_html=True, + ) + + st.markdown("

", unsafe_allow_html=True) + + +def main(): + """Main Streamlit app""" + initialize_session() + + # Header + st.markdown( + """ +
+        <div class="main-header">
+            <div class="main-title">👗 OmniShopAgent</div>
+            <div class="main-subtitle">AI Fashion Shopping Assistant</div>
+        </div>
+ """, + unsafe_allow_html=True, + ) + + # Sidebar (collapsed by default, but accessible) + with st.sidebar: + st.markdown("### ⚙️ Settings") + + if st.button("🗑️ Clear Chat", use_container_width=True): + if "shopping_agent" in st.session_state: + st.session_state.shopping_agent.clear_history() + st.session_state.messages = [] + st.session_state.uploaded_image = None + st.rerun() + + st.markdown("---") + st.caption(f"Session: `{st.session_state.session_id[:8]}...`") + + # Chat messages container + messages_container = st.container() + + with messages_container: + if not st.session_state.messages: + display_welcome() + else: + for message in st.session_state.messages: + display_message(message) + + # Fixed input area at bottom (using container to simulate fixed position) + st.markdown('
', unsafe_allow_html=True) + + input_container = st.container() + + with input_container: + # Image upload area (shown when + is clicked) + if st.session_state.show_image_upload: + uploaded_file = st.file_uploader( + "Choose an image", + type=["jpg", "jpeg", "png"], + key="file_uploader", + ) + + if uploaded_file: + st.session_state.uploaded_image = uploaded_file + # Show preview + col1, col2 = st.columns([1, 4]) + with col1: + img = Image.open(uploaded_file) + st.image(img, width=100) + with col2: + if st.button("❌ Remove"): + st.session_state.uploaded_image = None + st.session_state.show_image_upload = False + st.rerun() + + # Input row + col1, col2 = st.columns([1, 12]) + + with col1: + # Image upload toggle button + if st.button("➕", help="Add image", use_container_width=True): + st.session_state.show_image_upload = ( + not st.session_state.show_image_upload + ) + st.rerun() + + with col2: + # Text input + user_query = st.chat_input( + "Ask about fashion products...", + key="chat_input", + ) + + st.markdown("
", unsafe_allow_html=True) + + # Process user input + if user_query: + # Ensure shopping agent is initialized + if "shopping_agent" not in st.session_state: + st.error("Session not initialized. Please refresh the page.") + st.stop() + + # Save uploaded image if present, or get from recent history + image_path = None + if st.session_state.uploaded_image: + # User explicitly uploaded an image for this query + image_path = save_uploaded_image(st.session_state.uploaded_image) + else: + # Check if query refers to a previous image + query_lower = user_query.lower() + if any( + ref in query_lower + for ref in [ + "this", + "that", + "the image", + "the shirt", + "the product", + "it", + ] + ): + # Find the most recent message with an image + for msg in reversed(st.session_state.messages): + if msg.get("role") == "user" and msg.get("image_path"): + image_path = msg["image_path"] + logger.info(f"Using image from previous message: {image_path}") + break + + # Add user message + st.session_state.messages.append( + { + "role": "user", + "content": user_query, + "image_path": image_path, + } + ) + + # Display user message immediately + with messages_container: + display_message(st.session_state.messages[-1]) + + # Process with shopping agent + try: + shopping_agent = st.session_state.shopping_agent + + # Handle greetings + query_lower = user_query.lower().strip() + if query_lower in ["hi", "hello", "hey"]: + response = """Hello! 👋 I'm your fashion shopping assistant. 
+ +I can help you: +- Search for products by description +- Find items similar to images you upload +- Analyze product styles + +What are you looking for today?""" + tool_calls = [] + else: + # Process with agent + result = shopping_agent.chat( + query=user_query, + image_path=image_path, + ) + response = result["response"] + tool_calls = result.get("tool_calls", []) + + # Add assistant message + st.session_state.messages.append( + { + "role": "assistant", + "content": response, + "tool_calls": tool_calls, + } + ) + + # Clear uploaded image and hide upload area after sending + st.session_state.uploaded_image = None + st.session_state.show_image_upload = False + + # Auto-scroll to bottom with JavaScript + st.markdown( + """ + + """, + unsafe_allow_html=True, + ) + + except Exception as e: + logger.error(f"Error processing query: {e}", exc_info=True) + error_msg = f"I apologize, I encountered an error: {str(e)}" + + st.session_state.messages.append( + { + "role": "assistant", + "content": error_msg, + } + ) + + # Rerun to update UI + st.rerun() + + +if __name__ == "__main__": + main() diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000..88cb718 --- /dev/null +++ b/app/__init__.py @@ -0,0 +1,5 @@ +""" +OmniShopAgent - Multi-modal E-commerce Search Agent +""" + +__version__ = "0.1.0" diff --git a/app/agents/__init__.py b/app/agents/__init__.py new file mode 100644 index 0000000..c504615 --- /dev/null +++ b/app/agents/__init__.py @@ -0,0 +1,10 @@ +""" +Agent Layer - Autonomous Shopping Agent +""" + +from app.agents.shopping_agent import ShoppingAgent, create_shopping_agent + +__all__ = [ + "ShoppingAgent", + "create_shopping_agent", +] diff --git a/app/agents/shopping_agent.py b/app/agents/shopping_agent.py new file mode 100644 index 0000000..2bb1533 --- /dev/null +++ b/app/agents/shopping_agent.py @@ -0,0 +1,272 @@ +""" +Conversational Shopping Agent with LangGraph +True ReAct agent with autonomous tool calling and message accumulation 
+""" + +import logging +from pathlib import Path +from typing import Optional, Sequence + +from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage +from langchain_openai import ChatOpenAI +from langgraph.checkpoint.memory import MemorySaver +from langgraph.graph import END, START, StateGraph +from langgraph.graph.message import add_messages +from langgraph.prebuilt import ToolNode +from typing_extensions import Annotated, TypedDict + +from app.config import settings +from app.tools.search_tools import get_all_tools + +logger = logging.getLogger(__name__) + + +def _extract_message_text(msg) -> str: + """Extract text from message content. + LangChain 1.0: content may be str or content_blocks (list) for multimodal.""" + content = getattr(msg, "content", "") + if isinstance(content, str): + return content + if isinstance(content, list): + parts = [] + for block in content: + if isinstance(block, dict): + parts.append(block.get("text", block.get("content", ""))) + else: + parts.append(str(block)) + return "".join(str(p) for p in parts) + return str(content) if content else "" + + +class AgentState(TypedDict): + """State for the shopping agent with message accumulation""" + + messages: Annotated[Sequence[BaseMessage], add_messages] + current_image_path: Optional[str] # Track uploaded image + + +class ShoppingAgent: + """True ReAct agent with autonomous decision making""" + + def __init__(self, session_id: Optional[str] = None): + self.session_id = session_id or "default" + + # Initialize LLM + self.llm = ChatOpenAI( + model=settings.openai_model, + temperature=settings.openai_temperature, + api_key=settings.openai_api_key, + ) + + # Get tools and bind to model + self.tools = get_all_tools() + self.llm_with_tools = self.llm.bind_tools(self.tools) + + # Build graph + self.graph = self._build_graph() + + logger.info(f"Shopping agent initialized for session: {self.session_id}") + + def _build_graph(self): + """Build the LangGraph StateGraph""" + + # 
System prompt for the agent + system_prompt = """You are an intelligent fashion shopping assistant. You can: +1. Search for products by text description (use search_products) +2. Find visually similar products from images (use search_by_image) +3. Analyze image style and attributes (use analyze_image_style) + +When a user asks about products: +- For text queries: use search_products directly +- For image uploads: decide if you need to analyze_image_style first, then search +- You can call multiple tools in sequence if needed +- Always provide helpful, friendly responses + +CRITICAL FORMATTING RULES: +When presenting product results, you MUST use this EXACT format for EACH product: + +1. [Product Name] + ID: [Product ID Number] + Category: [Category] + Color: [Color] + Gender: [Gender] + (Include Season, Usage, Relevance if available) + +Example: +1. Puma Men White 3/4 Length Pants + ID: 12345 + Category: Apparel > Bottomwear > Track Pants + Color: White + Gender: Men + Season: Summer + Usage: Sports + Relevance: 95.2% + +DO NOT skip the ID field! It is essential for displaying product images. 
+Be conversational in your introduction, but preserve the exact product format.""" + + def agent_node(state: AgentState): + """Agent decision node - decides which tools to call or when to respond""" + messages = state["messages"] + + # Add system prompt if first message + if not any(isinstance(m, SystemMessage) for m in messages): + messages = [SystemMessage(content=system_prompt)] + list(messages) + + # Handle image context + if state.get("current_image_path"): + # Inject image path context for tool calls + # The agent can reference this in its reasoning + pass + + # Invoke LLM with tools + response = self.llm_with_tools.invoke(messages) + return {"messages": [response]} + + # Create tool node + tool_node = ToolNode(self.tools) + + def should_continue(state: AgentState): + """Determine if agent should continue or end""" + messages = state["messages"] + last_message = messages[-1] + + # If LLM made tool calls, continue to tools + if hasattr(last_message, "tool_calls") and last_message.tool_calls: + return "tools" + # Otherwise, end (agent has final response) + return END + + # Build graph + workflow = StateGraph(AgentState) + + workflow.add_node("agent", agent_node) + workflow.add_node("tools", tool_node) + + workflow.add_edge(START, "agent") + workflow.add_conditional_edges("agent", should_continue, ["tools", END]) + workflow.add_edge("tools", "agent") + + # Compile with memory + checkpointer = MemorySaver() + return workflow.compile(checkpointer=checkpointer) + + def chat(self, query: str, image_path: Optional[str] = None) -> dict: + """Process user query with the agent + + Args: + query: User's text query + image_path: Optional path to uploaded image + + Returns: + Dict with response and metadata + """ + try: + logger.info( + f"[{self.session_id}] Processing: '{query}' (image={'Yes' if image_path else 'No'})" + ) + + # Validate image + if image_path and not Path(image_path).exists(): + return { + "response": f"Error: Image file not found at '{image_path}'", + 
"error": True, + } + + # Build input message + message_content = query + if image_path: + message_content = f"{query}\n[User uploaded image: {image_path}]" + + # Invoke agent + config = {"configurable": {"thread_id": self.session_id}} + input_state = { + "messages": [HumanMessage(content=message_content)], + "current_image_path": image_path, + } + + # Track tool calls + tool_calls = [] + + # Stream events to capture tool calls + for event in self.graph.stream(input_state, config=config): + logger.info(f"Event: {event}") + + # Check for agent node (tool calls) + if "agent" in event: + agent_output = event["agent"] + if "messages" in agent_output: + for msg in agent_output["messages"]: + if hasattr(msg, "tool_calls") and msg.tool_calls: + for tc in msg.tool_calls: + tool_calls.append({ + "name": tc["name"], + "args": tc.get("args", {}), + }) + + # Check for tool node (tool results) + if "tools" in event: + tools_output = event["tools"] + if "messages" in tools_output: + for i, msg in enumerate(tools_output["messages"]): + if i < len(tool_calls): + tool_calls[i]["result"] = str(msg.content)[:200] + "..." 
+ + # Get final state + final_state = self.graph.get_state(config) + final_message = final_state.values["messages"][-1] + response_text = _extract_message_text(final_message) + + logger.info(f"[{self.session_id}] Response generated with {len(tool_calls)} tool calls") + + return { + "response": response_text, + "tool_calls": tool_calls, + "error": False, + } + + except Exception as e: + logger.error(f"Error in agent chat: {e}", exc_info=True) + return { + "response": f"I apologize, I encountered an error: {str(e)}", + "error": True, + } + + def get_conversation_history(self) -> list: + """Get conversation history for this session""" + try: + config = {"configurable": {"thread_id": self.session_id}} + state = self.graph.get_state(config) + + if not state or not state.values.get("messages"): + return [] + + messages = state.values["messages"] + result = [] + + for msg in messages: + # Skip system messages and tool messages + if isinstance(msg, SystemMessage): + continue + if hasattr(msg, "type") and msg.type in ["system", "tool"]: + continue + + role = "user" if msg.type == "human" else "assistant" + result.append({"role": role, "content": _extract_message_text(msg)}) + + return result + + except Exception as e: + logger.error(f"Error getting history: {e}") + return [] + + def clear_history(self): + """Clear conversation history for this session""" + # With MemorySaver, we can't easily clear, but we can log + logger.info(f"[{self.session_id}] History clear requested") + # In production, implement proper clearing or use new thread_id + + +def create_shopping_agent(session_id: Optional[str] = None) -> ShoppingAgent: + """Factory function to create a shopping agent""" + return ShoppingAgent(session_id=session_id) diff --git a/app/config.py b/app/config.py new file mode 100644 index 0000000..618552e --- /dev/null +++ b/app/config.py @@ -0,0 +1,86 @@ +""" +Configuration management for OmniShopAgent +Loads environment variables and provides configuration objects +""" + 
+import os + +from pydantic_settings import BaseSettings + + +class Settings(BaseSettings): + """Application settings loaded from environment variables + + All settings can be configured via .env file or environment variables. + """ + + # OpenAI Configuration + openai_api_key: str + openai_model: str = "gpt-4o-mini" + openai_embedding_model: str = "text-embedding-3-small" + openai_temperature: float = 0.7 + openai_max_tokens: int = 1000 + + # CLIP Server Configuration + clip_server_url: str = "grpc://localhost:51000" + + # Milvus Configuration + milvus_uri: str = "http://localhost:19530" + milvus_host: str = "localhost" + milvus_port: int = 19530 + text_collection_name: str = "text_embeddings" + image_collection_name: str = "image_embeddings" + text_dim: int = 1536 + image_dim: int = 512 + + @property + def milvus_uri_absolute(self) -> str: + """Get absolute path for Milvus URI + + Returns: + - For http/https URIs: returns as-is (Milvus Standalone) + - For file paths starting with ./: converts to absolute path (Milvus Lite) + - For other paths: returns as-is + """ + import os + + # If it's a network URI, return as-is (Milvus Standalone) + if self.milvus_uri.startswith(("http://", "https://")): + return self.milvus_uri + # If it's a relative path, convert to absolute (Milvus Lite) + if self.milvus_uri.startswith("./"): + base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + return os.path.join(base_dir, self.milvus_uri[2:]) + # Otherwise return as-is + return self.milvus_uri + + # Search Configuration + top_k_results: int = 10 + similarity_threshold: float = 0.6 + + # Application Configuration + app_host: str = "0.0.0.0" + app_port: int = 8000 + debug: bool = True + log_level: str = "INFO" + + # Data Paths + raw_data_path: str = "./data/raw" + processed_data_path: str = "./data/processed" + image_data_path: str = "./data/images" + + class Config: + env_file = ".env" + env_file_encoding = "utf-8" + case_sensitive = False + + +# Global settings 
instance +settings = Settings() + + +# Helper function to get absolute paths +def get_absolute_path(relative_path: str) -> str: + """Convert relative path to absolute path""" + base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + return os.path.join(base_dir, relative_path) diff --git a/app/services/__init__.py b/app/services/__init__.py new file mode 100644 index 0000000..ea964c3 --- /dev/null +++ b/app/services/__init__.py @@ -0,0 +1,14 @@ +""" +Services Module +Provides database and embedding services for the application +""" + +from app.services.embedding_service import EmbeddingService, get_embedding_service +from app.services.milvus_service import MilvusService, get_milvus_service + +__all__ = [ + "EmbeddingService", + "get_embedding_service", + "MilvusService", + "get_milvus_service", +] diff --git a/app/services/embedding_service.py b/app/services/embedding_service.py new file mode 100644 index 0000000..d26edf6 --- /dev/null +++ b/app/services/embedding_service.py @@ -0,0 +1,293 @@ +""" +Embedding Service for Text and Image Embeddings +Supports OpenAI text embeddings and CLIP image embeddings +""" + +import logging +from pathlib import Path +from typing import List, Optional, Union + +import numpy as np +from clip_client import Client as ClipClient +from openai import OpenAI + +from app.config import settings + +logger = logging.getLogger(__name__) + + +class EmbeddingService: + """Service for generating text and image embeddings""" + + def __init__( + self, + openai_api_key: Optional[str] = None, + clip_server_url: Optional[str] = None, + ): + """Initialize embedding service + + Args: + openai_api_key: OpenAI API key. If None, uses settings.openai_api_key + clip_server_url: CLIP server URL. 
If None, uses settings.clip_server_url + """ + # Initialize OpenAI client for text embeddings + self.openai_api_key = openai_api_key or settings.openai_api_key + self.openai_client = OpenAI(api_key=self.openai_api_key) + self.text_embedding_model = settings.openai_embedding_model + + # Initialize CLIP client for image embeddings + self.clip_server_url = clip_server_url or settings.clip_server_url + self.clip_client: Optional[ClipClient] = None + + logger.info("Embedding service initialized") + + def connect_clip(self) -> None: + """Connect to CLIP server""" + try: + self.clip_client = ClipClient(server=self.clip_server_url) + logger.info(f"Connected to CLIP server at {self.clip_server_url}") + except Exception as e: + logger.error(f"Failed to connect to CLIP server: {e}") + raise + + def disconnect_clip(self) -> None: + """Disconnect from CLIP server""" + if self.clip_client: + # Note: clip_client doesn't have explicit close method + self.clip_client = None + logger.info("Disconnected from CLIP server") + + def get_text_embedding(self, text: str) -> List[float]: + """Get embedding for a single text + + Args: + text: Input text + + Returns: + Embedding vector as list of floats + """ + try: + response = self.openai_client.embeddings.create( + input=text, model=self.text_embedding_model + ) + embedding = response.data[0].embedding + logger.debug(f"Generated text embedding for: {text[:50]}...") + return embedding + except Exception as e: + logger.error(f"Failed to generate text embedding: {e}") + raise + + def get_text_embeddings_batch( + self, texts: List[str], batch_size: int = 100 + ) -> List[List[float]]: + """Get embeddings for multiple texts in batches + + Args: + texts: List of input texts + batch_size: Number of texts to process at once + + Returns: + List of embedding vectors + """ + all_embeddings = [] + + for i in range(0, len(texts), batch_size): + batch = texts[i : i + batch_size] + + try: + response = self.openai_client.embeddings.create( + input=batch, 
model=self.text_embedding_model + ) + + # Extract embeddings in the correct order + embeddings = [item.embedding for item in response.data] + all_embeddings.extend(embeddings) + + logger.info( + f"Generated text embeddings for batch {i // batch_size + 1}: {len(embeddings)} embeddings" + ) + + except Exception as e: + logger.error( + f"Failed to generate text embeddings for batch {i // batch_size + 1}: {e}" + ) + raise + + return all_embeddings + + def get_image_embedding(self, image_path: Union[str, Path]) -> List[float]: + """Get CLIP embedding for a single image + + Args: + image_path: Path to image file + + Returns: + Embedding vector as list of floats + """ + if not self.clip_client: + raise RuntimeError("CLIP client not connected. Call connect_clip() first.") + + image_path = Path(image_path) + if not image_path.exists(): + raise FileNotFoundError(f"Image not found: {image_path}") + + try: + # Get embedding from CLIP server using image path (as string) + result = self.clip_client.encode([str(image_path)]) + + # Extract embedding - result is numpy array + import numpy as np + + if isinstance(result, np.ndarray): + # If result is numpy array, use first element + embedding = ( + result[0].tolist() if len(result.shape) > 1 else result.tolist() + ) + else: + # If result is DocumentArray + embedding = result[0].embedding.tolist() + + logger.debug(f"Generated image embedding for: {image_path.name}") + return embedding + + except Exception as e: + logger.error(f"Failed to generate image embedding for {image_path}: {e}") + raise + + def get_image_embeddings_batch( + self, image_paths: List[Union[str, Path]], batch_size: int = 32 + ) -> List[Optional[List[float]]]: + """Get CLIP embeddings for multiple images in batches + + Args: + image_paths: List of paths to image files + batch_size: Number of images to process at once + + Returns: + List of embedding vectors (None for failed images) + """ + if not self.clip_client: + raise RuntimeError("CLIP client not connected. 
Call connect_clip() first.") + + all_embeddings = [] + + for i in range(0, len(image_paths), batch_size): + batch_paths = image_paths[i : i + batch_size] + valid_paths = [] + valid_indices = [] + + # Check which images exist + for idx, path in enumerate(batch_paths): + path = Path(path) + if path.exists(): + valid_paths.append(str(path)) + valid_indices.append(idx) + else: + logger.warning(f"Image not found: {path}") + + # Get embeddings for valid images + if valid_paths: + try: + # Send paths as strings to CLIP server + result = self.clip_client.encode(valid_paths) + + # Create embeddings list with None for missing images + batch_embeddings = [None] * len(batch_paths) + + # Handle result format - could be numpy array or DocumentArray + import numpy as np + + if isinstance(result, np.ndarray): + # Result is numpy array - shape (n_images, embedding_dim) + for idx in range(len(result)): + original_idx = valid_indices[idx] + batch_embeddings[original_idx] = result[idx].tolist() + else: + # Result is DocumentArray + for idx, doc in enumerate(result): + original_idx = valid_indices[idx] + batch_embeddings[original_idx] = doc.embedding.tolist() + + all_embeddings.extend(batch_embeddings) + + logger.info( + f"Generated image embeddings for batch {i // batch_size + 1}: " + f"{len(valid_paths)}/{len(batch_paths)} successful" + ) + + except Exception as e: + logger.error( + f"Failed to generate image embeddings for batch {i // batch_size + 1}: {e}" + ) + # Add None for all images in failed batch + all_embeddings.extend([None] * len(batch_paths)) + else: + # All images in batch failed to load + all_embeddings.extend([None] * len(batch_paths)) + + return all_embeddings + + def get_text_embedding_from_image( + self, image_path: Union[str, Path] + ) -> List[float]: + """Get text-based embedding by describing the image + This is useful for cross-modal search + + Note: This is a placeholder for future implementation + that could use vision models to generate text descriptions + + 
Args:
+            image_path: Path to image file
+
+        Returns:
+            Text embedding vector
+        """
+        # Not implemented yet. A future version could use a vision-language
+        # model to generate a text description of the image and then embed
+        # that description for cross-modal search.
+        raise NotImplementedError("Text embedding from image not yet implemented")
+
+    def cosine_similarity(
+        self, embedding1: List[float], embedding2: List[float]
+    ) -> float:
+        """Calculate cosine similarity between two embeddings
+
+        Args:
+            embedding1: First embedding vector
+            embedding2: Second embedding vector
+
+        Returns:
+            Cosine similarity score (-1 to 1; 1 means identical direction)
+        """
+        vec1 = np.array(embedding1)
+        vec2 = np.array(embedding2)
+
+        # Normalize vectors
+        vec1_norm = vec1 / np.linalg.norm(vec1)
+        vec2_norm = vec2 / np.linalg.norm(vec2)
+
+        # Calculate cosine similarity
+        similarity = np.dot(vec1_norm, vec2_norm)
+
+        return float(similarity)
+
+    def get_embedding_dimensions(self) -> dict:
+        """Get the dimensions of text and image embeddings
+
+        Returns:
+            Dictionary with text_dim and image_dim
+        """
+        return {"text_dim": settings.text_dim, "image_dim": settings.image_dim}
+
+
+# Global instance
+_embedding_service: Optional[EmbeddingService] = None
+
+
+def get_embedding_service() -> EmbeddingService:
+    """Get or create the global embedding service instance"""
+    global _embedding_service
+    if _embedding_service is None:
+        _embedding_service = EmbeddingService()
+        _embedding_service.connect_clip()
+    return _embedding_service
diff --git a/app/services/milvus_service.py b/app/services/milvus_service.py
new file mode 100644
index 0000000..bdce812
--- /dev/null
+++ b/app/services/milvus_service.py
@@ -0,0 +1,480 @@
+"""
+Milvus Service for Vector Storage and Similarity Search
+Manages text and image embeddings in separate collections
+"""
+
+import logging
+from typing import Any, Dict, List, Optional
+
+from pymilvus import (
+    DataType,
+    MilvusClient,
+)
+
+from app.config import settings
+
+logger = logging.getLogger(__name__)
+
+
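The `cosine_similarity` helper above is just normalize-then-dot. A minimal standalone restatement of the same computation, using plain numpy and independent of the service class or any external server:

```python
import numpy as np


def cosine_similarity(a, b):
    # cos(theta) = (a . b) / (|a| * |b|) -- the same normalize-then-dot
    # computation as EmbeddingService.cosine_similarity
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


print(cosine_similarity([1.0, 0.0], [1.0, 0.0]))   # parallel vectors -> 1.0
print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))   # orthogonal vectors -> 0.0
print(cosine_similarity([1.0, 0.0], [-1.0, 0.0]))  # opposite vectors -> -1.0
```

Note the full range is -1 to 1; embeddings from models like `text-embedding-3-small` rarely produce negative similarities in practice, but the helper does not clamp.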
+class MilvusService: + """Service for managing vector embeddings in Milvus""" + + def __init__(self, uri: Optional[str] = None): + """Initialize Milvus service + + Args: + uri: Milvus connection URI. If None, uses settings.milvus_uri + """ + if uri: + self.uri = uri + else: + # Use absolute path for Milvus Lite + self.uri = settings.milvus_uri_absolute + self.text_collection_name = settings.text_collection_name + self.image_collection_name = settings.image_collection_name + self.text_dim = settings.text_dim + self.image_dim = settings.image_dim + + # Use MilvusClient for simplified operations + self._client: Optional[MilvusClient] = None + + logger.info(f"Initializing Milvus service with URI: {self.uri}") + + def is_connected(self) -> bool: + """Check if connected to Milvus""" + return self._client is not None + + def connect(self) -> None: + """Connect to Milvus""" + if self.is_connected(): + return + try: + self._client = MilvusClient(uri=self.uri) + logger.info(f"Connected to Milvus at {self.uri}") + except Exception as e: + logger.error(f"Failed to connect to Milvus: {e}") + raise + + def disconnect(self) -> None: + """Disconnect from Milvus""" + if self._client: + self._client.close() + self._client = None + logger.info("Disconnected from Milvus") + + @property + def client(self) -> MilvusClient: + """Get the Milvus client""" + if not self._client: + raise RuntimeError("Milvus not connected. 
Call connect() first.") + return self._client + + def create_text_collection(self, recreate: bool = False) -> None: + """Create collection for text embeddings with product metadata + + Args: + recreate: If True, drop existing collection and recreate + """ + if recreate and self.client.has_collection(self.text_collection_name): + self.client.drop_collection(self.text_collection_name) + logger.info(f"Dropped existing collection: {self.text_collection_name}") + + if self.client.has_collection(self.text_collection_name): + logger.info(f"Text collection already exists: {self.text_collection_name}") + return + + # Create collection with schema (includes metadata fields) + schema = MilvusClient.create_schema( + auto_id=False, + enable_dynamic_field=True, # Allow additional metadata fields + ) + + # Core fields + schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True) + schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=2000) + schema.add_field( + field_name="embedding", datatype=DataType.FLOAT_VECTOR, dim=self.text_dim + ) + + # Product metadata fields + schema.add_field( + field_name="productDisplayName", datatype=DataType.VARCHAR, max_length=500 + ) + schema.add_field(field_name="gender", datatype=DataType.VARCHAR, max_length=50) + schema.add_field( + field_name="masterCategory", datatype=DataType.VARCHAR, max_length=100 + ) + schema.add_field( + field_name="subCategory", datatype=DataType.VARCHAR, max_length=100 + ) + schema.add_field( + field_name="articleType", datatype=DataType.VARCHAR, max_length=100 + ) + schema.add_field( + field_name="baseColour", datatype=DataType.VARCHAR, max_length=50 + ) + schema.add_field(field_name="season", datatype=DataType.VARCHAR, max_length=50) + schema.add_field(field_name="usage", datatype=DataType.VARCHAR, max_length=50) + + # Create index parameters + index_params = self.client.prepare_index_params() + index_params.add_index( + field_name="embedding", + index_type="AUTOINDEX", + 
metric_type="COSINE", + ) + + # Create collection + self.client.create_collection( + collection_name=self.text_collection_name, + schema=schema, + index_params=index_params, + ) + + logger.info( + f"Created text collection with metadata: {self.text_collection_name}" + ) + + def create_image_collection(self, recreate: bool = False) -> None: + """Create collection for image embeddings with product metadata + + Args: + recreate: If True, drop existing collection and recreate + """ + if recreate and self.client.has_collection(self.image_collection_name): + self.client.drop_collection(self.image_collection_name) + logger.info(f"Dropped existing collection: {self.image_collection_name}") + + if self.client.has_collection(self.image_collection_name): + logger.info( + f"Image collection already exists: {self.image_collection_name}" + ) + return + + # Create collection with schema (includes metadata fields) + schema = MilvusClient.create_schema( + auto_id=False, + enable_dynamic_field=True, # Allow additional metadata fields + ) + + # Core fields + schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True) + schema.add_field( + field_name="image_path", datatype=DataType.VARCHAR, max_length=500 + ) + schema.add_field( + field_name="embedding", datatype=DataType.FLOAT_VECTOR, dim=self.image_dim + ) + + # Product metadata fields + schema.add_field( + field_name="productDisplayName", datatype=DataType.VARCHAR, max_length=500 + ) + schema.add_field(field_name="gender", datatype=DataType.VARCHAR, max_length=50) + schema.add_field( + field_name="masterCategory", datatype=DataType.VARCHAR, max_length=100 + ) + schema.add_field( + field_name="subCategory", datatype=DataType.VARCHAR, max_length=100 + ) + schema.add_field( + field_name="articleType", datatype=DataType.VARCHAR, max_length=100 + ) + schema.add_field( + field_name="baseColour", datatype=DataType.VARCHAR, max_length=50 + ) + schema.add_field(field_name="season", datatype=DataType.VARCHAR, max_length=50) + 
schema.add_field(field_name="usage", datatype=DataType.VARCHAR, max_length=50) + + # Create index parameters + index_params = self.client.prepare_index_params() + index_params.add_index( + field_name="embedding", + index_type="AUTOINDEX", + metric_type="COSINE", + ) + + # Create collection + self.client.create_collection( + collection_name=self.image_collection_name, + schema=schema, + index_params=index_params, + ) + + logger.info( + f"Created image collection with metadata: {self.image_collection_name}" + ) + + def insert_text_embeddings( + self, + embeddings: List[Dict[str, Any]], + ) -> int: + """Insert text embeddings with metadata into collection + + Args: + embeddings: List of dictionaries with keys: + - id: unique ID (product ID) + - text: the text that was embedded + - embedding: the embedding vector + - productDisplayName, gender, masterCategory, etc. (metadata) + + Returns: + Number of inserted embeddings + """ + if not embeddings: + return 0 + + try: + # Insert data directly (all fields including metadata) + # Milvus will accept all fields defined in schema + dynamic fields + data = embeddings + + # Insert data + result = self.client.insert( + collection_name=self.text_collection_name, + data=data, + ) + + logger.info(f"Inserted {len(data)} text embeddings") + return len(data) + + except Exception as e: + logger.error(f"Failed to insert text embeddings: {e}") + raise + + def insert_image_embeddings( + self, + embeddings: List[Dict[str, Any]], + ) -> int: + """Insert image embeddings with metadata into collection + + Args: + embeddings: List of dictionaries with keys: + - id: unique ID (product ID) + - image_path: path to the image file + - embedding: the embedding vector + - productDisplayName, gender, masterCategory, etc. 
(metadata)
+
+        Returns:
+            Number of inserted embeddings
+        """
+        if not embeddings:
+            return 0
+
+        try:
+            # Insert data directly; Milvus accepts all fields defined in the
+            # schema plus dynamic (extra) fields
+            data = embeddings
+
+            result = self.client.insert(
+                collection_name=self.image_collection_name,
+                data=data,
+            )
+
+            logger.info(f"Inserted {len(data)} image embeddings")
+            return len(data)
+
+        except Exception as e:
+            logger.error(f"Failed to insert image embeddings: {e}")
+            raise
+
+    def search_similar_text(
+        self,
+        query_embedding: List[float],
+        limit: int = 10,
+        filters: Optional[str] = None,
+        output_fields: Optional[List[str]] = None,
+    ) -> List[Dict[str, Any]]:
+        """Search for similar text embeddings
+
+        Args:
+            query_embedding: Query embedding vector
+            limit: Maximum number of results
+            filters: Filter expression (e.g., 'gender == "Men"')
+            output_fields: List of fields to return
+
+        Returns:
+            List of result dicts, each with:
+            - id: embedding ID
+            - distance: similarity distance
+            - the requested output fields, flattened into the dict
+        """
+        try:
+            if output_fields is None:
+                output_fields = [
+                    "id",
+                    "text",
+                    "productDisplayName",
+                    "gender",
+                    "masterCategory",
+                    "subCategory",
+                    "articleType",
+                    "baseColour",
+                ]
+
+            # MilvusClient.search takes the boolean filter expression via the
+            # `filter` parameter, not inside search_params
+            results = self.client.search(
+                collection_name=self.text_collection_name,
+                data=[query_embedding],
+                filter=filters or "",
+                limit=limit,
+                output_fields=output_fields,
+            )
+
+            # Flatten each hit into a single dict
+            formatted_results = []
+            if results and len(results) > 0:
+                for hit in results[0]:
+                    result = {"id": hit.get("id"), "distance": hit.get("distance")}
+                    entity = hit.get("entity", {})
+                    for field in output_fields:
+                        if field in entity:
+                            result[field] = entity.get(field)
+                    formatted_results.append(result)
+
+            logger.debug(f"Found {len(formatted_results)} similar text embeddings")
+            return formatted_results
+
+        except Exception as e:
+            logger.error(f"Failed to search similar text: {e}")
+            raise
+
+    def search_similar_images(
+        self,
+        query_embedding: List[float],
+        limit: int = 10,
+        filters: Optional[str] = None,
+        output_fields: Optional[List[str]] = None,
+    ) -> List[Dict[str, Any]]:
+        """Search for similar image embeddings
+
+        Args:
+            query_embedding: Query embedding vector
+            limit: Maximum number of results
+            filters: Filter expression (e.g., 'gender == "Men"')
+            output_fields: List of fields to return
+
+        Returns:
+            List of result dicts, each with:
+            - id: embedding ID
+            - distance: similarity distance
+            - the requested output fields, flattened into the dict
+        """
+        try:
+            if output_fields is None:
+                output_fields = [
+                    "id",
+                    "image_path",
+                    "productDisplayName",
+                    "gender",
+                    "masterCategory",
+                    "subCategory",
+                    "articleType",
+                    "baseColour",
+                ]
+
+            # MilvusClient.search takes the boolean filter expression via the
+            # `filter` parameter, not inside search_params
+            results = self.client.search(
+                collection_name=self.image_collection_name,
+                data=[query_embedding],
+                filter=filters or "",
+                limit=limit,
+                output_fields=output_fields,
+            )
+
+            # Flatten each hit into a single dict
+            formatted_results = []
+            if results and len(results) > 0:
+                for hit in results[0]:
+                    result = {"id": hit.get("id"), "distance": hit.get("distance")}
+                    entity = hit.get("entity", {})
+                    for field in output_fields:
+                        if field in entity:
+                            result[field] = entity.get(field)
+                    formatted_results.append(result)
+
+            logger.debug(f"Found {len(formatted_results)} similar image embeddings")
+            return formatted_results
+
+        except Exception as e:
+            logger.error(f"Failed to search similar images: {e}")
+            raise
+
+    def get_collection_stats(self, collection_name: str) -> Dict[str, Any]:
+        """Get statistics for a collection
+
+        Args:
+            collection_name: Name of the collection
+
+        Returns:
+            Dictionary with collection statistics
+        """
+        try:
+            stats =
self.client.get_collection_stats(collection_name) + return { + "collection_name": collection_name, + "row_count": stats.get("row_count", 0), + } + except Exception as e: + logger.error(f"Failed to get collection stats: {e}") + return {"collection_name": collection_name, "row_count": 0} + + def delete_by_ids(self, collection_name: str, ids: List[int]) -> int: + """Delete embeddings by IDs + + Args: + collection_name: Name of the collection + ids: List of IDs to delete + + Returns: + Number of deleted embeddings + """ + if not ids: + return 0 + + try: + self.client.delete( + collection_name=collection_name, + ids=ids, + ) + logger.info(f"Deleted {len(ids)} embeddings from {collection_name}") + return len(ids) + except Exception as e: + logger.error(f"Failed to delete embeddings: {e}") + raise + + def clear_collection(self, collection_name: str) -> None: + """Clear all data from a collection + + Args: + collection_name: Name of the collection + """ + try: + if self.client.has_collection(collection_name): + self.client.drop_collection(collection_name) + logger.info(f"Dropped collection: {collection_name}") + except Exception as e: + logger.error(f"Failed to clear collection: {e}") + raise + + +# Global instance +_milvus_service: Optional[MilvusService] = None + + +def get_milvus_service() -> MilvusService: + """Get or create the global Milvus service instance""" + global _milvus_service + if _milvus_service is None: + _milvus_service = MilvusService() + _milvus_service.connect() + return _milvus_service diff --git a/app/tools/__init__.py b/app/tools/__init__.py new file mode 100644 index 0000000..f8082e3 --- /dev/null +++ b/app/tools/__init__.py @@ -0,0 +1,17 @@ +""" +LangChain Tools for Product Search and Discovery +""" + +from app.tools.search_tools import ( + analyze_image_style, + get_all_tools, + search_by_image, + search_products, +) + +__all__ = [ + "search_products", + "search_by_image", + "analyze_image_style", + "get_all_tools", +] diff --git 
a/app/tools/search_tools.py b/app/tools/search_tools.py new file mode 100644 index 0000000..0a32a19 --- /dev/null +++ b/app/tools/search_tools.py @@ -0,0 +1,294 @@ +""" +Search Tools for Product Discovery +Provides text-based, image-based, and VLM reasoning capabilities +""" + +import base64 +import logging +from pathlib import Path +from typing import Optional + +from langchain_core.tools import tool +from openai import OpenAI + +from app.config import settings +from app.services.embedding_service import EmbeddingService +from app.services.milvus_service import MilvusService + +logger = logging.getLogger(__name__) + +# Initialize services as singletons +_embedding_service: Optional[EmbeddingService] = None +_milvus_service: Optional[MilvusService] = None +_openai_client: Optional[OpenAI] = None + + +def get_embedding_service() -> EmbeddingService: + global _embedding_service + if _embedding_service is None: + _embedding_service = EmbeddingService() + return _embedding_service + + +def get_milvus_service() -> MilvusService: + global _milvus_service + if _milvus_service is None: + _milvus_service = MilvusService() + _milvus_service.connect() + return _milvus_service + + +def get_openai_client() -> OpenAI: + global _openai_client + if _openai_client is None: + _openai_client = OpenAI(api_key=settings.openai_api_key) + return _openai_client + + +@tool +def search_products(query: str, limit: int = 5) -> str: + """Search for fashion products using natural language descriptions. 
+
+    Use when users describe what they want:
+    - "Find me red summer dresses"
+    - "Show me blue running shoes"
+    - "I want casual shirts for men"
+
+    Args:
+        query: Natural language product description
+        limit: Maximum number of results (1-20)
+
+    Returns:
+        Formatted string with product information
+    """
+    try:
+        logger.info(f"Searching products: '{query}', limit: {limit}")
+
+        embedding_service = get_embedding_service()
+        milvus_service = get_milvus_service()
+
+        if not milvus_service.is_connected():
+            milvus_service.connect()
+
+        query_embedding = embedding_service.get_text_embedding(query)
+
+        results = milvus_service.search_similar_text(
+            query_embedding=query_embedding,
+            limit=min(limit, 20),
+            filters=None,
+            output_fields=[
+                "id",
+                "productDisplayName",
+                "gender",
+                "masterCategory",
+                "subCategory",
+                "articleType",
+                "baseColour",
+                "season",
+                "usage",
+            ],
+        )
+
+        if not results:
+            return "No products found matching your search."
+
+        output = f"Found {len(results)} product(s):\n\n"
+
+        for idx, product in enumerate(results, 1):
+            output += f"{idx}. {product.get('productDisplayName', 'Unknown Product')}\n"
+            output += f"   ID: {product.get('id', 'N/A')}\n"
+            output += f"   Category: {product.get('masterCategory', 'N/A')} > {product.get('subCategory', 'N/A')} > {product.get('articleType', 'N/A')}\n"
+            output += f"   Color: {product.get('baseColour', 'N/A')}\n"
+            output += f"   Gender: {product.get('gender', 'N/A')}\n"
+
+            if product.get("season"):
+                output += f"   Season: {product.get('season')}\n"
+            if product.get("usage"):
+                output += f"   Usage: {product.get('usage')}\n"
+
+            if "distance" in product:
+                # With the COSINE metric, Milvus reports cosine similarity
+                # directly (higher is more similar), so use it as-is
+                similarity = product["distance"]
+                output += f"   Relevance: {similarity:.2%}\n"
+
+            output += "\n"
+
+        return output.strip()
+
+    except Exception as e:
+        logger.error(f"Error searching products: {e}", exc_info=True)
+        return f"Error searching products: {str(e)}"
+
+
+@tool
+def search_by_image(image_path: str, limit: int = 5) -> str:
+    """Find similar fashion products using an image.
+
+    Use when users want visually similar items:
+    - User uploads an image and asks "find similar items"
+    - "Show me products that look like this"
+
+    Args:
+        image_path: Path to the image file
+        limit: Maximum number of results (1-20)
+
+    Returns:
+        Formatted string with similar products
+    """
+    try:
+        logger.info(f"Image search: '{image_path}', limit: {limit}")
+
+        img_path = Path(image_path)
+        if not img_path.exists():
+            return f"Error: Image file not found at '{image_path}'"
+
+        embedding_service = get_embedding_service()
+        milvus_service = get_milvus_service()
+
+        if not milvus_service.is_connected():
+            milvus_service.connect()
+
+        if (
+            not hasattr(embedding_service, "clip_client")
+            or embedding_service.clip_client is None
+        ):
+            embedding_service.connect_clip()
+
+        image_embedding = embedding_service.get_image_embedding(image_path)
+
+        if image_embedding is None:
+            return "Error: Failed to generate embedding for image"
+
+        results = milvus_service.search_similar_images(
+            query_embedding=image_embedding,
+            limit=min(limit + 1,
21), + filters=None, + output_fields=[ + "id", + "image_path", + "productDisplayName", + "gender", + "masterCategory", + "subCategory", + "articleType", + "baseColour", + "season", + "usage", + ], + ) + + if not results: + return "No similar products found." + + # Filter out the query image itself + query_id = img_path.stem + filtered_results = [] + for result in results: + result_path = result.get("image_path", "") + if Path(result_path).stem != query_id: + filtered_results.append(result) + if len(filtered_results) >= limit: + break + + if not filtered_results: + return "No similar products found." + + output = f"Found {len(filtered_results)} visually similar product(s):\n\n" + + for idx, product in enumerate(filtered_results, 1): + output += f"{idx}. {product.get('productDisplayName', 'Unknown Product')}\n" + output += f" ID: {product.get('id', 'N/A')}\n" + output += f" Category: {product.get('masterCategory', 'N/A')} > {product.get('subCategory', 'N/A')} > {product.get('articleType', 'N/A')}\n" + output += f" Color: {product.get('baseColour', 'N/A')}\n" + output += f" Gender: {product.get('gender', 'N/A')}\n" + + if product.get("season"): + output += f" Season: {product.get('season')}\n" + if product.get("usage"): + output += f" Usage: {product.get('usage')}\n" + + if "distance" in product: + similarity = 1 - product["distance"] + output += f" Visual Similarity: {similarity:.2%}\n" + + output += "\n" + + return output.strip() + + except Exception as e: + logger.error(f"Error in image search: {e}", exc_info=True) + return f"Error searching by image: {str(e)}" + + +@tool +def analyze_image_style(image_path: str) -> str: + """Analyze a fashion product image using AI vision to extract detailed style information. 
+ + Use when you need to understand style/attributes from an image: + - Understand the style, color, pattern of a product + - Extract attributes like "casual", "formal", "vintage" + - Get detailed descriptions for subsequent searches + + Args: + image_path: Path to the image file + + Returns: + Detailed text description of the product's visual attributes + """ + try: + logger.info(f"Analyzing image with VLM: '{image_path}'") + + img_path = Path(image_path) + if not img_path.exists(): + return f"Error: Image file not found at '{image_path}'" + + with open(img_path, "rb") as image_file: + image_data = base64.b64encode(image_file.read()).decode("utf-8") + + prompt = """Analyze this fashion product image and provide a detailed description. + +Include: +- Product type (e.g., shirt, dress, shoes, pants, bag) +- Primary colors +- Style/design (e.g., casual, formal, sporty, vintage, modern) +- Pattern or texture (e.g., plain, striped, checked, floral) +- Key features (e.g., collar type, sleeve length, fit) +- Material appearance (if obvious, e.g., denim, cotton, leather) +- Suitable occasion (e.g., office wear, party, casual, sports) + +Provide a comprehensive yet concise description (3-4 sentences).""" + + client = get_openai_client() + response = client.chat.completions.create( + model="gpt-4o-mini", + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_data}", + "detail": "high", + }, + }, + ], + } + ], + max_tokens=500, + temperature=0.3, + ) + + analysis = response.choices[0].message.content.strip() + logger.info("VLM analysis completed") + + return analysis + + except Exception as e: + logger.error(f"Error analyzing image: {e}", exc_info=True) + return f"Error analyzing image: {str(e)}" + + +def get_all_tools(): + """Get all available tools for the agent""" + return [search_products, search_by_image, analyze_image_style] diff --git a/demo.pdf 
b/demo.pdf new file mode 100644 index 0000000..f66b729 Binary files /dev/null and b/demo.pdf differ diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..30fc64d --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,76 @@ +version: '3.5' + +services: + etcd: + container_name: milvus-etcd + image: quay.io/coreos/etcd:v3.5.5 + environment: + - ETCD_AUTO_COMPACTION_MODE=revision + - ETCD_AUTO_COMPACTION_RETENTION=1000 + - ETCD_QUOTA_BACKEND_BYTES=4294967296 + - ETCD_SNAPSHOT_COUNT=50000 + volumes: + - ./volumes/etcd:/etcd + command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd + healthcheck: + test: ["CMD", "etcdctl", "endpoint", "health"] + interval: 30s + timeout: 20s + retries: 3 + + minio: + container_name: milvus-minio + image: minio/minio:RELEASE.2023-03-20T20-16-18Z + environment: + MINIO_ACCESS_KEY: minioadmin + MINIO_SECRET_KEY: minioadmin + ports: + - "9001:9001" + - "9000:9000" + volumes: + - ./volumes/minio:/minio_data + command: minio server /minio_data --console-address ":9001" + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"] + interval: 30s + timeout: 20s + retries: 3 + + standalone: + container_name: milvus-standalone + image: milvusdb/milvus:v2.4.0 + command: ["milvus", "run", "standalone"] + security_opt: + - seccomp:unconfined + environment: + ETCD_ENDPOINTS: etcd:2379 + MINIO_ADDRESS: minio:9000 + volumes: + - ./volumes/milvus:/var/lib/milvus + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:9091/healthz"] + interval: 30s + start_period: 90s + timeout: 20s + retries: 3 + ports: + - "19530:19530" + - "9091:9091" + depends_on: + - "etcd" + - "minio" + + attu: + container_name: milvus-attu + image: zilliz/attu:v2.4 + environment: + MILVUS_URL: milvus-standalone:19530 + ports: + - "8000:3000" + depends_on: + - "standalone" + +networks: + default: + name: milvus + diff --git a/docs/DEPLOY_CENTOS8.md 
b/docs/DEPLOY_CENTOS8.md new file mode 100644 index 0000000..fda647c --- /dev/null +++ b/docs/DEPLOY_CENTOS8.md @@ -0,0 +1,216 @@ +# OmniShopAgent centOS 8 部署指南 + +## 一、环境要求 + +| 组件 | 要求 | +|------|------| +| 操作系统 | CentOS 8.x | +| Python | 3.12+(LangChain 1.x 要求 3.10+) | +| 内存 | 建议 8GB+(Milvus + CLIP 较占内存) | +| 磁盘 | 建议 20GB+(含数据集) | + +## 二、快速部署步骤 + +### 2.1 一键环境准备(推荐) + +```bash +cd /path/to/shop_agent +chmod +x scripts/*.sh +./scripts/setup_env_centos8.sh +``` + +该脚本会: +- 安装系统依赖(gcc、openssl-devel 等) +- 安装 Docker(用于 Milvus) +- 安装 Python 3.12(conda 或源码编译) +- 创建虚拟环境并安装 requirements.txt + +### 2.2 手动部署(分步执行) + +#### 步骤 1:安装系统依赖 + +```bash +sudo dnf install -y gcc gcc-c++ make openssl-devel bzip2-devel \ + libffi-devel sqlite-devel xz-devel zlib-devel curl wget git +``` + +#### 步骤 2:安装 Python 3.12 + +**方式 A:Miniconda(推荐)** + +```bash +wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh +bash Miniconda3-latest-Linux-x86_64.sh +# 按提示安装后 +conda create -n shop_agent python=3.12 +conda activate shop_agent +``` + +**方式 B:从源码编译** + +```bash +sudo dnf groupinstall -y 'Development Tools' +cd /tmp +wget https://www.python.org/ftp/python/3.12.0/Python-3.12.0.tgz +tar xzf Python-3.12.0.tgz +cd Python-3.12.0 +./configure --enable-optimizations --prefix=/usr/local +make -j $(nproc) +sudo make altinstall +``` + +#### 步骤 3:安装 Docker + +```bash +sudo dnf config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.repo +sudo dnf install -y docker-ce docker-ce-cli containerd.io docker-compose-plugin +sudo systemctl enable docker && sudo systemctl start docker +sudo usermod -aG docker $USER +# 执行 newgrp docker 或重新登录 +``` + +#### 步骤 4:创建虚拟环境并安装依赖 + +```bash +cd /path/to/shop_agent +python3.12 -m venv venv +source venv/bin/activate +pip install -U pip +pip install -r requirements.txt +``` + +#### 步骤 5:配置环境变量 + +```bash +cp .env.example .env +# 编辑 .env,至少配置: +# OPENAI_API_KEY=sk-xxx +# MILVUS_HOST=localhost +# MILVUS_PORT=19530 +# 
CLIP_SERVER_URL=grpc://localhost:51000 +``` + +## 三、数据准备 + +### 3.1 下载数据集 + +```bash +# 需先配置 Kaggle API:~/.kaggle/kaggle.json +python scripts/download_dataset.py +``` + +### 3.2 启动 Milvus 并索引数据 + +```bash +# 启动 Milvus +./scripts/run_milvus.sh + +# 等待就绪后,创建索引 +python scripts/index_data.py +``` + +## 四、启动服务 + +### 4.1 启动脚本说明 + +| 脚本 | 用途 | +|------|------| +| `start.sh` | 主启动脚本:启动 Milvus + Streamlit | +| `stop.sh` | 停止所有服务 | +| `run_milvus.sh` | 仅启动 Milvus | +| `run_clip.sh` | 仅启动 CLIP(图像搜索需此服务) | +| `check_services.sh` | 健康检查 | + +### 4.2 启动应用 + +```bash +# 方式 1:使用 start.sh(推荐) +./scripts/start.sh + +# 方式 2:分步启动 +# 终端 1:Milvus +./scripts/run_milvus.sh + +# 终端 2:CLIP(图像搜索需要) +./scripts/run_clip.sh + +# 终端 3:Streamlit +source venv/bin/activate +streamlit run app.py --server.port=8501 --server.address=0.0.0.0 +``` + +### 4.3 访问地址 + +- **Streamlit 应用**:http://服务器IP:8501 +- **Milvus Attu 管理界面**:http://服务器IP:8000 + +## 五、生产部署建议 + +### 5.1 使用 systemd 管理 Streamlit + +创建 `/etc/systemd/system/omishop-agent.service`: + +```ini +[Unit] +Description=OmniShopAgent Streamlit App +After=network.target docker.service + +[Service] +Type=simple +User=your_user +WorkingDirectory=/path/to/shop_agent +Environment="PATH=/path/to/shop_agent/venv/bin" +ExecStart=/path/to/shop_agent/venv/bin/streamlit run app.py --server.port=8501 --server.address=0.0.0.0 +Restart=on-failure + +[Install] +WantedBy=multi-user.target +``` + +```bash +sudo systemctl daemon-reload +sudo systemctl enable omishop-agent +sudo systemctl start omishop-agent +``` + +### 5.2 使用 Nginx 反向代理(可选) + +```nginx +server { + listen 80; + server_name your-domain.com; + location / { + proxy_pass http://127.0.0.1:8501; + proxy_http_version 1.1; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "upgrade"; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + } +} +``` + +### 5.3 防火墙 + +```bash +sudo firewall-cmd --permanent --add-port=8501/tcp +sudo firewall-cmd --permanent 
--add-port=19530/tcp +sudo firewall-cmd --reload +``` + +## 六、常见问题 + +### Q: Python 3.12 编译失败? +A: 确保已安装 `openssl-devel`、`libffi-devel`,或直接使用 Miniconda。 + +### Q: Docker 权限不足? +A: 执行 `sudo usermod -aG docker $USER` 后重新登录。 + +### Q: Milvus 启动超时? +A: 首次启动需拉取镜像,可能较慢。可检查 `docker compose logs -f standalone`。 + +### Q: 图像搜索不可用? +A: 需单独启动 CLIP 服务:`./scripts/run_clip.sh`。 + +### Q: 健康检查? +A: 执行 `./scripts/check_services.sh` 查看各组件状态。 diff --git a/docs/LANGCHAIN_1.0_MIGRATION.md b/docs/LANGCHAIN_1.0_MIGRATION.md new file mode 100644 index 0000000..5297017 --- /dev/null +++ b/docs/LANGCHAIN_1.0_MIGRATION.md @@ -0,0 +1,77 @@ +# LangChain 1.0 升级说明 + +## 一、升级概览 + +本项目已完成从 LangChain 0.3 到 LangChain 1.x 的升级,并同步升级 LangGraph 至 1.x。升级后兼容 Python 3.12。 + +## 二、依赖变更 + +| 包 | 升级前 | 升级后 | +|----|--------|--------| +| langchain | >=0.3.0 | >=1.0.0 | +| langchain-core | (间接依赖) | >=0.3.0 | +| langchain-openai | >=0.2.0 | >=0.2.0 | +| langgraph | >=0.2.74 | >=1.0.0 | +| langchain-community | >=0.4.0 | **已移除**(项目未使用) | + +## 三、代码改造说明 + +### 3.1 保持不变的部分 + +项目采用 **自定义 StateGraph** 架构(非 `create_react_agent`),以下导入在 LangChain/LangGraph 1.x 中保持兼容: + +```python +from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage +from langchain_openai import ChatOpenAI +from langgraph.checkpoint.memory import MemorySaver +from langgraph.graph import END, START, StateGraph +from langgraph.graph.message import add_messages +from langgraph.prebuilt import ToolNode +from langchain_core.tools import tool +``` + +### 3.2 已适配的变更 + +1. **消息内容提取**:LangChain 1.0 引入 `content_blocks`,`content` 可能为字符串或 multimodal 列表。新增 `_extract_message_text()` 辅助函数,统一处理两种格式。 +2. 
**依赖精简**:移除未使用的 `langchain-community`,减少依赖冲突。 + +### 3.3 LangChain 1.0 主要变更(参考) + +- **包命名空间精简**:核心功能移至 `langchain-core`,`langchain` 主包聚焦 Agent 构建 +- **create_agent**:若未来迁移到 `langchain.agents.create_agent`,可参考 `docs/Skills实现方案-LangChain1.0.md` +- **langchain-classic**:Legacy chains、Retrievers 等已迁移至 `langchain-classic`,本项目未使用 + +## 四、环境要求 + +- **Python**:3.12+(**LangChain 1.x 要求 Python 3.10+**,不支持 3.9 及以下) +- 若系统默认 Python 版本过低,需使用虚拟环境: + +```bash +# 方式 1:使用 conda(推荐,项目根目录 scripts/setup_conda_env.sh) +conda create -n shop_agent python=3.12 +conda activate shop_agent +pip install -r requirements.txt + +# 方式 2:使用 pyenv +pyenv install 3.12 +pyenv local 3.12 +pip install -r requirements.txt + +# 方式 3:使用 venv(需系统已安装 python3.12) +python3.12 -m venv venv +source venv/bin/activate # Linux/Mac +pip install -r requirements.txt +``` + +## 五、验证 + +```bash +# 验证导入 +python -c " +from langchain_core.messages import HumanMessage +from langchain_openai import ChatOpenAI +from langgraph.graph import StateGraph +from langgraph.prebuilt import ToolNode +print('LangChain 1.x 依赖加载成功') +" +``` diff --git a/docs/Skills实现方案-LangChain1.0.md b/docs/Skills实现方案-LangChain1.0.md new file mode 100644 index 0000000..6b34832 --- /dev/null +++ b/docs/Skills实现方案-LangChain1.0.md @@ -0,0 +1,318 @@ +# Skills 渐进式展开实现方案(LangChain 1.0+) + +## 一、需求概述 + +用 **Skills** 替代零散的工具调用,实现**渐进式展开**(Progressive Disclosure): +Agent 在 system prompt 中只看到技能摘要,按需加载详细技能内容,减少 token 消耗、提升扩展性。 + +| 技能 | 英文标识 | 职责 | +|------|----------|------| +| 查找相关商品 | lookup_related | 基于文本/图片查找相似或相关商品 | +| 搜索商品 | search_products | 按自然语言描述搜索商品 | +| 检验商品 | check_product | 检验商品是否符合用户要求 | +| 结果包装 | result_packaging | 格式化、排序、筛选并呈现结果 | +| 售后相关 | after_sales | 退换货、物流、保修等售后问题 | + +--- + +## 二、LangChain 1.0 中的 Skills 实现方式 + +### 2.1 两种实现路线 + +| 方式 | 适用场景 | 依赖 | +|------|----------|------| +| **方式 A:create_agent + 自定义 Skill 中间件** | 购物导购等业务 Agent | `langchain>=1.0`、`langgraph>=1.0` | +| **方式 B:Deep Agents + SKILL.md** | 依赖文件系统、多技能目录 | `deepagents` | + 
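两种路线的共同核心都是渐进式展开。下面给出一个与框架无关的极简 Python 示意(其中 `SKILLS` 字典、`build_system_prompt` 与 `load_skill` 的结构均为演示假设,非项目实际代码),帮助理解「摘要进 prompt、完整内容按需加载」的机制:

```python
# 极简示意:渐进式展开(Progressive Disclosure)的核心机制。
# 启动时 system prompt 只含技能摘要;完整说明仅在 load_skill 时返回。
SKILLS = {
    "search_products": {
        "description": "按自然语言描述搜索商品。",
        "content": "# 搜索商品\n1. 整理 query\n2. 调用 search_products(query, limit)",
    },
}


def build_system_prompt() -> str:
    """启动时只注入技能名称 + 简短描述,节省 token。"""
    lines = [f"- {name}: {meta['description']}" for name, meta in SKILLS.items()]
    return "可用技能(按需加载):\n" + "\n".join(lines)


def load_skill(skill_name: str) -> str:
    """Agent 判断需要某技能后调用,此时完整说明才进入上下文。"""
    skill = SKILLS.get(skill_name)
    if skill is None:
        return f"Skill '{skill_name}' not found. Available: {', '.join(SKILLS)}"
    return f"Loaded skill: {skill_name}\n\n{skill['content']}"
```

无论选方式 A 还是方式 B,token 开销都只与技能摘要数量相关,技能详情的长度不再进入每轮请求。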
+购物导购场景推荐**方式 A**,更易与现有 Milvus、CLIP 等服务集成。 + +### 2.2 核心思路:Progressive Disclosure + +``` +用户请求 → Agent 看轻量描述 → 判断需要的技能 → load_skill → 拿到完整说明 → 执行工具 → 回复 +``` + +- **启动时**:只注入技能名称 + 简短描述(1–2 句) +- **按需加载**:Agent 调用 `load_skill(skill_name)` 获取完整指令 +- **执行**:按技能说明调用对应工具 + +--- + +## 三、实现架构 + +### 3.1 技能定义结构 + +```python +from typing import TypedDict + +class Skill(TypedDict): + """可渐进式展开的技能""" + name: str # 唯一标识 + description: str # 1-2 句,展示在 system prompt + content: str # 完整指令,仅在 load_skill 时返回 +``` + +### 3.2 五个技能定义示例 + +```python +SKILLS: list[Skill] = [ + { + "name": "lookup_related", + "description": "查找与某商品相关的其他商品,支持以图搜图、文本相似、同品类推荐。", + "content": """# 查找相关商品 + +## 适用场景 +- 用户上传图片要求「找类似的」 +- 用户说「和这个差不多」「搭配的裤子」 +- 用户已有一件商品,想找相关款 + +## 操作步骤 +1. **有图片**:先调用 `analyze_image_style` 理解风格,再调用 `search_by_image` 或 `search_products` +2. **无图片**:用 `search_products` 描述品类+风格+颜色 +3. 可结合上下文中的商品 ID、品类做同品类推荐 + +## 可用工具 +- `search_by_image(image_path, limit)`:以图搜图 +- `search_products(query, limit)`:文本搜索 +- `analyze_image_style(image_path)`:分析图片风格""", + }, + { + "name": "search_products", + "description": "按自然语言描述搜索商品,如「红色连衣裙」「运动鞋」等。", + "content": """# 搜索商品 + +## 适用场景 +- 用户用文字描述想要什么 +- 如「冬天穿的外套」「正装衬衫」「跑步鞋」 + +## 操作步骤 +1. 将用户描述整理成结构化 query(品类+颜色+风格+场景) +2. 调用 `search_products(query, limit)`,limit 默认 5–10 +3. 如有图片,可先 `analyze_image_style` 提炼关键词再搜索 + +## 可用工具 +- `search_products(query, limit)`:自然语言搜索""", + }, + { + "name": "check_product", + "description": "检验商品是否符合用户要求,如尺寸、材质、场合、价格区间等。", + "content": """# 检验商品是否符合要求 + +## 适用场景 +- 用户问「这款适合我吗」「有没有 XX 材质的」 +- 用户提出约束:尺寸、价格、场合、材质 + +## 操作步骤 +1. 从对话中提取约束条件(尺寸、材质、场合、价格等) +2. 对已召回商品做筛选或二次搜索 +3. 调用 `search_products` 时在 query 中带上约束 +4. 回复时明确说明哪些符合、哪些不符合 + +## 注意 +- 无专门工具时,用 search_products 的 query 表达约束 +- 可结合商品元数据(baseColour, season, usage 等)做简单筛选""", + }, + { + "name": "result_packaging", + "description": "对搜索结果进行格式化、排序、筛选并呈现给用户。", + "content": """# 结果包装 + +## 适用场景 +- 工具返回多条商品后需要整理呈现 +- 用户要求「按价格排序」「只要前 3 个」 + +## 操作步骤 +1. 按相关性/相似度排序 +2. 
限制展示数量(通常 3–5 个) +3. **必须使用以下格式**呈现每个商品: + +``` +1. [Product Name] + ID: [Product ID Number] + Category: [Category] + Color: [Color] + Gender: [Gender] + Season: [Season] + Usage: [Usage] + Relevance: [XX%] +``` + +4. ID 字段不可省略,用于前端展示图片""", + }, + { + "name": "after_sales", + "description": "处理退换货、物流、保修、尺码建议等售后问题。", + "content": """# 售后相关 + +## 适用场景 +- 退换货政策、运费、签收时间 +- 尺码建议、洗涤说明 +- 保修、客服联系方式 + +## 操作步骤 +1. 此类问题无需调用商品搜索工具 +2. 按平台统一售后政策回答 +3. 涉及具体商品时,可结合商品 ID 查询详情后再回答 +4. 复杂问题引导用户联系客服""", + }, +] +``` + +--- + +## 四、核心代码实现 + +### 4.1 load_skill 工具 + +```python +from langchain.tools import tool + +@tool +def load_skill(skill_name: str) -> str: + """加载技能的完整内容到 Agent 上下文中。 + + 当需要处理特定类型请求时,调用此工具获取该技能的详细说明和操作步骤。 + + Args: + skill_name: 技能名称,可选值:lookup_related, search_products, check_product, result_packaging, after_sales + """ + for skill in SKILLS: + if skill["name"] == skill_name: + return f"Loaded skill: {skill_name}\n\n{skill['content']}" + + available = ", ".join(s["name"] for s in SKILLS) + return f"Skill '{skill_name}' not found. 
Available: {available}" +``` + +### 4.2 SkillMiddleware(注入技能描述) + +```python +from langchain.agents.middleware import AgentMiddleware, ModelRequest, ModelResponse +from langchain.messages import SystemMessage +from typing import Callable + +class ShoppingSkillMiddleware(AgentMiddleware): + """将技能描述注入 system prompt,使 Agent 能发现并按需加载技能""" + + tools = [load_skill] + + def __init__(self): + skills_list = [] + for skill in SKILLS: + skills_list.append(f"- **{skill['name']}**: {skill['description']}") + self.skills_prompt = "\n".join(skills_list) + + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelResponse: + skills_addendum = ( + f"\n\n## 可用技能(按需加载)\n\n{self.skills_prompt}\n\n" + "在需要详细说明时,使用 load_skill 工具加载对应技能。" + ) + new_content = list(request.system_message.content_blocks) + [ + {"type": "text", "text": skills_addendum} + ] + new_system_message = SystemMessage(content=new_content) + modified_request = request.override(system_message=new_system_message) + return handler(modified_request) +``` + +### 4.3 创建带 Skills 的 Agent + +```python +from langchain.agents import create_agent +from langgraph.checkpoint.memory import MemorySaver + +# 基础工具(搜索、以图搜图、风格分析等) +from app.tools.search_tools import search_products, search_by_image, analyze_image_style + +agent = create_agent( + model="gpt-4o-mini", + tools=[ + load_skill, # 技能加载 + search_products, + search_by_image, + analyze_image_style, + ], + system_prompt="""你是智能时尚购物助手。根据用户需求,先判断使用哪个技能,必要时用 load_skill 加载技能详情。 + +处理商品结果时,必须遵守 result_packaging 技能中的格式要求。""", + middleware=[ShoppingSkillMiddleware()], + checkpointer=MemorySaver(), +) +``` + +--- + +## 五、与工具的关系 + +| 能力 | 技能 | 工具 | +|------|------|------| +| 查找相关 | lookup_related | search_by_image, search_products, analyze_image_style | +| 搜索商品 | search_products | search_products | +| 检验商品 | check_product | search_products(用 query 表达约束) | +| 结果包装 | result_packaging | 无(纯 prompt 约束) | +| 售后 | after_sales | 
无(或对接客服 API) | + +- **技能**:提供「何时用、怎么用」的说明,支持渐进式加载。 +- **工具**:实际执行搜索、分析等操作。 + +--- + +## 六、可选:技能约束(进阶) + +若希望「先加载技能再使用工具」,可结合 `ToolRuntime` 和 state 做约束: + +```python +from langchain.tools import tool, ToolRuntime +from langgraph.types import Command +from langchain.messages import ToolMessage +from typing_extensions import NotRequired + +class CustomState(AgentState): + skills_loaded: NotRequired[list[str]] + +@tool +def load_skill(skill_name: str, runtime: ToolRuntime) -> Command: + """...""" + for skill in SKILLS: + if skill["name"] == skill_name: + content = f"Loaded skill: {skill_name}\n\n{skill['content']}" + return Command(update={ + "messages": [ToolMessage(content=content, tool_call_id=runtime.tool_call_id)], + "skills_loaded": [skill_name], + }) + # ... + +# 在 check_product 等工具中检查 skills_loaded +``` + +--- + +## 七、依赖与版本 + +```text +# requirements.txt +langchain>=1.0.0 +langchain-openai>=0.2.0 +langchain-core>=0.3.0 +langgraph>=1.0.0 +``` + +- Python 3.10+ +- 若使用 Deep Agents 的 SKILL.md,需额外安装 `deepagents` + +--- + +## 八、总结 + +| 项目 | 说明 | +|------|------| +| **效果** | 系统 prompt 只放简短技能描述,按需加载完整内容,减少 token、便于扩展 | +| **流程** | 轻量描述 → load_skill → 完整说明 → 调用工具 → 回复 | +| **实现** | `SkillMiddleware` + `load_skill` + `create_agent` | +| **技能** | lookup_related, search_products, check_product, result_packaging, after_sales | + +完整示例可参考官方教程:[Build a SQL assistant with on-demand skills](https://docs.langchain.com/oss/python/langchain/multi-agent/skills-sql-assistant)。 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..66c76f7 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,40 @@ +# Core Framework +fastapi>=0.109.0 +uvicorn[standard]>=0.27.0 +pydantic>=2.6.0 +pydantic-settings>=2.1.0 +streamlit>=1.50.0 + +# LLM & LangChain (Python 3.12, LangChain 1.x) +langchain>=1.0.0 +langchain-core>=0.3.0 +langchain-openai>=0.2.0 +langgraph>=1.0.0 +openai>=1.12.0 + +# Embeddings & Vision +clip-client>=3.5.0 # CLIP-as-Service client +Pillow>=10.2.0 # Image 
processing + +# Vector Database +pymilvus>=2.3.6 + +# Databases +pymongo>=4.6.1 + +# Utilities +python-dotenv>=1.0.1 +python-multipart>=0.0.9 +aiofiles>=23.2.1 +requests>=2.31.0 + +# Data Processing +pandas>=2.2.3 +numpy>=1.26.4 +tqdm>=4.66.1 + +# Development & Testing +pytest>=8.0.0 +pytest-asyncio>=0.23.4 +httpx>=0.26.0 +black>=24.1.1 diff --git a/scripts/check_services.sh b/scripts/check_services.sh new file mode 100755 index 0000000..ad207f3 --- /dev/null +++ b/scripts/check_services.sh @@ -0,0 +1,93 @@ +#!/usr/bin/env bash +# ============================================================================= +# OmniShopAgent - 服务健康检查脚本 +# 检查 Milvus、CLIP、Streamlit 等依赖服务状态 +# ============================================================================= +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' + +echo "==========================================" +echo "OmniShopAgent 服务健康检查" +echo "==========================================" + +# 1. Python 环境 +echo -n "[Python] " +if command -v python3 &>/dev/null; then + VER=$(python3 -c 'import sys; v=sys.version_info; print(f"{v.major}.{v.minor}.{v.micro}")' 2>/dev/null) + if python3 -c 'import sys; sys.exit(0 if sys.version_info >= (3, 10) else 1)' 2>/dev/null; then + echo -e "${GREEN}OK${NC} $VER" + else + echo -e "${YELLOW}WARN${NC} $VER (建议 3.12+)" + fi +else + echo -e "${RED}FAIL${NC} 未找到" +fi + + # 2. 虚拟环境 +echo -n "[Virtualenv] " +if [ -d "$PROJECT_ROOT/venv" ]; then + echo -e "${GREEN}OK${NC} $PROJECT_ROOT/venv" +else + echo -e "${YELLOW}WARN${NC} 未找到 venv" +fi + +# 3. .env 配置 +echo -n "[.env] " +if [ -f "$PROJECT_ROOT/.env" ]; then + if grep -q "OPENAI_API_KEY=sk-" "$PROJECT_ROOT/.env" 2>/dev/null; then + echo -e "${GREEN}OK${NC} 已配置" + else + echo -e "${YELLOW}WARN${NC} 请配置 OPENAI_API_KEY" + fi +else + echo -e "${RED}FAIL${NC} 未找到" +fi + +# 4.
Milvus +echo -n "[Milvus] " +if command -v docker &>/dev/null; then + if docker ps --format '{{.Names}}' 2>/dev/null | grep -q milvus-standalone; then + if curl -s -o /dev/null -w "%{http_code}" http://localhost:9091/healthz 2>/dev/null | grep -q 200; then + echo -e "${GREEN}OK${NC} localhost:19530" + else + echo -e "${YELLOW}WARN${NC} 容器运行中,健康检查未响应" + fi + else + echo -e "${YELLOW}WARN${NC} 未运行 (docker compose up -d)" + fi +else + echo -e "${YELLOW}SKIP${NC} Docker 未安装" +fi + +# 5. CLIP 服务(可选) +echo -n "[CLIP] " +if timeout 2 bash -c 'echo >/dev/tcp/localhost/51000' 2>/dev/null; then + echo -e "${GREEN}OK${NC} localhost:51000" +else + echo -e "${YELLOW}WARN${NC} 未运行 (图像搜索需启动: python -m clip_server launch)" +fi + +# 6. 数据目录 +echo -n "[数据] " +if [ -d "$PROJECT_ROOT/data/images" ] && [ -f "$PROJECT_ROOT/data/styles.csv" ]; then + IMG_COUNT=$(find "$PROJECT_ROOT/data/images" -name "*.jpg" 2>/dev/null | wc -l) + echo -e "${GREEN}OK${NC} $IMG_COUNT 张图片" +else + echo -e "${YELLOW}WARN${NC} 未找到 data/images 或 data/styles.csv (运行 download_dataset.py)" +fi + +# 7. Streamlit +echo -n "[Streamlit] " +if pgrep -f "streamlit run app.py" >/dev/null 2>&1; then + echo -e "${GREEN}OK${NC} 运行中" +else + echo -e "${YELLOW}WARN${NC} 未运行 (./scripts/start.sh)" +fi + +echo "==========================================" diff --git a/scripts/download_dataset.py b/scripts/download_dataset.py new file mode 100644 index 0000000..8d93252 --- /dev/null +++ b/scripts/download_dataset.py @@ -0,0 +1,95 @@ +""" +Script to download the Fashion Product Images Dataset from Kaggle + +Requirements: +1. Install Kaggle CLI: pip install kaggle +2. 
Setup Kaggle API credentials: + - Go to https://www.kaggle.com/settings/account + - Click "Create New API Token" + - Save kaggle.json to ~/.kaggle/kaggle.json + - chmod 600 ~/.kaggle/kaggle.json + +Usage: + python scripts/download_dataset.py +""" + +import subprocess +import zipfile +from pathlib import Path + + +def download_dataset(): + """Download and extract the Fashion Product Images Dataset""" + + # Get project root + project_root = Path(__file__).parent.parent + raw_data_path = project_root / "data" / "raw" + + # Check if data already exists + if (raw_data_path / "styles.csv").exists(): + print("Dataset already exists in data/raw/") + response = input("Do you want to re-download? (y/n): ") + if response.lower() != "y": + print("Skipping download.") + return + + # Check Kaggle credentials + kaggle_json = Path.home() / ".kaggle" / "kaggle.json" + if not kaggle_json.exists(): + print("Kaggle API credentials not found! See the setup steps at the top of this script.") + return + + print("Downloading dataset from Kaggle...") + + try: + # Download using Kaggle API + subprocess.run( + [ + "kaggle", + "datasets", + "download", + "-d", + "paramaggarwal/fashion-product-images-dataset", + "-p", + str(raw_data_path), + ], + check=True, + ) + + print("Download complete!") + + # Extract zip file + zip_path = raw_data_path / "fashion-product-images-dataset.zip" + if zip_path.exists(): + print("Extracting files...") + with zipfile.ZipFile(zip_path, "r") as zip_ref: + zip_ref.extractall(raw_data_path) + + print("Extraction complete!") + + # Clean up zip file + zip_path.unlink() + print("Cleaned up zip file") + + # Verify files + styles_csv = raw_data_path / "styles.csv" + images_dir = raw_data_path / "images" + + if styles_csv.exists() and images_dir.exists(): + print("\nDataset ready!") + + # Count images + image_count = len(list(images_dir.glob("*.jpg"))) + print(f"- Total images: {image_count:,}") + else: + print("Warning: Expected files not found") + + except subprocess.CalledProcessError: + print("Download failed!")
+ + except Exception as e: + print(f"Error: {e}") + + +if __name__ == "__main__": + download_dataset() diff --git a/scripts/index_data.py b/scripts/index_data.py new file mode 100644 index 0000000..c495200 --- /dev/null +++ b/scripts/index_data.py @@ -0,0 +1,467 @@ +""" +Data Indexing Script +Generates embeddings for products and stores them in Milvus +""" + +import csv +import logging +import os +import sys +from pathlib import Path +from typing import Any, Dict, Optional + +from tqdm import tqdm + +# Add parent directory to path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# Import config and settings first +# Direct imports from files to avoid __init__.py circular issues +import importlib.util + +from app.config import get_absolute_path, settings + + +def load_service_module(module_name, file_name): + """Load a service module directly from file""" + spec = importlib.util.spec_from_file_location( + module_name, + os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))), + f"app/services/{file_name}", + ), + ) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +embedding_module = load_service_module("embedding_service", "embedding_service.py") +milvus_module = load_service_module("milvus_service", "milvus_service.py") + +EmbeddingService = embedding_module.EmbeddingService +MilvusService = milvus_module.MilvusService + +# Configure logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +class DataIndexer: + """Index product data by generating and storing embeddings""" + + def __init__(self): + """Initialize services""" + self.embedding_service = EmbeddingService() + self.milvus_service = MilvusService() + + self.image_dir = Path(get_absolute_path(settings.image_data_path)) + self.styles_csv = get_absolute_path("./data/styles.csv") + self.images_csv = 
get_absolute_path("./data/images.csv") + + # Load product data from CSV + self.products = self._load_products_from_csv() + + def _load_products_from_csv(self) -> Dict[int, Dict[str, Any]]: + """Load products from CSV files""" + products = {} + + # Load images mapping + images_dict = {} + with open(self.images_csv, "r", encoding="utf-8") as f: + reader = csv.DictReader(f) + for row in reader: + product_id = int(row["filename"].split(".")[0]) + images_dict[product_id] = row["link"] + + # Load styles/products + with open(self.styles_csv, "r", encoding="utf-8") as f: + reader = csv.DictReader(f) + for row in reader: + try: + product_id = int(row["id"]) + products[product_id] = { + "id": product_id, + "gender": row.get("gender", ""), + "masterCategory": row.get("masterCategory", ""), + "subCategory": row.get("subCategory", ""), + "articleType": row.get("articleType", ""), + "baseColour": row.get("baseColour", ""), + "season": row.get("season", ""), + "year": int(row["year"]) if row.get("year") else 0, + "usage": row.get("usage", ""), + "productDisplayName": row.get("productDisplayName", ""), + "imageUrl": images_dict.get(product_id, ""), + "imagePath": f"{product_id}.jpg", + } + except (ValueError, KeyError) as e: + logger.warning(f"Error loading product {row.get('id')}: {e}") + continue + + logger.info(f"Loaded {len(products)} products from CSV") + return products + + def setup(self) -> None: + """Setup connections and collections""" + logger.info("Setting up services...") + + # Connect to CLIP server + self.embedding_service.connect_clip() + logger.info("✓ CLIP server connected") + + # Connect to Milvus + self.milvus_service.connect() + logger.info("✓ Milvus connected") + + # Create Milvus collections + self.milvus_service.create_text_collection(recreate=False) + self.milvus_service.create_image_collection(recreate=False) + logger.info("✓ Milvus collections ready") + + def teardown(self) -> None: + """Close all connections""" + logger.info("Closing connections...") + 
self.embedding_service.disconnect_clip() + self.milvus_service.disconnect() + logger.info("✓ All connections closed") + + def index_text_embeddings( + self, batch_size: int = 100, skip: int = 0, limit: Optional[int] = None + ) -> Dict[str, int]: + """Generate and store text embeddings for products + + Args: + batch_size: Number of products to process at once + skip: Number of products to skip + limit: Maximum number of products to process (None for all) + + Returns: + Dictionary with indexing statistics + """ + logger.info("Starting text embedding indexing...") + + # Get products list + product_ids = list(self.products.keys())[skip:] + if limit: + product_ids = product_ids[:limit] + + total_products = len(product_ids) + processed = 0 + inserted = 0 + errors = 0 + + with tqdm(total=total_products, desc="Indexing text embeddings") as pbar: + while processed < total_products: + # Get batch of products + current_batch_size = min(batch_size, total_products - processed) + batch_ids = product_ids[processed : processed + current_batch_size] + products = [self.products[pid] for pid in batch_ids] + + if not products: + break + + try: + # Prepare texts for embedding + texts = [] + text_mappings = [] + + for product in products: + # Create text representation of product + text = self._create_product_text(product) + texts.append(text) + text_mappings.append( + {"product_id": product["id"], "text": text} + ) + + # Generate embeddings + embeddings = self.embedding_service.get_text_embeddings_batch( + texts, batch_size=50 # OpenAI batch size + ) + + # Prepare data for Milvus (with metadata) + milvus_data = [] + for idx, (mapping, embedding) in enumerate( + zip(text_mappings, embeddings) + ): + product_id = mapping["product_id"] + product = self.products[product_id] + + milvus_data.append( + { + "id": product_id, + "text": mapping["text"][ + :2000 + ], # Truncate to max length + "embedding": embedding, + # Product metadata + "productDisplayName": product["productDisplayName"][ + 
:500 + ], + "gender": product["gender"][:50], + "masterCategory": product["masterCategory"][:100], + "subCategory": product["subCategory"][:100], + "articleType": product["articleType"][:100], + "baseColour": product["baseColour"][:50], + "season": product["season"][:50], + "usage": product["usage"][:50], + "year": product["year"], + "imageUrl": product["imageUrl"], + "imagePath": product["imagePath"], + } + ) + + # Insert into Milvus + count = self.milvus_service.insert_text_embeddings(milvus_data) + inserted += count + + except Exception as e: + logger.error( + f"Error processing text batch at offset {processed}: {e}" + ) + errors += len(products) + + processed += len(products) + pbar.update(len(products)) + + stats = {"total_processed": processed, "inserted": inserted, "errors": errors} + + logger.info(f"Text embedding indexing completed: {stats}") + return stats + + def index_image_embeddings( + self, batch_size: int = 32, skip: int = 0, limit: Optional[int] = None + ) -> Dict[str, int]: + """Generate and store image embeddings for products + + Args: + batch_size: Number of images to process at once + skip: Number of products to skip + limit: Maximum number of products to process (None for all) + + Returns: + Dictionary with indexing statistics + """ + logger.info("Starting image embedding indexing...") + + # Get products list + product_ids = list(self.products.keys())[skip:] + if limit: + product_ids = product_ids[:limit] + + total_products = len(product_ids) + processed = 0 + inserted = 0 + errors = 0 + + with tqdm(total=total_products, desc="Indexing image embeddings") as pbar: + while processed < total_products: + # Get batch of products + current_batch_size = min(batch_size, total_products - processed) + batch_ids = product_ids[processed : processed + current_batch_size] + products = [self.products[pid] for pid in batch_ids] + + if not products: + break + + try: + # Prepare image paths + image_paths = [] + image_mappings = [] + + for product in products: + 
image_path = self.image_dir / product["imagePath"] + image_paths.append(image_path) + image_mappings.append( + { + "product_id": product["id"], + "image_path": product["imagePath"], + } + ) + + # Generate embeddings + embeddings = self.embedding_service.get_image_embeddings_batch( + image_paths, batch_size=batch_size + ) + + # Prepare data for Milvus (with metadata) + milvus_data = [] + for idx, (mapping, embedding) in enumerate( + zip(image_mappings, embeddings) + ): + if embedding is not None: + product_id = mapping["product_id"] + product = self.products[product_id] + + milvus_data.append( + { + "id": product_id, + "image_path": mapping["image_path"], + "embedding": embedding, + # Product metadata + "productDisplayName": product["productDisplayName"][ + :500 + ], + "gender": product["gender"][:50], + "masterCategory": product["masterCategory"][:100], + "subCategory": product["subCategory"][:100], + "articleType": product["articleType"][:100], + "baseColour": product["baseColour"][:50], + "season": product["season"][:50], + "usage": product["usage"][:50], + "year": product["year"], + "imageUrl": product["imageUrl"], + } + ) + else: + errors += 1 + + # Insert into Milvus + if milvus_data: + count = self.milvus_service.insert_image_embeddings(milvus_data) + inserted += count + + except Exception as e: + logger.error( + f"Error processing image batch at offset {processed}: {e}" + ) + errors += len(products) + + processed += len(products) + pbar.update(len(products)) + + stats = {"total_processed": processed, "inserted": inserted, "errors": errors} + + logger.info(f"Image embedding indexing completed: {stats}") + return stats + + def _create_product_text(self, product: Dict[str, Any]) -> str: + """Create text representation of product for embedding + + Args: + product: Product document + + Returns: + Text representation + """ + # Create a natural language description + parts = [ + product.get("productDisplayName", ""), + f"Gender: {product.get('gender', '')}", + 
f"Category: {product.get('masterCategory', '')} > {product.get('subCategory', '')}", + f"Type: {product.get('articleType', '')}", + f"Color: {product.get('baseColour', '')}", + f"Season: {product.get('season', '')}", + f"Usage: {product.get('usage', '')}", + ] + + text = " | ".join( + [p for p in parts if p and p != "Gender: " and p != "Color: "] + ) + return text + + def get_stats(self) -> Dict[str, Any]: + """Get indexing statistics + + Returns: + Dictionary with statistics + """ + text_stats = self.milvus_service.get_collection_stats( + self.milvus_service.text_collection_name + ) + image_stats = self.milvus_service.get_collection_stats( + self.milvus_service.image_collection_name + ) + + return { + "total_products": len(self.products), + "milvus_text": text_stats, + "milvus_image": image_stats, + } + + +def main(): + """Main function""" + import argparse + + parser = argparse.ArgumentParser(description="Index product data for search") + parser.add_argument( + "--mode", + choices=["text", "image", "both"], + default="both", + help="Which embeddings to index", + ) + parser.add_argument( + "--batch-size", type=int, default=100, help="Batch size for processing" + ) + parser.add_argument( + "--skip", type=int, default=0, help="Number of products to skip" + ) + parser.add_argument( + "--limit", type=int, default=None, help="Maximum number of products to process" + ) + parser.add_argument("--stats", action="store_true", help="Show statistics only") + + args = parser.parse_args() + + # Create indexer + indexer = DataIndexer() + + try: + # Setup services + indexer.setup() + + if args.stats: + # Show statistics + stats = indexer.get_stats() + print("\n=== Indexing Statistics ===") + print(f"\nTotal Products in CSV: {stats['total_products']}") + + print("\nMilvus Text Embeddings:") + print(f" Collection: {stats['milvus_text']['collection_name']}") + print(f" Total embeddings: {stats['milvus_text']['row_count']}") + + print("\nMilvus Image Embeddings:") + print(f" 
Collection: {stats['milvus_image']['collection_name']}") + print(f" Total embeddings: {stats['milvus_image']['row_count']}") + + print( + f"\nCoverage: {stats['milvus_image']['row_count'] / stats['total_products'] * 100:.1f}%" + ) + else: + # Index data + if args.mode in ["text", "both"]: + logger.info("=== Indexing Text Embeddings ===") + text_stats = indexer.index_text_embeddings( + batch_size=args.batch_size, skip=args.skip, limit=args.limit + ) + print(f"\nText Indexing Results: {text_stats}") + + if args.mode in ["image", "both"]: + logger.info("=== Indexing Image Embeddings ===") + image_stats = indexer.index_image_embeddings( + batch_size=min(args.batch_size, 32), # Smaller batch for images + skip=args.skip, + limit=args.limit, + ) + print(f"\nImage Indexing Results: {image_stats}") + + # Show final statistics + logger.info("\n=== Final Statistics ===") + stats = indexer.get_stats() + print(f"Total products: {stats['total_products']}") + print(f"Text embeddings: {stats['milvus_text']['row_count']}") + print(f"Image embeddings: {stats['milvus_image']['row_count']}") + + except KeyboardInterrupt: + logger.info("\nIndexing interrupted by user") + except Exception as e: + logger.error(f"Error during indexing: {e}", exc_info=True) + sys.exit(1) + finally: + indexer.teardown() + + +if __name__ == "__main__": + main() diff --git a/scripts/run_clip.sh b/scripts/run_clip.sh new file mode 100755 index 0000000..f095625 --- /dev/null +++ b/scripts/run_clip.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +# ============================================================================= +# OmniShopAgent - 启动 CLIP 图像向量服务 +# 图像搜索、以图搜图功能依赖此服务 +# ============================================================================= +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." 
&& pwd)" +VENV_DIR="${VENV_DIR:-$PROJECT_ROOT/venv}" + +cd "$PROJECT_ROOT" + +if [ -d "$VENV_DIR" ]; then + set +u + source "$VENV_DIR/bin/activate" + set -u +fi + +echo "启动 CLIP 服务 (端口 51000)..." +echo "按 Ctrl+C 停止" +exec python -m clip_server diff --git a/scripts/run_milvus.sh b/scripts/run_milvus.sh new file mode 100755 index 0000000..190a816 --- /dev/null +++ b/scripts/run_milvus.sh @@ -0,0 +1,31 @@ +#!/usr/bin/env bash +# ============================================================================= +# OmniShopAgent - 启动 Milvus 向量数据库 +# 使用 Docker Compose 启动 Milvus 及相关依赖 +# ============================================================================= +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" + +cd "$PROJECT_ROOT" + +if ! command -v docker &>/dev/null; then + echo "错误: 未安装 Docker。请先运行 setup_env_centos8.sh" + exit 1 +fi + +echo "启动 Milvus..." +docker compose up -d 2>/dev/null || docker-compose up -d 2>/dev/null || { + echo "错误: 无法执行 docker compose。请确保已安装 Docker Compose" + exit 1 +} + +echo "等待 Milvus 就绪 (约 60 秒)..." +sleep 60 + +if curl -s -o /dev/null -w "%{http_code}" http://localhost:9091/healthz 2>/dev/null | grep -q 200; then + echo "Milvus 已就绪: localhost:19530" +else + echo "提示: Milvus 可能仍在启动,请稍后执行 check_services.sh 检查" +fi diff --git a/scripts/setup_env_centos8.sh b/scripts/setup_env_centos8.sh new file mode 100755 index 0000000..d1571b4 --- /dev/null +++ b/scripts/setup_env_centos8.sh @@ -0,0 +1,152 @@ +#!/usr/bin/env bash +# ============================================================================= +# OmniShopAgent - CentOS 8 环境准备脚本 +# 准备 Python 3.12、Docker、依赖及虚拟环境 +# ============================================================================= +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." 
&& pwd)" +VENV_DIR="${VENV_DIR:-$PROJECT_ROOT/venv}" +PYTHON_VERSION="${PYTHON_VERSION:-3.12}" + +echo "==========================================" +echo "OmniShopAgent - CentOS 8 环境准备" +echo "==========================================" +echo "项目目录: $PROJECT_ROOT" +echo "虚拟环境: $VENV_DIR" +echo "Python 版本: $PYTHON_VERSION" +echo "==========================================" + +# ----------------------------------------------------------------------------- +# 1. 安装系统依赖 +# ----------------------------------------------------------------------------- +echo "[1/4] 安装系统依赖..." +sudo dnf install -y \ + gcc \ + gcc-c++ \ + make \ + openssl-devel \ + bzip2-devel \ + libffi-devel \ + sqlite-devel \ + xz-devel \ + zlib-devel \ + readline-devel \ + tk-devel \ + libuuid-devel \ + curl \ + wget \ + git \ + tar + +# ----------------------------------------------------------------------------- +# 2. 安装 Docker(用于 Milvus) +# ----------------------------------------------------------------------------- +echo "[2/4] 检查/安装 Docker..." +if ! command -v docker &>/dev/null; then + echo " 安装 Docker..." + sudo dnf config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.repo 2>/dev/null || { + sudo dnf install -y dnf-plugins-core + sudo dnf config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.repo + } + sudo dnf install -y docker-ce docker-ce-cli containerd.io docker-compose-plugin + sudo systemctl enable docker + sudo systemctl start docker + sudo usermod -aG docker "$USER" 2>/dev/null || true + echo " Docker 已安装。请执行 'newgrp docker' 或重新登录以使用 docker 命令。" +else + echo " Docker 已安装: $(docker --version)" +fi + +# ----------------------------------------------------------------------------- +# 3. 安装 Python 3.12 +# ----------------------------------------------------------------------------- +echo "[3/4] 安装 Python $PYTHON_VERSION..." 
+ +USE_CONDA=false +if command -v python3.12 &>/dev/null; then + echo " Python 3.12 已安装" +elif command -v conda &>/dev/null; then + echo " 使用 conda 创建 Python $PYTHON_VERSION 环境..." + conda create -n shop_agent "python=$PYTHON_VERSION" -y + USE_CONDA=true + echo " Conda 环境已创建。请执行: conda activate shop_agent" + echo " 然后手动执行: pip install -r $PROJECT_ROOT/requirements.txt" + echo " 跳过 venv 创建..." +else + echo " 从源码编译 Python $PYTHON_VERSION..." + sudo dnf groupinstall -y 'Development Tools' + cd /tmp + PY_URL="https://www.python.org/ftp/python/${PYTHON_VERSION}/Python-${PYTHON_VERSION}.0.tgz" + PY_TGZ="Python-${PYTHON_VERSION}.0.tgz" + [ -f "$PY_TGZ" ] || wget -q "$PY_URL" -O "$PY_TGZ" + tar xzf "$PY_TGZ" + cd "Python-${PYTHON_VERSION}.0" + ./configure --enable-optimizations --prefix=/usr/local + make -j "$(nproc)" + sudo make altinstall + cd /tmp + rm -rf "Python-${PYTHON_VERSION}.0" "$PY_TGZ" +fi + +# ----------------------------------------------------------------------------- +# 4. 创建虚拟环境并安装依赖(非 conda 时) +# ----------------------------------------------------------------------------- +if [ "$USE_CONDA" = true ]; then + echo "[4/4] 已使用 conda,跳过 venv 创建" +else + echo "[4/4] 创建虚拟环境与安装 Python 依赖..." + + PYTHON_BIN="" + for p in python3.12 python3.11 python3; do + if command -v "$p" &>/dev/null; then + VER=$("$p" -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")' 2>/dev/null || echo "0") + if [[ "$VER" =~ ^3\.(1[0-9]|[2-9][0-9])$ ]]; then # 仅接受 Python 3.10+ + PYTHON_BIN="$p" + break + fi + fi + done + + if [ -z "$PYTHON_BIN" ]; then + echo " 错误: 未找到 Python 3.10+。若使用 conda,请先执行: conda activate shop_agent" + echo " 然后手动执行: pip install -r $PROJECT_ROOT/requirements.txt" + exit 1 + fi + + if [ ! -d "$VENV_DIR" ]; then + echo " 创建虚拟环境: $VENV_DIR" + "$PYTHON_BIN" -m venv "$VENV_DIR" + fi + + echo " 激活虚拟环境并安装依赖..." 
+ set +u + source "$VENV_DIR/bin/activate" + set -u + pip install -U pip + pip install -r "$PROJECT_ROOT/requirements.txt" + echo " Python 依赖安装完成。" +fi + +# 配置 .env +if [ ! -f "$PROJECT_ROOT/.env" ]; then + echo "" + echo " 创建 .env 配置文件..." + cp "$PROJECT_ROOT/.env.example" "$PROJECT_ROOT/.env" + echo " 请编辑 $PROJECT_ROOT/.env 配置 OPENAI_API_KEY 等参数。" +fi + +echo "" +echo "==========================================" +echo "环境准备完成!" +echo "==========================================" +echo "下一步:" +echo " 1. 编辑 .env 配置 OPENAI_API_KEY" +echo " 2. 下载数据: python scripts/download_dataset.py" +echo " 3. 启动 Milvus: ./scripts/run_milvus.sh" +echo " 4. 索引数据: python scripts/index_data.py" +echo " 5. 启动应用: ./scripts/start.sh" +echo "" +echo "激活虚拟环境: source $VENV_DIR/bin/activate" +echo "==========================================" diff --git a/scripts/start.sh b/scripts/start.sh new file mode 100755 index 0000000..fe42510 --- /dev/null +++ b/scripts/start.sh @@ -0,0 +1,63 @@ +#!/usr/bin/env bash +# ============================================================================= +# OmniShopAgent - 启动脚本 +# 启动 Milvus、CLIP(可选)、Streamlit 应用 +# ============================================================================= +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" +VENV_DIR="${VENV_DIR:-$PROJECT_ROOT/venv}" +STREAMLIT_PORT="${STREAMLIT_PORT:-8501}" +STREAMLIT_HOST="${STREAMLIT_HOST:-0.0.0.0}" + +cd "$PROJECT_ROOT" + +# 激活虚拟环境 +if [ -d "$VENV_DIR" ]; then + echo "激活虚拟环境: $VENV_DIR" + set +u + source "$VENV_DIR/bin/activate" + set -u +else + echo "警告: 未找到虚拟环境 $VENV_DIR,使用当前 Python" +fi + +echo "==========================================" +echo "OmniShopAgent 启动" +echo "==========================================" + +# 1. 启动 Milvus(Docker) +if command -v docker &>/dev/null; then + echo "[1/3] 检查 Milvus..." + if ! 
docker ps --format '{{.Names}}' 2>/dev/null | grep -q milvus-standalone; then + echo " 启动 Milvus (docker compose)..." + docker compose up -d 2>/dev/null || docker-compose up -d 2>/dev/null || { + echo " 警告: 无法启动 Milvus,请手动执行: docker compose up -d" + } + echo " 等待 Milvus 就绪 (30s)..." + sleep 30 + else + echo " Milvus 已运行" + fi +else + echo "[1/3] 跳过 Milvus: 未安装 Docker" +fi + +# 2. 检查 CLIP(可选,图像搜索需要) +echo "[2/3] 检查 CLIP 服务..." +echo " 提示: 图像搜索需 CLIP。若未启动,请另开终端执行: python -m clip_server" +echo " 仅文本搜索时无需 CLIP。" + +# 3. 启动 Streamlit +echo "[3/3] 启动 Streamlit (端口 $STREAMLIT_PORT)..." +echo "" +echo " 访问: http://$STREAMLIT_HOST:$STREAMLIT_PORT" +echo " 按 Ctrl+C 停止" +echo "==========================================" + +exec streamlit run app.py \ + --server.port="$STREAMLIT_PORT" \ + --server.address="$STREAMLIT_HOST" \ + --server.headless=true \ + --browser.gatherUsageStats=false diff --git a/scripts/stop.sh b/scripts/stop.sh new file mode 100755 index 0000000..ba2f64d --- /dev/null +++ b/scripts/stop.sh @@ -0,0 +1,46 @@ +#!/usr/bin/env bash +# ============================================================================= +# OmniShopAgent - 停止脚本 +# 停止 Streamlit 进程及 Milvus 容器 +# ============================================================================= +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)" +STREAMLIT_PORT="${STREAMLIT_PORT:-8501}" + +echo "==========================================" +echo "OmniShopAgent 停止" +echo "==========================================" + +# 1. 停止 Streamlit 进程 +echo "[1/2] 停止 Streamlit..." 
+if pgrep -f "streamlit run app.py" >/dev/null 2>&1; then + pkill -f "streamlit run app.py" 2>/dev/null || true + echo " Streamlit 已停止" +else + echo " Streamlit 未在运行" +fi + +# 按端口查找并终止 +if command -v lsof &>/dev/null; then + PID=$(lsof -ti:$STREAMLIT_PORT 2>/dev/null || true) + if [ -n "$PID" ]; then + kill $PID 2>/dev/null || true + echo " 已终止端口 $STREAMLIT_PORT 上的进程" + fi +fi + +# 2. 可选:停止 Milvus 容器 +echo "[2/2] 停止 Milvus..." +if command -v docker &>/dev/null; then + cd "$PROJECT_ROOT" + docker compose down 2>/dev/null || docker-compose down 2>/dev/null || true + echo " Milvus 已停止" +else + echo " Docker 未安装,跳过" +fi + +echo "==========================================" +echo "OmniShopAgent 已停止" +echo "==========================================" diff --git a/技术实现报告.md b/技术实现报告.md new file mode 100644 index 0000000..d6f94d3 --- /dev/null +++ b/技术实现报告.md @@ -0,0 +1,624 @@ +# OmniShopAgent 项目技术实现报告 + +## 一、项目概述 + +OmniShopAgent 是一个基于 **LangGraph** 和 **ReAct 模式** 的自主多模态时尚购物智能体。系统能够自主决定调用哪些工具、维护对话状态、判断何时回复,实现智能化的商品发现与推荐。 + +### 核心特性 + +- **自主工具选择与执行**:Agent 根据用户意图自主选择并调用工具 +- **多模态搜索**:支持文本搜索 + 图像搜索 +- **对话上下文感知**:多轮对话中保持上下文记忆 +- **实时视觉分析**:基于 VLM 的图片风格分析 + +--- + +## 二、技术栈 + +| 组件 | 技术选型 | +|------|----------| +| 运行环境 | Python 3.12 | +| Agent 框架 | LangGraph 1.x | +| LLM 框架 | LangChain 1.x(支持任意 LLM,默认 gpt-4o-mini) | +| 文本向量 | text-embedding-3-small | +| 图像向量 | CLIP ViT-B/32 | +| 向量数据库 | Milvus | +| 前端 | Streamlit | +| 数据集 | Kaggle Fashion Products | + +--- + +## 三、系统架构 + +### 3.1 整体架构图 + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ Streamlit 前端 (app.py) │ +└─────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ ShoppingAgent (shopping_agent.py) │ +│ ┌───────────────────────────────────────────────────────────┐ │ +│ │ LangGraph StateGraph + ReAct Pattern │ │ +│ │ START → Agent → [Has tool_calls?] 
→ Tools → Agent → END │ │ +│ └───────────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────────┘ + │ │ │ + ▼ ▼ ▼ +┌──────────────┐ ┌──────────────────┐ ┌─────────────────────┐ +│ search_ │ │ search_by_image │ │ analyze_image_style │ +│ products │ │ │ │ (OpenAI Vision) │ +└──────┬───────┘ └────────┬─────────┘ └──────────┬───────────┘ + │ │ │ + ▼ ▼ ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ EmbeddingService (embedding_service.py) │ +│ OpenAI API (文本) │ CLIP Server (图像) │ +└─────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ MilvusService (milvus_service.py) │ +│ text_embeddings 集合 │ image_embeddings 集合 │ +└─────────────────────────────────────────────────────────────────┘ +``` + +### 3.2 Agent 流程图(LangGraph) + +```mermaid +graph LR + START --> Agent + Agent -->|Has tool_calls| Tools + Agent -->|No tool_calls| END + Tools --> Agent +``` + +--- + +## 四、关键代码实现 + +### 4.1 Agent 核心实现(shopping_agent.py) + +#### 4.1.1 状态定义 + +```python +from typing_extensions import Annotated, TypedDict +from langgraph.graph.message import add_messages + +class AgentState(TypedDict): + """State for the shopping agent with message accumulation""" + messages: Annotated[Sequence[BaseMessage], add_messages] + current_image_path: Optional[str] # Track uploaded image +``` + +- `messages` 使用 `add_messages` 实现消息累加,支持多轮对话 +- `current_image_path` 存储当前上传的图片路径供工具使用 + +#### 4.1.2 LangGraph 图构建 + +```python +def _build_graph(self): + """Build the LangGraph StateGraph""" + + def agent_node(state: AgentState): + """Agent decision node - decides which tools to call or when to respond""" + messages = state["messages"] + if not any(isinstance(m, SystemMessage) for m in messages): + messages = [SystemMessage(content=system_prompt)] + list(messages) + response = self.llm_with_tools.invoke(messages) + return {"messages": 
[response]} + + tool_node = ToolNode(self.tools) + + def should_continue(state: AgentState): + """Determine if agent should continue or end""" + last_message = state["messages"][-1] + if hasattr(last_message, "tool_calls") and last_message.tool_calls: + return "tools" + return END + + workflow = StateGraph(AgentState) + workflow.add_node("agent", agent_node) + workflow.add_node("tools", tool_node) + workflow.add_edge(START, "agent") + workflow.add_conditional_edges("agent", should_continue, ["tools", END]) + workflow.add_edge("tools", "agent") + + checkpointer = MemorySaver() + return workflow.compile(checkpointer=checkpointer) +``` + +关键点: +- **agent_node**:将消息传入 LLM,由 LLM 决定是否调用工具 +- **should_continue**:若有 `tool_calls` 则进入工具节点,否则结束 +- **MemorySaver**:按 `thread_id` 持久化对话状态 + +#### 4.1.3 System Prompt 设计 + +```python +system_prompt = """You are an intelligent fashion shopping assistant. You can: +1. Search for products by text description (use search_products) +2. Find visually similar products from images (use search_by_image) +3. Analyze image style and attributes (use analyze_image_style) + +When a user asks about products: +- For text queries: use search_products directly +- For image uploads: decide if you need to analyze_image_style first, then search +- You can call multiple tools in sequence if needed +- Always provide helpful, friendly responses + +CRITICAL FORMATTING RULES: +When presenting product results, you MUST use this EXACT format for EACH product: +1. 
[Product Name] + ID: [Product ID Number] + Category: [Category] + Color: [Color] + Gender: [Gender] + (Include Season, Usage, Relevance if available) +...""" +``` + +通过 system prompt 约束工具使用和输出格式,保证前端可正确解析产品信息。 + +#### 4.1.4 对话入口与流式处理 + +```python +def chat(self, query: str, image_path: Optional[str] = None) -> dict: + # Build input message + message_content = query + if image_path: + message_content = f"{query}\n[User uploaded image: {image_path}]" + + config = {"configurable": {"thread_id": self.session_id}} + input_state = { + "messages": [HumanMessage(content=message_content)], + "current_image_path": image_path, + } + + tool_calls = [] + for event in self.graph.stream(input_state, config=config): + if "agent" in event: + for msg in event["agent"].get("messages", []): + if hasattr(msg, "tool_calls") and msg.tool_calls: + for tc in msg.tool_calls: + tool_calls.append({"name": tc["name"], "args": tc.get("args", {})}) + if "tools" in event: + # 记录工具执行结果 + ... + + final_state = self.graph.get_state(config) + response_text = final_state.values["messages"][-1].content + + return {"response": response_text, "tool_calls": tool_calls, "error": False} +``` + +--- + +### 4.2 搜索工具实现(search_tools.py) + +#### 4.2.1 文本语义搜索 + +```python +@tool +def search_products(query: str, limit: int = 5) -> str: + """Search for fashion products using natural language descriptions.""" + try: + embedding_service = get_embedding_service() + milvus_service = get_milvus_service() + + query_embedding = embedding_service.get_text_embedding(query) + + results = milvus_service.search_similar_text( + query_embedding=query_embedding, + limit=min(limit, 20), + filters=None, + output_fields=[ + "id", "productDisplayName", "gender", "masterCategory", + "subCategory", "articleType", "baseColour", "season", "usage", + ], + ) + + if not results: + return "No products found matching your search." 
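+        # Note: the exact field layout emitted below (numbered name, then ID,
+        # Category, Color, Gender) deliberately mirrors the format enforced by
+        # the agent's system prompt, so extract_products_from_response() in
+        # app.py can parse the products back out of the model's final reply.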
+ + output = f"Found {len(results)} product(s):\n\n" + for idx, product in enumerate(results, 1): + output += f"{idx}. {product.get('productDisplayName', 'Unknown Product')}\n" + output += f" ID: {product.get('id', 'N/A')}\n" + output += f" Category: {product.get('masterCategory')} > {product.get('subCategory')} > {product.get('articleType')}\n" + output += f" Color: {product.get('baseColour')}\n" + output += f" Gender: {product.get('gender')}\n" + if "distance" in product: + similarity = 1 - product["distance"] + output += f" Relevance: {similarity:.2%}\n" + output += "\n" + + return output.strip() + except Exception as e: + return f"Error searching products: {str(e)}" +``` + +#### 4.2.2 图像相似度搜索 + +```python +@tool +def search_by_image(image_path: str, limit: int = 5) -> str: + """Find similar fashion products using an image.""" + if not Path(image_path).exists(): + return f"Error: Image file not found at '{image_path}'" + + embedding_service = get_embedding_service() + milvus_service = get_milvus_service() + + if not embedding_service.clip_client: + embedding_service.connect_clip() + + image_embedding = embedding_service.get_image_embedding(image_path) + + results = milvus_service.search_similar_images( + query_embedding=image_embedding, + limit=min(limit + 1, 21), + output_fields=[...], + ) + + # 过滤掉查询图像本身(如上传的是商品库中的图) + query_id = Path(image_path).stem + filtered_results = [r for r in results if Path(r.get("image_path", "")).stem != query_id] + filtered_results = filtered_results[:limit] + + +``` + +#### 4.2.3 视觉分析(VLM) + +```python +@tool +def analyze_image_style(image_path: str) -> str: + """Analyze a fashion product image using AI vision to extract detailed style information.""" + with open(img_path, "rb") as image_file: + image_data = base64.b64encode(image_file.read()).decode("utf-8") + + prompt = """Analyze this fashion product image and provide a detailed description. 
+Include: +- Product type (e.g., shirt, dress, shoes, pants, bag) +- Primary colors +- Style/design (e.g., casual, formal, sporty, vintage, modern) +- Pattern or texture (e.g., plain, striped, checked, floral) +- Key features (e.g., collar type, sleeve length, fit) +- Material appearance (if obvious, e.g., denim, cotton, leather) +- Suitable occasion (e.g., office wear, party, casual, sports) +Provide a comprehensive yet concise description (3-4 sentences).""" + + client = get_openai_client() + response = client.chat.completions.create( + model="gpt-4o-mini", + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_data}", "detail": "high"}}, + ], + }], + max_tokens=500, + temperature=0.3, + ) + + return response.choices[0].message.content.strip() +``` + +--- + +### 4.3 向量服务实现 + +#### 4.3.1 EmbeddingService(embedding_service.py) + +```python +class EmbeddingService: + def get_text_embedding(self, text: str) -> List[float]: + """OpenAI text-embedding-3-small""" + response = self.openai_client.embeddings.create( + input=text, model=self.text_embedding_model + ) + return response.data[0].embedding + + def get_image_embedding(self, image_path: Union[str, Path]) -> List[float]: + """CLIP 图像向量""" + if not self.clip_client: + raise RuntimeError("CLIP client not connected. Call connect_clip() first.") + result = self.clip_client.encode([str(image_path)]) + if isinstance(result, np.ndarray): + embedding = result[0].tolist() if len(result.shape) > 1 else result.tolist() + else: + embedding = result[0].embedding.tolist() + return embedding + + def get_text_embeddings_batch(self, texts: List[str], batch_size: int = 100) -> List[List[float]]: + """批量文本嵌入,用于索引""" + for i in range(0, len(texts), batch_size): + batch = texts[i : i + batch_size] + response = self.openai_client.embeddings.create(input=batch, ...) 
+ embeddings = [item.embedding for item in response.data] + all_embeddings.extend(embeddings) + return all_embeddings +``` + +#### 4.3.2 MilvusService(milvus_service.py) + +**文本集合 Schema:** + +```python +schema = MilvusClient.create_schema(auto_id=False, enable_dynamic_field=True) +schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True) +schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=2000) +schema.add_field(field_name="embedding", datatype=DataType.FLOAT_VECTOR, dim=self.text_dim) # 1536 +schema.add_field(field_name="productDisplayName", datatype=DataType.VARCHAR, max_length=500) +schema.add_field(field_name="gender", datatype=DataType.VARCHAR, max_length=50) +schema.add_field(field_name="masterCategory", datatype=DataType.VARCHAR, max_length=100) +# ... 更多元数据字段 +``` + +**图像集合 Schema:** + +```python +schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True) +schema.add_field(field_name="image_path", datatype=DataType.VARCHAR, max_length=500) +schema.add_field(field_name="embedding", datatype=DataType.FLOAT_VECTOR, dim=self.image_dim) # 512 +# ... 
产品元数据 +``` + +**相似度搜索:** + +```python +def search_similar_text(self, query_embedding, limit=10, output_fields=None): + results = self.client.search( + collection_name=self.text_collection_name, + data=[query_embedding], + limit=limit, + output_fields=output_fields, + ) + formatted_results = [] + for hit in results[0]: + result = {"id": hit.get("id"), "distance": hit.get("distance")} + entity = hit.get("entity", {}) + for field in output_fields: + if field in entity: + result[field] = entity.get(field) + formatted_results.append(result) + return formatted_results +``` + +--- + +### 4.4 数据索引脚本(index_data.py) + +#### 4.4.1 产品数据加载 + +```python +def _load_products_from_csv(self) -> Dict[int, Dict[str, Any]]: + products = {} + # 加载 images.csv 映射 + with open(self.images_csv, "r") as f: + images_dict = {int(row["filename"].split(".")[0]): row["link"] for row in csv.DictReader(f)} + + # 加载 styles.csv + with open(self.styles_csv, "r") as f: + for row in csv.DictReader(f): + product_id = int(row["id"]) + products[product_id] = { + "id": product_id, + "gender": row.get("gender", ""), + "masterCategory": row.get("masterCategory", ""), + "subCategory": row.get("subCategory", ""), + "articleType": row.get("articleType", ""), + "baseColour": row.get("baseColour", ""), + "season": row.get("season", ""), + "usage": row.get("usage", ""), + "productDisplayName": row.get("productDisplayName", ""), + "imagePath": f"{product_id}.jpg", + } + return products +``` + +#### 4.4.2 文本索引 + +```python +def _create_product_text(self, product: Dict[str, Any]) -> str: + """构造产品文本用于 embedding""" + parts = [ + product.get("productDisplayName", ""), + f"Gender: {product.get('gender', '')}", + f"Category: {product.get('masterCategory', '')} > {product.get('subCategory', '')}", + f"Type: {product.get('articleType', '')}", + f"Color: {product.get('baseColour', '')}", + f"Season: {product.get('season', '')}", + f"Usage: {product.get('usage', '')}", + ] + return " | ".join([p for p in parts if p and p != 
"Gender: " and p != "Color: "]) +``` + +#### 4.4.3 批量索引流程 + +```python +# 文本索引 +texts = [self._create_product_text(p) for p in products] +embeddings = self.embedding_service.get_text_embeddings_batch(texts, batch_size=50) +milvus_data = [{ + "id": product_id, + "text": text[:2000], + "embedding": embedding, + "productDisplayName": product["productDisplayName"][:500], + "gender": product["gender"][:50], + # ... 其他元数据 +} for product_id, text, embedding in zip(...)] +self.milvus_service.insert_text_embeddings(milvus_data) + +# 图像索引 +image_paths = [self.image_dir / p["imagePath"] for p in products] +embeddings = self.embedding_service.get_image_embeddings_batch(image_paths, batch_size=32) +# 类似插入 image_embeddings 集合 +``` + +--- + +### 4.5 Streamlit 前端(app.py) + +#### 4.5.1 会话与 Agent 初始化 + +```python +def initialize_session(): + if "session_id" not in st.session_state: + st.session_state.session_id = str(uuid.uuid4()) + if "shopping_agent" not in st.session_state: + st.session_state.shopping_agent = ShoppingAgent(session_id=st.session_state.session_id) + if "messages" not in st.session_state: + st.session_state.messages = [] + if "uploaded_image" not in st.session_state: + st.session_state.uploaded_image = None +``` + +#### 4.5.2 产品信息解析 + +```python +def extract_products_from_response(response: str) -> list: + """从 Agent 回复中解析产品信息""" + products = [] + for line in response.split("\n"): + if re.match(r"^\*?\*?\d+\.\s+", line): + if current_product: + products.append(current_product) + current_product = {"name": re.sub(r"^\*?\*?\d+\.\s+", "", line).replace("**", "").strip()} + elif "ID:" in line: + id_match = re.search(r"(?:ID|id):\s*(\d+)", line) + if id_match: + current_product["id"] = id_match.group(1) + elif "Category:" in line: + cat_match = re.search(r"Category:\s*(.+?)(?:\n|$)", line) + if cat_match: + current_product["category"] = cat_match.group(1).strip() + # ... 
Color, Gender, Season, Usage, Similarity/Relevance + return products +``` + +#### 4.5.3 多轮对话中的图片引用 + +```python +# 用户输入 "make them formal" 时,若上一条消息有图片,则引用该图片 +if any(ref in query_lower for ref in ["this", "that", "the image", "it"]): + for msg in reversed(st.session_state.messages): + if msg.get("role") == "user" and msg.get("image_path"): + image_path = msg["image_path"] + break +``` + +--- + +### 4.6 配置管理(config.py) + +```python +class Settings(BaseSettings): + openai_api_key: str + openai_model: str = "gpt-4o-mini" + openai_embedding_model: str = "text-embedding-3-small" + clip_server_url: str = "grpc://localhost:51000" + milvus_uri: str = "http://localhost:19530" + text_collection_name: str = "text_embeddings" + image_collection_name: str = "image_embeddings" + text_dim: int = 1536 + image_dim: int = 512 + + @property + def milvus_uri_absolute(self) -> str: + """支持 Milvus Standalone 和 Milvus Lite""" + if self.milvus_uri.startswith(("http://", "https://")): + return self.milvus_uri + if self.milvus_uri.startswith("./"): + return os.path.join(base_dir, self.milvus_uri[2:]) + return self.milvus_uri + + class Config: + env_file = ".env" +``` + +--- + +## 五、部署与运行 + +### 5.1 依赖服务 + +```yaml +# docker-compose.yml 提供 +- etcd: 元数据存储 +- minio: 对象存储 +- milvus-standalone: 向量数据库 +- attu: Milvus 管理界面 +``` + +### 5.2 启动流程 + +```bash +# 1. 环境 +pip install -r requirements.txt +cp .env.example .env # 配置 OPENAI_API_KEY + +# 2. 下载数据 +python scripts/download_dataset.py # Kaggle Fashion Product Images Dataset + +# 3. 启动 CLIP 服务(需单独运行) +python -m clip_server + +# 4. 启动 Milvus +docker-compose up + +# 5. 索引数据 +python scripts/index_data.py + +# 6. 启动应用 +streamlit run app.py +``` + +--- + +## 六、典型交互流程 + +| 场景 | 用户输入 | Agent 行为 | 工具调用 | +|------|----------|------------|----------| +| 文本搜索 | "winter coats for women" | 直接文本搜索 | `search_products("winter coats women")` | +| 图像搜索 | [上传图片] "find similar" | 图像相似度搜索 | `search_by_image(path)` | +| 风格分析+搜索 | [上传复古夹克] "what style? 
find matching pants" | 先分析风格再搜索 | `analyze_image_style(path)` → `search_products("vintage pants casual")` | +| 多轮上下文 | [第1轮] "show me red dresses"
[第2轮] "make them formal" | 结合上下文 | `search_products("red formal dresses")` | + +--- + +## 七、设计要点总结 + +1. **ReAct 模式**:Agent 自主决定何时调用工具、调用哪些工具、是否继续调用。 +2. **LangGraph 状态图**:`START → Agent → [条件] → Tools → Agent → END`,支持多轮工具调用。 +3. **多模态**:文本 + 图像 + VLM 分析,覆盖文本搜索、以图搜图、风格理解。 +4. **双向量集合**:Milvus 中 text_embeddings / image_embeddings 分别存储,支持不同模态的检索。 +5. **会话持久化**:`MemorySaver` + `thread_id` 实现多轮对话记忆。 +6. **格式约束**:System prompt 严格限制产品输出格式,便于前端解析和展示。 + +--- + +## 八、附录:项目结构 + +``` +OmniShopAgent/ +├── app/ +│ ├── agents/ +│ │ └── shopping_agent.py +│ ├── config.py +│ ├── services/ +│ │ ├── embedding_service.py +│ │ └── milvus_service.py +│ └── tools/ +│ └── search_tools.py +├── scripts/ +│ ├── download_dataset.py +│ └── index_data.py +├── app.py +├── docker-compose.yml +└── requirements.txt +``` -- libgit2 0.21.2