Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package com.akto.action;

import com.akto.dao.ApiCollectionsDao;
import com.akto.dao.GuardrailPoliciesDao;
import com.akto.dao.context.Context;
import com.akto.dto.ApiCollection;
import com.akto.dto.GuardrailPolicies;
import com.akto.dto.User;
import com.akto.log.LoggerMaker;
Expand All @@ -10,6 +12,7 @@
import com.mongodb.client.model.Filters;
import com.mongodb.client.model.UpdateOptions;
import com.mongodb.client.model.Updates;
import org.apache.commons.lang3.StringUtils;

import lombok.Getter;
import lombok.Setter;
Expand All @@ -20,6 +23,7 @@
import java.util.List;
import java.util.Map;


public class GuardrailPoliciesAction extends UserAction {
private static final LoggerMaker loggerMaker = new LoggerMaker(GuardrailPoliciesAction.class, LogDb.DASHBOARD);

Expand All @@ -43,6 +47,11 @@ public String fetchGuardrailPolicies() {
this.guardrailPolicies = GuardrailPoliciesDao.instance.findAllSortedByCreatedTimestamp(0, 20);
this.total = GuardrailPoliciesDao.instance.getTotalCount();

// Populate basePrompt for policies with autoDetect enabled
for (GuardrailPolicies policy : guardrailPolicies) {
populateBasePromptIfNeeded(policy);
}

loggerMaker.info("Fetched " + guardrailPolicies.size() + " guardrail policies out of " + total + " total");

return SUCCESS.toUpperCase();
Expand All @@ -52,6 +61,53 @@ public String fetchGuardrailPolicies() {
}
}

/**
* Populates basePrompt in basePromptRule if:
* 1. basePromptRule exists and is enabled
* 2. autoDetect is true
* 3. basePrompt is not already set (or is empty)
* 4. There are selected agent servers
*
* Fetches detectedBasePrompt from the first selected agent collection
*/
private void populateBasePromptIfNeeded(GuardrailPolicies policy) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Duplicate logic in cyborg as well, can think of keeping it in some common place like libs/dao/src/main/java/com/akto/util
  2. Also, we are not showing the auto detected base prompt in UI. Why do we need this then?

try {
GuardrailPolicies.BasePromptRule basePromptRule = policy.getBasePromptRule();
if (basePromptRule == null || !basePromptRule.isEnabled() || !basePromptRule.isAutoDetect()) {
return;
}

// If basePrompt is already set, use it
if (StringUtils.isNotBlank(basePromptRule.getBasePrompt())) {
return;
}

// Get selected agent servers (prefer V2 format, fallback to old format)
List<GuardrailPolicies.SelectedServer> agentServers = policy.getEffectiveSelectedAgentServers();
if (agentServers == null || agentServers.isEmpty()) {
return;
}

// Try to fetch detected base prompt from the first selected agent collection
try {
int firstAgentCollectionId = Integer.parseInt(agentServers.get(0).getId());
ApiCollection agentCollection = ApiCollectionsDao.instance.getMeta(firstAgentCollectionId);

if (agentCollection != null && StringUtils.isNotBlank(agentCollection.getDetectedBasePrompt())) {
basePromptRule.setBasePrompt(agentCollection.getDetectedBasePrompt());
loggerMaker.debug("Populated basePrompt from collection " + firstAgentCollectionId +
" for policy: " + policy.getName());
}
} catch (NumberFormatException e) {
loggerMaker.debug("Invalid agent collection ID format: " + agentServers.get(0).getId());
} catch (Exception e) {
loggerMaker.debug("Error fetching detected base prompt for policy " + policy.getName() + ": " + e.getMessage());
}
} catch (Exception e) {
loggerMaker.debug("Error populating base prompt: " + e.getMessage());
}
}

