Spaces:
Running
Running
| """ | |
| AIFinder Data Loader | |
| Downloads and parses HuggingFace datasets, extracts assistant responses, | |
| and labels them with is_ai, provider, and model. | |
| """ | |
| import re | |
| import time | |
| from datasets import load_dataset | |
| from tqdm import tqdm | |
| from config import ( | |
| DATASET_REGISTRY, | |
| DEEPSEEK_AM_DATASETS, | |
| ) | |
| def _parse_msg(msg): | |
| """Parse a message that may be a dict or a JSON string.""" | |
| if isinstance(msg, dict): | |
| return msg | |
| if isinstance(msg, str): | |
| try: | |
| import json | |
| parsed = json.loads(msg) | |
| if isinstance(parsed, dict): | |
| return parsed | |
| except (json.JSONDecodeError, ValueError): | |
| pass | |
| return {} | |
def _extract_assistant_texts_from_conversations(rows):
    """Extract assistant message content from conversation datasets.

    Each row is expected to carry a 'conversations' or 'messages' column
    holding a list of {role, content} dicts (or JSON strings encoding
    such dicts). Assistant turns within one row are joined with blank
    lines; rows with no assistant content are dropped.
    """
    _ASSISTANT_ROLES = ("assistant", "gpt", "model")

    def _is_empty(seq):
        # Treat None and zero-length sized containers as "missing".
        return seq is None or (hasattr(seq, "__len__") and len(seq) == 0)

    extracted = []
    for row in rows:
        convo = row.get("conversations")
        if _is_empty(convo):
            convo = row.get("messages")
        if _is_empty(convo):
            convo = []
        assistant_chunks = [
            parsed.get("content", "")
            for parsed in map(_parse_msg, convo)
            if parsed.get("role", "") in _ASSISTANT_ROLES
            and parsed.get("content", "")
        ]
        if assistant_chunks:
            extracted.append("\n\n".join(assistant_chunks))
    return extracted
| def _extract_from_am_dataset(row): | |
| """Extract assistant text from a-m-team format (messages list with role/content).""" | |
| messages = row.get("messages") or row.get("conversations") or [] | |
| parts = [] | |
| for msg in messages: | |
| role = msg.get("role", "") if isinstance(msg, dict) else "" | |
| content = msg.get("content", "") if isinstance(msg, dict) else "" | |
| if role == "assistant" and content: | |
| parts.append(content) | |
| return "\n\n".join(parts) if parts else "" | |
def load_teichai_dataset(dataset_id, provider, model_name, kwargs):
    """Load a single conversation-format dataset and return (texts, providers, models).

    Tries `datasets.load_dataset` first; on failure falls back to the
    HF auto-converted parquet endpoint. Samples down to
    kwargs['max_samples'] with a fixed seed, and drops texts of 50
    characters or fewer. Returns three empty lists on total failure.
    """
    max_samples = kwargs.get("max_samples")
    load_kwargs = {"name": kwargs["name"]} if "name" in kwargs else {}
    try:
        rows = list(load_dataset(dataset_id, split="train", **load_kwargs))
    except Exception as e:
        # Fallback: load from auto-converted parquet via HF API
        try:
            import pandas as pd
            url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet"
            rows = pd.read_parquet(url).to_dict(orient="records")
        except Exception as e2:
            print(f" [SKIP] {dataset_id}: {e} / parquet fallback: {e2}")
            return [], [], []
    if max_samples and len(rows) > max_samples:
        import random
        random.seed(42)  # deterministic subsample across runs
        rows = random.sample(rows, max_samples)
    # Filter out empty/too-short texts
    texts = [
        t for t in _extract_assistant_texts_from_conversations(rows) if len(t) > 50
    ]
    if not texts:
        print(f" [SKIP] {dataset_id}: no valid texts extracted")
        return [], [], []
    count = len(texts)
    return texts, [provider] * count, [model_name] * count
def load_am_deepseek_dataset(dataset_id, provider, model_name, kwargs):
    """Load a-m-team DeepSeek dataset.

    Primary path loads the full 'train' split (honoring an optional
    'name' config); on failure it retries in streaming mode without the
    name kwarg, stopping once kwargs['max_samples'] rows are collected.
    Texts of 50 characters or fewer are discarded.
    """
    max_samples = kwargs.get("max_samples")
    load_kwargs = {}
    if "name" in kwargs:
        load_kwargs["name"] = kwargs["name"]
    try:
        ds = load_dataset(dataset_id, split="train", **load_kwargs)
    except Exception:
        # Try without name kwarg as fallback, streaming so we can stop early.
        try:
            stream = load_dataset(dataset_id, split="train", streaming=True)
            rows = []
            for record in stream:
                rows.append(record)
                if max_samples and len(rows) >= max_samples:
                    break
        except Exception as e2:
            print(f" [SKIP] {dataset_id}: {e2}")
            return [], [], []
    else:
        rows = list(ds)
        if max_samples and len(rows) > max_samples:
            rows = rows[:max_samples]
    texts = [
        t
        for t in (_extract_from_am_dataset(row) for row in rows)
        if len(t) > 50
    ]
    return texts, [provider] * len(texts), [model_name] * len(texts)
def _load_dataset_group(registry, loader, desc):
    """Run *loader* over every (dataset_id, provider, model_name, kwargs)
    entry in *registry*, printing per-dataset counts and timings.

    Args:
        registry: iterable of 4-tuples as declared in config.
        loader: callable(dataset_id, provider, model_name, kwargs) ->
            (texts, providers, models).
        desc: label for the tqdm progress bar.

    Returns:
        Combined (texts, providers, models) lists across the registry.
    """
    all_texts, all_providers, all_models = [], [], []
    for dataset_id, provider, model_name, kwargs in tqdm(registry, desc=desc):
        t0 = time.time()
        texts, providers, models = loader(dataset_id, provider, model_name, kwargs)
        elapsed = time.time() - t0
        all_texts.extend(texts)
        all_providers.extend(providers)
        all_models.extend(models)
        print(f" {dataset_id}: {len(texts)} samples ({elapsed:.1f}s)")
    return all_texts, all_providers, all_models


def load_all_data():
    """Load all datasets and return combined lists.

    Returns:
        texts: list of str
        providers: list of str
        models: list of str
        is_ai: list of int (1=AI, 0=Human)
    """
    # TeichAI datasets
    print("Loading TeichAI datasets...")
    all_texts, all_providers, all_models = _load_dataset_group(
        DATASET_REGISTRY, load_teichai_dataset, "TeichAI"
    )
    # DeepSeek a-m-team datasets
    print("\nLoading DeepSeek (a-m-team) datasets...")
    ds_texts, ds_providers, ds_models = _load_dataset_group(
        DEEPSEEK_AM_DATASETS, load_am_deepseek_dataset, "DeepSeek-AM"
    )
    all_texts.extend(ds_texts)
    all_providers.extend(ds_providers)
    all_models.extend(ds_models)
    # Build is_ai labels (all AI) -- every source here is model output.
    is_ai = [1] * len(all_texts)
    print(f"\n=== Total: {len(all_texts)} samples ===")
    # Print per-provider counts, most frequent first.
    from collections import Counter
    prov_counts = Counter(all_providers)
    for p, c in sorted(prov_counts.items(), key=lambda x: -x[1]):
        print(f" {p}: {c}")
    return all_texts, all_providers, all_models, is_ai
if __name__ == "__main__":
    # Script entry point: run the full pipeline once when executed directly.
    texts, providers, models, is_ai = load_all_data()