Skip to content

Commit da8b3a5

Browse files
authored
Merge branch 'ikawrakow:main' into main
2 parents 14e10f5 + 507f3a4 commit da8b3a5

23 files changed

+1086
-548
lines changed

common/common.cpp

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,30 @@ static std::string parse_device_list(const std::string& value) {
270270
return value;
271271
}
272272

273+
static std::string add_rpc_devices(std::string& servers) {
274+
std::string rpc_devices;
275+
#ifdef GGML_USE_RPC
276+
std::vector<std::string> rpc_servers = string_split(servers, ",");
277+
if (rpc_servers.empty()) {
278+
throw std::invalid_argument("no RPC servers specified");
279+
}
280+
for (auto& server : rpc_servers) {
281+
uint32_t dev_count = ggml_backend_rpc_get_device_count(server.c_str());
282+
uint32_t device = 0;
283+
for (uint32_t i = 0; i < dev_count; ++i) {
284+
const auto buft = ggml_backend_rpc_buffer_type(server.c_str(), device);
285+
if (buft != nullptr) {
286+
rpc_devices = rpc_devices + server + "|" + std::to_string(device) + ",";
287+
++device;
288+
}
289+
}
290+
}
291+
if (!rpc_devices.empty()) {
292+
rpc_devices = rpc_devices.substr(0, rpc_devices.size() - 1); // remove trailing comma
293+
}
294+
#endif
295+
return rpc_devices;
296+
}
273297

