1 :- module(predicate_data_generator, [generate_synthesis_data_from_predicate/3,
2 generate_synthesis_data_from_predicate/5]).
3
4 :- use_module(library(random)).
5 :- use_module(library(sets), [intersect/2]).
6
7 :- use_module(probsrc(bmachine), [b_load_machine_from_file/1,
8 b_get_machine_variables/1,
9 b_get_all_used_identifiers/1,
10 b_get_main_filename/1,
11 bmachine_is_precompiled/0,
12 b_machine_precompile/0]).
13 :- use_module(probsrc(parsercall), [parse_formula/2]).
14 :- use_module(probsrc(translate), [translate_bexpression/2]).
15 :- use_module(probsrc(solver_interface), [solve_predicate/3, type_check_in_machine_context/2]).
16 :- use_module(probsrc(bsyntaxtree), [conjunct_predicates/2, get_texpr_info/2, find_identifier_uses/3]).
17 :- use_module(probsrc('synthesis/deep_learning/ground_truth')).
18 :- use_module(probsrc('synthesis/deep_learning/b_machine_identifier_normalization')).
19 :- use_module(probsrc(preferences), [set_preference/2,
20 get_preference/2]).
21
22 :- use_module(synthesis(synthesis_util), [get_input_nodes_from_bindings/2,
23 create_equality_nodes_from_example/2,
24 b_get_typed_invariant_from_machine/1]).
25
% Lower bound for the number of examples drawn per polarity by
% get_random_amount_of_examples/1; also the minimum total number of
% (positive + negative) examples a record needs in order to be kept by
% get_augmented_states_for_predicate_rand/7.
min_amount_of_examples(3).
% Upper bounds (inclusive) for the number of positive and negative examples
% drawn by get_random_amount_of_examples/1.
max_amount_of_examples(positive, 8).
max_amount_of_examples(negative, 8).

% Default number of data augmentations used by generate_synthesis_data_from_predicate/3.
augment_records(5).
% Default solver time-out, passed by generate_synthesis_data_from_predicate/3
% to set_desired_preferences/4, which installs it as the time_out preference.
solver_timeout_ms(10000).
32
%% get_random_amount_of_examples(-Amounts).
%
% Amounts is a pair (PositiveAmount,NegativeAmount). Each component is drawn
% at random from min_amount_of_examples/1 up to and including the matching
% max_amount_of_examples/2 bound. The positive amount is drawn first.
get_random_amount_of_examples((PAmountOfExamples,NAmountOfExamples)) :-
    min_amount_of_examples(Lowest),
    random_amount_up_to(positive, Lowest, PAmountOfExamples),
    random_amount_up_to(negative, Lowest, NAmountOfExamples).

%% random_amount_up_to(+Polarity, +Lowest, -Amount).
%
% Draw Amount from [Lowest, max_amount_of_examples(Polarity)].
% random/3 excludes its upper limit, hence the +1.
random_amount_up_to(Polarity, Lowest, Amount) :-
    max_amount_of_examples(Polarity, Highest),
    ExclusiveUpper is Highest + 1,
    random(Lowest, ExclusiveUpper, Amount).
41
%% random_list_of_numbers(+Count, +Acc, -Pairs).
%
% Prepend Count distinct random (Positive,Negative) example-count pairs, drawn
% by get_random_amount_of_examples/1, to Acc. A non-positive Count yields Acc.
% Only finitely many distinct pairs exist (see max_distinct_amount_pairs/1).
% Once Acc holds that many, the list is returned as is instead of redrawing
% duplicates forever.
random_list_of_numbers(Count, Acc, Acc) :-
    Count =< 0,
    !.
random_list_of_numbers(_, Acc, Acc) :-
    max_distinct_amount_pairs(MaxPairs),
    length(Acc, Drawn),
    Drawn >= MaxPairs,
    !.
random_list_of_numbers(Count, Acc, Pairs) :-
    get_random_amount_of_examples(Pair),
    (   \+ member(Pair, Acc)
    ->  Count1 is Count - 1,
        random_list_of_numbers(Count1, [Pair|Acc], Pairs)
    ;   % duplicate pair: draw again
        random_list_of_numbers(Count, Acc, Pairs)
    ).

%% max_distinct_amount_pairs(-MaxPairs).
%
% Number of distinct pairs get_random_amount_of_examples/1 can produce, given
% the inclusive bounds min_amount_of_examples/1 and max_amount_of_examples/2.
max_distinct_amount_pairs(MaxPairs) :-
    min_amount_of_examples(Min),
    max_amount_of_examples(positive, PMax),
    max_amount_of_examples(negative, NMax),
    MaxPairs is (PMax - Min + 1) * (NMax - Min + 1).
