const OpenAI = require("openai");
const Provider = require("./ai-provider.js");
const InheritMultiple = require("./helpers/classes.js");
const UnTooled = require("./helpers/untooled.js");
const { tooledStream, tooledComplete } = require("./helpers/tooled.js");
const { RetryError } = require("../error.js");
const {
  DockerModelRunnerLLM,
  parseDockerModelRunnerEndpoint,
} = require("../../../AiProviders/dockerModelRunner/index.js");

/**
 * The agent provider for the Docker Model Runner.
 */
class DockerModelRunnerProvider extends InheritMultiple([Provider, UnTooled]) {
  model;

  /**
   * @param {{model?: string}} config
   */
  constructor(config = {}) {
    super();
    const model =
      config?.model || process.env.DOCKER_MODEL_RUNNER_LLM_MODEL_PREF || null;
    const client = new OpenAI({
      baseURL: parseDockerModelRunnerEndpoint(
        process.env.DOCKER_MODEL_RUNNER_BASE_PATH
      ),
      apiKey: null,
      maxRetries: 3,
    });

    this._client = client;
    this.model = model;
    this.verbose = true;
    this._supportsToolCalling = null;
  }

  get client() {
    return this._client;
  }

  get supportsAgentStreaming() {
    return true;
  }

  /**
   * Whether this provider supports native OpenAI-compatible tool calling.
   * Override in a subclass and return true to use native tool calling instead of UnTooled.
   * @returns {Promise<boolean>}
   */
  async supportsNativeToolCalling() {
    if (this._supportsToolCalling !== null) return this._supportsToolCalling;
    const dmr = new DockerModelRunnerLLM(null, this.model);
    const capabilities = await dmr.getModelCapabilities();
    this._supportsToolCalling = capabilities.tools === true;
    return this._supportsToolCalling;
  }

  async #handleFunctionCallChat({ messages = [] }) {
    return await this.client.chat.completions
      .create({
        model: this.model,
        messages,
      })
      .then((result) => {
        if (!result.hasOwnProperty("choices"))
          throw new Error("Docker Model Runner chat: No results!");
        if (result.choices.length === 0)
          throw new Error("Docker Model Runner chat: No results length!");
        return result.choices[0].message.content;
      })
      .catch((_) => {
        return null;
      });
  }

  async #handleFunctionCallStream({ messages = [] }) {
    return await this.client.chat.completions.create({
      model: this.model,
      stream: true,
      messages,
    });
  }

  /**
   * Stream a chat completion with tool calling support.
   * Uses native tool calling when supported, otherwise falls back to UnTooled.
   */
  async stream(messages, functions = [], eventHandler = null) {
    const useNative =
      functions.length > 0 && (await this.supportsNativeToolCalling());
    if (!useNative) {
      return await UnTooled.prototype.stream.call(
        this,
        messages,
        functions,
        this.#handleFunctionCallStream.bind(this),
        eventHandler
      );
    }

    this.providerLog(
      "Provider.stream (tooled) - will process this chat completion."
    );
    try {
      return await tooledStream(
        this.client,
        this.model,
        messages,
        functions,
        eventHandler,
        { provider: this }
      );
    } catch (error) {
      console.error(error.message, error);
      if (error instanceof OpenAI.AuthenticationError) throw error;
      if (
        error instanceof OpenAI.RateLimitError ||
        error instanceof OpenAI.InternalServerError ||
        error instanceof OpenAI.APIError
      ) {
        throw new RetryError(error.message);
      }
      throw error;
    }
  }

  /**
   * Create a non-streaming completion with tool calling support.
   * Uses native tool calling when supported, otherwise falls back to UnTooled.
   */
  async complete(messages, functions = []) {
    const useNative =
      functions.length > 0 && (await this.supportsNativeToolCalling());
    if (!useNative) {
      return await UnTooled.prototype.complete.call(
        this,
        messages,
        functions,
        this.#handleFunctionCallChat.bind(this)
      );
    }

    try {
      const result = await tooledComplete(
        this.client,
        this.model,
        messages,
        functions,
        this.getCost.bind(this),
        { provider: this }
      );
      if (result.retryWithError) {
        return this.complete([...messages, result.retryWithError], functions);
      }
      return result;
    } catch (error) {
      if (error instanceof OpenAI.AuthenticationError) throw error;
      if (
        error instanceof OpenAI.RateLimitError ||
        error instanceof OpenAI.InternalServerError ||
        error instanceof OpenAI.APIError
      ) {
        throw new RetryError(error.message);
      }
      throw error;
    }
  }

  /**
   * Get the cost of the completion.
   * Stubbed since Docker Model Runner has no cost basis.
   *
   * @param _usage The usage metrics from the completion.
   * @returns The cost of the completion (always 0).
   */
  getCost(_usage) {
    return 0;
  }
}

module.exports = DockerModelRunnerProvider;
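
// Usage sketch (illustrative only; the require path and model name below are
// assumptions, not taken from this repo):
//
//   const DockerModelRunnerProvider = require("./dockerModelRunner.js");
//   const provider = new DockerModelRunnerProvider({ model: "ai/smollm2" });
//
//   // With no functions, complete() routes through the UnTooled fallback;
//   // with functions and a tool-capable model, it uses native tool calling.
//   const response = await provider.complete([
//     { role: "user", content: "Hello!" },
//   ]);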