1 :- module(predicate_data_generator, [generate_synthesis_data_from_predicate_raw/3,
2 generate_synthesis_data_from_predicate_raw/5,
3 generate_synthesis_data_from_predicate_untyped/5]).
4
5 :- use_module(library(lists)).
6 :- use_module(library(random)).
7 :- use_module(library(sets), [intersect/2]).
8
9 :- use_module(probsrc(bmachine), [b_load_machine_from_file/1,
10 b_get_machine_variables/1,
11 b_get_all_used_identifiers/1,
12 b_get_main_filename/1,
13 bmachine_is_precompiled/0,
14 b_machine_precompile/0]).
15 :- use_module(probsrc(parsercall), [parse_formula/2]).
16 :- use_module(probsrc(translate), [translate_bexpression/2]).
17 :- use_module(probsrc(solver_interface), [solve_predicate/5,
18 type_check_in_machine_context/2]).
19 :- use_module(probsrc(bsyntaxtree), [conjunct_predicates/2,
20 get_texpr_info/2,
21 find_identifier_uses/3,
22 find_typed_identifier_uses/3]).
23 :- use_module(probsrc(preferences), [set_preference/2,
24 get_preference/2]).
25
26 :- use_module(probsrc('por/b_simplifier')).
27 :- use_module(synthesis('deep_learning/ground_truth')).
28 :- use_module(synthesis('deep_learning/b_machine_identifier_normalization')).
29 :- use_module(synthesis(synthesis_util), [get_input_nodes_from_bindings/2,
30 create_equality_nodes_from_example/2,
31 b_get_typed_invariant_from_machine/1]).
32
% Lower bound for the number of examples. It is used both as the lower bound of
% each random positive/negative count (get_random_amount_of_examples/1) and as
% the minimum total number of found examples a record needs in order to be
% kept (get_augmented_states_for_predicate_rand/8).
min_amount_of_examples(3).
% Upper bounds (inclusive, see get_random_amount_of_examples/1) for the random
% number of positive and negative examples per record.
max_amount_of_examples(positive, 8).
max_amount_of_examples(negative, 8).

% Default number of data augmentations (records) per predicate, used by
% generate_synthesis_data_from_predicate_raw/3.
augment_records(5).
% Default solver time-out in milliseconds; it is installed as ProB's time_out
% preference by set_desired_preferences/4.
solver_timeout_ms(10000).
39
%% get_random_amount_of_examples(-(PAmountOfExamples,NAmountOfExamples)).
%
% Draw a random number of positive and of negative examples. Each count lies
% between min_amount_of_examples/1 and the respective
% max_amount_of_examples/2 bound, both inclusive. random/3 excludes its upper
% bound, hence the +1. The positive count is drawn first.
get_random_amount_of_examples((PAmountOfExamples,NAmountOfExamples)) :-
    min_amount_of_examples(Lower),
    random_amount_of_examples_for(positive, Lower, PAmountOfExamples),
    random_amount_of_examples_for(negative, Lower, NAmountOfExamples).

% Draw one random count in [Lower, max_amount_of_examples(Kind)].
random_amount_of_examples_for(Kind, Lower, Amount) :-
    max_amount_of_examples(Kind, InclusiveUpper),
    ExclusiveUpper is InclusiveUpper + 1,
    random(Lower, ExclusiveUpper, Amount).
48
%% random_list_of_numbers(+Count, +Acc, -Amounts).
%
% Prepend up to Count new, pairwise distinct random (Positive,Negative)
% example-count pairs (see get_random_amount_of_examples/1) to Acc.
% Pairs already contained in Acc are never drawn again.
% Only finitely many distinct pairs exist. Count is therefore capped at that
% number, minus the pairs already in Acc, so the rejection sampling always
% terminates. Previously a Count larger than the number of distinct pairs, or
% a negative Count, looped forever. Acc is expected to contain only distinct
% pairs from that range (the caller passes []).
random_list_of_numbers(Count, Acc, Amounts) :-
    max_distinct_amounts_of_examples(MaxDistinct),
    length(Acc, AlreadyDrawn),
    Feasible is max(0, min(Count, MaxDistinct - AlreadyDrawn)),
    draw_distinct_amounts_of_examples(Feasible, Acc, Amounts).

