Skip to content

Commit 45dc6e5

Browse files
committed
WIP : embedded query into mutate_deep_narrow_path,
added function variable_substitution_deep_narrow_mut_query
1 parent 990a27f commit 45dc6e5

File tree

4 files changed

+132
-23
lines changed

4 files changed

+132
-23
lines changed

config/defaults.py

+4
Original file line numberDiff line numberDiff line change
@@ -82,5 +82,9 @@
8282
MUTPB_DN_MIN_LEN = 2 # minimum length of the deep and narrow paths
8383
MUTPB_DN_MAX_LEN = 10 # absolute max of path length if not stopped by term_pb
8484
MUTPB_DN_TERM_PB = 0.3 # prob to terminate node expansion each step > min_len
85+
MUTPB_DN_FILTER_NODE_COUNT = 10
86+
MUTPB_DN_FILTER_EDGE_COUNT = 1
87+
MUTPB_DN_QUERY_LIMIT = 32
8588
# for import in helpers and __init__
89+
8690
__all__ = [_v for _v in globals().keys() if _v.isupper()]

gp_learner.py

+72-7
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@
4141
from cluster import expected_precision_loss_by_query_reduction
4242
from cluster import select_best_variant
4343
import config
44-
from gp_query import ask_multi_query
44+
from gp_query import ask_multi_query, \
45+
variable_substitution_deep_narrow_mut_query
4546
from gp_query import calibrate_query_timeout
4647
from gp_query import combined_ask_count_multi_query
4748
from gp_query import predict_query
@@ -420,7 +421,7 @@ def _mutate_expand_node_helper(node, pb_en_out_link=config.MUTPB_EN_OUT_LINK):
420421
new_triple = (node, var_edge, var_node)
421422
else:
422423
new_triple = (var_node, var_edge, node)
423-
return new_triple, var_node
424+
return new_triple, var_node, var_edge
424425

425426

426427
def mutate_expand_node(child, node=None):
@@ -433,11 +434,10 @@ def mutate_expand_node(child, node=None):
433434

434435

