query.py 8.22 KB
Newer Older
1
import itertools
import json
import logging

from ftfy.fixes import remove_control_chars

from conceptnet5.edges import transform_for_linked_data
from .connection import get_db_connection
6

7
# Names of the query criteria that refer to nodes.
# NOTE(review): not referenced in the visible part of this module; presumably
# used by callers to decide which criteria are matched by URI prefix -- confirm.
NODE_PREFIX_CRITERIA = {'node', 'other', 'start', 'end'}

# Cache of generated list queries, filled in by make_list_query(). Keys are
# sorted tuples of criterion names; values are SQL strings.
LIST_QUERIES = {}
# NOTE(review): not referenced in the visible part of this module.
FEATURE_QUERIES = {}

# All query templates below take their parameters in the `%(name)s` style,
# which is what the cursor's execute(query, params) calls in this module use.
RANDOM_QUERY = "SELECT uri, data FROM edges TABLESAMPLE SYSTEM(0.01) ORDER BY random() LIMIT %(limit)s"
# The prefix placeholder used to be written as `:prefix`, which does not match
# the `%(name)s` style of every other parameter and so could never be bound.
RANDOM_NODES_QUERY = "SELECT * FROM nodes TABLESAMPLE SYSTEM(1) WHERE uri LIKE %(prefix)s ORDER BY random() LIMIT %(limit)s"
DATASET_QUERY = "SELECT uri, data FROM edges TABLESAMPLE SYSTEM(0.01) WHERE data->'dataset' = %(dataset)s ORDER BY weight DESC OFFSET %(offset)s LIMIT %(limit)s"

# Node URIs for which AssertionFinder.query() switches from an exact match on
# the URI to a `LIKE '<uri>%%'` filter applied to the already-matched edges.
TOO_BIG_PREFIXES = ['/c/en', '/c/fr', '/c/es', '/c/de', '/c/ja', '/c/zh', '/c/pt', '/c/la', '/c/it', '/c/ru', '/c/fi']

# Given a node URI, select (direction, relation URI, edge data) rows from the
# ranked_features table, keeping only rows with rank <= %(limit)s. The rows
# come back ordered by direction and relation URI, which is what lets
# AssertionFinder.lookup_grouped_by_feature() group them with
# itertools.groupby.
NODE_TO_FEATURE_QUERY = """
WITH node_ids AS (
    SELECT p.node_id FROM nodes n, node_prefixes p
    WHERE p.prefix_id=n.id AND n.uri=%(node)s
    LIMIT 10
)
SELECT rf.direction, r.uri, e.data
FROM ranked_features rf, edges e, relations r
WHERE rf.node_id IN (SELECT node_id FROM node_ids)
AND rf.edge_id = e.id
AND rf.rel_id = r.id
AND rank <= %(limit)s
ORDER BY direction, uri, rank;
"""
# NOTE(review): not referenced in the visible part of this module.
MAX_GROUP_SIZE = 20
33
34


