Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions examples/cli/cli.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ struct whisper_params {
bool log_score = false;
bool use_gpu = true;
bool flash_attn = true;
int32_t gpu_device = 0;
bool suppress_nst = false;
bool carry_initial_prompt = false;

Expand Down Expand Up @@ -129,6 +130,10 @@ static char * requires_value_error(const std::string & arg) {
}

static bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
if (const char * env_gpu_device = std::getenv("GPU_DEVICE")) {
params.gpu_device = std::stoi(env_gpu_device);
}

for (int i = 1; i < argc; i++) {
std::string arg = argv[i];

Expand Down Expand Up @@ -195,6 +200,7 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = ARGV_NEXT; }
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-g" || arg == "--gpu-device") { params.gpu_device = std::stoi(ARGV_NEXT); }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we should name this option `--device` (instead of `--gpu-device`) to be consistent with llama.cpp.

else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; }
else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; }
Expand Down Expand Up @@ -276,6 +282,7 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params
fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " -g N, --gpu-device N [%-7d] GPU device ID (default: 0)\n", params.gpu_device);
fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false");
fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", params.flash_attn ? "false" : "true");
fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false");
Expand Down Expand Up @@ -1003,6 +1010,7 @@ int main(int argc, char ** argv) {
struct whisper_context_params cparams = whisper_context_default_params();

cparams.use_gpu = params.use_gpu;
cparams.gpu_device = params.gpu_device;
cparams.flash_attn = params.flash_attn;

if (!params.dtw.empty()) {
Expand Down
8 changes: 8 additions & 0 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ struct whisper_params {
bool no_timestamps = false;
bool use_gpu = true;
bool flash_attn = true;
int32_t gpu_device = 0;
bool suppress_nst = false;
bool no_context = true;
bool no_language_probabilities = false;
Expand Down Expand Up @@ -177,6 +178,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false");
fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold);
fprintf(stderr, " -ng, --no-gpu [%-7s] do not use gpu\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " -g N, --gpu-device N [%-7d] GPU device ID (default: 0)\n", params.gpu_device);
fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false");
fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", params.flash_attn ? "false" : "true");
fprintf(stderr, " -nlp, --no-language-probabilities [%-7s] exclude language probabilities from verbose_json output\n", params.no_language_probabilities ? "true" : "false");
Expand All @@ -196,6 +198,10 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
}

bool whisper_params_parse(int argc, char ** argv, whisper_params & params, server_params & sparams) {
if (const char * env_gpu_device = std::getenv("GPU_DEVICE")) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps this environment variable should be named `WHISPER_ARG_DEVICE` (instead of `GPU_DEVICE`) to be consistent with llama.cpp.

params.gpu_device = std::stoi(env_gpu_device);
}

for (int i = 1; i < argc; i++) {
std::string arg = argv[i];

Expand Down Expand Up @@ -235,6 +241,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-g" || arg == "--gpu-device") { params.gpu_device = std::stoi(argv[++i]); }
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; }
else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; }
Expand Down Expand Up @@ -638,6 +645,7 @@ int main(int argc, char ** argv) {
struct whisper_context_params cparams = whisper_context_default_params();

cparams.use_gpu = params.use_gpu;
cparams.gpu_device = params.gpu_device;
cparams.flash_attn = params.flash_attn;

if (!params.dtw.empty()) {
Expand Down
Loading