52
%% exclude_solution(+State, +ExclusionPred, -NewExclusionPred).
%
% Extend ExclusionPred so that it also rules out the variable assignment in
% State, a list of variable nodes such as produced by filter_machine_var_states/4.
% An empty State leaves ExclusionPred unchanged.
exclude_solution(State, ExclusionPred, NewExclusionPred) :-
    (   State = []
    ->  NewExclusionPred = ExclusionPred
    ;   create_equality_nodes_from_example(State, Equalities),
        conjunct_predicates(Equalities, AllEqual),
        truth_or_exclude(AllEqual, ExclusionPred, NewExclusionPred)
    ).
59
%% truth_or_exclude(+EqualityConj, +ExclusionPred, -NewExclusionPred).
%
% Conjoin the negation of EqualityConj to ExclusionPred. A trivially true
% EqualityConj (b(truth,pred,[])) is skipped, since its negation would rule
% out every solution; ExclusionPred is then returned unchanged.
truth_or_exclude(EqualityConj, ExclusionPred, NewExclusionPred) :-
    EqualityConj = b(truth,pred,[]),
    NewExclusionPred = ExclusionPred,
    !.
truth_or_exclude(EqualityConj, ExclusionPred, NewExclusionPred) :-
    Negated = b(negation(EqualityConj),pred,[]),
    NewExclusionPred = b(conjunct(Negated,ExclusionPred),pred,[]).
63
%% filter_machine_var_states(+VarStates, +MachineVars, +Acc, -States).
%
% Keep the variable nodes whose info list contains synthesis(machinevar, Name)
% for a Name that occurs as b(identifier(Name),_,_) in MachineVars.
% Kept nodes are prepended to Acc, so their relative order is reversed.
filter_machine_var_states([], _, Acc, Acc).
filter_machine_var_states([VarState|Rest], MachineVars, Acc, States) :-
    (   get_texpr_info(VarState, Info),
        member(synthesis(machinevar, VarName), Info),
        member(b(identifier(VarName), _, _), MachineVars)
    ->  NextAcc = [VarState|Acc]
    ;   NextAcc = Acc
    ),
    filter_machine_var_states(Rest, MachineVars, NextAcc, States).
73
%% map_translate_var_state(+Env, +VarStateAsts, -PrettyState, -NewEnv).
%
% Normalise the identifiers of every variable-state node (threading the
% normalisation environment from Env to NewEnv) and pretty-print it as a
% triple (VarName, Type, PrettyValue).
% Fails if a node's info lacks a synthesis(machinevar, VarName) entry.
map_translate_var_state(Env, [], [], Env).
map_translate_var_state(Env, [Ast|Asts], [Entry|Entries], NewEnv) :-
    normalize_ids_in_b_ast(Env, Ast, NormalizedAst, NextEnv),
    NormalizedAst = b(_, Type, Info),
    member(synthesis(machinevar, VarName), Info),
    translate_bexpression(NormalizedAst, PrettyValue),
    Entry = (VarName, Type, PrettyValue),
    map_translate_var_state(NextEnv, Asts, Entries, NewEnv).
81
%% get_amount_of_states_for_predicate(+Env, +AmountOfExamples, +ExclusionPred, +MachineVars,
%%                                    +VPredicateAst, +Acc, -ListOfExamples, -NewEnv).
%
% Ask the solver up to AmountOfExamples times for a solution of VPredicateAst
% that ExclusionPred does not rule out. Each solution is processed as follows:
%   - it is restricted to the machine variables in MachineVars;
%   - it is pretty-printed, threading the normalisation environment from Env;
%   - it is prepended to Acc;
%   - it is added to the exclusion predicate, so later solver calls must find
%     a different state.
% Generation stops early, keeping the examples found so far, when the solver
% returns no solution or a solution assigning no machine variable.
get_amount_of_states_for_predicate(Env, 0, _, _, _, Acc, Acc, Env) :-
    !.
