Skip to content

Commit f1dca6b

Browse files
committed
Normalize MCQ (set-input query) score to range [0-1]
1 parent cc7c566 commit f1dca6b

File tree

3 files changed

+49
-26
lines changed

cohd/cohd_trapi.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def operate(self):
6969
batch_size_limit = 100 # max length of any IDs list
7070
limit_max_results = 500
7171
json_inf_replacement = 999 # value to replace +/-Infinity with in JSON
72+
mcq_score_scaling = 0.75 # magic number to adjust normalized MCQ score
7273
supported_query_methods = ['relativeFrequency', 'obsExpRatio', 'chiSquare']
7374
supported_operation = 'lookup_and_score'
7475

cohd/cohd_trapi_15.py

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class CohdTrapi150(CohdTrapi):
4242
edge_types_negative = ['biolink:negatively_correlated_with']
4343
default_negative_predicate = edge_types_negative[0]
4444

45-
tool_version = f'{CohdTrapi._SERVICE_NAME} 6.5.3'
45+
tool_version = f'{CohdTrapi._SERVICE_NAME} 6.5.4'
4646
schema_version = '1.5.0'
4747
biolink_version = bm_version
4848

@@ -600,15 +600,15 @@ def _interpret_query(self):
600600
if self._concept_1_set_interpretation == 'BATCH':
601601
ids = list(set(concept_1_qnode['ids'])) # remove duplicate CURIEs
602602
elif self._concept_1_set_interpretation == 'MANY':
603-
member_ids = concept_1_qnode.get('member_ids')
604-
if not member_ids:
603+
self._mcq_member_ids = concept_1_qnode.get('member_ids')
604+
if not self._mcq_member_ids:
605605
# Missing required member_ids for MCQ
606606
self._valid_query = False
607607
description = 'set_interpretation: MANY but no member_ids'
608608
response = self._trapi_mini_response(TrapiStatusCode.MISSING_MEMBER_IDS, description)
609609
self._invalid_query_response = response, 200
610610
return self._valid_query, self._invalid_query_response
611-
ids = list(set(concept_1_qnode['member_ids'])) # remove duplicate CURIEs
611+
ids = list(self._mcq_member_ids) # remove duplicate CURIEs
612612

613613
# Get the MCQ set ID
614614
self._mcq_set_id = concept_1_qnode['ids'][0]
@@ -999,31 +999,36 @@ def operate_mcq(self):
999999
# categories (domains)
10001000
for domain_id, concept_class_id in self._domain_class_pairs:
10011001
new_results = query_cohd_mysql.query_trapi_mcq(concept_ids=self._concept_1_omop_ids,
1002-
dataset_id=self._dataset_id,
1003-
domain_id=domain_id,
1004-
concept_class_id=concept_class_id,
1005-
ln_ratio_sign=self._association_direction,
1006-
confidence=self._confidence_interval,
1007-
bypass=self._bypass_cache)
1002+
n_member_ids=len(self._mcq_member_ids),
1003+
score_scaling=CohdTrapi.mcq_score_scaling,
1004+
dataset_id=self._dataset_id,
1005+
domain_id=domain_id,
1006+
concept_class_id=concept_class_id,
1007+
ln_ratio_sign=self._association_direction,
1008+
confidence=self._confidence_interval,
1009+
bypass=self._bypass_cache)
10081010
new_set_results, new_single_results = new_results
10091011
if new_set_results:
10101012
set_results.extend(new_set_results)
10111013
single_results.update(new_single_results)
10121014
else:
10131015
# No category (domain) was specified for Node 2. Query the associations between Node 1 and all
10141016
# domains
1015-
new_results = query_cohd_mysql.query_trapi_mcq(concept_id_1=self._concept_1_omop_ids,
1016-
dataset_id=self._dataset_id, domain_id=None,
1017-
ln_ratio_sign=self._association_direction,
1018-
confidence=self._confidence_interval,
1019-
bypass=self._bypass_cache)
1017+
new_results = query_cohd_mysql.query_trapi_mcq(concept_ids=self._concept_1_omop_ids,
1018+
n_member_ids=len(self._mcq_member_ids),
1019+
score_scaling=CohdTrapi.mcq_score_scaling,
1020+
dataset_id=self._dataset_id,
1021+
domain_id=None,
1022+
ln_ratio_sign=self._association_direction,
1023+
confidence=self._confidence_interval,
1024+
bypass=self._bypass_cache)
10201025
new_set_results, new_single_results = new_results
10211026
if new_set_results:
10221027
set_results.extend(new_set_results)
10231028
single_results.update(new_single_results)
10241029

