diff --git a/examples/cli/cli.cpp b/examples/cli/cli.cpp index 9a54742fe1d..5f510ba50d8 100644 --- a/examples/cli/cli.cpp +++ b/examples/cli/cli.cpp @@ -77,6 +77,7 @@ struct whisper_params { bool log_score = false; bool use_gpu = true; bool flash_attn = true; + int32_t gpu_device = 0; bool suppress_nst = false; bool carry_initial_prompt = false; @@ -129,6 +130,10 @@ static char * requires_value_error(const std::string & arg) { } static bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { + if (const char * env_gpu_device = std::getenv("GPU_DEVICE")) { + params.gpu_device = std::stoi(env_gpu_device); + } + for (int i = 1; i < argc; i++) { std::string arg = argv[i]; @@ -195,6 +200,7 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params else if (arg == "-dtw" || arg == "--dtw") { params.dtw = ARGV_NEXT; } else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-g" || arg == "--gpu-device") { params.gpu_device = std::stoi(ARGV_NEXT); } else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; } else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; } @@ -276,6 +282,7 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str()); fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false"); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -g N, --gpu-device N [%-7d] GPU device ID (default: 0)\n", params.gpu_device); fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false"); fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", params.flash_attn ? "false" : "true"); fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false"); @@ -1003,6 +1010,7 @@ int main(int argc, char ** argv) { struct whisper_context_params cparams = whisper_context_default_params(); cparams.use_gpu = params.use_gpu; + cparams.gpu_device = params.gpu_device; cparams.flash_attn = params.flash_attn; if (!params.dtw.empty()) { diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 1d49aa3be52..2199ab9bead 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -102,6 +102,7 @@ struct whisper_params { bool no_timestamps = false; bool use_gpu = true; bool flash_attn = true; + int32_t gpu_device = 0; bool suppress_nst = false; bool no_context = true; bool no_language_probabilities = false; @@ -177,6 +178,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false"); fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold); fprintf(stderr, " -ng, --no-gpu [%-7s] do not use gpu\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -g N, --gpu-device N [%-7d] GPU device ID (default: 0)\n", params.gpu_device); fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false"); fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", params.flash_attn ? "false" : "true"); fprintf(stderr, " -nlp, --no-language-probabilities [%-7s] exclude language probabilities from verbose_json output\n", params.no_language_probabilities ? "true" : "false"); @@ -196,6 +198,10 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para } bool whisper_params_parse(int argc, char ** argv, whisper_params & params, server_params & sparams) { + if (const char * env_gpu_device = std::getenv("GPU_DEVICE")) { + params.gpu_device = std::stoi(env_gpu_device); + } + for (int i = 1; i < argc; i++) { std::string arg = argv[i]; @@ -235,6 +241,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; } else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-g" || arg == "--gpu-device") { params.gpu_device = std::stoi(argv[++i]); } else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; } else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; } @@ -638,6 +645,7 @@ int main(int argc, char ** argv) { struct whisper_context_params cparams = whisper_context_default_params(); cparams.use_gpu = params.use_gpu; + cparams.gpu_device = params.gpu_device; cparams.flash_attn = params.flash_attn; if (!params.dtw.empty()) {