35
def make_list_query(criteria):
    """
    Build the SQL query that lists the edges matching a set of criteria.

    `criteria` is any collection of criterion names (a dict keyed by name
    works); only the names are inspected here, never the values. The names
    this function reacts to are 'node', 'other', 'start', 'end', 'rel',
    'source', and the 'filter_node' / 'filter_other' / 'filter_start' /
    'filter_end' variants.

    The returned string contains a `%(name)s` placeholder for each criterion
    it uses, plus `%(offset)s` and `%(limit)s`; the caller supplies the
    values when executing the query.

    Queries are cached in LIST_QUERIES under the sorted tuple of criterion
    names, so each distinct combination is only assembled once.
    """
    crit_tuple = tuple(sorted(criteria))
    if crit_tuple in LIST_QUERIES:
        return LIST_QUERIES[crit_tuple]

    parts = ["WITH matched_edges AS ("]

    # A 'node' criterion may match either end of an edge, so the query is the
    # UNION ALL of a forward pass (node = start side) and a backward pass
    # (node = end side). Every other combination needs only the forward pass.
    if 'node' in criteria:
        piece_directions = [1, -1]
    else:
        piece_directions = [1]

    for direction in piece_directions:
        if direction == -1:
            parts.append("UNION ALL")

        parts.append("SELECT e.uri, e.weight, e.data, np1.uri as starturi, np2.uri as enduri")
        if direction == 1:
            parts.append(", np1.uri as node, np2.uri as other")
        else:
            parts.append(", np2.uri as node, np1.uri as other")

        parts.append("""
            FROM relations r, edges e, nodes n1, nodes n2,
                 node_prefixes p1, node_prefixes p2, nodes np1, nodes np2
        """)
        if 'source' in criteria:
            # Continues the FROM list started just above.
            parts.append(", edge_sources es, sources s")

        parts.append("""
            WHERE e.relation_id=r.id
            AND e.start_id=n1.id
            AND e.end_id=n2.id
            AND p1.prefix_id=np1.id
            AND p1.node_id=n1.id
            AND p2.prefix_id=np2.id
            AND p2.node_id=n2.id
        """)
        if 'source' in criteria:
            parts.append("AND s.uri=%(source)s AND es.source_id=s.id AND es.edge_id=e.id")

        # When a 'filter_X' criterion is present, X is not matched exactly
        # here; it is applied as a LIKE filter on the matched edges below.
        if 'node' in criteria and 'filter_node' not in criteria:
            if direction == 1:
                parts.append("AND np1.uri = %(node)s")
            else:
                parts.append("AND np2.uri = %(node)s")
        if 'other' in criteria and 'filter_other' not in criteria:
            if direction == 1:
                parts.append("AND np2.uri = %(other)s")
            else:
                parts.append("AND np1.uri = %(other)s")
        if 'rel' in criteria:
            parts.append("AND r.uri = %(rel)s")
        if 'start' in criteria and 'filter_start' not in criteria:
            parts.append("AND np1.uri = %(start)s")
        if 'end' in criteria and 'filter_end' not in criteria:
            parts.append("AND np2.uri = %(end)s")

    # Cap the size of the intermediate result before de-duplicating it.
    parts.append("LIMIT 10000")
    parts.append(")")

    parts.append("SELECT DISTINCT ON (weight, uri) uri, data FROM matched_edges")
    more_clauses = []
    if 'filter_node' in criteria:
        more_clauses.append('node LIKE %(filter_node)s')
    if 'filter_other' in criteria:
        more_clauses.append('other LIKE %(filter_other)s')
    if 'filter_start' in criteria:
        more_clauses.append('starturi LIKE %(filter_start)s')
    if 'filter_end' in criteria:
        more_clauses.append('enduri LIKE %(filter_end)s')
    if more_clauses:
        parts.append("WHERE " + " AND ".join(more_clauses))

    parts.append("""
        ORDER BY weight DESC, uri
        OFFSET %(offset)s LIMIT %(limit)s
    """)

    query = '\n'.join(parts)
    LIST_QUERIES[crit_tuple] = query
    return query


