11package mistral
22
33import (
4+ "encoding/json"
5+
46 "github.com/mutablelogic/go-llm"
7+ "github.com/mutablelogic/go-llm/pkg/tool"
58)
69
710///////////////////////////////////////////////////////////////////////////////
@@ -12,6 +15,21 @@ type Completions []Completion
1215
1316var _ llm.Completion = Completions {}
1417
18+ // Message with text or object content
19+ type Message struct {
20+ RoleContent
21+ ToolCallArray `json:"tool_calls,omitempty"`
22+ }
23+
24+ type RoleContent struct {
25+ Role string `json:"role,omitempty"` // assistant, user, tool, system
26+ Id string `json:"tool_call_id,omitempty"` // tool call - when role is tool
27+ Name string `json:"name,omitempty"` // function name - when role is tool
28+ Content any `json:"content,omitempty"` // string or array of text, reference, image_url
29+ }
30+
31+ var _ llm.Completion = (* Message )(nil )
32+
1533// Completion Variation
1634type Completion struct {
1735 Index uint64 `json:"index"`
@@ -20,23 +38,15 @@ type Completion struct {
2038 Reason string `json:"finish_reason,omitempty"`
2139}
2240
23- // Message with text or object content
24- type Message struct {
25- Role string `json:"role,omitempty"` // assistant, user, tool, system
26- Prefix bool `json:"prefix,omitempty"`
27- Content any `json:"content,omitempty"`
28- ToolCalls `json:"tool_calls,omitempty"`
29- }
30-
3141type Content struct {
32- Type string `json:"type"` // text, reference, image_url
42+ Type string `json:"type,omitempty "` // text, reference, image_url
3343 * Text `json:"text,omitempty"` // text content
3444 * Prediction `json:"content,omitempty"` // prediction
3545 * Image `json:"image_url,omitempty"` // image_url
3646}
3747
3848// A set of tool calls
39- type ToolCalls []ToolCall
49+ type ToolCallArray []ToolCall
4050
4151// text content
4252type Text string
@@ -78,62 +88,87 @@ func NewImageAttachment(a *llm.Attachment) *Content {
7888}
7989
8090///////////////////////////////////////////////////////////////////////////////
81- // PUBLIC METHODS
91+ // PUBLIC METHODS - MESSAGE
8292
83- // Return the number of completions
84- func (c Completions ) Num () int {
85- return len (c )
93+ func (m Message ) Num () int {
94+ return 1
8695}
8796
88- // Return the role of the completion
89- func (c Completions ) Role () string {
90- // The role should be the same for all completions, let's use the first one
91- if len (c ) == 0 {
92- return ""
93- }
94- return c [0 ].Message .Role
97+ func (m Message ) Role () string {
98+ return m .RoleContent .Role
9599}
96100
97- // Return the text content for a specific completion
98- func (c Completions ) Text (index int ) string {
99- if index < 0 || index >= len (c ) {
101+ func (m Message ) Text (index int ) string {
102+ if index != 0 {
100103 return ""
101104 }
102- completion := c [ index ]. Message
103- if text , ok := completion .Content .(string ); ok {
105+ // If content is text, return it
106+ if text , ok := m .Content .(string ); ok {
104107 return text
105108 }
106- // TODO: Will the text be in other forms?
109+ // For other kinds, return empty string for the moment
107110 return ""
108111}
109112
110- // Return the current session tool calls given the completion index.
111- // Will return nil if no tool calls were returned.
112- func (c Completions ) ToolCalls (index int ) []llm.ToolCall {
113- if index < 0 || index >= len (c ) {
114- return nil
115- }
116-
117- // Get the completion
118- completion := c [index ].Message
119- if completion == nil {
113+ func (m Message ) ToolCalls (index int ) []llm.ToolCall {
114+ if index != 0 {
120115 return nil
121116 }
122117
123118 // Make the tool calls
124- calls := make ([]llm.ToolCall , 0 , len (completion .ToolCalls ))
125- for _ , call := range completion .ToolCalls {
126- calls = append (calls , & toolcall {call })
119+ calls := make ([]llm.ToolCall , 0 , len (m .ToolCallArray ))
120+ for _ , call := range m .ToolCallArray {
121+ var args map [string ]any
122+ if call .Function .Arguments != "" {
123+ if err := json .Unmarshal ([]byte (call .Function .Arguments ), & args ); err != nil {
124+ return nil
125+ }
126+ }
127+ calls = append (calls , tool .NewCall (call .Id , call .Function .Name , args ))
127128 }
128129
129130 // Return success
130131 return calls
131132}
132133
134+ ///////////////////////////////////////////////////////////////////////////////
135+ // PUBLIC METHODS - COMPLETIONS
136+
137+ // Return the number of completions
138+ func (c Completions ) Num () int {
139+ return len (c )
140+ }
141+
133142// Return message for a specific completion
134143func (c Completions ) Message (index int ) * Message {
135144 if index < 0 || index >= len (c ) {
136145 return nil
137146 }
138147 return c [index ].Message
139148}
149+
150+ // Return the role of the completion
151+ func (c Completions ) Role () string {
152+ // The role should be the same for all completions, let's use the first one
153+ if len (c ) == 0 {
154+ return ""
155+ }
156+ return c [0 ].Message .Role ()
157+ }
158+
159+ // Return the text content for a specific completion
160+ func (c Completions ) Text (index int ) string {
161+ if index < 0 || index >= len (c ) {
162+ return ""
163+ }
164+ return c [index ].Message .Text (0 )
165+ }
166+
167+ // Return the current session tool calls given the completion index.
168+ // Will return nil if no tool calls were returned.
169+ func (c Completions ) ToolCalls (index int ) []llm.ToolCall {
170+ if index < 0 || index >= len (c ) {
171+ return nil
172+ }
173+ return c [index ].Message .ToolCalls (0 )
174+ }
0 commit comments