1 | :- module(predicate_data_generator, [generate_synthesis_data_from_predicate/3, | |
2 | generate_synthesis_data_from_predicate/5]). | |
3 | ||
4 | :- use_module(library(random)). | |
5 | :- use_module(library(sets), [intersect/2]). | |
6 | ||
7 | :- use_module(probsrc(bmachine), [b_load_machine_from_file/1, | |
8 | b_get_machine_variables/1, | |
9 | b_get_all_used_identifiers/1, | |
10 | b_get_main_filename/1, | |
11 | bmachine_is_precompiled/0, | |
12 | b_machine_precompile/0]). | |
13 | :- use_module(probsrc(parsercall), [parse_formula/2]). | |
14 | :- use_module(probsrc(translate), [translate_bexpression/2]). | |
15 | :- use_module(probsrc(solver_interface), [solve_predicate/3, type_check_in_machine_context/2]). | |
16 | :- use_module(probsrc(bsyntaxtree), [conjunct_predicates/2, get_texpr_info/2, find_identifier_uses/3]). | |
17 | :- use_module(probsrc('synthesis/deep_learning/ground_truth')). | |
18 | :- use_module(probsrc('synthesis/deep_learning/b_machine_identifier_normalization')). | |
19 | :- use_module(probsrc(preferences), [set_preference/2, | |
20 | get_preference/2]). | |
21 | ||
22 | :- use_module(synthesis(synthesis_util), [get_input_nodes_from_bindings/2, | |
23 | create_equality_nodes_from_example/2, | |
24 | b_get_typed_invariant_from_machine/1]). | |
25 | ||
% Lower bound used when drawing a random amount of positive/negative examples.
% Also the minimum total (positive + negative) a generated record must reach to be kept.
min_amount_of_examples(3).
% Largest amount of examples that may be requested per polarity.
max_amount_of_examples(positive, 8).
max_amount_of_examples(negative, 8).

% Default number of data augmentations used by generate_synthesis_data_from_predicate/3.
augment_records(5).
% Default value for the time_out preference (milliseconds, going by the name) used by generate_synthesis_data_from_predicate/3.
solver_timeout_ms(10000).
32 | ||
%% get_random_amount_of_examples(-Amounts).
%
% Amounts is a pair (Positive,Negative) of randomly drawn example counts.
% Both counts are drawn with min_amount_of_examples/1 as lower bound and the
% respective max_amount_of_examples/2 value plus one as upper bound.
get_random_amount_of_examples((PositiveCount,NegativeCount)) :-
    min_amount_of_examples(Lower),
    max_amount_of_examples(positive, PositiveMax),
    max_amount_of_examples(negative, NegativeMax),
    % The upper bound is incremented by one; presumably random/3 excludes its
    % upper bound so that the configured maximum itself can be drawn.
    PositiveUpper is PositiveMax+1,
    random(Lower, PositiveUpper, PositiveCount),
    NegativeUpper is NegativeMax+1,
    random(Lower, NegativeUpper, NegativeCount).
41 | ||
%% random_list_of_numbers(+Count, +Acc, -Pairs).
%
% Pairs extends Acc by Count pairwise distinct (Positive,Negative) amounts drawn
% by get_random_amount_of_examples/1. New pairs are pushed to the front of Acc.
% Count is clamped to the number of pairs that can still be drawn, since asking
% for more distinct pairs than exist would retry forever. The clamp treats every
% element of Acc as an already drawn pair. A Count that is not a positive
% integer yields Acc unchanged.
random_list_of_numbers(Count, Acc, Pairs) :-
    integer(Count),
    Count > 0,
    !,
    count_of_distinct_amount_pairs(Total),
    length(Acc, AlreadyDrawn),
    Drawable is max(0, min(Count, Total-AlreadyDrawn)),
    draw_distinct_amount_pairs(Drawable, Acc, Pairs).
random_list_of_numbers(_, Acc, Acc).

