Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 192 additions & 4 deletions stockfish/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
:copyright: (c) 2016-2021 by Ilya Zhelyabuzhsky.
:license: MIT, see LICENSE for more details.
"""

from __future__ import annotations
import subprocess
from typing import Any, List, Optional
from typing import Any, List, Optional, Generator
import copy
from os import path
from dataclasses import dataclass
Expand All @@ -25,7 +25,10 @@ class Stockfish:
# Used in test_models: will count how many times the del function is called.

def __init__(
self, path: str = "stockfish", depth: int = 15, parameters: dict = None
self,
path: str = "stockfish",
depth: int = 15,
parameters: Optional[dict] = None,
) -> None:
self._DEFAULT_STOCKFISH_PARAMS = {
"Debug Log File": "",
Expand Down Expand Up @@ -324,7 +327,9 @@ def set_elo_rating(self, elo_rating: int = 1350) -> None:
{"UCI_LimitStrength": "true", "UCI_Elo": elo_rating}
)

def get_best_move(self, wtime: int = None, btime: int = None) -> Optional[str]:
def get_best_move(
self, wtime: Optional[int] = None, btime: Optional[int] = None
) -> Optional[str]:
"""Returns best move with current position on the board.
wtime and btime arguments influence the search only if provided.

Expand Down Expand Up @@ -583,6 +588,189 @@ def get_top_moves(self, num_top_moves: int = 5) -> List[dict]:
self._parameters.update({"MultiPV": old_MultiPV_value})
return top_moves

class TopMove:
def __init__(self, line: str) -> None:
splits = line.split(" ")
pv_index = splits.index("pv")
self.move = splits[pv_index + 1]
self.line = splits[pv_index + 1 :]
self.depth = int(splits[splits.index("depth") + 1])
self.seldepth = int(splits[splits.index("seldepth") + 1])

self.cp = None
self.mate = None

try:
self.cp = int(splits[splits.index("cp") + 1])
except ValueError:
self.mate = int(splits[splits.index("mate") + 1])

def dict(self) -> dict:
return {
"move": self.move,
"depth": self.depth,
"seldepth": self.seldepth,
"line": self.line,
"cp": self.cp,
"mate": self.mate,
}

# compare if this move is better than the other move
def __gt__(self, other: Stockfish.TopMove) -> bool:

if other.mate is None:
# this move is mate and the other is not
if self.mate is not None:
# a negative mate value is a losing move
return self.mate < 0

# both moves has no mate, compare the depth first than centipawn
if self.depth == other.depth:
if self.cp == other.cp:
if self.seldepth is None or other.seldepth is None:
raise RuntimeError("None value when it should be an int.")
return self.seldepth > other.seldepth
else:
if self.cp is None or other.cp is None:
raise RuntimeError("None value when it should be an int.")
return self.cp > other.cp
else:
return self.depth > other.depth

else:
# both this move and other move is mate
if self.mate is not None:
# both losing move, which takes more moves is better
# both winning move, which takes less move is better
if (
self.mate < 0
and other.mate < 0
or self.mate > 0
and other.mate > 0
):
return self.mate < other.mate
else:
# comparing a losing move with a winning move, positive mate score is winning
return self.mate > other.mate
else:
return other.mate < 0

# the oposite of __gt__
def __lt__(self, other: Stockfish.TopMove) -> bool:
return not self.__gt__(other)

# equal move, by "move", not by score/evaluation
def __eq__(self, other: object) -> bool:
if not isinstance(other, Stockfish.TopMove):
return False
return self.move == other.move

def generate_top_moves(
self, num_top_moves: int = 5
) -> Generator[List[TopMove], None, None]:
"""Returns a generator that yields top moves in the position at each depth

Args:
num_top_moves:
The number of moves to return info on, assuming there are at least
those many legal moves.

Returns:
A generator that yields top moves in the position at each depth.

The evaluation could be stopped early by calling Generator.close();
this however will take some time for stockfish to stop.

Unlike `get_top_moves` - which returns a list of dict, this will yield
a list of `Stockfish.TopMove` instead, and the score (cp/mate) is relative
to which side is playing instead of absolute like `get_top_moves`.

The score is either `cp` or `mate`; a higher `cp` is better, a positive `mate`
is winning and vice versa.

If there are no moves in the position, an empty list is returned.
"""

if num_top_moves <= 0:
raise ValueError("num_top_moves is not a positive number.")

old_MultiPV_value = self._parameters["MultiPV"]
if num_top_moves != self._parameters["MultiPV"]:
self._set_option("MultiPV", num_top_moves)
self._parameters.update({"MultiPV": num_top_moves})

foundBestMove = False

try:
self._go()

top_moves: List[Stockfish.TopMove] = []
current_depth = 1

while True:
line = self._read_line()

if "multipv" in line and "depth" in line:
move = Stockfish.TopMove(line)

# try to find the move in the list, if it exists then update it, else append to the list
try:
idx = top_moves.index(move)

# don't update if the new move has a smaller depth than the one in the list
if move.depth >= top_moves[idx].depth:
top_moves[idx] = move

except ValueError:
top_moves.append(move)

# yield the top moves once the current depth changed, the current depth might be smaller than the old depth
if move.depth != current_depth:
current_depth = move.depth
top_moves.sort(reverse=True)
yield top_moves[:num_top_moves]

elif line.startswith("bestmove"):
foundBestMove = True
best_move = line.split(" ")[1]

# no more moves, the game is ended
if best_move == "(none)":
yield []
else:
# sort the list once again
top_moves.sort(reverse=True)

# if the move at index 0 is not the best move returned by stockfish
if best_move != top_moves[0].move:
for move in top_moves:
if best_move == move.move:
top_moves.remove(move)
top_moves.insert(0, move)
break
else:
raise ValueError(
f"Stockfish returned the best move: {best_move}, but it's not in the list"
)

yield top_moves[:num_top_moves]

break

except BaseException as e:
raise e from e

finally:
# stockfish has not returned the best move, but the generator was signaled to close
if not foundBestMove:
self._put("stop")
while not self._read_line().startswith("bestmove"):
pass

if old_MultiPV_value != self._parameters["MultiPV"]:
self._set_option("MultiPV", old_MultiPV_value)
self._parameters.update({"MultiPV": old_MultiPV_value})

@dataclass
class BenchmarkParameters:
ttSize: int = 16
Expand Down