435436
def mutate_deep_narrow_path(
436-
child,
437+
child, sparql, timeout, gtp_scores,
437438
min_len=config.MUTPB_DN_MIN_LEN,
438439
max_len=config.MUTPB_DN_MAX_LEN,
439440
term_pb=config.MUTPB_DN_TERM_PB,
440-
pb_en_out_link=config.MUTPB_EN_OUT_LINK,
441441
):
442442
assert isinstance(child, GraphPattern)
443443
nodes = list(child.nodes)
@@ -451,15 +451,76 @@ def mutate_deep_narrow_path(
451451
if hop >= max_len:
452452
break
453453
hop += 1
454-
new_triple, var_node = _mutate_expand_node_helper(start_node)
455-
gp += [new_triple]
456-
start_node = var_node
454+
new_triple, var_node, var_edge = _mutate_expand_node_helper(start_node)
455+
test_gp = gp + [new_triple]
456+
test_gp, fixed = _mutate_deep_narrow_path_helper(
457+
sparql, timeout, gtp_scores, test_gp, var_edge, var_node)
458+
if fixed == 'Y':
459+
start_node = var_node
460+
gp = test_gp
457461

458462
# TODO: insert connection to a target node
459463
# TODO: fix edge or node ( to_count_var_over_values_query)
460464
return gp
461465

462466

467+
def _mutate_deep_narrow_path_helper(
468+
sparql,
469+
timeout,
470+
gtp_scores,
471+
child,
472+
edge_var,
473+
node_var,
474+
gtp_sample_n=config.MUTPB_FV_RGTP_SAMPLE_N,
475+
limit_res=config.MUTPB_DN_QUERY_LIMIT,
476+
sample_n=config.MUTPB_FV_SAMPLE_MAXN,
477+
):
478+
assert isinstance(child, GraphPattern)
479+
assert isinstance(gtp_scores, GTPScores)
480+
481+
# The further we get, the less gtps are remaining. Sampling too many (all)
482+
# of them might hurt as common substitutions (> limit ones) which are dead
483+
# ends could cover less common ones that could actually help
484+
gtp_sample_n = min(gtp_sample_n, int(gtp_scores.remaining_gain))
485+
gtp_sample_n = random.randint(1, gtp_sample_n)
486+
487+
ground_truth_pairs = gtp_scores.remaining_gain_sample_gtps(
488+
n=gtp_sample_n)
489+
t, substitution_counts = variable_substitution_deep_narrow_mut_query(
490+
sparql, timeout, child, edge_var, node_var, ground_truth_pairs,
491+
limit_res)
492+
if not substitution_counts:
493+
# the current pattern is unfit, as we can't find anything fulfilling it
494+
logger.debug("tried to fix a var %s without result:\n%s"
495+
"seems as if the pattern can't be fulfilled!",
496+
edge_var, child.to_sparql_select_query())
497+
fixed = 'N'
498+
return [child], fixed
499+
mutate_fix_var_filter(substitution_counts)
500+
if not substitution_counts:
501+
# could have happened that we removed the only possible substitution
502+
fixed = 'N'
503+
return [child], fixed
504+
# randomly pick n of the substitutions with a prob ~ to their counts
505+
items, counts = zip(*substitution_counts.most_common())
506+
substs = sample_from_list(items, counts, sample_n)
507+
logger.info(
508+
'fixed variable %s in %sto:\n %s\n<%d out of:\n%s\n',
509+
edge_var.n3(),
510+
child,
511+
'\n '.join([subst.n3() for subst in substs]),
512+
sample_n,
513+
'\n'.join([' %d: %s' % (c, v.n3())
514+
for v, c in substitution_counts.most_common()]),
515+
)
516+
fixed = 'Y'
517+
res = [
518+
GraphPattern(child, mapping={edge_var: subst})
519+
for subst in substs
520+
]
521+
return res, fixed
522+
523+
463524
def mutate_add_edge(child):
464525
# TODO: can maybe be improved by sparqling
465526
nodes = list(child.nodes)
@@ -682,6 +743,7 @@ def mutate(
682743
pb_dt=config.MUTPB_DT,
683744
pb_en=config.MUTPB_EN,
684745
pb_fv=config.MUTPB_FV,
746+
pb_dn=config.MUTPB_DN,
685747
pb_id=config.MUTPB_ID,
686748
pb_iv=config.MUTPB_IV,
687749
pb_mv=config.MUTPB_MV,
@@ -721,6 +783,9 @@ def mutate(
721783
if random.random() < pb_sp:
722784
child = mutate_simplify_pattern(child)
723785

786+
if random.random() < pb_dn:
787+
child = mutate_deep_narrow_path(child, sparql, timeout, gtp_scores)
788+
724789
if random.random() < pb_fv:
725790
child = canonicalize(child)
726791
children = mutate_fix_var(sparql, timeout, gtp_scores, child)

gp_query.py

+41
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,21 @@ def variable_substitution_query(
426426
)
427427

428428

429+
def variable_substitution_deep_narrow_mut_query(
430+
sparql, timeout, graph_pattern, edge_var, node_var,
431+
source_target_pairs, limit_res, batch_size=config.BATCH_SIZE):
432+
_vars, _values, _ret_val_mapping = _get_vars_values_mapping(
433+
graph_pattern, source_target_pairs)
434+
_edge_var_node_var_and_vars = (edge_var, node_var, _vars)
435+
return _multi_query(
436+
sparql, timeout, graph_pattern, source_target_pairs, batch_size,
437+
_edge_var_node_var_and_vars, _values, _ret_val_mapping,
438+
_var_subst_res_init, _var_subst_dnp_chunk_q,
439+
_var_subst_dnp_chunk_result_ext, limit=limit_res,
440+
# non standard, passed via **kwds, see handling below
441+
)
442+
443+
429444
# noinspection PyUnusedLocal
430445
def _var_subst_res_init(_, **kwds):
431446
return Counter()
@@ -440,6 +455,17 @@ def _var_subst_chunk_q(gp, _sel_var_and_vars, values_chunk, limit):
440455
limit=limit)
441456

442457

458+
def _var_subst_dnp_chunk_q(gp, _edge_var_node_var_and_vars,
459+
values_chunk, limit):
460+
edge_var, node_var, _vars = _edge_var_node_var_and_vars
461+
return gp.to_find_edge_var_for_narrow_path_query(
462+
edge_var=edge_var,
463+
node_var=node_var,
464+
vars_=_vars,
465+
values={_vars: values_chunk},
466+
limit_res=limit)
467+
468+
443469
# noinspection PyUnusedLocal
444470
def _var_subst_chunk_result_ext(q_res, _sel_var_and_vars, _, **kwds):
445471
var, _vars = _sel_var_and_vars
@@ -456,6 +482,21 @@ def _var_subst_chunk_result_ext(q_res, _sel_var_and_vars, _, **kwds):
456482
return chunk_res
457483

458484

485+
def _var_subst_dnp_chunk_result_ext(q_res, _edge_var_node_var_and_vars, _, **kwds):
486+
edge_var, node_var, _vars = _edge_var_node_var_and_vars
487+
chunk_res = Counter()
488+
res_rows_path = ['results', 'bindings']
489+
bindings = sparql_json_result_bindings_to_rdflib(
490+
get_path(q_res, res_rows_path, default=[])
491+
)
492+
493+
for row in bindings:
494+
row_res = get_path(row, [edge_var])
495+
count_res = int(get_path(row, [COUNT_VAR], '0'))
496+
chunk_res[row_res] += count_res
497+
return chunk_res
498+
499+
459500
def _var_subst_res_update(res, update, **_):
460501
res += update
461502

graph_pattern.py

+15-16
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import six
3232

3333
from utils import URIShortener
34+
import config
3435

3536
logger = logging.getLogger(__name__)
3637

@@ -41,10 +42,12 @@
4142
TARGET_VAR = Variable('target')
4243
ASK_VAR = Variable('ask')
4344
COUNT_VAR = Variable('count')
44-
EDGE_VAR_COUNT = Variable('edge_count_var')
45-
NODE_VAR_COUNT = Variable('node_count_var')
46-
MAX_NODE_COUNT = Variable('maximum node count')
47-
PRIO_VAR = Variable('priority')
45+
EDGE_VAR_COUNT = Variable('edge_var_count')
46+
NODE_VAR_COUNT = Variable('node_var_count')
47+
MAX_NODE_COUNT = Variable('max_node_count')
48+
PRIO_VAR = Variable('prio_var')
49+
50+
4851
def gen_random_var():
4952
return Variable(RANDOM_VAR_PREFIX + ''.join(
5053
random.choice(string.ascii_letters + string.digits)
@@ -711,9 +714,10 @@ def to_count_var_over_values_query(self, var, vars_, values, limit):
711714
res += 'LIMIT %d\n' % limit
712715
return self._sparql_prefix(res)
713716

714-
def to_find_edge_var_for_narrow_path_query(self, edge_var, node_var,
715-
vars_, filter_node_count,
716-
filter_edge_count, limit_res):
717+
def to_find_edge_var_for_narrow_path_query\
718+
(self, edge_var, node_var, vars_, values, limit_res,
719+
filter_node_count=config.MUTPB_DN_FILTER_NODE_COUNT,
720+
filter_edge_count=config.MUTPB_DN_FILTER_EDGE_COUNT):
717721
"""Counts possible substitutions for edge_var to get a narrow path
718722
719723
Meant to perform a query like this:
@@ -759,21 +763,16 @@ def to_find_edge_var_for_narrow_path_query(self, edge_var, node_var,
759763

760764
res = 'SELECT * WHERE {\n'
761765
res += ' {\n'\
762-
' SELECT %s (COUNT(*) as %s) (Max(%s) AS %s) ' \
763-
' (COUNT(*)/AVG(%s) AS %s) WHERE {\n' % (
766+
' SELECT %s (COUNT(*) AS %s) (MAX(%s) AS %s) ' \
767+
'(COUNT(*)/AVG(%s) AS %s) WHERE {\n' % (
764768
edge_var.n3(), EDGE_VAR_COUNT.n3(),
765769
NODE_VAR_COUNT.n3(), MAX_NODE_COUNT.n3(),
766770
NODE_VAR_COUNT.n3(), PRIO_VAR.n3())
767771
res += ' SELECT DISTINCT %s %s (COUNT(%s) AS %s) WHERE {\n' % (
768772
' '.join([v.n3() for v in vars_]),
769773
edge_var.n3(), node_var.n3(), NODE_VAR_COUNT.n3())
770-
# res += self._sparql_values_part(values)
771-
res += 'VALUES(%s) {\n' \
772-
'(dbr: Adolescence dbr: Youth)' \
773-
'(dbr:Adult dbr:Child)' \
774-
'(dbr:Angel dbr:Heaven)' \
775-
'(dbr:Arithmetic dbr:Mathematics)' \
776-
'}\n' % (' '.join([v.n3() for v in vars_]))
774+
res += self._sparql_values_part(values)
775+
777776
# triples part
778777
tres = []
779778
for s, p, o in self:

0 commit comments

Comments
 (0)