package main
import (
"context"
"fmt"
"os"
"time"
"github.com/maximhq/bifrost/core/schemas"
)
// ProviderOpenAICustom is the name under which the OpenAI-compatible custom
// endpoint (configured in GetConfigForProvider) is registered.
const ProviderOpenAICustom schemas.ModelProvider = "openai-custom"

// MyAccount supplies Bifrost with the providers, API keys, and per-provider
// configuration used by this program. It holds no state: API keys are read
// from environment variables each time they are requested.
type MyAccount struct{}
// GetConfiguredProviders lists the providers this account serves: the built-in
// OpenAI provider followed by the custom OpenAI-compatible provider
// ProviderOpenAICustom. It always returns a nil error.
func (a *MyAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) {
	providers := make([]schemas.ModelProvider, 0, 2)
	providers = append(providers, schemas.OpenAI, ProviderOpenAICustom)
	return providers, nil
}
// GetKeysForProvider returns a single API key for provider, read at call time
// from OPENAI_API_KEY (schemas.OpenAI) or OPENAI_CUSTOM_API_KEY
// (ProviderOpenAICustom). The key carries weight 1.0 and an empty, non-nil
// model list (presumably "no model restriction" — confirm against Bifrost docs).
//
// An unset environment variable produces a key with an empty Value, not an
// error. ctx is accepted for the interface but not used. Any other provider
// yields a "not supported" error.
func (a *MyAccount) GetKeysForProvider(ctx context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) {
	var envVar string
	switch provider {
	case schemas.OpenAI:
		envVar = "OPENAI_API_KEY"
	case ProviderOpenAICustom:
		// API key for the OpenAI-compatible endpoint.
		envVar = "OPENAI_CUSTOM_API_KEY"
	default:
		return nil, fmt.Errorf("provider %s not supported", provider)
	}

	key := schemas.Key{
		Value:  os.Getenv(envVar),
		Models: []string{},
		Weight: 1.0,
	}
	return []schemas.Key{key}, nil
}
// GetConfigForProvider returns the network and concurrency settings Bifrost
// should use for provider.
//
// The built-in OpenAI provider uses schemas.DefaultNetworkConfig and
// schemas.DefaultConcurrencyAndBufferSize. ProviderOpenAICustom targets an
// OpenAI-compatible endpoint and speaks the OpenAI protocol. Only chat
// completion (plain and streaming) is enabled for it, and both request types
// are routed to "/v1/chat/completions". Any other provider yields a
// "not supported" error.
func (a *MyAccount) GetConfigForProvider(provider schemas.ModelProvider) (*schemas.ProviderConfig, error) {
	if provider == schemas.OpenAI {
		return &schemas.ProviderConfig{
			NetworkConfig:            schemas.DefaultNetworkConfig,
			ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize,
		}, nil
	}
	if provider != ProviderOpenAICustom {
		return nil, fmt.Errorf("provider %s not supported", provider)
	}

	// Replace the placeholder BaseURL with the real OpenAI-compatible endpoint.
	network := schemas.NetworkConfig{
		BaseURL:                        "https://your-openai-compatible-endpoint.com",
		DefaultRequestTimeoutInSeconds: 60,
		MaxRetries:                     1,
		RetryBackoffInitial:            100 * time.Millisecond,
		RetryBackoffMax:                2 * time.Second,
	}

	concurrency := schemas.ConcurrencyAndBufferSize{
		Concurrency: 3,
		BufferSize:  10,
	}

	// Every request type is listed explicitly so it is obvious that only the
	// two chat-completion variants are enabled for this endpoint.
	allowed := &schemas.AllowedRequests{
		TextCompletion:       false,
		TextCompletionStream: false,
		ChatCompletion:       true,
		ChatCompletionStream: true,
		Responses:            false,
		ResponsesStream:      false,
		Embedding:            false,
		Speech:               false,
		SpeechStream:         false,
		Transcription:        false,
		TranscriptionStream:  false,
	}

	// Plain and streaming chat requests share the same endpoint path.
	paths := map[schemas.RequestType]string{
		schemas.ChatCompletionRequest:       "/v1/chat/completions",
		schemas.ChatCompletionStreamRequest: "/v1/chat/completions",
	}

	// BaseProviderType makes the custom provider reuse the OpenAI protocol.
	custom := &schemas.CustomProviderConfig{
		BaseProviderType:     schemas.OpenAI,
		AllowedRequests:      allowed,
		RequestPathOverrides: paths,
	}

	return &schemas.ProviderConfig{
		NetworkConfig:            network,
		ConcurrencyAndBufferSize: concurrency,
		CustomProviderConfig:     custom,
	}, nil
}