Skip to content

Commit 89d10bc

Browse files
committed
enhance contamination
1 parent 2e0911f commit 89d10bc

File tree

1 file changed

+48
-22
lines changed

1 file changed

+48
-22
lines changed

app.py

Lines changed: 48 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -234,52 +234,78 @@ def impurity(text, min_len=10):
234234
with metrics:

    with st.spinner('Calculating contamination ratio...'):
        train_dataset = datasets['train']
        test_dataset = datasets['test']

        from nltk import ngrams
        from datasketch import MinHash, MinHashLSH

        # MinHash signatures and the LSH index MUST be built with the same
        # num_perm, so keep the value in one place instead of repeating the
        # literal 128 in two calls.
        NUM_PERM = 128
        SHINGLE_SIZE = 13       # character 13-grams (the text is shingled per character)
        JACCARD_THRESHOLD = 0.8 # LSH similarity threshold for "same document"

        def process_data(df, n=SHINGLE_SIZE, num_perm=NUM_PERM):
            """Build a MinHash signature per row of *df* from character n-grams.

            Args:
                df: DataFrame-like with a 'text' column (iterated via iterrows).
                n: shingle length in characters (default: SHINGLE_SIZE).
                num_perm: number of MinHash permutations (default: NUM_PERM).

            Returns:
                dict mapping row index -> datasketch.MinHash.

            Note: rows whose text is shorter than ``n`` produce an empty
            signature (no shingles) and therefore never match anything.
            """
            minhashes = {}
            for idx, row in df.iterrows():
                mh = MinHash(num_perm=num_perm)
                for gram in ngrams(row['text'], n):
                    mh.update(''.join(gram).encode('utf-8'))
                minhashes[idx] = mh
            return minhashes

        train_minhashes = process_data(train_dataset)
        test_minhashes = process_data(test_dataset)

        # Index every train document, then count test documents that collide
        # with at least one train document at the similarity threshold --
        # those are treated as contaminated.
        lsh = MinHashLSH(threshold=JACCARD_THRESHOLD, num_perm=NUM_PERM)
        for idx, minhash in train_minhashes.items():
            lsh.insert(idx, minhash)

        duplicates_count = sum(
            1 for minhash in test_minhashes.values() if lsh.query(minhash)
        )

        train_dataset_count = len(train_dataset)
        test_dataset_count = len(test_dataset)
        # Guard the division: an empty test split must not crash the app.
        contaminate_ratio = (
            duplicates_count / test_dataset_count if test_dataset_count else 0.0
        )

    col1, col2, col3, col4 = st.columns(4)
    col1.metric(label="Train Set Size", value="%d" % train_dataset_count)
    col2.metric(label="Test Set Size", value="%d" % test_dataset_count)
    col3.metric(label="Overlapped Docs", value="%d" % duplicates_count)
    col4.metric(label="Contaminated Ratio", value="%.2f%%" % (contaminate_ratio * 100))
263276
with code:
    # Show the exact contamination-detection snippet that the metrics
    # section executes, so users can reproduce the numbers themselves.
    snippet = '''
from nltk import ngrams
from datasketch import MinHash, MinHashLSH

def process_data(df):
    minhashes = {}
    for idx, r in df.iterrows():
        minhash = MinHash(num_perm=128)
        for d in ngrams(r['text'], 13):
            s = "".join(d).encode('utf-8')
            minhash.update(s)
        minhashes[idx] = minhash
    return minhashes

train_minhashes = process_data(train_dataset)
test_minhashes = process_data(test_dataset)

lsh = MinHashLSH(threshold=0.8, num_perm=128)

for idx, minhash in train_minhashes.items():
    lsh.insert(idx, minhash)

duplicates_count = 0
for idx, minhash in test_minhashes.items():
    result = lsh.query(minhash)
    if len(result) > 0:
        duplicates_count += 1

train_dataset_count = len(train_dataset)
test_dataset_count = len(test_dataset)
contaminate_ratio = duplicates_count / test_dataset_count
'''
    st.code(snippet)
285311

0 commit comments

Comments
 (0)