% Number of distinct (Positive,Negative) pairs get_random_amount_of_examples/1
% can produce; each component ranges over [Min, Max] inclusive.
max_distinct_amounts_of_examples(MaxDistinct) :-
    min_amount_of_examples(Min),
    max_amount_of_examples(positive, PMax),
    max_amount_of_examples(negative, NMax),
    MaxDistinct is max(0, PMax - Min + 1) * max(0, NMax - Min + 1).

% Rejection-sample Count further pairs that are not yet in Acc.
draw_distinct_amounts_of_examples(Count, Acc, Acc) :-
    Count =< 0,
    !.
draw_distinct_amounts_of_examples(Count, Acc, Amounts) :-
    get_random_amount_of_examples(Pair),
    (   memberchk(Pair, Acc)
    ->  % duplicate: draw again without consuming the budget
        draw_distinct_amounts_of_examples(Count, Acc, Amounts)
    ;   Count1 is Count - 1,
        draw_distinct_amounts_of_examples(Count1, [Pair|Acc], Amounts)
    ).
59
%% exclude_solution(+State, +ExclusionPred, -NewExclusionPred).
%
% Extend ExclusionPred so that the variable assignment described by State
% (a list of typed nodes) is ruled out in subsequent solver calls. An empty
% State leaves ExclusionPred unchanged.
exclude_solution(State, ExclusionPred, NewExclusionPred) :-
    (   State = [],
        NewExclusionPred = ExclusionPred
    ->  true
    ;   create_equality_nodes_from_example(State, Equalities),
        conjunct_predicates(Equalities, StateAsPredicate),
        truth_or_exclude(StateAsPredicate, ExclusionPred, NewExclusionPred)
    ).
66
%% truth_or_exclude(+EqualityConj, +ExclusionPred, -NewExclusionPred).
%
% Conjoin the negation of EqualityConj to ExclusionPred. A trivial
% EqualityConj (truth) leaves ExclusionPred unchanged, because negating it
% would exclude everything.
truth_or_exclude(EqualityConj, ExclusionPred, NewExclusionPred) :-
    (   EqualityConj = b(truth,pred,[]),
        NewExclusionPred = ExclusionPred
    ->  true
    ;   Negated = b(negation(EqualityConj),pred,[]),
        NewExclusionPred = b(conjunct(Negated,ExclusionPred),pred,[])
    ).
70
%% filter_machine_var_states(+VarStates, +MachineVars, +Acc, -States).
%
% Keep the nodes of VarStates whose info list contains a
% synthesis(machinevar, _) entry, prepending each kept node to Acc. The
% result is therefore in reverse order. MachineVars is currently unused:
% operation parameters are also marked as machinevar but are not in
% MachineVars, so they are not filtered against it.
filter_machine_var_states([], _, Kept, Kept).
filter_machine_var_states([Node|Nodes], MachineVars, Kept, States) :-
    (   get_texpr_info(Node, NodeInfo),
        member(synthesis(machinevar, _), NodeInfo)
    ->  NextKept = [Node|Kept]
    ;   NextKept = Kept
    ),
    filter_machine_var_states(Nodes, MachineVars, NextKept, States).
81
%% map_translate_var_state(+Env, +VarAsts, -PrettyStates, -NewEnv).
%
% Normalize the identifiers of each variable node (threading the
% normalization environment from left to right) and turn it into a triple
% (VarName, Type, PrettyPrintedValue). VarName is taken from the node's
% synthesis(machinevar, VarName) info entry after normalization.
map_translate_var_state(Env, [], [], Env).
map_translate_var_state(EnvIn, [VarAst|VarAsts], [(VarName,VarType,PrettyValue)|PrettyRest], EnvOut) :-
    normalize_ids_in_b_ast(EnvIn, VarAst, NormalizedAst, EnvNext),
    NormalizedAst = b(_, VarType, NormalizedInfo),
    member(synthesis(machinevar, VarName), NormalizedInfo),
    translate_bexpression(NormalizedAst, PrettyValue),
    map_translate_var_state(EnvNext, VarAsts, PrettyRest, EnvOut).
