Skip to content

Commit 5cd74b5

Browse files
authored
Add methods to convert words to id
2 parents 933bb5d + 9483cb4 commit 5cd74b5

File tree

5 files changed

+34
-9
lines changed

5 files changed

+34
-9
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ python:
44
# command to install dependencies
55
install:
66
- pip install -r requirements.txt
7-
- sudo mv data/nltk_data /usr/share/nltk_data
7+
- sudo mv nltk_data /usr/share/nltk_data
88
# command to run tests
99
script:
1010
- cd streampredictor

requirements.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
nltk==3.2.2
2-
numpy==1.12.0
3-
protobuf==3.2.0
4-
matplotlib==2.0.0
5-
progressbar2==3.12.0
1+
nltk
2+
numpy
3+
protobuf
4+
matplotlib
5+
progressbar

streampredictor/DataObtainer.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,18 @@ def get_clean_words_from_file(file, max_input_length):
4444
text = opened_file.read()
4545
return nltk.word_tokenize(clean_text(text))[:max_input_length]
4646

47+
def get_words_from_ptb(file, max_input_length):
48+
with open(file) as opened_file:
49+
text = opened_file.read().replace('\n', '')
50+
return text.split(' ')[:max_input_length]
4751

4852
def clean_text(text, max_input_length=10**10000):
4953
text = text.replace('\n', ' ')
5054
max_length = min(max_input_length, len(text))
5155
rotation = random.randint(0,max_length)
5256
text = text[rotation:max_length] + text[:rotation]
5357
# make sure to remove # for category separation
54-
text = ''.join(e for e in text if e.isalnum() or e in '.?", ')
58+
text = ''.join(e for e in text if e.isalnum() or e in '.?", <>')
5559
return text
5660

5761

@@ -60,7 +64,18 @@ def get_online_words(max_input_length):
6064
words = nltk.word_tokenize(clean_text(text, max_input_length))
6165
return words
6266

67+
def convert_words_to_id(words):
68+
"""
69+
Converts words list to id list and returns id sequence, word2id and id2word dictionary.
70+
"""
71+
unique_words = list(set(words))
72+
id2word = dict((id,word) for id,word in enumerate(unique_words))
73+
word2id = dict((i,j) for j,i in id2word.iteritems())
74+
id_sequence = [word2id[word] for word in words]
75+
return id_sequence, word2id, id2word
6376

6477
if __name__ == '__main__':
65-
text = get_random_book_local('../data')
66-
print(text)
78+
words = get_words_from_ptb('../data/ptb.test.txt', max_input_length=100)
79+
print(words)
80+
seq, word2id, id2word = convert_words_to_id(words)
81+
print(seq)

streampredictor/test_word_predictor.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from unittest import TestCase
22

33
from . import WordPredictor
4+
from . import DataObtainer
45

56
training_text = 'cat hat mat bat sat in the barn'
67
words = training_text.split(' ')
@@ -18,3 +19,12 @@ def test_generates_sample(self):
1819
wp.train(words)
1920
generated_text = wp.generate(5)
2021
self.assertGreater(len(generated_text.split(' ')), 5)
22+
23+
24+
class TestDataObtainer(TestCase):
25+
def test_convert_word_2_id(self):
26+
test_words = ['aaa', 'bbb', 'ccc', 'aaa']
27+
seq, word2id, id2word = DataObtainer.convert_words_to_id(test_words)
28+
self.assertEqual([0,1,2,0], seq)
29+
for word in test_words:
30+
self.assertEqual(word, id2word[word2id[word]])

0 commit comments

Comments
 (0)