% Number of different (Positive,Negative) pairs permitted by the configured bounds.
count_of_distinct_amount_pairs(Total) :-
    min_amount_of_examples(Min),
    max_amount_of_examples(positive, PositiveMax),
    max_amount_of_examples(negative, NegativeMax),
    Total is max(0, PositiveMax-Min+1) * max(0, NegativeMax-Min+1).

% Draw random pairs until Remaining new ones have been added to Acc; duplicates are redrawn.
draw_distinct_amount_pairs(0, Acc, Acc) :-
    !.
draw_distinct_amount_pairs(Remaining, Acc, Pairs) :-
    get_random_amount_of_examples(Pair),
    (   \+ member(Pair, Acc)
    ->  Remaining1 is Remaining-1,
        draw_distinct_amount_pairs(Remaining1, [Pair|Acc], Pairs)
    ;   draw_distinct_amount_pairs(Remaining, Acc, Pairs)
    ).
52 | ||
%% exclude_solution(+State, +ExclusionPred, -NewExclusionPred).
%
% Extend ExclusionPred so that State is ruled out: the equality nodes created
% for State are conjoined and handed to truth_or_exclude/3.
% An empty State leaves ExclusionPred unchanged.
exclude_solution(State, ExclusionPred, NewExclusionPred) :-
    (   State = [],
        NewExclusionPred = ExclusionPred
    ->  true
    ;   create_equality_nodes_from_example(State, Equalities),
        conjunct_predicates(Equalities, StateAsPredicate),
        truth_or_exclude(StateAsPredicate, ExclusionPred, NewExclusionPred)
    ).
59 | ||
%% truth_or_exclude(+EqualityConj, +ExclusionPred, -NewExclusionPred).
%
% If EqualityConj is the plain truth node, ExclusionPred is returned unchanged.
% Otherwise NewExclusionPred is the conjunction of the negated EqualityConj and ExclusionPred.
truth_or_exclude(EqualityConj, ExclusionPred, NewExclusionPred) :-
    (   EqualityConj = b(truth,pred,[]),
        NewExclusionPred = ExclusionPred
    ->  true
    ;   Negated = b(negation(EqualityConj),pred,[]),
        NewExclusionPred = b(conjunct(Negated,ExclusionPred),pred,[])
    ).
63 | ||
%% filter_machine_var_states(+VarStates, +MachineVars, +Acc, -States).
%
% Keep those nodes of VarStates whose info field contains synthesis(machinevar, Name)
% with Name being the identifier of a node in MachineVars. Kept nodes are pushed
% onto Acc, so States lists them in reverse order of occurrence.
filter_machine_var_states([], _, Acc, Acc).
filter_machine_var_states([Candidate|Rest], MachineVars, Acc, States) :-
    (   get_texpr_info(Candidate, Info),
        member(synthesis(machinevar, Name), Info),
        member(b(identifier(Name), _, _), MachineVars)
    ->  Acc1 = [Candidate|Acc]
    ;   Acc1 = Acc
    ),
    filter_machine_var_states(Rest, MachineVars, Acc1, States).
73 | ||
%% map_translate_var_state(+Env, +Asts, -Triples, -NewEnv).
%
% Normalize the identifiers of each state node and turn it into a triple
% (VarName, Type, PrettyPrintedNode). VarName is read from the synthesis(machinevar, _)
% entry of the normalized node's info field. The normalization environment is
% threaded through the list; NewEnv is the environment after the last node.
map_translate_var_state(Env, [], [], Env).
map_translate_var_state(Env0, [StateAst|Rest], [(Name,Type,Pretty)|Triples], EnvN) :-
    normalize_ids_in_b_ast(Env0, StateAst, Normalized, Env1),
    Normalized = b(_, Type, Info),
    member(synthesis(machinevar, Name), Info),
    translate_bexpression(Normalized, Pretty),
    map_translate_var_state(Env1, Rest, Triples, EnvN).
