Skip to content
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.elastic.ccm;

import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.support.ActionTestUtils;
import org.elasticsearch.action.support.TestPlainActionFuture;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.test.ESSingleNodeTestCase;
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
import org.junit.After;
import org.junit.Before;

import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.sameInstance;

public class CCMCacheTests extends ESSingleNodeTestCase {

private static final TimeValue TIMEOUT = TimeValue.THIRTY_SECONDS;

private CCMCache ccmCache;
private CCMPersistentStorageService ccmPersistentStorageService;

@Override
protected Collection<Class<? extends Plugin>> getPlugins() {
return List.of(LocalStateInferencePlugin.class);
}

@Before
public void createComponents() {
ccmCache = node().injector().getInstance(CCMCache.class);
ccmPersistentStorageService = node().injector().getInstance(CCMPersistentStorageService.class);
}

@Override
protected boolean resetNodeAfterTest() {
return true;
}

@After
public void clearCacheAndIndex() {
try {
indicesAdmin().prepareDelete(CCMIndex.INDEX_NAME).execute().actionGet(TIMEOUT);
} catch (ResourceNotFoundException e) {
// mission complete!
}
}

public void testCacheHit() throws IOException {
var expectedCcmModel = storeCcm();
var actualCcmModel = getFromCache();
assertThat(actualCcmModel, equalTo(expectedCcmModel));
assertThat(ccmCache.stats().getHits(), equalTo(0L));
assertThat(getFromCache(), sameInstance(actualCcmModel));
assertThat(ccmCache.stats().getHits(), equalTo(1L));
}

private CCMModel storeCcm() throws IOException {
var ccmModel = CCMModel.fromXContentBytes(new BytesArray("""
{
"api_key": "test_key"
}
"""));
var listener = new TestPlainActionFuture<Void>();
ccmPersistentStorageService.store(ccmModel, listener);
listener.actionGet(TIMEOUT);
return ccmModel;
}

private CCMModel getFromCache() {
var listener = new TestPlainActionFuture<CCMModel>();
ccmCache.get(listener);
return listener.actionGet(TIMEOUT);
}

public void testCacheInvalidate() throws Exception {
var expectedCcmModel = storeCcm();
var actualCcmModel = getFromCache();
assertThat(actualCcmModel, equalTo(expectedCcmModel));
assertThat(ccmCache.stats().getHits(), equalTo(0L));
assertThat(ccmCache.stats().getMisses(), equalTo(1L));
assertThat(ccmCache.cacheCount(), equalTo(1));

var listener = new TestPlainActionFuture<Void>();
ccmCache.invalidate(listener);
listener.actionGet(TIMEOUT);

assertThat(getFromCache(), not(sameInstance(actualCcmModel)));
assertThat(ccmCache.stats().getHits(), equalTo(0L));
assertThat(ccmCache.stats().getMisses(), equalTo(2L));
assertThat(ccmCache.stats().getEvictions(), equalTo(1L));
assertThat(ccmCache.cacheCount(), equalTo(1));
}

public void testEmptyInvalidate() throws InterruptedException {
var latch = new CountDownLatch(1);
ccmCache.invalidate(ActionTestUtils.assertNoFailureListener(success -> latch.countDown()));
assertTrue(latch.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS));

assertThat(ccmCache.stats().getEvictions(), equalTo(0L));
assertThat(ccmCache.cacheCount(), equalTo(0));
}

private boolean isPresent() {
var listener = new TestPlainActionFuture<Boolean>();
ccmCache.isEnabled(listener);
return listener.actionGet(TIMEOUT);
}

public void testIsEnabled() throws IOException {
storeCcm();

getFromCache();
assertThat(ccmCache.stats().getHits(), equalTo(0L));
assertThat(ccmCache.stats().getMisses(), equalTo(1L));

assertTrue(isPresent());
assertThat(ccmCache.stats().getHits(), equalTo(1L));
assertThat(ccmCache.stats().getMisses(), equalTo(1L));
}

public void testIsDisabledWithMissingIndex() {
assertFalse(isPresent());
}

public void testIsDisabledWithPresentIndex() {
indicesAdmin().prepareCreate(CCMIndex.INDEX_NAME).execute().actionGet(TIMEOUT);
assertFalse(isPresent());
}