89
%% get_amount_of_states_for_predicate(+Env, +AmountOfExamples, +ExclusionPred, +MachineVars, +PredicateAst, +Acc, -ListOfExamples, -NewEnv).
%
% Collect up to AmountOfExamples distinct solutions (states) of PredicateAst.
% After each solution, ExclusionPred is extended to rule out that state.
% Found states are pretty-printed and prepended to Acc. Collection stops early,
% keeping what was found so far, as soon as no further usable solution is found.
get_amount_of_states_for_predicate(Env, 0, _, _, _, Acc, Acc, Env) :-
    !.
get_amount_of_states_for_predicate(Env, AmountOfExamples, ExclusionPred, MachineVars, PredicateAst, Acc, ListOfExamples, NewEnv) :-
    Remaining is AmountOfExamples-1,
    next_state_for_predicate(Env, ExclusionPred, MachineVars, PredicateAst, PrettyState, NextExclusionPred, NextEnv),
    !,
    get_amount_of_states_for_predicate(NextEnv, Remaining, NextExclusionPred, MachineVars, PredicateAst, [PrettyState|Acc], ListOfExamples, NewEnv).
% no (new) solution found: return the examples collected so far
get_amount_of_states_for_predicate(Env, _, _, _, _, Acc, Acc, Env).

% Solve ExclusionPred & PredicateAst, keep only the machine-variable bindings
% of the solution (failing if none remain), extend the exclusion predicate by
% that state, and pretty-print the state under the normalization environment.
next_state_for_predicate(Env, ExclusionPred, MachineVars, PredicateAst, PrettyState, NextExclusionPred, NextEnv) :-
    Query = b(conjunct(ExclusionPred,PredicateAst),pred,[]),
    solve_predicate(Query, _, 1, [force_evaluation], Solution),
    Solution = solution(Bindings),
    get_input_nodes_from_bindings(Bindings, SolutionNodes),
    filter_machine_var_states(SolutionNodes, MachineVars, [], State),
    State \== [],
    exclude_solution(State, ExclusionPred, NextExclusionPred),
    map_translate_var_state(Env, State, PrettyState, NextEnv).
105
%% get_augmented_states_for_predicate(+Env, +AR, +MachineVars, +UsedComponents, +PredicateAst, +TypingPredicate, -AugmentedSetOfData, -NewEnv).
%
% Draw AR distinct random (Positive,Negative) example counts and produce one
% data record per count pair (records without enough examples are skipped).
get_augmented_states_for_predicate(Env, AR, MachineVars, UsedComponents, PredicateAst, TypingPredicate, AugmentedSetOfData, NewEnv) :-
    random_list_of_numbers(AR, [], ExampleCountPairs),
    get_augmented_states_for_predicate_rand(Env, ExampleCountPairs, MachineVars, UsedComponents,
                                            PredicateAst, TypingPredicate, AugmentedSetOfData, NewEnv).