81 | ||
%% get_amount_of_states_for_predicate(+Env, +Amount, +ExclusionPred, +MachineVars, +VPredicateAst, +Acc, -ListOfExamples, -NewEnv).
%
% Collect up to Amount states satisfying VPredicateAst. After each solution the
% found state is added to the exclusion predicate, which is conjoined with
% VPredicateAst for the next solver call. Only machine variables are kept in a
% state; each state is stored in its translated form in front of Acc.
% Collection stops early, returning what has been gathered so far, as soon as
% one step does not yield a usable state.
get_amount_of_states_for_predicate(Env, 0, _, _, _, Acc, Acc, Env) :-
    !.
get_amount_of_states_for_predicate(Env, Remaining, ExclusionPred, MachineVars, VPredicateAst, Acc, ListOfExamples, NewEnv) :-
    Remaining1 is Remaining-1,
    Constraint = b(conjunct(ExclusionPred,VPredicateAst),pred,[]),
    (   solve_predicate(Constraint, _, Solution),
        Solution = solution(Bindings),
        get_input_nodes_from_bindings(Bindings, AllNodes),
        filter_machine_var_states(AllNodes, MachineVars, [], State),
        State \== [],
        exclude_solution(State, ExclusionPred, ExclusionPred1),
        map_translate_var_state(Env, State, PrettyState, Env1)
    ->  get_amount_of_states_for_predicate(Env1, Remaining1, ExclusionPred1, MachineVars, VPredicateAst, [PrettyState|Acc], ListOfExamples, NewEnv)
    ;   % cancel if no solution found
        ListOfExamples = Acc,
        NewEnv = Env
    ).
97 | ||
%% get_augmented_states_for_predicate(+Env, +AR, +MachineVars, +UsedComponents, +PredicateAst, -AugmentedSetOfData, -NewEnv).
%
% Draw AR distinct random (Positive,Negative) example amounts and generate one
% data record for each of them.
get_augmented_states_for_predicate(Env, AR, MachineVars, UsedComponents, PredicateAst, AugmentedSetOfData, NewEnv) :-
    random_list_of_numbers(AR, [], ExampleAmounts),
    get_augmented_states_for_predicate_rand(Env, ExampleAmounts, MachineVars, UsedComponents, PredicateAst, AugmentedSetOfData, NewEnv).
101 | ||
%% get_augmented_states_for_predicate_rand(+Env, +Amounts, +MachineVars, +UsedComponents, +PredicateAst, -Data, -NewEnv).
%
% For each (Positive,Negative) pair in Amounts, collect that many states for
% PredicateAst and for its negation and emit a record
% (PositiveStates, NegativeStates, UsedComponents). A pair is skipped, and the
% environment changes made for it are dropped, if fewer than
% min_amount_of_examples/1 states were found in total.
get_augmented_states_for_predicate_rand(Env, [], _, _, _, [], Env).
get_augmented_states_for_predicate_rand(Env, [Amounts|Rest], MachineVars, UsedComponents, PredicateAst, Data, NewEnv) :-
    (   Data = [(PositiveStates,NegativeStates,UsedComponents)|DataRest],
        Amounts = (PositiveAmount,NegativeAmount),
        Truth = b(truth,pred,[]),
        Negated = b(negation(PredicateAst),pred,[]),
        get_amount_of_states_for_predicate(Env, PositiveAmount, Truth, MachineVars, PredicateAst, [], PositiveStates, Env1),
        get_amount_of_states_for_predicate(Env1, NegativeAmount, Truth, MachineVars, Negated, [], NegativeStates, Env2),
        length(PositiveStates, PositiveFound),
        length(NegativeStates, NegativeFound),
        min_amount_of_examples(Minimum),
        PositiveFound+NegativeFound >= Minimum
    ->  get_augmented_states_for_predicate_rand(Env2, Rest, MachineVars, UsedComponents, PredicateAst, DataRest, NewEnv)
    ;   % not enough examples for this pair: skip it
        get_augmented_states_for_predicate_rand(Env, Rest, MachineVars, UsedComponents, PredicateAst, Data, NewEnv)
    ).