public void testIsDisabledWithCacheHit() {
indicesAdmin().prepareCreate(CCMIndex.INDEX_NAME).execute().actionGet(TIMEOUT);

assertFalse(isPresent());
assertThat(ccmCache.stats().getHits(), equalTo(0L));
assertThat(ccmCache.stats().getMisses(), equalTo(1L));

assertFalse(isPresent());
assertThat(ccmCache.stats().getHits(), equalTo(1L));
assertThat(ccmCache.stats().getMisses(), equalTo(1L));
}
Copy link
Contributor

@jonathan-buttner jonathan-buttner Nov 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might have missed it but do we have a test for the else case in this block:

CCMCache::get

if (cachedEntry != null && cachedEntry.enabled()) {
            listener.onResponse(cachedEntry.ccmModel());
        } else {

Could we add a test that when the internal entry is in the disabled state (but present and not null) that we get a cache miss aka hit the else case?


public void testIsDisabledRefreshedWithGet() throws IOException {
indicesAdmin().prepareCreate(CCMIndex.INDEX_NAME).execute().actionGet(TIMEOUT);

assertFalse(isPresent());
assertThat(ccmCache.stats().getHits(), equalTo(0L));
assertThat(ccmCache.stats().getMisses(), equalTo(1L));

var expectedCcmModel = storeCcm();

assertFalse(isPresent());
assertThat(ccmCache.stats().getHits(), equalTo(1L));
assertThat(ccmCache.stats().getMisses(), equalTo(1L));

var actualCcmModel = getFromCache();
assertThat(actualCcmModel, equalTo(expectedCcmModel));
assertThat(ccmCache.stats().getHits(), equalTo(2L));
assertThat(ccmCache.stats().getMisses(), equalTo(1L));

assertTrue(isPresent());
assertThat(ccmCache.stats().getHits(), equalTo(3L));
assertThat(ccmCache.stats().getMisses(), equalTo(1L));
}
}
1 change: 1 addition & 0 deletions x-pack/plugin/inference/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
exports org.elasticsearch.xpack.inference.registry;
exports org.elasticsearch.xpack.inference.rest;
exports org.elasticsearch.xpack.inference.services;
exports org.elasticsearch.xpack.inference.services.elastic.ccm;
exports org.elasticsearch.xpack.inference;
exports org.elasticsearch.xpack.inference.action.task;
exports org.elasticsearch.xpack.inference.telemetry;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,12 @@ public class InferenceFeatures implements FeatureSpecification {
private static final NodeFeature SEMANTIC_TEXT_FIELDS_CHUNKS_FORMAT = new NodeFeature("semantic_text.fields_chunks_format");

public static final NodeFeature INFERENCE_ENDPOINT_CACHE = new NodeFeature("inference.endpoint.cache");
public static final NodeFeature INFERENCE_CCM_CACHE = new NodeFeature("inference.ccm.cache");
public static final NodeFeature SEARCH_USAGE_EXTENDED_DATA = new NodeFeature("search.usage.extended_data");

@Override
public Set<NodeFeature> getFeatures() {
return Set.of(INFERENCE_ENDPOINT_CACHE);
return Set.of(INFERENCE_ENDPOINT_CACHE, INFERENCE_CCM_CACHE);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@
import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller;
import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationTaskExecutor;
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMCache;
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMFeature;
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMIndex;
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMPersistentStorageService;
Expand Down Expand Up @@ -276,7 +277,8 @@ public List<ActionHandler> getActions() {
new ActionHandler(StoreInferenceEndpointsAction.INSTANCE, TransportStoreEndpointsAction.class),
new ActionHandler(GetCCMConfigurationAction.INSTANCE, TransportGetCCMConfigurationAction.class),
new ActionHandler(PutCCMConfigurationAction.INSTANCE, TransportPutCCMConfigurationAction.class),
new ActionHandler(DeleteCCMConfigurationAction.INSTANCE, TransportDeleteCCMConfigurationAction.class)
new ActionHandler(DeleteCCMConfigurationAction.INSTANCE, TransportDeleteCCMConfigurationAction.class),
new ActionHandler(CCMCache.ClearCCMCacheAction.INSTANCE, CCMCache.ClearCCMCacheAction.class)
);
}

Expand Down Expand Up @@ -453,7 +455,19 @@ public Collection<?> createComponents(PluginServices services) {
private Collection<?> createCCMComponents(PluginServices services) {
ccmFeature.set(new CCMFeature(settings));
var ccmPersistentStorageService = new CCMPersistentStorageService(services.client());
return List.of(new CCMService(ccmPersistentStorageService), ccmFeature.get(), ccmPersistentStorageService);
return List.of(
new CCMService(ccmPersistentStorageService),
ccmFeature.get(),
ccmPersistentStorageService,
new CCMCache(
ccmPersistentStorageService,
services.clusterService(),
settings,
services.featureService(),
services.projectResolver(),
services.client()
)
);
}

@Override
Expand Down Expand Up @@ -653,6 +667,7 @@ public static Set<Setting<?>> getInferenceSettings() {
settings.addAll(InferenceEndpointRegistry.getSettingsDefinitions());
settings.addAll(ElasticInferenceServiceSettings.getSettingsDefinitions());
settings.addAll(CCMSettings.getSettingsDefinitions());
settings.addAll(CCMCache.getSettingsDefinitions());
return Collections.unmodifiableSet(settings);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.common;

import org.elasticsearch.action.FailedNodeException;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.TransportAction;
import org.elasticsearch.action.support.nodes.BaseNodeResponse;
import org.elasticsearch.action.support.nodes.BaseNodesRequest;
import org.elasticsearch.action.support.nodes.BaseNodesResponse;
import org.elasticsearch.action.support.nodes.TransportNodesAction;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.AbstractTransportRequest;
import org.elasticsearch.transport.TransportService;

import java.io.IOException;
import java.util.List;
import java.util.Map;

/**
* Broadcasts a {@link Writeable} to all nodes and responds with an empty object.
* This is intended to be used as a fire-and-forget style, where responses and failures are logged and swallowed.
*/
public abstract class BroadcastMessageAction<Message extends Writeable> extends TransportNodesAction<
BroadcastMessageAction.Request<Message>,
BroadcastMessageAction.Response,
BroadcastMessageAction.NodeRequest<Message>,
BroadcastMessageAction.NodeResponse,
Void> {

protected BroadcastMessageAction(
String actionName,
ClusterService clusterService,
TransportService transportService,
ActionFilters actionFilters,
Writeable.Reader<Message> messageReader
) {
super(
actionName,
clusterService,
transportService,
actionFilters,
in -> new NodeRequest<>(messageReader.read(in)),
clusterService.threadPool().executor(ThreadPool.Names.MANAGEMENT)
);
}

@Override
protected Response newResponse(Request<Message> request, List<NodeResponse> nodeResponses, List<FailedNodeException> failures) {
return new Response(clusterService.getClusterName(), nodeResponses, failures);
}

@Override
protected NodeRequest<Message> newNodeRequest(Request<Message> request) {
return new NodeRequest<>(request.message);
}

@Override
protected NodeResponse newNodeResponse(StreamInput in, DiscoveryNode node) throws IOException {
return new NodeResponse(in, node);
}

@Override
protected NodeResponse nodeOperation(NodeRequest<Message> request, Task task) {
receiveMessage(request.message);
return new NodeResponse(transportService.getLocalNode());
}

/**
* This method is run on each node in the cluster.
*/
protected abstract void receiveMessage(Message message);

public static <T extends Writeable> Request<T> request(T message, TimeValue timeout) {
return new Request<>(message, timeout);
}

public static class Request<Message extends Writeable> extends BaseNodesRequest {
private final Message message;

protected Request(Message message, TimeValue timeout) {
super(Strings.EMPTY_ARRAY);
this.message = message;
setTimeout(timeout);
}
}

public static class Response extends BaseNodesResponse<NodeResponse> {

protected Response(ClusterName clusterName, List<NodeResponse> nodes, List<FailedNodeException> failures) {
super(clusterName, nodes, failures);
}

@Override
protected List<NodeResponse> readNodesFrom(StreamInput in) throws IOException {
return in.readCollectionAsList(NodeResponse::new);
}

@Override
protected void writeNodesTo(StreamOutput out, List<NodeResponse> nodes) {
TransportAction.localOnly();
}
}

public static class NodeRequest<Message extends Writeable> extends AbstractTransportRequest {
private final Message message;

private NodeRequest(Message message) {
this.message = message;
}

@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new CancellableTask(id, type, action, "broadcasted message to an individual node", parentTaskId, headers);
}
}

public static class NodeResponse extends BaseNodeResponse {
protected NodeResponse(StreamInput in) throws IOException {
super(in);
}

protected NodeResponse(StreamInput in, DiscoveryNode node) throws IOException {
super(in, node);
}

protected NodeResponse(DiscoveryNode node) {
super(node);
}
}
}
Loading