Skip to content

Commit ef0030b

Browse files
committed
Add Qwen3Guard support for prompt guard (WIP)
Co-authored-by: Chen Wang <[email protected]> Co-authored-by: Yue Zhu <[email protected]> Signed-off-by: Yue Zhu <[email protected]>
1 parent c5f55f8 commit ef0030b

File tree

2 files changed

+101
-6
lines changed

2 files changed

+101
-6
lines changed

src/semantic-router/pkg/classification/classifier.go

Lines changed: 98 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,22 @@ func (c *ModernBertJailbreakInitializer) Init(modelID string, useCPU bool, numCl
106106
return nil
107107
}
108108

109+
type Qwen3GuardJailbreakInitializer struct{}
110+
111+
func (c *Qwen3GuardJailbreakInitializer) Init(modelID string, useCPU bool, numClasses ...int) error {
112+
err := candle_binding.InitQwen3Guard(modelID)
113+
if err != nil {
114+
return err
115+
}
116+
logging.Infof("Initialized Qwen3Guard jailbreak classifier")
117+
return nil
118+
}
119+
109120
// createJailbreakInitializer creates the appropriate jailbreak initializer based on configuration
110-
func createJailbreakInitializer(useModernBERT bool) JailbreakInitializer {
121+
func createJailbreakInitializer(useModernBERT bool, useQwen3Guard bool) JailbreakInitializer {
122+
if useQwen3Guard {
123+
return &Qwen3GuardJailbreakInitializer{}
124+
}
111125
if useModernBERT {
112126
return &ModernBertJailbreakInitializer{}
113127
}
@@ -130,8 +144,74 @@ func (c *ModernBertJailbreakInference) Classify(text string) (candle_binding.Cla
130144
return candle_binding.ClassifyModernBertJailbreakText(text)
131145
}
132146

147+
type Qwen3GuardJailbreakInference struct{}
148+
149+
func (c *Qwen3GuardJailbreakInference) Classify(text string) (candle_binding.ClassResult, error) {
150+
// Use Qwen3Guard to classify the text
151+
result, err := candle_binding.ClassifyPromptSafety(text)
152+
if err != nil {
153+
return candle_binding.ClassResult{}, fmt.Errorf("qwen3guard classification failed: %w", err)
154+
}
155+
156+
// Convert SafetyClassificationResult to ClassResult
157+
// Class 0 = safe/benign, Class 1 = jailbreak
158+
// Check if "Jailbreak" is in categories or if SafetyLabel is "Unsafe" or "Controversial"
159+
isJailbreak := false
160+
confidence := float32(0.0)
161+
162+
// Check for jailbreak category
163+
for _, cat := range result.Categories {
164+
if cat == "Jailbreak" {
165+
isJailbreak = true
166+
confidence = 0.9 // High confidence if jailbreak category is detected
167+
break
168+
}
169+
}
170+
171+
// If no jailbreak category but unsafe/controversial, still consider it risky
172+
if !isJailbreak && (result.SafetyLabel == "Unsafe" || result.SafetyLabel == "Controversial") {
173+
// Check if any unsafe categories are present
174+
unsafeCategories := []string{"Violent", "Non-violent Illegal Acts", "Sexual Content or Sexual Acts",
175+
"Suicide & Self-Harm", "Unethical Acts", "Politically Sensitive Topics", "Copyright Violation"}
176+
for _, cat := range result.Categories {
177+
for _, unsafeCat := range unsafeCategories {
178+
if cat == unsafeCat {
179+
isJailbreak = true
180+
confidence = 0.7 // Medium confidence for other unsafe content
181+
break
182+
}
183+
}
184+
if isJailbreak {
185+
break
186+
}
187+
}
188+
}
189+
190+
// If safe, set confidence based on safety label
191+
if !isJailbreak {
192+
if result.SafetyLabel == "Safe" {
193+
confidence = 0.95 // High confidence for safe content
194+
} else {
195+
confidence = 0.5 // Low confidence if label is unclear
196+
}
197+
}
198+
199+
class := 0 // safe/benign
200+
if isJailbreak {
201+
class = 1 // jailbreak
202+
}
203+
204+
return candle_binding.ClassResult{
205+
Class: class,
206+
Confidence: confidence,
207+
}, nil
208+
}
209+
133210
// createJailbreakInference creates the appropriate jailbreak inference based on configuration
134-
func createJailbreakInference(useModernBERT bool) JailbreakInference {
211+
func createJailbreakInference(useModernBERT bool, useQwen3Guard bool) JailbreakInference {
212+
if useQwen3Guard {
213+
return &Qwen3GuardJailbreakInference{}
214+
}
135215
if useModernBERT {
136216
return &ModernBertJailbreakInference{}
137217
}
@@ -321,7 +401,7 @@ func newClassifierWithOptions(cfg *config.RouterConfig, options ...option) (*Cla
321401
// allowing flexible deployment scenarios such as gradual migration.
322402
func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, piiMapping *PIIMapping, jailbreakMapping *JailbreakMapping) (*Classifier, error) {
323403
options := []option{
324-
withJailbreak(jailbreakMapping, createJailbreakInitializer(cfg.PromptGuard.UseModernBERT), createJailbreakInference(cfg.PromptGuard.UseModernBERT)),
404+
withJailbreak(jailbreakMapping, createJailbreakInitializer(cfg.PromptGuard.UseModernBERT, cfg.PromptGuard.UseQwen3Guard), createJailbreakInference(cfg.PromptGuard.UseModernBERT, cfg.PromptGuard.UseQwen3Guard)),
325405
withPII(piiMapping, createPIIInitializer(), createPIIInference()),
326406
}
327407

@@ -393,9 +473,21 @@ func (c *Classifier) initializeJailbreakClassifier() error {
393473
return fmt.Errorf("jailbreak detection is not properly configured")
394474
}
395475

396-
numClasses := c.JailbreakMapping.GetJailbreakTypeCount()
397-
if numClasses < 2 {
398-
return fmt.Errorf("not enough jailbreak types for classification, need at least 2, got %d", numClasses)
476+
// Qwen3Guard doesn't require numClasses, but other models do
477+
// For Qwen3Guard, we still need the mapping for type name lookup, but numClasses is optional
478+
var numClasses int
479+
if c.Config.PromptGuard.UseQwen3Guard {
480+
// Qwen3Guard doesn't use numClasses, but we still need at least 2 for the mapping
481+
numClasses = c.JailbreakMapping.GetJailbreakTypeCount()
482+
if numClasses < 2 {
483+
// For Qwen3Guard, we can work with just 2 classes (benign and jailbreak)
484+
numClasses = 2
485+
}
486+
} else {
487+
numClasses = c.JailbreakMapping.GetJailbreakTypeCount()
488+
if numClasses < 2 {
489+
return fmt.Errorf("not enough jailbreak types for classification, need at least 2, got %d", numClasses)
490+
}
399491
}
400492

401493
return c.jailbreakInitializer.Init(c.Config.PromptGuard.ModelID, c.Config.PromptGuard.UseCPU, numClasses)

src/semantic-router/pkg/config/config.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,9 @@ type PromptGuardConfig struct {
336336
// Use ModernBERT for jailbreak detection
337337
UseModernBERT bool `yaml:"use_modernbert"`
338338

339+
// Use Qwen3Guard for jailbreak detection (generative model)
340+
UseQwen3Guard bool `yaml:"use_qwen3guard"`
341+
339342
// Path to the jailbreak type mapping file
340343
JailbreakMappingPath string `yaml:"jailbreak_mapping_path"`
341344
}

0 commit comments

Comments
 (0)