-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate_puzzles.py
87 lines (81 loc) · 1.64 KB
/
generate_puzzles.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import shutil
from pathlib import Path
from gradio_client import Client
from tqdm import tqdm
# Puzzle theme tags to export, one CSV per tag. Order is significant only in
# that it is the order the export loop walks through them. Tags are single
# tokens with no embedded whitespace, so a whitespace split reproduces them.
themes = """
    advancedPawn advantage anastasiaMate arabianMate attackingF2F7
    attraction backRankMate bishopEndgame bodenMate castling
    capturingDefender crushing doubleBishopMate dovetailMate equality
    kingsideAttack clearance defensiveMove deflection discoveredAttack
    doubleCheck endgame exposedKing fork hangingPiece
    hookMate interference intermezzo knightEndgame long
    master masterVsMaster mate mateIn1 mateIn2
    mateIn3 mateIn4 mateIn5 middlegame oneMove
    opening pawnEndgame pin promotion queenEndgame
    queenRookEndgame queensideAttack quietMove rookEndgame sacrifice
    short skewer smotheredMate superGM trappedPiece
    underPromotion veryLong xRayAttack zugzwang healthyMix
    playerGames
""".split()
# Gradio client for the puzzle-database app; the app must already be serving
# on this address before the script runs.
client = Client("http://localhost:7860/")

# Destination for the per-theme CSV exports. Create it up front so the first
# shutil.copy does not fail with FileNotFoundError on a fresh checkout.
output_dir = Path("puzzles")
output_dir.mkdir(parents=True, exist_ok=True)

for theme in tqdm(themes, desc="Generating puzzles..."):
    # Query one theme at a time so each theme ends up in its own CSV file.
    # The first return value (the inline CSV data) is not needed here.
    _, csv_file = client.predict(
        themes=[theme],
        popularity_range=[80, 100],
        rating_range=[0, 4000],
        # NOTE(review): 1007625 presumably is the largest play count in the
        # database (i.e. "no upper bound") - confirm against the app.
        nb_plays_range=[0, 1007625],
        opening_tags=None,
        max=1000,  # keyword name is dictated by the endpoint's API
        api_name="/get_puzzles_from_db",
    )
    # The original code reads the temp-file path from the "value" key of the
    # returned file component; fall back to a bare path in case the installed
    # client version returns the path directly (assumption - verify).
    src = csv_file["value"] if isinstance(csv_file, dict) else csv_file
    shutil.copy(Path(src), output_dir / f"{theme}.csv")