Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 23 additions & 20 deletions pkg/mqtt/client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package mqtt

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
Expand All @@ -10,14 +11,15 @@ import (
"time"

paho "github.com/eclipse/paho.mqtt.golang"
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/backend/log"
)

type Client interface {
GetTopic(string) (*Topic, bool)
IsConnected() bool
Subscribe(string) *Topic
Unsubscribe(string)
Subscribe(string, log.Logger) *Topic
Unsubscribe(string, log.Logger)
Dispose()
}

Expand All @@ -37,7 +39,8 @@ type client struct {
topics TopicMap
}

func NewClient(o Options) (Client, error) {
func NewClient(ctx context.Context, o Options) (Client, error) {
logger := log.DefaultLogger.FromContext(ctx)
opts := paho.NewClientOptions()

opts.AddBroker(o.URI)
Expand All @@ -63,7 +66,7 @@ func NewClient(o Options) (Client, error) {
if o.TLSClientCert != "" || o.TLSClientKey != "" {
cert, err := tls.X509KeyPair([]byte(o.TLSClientCert), []byte(o.TLSClientKey))
if err != nil {
return nil, fmt.Errorf("failed to setup TLSClientCert: %w", err)
return nil, backend.DownstreamErrorf("failed to setup TLSClientCert: %w", err)
}

tlsConfig.Certificates = append(tlsConfig.Certificates, cert)
Expand All @@ -82,17 +85,17 @@ func NewClient(o Options) (Client, error) {
opts.SetCleanSession(false)
opts.SetMaxReconnectInterval(10 * time.Second)
opts.SetConnectionLostHandler(func(c paho.Client, err error) {
log.DefaultLogger.Error("MQTT Connection lost", "error", err)
logger.Warn("MQTT Connection lost", "error", err)
})
opts.SetReconnectingHandler(func(c paho.Client, options *paho.ClientOptions) {
log.DefaultLogger.Debug("MQTT Reconnecting")
logger.Debug("MQTT Reconnecting")
})

log.DefaultLogger.Info("MQTT Connecting", "clientID", clientID)
logger.Info("MQTT Connecting", "clientID", clientID)

pahoClient := paho.NewClient(opts)
if token := pahoClient.Connect(); token.Wait() && token.Error() != nil {
return nil, fmt.Errorf("error connecting to MQTT broker: %s", token.Error())
return nil, backend.DownstreamErrorf("error connecting to MQTT broker: %s", token.Error())
}

return &client{
Expand All @@ -117,20 +120,20 @@ func (c *client) GetTopic(reqPath string) (*Topic, bool) {
return c.topics.Load(reqPath)
}

func (c *client) Subscribe(reqPath string) *Topic {
func (c *client) Subscribe(reqPath string, logger log.Logger) *Topic {
// Check if there's already a topic with this exact key (reqPath)
if existingTopic, ok := c.topics.Load(reqPath); ok {
return existingTopic
}

chunks := strings.Split(reqPath, "/")
if len(chunks) < 2 {
log.DefaultLogger.Error("Invalid path", "path", reqPath)
logger.Error("Invalid path", "path", reqPath)
return nil
}
interval, err := time.ParseDuration(chunks[0])
if err != nil {
log.DefaultLogger.Error("Invalid interval", "path", reqPath, "interval", chunks[0])
logger.Error("Invalid interval", "path", reqPath, "interval", chunks[0])
return nil
}

Expand All @@ -145,27 +148,27 @@ func (c *client) Subscribe(reqPath string) *Topic {
Interval: interval,
}

topic, err := decodeTopic(t.Path)
topic, err := decodeTopic(t.Path, logger)
if err != nil {
log.DefaultLogger.Error("Error decoding MQTT topic name", "encodedTopic", t.Path, "error", err)
logger.Error("Error decoding MQTT topic name", "encodedTopic", t.Path, "error", backend.DownstreamError(err))
return nil
}

log.DefaultLogger.Debug("Subscribing to MQTT topic", "topic", topic)
logger.Debug("Subscribing to MQTT topic", "topic", topic)

if token := c.client.Subscribe(topic, 0, func(_ paho.Client, m paho.Message) {
// by wrapping HandleMessage we can directly get the correct topicPath for the incoming topic
// and don't need to regex it against + and #.
c.HandleMessage(topicPath, []byte(m.Payload()))
}); token.Wait() && token.Error() != nil {
log.DefaultLogger.Error("Error subscribing to MQTT topic", "topic", topic, "error", token.Error())
logger.Error("Error subscribing to MQTT topic", "topic", topic, "error", backend.DownstreamError(token.Error()))
}
// Store the topic using reqPath as the key (which includes streaming key)
c.topics.Map.Store(reqPath, t)
return t
}

func (c *client) Unsubscribe(reqPath string) {
func (c *client) Unsubscribe(reqPath string, logger log.Logger) {
t, ok := c.GetTopic(reqPath)
if !ok {
return
Expand All @@ -178,16 +181,16 @@ func (c *client) Unsubscribe(reqPath string) {
return
}

log.DefaultLogger.Debug("Unsubscribing from MQTT topic", "topic", t.Path)
logger.Debug("Unsubscribing from MQTT topic", "topic", t.Path)

topic, err := decodeTopic(t.Path)
topic, err := decodeTopic(t.Path, logger)
if err != nil {
log.DefaultLogger.Error("Error decoding MQTT topic name", "encodedTopic", t.Path, "error", err)
logger.Error("Error decoding MQTT topic name", "encodedTopic", t.Path, "error", backend.DownstreamError(err))
return
}

if token := c.client.Unsubscribe(topic); token.Wait() && token.Error() != nil {
log.DefaultLogger.Error("Error unsubscribing from MQTT topic", "topic", t.Path, "error", token.Error())
logger.Error("Error unsubscribing from MQTT topic", "topic", t.Path, "error", backend.DownstreamError(token.Error()))
}
}

Expand Down
16 changes: 8 additions & 8 deletions pkg/mqtt/framer.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type framer struct {
fieldMap map[string]int
}

func (df *framer) next() error {
func (df *framer) next(logger log.Logger) error {
switch df.iterator.WhatIsNext() {
case jsoniter.StringValue:
v := df.iterator.ReadString()
Expand All @@ -29,7 +29,7 @@ func (df *framer) next() error {
v := df.iterator.ReadBool()
df.addValue(data.FieldTypeNullableBool, &v)
case jsoniter.NilValue:
df.addNil()
df.addNil(logger)
df.iterator.ReadNil()
case jsoniter.ArrayValue:
df.addValue(data.FieldTypeJSON, json.RawMessage(df.iterator.SkipAndReturnBytes()))
Expand All @@ -42,7 +42,7 @@ func (df *framer) next() error {
for fname := df.iterator.ReadObject(); fname != ""; fname = df.iterator.ReadObject() {
if size == 0 {
df.path = append(df.path, fname)
if err := df.next(); err != nil {
if err := df.next(logger); err != nil {
return err
}
}
Expand All @@ -61,12 +61,12 @@ func (df *framer) key() string {
return strings.Join(df.path, "")
}

func (df *framer) addNil() {
func (df *framer) addNil(logger log.Logger) {
if idx, ok := df.fieldMap[df.key()]; ok {
df.fields[idx].Set(0, nil)
return
}
log.DefaultLogger.Debug("nil value for unknown field", "key", df.key())
logger.Debug("nil value for unknown field", "key", df.key())
}

func (df *framer) addValue(fieldType data.FieldType, v interface{}) {
Expand Down Expand Up @@ -96,7 +96,7 @@ func newFramer() *framer {
return df
}

func (df *framer) toFrame(messages []Message) (*data.Frame, error) {
func (df *framer) toFrame(messages []Message, logger log.Logger) (*data.Frame, error) {
// clear the data in the fields
for _, field := range df.fields {
for i := field.Len() - 1; i >= 0; i-- {
Expand All @@ -106,10 +106,10 @@ func (df *framer) toFrame(messages []Message) (*data.Frame, error) {

for _, message := range messages {
df.iterator = jsoniter.ParseBytes(jsoniter.ConfigDefault, message.Value)
err := df.next()
err := df.next(logger)
if err != nil {
// If JSON parsing fails, treat the raw bytes as a string value
log.DefaultLogger.Debug("JSON parsing failed, treating as raw string", "error", err, "value", string(message.Value))
logger.Debug("JSON parsing failed, treating as raw string", "error", err, "value", string(message.Value))
rawValue := string(message.Value)
df.addValue(data.FieldTypeNullableString, &rawValue)
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/mqtt/framer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"testing"
"time"

"github.com/grafana/grafana-plugin-sdk-go/backend/log"
"github.com/grafana/grafana-plugin-sdk-go/experimental"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -86,7 +87,7 @@ func runTest(t *testing.T, name string, values ...any) {
for i, v := range values {
messages = append(messages, Message{Timestamp: timestamp.Add(time.Duration(i) * time.Minute), Value: toJSON(v)})
}
frame, err := f.toFrame(messages)
frame, err := f.toFrame(messages, log.DefaultLogger)
require.NoError(t, err)
require.NotNil(t, frame)
experimental.CheckGoldenJSONFrame(t, "testdata", name, frame, update)
Expand All @@ -100,7 +101,7 @@ func runRawTest(t *testing.T, name string, rawValues ...[]byte) {
for i, v := range rawValues {
messages = append(messages, Message{Timestamp: timestamp.Add(time.Duration(i) * time.Minute), Value: v})
}
frame, err := f.toFrame(messages)
frame, err := f.toFrame(messages, log.DefaultLogger)
require.NoError(t, err)
require.NotNil(t, frame)
experimental.CheckGoldenJSONFrame(t, "testdata", name, frame, update)
Expand Down
12 changes: 6 additions & 6 deletions pkg/mqtt/topic.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ func (t *Topic) Key() string {
}

// ToDataFrame converts the topic to a data frame.
func (t *Topic) ToDataFrame() (*data.Frame, error) {
func (t *Topic) ToDataFrame(logger log.Logger) (*data.Frame, error) {
if t.framer == nil {
t.framer = newFramer()
}
return t.framer.toFrame(t.Messages)
return t.framer.toFrame(t.Messages, logger)
}

// TopicMap is a thread-safe map of topics
Expand All @@ -58,7 +58,7 @@ func (tm *TopicMap) Load(key string) (*Topic, bool) {

// AddMessage adds a message to the topic for the given path.
func (tm *TopicMap) AddMessage(path string, message Message) {
tm.Map.Range(func(key, t any) bool {
tm.Range(func(key, t any) bool {
topic, ok := t.(*Topic)
if !ok {
return false
Expand All @@ -75,7 +75,7 @@ func (tm *TopicMap) AddMessage(path string, message Message) {
func (tm *TopicMap) HasSubscription(path string) bool {
found := false

tm.Map.Range(func(key, t any) bool {
tm.Range(func(key, t any) bool {
topic, ok := t.(*Topic)
if !ok {
return true // this shouldn't happen, but continue iterating
Expand Down Expand Up @@ -110,10 +110,10 @@ func (tm *TopicMap) Delete(key string) {
//
// To comply with these restrictions, the topic is encoded using URL-safe base64
// encoding. (RFC 4648; 5. Base 64 Encoding with URL and Filename Safe Alphabet)
func decodeTopic(topicPath string) (string, error) {
func decodeTopic(topicPath string, logger log.Logger) (string, error) {
chunks := strings.Split(topicPath, "/")
topic := chunks[0]
log.DefaultLogger.Debug("Decoding MQTT topic name", "encodedTopic", topic)
logger.Debug("Decoding MQTT topic name", "encodedTopic", topic)
decoded, err := base64.RawURLEncoding.DecodeString(topic)

if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions pkg/plugin/datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ var (
)

// NewMQTTDatasource creates a new datasource instance.
func NewMQTTInstance(_ context.Context, s backend.DataSourceInstanceSettings) (instancemgmt.Instance, error) {
func NewMQTTInstance(ctx context.Context, s backend.DataSourceInstanceSettings) (instancemgmt.Instance, error) {
settings, err := getDatasourceSettings(s)
if err != nil {
return nil, err
}

client, err := mqtt.NewClient(*settings)
client, err := mqtt.NewClient(ctx, *settings)
if err != nil {
return nil, err
}
Expand Down
20 changes: 7 additions & 13 deletions pkg/plugin/datasource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"testing"

"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana-plugin-sdk-go/backend/log"
"github.com/grafana/mqtt-datasource/pkg/mqtt"
"github.com/grafana/mqtt-datasource/pkg/plugin"
"github.com/stretchr/testify/require"
Expand All @@ -13,8 +14,7 @@ import (
func TestCheckHealthHandler(t *testing.T) {
t.Run("HealthStatusOK when can connect", func(t *testing.T) {
ds := plugin.NewMQTTDatasource(&fakeMQTTClient{
connected: true,
subscribed: false,
connected: true,
}, "xyz")

res, _ := ds.CheckHealth(
Expand All @@ -28,8 +28,7 @@ func TestCheckHealthHandler(t *testing.T) {

t.Run("HealthStatusError when disconnected", func(t *testing.T) {
ds := plugin.NewMQTTDatasource(&fakeMQTTClient{
connected: false,
subscribed: false,
connected: false,
}, "xyz")

res, _ := ds.CheckHealth(
Expand All @@ -43,8 +42,7 @@ func TestCheckHealthHandler(t *testing.T) {
}

type fakeMQTTClient struct {
connected bool
subscribed bool
connected bool
}

func (c *fakeMQTTClient) GetTopic(_ string) (*mqtt.Topic, bool) {
Expand All @@ -55,10 +53,6 @@ func (c *fakeMQTTClient) IsConnected() bool {
return c.connected
}

func (c *fakeMQTTClient) IsSubscribed(_ string) bool {
return c.subscribed
}

func (c *fakeMQTTClient) Subscribe(_ string) *mqtt.Topic { return nil }
func (c *fakeMQTTClient) Unsubscribe(_ string) {}
func (c *fakeMQTTClient) Dispose() {}
func (c *fakeMQTTClient) Subscribe(_ string, _ log.Logger) *mqtt.Topic { return nil }
func (c *fakeMQTTClient) Unsubscribe(_ string, _ log.Logger) {}
func (c *fakeMQTTClient) Dispose() {}
7 changes: 2 additions & 5 deletions pkg/plugin/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package plugin
import (
"context"
"encoding/json"
"fmt"
"path"

"github.com/grafana/grafana-plugin-sdk-go/backend"
Expand All @@ -30,13 +29,11 @@ func (ds *MQTTDatasource) query(query backend.DataQuery) backend.DataResponse {
)

if err := json.Unmarshal(query.JSON, &t); err != nil {
response.Error = err
return response
return backend.ErrorResponseWithErrorSource(backend.DownstreamErrorf("failed to unmarshal query: %w", err))
}

if t.Path == "" {
response.Error = fmt.Errorf("topic path is required")
return response
return backend.ErrorResponseWithErrorSource(backend.DownstreamErrorf("topic path is required"))
}

t.Interval = query.Interval
Expand Down
Loading
Loading