Skip to content

Commit 8e025b6

Browse files
agentic form fixes (#123)
1 parent a6a1186 commit 8e025b6

File tree

2 files changed

+57
-38
lines changed

2 files changed

+57
-38
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "solana-agent"
3-
version = "31.1.1"
3+
version = "31.1.2"
44
description = "AI Agents for Solana"
55
authors = ["Bevan Hunt <bevan@bevanhunt.com>"]
66
license = "MIT"

solana_agent/services/query.py

Lines changed: 56 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,9 @@ async def process(
200200

201201
# 7) Captured data context + incremental save using previous assistant message
202202
capture_context = ""
203-
form_complete = False
203+
# Two completion flags:
204+
required_complete = False
205+
form_complete = False # required + optional
204206

205207
# Helpers
206208
def _non_empty(v: Any) -> bool:
@@ -215,7 +217,6 @@ def _non_empty(v: Any) -> bool:
215217

216218
def _parse_numbers_list(s: str) -> List[str]:
217219
nums = re.findall(r"\b(\d+)\b", s)
218-
# dedupe keep order
219220
seen, out = set(), []
220221
for n in nums:
221222
if n not in seen:
@@ -225,23 +226,17 @@ def _parse_numbers_list(s: str) -> List[str]:
225226

226227
def _extract_numbered_options(text: str) -> Dict[str, str]:
227228
"""Parse previous assistant message for lines like:
228-
'1) Foo', '2. Bar', '- 3) Baz', '* 4. Buzz'
229-
Returns mapping '1' -> 'Foo', etc.
230-
"""
229+
'1) Foo', '1. Foo', '- 1) Foo', '* 1. Foo' -> {'1': 'Foo'}"""
231230
options: Dict[str, str] = {}
232231
if not text:
233232
return options
234233
for raw in text.splitlines():
235234
line = raw.strip()
236235
if not line:
237236
continue
238-
# Common Markdown patterns: "1. Label", "1) Label", "- 1) Label", "* 1. Label"
239237
m = re.match(r"^(?:[-*]\s*)?(\d+)[\.)]?\s+(.*)$", line)
240238
if m:
241-
idx, label = m.group(1), m.group(2).strip()
242-
# Strip trailing markdown soft-break spaces
243-
label = label.rstrip()
244-
# Ignore labels that are too short or look like continuations
239+
idx, label = m.group(1), m.group(2).strip().rstrip()
245240
if len(label) >= 1:
246241
options[idx] = label
247242
return options
@@ -252,7 +247,6 @@ def _detect_field_from_prev_question(
252247
if not prev_text or not isinstance(schema, dict):
253248
return None
254249
t = prev_text.lower()
255-
# Heuristic synonyms for your onboarding schema
256250
patterns = [
257251
("ideas", ["which ideas attract you", "ideas"]),
258252
("description", ["please describe yourself", "describe yourself"]),
@@ -269,7 +263,6 @@ def _detect_field_from_prev_question(
269263
for field, keys in patterns:
270264
if field in candidates and any(key in t for key in keys):
271265
return field
272-
# Fallback: property name appears directly
273266
for field in candidates:
274267
if field in t:
275268
return field
@@ -322,33 +315,41 @@ def _detect_field_from_prev_question(
322315
required_fields = list(
323316
(active_capture_schema or {}).get("required", []) or []
324317
)
325-
# Prefer a field detected from prev assistant; else if exactly one required missing, use it
326-
target_field: Optional[str] = _detect_field_from_prev_question(
327-
prev_assistant, active_capture_schema
328-
)
318+
all_fields = list(props.keys())
319+
optional_fields = [
320+
f for f in all_fields if f not in set(required_fields)
321+
]
322+
329323
active_data_existing = (
330324
latest_by_name.get(active_capture_name, {}) or {}
331325
).get("data", {}) or {}
332326

333-
def _missing_required() -> List[str]:
327+
def _missing(fields: List[str]) -> List[str]:
334328
return [
335329
f
336-
for f in required_fields
330+
for f in fields
337331
if not _non_empty(active_data_existing.get(f))
338332
]
339333

334+
missing_required = _missing(required_fields)
335+
missing_optional = _missing(optional_fields)
336+
337+
target_field: Optional[str] = _detect_field_from_prev_question(
338+
prev_assistant, active_capture_schema
339+
)
340340
if not target_field:
341-
missing = _missing_required()
342-
if len(missing) == 1:
343-
target_field = missing[0]
341+
# If exactly one required missing, target it; else if none required missing and exactly one optional missing, target it.
342+
if len(missing_required) == 1:
343+
target_field = missing_required[0]
344+
elif len(missing_required) == 0 and len(missing_optional) == 1:
345+
target_field = missing_optional[0]
344346

345-
if target_field:
347+
if target_field and target_field in props:
346348
f_schema = props.get(target_field, {}) or {}
347349
f_type = f_schema.get("type")
348350
number_to_label = _extract_numbered_options(prev_assistant)
349351

350352
if number_to_label:
351-
# Map any numbers in user's reply to their labels
352353
nums = _parse_numbers_list(user_text)
353354
labels = [
354355
number_to_label[n] for n in nums if n in number_to_label
@@ -359,7 +360,6 @@ def _missing_required() -> List[str]:
359360
else:
360361
incremental[target_field] = labels[0]
361362

362-
# If we didn't map via options, fallback to type-based parse
363363
if target_field not in incremental:
364364
if f_type == "number":
365365
m = re.search(r"\b([0-9]+(?:\.[0-9]+)?)\b", user_text)
@@ -369,19 +369,17 @@ def _missing_required() -> List[str]:
369369
except Exception:
370370
pass
371371
elif f_type == "array":
372-
# Accept CSV-style input as array of strings
373372
parts = [
374373
p.strip()
375374
for p in re.split(r"[,\n;]+", user_text)
376375
if p.strip()
377376
]
378377
if parts:
379378
incremental[target_field] = parts
380-
else: # string/default
379+
else:
381380
if user_text.strip():
382381
incremental[target_field] = user_text.strip()
383382

384-
# Filter out empty junk and save
385383
if incremental:
386384
cleaned = {
387385
k: v for k, v in incremental.items() if _non_empty(v)
@@ -397,6 +395,7 @@ def _missing_required() -> List[str]:
397395
)
398396
except Exception as se:
399397
logger.error(f"Error saving incremental capture: {se}")
398+
400399
except Exception as e:
401400
logger.debug(f"Incremental extraction skipped: {e}")
402401

@@ -411,19 +410,33 @@ def _get_active_data(name: Optional[str]) -> Dict[str, Any]:
411410

412411
lines: List[str] = []
413412
if active_capture_name and isinstance(active_capture_schema, dict):
414-
active_data = _get_active_data(active_capture_name)
413+
props = (active_capture_schema or {}).get("properties", {})
415414
required_fields = list(
416415
(active_capture_schema or {}).get("required", []) or []
417416
)
418-
missing = [
419-
f for f in required_fields if not _non_empty(active_data.get(f))
417+
all_fields = list(props.keys())
418+
optional_fields = [
419+
f for f in all_fields if f not in set(required_fields)
420420
]
421-
form_complete = len(missing) == 0 and len(required_fields) > 0
421+
422+
active_data = _get_active_data(active_capture_name)
423+
424+
def _missing_from(data: Dict[str, Any], fields: List[str]) -> List[str]:
425+
return [f for f in fields if not _non_empty(data.get(f))]
426+
427+
missing_required = _missing_from(active_data, required_fields)
428+
missing_optional = _missing_from(active_data, optional_fields)
429+
430+
required_complete = (
431+
len(missing_required) == 0 and len(required_fields) > 0
432+
)
433+
form_complete = required_complete and len(missing_optional) == 0
422434

423435
lines.append(
424436
"CAPTURED FORM STATE (Authoritative; do not re-ask filled values):"
425437
)
426438
lines.append(f"- form_name: {active_capture_name}")
439+
427440
if active_data:
428441
pairs = [
429442
f"{k}: {v}" for k, v in active_data.items() if _non_empty(v)
@@ -433,8 +446,12 @@ def _get_active_data(name: Optional[str]) -> Dict[str, Any]:
433446
)
434447
else:
435448
lines.append("- filled_fields: (none)")
449+
450+
lines.append(
451+
f"- missing_required_fields: {', '.join(missing_required) if missing_required else '(none)'}"
452+
)
436453
lines.append(
437-
f"- missing_required_fields: {', '.join(missing) if missing else '(none)'}"
454+
f"- missing_optional_fields: {', '.join(missing_optional) if missing_optional else '(none)'}"
438455
)
439456
lines.append("")
440457

@@ -455,7 +472,7 @@ def _get_active_data(name: Optional[str]) -> Dict[str, Any]:
455472
if lines:
456473
capture_context = "\n".join(lines) + "\n\n"
457474

458-
# Merge contexts
475+
# Merge contexts + flow rules
459476
combined_context = ""
460477
if capture_context:
461478
combined_context += capture_context
@@ -470,9 +487,11 @@ def _get_active_data(name: Optional[str]) -> Dict[str, Any]:
470487
"- Prefer KB/tools for facts.\n"
471488
"- History is for tone and continuity.\n\n"
472489
"FORM FLOW RULES:\n"
473-
"- Ask exactly one missing required field per turn.\n"
490+
"- Ask exactly one field per turn.\n"
491+
"- If any required fields are missing, ask the next missing required field.\n"
492+
"- If all required fields are filled but optional fields are missing, ask the next missing optional field.\n"
474493
"- Do NOT re-ask or verify values present in Captured User Data (auto-saved, authoritative).\n"
475-
"- If no required fields are missing, proceed without further capture questions.\n\n"
494+
"- Do NOT provide summaries until no required or optional fields are missing.\n\n"
476495
)
477496

478497
# 8) Generate response
@@ -510,7 +529,7 @@ def _get_active_data(name: Optional[str]) -> Dict[str, Any]:
510529
except Exception:
511530
pass
512531

513-
# If form is complete, ask for structured output JSON
532+
# Only run final structured output when no required or optional fields are missing
514533
if capture_schema and capture_name and form_complete:
515534
try:
516535
DynamicModel = self._build_model_from_json_schema(
@@ -739,5 +758,5 @@ def py_type(js: Dict[str, Any]):
739758
else:
740759
fields[field_name] = (typ, default)
741760

742-
Model = create_model(name, **fields) # type: ignore
761+
Model = create_model(name, **fields)
743762
return Model

0 commit comments

Comments
 (0)