Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package com.akto.action;

import com.akto.dao.ApiCollectionsDao;
import com.akto.dao.GuardrailPoliciesDao;
import com.akto.dao.context.Context;
import com.akto.dto.ApiCollection;
import com.akto.dto.GuardrailPolicies;
import com.akto.dto.User;
import com.akto.log.LoggerMaker;
Expand All @@ -10,6 +12,7 @@
import com.mongodb.client.model.Filters;
import com.mongodb.client.model.UpdateOptions;
import com.mongodb.client.model.Updates;
import org.apache.commons.lang3.StringUtils;

import lombok.Getter;
import lombok.Setter;
Expand All @@ -20,6 +23,7 @@
import java.util.List;
import java.util.Map;


public class GuardrailPoliciesAction extends UserAction {
private static final LoggerMaker loggerMaker = new LoggerMaker(GuardrailPoliciesAction.class, LogDb.DASHBOARD);

Expand All @@ -46,6 +50,11 @@ public String fetchGuardrailPolicies() {
this.guardrailPolicies = GuardrailPoliciesDao.instance.findAllSortedByCreatedTimestamp(0, 20);
this.total = GuardrailPoliciesDao.instance.getTotalCount();

// Populate basePrompt for policies with autoDetect enabled
for (GuardrailPolicies policy : guardrailPolicies) {
populateBasePromptIfNeeded(policy);
}

loggerMaker.info("Fetched " + guardrailPolicies.size() + " guardrail policies out of " + total + " total");

return SUCCESS.toUpperCase();
Expand All @@ -55,6 +64,53 @@ public String fetchGuardrailPolicies() {
}
}

/**
* Populates basePrompt in basePromptRule if:
* 1. basePromptRule exists and is enabled
* 2. autoDetect is true
* 3. basePrompt is not already set (or is empty)
* 4. There are selected agent servers
*
* Fetches detectedBasePrompt from the first selected agent collection
*/
private void populateBasePromptIfNeeded(GuardrailPolicies policy) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Duplicate logic in cyborg as well, can think of keeping it in some common place like libs/dao/src/main/java/com/akto/util
  2. Also, we are not showing the auto detected base prompt in UI. Why do we need this then?

try {
GuardrailPolicies.BasePromptRule basePromptRule = policy.getBasePromptRule();
if (basePromptRule == null || !basePromptRule.isEnabled() || !basePromptRule.isAutoDetect()) {
return;
}

// If basePrompt is already set, use it
if (StringUtils.isNotBlank(basePromptRule.getBasePrompt())) {
return;
}

// Get selected agent servers (prefer V2 format, fallback to old format)
List<GuardrailPolicies.SelectedServer> agentServers = policy.getEffectiveSelectedAgentServers();
if (agentServers == null || agentServers.isEmpty()) {
return;
}

// Try to fetch detected base prompt from the first selected agent collection
try {
int firstAgentCollectionId = Integer.parseInt(agentServers.get(0).getId());
ApiCollection agentCollection = ApiCollectionsDao.instance.getMeta(firstAgentCollectionId);

if (agentCollection != null && StringUtils.isNotBlank(agentCollection.getDetectedBasePrompt())) {
basePromptRule.setBasePrompt(agentCollection.getDetectedBasePrompt());
loggerMaker.debug("Populated basePrompt from collection " + firstAgentCollectionId +
" for policy: " + policy.getName());
}
} catch (NumberFormatException e) {
loggerMaker.debug("Invalid agent collection ID format: " + agentServers.get(0).getId());
} catch (Exception e) {
loggerMaker.debug("Error fetching detected base prompt for policy " + policy.getName() + ": " + e.getMessage());
}
} catch (Exception e) {
loggerMaker.debug("Error populating base prompt: " + e.getMessage());
}
}

public String createGuardrailPolicy() {
try {
User user = getSUser();
Expand Down Expand Up @@ -99,6 +155,7 @@ public String createGuardrailPolicy() {
updates.add(Updates.set("regexPatternsV2", policy.getRegexPatternsV2()));
updates.add(Updates.set("contentFiltering", policy.getContentFiltering()));
updates.add(Updates.set("llmRule", policy.getLlmRule()));
updates.add(Updates.set("basePromptRule", policy.getBasePromptRule()));
updates.add(Updates.set("selectedMcpServers", policy.getSelectedMcpServers()));
updates.add(Updates.set("selectedAgentServers", policy.getSelectedAgentServers()));
updates.add(Updates.set("selectedMcpServersV2", policy.getSelectedMcpServersV2()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,8 @@ function GuardrailPolicies() {
contentFiltering: guardrailData.contentFilters || {},
// Add LLM policy if present
...(guardrailData.llmRule ? { llmRule: guardrailData.llmRule } : {}),
// Add Base Prompt Rule if present
...(guardrailData.basePromptRule ? { basePromptRule: guardrailData.basePromptRule } : {}),
applyOnResponse: guardrailData.applyOnResponse || false,
applyOnRequest: guardrailData.applyOnRequest || false,
url: guardrailData.url || '',
Expand Down
Loading
Loading