1 % (c) 2022-2024 Lehrstuhl fuer Softwaretechnik und Programmiersprachen,
2 % Heinrich Heine Universitaet Duesseldorf
3 % This software is licenced under EPL 1.0 (http://www.eclipse.org/org/documents/epl-v10.html)
4
5 % MCTS: Monte-Carlo Tree Search
6 % initial version by Michael Leuschel, May 2022
7
8 % requires:
9 % GAME_OVER
10 % GAME_VALUE
11 % GAME_PLAYER
12 % GAME_MCTS_RUNS
13 % GAME_MCTS_TIMEOUT
14 % GAME_MCTS_CACHE_LAST_TREE
15 % GAME_MCTS_MAX_SIMULATION_LENGTH
16
17 % TODO: allow to use for CSP, ...
18 % improve performance of evaluation, preparing state and avoiding repeated unpacking
19
20 :- module(mcts_game_play, [mcts_auto_play/4, mcts_auto_play_available/0,
21 tcltk_gen_last_mcts_tree/2, mcts_tree_available/0]).
22
23
24 :- use_module(probsrc(module_information)).
25 :- module_info(group,misc).
26 :- module_info(description,'This module provides Monte Carlo Tree Search for games.').
27
28 % use_module(extrasrc(mcts_game_play)), current_state_id(I), mcts_auto_play(I,Action,TID,Dest).
29
30 :- use_module(probsrc(error_manager)).
31 :- use_module(probsrc(debug)).
32 :- use_module(probsrc(state_space),[current_state_id/1, transition/4,
33 visited_expression/2]).
34 :- use_module(probsrc(tcltk_interface),[compute_all_transitions_if_necessary/2]).
35
% game_move(+SrcID,?Move,?TID,?TargetID):
% enumerate the transitions leaving state SrcID; compute_all_transitions_if_necessary/2
% is called first so that the transition/4 facts for SrcID are available.
game_move(SrcID,Move,TID,TargetID) :-
    compute_all_transitions_if_necessary(SrcID,true),
    transition(SrcID,Move,TID,TargetID).
39
40
41
42 :- use_module(library(random),[random_select/3]).
43 % use if the game does not provide random_game_move
44 % drawback: it computes all transitions first, thereby adding all successors
45 % and then selects a random move amongst them
46 % TODO: investigate whether we can implement a random move using random enumeration
% my_random_game_move(+X,-Z): choose a successor state Z of X for a random playout.
% A directly winning successor is preferred (Mini-Max style optimisation);
% otherwise one of all successors is picked at random.
% Fails if X has no successor.
my_random_game_move(X,Z) :-
    ( winning_move(X,WinningID)
    -> Z = WinningID % do not perform a random move; pick the winning move directly
    ;  findall(SuccID,game_move(X,_,_,SuccID),Successors),
       random_select(Z,Successors,_)
    ).
53
% winning_move(+StateID,-WinID): WinID is a successor of StateID which is a terminal
% node marked as winning for its parent.
% As game_move/4 computes all successors anyway, we can just as well look for a winning one.
% TODO: extend this detection also to non-terminal nodes?!
winning_move(StateID,WinID) :-
    game_move(StateID,_Action,_TransID,WinID),
    terminal(WinID,_Utility,winning).
57
% other_player(?Player,?Opponent): the two players (min and max) and their respective opponent
other_player(min,max).
other_player(max,min).
60
61 :- use_module(probsrc(state_space_exploration_modes),
62 [compute_heuristic_function_for_state_id/2]).
63 :- use_module(probsrc(eventhandling),[register_event_listener/3]).
64 :- register_event_listener(reset_specification,reset_mcts,'Reset MCTS information').
65 :- register_event_listener(reset_prob,reset_mcts,'Reset MCTS information').
66
67 :- dynamic terminal_node_cache/3.
% reset_mcts: forget the saved MCTS tree and all cached terminal-node information;
% registered above as listener for the reset_specification and reset_prob events
reset_mcts :-
    retractall(saved_mcts_tree(_,_)),
    retractall(terminal_node_cache(_,_,_)).
69
70 :- use_module(probsrc(specfile),[state_corresponds_to_fully_setup_b_machine/2]).
71
% utility(+State,-Value): evaluate the mandatory GAME_VALUE animation expression
% in State and convert the result to a Prolog number
utility(State,Value) :-
    eval_animation_expression(State,'GAME_VALUE',RawValue,mandatory),
    get_number(RawValue,Value).
75
% terminal(+NodeID,?Val,?WinForParent):
% check if node is terminal (GAME_OVER holds) and compute its utility if it is.
%  Val: the utility (GAME_VALUE) of the node; 0 with a warning if it cannot be computed
%  WinForParent: winning if the utility is a win for the player who moved into
%                this node, not_directly_winning otherwise
% Fails for non-terminal nodes. The result is cached in terminal_node_cache/3,
% non-terminal nodes are cached with the marker non_terminal.
% The result is computed and cached with fresh variables and only then unified
% with the arguments: this way a call with a bound argument
% (e.g., terminal(ID,_,winning) in winning_move/2) still fills the cache when
% the node is terminal but the argument does not match.
terminal(NodeID,Val,WinForParent) :-
    terminal_node_cache(NodeID,CachedVal,CachedWin), !,
    CachedVal \= non_terminal,
    Val = CachedVal, WinForParent = CachedWin.
terminal(NodeID,Val,WinForParent) :-
    visited_expression(NodeID,State),
    eval_animation_expression(State,'GAME_OVER',Res,mandatory),
    ( is_true(Res)
    -> ( utility(State,Utility) -> true
       ; add_warning(mcts_game_play,'Could not compute utility for terminal node: ',NodeID),
         Utility = 0
       ),
       ( winning_node_for_parent(State,Utility) -> Win = winning ; Win = not_directly_winning ),
       assert(terminal_node_cache(NodeID,Utility,Win)),
       Val = Utility, WinForParent = Win
    ;  assert(terminal_node_cache(NodeID,non_terminal,not_directly_winning)),
       fail
    ).
92
% is_true(?Value): the values of an evaluated animation expression which are accepted as true
is_true(pred_true).
is_true(true). % XTL
is_true('TRUE').
96
% get_player_in_state_id(+NodeID,-R): R is the player (min or max) in the state
% stored for NodeID; reports an internal error and fails if there is no such state
get_player_in_state_id(NodeID,R) :-
    ( visited_expression(NodeID,State)
    -> get_player(State,R)
    ;  add_internal_error('Illegal state:',get_player_in_state_id(NodeID,R)),
       fail
    ).
102
% get_player(+State,-R): R=max if the GAME_PLAYER animation expression evaluates
% to a value accepted by is_max_player/1 in State; in all other cases
% (including a failed evaluation) R=min
get_player(State,R) :-
    ( eval_animation_expression(State,'GAME_PLAYER',PlayerVal,mandatory),
      is_max_player(PlayerVal)
    -> R = max
    ;  R = min
    ).
108
109 :- use_module(probsrc(b_global_sets),[is_b_global_constant/3]).
% is_max_player(+Value): succeeds if the GAME_PLAYER value denotes the maximizing player.
% Supported: string(Atom), fd(Nr,Set) elements of a B global set (checked via their
% constant name), the booleans pred_true (max) / pred_false (not max) and plain atoms.
is_max_player(Value) :-
    ( Value = string(Atom) -> is_max_aux(Atom)
    ; Value = fd(Nr,Set), is_b_global_constant(Set,Nr,Cst) -> is_max_aux(Cst)
    ; Value = pred_true -> true
    ; Value = pred_false -> fail
    ; is_max_aux(Value) % XTL
    ).
115
% is_max_aux(+Atom): succeeds for a name of the max player, fails silently for a name
% of the min player and fails with an error message for any other value
is_max_aux(Atom) :-
    ( is_max_atom(Atom) -> true
    ; is_min_atom(Atom) -> fail
    ; add_error(mcts_game_play,'Illegal GAME_PLAYER value, must be min or max:',Atom),
      fail
    ).
119
% is_max_atom(?Atom): accepted spellings for the maximizing player
is_max_atom(max).
is_max_atom('MAX').
is_max_atom('Max').
% is_min_atom(?Atom): accepted spellings for the minimizing player
is_min_atom(min).
is_min_atom('MIN').
is_min_atom('Min').
126
127
128 :- use_module(probsrc(xtl_interface),[xtl_game_info/3]).
129
130 % TODO: also provide version which checks if available in state
% mcts_auto_play_available: succeeds if the loaded specification provides game
% information: a GAME_PLAYER or GAME_MCTS_RUNS animation expression, or
% (in XTL mode) game info for the key GAME_MCTS_RUNS
mcts_auto_play_available :-
    b_get_machine_animation_expression('GAME_PLAYER',_), !.
mcts_auto_play_available :-
    b_get_machine_animation_expression('GAME_MCTS_RUNS',_), !.
mcts_auto_play_available :-
    xtl_mode,
    xtl_game_info('GAME_MCTS_RUNS',_,_), !.
135
136 :- use_module(probsrc(bmachine),[b_get_machine_animation_expression/2,b_get_definition/5]).
137 :- use_module(probsrc(specfile),[csp_mode/0, xtl_mode/0]).
% eval_animation_expression_in_state_id(+NodeID,+STR,-Val,+Mandatory):
% like eval_animation_expression/4, but for the state stored under the id NodeID
eval_animation_expression_in_state_id(NodeID,STR,Val,Mandatory) :-
    visited_expression(NodeID,StoredState),
    eval_animation_expression(StoredState,STR,Val,Mandatory).
141
% eval_animation_expression(+State,+STR,-Val,+Mandatory):
% evaluate the game definition named STR (e.g., 'GAME_OVER') in State.
%  Mandatory: when it is the atom mandatory, a warning is generated if no
%             definition for STR can be found (callers in this file pass
%             mandatory or not_mandatory)
% The clauses are tried in order; the predicate fails whenever no value can be computed.

% 1. the B machine provides an animation expression for STR: evaluate it in the B state
eval_animation_expression(State,STR,Val,_) :-
    b_get_machine_animation_expression(STR,TExpr),!,
    state_corresponds_to_fully_setup_b_machine(State,BState),
    b_interpreter:b_compute_expression_nowf(TExpr,[],BState,Val,'MCTS',0).
% 2. XTL mode: the value is obtained via xtl_game_info/3;
%    if that fails a warning is only generated for mandatory keys
eval_animation_expression(State,STR,Val,Mandatory) :- xtl_mode,!,
    (xtl_game_info(STR,State,Res) -> Val=Res
     ; Mandatory=mandatory,
       add_warning(mcts_game_play,'Add definition for prob_game_info(Key,State,Val) with Key = ',STR),fail).
% 3. a DEFINITION named STR exists but was not accepted as animation expression by clause 1:
%    tell the user how to rewrite it and fail
eval_animation_expression(_,STR,_,_) :-
    b_get_definition(STR,DefType,_Args,_Body,_Deps),!,
    (DefType = expression
     -> add_warning(mcts_game_play,'Please rewrite DEFINITION to use no arguments: ',STR)
     ;  add_warning(mcts_game_play,'Please rewrite DEFINITION to return an expression and use no arguments: ',STR)
    ),
    fail.
% 4. nothing found at all: warn only if the definition is mandatory; always fails
eval_animation_expression(_,STR,_,mandatory) :-
    add_warning(mcts_game_play,'Please add DEFINITION for: ',STR),fail.
159
160 % TODO: get definitions also from VisB JSON files (e.g., for Event-B, ...)
161
162 % ------------------------------
163
164
% mcts_auto_play(+StateID,-Action,-TransID,-DestID):
% entry point: use MCTS to choose a move (Action, TransID) from state id StateID to DestID.
% The parameters are read from optional game definitions:
%   GAME_MCTS_RUNS (default 1000), GAME_MCTS_TIMEOUT (default 5000),
%   GAME_MCTS_MAX_SIMULATION_LENGTH (default 100, stored in max_simulation_length/1),
%   GAME_MCTS_CACHE_LAST_TREE (tree is cached unless the definition exists and is not true)
mcts_auto_play(StateID,Action,TransID,DestID) :-
    get_animation_expression_nr(StateID,'GAME_MCTS_RUNS',1000,Runs),
    get_animation_expression_nr(StateID,'GAME_MCTS_TIMEOUT',5000,TimeLimit),
    retractall(max_simulation_length(_)),
    get_animation_expression_nr(StateID,'GAME_MCTS_MAX_SIMULATION_LENGTH',100,MaxLen),
    assert(max_simulation_length(MaxLen)),
    ( eval_animation_expression_in_state_id(StateID,'GAME_MCTS_CACHE_LAST_TREE',Flag,not_mandatory)
    -> ( is_true(Flag) -> CacheMode = cache
       ; CacheMode = no_cache
       )
    ;  CacheMode = cache
    ),
    mcts_auto_play(StateID,CacheMode,TimeLimit,Runs,Action,TransID,DestID).
175
% get_animation_expression_nr(+State,+DEFNAME,+Default,-Res):
% obtain the numeric parameter DEFNAME by evaluating the optional game definition
% of that name in the state with id State.
% Res is Default if the definition does not exist (or cannot be evaluated);
% if it yields something which is not a number >= 0 a warning naming the
% offending definition is generated and Default is used as well.
get_animation_expression_nr(State,DEFNAME,Default,Res) :-
    ( eval_animation_expression_in_state_id(State,DEFNAME,R,not_mandatory)
    -> ( get_number(R,Nr), Nr >= 0 -> Res = Nr
       ; % report the definition that was actually evaluated
         ajoin([DEFNAME,' should return a non-negative number: '],Msg),
         add_warning(mcts_game_play,Msg,R),
         Res = Default
       )
    ;  Res = Default
    ).
183
% get_number(+Value,-Nr): extract a Prolog number from an evaluated value:
% either int(Nr), term(floating(Nr)) or a plain Prolog number
get_number(Value,Nr) :-
    ( Value = int(Nr)
    ; Value = term(floating(Nr))
    ; number(Value), Nr = Value
    ).
187
188 :- dynamic saved_mcts_tree/2.
% save_last_mcts_tree(+Cache,+State2,+FinalTree):
% if Cache=cache replace the saved tree by FinalTree (together with the id State2
% of the state reached by the chosen move); do nothing for any other Cache value
save_last_mcts_tree(Cache,State2,FinalTree) :-
    ( Cache = cache
    -> retractall(saved_mcts_tree(_,_)),
       % for debugging: gen_dot(FinalTree,2)
       % the tree for the child can be obtained via: get_mcts_child_for_state(FinalTree,State2,NewTree)
       assert(saved_mcts_tree(FinalTree,State2))
    ;  true
    ).
195
196
197
% mcts_auto_play(+State,+Cache,+Timeout,+SimRuns,-Action,-TransID,-State2):
% run MCTS from the state id State, reusing the relevant part of the tree saved
% by the previous run if possible, and save the resulting tree if Cache=cache
mcts_auto_play(State,Cache,Timeout,SimRuns,Action,TransID,State2) :-
    initial_mcts_tree(State,Cache,InitialTree),
    mcts_incr_auto_play(State,Timeout,SimRuns,Action,TransID,State2,InitialTree,FinalTree),
    save_last_mcts_tree(Cache,State2,FinalTree).

% initial_mcts_tree(+State,+Cache,-InitialTree):
% the saved tree is always consumed (retracted); with Cache=cache we descend in it
% first to the state chosen by the last run and then to the current State
% TODO: check if this is not already the right tree; check if we do not need another move
initial_mcts_tree(State,Cache,InitialTree) :-
    retract(saved_mcts_tree(OldTree,LastTarget)),
    Cache = cache,
    get_mcts_child_for_state(OldTree,LastTarget,SubTree),
    get_mcts_child_for_state(SubTree,State,InitialTree),
    !. % for debugging: print_tree(InitialTree,0,2)
initial_mcts_tree(State,_,leaf(State)). % start with a fresh tree
207
% mcts_incr_auto_play(+State,+Timeout,+SimRuns,-Action,-TransID,-State2,+Tree,-FinalTree):
% a version where we can provide the initial MCTS Tree for incremental reuse.
% Only the first solution of mcts_run/6 is used; the chosen target State2 is then
% mapped back to a transition (Action, TransID) and the timing is printed.
mcts_incr_auto_play(State,Timeout,SimRuns,Action,TransID,State2,Tree,FinalTree) :-
    format('Starting MCTS AUTO PLAY, SimRuns=~w, Timeout=~w~n',[SimRuns,Timeout]),
    statistics(walltime,[T0|_]),
    once(mcts_run(Timeout,SimRuns,Tree,FinalTree,Visits,State2)),
    statistics(walltime,[T1|_]),
    Elapsed is T1 - T0,
    game_move(State,Action,TransID,State2),
    format('Move found by MCTS in ~w ms (~w runs, ~w visits): ~w (~w --(~w)--> ~w)~n',[Elapsed,SimRuns,Visits,Action,State,TransID,State2]).
218
219
220 %mcts_run(Target) :- start(Init), mcts_run(5000,10000,leaf(Init),_,_,Target).
221
% mcts_run(+Timeout,+Nr,+Tree,-FinalTree,-Visits,-Target):
% improve Tree by MCTS iterations, limited by the run counter Nr and by Timeout
% (added to the current walltime), then select the child Target of the root with
% the most Visits. In debug mode the top of the final tree is printed.
mcts_run(Timeout,Nr,Tree,FinalTree,Visits,Target) :-
    statistics(walltime,[Now|_]),
    Deadline is Now+Timeout,
    % for profiling wrap the next call into set_prolog_flag(profiling,on/off) and print_profile
    mcts_loop(Deadline,Nr,Tree,FinalTree),
    get_best_mcts_move(FinalTree,Visits,Target),
    get_node(FinalTree,Root),
    debug_format(19,'Best move from ~w is to ~w (~w visits)~n',[Root,Target,Visits]),
    ( \+ debug_mode(on) -> true
    ; print_tree(FinalTree,0,3)
    ).
230
231
% mcts_loop(+EndTime,+Nr,+Tree,-FinalTree):
% run MCTS for a single initial tree with Nr iterations, stopping earlier if the
% walltime EndTime has passed (checked by mcts_time_out/3).
% The guard is Nr>0 so that exactly Nr iterations are performed; in particular a
% single requested run still expands the root, so that a move can be selected.
mcts_loop(EndTime,Nr,Tree,FinalTree) :- Nr>0, !,
    mcts(Tree,_,NewTree),
    Remaining is Nr-1,
    ( Remaining > 0, % no time-out message when we are finished anyway
      mcts_time_out(EndTime,Remaining,NewTree)
    -> FinalTree = NewTree
    ;  mcts_loop(EndTime,Remaining,NewTree,FinalTree)
    ).
mcts_loop(_,_,Tree,Tree). % no runs left
239
% mcts_time_out(+EndTime,+Nr,+Tree): succeeds (and prints a message) if the walltime
% EndTime has passed; the clock is only consulted when the number Nr of remaining
% runs is a multiple of 10. Tree is only used for (commented out) debugging output.
mcts_time_out(EndTime,Nr,_Tree) :-
    Nr mod 10 =:= 0,
    statistics(walltime,[Now|_]),
    % for debugging: format(' ~w -> ',[Nr]), print_tree_summary(_Tree),
    Now > EndTime,
    format('MCTS TIME-OUT with ~w runs remaining.~n',[Nr]).
245
246
247 :- use_module(library(lists),[maplist/3, max_member/2, reverse/2]).
248
% get_mcts_child_for_state(+Tree,+State,-Res):
% find a direct child for a given (successor) state of the root state;
% can be used to update the tree after a move was made.
%  - if the root of Tree is already a node for State, Tree itself is returned
%    (we apply MCTS directly for other player)
%  - otherwise the first child for State is returned
%  - if there is none a message is printed and a new root leaf(State) is created
get_mcts_child_for_state(Tree,State,Res) :-
    ( Tree = node(State,_,_,_)
    -> Res = Tree
    ; Tree = node(_,_,_,Children),
      member(Res,Children),
      get_node(Res,State)
    -> true
    ; print(cannot_get_child_for_state(State,Tree)),nl,
      Res = leaf(State)
    ).
260
% get_best_mcts_move(+Tree,-MaxV,-Target):
% Target is the state of a child of the root which has the maximal number MaxV
% of visits (children with the same count are found upon backtracking).
% Fails for leaves and for nodes without children.
get_best_mcts_move(node(_,_,_,Children),MaxV,Target) :-
    maplist(get_visits,Children,VisitCounts),
    max_member(MaxV,VisitCounts),
    member(BestChild,Children),
    get_visits(BestChild,MaxV),
    get_node(BestChild,Target).
267
% invert_win(+Win,-Inverted): convert the result as seen by one player into the
% result for the opponent: 1 (win) <-> 0 (loss), 0.5 (draw) stays 0.5;
% any other value R is mapped to 1-R
invert_win(Win,Inverted) :-
    ( Win = 0 -> Inverted = 1
    ; Win = 1 -> Inverted = 0
    ; Win = 0.5 -> Inverted = 0.5
    ; Inverted is 1-Win
    ).
272
273 %mcts(X,_,_) :- print(mcts(X)),nl,fail.
% mcts(+Tree,-Win,-NewTree): perform one MCTS iteration on Tree.
%  Win: result of this iteration as reported by simulate_for_parent/4 for the
%       root of Tree (1 win, 0.5 draw, 0 loss, viewed by the parent of the root)
%  NewTree: Tree with updated statistics and possibly one expanded leaf
% Inner nodes: descend into the child with the best UCB value and invert the
%  result obtained for that child; nodes without children are simulated again.
% Leaves: are simulated and expanded into a node with one leaf per successor.
mcts(node(State,Wins,Visits,Childs),OuterWin,node(State,Wins1,V1,Childs1)) :-
    V1 is Visits+1,
    ( Childs=[]
    -> % the node has no children; simulate it, i.e., compute the utility value
       Childs1=[],
       simulate_for_parent(State,_,OuterWin,terminal)
    ;  LogNi is log(V1),
       ( select_best_ucb_child(Childs,State,LogNi,Child,Childs1,Child1) -> true
       ; % should not happen for a non-empty list of children; report it and
         % fail rather than continuing with an unbound child
         add_internal_error('MCTS child selection failed for node: ',State),
         fail
       ),
       mcts(Child,ChildWin,Child1),
       invert_win(ChildWin,OuterWin)
    ),
    % backpropagate:
    Wins1 is Wins+OuterWin.
mcts(leaf(State),Wins,node(State,Wins,1,Childs)) :-
    simulate_for_parent(State,_Val,Wins,Kind),
    ( Kind=terminal
    -> Childs=[] % do not add any children; game is over already
    ; winning_move(State,Child)
    -> Childs=[leaf(Child)] % minimax optimisation: pretend other children do not exist
    ; findall(leaf(C),game_move(State,_,_,C),Childs)
    ).
296
% select_best_ucb_child(+Children,+Parent,+LogNi,-Best,-NewChildren,-NewBest):
% Best is the child with the maximal UCB value; NewChildren is the new list of
% children in which Best is replaced by the (still unbound) variable NewBest,
% to be filled in by the caller. NewBest is put first in NewChildren.
select_best_ucb_child(Children,_Parent,LogNi,Best,NewChildren,NewBest) :-
    ( Children = [Only]
    -> % no need to compute when there is a single child
       Best = Only,
       NewChildren = [NewBest]
    ; Children = [First|Others]
    -> ucb(First,LogNi,FirstUCB),
       get_max_ucb(Others,LogNi,FirstUCB,First,[],Best,Rest),
       NewChildren = [NewBest|Rest]
    ).
302
303 %select_best_ucb_child(Children,_Parent,LogNi,Child,NewChildren,NewChild) :- % old version with sort/maplist
304 % maplist(create_ucb_node(LogNi),Children,UC),
305 % sort(UC,UCS), reverse(UCS,RUCS), % this can be done more efficiently
306 % maplist(project,RUCS,SortedChildren),
307 % SortedChildren = [Child|Rest],
308 % NewChildren = [NewChild|Rest].
309
% get_max_ucb(+Nodes,+LogNi,+CurMax,+BestSoFar,+RestSoFar,-BestChild,-Rest):
% select BestChild with maximal UCB value among BestSoFar (with value CurMax) and Nodes;
% all other nodes are accumulated (in reverse order) in Rest.
% A node only replaces the best one so far if its value is strictly greater.
get_max_ucb([],_,_,Best,Others,Best,Others).
get_max_ucb([Node|T],LogNi,CurMax,BestSoFar,RestSoFar,BestChild,Rest) :-
    ( ucb(Node,LogNi,NodeUCB),
      NodeUCB > CurMax
    -> % we have a new best node
       get_max_ucb(T,LogNi,NodeUCB,Node,[BestSoFar|RestSoFar],BestChild,Rest)
    ;  get_max_ucb(T,LogNi,CurMax,BestSoFar,[Node|RestSoFar],BestChild,Rest)
    ).
319
320
321 % helper functions for maplist:
322 %project(ucb(_,Node),Node).
323 %create_ucb_node(LogNi,Node,ucb(UCB,Node)) :- ucb(Node,LogNi,UCB).
% Accessors for MCTS trees; a tree is either
%   leaf(StateID)                        a not yet expanded state, or
%   node(StateID,Wins,Visits,Children)   with Children a list of trees

% get_visits(+Tree,-Visits): number of visits of the root; 0 for leaves
get_visits(node(_,_,V,_),V).
get_visits(leaf(_),0).
% get_wins(+Tree,-Wins): accumulated wins of the root; 0 for leaves
get_wins(node(_,W,_,_),W).
get_wins(leaf(_),0).
% get_node(+Tree,-StateID): the state id stored in the root
get_node(node(N,_,_,_),N).
get_node(leaf(N),N).
% get_child(+Tree,-Child): enumerate the direct children of a node upon backtracking
% (fails for leaves and if the list of children is still unbound)
get_child(node(_,_,_,C),Child) :- nonvar(C),member(Child,C).
331
% ucb(+Tree,+LogNi,-Res): compute UCB value for a node.
%  LogNi is the value computed by the caller as log of the parent's visit count.
%  - leaves get the value 1000000 and nodes without visits the value 10000
%  - nodes without children where all visits were wins get the value 10001:
%    winning node for opponent; we assume player will always detect the winning
%    move (limited mini-max improvement)
%    will make a difference for tic-tac-toe, 40 SimRuns, [- - -, 0 x -, x - -] : with detection x blocked
%  - otherwise: Wins/Visits + sqrt(2*LogNi/Visits)
ucb(leaf(_),_,1000000).
ucb(node(_ID,Wins,Visits,Children),LogNi,Res) :-
    ( Visits = 0
    -> Res = 10000
    ; Children = [], % terminal node
      Wins = Visits
    -> Res = 10001 % for debugging: print(minimax_detection(_ID,Wins,Visits)),nl
    ; Exploitation is Wins/Visits,
      Exploration is sqrt(2.0 * LogNi / Visits),
      Res is Exploitation + Exploration
    ).
343
344 :- dynamic max_simulation_length/1. % GAME_MCTS_MAX_SIMULATION_LENGTH
345
% simulate_for_parent(+NodeId,-Val,-Res,?NodeKind):
% simulate (random playout from NodeId, bounded by max_simulation_length/1)
% and report win as viewed by the parent of NodeId.
%  Val: utility value obtained by the simulation
%  Res: 1 (win for parent), 0.5 (draw) or 0 (loss)
%  NodeKind: terminal or non_terminal, as reported by simulate_random/4
% A draw is detected by numeric comparison so that a floating point utility
% 0.0 (cf. get_number/2) also counts as a draw and not as a loss.
simulate_for_parent(NodeId,Val,Res,NodeKind) :-
    max_simulation_length(MaxSimulationLength),
    simulate_random(MaxSimulationLength,NodeId,Val,NodeKind),
    ( winning_node_for_parent_id(NodeId,Val) -> Res = 1
    ; Val =:= 0 -> Res = 0.5 % draw
    ; Res = 0 % loss
    ).
353
% winning_node_for_parent_id(+NodeID,+Val): like winning_node_for_parent/2,
% but for the state stored under the id NodeID
winning_node_for_parent_id(NodeID,Val) :-
    visited_expression(NodeID,StoredState),
    winning_node_for_parent(StoredState,Val).
% winning_node_for_parent(+State,+Val):
% check if value Val of some child node State is winning for its ancestor,
% i.e., for the opponent of the player who is to move in State
winning_node_for_parent(State,Val) :-
    get_player(State,Player),
    value_wins_for_opponent_of(Player,Val).

% value_wins_for_opponent_of(+Player,+Val)
value_wins_for_opponent_of(max,Val) :- !,
    Val < 0. % parent node is minimizing
value_wins_for_opponent_of(_,Val) :-
    Val > 0. % parent node is maximizing
362
363
364 :- use_module(library(random),[random_member/2, random_permutation/2]).
% simulate_random(+Max,+X,-Res,?Kind):
% perform a random playout of at most Max moves starting in state id X.
%  Res: utility of the terminal state reached; 0 if the playout had to be stopped
%  Kind: classification of the start state X itself:
%        terminal if X is a terminal state or the playout could not even start
%        from X; non_terminal if at least one move was made
% Note: Kind is unified in the clause heads; when it is already bound to
% terminal by the caller the second clause is skipped, i.e., no move is attempted.
simulate_random(_,X,Res,terminal) :-
    terminal(X,R,_),
    !,Res=R.
simulate_random(Max,X,Res,non_terminal) :- Max>0, M1 is Max-1,
    my_random_game_move(X,Z),!, % commit to the randomly chosen successor
    simulate_random(M1,Z,Res,_).
simulate_random(_,_,Res,terminal) :-
    Res = 0. % no moves possible or limit exceeded, we assume a draw
    % TODO: call heuristic function if simulation was stopped
374
375
376 % -------------
377
378 :- use_module(probsrc(translate),[translate_bstate_limited/3]).
379 :- use_module(dotsrc(dot_graph_generator), [gen_dot_graph/6]).
380 :- use_module(probsrc(tools_strings),[ajoin/2]).
381 % DOT rendering
382
% tcltk_gen_last_mcts_tree(+MaxDepth,+File):
% write a DOT rendering of the saved MCTS tree to File, using mcts_node/8 and
% mcts_trans/7 with depth limit MaxDepth; generates an error and fails if no tree is saved
tcltk_gen_last_mcts_tree(MaxDepth,File) :-
    ( saved_mcts_tree(Tree,_)
    -> gen_dot_graph(File,[rankdir/'LR',no_page_size],
                     mcts_node(Tree,MaxDepth),mcts_trans(Tree,MaxDepth),
                     dot_no_same_rank,dot_no_subgraph)
    ;  add_error(mcts_game_play,'Run MCTS Game Play first and set GAME_MCTS_CACHE_LAST_TREE to TRUE',''),
       fail
    ).
390
% mcts_tree_available: succeeds if a tree was saved by a previous MCTS run (see save_last_mcts_tree/3)
mcts_tree_available :- saved_mcts_tree(_,_).
392
393 %gen_dot(Tree,MaxDepth) :- gen_dot_graph('mcts_tree.dot',[rankdir/'LR',no_page_size],mcts_node(Tree,MaxDepth),mcts_trans(Tree,MaxDepth),
394 % dot_no_same_rank,dot_no_subgraph).
395
396 :- use_module(probsrc(translate),[translate_event_with_limit/3]).
% mcts_node(+Tree,+MaxDepth,-NodeID,-SubGraph,-NodeDesc,-Shape,-Style,-Color):
% node predicate passed to gen_dot_graph/6: enumerates the root of Tree and,
% upon backtracking, the nodes of its subtrees while MaxDepth >= 1.
% The fourth argument is always none (presumably the subgraph of the node - confirm
% against dot_graph_generator).
% The label shows wins (w), visits (n), the player and, in B mode, a textual
% representation of the state limited by the argument 50; '??' is used whenever
% the state or the player cannot be determined.
%  terminal nodes: bold; green/red/blue for positive/negative/zero utility
%  the current state of the animator: black; all other nodes: gray
mcts_node(Tree,_,NodeID,none,NodeDesc,Shape,Style,Color) :-
    get_node(Tree,NodeID),
    get_wins(Tree,W), get_visits(Tree,N),
    (visited_expression(NodeID,State),specfile:b_mode,translate_bstate_limited(State,50,TS) -> true ; TS='??'),
    (get_player_in_state_id(NodeID,Player) -> true ; Player='??'),
    (terminal(NodeID,UVal,WinParent)
     -> ajoin(['id (terminal): ',NodeID,', w=',W,',n=',N,'\\nplayer=',Player,
               ', utility=',UVal,', ',WinParent,'\\n',TS],NodeDesc),
        Style=bold,
        (UVal>0 -> Color=green ; UVal<0 -> Color=red ; Color=blue)
     ; current_state_id(NodeID)
     -> ajoin(['id (current): ',NodeID,', w=',W,',n=',N,', player=',Player,'\\n',TS],NodeDesc),
        Style=rounded,
        Color=black
     ; ajoin(['id: ',NodeID,', w=',W,',n=',N,', player=',Player,'\\n',TS],NodeDesc),
        Style=rounded,
        Color=gray
    ),
    Shape=rectangle.
% recursive case: the nodes of the children, with the depth limit decreased by one
mcts_node(Tree,MaxDepth,NodeID,none,NodeDesc,Shape,Style,Color) :-
    MaxDepth>=1, MD is MaxDepth-1,
    get_child(Tree,Child),
    mcts_node(Child,MD,NodeID,none,NodeDesc,Shape,Style,Color).
% mcts_trans(+Tree,+MaxDepth,-NodeID,-Label,-SuccID,-Color,-Style):
% edge predicate passed to gen_dot_graph/6: enumerates the edges from the root of
% Tree to its direct children and, upon backtracking, the edges of the subtrees
% while MaxDepth > 1.
%  Style: bold for the edge to the child selected by get_best_mcts_move/3, solid otherwise
%  Color: black if the edge leads to the current state of the animator, lightgray otherwise
%  Label: the event of a transition between the two states (limited by the
%         argument 30), or 'move' if no such transition is stored
mcts_trans(Tree,_,NodeID,Label,SuccID,Color,Style) :-
    get_node(Tree,NodeID),
    (get_best_mcts_move(Tree,_,Best) -> true ; Best=unknown), % unknown: Tree is a leaf or has no children
    get_child(Tree,Child),
    get_node(Child,SuccID),
    (SuccID=Best -> Style=bold ; Style=solid),
    (current_state_id(SuccID) -> Color=black ; Color=lightgray),
    (transition(NodeID,Ev,_,SuccID) -> translate_event_with_limit(Ev,30,Label) ; Label='move').
% recursive case: the edges inside the subtrees of the children
mcts_trans(Tree,MaxDepth,NodeID,Label,SuccID,Color,Style) :-
    MaxDepth>1, MD is MaxDepth-1,
    get_child(Tree,Child),
    mcts_trans(Child,MD,NodeID,Label,SuccID,Color,Style).
432
433
434 % -------------
435
% pretty_print_node(+StateID): print the state id of a tree node (without newline)
pretty_print_node(StateID) :-
    write('State: '),
    write(StateID).
437
% print_tree(+Tree,+Indent,+Max): print Tree on the current output for debugging.
%  Indent: indentation level in successor notation (0, s(0), ...)
%  Max: remaining depth; nodes are only printed while Max > 0 (the call fails
%       for a node otherwise), leaves are always printed
print_tree(leaf(StateID),Indent,_) :-
    indent(Indent),
    pretty_print_node(StateID),
    nl.
print_tree(node(StateID,Wins,Visits,Children),Indent,Max) :-
    Max > 0,
    Deeper is Max-1,
    indent(Indent), pretty_print_node(StateID), nl,
    length(Children,NrChildren),
    indent(Indent), format(' w=~w, n=~w, childs=~w~n',[Wins,Visits,NrChildren]),
    % print all children; children which cannot be printed any more are skipped
    findall(_,(member(C,Children), print_tree(C,s(Indent),Deeper)),_).
% indent(+Level): print one ' + ' marker (using print/1) per s/1 in the successor-notation Level
indent(0).
indent(s(X)) :- print(' + '), indent(X).
447
% print_tree_summary(+Tree): print a one line summary for the root of Tree (for debugging)
print_tree_summary(Tree) :-
    ( Tree = leaf(Node)
    -> format('Tree is leaf for ~w~n',[Node])
    ; Tree = node(Node,Wins,Visits,_Children)
    -> format('Tree for ~w: ~w wins, ~w visits~n',[Node,Wins,Visits])
    ).