109
%% get_augmented_states_for_predicate_rand(+Env, +CountPairs, +MachineVars, +UsedComponents, +PredicateAst, +TypingPredicate, -Records, -NewEnv).
%
% For every (PositiveCount,NegativeCount) pair, collect positive examples of
% TypingPredicate & PredicateAst and negative examples of
% TypingPredicate & not(PredicateAst). The result is a record
% (PositiveStates, NegativeStates, UsedComponents), kept only if at least
% min_amount_of_examples/1 examples were found in total. The environment is
% only advanced for kept records.
get_augmented_states_for_predicate_rand(Env, [], _, _, _, _, [], Env).
get_augmented_states_for_predicate_rand(Env, [CountPair|CountPairs], MachineVars, UsedComponents, PredicateAst, TypingPredicate, Records, NewEnv) :-
    (   CountPair = (PosCount,NegCount),
        PosQuery = b(conjunct(TypingPredicate,PredicateAst),pred,[]),
        NegQuery = b(conjunct(TypingPredicate,b(negation(PredicateAst),pred,[])),pred,[]),
        get_amount_of_states_for_predicate(Env, PosCount, b(truth,pred,[]), MachineVars, PosQuery, [], PositiveStates, EnvAfterPos),
        get_amount_of_states_for_predicate(EnvAfterPos, NegCount, b(truth,pred,[]), MachineVars, NegQuery, [], NegativeStates, EnvAfterNeg),
        length(PositiveStates, PosFound),
        length(NegativeStates, NegFound),
        TotalFound is PosFound+NegFound,
        min_amount_of_examples(Required),
        TotalFound >= Required,
        Records = [(PositiveStates,NegativeStates,UsedComponents)|MoreRecords]
    ->  NextEnv = EnvAfterNeg
    ;   % too few examples: drop this record and keep the old environment
        Records = MoreRecords,
        NextEnv = Env
    ),
    get_augmented_states_for_predicate_rand(NextEnv, CountPairs, MachineVars, UsedComponents, PredicateAst, TypingPredicate, MoreRecords, NewEnv).
126
% AST node functors (components) that must not appear in a kept constraint.
% filter_predicate/3 drops every constraint whose AST contains one of these,
% and ast_uses_id_from_list/2 rejects ASTs whose top-level node is one of them.
% Built-in integer sets:
blacklist_component(int_set).
blacklist_component(integer_set).
blacklist_component(nat_set).
blacklist_component(nat1_set).
blacklist_component(natural_set).
blacklist_component(natural1_set).
% Quantifiers and set comprehension:
blacklist_component(forall).
blacklist_component(exists).
blacklist_component(comprehension_set).
% Relation and function type constructors:
blacklist_component(total_function).
blacklist_component(total_surjection).
blacklist_component(total_injection).
blacklist_component(total_relation).
blacklist_component(total_surjection_relation).
blacklist_component(surjection_relation).
blacklist_component(partial_function).
blacklist_component(partial_injection).
blacklist_component(partial_surjection).
blacklist_component(partial_bijection).
% Lambda abstraction and quantified set operators:
blacklist_component(lambda).
blacklist_component(quantified_union).
blacklist_component(quantified_intersection).
% Currently NOT blacklisted (kept here as candidates):
%blacklist_component(function).
%blacklist_component(general_sum).
%blacklist_component(general_product).
152
%% filter_predicate(+UsedIds, +TPredicateAst, -PredicateAst).
%
% Remove constraints that do not refer to any identifier in UsedIds.
% For instance, 'x: NAT & x < 10 & NAT \/ NAT1 = NAT' will be reduced to 'x: NAT & x < 10' with x being in UsedIds.
% NOTE(review): integer_set is listed by blacklist_component/1, so if NAT is
% represented by an integer_set node here, 'x: NAT' is dropped as well. Confirm
% the example above.
% Remove constraints that use components from blacklist_component/1.
% Fails if no constraint remains (the filtered predicate would be truth).

% Logical binary connectives: filter both operands recursively and rebuild the
% node from what is kept (see filter_predicate_binary/6). If both operands are
% filtered out, this clause fails and the next clause is tried.
filter_predicate(UsedIds, Ast, NewAst) :-
    Ast = b(Node,_,_),
    Node =.. [Functor, Lhs, Rhs],
    ( Functor == conjunct
    ; Functor == disjunct
    ; Functor == equivalence
    ; Functor == implication),
    filter_predicate_binary(UsedIds, Ast, Functor, Lhs, Rhs, Clean),
    !,
    NewAst = Clean.