274298
std::pair<long, std::vector<char>> common_remote_get_content(const std::string& url, const common_remote_params&) {
275299
if (!url.empty()) {
@@ -1296,15 +1320,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
12961320
if (arg == "--rpc") {
12971321
CHECK_ARG
12981322
#ifdef GGML_USE_RPC
1299-
params.rpc_servers = argv[i];
1300-
std::string servers(params.rpc_servers);
1301-
size_t pos = 0;
1302-
while ((pos = servers.find(",")) != std::string::npos) {
1303-
std::string server = servers.substr(0, pos);
1304-
ggml_backend_rpc_buffer_type(server.c_str());
1305-
servers.erase(0, pos + 1);
1323+
std::string servers(argv[i]);
1324+
servers = add_rpc_devices(servers);
1325+
if (servers.empty()) {
1326+
return false;
13061327
}
1307-
ggml_backend_rpc_buffer_type(servers.c_str());
1328+
params.rpc_servers = servers;
13081329
#endif
13091330
return true;
13101331
}
@@ -1319,10 +1340,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
13191340
}
13201341
if (arg == "--override-tensor" || arg == "-ot") {
13211342
CHECK_ARG
1322-
/*for (auto endpoint : params.rpc_servers.split)
1323-
{
1324-
1325-
}*/
13261343
if (!parse_buft_overrides(std::string{ argv[i] }, params.tensor_buft_overrides)) {
13271344
fprintf(stderr, "error: Invalid tensor buffer type override: %s\n", argv[i]);
13281345
invalid_param = true;

common/grammar-parser.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,9 @@ namespace grammar_parser {
369369
}
370370
// Validate the state to ensure that all rules are defined
371371
for (const auto & rule : state.rules) {
372+
if (rule.empty()) {
373+
throw std::runtime_error("Undefined rule");
374+
}
372375
for (const auto & elem : rule) {
373376
if (elem.type == LLAMA_GRETYPE_RULE_REF) {
374377
// Ensure that the rule at that location exists

common/json-schema-to-grammar.cpp

Lines changed: 60 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "json-schema-to-grammar.h"
2+
#include "common.h"
23
#include <algorithm>
34
#include <fstream>
45
#include <map>
@@ -19,6 +20,9 @@ static std::string repeat(const std::string & str, size_t n);
1920
static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") {
2021
auto has_max = max_items != std::numeric_limits<int>::max();
2122

23+
if (max_items == 0) {
24+
return "";
25+
}
2226
if (min_items == 0 && max_items == 1) {
2327
return item_rule + "?";
2428
}
@@ -40,52 +44,9 @@ static std::string build_repetition(const std::string & item_rule, int min_items
4044
return result;
4145
}
4246

43-
/* Minimalistic replacement for std::string_view, which is only available from C++17 onwards */
44-
class string_view {
45-
const std::string & _str;
46-
const size_t _start;
47-
const size_t _end;
48-
public:
49-
string_view(const std::string & str, size_t start = 0, size_t end = std::string::npos) : _str(str), _start(start), _end(end == std::string::npos ? str.length() : end) {}
50-
51-
size_t size() const {
52-
return _end - _start;
53-
}
54-
55-
size_t length() const {
56-
return size();
57-
}
58-
59-
operator std::string() const {
60-
return str();
61-
}
62-
63-
std::string str() const {
64-
return _str.substr(_start, _end - _start);
65-
}
66-
67-
string_view substr(size_t pos, size_t len = std::string::npos) const {
68-
return string_view(_str, _start + pos, len == std::string::npos ? _end : _start + pos + len);
69-
}
70-
71-
char operator[](size_t pos) const {
72-
auto index = _start + pos;
73-
if (index >= _end) {
74-
throw std::out_of_range("string_view index out of range");
75-
}
76-
return _str[_start + pos];
77-
}
78-
79-
bool operator==(const string_view & other) const {
80-
std::string this_str = *this;
81-
std::string other_str = other;
82-
return this_str == other_str;
83-
}
84-
};
85-
86-
static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
87-
auto has_min = min_value != std::numeric_limits<int>::min();
88-
auto has_max = max_value != std::numeric_limits<int>::max();
47+
static void _build_min_max_int(int64_t min_value, int64_t max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
48+
auto has_min = min_value != std::numeric_limits<int64_t>::min();
49+
auto has_max = max_value != std::numeric_limits<int64_t>::max();
8950

9051
auto digit_range = [&](char from, char to) {
9152
out << "[";
@@ -111,14 +72,14 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
11172
}
11273
out << "}";
11374
};
114-
std::function<void(const string_view &, const string_view &)> uniform_range =
115-
[&](const string_view & from, const string_view & to) {
75+
std::function<void(const std::string_view &, const std::string_view &)> uniform_range =
76+
[&](const std::string_view & from, const std::string_view & to) {
11677
size_t i = 0;
11778
while (i < from.length() && i < to.length() && from[i] == to[i]) {
11879
i++;
11980
}
12081
if (i > 0) {
121-
out << "\"" << from.substr(0, i).str() << "\"";
82+
out << "\"" << from.substr(0, i) << "\"";
12283
}
12384
if (i < from.length() && i < to.length()) {
12485
if (i > 0) {
@@ -201,7 +162,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
201162
if (has_min) {
202163
if (min_value < 0) {
203164
out << "\"-\" (";
204-
_build_min_max_int(std::numeric_limits<int>::min(), -min_value, out, decimals_left, /* top_level= */ false);
165+
_build_min_max_int(std::numeric_limits<int64_t>::min(), -min_value, out, decimals_left, /* top_level= */ false);
205166
out << ") | [0] | [1-9] ";
206167
more_digits(0, decimals_left - 1);
207168
} else if (min_value == 0) {
@@ -236,7 +197,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
236197
}
237198
digit_range(c, c);
238199
out << " (";
239-
_build_min_max_int(std::stoi(min_s.substr(1)), std::numeric_limits<int>::max(), out, less_decimals, /* top_level= */ false);
200+
_build_min_max_int(std::stoll(min_s.substr(1)), std::numeric_limits<int64_t>::max(), out, less_decimals, /* top_level= */ false);
240201
out << ")";
241202
if (c < '9') {
242203
out << " | ";
@@ -258,7 +219,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
258219
_build_min_max_int(0, max_value, out, decimals_left, /* top_level= */ true);
259220
} else {
260221
out << "\"-\" (";
261-
_build_min_max_int(-max_value, std::numeric_limits<int>::max(), out, decimals_left, /* top_level= */ false);
222+
_build_min_max_int(-max_value, std::numeric_limits<int64_t>::max(), out, decimals_left, /* top_level= */ false);
262223
out << ")";
263224
}
264225
return;
@@ -615,7 +576,7 @@ class SchemaConverter {
615576
}
616577
return join_seq();
617578
};
618-
return _add_rule(name, "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space");
579+
return _add_rule(name, "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space");
619580
}
620581

621582
/*
@@ -688,7 +649,10 @@ class SchemaConverter {
688649
}
689650

690651
std::string _resolve_ref(const std::string & ref) {
691-
std::string ref_name = ref.substr(ref.find_last_of('/') + 1);
652+
auto it = ref.find('#');
653+
std::string ref_fragment = it != std::string::npos ? ref.substr(it + 1) : ref;
654+
static const std::regex nonalphanumeric_regex(R"([^a-zA-Z0-9-]+)");
655+
std::string ref_name = "ref" + std::regex_replace(ref_fragment, nonalphanumeric_regex, "-");
692656
if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) {
693657
_refs_being_resolved.insert(ref);
694658
json resolved = _refs[ref];
@@ -861,11 +825,24 @@ class SchemaConverter {
861825
std::vector<std::string> tokens = split(pointer, "/");
862826
for (size_t i = 1; i < tokens.size(); ++i) {
863827
std::string sel = tokens[i];
864-
if (target.is_null() || !target.contains(sel)) {
828+
if (target.is_object() && target.contains(sel)) {
829+
target = target[sel];
830+
} else if (target.is_array()) {
831+
size_t sel_index;
832+
try {
833+
sel_index = std::stoul(sel);
834+
} catch (const std::invalid_argument & e) {
835+
sel_index = target.size();
836+
}
837+
if (sel_index >= target.size()) {
838+
_errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
839+
return;
840+
}
841+
target = target[sel_index];
842+
} else {
865843
_errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
866844
return;
867845
}
868-
target = target[sel];
869846
}
870847
_refs[ref] = target;
871848
}
@@ -931,9 +908,10 @@ class SchemaConverter {
931908
_build_object_rule(
932909
properties, required, name,
933910
schema.contains("additionalProperties") ? schema["additionalProperties"] : json()));
934-
} else if ((schema_type.is_null() || schema_type == "object") && schema.contains("allOf")) {
911+
} else if ((schema_type.is_null() || schema_type == "object" || schema_type == "string") && schema.contains("allOf")) {
935912
std::unordered_set<std::string> required;
936913
std::vector<std::pair<std::string, json>> properties;
914+
std::map<std::string, size_t> enum_values;
937915
std::string hybrid_name = name;
938916
std::function<void(const json &, bool)> add_component = [&](const json & comp_schema, bool is_required) {
939917
if (comp_schema.contains("$ref")) {
@@ -945,6 +923,14 @@ class SchemaConverter {
945923
required.insert(prop.key());
946924
}
947925
}
926+
} else if (comp_schema.contains("enum")) {
927+
for (const auto & v : comp_schema["enum"]) {
928+
const auto rule = _generate_constant_rule(v);
929+
if (enum_values.find(rule) == enum_values.end()) {
930+
enum_values[rule] = 0;
931+
}
932+
enum_values[rule] += 1;
933+
}
948934
} else {
949935
// todo warning
950936
}
@@ -958,6 +944,17 @@ class SchemaConverter {
958944
add_component(t, true);
959945
}
960946
}
947+
if (!enum_values.empty()) {
948+
std::vector<std::string> enum_intersection;
949+
for (const auto & p : enum_values) {
950+
if (p.second == schema["allOf"].size()) {
951+
enum_intersection.push_back(p.first);
952+
}
953+
}
954+
if (!enum_intersection.empty()) {
955+
return _add_rule(rule_name, "(" + string_join(enum_intersection, " | ") + ") space");
956+
}
957+
}
961958
return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json()));
962959
} else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) {
963960
json items = schema.contains("items") ? schema["items"] : schema["prefixItems"];
@@ -992,17 +989,17 @@ class SchemaConverter {
992989
int max_len = schema.contains("maxLength") ? schema["maxLength"].get<int>() : std::numeric_limits<int>::max();
993990
return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space");
994991
} else if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) {
995-
int min_value = std::numeric_limits<int>::min();
996-
int max_value = std::numeric_limits<int>::max();
992+
int64_t min_value = std::numeric_limits<int64_t>::min();
993+
int64_t max_value = std::numeric_limits<int64_t>::max();
997994
if (schema.contains("minimum")) {
998-
min_value = schema["minimum"].get<int>();
995+
min_value = schema["minimum"].get<int64_t>();
999996
} else if (schema.contains("exclusiveMinimum")) {
1000-
min_value = schema["exclusiveMinimum"].get<int>() + 1;
997+
min_value = schema["exclusiveMinimum"].get<int64_t>() + 1;
1001998
}
1002999
if (schema.contains("maximum")) {
1003-
max_value = schema["maximum"].get<int>();
1000+
max_value = schema["maximum"].get<int64_t>();
10041001
} else if (schema.contains("exclusiveMaximum")) {
1005-
max_value = schema["exclusiveMaximum"].get<int>() - 1;
1002+
max_value = schema["exclusiveMaximum"].get<int64_t>() - 1;
10061003
}
10071004
std::stringstream out;
10081005
out << "(";

common/sampling.cpp

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_vocab* vo
2222
#endif // LLAMA_USE_LLGUIDANCE
2323
}
2424
else {
25-
2625
std::vector<std::string> trigger_patterns;
2726
std::vector<std::string> patterns_anywhere;
2827
std::vector<llama_token> trigger_tokens;
@@ -70,30 +69,34 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_vocab* vo
7069
trigger_tokens.data(), trigger_tokens.size())
7170
: llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
7271

73-
// if there is a grammar, parse it
74-
if (!params.grammar.empty()) {
75-
result->parsed_grammar = grammar_parser::parse(params.grammar.c_str());
72+
//if (!grmr) {
73+
// return nullptr;
74+
//}
75+
76+
// if there is a grammar, parse it
77+
if (!params.grammar.empty()) {
78+
result->parsed_grammar = grammar_parser::parse(params.grammar.c_str());
7679
if (result->parsed_grammar.success) {
77-
// will be empty (default) if there are parse errors
78-
if (result->parsed_grammar.rules.empty()) {
79-
fprintf(stderr, "%s: failed to parse grammar\n", __func__);
80-
delete result;
81-
return nullptr;
82-
}
80+
// will be empty (default) if there are parse errors
81+
if (result->parsed_grammar.rules.empty()) {
82+
fprintf(stderr, "%s: failed to parse grammar\n", __func__);
83+
delete result;
84+
return nullptr;
85+
}
8386

84-
// Ensure that there is a "root" node.
85-
if (result->parsed_grammar.symbol_ids.find("root") == result->parsed_grammar.symbol_ids.end()) {
86-
fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__);
87-
delete result;
88-
return nullptr;
89-
}
87+
// Ensure that there is a "root" node.
88+
if (result->parsed_grammar.symbol_ids.find("root") == result->parsed_grammar.symbol_ids.end()) {
89+
fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__);
90+
delete result;
91+
return nullptr;
92+
}
9093
if (grmr == nullptr) {
91-
throw std::runtime_error("Failed to initialize llama_grammar");
92-
}
93-
}
94+
throw std::runtime_error("Failed to initialize llama_grammar");
95+
}
96+
}
9497
}
95-
result->prev.resize(params.n_prev);
96-
result->n_valid = 0;
98+
result->prev.resize(params.n_prev);
99+
result->n_valid = 0;
97100
}
98101
result->grammar = grmr;
99102
// init DRY

examples/gbnf-validator/gbnf-validator.cpp

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,14 @@
1313
#include <vector>
1414

1515
static bool llama_sample_grammar_string(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) {
16-
auto decoded = decode_utf8(input_str, {});
17-
const auto & code_points = decoded.first;
18-
19-
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
20-
llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
21-
16+
const auto cpts = unicode_cpts_from_utf8(input_str);
17+
auto& cur_stacks = llama_grammar_get_stacks(grammar);
2218
size_t pos = 0;
23-
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
24-
const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy
25-
26-
llama_grammar_accept(rules, prev_stacks, *it, cur_stacks);
27-
19+
for (const auto& cpt : cpts) {
20+
llama_grammar_accept(grammar, cpt);
2821
if (cur_stacks.empty()) {
2922
error_pos = pos;
30-
error_msg = "Unexpected character '" + unicode_cpt_to_utf8(*it) + "'";
31-
cur_stacks = prev_stacks;
23+
error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'";
3224
return false;
3325
}
3426
++pos;

0 commit comments

Comments
 (0)