116 | ||
%% remove_foolish_sub_predicates(+AllUsedIds, +TPredicateAst, -PredicateAst).
%
% Remove predicates that do not refer to any identifier within a machine (AllUsedIds).
% For instance, 'x: NAT & x < 10 & NAT \/ NAT1 = NAT' will be reduced to 'x: NAT & x < 10' with x being a machine variable.
% Conjunctions and disjunctions are split into their operands, which are
% filtered recursively; any other predicate is kept as a whole if it uses at
% least one identifier of AllUsedIds.
% Fails if nothing of the predicate remains.
% TODO: remove typing predicates ???
remove_foolish_sub_predicates(AllUsedIds, Ast, NewAst) :-
    Ast = b(Node,_,_),
    Node =.. [Functor, Lhs, Rhs],
    % The test has to look at the functor: Node itself is the compound term
    % conjunct(Lhs,Rhs) or disjunct(Lhs,Rhs) and is never identical to a bare atom.
    (Functor == conjunct ; Functor == disjunct),
    !,
    remove_foolish_sub_predicates_binary(AllUsedIds, Ast, Functor, Lhs, Rhs, NewAst).
% fail if sub-predicate does not use any ids from AllUsedIds
remove_foolish_sub_predicates(AllUsedIds, Ast, Ast) :-
    find_identifier_uses(Ast, [], UsedIds),
    intersect(AllUsedIds, UsedIds),
    !.
133 | ||
%% remove_foolish_sub_predicates_binary(+AllUsedIds, +Ast, +Functor, +Lhs, +Rhs, -NewAst).
%
% Filter both operands of the binary predicate Ast (with operator Functor).
% If both operands survive, the node is rebuilt with the original type and
% info; if only one survives, that operand replaces the whole node.
% Each operand is filtered exactly once, so the work stays linear in the size
% of the predicate instead of repeating the recursive calls per clause.
remove_foolish_sub_predicates_binary(AllUsedIds, Ast, Functor, Lhs, Rhs, NewAst) :-
    (   remove_foolish_sub_predicates(AllUsedIds, Lhs, NLhs)
    ->  (   remove_foolish_sub_predicates(AllUsedIds, Rhs, NRhs)
        ->  Ast = b(_,Type,Info),
            NNode =.. [Functor, NLhs, NRhs],
            NewAst = b(NNode,Type,Info)
        ;   % only the left-hand side refers to machine identifiers
            NewAst = NLhs
        )
    ;   % only the right-hand side can remain;
        % fail if both sides do not use any ids from AllUsedIds
        remove_foolish_sub_predicates(AllUsedIds, Rhs, NewAst)
    ).
145 | ||
%% generate_synthesis_data_from_predicate(+MachinePath, +RawPredicate, -GeneratedData).
%
% Generate positive and negative examples for a pretty-printed predicate and
% extract the ground truth B components as used by the program synthesis tool.
% Data is a list of triples (PositiveExamples, NegativeExamples, GroundTruth) considering
% data augmentation. Each example is a list of triples (MachineVar, Type, Value).
% The machine at MachinePath is loaded first unless it is already the loaded one.
% Uses the defaults augment_records/1 and solver_timeout_ms/1 and delegates to
% generate_synthesis_data_from_predicate/5.
generate_synthesis_data_from_predicate(MachinePath, RawPredicate, AugmentedSetOfData) :-
    solver_timeout_ms(DefaultTimeoutMs),
    augment_records(DefaultAugmentations),
    generate_synthesis_data_from_predicate(MachinePath, DefaultAugmentations, DefaultTimeoutMs, RawPredicate, AugmentedSetOfData).
158 | ||
%% load_b_machine_if_unloaded(+MachinePath).
%
% Load and precompile the machine at MachinePath, unless a precompiled machine
% is present whose main filename ends with MachinePath.
load_b_machine_if_unloaded(MachinePath) :-
    (   bmachine_is_precompiled,
        b_get_main_filename(CurrentMainFile),
        % suffix test: MachinePath may be given relative to the loaded file's path
        atom_concat(_, MachinePath, CurrentMainFile)
    ->  true
    ;   b_load_machine_from_file(MachinePath),
        b_machine_precompile
    ).
