From 094c26162e6841fc52f1aa7f0a4a03341427564c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zolt=C3=A1n=20Bedi?= Date: Wed, 10 Sep 2025 13:58:26 +0200 Subject: [PATCH] Add error source and use context logger --- pkg/mqtt/client.go | 43 +++++++++++++++++++---------------- pkg/mqtt/framer.go | 16 ++++++------- pkg/mqtt/framer_test.go | 5 ++-- pkg/mqtt/topic.go | 12 +++++----- pkg/plugin/datasource.go | 4 ++-- pkg/plugin/datasource_test.go | 20 ++++++---------- pkg/plugin/query.go | 7 ++---- pkg/plugin/stream.go | 26 ++++++++++----------- pkg/plugin/stream_test.go | 5 ++-- 9 files changed, 67 insertions(+), 71 deletions(-) diff --git a/pkg/mqtt/client.go b/pkg/mqtt/client.go index 7aa9488..ed45b12 100644 --- a/pkg/mqtt/client.go +++ b/pkg/mqtt/client.go @@ -1,6 +1,7 @@ package mqtt import ( + "context" "crypto/tls" "crypto/x509" "fmt" @@ -10,14 +11,15 @@ import ( "time" paho "github.com/eclipse/paho.mqtt.golang" + "github.com/grafana/grafana-plugin-sdk-go/backend" "github.com/grafana/grafana-plugin-sdk-go/backend/log" ) type Client interface { GetTopic(string) (*Topic, bool) IsConnected() bool - Subscribe(string) *Topic - Unsubscribe(string) + Subscribe(string, log.Logger) *Topic + Unsubscribe(string, log.Logger) Dispose() } @@ -37,7 +39,8 @@ type client struct { topics TopicMap } -func NewClient(o Options) (Client, error) { +func NewClient(ctx context.Context, o Options) (Client, error) { + logger := log.DefaultLogger.FromContext(ctx) opts := paho.NewClientOptions() opts.AddBroker(o.URI) @@ -63,7 +66,7 @@ func NewClient(o Options) (Client, error) { if o.TLSClientCert != "" || o.TLSClientKey != "" { cert, err := tls.X509KeyPair([]byte(o.TLSClientCert), []byte(o.TLSClientKey)) if err != nil { - return nil, fmt.Errorf("failed to setup TLSClientCert: %w", err) + return nil, backend.DownstreamErrorf("failed to setup TLSClientCert: %w", err) } tlsConfig.Certificates = append(tlsConfig.Certificates, cert) @@ -82,17 +85,17 @@ func NewClient(o Options) (Client, error) { opts.SetCleanSession(false) opts.SetMaxReconnectInterval(10 * time.Second) opts.SetConnectionLostHandler(func(c paho.Client, err error) { - log.DefaultLogger.Error("MQTT Connection lost", "error", err) + logger.Warn("MQTT Connection lost", "error", err) }) opts.SetReconnectingHandler(func(c paho.Client, options *paho.ClientOptions) { - log.DefaultLogger.Debug("MQTT Reconnecting") + logger.Debug("MQTT Reconnecting") }) - log.DefaultLogger.Info("MQTT Connecting", "clientID", clientID) + logger.Info("MQTT Connecting", "clientID", clientID) pahoClient := paho.NewClient(opts) if token := pahoClient.Connect(); token.Wait() && token.Error() != nil { - return nil, fmt.Errorf("error connecting to MQTT broker: %s", token.Error()) + return nil, backend.DownstreamErrorf("error connecting to MQTT broker: %s", token.Error()) } return &client{ @@ -117,7 +120,7 @@ func (c *client) GetTopic(reqPath string) (*Topic, bool) { return c.topics.Load(reqPath) } -func (c *client) Subscribe(reqPath string) *Topic { +func (c *client) Subscribe(reqPath string, logger log.Logger) *Topic { // Check if there's already a topic with this exact key (reqPath) if existingTopic, ok := c.topics.Load(reqPath); ok { return existingTopic @@ -125,12 +128,12 @@ func (c *client) Subscribe(reqPath string) *Topic { chunks := strings.Split(reqPath, "/") if len(chunks) < 2 { - log.DefaultLogger.Error("Invalid path", "path", reqPath) + logger.Error("Invalid path", "path", reqPath) return nil } interval, err := time.ParseDuration(chunks[0]) if err != nil { - log.DefaultLogger.Error("Invalid interval", "path", reqPath, "interval", chunks[0]) + logger.Error("Invalid interval", "path", reqPath, "interval", chunks[0]) return nil } @@ -145,27 +148,27 @@ func (c *client) Subscribe(reqPath string) *Topic { Interval: interval, } - topic, err := decodeTopic(t.Path) + topic, err := decodeTopic(t.Path, logger) if err != nil { - log.DefaultLogger.Error("Error decoding MQTT topic name", "encodedTopic", t.Path, "error", err) + logger.Error("Error decoding MQTT topic name", "encodedTopic", t.Path, "error", backend.DownstreamError(err)) return nil } - log.DefaultLogger.Debug("Subscribing to MQTT topic", "topic", topic) + logger.Debug("Subscribing to MQTT topic", "topic", topic) if token := c.client.Subscribe(topic, 0, func(_ paho.Client, m paho.Message) { // by wrapping HandleMessage we can directly get the correct topicPath for the incoming topic // and don't need to regex it against + and #. c.HandleMessage(topicPath, []byte(m.Payload())) }); token.Wait() && token.Error() != nil { - log.DefaultLogger.Error("Error subscribing to MQTT topic", "topic", topic, "error", token.Error()) + logger.Error("Error subscribing to MQTT topic", "topic", topic, "error", backend.DownstreamError(token.Error())) } // Store the topic using reqPath as the key (which includes streaming key) c.topics.Map.Store(reqPath, t) return t } -func (c *client) Unsubscribe(reqPath string) { +func (c *client) Unsubscribe(reqPath string, logger log.Logger) { t, ok := c.GetTopic(reqPath) if !ok { return @@ -178,16 +181,16 @@ func (c *client) Unsubscribe(reqPath string) { return } - log.DefaultLogger.Debug("Unsubscribing from MQTT topic", "topic", t.Path) + logger.Debug("Unsubscribing from MQTT topic", "topic", t.Path) - topic, err := decodeTopic(t.Path) + topic, err := decodeTopic(t.Path, logger) if err != nil { - log.DefaultLogger.Error("Error decoding MQTT topic name", "encodedTopic", t.Path, "error", err) + logger.Error("Error decoding MQTT topic name", "encodedTopic", t.Path, "error", backend.DownstreamError(err)) return } if token := c.client.Unsubscribe(topic); token.Wait() && token.Error() != nil { - log.DefaultLogger.Error("Error unsubscribing from MQTT topic", "topic", t.Path, "error", token.Error()) + logger.Error("Error unsubscribing from MQTT topic", "topic", t.Path, "error", backend.DownstreamError(token.Error())) } } diff --git a/pkg/mqtt/framer.go b/pkg/mqtt/framer.go index 434fa49..952afd8 100644 --- a/pkg/mqtt/framer.go +++ b/pkg/mqtt/framer.go @@ -17,7 +17,7 @@ type framer struct { fieldMap map[string]int } -func (df *framer) next() error { +func (df *framer) next(logger log.Logger) error { switch df.iterator.WhatIsNext() { case jsoniter.StringValue: v := df.iterator.ReadString() @@ -29,7 +29,7 @@ func (df *framer) next() error { v := df.iterator.ReadBool() df.addValue(data.FieldTypeNullableBool, &v) case jsoniter.NilValue: - df.addNil() + df.addNil(logger) df.iterator.ReadNil() case jsoniter.ArrayValue: df.addValue(data.FieldTypeJSON, json.RawMessage(df.iterator.SkipAndReturnBytes())) @@ -42,7 +42,7 @@ func (df *framer) next() error { for fname := df.iterator.ReadObject(); fname != ""; fname = df.iterator.ReadObject() { if size == 0 { df.path = append(df.path, fname) - if err := df.next(); err != nil { + if err := df.next(logger); err != nil { return err } } @@ -61,12 +61,12 @@ func (df *framer) key() string { return strings.Join(df.path, "") } -func (df *framer) addNil() { +func (df *framer) addNil(logger log.Logger) { if idx, ok := df.fieldMap[df.key()]; ok { df.fields[idx].Set(0, nil) return } - log.DefaultLogger.Debug("nil value for unknown field", "key", df.key()) + logger.Debug("nil value for unknown field", "key", df.key()) } func (df *framer) addValue(fieldType data.FieldType, v interface{}) { @@ -96,7 +96,7 @@ func newFramer() *framer { return df } -func (df *framer) toFrame(messages []Message) (*data.Frame, error) { +func (df *framer) toFrame(messages []Message, logger log.Logger) (*data.Frame, error) { // clear the data in the fields for _, field := range df.fields { for i := field.Len() - 1; i >= 0; i-- { @@ -106,10 +106,10 @@ func (df *framer) toFrame(messages []Message) (*data.Frame, error) { for _, message := range messages { df.iterator = jsoniter.ParseBytes(jsoniter.ConfigDefault, message.Value) - err := df.next() + err := df.next(logger) if err != nil { // If JSON parsing fails, treat the raw bytes as a string value - log.DefaultLogger.Debug("JSON parsing failed, treating as raw string", "error", err, "value", string(message.Value)) + logger.Debug("JSON parsing failed, treating as raw string", "error", err, "value", string(message.Value)) rawValue := string(message.Value) df.addValue(data.FieldTypeNullableString, &rawValue) } diff --git a/pkg/mqtt/framer_test.go b/pkg/mqtt/framer_test.go index 7175a13..5895c68 100644 --- a/pkg/mqtt/framer_test.go +++ b/pkg/mqtt/framer_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/grafana/grafana-plugin-sdk-go/backend/log" "github.com/grafana/grafana-plugin-sdk-go/experimental" "github.com/stretchr/testify/require" ) @@ -86,7 +87,7 @@ func runTest(t *testing.T, name string, values ...any) { for i, v := range values { messages = append(messages, Message{Timestamp: timestamp.Add(time.Duration(i) * time.Minute), Value: toJSON(v)}) } - frame, err := f.toFrame(messages) + frame, err := f.toFrame(messages, log.DefaultLogger) require.NoError(t, err) require.NotNil(t, frame) experimental.CheckGoldenJSONFrame(t, "testdata", name, frame, update) @@ -100,7 +101,7 @@ func runRawTest(t *testing.T, name string, rawValues ...[]byte) { for i, v := range rawValues { messages = append(messages, Message{Timestamp: timestamp.Add(time.Duration(i) * time.Minute), Value: v}) } - frame, err := f.toFrame(messages) + frame, err := f.toFrame(messages, log.DefaultLogger) require.NoError(t, err) require.NotNil(t, frame) experimental.CheckGoldenJSONFrame(t, "testdata", name, frame, update) diff --git a/pkg/mqtt/topic.go b/pkg/mqtt/topic.go index 029c4b1..8f1dbfd 100644 --- a/pkg/mqtt/topic.go +++ b/pkg/mqtt/topic.go @@ -33,11 +33,11 @@ func (t *Topic) Key() string { } // ToDataFrame converts the topic to a data frame. -func (t *Topic) ToDataFrame() (*data.Frame, error) { +func (t *Topic) ToDataFrame(logger log.Logger) (*data.Frame, error) { if t.framer == nil { t.framer = newFramer() } - return t.framer.toFrame(t.Messages) + return t.framer.toFrame(t.Messages, logger) } // TopicMap is a thread-safe map of topics @@ -58,7 +58,7 @@ func (tm *TopicMap) Load(key string) (*Topic, bool) { // AddMessage adds a message to the topic for the given path. func (tm *TopicMap) AddMessage(path string, message Message) { - tm.Map.Range(func(key, t any) bool { + tm.Range(func(key, t any) bool { topic, ok := t.(*Topic) if !ok { return false @@ -75,7 +75,7 @@ func (tm *TopicMap) AddMessage(path string, message Message) { func (tm *TopicMap) HasSubscription(path string) bool { found := false - tm.Map.Range(func(key, t any) bool { + tm.Range(func(key, t any) bool { topic, ok := t.(*Topic) if !ok { return true // this shouldn't happen, but continue iterating @@ -110,10 +110,10 @@ func (tm *TopicMap) Delete(key string) { // // To comply with these restrictions, the topic is encoded using URL-safe base64 // encoding. (RFC 4648; 5. Base 64 Encoding with URL and Filename Safe Alphabet) -func decodeTopic(topicPath string) (string, error) { +func decodeTopic(topicPath string, logger log.Logger) (string, error) { chunks := strings.Split(topicPath, "/") topic := chunks[0] - log.DefaultLogger.Debug("Decoding MQTT topic name", "encodedTopic", topic) + logger.Debug("Decoding MQTT topic name", "encodedTopic", topic) decoded, err := base64.RawURLEncoding.DecodeString(topic) if err != nil { diff --git a/pkg/plugin/datasource.go b/pkg/plugin/datasource.go index 83de2a4..660b52c 100644 --- a/pkg/plugin/datasource.go +++ b/pkg/plugin/datasource.go @@ -21,13 +21,13 @@ var ( ) // NewMQTTDatasource creates a new datasource instance. -func NewMQTTInstance(_ context.Context, s backend.DataSourceInstanceSettings) (instancemgmt.Instance, error) { +func NewMQTTInstance(ctx context.Context, s backend.DataSourceInstanceSettings) (instancemgmt.Instance, error) { settings, err := getDatasourceSettings(s) if err != nil { return nil, err } - client, err := mqtt.NewClient(*settings) + client, err := mqtt.NewClient(ctx, *settings) if err != nil { return nil, err } diff --git a/pkg/plugin/datasource_test.go b/pkg/plugin/datasource_test.go index 96f8b65..329c6e7 100644 --- a/pkg/plugin/datasource_test.go +++ b/pkg/plugin/datasource_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/grafana/grafana-plugin-sdk-go/backend" + "github.com/grafana/grafana-plugin-sdk-go/backend/log" "github.com/grafana/mqtt-datasource/pkg/mqtt" "github.com/grafana/mqtt-datasource/pkg/plugin" "github.com/stretchr/testify/require" @@ -13,8 +14,7 @@ import ( func TestCheckHealthHandler(t *testing.T) { t.Run("HealthStatusOK when can connect", func(t *testing.T) { ds := plugin.NewMQTTDatasource(&fakeMQTTClient{ - connected: true, - subscribed: false, + connected: true, }, "xyz") res, _ := ds.CheckHealth( @@ -28,8 +28,7 @@ func TestCheckHealthHandler(t *testing.T) { t.Run("HealthStatusError when disconnected", func(t *testing.T) { ds := plugin.NewMQTTDatasource(&fakeMQTTClient{ - connected: false, - subscribed: false, + connected: false, }, "xyz") res, _ := ds.CheckHealth( @@ -43,8 +42,7 @@ func TestCheckHealthHandler(t *testing.T) { } type fakeMQTTClient struct { - connected bool - subscribed bool + connected bool } func (c *fakeMQTTClient) GetTopic(_ string) (*mqtt.Topic, bool) { @@ -55,10 +53,6 @@ func (c *fakeMQTTClient) IsConnected() bool { return c.connected } -func (c *fakeMQTTClient) IsSubscribed(_ string) bool { - return c.subscribed -} - -func (c *fakeMQTTClient) Subscribe(_ string) *mqtt.Topic { return nil } -func (c *fakeMQTTClient) Unsubscribe(_ string) {} -func (c *fakeMQTTClient) Dispose() {} +func (c *fakeMQTTClient) Subscribe(_ string, _ log.Logger) *mqtt.Topic { return nil } +func (c *fakeMQTTClient) Unsubscribe(_ string, _ log.Logger) {} +func (c *fakeMQTTClient) Dispose() {} diff --git a/pkg/plugin/query.go b/pkg/plugin/query.go index 9243b6a..329b6b3 100644 --- a/pkg/plugin/query.go +++ b/pkg/plugin/query.go @@ -3,7 +3,6 @@ package plugin import ( "context" "encoding/json" - "fmt" "path" "github.com/grafana/grafana-plugin-sdk-go/backend" @@ -30,13 +29,11 @@ func (ds *MQTTDatasource) query(query backend.DataQuery) backend.DataResponse { ) if err := json.Unmarshal(query.JSON, &t); err != nil { - response.Error = err - return response + return backend.ErrorResponseWithErrorSource(backend.DownstreamErrorf("failed to unmarshal query: %w", err)) } if t.Path == "" { - response.Error = fmt.Errorf("topic path is required") - return response + return backend.ErrorResponseWithErrorSource(backend.DownstreamErrorf("topic path is required")) } t.Interval = query.Interval diff --git a/pkg/plugin/stream.go b/pkg/plugin/stream.go index 5bda851..d916aca 100644 --- a/pkg/plugin/stream.go +++ b/pkg/plugin/stream.go @@ -2,7 +2,6 @@ package plugin import ( "context" - "fmt" "strconv" "strings" "time" @@ -19,42 +18,43 @@ func (ds *MQTTDatasource) RunStream(ctx context.Context, req *backend.RunStreamR // Channel path format: "ds/{uid}/{topicKey}" where topicKey includes streaming key // We need to remove the channelPrefix ("ds/{uid}") to get the topic key topicKey := strings.TrimPrefix(req.Path, ds.channelPrefix+"/") + logger := log.DefaultLogger.FromContext(ctx) chunks := strings.Split(topicKey, "/") if len(chunks) < 2 { - return fmt.Errorf("invalid topic key: %s", topicKey) + return backend.DownstreamErrorf("invalid topic key: %s", topicKey) } interval, err := time.ParseDuration(chunks[0]) if err != nil { - return err + return backend.DownstreamErrorf("invalid interval: %s", chunks[0]) } - ds.Client.Subscribe(topicKey) - defer ds.Client.Unsubscribe(topicKey) + ds.Client.Subscribe(topicKey, logger) + defer ds.Client.Unsubscribe(topicKey, logger) ticker := time.NewTicker(interval) for { select { case <-ctx.Done(): - log.DefaultLogger.Debug("stopped streaming (context canceled)", "path", req.Path, "topicKey", topicKey) + logger.Debug("stopped streaming (context canceled)", "path", req.Path, "topicKey", topicKey) ticker.Stop() return nil case <-ticker.C: topic, ok := ds.Client.GetTopic(topicKey) if !ok { - log.DefaultLogger.Debug("topic not found", "path", req.Path, "topicKey", topicKey) + logger.Debug("topic not found", "path", req.Path, "topicKey", topicKey) break } - frame, err := topic.ToDataFrame() + frame, err := topic.ToDataFrame(logger) if err != nil { - log.DefaultLogger.Error("failed to convert topic to data frame", "path", req.Path, "error", err) + logger.Error("failed to convert topic to data frame", "path", req.Path, "error", backend.DownstreamError(err)) break } topic.Messages = []mqtt.Message{} if err := sender.SendFrame(frame, data.IncludeAll); err != nil { - log.DefaultLogger.Error("failed to send data frame", "path", req.Path, "error", err) + logger.Error("failed to send data frame", "path", req.Path, "error", backend.DownstreamError(err)) } } @@ -68,21 +68,21 @@ func (ds *MQTTDatasource) SubscribeStream(ctx context.Context, req *backend.Subs if len(pathParts) < 5 { return &backend.SubscribeStreamResponse{ Status: backend.SubscribeStreamStatusNotFound, - }, fmt.Errorf("invalid channel path format") + }, backend.DownstreamErrorf("invalid channel path format") } orgId, err := strconv.ParseInt(pathParts[len(pathParts)-1], 10, 64) if err != nil { return &backend.SubscribeStreamResponse{ Status: backend.SubscribeStreamStatusNotFound, - }, fmt.Errorf("unable to determine orgId from request") + }, backend.DownstreamErrorf("unable to determine orgId from request") } pluginCfg := backend.PluginConfigFromContext(ctx) if orgId != pluginCfg.OrgID { return &backend.SubscribeStreamResponse{ Status: backend.SubscribeStreamStatusPermissionDenied, - }, fmt.Errorf("invalid orgId supplied in request") + }, backend.DownstreamErrorf("invalid orgId supplied in request") } return &backend.SubscribeStreamResponse{ diff --git a/pkg/plugin/stream_test.go b/pkg/plugin/stream_test.go index f12c0ce..e5147e6 100644 --- a/pkg/plugin/stream_test.go +++ b/pkg/plugin/stream_test.go @@ -114,9 +114,10 @@ func TestMQTTDatasource_SubscribeStream_PathParsing(t *testing.T) { t.Run(tt.name, func(t *testing.T) { // Create context with matching org orgID := int64(456) // Default org for testing - if tt.expectedOrg == "789" { + switch tt.expectedOrg { + case "789": orgID = 789 - } else if tt.expectedOrg == "123" { + case "123": orgID = 123 }