public String createGuardrailPolicy() {
try {
User user = getSUser();
Expand Down Expand Up @@ -96,6 +152,7 @@ public String createGuardrailPolicy() {
updates.add(Updates.set("regexPatternsV2", policy.getRegexPatternsV2()));
updates.add(Updates.set("contentFiltering", policy.getContentFiltering()));
updates.add(Updates.set("llmRule", policy.getLlmRule()));
updates.add(Updates.set("basePromptRule", policy.getBasePromptRule()));
updates.add(Updates.set("selectedMcpServers", policy.getSelectedMcpServers()));
updates.add(Updates.set("selectedAgentServers", policy.getSelectedAgentServers()));
updates.add(Updates.set("selectedMcpServersV2", policy.getSelectedMcpServersV2()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,8 @@ function GuardrailPolicies() {
contentFiltering: guardrailData.contentFilters || {},
// Add LLM policy if present
...(guardrailData.llmRule ? { llmRule: guardrailData.llmRule } : {}),
// Add Base Prompt Rule if present
...(guardrailData.basePromptRule ? { basePromptRule: guardrailData.basePromptRule } : {}),
applyOnResponse: guardrailData.applyOnResponse || false,
applyOnRequest: guardrailData.applyOnRequest || false,
url: guardrailData.url || '',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ const CreateGuardrailModal = ({ isOpen, onClose, onSave, editingPolicy = null, i
const [enableLlmRule, setEnableLlmRule] = useState(false);
const [llmConfidenceScore, setLlmConfidenceScore] = useState(0.5);

// Step 6.5: Base Prompt Rule
const [enableBasePromptRule, setEnableBasePromptRule] = useState(false);
const [basePrompt, setBasePrompt] = useState("");
const [basePromptAutoDetect, setBasePromptAutoDetect] = useState(true);
const [basePromptConfidenceScore, setBasePromptConfidenceScore] = useState(0.5);

// Step 7: URL and Confidence Score
const [url, setUrl] = useState("");
const [confidenceScore, setConfidenceScore] = useState(25); // Start with 25 (first checkpoint)
Expand Down Expand Up @@ -162,12 +168,18 @@ const CreateGuardrailModal = ({ isOpen, onClose, onSave, editingPolicy = null, i
},
{
number: 7,
title: "Base Prompt Rule",
optional: true,
summary: enableBasePromptRule ? `Enabled${basePromptAutoDetect ? ' (Auto-detect)' : ''}${basePrompt ? ` - ${basePrompt.substring(0, 30)}${basePrompt.length > 30 ? '...' : ''}` : ''}, Confidence: ${basePromptConfidenceScore.toFixed(2)}` : null
},
{
number: 8,
title: "URL and Confidence Score",
optional: true,
summary: url ? `URL: ${url.substring(0, 30)}${url.length > 30 ? '...' : ''}, Confidence: ${confidenceScore}` : null
},
{
number: 8,
number: 9,
title: "Server and application settings",
optional: false,
summary: (selectedMcpServers.length > 0 || selectedAgentServers.length > 0)
Expand Down Expand Up @@ -289,6 +301,10 @@ const CreateGuardrailModal = ({ isOpen, onClose, onSave, editingPolicy = null, i
setLlmRule("");
setEnableLlmRule(false);
setLlmConfidenceScore(0.5);
setEnableBasePromptRule(false);
setBasePrompt("");
setBasePromptAutoDetect(true);
setBasePromptConfidenceScore(0.5);
setUrl("");
setConfidenceScore(25);
setUrlError("");
Expand Down Expand Up @@ -359,6 +375,19 @@ const CreateGuardrailModal = ({ isOpen, onClose, onSave, editingPolicy = null, i
setLlmConfidenceScore(0.5);
}

// Base Prompt Rule
if (policy.basePromptRule) {
setEnableBasePromptRule(policy.basePromptRule.enabled || false);
setBasePrompt(policy.basePromptRule.basePrompt || "");
setBasePromptAutoDetect(policy.basePromptRule.autoDetect !== undefined ? policy.basePromptRule.autoDetect : true);
setBasePromptConfidenceScore(policy.basePromptRule.confidenceScore !== undefined ? policy.basePromptRule.confidenceScore : 0.5);
} else {
setEnableBasePromptRule(false);
setBasePrompt("");
setBasePromptAutoDetect(true);
setBasePromptConfidenceScore(0.5);
}

// URL and Confidence Score
setUrl(policy.url || "");
// Map existing confidence score to nearest checkpoint
Expand Down Expand Up @@ -406,7 +435,7 @@ const CreateGuardrailModal = ({ isOpen, onClose, onSave, editingPolicy = null, i
};

const handleSkipToServers = () => {
setCurrentStep(8);
setCurrentStep(9);
};

const handleSave = async () => {
Expand Down Expand Up @@ -465,6 +494,14 @@ const CreateGuardrailModal = ({ isOpen, onClose, onSave, editingPolicy = null, i
confidenceScore: llmConfidenceScore
}
} : {}),
...(enableBasePromptRule ? {
basePromptRule: {
enabled: true,
basePrompt: basePromptAutoDetect ? null : basePrompt.trim(), // Only save manual prompt, auto-detect will fetch on read
autoDetect: basePromptAutoDetect,
confidenceScore: basePromptConfidenceScore
}
} : {}),
url: url || null,
confidenceScore: confidenceScore,
selectedMcpServers: selectedMcpServers, // Old format (just IDs)
Expand Down Expand Up @@ -958,6 +995,62 @@ const CreateGuardrailModal = ({ isOpen, onClose, onSave, editingPolicy = null, i
);

const renderStep7 = () => (
<LegacyCard sectioned>
<VerticalStack gap="4">
<Text variant="headingMd">Base Prompt Rule</Text>
<Text variant="bodyMd" tone="subdued">
Configure a base prompt rule to check the intent of user input in agent prompts with placeholders like {`{var}`} or {`{}`}.
</Text>

<Checkbox
label="Enable base prompt rule"
checked={enableBasePromptRule}
onChange={setEnableBasePromptRule}
helpText="When enabled, the guardrail will analyze user inputs that fill placeholders in the base prompt."
/>

{enableBasePromptRule && (
<>
<Checkbox
label="Auto-detect base prompt from traffic"
checked={basePromptAutoDetect}
onChange={setBasePromptAutoDetect}
helpText="Automatically detect the base prompt pattern from agent traffic. If disabled, you must provide the base prompt manually."
/>

{!basePromptAutoDetect && (
<TextField
label="Base Prompt Template"
value={basePrompt}
onChange={setBasePrompt}
multiline={5}
placeholder="You are a helpful assistant. Answer the following question: {}"
helpText="Provide the base prompt template with placeholders using {} or {var_name} syntax."
/>
)}

<Box>
<Text variant="bodyMd" fontWeight="medium">Confidence Score: {basePromptConfidenceScore.toFixed(2)}</Text>
<Box paddingBlockStart="2">
<RangeSlider
label=""
value={basePromptConfidenceScore}
min={0}
max={1}
step={0.01}
output
onChange={setBasePromptConfidenceScore}
helpText="Set the confidence threshold (0-1). Higher values require more confidence to block content."
/>
</Box>
</Box>
</>
)}
</VerticalStack>
</LegacyCard>
);

const renderStep8 = () => (
<LegacyCard sectioned>
<VerticalStack gap="4">
<Text variant="headingMd">URL and Confidence Score</Text>
Expand Down Expand Up @@ -1001,7 +1094,7 @@ const CreateGuardrailModal = ({ isOpen, onClose, onSave, editingPolicy = null, i
</LegacyCard>
);

const renderStep8 = () => (
const renderStep9 = () => (
<LegacyCard sectioned>
<VerticalStack gap="4">
<Text variant="headingMd">Server and application settings</Text>
Expand Down Expand Up @@ -1069,6 +1162,7 @@ const CreateGuardrailModal = ({ isOpen, onClose, onSave, editingPolicy = null, i
case 6: return renderStep6();
case 7: return renderStep7();
case 8: return renderStep8();
case 9: return renderStep9();
default: return renderStep1();
}
};
Expand All @@ -1083,7 +1177,7 @@ const CreateGuardrailModal = ({ isOpen, onClose, onSave, editingPolicy = null, i
});
}

if (currentStep > 1 && currentStep < 8) {
if (currentStep > 1 && currentStep < 9) {
actions.push({
content: "Skip to Server settings",
onAction: handleSkipToServers
Expand All @@ -1094,14 +1188,14 @@ const CreateGuardrailModal = ({ isOpen, onClose, onSave, editingPolicy = null, i
};

const getPrimaryAction = () => {
if (currentStep === 8) {
if (currentStep === 9) {
return {
content: isEditMode ? "Update Guardrail" : "Create Guardrail",
onAction: handleSave,
loading: loading,
disabled: !name.trim() || !blockedMessage.trim() || urlError
};
} else if (currentStep < 8) {
} else if (currentStep < 9) {
return {
content: "Next",
onAction: handleNext,
Expand Down
11 changes: 11 additions & 0 deletions libs/dao/src/main/java/com/akto/dto/ApiCollection.java
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ public class ApiCollection {
String registryStatus;
public static final String REGISTRY_STATUS = "registryStatus";

String detectedBasePrompt;
public static final String DETECTED_BASE_PROMPT = "detectedBasePrompt";

private static final List<String> ENV_KEYWORDS_WITH_DOT = Arrays.asList(
"staging", "preprod", "qa", "demo", "dev", "test", "svc",
"localhost", "local", "intranet", "lan", "example", "invalid",
Expand Down Expand Up @@ -467,4 +470,12 @@ public String getRegistryStatus() {
public void setRegistryStatus(String registryStatus) {
this.registryStatus = registryStatus;
}

public String getDetectedBasePrompt() {
return detectedBasePrompt;
}

public void setDetectedBasePrompt(String detectedBasePrompt) {
this.detectedBasePrompt = detectedBasePrompt;
}
}
23 changes: 22 additions & 1 deletion libs/dao/src/main/java/com/akto/dto/GuardrailPolicies.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ public class GuardrailPolicies {

private LLMRule llmRule;

// Step 6.5: Base Prompt Rule - for checking intent of user input in agent base prompts with placeholders
private BasePromptRule basePromptRule;

// Step 7: Server and application settings (old format - backward compatibility)
private List<String> selectedMcpServers;
private List<String> selectedAgentServers;
Expand Down Expand Up @@ -118,7 +121,7 @@ public GuardrailPolicies(String name, String description, String blockedMessage,
int updatedTimestamp, String createdBy, String updatedBy, String selectedCollection,
String selectedModel, List<DeniedTopic> deniedTopics, List<PiiType> piiTypes,
List<String> regexPatterns, List<RegexPattern> regexPatternsV2, Map<String, Object> contentFiltering,
LLMRule llmRule,
LLMRule llmRule, BasePromptRule basePromptRule,
List<String> selectedMcpServers, List<String> selectedAgentServers,
List<SelectedServer> selectedMcpServersV2, List<SelectedServer> selectedAgentServersV2,
boolean applyOnResponse, boolean applyOnRequest, String url, double confidenceScore, boolean active) {
Expand All @@ -138,6 +141,7 @@ public GuardrailPolicies(String name, String description, String blockedMessage,
this.regexPatternsV2 = regexPatternsV2;
this.contentFiltering = contentFiltering;
this.llmRule = llmRule;
this.basePromptRule = basePromptRule;
this.selectedMcpServers = selectedMcpServers;
this.selectedAgentServers = selectedAgentServers;
this.selectedMcpServersV2 = selectedMcpServersV2;
Expand Down Expand Up @@ -217,4 +221,21 @@ public LLMRule(boolean enabled, String userPrompt, double confidenceScore) {
this.confidenceScore = confidenceScore;
}
}

@Getter
@Setter
@NoArgsConstructor
public static class BasePromptRule {
private boolean enabled;
private String basePrompt; // Base prompt with placeholders like {var} or {}
private boolean autoDetect; // Whether to auto-detect base_prompt from traffic
private double confidenceScore;

public BasePromptRule(boolean enabled, String basePrompt, boolean autoDetect, double confidenceScore) {
this.enabled = enabled;
this.basePrompt = basePrompt;
this.autoDetect = autoDetect;
this.confidenceScore = confidenceScore;
}
}
}
Loading