167 | ||
%% generate_synthesis_data_from_predicate(+MachinePath, +AugmentRecords, +SolverTimeoutMs, +RawPredicate, -GeneratedData).
%
% AugmentRecords is the amount of data augmentations.
% SolverTimeoutMs is installed as the time_out preference while data is generated.
% Fails if the loaded machine uses records, or if RawPredicate cannot be
% parsed, type checked or reduced to a predicate over machine identifiers.
% The preferences changed for the generation are restored in every case,
% i.e., on success, on failure and when an exception is raised.
generate_synthesis_data_from_predicate(MachinePath, AugmentRecords, SolverTimeoutMs, RawPredicate, AugmentedSetOfData) :-
    load_b_machine_if_unloaded(MachinePath),
    \+ current_machine_uses_records,
    set_desired_preferences(SolverTimeoutMs, OldKodkodPref, OldTimeOutPref, OldRandPref),
    % once/1 makes the goal determinate so that the cleanup runs right away
    % instead of being postponed until a choice point is removed.
    call_cleanup(
        once(generate_synthesis_data_with_preferences_set(AugmentRecords, RawPredicate, AugmentedSetOfData)),
        reset_old_preferences(OldKodkodPref, OldTimeOutPref, OldRandPref)).

% Parse and type check RawPredicate in the context of the loaded machine, drop
% sub-predicates without machine identifiers and generate the augmented data
% for the remaining predicate conjoined with the machine invariant.
generate_synthesis_data_with_preferences_set(AugmentRecords, RawPredicate, AugmentedSetOfData) :-
    atom_codes(RawPredicate, RawPredicateCodes),
    parse_formula(RawPredicateCodes, ParsedPredicate),
    type_check_in_machine_context([ParsedPredicate], TParsedPredicates),
    TParsedPredicates = [TPredicateAst],
    b_get_machine_variables(MachineVars),
    b_get_all_used_identifiers(AllUsedIds),
    remove_foolish_sub_predicates(AllUsedIds, TPredicateAst, PredicateAst),
    get_library_components_from_pred_or_expr(PredicateAst, UsedComponents),
    get_normalized_id_name_mapping(NormalizedSets, NormalizedIds, NOperationNames),
    Env = [[], NormalizedSets, NormalizedIds, NOperationNames],
    b_get_typed_invariant_from_machine(Invariant),
    PredWithInv = b(conjunct(Invariant, PredicateAst),pred,[]),
    get_augmented_states_for_predicate(Env, AugmentRecords, MachineVars, UsedComponents, PredWithInv, AugmentedSetOfData, _).
189 | ||
%% reset_old_preferences(+OldKodkodPref, +OldTimeOutPref, +OldRandPref).
%
% Write the given values back to the preferences try_kodkod_on_load, time_out
% and randomise_enumeration_order (in this order).
reset_old_preferences(KodkodOnLoad, TimeOut, RandomiseEnumeration) :-
    set_preference(try_kodkod_on_load, KodkodOnLoad),
    set_preference(time_out, TimeOut),
    set_preference(randomise_enumeration_order, RandomiseEnumeration).
196 | ||
%% set_desired_preferences(+Timeout, -OldKodkodPref, -OldTimeOutPref, -OldRandPref).
%
% Set desired preferences and return the old ones:
% try_kodkod_on_load becomes false, randomise_enumeration_order becomes true
% and time_out becomes Timeout.
set_desired_preferences(Timeout, OldKodkodPref, OldTimeOutPref, OldRandPref) :-
    exchange_preference_value(try_kodkod_on_load, false, OldKodkodPref),
    exchange_preference_value(randomise_enumeration_order, true, OldRandPref),
    exchange_preference_value(time_out, Timeout, OldTimeOutPref).

% Read the current value of Preference into OldValue, then set it to NewValue.
exchange_preference_value(Preference, NewValue, OldValue) :-
    get_preference(Preference, OldValue),
    set_preference(Preference, NewValue).