get_amount_of_states_for_predicate(Env, AmountOfExamples, ExclusionPred, MachineVars, VPredicateAst, Acc, ListOfExamples, NewEnv) :-
    Remaining is AmountOfExamples - 1,
    Query = b(conjunct(ExclusionPred,VPredicateAst),pred,[]),
    (   solve_predicate(Query, _, Solution),
        Solution = solution(Bindings),
        get_input_nodes_from_bindings(Bindings, AllVarStates),
        filter_machine_var_states(AllVarStates, MachineVars, [], MachineVarStates),
        MachineVarStates \== [],
        exclude_solution(MachineVarStates, ExclusionPred, NextExclusionPred),
        map_translate_var_state(Env, MachineVarStates, PrettyState, NextEnv)
    ->  get_amount_of_states_for_predicate(NextEnv, Remaining, NextExclusionPred, MachineVars, VPredicateAst,
                                           [PrettyState|Acc], ListOfExamples, NewEnv)
    ;   % cancel if no solution found
        ListOfExamples = Acc,
        NewEnv = Env
    ).
97
%% get_augmented_states_for_predicate(+Env, +AR, +MachineVars, +UsedComponents,
%%                                    +PredicateAst, -AugmentedSetOfData, -NewEnv).
%
% Draw AR distinct (PositiveAmount,NegativeAmount) pairs and build one data
% record per pair via get_augmented_states_for_predicate_rand/7. That predicate
% drops records with too few examples.
get_augmented_states_for_predicate(Env, AR, MachineVars, UsedComponents, PredicateAst, AugmentedSetOfData, NewEnv) :-
    random_list_of_numbers(AR, [], AmountPairs),
    get_augmented_states_for_predicate_rand(Env, AmountPairs, MachineVars, UsedComponents,
                                            PredicateAst, AugmentedSetOfData, NewEnv).
101
%% get_augmented_states_for_predicate_rand(+Env, +AmountPairs, +MachineVars, +UsedComponents,
%%                                         +PredicateAst, -Records, -NewEnv).
%
% For each (PAmount,NAmount) pair:
%   - generate up to PAmount positive examples (states satisfying PredicateAst);
%   - generate up to NAmount negative examples (states satisfying its negation);
%   - emit the record (PositiveStates, NegativeStates, UsedComponents).
% A pair whose record has fewer than min_amount_of_examples/1 examples in total
% is skipped, and the environment it produced is discarded.
get_augmented_states_for_predicate_rand(Env, [], _, _, _, [], Env).
get_augmented_states_for_predicate_rand(Env, [AmountPair|Pairs], MachineVars, UsedComponents, PredicateAst, Records, NewEnv) :-
    NoExclusion = b(truth,pred,[]),
    (   AmountPair = (PAmount,NAmount),
        get_amount_of_states_for_predicate(Env, PAmount, NoExclusion, MachineVars, PredicateAst,
                                           [], PositiveStates, EnvAfterPos),
        get_amount_of_states_for_predicate(EnvAfterPos, NAmount, NoExclusion, MachineVars,
                                           b(negation(PredicateAst),pred,[]), [], NegativeStates, EnvAfterNeg),
        length(PositiveStates, PosCount),
        length(NegativeStates, NegCount),
        Total is PosCount + NegCount,
        min_amount_of_examples(MinAmountOfExamples),
        Total >= MinAmountOfExamples
    ->  Records = [(PositiveStates,NegativeStates,UsedComponents)|MoreRecords],
        get_augmented_states_for_predicate_rand(EnvAfterNeg, Pairs, MachineVars, UsedComponents,
                                                PredicateAst, MoreRecords, NewEnv)
    ;   get_augmented_states_for_predicate_rand(Env, Pairs, MachineVars, UsedComponents,
                                                PredicateAst, Records, NewEnv)
    ).
116
%% remove_foolish_sub_predicates(+AllUsedIds, +TPredicateAst, -PredicateAst).
%
% Remove sub-predicates of conjunctions and disjunctions that do not refer to
% any identifier within a machine (AllUsedIds).
% For instance, 'x: NAT & x < 10 & NAT \/ NAT1 = NAT' will be reduced to
% 'x: NAT & x < 10' with x being a machine variable.
% Fails if the (remaining) predicate refers to none of AllUsedIds.
% TODO: remove typing predicates ???
remove_foolish_sub_predicates(AllUsedIds, Ast, NewAst) :-
    Ast = b(Node,_,_),
    Node =.. [Functor, Lhs, Rhs],
    % Test the functor, not the whole node: a compound node such as
    % conjunct(L,R) is never == to the atom conjunct.
    splittable_connective(Functor),
    !,
    remove_foolish_sub_predicates_binary(AllUsedIds, Ast, Functor, Lhs, Rhs, NewAst).
% fail if sub-predicate does not use any ids from AllUsedIds
remove_foolish_sub_predicates(AllUsedIds, Ast, Ast) :-
    find_identifier_uses(Ast, [], UsedIds),
    intersect(AllUsedIds, UsedIds),
    !.