10251030
# Results within each query call should be sorted, but still need to be sorted across query calls
1026-
new_set_results = sort_cohd_results(new_set_results, sort_field='ln_ratio_score')
1031+
new_set_results = sort_cohd_results(new_set_results, sort_field='mcq_score')
10271032

10281033
# Convert results from COHD format to Translator Reasoner standard
10291034
self._add_mcq_results_to_trapi(set_results, single_results)
@@ -1169,8 +1174,8 @@ def _add_mcq_result(self, set_result, single_results, criteria):
11691174
kg_node_2, kg_set_edge, kg_set_edge_id = self._add_kg_set_edge(node_2, is_subject, set_result)
11701175

11711176
# Add to results
1172-
score = set_result['ln_ratio_score']
1173-
self._add_result(self._mcq_set_id, concept_2_curie, kg_set_edge_id, score)
1177+
score = set_result['mcq_score']
1178+
self._add_result(self._mcq_set_id, concept_2_curie, kg_set_edge_id, score, mcq=True)
11741179

11751180
# Add single result edges and auxiliary graphs
11761181
support_graphs = list()
@@ -1196,7 +1201,7 @@ def _add_mcq_result(self, set_result, single_results, criteria):
11961201
"value": support_graphs
11971202
})
11981203

1199-
def _add_result(self, kg_node_1_id, kg_node_2_id, kg_edge_id, score):
1204+
def _add_result(self, kg_node_1_id, kg_node_2_id, kg_edge_id, score, mcq=False):
12001205
""" Adds a knowledge graph edge to the results list
12011206
12021207
Parameters
@@ -1205,6 +1210,7 @@ def _add_result(self, kg_node_1_id, kg_node_2_id, kg_edge_id, score):
12051210
kg_node_2_id: Object node ID
12061211
kg_edge_id: edge ID
12071212
score: result score
1213+
mcq: True/False if MCQ analysis
12081214
12091215
Returns
12101216
-------
@@ -1231,7 +1237,7 @@ def _add_result(self, kg_node_1_id, kg_node_2_id, kg_edge_id, score):
12311237
}]
12321238
},
12331239
'score': score,
1234-
'scoring_method': 'Lower bound of biolink:ln_ratio_confidence_interval',
1240+
'scoring_method': 'COHD set-input query scoring, range: [0,1]' if mcq else 'Lower bound of biolink:ln_ratio_confidence_interval',
12351241
}
12361242
]
12371243
}
@@ -1913,7 +1919,7 @@ def _add_kg_set_edge(self, node_2, is_subject, set_result):
19131919
'value_type_id': 'EDAM:data_1772', # Score
19141920
'attribute_source': CohdTrapi._INFORES_ID,
19151921
'description': 'Observed-expected frequency ratio.'
1916-
},
1922+
},
19171923
{
19181924
'attribute_type_id': 'biolink:supporting_data_set', # Database ID
19191925
'original_attribute_name': 'dataset_id',

cohd/query_cohd_mysql.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pymysql
22
from flask import jsonify
33
from scipy.stats import chisquare
4+
import numpy as np
45
from numpy import argsort
56
import logging
67
import pandas as pd
@@ -15,6 +16,7 @@
1516
DATASET_ID_DEFAULT = 1
1617
DATASET_ID_DEFAULT_HIER = 3
1718
DEFAULT_CONFIDENCE = 0.99
19+
DEFAULT_MCQ_SCORE_SCALING = 0.75
1820

1921
# OXO API configuration
2022
URL_OXO_SEARCH = 'https://www.ebi.ac.uk/spot/oxo/api/search'
@@ -1132,7 +1134,11 @@ def query_db(service, method, args):
11321134
elif type(concept_ids) is not list:
11331135
concept_ids = [concept_ids]
11341136

1135-
set_results, single_results = query_trapi_mcq(concept_ids, dataset_id, domain_id, bypass=True)
1137+
set_results, single_results = query_trapi_mcq(concept_ids=concept_ids,
1138+
n_member_ids=len(concept_ids),
1139+
dataset_id=dataset_id,
1140+
domain_id=domain_id,
1141+
bypass=True)
11361142
json_return = {
11371143
'set_results': set_results,
11381144
'single_results': single_results
@@ -1837,7 +1843,7 @@ def _get_weighted_statistics(cur=None,dataset_id=None,domain_id = None,concept_i
18371843
concept_list_1_w_df= pd.DataFrame({'concept_id_1':concept_id_1})
18381844
concept_list_1_w_df['w'] = 1
18391845

1840-
# Calculate the weights based on Jaccard index between input concep
1846+
# Calculate the weights based on Jaccard index between input concepts
18411847
pair_count_q1 = pd.DataFrame(get_pair_concept_count(cur=cur,dataset_id=dataset_id,domain_id=domain_id, concept_id_list_1=concept_id_1,concept_id_list_2=concept_id_1))
18421848
if pair_count_q1.shape[0] > 0:
18431849
# Sum of Jaccard index
@@ -1849,6 +1855,7 @@ def _get_weighted_statistics(cur=None,dataset_id=None,domain_id = None,concept_i
18491855
# Weight = 1/(1 + sum(Jaccards))
18501856
concept_list_1_w_df['w'] = 1/concept_list_1_w_df['w']
18511857
concept_list_1_w_df = concept_list_1_w_df[['concept_id_1','w']]
1858+
total_weights = concept_list_1_w_df.w.sum()
18521859

18531860
# Multiply the scores by the weights
18541861
pair_count_df = pair_count_df.merge(concept_list_1_w_df)
@@ -1858,7 +1865,7 @@ def _get_weighted_statistics(cur=None,dataset_id=None,domain_id = None,concept_i
18581865
# Group by concept_id_2. Sum the scores and combine concept_id_1 into a list
18591866
gb = pair_count_df.groupby('concept_id_2')
18601867
weighted_stats = gb[json_key].agg('sum')
1861-
return weighted_stats.reset_index()
1868+
return weighted_stats.reset_index(), total_weights
18621869

18631870

18641871
def _get_ci_scores(r, score_col):
@@ -1871,14 +1878,17 @@ def _get_ci_scores(r, score_col):
18711878

18721879

18731880
@cache.memoize(timeout=86400, unless=_bypass_cache)
1874-
def query_trapi_mcq(concept_ids, dataset_id=None, domain_id=None, concept_class_id=None,
1881+
def query_trapi_mcq(concept_ids, n_member_ids, score_scaling=DEFAULT_MCQ_SCORE_SCALING,
1882+
dataset_id=None, domain_id=None, concept_class_id=None,
18751883
ln_ratio_sign=0, confidence=DEFAULT_CONFIDENCE, bypass=False):
18761884
""" Query for TRAPI Multicurie Query. Calculates weighted scores using methods similar to linkage disequilibrium to
18771885
downweight contributions from input concepts that are similar to each other
18781886
18791887
Parameters
18801888
----------
18811889
concept_ids: list of OMOP concept IDs
1890+
n_member_ids: number of input IDs in set node
1891+
score_scaling: linear scaling of ln_ratio_score prior to logistic normalization
18821892
dataset_id: (optional) String - COHD dataset ID
18831893
domain_id: (optional) String - OMOP domain ID
18841894
concept_class_id: (optional) String - OMOP concept class ID
@@ -1912,13 +1922,19 @@ def query_trapi_mcq(concept_ids, dataset_id=None, domain_id=None, concept_class_
19121922

19131923
# Adjust the scores by weights
19141924
concept_list_1 = list(set(associations['concept_id_1'].tolist()))
1915-
weighted_ln_ratio = _get_weighted_statistics(cur=cur, dataset_id=dataset_id, domain_id=domain_id,
1925+
weighted_ln_ratio, total_weights = _get_weighted_statistics(cur=cur, dataset_id=dataset_id, domain_id=domain_id,
19161926
concept_id_1=concept_list_1, pair_count_df=associations,
19171927
json_key = 'ln_ratio_score')
19181928
# weighted_log_odds = _get_weighted_statistics(cur=cur, dataset_id=dataset_id, domain_id=domain_id,
19191929
# concept_id_1=concept_list_1, pair_count_df=associations,
19201930
# json_key = 'log_odds_score')
19211931

1932+
# For TRAPI result score, normalize the score relative to the number of input CURIEs and
1933+
# scale the score range to [0-1] using a scaled logistic function
1934+
n_mapped_ids = len(concept_list_1)
1935+
weighted_ln_ratio['mcq_score'] = weighted_ln_ratio['ln_ratio_score'] / total_weights * n_mapped_ids / n_member_ids
1936+
weighted_ln_ratio['mcq_score'] = (1/(1+np.exp(-np.abs(weighted_ln_ratio['mcq_score']*score_scaling)))-0.5) * 2
1937+
19221938
# Add list of single associations
19231939
single_associations = dict()
19241940
for i, row in weighted_ln_ratio.iterrows():

0 commit comments

Comments (0)