Skip to content

Commit a915000

Browse files
committed
Adding possiblilty to output directly from network
1 parent 7cd01b1 commit a915000

File tree

3 files changed

+45
-18
lines changed

3 files changed

+45
-18
lines changed

README.md

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,16 +52,18 @@ The flags require a value to be passed as the following argument.
5252
./net datafile -lr 0.03
5353

5454
The following flags are available:
55-
-r : read a previously trained network, the name of which is currently configured to be 'lstm_net.net'.
56-
-lr: learning rate that is to be used during training, see the example above.
57-
-it: the number of iterations used for training (not to be confused with epochs).
58-
-mb: mini batch size.
59-
-dl: decrease the learning rate over time, according to lr(n+1) <- lr(n) / (1 + n/value).
60-
-st: number of iterations between how the network is continously stored during training (.json and .net).
55+
-r : read a previously trained network, the name of which is currently configured to be 'lstm_net.net'.
56+
-lr : learning rate that is to be used during training, see the example above.
57+
-it : the number of iterations used for training (not to be confused with epochs).
58+
-mb : mini batch size.
59+
-dl : decrease the learning rate over time, according to lr(n+1) <- lr(n) / (1 + n/value).
60+
-st : number of iterations between how the network is continously stored during training (.json and .net).
61+
-out: number of characters to output directly, note: a network and a datafile must be provided.
6162

6263
Check std_conf.h to see what default values are used, these are set during compilation.
6364

64-
./net compiled Feb 14 2019 13:41:44
65+
./net compiled Feb 14 2019 14:41:42
66+
6567
</pre>
6668

6769
The -st flags is great. Per default the network is stored upon interrupting the program with Ctrl-C. But using this argument, you can let the program train and have it store the network continously during the training process.

lstm.c

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -804,7 +804,7 @@ void lstm_store_net_layers(lstm_model_t** model, const char * filename, int laye
804804
fp = fopen(filename, "w");
805805

806806
if ( fp == NULL ) {
807-
printf("Failed to open file: %s for writing.\n", filename);
807+
fprintf(stderr, "Failed to open file: %s for writing.\n", filename);
808808
return;
809809
}
810810

@@ -898,7 +898,7 @@ void lstm_read_net_layers(lstm_model_t** model, const char * filename, int layer
898898

899899
if ( fp == NULL ) {
900900
printf("Failed to open file: %s for reading.\n", filename);
901-
return;
901+
exit(1);
902902
}
903903

904904
while ( p < layers ) {
@@ -918,7 +918,6 @@ void lstm_read_net_layers(lstm_model_t** model, const char * filename, int layer
918918
++p;
919919
}
920920

921-
printf("Loaded the net: %s\n", filename);
922921
fclose(fp);
923922
}
924923

main.c

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ lstm_model_t *model = NULL, *layer1 = NULL, *layer2 = NULL;
1717
lstm_model_t **model_layers;
1818
set_T set;
1919

20+
static int write_output_directly_bytes = 0;
21+
static char *read_network = NULL;
22+
2023
void store_the_net_layers(int signo)
2124
{
2225
if ( model_layers != NULL ){
@@ -44,12 +47,13 @@ void usage(char *argv[]) {
4447
printf(" %s datafile -lr 0.03\n", argv[0]);
4548
printf("\n");
4649
printf("The following flags are available:\n");
47-
printf(" -r : read a previously trained network, the name of which is currently configured to be '%s'.\n", STD_LOADABLE_NET_NAME);
48-
printf(" -lr: learning rate that is to be used during training, see the example above.\n");
49-
printf(" -it: the number of iterations used for training (not to be confused with epochs).\n");
50-
printf(" -mb: mini batch size.\n");
51-
printf(" -dl: decrease the learning rate over time, according to lr(n+1) <- lr(n) / (1 + n/value).\n");
52-
printf(" -st: number of iterations between how the network is continously stored during training (.json and .net).\n");
50+
printf(" -r : read a previously trained network, the name of which is currently configured to be '%s'.\n", STD_LOADABLE_NET_NAME);
51+
printf(" -lr : learning rate that is to be used during training, see the example above.\n");
52+
printf(" -it : the number of iterations used for training (not to be confused with epochs).\n");
53+
printf(" -mb : mini batch size.\n");
54+
printf(" -dl : decrease the learning rate over time, according to lr(n+1) <- lr(n) / (1 + n/value).\n");
55+
printf(" -st : number of iterations between how the network is continously stored during training (.json and .net).\n");
56+
printf(" -out: number of characters to output directly, note: a network and a datafile must be provided.\n");
5357
printf("\n");
5458
printf("Check std_conf.h to see what default values are used, these are set during compilation.\n");
5559
printf("\n");
@@ -67,7 +71,7 @@ void parse_input_args(int argc, char** argv, lstm_model_parameters_t* params)
6771
break; // All flags have values attributed to them
6872

6973
if ( !strcmp(argv[a], "-r") ) {
70-
lstm_read_net_layers(model_layers, argv[a + 1], LAYERS);
74+
read_network = argv[a + 1];
7175
} else if ( !strcmp(argv[a], "-lr") ) {
7276
params->learning_rate = atof(argv[a + 1]);
7377
if ( params->learning_rate == 0.0 ) {
@@ -94,6 +98,11 @@ void parse_input_args(int argc, char** argv, lstm_model_parameters_t* params)
9498
if ( params->store_network_every == 0 ) {
9599
usage(argv);
96100
}
101+
} else if ( !strcmp(argv[a], "-out") ) {
102+
write_output_directly_bytes = atoi(argv[a+1]);
103+
if ( write_output_directly_bytes <= 0 ) {
104+
usage(argv);
105+
}
97106
}
98107

99108
a += 2;
@@ -164,7 +173,7 @@ int main(int argc, char *argv[])
164173
set_insert_symbol(&set, (char)c );
165174
++file_size;
166175
}
167-
set_insert_symbol(&set, '.');
176+
168177
fclose(fp);
169178

170179
X_train = calloc(file_size+1, sizeof(int));
@@ -197,6 +206,23 @@ int main(int argc, char *argv[])
197206

198207
parse_input_args(argc, argv, &params);
199208

209+
if ( write_output_directly_bytes && read_network != NULL ) {
210+
211+
if ( read_network != NULL )
212+
lstm_read_net_layers(model_layers, read_network, LAYERS);
213+
214+
lstm_output_string_layers(model_layers, &set, set_indx_to_char(&set, 0), write_output_directly_bytes, LAYERS);
215+
216+
free(model_layers);
217+
free(X_train);
218+
return 0;
219+
} else if ( write_output_directly_bytes && read_network == NULL ) {
220+
usage(argv);
221+
} else if ( read_network != NULL ) {
222+
lstm_read_net_layers(model_layers, read_network, LAYERS);
223+
printf("Loaded the net: %s\n", read_network);
224+
}
225+
200226
if ( argc >= 6 && !strcmp(argv[4], "-c") ) {
201227
do {
202228
clean = strchr(argv[5], '_');

0 commit comments

Comments
 (0)