# data_transformer.py (10.6 KB)
"""
Data transformer for converting source data to ES documents.

Handles field mapping, type conversion, and embedding generation.
"""

import pandas as pd
import numpy as np
import datetime
from typing import Dict, Any, List, Optional
from config import SearchConfig, FieldConfig, FieldType
from embeddings import BgeEncoder, CLIPImageEncoder
from utils.cache import EmbeddingCache


class DataTransformer:
    """Transform source data into ES-ready documents.

    A batch is handled in two passes: embeddings for every configured
    embedding field are generated first (consulting the embedding caches
    when caching is enabled), then each row is converted into a document
    according to ``config.fields``.
    """

    # Directories handed to EmbeddingCache when caching is enabled.
    TEXT_CACHE_DIR = ".cache/text_embeddings"
    IMAGE_CACHE_DIR = ".cache/image_embeddings"

    # Images are encoded in batches no larger than this.
    MAX_IMAGE_BATCH_SIZE = 8

    # Stripped, lower-cased spellings treated as True for BOOLEAN fields.
    TRUE_STRINGS = frozenset({'true', '1', 'yes', 'y'})

    # strptime patterns tried, in order, for DATE fields given as strings.
    DATE_FORMATS = (
        '%Y-%m-%d %H:%M:%S',     # 2020-07-07 16:44:09
        '%Y-%m-%d %H:%M:%S.%f',  # 2020-07-07 16:44:09.123
        '%Y-%m-%dT%H:%M:%S',     # 2020-07-07T16:44:09
        '%Y-%m-%d',              # 2020-07-07
    )

    def __init__(
        self,
        config: "SearchConfig",
        text_encoder: "Optional[BgeEncoder]" = None,
        image_encoder: "Optional[CLIPImageEncoder]" = None,
        use_cache: bool = True
    ):
        """
        Initialize data transformer.

        Args:
            config: Search configuration; its ``fields`` drive the transformation
            text_encoder: Text embedding encoder (lazy loaded if not provided)
            image_encoder: Image embedding encoder (lazy loaded if not provided)
            use_cache: Whether to use embedding cache
        """
        self.config = config
        self._text_encoder = text_encoder
        self._image_encoder = image_encoder
        self.use_cache = use_cache

        if use_cache:
            self.text_cache = EmbeddingCache(self.TEXT_CACHE_DIR)
            self.image_cache = EmbeddingCache(self.IMAGE_CACHE_DIR)
        else:
            self.text_cache = None
            self.image_cache = None

    @property
    def text_encoder(self) -> "BgeEncoder":
        """Return the text encoder, creating it on first access."""
        if self._text_encoder is None:
            print("[DataTransformer] Initializing text encoder...")
            self._text_encoder = BgeEncoder()
        return self._text_encoder

    @property
    def image_encoder(self) -> "CLIPImageEncoder":
        """Return the image encoder, creating it on first access."""
        if self._image_encoder is None:
            print("[DataTransformer] Initializing image encoder...")
            self._image_encoder = CLIPImageEncoder()
        return self._image_encoder

    def transform_batch(
        self,
        df: pd.DataFrame,
        batch_size: int = 32
    ) -> List[Dict[str, Any]]:
        """
        Transform a batch of source data into ES documents.

        Rows that produce an empty document, or that lack a required field,
        are left out of the result.

        NOTE: rows are matched to their embeddings by index label, so rows
        sharing a label also share embeddings.

        Args:
            df: DataFrame with source data
            batch_size: Batch size for embedding generation

        Returns:
            List of ES documents
        """
        # First pass: generate all embeddings in batch
        embedding_data = self._generate_embeddings_batch(df, batch_size)

        # Second pass: build documents.
        # iterrows() builds one Series per row and upcasts mixed numeric
        # columns to a common dtype (ints become floats, so a keyword id
        # would render as "7.0"). Casting to object first keeps each cell's
        # own type.
        documents = []
        for idx, row in df.astype(object).iterrows():
            doc = self._transform_row(row, embedding_data.get(idx, {}))
            if doc:
                documents.append(doc)

        return documents

    def _generate_embeddings_batch(
        self,
        df: pd.DataFrame,
        batch_size: int
    ) -> Dict[Any, Dict[str, Any]]:
        """
        Generate all embeddings in batch for efficiency.

        Text embedding fields are processed first, then image embedding
        fields. Fields whose source column is absent from ``df`` are skipped.

        Args:
            df: Source dataframe
            batch_size: Batch size (images use at most MAX_IMAGE_BATCH_SIZE)

        Returns:
            Dictionary mapping row index label to ``{field name: embedding}``
        """
        result: Dict[Any, Dict[str, Any]] = {}

        text_fields = [
            field for field in self.config.fields
            if field.field_type == FieldType.TEXT_EMBEDDING
        ]
        image_fields = [
            field for field in self.config.fields
            if field.field_type == FieldType.IMAGE_EMBEDDING
        ]

        text_cache = self.text_cache if self.use_cache else None
        image_cache = self.image_cache if self.use_cache else None

        # The encoder properties are only touched inside the lambdas, so a
        # model is loaded only when something actually needs encoding.
        for field in text_fields:
            if field.source_column not in df.columns:
                continue
            print(f"[DataTransformer] Generating text embeddings for field: {field.name}")
            self._embed_column(
                df, field, text_cache,
                lambda items, size: self.text_encoder.encode_batch(
                    items, batch_size=size
                ),
                batch_size,
                result,
            )

        for field in image_fields:
            if field.source_column not in df.columns:
                continue
            print(f"[DataTransformer] Generating image embeddings for field: {field.name}")
            self._embed_column(
                df, field, image_cache,
                lambda items, size: self.image_encoder.encode_batch(
                    items, batch_size=size
                ),
                min(self.MAX_IMAGE_BATCH_SIZE, batch_size),
                result,
            )

        return result

    def _embed_column(
        self,
        df: pd.DataFrame,
        field: "FieldConfig",
        cache: Any,
        encode: Any,
        batch_size: int,
        result: Dict[Any, Dict[str, Any]]
    ) -> None:
        """
        Fill ``result`` with embeddings for one embedding field.

        Null and empty-string cells are skipped. Each distinct input is
        looked up in the cache and, on a miss, encoded once even when it
        appears in several rows. Inputs for which the encoder returns None
        are neither stored nor cached.

        Args:
            df: Source dataframe (must contain ``field.source_column``)
            field: Embedding field being generated
            cache: Object with exists/get/set, or None to bypass caching
            encode: Callable ``(inputs, batch_size)`` returning one embedding
                per input, in order
            batch_size: Batch size forwarded to ``encode``
            result: Mapping of row index label to embeddings, updated in place
        """
        # Distinct uncached input -> index labels of the rows that need it.
        pending: Dict[str, List[Any]] = {}

        for idx, raw in df[field.source_column].items():
            if self._is_missing(raw) or (isinstance(raw, str) and raw == ''):
                continue

            key = str(raw)
            if key in pending:
                pending[key].append(idx)
                continue

            cached = None
            if cache is not None and cache.exists(key):
                cached = cache.get(key)

            if cached is not None:
                result.setdefault(idx, {})[field.name] = cached
            else:
                pending[key] = [idx]

        if not pending:
            return

        inputs = list(pending)
        embeddings = encode(inputs, batch_size)

        for key, emb in zip(inputs, embeddings):
            if emb is None:
                continue
            for idx in pending[key]:
                result.setdefault(idx, {})[field.name] = emb
            if cache is not None:
                cache.set(key, emb)

    def _transform_row(
        self,
        row: pd.Series,
        embedding_data: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """
        Transform a single row into an ES document.

        Args:
            row: Source data row
            embedding_data: Pre-computed embeddings for this row

        Returns:
            ES document, or None if a required field is missing or null
        """
        doc = {}

        for field in self.config.fields:
            field_name = field.name

            # Handle embedding fields
            if field.field_type in (FieldType.TEXT_EMBEDDING, FieldType.IMAGE_EMBEDDING):
                emb = embedding_data.get(field_name)
                if emb is not None:
                    doc[field_name] = self._embedding_to_list(emb)
                continue

            # Handle regular fields
            source_col = field.source_column
            if source_col not in row:
                if field.required:
                    print(f"Warning: Required field '{field_name}' missing in row")
                    return None
                continue

            value = row[source_col]

            # Skip null values for non-required fields
            if self._is_missing(value):
                if field.required:
                    print(f"Warning: Required field '{field_name}' is null")
                    return None
                continue

            # Values that cannot be converted are left out of the document.
            converted_value = self._convert_value(value, field)
            if converted_value is not None:
                doc[field_name] = converted_value

        return doc

    def _convert_value(self, value: Any, field: "FieldConfig") -> Any:
        """
        Convert value to appropriate type for ES.

        Args:
            value: Raw cell value
            field: Field configuration; ``field.field_type`` selects the conversion

        Returns:
            Converted value, or None when the value is null or cannot be
            converted. Field types without a dedicated conversion return the
            value unchanged.
        """
        if self._is_missing(value):
            return None

        field_type = field.field_type

        if field_type in (FieldType.TEXT, FieldType.KEYWORD):
            return str(value)

        if field_type in (FieldType.INT, FieldType.LONG):
            try:
                return int(value)
            except (ValueError, TypeError, OverflowError):
                # OverflowError: int(inf)
                return None

        if field_type in (FieldType.FLOAT, FieldType.DOUBLE):
            try:
                number = float(value)
            except (ValueError, TypeError, OverflowError):
                # OverflowError: float() of an int too large for a double
                return None
            # ES numeric fields accept finite values only.
            return number if np.isfinite(number) else None

        if field_type == FieldType.BOOLEAN:
            # numpy scalars are what pandas yields from homogeneous columns.
            if isinstance(value, (bool, np.bool_)):
                return bool(value)
            if isinstance(value, (int, float, np.integer, np.floating)):
                return bool(value)
            if isinstance(value, str):
                return value.strip().lower() in self.TRUE_STRINGS
            return None

        if field_type == FieldType.DATE:
            # pd.Timestamp is a datetime.datetime subclass, so it lands here.
            if isinstance(value, (datetime.datetime, datetime.date)):
                return value.isoformat()
            if isinstance(value, np.datetime64):
                return pd.Timestamp(value).isoformat()
            if isinstance(value, str):
                text = value.strip()
                for fmt in self.DATE_FORMATS:
                    try:
                        return datetime.datetime.strptime(text, fmt).isoformat()
                    except ValueError:
                        continue
                # If no format matches, return original string
                return value
            return value

        return value

    @staticmethod
    def _is_missing(value: Any) -> bool:
        """Return True when ``value`` is a null scalar (None, NaN, NaT, ...).

        ``pd.isna`` answers element-wise for list/array-like input; such
        containers are treated as present rather than raising on an
        ambiguous truth value.
        """
        flag = pd.isna(value)
        if isinstance(flag, (bool, np.bool_)):
            return bool(flag)
        return False

    @staticmethod
    def _embedding_to_list(emb: Any) -> Any:
        """Return ``emb`` as plain nested lists of Python numbers.

        Objects exposing ``tolist()`` (e.g. numpy arrays) use it directly;
        other sequences are routed through numpy so numpy scalar elements are
        unboxed as well.
        """
        if hasattr(emb, 'tolist'):
            return emb.tolist()
        return np.asarray(emb).tolist()