% Any other binary node (and logical nodes whose operands were all filtered
% out): drop the whole node if an operand contains a blacklisted component,
% otherwise keep it unchanged provided it uses an identifier from UsedIds.
% The cut commits only once both component lists are collected; if
% get_used_components/3 fails, the later clauses are tried instead.
% remove constraint if contains component in blacklist_component/1
filter_predicate(UsedIds, Ast, NewAst) :-
    Ast = b(Node,_,_),
    Node =.. [_, Lhs, Rhs],
    get_used_components(Lhs, [], LhsUsedComponents),
    get_used_components(Rhs, [], RhsUsedComponents),
    findall(BlacklistCmpt, (member(BlacklistCmpt, LhsUsedComponents), blacklist_component(BlacklistCmpt)), LhsBlacklistCmpts),
    findall(BlacklistCmpt, (member(BlacklistCmpt, RhsUsedComponents), blacklist_component(BlacklistCmpt)), RhsBlacklistCmpts),
    !,
    ( LhsBlacklistCmpts == [],
      RhsBlacklistCmpts == []
    -> ast_uses_id_from_list(UsedIds, Ast),
       NewAst = Ast
    ; fail
    ).
% Negation: filter the negated predicate and negate the result. The original
% node's info list is replaced by [].
filter_predicate(UsedIds, Negation, NegNewAst) :-
    Negation = b(negation(Ast),pred,_),
    !,
    filter_predicate(UsedIds, Ast, NewAst),
    NegNewAst = b(negation(NewAst),pred,[]).
% All remaining nodes: keep unchanged if they use an identifier from UsedIds
% and contain no blacklisted component.
% fail if ast does not use any ids from UsedIds
filter_predicate(UsedIds, Ast, NewAst) :-
    ast_uses_id_from_list(UsedIds, Ast),
    get_used_components(Ast, [], UsedComponents),
    findall(BlacklistCmpt, (member(BlacklistCmpt, UsedComponents), blacklist_component(BlacklistCmpt)), BlacklistCmpts),
    BlacklistCmpts == [],
    NewAst = Ast.
195
%% filter_predicate_binary(+UsedIds, +Ast, +Functor, +Lhs, +Rhs, -NewNode).
%
% Filter both operands of the logical binary node Ast with filter_predicate/3
% (left operand first). If both are kept, rebuild the node with the same
% Functor, type and info. If only one is kept, the other is treated as truth
% and the kept operand replaces the node. If neither is kept, fail, because
% the predicate would be truth.
filter_predicate_binary(UsedIds, b(_,Type,Info), Functor, Lhs, Rhs, NewNode) :-
    filtered_operand(UsedIds, Lhs, LhsOutcome),
    filtered_operand(UsedIds, Rhs, RhsOutcome),
    combine_filtered_operands(LhsOutcome, RhsOutcome, Functor, Type, Info, NewNode),
    !.

% Outcome of filtering one operand: kept(FilteredAst) or dropped.
filtered_operand(UsedIds, Operand, Outcome) :-
    (   filter_predicate(UsedIds, Operand, Filtered)
    ->  Outcome = kept(Filtered)
    ;   Outcome = dropped
    ).

% Assemble the filtered node; fails when both operands were dropped.
combine_filtered_operands(kept(NLhs), kept(NRhs), Functor, Type, Info, b(NNode,Type,Info)) :-
    NNode =.. [Functor, NLhs, NRhs].
combine_filtered_operands(kept(NLhs), dropped, _, _, _, NLhs).
combine_filtered_operands(dropped, kept(NRhs), _, _, _, NRhs).
210
%% ast_uses_id_from_list(+UsedIds, +Ast).
%
% True if the top-level node of Ast is not a blacklisted component and Ast
% uses at least one identifier that is also in UsedIds. Only the first
% identifier list computed by find_identifier_uses/3 is considered.
ast_uses_id_from_list(UsedIds, Ast) :-
    Ast = b(TopNode, _, _),
    functor(TopNode, TopFunctor, _Arity),
    \+ blacklist_component(TopFunctor),
    once(bsyntaxtree:find_identifier_uses(Ast, [], IdsInAst)),
    intersect(UsedIds, IdsInAst).