class AssertionFinder(object):
    """
    Looks up edges (assertions) in the ConceptNet database.

    The database connection is opened lazily, the first time a method needs
    a cursor, using the `dbname` given to the constructor.
    """

    def __init__(self, dbname=None):
        """
        Remember which database to use; no connection is opened yet.

        `dbname` is passed as-is to get_db_connection() when the first query
        runs.
        """
        self.connection = None
        self.dbname = dbname

    def _get_cursor(self):
        """
        Return a new cursor, opening the database connection on first use.
        """
        if self.connection is None:
            self.connection = get_db_connection(self.dbname)
        return self.connection.cursor()

    def lookup(self, uri, limit=100, offset=0):
        """
        Find the edges associated with a URI, dispatching on its prefix:

        - '/c/...' or 'http...': edges with this node at either end
        - '/r/...': edges with this relation
        - '/s/...': edges with this source
        - '/a/...': the assertion with this URI (`limit` and `offset` are
          not used)
        - '/d/...': a sample of edges from this dataset

        Returns a list of edges, each passed through
        transform_for_linked_data(). Raises ValueError for any other kind of
        URI.
        """
        if uri.startswith(('/c/', 'http')):
            criteria = {'node': uri}
        elif uri.startswith('/r/'):
            criteria = {'rel': uri}
        elif uri.startswith('/s/'):
            criteria = {'source': uri}
        elif uri.startswith('/a/'):
            return self.lookup_assertion(uri)
        elif uri.startswith('/d/'):
            return self.sample_dataset(uri, limit, offset)
        else:
            # The message used to be raised with its %r placeholder unfilled.
            raise ValueError(
                "%r isn't a ConceptNet URI that can be looked up" % (uri,)
            )
        return self.query(criteria, limit, offset)

    def lookup_grouped_by_feature(self, uri, limit=20):
        """
        Find the edges involving the node `uri`, grouped by feature.

        Returns a dict whose keys are (direction, relation URI) tuples and
        whose values are lists of edges, each passed through
        transform_for_linked_data(). Every edge gets an extra 'other' key
        naming the node at the far end of the edge. `limit` is the highest
        rank of edge to include, as stored in the ranked_features table.
        """
        uri = remove_control_chars(uri)

        def extract_feature(row):
            # (direction, relation URI) identifies the feature of a row.
            return tuple(row[:2])

        def feature_data(row):
            direction, _, data = row

            # Hacky way to figure out what the 'other' node is, the one that
            # (in most cases) didn't match the URI. If both start with our
            # given URI, take the longer one, which is either a more specific
            # sense or a different, longer word.
            shorter, longer = sorted([data['start'], data['end']], key=len)
            if shorter.startswith(uri):
                data['other'] = longer
            else:
                data['other'] = shorter
            return data

        cursor = self._get_cursor()
        cursor.execute(NODE_TO_FEATURE_QUERY, {'node': uri, 'limit': limit})

        # groupby only merges adjacent rows; this relies on the query
        # returning its rows ordered by direction and relation URI.
        results = {}
        for feature, rows in itertools.groupby(cursor.fetchall(), extract_feature):
            results[feature] = [
                transform_for_linked_data(feature_data(row)) for row in rows
            ]
        return results

    def lookup_assertion(self, uri):
        """
        Return the edges whose URI is exactly `uri`, as a list of edges
        passed through transform_for_linked_data().
        """
        # Sanitize URIs to remove control characters such as \x00. The postgres driver would
        # remove \x00 anyway, but this avoids reporting a server error when that happens.
        uri = remove_control_chars(uri)

        cursor = self._get_cursor()
        cursor.execute("SELECT data FROM edges WHERE uri=%(uri)s", {'uri': uri})
        return [transform_for_linked_data(data) for (data,) in cursor.fetchall()]

    def sample_dataset(self, uri, limit=50, offset=0):
        """
        Return a sample of edges from the dataset `uri`, heaviest first.

        The edges come from a random sample of the edges table (see
        DATASET_QUERY), so repeated calls can return different results.
        """
        uri = remove_control_chars(uri)

        cursor = self._get_cursor()
        # The query compares against a JSON value, so the URI is sent as a
        # JSON-encoded string.
        dataset_json = json.dumps(uri)
        cursor.execute(
            DATASET_QUERY,
            {'dataset': dataset_json, 'limit': limit, 'offset': offset},
        )
        return [transform_for_linked_data(data) for _uri, data in cursor.fetchall()]

    def random_edges(self, limit=20):
        """
        Return up to `limit` randomly chosen edges, each passed through
        transform_for_linked_data().
        """
        cursor = self._get_cursor()
        cursor.execute(RANDOM_QUERY, {'limit': limit})
        return [transform_for_linked_data(data) for _uri, data in cursor.fetchall()]

    def query(self, criteria, limit=20, offset=0):
        """
        Return the edges matching `criteria`, a dict from criterion names
        (such as 'node', 'other', 'start', 'end', 'rel', 'source') to the
        URIs they must match.

        The caller's dict is not modified. Results are a list of edges passed
        through transform_for_linked_data(), paged by `offset` and `limit`.
        """
        criteria = criteria.copy()

        # Node URIs listed in TOO_BIG_PREFIXES are not matched exactly;
        # they are turned into LIKE patterns that filter the matched edges.
        for criterion in ['node', 'other', 'start', 'end']:
            if criterion in criteria and criteria[criterion] in TOO_BIG_PREFIXES:
                criteria['filter_' + criterion] = criteria[criterion] + '%'

        query_string = make_list_query(criteria)
        params = {
            key: remove_control_chars(value)
            for (key, value) in criteria.items()
        }
        params['limit'] = limit
        params['offset'] = offset

        cursor = self._get_cursor()
        # Kept as a debug-level log message instead of printing every query
        # to standard output.
        logging.getLogger(__name__).debug(
            "Executing list query: %s with params %r", query_string, params
        )
        cursor.execute(query_string, params)
        return [transform_for_linked_data(data) for _uri, data in cursor.fetchall()]