Skip to content

Commit 64ebfc0

Browse files
MatteoCalabro-TomTomccojocar
authored andcommitted
feat(autofix): update gemini sdk and add anthropic claude
* upgrade gemini sdk to google.golang.org/genai v1.25.0 * support newer gemini models * add anthropic claude
1 parent 506407e commit 64ebfc0

File tree

7 files changed

+229
-158
lines changed

7 files changed

+229
-158
lines changed

autofix/ai.go

Lines changed: 23 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -4,97 +4,53 @@ import (
44
"context"
55
"errors"
66
"fmt"
7+
"strings"
78
"time"
89

9-
"github.com/google/generative-ai-go/genai"
10-
"google.golang.org/api/option"
11-
1210
"github.com/securego/gosec/v2/issue"
1311
)
1412

1513
const (
16-
GeminiModel = "gemini-1.5-flash"
17-
AIPrompt = `Provide a brief explanation and a solution to fix this security issue
14+
AIProviderFlagHelp = `AI API provider to generate auto fixes to issues. Valid options are:
15+
- gemini-2.5-pro, gemini-2.5-flash, gemini-2.5-flash-lite, gemini-2.0-flash, gemini-2.0-flash-lite (gemini, default);
16+
- claude-sonnet-4-0 (claude, default), claude-opus-4-0, claude-opus-4-1, claude-sonnet-3-7`
17+
18+
AIPrompt = `Provide a brief explanation and a solution to fix this security issue
1819
in Go programming language: %q.
1920
Answer in markdown format and keep the response limited to 200 words.`
20-
GeminiProvider = "gemini"
2121

2222
timeout = 30 * time.Second
2323
)
2424

25-
// GenAIClient defines the interface for the GenAI client.
2625
type GenAIClient interface {
27-
// Close clean up and close the client.
28-
Close() error
29-
// GenerativeModel build the generative mode.
30-
GenerativeModel(name string) GenAIGenerativeModel
26+
GenerateSolution(ctx context.Context, prompt string) (string, error)
3127
}
3228

33-
// GenAIGenerativeModel defines the interface for the Generative Model.
34-
type GenAIGenerativeModel interface {
35-
// GenerateContent generates an response for given prompt.
36-
GenerateContent(ctx context.Context, prompt string) (string, error)
37-
}
38-
39-
// genAIClientWrapper wraps the genai.Client to implement GenAIClient.
40-
type genAIClientWrapper struct {
41-
client *genai.Client
42-
}
43-
44-
// Close closes the gen AI client.
45-
func (w *genAIClientWrapper) Close() error {
46-
return w.client.Close()
47-
}
48-
49-
// GenerativeModel builds the generative Model.
50-
func (w *genAIClientWrapper) GenerativeModel(name string) GenAIGenerativeModel {
51-
return &genAIGenerativeModelWrapper{model: w.client.GenerativeModel(name)}
52-
}
53-
54-
// genAIGenerativeModelWrapper wraps the genai.GenerativeModel to implement GenAIGenerativeModel
55-
type genAIGenerativeModelWrapper struct {
56-
// model is the underlying generative model
57-
model *genai.GenerativeModel
58-
}
59-
60-
// GenerateContent generates a response for the given prompt using gemini API.
61-
func (w *genAIGenerativeModelWrapper) GenerateContent(ctx context.Context, prompt string) (string, error) {
62-
resp, err := w.model.GenerateContent(ctx, genai.Text(prompt))
63-
if err != nil {
64-
return "", fmt.Errorf("generating autofix: %w", err)
65-
}
66-
if len(resp.Candidates) == 0 {
67-
return "", errors.New("no autofix returned by gemini")
68-
}
69-
70-
if len(resp.Candidates[0].Content.Parts) == 0 {
71-
return "", errors.New("nothing found in the first autofix returned by gemini")
72-
}
73-
74-
// Return the first candidate
75-
return fmt.Sprintf("%+v", resp.Candidates[0].Content.Parts[0]), nil
76-
}
29+
// GenerateSolution generates a solution for the given issues using the specified AI provider
30+
func GenerateSolution(model, aiAPIKey string, issues []*issue.Issue) (err error) {
31+
var client GenAIClient
7732

78-
// NewGenAIClient creates a new gemini API client.
79-
func NewGenAIClient(ctx context.Context, aiAPIKey, endpoint string) (GenAIClient, error) {
80-
clientOptions := []option.ClientOption{option.WithAPIKey(aiAPIKey)}
81-
if endpoint != "" {
82-
clientOptions = append(clientOptions, option.WithEndpoint(endpoint))
33+
switch {
34+
case strings.HasPrefix(model, "claude"):
35+
client, err = NewClaudeClient(model, aiAPIKey)
36+
case strings.HasPrefix(model, "gemini"):
37+
client, err = NewGeminiClient(model, aiAPIKey)
8338
}
8439

85-
client, err := genai.NewClient(ctx, clientOptions...)
86-
if err != nil {
87-
return nil, fmt.Errorf("calling gemini API: %w", err)
40+
switch {
41+
case err != nil:
42+
return fmt.Errorf("initializing AI client: %w", err)
43+
case client == nil:
44+
return fmt.Errorf("unsupported AI backend: %s", model)
8845
}
8946

90-
return &genAIClientWrapper{client: client}, nil
47+
return generateSolution(client, issues)
9148
}
9249

93-
func generateSolutionByGemini(client GenAIClient, issues []*issue.Issue) error {
50+
func generateSolution(client GenAIClient, issues []*issue.Issue) error {
9451
ctx, cancel := context.WithTimeout(context.Background(), timeout)
9552
defer cancel()
9653

97-
model := client.GenerativeModel(GeminiModel)
9854
cachedAutofix := make(map[string]string)
9955
for _, issue := range issues {
10056
if val, ok := cachedAutofix[issue.What]; ok {
@@ -103,7 +59,7 @@ func generateSolutionByGemini(client GenAIClient, issues []*issue.Issue) error {
10359
}
10460

10561
prompt := fmt.Sprintf(AIPrompt, issue.What)
106-
resp, err := model.GenerateContent(ctx, prompt)
62+
resp, err := client.GenerateSolution(ctx, prompt)
10763
if err != nil {
10864
return fmt.Errorf("generating autofix with gemini: %w", err)
10965
}
@@ -117,26 +73,3 @@ func generateSolutionByGemini(client GenAIClient, issues []*issue.Issue) error {
11773
}
11874
return nil
11975
}
120-
121-
// GenerateSolution generates a solution for the given issues using the specified AI provider
122-
func GenerateSolution(aiAPIProvider, aiAPIKey, endpoint string, issues []*issue.Issue) error {
123-
ctx, cancel := context.WithTimeout(context.Background(), timeout)
124-
defer cancel()
125-
126-
var client GenAIClient
127-
128-
switch aiAPIProvider {
129-
case GeminiProvider:
130-
var err error
131-
client, err = NewGenAIClient(ctx, aiAPIKey, endpoint)
132-
if err != nil {
133-
return fmt.Errorf("generating autofix: %w", err)
134-
}
135-
default:
136-
return errors.New("ai provider not supported")
137-
}
138-
139-
defer client.Close()
140-
141-
return generateSolutionByGemini(client, issues)
142-
}

autofix/ai_test.go

Lines changed: 12 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,7 @@ type MockGenAIClient struct {
1717
mock.Mock
1818
}
1919

20-
func (m *MockGenAIClient) Close() error {
21-
args := m.Called()
22-
return args.Error(0)
23-
}
24-
25-
func (m *MockGenAIClient) GenerativeModel(name string) GenAIGenerativeModel {
26-
args := m.Called(name)
27-
return args.Get(0).(GenAIGenerativeModel)
28-
}
29-
30-
// MockGenAIGenerativeModel is a mock of the GenAIGenerativeModel interface
31-
type MockGenAIGenerativeModel struct {
32-
mock.Mock
33-
}
34-
35-
func (m *MockGenAIGenerativeModel) GenerateContent(ctx context.Context, prompt string) (string, error) {
20+
func (m *MockGenAIClient) GenerateSolution(ctx context.Context, prompt string) (string, error) {
3621
args := m.Called(ctx, prompt)
3722
return args.String(0), args.Error(1)
3823
}
@@ -44,17 +29,15 @@ func TestGenerateSolutionByGemini_Success(t *testing.T) {
4429
}
4530

4631
mockClient := new(MockGenAIClient)
47-
mockModel := new(MockGenAIGenerativeModel)
48-
mockClient.On("GenerativeModel", GeminiModel).Return(mockModel).Once()
49-
mockModel.On("GenerateContent", mock.Anything, mock.Anything).Return("Autofix for issue 1", nil).Once()
32+
mockClient.On("GenerateSolution", mock.Anything, mock.Anything).Return("Autofix for issue 1", nil).Once()
5033

5134
// Act
52-
err := generateSolutionByGemini(mockClient, issues)
35+
err := generateSolution(mockClient, issues)
5336

5437
// Assert
5538
require.NoError(t, err)
5639
assert.Equal(t, []*issue.Issue{{What: "Example issue 1", Autofix: "Autofix for issue 1"}}, issues)
57-
mock.AssertExpectationsForObjects(t, mockClient, mockModel)
40+
mock.AssertExpectationsForObjects(t, mockClient)
5841
}
5942

6043
func TestGenerateSolutionByGemini_NoCandidates(t *testing.T) {
@@ -64,16 +47,14 @@ func TestGenerateSolutionByGemini_NoCandidates(t *testing.T) {
6447
}
6548

6649
mockClient := new(MockGenAIClient)
67-
mockModel := new(MockGenAIGenerativeModel)
68-
mockClient.On("GenerativeModel", GeminiModel).Return(mockModel).Once()
69-
mockModel.On("GenerateContent", mock.Anything, mock.Anything).Return("", nil).Once()
50+
mockClient.On("GenerateSolution", mock.Anything, mock.Anything).Return("", nil).Once()
7051

7152
// Act
72-
err := generateSolutionByGemini(mockClient, issues)
53+
err := generateSolution(mockClient, issues)
7354

7455
// Assert
7556
require.EqualError(t, err, "no autofix returned by gemini")
76-
mock.AssertExpectationsForObjects(t, mockClient, mockModel)
57+
mock.AssertExpectationsForObjects(t, mockClient)
7758
}
7859

7960
func TestGenerateSolutionByGemini_APIError(t *testing.T) {
@@ -83,16 +64,14 @@ func TestGenerateSolutionByGemini_APIError(t *testing.T) {
8364
}
8465

8566
mockClient := new(MockGenAIClient)
86-
mockModel := new(MockGenAIGenerativeModel)
87-
mockClient.On("GenerativeModel", GeminiModel).Return(mockModel).Once()
88-
mockModel.On("GenerateContent", mock.Anything, mock.Anything).Return("", errors.New("API error")).Once()
67+
mockClient.On("GenerateSolution", mock.Anything, mock.Anything).Return("", errors.New("API error")).Once()
8968

9069
// Act
91-
err := generateSolutionByGemini(mockClient, issues)
70+
err := generateSolution(mockClient, issues)
9271

9372
// Assert
9473
require.EqualError(t, err, "generating autofix with gemini: API error")
95-
mock.AssertExpectationsForObjects(t, mockClient, mockModel)
74+
mock.AssertExpectationsForObjects(t, mockClient)
9675
}
9776

9877
func TestGenerateSolution_UnsupportedProvider(t *testing.T) {
@@ -102,8 +81,8 @@ func TestGenerateSolution_UnsupportedProvider(t *testing.T) {
10281
}
10382

10483
// Act
105-
err := GenerateSolution("unsupported-provider", "test-api-key", "", issues)
84+
err := GenerateSolution("unsupported-provider", "test-api-key", issues)
10685

10786
// Assert
108-
require.EqualError(t, err, "ai provider not supported")
87+
require.EqualError(t, err, "unsupported AI backend: unsupported-provider")
10988
}

autofix/claude.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package autofix
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
8+
"github.com/anthropics/anthropic-sdk-go"
9+
"github.com/anthropics/anthropic-sdk-go/option"
10+
)
11+
12+
const (
13+
ModelClaudeOpus4_0 = anthropic.ModelClaudeOpus4_0
14+
ModelClaudeOpus4_1 = anthropic.ModelClaudeOpus4_1_20250805
15+
ModelClaudeSonnet4_0 = anthropic.ModelClaudeSonnet4_0
16+
)
17+
18+
var _ GenAIClient = (*claudeWrapper)(nil)
19+
20+
type claudeWrapper struct {
21+
client anthropic.Client
22+
model anthropic.Model
23+
}
24+
25+
func NewClaudeClient(model, apiKey string) (GenAIClient, error) {
26+
var options []option.RequestOption
27+
28+
if apiKey != "" {
29+
options = append(options, option.WithAPIKey(apiKey))
30+
}
31+
32+
anthropicModel := parseAnthropicModel(model)
33+
34+
return &claudeWrapper{
35+
client: anthropic.NewClient(options...),
36+
model: anthropicModel,
37+
}, nil
38+
}
39+
40+
func (c *claudeWrapper) GenerateSolution(ctx context.Context, prompt string) (string, error) {
41+
resp, err := c.client.Messages.New(ctx, anthropic.MessageNewParams{
42+
Model: anthropic.Model(c.model),
43+
MaxTokens: 1024,
44+
Messages: []anthropic.MessageParam{
45+
anthropic.NewUserMessage(anthropic.NewTextBlock(prompt)),
46+
},
47+
})
48+
if err != nil {
49+
return "", fmt.Errorf("generating autofix: %w", err)
50+
}
51+
52+
if resp == nil || len(resp.Content) == 0 {
53+
return "", errors.New("no autofix returned by claude")
54+
}
55+
56+
if len(resp.Content[0].Text) == 0 {
57+
return "", errors.New("nothing found in the first autofix returned by claude")
58+
}
59+
60+
return resp.Content[0].Text, nil
61+
}
62+
63+
func parseAnthropicModel(model string) anthropic.Model {
64+
switch model {
65+
case "claude-sonnet-3-7":
66+
return anthropic.ModelClaude3_7SonnetLatest
67+
case "claude-opus", "claude-opus-4-0":
68+
return anthropic.ModelClaudeOpus4_0
69+
case "claude-opus-4-1":
70+
return anthropic.ModelClaudeOpus4_1_20250805
71+
}
72+
73+
return anthropic.ModelClaudeSonnet4_0
74+
}

0 commit comments

Comments
 (0)