%% splittable_connective(?Functor).
%
% Binary predicate connectives whose operands are pruned individually.
splittable_connective(conjunct).
splittable_connective(disjunct).

%% remove_foolish_sub_predicates_binary(+AllUsedIds, +Ast, +Functor, +Lhs, +Rhs, -NewAst).
%
% Prune each operand of the connective Functor exactly once:
%   - if both operands survive, the node is rebuilt with Ast's type and info;
%   - if only one survives, it replaces the whole node;
%   - if neither survives, this fails.
remove_foolish_sub_predicates_binary(AllUsedIds, b(_,Type,Info), Functor, Lhs, Rhs, NewAst) :-
    prune_operand(AllUsedIds, Lhs, LhsResult),
    prune_operand(AllUsedIds, Rhs, RhsResult),
    combine_pruned_operands(LhsResult, RhsResult, Functor, Type, Info, NewAst).

%% prune_operand(+AllUsedIds, +Ast, -Result).
%
% Result is kept(NewAst) if Ast (after pruning) uses an id from AllUsedIds,
% otherwise removed.
prune_operand(AllUsedIds, Ast, kept(NewAst)) :-
    remove_foolish_sub_predicates(AllUsedIds, Ast, NewAst),
    !.
prune_operand(_, _, removed).

%% combine_pruned_operands(+LhsResult, +RhsResult, +Functor, +Type, +Info, -NewAst).
%
% Rebuild the connective from the surviving operands.
% Fails if both sides do not use any ids from AllUsedIds.
combine_pruned_operands(kept(NLhs), kept(NRhs), Functor, Type, Info, b(NNode,Type,Info)) :-
    !,
    NNode =.. [Functor, NLhs, NRhs].
combine_pruned_operands(removed, kept(NRhs), _, _, _, NRhs) :-
    !.
combine_pruned_operands(kept(NLhs), removed, _, _, _, NLhs).
145
%% generate_synthesis_data_from_predicate(+MachinePath, +RawPredicate, -AugmentedSetOfData).
%
% Generate positive and negative examples for the pretty-printed predicate
% RawPredicate (an atom) over the B or Event-B machine at MachinePath. Also
% extract the ground truth B components as used by the program synthesis tool.
%
% The machine is loaded from MachinePath unless it already is the precompiled
% main machine.
%
% AugmentedSetOfData is a list of triples
% (PositiveExamples, NegativeExamples, GroundTruth), one per data augmentation.
% Each example is a list of triples (MachineVar, Type, Value), where Value is
% the pretty-printed value of the machine variable.
%
% Uses the defaults augment_records/1 and solver_timeout_ms/1; see
% generate_synthesis_data_from_predicate/5 for when the predicate fails.
generate_synthesis_data_from_predicate(MachinePath, RawPredicate, AugmentedSetOfData) :-
    augment_records(AR),
    solver_timeout_ms(SolverTimeoutMs),
    generate_synthesis_data_from_predicate(MachinePath, AR, SolverTimeoutMs, RawPredicate, AugmentedSetOfData).
158
%% load_b_machine_if_unloaded(+MachinePath).
%
% Ensure the machine at MachinePath is loaded and precompiled. Loading is
% skipped only if the precompiled main machine's file name is MachinePath,
% or ends with '/' followed by MachinePath. A relative path therefore matches
% the loaded absolute path, but 'data.mch' no longer matches '/x/mydata.mch'.
% In any other case the machine is (re)loaded.
load_b_machine_if_unloaded(MachinePath) :-
    bmachine_is_precompiled,
    b_get_main_filename(LoadedMachinePath),
    same_machine_path(MachinePath, LoadedMachinePath),
    !.
load_b_machine_if_unloaded(MachinePath) :-
    b_load_machine_from_file(MachinePath),
    b_machine_precompile.

%% same_machine_path(+MachinePath, +LoadedMachinePath).
%
% True if LoadedMachinePath equals MachinePath, or ends with MachinePath at a
% path-component boundary ('/').
same_machine_path(MachinePath, MachinePath) :-
    !.
same_machine_path(MachinePath, LoadedMachinePath) :-
    atom_concat('/', MachinePath, Suffix),
    atom_concat(_, Suffix, LoadedMachinePath),
    !.
