Skip to content

Commit 13628d8

Browse files
authored
server: add --media-path for local media files (ggml-org#17697)
* server: add --media-path for local media files * remove unused fn
1 parent a96283a commit 13628d8

File tree

9 files changed

+133
-38
lines changed

9 files changed

+133
-38
lines changed

common/arg.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2488,12 +2488,29 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
24882488
"path to save slot kv cache (default: disabled)",
24892489
[](common_params & params, const std::string & value) {
24902490
params.slot_save_path = value;
2491+
if (!fs_is_directory(params.slot_save_path)) {
2492+
throw std::invalid_argument("not a directory: " + value);
2493+
}
24912494
// if doesn't end with DIRECTORY_SEPARATOR, add it
24922495
if (!params.slot_save_path.empty() && params.slot_save_path[params.slot_save_path.size() - 1] != DIRECTORY_SEPARATOR) {
24932496
params.slot_save_path += DIRECTORY_SEPARATOR;
24942497
}
24952498
}
24962499
).set_examples({LLAMA_EXAMPLE_SERVER}));
2500+
add_opt(common_arg(
2501+
{"--media-path"}, "PATH",
2502+
"directory for loading local media files; files can be accessed via file:// URLs using relative paths (default: disabled)",
2503+
[](common_params & params, const std::string & value) {
2504+
params.media_path = value;
2505+
if (!fs_is_directory(params.media_path)) {
2506+
throw std::invalid_argument("not a directory: " + value);
2507+
}
2508+
// if doesn't end with DIRECTORY_SEPARATOR, add it
2509+
if (!params.media_path.empty() && params.media_path[params.media_path.size() - 1] != DIRECTORY_SEPARATOR) {
2510+
params.media_path += DIRECTORY_SEPARATOR;
2511+
}
2512+
}
2513+
).set_examples({LLAMA_EXAMPLE_SERVER}));
24972514
add_opt(common_arg(
24982515
{"--models-dir"}, "PATH",
24992516
"directory containing models for the router server (default: disabled)",

common/common.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -694,7 +694,7 @@ bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_over
694694

695695
// Validate if a filename is safe to use
696696
// To validate a full path, split the path by the OS-specific path separator, and validate each part with this function
697-
bool fs_validate_filename(const std::string & filename) {
697+
bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
698698
if (!filename.length()) {
699699
// Empty filename invalid
700700
return false;
@@ -754,10 +754,14 @@ bool fs_validate_filename(const std::string & filename) {
754754
|| (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs
755755
|| c == 0xFFFD // Replacement Character (UTF-8)
756756
|| c == 0xFEFF // Byte Order Mark (BOM)
757-
|| c == '/' || c == '\\' || c == ':' || c == '*' // Illegal characters
757+
|| c == ':' || c == '*' // Illegal characters
758758
|| c == '?' || c == '"' || c == '<' || c == '>' || c == '|') {
759759
return false;
760760
}
761+
if (!allow_subdirs && (c == '/' || c == '\\')) {
762+
// Subdirectories not allowed, reject path separators
763+
return false;
764+
}
761765
}
762766

763767
// Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename
@@ -859,6 +863,11 @@ bool fs_create_directory_with_parents(const std::string & path) {
859863
#endif // _WIN32
860864
}
861865

866+
bool fs_is_directory(const std::string & path) {
867+
std::filesystem::path dir(path);
868+
return std::filesystem::exists(dir) && std::filesystem::is_directory(dir);
869+
}
870+
862871
std::string fs_get_cache_directory() {
863872
std::string cache_directory = "";
864873
auto ensure_trailing_slash = [](std::string p) {

common/common.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,7 @@ struct common_params {
485485
bool log_json = false;
486486

487487
std::string slot_save_path;
488+
std::string media_path; // path to directory for loading media files
488489

489490
float slot_prompt_similarity = 0.1f;
490491

@@ -635,8 +636,9 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
635636
// Filesystem utils
636637
//
637638

638-
bool fs_validate_filename(const std::string & filename);
639+
bool fs_validate_filename(const std::string & filename, bool allow_subdirs = false);
639640
bool fs_create_directory_with_parents(const std::string & path);
641+
bool fs_is_directory(const std::string & path);
640642

641643
std::string fs_get_cache_directory();
642644
std::string fs_get_cache_file(const std::string & filename);

tools/server/server-common.cpp

Lines changed: 64 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include <random>
1313
#include <sstream>
14+
#include <fstream>
1415

1516
json format_error_response(const std::string & message, const enum error_type type) {
1617
std::string type_str;
@@ -774,6 +775,65 @@ json oaicompat_completion_params_parse(const json & body) {
774775
return llama_params;
775776
}
776777

778+
// media_path always end with '/', see arg.cpp
779+
static void handle_media(
780+
std::vector<raw_buffer> & out_files,
781+
json & media_obj,
782+
const std::string & media_path) {
783+
std::string url = json_value(media_obj, "url", std::string());
784+
if (string_starts_with(url, "http")) {
785+
// download remote image
786+
// TODO @ngxson : maybe make these params configurable
787+
common_remote_params params;
788+
params.headers.push_back("User-Agent: llama.cpp/" + build_info);
789+
params.max_size = 1024 * 1024 * 10; // 10MB
790+
params.timeout = 10; // seconds
791+
SRV_INF("downloading image from '%s'\n", url.c_str());
792+
auto res = common_remote_get_content(url, params);
793+
if (200 <= res.first && res.first < 300) {
794+
SRV_INF("downloaded %ld bytes\n", res.second.size());
795+
raw_buffer data;
796+
data.insert(data.end(), res.second.begin(), res.second.end());
797+
out_files.push_back(data);
798+
} else {
799+
throw std::runtime_error("Failed to download image");
800+
}
801+
802+
} else if (string_starts_with(url, "file://")) {
803+
if (media_path.empty()) {
804+
throw std::invalid_argument("file:// URLs are not allowed unless --media-path is specified");
805+
}
806+
// load local image file
807+
std::string file_path = url.substr(7); // remove "file://"
808+
raw_buffer data;
809+
if (!fs_validate_filename(file_path, true)) {
810+
throw std::invalid_argument("file path is not allowed: " + file_path);
811+
}
812+
SRV_INF("loading image from local file '%s'\n", (media_path + file_path).c_str());
813+
std::ifstream file(media_path + file_path, std::ios::binary);
814+
if (!file) {
815+
throw std::invalid_argument("file does not exist or cannot be opened: " + file_path);
816+
}
817+
data.assign((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
818+
out_files.push_back(data);
819+
820+
} else {
821+
// try to decode base64 image
822+
std::vector<std::string> parts = string_split<std::string>(url, /*separator*/ ',');
823+
if (parts.size() != 2) {
824+
throw std::runtime_error("Invalid url value");
825+
} else if (!string_starts_with(parts[0], "data:image/")) {
826+
throw std::runtime_error("Invalid url format: " + parts[0]);
827+
} else if (!string_ends_with(parts[0], "base64")) {
828+
throw std::runtime_error("url must be base64 encoded");
829+
} else {
830+
auto base64_data = parts[1];
831+
auto decoded_data = base64_decode(base64_data);
832+
out_files.push_back(decoded_data);
833+
}
834+
}
835+
}
836+
777837
// used by /chat/completions endpoint
778838
json oaicompat_chat_params_parse(
779839
json & body, /* openai api json semantics */
@@ -860,41 +920,8 @@ json oaicompat_chat_params_parse(
860920
throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
861921
}
862922

863-
json image_url = json_value(p, "image_url", json::object());
864-
std::string url = json_value(image_url, "url", std::string());
865-
if (string_starts_with(url, "http")) {
866-
// download remote image
867-
// TODO @ngxson : maybe make these params configurable
868-
common_remote_params params;
869-
params.headers.push_back("User-Agent: llama.cpp/" + build_info);
870-
params.max_size = 1024 * 1024 * 10; // 10MB
871-
params.timeout = 10; // seconds
872-
SRV_INF("downloading image from '%s'\n", url.c_str());
873-
auto res = common_remote_get_content(url, params);
874-
if (200 <= res.first && res.first < 300) {
875-
SRV_INF("downloaded %ld bytes\n", res.second.size());
876-
raw_buffer data;
877-
data.insert(data.end(), res.second.begin(), res.second.end());
878-
out_files.push_back(data);
879-
} else {
880-
throw std::runtime_error("Failed to download image");
881-
}
882-
883-
} else {
884-
// try to decode base64 image
885-
std::vector<std::string> parts = string_split<std::string>(url, /*separator*/ ',');
886-
if (parts.size() != 2) {
887-
throw std::invalid_argument("Invalid image_url.url value");
888-
} else if (!string_starts_with(parts[0], "data:image/")) {
889-
throw std::invalid_argument("Invalid image_url.url format: " + parts[0]);
890-
} else if (!string_ends_with(parts[0], "base64")) {
891-
throw std::invalid_argument("image_url.url must be base64 encoded");
892-
} else {
893-
auto base64_data = parts[1];
894-
auto decoded_data = base64_decode(base64_data);
895-
out_files.push_back(decoded_data);
896-
}
897-
}
923+
json image_url = json_value(p, "image_url", json::object());
924+
handle_media(out_files, image_url, opt.media_path);
898925

899926
// replace this chunk with a marker
900927
p["type"] = "text";
@@ -916,6 +943,8 @@ json oaicompat_chat_params_parse(
916943
auto decoded_data = base64_decode(data); // expected to be base64 encoded
917944
out_files.push_back(decoded_data);
918945

946+
// TODO: add audio_url support by reusing handle_media()
947+
919948
// replace this chunk with a marker
920949
p["type"] = "text";
921950
p["text"] = mtmd_default_marker();

tools/server/server-common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ struct oaicompat_parser_options {
284284
bool allow_image;
285285
bool allow_audio;
286286
bool enable_thinking = true;
287+
std::string media_path;
287288
};
288289

289290
// used by /chat/completions endpoint

tools/server/server-context.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,7 @@ struct server_context_impl {
788788
/* allow_image */ mctx ? mtmd_support_vision(mctx) : false,
789789
/* allow_audio */ mctx ? mtmd_support_audio (mctx) : false,
790790
/* enable_thinking */ enable_thinking,
791+
/* media_path */ params_base.media_path,
791792
};
792793

793794
// print sample chat example to make it clear which template is used

tools/server/server.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,11 @@ static server_http_context::handler_t ex_wrapper(server_http_context::handler_t
3838
try {
3939
return func(req);
4040
} catch (const std::invalid_argument & e) {
41+
// treat invalid_argument as invalid request (400)
4142
error = ERROR_TYPE_INVALID_REQUEST;
4243
message = e.what();
4344
} catch (const std::exception & e) {
45+
// treat other exceptions as server error (500)
4446
error = ERROR_TYPE_SERVER;
4547
message = e.what();
4648
} catch (...) {

tools/server/tests/unit/test_security.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,34 @@ def test_cors_options(origin: str, cors_header: str, cors_header_value: str):
9494
assert res.status_code == 200
9595
assert cors_header in res.headers
9696
assert res.headers[cors_header] == cors_header_value
97+
98+
99+
@pytest.mark.parametrize(
100+
"media_path, image_url, success",
101+
[
102+
(None, "file://mtmd/test-1.jpeg", False), # disabled media path, should fail
103+
("../../../tools", "file://mtmd/test-1.jpeg", True),
104+
("../../../tools", "file:////mtmd//test-1.jpeg", True), # should be the same file as above
105+
("../../../tools", "file://mtmd/notfound.jpeg", False), # non-existent file
106+
("../../../tools", "file://../mtmd/test-1.jpeg", False), # no directory traversal
107+
]
108+
)
109+
def test_local_media_file(media_path, image_url, success,):
110+
server = ServerPreset.tinygemma3()
111+
server.media_path = media_path
112+
server.start()
113+
res = server.make_request("POST", "/chat/completions", data={
114+
"max_tokens": 1,
115+
"messages": [
116+
{"role": "user", "content": [
117+
{"type": "text", "text": "test"},
118+
{"type": "image_url", "image_url": {
119+
"url": image_url,
120+
}},
121+
]},
122+
],
123+
})
124+
if success:
125+
assert res.status_code == 200
126+
else:
127+
assert res.status_code == 400

tools/server/tests/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ class ServerProcess:
9595
chat_template_file: str | None = None
9696
server_path: str | None = None
9797
mmproj_url: str | None = None
98+
media_path: str | None = None
9899

99100
# session variables
100101
process: subprocess.Popen | None = None
@@ -217,6 +218,8 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
217218
server_args.extend(["--chat-template-file", self.chat_template_file])
218219
if self.mmproj_url:
219220
server_args.extend(["--mmproj-url", self.mmproj_url])
221+
if self.media_path:
222+
server_args.extend(["--media-path", self.media_path])
220223

221224
args = [str(arg) for arg in [server_path, *server_args]]
222225
print(f"tests: starting server with: {' '.join(args)}")

0 commit comments

Comments
 (0)