LH-Tech-AI commited on
Commit
e699431
·
verified ·
1 Parent(s): 2259c4b

Create benchmark.py

Browse files

Run a benchmark of the model over 500 images.

Files changed (1) hide show
  1. benchmark.py +113 -0
benchmark.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Benchmark setup: imports, configuration, model load, and image preprocessing.
print("[*] Setting up...")  # was an f-string with no placeholders

import torch
import requests
import random
import numpy as np          # NOTE(review): appears unused in this script
from io import BytesIO
from PIL import Image
from torchvision import transforms
from transformers import ResNetForImageClassification
from collections import Counter  # NOTE(review): appears unused in this script

# --- 1. CONFIGURATION & SETUP ---
ANGLES = [0, 90, 180, 270]   # rotation classes; index in this list == model class id
NUM_IMAGES = 500             # benchmark sample size
MODEL_NAME = "LH-Tech-AI/GyroScope"
IMG_SOURCE_URL = "https://loremflickr.com/400/400/all"  # random-image service

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[*] Using device: {device}")

# Load the model
print(f"[*] Loading model {MODEL_NAME}...")
model = ResNetForImageClassification.from_pretrained(MODEL_NAME)
model.eval()
model.to(device)

# Preprocessing (resize/crop to 224x224 with ImageNet mean/std normalization)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Per-image outcomes accumulated by the evaluation loop below.
results = []
37
+
38
# --- 2. EVALUATION LOOP ---
# For each image: download, apply a random known rotation, ask the model to
# classify the rotation, and record whether the prediction matched.
print(f"[*] Starting download and evaluation of {NUM_IMAGES} images (In-Memory)...")

bar_length = 20  # progress-bar width; loop-invariant, hoisted out of the loop

for i in range(1, NUM_IMAGES + 1):
    try:
        # Fetch a random image straight into RAM (never touches disk).
        response = requests.get(f"{IMG_SOURCE_URL}?random={i}", timeout=10)
        img = Image.open(BytesIO(response.content)).convert("RGB")

        # Pick the ground-truth rotation; its index in ANGLES is the class label.
        true_angle = random.choice(ANGLES)
        label_idx = ANGLES.index(true_angle)

        # Rotate the image (PIL rotates counter-clockwise; expand=True keeps
        # the whole frame instead of clipping corners).
        rotated_img = img.rotate(true_angle, expand=True)

        # Run inference; no_grad avoids building an autograd graph.
        tensor = preprocess(rotated_img).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = model(pixel_values=tensor).logits
        pred_idx = logits.argmax().item()

        is_correct = (pred_idx == label_idx)
        results.append({
            "true": true_angle,
            "pred": ANGLES[pred_idx],
            "correct": is_correct,
        })

        # Single-line progress bar, overwritten in place via \r.
        # (Original computed `status` twice; the duplicate is removed.)
        status = "✓" if is_correct else "✗"
        percent = (i / NUM_IMAGES) * 100
        filled_length = int(bar_length * i // NUM_IMAGES)
        bar = '#' * filled_length + ' ' * (bar_length - filled_length)
        print(f"\rProgress: [{bar}] {percent:.1f}% ({i}/{NUM_IMAGES}) | Last result: {status}", end="")

    except Exception as e:
        # Best-effort: a failed download or inference skips this image but
        # keeps the benchmark running.
        print(f"\n[!] Error processing image {i}: {e}")
78
+
79
# --- 3. RESULTS ---
# Print overall accuracy plus a per-rotation-class breakdown.
print("\n\n" + "="*15)
print("   RESULTS")
print("="*15)

total_correct = sum(1 for r in results if r['correct'])
# Guard against an empty results list (e.g. every download failed) to avoid
# a ZeroDivisionError at the very end of a long run.
accuracy = (total_correct / len(results)) * 100 if results else 0.0

print(f"Overall result: {total_correct}/{len(results)} correct")
print(f"Hit rate: {accuracy:.2f} %")
print("-" * 30)

print("Details per rotation class:")
for angle in ANGLES:
    class_results = [r for r in results if r['true'] == angle]
    if class_results:  # skip classes with no samples (avoids division by zero)
        correct_in_class = sum(1 for r in class_results if r['correct'])
        class_acc = (correct_in_class / len(class_results)) * 100
        print(f"  {angle:>3}° : {correct_in_class:>2}/{len(class_results):>2} correct ({class_acc:>6.2f}%)")

print("="*30)
100
+
101
+ # Result of our benchmark:
102
+ # ===============
103
+ # RESULTS
104
+ # ===============
105
+ # Overall result: 411/500 correct
106
+ # Hit rate: 82.20 %
107
+ # ------------------------------
108
+ # Details per rotation class:
109
+ # 0° : 96/124 correct ( 77.42%)
110
+ # 90° : 103/119 correct ( 86.55%)
111
+ # 180° : 112/129 correct ( 86.82%)
112
+ # 270° : 100/128 correct ( 78.12%)
113
+ # ==============================