
Commit 7733409

common: improve verbosity level definitions (ggml-org#17630)
* common: improve verbosity level definitions
* string_format
* update autogen docs
1 parent cd3c118 commit 7733409


6 files changed, +46 −19 lines changed

common/arg.cpp

Lines changed: 7 additions & 1 deletion
```diff
@@ -2674,7 +2674,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_env("LLAMA_OFFLINE"));
     add_opt(common_arg(
         {"-lv", "--verbosity", "--log-verbosity"}, "N",
-        "Set the verbosity threshold. Messages with a higher verbosity will be ignored.",
+        string_format("Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:\n"
+                      " - 0: generic output\n"
+                      " - 1: error\n"
+                      " - 2: warning\n"
+                      " - 3: info\n"
+                      " - 4: debug\n"
+                      "(default: %d)\n", params.verbosity),
         [](common_params & params, int value) {
             params.verbosity = value;
             common_log_set_verbosity_thold(value);
```
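The help text is now assembled with string_format so the compiled-in default (params.verbosity, set to 3 in common/common.h below) shows up in the `--help` output. As a rough illustration of what such a printf-style formatter does, here is a minimal sketch; it is not the actual string_format from common/common.h, only an assumed stand-in:

```cpp
#include <cstdarg>
#include <cstdio>
#include <string>
#include <vector>

// Illustrative printf-style formatter, similar in spirit to the string_format
// helper used in the hunk above (the real one lives in common/common.h).
static std::string format_sketch(const char * fmt, ...) {
    va_list args;
    va_start(args, fmt);
    va_list copy;
    va_copy(copy, args);
    const int len = vsnprintf(nullptr, 0, fmt, copy); // measure required length
    va_end(copy);
    if (len < 0) {
        va_end(args);
        return "";
    }
    std::vector<char> buf(len + 1);
    vsnprintf(buf.data(), buf.size(), fmt, args);      // fill the buffer
    va_end(args);
    return std::string(buf.data(), len);
}

int main() {
    const int verbosity = 3; // default from common/common.h in this commit
    printf("%s", format_sketch("(default: %d)\n", verbosity).c_str());
    return 0;
}
```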

common/common.h

Lines changed: 1 addition & 1 deletion
```diff
@@ -369,7 +369,7 @@ struct common_params {
 
     std::vector<common_control_vector_load_info> control_vectors; // control vector with user defined scale
 
-    int32_t verbosity = 0;
+    int32_t verbosity = 3; // LOG_LEVEL_INFO
     int32_t control_vector_layer_start = -1; // layer range for control vector
     int32_t control_vector_layer_end = -1; // layer range for control vector
     bool offline = false;
```

common/download.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -430,7 +430,7 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
     curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
     curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);
     curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
-    curl_easy_setopt(curl.get(), CURLOPT_VERBOSE, 1L);
+    curl_easy_setopt(curl.get(), CURLOPT_VERBOSE, 0L);
     typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
     auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
         auto data_vec = static_cast<std::vector<char> *>(data);
```

common/log.cpp

Lines changed: 15 additions & 1 deletion
```diff
@@ -443,8 +443,22 @@ void common_log_set_timestamps(struct common_log * log, bool timestamps) {
     log->set_timestamps(timestamps);
 }
 
+static int common_get_verbosity(enum ggml_log_level level) {
+    switch (level) {
+        case GGML_LOG_LEVEL_DEBUG: return LOG_LEVEL_DEBUG;
+        case GGML_LOG_LEVEL_INFO:  return LOG_LEVEL_INFO;
+        case GGML_LOG_LEVEL_WARN:  return LOG_LEVEL_WARN;
+        case GGML_LOG_LEVEL_ERROR: return LOG_LEVEL_ERROR;
+        case GGML_LOG_LEVEL_CONT:  return LOG_LEVEL_INFO; // same as INFO
+        case GGML_LOG_LEVEL_NONE:
+        default:
+            return LOG_LEVEL_OUTPUT;
+    }
+}
+
 void common_log_default_callback(enum ggml_log_level level, const char * text, void * /*user_data*/) {
-    if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) {
+    auto verbosity = common_get_verbosity(level);
+    if (verbosity <= common_log_verbosity_thold) {
         common_log_add(common_log_main(), level, "%s", text);
     }
 }
```
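Previously every message arriving through the default callback was gated with LOG_DEFAULT_LLAMA (verbosity 0), so `-lv` could not distinguish severities; with this hunk each ggml log level maps to its own verbosity. A minimal standalone sketch of the resulting filtering rule, using the level values introduced in the common/log.h hunk below (all other names here are illustrative, not part of the commit):

```cpp
#include <cstdio>

// Verbosity values as introduced in common/log.h by this commit.
enum { LOG_LEVEL_OUTPUT = 0, LOG_LEVEL_ERROR = 1, LOG_LEVEL_WARN = 2, LOG_LEVEL_INFO = 3, LOG_LEVEL_DEBUG = 4 };

// Stand-in for ggml_log_level (the real enum lives in ggml.h).
enum sketch_level { SK_DEBUG, SK_INFO, SK_WARN, SK_ERROR, SK_NONE };

// Mirrors the idea of common_get_verbosity() from the hunk above.
static int map_verbosity(sketch_level level) {
    switch (level) {
        case SK_DEBUG: return LOG_LEVEL_DEBUG;
        case SK_INFO:  return LOG_LEVEL_INFO;
        case SK_WARN:  return LOG_LEVEL_WARN;
        case SK_ERROR: return LOG_LEVEL_ERROR;
        default:       return LOG_LEVEL_OUTPUT;
    }
}

int main() {
    const int thold = 2; // e.g. the process was started with `-lv 2`
    const sketch_level levels[] = { SK_DEBUG, SK_INFO, SK_WARN, SK_ERROR, SK_NONE };
    const char *       names[]  = { "debug", "info", "warn", "error", "none/output" };
    for (int i = 0; i < 5; ++i) {
        // A message is kept only when its verbosity does not exceed the threshold.
        printf("%-12s -> %s\n", names[i], map_verbosity(levels[i]) <= thold ? "logged" : "dropped");
    }
    return 0;
}
```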

common/log.h

Lines changed: 19 additions & 12 deletions
```diff
@@ -21,8 +21,14 @@
 # define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
 #endif
 
-#define LOG_DEFAULT_DEBUG 1
-#define LOG_DEFAULT_LLAMA 0
+#define LOG_LEVEL_DEBUG  4
+#define LOG_LEVEL_INFO   3
+#define LOG_LEVEL_WARN   2
+#define LOG_LEVEL_ERROR  1
+#define LOG_LEVEL_OUTPUT 0 // output data from tools
+
+#define LOG_DEFAULT_DEBUG LOG_LEVEL_DEBUG
+#define LOG_DEFAULT_LLAMA LOG_LEVEL_INFO
 
 enum log_colors {
     LOG_COLORS_AUTO = -1,
@@ -67,10 +73,11 @@ void common_log_add(struct common_log * log, enum ggml_log_level level, const ch
 // 0.00.090.578 I llm_load_tensors: offloading 32 repeating layers to GPU
 // 0.00.090.579 I llm_load_tensors: offloading non-repeating layers to GPU
 //
-// I - info    (stdout, V = 0)
-// W - warning (stderr, V = 0)
-// E - error   (stderr, V = 0)
 // D - debug   (stderr, V = LOG_DEFAULT_DEBUG)
+// I - info    (stdout, V = LOG_DEFAULT_INFO)
+// W - warning (stderr, V = LOG_DEFAULT_WARN)
+// E - error   (stderr, V = LOG_DEFAULT_ERROR)
+// O - output  (stdout, V = LOG_DEFAULT_OUTPUT)
 //
 
 void common_log_set_file (struct common_log * log, const char * file); // not thread-safe
@@ -95,14 +102,14 @@ void common_log_set_timestamps(struct common_log * log, bool timestamps); // w
     } \
 } while (0)
 
-#define LOG(...)             LOG_TMPL(GGML_LOG_LEVEL_NONE, 0, __VA_ARGS__)
-#define LOGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_NONE, verbosity, __VA_ARGS__)
+#define LOG(...)             LOG_TMPL(GGML_LOG_LEVEL_NONE, LOG_LEVEL_OUTPUT, __VA_ARGS__)
+#define LOGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_NONE, verbosity, __VA_ARGS__)
 
-#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO,  0, __VA_ARGS__)
-#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN,  0, __VA_ARGS__)
-#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, 0, __VA_ARGS__)
-#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, LOG_DEFAULT_DEBUG, __VA_ARGS__)
-#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT,  0, __VA_ARGS__)
+#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, LOG_LEVEL_DEBUG, __VA_ARGS__)
+#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO,  LOG_LEVEL_INFO,  __VA_ARGS__)
+#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN,  LOG_LEVEL_WARN,  __VA_ARGS__)
+#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, LOG_LEVEL_ERROR, __VA_ARGS__)
+#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT,  LOG_LEVEL_INFO,  __VA_ARGS__) // same as INFO
 
 #define LOG_INFV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_INFO, verbosity, __VA_ARGS__)
 #define LOG_WRNV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_WARN, verbosity, __VA_ARGS__)
```
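Under the new scheme a lower value means a more important message, so the default threshold of 3 (LOG_LEVEL_INFO) shows generic output, errors, warnings and info while hiding debug. A tiny compile-time sanity check of that ordering; the constants are copied from the hunk above, but the file itself is only an illustrative sketch, not part of the commit:

```cpp
// Sanity-check sketch of the intended ordering of the new verbosity levels.
#define LOG_LEVEL_DEBUG  4
#define LOG_LEVEL_INFO   3
#define LOG_LEVEL_WARN   2
#define LOG_LEVEL_ERROR  1
#define LOG_LEVEL_OUTPUT 0

// Lower value = more important: it survives at lower -lv thresholds.
static_assert(LOG_LEVEL_OUTPUT < LOG_LEVEL_ERROR, "generic output is always shown");
static_assert(LOG_LEVEL_ERROR  < LOG_LEVEL_WARN,  "errors shown before warnings");
static_assert(LOG_LEVEL_WARN   < LOG_LEVEL_INFO,  "warnings shown before info");
static_assert(LOG_LEVEL_INFO   < LOG_LEVEL_DEBUG, "info shown before debug");

int main() { return 0; }
```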

tools/server/README.md

Lines changed: 3 additions & 3 deletions
```diff
@@ -52,7 +52,7 @@ The project is under active development, and we are [looking for feedback and co
 | `-ub, --ubatch-size N` | physical maximum batch size (default: 512)<br/>(env: LLAMA_ARG_UBATCH) |
 | `--keep N` | number of tokens to keep from the initial prompt (default: 0, -1 = all) |
 | `--swa-full` | use full-size SWA cache (default: false)<br/>[(more info)](https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)<br/>(env: LLAMA_ARG_SWA_FULL) |
-| `--kv-unified, -kvu` | use single unified KV buffer for the KV cache of all sequences (default: false)<br/>[(more info)](https://github.com/ggml-org/llama.cpp/pull/14363)<br/>(env: LLAMA_ARG_KV_SPLIT) |
+| `--kv-unified, -kvu` | use single unified KV buffer for the KV cache of all sequences (default: false)<br/>[(more info)](https://github.com/ggml-org/llama.cpp/pull/14363)<br/>(env: LLAMA_ARG_KV_UNIFIED) |
 | `-fa, --flash-attn [on\|off\|auto]` | set Flash Attention use ('on', 'off', or 'auto', default: 'auto')<br/>(env: LLAMA_ARG_FLASH_ATTN) |
 | `--no-perf` | disable internal libllama performance timings (default: false)<br/>(env: LLAMA_ARG_NO_PERF) |
 | `-e, --escape` | process escapes sequences (\n, \r, \t, \', \", \\) (default: true) |
@@ -103,11 +103,11 @@ The project is under active development, and we are [looking for feedback and co
 | `-hffv, --hf-file-v FILE` | Hugging Face model file for the vocoder model (default: unused)<br/>(env: LLAMA_ARG_HF_FILE_V) |
 | `-hft, --hf-token TOKEN` | Hugging Face access token (default: value from HF_TOKEN environment variable)<br/>(env: HF_TOKEN) |
 | `--log-disable` | Log disable |
-| `--log-file FNAME` | Log to file |
+| `--log-file FNAME` | Log to file<br/>(env: LLAMA_LOG_FILE) |
 | `--log-colors [on\|off\|auto]` | Set colored logging ('on', 'off', or 'auto', default: 'auto')<br/>'auto' enables colors when output is to a terminal<br/>(env: LLAMA_LOG_COLORS) |
 | `-v, --verbose, --log-verbose` | Set verbosity level to infinity (i.e. log all messages, useful for debugging) |
 | `--offline` | Offline mode: forces use of cache, prevents network access<br/>(env: LLAMA_OFFLINE) |
-| `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored.<br/>(env: LLAMA_LOG_VERBOSITY) |
+| `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:<br/> - 0: generic output<br/> - 1: error<br/> - 2: warning<br/> - 3: info<br/> - 4: debug<br/>(default: 3)<br/><br/>(env: LLAMA_LOG_VERBOSITY) |
 | `--log-prefix` | Enable prefix in log messages<br/>(env: LLAMA_LOG_PREFIX) |
 | `--log-timestamps` | Enable timestamps in log messages<br/>(env: LLAMA_LOG_TIMESTAMPS) |
 | `-ctkd, --cache-type-k-draft TYPE` | KV cache data type for K for the draft model<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_K_DRAFT) |
```
