Skip to content

Commit 8c0c13f

Browse files
fix(ace): Ensure only input fields are passed to curator (#456)
The AxACE optimizer was incorrectly passing the entire example, including output fields, to the curator's `question_context`. This caused the curator to generate a playbook that confused inputs and outputs. This commit fixes the issue by filtering the example object against the input fields of the program's signature before serializing it to `question_context`.

Additionally, this commit:
- Adds a unit test to verify that the `question_context` is correctly formed.
- Removes several `console.log` statements from the test suite.
- Refactors the test suite to use the `flow()` factory function instead of the deprecated `new AxFlow()` constructor.

Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
1 parent 2f3e6ac commit 8c0c13f

File tree

8 files changed

+282
-232
lines changed

8 files changed

+282
-232
lines changed

src/ax/dsp/optimizer.test.ts

Lines changed: 46 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,24 +1,66 @@
1-
import { describe, expect, it } from 'vitest';
1+
import { describe, expect, it, vi } from 'vitest';
22

33
import type { AxAIService } from '../ai/types.js';
44

55
import type { AxOptimizer } from './optimizer.js';
66
import { AxBootstrapFewShot } from './optimizers/bootstrapFewshot.js';
7+
import { AxACE } from './optimizers/ace.js';
78
import { AxMiPRO } from './optimizers/miproV2.js';
9+
import { f } from './sig.js';
10+
import { ax } from './template.js';
811

912
// Mock dependencies
1013
const mockAI = {
1114
name: 'mock',
12-
chat: async () => ({ results: [{ index: 0, content: 'mock response' }] }),
15+
chat: async () => ({
16+
results: [
17+
{
18+
index: 0,
19+
content: JSON.stringify({
20+
answer: 'mock student response',
21+
}),
22+
},
23+
],
24+
}),
25+
getOptions: () => ({ tracer: undefined }),
26+
getLogger: () => undefined,
1327
} as unknown as AxAIService;
1428

15-
// Removed unused mockProgram
16-
1729
const mockExamples = [
1830
{ input: 'test input', output: 'test output' },
1931
{ input: 'test input 2', output: 'test output 2' },
2032
];
2133

34+
describe('AxACE', () => {
35+
it('should call runCurator with the correct context', async () => {
36+
const ace = new AxACE({
37+
studentAI: mockAI,
38+
});
39+
40+
const runCuratorSpy = vi
41+
.spyOn(ace as any, 'runCurator')
42+
.mockResolvedValue(undefined);
43+
44+
const program = ax(
45+
f().input('question', f.string()).output('answer', f.string()).build()
46+
);
47+
const examples = [
48+
{ question: 'q1', answer: 'a1' },
49+
{ question: 'q2', answer: 'a2' },
50+
];
51+
const metricFn = () => 1;
52+
53+
await ace.compile(program, examples, metricFn, {
54+
aceOptions: { maxEpochs: 1, maxReflectorRounds: 1 },
55+
});
56+
57+
expect(runCuratorSpy).toHaveBeenCalled();
58+
const curatorCall = runCuratorSpy.mock.calls[0][0];
59+
expect(curatorCall.example).toHaveProperty('question');
60+
expect(curatorCall.example).toHaveProperty('answer');
61+
});
62+
});
63+
2264
describe('Optimizer Interface', () => {
2365
it('AxBootstrapFewShot implements AxOptimizer interface', () => {
2466
const optimizer = new AxBootstrapFewShot({

src/ax/dsp/optimizers/ace.test.ts

Lines changed: 58 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,8 @@
1-
import { describe, expect, it } from 'vitest';
1+
import { describe, expect, it, vi } from 'vitest';
2+
3+
import { ax } from '../template.js';
4+
import { f } from '../sig.js';
5+
import type { AxAIService } from '../../ai/types';
26

37
import { applyCuratorOperations, createEmptyPlaybook } from './acePlaybook.js';
48
import { AxACE } from './ace.js';
@@ -128,3 +132,56 @@ describe('AxACE helpers', () => {
128132
expect(newBullet?.content).toBe('Third tactic');
129133
});
130134
});
135+
136+
describe('AxACE', () => {
137+
it('runCurator should only receive input fields in question_context', async () => {
138+
const mockCuratorAI = {
139+
name: 'mockCurator',
140+
chat: vi.fn().mockResolvedValue({
141+
results: [
142+
{
143+
index: 0,
144+
content: '{"reasoning": "mock", "operations":[]}',
145+
},
146+
],
147+
}),
148+
getOptions: () => ({ tracer: undefined }),
149+
getLogger: () => undefined,
150+
} as unknown as AxAIService;
151+
152+
const program = ax(
153+
f().input('question', f.string()).output('answer', f.string()).build()
154+
);
155+
156+
const example = {
157+
question: 'This is the input',
158+
answer: 'This is the output',
159+
};
160+
161+
const ace = new AxACE({
162+
studentAI: {} as any,
163+
teacherAI: mockCuratorAI,
164+
});
165+
166+
const curatorProgram = (ace as any).getOrCreateCuratorProgram();
167+
const forwardSpy = vi.spyOn(curatorProgram, 'forward');
168+
169+
// Directly call the internal runCurator method for a focused unit test
170+
await (ace as any).runCurator({
171+
program,
172+
example,
173+
reflection: { keyInsight: 'test' }, // Minimal reflection to trigger curator
174+
playbook: { sections: {}, stats: { bulletCount: 0 } },
175+
});
176+
177+
expect(forwardSpy).toHaveBeenCalled();
178+
179+
const forwardArgs = forwardSpy.mock.calls[0][1] as any;
180+
const receivedContext = JSON.parse(forwardArgs.question_context);
181+
182+
expect(receivedContext).toBeDefined();
183+
expect(receivedContext).toHaveProperty('question');
184+
expect(receivedContext.question).toBe('This is the input');
185+
expect(receivedContext).not.toHaveProperty('answer');
186+
});
187+
});

src/ax/dsp/optimizers/ace.ts

Lines changed: 27 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -148,6 +148,8 @@ export class AxACE extends AxBaseOptimizer {
148148

149149
private curatorProgram?: AxGen<any, any>;
150150

151+
private program?: Readonly<AxGen<any, any>>;
152+
151153
constructor(
152154
args: Readonly<AxOptimizerArgs>,
153155
options?: Readonly<AxACEOptions>
@@ -220,6 +222,7 @@ export class AxACE extends AxBaseOptimizer {
220222

221223
const startTime = Date.now();
222224
this.validateExamples(examples);
225+
this.program = program;
223226

224227
const baseInstruction = await this.extractProgramInstruction(program);
225228
const originalDescription = program.getSignature().getDescription() ?? '';
@@ -287,6 +290,7 @@ export class AxACE extends AxBaseOptimizer {
287290
});
288291

289292
const rawCurator = await this.runCurator({
293+
program,
290294
example,
291295
reflection,
292296
playbook: this.playbook,
@@ -458,6 +462,12 @@ export class AxACE extends AxBaseOptimizer {
458462
feedback?: string;
459463
}>
460464
): Promise<AxACECuratorOutput | undefined> {
465+
if (!this.program) {
466+
throw new Error(
467+
'AxACE: `compile` must be run before `applyOnlineUpdate`'
468+
);
469+
}
470+
461471
const generatorOutput = this.createGeneratorOutput(
462472
args.prediction,
463473
args.example
@@ -480,6 +490,7 @@ export class AxACE extends AxBaseOptimizer {
480490
});
481491

482492
const rawCurator = await this.runCurator({
493+
program: this.program,
483494
example: args.example,
484495
reflection,
485496
playbook: this.playbook,
@@ -992,11 +1003,13 @@ export class AxACE extends AxBaseOptimizer {
9921003
}
9931004
}
9941005

995-
private async runCurator({
1006+
private async runCurator<IN, OUT extends AxGenOut>({
1007+
program,
9961008
example,
9971009
reflection,
9981010
playbook,
9991011
}: Readonly<{
1012+
program: Readonly<AxGen<IN, OUT>>;
10001013
example: AxExample;
10011014
reflection?: AxACEReflectionOutput;
10021015
playbook: AxACEPlaybook;
@@ -1008,14 +1021,26 @@ export class AxACE extends AxBaseOptimizer {
10081021
const curator = this.getOrCreateCuratorProgram();
10091022
const curatorAI = this.teacherAI ?? this.studentAI;
10101023

1024+
const signature = program.getSignature();
1025+
const inputFields = signature.getInputFields();
1026+
const questionContext = inputFields.reduce(
1027+
(acc, field) => {
1028+
if (field.name in example) {
1029+
acc[field.name] = example[field.name as keyof typeof example];
1030+
}
1031+
return acc;
1032+
},
1033+
{} as Record<string, unknown>
1034+
);
1035+
10111036
try {
10121037
const outputRaw = await curator.forward(curatorAI, {
10131038
playbook: JSON.stringify({
10141039
markdown: renderPlaybook(playbook),
10151040
structured: playbook,
10161041
}),
10171042
reflection: JSON.stringify(reflection),
1018-
question_context: JSON.stringify(example),
1043+
question_context: JSON.stringify(questionContext),
10191044
token_budget: 1024,
10201045
});
10211046

src/ax/dsp/sigtypes-runtime.test.ts

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -69,10 +69,6 @@ describe('TypeScript Parser Parity with JS Parser', () => {
6969
expect(parsed.outputs.length).toBeGreaterThan(0);
7070
}).not.toThrow(`JS Parser failed for: "${signature}"`);
7171
}
72-
73-
console.log(
74-
`✅ JS Parser successfully handled ${testCases.length} test cases`
75-
);
7672
});
7773

7874
test('should demonstrate TypeScript type inference capabilities', () => {
@@ -97,7 +93,5 @@ describe('TypeScript Parser Parity with JS Parser', () => {
9793
};
9894

9995
expect(simpleClassExample.outputs.category).toBe('positive');
100-
101-
console.log('✅ TypeScript type inference working for supported cases');
10296
});
10397
});

src/ax/flow/flow-extensions.test.ts

Lines changed: 13 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,11 @@
11
import { describe, expect, it } from 'vitest';
22

3-
import { AxFlow } from './flow.js';
3+
import { flow } from './flow.js';
44
import { f } from '../dsp/sig.js';
55

66
describe('AxFlow nodeExtended method', () => {
77
it('should create chain-of-thought node with internal reasoning field', () => {
8-
const flow = new AxFlow();
9-
const cotFlow = flow.nodeExtended(
8+
const cotFlow = flow().nodeExtended(
109
'reasoner',
1110
'userInput:string -> answer:string',
1211
{
@@ -32,8 +31,7 @@ describe('AxFlow nodeExtended method', () => {
3231
});
3332

3433
it('should create confidence-scoring node', () => {
35-
const flow = new AxFlow();
36-
const confFlow = flow.nodeExtended(
34+
const confFlow = flow().nodeExtended(
3735
'scorer',
3836
'userInput:string -> analysis:string',
3937
{
@@ -56,8 +54,7 @@ describe('AxFlow nodeExtended method', () => {
5654
});
5755

5856
it('should create contextual node with additional input fields', () => {
59-
const flow = new AxFlow();
60-
const contextFlow = flow.nodeExtended(
57+
const contextFlow = flow().nodeExtended(
6158
'contextual',
6259
'question:string -> answer:string',
6360
{
@@ -82,8 +79,7 @@ describe('AxFlow nodeExtended method', () => {
8279
});
8380

8481
it('should create extended node with all extension types', () => {
85-
const flow = new AxFlow();
86-
const extendedFlow = flow.nodeExtended(
82+
const extendedFlow = flow().nodeExtended(
8783
'analyzer',
8884
'userInput:string -> analysis:string',
8985
{
@@ -130,10 +126,8 @@ describe('AxFlow nodeExtended method', () => {
130126
});
131127

132128
it('should maintain type safety and prevent duplicate field names', () => {
133-
const flow = new AxFlow();
134-
135129
expect(() =>
136-
flow.nodeExtended('test', 'userInput:string -> analysis:string', {
130+
flow().nodeExtended('test', 'userInput:string -> analysis:string', {
137131
appendInputs: [
138132
{
139133
name: 'userInput',
@@ -145,10 +139,10 @@ describe('AxFlow nodeExtended method', () => {
145139
});
146140

147141
it('should work with AxSignature instances as base', () => {
148-
const flow = new AxFlow();
149-
const baseSig = flow.getSignature(); // Get default signature
142+
const myFlow = flow();
143+
const baseSig = myFlow.getSignature(); // Get default signature
150144

151-
const extendedFlow = flow.nodeExtended('thinker', baseSig, {
145+
const extendedFlow = myFlow.nodeExtended('thinker', baseSig, {
152146
prependOutputs: [
153147
{
154148
name: 'reasoning',
@@ -165,9 +159,7 @@ describe('AxFlow nodeExtended method', () => {
165159
});
166160

167161
it('should support method chaining', () => {
168-
const flow = new AxFlow();
169-
170-
const chainedFlow = flow
162+
const chainedFlow = flow()
171163
.nodeExtended('reasoner', 'question:string -> analysis:string', {
172164
prependOutputs: [
173165
{
@@ -188,11 +180,9 @@ describe('AxFlow nodeExtended method', () => {
188180
});
189181

190182
it('should validate field types according to input/output rules', () => {
191-
const flow = new AxFlow();
192-
193183
// Class types not allowed in input
194184
expect(() =>
195-
flow.nodeExtended('test', 'userInput:string -> analysis:string', {
185+
flow().nodeExtended('test', 'userInput:string -> analysis:string', {
196186
appendInputs: [
197187
{
198188
name: 'category',
@@ -204,7 +194,7 @@ describe('AxFlow nodeExtended method', () => {
204194

205195
// Image types not allowed in output
206196
expect(() =>
207-
flow.nodeExtended('test', 'userInput:string -> analysis:string', {
197+
flow().nodeExtended('test', 'userInput:string -> analysis:string', {
208198
appendOutputs: [
209199
{
210200
name: 'outputImage',
@@ -216,10 +206,8 @@ describe('AxFlow nodeExtended method', () => {
216206
});
217207

218208
it('should have nx alias that works identically to nodeExtended', () => {
219-
const flow = new AxFlow();
220-
221209
// Test nx alias with same functionality as nodeExtended
222-
const nxFlow = flow.nx('reasoner', 'userInput:string -> answer:string', {
210+
const nxFlow = flow().nx('reasoner', 'userInput:string -> answer:string', {
223211
prependOutputs: [
224212
{
225213
name: 'reasoning',

0 commit comments

Comments
 (0)