218
%% get_used_components(+Term, +Acc, -UsedComponents).
%
% Prepend to Acc the functors of all b/3 AST nodes occurring in Term
% (pre-order, so the result is in reverse order of discovery).
% Non-AST compound arguments are traversed but not recorded. Examples are the
% element list of set_extension/1 or the identifier list of quantifiers.
% Previously such arguments made the whole traversal fail, and
% filter_predicate/3 then silently dropped harmless constraints such as
% x : {1,2}. Atomic leaves and unbound variables contribute nothing.
get_used_components(Term, Acc, Acc) :-
    var(Term),
    !.
get_used_components(b(Node,_,_), Acc, UsedComponents) :-
    !,
    (   var(Node)
    ->  UsedComponents = Acc
    ;   Node =.. [Component|Args],
        map_get_used_components(Args, [Component|Acc], UsedComponents)
    ).
get_used_components(Term, Acc, Acc) :-
    atomic(Term),
    !.
get_used_components(Term, Acc, UsedComponents) :-
    % non-AST compound (e.g. a list of sub-ASTs): look inside, record nothing
    Term =.. [_|Args],
    map_get_used_components(Args, Acc, UsedComponents).

%% map_get_used_components(+Terms, +Acc, -UsedComponents).
%
% Fold get_used_components/3 over a list of terms, left to right.
map_get_used_components([], Acc, Acc).
map_get_used_components([Arg|T], Acc, UsedComponents) :-
    get_used_components(Arg, Acc, NewAcc),
    map_get_used_components(T, NewAcc, UsedComponents).
231
%% generate_synthesis_data_from_predicate_raw(+MachinePath, +RawPredicate, -AugmentedSetOfData).
%
% Generate positive and negative examples for a pretty-printed B predicate
% (given as an atom) and extract the ground truth B components as used by the
% program synthesis tool. The default number of augmentations
% (augment_records/1) and the default solver time-out (solver_timeout_ms/1)
% are used.
% AugmentedSetOfData is a list of triples
% (PositiveExamples, NegativeExamples, GroundTruth), one per kept
% augmentation. Each example is a list of triples
% (MachineVar, Type, PrettyPrintedValue).
% The machine at MachinePath is loaded unless it already is the loaded
% machine, and RawPredicate is type checked in that machine's context.
% Fails silently, e.g. if the machine uses records or if no constraint of
% the predicate refers to a machine variable.
generate_synthesis_data_from_predicate_raw(MachinePath, RawPredicate, AugmentedSetOfData) :-
    solver_timeout_ms(DefaultTimeoutMs),
    augment_records(DefaultAugmentations),
    generate_synthesis_data_from_predicate_raw(MachinePath, DefaultAugmentations, DefaultTimeoutMs,
                                               RawPredicate, AugmentedSetOfData).
244
% Cache of the normalized identifier name mapping of the loaded machine;
% cleared by load_b_machine_if_unloaded/1 when a machine is (re)loaded.
:- dynamic normalized_id_name_mapping/4.
:- volatile normalized_id_name_mapping/4.

%% get_normalized_id_name_mapping_stateful(-NormalizedSets, -NormalizedIds, -NOperationNames, -NRecordFieldNames).
%
% Return the cached mapping if one exists; otherwise compute it with
% get_normalized_id_name_mapping/4 and cache it.
get_normalized_id_name_mapping_stateful(NormalizedSets, NormalizedIds, NOperationNames, NRecordFieldNames) :-
    (   normalized_id_name_mapping(_, _, _, _)
    ->  normalized_id_name_mapping(NormalizedSets, NormalizedIds, NOperationNames, NRecordFieldNames)
    ;   get_normalized_id_name_mapping(NormalizedSets, NormalizedIds, NOperationNames, NRecordFieldNames),
        asserta(normalized_id_name_mapping(NormalizedSets, NormalizedIds, NOperationNames, NRecordFieldNames))
    ).