167
%% generate_synthesis_data_from_predicate(+MachinePath, +AugmentRecords, +SolverTimeoutMs,
%%                                        +RawPredicate, -AugmentedSetOfData).
%
% Like generate_synthesis_data_from_predicate/3, but with explicit settings:
%   - AugmentRecords is the amount of data augmentations (distinct random
%     positive/negative example-count pairs to try);
%   - SolverTimeoutMs is installed as the time_out preference while the data
%     is generated.
%
% Fails if current_machine_uses_records/0 succeeds for the loaded machine, or
% if no sub-predicate of RawPredicate refers to an identifier of the machine.
% Parsing or type checking problems make it fail or raise, depending on
% parse_formula/2 and type_check_in_machine_context/2.
%
% The preferences changed by set_desired_preferences/4 are restored
% afterwards, also when generation fails or raises an exception.
generate_synthesis_data_from_predicate(MachinePath, AugmentRecords, SolverTimeoutMs, RawPredicate, AugmentedSetOfData) :-
    load_b_machine_if_unloaded(MachinePath),
    \+ current_machine_uses_records,
    set_desired_preferences(SolverTimeoutMs, OldKodkodPref, OldTimeOutPref, OldRandPref),
    % call_cleanup/2 restores the caller's global preferences on success,
    % failure and exceptions alike; once/1 makes the goal deterministic so
    % the cleanup runs immediately on success.
    call_cleanup(once(generate_synthesis_data_with_preferences(AugmentRecords, RawPredicate, AugmentedSetOfData)),
                 reset_old_preferences(OldKodkodPref, OldTimeOutPref, OldRandPref)).

%% generate_synthesis_data_with_preferences(+AugmentRecords, +RawPredicate, -AugmentedSetOfData).
%
% Core of generate_synthesis_data_from_predicate/5, run while the solver
% preferences are set:
%   - parse and type check RawPredicate in the context of the loaded machine;
%   - drop sub-predicates not referring to machine identifiers;
%   - conjoin the result with the machine invariant;
%   - generate AugmentRecords augmented example records for it.
generate_synthesis_data_with_preferences(AugmentRecords, RawPredicate, AugmentedSetOfData) :-
    atom_codes(RawPredicate, RawPredicateCodes),
    parse_formula(RawPredicateCodes, ParsedPredicate),
    type_check_in_machine_context([ParsedPredicate], TParsedPredicates),
    TParsedPredicates = [TPredicateAst],
    b_get_machine_variables(MachineVars),
    b_get_all_used_identifiers(AllUsedIds),
    remove_foolish_sub_predicates(AllUsedIds, TPredicateAst, PredicateAst),
    get_library_components_from_pred_or_expr(PredicateAst, UsedComponents),
    get_normalized_id_name_mapping(NormalizedSets, NormalizedIds, NOperationNames),
    Env = [[], NormalizedSets, NormalizedIds, NOperationNames],
    b_get_typed_invariant_from_machine(Invariant),
    PredWithInv = b(conjunct(Invariant, PredicateAst),pred,[]),
    get_augmented_states_for_predicate(Env, AugmentRecords, MachineVars, UsedComponents, PredWithInv, AugmentedSetOfData, _).
189
%% reset_old_preferences(+OldKodkodPref, +OldTimeOutPref, +OldRandPref).
%
% Restore the preference values previously returned by
% set_desired_preferences/4, in the order try_kodkod_on_load, time_out,
% randomise_enumeration_order.
reset_old_preferences(OldKodkodPref, OldTimeOutPref, OldRandPref) :-
    restore_preferences([try_kodkod_on_load-OldKodkodPref,
                         time_out-OldTimeOutPref,
                         randomise_enumeration_order-OldRandPref]).

%% restore_preferences(+PrefValuePairs).
%
% Set each Pref-Value pair, front to back.
restore_preferences([]).
restore_preferences([Pref-Value|Rest]) :-
    set_preference(Pref, Value),
    restore_preferences(Rest).
196
%% set_desired_preferences(+Timeout, -OldKodkodPref, -OldTimeOutPref, -OldRandPref).
%
% Configure the solver preferences for example generation and return the
% previous values, so that reset_old_preferences/3 can restore them:
%   - try_kodkod_on_load is set to false;
%   - randomise_enumeration_order is set to true (presumably so that
%     solutions vary between runs);
%   - time_out is set to Timeout (solver_timeout_ms/1 by default when called
%     via generate_synthesis_data_from_predicate/3).
set_desired_preferences(Timeout, OldKodkodPref, OldTimeOutPref, OldRandPref) :-
    get_preference(try_kodkod_on_load, OldKodkodPref),
    set_preference(try_kodkod_on_load, false),
    get_preference(randomise_enumeration_order, OldRandPref),
    set_preference(randomise_enumeration_order, true),
    get_preference(time_out, OldTimeOutPref),
    set_preference(time_out, Timeout).