const OpenAI = require("openai"); const Provider = require("./ai-provider.js"); const InheritMultiple = require("./helpers/classes.js"); const UnTooled = require("./helpers/untooled.js"); const { tooledStream, tooledComplete } = require("./helpers/tooled.js"); const { RetryError } = require("../error.js"); /** * The agent provider for the GroqAI provider. * Supports true OpenAI-compatible tool calling when enabled via ENV, * falling back to the UnTooled prompt-based approach otherwise. */ class GroqProvider extends InheritMultiple([Provider, UnTooled]) { model; constructor(config = {}) { const { model = "llama3-8b-8192" } = config; super(); const client = new OpenAI({ baseURL: "https://api.groq.com/openai/v1", apiKey: process.env.GROQ_API_KEY, maxRetries: 3, }); this._client = client; this.model = model; this.verbose = true; this._supportsToolCalling = null; } get client() { return this._client; } get supportsAgentStreaming() { return true; } /** * Whether this provider supports native OpenAI-compatible tool calling. * - Since Groq models vary in tool calling support, we check the ENV. * - If the ENV is not set, we default to false. * @returns {boolean} */ supportsNativeToolCalling() { if (this._supportsToolCalling !== null) return this._supportsToolCalling; const supportsToolCalling = process.env.PROVIDER_SUPPORTS_NATIVE_TOOL_CALLING?.includes("groq"); if (supportsToolCalling) this.providerLog("Groq supports native tool calling is ENABLED via ENV."); else this.providerLog( "Groq supports native tool calling is DISABLED via ENV. Will use UnTooled instead." ); this._supportsToolCalling = supportsToolCalling; return supportsToolCalling; } async #handleFunctionCallChat({ messages = [] }) { return await this.client.chat.completions .create({ model: this.model, messages, }) .then((result) => { if (!result.hasOwnProperty("choices")) throw new Error("GroqAI chat: No results!"); if (result.choices.length === 0) throw new Error("GroqAI chat: No results length!"); return result.choices[0].message.content; }) .catch((_) => { return null; }); } async #handleFunctionCallStream({ messages = [] }) { return await this.client.chat.completions.create({ model: this.model, stream: true, messages, }); } /** * Stream a chat completion with tool calling support. * Uses native tool calling when enabled via ENV, otherwise falls back to UnTooled. */ async stream(messages, functions = [], eventHandler = null) { const useNative = functions.length > 0 && (await this.supportsNativeToolCalling()); if (!useNative) { return await UnTooled.prototype.stream.call( this, messages, functions, this.#handleFunctionCallStream.bind(this), eventHandler ); } this.providerLog( "Provider.stream (tooled) - will process this chat completion." ); try { return await tooledStream( this.client, this.model, messages, functions, eventHandler ); } catch (error) { console.error(error.message, error); if (error instanceof OpenAI.AuthenticationError) throw error; if ( error instanceof OpenAI.RateLimitError || error instanceof OpenAI.InternalServerError || error instanceof OpenAI.APIError ) { throw new RetryError(error.message); } throw error; } } /** * Create a non-streaming completion with tool calling support. * Uses native tool calling when enabled via ENV, otherwise falls back to UnTooled. */ async complete(messages, functions = []) { const useNative = functions.length > 0 && (await this.supportsNativeToolCalling()); if (!useNative) { return await UnTooled.prototype.complete.call( this, messages, functions, this.#handleFunctionCallChat.bind(this) ); } try { const result = await tooledComplete( this.client, this.model, messages, functions, this.getCost.bind(this) ); if (result.retryWithError) { return this.complete([...messages, result.retryWithError], functions); } return result; } catch (error) { if (error instanceof OpenAI.AuthenticationError) throw error; if ( error instanceof OpenAI.RateLimitError || error instanceof OpenAI.InternalServerError || error instanceof OpenAI.APIError ) { throw new RetryError(error.message); } throw error; } } /** * Get the cost of the completion. * @param _usage The completion to get the cost for. * @returns The cost of the completion. */ getCost(_usage) { return 0; } } module.exports = GroqProvider;