254
%% load_b_machine_if_unloaded(+MachinePath).
%
% Load and precompile the machine at MachinePath unless it already is the
% precompiled main machine. MachinePath may be given relative to the loaded
% main file. It matches if it equals the loaded file name or is a suffix of
% it that starts right after a path separator. Previously any suffix matched,
% so 'Lift.mch' was taken to be loaded while '.../SimpleLift.mch' was loaded.
% Loading a machine discards the cached normalized identifier mapping.
load_b_machine_if_unloaded(MachinePath) :-
    bmachine_is_precompiled,
    b_get_main_filename(LoadedMachinePath),
    machine_path_matches_loaded(MachinePath, LoadedMachinePath),
    !.
load_b_machine_if_unloaded(MachinePath) :-
    retractall(normalized_id_name_mapping(_, _, _, _)),
    b_load_machine_from_file(MachinePath),
    b_machine_precompile.

% True if LoadedMachinePath ends with MachinePath at a path-component
% boundary (or both are equal).
machine_path_matches_loaded(MachinePath, LoadedMachinePath) :-
    atom_concat(Prefix, MachinePath, LoadedMachinePath),
    (   Prefix == ''
    ->  true
    ;   sub_atom(Prefix, _, 1, 0, LastChar),
        machine_path_separator(LastChar)
    ),
    !.

machine_path_separator('/').
machine_path_separator('\\').
264
% get_id_name(+IdentifierAst, -Name): name of a typed identifier node.
get_id_name(IdentifierAst, Name) :-
    IdentifierAst = b(identifier(Name), _Type, _Info).
266
%% generate_synthesis_data_from_predicate_raw(+MachinePath, +AugmentRecords, +SolverTimeoutMs, +RawPredicate, -AugmentedSetOfData).
%
% Like generate_synthesis_data_from_predicate_raw/3, but with an explicit
% number of data augmentations (AugmentRecords) and solver time-out in
% milliseconds. RawPredicate is an atom holding the predicate's B syntax; it
% is parsed here and type checked later.
generate_synthesis_data_from_predicate_raw(MachinePath, AugmentRecords, SolverTimeoutMs, RawPredicate, AugmentedSetOfData) :-
    atom_codes(RawPredicate, PredicateCodes),
    parse_formula(PredicateCodes, ParsedPredicate),
    generate_synthesis_data_from_predicate_untyped(MachinePath, AugmentRecords, SolverTimeoutMs,
                                                   ParsedPredicate, AugmentedSetOfData).
274
%% generate_synthesis_data_from_predicate_untyped(+MachinePath, +AugmentRecords, +SolverTimeoutMs, +UntypedPredAst, -AugmentedSetOfData).
%
% Make sure the machine at MachinePath is loaded, type check the parsed
% predicate in that machine's context and generate the synthesis data for
% the typed predicate.
generate_synthesis_data_from_predicate_untyped(MachinePath, AugmentRecords, SolverTimeoutMs, UntypedPredAst, AugmentedSetOfData) :-
    load_b_machine_if_unloaded(MachinePath),
    type_check_in_machine_context([UntypedPredAst], TypeCheckedAsts),
    TypeCheckedAsts = [TypedPredicate],
    generate_synthesis_data_from_predicate_ast(MachinePath, AugmentRecords, SolverTimeoutMs,
                                               TypedPredicate, AugmentedSetOfData).
282
%% generate_synthesis_data_from_predicate_ast(+MachinePath, +AugmentRecords, +SolverTimeoutMs, +PredicateAst, -AugmentedSetOfData).
%
% Generate the augmented synthesis data for a typed predicate AST.
% The predicate is cleaned up, simplified and filtered to the constraints that
% use machine variables before examples are generated.
% All preference changes are undone when this predicate exits, fails or
% raises an exception. This covers the temporary optimize_ast/normalize_ast
% settings and the solver preferences from set_desired_preferences/4.
% Previously any failure, e.g. for a machine using records or a predicate
% without relevant constraints, left them altered. The predicate is
% deterministic.
generate_synthesis_data_from_predicate_ast(MachinePath, AugmentRecords, SolverTimeoutMs, PredicateAst, AugmentedSetOfData) :-
    preferences:temporary_set_preference(optimize_ast, false),
    preferences:temporary_set_preference(normalize_ast, true),
    % should we also set normalize_ast_sort_commutative to true?
    call_cleanup(
        once(generate_synthesis_data_in_machine(MachinePath, AugmentRecords, SolverTimeoutMs,
                                                PredicateAst, AugmentedSetOfData)),
        ( preferences:reset_temporary_preference(optimize_ast),
          preferences:reset_temporary_preference(normalize_ast) )).

