Commit 0b73c877f36d43093c7b4fcd4cf662a40a58a95e
1 parent: d35d18eb
fix
Showing 6 changed files with 24 additions and 37 deletions
Show diff stats
| ... | ... | @@ -0,0 +1,7 @@ |
| 1 | +from modelscope.pipelines import pipeline | |
| 2 | +from modelscope.utils.constant import Tasks | |
| 3 | +from modelscope.outputs import OutputKeys | |
| 4 | + | |
| 5 | +img_captioning = pipeline(Tasks.image_captioning, model='iic/ofa_image-caption_coco_distilled_en', model_revision='master') | |
| 6 | +result = img_captioning('https://modelscope.oss-cn-beijing.aliyuncs.com/demo/image-captioning/donuts.jpeg') | |
| 7 | +print(result[OutputKeys.CAPTION]) # 'a wooden table topped with different types of donuts' | |
| 0 | 8 | \ No newline at end of file | ... | ... |
offline_tasks/scripts/a.py deleted
| ... | ... | @@ -1,36 +0,0 @@ |
| 1 | -from modelscope import AutoProcessor, Gemma3nForConditionalGeneration | |
| 2 | -from PIL import Image | |
| 3 | -import requests | |
| 4 | -import torch | |
| 5 | -model_id = "google/gemma-3n-e4b-it" | |
| 6 | -model = Gemma3nForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16,).eval() | |
| 7 | -processor = AutoProcessor.from_pretrained(model_id) | |
| 8 | -messages = [ | |
| 9 | - { | |
| 10 | - "role": "system", | |
| 11 | - "content": [{"type": "text", "text": "You are a helpful assistant."}] | |
| 12 | - }, | |
| 13 | - { | |
| 14 | - "role": "user", | |
| 15 | - "content": [ | |
| 16 | - {"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"}, | |
| 17 | - {"type": "text", "text": "Describe this image in detail."} | |
| 18 | - ] | |
| 19 | - } | |
| 20 | -] | |
| 21 | -inputs = processor.apply_chat_template( | |
| 22 | - messages, | |
| 23 | - add_generation_prompt=True, | |
| 24 | - tokenize=True, | |
| 25 | - return_dict=True, | |
| 26 | - return_tensors="pt", | |
| 27 | -).to(model.device) | |
| 28 | -input_len = inputs["input_ids"].shape[-1] | |
| 29 | -with torch.inference_mode(): | |
| 30 | - generation = model.generate(**inputs, max_new_tokens=100, do_sample=False) | |
| 31 | - generation = generation[0][input_len:] | |
| 32 | -decoded = processor.decode(generation, skip_special_tokens=True) | |
| 33 | -print(decoded) | |
| 34 | -# **Overall Impression:** The image is a close-up shot of a vibrant garden scene, | |
| 35 | -# focusing on a cluster of pink cosmos flowers and a busy bumblebee. | |
| 36 | -# It has a slightly soft, natural feel, likely captured in daylight. | |
| 37 | 0 | \ No newline at end of file |
offline_tasks/scripts/i2i_deepwalk.py
| ... | ... | @@ -280,6 +280,10 @@ def main(): |
| 280 | 280 | df = pd.read_sql(sql_query, engine) |
| 281 | 281 | logger.info(f"获取到 {len(df)} 条记录") |
| 282 | 282 | |
| 283 | + # 确保 item_id 为整数类型、user_id 为字符串类型 | |
| 284 | + df['item_id'] = df['item_id'].astype(int) | |
| 285 | + df['user_id'] = df['user_id'].astype(str) | |
| 286 | + | |
| 283 | 287 | # 记录数据信息 |
| 284 | 288 | log_dataframe_info(logger, df, "用户行为数据") |
| 285 | 289 | ... | ... |
offline_tasks/scripts/i2i_item_behavior.py
| ... | ... | @@ -56,6 +56,10 @@ if args.debug: |
| 56 | 56 | # 执行 SQL 查询并将结果加载到 pandas DataFrame |
| 57 | 57 | df = pd.read_sql(sql_query, engine) |
| 58 | 58 | |
| 59 | +# 确保 item_id 为整数类型 | |
| 60 | +df['item_id'] = df['item_id'].astype(int) | |
| 61 | +df['user_id'] = df['user_id'].astype(str) # user_id保持为字符串 | |
| 62 | + | |
| 59 | 63 | if args.debug: |
| 60 | 64 | print(f"[DEBUG] 查询完成,共 {len(df)} 条记录") |
| 61 | 65 | print(f"[DEBUG] 唯一用户数: {df['user_id'].nunique()}") | ... | ... |
offline_tasks/scripts/i2i_session_w2v.py
| ... | ... | @@ -227,6 +227,10 @@ def main(): |
| 227 | 227 | df = pd.read_sql(sql_query, engine) |
| 228 | 228 | logger.info(f"获取到 {len(df)} 条记录") |
| 229 | 229 | |
| 230 | + # 确保 item_id 为整数类型、user_id 为字符串类型 | |
| 231 | + df['item_id'] = df['item_id'].astype(int) | |
| 232 | + df['user_id'] = df['user_id'].astype(str) | |
| 233 | + | |
| 230 | 234 | # 记录数据信息 |
| 231 | 235 | log_dataframe_info(logger, df, "用户行为数据") |
| 232 | 236 | ... | ... |
offline_tasks/scripts/i2i_swing.py
| ... | ... | @@ -18,7 +18,7 @@ from config.offline_config import ( |
| 18 | 18 | from scripts.debug_utils import ( |
| 19 | 19 | setup_debug_logger, log_dataframe_info, log_dict_stats, |
| 20 | 20 | save_readable_index, load_name_mappings_from_file, log_algorithm_params, |
| 21 | - log_processing_step, clean_item_name | |
| 21 | + log_processing_step | |
| 22 | 22 | ) |
| 23 | 23 | |
| 24 | 24 | |
| ... | ... | @@ -285,6 +285,10 @@ def main(): |
| 285 | 285 | df = pd.read_sql(sql_query, engine) |
| 286 | 286 | logger.info(f"获取到 {len(df)} 条记录") |
| 287 | 287 | |
| 288 | + # 确保 item_id 为整数类型、user_id 为字符串类型 | |
| 289 | + df['item_id'] = df['item_id'].astype(int) | |
| 290 | + df['user_id'] = df['user_id'].astype(str) | |
| 291 | + | |
| 288 | 292 | # Debug: 显示数据详情 |
| 289 | 293 | if args.debug: |
| 290 | 294 | log_dataframe_info(logger, df, "用户行为数据", sample_size=10) | ... | ... |