% Load the machine, check preconditions, and run data generation with the
% solver preferences set; the previous solver preferences are always restored.
generate_synthesis_data_in_machine(MachinePath, AugmentRecords, SolverTimeoutMs, PredicateAst, AugmentedSetOfData) :-
    load_b_machine_if_unloaded(MachinePath),
    \+ current_machine_uses_records,
    retractall(tools:id_counter(_)),
    set_desired_preferences(SolverTimeoutMs, OldKodkodPref, OldTimeOutPref, OldRandPref),
    call_cleanup(
        once(generate_synthesis_data_with_solver_prefs(AugmentRecords, PredicateAst, AugmentedSetOfData)),
        reset_old_preferences(OldKodkodPref, OldTimeOutPref, OldRandPref)).

% Filter the predicate to constraints on machine variables, compute its ground
% truth components and typing predicate, and generate the augmented records.
generate_synthesis_data_with_solver_prefs(AugmentRecords, PredicateAst, AugmentedSetOfData) :-
    b_get_machine_variables(MachineVars),
    maplist(get_id_name, MachineVars, MachineVarNames),
    % b_get_all_used_identifiers/1 would also include constraints not using machine variables
    b_ast_cleanup:clean_up_pred(PredicateAst, _, CleanPred),
    b_simplifier:simplify_b_predicate(CleanPred, SimplifiedPred),
    filter_predicate(MachineVarNames, SimplifiedPred, FilteredPredicateAst),
    get_library_components_from_pred_or_expr(FilteredPredicateAst, UsedComponents),
    get_normalized_id_name_mapping_stateful(NormalizedSets, NormalizedIds, NOperationNames, NRecordFieldNames),
    Env = [[], NormalizedSets, NormalizedIds, NOperationNames, NRecordFieldNames],
    % typing predicates are generated for the identifiers of the filtered predicate
    find_typed_identifier_uses(FilteredPredicateAst, [], UsedIds),
    translate:generate_typing_predicates(UsedIds, TypingPredicates),
    conjunct_predicates(TypingPredicates, TypingPredicate),
    get_augmented_states_for_predicate(Env, AugmentRecords, MachineVars, UsedComponents,
                                       FilteredPredicateAst, TypingPredicate, AugmentedSetOfData, _).
310
%% reset_old_preferences(+OldKodkodPref, +OldTimeOutPref, +OldRandPref).
%
% Restore the preferences changed by set_desired_preferences/4, in the order
% try_kodkod_on_load, time_out, randomise_enumeration_order.
reset_old_preferences(OldKodkodPref, OldTimeOutPref, OldRandPref) :-
    maplist(set_preference_pair,
            [ try_kodkod_on_load-OldKodkodPref,
              time_out-OldTimeOutPref,
              randomise_enumeration_order-OldRandPref ]).

% Set one preference given as Preference-Value.
set_preference_pair(Preference-Value) :-
    set_preference(Preference, Value).
317
%% set_desired_preferences(+Timeout, -OldKodkodPref, -OldTimeOutPref, -OldRandPref).
%
% Disable try_kodkod_on_load, enable randomise_enumeration_order and set
% the time_out preference to Timeout. The previous values are returned so
% that reset_old_preferences/3 can restore them.
set_desired_preferences(Timeout, OldKodkodPref, OldTimeOutPref, OldRandPref) :-
    get_preference(try_kodkod_on_load, OldKodkodPref),
    get_preference(randomise_enumeration_order, OldRandPref),
    get_preference(time_out, OldTimeOutPref),
    set_preference(try_kodkod_on_load, false),
    set_preference(randomise_enumeration_order, true),
    set_preference(time_out, Timeout).