refactor: rename pi-* packages to forge-native names (Phase 1)
Rename all four packages/pi-* directories to forge-native names, stripping the 'pi' identity and establishing forge's own: - packages/pi-coding-agent → packages/coding-agent - packages/pi-ai → packages/ai - packages/pi-agent-core → packages/agent-core - packages/pi-tui → packages/tui Package names updated: - @singularity-forge/pi-coding-agent → @singularity-forge/coding-agent - @singularity-forge/pi-ai → @singularity-forge/ai - @singularity-forge/pi-agent-core → @singularity-forge/agent-core - @singularity-forge/pi-tui → @singularity-forge/tui All import references, bare string references, path references, internal variable names (_bundledPi*), and dist files updated. @mariozechner/pi-* third-party compat aliases preserved. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
parent
6725a55591
commit
02a4339a51
576 changed files with 17234 additions and 150217 deletions
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
.sf/metrics.db
BIN
.sf/metrics.db
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -55,14 +55,14 @@
|
|||
},
|
||||
"discuss-milestone": {
|
||||
"minimax/MiniMax-M2.7-highspeed": {
|
||||
"successes": 2,
|
||||
"successes": 3,
|
||||
"failures": 0,
|
||||
"timeouts": 0,
|
||||
"totalTokens": 8639600,
|
||||
"totalCost": 2.0647307100000005,
|
||||
"lastUsed": "2026-05-10T01:43:48.671Z",
|
||||
"totalTokens": 10591636,
|
||||
"totalCost": 2.6534383800000003,
|
||||
"lastUsed": "2026-05-10T08:13:47.678Z",
|
||||
"successRate": 1,
|
||||
"total": 2
|
||||
"total": 3
|
||||
}
|
||||
},
|
||||
"run-uat": {
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@
|
|||
"!!**/dist-test",
|
||||
"!!**/rust-engine/npm",
|
||||
"!!**/*.min.js",
|
||||
"!!packages/pi-coding-agent/src/core/export-html/template.css",
|
||||
"!!packages/coding-agent/src/core/export-html/template.css",
|
||||
"!!src/resources/skills/create-sf-extension/templates"
|
||||
]
|
||||
},
|
||||
|
|
|
|||
33150
package-lock.json
generated
33150
package-lock.json
generated
File diff suppressed because it is too large
Load diff
10
package.json
10
package.json
|
|
@ -42,10 +42,10 @@
|
|||
},
|
||||
"packageManager": "npm@11.13.0",
|
||||
"scripts": {
|
||||
"build:pi-tui": "npm --workspace @singularity-forge/pi-tui run build",
|
||||
"build:pi-ai": "npm --workspace @singularity-forge/pi-ai run build",
|
||||
"build:pi-agent-core": "npm --workspace @singularity-forge/pi-agent-core run build",
|
||||
"build:pi-coding-agent": "npm --workspace @singularity-forge/pi-coding-agent run build",
|
||||
"build:pi-tui": "npm --workspace @singularity-forge/tui run build",
|
||||
"build:pi-ai": "npm --workspace @singularity-forge/ai run build",
|
||||
"build:pi-agent-core": "npm --workspace @singularity-forge/agent-core run build",
|
||||
"build:pi-coding-agent": "npm --workspace @singularity-forge/coding-agent run build",
|
||||
"build:native-pkg": "npm --workspace @singularity-forge/native run build",
|
||||
"build:rpc-client": "npm --workspace @singularity-forge/rpc-client run build",
|
||||
"build:google-gemini-cli-provider": "npm --workspace @singularity-forge/google-gemini-cli-provider run build",
|
||||
|
|
@ -60,7 +60,7 @@
|
|||
"copy-themes": "node scripts/copy-themes.cjs",
|
||||
"copy-export-html": "node scripts/copy-export-html.cjs",
|
||||
"test:unit": "npx vitest run --config vitest.config.ts",
|
||||
"test:packages": "node --test packages/pi-coding-agent/dist/core/*.test.js packages/pi-coding-agent/dist/core/tools/spawn-shell-windows.test.js",
|
||||
"test:packages": "node --test packages/coding-agent/dist/core/*.test.js packages/coding-agent/dist/core/tools/spawn-shell-windows.test.js",
|
||||
"test:marketplace": "npx vitest run src/resources/extensions/sf/tests/claude-import-tui.test.ts src/tests/marketplace-discovery.test.ts --config vitest.config.ts",
|
||||
"test:sf-light": "npx vitest run src/resources/extensions/sf/tests --config vitest.config.ts",
|
||||
"test:coverage": "npx vitest run --config vitest.config.ts --coverage",
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
* dedicated workspace package so provider code can depend on one small helper
|
||||
* instead of embedding the upstream integration inline.
|
||||
*
|
||||
* Consumer: `@singularity-forge/pi-ai` Google Gemini provider.
|
||||
* Consumer: `@singularity-forge/ai` Google Gemini provider.
|
||||
*/
|
||||
import {
|
||||
AuthType,
|
||||
|
|
|
|||
|
|
@ -1,21 +0,0 @@
|
|||
{
|
||||
"name": "@singularity-forge/pi-agent-core",
|
||||
"version": "2.75.3",
|
||||
"description": "General-purpose agent core (vendored from pi-mono)",
|
||||
"type": "module",
|
||||
"main": "./dist/index.js",
|
||||
"types": "./dist/index.d.ts",
|
||||
"exports": {
|
||||
".": {
|
||||
"types": "./dist/index.d.ts",
|
||||
"import": "./dist/index.js"
|
||||
}
|
||||
},
|
||||
"scripts": {
|
||||
"build": "tsc -p tsconfig.json"
|
||||
},
|
||||
"dependencies": {},
|
||||
"engines": {
|
||||
"node": ">=26.1.0"
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,975 +0,0 @@
|
|||
/**
|
||||
* Agent loop that works with AgentMessage throughout.
|
||||
* Transforms to Message[] only at the LLM call boundary.
|
||||
*/
|
||||
|
||||
import {
|
||||
type AssistantMessage,
|
||||
type Context,
|
||||
EventStream,
|
||||
streamSimple,
|
||||
type ToolResultMessage,
|
||||
validateToolArguments,
|
||||
} from "@singularity-forge/pi-ai";
|
||||
import type {
|
||||
AgentContext,
|
||||
AgentEvent,
|
||||
AgentLoopConfig,
|
||||
AgentMessage,
|
||||
AgentTool,
|
||||
AgentToolCall,
|
||||
AgentToolResult,
|
||||
StreamFn,
|
||||
} from "./types.js";
|
||||
|
||||
/**
|
||||
* Maximum number of consecutive turns where ALL tool calls in the turn fail
|
||||
* schema validation before the loop terminates. This prevents unbounded retry
|
||||
* loops when the LLM repeatedly emits tool calls with arguments that cannot
|
||||
* pass validation (e.g., schema overload, truncated JSON, missing required
|
||||
* fields). See: https://github.com/singularity-forge/sf-run/issues/2783
|
||||
*/
|
||||
export const MAX_CONSECUTIVE_VALIDATION_FAILURES = 3;
|
||||
|
||||
export const ZERO_USAGE = {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
} as const;
|
||||
|
||||
/**
|
||||
* Build an AssistantMessage for an unhandled error caught outside runLoop.
|
||||
* Uses the model from config so the message satisfies the full interface.
|
||||
*/
|
||||
function createErrorMessage(
|
||||
error: unknown,
|
||||
config: AgentLoopConfig,
|
||||
): AssistantMessage {
|
||||
const msg = error instanceof Error ? error.message : String(error);
|
||||
return {
|
||||
role: "assistant",
|
||||
content: [{ type: "text", text: msg }],
|
||||
api: config.model.api,
|
||||
provider: config.model.provider,
|
||||
model: config.model.id,
|
||||
usage: ZERO_USAGE,
|
||||
stopReason: "error",
|
||||
errorMessage: msg,
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Emit a message_start + message_end pair for a single message.
|
||||
*/
|
||||
function emitMessagePair(
|
||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||
message: AgentMessage,
|
||||
): void {
|
||||
stream.push({ type: "message_start", message });
|
||||
stream.push({ type: "message_end", message });
|
||||
}
|
||||
|
||||
/**
|
||||
* Emit the standard error sequence when the outer agent loop catches an error.
|
||||
* Pushes message_start/end, turn_end, agent_end, then closes the stream.
|
||||
*/
|
||||
function emitErrorSequence(
|
||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||
errMsg: AssistantMessage,
|
||||
newMessages: AgentMessage[],
|
||||
): void {
|
||||
emitMessagePair(stream, errMsg);
|
||||
stream.push({ type: "turn_end", message: errMsg, toolResults: [] });
|
||||
stream.push({ type: "agent_end", messages: [...newMessages, errMsg] });
|
||||
stream.end([...newMessages, errMsg]);
|
||||
}
|
||||
|
||||
/**
|
||||
* Start an agent loop with a new prompt message.
|
||||
* The prompt is added to the context and events are emitted for it.
|
||||
*/
|
||||
export function agentLoop(
|
||||
prompts: AgentMessage[],
|
||||
context: AgentContext,
|
||||
config: AgentLoopConfig,
|
||||
signal?: AbortSignal,
|
||||
streamFn?: StreamFn,
|
||||
): EventStream<AgentEvent, AgentMessage[]> {
|
||||
const stream = createAgentStream();
|
||||
|
||||
(async () => {
|
||||
const newMessages: AgentMessage[] = [...prompts];
|
||||
const currentContext: AgentContext = {
|
||||
...context,
|
||||
messages: [...context.messages, ...prompts],
|
||||
};
|
||||
|
||||
stream.push({ type: "agent_start" });
|
||||
stream.push({ type: "turn_start" });
|
||||
for (const prompt of prompts) {
|
||||
emitMessagePair(stream, prompt);
|
||||
}
|
||||
|
||||
try {
|
||||
await runLoop(
|
||||
currentContext,
|
||||
newMessages,
|
||||
config,
|
||||
signal,
|
||||
stream,
|
||||
streamFn,
|
||||
);
|
||||
} catch (error) {
|
||||
emitErrorSequence(stream, createErrorMessage(error, config), newMessages);
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
}
|
||||
|
||||
/**
|
||||
* Continue an agent loop from the current context without adding a new message.
|
||||
* Used for retries - context already has user message or tool results.
|
||||
*
|
||||
* **Important:** The last message in context must convert to a `user` or `toolResult` message
|
||||
* via `convertToLlm`. If it doesn't, the LLM provider will reject the request.
|
||||
* This cannot be validated here since `convertToLlm` is only called once per turn.
|
||||
*/
|
||||
export function agentLoopContinue(
|
||||
context: AgentContext,
|
||||
config: AgentLoopConfig,
|
||||
signal?: AbortSignal,
|
||||
streamFn?: StreamFn,
|
||||
): EventStream<AgentEvent, AgentMessage[]> {
|
||||
if (context.messages.length === 0) {
|
||||
throw new Error("Cannot continue: no messages in context");
|
||||
}
|
||||
|
||||
if (context.messages[context.messages.length - 1].role === "assistant") {
|
||||
throw new Error("Cannot continue from message role: assistant");
|
||||
}
|
||||
|
||||
const stream = createAgentStream();
|
||||
|
||||
(async () => {
|
||||
const newMessages: AgentMessage[] = [];
|
||||
const currentContext: AgentContext = {
|
||||
...context,
|
||||
messages: [...context.messages],
|
||||
};
|
||||
|
||||
stream.push({ type: "agent_start" });
|
||||
stream.push({ type: "turn_start" });
|
||||
|
||||
try {
|
||||
await runLoop(
|
||||
currentContext,
|
||||
newMessages,
|
||||
config,
|
||||
signal,
|
||||
stream,
|
||||
streamFn,
|
||||
);
|
||||
} catch (error) {
|
||||
emitErrorSequence(stream, createErrorMessage(error, config), newMessages);
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
}
|
||||
|
||||
function createAgentStream(): EventStream<AgentEvent, AgentMessage[]> {
|
||||
return new EventStream<AgentEvent, AgentMessage[]>(
|
||||
(event: AgentEvent) => event.type === "agent_end",
|
||||
(event: AgentEvent) => (event.type === "agent_end" ? event.messages : []),
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Main loop logic shared by agentLoop and agentLoopContinue.
|
||||
*/
|
||||
async function runLoop(
|
||||
currentContext: AgentContext,
|
||||
newMessages: AgentMessage[],
|
||||
config: AgentLoopConfig,
|
||||
signal: AbortSignal | undefined,
|
||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||
streamFn?: StreamFn,
|
||||
): Promise<void> {
|
||||
let firstTurn = true;
|
||||
// Check for steering messages at start (user may have typed while waiting)
|
||||
let pendingMessages: AgentMessage[] =
|
||||
(await config.getSteeringMessages?.()) || [];
|
||||
|
||||
// Track consecutive turns where ALL tool calls fail validation.
|
||||
// When the LLM repeatedly emits tool calls with schema-overloaded or malformed
|
||||
// arguments, each turn produces only error tool results. Without a cap, this
|
||||
// creates an unbounded retry loop that burns budget. (#2783)
|
||||
let consecutiveAllToolErrorTurns = 0;
|
||||
|
||||
// Outer loop: continues when queued follow-up messages arrive after agent would stop
|
||||
while (true) {
|
||||
let hasMoreToolCalls = true;
|
||||
let steeringAfterTools: AgentMessage[] | null = null;
|
||||
|
||||
// Inner loop: process tool calls and steering messages
|
||||
while (hasMoreToolCalls || pendingMessages.length > 0) {
|
||||
if (!firstTurn) {
|
||||
stream.push({ type: "turn_start" });
|
||||
} else {
|
||||
firstTurn = false;
|
||||
}
|
||||
|
||||
// Process pending messages (inject before next assistant response)
|
||||
if (pendingMessages.length > 0) {
|
||||
for (const message of pendingMessages) {
|
||||
emitMessagePair(stream, message);
|
||||
currentContext.messages.push(message);
|
||||
newMessages.push(message);
|
||||
}
|
||||
pendingMessages = [];
|
||||
}
|
||||
|
||||
// Stream assistant response
|
||||
let message: AssistantMessage;
|
||||
try {
|
||||
message = await streamAssistantResponse(
|
||||
currentContext,
|
||||
config,
|
||||
signal,
|
||||
stream,
|
||||
streamFn,
|
||||
);
|
||||
} catch (error) {
|
||||
// Critical failure before stream started (e.g. getApiKey threw, credentials in
|
||||
// backoff, network unavailable). Convert to a graceful error message so the
|
||||
// agent loop can end cleanly instead of crashing with an unhandled rejection.
|
||||
const errorText =
|
||||
error instanceof Error ? error.message : String(error);
|
||||
message = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: config.model.api,
|
||||
provider: config.model.provider,
|
||||
model: config.model.id,
|
||||
usage: ZERO_USAGE,
|
||||
stopReason: signal?.aborted ? "aborted" : "error",
|
||||
errorMessage: errorText,
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
stream.push({ type: "message_start", message: { ...message } });
|
||||
stream.push({ type: "message_end", message });
|
||||
currentContext.messages.push(message);
|
||||
}
|
||||
newMessages.push(message);
|
||||
|
||||
if (message.stopReason === "error" || message.stopReason === "aborted") {
|
||||
stream.push({ type: "turn_end", message, toolResults: [] });
|
||||
stream.push({ type: "agent_end", messages: newMessages });
|
||||
stream.end(newMessages);
|
||||
return;
|
||||
}
|
||||
|
||||
// Check for tool calls or paused server turn
|
||||
const toolCalls = message.content.filter((c) => c.type === "toolCall");
|
||||
hasMoreToolCalls =
|
||||
toolCalls.length > 0 || message.stopReason === "pauseTurn";
|
||||
|
||||
const toolResults: ToolResultMessage[] = [];
|
||||
if (hasMoreToolCalls && config.externalToolExecution) {
|
||||
// External execution mode: tools were handled by the provider
|
||||
// (e.g., Claude Code SDK). Emit tool_execution events for each
|
||||
// tool call. Prefer any provider-supplied externalResult attached
|
||||
// to the tool call so the UI can show the real stdout/stderr
|
||||
// instead of a generic placeholder.
|
||||
for (const tc of toolCalls as AgentToolCall[]) {
|
||||
const externalResult = (
|
||||
tc as AgentToolCall & {
|
||||
externalResult?: {
|
||||
content?: Array<{
|
||||
type: string;
|
||||
text?: string;
|
||||
data?: string;
|
||||
mimeType?: string;
|
||||
}>;
|
||||
details?: Record<string, unknown>;
|
||||
isError?: boolean;
|
||||
};
|
||||
}
|
||||
).externalResult;
|
||||
stream.push({
|
||||
type: "tool_execution_start",
|
||||
toolCallId: tc.id,
|
||||
toolName: tc.name,
|
||||
args: tc.arguments,
|
||||
});
|
||||
stream.push({
|
||||
type: "tool_execution_end",
|
||||
toolCallId: tc.id,
|
||||
toolName: tc.name,
|
||||
result: externalResult
|
||||
? {
|
||||
content: externalResult.content ?? [
|
||||
{ type: "text", text: "" },
|
||||
],
|
||||
details: externalResult.details ?? {},
|
||||
}
|
||||
: {
|
||||
content: [{ type: "text", text: "(executed by provider)" }],
|
||||
details: {},
|
||||
},
|
||||
isError: externalResult?.isError ?? false,
|
||||
});
|
||||
}
|
||||
// Don't add tool results to context or loop back — the streamSimple
|
||||
// call already ran the full multi-turn agentic loop.
|
||||
hasMoreToolCalls = false;
|
||||
} else if (hasMoreToolCalls) {
|
||||
const toolExecution = await executeToolCalls(
|
||||
currentContext,
|
||||
message,
|
||||
config,
|
||||
signal,
|
||||
stream,
|
||||
);
|
||||
toolResults.push(...toolExecution.toolResults);
|
||||
steeringAfterTools = toolExecution.steeringMessages ?? null;
|
||||
|
||||
for (const result of toolResults) {
|
||||
currentContext.messages.push(result);
|
||||
newMessages.push(result);
|
||||
}
|
||||
|
||||
// Schema overload detection (#2783): count only preparation-phase
|
||||
// errors (schema validation, tool-not-found, tool-blocked) toward the
|
||||
// consecutive failure cap. Tool execution errors — such as bash
|
||||
// commands returning non-zero exit codes (e.g. grep/rg exit 1 for
|
||||
// "no matches") — are valid tool usage and must NOT trigger the cap.
|
||||
// See: #3618
|
||||
const hasPreparationErrors = toolExecution.preparationErrorCount > 0;
|
||||
const allToolsFailedPreparation =
|
||||
toolResults.length > 0 &&
|
||||
toolExecution.preparationErrorCount === toolResults.length;
|
||||
if (allToolsFailedPreparation) {
|
||||
consecutiveAllToolErrorTurns++;
|
||||
} else if (!hasPreparationErrors) {
|
||||
// Reset only when there are zero preparation errors this turn.
|
||||
// Mixed turns (some prep errors, some successes) don't reset,
|
||||
// but they also don't increment — this avoids masking a
|
||||
// pattern of alternating schema failures with one working call.
|
||||
consecutiveAllToolErrorTurns = 0;
|
||||
}
|
||||
|
||||
if (
|
||||
consecutiveAllToolErrorTurns >= MAX_CONSECUTIVE_VALIDATION_FAILURES
|
||||
) {
|
||||
// Force-stop: the LLM is stuck retrying broken tool calls.
|
||||
// Emit the turn_end and terminate the agent loop cleanly.
|
||||
stream.push({ type: "turn_end", message, toolResults });
|
||||
const stopMessage: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Agent stopped: ${consecutiveAllToolErrorTurns} consecutive turns with all tool calls failing. This usually means the model is repeatedly sending arguments that do not match the tool schema.`,
|
||||
},
|
||||
],
|
||||
api: config.model.api,
|
||||
provider: config.model.provider,
|
||||
model: config.model.id,
|
||||
usage: ZERO_USAGE,
|
||||
stopReason: "error",
|
||||
errorMessage:
|
||||
"Schema overload: consecutive tool validation failures exceeded cap",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
emitMessagePair(stream, stopMessage);
|
||||
newMessages.push(stopMessage);
|
||||
stream.push({
|
||||
type: "turn_end",
|
||||
message: stopMessage,
|
||||
toolResults: [],
|
||||
});
|
||||
stream.push({ type: "agent_end", messages: newMessages });
|
||||
stream.end(newMessages);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
stream.push({ type: "turn_end", message, toolResults });
|
||||
|
||||
// Get steering messages after turn completes
|
||||
if (steeringAfterTools && steeringAfterTools.length > 0) {
|
||||
pendingMessages = steeringAfterTools;
|
||||
steeringAfterTools = null;
|
||||
} else {
|
||||
pendingMessages = (await config.getSteeringMessages?.()) || [];
|
||||
}
|
||||
}
|
||||
|
||||
// Agent would stop here. Check for follow-up messages.
|
||||
const followUpMessages = (await config.getFollowUpMessages?.()) || [];
|
||||
if (followUpMessages.length > 0) {
|
||||
// Set as pending so inner loop processes them
|
||||
pendingMessages = followUpMessages;
|
||||
continue;
|
||||
}
|
||||
|
||||
// No more messages, exit
|
||||
break;
|
||||
}
|
||||
|
||||
stream.push({ type: "agent_end", messages: newMessages });
|
||||
stream.end(newMessages);
|
||||
}
|
||||
|
||||
/**
|
||||
* Stream an assistant response from the LLM.
|
||||
* This is where AgentMessage[] gets transformed to Message[] for the LLM.
|
||||
*/
|
||||
async function streamAssistantResponse(
|
||||
context: AgentContext,
|
||||
config: AgentLoopConfig,
|
||||
signal: AbortSignal | undefined,
|
||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||
streamFn?: StreamFn,
|
||||
): Promise<AssistantMessage> {
|
||||
// Apply context transform if configured (AgentMessage[] → AgentMessage[])
|
||||
let messages = context.messages;
|
||||
if (config.transformContext) {
|
||||
messages = await config.transformContext(messages, signal);
|
||||
}
|
||||
|
||||
// Convert to LLM-compatible messages (AgentMessage[] → Message[])
|
||||
const llmMessages = await config.convertToLlm(messages);
|
||||
|
||||
// Build LLM context
|
||||
const llmContext: Context = {
|
||||
systemPrompt: context.systemPrompt,
|
||||
messages: llmMessages,
|
||||
tools: context.tools,
|
||||
};
|
||||
|
||||
const streamFunction = streamFn || streamSimple;
|
||||
|
||||
// Resolve API key (important for expiring tokens)
|
||||
const resolvedApiKey =
|
||||
(config.getApiKey
|
||||
? await config.getApiKey(config.model.provider)
|
||||
: undefined) || config.apiKey;
|
||||
|
||||
const response = await streamFunction(config.model, llmContext, {
|
||||
...config,
|
||||
apiKey: resolvedApiKey,
|
||||
signal,
|
||||
});
|
||||
|
||||
let partialMessage: AssistantMessage | null = null;
|
||||
let addedPartial = false;
|
||||
|
||||
for await (const event of response) {
|
||||
switch (event.type) {
|
||||
case "start":
|
||||
partialMessage = event.partial;
|
||||
context.messages.push(partialMessage);
|
||||
addedPartial = true;
|
||||
stream.push({ type: "message_start", message: { ...partialMessage } });
|
||||
break;
|
||||
|
||||
case "text_start":
|
||||
case "text_delta":
|
||||
case "text_end":
|
||||
case "thinking_start":
|
||||
case "thinking_delta":
|
||||
case "thinking_end":
|
||||
case "toolcall_start":
|
||||
case "toolcall_delta":
|
||||
case "toolcall_end":
|
||||
case "server_tool_use":
|
||||
case "web_search_result":
|
||||
if (partialMessage) {
|
||||
partialMessage = event.partial;
|
||||
context.messages[context.messages.length - 1] = partialMessage;
|
||||
stream.push({
|
||||
type: "message_update",
|
||||
assistantMessageEvent: event,
|
||||
message: { ...partialMessage },
|
||||
});
|
||||
|
||||
// Predictive Execution: stream hook for pre-fetching
|
||||
if (
|
||||
config.onStreamChunk &&
|
||||
(event.type === "text_delta" || event.type === "thinking_delta")
|
||||
) {
|
||||
try {
|
||||
config.onStreamChunk(event.delta, context);
|
||||
} catch {
|
||||
// Predictive hooks are advisory; never let prefetch/critic
|
||||
// failures interrupt provider streaming.
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
case "done":
|
||||
case "error": {
|
||||
const finalMessage = await response.result();
|
||||
if (addedPartial) {
|
||||
context.messages[context.messages.length - 1] = finalMessage;
|
||||
} else {
|
||||
context.messages.push(finalMessage);
|
||||
}
|
||||
if (!addedPartial) {
|
||||
stream.push({ type: "message_start", message: { ...finalMessage } });
|
||||
}
|
||||
stream.push({ type: "message_end", message: finalMessage });
|
||||
return finalMessage;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return await response.result();
|
||||
}
|
||||
|
||||
/**
|
||||
* Result from executing tool calls in a turn. Includes metadata about
|
||||
* error provenance so the schema overload detector can distinguish
|
||||
* preparation failures (schema validation, tool-not-found, tool-blocked)
|
||||
* from execution failures (the tool ran but threw, e.g. bash exit code 1).
|
||||
*/
|
||||
interface ToolExecutionResult {
|
||||
toolResults: ToolResultMessage[];
|
||||
steeringMessages?: AgentMessage[];
|
||||
/** Number of tool results that failed during preparation (validation/schema). */
|
||||
preparationErrorCount: number;
|
||||
}
|
||||
|
||||
function hasUserSteeringMessage(messages: readonly AgentMessage[]): boolean {
|
||||
return messages.some((message) => message.role === "user");
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute tool calls from an assistant message.
|
||||
*/
|
||||
async function executeToolCalls(
|
||||
currentContext: AgentContext,
|
||||
assistantMessage: AssistantMessage,
|
||||
config: AgentLoopConfig,
|
||||
signal: AbortSignal | undefined,
|
||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||
): Promise<ToolExecutionResult> {
|
||||
const toolCalls = assistantMessage.content.filter(
|
||||
(c) => c.type === "toolCall",
|
||||
) as AgentToolCall[];
|
||||
if (config.toolExecution === "sequential") {
|
||||
return executeToolCallsSequential(
|
||||
currentContext,
|
||||
assistantMessage,
|
||||
toolCalls,
|
||||
config,
|
||||
signal,
|
||||
stream,
|
||||
);
|
||||
}
|
||||
return executeToolCallsParallel(
|
||||
currentContext,
|
||||
assistantMessage,
|
||||
toolCalls,
|
||||
config,
|
||||
signal,
|
||||
stream,
|
||||
);
|
||||
}
|
||||
|
||||
async function executeToolCallsSequential(
|
||||
currentContext: AgentContext,
|
||||
assistantMessage: AssistantMessage,
|
||||
toolCalls: AgentToolCall[],
|
||||
config: AgentLoopConfig,
|
||||
signal: AbortSignal | undefined,
|
||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||
): Promise<ToolExecutionResult> {
|
||||
const results: ToolResultMessage[] = [];
|
||||
let steeringMessages: AgentMessage[] | undefined;
|
||||
let preparationErrorCount = 0;
|
||||
const interruptOnSteering = config.interruptToolExecutionOnSteering === true;
|
||||
|
||||
for (let index = 0; index < toolCalls.length; index++) {
|
||||
const toolCall = toolCalls[index];
|
||||
stream.push({
|
||||
type: "tool_execution_start",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
args: toolCall.arguments,
|
||||
});
|
||||
|
||||
const preparation = await prepareToolCall(
|
||||
currentContext,
|
||||
assistantMessage,
|
||||
toolCall,
|
||||
config,
|
||||
signal,
|
||||
);
|
||||
if (preparation.kind === "immediate") {
|
||||
if (preparation.isError) {
|
||||
preparationErrorCount++;
|
||||
}
|
||||
results.push(
|
||||
emitToolCallOutcome(
|
||||
toolCall,
|
||||
preparation.result,
|
||||
preparation.isError,
|
||||
stream,
|
||||
),
|
||||
);
|
||||
} else {
|
||||
const executed = await executePreparedToolCall(
|
||||
preparation,
|
||||
signal,
|
||||
stream,
|
||||
);
|
||||
results.push(
|
||||
await finalizeExecutedToolCall(
|
||||
currentContext,
|
||||
assistantMessage,
|
||||
preparation,
|
||||
executed,
|
||||
config,
|
||||
signal,
|
||||
stream,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
if (config.getSteeringMessages) {
|
||||
const steering = await config.getSteeringMessages();
|
||||
if (steering.length > 0) {
|
||||
steeringMessages = [...(steeringMessages ?? []), ...steering];
|
||||
if (interruptOnSteering && hasUserSteeringMessage(steering)) {
|
||||
const remainingCalls = toolCalls.slice(index + 1);
|
||||
for (const skipped of remainingCalls) {
|
||||
results.push(skipToolCall(skipped, stream));
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { toolResults: results, steeringMessages, preparationErrorCount };
|
||||
}
|
||||
|
||||
async function executeToolCallsParallel(
|
||||
currentContext: AgentContext,
|
||||
assistantMessage: AssistantMessage,
|
||||
toolCalls: AgentToolCall[],
|
||||
config: AgentLoopConfig,
|
||||
signal: AbortSignal | undefined,
|
||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||
): Promise<ToolExecutionResult> {
|
||||
const results: ToolResultMessage[] = [];
|
||||
const runnableCalls: PreparedToolCall[] = [];
|
||||
let steeringMessages: AgentMessage[] | undefined;
|
||||
let preparationErrorCount = 0;
|
||||
const interruptOnSteering = config.interruptToolExecutionOnSteering === true;
|
||||
|
||||
for (let index = 0; index < toolCalls.length; index++) {
|
||||
const toolCall = toolCalls[index];
|
||||
stream.push({
|
||||
type: "tool_execution_start",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
args: toolCall.arguments,
|
||||
});
|
||||
|
||||
const preparation = await prepareToolCall(
|
||||
currentContext,
|
||||
assistantMessage,
|
||||
toolCall,
|
||||
config,
|
||||
signal,
|
||||
);
|
||||
if (preparation.kind === "immediate") {
|
||||
if (preparation.isError) {
|
||||
preparationErrorCount++;
|
||||
}
|
||||
results.push(
|
||||
emitToolCallOutcome(
|
||||
toolCall,
|
||||
preparation.result,
|
||||
preparation.isError,
|
||||
stream,
|
||||
),
|
||||
);
|
||||
} else {
|
||||
runnableCalls.push(preparation);
|
||||
}
|
||||
|
||||
if (config.getSteeringMessages) {
|
||||
const steering = await config.getSteeringMessages();
|
||||
if (steering.length > 0) {
|
||||
steeringMessages = [...(steeringMessages ?? []), ...steering];
|
||||
if (interruptOnSteering && hasUserSteeringMessage(steering)) {
|
||||
for (const runnable of runnableCalls) {
|
||||
results.push(
|
||||
skipToolCall(runnable.toolCall, stream, { emitStart: false }),
|
||||
);
|
||||
}
|
||||
const remainingCalls = toolCalls.slice(index + 1);
|
||||
for (const skipped of remainingCalls) {
|
||||
results.push(skipToolCall(skipped, stream));
|
||||
}
|
||||
return {
|
||||
toolResults: results,
|
||||
steeringMessages,
|
||||
preparationErrorCount,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const runningCalls = runnableCalls.map((prepared) => ({
|
||||
prepared,
|
||||
execution: executePreparedToolCall(prepared, signal, stream),
|
||||
}));
|
||||
|
||||
for (const running of runningCalls) {
|
||||
const executed = await running.execution;
|
||||
results.push(
|
||||
await finalizeExecutedToolCall(
|
||||
currentContext,
|
||||
assistantMessage,
|
||||
running.prepared,
|
||||
executed,
|
||||
config,
|
||||
signal,
|
||||
stream,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
if (!steeringMessages && config.getSteeringMessages) {
|
||||
const steering = await config.getSteeringMessages();
|
||||
if (steering.length > 0) {
|
||||
steeringMessages = steering;
|
||||
}
|
||||
}
|
||||
|
||||
return { toolResults: results, steeringMessages, preparationErrorCount };
|
||||
}
|
||||
|
||||
type PreparedToolCall = {
|
||||
kind: "prepared";
|
||||
toolCall: AgentToolCall;
|
||||
tool: AgentTool<any>;
|
||||
args: unknown;
|
||||
};
|
||||
|
||||
type ImmediateToolCallOutcome = {
|
||||
kind: "immediate";
|
||||
result: AgentToolResult<any>;
|
||||
isError: boolean;
|
||||
};
|
||||
|
||||
type ExecutedToolCallOutcome = {
|
||||
result: AgentToolResult<any>;
|
||||
isError: boolean;
|
||||
};
|
||||
|
||||
async function prepareToolCall(
|
||||
currentContext: AgentContext,
|
||||
assistantMessage: AssistantMessage,
|
||||
toolCall: AgentToolCall,
|
||||
config: AgentLoopConfig,
|
||||
signal: AbortSignal | undefined,
|
||||
): Promise<PreparedToolCall | ImmediateToolCallOutcome> {
|
||||
const tool = currentContext.tools?.find((t) => t.name === toolCall.name);
|
||||
if (!tool) {
|
||||
return {
|
||||
kind: "immediate",
|
||||
result: createErrorToolResult(`Tool ${toolCall.name} not found`),
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const validatedArgs = validateToolArguments(tool, toolCall);
|
||||
if (config.beforeToolCall) {
|
||||
const beforeResult = await config.beforeToolCall(
|
||||
{
|
||||
assistantMessage,
|
||||
toolCall,
|
||||
args: validatedArgs,
|
||||
context: currentContext,
|
||||
},
|
||||
signal,
|
||||
);
|
||||
if (beforeResult?.block) {
|
||||
return {
|
||||
kind: "immediate",
|
||||
result: createErrorToolResult(
|
||||
beforeResult.reason || "Tool execution was blocked",
|
||||
),
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
}
|
||||
return {
|
||||
kind: "prepared",
|
||||
toolCall,
|
||||
tool,
|
||||
args: validatedArgs,
|
||||
};
|
||||
} catch (error) {
|
||||
return {
|
||||
kind: "immediate",
|
||||
result: createErrorToolResult(
|
||||
error instanceof Error ? error.message : String(error),
|
||||
),
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
async function executePreparedToolCall(
|
||||
prepared: PreparedToolCall,
|
||||
signal: AbortSignal | undefined,
|
||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||
): Promise<ExecutedToolCallOutcome> {
|
||||
try {
|
||||
const result = await prepared.tool.execute(
|
||||
prepared.toolCall.id,
|
||||
prepared.args as never,
|
||||
signal,
|
||||
(partialResult) => {
|
||||
stream.push({
|
||||
type: "tool_execution_update",
|
||||
toolCallId: prepared.toolCall.id,
|
||||
toolName: prepared.toolCall.name,
|
||||
args: prepared.toolCall.arguments,
|
||||
partialResult,
|
||||
});
|
||||
},
|
||||
);
|
||||
return { result, isError: false };
|
||||
} catch (error) {
|
||||
return {
|
||||
result: createErrorToolResult(
|
||||
error instanceof Error ? error.message : String(error),
|
||||
),
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Apply the optional `afterToolCall` hook to an executed tool call and emit
 * the final outcome.
 *
 * The hook may override the result per-field: `content`, `details`, and
 * `isError` are each replaced only when the hook returns a defined value for
 * them; `undefined` means "keep the executed value". If the hook itself
 * throws, the whole outcome is replaced with an error result. The final
 * result is published (event + toolResult message) via emitToolCallOutcome.
 */
async function finalizeExecutedToolCall(
  currentContext: AgentContext,
  assistantMessage: AssistantMessage,
  prepared: PreparedToolCall,
  executed: ExecutedToolCallOutcome,
  config: AgentLoopConfig,
  signal: AbortSignal | undefined,
  stream: EventStream<AgentEvent, AgentMessage[]>,
): Promise<ToolResultMessage> {
  let result = executed.result;
  let isError = executed.isError;

  if (config.afterToolCall) {
    try {
      const afterResult = await config.afterToolCall(
        {
          assistantMessage,
          toolCall: prepared.toolCall,
          args: prepared.args,
          result,
          isError,
          context: currentContext,
        },
        signal,
      );
      if (afterResult) {
        // Per-field override: undefined keeps the executed value.
        result = {
          content:
            afterResult.content !== undefined
              ? afterResult.content
              : result.content,
          details:
            afterResult.details !== undefined
              ? afterResult.details
              : result.details,
        };
        isError =
          afterResult.isError !== undefined ? afterResult.isError : isError;
      }
    } catch (error) {
      // A failing hook supersedes the tool's own result with an error.
      result = createErrorToolResult(
        error instanceof Error ? error.message : String(error),
      );
      isError = true;
    }
  }

  return emitToolCallOutcome(prepared.toolCall, result, isError, stream);
}
|
||||
|
||||
function createErrorToolResult(message: string): AgentToolResult<any> {
|
||||
return {
|
||||
content: [{ type: "text", text: message }],
|
||||
details: {},
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Publish the final outcome of a tool call.
 *
 * Pushes a `tool_execution_end` event, builds the corresponding `toolResult`
 * message, emits it as a message pair on the stream, and returns the message
 * so the caller can append it to the conversation history.
 */
function emitToolCallOutcome(
  toolCall: AgentToolCall,
  result: AgentToolResult<any>,
  isError: boolean,
  stream: EventStream<AgentEvent, AgentMessage[]>,
): ToolResultMessage {
  stream.push({
    type: "tool_execution_end",
    toolCallId: toolCall.id,
    toolName: toolCall.name,
    result,
    isError,
  });

  const toolResultMessage: ToolResultMessage = {
    role: "toolResult",
    toolCallId: toolCall.id,
    toolName: toolCall.name,
    content: result.content,
    details: result.details,
    isError,
    timestamp: Date.now(),
  };

  emitMessagePair(stream, toolResultMessage);
  return toolResultMessage;
}
|
||||
|
||||
function skipToolCall(
|
||||
toolCall: AgentToolCall,
|
||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||
options?: { emitStart?: boolean },
|
||||
): ToolResultMessage {
|
||||
const result: AgentToolResult<any> = {
|
||||
content: [{ type: "text", text: "Skipped due to queued user message." }],
|
||||
details: {},
|
||||
};
|
||||
|
||||
if (options?.emitStart !== false) {
|
||||
stream.push({
|
||||
type: "tool_execution_start",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
args: toolCall.arguments,
|
||||
});
|
||||
}
|
||||
|
||||
return emitToolCallOutcome(toolCall, result, true, stream);
|
||||
}
|
||||
|
|
@ -1,190 +0,0 @@
|
|||
// Agent activeInferenceModel regression tests
|
||||
// Verifies that activeInferenceModel is set/cleared correctly in _runLoop,
|
||||
// and that the footer reads activeInferenceModel instead of state.model.
|
||||
// Regression test for https://github.com/singularity-forge/sf-run/issues/1844 Bug 2
|
||||
|
||||
import assert from "node:assert/strict";
|
||||
import { readFileSync } from "node:fs";
|
||||
import { dirname, join } from "node:path";
|
||||
import { fileURLToPath } from "node:url";
|
||||
import {
|
||||
type AssistantMessageEventStream,
|
||||
getModel,
|
||||
} from "@singularity-forge/pi-ai";
|
||||
import { describe, it } from "vitest";
|
||||
import { Agent } from "./agent.ts";
|
||||
|
||||
const __dirname = dirname(fileURLToPath(import.meta.url));
|
||||
|
||||
describe("Agent — activeInferenceModel (#1844 Bug 2)", () => {
|
||||
  it("activeInferenceModel is declared in AgentState interface", () => {
    // Source-level check: the field must exist on AgentState so the agent
    // can publish which model is currently streaming.
    const typesSource = readFileSync(join(__dirname, "types.ts"), "utf-8");
    assert.match(
      typesSource,
      /activeInferenceModel\??:\s*Model/,
      "AgentState must declare activeInferenceModel field",
    );
  });
|
||||
|
||||
  it("_runLoop sets activeInferenceModel before streaming and clears in finally", () => {
    const agentSource = readFileSync(join(__dirname, "agent.ts"), "utf-8");

    // NOTE: "setLine"/"clearLine" are character offsets from indexOf, not
    // line numbers; only their relative ordering matters below.

    // Must set activeInferenceModel = model before streaming starts
    const setLine = agentSource.indexOf(
      "this._state.activeInferenceModel = model",
    );
    assert.ok(
      setLine > -1,
      "agent.ts must set activeInferenceModel = model in _runLoop",
    );

    // Must clear activeInferenceModel = undefined after streaming completes
    const clearLine = agentSource.indexOf(
      "this._state.activeInferenceModel = undefined",
    );
    assert.ok(
      clearLine > -1,
      "agent.ts must clear activeInferenceModel in finally block",
    );

    // The set must come before the clear
    assert.ok(
      setLine < clearLine,
      "activeInferenceModel must be set before cleared",
    );
  });
|
||||
|
||||
it("footer displays activeInferenceModel instead of state.model", () => {
|
||||
const footerPath = join(
|
||||
__dirname,
|
||||
"..",
|
||||
"..",
|
||||
"pi-coding-agent",
|
||||
"src",
|
||||
"modes",
|
||||
"interactive",
|
||||
"components",
|
||||
"footer.ts",
|
||||
);
|
||||
const footerSource = readFileSync(footerPath, "utf-8");
|
||||
assert.match(
|
||||
footerSource,
|
||||
/activeInferenceModel/,
|
||||
"footer.ts must reference activeInferenceModel for display",
|
||||
);
|
||||
});
|
||||
|
||||
  it("activeInferenceModel is set before AbortController creation", () => {
    const agentSource = readFileSync(join(__dirname, "agent.ts"), "utf-8");

    // NOTE: indexOf yields character offsets, not line numbers; only the
    // relative ordering of the two matches is asserted.
    const setLine = agentSource.indexOf(
      "this._state.activeInferenceModel = model",
    );
    const abortLine = agentSource.indexOf(
      "this.abortController = new AbortController",
    );
    assert.ok(setLine > -1 && abortLine > -1);
    assert.ok(
      setLine < abortLine,
      "activeInferenceModel must be set before streaming infrastructure is created",
    );
  });
|
||||
|
||||
  it("getProviderOptions are forwarded into the provider stream call", async () => {
    let capturedOptions: Record<string, unknown> | undefined;
    // Agent wired with a stub streamFn that records the options it receives
    // and then yields a minimal start/done event pair with zeroed usage.
    const agent = new Agent({
      initialState: {
        model: getModel("anthropic", "claude-3-5-sonnet-20241022"),
        systemPrompt: "test",
        tools: [],
      },
      getProviderOptions: async () => ({ customRuntimeOption: "present" }),
      streamFn: (_model, _context, options): AssistantMessageEventStream => {
        capturedOptions = options as Record<string, unknown> | undefined;
        return {
          async *[Symbol.asyncIterator]() {
            // "start" event with an empty partial assistant message.
            yield {
              type: "start",
              partial: {
                role: "assistant",
                content: [],
                api: "anthropic-messages",
                provider: "anthropic",
                model: "claude-3-5-sonnet-20241022",
                usage: {
                  input: 0,
                  output: 0,
                  cacheRead: 0,
                  cacheWrite: 0,
                  totalTokens: 0,
                  cost: {
                    input: 0,
                    output: 0,
                    cacheRead: 0,
                    cacheWrite: 0,
                    total: 0,
                  },
                },
                stopReason: "stop",
                timestamp: Date.now(),
              },
            };
            // "done" event carrying the completed assistant message.
            yield {
              type: "done",
              message: {
                role: "assistant",
                content: [{ type: "text", text: "ok" }],
                api: "anthropic-messages",
                provider: "anthropic",
                model: "claude-3-5-sonnet-20241022",
                usage: {
                  input: 0,
                  output: 0,
                  cacheRead: 0,
                  cacheWrite: 0,
                  totalTokens: 0,
                  cost: {
                    input: 0,
                    output: 0,
                    cacheRead: 0,
                    cacheWrite: 0,
                    total: 0,
                  },
                },
                stopReason: "stop",
                timestamp: Date.now(),
              },
            };
          },
          // result() resolves to the same final assistant message.
          result: async () => ({
            role: "assistant",
            content: [{ type: "text", text: "ok" }],
            api: "anthropic-messages",
            provider: "anthropic",
            model: "claude-3-5-sonnet-20241022",
            usage: {
              input: 0,
              output: 0,
              cacheRead: 0,
              cacheWrite: 0,
              totalTokens: 0,
              cost: {
                input: 0,
                output: 0,
                cacheRead: 0,
                cacheWrite: 0,
                total: 0,
              },
            },
            stopReason: "stop",
            timestamp: Date.now(),
          }),
          [Symbol.asyncDispose]: async () => {},
        } as AssistantMessageEventStream;
      },
    });

    await agent.prompt("hello");
    // The custom runtime option must have reached the stream call unchanged.
    assert.equal(capturedOptions?.customRuntimeOption, "present");
  });
|
||||
});
|
||||
|
|
@ -1,688 +0,0 @@
|
|||
/**
|
||||
* Agent class that uses the agent-loop directly.
|
||||
* No transport abstraction - calls streamSimple via the loop.
|
||||
*/
|
||||
|
||||
import {
|
||||
getModel,
|
||||
type ImageContent,
|
||||
type Message,
|
||||
type Model,
|
||||
type SimpleStreamOptions,
|
||||
streamSimple,
|
||||
type TextContent,
|
||||
type ThinkingBudgets,
|
||||
type Transport,
|
||||
} from "@singularity-forge/pi-ai";
|
||||
import { agentLoop, agentLoopContinue, ZERO_USAGE } from "./agent-loop.js";
|
||||
import type {
|
||||
AgentContext,
|
||||
AgentEvent,
|
||||
AgentLoopConfig,
|
||||
AgentMessage,
|
||||
AgentState,
|
||||
AgentTool,
|
||||
StreamFn,
|
||||
ThinkingLevel,
|
||||
} from "./types.js";
|
||||
|
||||
/**
 * Default convertToLlm: keep only messages the LLM API understands
 * (user / assistant / toolResult), dropping agent-internal message kinds.
 *
 * NOTE(review): despite the original "convert attachments" wording, this
 * implementation only filters — attachment conversion (if any) must be done
 * by a custom convertToLlm supplied via AgentOptions.
 */
function defaultConvertToLlm(messages: AgentMessage[]): Message[] {
  return messages.filter(
    (m) =>
      m.role === "user" || m.role === "assistant" || m.role === "toolResult",
  );
}
|
||||
|
||||
/** Construction options for {@link Agent}. All fields are optional. */
export interface AgentOptions {
  /** Partial initial state merged over the Agent's built-in defaults. */
  initialState?: Partial<AgentState>;

  /**
   * Converts AgentMessage[] to LLM-compatible Message[] before each LLM call.
   * Default filters to user/assistant/toolResult and converts attachments.
   */
  convertToLlm?: (messages: AgentMessage[]) => Message[] | Promise<Message[]>;

  /**
   * Optional transform applied to context before convertToLlm.
   * Use for context pruning, injecting external context, etc.
   */
  transformContext?: (
    messages: AgentMessage[],
    signal?: AbortSignal,
  ) => Promise<AgentMessage[]>;

  /**
   * Steering mode: "all" = send all steering messages at once, "one-at-a-time" = one per turn
   */
  steeringMode?: "all" | "one-at-a-time";

  /**
   * Whether steering messages interrupt the current assistant tool batch.
   * Defaults to false so user comments are absorbed at the next safe boundary.
   */
  interruptToolExecutionOnSteering?: boolean;

  /**
   * Follow-up mode: "all" = send all follow-up messages at once, "one-at-a-time" = one per turn
   */
  followUpMode?: "all" | "one-at-a-time";

  /**
   * Custom stream function (for proxy backends, etc.). Default uses streamSimple.
   */
  streamFn?: StreamFn;

  /**
   * Optional session identifier forwarded to LLM providers.
   * Used by providers that support session-based caching (e.g., OpenAI Codex).
   */
  sessionId?: string;

  /**
   * Resolves an API key dynamically for each LLM call.
   * Useful for expiring tokens (e.g., GitHub Copilot OAuth).
   */
  getApiKey?: (
    provider: string,
  ) => Promise<string | undefined> | string | undefined;

  /**
   * Inspect or replace provider payloads before they are sent.
   */
  onPayload?: SimpleStreamOptions["onPayload"];

  /**
   * Custom token budgets for thinking levels (token-based providers only).
   */
  thinkingBudgets?: ThinkingBudgets;

  /**
   * Preferred transport for providers that support multiple transports.
   */
  transport?: Transport;

  /**
   * Maximum delay in milliseconds to wait for a retry when the server requests a long wait.
   * If the server's requested delay exceeds this value, the request fails immediately,
   * allowing higher-level retry logic to handle it with user visibility.
   * Default: 60000 (60 seconds). Set to 0 to disable the cap.
   */
  maxRetryDelayMs?: number;

  /**
   * Determines whether a model uses external tool execution (tools handled
   * by the provider, not dispatched locally). Evaluated per-loop so model
   * switches mid-session are handled correctly.
   */
  externalToolExecution?: (model: Model<any>) => boolean;

  /**
   * Optional provider-specific options to merge into the next stream call.
   *
   * Use this for runtime-only callbacks or handles that should not live in
   * shared agent state, such as UI bridges for external CLI providers.
   */
  getProviderOptions?: (
    model: Model<any>,
  ) =>
    | Record<string, unknown>
    | undefined
    | Promise<Record<string, unknown> | undefined>;
}
|
||||
|
||||
/**
 * Internal wrapper that tracks message origin for origin-aware queue clearing.
 * "user" = typed by human in TUI; "system" = generated by extensions/background jobs.
 */
interface QueueEntry {
  // The queued steering or follow-up message itself.
  message: AgentMessage;
  // Origin drives drainUserMessages(): user-typed messages survive aborts,
  // system-generated ones do not.
  origin: "user" | "system";
}
|
||||
|
||||
/**
 * Stateful wrapper around the agent loop: owns conversation state, the
 * steering/follow-up queues, streaming lifecycle (abort, idle tracking), and
 * fans loop events out to subscribers. Calls streamSimple (or a custom
 * streamFn) via agentLoop/agentLoopContinue — no transport abstraction.
 */
export class Agent {
  // Mutable agent state; initialState from the constructor is shallow-merged
  // over these defaults.
  private _state: AgentState = {
    systemPrompt: "",
    model: getModel("google", "gemini-2.5-flash-lite-preview-06-17"),
    thinkingLevel: "off",
    tools: [],
    messages: [],
    isStreaming: false,
    streamMessage: null,
    pendingToolCalls: new Set<string>(),
    error: undefined,
  };

  // Event subscribers registered via subscribe().
  private listeners = new Set<(e: AgentEvent) => void>();
  // Controller for the in-flight run; undefined while idle.
  private abortController?: AbortController;
  private convertToLlm: (
    messages: AgentMessage[],
  ) => Message[] | Promise<Message[]>;
  private transformContext?: (
    messages: AgentMessage[],
    signal?: AbortSignal,
  ) => Promise<AgentMessage[]>;
  // Messages injected mid-run (steering) and after the run (follow-up).
  private steeringQueue: QueueEntry[] = [];
  private followUpQueue: QueueEntry[] = [];
  private steeringMode: "all" | "one-at-a-time";
  private followUpMode: "all" | "one-at-a-time";
  public streamFn: StreamFn;
  private _sessionId?: string;
  public getApiKey?: (
    provider: string,
  ) => Promise<string | undefined> | string | undefined;
  private _onPayload?: SimpleStreamOptions["onPayload"];
  // Promise resolved when the current run finishes; backs waitForIdle().
  private runningPrompt?: Promise<void>;
  private resolveRunningPrompt?: () => void;
  private _thinkingBudgets?: ThinkingBudgets;
  private _transport: Transport;
  private _maxRetryDelayMs?: number;
  private _beforeToolCall?: AgentLoopConfig["beforeToolCall"];
  private _afterToolCall?: AgentLoopConfig["afterToolCall"];
  private _externalToolExecution?: (model: Model<any>) => boolean;
  private _getProviderOptions?: AgentOptions["getProviderOptions"];
  private _interruptToolExecutionOnSteering: boolean;

  constructor(opts: AgentOptions = {}) {
    this._state = { ...this._state, ...opts.initialState };
    this.convertToLlm = opts.convertToLlm || defaultConvertToLlm;
    this.transformContext = opts.transformContext;
    this.steeringMode = opts.steeringMode || "one-at-a-time";
    this.followUpMode = opts.followUpMode || "one-at-a-time";
    this.streamFn = opts.streamFn || streamSimple;
    this._sessionId = opts.sessionId;
    this.getApiKey = opts.getApiKey;
    this._onPayload = opts.onPayload;
    this._thinkingBudgets = opts.thinkingBudgets;
    this._transport = opts.transport ?? "sse";
    this._maxRetryDelayMs = opts.maxRetryDelayMs;
    this._externalToolExecution = opts.externalToolExecution;
    this._getProviderOptions = opts.getProviderOptions;
    this._interruptToolExecutionOnSteering =
      opts.interruptToolExecutionOnSteering ?? false;
  }

  /**
   * Get the current session ID used for provider caching.
   */
  get sessionId(): string | undefined {
    return this._sessionId;
  }

  /**
   * Set the session ID for provider caching.
   * Call this when switching sessions (new session, branch, resume).
   */
  set sessionId(value: string | undefined) {
    this._sessionId = value;
  }

  /**
   * Get the current thinking budgets.
   */
  get thinkingBudgets(): ThinkingBudgets | undefined {
    return this._thinkingBudgets;
  }

  /**
   * Set custom thinking budgets for token-based providers.
   */
  set thinkingBudgets(value: ThinkingBudgets | undefined) {
    this._thinkingBudgets = value;
  }

  /**
   * Get the current preferred transport.
   */
  get transport(): Transport {
    return this._transport;
  }

  /**
   * Set the preferred transport.
   */
  setTransport(value: Transport) {
    this._transport = value;
  }

  /**
   * Get the current max retry delay in milliseconds.
   */
  get maxRetryDelayMs(): number | undefined {
    return this._maxRetryDelayMs;
  }

  /**
   * Set the maximum delay to wait for server-requested retries.
   * Set to 0 to disable the cap.
   */
  set maxRetryDelayMs(value: number | undefined) {
    this._maxRetryDelayMs = value;
  }

  /**
   * Install a hook called before each tool executes, after argument validation.
   * Return `{ block: true }` to prevent execution.
   */
  setBeforeToolCall(fn: AgentLoopConfig["beforeToolCall"]): void {
    this._beforeToolCall = fn;
  }

  /**
   * Install a hook called after each tool executes, before results are emitted.
   * Return field overrides for content/details/isError.
   */
  setAfterToolCall(fn: AgentLoopConfig["afterToolCall"]): void {
    this._afterToolCall = fn;
  }

  // NOTE(review): returns the live internal state object (not a copy);
  // callers can observe in-place mutation during streaming.
  get state(): AgentState {
    return this._state;
  }

  /** Register an event listener; returns an unsubscribe function. */
  subscribe(fn: (e: AgentEvent) => void): () => void {
    this.listeners.add(fn);
    return () => this.listeners.delete(fn);
  }

  // State mutators
  setSystemPrompt(v: string) {
    this._state.systemPrompt = v;
  }

  setModel(m: Model<any>) {
    this._state.model = m;
  }

  setThinkingLevel(l: ThinkingLevel) {
    this._state.thinkingLevel = l;
  }

  setSteeringMode(mode: "all" | "one-at-a-time") {
    this.steeringMode = mode;
  }

  getSteeringMode(): "all" | "one-at-a-time" {
    return this.steeringMode;
  }

  setFollowUpMode(mode: "all" | "one-at-a-time") {
    this.followUpMode = mode;
  }

  getFollowUpMode(): "all" | "one-at-a-time" {
    return this.followUpMode;
  }

  setTools(t: AgentTool<any>[]) {
    this._state.tools = t;
  }

  // Replace the whole conversation (defensive copy of the input array).
  replaceMessages(ms: AgentMessage[]) {
    this._state.messages = ms.slice();
  }

  // Append immutably so consumers holding the old array reference see a change.
  appendMessage(m: AgentMessage) {
    this._state.messages = [...this._state.messages, m];
  }

  /**
   * Queue a steering message for the agent mid-run.
   * Delivered after the current tool batch unless interrupt behavior is explicitly enabled.
   */
  steer(m: AgentMessage, origin: "user" | "system" = "system") {
    this.steeringQueue.push({ message: m, origin });
  }

  /**
   * Queue a follow-up message to be processed after the agent finishes.
   * Delivered only when agent has no more tool calls or steering messages.
   */
  followUp(m: AgentMessage, origin: "user" | "system" = "system") {
    this.followUpQueue.push({ message: m, origin });
  }

  clearSteeringQueue() {
    this.steeringQueue = [];
  }

  clearFollowUpQueue() {
    this.followUpQueue = [];
  }

  clearAllQueues() {
    this.steeringQueue = [];
    this.followUpQueue = [];
  }

  /**
   * Drain user-origin messages from queues, leaving system messages in place.
   * Used during abort to preserve messages the user explicitly typed.
   */
  drainUserMessages(): { steering: AgentMessage[]; followUp: AgentMessage[] } {
    const userSteering = this.steeringQueue
      .filter((e) => e.origin === "user")
      .map((e) => e.message);
    const userFollowUp = this.followUpQueue
      .filter((e) => e.origin === "user")
      .map((e) => e.message);
    this.steeringQueue = this.steeringQueue.filter((e) => e.origin !== "user");
    this.followUpQueue = this.followUpQueue.filter((e) => e.origin !== "user");
    return { steering: userSteering, followUp: userFollowUp };
  }

  /** True when any steering or follow-up message is waiting. */
  hasQueuedMessages(): boolean {
    return this.steeringQueue.length > 0 || this.followUpQueue.length > 0;
  }

  // Pop steering messages per the configured mode: one per turn, or all at once.
  private dequeueSteeringMessages(): AgentMessage[] {
    if (this.steeringMode === "one-at-a-time") {
      if (this.steeringQueue.length > 0) {
        const first = this.steeringQueue[0];
        this.steeringQueue = this.steeringQueue.slice(1);
        return [first.message];
      }
      return [];
    }

    const steering = this.steeringQueue.map((e) => e.message);
    this.steeringQueue = [];
    return steering;
  }

  // Pop follow-up messages per the configured mode (mirrors steering above).
  private dequeueFollowUpMessages(): AgentMessage[] {
    if (this.followUpMode === "one-at-a-time") {
      if (this.followUpQueue.length > 0) {
        const first = this.followUpQueue[0];
        this.followUpQueue = this.followUpQueue.slice(1);
        return [first.message];
      }
      return [];
    }

    const followUp = this.followUpQueue.map((e) => e.message);
    this.followUpQueue = [];
    return followUp;
  }

  clearMessages() {
    this._state.messages = [];
  }

  /** Abort the in-flight run, if any. */
  abort() {
    this.abortController?.abort();
  }

  /** Resolves when the current run completes (immediately when idle). */
  waitForIdle(): Promise<void> {
    return this.runningPrompt ?? Promise.resolve();
  }

  // Clear conversation, streaming flags, and both queues.
  // NOTE(review): does not reset activeInferenceModel — presumably only
  // called while idle, when it is already undefined; confirm against callers.
  reset() {
    this._state.messages = [];
    this._state.isStreaming = false;
    this._state.streamMessage = null;
    this._state.pendingToolCalls = new Set<string>();
    this._state.error = undefined;
    this.steeringQueue = [];
    this.followUpQueue = [];
  }

  /** Send a prompt with an AgentMessage */
  async prompt(message: AgentMessage | AgentMessage[]): Promise<void>;
  async prompt(input: string, images?: ImageContent[]): Promise<void>;
  async prompt(
    input: string | AgentMessage | AgentMessage[],
    images?: ImageContent[],
  ) {
    if (this._state.isStreaming) {
      throw new Error(
        "Agent is already processing a prompt. Please wait for it to finish before sending another message.",
      );
    }

    const model = this._state.model;
    if (!model) throw new Error("No model configured");

    let msgs: AgentMessage[];

    if (Array.isArray(input)) {
      msgs = input;
    } else if (typeof input === "string") {
      // Wrap a plain string (plus optional images) into a user message.
      const content: Array<TextContent | ImageContent> = [
        { type: "text", text: input },
      ];
      if (images && images.length > 0) {
        content.push(...images);
      }
      msgs = [
        {
          role: "user",
          content,
          timestamp: Date.now(),
        },
      ];
    } else {
      msgs = [input];
    }

    await this._runLoop(msgs);
  }

  /**
   * Continue from current context (used for retries and resuming queued messages).
   */
  async continue() {
    if (this._state.isStreaming) {
      throw new Error(
        "Agent is already processing. Wait for completion before continuing.",
      );
    }

    const messages = this._state.messages;
    if (messages.length === 0) {
      throw new Error("No messages to continue from");
    }
    if (messages[messages.length - 1].role === "assistant") {
      // Conversation already ended with the assistant: resume only if queued
      // steering/follow-up messages exist (steering takes priority).
      const queuedSteering = this.dequeueSteeringMessages();
      if (queuedSteering.length > 0) {
        // Skip the loop's first steering poll: we just consumed that slot.
        await this._runLoop(queuedSteering, { skipInitialSteeringPoll: true });
        return;
      }

      const queuedFollowUp = this.dequeueFollowUpMessages();
      if (queuedFollowUp.length > 0) {
        await this._runLoop(queuedFollowUp);
        return;
      }

      throw new Error("Cannot continue from message role: assistant");
    }

    await this._runLoop(undefined);
  }

  /**
   * Run the agent loop.
   * If messages are provided, starts a new conversation turn with those messages.
   * Otherwise, continues from existing context.
   */
  private async _runLoop(
    messages?: AgentMessage[],
    options?: { skipInitialSteeringPoll?: boolean },
  ) {
    const model = this._state.model;
    if (!model) throw new Error("No model configured");

    // Expose the model being streamed; cleared in finally (see #1844 Bug 2).
    this._state.activeInferenceModel = model;

    // Back waitForIdle(): resolved in finally regardless of outcome.
    this.runningPrompt = new Promise<void>((resolve) => {
      this.resolveRunningPrompt = resolve;
    });

    this.abortController = new AbortController();
    this._state.isStreaming = true;
    this._state.streamMessage = null;
    this._state.error = undefined;

    const reasoning =
      this._state.thinkingLevel === "off"
        ? undefined
        : this._state.thinkingLevel;

    // Snapshot the conversation so loop-internal mutation cannot affect us.
    const context: AgentContext = {
      systemPrompt: this._state.systemPrompt,
      messages: this._state.messages.slice(),
      tools: this._state.tools,
    };

    let skipInitialSteeringPoll = options?.skipInitialSteeringPoll === true;
    const providerOptions = await this._getProviderOptions?.(model);

    // Provider options are spread first so core config keys cannot be clobbered.
    const config: AgentLoopConfig = {
      ...(providerOptions ?? {}),
      model,
      reasoning,
      sessionId: this._sessionId,
      onPayload: this._onPayload,
      transport: this._transport,
      thinkingBudgets: this._thinkingBudgets,
      maxRetryDelayMs: this._maxRetryDelayMs,
      convertToLlm: this.convertToLlm,
      transformContext: this.transformContext,
      getApiKey: this.getApiKey,
      getSteeringMessages: async () => {
        if (skipInitialSteeringPoll) {
          // The caller (continue()) already dequeued one steering batch.
          skipInitialSteeringPoll = false;
          return [];
        }
        return this.dequeueSteeringMessages();
      },
      getFollowUpMessages: async () => this.dequeueFollowUpMessages(),
      beforeToolCall: this._beforeToolCall,
      afterToolCall: this._afterToolCall,
      interruptToolExecutionOnSteering: this._interruptToolExecutionOnSteering,
      externalToolExecution: this._externalToolExecution?.(model) ?? false,
    };

    // Tracks the in-progress assistant message so it can be salvaged if the
    // stream ends without a message_end.
    let partial: AgentMessage | null = null;

    try {
      const stream = messages
        ? agentLoop(
            messages,
            context,
            config,
            this.abortController.signal,
            this.streamFn,
          )
        : agentLoopContinue(
            context,
            config,
            this.abortController.signal,
            this.streamFn,
          );

      for await (const event of stream) {
        // Update internal state based on events
        switch (event.type) {
          case "message_start":
          case "message_update":
            partial = event.message;
            this._state.streamMessage = event.message;
            break;

          case "message_end":
            partial = null;
            this._state.streamMessage = null;
            this.appendMessage(event.message);
            break;

          case "tool_execution_start":
            this._updatePendingToolCalls("add", event.toolCallId);
            break;

          case "tool_execution_end":
            this._updatePendingToolCalls("delete", event.toolCallId);
            break;

          case "turn_end":
            if (
              event.message.role === "assistant" &&
              (event.message as any).errorMessage
            ) {
              this._state.error = (event.message as any).errorMessage;
            }
            break;

          case "agent_end":
            this._state.isStreaming = false;
            this._state.streamMessage = null;
            break;
        }

        // Emit to listeners
        this.emit(event);
      }

      // Handle any remaining partial message
      if (
        partial &&
        partial.role === "assistant" &&
        partial.content.length > 0
      ) {
        // Keep the partial only if it carries any non-whitespace content.
        const onlyEmpty = !partial.content.some(
          (c) =>
            (c.type === "thinking" && c.thinking.trim().length > 0) ||
            (c.type === "text" && c.text.trim().length > 0) ||
            (c.type === "toolCall" && c.name.trim().length > 0),
        );
        if (!onlyEmpty) {
          this.appendMessage(partial);
        } else {
          if (this.abortController?.signal.aborted) {
            // Surface the abort through the catch path below.
            throw new Error("Request was aborted");
          }
        }
      }
    } catch (err: any) {
      // Convert any failure (including abort) into a terminal assistant
      // message so the conversation history records the outcome.
      const errorMsg: AgentMessage = {
        role: "assistant",
        content: [{ type: "text", text: "" }],
        api: model.api,
        provider: model.provider,
        model: model.id,
        usage: ZERO_USAGE,
        stopReason: this.abortController?.signal.aborted ? "aborted" : "error",
        errorMessage: err?.message || String(err),
        timestamp: Date.now(),
      } as AgentMessage;

      this.appendMessage(errorMsg);
      this._state.error = err?.message || String(err);
      this.emit({ type: "agent_end", messages: [errorMsg] });
    } finally {
      // Always restore idle state, clear the active model, and release waiters.
      this._state.isStreaming = false;
      this._state.streamMessage = null;
      this._state.pendingToolCalls = new Set<string>();
      this._state.activeInferenceModel = undefined;
      this.abortController = undefined;
      this.resolveRunningPrompt?.();
      this.runningPrompt = undefined;
      this.resolveRunningPrompt = undefined;
    }
  }

  // Copy-on-write update of the pending-tool-call set so consumers comparing
  // Set identity can detect changes.
  private _updatePendingToolCalls(action: "add" | "delete", id: string): void {
    const s = new Set(this._state.pendingToolCalls);
    s[action](id);
    this._state.pendingToolCalls = s;
  }

  // Fan an event out to all subscribers.
  private emit(e: AgentEvent) {
    for (const listener of this.listeners) {
      listener(e);
    }
  }
}
|
||||
|
|
@ -1,12 +0,0 @@
|
|||
// Public entry point for this package: re-exports the Agent class, the loop
// primitives, interactive-question helpers, proxy utilities, the SF project
// graph, and shared types.

// Core Agent
export * from "./agent.js";
// Loop functions
export * from "./agent-loop.js";
// Interactive question contract
export * from "./interactive-questions.js";
// Proxy utilities
export * from "./proxy.js";
// SF project graph
export * from "./sf-graph.js";
// Types
export * from "./types.js";
|
||||
|
|
@ -1,73 +0,0 @@
|
|||
import assert from "node:assert/strict";
|
||||
import { test } from "vitest";
|
||||
import {
|
||||
formatRoundResultForTool,
|
||||
type Question,
|
||||
roundResultFromElicitationContent,
|
||||
roundResultFromRemoteAnswer,
|
||||
} from "./interactive-questions.js";
|
||||
|
||||
const questions: Question[] = [
|
||||
{
|
||||
id: "choice",
|
||||
header: "Choice",
|
||||
question: "Pick one",
|
||||
options: [
|
||||
{ label: "Alpha", description: "A" },
|
||||
{ label: "None of the above", description: "Other" },
|
||||
],
|
||||
},
|
||||
{
|
||||
id: "multi",
|
||||
header: "Multi",
|
||||
question: "Pick many",
|
||||
allowMultiple: true,
|
||||
options: [
|
||||
{ label: "Frontend", description: "UI" },
|
||||
{ label: "Backend", description: "API" },
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
test("roundResultFromElicitationContent preserves notes and multi-select arrays", () => {
|
||||
const result = roundResultFromElicitationContent(questions, {
|
||||
action: "accept",
|
||||
content: {
|
||||
choice: "None of the above",
|
||||
choice__note: "Hybrid",
|
||||
multi: ["Frontend"],
|
||||
},
|
||||
});
|
||||
|
||||
assert.deepEqual(result, {
|
||||
endInterview: false,
|
||||
answers: {
|
||||
choice: { selected: "None of the above", notes: "Hybrid" },
|
||||
multi: { selected: ["Frontend"], notes: "" },
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
test("roundResultFromRemoteAnswer uses question metadata to keep one multi-select as array", () => {
|
||||
const result = roundResultFromRemoteAnswer(
|
||||
{
|
||||
answers: {
|
||||
choice: { answers: ["Alpha"] },
|
||||
multi: { answers: ["Backend"] },
|
||||
},
|
||||
},
|
||||
questions,
|
||||
);
|
||||
|
||||
assert.deepEqual(result.answers.choice.selected, "Alpha");
|
||||
assert.deepEqual(result.answers.multi.selected, ["Backend"]);
|
||||
assert.equal(
|
||||
formatRoundResultForTool(result),
|
||||
JSON.stringify({
|
||||
answers: {
|
||||
choice: { answers: ["Alpha"] },
|
||||
multi: { answers: ["Backend"] },
|
||||
},
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
|
@ -1,171 +0,0 @@
|
|||
/**
|
||||
* Shared structured-question contract for local UI, remote channels, and MCP.
|
||||
*
|
||||
* Purpose: keep every ask_user_questions transport on the same answer shape so
|
||||
* gate hooks and LLM-facing JSON do not drift between local TUI, remote
|
||||
* Slack/Discord/Telegram, and MCP elicitation paths.
|
||||
*
|
||||
* Consumer: SF ask_user_questions extension, remote question manager, and
|
||||
* structured-question transports.
|
||||
*/
|
||||
|
||||
export interface QuestionOption {
|
||||
label: string;
|
||||
description: string;
|
||||
}
|
||||
|
||||
export interface Question {
|
||||
id: string;
|
||||
header: string;
|
||||
question: string;
|
||||
options: QuestionOption[];
|
||||
allowMultiple?: boolean;
|
||||
}
|
||||
|
||||
export interface RoundAnswer {
|
||||
selected: string | string[];
|
||||
notes: string;
|
||||
}
|
||||
|
||||
export interface RoundResult {
|
||||
/** Always false; wrap-up/exit is handled outside a single question round. */
|
||||
endInterview: false;
|
||||
answers: Record<string, RoundAnswer>;
|
||||
}
|
||||
|
||||
export interface RemoteAnswerLike {
|
||||
answers: Record<string, { answers?: string[]; user_note?: string }>;
|
||||
}
|
||||
|
||||
export type ElicitationContentValue = string | number | boolean | string[];
|
||||
|
||||
export interface ElicitationResultLike {
|
||||
action?: "accept" | "decline" | "cancel" | string;
|
||||
content?: Record<string, ElicitationContentValue>;
|
||||
}
|
||||
|
||||
export const DEFAULT_OTHER_OPTION_LABEL = "None of the above";
|
||||
|
||||
function normalizeNote(value: ElicitationContentValue | undefined): string {
|
||||
return typeof value === "string" ? value.trim() : "";
|
||||
}
|
||||
|
||||
function normalizeSelectedList(
|
||||
value: ElicitationContentValue | undefined,
|
||||
allowMultiple: boolean,
|
||||
): string[] {
|
||||
if (allowMultiple) {
|
||||
return Array.isArray(value)
|
||||
? value.filter((item): item is string => typeof item === "string")
|
||||
: [];
|
||||
}
|
||||
return typeof value === "string" && value.length > 0 ? [value] : [];
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert local/MCP elicitation form content into the canonical RoundResult.
|
||||
*
|
||||
* Purpose: preserve the multi-select array contract and "None of the above"
|
||||
* notes consistently across transports.
|
||||
*
|
||||
* Consumer: MCP ask_user_questions handler and any form-based local bridge.
|
||||
*/
|
||||
export function roundResultFromElicitationContent(
|
||||
questions: readonly Question[],
|
||||
result: ElicitationResultLike,
|
||||
otherOptionLabel = DEFAULT_OTHER_OPTION_LABEL,
|
||||
): RoundResult {
|
||||
const content = result.content ?? {};
|
||||
const answers: Record<string, RoundAnswer> = {};
|
||||
|
||||
for (const question of questions) {
|
||||
if (question.allowMultiple) {
|
||||
answers[question.id] = {
|
||||
selected: normalizeSelectedList(content[question.id], true),
|
||||
notes: "",
|
||||
};
|
||||
continue;
|
||||
}
|
||||
|
||||
const list = normalizeSelectedList(content[question.id], false);
|
||||
const selected = list[0] ?? "";
|
||||
const notes =
|
||||
selected === otherOptionLabel
|
||||
? normalizeNote(content[`${question.id}__note`])
|
||||
: "";
|
||||
answers[question.id] = { selected, notes };
|
||||
}
|
||||
|
||||
return { endInterview: false, answers };
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert a remote-channel answer into the canonical RoundResult.
|
||||
*
|
||||
* Purpose: remote adapters store answers as `{ answers: string[] }`; consumers
|
||||
* need the same `selected` shape as local TUI, especially array preservation for
|
||||
* multi-select questions with a single selected item.
|
||||
*
|
||||
* Consumer: SF remote question manager.
|
||||
*/
|
||||
export function roundResultFromRemoteAnswer(
|
||||
answer: RemoteAnswerLike,
|
||||
questions: readonly Question[],
|
||||
): RoundResult {
|
||||
const allowMultipleById = new Map<string, boolean>();
|
||||
for (const question of questions) {
|
||||
allowMultipleById.set(question.id, question.allowMultiple ?? false);
|
||||
}
|
||||
|
||||
const answers: Record<string, RoundAnswer> = {};
|
||||
for (const [id, data] of Object.entries(answer.answers)) {
|
||||
const list = data.answers ?? [];
|
||||
const allowMultiple = allowMultipleById.get(id) ?? false;
|
||||
answers[id] = {
|
||||
selected: allowMultiple ? [...list] : (list[0] ?? ""),
|
||||
notes: data.user_note ?? "",
|
||||
};
|
||||
}
|
||||
|
||||
return { endInterview: false, answers };
|
||||
}
|
||||
|
||||
/**
|
||||
* Render the canonical RoundResult as the historical LLM/tool JSON payload.
|
||||
*
|
||||
* Purpose: keep the text response backward-compatible while structured callers
|
||||
* consume RoundResult directly.
|
||||
*
|
||||
* Consumer: ask_user_questions local/remote/MCP handlers.
|
||||
*/
|
||||
export function formatRoundResultForTool(result: RoundResult): string {
|
||||
const answers: Record<string, { answers: string[] }> = {};
|
||||
for (const [id, answer] of Object.entries(result.answers)) {
|
||||
const list = Array.isArray(answer.selected)
|
||||
? [...answer.selected]
|
||||
: [answer.selected];
|
||||
if (answer.notes) list.push(`user_note: ${answer.notes}`);
|
||||
answers[id] = { answers: list };
|
||||
}
|
||||
return JSON.stringify({ answers });
|
||||
}
|
||||
|
||||
/**
|
||||
* Build the structured content payload shared by MCP and extension details.
|
||||
*
|
||||
* Purpose: provide the same cancellation and response contract to gate hooks
|
||||
* regardless of transport.
|
||||
*
|
||||
* Consumer: MCP ask_user_questions handler.
|
||||
*/
|
||||
export function buildQuestionStructuredContent(
|
||||
questions: readonly Question[],
|
||||
response: RoundResult | null,
|
||||
cancelled: boolean,
|
||||
): {
|
||||
questions: readonly Question[];
|
||||
response: RoundResult | null;
|
||||
cancelled: boolean;
|
||||
} {
|
||||
return { questions, response, cancelled };
|
||||
}
|
||||
|
|
@ -1,363 +0,0 @@
|
|||
/**
|
||||
* Proxy stream function for apps that route LLM calls through a server.
|
||||
* The server manages auth and proxies requests to LLM providers.
|
||||
*/
|
||||
|
||||
// Internal import for JSON parsing utility
|
||||
import {
|
||||
type AssistantMessage,
|
||||
type AssistantMessageEvent,
|
||||
type Context,
|
||||
EventStream,
|
||||
type Model,
|
||||
parseStreamingJson,
|
||||
type SimpleStreamOptions,
|
||||
type StopReason,
|
||||
type ToolCall,
|
||||
} from "@singularity-forge/pi-ai";
|
||||
import { ZERO_USAGE } from "./agent-loop.js";
|
||||
|
||||
// Create stream class matching ProxyMessageEventStream
|
||||
class ProxyMessageEventStream extends EventStream<
|
||||
AssistantMessageEvent,
|
||||
AssistantMessage
|
||||
> {
|
||||
constructor() {
|
||||
super(
|
||||
(event) => event.type === "done" || event.type === "error",
|
||||
(event) => {
|
||||
if (event.type === "done") return event.message;
|
||||
if (event.type === "error") return event.error;
|
||||
throw new Error("Unexpected event type");
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Proxy event types - server sends these with partial field stripped to reduce bandwidth.
|
||||
*/
|
||||
export type ProxyAssistantMessageEvent =
|
||||
| { type: "start" }
|
||||
| { type: "text_start"; contentIndex: number }
|
||||
| { type: "text_delta"; contentIndex: number; delta: string }
|
||||
| { type: "text_end"; contentIndex: number; contentSignature?: string }
|
||||
| { type: "thinking_start"; contentIndex: number }
|
||||
| { type: "thinking_delta"; contentIndex: number; delta: string }
|
||||
| { type: "thinking_end"; contentIndex: number; contentSignature?: string }
|
||||
| {
|
||||
type: "toolcall_start";
|
||||
contentIndex: number;
|
||||
id: string;
|
||||
toolName: string;
|
||||
}
|
||||
| { type: "toolcall_delta"; contentIndex: number; delta: string }
|
||||
| { type: "toolcall_end"; contentIndex: number }
|
||||
| {
|
||||
type: "done";
|
||||
reason: Extract<StopReason, "stop" | "length" | "toolUse" | "pauseTurn">;
|
||||
usage: AssistantMessage["usage"];
|
||||
}
|
||||
| {
|
||||
type: "error";
|
||||
reason: Extract<StopReason, "aborted" | "error">;
|
||||
errorMessage?: string;
|
||||
usage: AssistantMessage["usage"];
|
||||
};
|
||||
|
||||
export interface ProxyStreamOptions extends SimpleStreamOptions {
|
||||
/** Auth token for the proxy server */
|
||||
authToken: string;
|
||||
/** Proxy server URL (e.g., "https://genai.example.com") */
|
||||
proxyUrl: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Stream function that proxies through a server instead of calling LLM providers directly.
|
||||
* The server strips the partial field from delta events to reduce bandwidth.
|
||||
* We reconstruct the partial message client-side.
|
||||
*
|
||||
* Use this as the `streamFn` option when creating an Agent that needs to go through a proxy.
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* const agent = new Agent({
|
||||
* streamFn: (model, context, options) =>
|
||||
* streamProxy(model, context, {
|
||||
* ...options,
|
||||
* authToken: await getAuthToken(),
|
||||
* proxyUrl: "https://genai.example.com",
|
||||
* }),
|
||||
* });
|
||||
* ```
|
||||
*/
|
||||
function _streamProxy(
|
||||
model: Model<any>,
|
||||
context: Context,
|
||||
options: ProxyStreamOptions,
|
||||
): ProxyMessageEventStream {
|
||||
const stream = new ProxyMessageEventStream();
|
||||
|
||||
(async () => {
|
||||
// Initialize the partial message that we'll build up from events
|
||||
const partial: AssistantMessage = {
|
||||
role: "assistant",
|
||||
stopReason: "stop",
|
||||
content: [],
|
||||
api: model.api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: { ...ZERO_USAGE, cost: { ...ZERO_USAGE.cost } },
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
let reader: ReadableStreamDefaultReader<Uint8Array> | undefined;
|
||||
|
||||
const abortHandler = () => {
|
||||
if (reader) {
|
||||
reader.cancel("Request aborted by user").catch(() => {});
|
||||
}
|
||||
};
|
||||
|
||||
if (options.signal) {
|
||||
options.signal.addEventListener("abort", abortHandler);
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch(`${options.proxyUrl}/api/stream`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
Authorization: `Bearer ${options.authToken}`,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
model,
|
||||
context,
|
||||
options: {
|
||||
temperature: options.temperature,
|
||||
maxTokens: options.maxTokens,
|
||||
reasoning: options.reasoning,
|
||||
},
|
||||
}),
|
||||
signal: options.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
let errorMessage = `Proxy error: ${response.status} ${response.statusText}`;
|
||||
try {
|
||||
const errorData = (await response.json()) as { error?: string };
|
||||
if (errorData.error) {
|
||||
errorMessage = `Proxy error: ${errorData.error}`;
|
||||
}
|
||||
} catch {
|
||||
// Couldn't parse error response
|
||||
}
|
||||
throw new Error(errorMessage);
|
||||
}
|
||||
|
||||
reader = response.body!.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
if (options.signal?.aborted) {
|
||||
throw new Error("Request aborted by user");
|
||||
}
|
||||
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
const lines = buffer.split("\n");
|
||||
buffer = lines.pop() || "";
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.startsWith("data: ")) {
|
||||
const data = line.slice(6).trim();
|
||||
if (data) {
|
||||
const proxyEvent = JSON.parse(data) as ProxyAssistantMessageEvent;
|
||||
const event = processProxyEvent(proxyEvent, partial);
|
||||
if (event) {
|
||||
stream.push(event);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (options.signal?.aborted) {
|
||||
throw new Error("Request aborted by user");
|
||||
}
|
||||
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
const errorMessage =
|
||||
error instanceof Error ? error.message : String(error);
|
||||
const reason = options.signal?.aborted ? "aborted" : "error";
|
||||
partial.stopReason = reason;
|
||||
partial.errorMessage = errorMessage;
|
||||
stream.push({
|
||||
type: "error",
|
||||
reason,
|
||||
error: partial,
|
||||
});
|
||||
stream.end();
|
||||
} finally {
|
||||
if (options.signal) {
|
||||
options.signal.removeEventListener("abort", abortHandler);
|
||||
}
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
}
|
||||
|
||||
/**
|
||||
* Process a proxy event and update the partial message.
|
||||
*/
|
||||
function processProxyEvent(
|
||||
proxyEvent: ProxyAssistantMessageEvent,
|
||||
partial: AssistantMessage,
|
||||
): AssistantMessageEvent | undefined {
|
||||
switch (proxyEvent.type) {
|
||||
case "start":
|
||||
return { type: "start", partial };
|
||||
|
||||
case "text_start":
|
||||
partial.content[proxyEvent.contentIndex] = { type: "text", text: "" };
|
||||
return {
|
||||
type: "text_start",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
partial,
|
||||
};
|
||||
|
||||
case "text_delta": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "text") {
|
||||
content.text += proxyEvent.delta;
|
||||
return {
|
||||
type: "text_delta",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
delta: proxyEvent.delta,
|
||||
partial,
|
||||
};
|
||||
}
|
||||
throw new Error("Received text_delta for non-text content");
|
||||
}
|
||||
|
||||
case "text_end": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "text") {
|
||||
content.textSignature = proxyEvent.contentSignature;
|
||||
return {
|
||||
type: "text_end",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
content: content.text,
|
||||
partial,
|
||||
};
|
||||
}
|
||||
throw new Error("Received text_end for non-text content");
|
||||
}
|
||||
|
||||
case "thinking_start":
|
||||
partial.content[proxyEvent.contentIndex] = {
|
||||
type: "thinking",
|
||||
thinking: "",
|
||||
};
|
||||
return {
|
||||
type: "thinking_start",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
partial,
|
||||
};
|
||||
|
||||
case "thinking_delta": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "thinking") {
|
||||
content.thinking += proxyEvent.delta;
|
||||
return {
|
||||
type: "thinking_delta",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
delta: proxyEvent.delta,
|
||||
partial,
|
||||
};
|
||||
}
|
||||
throw new Error("Received thinking_delta for non-thinking content");
|
||||
}
|
||||
|
||||
case "thinking_end": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "thinking") {
|
||||
content.thinkingSignature = proxyEvent.contentSignature;
|
||||
return {
|
||||
type: "thinking_end",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
content: content.thinking,
|
||||
partial,
|
||||
};
|
||||
}
|
||||
throw new Error("Received thinking_end for non-thinking content");
|
||||
}
|
||||
|
||||
case "toolcall_start":
|
||||
partial.content[proxyEvent.contentIndex] = {
|
||||
type: "toolCall",
|
||||
id: proxyEvent.id,
|
||||
name: proxyEvent.toolName,
|
||||
arguments: {},
|
||||
partialJson: "",
|
||||
} satisfies ToolCall & { partialJson: string } as ToolCall;
|
||||
return {
|
||||
type: "toolcall_start",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
partial,
|
||||
};
|
||||
|
||||
case "toolcall_delta": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "toolCall") {
|
||||
(content as any).partialJson += proxyEvent.delta;
|
||||
content.arguments =
|
||||
parseStreamingJson((content as any).partialJson) || {};
|
||||
partial.content[proxyEvent.contentIndex] = { ...content }; // Trigger reactivity
|
||||
return {
|
||||
type: "toolcall_delta",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
delta: proxyEvent.delta,
|
||||
partial,
|
||||
};
|
||||
}
|
||||
throw new Error("Received toolcall_delta for non-toolCall content");
|
||||
}
|
||||
|
||||
case "toolcall_end": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "toolCall") {
|
||||
delete (content as any).partialJson;
|
||||
return {
|
||||
type: "toolcall_end",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
toolCall: content,
|
||||
partial,
|
||||
};
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
case "done":
|
||||
partial.stopReason = proxyEvent.reason;
|
||||
partial.usage = proxyEvent.usage;
|
||||
return { type: "done", reason: proxyEvent.reason, message: partial };
|
||||
|
||||
case "error":
|
||||
partial.stopReason = proxyEvent.reason;
|
||||
partial.errorMessage = proxyEvent.errorMessage;
|
||||
partial.usage = proxyEvent.usage;
|
||||
return { type: "error", reason: proxyEvent.reason, error: partial };
|
||||
|
||||
default: {
|
||||
const _exhaustiveCheck: never = proxyEvent;
|
||||
console.warn(`Unhandled proxy event type: ${(proxyEvent as any).type}`);
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,396 +0,0 @@
|
|||
import type { Static, TSchema } from "@sinclair/typebox";
|
||||
import type {
|
||||
AssistantMessage,
|
||||
AssistantMessageEvent,
|
||||
ImageContent,
|
||||
Message,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
streamSimple,
|
||||
TextContent,
|
||||
Tool,
|
||||
ToolResultMessage,
|
||||
} from "@singularity-forge/pi-ai";
|
||||
|
||||
/** Stream function - can return sync or Promise for async config lookup */
|
||||
export type StreamFn = (
|
||||
...args: Parameters<typeof streamSimple>
|
||||
) => ReturnType<typeof streamSimple> | Promise<ReturnType<typeof streamSimple>>;
|
||||
|
||||
/**
|
||||
* Configuration for how tool calls from a single assistant message are executed.
|
||||
*
|
||||
* - "sequential": each tool call is prepared, executed, and finalized before the next one starts.
|
||||
* - "parallel": tool calls are prepared sequentially, then allowed tools execute concurrently.
|
||||
* Final tool results are still emitted in assistant source order.
|
||||
*/
|
||||
export type ToolExecutionMode = "sequential" | "parallel";
|
||||
|
||||
/** A single tool call content block emitted by an assistant message. */
|
||||
export type AgentToolCall = Extract<
|
||||
AssistantMessage["content"][number],
|
||||
{ type: "toolCall" }
|
||||
>;
|
||||
|
||||
/**
|
||||
* Result returned from `beforeToolCall`.
|
||||
*
|
||||
* Returning `{ block: true }` prevents the tool from executing. The loop emits an error tool result instead.
|
||||
* `reason` becomes the text shown in that error result. If omitted, a default blocked message is used.
|
||||
*/
|
||||
export interface BeforeToolCallResult {
|
||||
block?: boolean;
|
||||
reason?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Partial override returned from `afterToolCall`.
|
||||
*
|
||||
* Merge semantics are field-by-field:
|
||||
* - `content`: if provided, replaces the tool result content array in full
|
||||
* - `details`: if provided, replaces the tool result details value in full
|
||||
* - `isError`: if provided, replaces the tool result error flag
|
||||
*
|
||||
* Omitted fields keep the original executed tool result values.
|
||||
*/
|
||||
export interface AfterToolCallResult {
|
||||
content?: (TextContent | ImageContent)[];
|
||||
details?: unknown;
|
||||
isError?: boolean;
|
||||
}
|
||||
|
||||
/** Context passed to `beforeToolCall`. */
|
||||
export interface BeforeToolCallContext {
|
||||
/** The assistant message that requested the tool call. */
|
||||
assistantMessage: AssistantMessage;
|
||||
/** The raw tool call block from `assistantMessage.content`. */
|
||||
toolCall: AgentToolCall;
|
||||
/** Validated tool arguments for the target tool schema. */
|
||||
args: unknown;
|
||||
/** Current agent context at the time the tool call is prepared. */
|
||||
context: AgentContext;
|
||||
}
|
||||
|
||||
/** Context passed to `afterToolCall`. */
|
||||
export interface AfterToolCallContext {
|
||||
/** The assistant message that requested the tool call. */
|
||||
assistantMessage: AssistantMessage;
|
||||
/** The raw tool call block from `assistantMessage.content`. */
|
||||
toolCall: AgentToolCall;
|
||||
/** Validated tool arguments for the target tool schema. */
|
||||
args: unknown;
|
||||
/** The executed tool result before any `afterToolCall` overrides are applied. */
|
||||
result: AgentToolResult<any>;
|
||||
/** Whether the executed tool result is currently treated as an error. */
|
||||
isError: boolean;
|
||||
/** Current agent context at the time the tool call is finalized. */
|
||||
context: AgentContext;
|
||||
}
|
||||
|
||||
/**
|
||||
* Configuration for the agent loop.
|
||||
*/
|
||||
export interface AgentLoopConfig extends SimpleStreamOptions {
|
||||
model: Model<any>;
|
||||
|
||||
/**
|
||||
* Converts AgentMessage[] to LLM-compatible Message[] before each LLM call.
|
||||
*
|
||||
* Each AgentMessage must be converted to a UserMessage, AssistantMessage, or ToolResultMessage
|
||||
* that the LLM can understand. AgentMessages that cannot be converted (e.g., UI-only notifications,
|
||||
* status messages) should be filtered out.
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* convertToLlm: (messages) => messages.flatMap(m => {
|
||||
* if (m.role === "custom") {
|
||||
* // Convert custom message to user message
|
||||
* return [{ role: "user", content: m.content, timestamp: m.timestamp }];
|
||||
* }
|
||||
* if (m.role === "notification") {
|
||||
* // Filter out UI-only messages
|
||||
* return [];
|
||||
* }
|
||||
* // Pass through standard LLM messages
|
||||
* return [m];
|
||||
* })
|
||||
* ```
|
||||
*/
|
||||
convertToLlm: (messages: AgentMessage[]) => Message[] | Promise<Message[]>;
|
||||
|
||||
/**
|
||||
* Optional transform applied to the context before `convertToLlm`.
|
||||
*
|
||||
* Use this for operations that work at the AgentMessage level:
|
||||
* - Context window management (pruning old messages)
|
||||
* - Injecting context from external sources
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* transformContext: async (messages) => {
|
||||
* if (estimateTokens(messages) > MAX_TOKENS) {
|
||||
* return pruneOldMessages(messages);
|
||||
* }
|
||||
* return messages;
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
transformContext?: (
|
||||
messages: AgentMessage[],
|
||||
signal?: AbortSignal,
|
||||
) => Promise<AgentMessage[]>;
|
||||
|
||||
/**
|
||||
* Resolves an API key dynamically for each LLM call.
|
||||
*
|
||||
* Useful for short-lived OAuth tokens (e.g., GitHub Copilot) that may expire
|
||||
* during long-running tool execution phases.
|
||||
*/
|
||||
getApiKey?: (
|
||||
provider: string,
|
||||
) => Promise<string | undefined> | string | undefined;
|
||||
|
||||
/**
|
||||
* Streaming hook for Predictive Execution.
|
||||
* Called whenever a chunk of text or thinking is streamed from the LLM.
|
||||
* Allows the system to parse intent early (e.g., "I should check") and pre-fetch context
|
||||
* or run background jobs before the LLM finishes and requests a tool.
|
||||
*/
|
||||
onStreamChunk?: (chunk: string, context: AgentContext) => void;
|
||||
|
||||
/**
|
||||
* Returns steering messages to inject into the conversation mid-run.
|
||||
*
|
||||
* Called after tool execution boundaries to check for user steering.
|
||||
* By default, returned messages are added to the context after the current
|
||||
* tool batch finishes and before the next LLM call.
|
||||
*
|
||||
* Use this for "steering" the agent while it's working.
|
||||
*/
|
||||
getSteeringMessages?: () => Promise<AgentMessage[]>;
|
||||
|
||||
/**
|
||||
* Whether steering messages interrupt the current assistant tool batch.
|
||||
*
|
||||
* Default false preserves active tool execution: a user message typed while
|
||||
* tools are running is absorbed after the current tool batch finishes. Set
|
||||
* this true only for explicit stop-now workflows that should skip remaining
|
||||
* tool calls.
|
||||
*/
|
||||
interruptToolExecutionOnSteering?: boolean;
|
||||
|
||||
/**
|
||||
* Returns follow-up messages to process after the agent would otherwise stop.
|
||||
*
|
||||
* Called when the agent has no more tool calls and no steering messages.
|
||||
* If messages are returned, they're added to the context and the agent
|
||||
* continues with another turn.
|
||||
*
|
||||
* Use this for follow-up messages that should wait until the agent finishes.
|
||||
*/
|
||||
getFollowUpMessages?: () => Promise<AgentMessage[]>;
|
||||
|
||||
/**
|
||||
* Tool execution mode.
|
||||
* - "sequential": execute tool calls one by one
|
||||
* - "parallel": preflight tool calls sequentially, then execute allowed tools concurrently
|
||||
*
|
||||
* Default: "parallel"
|
||||
*/
|
||||
toolExecution?: ToolExecutionMode;
|
||||
|
||||
/**
|
||||
* Called before a tool is executed, after arguments have been validated.
|
||||
*
|
||||
* Return `{ block: true }` to prevent execution. The loop emits an error tool result instead.
|
||||
* The hook receives the agent abort signal and is responsible for honoring it.
|
||||
*/
|
||||
beforeToolCall?: (
|
||||
context: BeforeToolCallContext,
|
||||
signal?: AbortSignal,
|
||||
) => Promise<BeforeToolCallResult | undefined>;
|
||||
|
||||
/**
|
||||
* Called after a tool finishes executing, before final tool events are emitted.
|
||||
*
|
||||
* Return an `AfterToolCallResult` to override parts of the executed tool result:
|
||||
* - `content` replaces the full content array
|
||||
* - `details` replaces the full details payload
|
||||
* - `isError` replaces the error flag
|
||||
*
|
||||
* Any omitted fields keep their original values. No deep merge is performed.
|
||||
* The hook receives the agent abort signal and is responsible for honoring it.
|
||||
*/
|
||||
afterToolCall?: (
|
||||
context: AfterToolCallContext,
|
||||
signal?: AbortSignal,
|
||||
) => Promise<AfterToolCallResult | undefined>;
|
||||
|
||||
/**
|
||||
* When true, tool calls in assistant messages are rendered in the TUI
|
||||
* but NOT executed locally. Used for providers that handle tool execution
|
||||
* internally (e.g., Claude Code CLI via Agent SDK).
|
||||
*
|
||||
* The agent loop emits tool_execution_start/end events for TUI rendering
|
||||
* but skips tool.execute() and does not add tool results to context.
|
||||
*/
|
||||
externalToolExecution?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Thinking/reasoning level for models that support it.
|
||||
* Note: "xhigh" is only supported by OpenAI gpt-5.1-codex-max, gpt-5.2, gpt-5.2-codex, gpt-5.3, and gpt-5.3-codex models.
|
||||
*/
|
||||
export type ThinkingLevel =
|
||||
| "off"
|
||||
| "minimal"
|
||||
| "low"
|
||||
| "medium"
|
||||
| "high"
|
||||
| "xhigh";
|
||||
|
||||
/**
|
||||
* Extensible interface for custom app messages.
|
||||
* Apps can extend via declaration merging:
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* declare module "@mariozechner/agent" {
|
||||
* interface CustomAgentMessages {
|
||||
* artifact: ArtifactMessage;
|
||||
* notification: NotificationMessage;
|
||||
* }
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
// biome-ignore lint/suspicious/noEmptyInterface: extension point for downstream declaration merging
|
||||
export interface CustomAgentMessages {}
|
||||
|
||||
/**
|
||||
* AgentMessage: Union of LLM messages + custom messages.
|
||||
* This abstraction allows apps to add custom message types while maintaining
|
||||
* type safety and compatibility with the base LLM messages.
|
||||
*/
|
||||
export type AgentMessage =
|
||||
| Message
|
||||
| CustomAgentMessages[keyof CustomAgentMessages];
|
||||
|
||||
/**
|
||||
* Agent state containing all configuration and conversation data.
|
||||
*/
|
||||
export interface AgentState {
|
||||
systemPrompt: string;
|
||||
model: Model<any>;
|
||||
thinkingLevel: ThinkingLevel;
|
||||
tools: AgentTool<any>[];
|
||||
messages: AgentMessage[]; // Can include attachments + custom message types
|
||||
isStreaming: boolean;
|
||||
streamMessage: AgentMessage | null;
|
||||
pendingToolCalls: Set<string>;
|
||||
error?: string;
|
||||
/**
|
||||
* The model currently being used for inference. Set at _runLoop() start,
|
||||
* cleared when the loop ends. When present, UI should display this instead
|
||||
* of `model` to avoid showing a stale value after a mid-turn model switch.
|
||||
*/
|
||||
activeInferenceModel?: Model<any>;
|
||||
}
|
||||
|
||||
export interface AgentToolResult<T> {
|
||||
// Content blocks supporting text and images
|
||||
content: (TextContent | ImageContent)[];
|
||||
// Details to be displayed in a UI or logged
|
||||
details: T;
|
||||
}
|
||||
|
||||
// Callback for streaming tool execution updates
|
||||
export type AgentToolUpdateCallback<T = any> = (
|
||||
partialResult: AgentToolResult<T>,
|
||||
) => void;
|
||||
|
||||
// AgentTool extends Tool but adds the execute function
|
||||
export interface AgentTool<
|
||||
TParameters extends TSchema = TSchema,
|
||||
TDetails = any,
|
||||
> extends Tool<TParameters> {
|
||||
// A human-readable label for the tool to be displayed in UI
|
||||
label: string;
|
||||
execute: (
|
||||
toolCallId: string,
|
||||
params: Static<TParameters>,
|
||||
signal?: AbortSignal,
|
||||
onUpdate?: AgentToolUpdateCallback<TDetails>,
|
||||
) => Promise<AgentToolResult<TDetails>>;
|
||||
}
|
||||
|
||||
// AgentContext is like Context but uses AgentTool
|
||||
export interface AgentContext {
|
||||
systemPrompt: string;
|
||||
messages: AgentMessage[];
|
||||
tools?: AgentTool<any>[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Events emitted by the Agent for UI updates.
|
||||
* These events provide fine-grained lifecycle information for messages, turns, and tool executions.
|
||||
*/
|
||||
export type AgentEvent =
|
||||
// Agent lifecycle
|
||||
| { type: "agent_start" }
|
||||
| { type: "agent_end"; messages: AgentMessage[] }
|
||||
// Turn lifecycle - a turn is one assistant response + any tool calls/results
|
||||
| { type: "turn_start" }
|
||||
| {
|
||||
type: "turn_end";
|
||||
message: AgentMessage;
|
||||
toolResults: ToolResultMessage[];
|
||||
}
|
||||
// Message lifecycle - emitted for user, assistant, and toolResult messages
|
||||
| { type: "message_start"; message: AgentMessage }
|
||||
// Only emitted for assistant messages during streaming
|
||||
| {
|
||||
type: "message_update";
|
||||
message: AgentMessage;
|
||||
assistantMessageEvent: AssistantMessageEvent;
|
||||
}
|
||||
| { type: "message_end"; message: AgentMessage }
|
||||
// Tool execution lifecycle
|
||||
| {
|
||||
type: "tool_execution_start";
|
||||
toolCallId: string;
|
||||
toolName: string;
|
||||
args: any;
|
||||
}
|
||||
| {
|
||||
type: "tool_execution_update";
|
||||
toolCallId: string;
|
||||
toolName: string;
|
||||
args: any;
|
||||
partialResult: any;
|
||||
}
|
||||
| {
|
||||
type: "tool_execution_end";
|
||||
toolCallId: string;
|
||||
toolName: string;
|
||||
result: any;
|
||||
isError: boolean;
|
||||
};
|
||||
|
||||
export interface MemoryRecord {
|
||||
id?: string;
|
||||
text?: string;
|
||||
summary?: string;
|
||||
tags?: string[];
|
||||
metadata?: Record<string, unknown>;
|
||||
[key: string]: unknown;
|
||||
}
|
||||
|
||||
export interface MemoryProvider {
|
||||
/** Search for specific anti-patterns or facts across federated nodes or locally. */
|
||||
search(
|
||||
query: string,
|
||||
options?: { limit?: number; threshold?: number },
|
||||
): Promise<MemoryRecord[]>;
|
||||
/** Store a new learning or anti-pattern to the federated graph. */
|
||||
store(memory: MemoryRecord): Promise<void>;
|
||||
}
|
||||
|
|
@ -1,34 +0,0 @@
|
|||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2024",
|
||||
"module": "Node16",
|
||||
"lib": ["ES2024"],
|
||||
"strict": true,
|
||||
"esModuleInterop": true,
|
||||
"skipLibCheck": true,
|
||||
"incremental": true,
|
||||
"forceConsistentCasingInFileNames": true,
|
||||
"declaration": true,
|
||||
"declarationMap": true,
|
||||
"sourceMap": true,
|
||||
"inlineSources": true,
|
||||
"inlineSourceMap": false,
|
||||
"moduleResolution": "Node16",
|
||||
"resolveJsonModule": true,
|
||||
"allowImportingTsExtensions": false,
|
||||
"experimentalDecorators": true,
|
||||
"emitDecoratorMetadata": true,
|
||||
"useDefineForClassFields": false,
|
||||
"types": ["node"],
|
||||
"outDir": "./dist",
|
||||
"rootDir": "./src"
|
||||
},
|
||||
"include": ["src/**/*.ts"],
|
||||
"exclude": [
|
||||
"node_modules",
|
||||
"dist",
|
||||
"**/*.d.ts",
|
||||
"src/**/*.d.ts",
|
||||
"src/**/*.test.ts"
|
||||
]
|
||||
}
|
||||
1
packages/pi-ai/bedrock-provider.d.ts
vendored
1
packages/pi-ai/bedrock-provider.d.ts
vendored
|
|
@ -1 +0,0 @@
|
|||
export * from "./dist/bedrock-provider.js";
|
||||
|
|
@ -1 +0,0 @@
|
|||
export * from "./dist/bedrock-provider.js";
|
||||
1
packages/pi-ai/oauth.d.ts
vendored
1
packages/pi-ai/oauth.d.ts
vendored
|
|
@ -1 +0,0 @@
|
|||
export * from "./dist/oauth.js";
|
||||
|
|
@ -1 +0,0 @@
|
|||
export * from "./dist/oauth.js";
|
||||
|
|
@ -1,50 +0,0 @@
|
|||
{
|
||||
"name": "@singularity-forge/pi-ai",
|
||||
"version": "2.75.3",
|
||||
"description": "Unified LLM API (vendored from pi-mono)",
|
||||
"type": "module",
|
||||
"main": "./dist/index.js",
|
||||
"types": "./dist/index.d.ts",
|
||||
"exports": {
|
||||
".": {
|
||||
"types": "./dist/index.d.ts",
|
||||
"import": "./dist/index.js"
|
||||
},
|
||||
"./oauth": {
|
||||
"types": "./dist/oauth.d.ts",
|
||||
"import": "./dist/oauth.js"
|
||||
},
|
||||
"./bedrock-provider": {
|
||||
"types": "./bedrock-provider.d.ts",
|
||||
"import": "./bedrock-provider.js"
|
||||
}
|
||||
},
|
||||
"scripts": {
|
||||
"build": "tsc -p tsconfig.json"
|
||||
},
|
||||
"dependencies": {
|
||||
"@anthropic-ai/sdk": "^0.95.1",
|
||||
"@anthropic-ai/vertex-sdk": "^0.16.0",
|
||||
"@aws-sdk/client-bedrock-runtime": "^3.1045.0",
|
||||
"@google/gemini-cli-core": "^0.41.2",
|
||||
"@google/genai": "^2.0.1",
|
||||
"@mistralai/mistralai": "^2.2.1",
|
||||
"@singularity-forge/google-gemini-cli-provider": "^2.75.3",
|
||||
"@sinclair/typebox": "^0.34.49",
|
||||
"ajv": "^8.20.0",
|
||||
"ajv-formats": "^3.0.1",
|
||||
"chalk": "^5.6.2",
|
||||
"jsonrepair": "^3.14.0",
|
||||
"openai": "^6.37.0",
|
||||
"proxy-agent": "^8.0.1",
|
||||
"undici": "^8.2.0",
|
||||
"yaml": "^2.8.3",
|
||||
"zod-to-json-schema": "^3.24.6"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@smithy/node-http-handler": "^4.5.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=26.1.0"
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,89 +0,0 @@
|
|||
import type {
|
||||
Api,
|
||||
AssistantMessageEventStream,
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
} from "./types.js";
|
||||
|
||||
export type ApiStreamFunction = (
|
||||
model: Model<Api>,
|
||||
context: Context,
|
||||
options?: StreamOptions,
|
||||
) => AssistantMessageEventStream;
|
||||
|
||||
export type ApiStreamSimpleFunction = (
|
||||
model: Model<Api>,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
) => AssistantMessageEventStream;
|
||||
|
||||
export interface ApiProvider<
|
||||
TApi extends Api = Api,
|
||||
TOptions extends StreamOptions = StreamOptions,
|
||||
> {
|
||||
api: TApi;
|
||||
stream: StreamFunction<TApi, TOptions>;
|
||||
streamSimple: StreamFunction<TApi, SimpleStreamOptions>;
|
||||
}
|
||||
|
||||
interface ApiProviderInternal {
|
||||
api: Api;
|
||||
stream: ApiStreamFunction;
|
||||
streamSimple: ApiStreamSimpleFunction;
|
||||
}
|
||||
|
||||
type RegisteredApiProvider = {
|
||||
provider: ApiProviderInternal;
|
||||
sourceId?: string;
|
||||
};
|
||||
|
||||
const apiProviderRegistry = new Map<string, RegisteredApiProvider>();
|
||||
|
||||
function wrapStream<TApi extends Api, TOptions extends StreamOptions>(
|
||||
api: TApi,
|
||||
stream: StreamFunction<TApi, TOptions>,
|
||||
): ApiStreamFunction {
|
||||
return (model, context, options) => {
|
||||
if (model.api !== api) {
|
||||
throw new Error(`Mismatched api: ${model.api} expected ${api}`);
|
||||
}
|
||||
return stream(model as Model<TApi>, context, options as TOptions);
|
||||
};
|
||||
}
|
||||
|
||||
function wrapStreamSimple<TApi extends Api>(
|
||||
api: TApi,
|
||||
streamSimple: StreamFunction<TApi, SimpleStreamOptions>,
|
||||
): ApiStreamSimpleFunction {
|
||||
return (model, context, options) => {
|
||||
if (model.api !== api) {
|
||||
throw new Error(`Mismatched api: ${model.api} expected ${api}`);
|
||||
}
|
||||
return streamSimple(model as Model<TApi>, context, options);
|
||||
};
|
||||
}
|
||||
|
||||
export function registerApiProvider<
|
||||
TApi extends Api,
|
||||
TOptions extends StreamOptions,
|
||||
>(provider: ApiProvider<TApi, TOptions>, sourceId?: string): void {
|
||||
apiProviderRegistry.set(provider.api, {
|
||||
provider: {
|
||||
api: provider.api,
|
||||
stream: wrapStream(provider.api, provider.stream),
|
||||
streamSimple: wrapStreamSimple(provider.api, provider.streamSimple),
|
||||
},
|
||||
sourceId,
|
||||
});
|
||||
}
|
||||
|
||||
export function getApiProvider(api: Api): ApiProviderInternal | undefined {
|
||||
return apiProviderRegistry.get(api)?.provider;
|
||||
}
|
||||
|
||||
export function clearApiProviders(): void {
|
||||
apiProviderRegistry.clear();
|
||||
}
|
||||
|
|
@ -1,9 +0,0 @@
|
|||
import {
|
||||
streamBedrock,
|
||||
streamSimpleBedrock,
|
||||
} from "./providers/amazon-bedrock.js";
|
||||
|
||||
export const bedrockProviderModule = {
|
||||
streamBedrock,
|
||||
streamSimpleBedrock,
|
||||
};
|
||||
|
|
@ -1,152 +0,0 @@
|
|||
#!/usr/bin/env node
|
||||
|
||||
import { existsSync, readFileSync, writeFileSync } from "node:fs";
|
||||
import { createInterface } from "node:readline";
|
||||
import { getOAuthProvider, getOAuthProviders } from "./utils/oauth/index.js";
|
||||
import type { OAuthCredentials, OAuthProviderId } from "./utils/oauth/types.js";
|
||||
|
||||
const AUTH_FILE = "auth.json";
|
||||
const PROVIDERS = getOAuthProviders();
|
||||
|
||||
function prompt(
|
||||
rl: ReturnType<typeof createInterface>,
|
||||
question: string,
|
||||
): Promise<string> {
|
||||
return new Promise((resolve) => rl.question(question, resolve));
|
||||
}
|
||||
|
||||
function loadAuth(): Record<string, { type: "oauth" } & OAuthCredentials> {
|
||||
if (!existsSync(AUTH_FILE)) return {};
|
||||
try {
|
||||
return JSON.parse(readFileSync(AUTH_FILE, "utf-8"));
|
||||
} catch {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
function saveAuth(
|
||||
auth: Record<string, { type: "oauth" } & OAuthCredentials>,
|
||||
): void {
|
||||
writeFileSync(AUTH_FILE, JSON.stringify(auth, null, 2), "utf-8");
|
||||
}
|
||||
|
||||
async function login(providerId: OAuthProviderId): Promise<void> {
|
||||
const provider = getOAuthProvider(providerId);
|
||||
if (!provider) {
|
||||
console.error(`Unknown provider: ${providerId}`);
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
const rl = createInterface({ input: process.stdin, output: process.stdout });
|
||||
const promptFn = (msg: string) => prompt(rl, `${msg} `);
|
||||
|
||||
try {
|
||||
const credentials = await provider.login({
|
||||
onAuth: (info) => {
|
||||
console.log(`\nOpen this URL in your browser:\n${info.url}`);
|
||||
if (info.instructions) console.log(info.instructions);
|
||||
console.log();
|
||||
},
|
||||
onPrompt: async (p) => {
|
||||
return await promptFn(
|
||||
`${p.message}${p.placeholder ? ` (${p.placeholder})` : ""}:`,
|
||||
);
|
||||
},
|
||||
onProgress: (msg) => console.log(msg),
|
||||
});
|
||||
|
||||
const auth = loadAuth();
|
||||
auth[providerId] = { type: "oauth", ...credentials };
|
||||
saveAuth(auth);
|
||||
|
||||
console.log(`\nCredentials saved to ${AUTH_FILE}`);
|
||||
} finally {
|
||||
rl.close();
|
||||
}
|
||||
}
|
||||
|
||||
async function main(): Promise<void> {
|
||||
const args = process.argv.slice(2);
|
||||
const command = args[0];
|
||||
|
||||
if (
|
||||
!command ||
|
||||
command === "help" ||
|
||||
command === "--help" ||
|
||||
command === "-h"
|
||||
) {
|
||||
const providerList = PROVIDERS.map(
|
||||
(p) => ` ${p.id.padEnd(20)} ${p.name}`,
|
||||
).join("\n");
|
||||
console.log(`Usage: npx @singularity-forge/pi-ai <command> [provider]
|
||||
|
||||
Commands:
|
||||
login [provider] Login to an OAuth provider
|
||||
list List available providers
|
||||
|
||||
Providers:
|
||||
${providerList}
|
||||
|
||||
Examples:
|
||||
npx @singularity-forge/pi-ai login # interactive provider selection
|
||||
npx @singularity-forge/pi-ai login anthropic # login to specific provider
|
||||
npx @singularity-forge/pi-ai list # list providers
|
||||
`);
|
||||
return;
|
||||
}
|
||||
|
||||
if (command === "list") {
|
||||
console.log("Available OAuth providers:\n");
|
||||
for (const p of PROVIDERS) {
|
||||
console.log(` ${p.id.padEnd(20)} ${p.name}`);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (command === "login") {
|
||||
let provider = args[1] as OAuthProviderId | undefined;
|
||||
|
||||
if (!provider) {
|
||||
const rl = createInterface({
|
||||
input: process.stdin,
|
||||
output: process.stdout,
|
||||
});
|
||||
console.log("Select a provider:\n");
|
||||
for (let i = 0; i < PROVIDERS.length; i++) {
|
||||
console.log(` ${i + 1}. ${PROVIDERS[i].name}`);
|
||||
}
|
||||
console.log();
|
||||
|
||||
const choice = await prompt(rl, `Enter number (1-${PROVIDERS.length}): `);
|
||||
rl.close();
|
||||
|
||||
const index = parseInt(choice, 10) - 1;
|
||||
if (index < 0 || index >= PROVIDERS.length) {
|
||||
console.error("Invalid selection");
|
||||
process.exit(1);
|
||||
}
|
||||
provider = PROVIDERS[index].id;
|
||||
}
|
||||
|
||||
if (!PROVIDERS.some((p) => p.id === provider)) {
|
||||
console.error(`Unknown provider: ${provider}`);
|
||||
console.error(
|
||||
`Use 'npx @singularity-forge/pi-ai list' to see available providers`,
|
||||
);
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
console.log(`Logging in to ${provider}...`);
|
||||
await login(provider);
|
||||
return;
|
||||
}
|
||||
|
||||
console.error(`Unknown command: ${command}`);
|
||||
console.error(`Use 'npx @singularity-forge/pi-ai --help' for usage`);
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
main().catch((err) => {
|
||||
console.error("Error:", err.message);
|
||||
process.exit(1);
|
||||
});
|
||||
|
|
@ -1,59 +0,0 @@
|
|||
import assert from "node:assert/strict";
|
||||
import { describe, it } from "vitest";
|
||||
import { getEnvApiKey } from "./env-api-keys.js";
|
||||
|
||||
describe("getEnvApiKey", () => {
|
||||
it("uses GEMINI_API_KEY for google when present", () => {
|
||||
const savedGemini = process.env.GEMINI_API_KEY;
|
||||
const savedGoogleGenerative = process.env.GOOGLE_GENERATIVE_AI_API_KEY;
|
||||
|
||||
process.env.GEMINI_API_KEY = "gemini-key";
|
||||
process.env.GOOGLE_GENERATIVE_AI_API_KEY = "google-generative-key";
|
||||
|
||||
try {
|
||||
assert.equal(getEnvApiKey("google"), "gemini-key");
|
||||
} finally {
|
||||
if (savedGemini === undefined) delete process.env.GEMINI_API_KEY;
|
||||
else process.env.GEMINI_API_KEY = savedGemini;
|
||||
if (savedGoogleGenerative === undefined)
|
||||
delete process.env.GOOGLE_GENERATIVE_AI_API_KEY;
|
||||
else process.env.GOOGLE_GENERATIVE_AI_API_KEY = savedGoogleGenerative;
|
||||
}
|
||||
});
|
||||
|
||||
it("accepts GOOGLE_GENERATIVE_AI_API_KEY for google", () => {
|
||||
const savedGemini = process.env.GEMINI_API_KEY;
|
||||
const savedGoogleGenerative = process.env.GOOGLE_GENERATIVE_AI_API_KEY;
|
||||
|
||||
delete process.env.GEMINI_API_KEY;
|
||||
process.env.GOOGLE_GENERATIVE_AI_API_KEY = "google-generative-key";
|
||||
|
||||
try {
|
||||
assert.equal(getEnvApiKey("google"), "google-generative-key");
|
||||
} finally {
|
||||
if (savedGemini === undefined) delete process.env.GEMINI_API_KEY;
|
||||
else process.env.GEMINI_API_KEY = savedGemini;
|
||||
if (savedGoogleGenerative === undefined)
|
||||
delete process.env.GOOGLE_GENERATIVE_AI_API_KEY;
|
||||
else process.env.GOOGLE_GENERATIVE_AI_API_KEY = savedGoogleGenerative;
|
||||
}
|
||||
});
|
||||
|
||||
it("uses the OpenCode Go subscription key before the Zen key", () => {
|
||||
const savedZen = process.env.OPENCODE_API_KEY;
|
||||
const savedGo = process.env.OPENCODE_GO_API_KEY;
|
||||
|
||||
process.env.OPENCODE_API_KEY = "zen-key";
|
||||
process.env.OPENCODE_GO_API_KEY = "go-key";
|
||||
|
||||
try {
|
||||
assert.equal(getEnvApiKey("opencode"), "zen-key");
|
||||
assert.equal(getEnvApiKey("opencode-go"), "go-key");
|
||||
} finally {
|
||||
if (savedZen === undefined) delete process.env.OPENCODE_API_KEY;
|
||||
else process.env.OPENCODE_API_KEY = savedZen;
|
||||
if (savedGo === undefined) delete process.env.OPENCODE_GO_API_KEY;
|
||||
else process.env.OPENCODE_GO_API_KEY = savedGo;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
|
@ -1,192 +0,0 @@
|
|||
// NEVER convert to top-level imports - breaks browser/Vite builds (web-ui)
|
||||
let _existsSync: typeof import("node:fs").existsSync | null = null;
|
||||
let _homedir: typeof import("node:os").homedir | null = null;
|
||||
let _join: typeof import("node:path").join | null = null;
|
||||
|
||||
type DynamicImport = (specifier: string) => Promise<unknown>;
|
||||
|
||||
const dynamicImport: DynamicImport = (specifier) => import(specifier);
|
||||
const NODE_FS_SPECIFIER = "node:" + "fs";
|
||||
const NODE_OS_SPECIFIER = "node:" + "os";
|
||||
const NODE_PATH_SPECIFIER = "node:" + "path";
|
||||
|
||||
// Eagerly load in Node.js/Bun environment only
|
||||
if (
|
||||
typeof process !== "undefined" &&
|
||||
(process.versions?.node || process.versions?.bun)
|
||||
) {
|
||||
dynamicImport(NODE_FS_SPECIFIER).then((m) => {
|
||||
_existsSync = (m as typeof import("node:fs")).existsSync;
|
||||
});
|
||||
dynamicImport(NODE_OS_SPECIFIER).then((m) => {
|
||||
_homedir = (m as typeof import("node:os")).homedir;
|
||||
});
|
||||
dynamicImport(NODE_PATH_SPECIFIER).then((m) => {
|
||||
_join = (m as typeof import("node:path")).join;
|
||||
});
|
||||
}
|
||||
|
||||
import type { KnownProvider } from "./types.js";
|
||||
|
||||
let cachedVertexAdcCredentialsExists: boolean | null = null;
|
||||
|
||||
function hasVertexAdcCredentials(): boolean {
|
||||
if (cachedVertexAdcCredentialsExists === null) {
|
||||
// If node modules haven't loaded yet (async import race at startup),
|
||||
// return false WITHOUT caching so the next call retries once they're ready.
|
||||
// Only cache false permanently in a browser environment where fs is never available.
|
||||
if (!_existsSync || !_homedir || !_join) {
|
||||
const isNode =
|
||||
typeof process !== "undefined" &&
|
||||
(process.versions?.node || process.versions?.bun);
|
||||
if (!isNode) {
|
||||
// Definitively in a browser — safe to cache false permanently
|
||||
cachedVertexAdcCredentialsExists = false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check GOOGLE_APPLICATION_CREDENTIALS env var first (standard way)
|
||||
const gacPath = process.env.GOOGLE_APPLICATION_CREDENTIALS;
|
||||
if (gacPath) {
|
||||
cachedVertexAdcCredentialsExists = _existsSync(gacPath);
|
||||
} else {
|
||||
// Fall back to default ADC path (lazy evaluation)
|
||||
cachedVertexAdcCredentialsExists = _existsSync(
|
||||
_join(
|
||||
_homedir(),
|
||||
".config",
|
||||
"gcloud",
|
||||
"application_default_credentials.json",
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
return cachedVertexAdcCredentialsExists;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get API key for provider from known environment variables, e.g. OPENAI_API_KEY.
|
||||
*
|
||||
* Will not return API keys for providers that require OAuth tokens.
|
||||
*/
|
||||
export function getEnvApiKey(provider: KnownProvider): string | undefined;
|
||||
export function getEnvApiKey(provider: string): string | undefined;
|
||||
export function getEnvApiKey(provider: any): string | undefined {
|
||||
// Fall back to environment variables
|
||||
if (provider === "github-copilot") {
|
||||
return (
|
||||
process.env.COPILOT_GITHUB_TOKEN ||
|
||||
process.env.GH_TOKEN ||
|
||||
process.env.GITHUB_TOKEN
|
||||
);
|
||||
}
|
||||
|
||||
// ANTHROPIC_OAUTH_TOKEN takes precedence over ANTHROPIC_API_KEY
|
||||
if (provider === "anthropic") {
|
||||
return process.env.ANTHROPIC_OAUTH_TOKEN || process.env.ANTHROPIC_API_KEY;
|
||||
}
|
||||
|
||||
// Anthropic on Vertex AI uses Application Default Credentials.
|
||||
// Detected via ANTHROPIC_VERTEX_PROJECT_ID (same env var as Claude Code).
|
||||
if (provider === "anthropic-vertex") {
|
||||
const hasProject = !!process.env.ANTHROPIC_VERTEX_PROJECT_ID;
|
||||
if (hasProject) {
|
||||
return "<authenticated>";
|
||||
}
|
||||
// Fall back to Google Cloud project env vars
|
||||
const hasGoogleProject = !!(
|
||||
process.env.GOOGLE_CLOUD_PROJECT || process.env.GCLOUD_PROJECT
|
||||
);
|
||||
if (hasGoogleProject && hasVertexAdcCredentials()) {
|
||||
return "<authenticated>";
|
||||
}
|
||||
}
|
||||
|
||||
// Vertex AI uses Application Default Credentials, not API keys.
|
||||
// Auth is configured via `gcloud auth application-default login`.
|
||||
if (provider === "google-vertex") {
|
||||
const hasCredentials = hasVertexAdcCredentials();
|
||||
const hasProject = !!(
|
||||
process.env.GOOGLE_CLOUD_PROJECT || process.env.GCLOUD_PROJECT
|
||||
);
|
||||
const hasLocation = !!process.env.GOOGLE_CLOUD_LOCATION;
|
||||
|
||||
if (hasCredentials && hasProject && hasLocation) {
|
||||
return "<authenticated>";
|
||||
}
|
||||
}
|
||||
|
||||
// Xiaomi MiMo token-plan providers share a single key; allow legacy fallbacks.
|
||||
if (
|
||||
provider === "xiaomi" ||
|
||||
provider === "xiaomi-token-plan-ams" ||
|
||||
provider === "xiaomi-token-plan-sgp" ||
|
||||
provider === "xiaomi-token-plan-cn"
|
||||
) {
|
||||
return (
|
||||
process.env.XIAOMI_API_KEY ||
|
||||
process.env.XIAOMI_TOKEN_PLAN_API_KEY ||
|
||||
process.env.MIMO_API_KEY
|
||||
);
|
||||
}
|
||||
|
||||
if (provider === "amazon-bedrock") {
|
||||
// Amazon Bedrock supports multiple credential sources:
|
||||
// 1. AWS_PROFILE - named profile from ~/.aws/credentials
|
||||
// 2. AWS_ACCESS_KEY_ID + AWS_SECRET_ACCESS_KEY - standard IAM keys
|
||||
// 3. AWS_BEARER_TOKEN_BEDROCK - Bedrock API keys (bearer token)
|
||||
// 4. AWS_CONTAINER_CREDENTIALS_RELATIVE_URI - ECS task roles
|
||||
// 5. AWS_CONTAINER_CREDENTIALS_FULL_URI - ECS task roles (full URI)
|
||||
// 6. AWS_WEB_IDENTITY_TOKEN_FILE - IRSA (IAM Roles for Service Accounts)
|
||||
if (
|
||||
process.env.AWS_PROFILE ||
|
||||
(process.env.AWS_ACCESS_KEY_ID && process.env.AWS_SECRET_ACCESS_KEY) ||
|
||||
process.env.AWS_BEARER_TOKEN_BEDROCK ||
|
||||
process.env.AWS_CONTAINER_CREDENTIALS_RELATIVE_URI ||
|
||||
process.env.AWS_CONTAINER_CREDENTIALS_FULL_URI ||
|
||||
process.env.AWS_WEB_IDENTITY_TOKEN_FILE
|
||||
) {
|
||||
return "<authenticated>";
|
||||
}
|
||||
}
|
||||
|
||||
const envMap: Record<string, string | string[]> = {
|
||||
openai: "OPENAI_API_KEY",
|
||||
"azure-openai-responses": "AZURE_OPENAI_API_KEY",
|
||||
google: ["GEMINI_API_KEY", "GOOGLE_GENERATIVE_AI_API_KEY"],
|
||||
groq: "GROQ_API_KEY",
|
||||
cerebras: "CEREBRAS_API_KEY",
|
||||
xai: "XAI_API_KEY",
|
||||
openrouter: "OPENROUTER_API_KEY",
|
||||
"vercel-ai-gateway": "AI_GATEWAY_API_KEY",
|
||||
zai: "ZAI_API_KEY",
|
||||
mistral: "MISTRAL_API_KEY",
|
||||
minimax: "MINIMAX_API_KEY",
|
||||
"minimax-cn": "MINIMAX_CN_API_KEY",
|
||||
huggingface: "HF_TOKEN",
|
||||
opencode: "OPENCODE_API_KEY",
|
||||
"opencode-go": ["OPENCODE_GO_API_KEY", "OPENCODE_API_KEY"],
|
||||
"kimi-coding": "KIMI_API_KEY",
|
||||
xiaomi: "XIAOMI_API_KEY",
|
||||
"xiaomi-token-plan-ams": "XIAOMI_API_KEY",
|
||||
"xiaomi-token-plan-sgp": "XIAOMI_API_KEY",
|
||||
"xiaomi-token-plan-cn": "XIAOMI_API_KEY",
|
||||
"alibaba-coding-plan": "ALIBABA_API_KEY",
|
||||
"alibaba-dashscope": "DASHSCOPE_API_KEY",
|
||||
ollama: "OLLAMA_API_KEY",
|
||||
"ollama-cloud": "OLLAMA_API_KEY",
|
||||
"custom-openai": "CUSTOM_OPENAI_API_KEY",
|
||||
longcat: "LONGCAT_API_KEY",
|
||||
};
|
||||
|
||||
const envVar = envMap[provider];
|
||||
if (Array.isArray(envVar)) {
|
||||
for (const name of envVar) {
|
||||
const value = process.env[name];
|
||||
if (value) return value;
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
return envVar ? process.env[envVar] : undefined;
|
||||
}
|
||||
|
|
@ -1,42 +0,0 @@
|
|||
export type { Static, TSchema } from "@sinclair/typebox";
|
||||
export { Type } from "@sinclair/typebox";
|
||||
|
||||
export * from "./api-registry.js";
|
||||
export * from "./env-api-keys.js";
|
||||
export * from "./models.js";
|
||||
export * from "./providers/anthropic.js";
|
||||
export {
|
||||
mapThinkingLevelToEffort,
|
||||
supportsAdaptiveThinking,
|
||||
} from "./providers/anthropic-shared.js";
|
||||
export * from "./providers/azure-openai-responses.js";
|
||||
export * from "./providers/google.js";
|
||||
export * from "./providers/google-gemini-cli.js";
|
||||
export * from "./providers/google-vertex.js";
|
||||
export * from "./providers/mistral.js";
|
||||
export * from "./providers/openai-completions.js";
|
||||
export * from "./providers/openai-responses.js";
|
||||
export * from "./providers/provider-capabilities.js";
|
||||
export * from "./providers/register-builtins.js";
|
||||
export type { ProviderSwitchReport } from "./providers/transform-messages.js";
|
||||
export {
|
||||
createEmptyReport,
|
||||
hasTransformations,
|
||||
transformMessagesWithReport,
|
||||
} from "./providers/transform-messages.js";
|
||||
export * from "./stream.js";
|
||||
export * from "./types.js";
|
||||
export * from "./utils/event-stream.js";
|
||||
export * from "./utils/json-parse.js";
|
||||
export type {
|
||||
OAuthAuthInfo,
|
||||
OAuthCredentials,
|
||||
OAuthLoginCallbacks,
|
||||
OAuthPrompt,
|
||||
OAuthProviderId,
|
||||
OAuthProviderInterface,
|
||||
} from "./utils/oauth/types.js";
|
||||
export * from "./utils/overflow.js";
|
||||
export * from "./utils/repair-tool-json.js";
|
||||
export * from "./utils/typebox-helpers.js";
|
||||
export * from "./utils/validation.js";
|
||||
|
|
@ -1,369 +0,0 @@
|
|||
// Manually-maintained model definitions for providers NOT tracked by models.dev.
|
||||
//
|
||||
// The auto-generated file (models.generated.ts) is rebuilt from the models.dev
|
||||
// third-party catalog. Providers that use proprietary endpoints and are not
|
||||
// listed on models.dev must be defined here so they survive regeneration.
|
||||
//
|
||||
// See: https://github.com/singularity-forge/sf-run/issues/2339
|
||||
//
|
||||
// To add a custom provider:
|
||||
// 1. Add its model definitions below following the existing pattern.
|
||||
// 2. Add its API key mapping to env-api-keys.ts.
|
||||
// 3. Add its provider name to KnownProvider in types.ts (if not already there).
|
||||
|
||||
import type { Model } from "./types.js";
|
||||
|
||||
export const CUSTOM_MODELS = {
|
||||
// ─── Alibaba Coding Plan ─────────────────────────────────────────────
|
||||
// Direct Alibaba DashScope Coding Plan endpoint (OpenAI-compatible).
|
||||
// NOT the same as alibaba/* models on OpenRouter — different endpoint & auth.
|
||||
// Original PR: #295 | Fixes: #1003, #1055, #1057
|
||||
"alibaba-coding-plan": {
|
||||
"qwen3.5-plus": {
|
||||
id: "qwen3.5-plus",
|
||||
name: "Qwen3.5 Plus",
|
||||
api: "openai-completions",
|
||||
provider: "alibaba-coding-plan",
|
||||
baseUrl: "https://coding-intl.dashscope.aliyuncs.com/v1",
|
||||
reasoning: true,
|
||||
input: ["text"],
|
||||
cost: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
contextWindow: 983616,
|
||||
maxTokens: 65536,
|
||||
compat: { thinkingFormat: "qwen", supportsDeveloperRole: false },
|
||||
} satisfies Model<"openai-completions">,
|
||||
"qwen3-max-2026-01-23": {
|
||||
id: "qwen3-max-2026-01-23",
|
||||
name: "Qwen3 Max 2026-01-23",
|
||||
api: "openai-completions",
|
||||
provider: "alibaba-coding-plan",
|
||||
baseUrl: "https://coding-intl.dashscope.aliyuncs.com/v1",
|
||||
reasoning: true,
|
||||
input: ["text"],
|
||||
cost: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
contextWindow: 258048,
|
||||
maxTokens: 32768,
|
||||
compat: { thinkingFormat: "qwen", supportsDeveloperRole: false },
|
||||
} satisfies Model<"openai-completions">,
|
||||
"qwen3-coder-next": {
|
||||
id: "qwen3-coder-next",
|
||||
name: "Qwen3 Coder Next",
|
||||
api: "openai-completions",
|
||||
provider: "alibaba-coding-plan",
|
||||
baseUrl: "https://coding-intl.dashscope.aliyuncs.com/v1",
|
||||
reasoning: false,
|
||||
input: ["text"],
|
||||
cost: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
contextWindow: 204800,
|
||||
maxTokens: 65536,
|
||||
compat: { supportsDeveloperRole: false },
|
||||
} satisfies Model<"openai-completions">,
|
||||
"qwen3-coder-plus": {
|
||||
id: "qwen3-coder-plus",
|
||||
name: "Qwen3 Coder Plus",
|
||||
api: "openai-completions",
|
||||
provider: "alibaba-coding-plan",
|
||||
baseUrl: "https://coding-intl.dashscope.aliyuncs.com/v1",
|
||||
reasoning: false,
|
||||
input: ["text"],
|
||||
cost: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
contextWindow: 997952,
|
||||
maxTokens: 65536,
|
||||
compat: { supportsDeveloperRole: false },
|
||||
} satisfies Model<"openai-completions">,
|
||||
"MiniMax-M2.5": {
|
||||
id: "MiniMax-M2.5",
|
||||
name: "MiniMax M2.5",
|
||||
api: "openai-completions",
|
||||
provider: "alibaba-coding-plan",
|
||||
baseUrl: "https://coding-intl.dashscope.aliyuncs.com/v1",
|
||||
reasoning: true,
|
||||
input: ["text"],
|
||||
cost: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
contextWindow: 196608,
|
||||
maxTokens: 65536,
|
||||
compat: {
|
||||
supportsStore: false,
|
||||
supportsDeveloperRole: false,
|
||||
supportsReasoningEffort: true,
|
||||
maxTokensField: "max_tokens",
|
||||
},
|
||||
} satisfies Model<"openai-completions">,
|
||||
"glm-5": {
|
||||
id: "glm-5",
|
||||
name: "GLM-5",
|
||||
api: "openai-completions",
|
||||
provider: "alibaba-coding-plan",
|
||||
baseUrl: "https://coding-intl.dashscope.aliyuncs.com/v1",
|
||||
reasoning: true,
|
||||
input: ["text"],
|
||||
cost: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
contextWindow: 202752,
|
||||
maxTokens: 16384,
|
||||
compat: { thinkingFormat: "qwen", supportsDeveloperRole: false },
|
||||
} satisfies Model<"openai-completions">,
|
||||
"glm-4.7": {
|
||||
id: "glm-4.7",
|
||||
name: "GLM-4.7",
|
||||
api: "openai-completions",
|
||||
provider: "alibaba-coding-plan",
|
||||
baseUrl: "https://coding-intl.dashscope.aliyuncs.com/v1",
|
||||
reasoning: true,
|
||||
input: ["text"],
|
||||
cost: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
contextWindow: 169984,
|
||||
maxTokens: 16384,
|
||||
compat: { thinkingFormat: "qwen", supportsDeveloperRole: false },
|
||||
} satisfies Model<"openai-completions">,
|
||||
"kimi-k2.5": {
|
||||
id: "kimi-k2.5",
|
||||
name: "Kimi K2.5",
|
||||
api: "openai-completions",
|
||||
provider: "alibaba-coding-plan",
|
||||
baseUrl: "https://coding-intl.dashscope.aliyuncs.com/v1",
|
||||
reasoning: true,
|
||||
input: ["text"],
|
||||
cost: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
contextWindow: 258048,
|
||||
maxTokens: 32768,
|
||||
compat: { thinkingFormat: "zai", supportsDeveloperRole: false },
|
||||
} satisfies Model<"openai-completions">,
|
||||
},
|
||||
|
||||
// ─── Alibaba DashScope ───────────────────────────────────────────────
|
||||
// Regular DashScope API for users without the Coding Plan.
|
||||
// Uses the international OpenAI-compatible endpoint.
|
||||
// Requires DASHSCOPE_API_KEY from: dashscope.console.aliyun.com
|
||||
// Pricing: https://www.alibabacloud.com/help/en/model-studio/model-pricing
|
||||
"alibaba-dashscope": {
|
||||
"qwen3-max": {
|
||||
id: "qwen3-max",
|
||||
name: "Qwen3 Max",
|
||||
api: "openai-completions",
|
||||
provider: "alibaba-dashscope",
|
||||
baseUrl: "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
|
||||
reasoning: true,
|
||||
input: ["text"],
|
||||
cost: {
|
||||
input: 1.2,
|
||||
output: 6,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
contextWindow: 1000000,
|
||||
maxTokens: 32768,
|
||||
compat: { thinkingFormat: "qwen", supportsDeveloperRole: false },
|
||||
} satisfies Model<"openai-completions">,
|
||||
"qwen3.5-plus": {
|
||||
id: "qwen3.5-plus",
|
||||
name: "Qwen3.5 Plus",
|
||||
api: "openai-completions",
|
||||
provider: "alibaba-dashscope",
|
||||
baseUrl: "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
|
||||
reasoning: true,
|
||||
input: ["text"],
|
||||
cost: {
|
||||
input: 0.4,
|
||||
output: 1.2,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
contextWindow: 1000000,
|
||||
maxTokens: 65536,
|
||||
compat: { thinkingFormat: "qwen", supportsDeveloperRole: false },
|
||||
} satisfies Model<"openai-completions">,
|
||||
"qwen3.5-flash": {
|
||||
id: "qwen3.5-flash",
|
||||
name: "Qwen3.5 Flash",
|
||||
api: "openai-completions",
|
||||
provider: "alibaba-dashscope",
|
||||
baseUrl: "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
|
||||
reasoning: false,
|
||||
input: ["text"],
|
||||
cost: {
|
||||
input: 0.1,
|
||||
output: 0.4,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
contextWindow: 1000000,
|
||||
maxTokens: 32768,
|
||||
compat: { supportsDeveloperRole: false },
|
||||
} satisfies Model<"openai-completions">,
|
||||
"qwen3-coder-plus": {
|
||||
id: "qwen3-coder-plus",
|
||||
name: "Qwen3 Coder Plus",
|
||||
api: "openai-completions",
|
||||
provider: "alibaba-dashscope",
|
||||
baseUrl: "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
|
||||
reasoning: false,
|
||||
input: ["text"],
|
||||
cost: {
|
||||
input: 1.0,
|
||||
output: 5.0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
contextWindow: 1000000,
|
||||
maxTokens: 65536,
|
||||
compat: { supportsDeveloperRole: false },
|
||||
} satisfies Model<"openai-completions">,
|
||||
"qwen3.6-plus": {
|
||||
id: "qwen3.6-plus",
|
||||
name: "Qwen3.6 Plus",
|
||||
api: "openai-completions",
|
||||
provider: "alibaba-dashscope",
|
||||
baseUrl: "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
|
||||
reasoning: true,
|
||||
input: ["text"],
|
||||
cost: {
|
||||
input: 0.5,
|
||||
output: 3.0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
contextWindow: 1000000,
|
||||
maxTokens: 65536,
|
||||
compat: { thinkingFormat: "qwen", supportsDeveloperRole: false },
|
||||
} satisfies Model<"openai-completions">,
|
||||
},
|
||||
|
||||
// ─── Z.AI (GLM-5.1) ────────────────────────────────────────────────
|
||||
// GLM-5.1 is the latest GLM model from Zhipu AI, not yet in models.dev.
|
||||
// Uses the Z.AI Coding Plan endpoint (OpenAI-compatible).
|
||||
// Ref: https://docs.z.ai/devpack/using5.1
|
||||
zai: {
|
||||
"glm-5.1": {
|
||||
id: "glm-5.1",
|
||||
name: "GLM-5.1",
|
||||
api: "openai-completions",
|
||||
provider: "zai",
|
||||
baseUrl: "https://api.z.ai/api/coding/paas/v4",
|
||||
reasoning: true,
|
||||
input: ["text"],
|
||||
cost: {
|
||||
input: 1,
|
||||
output: 3.2,
|
||||
cacheRead: 0.2,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
contextWindow: 204800,
|
||||
maxTokens: 131072,
|
||||
compat: { thinkingFormat: "zai", supportsDeveloperRole: false },
|
||||
} satisfies Model<"openai-completions">,
|
||||
},
|
||||
|
||||
// ─── Xiaomi MiMo ─────────────────────────────────────────────────────
|
||||
// Direct Xiaomi Token Plan AMS endpoint (Anthropic-compatible).
|
||||
// Uses Bearer auth with XIAOMI_API_KEY against /anthropic.
|
||||
xiaomi: {
|
||||
"mimo-v2-omni": {
|
||||
id: "mimo-v2-omni",
|
||||
name: "MiMo V2 Omni",
|
||||
api: "anthropic-messages",
|
||||
provider: "xiaomi",
|
||||
baseUrl: "https://token-plan-ams.xiaomimimo.com/anthropic",
|
||||
reasoning: true,
|
||||
input: ["text", "image"],
|
||||
cost: {
|
||||
input: 0.4,
|
||||
output: 2,
|
||||
cacheRead: 0.08,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
contextWindow: 262144,
|
||||
maxTokens: 65536,
|
||||
} satisfies Model<"anthropic-messages">,
|
||||
"mimo-v2-pro": {
|
||||
id: "mimo-v2-pro",
|
||||
name: "MiMo V2 Pro",
|
||||
api: "anthropic-messages",
|
||||
provider: "xiaomi",
|
||||
baseUrl: "https://token-plan-ams.xiaomimimo.com/anthropic",
|
||||
reasoning: true,
|
||||
input: ["text"],
|
||||
cost: {
|
||||
input: 1,
|
||||
output: 3,
|
||||
cacheRead: 0.2,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
contextWindow: 1048576,
|
||||
maxTokens: 131072,
|
||||
} satisfies Model<"anthropic-messages">,
|
||||
"mimo-v2.5": {
|
||||
id: "mimo-v2.5",
|
||||
name: "MiMo V2.5",
|
||||
api: "anthropic-messages",
|
||||
provider: "xiaomi",
|
||||
baseUrl: "https://token-plan-ams.xiaomimimo.com/anthropic",
|
||||
reasoning: true,
|
||||
input: ["text"],
|
||||
cost: {
|
||||
input: 1,
|
||||
output: 3,
|
||||
cacheRead: 0.2,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
contextWindow: 1048576,
|
||||
maxTokens: 131072,
|
||||
} satisfies Model<"anthropic-messages">,
|
||||
"mimo-v2.5-pro": {
|
||||
id: "mimo-v2.5-pro",
|
||||
name: "MiMo V2.5 Pro",
|
||||
api: "anthropic-messages",
|
||||
provider: "xiaomi",
|
||||
baseUrl: "https://token-plan-ams.xiaomimimo.com/anthropic",
|
||||
reasoning: true,
|
||||
input: ["text"],
|
||||
cost: {
|
||||
input: 1,
|
||||
output: 3,
|
||||
cacheRead: 0.2,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
contextWindow: 1048576,
|
||||
maxTokens: 131072,
|
||||
} satisfies Model<"anthropic-messages">,
|
||||
},
|
||||
} as const;
|
||||
|
|
@ -1,509 +0,0 @@
|
|||
import assert from "node:assert/strict";
|
||||
import { describe, it } from "vitest";
|
||||
import { MODELS } from "./models.generated.js";
|
||||
import { getModel, getModels, getProviders } from "./models.js";
|
||||
|
||||
// ═══════════════════════════════════════════════════════════════════════════
|
||||
// Regression: qwen/qwen3.6-plus missing from OpenRouter (issue #3582)
|
||||
// ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
describe("regression #3582 — qwen/qwen3.6-plus available via openrouter", () => {
|
||||
it("qwen/qwen3.6-plus exists in MODELS['openrouter']", () => {
|
||||
const model =
|
||||
MODELS["openrouter"][
|
||||
"qwen/qwen3.6-plus" as keyof (typeof MODELS)["openrouter"]
|
||||
];
|
||||
assert.ok(model, "qwen/qwen3.6-plus must be present in MODELS.openrouter");
|
||||
});
|
||||
|
||||
it("qwen/qwen3.6-plus is accessible via getModel()", () => {
|
||||
const model = getModel("openrouter", "qwen/qwen3.6-plus" as any);
|
||||
assert.ok(
|
||||
model,
|
||||
"getModel('openrouter', 'qwen/qwen3.6-plus') must return a model",
|
||||
);
|
||||
});
|
||||
|
||||
it("qwen/qwen3.6-plus has id matching its registry key", () => {
|
||||
const model = getModel("openrouter", "qwen/qwen3.6-plus" as any);
|
||||
assert.equal(model.id, "qwen/qwen3.6-plus");
|
||||
});
|
||||
|
||||
it("qwen/qwen3.6-plus has provider set to openrouter", () => {
|
||||
const model = getModel("openrouter", "qwen/qwen3.6-plus" as any);
|
||||
assert.equal(model.provider, "openrouter");
|
||||
});
|
||||
|
||||
it("qwen/qwen3.6-plus has reasoning enabled", () => {
|
||||
const model = getModel("openrouter", "qwen/qwen3.6-plus" as any);
|
||||
assert.equal(model.reasoning, true, "Qwen3.6 Plus is a reasoning model");
|
||||
});
|
||||
|
||||
it("qwen/qwen3.6-plus has 1M context window", () => {
|
||||
const model = getModel("openrouter", "qwen/qwen3.6-plus" as any);
|
||||
assert.equal(model.contextWindow, 1_000_000);
|
||||
});
|
||||
});
|
||||
|
||||
// ═══════════════════════════════════════════════════════════════════════════
|
||||
// Regression: z-ai/glm-5.1 missing from OpenRouter (issue #4069)
|
||||
// ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
describe("regression #4069 — z-ai/glm-5.1 available via openrouter", () => {
|
||||
it("z-ai/glm-5.1 exists in MODELS['openrouter']", () => {
|
||||
const model =
|
||||
MODELS["openrouter"][
|
||||
"z-ai/glm-5.1" as keyof (typeof MODELS)["openrouter"]
|
||||
];
|
||||
assert.ok(model, "z-ai/glm-5.1 must be present in MODELS.openrouter");
|
||||
});
|
||||
|
||||
it("z-ai/glm-5.1 is accessible via getModel()", () => {
|
||||
const model = getModel("openrouter", "z-ai/glm-5.1" as any);
|
||||
assert.ok(
|
||||
model,
|
||||
"getModel('openrouter', 'z-ai/glm-5.1') must return a model",
|
||||
);
|
||||
});
|
||||
|
||||
it("z-ai/glm-5.1 has id matching its registry key", () => {
|
||||
const model = getModel("openrouter", "z-ai/glm-5.1" as any);
|
||||
assert.equal(model.id, "z-ai/glm-5.1");
|
||||
});
|
||||
|
||||
it("z-ai/glm-5.1 has provider set to openrouter", () => {
|
||||
const model = getModel("openrouter", "z-ai/glm-5.1" as any);
|
||||
assert.equal(model.provider, "openrouter");
|
||||
});
|
||||
|
||||
it("z-ai/glm-5.1 has a positive context window", () => {
|
||||
const model = getModel("openrouter", "z-ai/glm-5.1" as any);
|
||||
assert.ok(model.contextWindow > 0);
|
||||
});
|
||||
|
||||
it("z-ai/glm-5.1 uses the OpenRouter base URL", () => {
|
||||
const model = getModel("openrouter", "z-ai/glm-5.1" as any);
|
||||
assert.equal(model.baseUrl, "https://openrouter.ai/api/v1");
|
||||
});
|
||||
});
|
||||
|
||||
// ═══════════════════════════════════════════════════════════════════════════
|
||||
// Structural invariants — every model in MODELS must be well-formed
|
||||
// ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
describe("MODELS structural invariants", () => {
|
||||
type ModelEntry = {
|
||||
providerKey: string;
|
||||
modelKey: string;
|
||||
model: Record<string, unknown>;
|
||||
};
|
||||
|
||||
function allModels(): ModelEntry[] {
|
||||
const entries: ModelEntry[] = [];
|
||||
for (const [providerKey, providerModels] of Object.entries(MODELS)) {
|
||||
for (const [modelKey, model] of Object.entries(providerModels)) {
|
||||
entries.push({
|
||||
providerKey,
|
||||
modelKey,
|
||||
model: model as Record<string, unknown>,
|
||||
});
|
||||
}
|
||||
}
|
||||
return entries;
|
||||
}
|
||||
|
||||
it("every model's id field matches its key in MODELS", () => {
|
||||
const mismatches: string[] = [];
|
||||
for (const { providerKey, modelKey, model } of allModels()) {
|
||||
if (model["id"] !== modelKey) {
|
||||
mismatches.push(`${providerKey}/${modelKey}: id="${model["id"]}"`);
|
||||
}
|
||||
}
|
||||
assert.deepEqual(
|
||||
mismatches,
|
||||
[],
|
||||
`Models where 'id' doesn't match registry key:\n ${mismatches.join("\n ")}`,
|
||||
);
|
||||
});
|
||||
|
||||
it("every model's provider field matches its parent provider key", () => {
|
||||
const mismatches: string[] = [];
|
||||
for (const { providerKey, modelKey, model } of allModels()) {
|
||||
if (model["provider"] !== providerKey) {
|
||||
mismatches.push(
|
||||
`${providerKey}/${modelKey}: provider="${model["provider"]}"`,
|
||||
);
|
||||
}
|
||||
}
|
||||
assert.deepEqual(
|
||||
mismatches,
|
||||
[],
|
||||
`Models where 'provider' doesn't match parent key:\n ${mismatches.join("\n ")}`,
|
||||
);
|
||||
});
|
||||
|
||||
it("every model has a non-empty string name", () => {
|
||||
const invalid: string[] = [];
|
||||
for (const { providerKey, modelKey, model } of allModels()) {
|
||||
if (typeof model["name"] !== "string" || model["name"].trim() === "") {
|
||||
invalid.push(`${providerKey}/${modelKey}`);
|
||||
}
|
||||
}
|
||||
assert.deepEqual(
|
||||
invalid,
|
||||
[],
|
||||
`Models with missing or empty name:\n ${invalid.join("\n ")}`,
|
||||
);
|
||||
});
|
||||
|
||||
it("every model has a non-empty string api", () => {
|
||||
const invalid: string[] = [];
|
||||
for (const { providerKey, modelKey, model } of allModels()) {
|
||||
if (typeof model["api"] !== "string" || model["api"].trim() === "") {
|
||||
invalid.push(`${providerKey}/${modelKey}`);
|
||||
}
|
||||
}
|
||||
assert.deepEqual(
|
||||
invalid,
|
||||
[],
|
||||
`Models with missing or empty api:\n ${invalid.join("\n ")}`,
|
||||
);
|
||||
});
|
||||
|
||||
it("every model's baseUrl starts with https:// (or is empty for azure-openai-responses)", () => {
|
||||
const invalid: string[] = [];
|
||||
for (const { providerKey, modelKey, model } of allModels()) {
|
||||
if (providerKey === "azure-openai-responses") continue;
|
||||
const url = model["baseUrl"];
|
||||
if (typeof url !== "string" || !url.startsWith("https://")) {
|
||||
invalid.push(`${providerKey}/${modelKey}: baseUrl="${url}"`);
|
||||
}
|
||||
}
|
||||
assert.deepEqual(
|
||||
invalid,
|
||||
[],
|
||||
`Models with missing or non-HTTPS baseUrl:\n ${invalid.join("\n ")}`,
|
||||
);
|
||||
});
|
||||
|
||||
it("azure-openai-responses models have an empty baseUrl (runtime-configured)", () => {
|
||||
const models = getModels("azure-openai-responses");
|
||||
assert.ok(
|
||||
models.length > 0,
|
||||
"azure-openai-responses must have at least one model",
|
||||
);
|
||||
for (const model of models) {
|
||||
assert.equal(
|
||||
model.baseUrl,
|
||||
"",
|
||||
`azure-openai-responses/${model.id} should have empty baseUrl`,
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
it("every model has a boolean reasoning field", () => {
|
||||
const invalid: string[] = [];
|
||||
for (const { providerKey, modelKey, model } of allModels()) {
|
||||
if (typeof model["reasoning"] !== "boolean") {
|
||||
invalid.push(
|
||||
`${providerKey}/${modelKey}: reasoning=${model["reasoning"]}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
assert.deepEqual(
|
||||
invalid,
|
||||
[],
|
||||
`Models with non-boolean reasoning:\n ${invalid.join("\n ")}`,
|
||||
);
|
||||
});
|
||||
|
||||
it("every model has a non-empty input array", () => {
|
||||
const invalid: string[] = [];
|
||||
for (const { providerKey, modelKey, model } of allModels()) {
|
||||
const input = model["input"];
|
||||
if (!Array.isArray(input) || input.length === 0) {
|
||||
invalid.push(`${providerKey}/${modelKey}`);
|
||||
}
|
||||
}
|
||||
assert.deepEqual(
|
||||
invalid,
|
||||
[],
|
||||
`Models with missing or empty input array:\n ${invalid.join("\n ")}`,
|
||||
);
|
||||
});
|
||||
|
||||
it("every model has a positive contextWindow", () => {
|
||||
const invalid: string[] = [];
|
||||
for (const { providerKey, modelKey, model } of allModels()) {
|
||||
const cw = model["contextWindow"];
|
||||
if (typeof cw !== "number" || cw <= 0 || !Number.isFinite(cw)) {
|
||||
invalid.push(`${providerKey}/${modelKey}: contextWindow=${cw}`);
|
||||
}
|
||||
}
|
||||
assert.deepEqual(
|
||||
invalid,
|
||||
[],
|
||||
`Models with invalid contextWindow:\n ${invalid.join("\n ")}`,
|
||||
);
|
||||
});
|
||||
|
||||
it("every model has a positive maxTokens", () => {
|
||||
const invalid: string[] = [];
|
||||
for (const { providerKey, modelKey, model } of allModels()) {
|
||||
const mt = model["maxTokens"];
|
||||
if (typeof mt !== "number" || mt <= 0 || !Number.isFinite(mt)) {
|
||||
invalid.push(`${providerKey}/${modelKey}: maxTokens=${mt}`);
|
||||
}
|
||||
}
|
||||
assert.deepEqual(
|
||||
invalid,
|
||||
[],
|
||||
`Models with invalid maxTokens:\n ${invalid.join("\n ")}`,
|
||||
);
|
||||
});
|
||||
|
||||
it("every model's maxTokens does not exceed contextWindow", () => {
|
||||
const knownExceptions = new Set([
|
||||
"openrouter/meta-llama/llama-3-8b-instruct",
|
||||
"openrouter/nex-agi/deepseek-v3.1-nex-n1",
|
||||
"openrouter/openai/gpt-3.5-turbo-0613",
|
||||
"openrouter/z-ai/glm-5",
|
||||
]);
|
||||
|
||||
const invalid: string[] = [];
|
||||
for (const { providerKey, modelKey, model } of allModels()) {
|
||||
if (knownExceptions.has(`${providerKey}/${modelKey}`)) continue;
|
||||
const cw = model["contextWindow"] as number;
|
||||
const mt = model["maxTokens"] as number;
|
||||
if (typeof cw === "number" && typeof mt === "number" && mt > cw) {
|
||||
invalid.push(
|
||||
`${providerKey}/${modelKey}: maxTokens(${mt}) > contextWindow(${cw})`,
|
||||
);
|
||||
}
|
||||
}
|
||||
assert.deepEqual(
|
||||
invalid,
|
||||
[],
|
||||
`Models where maxTokens exceeds contextWindow:\n ${invalid.join("\n ")}`,
|
||||
);
|
||||
});
|
||||
|
||||
it("every model has a cost object with non-negative numeric fields", () => {
|
||||
const invalid: string[] = [];
|
||||
for (const { providerKey, modelKey, model } of allModels()) {
|
||||
const cost = model["cost"] as Record<string, unknown> | undefined;
|
||||
if (!cost || typeof cost !== "object") {
|
||||
invalid.push(`${providerKey}/${modelKey}: missing cost object`);
|
||||
continue;
|
||||
}
|
||||
for (const field of [
|
||||
"input",
|
||||
"output",
|
||||
"cacheRead",
|
||||
"cacheWrite",
|
||||
] as const) {
|
||||
const val = cost[field];
|
||||
if (typeof val !== "number" || val < 0 || !Number.isFinite(val)) {
|
||||
invalid.push(`${providerKey}/${modelKey}: cost.${field}=${val}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
assert.deepEqual(
|
||||
invalid,
|
||||
[],
|
||||
`Models with invalid cost fields:\n ${invalid.join("\n ")}`,
|
||||
);
|
||||
});
|
||||
|
||||
it("does not expose OpenRouter meta-router aliases as selectable models", () => {
|
||||
const openrouterModels = MODELS["openrouter"] as Record<string, unknown>;
|
||||
for (const id of ["auto", "openrouter/auto", "openrouter/free"]) {
|
||||
assert.equal(
|
||||
openrouterModels[id],
|
||||
undefined,
|
||||
`openrouter/${id} must be blocked`,
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
it("replaces retired OpenRouter Elephant alias with current Ling routes", () => {
|
||||
const openrouterModels = MODELS["openrouter"] as Record<string, unknown>;
|
||||
assert.equal(openrouterModels["openrouter/elephant-alpha"], undefined);
|
||||
assert.ok(openrouterModels["inclusionai/ling-2.6-1t:free"]);
|
||||
assert.ok(openrouterModels["inclusionai/ling-2.6-flash"]);
|
||||
});
|
||||
|
||||
it("no provider has duplicate model IDs", () => {
|
||||
const duplicates: string[] = [];
|
||||
for (const [providerKey, providerModels] of Object.entries(MODELS)) {
|
||||
const ids = Object.values(providerModels).map(
|
||||
(m) => (m as Record<string, unknown>)["id"] as string,
|
||||
);
|
||||
const seen = new Set<string>();
|
||||
for (const id of ids) {
|
||||
if (seen.has(id)) duplicates.push(`${providerKey}/${id}`);
|
||||
seen.add(id);
|
||||
}
|
||||
}
|
||||
assert.deepEqual(
|
||||
duplicates,
|
||||
[],
|
||||
`Duplicate model IDs within a provider:\n ${duplicates.join("\n ")}`,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
// ═══════════════════════════════════════════════════════════════════════════
|
||||
// Registry shape
|
||||
// ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
describe("MODELS registry shape", () => {
|
||||
it("has exactly 26 providers", () => {
|
||||
const count = Object.keys(MODELS).length;
|
||||
assert.equal(
|
||||
count,
|
||||
26,
|
||||
`Expected 26 providers, got ${count}: ${Object.keys(MODELS).join(", ")}`,
|
||||
);
|
||||
});
|
||||
|
||||
it("has at least 200 models in total (sanity check)", () => {
|
||||
let total = 0;
|
||||
for (const providerModels of Object.values(MODELS)) {
|
||||
total += Object.keys(providerModels).length;
|
||||
}
|
||||
assert.ok(
|
||||
total >= 200,
|
||||
`Registry has only ${total} models — unexpectedly small`,
|
||||
);
|
||||
});
|
||||
|
||||
it("all 26 expected providers are present", () => {
|
||||
const expected = [
|
||||
"amazon-bedrock",
|
||||
"anthropic",
|
||||
"azure-openai-responses",
|
||||
"cerebras",
|
||||
"github-copilot",
|
||||
"google",
|
||||
"google-gemini-cli",
|
||||
"google-vertex",
|
||||
"groq",
|
||||
"huggingface",
|
||||
"kimi-coding",
|
||||
"minimax",
|
||||
"minimax-cn",
|
||||
"mistral",
|
||||
"openai",
|
||||
"openai-codex",
|
||||
"opencode",
|
||||
"opencode-go",
|
||||
"openrouter",
|
||||
"vercel-ai-gateway",
|
||||
"xai",
|
||||
"xiaomi",
|
||||
"xiaomi-token-plan-ams",
|
||||
"xiaomi-token-plan-cn",
|
||||
"xiaomi-token-plan-sgp",
|
||||
"zai",
|
||||
];
|
||||
const actual = Object.keys(MODELS).sort();
|
||||
assert.deepEqual(actual, expected.sort());
|
||||
});
|
||||
|
||||
it("getProviders() returns all generated providers", () => {
|
||||
const providers = getProviders();
|
||||
for (const p of Object.keys(MODELS)) {
|
||||
assert.ok(
|
||||
providers.includes(p as any),
|
||||
`getProviders() missing generated provider: ${p}`,
|
||||
);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// ═══════════════════════════════════════════════════════════════════════════
|
||||
// Removed models must not exist
|
||||
// ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
describe("removed models are absent from the registry", () => {
|
||||
const removedModels: Array<{ provider: string; id: string }> = [
|
||||
{ provider: "openrouter", id: "anthropic/claude-3.5-sonnet" },
|
||||
{ provider: "openrouter", id: "anthropic/claude-3.5-sonnet-20240620" },
|
||||
{ provider: "openrouter", id: "mistralai/mistral-small-24b-instruct-2501" },
|
||||
{
|
||||
provider: "openrouter",
|
||||
id: "mistralai/mistral-small-3.1-24b-instruct:free",
|
||||
},
|
||||
{ provider: "openrouter", id: "qwen/qwen3-4b:free" },
|
||||
{ provider: "openrouter", id: "stepfun/step-3.5-flash:free" },
|
||||
{ provider: "openrouter", id: "x-ai/grok-4.20-beta" },
|
||||
{ provider: "openrouter", id: "arcee-ai/trinity-mini:free" },
|
||||
{ provider: "openrouter", id: "google/gemini-3-pro-preview" },
|
||||
{ provider: "openrouter", id: "kwaipilot/kat-coder-pro" },
|
||||
{ provider: "openrouter", id: "meituan/longcat-flash-thinking" },
|
||||
{ provider: "vercel-ai-gateway", id: "xai/grok-2-vision" },
|
||||
{ provider: "anthropic", id: "claude-3-7-sonnet-latest" },
|
||||
];
|
||||
|
||||
for (const { provider, id } of removedModels) {
|
||||
it(`${provider}/${id} has been removed`, () => {
|
||||
const model = getModel(provider as any, id as any);
|
||||
assert.equal(
|
||||
model,
|
||||
undefined,
|
||||
`${provider}/${id} should be removed but is still present`,
|
||||
);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
// ═══════════════════════════════════════════════════════════════════════════
|
||||
// Spot-checks for notable models added in this regeneration
|
||||
// ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
describe("spot-checks for models added in this regeneration", () => {
|
||||
const newModels: Array<{
|
||||
provider: string;
|
||||
id: string;
|
||||
reasoning?: boolean;
|
||||
}> = [
|
||||
{ provider: "openrouter", id: "z-ai/glm-5.1" },
|
||||
{ provider: "openrouter", id: "z-ai/glm-5v-turbo" },
|
||||
{ provider: "openrouter", id: "google/gemma-4-31b-it" },
|
||||
{ provider: "openrouter", id: "google/gemma-4-26b-a4b-it" },
|
||||
{
|
||||
provider: "openrouter",
|
||||
id: "arcee-ai/trinity-large-thinking",
|
||||
reasoning: true,
|
||||
},
|
||||
{ provider: "openrouter", id: "openai/gpt-audio" },
|
||||
{ provider: "openrouter", id: "anthropic/claude-opus-4.6-fast" },
|
||||
{ provider: "openrouter", id: "qwen/qwen3.6-plus" },
|
||||
{ provider: "groq", id: "groq/compound" },
|
||||
{ provider: "groq", id: "groq/compound-mini" },
|
||||
{ provider: "huggingface", id: "zai-org/GLM-5.1" },
|
||||
{ provider: "openai", id: "gpt-5.3-chat-latest" },
|
||||
{ provider: "mistral", id: "mistral-small-2603" },
|
||||
{ provider: "zai", id: "glm-5.1" },
|
||||
];
|
||||
|
||||
for (const { provider, id, reasoning } of newModels) {
|
||||
it(`${provider}/${id} is present in the registry`, () => {
|
||||
const model = getModel(provider as any, id as any);
|
||||
assert.ok(
|
||||
model,
|
||||
`Expected ${provider}/${id} to be present after regeneration`,
|
||||
);
|
||||
assert.equal(model.id, id);
|
||||
assert.equal(model.provider, provider);
|
||||
if (reasoning !== undefined) {
|
||||
assert.equal(
|
||||
model.reasoning,
|
||||
reasoning,
|
||||
`${id} reasoning should be ${reasoning}`,
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,516 +0,0 @@
|
|||
import assert from "node:assert/strict";
|
||||
import { describe, it } from "vitest";
|
||||
import {
|
||||
applyCapabilityPatches,
|
||||
getModel,
|
||||
getModels,
|
||||
getProviders,
|
||||
supportsXhigh,
|
||||
} from "./models.js";
|
||||
import type { Api, Model } from "./types.js";
|
||||
|
||||
// ═══════════════════════════════════════════════════════════════════════════
|
||||
// Custom provider preservation (regression: #2339)
|
||||
//
|
||||
// Custom providers (like alibaba-coding-plan) are manually maintained and
|
||||
// NOT sourced from models.dev. They must survive models.generated.ts
|
||||
// regeneration by living in models.custom.ts.
|
||||
// ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
describe("model registry — custom providers", () => {
|
||||
it("hides Gemini customtools variants from the runtime registry", () => {
|
||||
const googleModels = getModels("google").map((model) => model.id);
|
||||
const geminiCliModels = getModels("google-gemini-cli").map(
|
||||
(model) => model.id,
|
||||
);
|
||||
|
||||
assert.equal(
|
||||
googleModels.some((id) => id.endsWith("-customtools")),
|
||||
false,
|
||||
);
|
||||
assert.equal(
|
||||
geminiCliModels.some((id) => id.endsWith("-customtools")),
|
||||
false,
|
||||
);
|
||||
assert.equal(
|
||||
getModel("google" as any, "gemini-3.1-pro-preview-customtools" as any),
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
it("alibaba-coding-plan is a registered provider", () => {
|
||||
const providers = getProviders();
|
||||
assert.ok(
|
||||
providers.includes("alibaba-coding-plan"),
|
||||
`Expected "alibaba-coding-plan" in providers, got: ${providers.join(", ")}`,
|
||||
);
|
||||
});
|
||||
|
||||
it("alibaba-coding-plan has all expected models", () => {
|
||||
const models = getModels("alibaba-coding-plan");
|
||||
const ids = models.map((m) => m.id).sort();
|
||||
const expected = [
|
||||
"MiniMax-M2.5",
|
||||
"glm-4.7",
|
||||
"glm-5",
|
||||
"kimi-k2.5",
|
||||
"qwen3-coder-next",
|
||||
"qwen3-coder-plus",
|
||||
"qwen3-max-2026-01-23",
|
||||
"qwen3.5-plus",
|
||||
];
|
||||
assert.deepEqual(ids, expected);
|
||||
});
|
||||
|
||||
it("alibaba-coding-plan models use the correct base URL", () => {
|
||||
const models = getModels("alibaba-coding-plan");
|
||||
for (const model of models) {
|
||||
assert.equal(
|
||||
model.baseUrl,
|
||||
"https://coding-intl.dashscope.aliyuncs.com/v1",
|
||||
`Model ${model.id} has wrong baseUrl: ${model.baseUrl}`,
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
it("alibaba-coding-plan models use openai-completions API", () => {
|
||||
const models = getModels("alibaba-coding-plan");
|
||||
for (const model of models) {
|
||||
assert.equal(
|
||||
model.api,
|
||||
"openai-completions",
|
||||
`Model ${model.id} has wrong api: ${model.api}`,
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
it("alibaba-coding-plan models have provider set correctly", () => {
|
||||
const models = getModels("alibaba-coding-plan");
|
||||
for (const model of models) {
|
||||
assert.equal(
|
||||
model.provider,
|
||||
"alibaba-coding-plan",
|
||||
`Model ${model.id} has wrong provider: ${model.provider}`,
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
it("getModel retrieves alibaba-coding-plan models by provider+id", () => {
|
||||
// Use type assertion to test runtime behavior — alibaba-coding-plan may come
|
||||
// from custom models rather than the generated file, so the narrow
|
||||
// GeneratedProvider type doesn't include it until models.custom.ts is merged.
|
||||
const model = getModel("alibaba-coding-plan" as any, "qwen3.5-plus" as any);
|
||||
assert.ok(
|
||||
model,
|
||||
"Expected getModel to return a model for alibaba-coding-plan/qwen3.5-plus",
|
||||
);
|
||||
assert.equal(model.id, "qwen3.5-plus");
|
||||
assert.equal(model.provider, "alibaba-coding-plan");
|
||||
});
|
||||
});
|
||||
|
||||
describe("model registry — custom zai provider (GLM-5.1)", () => {
|
||||
it("zai provider includes glm-5.1 from custom models", () => {
|
||||
const models = getModels("zai" as any);
|
||||
const ids = models.map((m) => m.id);
|
||||
assert.ok(
|
||||
ids.includes("glm-5.1"),
|
||||
`Expected "glm-5.1" in zai models, got: ${ids.join(", ")}`,
|
||||
);
|
||||
});
|
||||
|
||||
it("glm-5.1 has correct provider and base URL", () => {
|
||||
const model = getModel("zai" as any, "glm-5.1" as any);
|
||||
assert.ok(model, "Expected getModel to return a model for zai/glm-5.1");
|
||||
assert.equal(model.id, "glm-5.1");
|
||||
assert.equal(model.provider, "zai");
|
||||
assert.equal(model.baseUrl, "https://api.z.ai/api/coding/paas/v4");
|
||||
assert.equal(model.api, "openai-completions");
|
||||
});
|
||||
|
||||
it("glm-5.1 has reasoning enabled and correct context window", () => {
|
||||
const model = getModel("zai" as any, "glm-5.1" as any);
|
||||
assert.ok(model);
|
||||
assert.equal(model.reasoning, true);
|
||||
assert.equal(model.contextWindow, 200000);
|
||||
assert.equal(model.maxTokens, 131072);
|
||||
});
|
||||
|
||||
it("custom glm-5.1 does not overwrite generated zai models", () => {
|
||||
const models = getModels("zai" as any);
|
||||
const ids = models.map((m) => m.id);
|
||||
// Generated models must still exist alongside custom glm-5.1
|
||||
assert.ok(ids.includes("glm-5"), "Generated glm-5 should still exist");
|
||||
assert.ok(
|
||||
ids.includes("glm-5-turbo"),
|
||||
"Generated glm-5-turbo should still exist",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("model registry — xiaomi provider", () => {
|
||||
it("xiaomi is a registered provider", () => {
|
||||
const providers = getProviders();
|
||||
assert.ok(
|
||||
providers.includes("xiaomi"),
|
||||
`Expected "xiaomi" in providers, got: ${providers.join(", ")}`,
|
||||
);
|
||||
});
|
||||
|
||||
it("xiaomi includes the expected chat models from the direct Anthropic-compatible endpoint", () => {
|
||||
const models = getModels("xiaomi" as any);
|
||||
const ids = models.map((m) => m.id).sort();
|
||||
assert.deepEqual(ids, [
|
||||
"mimo-v2-flash",
|
||||
"mimo-v2-omni",
|
||||
"mimo-v2-pro",
|
||||
"mimo-v2.5",
|
||||
"mimo-v2.5-pro",
|
||||
]);
|
||||
});
|
||||
|
||||
it("xiaomi models use the Anthropic-compatible endpoint and provider identity", () => {
|
||||
const models = getModels("xiaomi" as any);
|
||||
for (const model of models) {
|
||||
assert.equal(model.provider, "xiaomi");
|
||||
assert.equal(model.api, "anthropic-messages");
|
||||
assert.equal(
|
||||
model.baseUrl,
|
||||
"https://token-plan-ams.xiaomimimo.com/anthropic",
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
it("getModel retrieves xiaomi MiMo models by provider+id", () => {
|
||||
const model = getModel("xiaomi" as any, "mimo-v2-pro" as any);
|
||||
assert.ok(
|
||||
model,
|
||||
"Expected getModel to return a model for xiaomi/mimo-v2-pro",
|
||||
);
|
||||
assert.equal(model.id, "mimo-v2-pro");
|
||||
assert.equal(model.provider, "xiaomi");
|
||||
assert.equal(model.api, "anthropic-messages");
|
||||
});
|
||||
});
|
||||
|
||||
describe("model registry — kimi-coding provider", () => {
|
||||
it("kimi-coding is a registered provider", () => {
|
||||
const providers = getProviders();
|
||||
assert.ok(
|
||||
providers.includes("kimi-coding"),
|
||||
`Expected "kimi-coding" in providers, got: ${providers.join(", ")}`,
|
||||
);
|
||||
});
|
||||
|
||||
it("kimi-coding exposes the canonical live model id", () => {
|
||||
const model = getModel("kimi-coding" as any, "kimi-for-coding" as any);
|
||||
assert.ok(model, "Expected getModel to return kimi-coding/kimi-for-coding");
|
||||
assert.equal(model.id, "kimi-for-coding");
|
||||
assert.equal(model.provider, "kimi-coding");
|
||||
assert.equal(model.api, "anthropic-messages");
|
||||
assert.equal(model.baseUrl, "https://api.kimi.com/coding");
|
||||
assert.equal(model.contextWindow, 262144);
|
||||
});
|
||||
|
||||
it("kimi-coding uses market comparison pricing for Kimi K2.6", () => {
|
||||
const model = getModel("kimi-coding" as any, "kimi-for-coding" as any);
|
||||
assert.ok(model, "Expected getModel to return kimi-coding/kimi-for-coding");
|
||||
assert.equal(model.name, "Kimi K2.6");
|
||||
assert.equal(model.cost.input, 0.7448);
|
||||
assert.equal(model.cost.output, 4.655);
|
||||
});
|
||||
});
|
||||
|
||||
// ═══════════════════════════════════════════════════════════════════════════
|
||||
// New provider: alibaba-dashscope (feat: #3891)
|
||||
//
|
||||
// Regular DashScope API for users without the Coding Plan.
|
||||
// Separate from alibaba-coding-plan — different endpoint, auth, and pricing.
|
||||
// ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
describe("model registry — alibaba-dashscope provider", () => {
|
||||
it("alibaba-dashscope is a registered provider", () => {
|
||||
const providers = getProviders();
|
||||
assert.ok(
|
||||
providers.includes("alibaba-dashscope"),
|
||||
`Expected "alibaba-dashscope" in providers, got: ${providers.join(", ")}`,
|
||||
);
|
||||
});
|
||||
|
||||
it("alibaba-dashscope has all expected models", () => {
|
||||
const models = getModels("alibaba-dashscope");
|
||||
const ids = models.map((m) => m.id).sort();
|
||||
const expected = [
|
||||
"qwen3-coder-plus",
|
||||
"qwen3-max",
|
||||
"qwen3.5-flash",
|
||||
"qwen3.5-plus",
|
||||
"qwen3.6-plus",
|
||||
];
|
||||
assert.deepEqual(ids, expected);
|
||||
});
|
||||
|
||||
it("alibaba-dashscope models use the international DashScope base URL", () => {
|
||||
const models = getModels("alibaba-dashscope");
|
||||
for (const model of models) {
|
||||
assert.equal(
|
||||
model.baseUrl,
|
||||
"https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
|
||||
`Model ${model.id} has wrong baseUrl: ${model.baseUrl}`,
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
it("alibaba-dashscope models use openai-completions API", () => {
|
||||
const models = getModels("alibaba-dashscope");
|
||||
for (const model of models) {
|
||||
assert.equal(
|
||||
model.api,
|
||||
"openai-completions",
|
||||
`Model ${model.id} has wrong api: ${model.api}`,
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
it("alibaba-dashscope models have provider set correctly", () => {
|
||||
const models = getModels("alibaba-dashscope");
|
||||
for (const model of models) {
|
||||
assert.equal(
|
||||
model.provider,
|
||||
"alibaba-dashscope",
|
||||
`Model ${model.id} has wrong provider: ${model.provider}`,
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
it("alibaba-dashscope models all have 1M context window", () => {
|
||||
const models = getModels("alibaba-dashscope");
|
||||
for (const model of models) {
|
||||
assert.equal(
|
||||
model.contextWindow,
|
||||
1_000_000,
|
||||
`Model ${model.id} has wrong contextWindow: ${model.contextWindow}`,
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
it("alibaba-dashscope models have positive paid costs (not free-tier)", () => {
|
||||
const models = getModels("alibaba-dashscope");
|
||||
for (const model of models) {
|
||||
assert.ok(
|
||||
model.cost.input > 0,
|
||||
`${model.id}: input cost should be > 0 (paid tier)`,
|
||||
);
|
||||
assert.ok(
|
||||
model.cost.output > 0,
|
||||
`${model.id}: output cost should be > 0 (paid tier)`,
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
it("qwen3-max is a reasoning model with correct pricing", () => {
|
||||
const model = getModel("alibaba-dashscope" as any, "qwen3-max" as any);
|
||||
assert.ok(
|
||||
model,
|
||||
"Expected getModel to return qwen3-max for alibaba-dashscope",
|
||||
);
|
||||
assert.equal(model.reasoning, true);
|
||||
assert.equal(model.cost.input, 1.2);
|
||||
assert.equal(model.cost.output, 6);
|
||||
assert.equal(model.maxTokens, 32768);
|
||||
});
|
||||
|
||||
it("qwen3.5-plus is a reasoning model with correct pricing", () => {
|
||||
const model = getModel("alibaba-dashscope" as any, "qwen3.5-plus" as any);
|
||||
assert.ok(
|
||||
model,
|
||||
"Expected getModel to return qwen3.5-plus for alibaba-dashscope",
|
||||
);
|
||||
assert.equal(model.reasoning, true);
|
||||
assert.equal(model.cost.input, 0.4);
|
||||
assert.equal(model.cost.output, 1.2);
|
||||
assert.equal(model.maxTokens, 65536);
|
||||
});
|
||||
|
||||
it("qwen3.5-flash is not a reasoning model", () => {
|
||||
const model = getModel("alibaba-dashscope" as any, "qwen3.5-flash" as any);
|
||||
assert.ok(
|
||||
model,
|
||||
"Expected getModel to return qwen3.5-flash for alibaba-dashscope",
|
||||
);
|
||||
assert.equal(model.reasoning, false);
|
||||
assert.equal(model.cost.input, 0.1);
|
||||
assert.equal(model.cost.output, 0.4);
|
||||
});
|
||||
|
||||
it("qwen3-coder-plus is not a reasoning model", () => {
|
||||
const model = getModel(
|
||||
"alibaba-dashscope" as any,
|
||||
"qwen3-coder-plus" as any,
|
||||
);
|
||||
assert.ok(
|
||||
model,
|
||||
"Expected getModel to return qwen3-coder-plus for alibaba-dashscope",
|
||||
);
|
||||
assert.equal(model.reasoning, false);
|
||||
assert.equal(model.cost.input, 1.0);
|
||||
assert.equal(model.cost.output, 5.0);
|
||||
});
|
||||
|
||||
it("qwen3.6-plus is a reasoning model", () => {
|
||||
const model = getModel("alibaba-dashscope" as any, "qwen3.6-plus" as any);
|
||||
assert.ok(
|
||||
model,
|
||||
"Expected getModel to return qwen3.6-plus for alibaba-dashscope",
|
||||
);
|
||||
assert.equal(model.reasoning, true);
|
||||
assert.equal(model.cost.input, 0.5);
|
||||
assert.equal(model.cost.output, 3.0);
|
||||
});
|
||||
|
||||
it("alibaba-dashscope is independent of alibaba-coding-plan (different endpoint)", () => {
|
||||
const dashscope = getModels("alibaba-dashscope");
|
||||
const codingPlan = getModels("alibaba-coding-plan");
|
||||
for (const m of dashscope) {
|
||||
assert.notEqual(
|
||||
m.baseUrl,
|
||||
"https://coding-intl.dashscope.aliyuncs.com/v1",
|
||||
`${m.id} must not use the Coding Plan endpoint`,
|
||||
);
|
||||
}
|
||||
// Both providers must coexist — coding-plan must not have been overwritten
|
||||
assert.ok(
|
||||
codingPlan.length > 0,
|
||||
"alibaba-coding-plan must still have models",
|
||||
);
|
||||
});
|
||||
|
||||
it("getModel returns undefined for unknown model in alibaba-dashscope (failure path)", () => {
|
||||
const model = getModel("alibaba-dashscope" as any, "does-not-exist" as any);
|
||||
assert.equal(model, undefined);
|
||||
});
|
||||
});
|
||||
|
||||
describe("model registry — custom models do not collide with generated models", () => {
|
||||
it("generated providers still exist alongside custom providers", () => {
|
||||
const providers = getProviders();
|
||||
// Spot-check a few generated providers
|
||||
assert.ok(providers.includes("openai"), "openai should be in providers");
|
||||
assert.ok(
|
||||
providers.includes("anthropic"),
|
||||
"anthropic should be in providers",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
// ═══════════════════════════════════════════════════════════════════════════
|
||||
// Capability patches (regression: #2546)
|
||||
//
|
||||
// CAPABILITY_PATCHES must apply capabilities to models in the static
|
||||
// registry AND to models constructed outside of it (custom, extension,
|
||||
// discovered). supportsXhigh() reads model.capabilities — not model IDs.
|
||||
// ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
/** Helper: build a minimal synthetic model for testing */
|
||||
function syntheticModel(overrides: Partial<Model<Api>>): Model<Api> {
|
||||
return {
|
||||
id: "test-model",
|
||||
name: "Test Model",
|
||||
api: "openai-completions" as Api,
|
||||
provider: "test-provider",
|
||||
baseUrl: "https://example.com",
|
||||
reasoning: false,
|
||||
input: ["text"],
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: 128000,
|
||||
maxTokens: 16384,
|
||||
...overrides,
|
||||
} as Model<Api>;
|
||||
}
|
||||
|
||||
describe("supportsXhigh — registry models", () => {
|
||||
it("returns true for GPT-5.4 from the registry", () => {
|
||||
const model = getModel("openai", "gpt-5.4" as any);
|
||||
if (!model) return; // skip if model not in generated catalog
|
||||
assert.equal(supportsXhigh(model), true);
|
||||
});
|
||||
|
||||
it("returns false for a non-reasoning model", () => {
|
||||
const models = getModels("openai");
|
||||
const nonXhigh = models.find((m) => !m.id.includes("gpt-5."));
|
||||
if (!nonXhigh) return;
|
||||
assert.equal(supportsXhigh(nonXhigh), false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("supportsXhigh — synthetic models (regression: custom/extension models)", () => {
|
||||
it("returns false for a model without capabilities", () => {
|
||||
const model = syntheticModel({ id: "my-custom-model" });
|
||||
assert.equal(supportsXhigh(model), false);
|
||||
});
|
||||
|
||||
it("returns true when capabilities.supportsXhigh is explicitly set", () => {
|
||||
const model = syntheticModel({
|
||||
id: "my-custom-model",
|
||||
capabilities: { supportsXhigh: true },
|
||||
});
|
||||
assert.equal(supportsXhigh(model), true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("applyCapabilityPatches", () => {
|
||||
it("patches a GPT-5.4 model that has no capabilities", () => {
|
||||
const model = syntheticModel({ id: "gpt-5.4-custom" });
|
||||
assert.equal(model.capabilities, undefined);
|
||||
|
||||
const [patched] = applyCapabilityPatches([model]);
|
||||
assert.equal(patched.capabilities?.supportsXhigh, true);
|
||||
assert.equal(patched.capabilities?.supportsServiceTier, true);
|
||||
});
|
||||
|
||||
it("patches a GPT-5.2 model", () => {
|
||||
const model = syntheticModel({ id: "gpt-5.2" });
|
||||
const [patched] = applyCapabilityPatches([model]);
|
||||
assert.equal(patched.capabilities?.supportsXhigh, true);
|
||||
});
|
||||
|
||||
it("patches an Anthropic Opus 4.6 model", () => {
|
||||
const model = syntheticModel({
|
||||
id: "claude-opus-4-6-20260301",
|
||||
api: "anthropic-messages" as Api,
|
||||
});
|
||||
const [patched] = applyCapabilityPatches([model]);
|
||||
assert.equal(patched.capabilities?.supportsXhigh, true);
|
||||
// Opus should not get supportsServiceTier
|
||||
assert.equal(patched.capabilities?.supportsServiceTier, undefined);
|
||||
});
|
||||
|
||||
it("preserves explicit capabilities over patches", () => {
|
||||
const model = syntheticModel({
|
||||
id: "gpt-5.4-custom",
|
||||
capabilities: { supportsXhigh: false, charsPerToken: 3 },
|
||||
});
|
||||
const [patched] = applyCapabilityPatches([model]);
|
||||
// Explicit supportsXhigh: false wins over patch's true
|
||||
assert.equal(patched.capabilities?.supportsXhigh, false);
|
||||
// Patch fills in supportsServiceTier since it wasn't explicitly set
|
||||
assert.equal(patched.capabilities?.supportsServiceTier, true);
|
||||
// Explicit charsPerToken is preserved
|
||||
assert.equal(patched.capabilities?.charsPerToken, 3);
|
||||
});
|
||||
|
||||
it("does not modify models that match no patches", () => {
|
||||
const model = syntheticModel({ id: "gemini-2.5-pro" });
|
||||
const [patched] = applyCapabilityPatches([model]);
|
||||
assert.equal(patched.capabilities, undefined);
|
||||
// Should return the same reference when unpatched
|
||||
assert.equal(patched, model);
|
||||
});
|
||||
|
||||
it("is idempotent — re-applying patches produces the same result", () => {
|
||||
const model = syntheticModel({ id: "gpt-5.3" });
|
||||
const first = applyCapabilityPatches([model]);
|
||||
const second = applyCapabilityPatches(first);
|
||||
assert.deepEqual(first[0].capabilities, second[0].capabilities);
|
||||
});
|
||||
});
|
||||
|
|
@ -1,197 +0,0 @@
|
|||
import { CUSTOM_MODELS } from "./models.custom.js";
|
||||
import { MODELS } from "./models.generated.js";
|
||||
import type {
|
||||
Api,
|
||||
KnownProvider,
|
||||
Model,
|
||||
ModelCapabilities,
|
||||
Usage,
|
||||
} from "./types.js";
|
||||
|
||||
const modelRegistry: Map<string, Map<string, Model<Api>>> = new Map();
|
||||
|
||||
function isHiddenBuiltInModelId(id: string): boolean {
|
||||
return id.endsWith("-customtools");
|
||||
}
|
||||
|
||||
// Initialize registry from auto-generated MODELS (models.dev catalog)
|
||||
for (const [provider, models] of Object.entries(MODELS)) {
|
||||
const providerModels = new Map<string, Model<Api>>();
|
||||
for (const [id, model] of Object.entries(models)) {
|
||||
if (isHiddenBuiltInModelId(id)) continue;
|
||||
providerModels.set(id, model as Model<Api>);
|
||||
}
|
||||
modelRegistry.set(provider, providerModels);
|
||||
}
|
||||
|
||||
// Merge manually-maintained custom providers that are NOT in models.dev.
|
||||
// Custom models are additive — they never overwrite generated entries.
|
||||
// See: https://github.com/singularity-forge/sf-run/issues/2339
|
||||
for (const [provider, models] of Object.entries(CUSTOM_MODELS)) {
|
||||
if (!modelRegistry.has(provider)) {
|
||||
modelRegistry.set(provider, new Map<string, Model<Api>>());
|
||||
}
|
||||
const providerModels = modelRegistry.get(provider)!;
|
||||
for (const [id, model] of Object.entries(models)) {
|
||||
if (!providerModels.has(id)) {
|
||||
providerModels.set(id, model as Model<Api>);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const kimiCodingModels = modelRegistry.get("kimi-coding");
|
||||
const kimiK26 = kimiCodingModels?.get("kimi-k2.6");
|
||||
if (kimiCodingModels && kimiK26 && !kimiCodingModels.has("kimi-for-coding")) {
|
||||
kimiCodingModels.set("kimi-for-coding", {
|
||||
...kimiK26,
|
||||
id: "kimi-for-coding",
|
||||
});
|
||||
}
|
||||
|
||||
// ─── Capability Patches ───────────────────────────────────────────────────────
//
// Declare capabilities for models that pre-date the `capabilities` field or
// that live in the auto-generated catalog (models.generated.ts) which we
// cannot edit directly. Pattern-matching on model IDs is acceptable HERE
// because this is the single source of truth — call sites must never repeat it.
//
// Add new entries as additional capabilities emerge. Existing models that
// define `capabilities` in their model definition take precedence (the patch
// only fills in fields that are not already set).

/** A predicate plus the capability fields to merge into matching models. */
type CapabilityPatch = {
	match: (m: Model<Api>) => boolean;
	caps: ModelCapabilities;
};

// Ordered: only the FIRST matching patch is applied per model (both
// applyCapabilityPatches and the module-load loop stop at the first match).
const CAPABILITY_PATCHES: CapabilityPatch[] = [
	// GPT-5.x supports xhigh thinking and OpenAI service tiers
	{
		match: (m) =>
			m.id.includes("gpt-5.2") ||
			m.id.includes("gpt-5.3") ||
			m.id.includes("gpt-5.4"),
		caps: { supportsXhigh: true, supportsServiceTier: true },
	},
	// Anthropic Opus 4.6 supports xhigh thinking (both hyphen and dot id styles)
	{
		match: (m) =>
			m.api === "anthropic-messages" &&
			(m.id.includes("opus-4-6") || m.id.includes("opus-4.6")),
		caps: { supportsXhigh: true },
	},
];
|
||||
|
||||
/**
|
||||
* Apply capability patches to a list of models.
|
||||
*
|
||||
* Models constructed outside the static pi-ai registry (custom models from
|
||||
* models.json, extension-registered models, discovered models) do not pass
|
||||
* through the module-init patch loop. Call this function after assembling
|
||||
* any model list to ensure capabilities are set correctly.
|
||||
*
|
||||
* Explicit `capabilities` already set on a model take precedence over patches.
|
||||
*/
|
||||
export function applyCapabilityPatches(models: Model<Api>[]): Model<Api>[] {
|
||||
return models.map((model) => {
|
||||
for (const patch of CAPABILITY_PATCHES) {
|
||||
if (patch.match(model)) {
|
||||
return {
|
||||
...model,
|
||||
capabilities: { ...patch.caps, ...model.capabilities },
|
||||
};
|
||||
}
|
||||
}
|
||||
return model;
|
||||
});
|
||||
}
|
||||
|
||||
// Apply patches to the static registry at module load
|
||||
for (const [, providerModels] of modelRegistry) {
|
||||
for (const [id, model] of providerModels) {
|
||||
for (const patch of CAPABILITY_PATCHES) {
|
||||
if (patch.match(model)) {
|
||||
providerModels.set(id, {
|
||||
...model,
|
||||
capabilities: { ...patch.caps, ...model.capabilities },
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Providers that have entries in the generated MODELS constant */
type GeneratedProvider = keyof typeof MODELS & KnownProvider;

// Resolves the concrete `api` literal type recorded for a provider/model
// pair in the generated catalog; falls back to `never` when the entry has
// no statically-known `api` field (or it is not a valid Api member).
type ModelApi<
	TProvider extends GeneratedProvider,
	TModelId extends keyof (typeof MODELS)[TProvider],
> = (typeof MODELS)[TProvider][TModelId] extends { api: infer TApi }
	? TApi extends Api
		? TApi
		: never
	: never;
|
||||
|
||||
export function getModel<
|
||||
TProvider extends GeneratedProvider,
|
||||
TModelId extends keyof (typeof MODELS)[TProvider],
|
||||
>(
|
||||
provider: TProvider,
|
||||
modelId: TModelId,
|
||||
): Model<ModelApi<TProvider, TModelId>> {
|
||||
const providerModels = modelRegistry.get(provider);
|
||||
return providerModels?.get(modelId as string) as Model<
|
||||
ModelApi<TProvider, TModelId>
|
||||
>;
|
||||
}
|
||||
|
||||
export function getProviders(): KnownProvider[] {
|
||||
return Array.from(modelRegistry.keys()) as KnownProvider[];
|
||||
}
|
||||
|
||||
export function getModels<TProvider extends KnownProvider>(
|
||||
provider: TProvider,
|
||||
): Model<Api>[] {
|
||||
const models = modelRegistry.get(provider);
|
||||
return models ? (Array.from(models.values()) as Model<Api>[]) : [];
|
||||
}
|
||||
|
||||
export function calculateCost<TApi extends Api>(
|
||||
model: Model<TApi>,
|
||||
usage: Usage,
|
||||
): Usage["cost"] {
|
||||
usage.cost.input = (model.cost.input / 1000000) * usage.input;
|
||||
usage.cost.output = (model.cost.output / 1000000) * usage.output;
|
||||
usage.cost.cacheRead = (model.cost.cacheRead / 1000000) * usage.cacheRead;
|
||||
usage.cost.cacheWrite = (model.cost.cacheWrite / 1000000) * usage.cacheWrite;
|
||||
usage.cost.total =
|
||||
usage.cost.input +
|
||||
usage.cost.output +
|
||||
usage.cost.cacheRead +
|
||||
usage.cost.cacheWrite;
|
||||
return usage.cost;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a model supports xhigh thinking level.
|
||||
*
|
||||
* Reads from `model.capabilities.supportsXhigh` — set via CAPABILITY_PATCHES
|
||||
* for generated models or declared directly in custom model definitions.
|
||||
* Do not add model-ID or provider-name checks here; update CAPABILITY_PATCHES instead.
|
||||
*/
|
||||
export function supportsXhigh<TApi extends Api>(model: Model<TApi>): boolean {
|
||||
return model.capabilities?.supportsXhigh ?? false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if two models are equal by comparing both their id and provider.
|
||||
* Returns false if either model is null or undefined.
|
||||
*/
|
||||
export function modelsAreEqual<TApi extends Api>(
|
||||
a: Model<TApi> | null | undefined,
|
||||
b: Model<TApi> | null | undefined,
|
||||
): boolean {
|
||||
if (!a || !b) return false;
|
||||
return a.id === b.id && a.provider === b.provider;
|
||||
}
|
||||
|
|
@ -1 +0,0 @@
|
|||
export * from "./utils/oauth/index.js";
|
||||
|
|
@ -1,184 +0,0 @@
|
|||
/**
|
||||
* TDD Red Phase — Bug #4392 / Pre-existing Bug #4352
|
||||
*
|
||||
* `supportsAdaptiveThinking()` in amazon-bedrock.ts is missing opus-4-7,
|
||||
* sonnet-4-7, and haiku-4-5. These tests FAIL until the bug is fixed.
|
||||
*
|
||||
* Related: #4392 (opus-4-7 adaptive thinking not recognised on Bedrock)
|
||||
* #4352 (pre-existing: only opus-4-6 / sonnet-4-6 whitelisted)
|
||||
*/
|
||||
|
||||
import assert from "node:assert/strict";
|
||||
import { describe, it } from "vitest";
|
||||
import type { Model } from "../types.js";
|
||||
import {
|
||||
type BedrockOptions,
|
||||
buildAdditionalModelRequestFields,
|
||||
mapThinkingLevelToEffort,
|
||||
supportsAdaptiveThinking,
|
||||
} from "./amazon-bedrock.js";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function makeModel(id: string): Model<"bedrock-converse-stream"> {
|
||||
return {
|
||||
id,
|
||||
name: id,
|
||||
api: "bedrock-converse-stream",
|
||||
provider: "amazon-bedrock" as any,
|
||||
baseUrl: "",
|
||||
reasoning: true,
|
||||
input: ["text"],
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: 200000,
|
||||
maxTokens: 32000,
|
||||
};
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// supportsAdaptiveThinking — RED tests (#4392 / #4352)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
describe("supportsAdaptiveThinking — Bug #4392 / #4352: missing models", () => {
|
||||
// These two already pass (regression guard):
|
||||
it("returns true for opus-4-6 (hyphen, Bedrock ARN style)", () => {
|
||||
assert.ok(
|
||||
supportsAdaptiveThinking("anthropic.claude-opus-4-6-20250514-v1:0"),
|
||||
);
|
||||
});
|
||||
|
||||
it("returns true for sonnet-4-6 (hyphen)", () => {
|
||||
assert.ok(
|
||||
supportsAdaptiveThinking("anthropic.claude-sonnet-4-6-20250514-v1:0"),
|
||||
);
|
||||
});
|
||||
|
||||
// --- RED: the following FAIL because opus-4-7 / sonnet-4-7 / haiku-4-5 are missing ---
|
||||
|
||||
it("[#4392] returns true for opus-4-7 (hyphen, Bedrock ARN style)", () => {
|
||||
// FAILS: supportsAdaptiveThinking does not include 'opus-4-7'
|
||||
assert.ok(
|
||||
supportsAdaptiveThinking("anthropic.claude-opus-4-7-20250514-v1:0"),
|
||||
"opus-4-7 should support adaptive thinking (bug #4392)",
|
||||
);
|
||||
});
|
||||
|
||||
it("[#4392] returns true for opus-4-7 (dot separator)", () => {
|
||||
// FAILS: supportsAdaptiveThinking does not include 'opus-4.7'
|
||||
assert.ok(
|
||||
supportsAdaptiveThinking("anthropic.claude-opus-4.7-20250514-v1:0"),
|
||||
"opus-4.7 (dot) should support adaptive thinking (bug #4392)",
|
||||
);
|
||||
});
|
||||
|
||||
it("[#4352] returns true for sonnet-4-7 (hyphen)", () => {
|
||||
// FAILS: supportsAdaptiveThinking does not include 'sonnet-4-7'
|
||||
assert.ok(
|
||||
supportsAdaptiveThinking("anthropic.claude-sonnet-4-7-20250514-v1:0"),
|
||||
"sonnet-4-7 should support adaptive thinking (bug #4352)",
|
||||
);
|
||||
});
|
||||
|
||||
it("[#4352] returns true for haiku-4-5 (hyphen)", () => {
|
||||
// FAILS: supportsAdaptiveThinking does not include 'haiku-4-5'
|
||||
assert.ok(
|
||||
supportsAdaptiveThinking("anthropic.claude-haiku-4-5-20250514-v1:0"),
|
||||
"haiku-4-5 should support adaptive thinking (bug #4352)",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// buildAdditionalModelRequestFields — adaptive thinking output for opus-4-7
|
||||
// Tests go through the public API surface to validate end-to-end behaviour.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
describe("buildAdditionalModelRequestFields — Bug #4392: opus-4-7 must use adaptive thinking", () => {
|
||||
const options: BedrockOptions = { reasoning: "high" };
|
||||
|
||||
it("[#4392] opus-4-7 Bedrock ARN → thinking.type === 'adaptive' (not budget_tokens)", () => {
|
||||
const model = makeModel("anthropic.claude-opus-4-7-20250514-v1:0");
|
||||
const fields = buildAdditionalModelRequestFields(model, options);
|
||||
// FAILS: because supportsAdaptiveThinking returns false for opus-4-7,
|
||||
// the function returns { thinking: { type: "enabled", budget_tokens: ... } }
|
||||
assert.equal(
|
||||
fields?.thinking?.type,
|
||||
"adaptive",
|
||||
"opus-4-7 should produce thinking.type='adaptive', not budget_tokens",
|
||||
);
|
||||
});
|
||||
|
||||
it("[#4392] opus-4-7 dot separator → thinking.type === 'adaptive'", () => {
|
||||
const model = makeModel("anthropic.claude-opus-4.7-20250514-v1:0");
|
||||
const fields = buildAdditionalModelRequestFields(model, options);
|
||||
assert.equal(
|
||||
fields?.thinking?.type,
|
||||
"adaptive",
|
||||
"opus-4.7 (dot) should produce thinking.type='adaptive'",
|
||||
);
|
||||
});
|
||||
|
||||
it("[#4352] sonnet-4-7 → thinking.type === 'adaptive'", () => {
|
||||
const model = makeModel("anthropic.claude-sonnet-4-7-20250514-v1:0");
|
||||
const fields = buildAdditionalModelRequestFields(model, options);
|
||||
assert.equal(
|
||||
fields?.thinking?.type,
|
||||
"adaptive",
|
||||
"sonnet-4-7 should produce thinking.type='adaptive'",
|
||||
);
|
||||
});
|
||||
|
||||
it("[#4352] haiku-4-5 → thinking.type === 'adaptive'", () => {
|
||||
const model = makeModel("anthropic.claude-haiku-4-5-20250514-v1:0");
|
||||
const fields = buildAdditionalModelRequestFields(model, options);
|
||||
assert.equal(
|
||||
fields?.thinking?.type,
|
||||
"adaptive",
|
||||
"haiku-4-5 should produce thinking.type='adaptive'",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// mapThinkingLevelToEffort — RED test for xhigh on opus-4-7
|
||||
// The Bedrock version returns "max" (dead code path at line 411), whereas
|
||||
// the correct value is "xhigh" (as implemented in anthropic-shared.ts).
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
describe("mapThinkingLevelToEffort — Bug #4392: opus-4-7 xhigh should return 'xhigh' not 'max'", () => {
|
||||
it("[#4392] maps xhigh → 'xhigh' for opus-4-7 (native xhigh support)", () => {
|
||||
// FAILS: current code returns "max" for opus-4-7 at line 411,
|
||||
// and in any case this code path is unreachable because
|
||||
// supportsAdaptiveThinking returns false for opus-4-7.
|
||||
// After the fix, supportsAdaptiveThinking will return true AND
|
||||
// mapThinkingLevelToEffort must return "xhigh" (not "max").
|
||||
const result = mapThinkingLevelToEffort(
|
||||
"xhigh",
|
||||
"anthropic.claude-opus-4-7-20250514-v1:0",
|
||||
);
|
||||
assert.equal(
|
||||
result,
|
||||
"xhigh",
|
||||
"opus-4-7 supports native xhigh effort — must not be clamped to 'max'",
|
||||
);
|
||||
});
|
||||
|
||||
it("[#4392] maps xhigh → 'max' for opus-4-6 (no native xhigh, clamped)", () => {
|
||||
// This already passes — regression guard.
|
||||
const result = mapThinkingLevelToEffort(
|
||||
"xhigh",
|
||||
"anthropic.claude-opus-4-6-20250514-v1:0",
|
||||
);
|
||||
assert.equal(result, "max");
|
||||
});
|
||||
|
||||
it("maps high → 'high' for opus-4-7 (not affected by bug)", () => {
|
||||
const result = mapThinkingLevelToEffort(
|
||||
"high",
|
||||
"anthropic.claude-opus-4-7-20250514-v1:0",
|
||||
);
|
||||
assert.equal(result, "high");
|
||||
});
|
||||
});
|
||||
|
|
@ -1,975 +0,0 @@
|
|||
import {
|
||||
BedrockRuntimeClient,
|
||||
type BedrockRuntimeClientConfig,
|
||||
StopReason as BedrockStopReason,
|
||||
type Tool as BedrockTool,
|
||||
CachePointType,
|
||||
CacheTTL,
|
||||
type ContentBlock,
|
||||
type ContentBlockDeltaEvent,
|
||||
type ContentBlockStartEvent,
|
||||
type ContentBlockStopEvent,
|
||||
ConversationRole,
|
||||
ConverseStreamCommand,
|
||||
type ConverseStreamMetadataEvent,
|
||||
ImageFormat,
|
||||
type Message,
|
||||
type SystemContentBlock,
|
||||
type ToolChoice,
|
||||
type ToolConfiguration,
|
||||
ToolResultStatus,
|
||||
} from "@aws-sdk/client-bedrock-runtime";
|
||||
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
CacheRetention,
|
||||
Context,
|
||||
Model,
|
||||
RequestedThinkingLevel,
|
||||
SimpleStreamOptions,
|
||||
StopReason,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
TextContent,
|
||||
ThinkingBudgets,
|
||||
ThinkingContent,
|
||||
ThinkingLevel,
|
||||
Tool,
|
||||
ToolCall,
|
||||
ToolResultMessage,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { parseStreamingJson } from "../utils/json-parse.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import {
|
||||
adjustMaxTokensForThinking,
|
||||
buildBaseOptions,
|
||||
clampReasoning,
|
||||
isAutoReasoning,
|
||||
resolveReasoningLevel,
|
||||
} from "./simple-options.js";
|
||||
import { transformMessagesWithReport } from "./transform-messages.js";
|
||||
|
||||
/** Options accepted by the Bedrock ConverseStream driver, on top of StreamOptions. */
export interface BedrockOptions extends StreamOptions {
	/* AWS region; when unset, env vars and then the SDK's own resolution apply. */
	region?: string;
	/* Named AWS credentials/config profile passed to the Bedrock client. */
	profile?: string;
	/* Tool choice forwarded to the Converse API: auto/any/none or a specific tool. */
	toolChoice?: "auto" | "any" | "none" | { type: "tool"; name: string };
	/* See https://docs.aws.amazon.com/bedrock/latest/userguide/inference-reasoning.html for supported models. */
	reasoning?: RequestedThinkingLevel;
	/* Custom token budgets per thinking level. Overrides default budgets. */
	thinkingBudgets?: ThinkingBudgets;
	/* Only supported by Claude 4.x models, see https://docs.aws.amazon.com/bedrock/latest/userguide/claude-messages-extended-thinking.html#claude-messages-extended-thinking-tool-use-interleaved */
	interleavedThinking?: boolean;
}
|
||||
|
||||
// A content block under construction while streaming. `index` presumably
// mirrors Bedrock's contentBlockIndex and `partialJson` buffers incremental
// tool-input JSON deltas (fed to parseStreamingJson) — confirm against the
// stream-event handling below.
type Block = (TextContent | ThinkingContent | ToolCall) & {
	index?: number;
	partialJson?: string;
};
|
||||
|
||||
export const streamBedrock: StreamFunction<
|
||||
"bedrock-converse-stream",
|
||||
BedrockOptions
|
||||
> = (
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
context: Context,
|
||||
options: BedrockOptions = {},
|
||||
): AssistantMessageEventStream => {
|
||||
const stream = new AssistantMessageEventStream();
|
||||
|
||||
(async () => {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: "bedrock-converse-stream" as Api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
const blocks = output.content as Block[];
|
||||
|
||||
const config: BedrockRuntimeClientConfig = {
|
||||
profile: options.profile,
|
||||
};
|
||||
|
||||
// in Node.js/Bun environment only
|
||||
if (
|
||||
typeof process !== "undefined" &&
|
||||
(process.versions?.node || process.versions?.bun)
|
||||
) {
|
||||
// Region resolution: explicit option > env vars > SDK default chain.
|
||||
// When AWS_PROFILE is set, we leave region undefined so the SDK can
|
||||
// resolve it from AWS profile config. Otherwise fall back to us-east-1.
|
||||
const explicitRegion =
|
||||
options.region ||
|
||||
process.env.AWS_REGION ||
|
||||
process.env.AWS_DEFAULT_REGION;
|
||||
if (explicitRegion) {
|
||||
config.region = explicitRegion;
|
||||
} else if (!process.env.AWS_PROFILE) {
|
||||
config.region = "us-east-1";
|
||||
}
|
||||
|
||||
// Support proxies that don't need authentication
|
||||
if (process.env.AWS_BEDROCK_SKIP_AUTH === "1") {
|
||||
config.credentials = {
|
||||
accessKeyId: "dummy-access-key",
|
||||
secretAccessKey: "dummy-secret-key",
|
||||
};
|
||||
}
|
||||
|
||||
if (
|
||||
process.env.HTTP_PROXY ||
|
||||
process.env.HTTPS_PROXY ||
|
||||
process.env.NO_PROXY ||
|
||||
process.env.http_proxy ||
|
||||
process.env.https_proxy ||
|
||||
process.env.no_proxy
|
||||
) {
|
||||
const nodeHttpHandler = await import("@smithy/node-http-handler");
|
||||
const proxyAgent = await import("proxy-agent");
|
||||
|
||||
const agent = new proxyAgent.ProxyAgent();
|
||||
|
||||
// Bedrock runtime uses NodeHttp2Handler by default since v3.798.0, which is based
|
||||
// on `http2` module and has no support for http agent.
|
||||
// Use NodeHttpHandler to support http agent.
|
||||
config.requestHandler = new nodeHttpHandler.NodeHttpHandler({
|
||||
httpAgent: agent,
|
||||
httpsAgent: agent,
|
||||
});
|
||||
} else if (process.env.AWS_BEDROCK_FORCE_HTTP1 === "1") {
|
||||
// Some custom endpoints require HTTP/1.1 instead of HTTP/2
|
||||
const nodeHttpHandler = await import("@smithy/node-http-handler");
|
||||
config.requestHandler = new nodeHttpHandler.NodeHttpHandler();
|
||||
}
|
||||
} else {
|
||||
// Non-Node environment (browser): fall back to us-east-1 since
|
||||
// there's no config file resolution available.
|
||||
config.region = options.region || "us-east-1";
|
||||
}
|
||||
|
||||
try {
|
||||
const client = new BedrockRuntimeClient(config);
|
||||
|
||||
const cacheRetention = resolveCacheRetention(options.cacheRetention);
|
||||
let commandInput = {
|
||||
modelId: model.id,
|
||||
messages: convertMessages(context, model, cacheRetention),
|
||||
system: buildSystemPrompt(context.systemPrompt, model, cacheRetention),
|
||||
inferenceConfig: {
|
||||
maxTokens: options.maxTokens,
|
||||
temperature: options.temperature,
|
||||
},
|
||||
toolConfig: convertToolConfig(
|
||||
context.tools,
|
||||
options.toolChoice,
|
||||
model,
|
||||
cacheRetention,
|
||||
),
|
||||
additionalModelRequestFields: buildAdditionalModelRequestFields(
|
||||
model,
|
||||
options,
|
||||
),
|
||||
};
|
||||
const nextCommandInput = await options?.onPayload?.(commandInput, model);
|
||||
if (nextCommandInput !== undefined) {
|
||||
commandInput = nextCommandInput as typeof commandInput;
|
||||
}
|
||||
const command = new ConverseStreamCommand(commandInput);
|
||||
|
||||
const response = await client.send(command, {
|
||||
abortSignal: options.signal,
|
||||
});
|
||||
|
||||
for await (const item of response.stream!) {
|
||||
if (item.messageStart) {
|
||||
if (item.messageStart.role !== ConversationRole.ASSISTANT) {
|
||||
throw new Error(
|
||||
"Unexpected assistant message start but got user message start instead",
|
||||
);
|
||||
}
|
||||
stream.push({ type: "start", partial: output });
|
||||
} else if (item.contentBlockStart) {
|
||||
handleContentBlockStart(
|
||||
item.contentBlockStart,
|
||||
blocks,
|
||||
output,
|
||||
stream,
|
||||
);
|
||||
} else if (item.contentBlockDelta) {
|
||||
handleContentBlockDelta(
|
||||
item.contentBlockDelta,
|
||||
blocks,
|
||||
output,
|
||||
stream,
|
||||
);
|
||||
} else if (item.contentBlockStop) {
|
||||
handleContentBlockStop(item.contentBlockStop, blocks, output, stream);
|
||||
} else if (item.messageStop) {
|
||||
output.stopReason = mapStopReason(item.messageStop.stopReason);
|
||||
} else if (item.metadata) {
|
||||
handleMetadata(item.metadata, model, output);
|
||||
} else if (item.internalServerException) {
|
||||
throw new Error(
|
||||
`Internal server error: ${item.internalServerException.message}`,
|
||||
);
|
||||
} else if (item.modelStreamErrorException) {
|
||||
throw new Error(
|
||||
`Model stream error: ${item.modelStreamErrorException.message}`,
|
||||
);
|
||||
} else if (item.validationException) {
|
||||
throw new Error(
|
||||
`Validation error: ${item.validationException.message}`,
|
||||
);
|
||||
} else if (item.throttlingException) {
|
||||
throw new Error(
|
||||
`Throttling error: ${item.throttlingException.message}`,
|
||||
);
|
||||
} else if (item.serviceUnavailableException) {
|
||||
throw new Error(
|
||||
`Service unavailable: ${item.serviceUnavailableException.message}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if (options.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
if (output.stopReason === "error" || output.stopReason === "aborted") {
|
||||
throw new Error("An unknown error occurred");
|
||||
}
|
||||
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
for (const block of output.content) {
|
||||
delete (block as Block).index;
|
||||
delete (block as Block).partialJson;
|
||||
}
|
||||
output.stopReason = options.signal?.aborted ? "aborted" : "error";
|
||||
output.errorMessage =
|
||||
error instanceof Error ? error.message : JSON.stringify(error);
|
||||
stream.push({ type: "error", reason: output.stopReason, error: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
/**
 * Simple streaming entry point for Bedrock Converse.
 *
 * Maps the generic `SimpleStreamOptions` reasoning settings onto
 * Bedrock-specific options and delegates to `streamBedrock`:
 * - no reasoning requested → stream with reasoning disabled;
 * - adaptive-thinking Claude models → forward the reasoning level (raw
 *   "auto", or the resolved level) plus any custom thinking budgets;
 * - older Claude models → adjust maxTokens and inject an explicit thinking
 *   budget via `adjustMaxTokensForThinking`;
 * - non-Claude models → forward the resolved level and budgets unchanged.
 */
export const streamSimpleBedrock: StreamFunction<
  "bedrock-converse-stream",
  SimpleStreamOptions
> = (
  model: Model<"bedrock-converse-stream">,
  context: Context,
  options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
  const base = buildBaseOptions(model, options, undefined);
  // No reasoning requested: disable thinking entirely.
  if (!options?.reasoning) {
    return streamBedrock(model, context, {
      ...base,
      reasoning: undefined,
    } satisfies BedrockOptions);
  }

  const effectiveReasoning = resolveReasoningLevel(model, options.reasoning);

  if (
    model.id.includes("anthropic.claude") ||
    model.id.includes("anthropic/claude")
  ) {
    // Adaptive-thinking Claude with "auto": pass the raw level through
    // rather than the resolved one.
    if (
      supportsAdaptiveThinking(model.id, model.name) &&
      isAutoReasoning(options.reasoning)
    ) {
      return streamBedrock(model, context, {
        ...base,
        reasoning: options.reasoning,
        thinkingBudgets: options.thinkingBudgets,
      } satisfies BedrockOptions);
    }

    // Adaptive-thinking Claude with a fixed level: pass the resolved level.
    if (supportsAdaptiveThinking(model.id, model.name)) {
      return streamBedrock(model, context, {
        ...base,
        reasoning: effectiveReasoning,
        thinkingBudgets: options.thinkingBudgets,
      } satisfies BedrockOptions);
    }

    // Budget-based Claude: adjustMaxTokensForThinking returns both an
    // adjusted maxTokens and a concrete thinking budget for the resolved
    // level, which is spliced into the per-level budget map below.
    const adjusted = adjustMaxTokensForThinking(
      base.maxTokens || 0,
      model.maxTokens,
      effectiveReasoning!,
      options.thinkingBudgets,
    );

    return streamBedrock(model, context, {
      ...base,
      maxTokens: adjusted.maxTokens,
      reasoning: effectiveReasoning,
      thinkingBudgets: {
        ...(options.thinkingBudgets || {}),
        [clampReasoning(effectiveReasoning)!]: adjusted.thinkingBudget,
      },
    } satisfies BedrockOptions);
  }

  // Non-Claude models: forward resolved reasoning and budgets unchanged.
  return streamBedrock(model, context, {
    ...base,
    reasoning: effectiveReasoning,
    thinkingBudgets: options.thinkingBudgets,
  } satisfies BedrockOptions);
};
|
||||
|
||||
function handleContentBlockStart(
|
||||
event: ContentBlockStartEvent,
|
||||
blocks: Block[],
|
||||
output: AssistantMessage,
|
||||
stream: AssistantMessageEventStream,
|
||||
): void {
|
||||
const index = event.contentBlockIndex!;
|
||||
const start = event.start;
|
||||
|
||||
if (start?.toolUse) {
|
||||
const block: Block = {
|
||||
type: "toolCall",
|
||||
id: start.toolUse.toolUseId || "",
|
||||
name: start.toolUse.name || "",
|
||||
arguments: {},
|
||||
partialJson: "",
|
||||
index,
|
||||
};
|
||||
output.content.push(block);
|
||||
stream.push({
|
||||
type: "toolcall_start",
|
||||
contentIndex: blocks.length - 1,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Handle a Bedrock contentBlockDelta event.
 *
 * Routes text, tool-call JSON, and reasoning deltas into the in-progress
 * block whose `index` matches the provider's contentBlockIndex, emitting the
 * corresponding *_start / *_delta stream events. Text and thinking blocks
 * are created lazily here on their first delta, since Bedrock does not send
 * contentBlockStart for them.
 */
function handleContentBlockDelta(
  event: ContentBlockDeltaEvent,
  blocks: Block[],
  output: AssistantMessage,
  stream: AssistantMessageEventStream,
): void {
  const contentBlockIndex = event.contentBlockIndex!;
  const delta = event.delta;
  // Locate the block by the provider-side index stored on it at creation.
  let index = blocks.findIndex((b) => b.index === contentBlockIndex);
  let block = blocks[index];

  if (delta?.text !== undefined) {
    // If no text block exists yet, create one, as `handleContentBlockStart` is not sent for text blocks
    if (!block) {
      const newBlock: Block = {
        type: "text",
        text: "",
        index: contentBlockIndex,
      };
      output.content.push(newBlock);
      index = blocks.length - 1;
      block = blocks[index];
      stream.push({ type: "text_start", contentIndex: index, partial: output });
    }
    if (block.type === "text") {
      block.text += delta.text;
      stream.push({
        type: "text_delta",
        contentIndex: index,
        delta: delta.text,
        partial: output,
      });
    }
  } else if (delta?.toolUse && block?.type === "toolCall") {
    // Accumulate the raw JSON fragments; parseStreamingJson keeps a
    // best-effort parsed view of the arguments while they stream in.
    block.partialJson = (block.partialJson || "") + (delta.toolUse.input || "");
    block.arguments = parseStreamingJson(block.partialJson);
    stream.push({
      type: "toolcall_delta",
      contentIndex: index,
      delta: delta.toolUse.input || "",
      partial: output,
    });
  } else if (delta?.reasoningContent) {
    let thinkingBlock = block;
    let thinkingIndex = index;

    // Thinking blocks are also created lazily on the first reasoning delta.
    if (!thinkingBlock) {
      const newBlock: Block = {
        type: "thinking",
        thinking: "",
        thinkingSignature: "",
        index: contentBlockIndex,
      };
      output.content.push(newBlock);
      thinkingIndex = blocks.length - 1;
      thinkingBlock = blocks[thinkingIndex];
      stream.push({
        type: "thinking_start",
        contentIndex: thinkingIndex,
        partial: output,
      });
    }

    if (thinkingBlock?.type === "thinking") {
      if (delta.reasoningContent.text) {
        thinkingBlock.thinking += delta.reasoningContent.text;
        stream.push({
          type: "thinking_delta",
          contentIndex: thinkingIndex,
          delta: delta.reasoningContent.text,
          partial: output,
        });
      }
      // Signature fragments arrive separately from text; concatenate them
      // without emitting a stream event.
      if (delta.reasoningContent.signature) {
        thinkingBlock.thinkingSignature =
          (thinkingBlock.thinkingSignature || "") +
          delta.reasoningContent.signature;
      }
    }
  }
}
|
||||
|
||||
function handleMetadata(
|
||||
event: ConverseStreamMetadataEvent,
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
output: AssistantMessage,
|
||||
): void {
|
||||
if (event.usage) {
|
||||
output.usage.input = event.usage.inputTokens || 0;
|
||||
output.usage.output = event.usage.outputTokens || 0;
|
||||
output.usage.cacheRead = event.usage.cacheReadInputTokens || 0;
|
||||
output.usage.cacheWrite = event.usage.cacheWriteInputTokens || 0;
|
||||
output.usage.totalTokens =
|
||||
event.usage.totalTokens || output.usage.input + output.usage.output;
|
||||
calculateCost(model, output.usage);
|
||||
}
|
||||
}
|
||||
|
||||
function handleContentBlockStop(
|
||||
event: ContentBlockStopEvent,
|
||||
blocks: Block[],
|
||||
output: AssistantMessage,
|
||||
stream: AssistantMessageEventStream,
|
||||
): void {
|
||||
const index = blocks.findIndex((b) => b.index === event.contentBlockIndex);
|
||||
const block = blocks[index];
|
||||
if (!block) return;
|
||||
delete (block as Block).index;
|
||||
|
||||
switch (block.type) {
|
||||
case "text":
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: index,
|
||||
content: block.text,
|
||||
partial: output,
|
||||
});
|
||||
break;
|
||||
case "thinking":
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: index,
|
||||
content: block.thinking,
|
||||
partial: output,
|
||||
});
|
||||
break;
|
||||
case "toolCall":
|
||||
block.arguments = parseStreamingJson(block.partialJson);
|
||||
delete (block as Block).partialJson;
|
||||
stream.push({
|
||||
type: "toolcall_end",
|
||||
contentIndex: index,
|
||||
toolCall: block,
|
||||
partial: output,
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks both model ID and model name to support application inference profiles
|
||||
* whose ARNs don't contain the model name.
|
||||
*/
|
||||
function getModelMatchCandidates(
|
||||
modelId: string,
|
||||
modelName?: string,
|
||||
): string[] {
|
||||
const values = modelName ? [modelId, modelName] : [modelId];
|
||||
return values.flatMap((value) => {
|
||||
const lower = value.toLowerCase();
|
||||
return [lower, lower.replace(/[\s_.:]+/g, "-")];
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the model supports adaptive thinking (Opus 4.6/4.7, Sonnet 4.6/4.7, Haiku 4.5).
|
||||
* @internal exported for testing only
|
||||
*/
|
||||
export function supportsAdaptiveThinking(
|
||||
modelId: string,
|
||||
modelName?: string,
|
||||
): boolean {
|
||||
const candidates = getModelMatchCandidates(modelId, modelName);
|
||||
return candidates.some(
|
||||
(s) =>
|
||||
s.includes("opus-4-6") ||
|
||||
s.includes("opus-4-7") ||
|
||||
s.includes("sonnet-4-6") ||
|
||||
s.includes("sonnet-4-7") ||
|
||||
s.includes("haiku-4-5"),
|
||||
);
|
||||
}
|
||||
|
||||
/** @internal exported for testing only */
|
||||
export function mapThinkingLevelToEffort(
|
||||
level: SimpleStreamOptions["reasoning"],
|
||||
modelId: string,
|
||||
modelName?: string,
|
||||
): "low" | "medium" | "high" | "xhigh" | "max" {
|
||||
const candidates = getModelMatchCandidates(modelId, modelName);
|
||||
switch (level) {
|
||||
case "auto":
|
||||
return "medium";
|
||||
case "minimal":
|
||||
case "low":
|
||||
return "low";
|
||||
case "medium":
|
||||
return "medium";
|
||||
case "high":
|
||||
return "high";
|
||||
case "xhigh":
|
||||
if (candidates.some((s) => s.includes("opus-4-7"))) return "xhigh";
|
||||
if (candidates.some((s) => s.includes("opus-4-6"))) return "max";
|
||||
return "high";
|
||||
default:
|
||||
return "high";
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve cache retention preference.
|
||||
* Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility.
|
||||
*/
|
||||
function resolveCacheRetention(
|
||||
cacheRetention?: CacheRetention,
|
||||
): CacheRetention {
|
||||
if (cacheRetention) {
|
||||
return cacheRetention;
|
||||
}
|
||||
if (
|
||||
typeof process !== "undefined" &&
|
||||
process.env.PI_CACHE_RETENTION === "long"
|
||||
) {
|
||||
return "long";
|
||||
}
|
||||
return "short";
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the model supports prompt caching.
|
||||
* Supported: Claude 3.5 Haiku, Claude 3.7 Sonnet, Claude 4.x models
|
||||
*/
|
||||
function supportsPromptCaching(
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
): boolean {
|
||||
if (model.cost.cacheRead || model.cost.cacheWrite) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const candidates = getModelMatchCandidates(model.id, model.name);
|
||||
const hasClaudeRef = candidates.some((s) => s.includes("claude"));
|
||||
if (!hasClaudeRef) {
|
||||
return false;
|
||||
}
|
||||
// Claude 4.x models (opus-4, sonnet-4, haiku-4)
|
||||
if (candidates.some((s) => s.includes("-4-"))) return true;
|
||||
// Claude 3.7 Sonnet
|
||||
if (candidates.some((s) => s.includes("claude-3-7-sonnet"))) return true;
|
||||
// Claude 3.5 Haiku
|
||||
if (candidates.some((s) => s.includes("claude-3-5-haiku"))) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the model supports thinking signatures in reasoningContent.
|
||||
* Only Anthropic Claude models support the signature field.
|
||||
* Other models (OpenAI, Qwen, Minimax, Moonshot, etc.) reject it with:
|
||||
* "This model doesn't support the reasoningContent.reasoningText.signature field"
|
||||
*/
|
||||
function supportsThinkingSignature(
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
): boolean {
|
||||
const id = model.id.toLowerCase();
|
||||
return id.includes("anthropic.claude") || id.includes("anthropic/claude");
|
||||
}
|
||||
|
||||
function buildSystemPrompt(
|
||||
systemPrompt: string | undefined,
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
cacheRetention: CacheRetention,
|
||||
): SystemContentBlock[] | undefined {
|
||||
if (!systemPrompt) return undefined;
|
||||
|
||||
const blocks: SystemContentBlock[] = [
|
||||
{ text: sanitizeSurrogates(systemPrompt) },
|
||||
];
|
||||
|
||||
// Add cache point for supported Claude models when caching is enabled
|
||||
if (cacheRetention !== "none" && supportsPromptCaching(model)) {
|
||||
blocks.push({
|
||||
cachePoint: {
|
||||
type: CachePointType.DEFAULT,
|
||||
...(cacheRetention === "long" ? { ttl: CacheTTL.ONE_HOUR } : {}),
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
return blocks;
|
||||
}
|
||||
|
||||
function normalizeToolCallId(id: string): string {
|
||||
const sanitized = id.replace(/[^a-zA-Z0-9_-]/g, "_");
|
||||
return sanitized.length > 64 ? sanitized.slice(0, 64) : sanitized;
|
||||
}
|
||||
|
||||
/**
 * Convert the generic message history into Bedrock Converse `Message`s.
 *
 * Handles:
 * - user messages (text and image parts),
 * - assistant messages (text, tool calls, thinking) — empty parts and
 *   empty messages are dropped because Bedrock rejects empty content,
 * - tool results — consecutive results are merged into a single user
 *   message, as Bedrock requires,
 * and finally appends a cache point to a trailing user message when prompt
 * caching is enabled and supported by the model.
 */
function convertMessages(
  context: Context,
  model: Model<"bedrock-converse-stream">,
  cacheRetention: CacheRetention,
): Message[] {
  const result: Message[] = [];
  // Normalize tool-call IDs and apply provider-specific message transforms
  // before conversion.
  const transformedMessages = transformMessagesWithReport(
    context.messages,
    model,
    normalizeToolCallId,
    "bedrock-converse-stream",
  );

  for (let i = 0; i < transformedMessages.length; i++) {
    const m = transformedMessages[i];

    switch (m.role) {
      case "user":
        result.push({
          role: ConversationRole.USER,
          content:
            typeof m.content === "string"
              ? [{ text: sanitizeSurrogates(m.content) }]
              : m.content.map((c) => {
                  switch (c.type) {
                    case "text":
                      return { text: sanitizeSurrogates(c.text) };
                    case "image":
                      return { image: createImageBlock(c.mimeType, c.data) };
                    default:
                      throw new Error("Unknown user content type");
                  }
                }),
        });
        break;
      case "assistant": {
        // Skip assistant messages with empty content (e.g., from aborted requests)
        // Bedrock rejects messages with empty content arrays
        if (m.content.length === 0) {
          continue;
        }
        const contentBlocks: ContentBlock[] = [];
        for (const c of m.content) {
          switch (c.type) {
            case "text":
              // Skip empty text blocks
              if (c.text.trim().length === 0) continue;
              contentBlocks.push({ text: sanitizeSurrogates(c.text) });
              break;
            case "toolCall":
              contentBlocks.push({
                toolUse: { toolUseId: c.id, name: c.name, input: c.arguments },
              });
              break;
            case "thinking":
              // Skip empty thinking blocks
              if (c.thinking.trim().length === 0) continue;
              // Only Anthropic models support the signature field in reasoningText.
              // For other models, we omit the signature to avoid errors like:
              // "This model doesn't support the reasoningContent.reasoningText.signature field"
              if (supportsThinkingSignature(model)) {
                contentBlocks.push({
                  reasoningContent: {
                    reasoningText: {
                      text: sanitizeSurrogates(c.thinking),
                      signature: c.thinkingSignature,
                    },
                  },
                });
              } else {
                contentBlocks.push({
                  reasoningContent: {
                    reasoningText: { text: sanitizeSurrogates(c.thinking) },
                  },
                });
              }
              break;
            default:
              throw new Error("Unknown assistant content type");
          }
        }
        // Skip if all content blocks were filtered out
        if (contentBlocks.length === 0) {
          continue;
        }
        result.push({
          role: ConversationRole.ASSISTANT,
          content: contentBlocks,
        });
        break;
      }
      case "toolResult": {
        // Collect all consecutive toolResult messages into a single user message
        // Bedrock requires all tool results to be in one message
        const toolResults: ContentBlock.ToolResultMember[] = [];

        // Add current tool result with all content blocks combined
        toolResults.push({
          toolResult: {
            toolUseId: m.toolCallId,
            content: m.content.map((c) =>
              c.type === "image"
                ? { image: createImageBlock(c.mimeType, c.data) }
                : { text: sanitizeSurrogates(c.text) },
            ),
            status: m.isError
              ? ToolResultStatus.ERROR
              : ToolResultStatus.SUCCESS,
          },
        });

        // Look ahead for consecutive toolResult messages
        let j = i + 1;
        while (
          j < transformedMessages.length &&
          transformedMessages[j].role === "toolResult"
        ) {
          const nextMsg = transformedMessages[j] as ToolResultMessage;
          toolResults.push({
            toolResult: {
              toolUseId: nextMsg.toolCallId,
              content: nextMsg.content.map((c) =>
                c.type === "image"
                  ? { image: createImageBlock(c.mimeType, c.data) }
                  : { text: sanitizeSurrogates(c.text) },
              ),
              status: nextMsg.isError
                ? ToolResultStatus.ERROR
                : ToolResultStatus.SUCCESS,
            },
          });
          j++;
        }

        // Skip the messages we've already processed
        i = j - 1;

        result.push({
          role: ConversationRole.USER,
          content: toolResults,
        });
        break;
      }
      default:
        throw new Error("Unknown message role");
    }
  }

  // Add cache point to the last user message for supported Claude models when caching is enabled
  if (
    cacheRetention !== "none" &&
    supportsPromptCaching(model) &&
    result.length > 0
  ) {
    const lastMessage = result[result.length - 1];
    if (lastMessage.role === ConversationRole.USER && lastMessage.content) {
      (lastMessage.content as ContentBlock[]).push({
        cachePoint: {
          type: CachePointType.DEFAULT,
          ...(cacheRetention === "long" ? { ttl: CacheTTL.ONE_HOUR } : {}),
        },
      });
    }
  }

  return result;
}
|
||||
|
||||
function convertToolConfig(
|
||||
tools: Tool[] | undefined,
|
||||
toolChoice: BedrockOptions["toolChoice"],
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
cacheRetention: CacheRetention,
|
||||
): ToolConfiguration | undefined {
|
||||
if (!tools?.length || toolChoice === "none") return undefined;
|
||||
|
||||
const bedrockTools: BedrockTool[] = tools.map((tool) => ({
|
||||
toolSpec: {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
inputSchema: { json: tool.parameters },
|
||||
},
|
||||
}));
|
||||
|
||||
// Add cachePoint after last tool for supported models
|
||||
if (cacheRetention !== "none" && supportsPromptCaching(model)) {
|
||||
bedrockTools.push({
|
||||
cachePoint: {
|
||||
type: CachePointType.DEFAULT,
|
||||
...(cacheRetention === "long" ? { ttl: CacheTTL.ONE_HOUR } : {}),
|
||||
},
|
||||
} as any);
|
||||
}
|
||||
|
||||
let bedrockToolChoice: ToolChoice | undefined;
|
||||
switch (toolChoice) {
|
||||
case "auto":
|
||||
bedrockToolChoice = { auto: {} };
|
||||
break;
|
||||
case "any":
|
||||
bedrockToolChoice = { any: {} };
|
||||
break;
|
||||
default:
|
||||
if (toolChoice?.type === "tool") {
|
||||
bedrockToolChoice = { tool: { name: toolChoice.name } };
|
||||
}
|
||||
}
|
||||
|
||||
return { tools: bedrockTools, toolChoice: bedrockToolChoice };
|
||||
}
|
||||
|
||||
function mapStopReason(reason: string | undefined): StopReason {
|
||||
switch (reason) {
|
||||
case BedrockStopReason.END_TURN:
|
||||
case BedrockStopReason.STOP_SEQUENCE:
|
||||
return "stop";
|
||||
case BedrockStopReason.MAX_TOKENS:
|
||||
case BedrockStopReason.MODEL_CONTEXT_WINDOW_EXCEEDED:
|
||||
return "length";
|
||||
case BedrockStopReason.TOOL_USE:
|
||||
return "toolUse";
|
||||
default:
|
||||
return "error";
|
||||
}
|
||||
}
|
||||
|
||||
/** @internal exported for testing only */
|
||||
export function buildAdditionalModelRequestFields(
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
options: BedrockOptions,
|
||||
): Record<string, any> | undefined {
|
||||
if (!options.reasoning || !model.reasoning) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
if (
|
||||
model.id.includes("anthropic.claude") ||
|
||||
model.id.includes("anthropic/claude")
|
||||
) {
|
||||
const result: Record<string, any> = supportsAdaptiveThinking(
|
||||
model.id,
|
||||
model.name,
|
||||
)
|
||||
? options.reasoning === "auto"
|
||||
? {
|
||||
thinking: { type: "adaptive" },
|
||||
}
|
||||
: {
|
||||
thinking: { type: "adaptive" },
|
||||
output_config: {
|
||||
effort: mapThinkingLevelToEffort(
|
||||
options.reasoning,
|
||||
model.id,
|
||||
model.name,
|
||||
),
|
||||
},
|
||||
}
|
||||
: (() => {
|
||||
const defaultBudgets: Record<ThinkingLevel, number> = {
|
||||
minimal: 1024,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: 16384,
|
||||
xhigh: 16384, // Claude doesn't support xhigh, clamp to high
|
||||
};
|
||||
|
||||
// Custom budgets override defaults (xhigh not in ThinkingBudgets, use high)
|
||||
const normalizedReasoning =
|
||||
options.reasoning === "auto" ? "medium" : options.reasoning;
|
||||
const level =
|
||||
normalizedReasoning === "xhigh" ? "high" : normalizedReasoning;
|
||||
const budget =
|
||||
options.thinkingBudgets?.[level] ??
|
||||
defaultBudgets[normalizedReasoning];
|
||||
|
||||
return {
|
||||
thinking: {
|
||||
type: "enabled",
|
||||
budget_tokens: budget,
|
||||
},
|
||||
};
|
||||
})();
|
||||
|
||||
if (
|
||||
!supportsAdaptiveThinking(model.id, model.name) &&
|
||||
(options.interleavedThinking ?? true)
|
||||
) {
|
||||
result.anthropic_beta = ["interleaved-thinking-2025-05-14"];
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
function createImageBlock(mimeType: string, data: string) {
|
||||
let format: ImageFormat;
|
||||
switch (mimeType) {
|
||||
case "image/jpeg":
|
||||
case "image/jpg":
|
||||
format = ImageFormat.JPEG;
|
||||
break;
|
||||
case "image/png":
|
||||
format = ImageFormat.PNG;
|
||||
break;
|
||||
case "image/gif":
|
||||
format = ImageFormat.GIF;
|
||||
break;
|
||||
case "image/webp":
|
||||
format = ImageFormat.WEBP;
|
||||
break;
|
||||
default:
|
||||
throw new Error(`Unknown image type: ${mimeType}`);
|
||||
}
|
||||
|
||||
const binaryString = atob(data);
|
||||
const bytes = new Uint8Array(binaryString.length);
|
||||
for (let i = 0; i < binaryString.length; i++) {
|
||||
bytes[i] = binaryString.charCodeAt(i);
|
||||
}
|
||||
|
||||
return { source: { bytes }, format };
|
||||
}
|
||||
|
|
@ -1,107 +0,0 @@
|
|||
// Regression tests for the Anthropic provider.
// - #3783: Bearer-token-only Anthropic-compatible providers must authenticate
//   via authToken instead of the x-api-key header.
// - #4140: ANTHROPIC_BASE_URL must override model.baseUrl.
// Some cases assert against the provider's source text because the relevant
// behavior is internal to createClient and not directly observable.
import assert from "node:assert/strict";
import { readFileSync } from "node:fs";
import { dirname, join } from "node:path";
import { fileURLToPath } from "node:url";
import { test } from "vitest";

import {
  resolveAnthropicBaseUrl,
  usesAnthropicBearerAuth,
} from "./anthropic.js";

const __dirname = dirname(fileURLToPath(import.meta.url));

test("usesAnthropicBearerAuth covers Bearer-only Anthropic-compatible providers (#3783)", () => {
  assert.equal(usesAnthropicBearerAuth("alibaba-coding-plan"), true);
  assert.equal(usesAnthropicBearerAuth("minimax"), true);
  assert.equal(usesAnthropicBearerAuth("minimax-cn"), true);
  assert.equal(usesAnthropicBearerAuth("longcat"), true);
  assert.equal(usesAnthropicBearerAuth("xiaomi"), true);
  assert.equal(usesAnthropicBearerAuth("anthropic"), false);
});

test("createClient routes Bearer-auth providers through authToken (#3783)", () => {
  const source = readFileSync(
    join(__dirname, "..", "..", "src", "providers", "anthropic.ts"),
    "utf-8",
  );
  assert.ok(
    source.includes(
      "const usesBearerAuth = usesAnthropicBearerAuth(model.provider);",
    ),
    "createClient should derive auth mode from usesAnthropicBearerAuth",
  );
  assert.ok(
    source.includes("apiKey: usesBearerAuth ? null : apiKey"),
    "Bearer-auth providers should skip x-api-key auth",
  );
  assert.ok(
    source.includes("authToken: usesBearerAuth ? apiKey : undefined"),
    "Bearer-auth providers should send authToken instead",
  );
});

// Minimal model stub — only the field resolveAnthropicBaseUrl cares about.
const stubModel = { baseUrl: "https://api.anthropic.com" } as Parameters<
  typeof resolveAnthropicBaseUrl
>[0];

test("resolveAnthropicBaseUrl returns model.baseUrl when ANTHROPIC_BASE_URL is unset (#4140)", () => {
  // Save and restore the env var so tests don't leak state into each other.
  const saved = process.env.ANTHROPIC_BASE_URL;
  try {
    delete process.env.ANTHROPIC_BASE_URL;
    assert.equal(
      resolveAnthropicBaseUrl(stubModel),
      "https://api.anthropic.com",
    );
  } finally {
    if (saved === undefined) delete process.env.ANTHROPIC_BASE_URL;
    else process.env.ANTHROPIC_BASE_URL = saved;
  }
});

test("resolveAnthropicBaseUrl prefers ANTHROPIC_BASE_URL over model.baseUrl (#4140)", () => {
  const saved = process.env.ANTHROPIC_BASE_URL;
  try {
    process.env.ANTHROPIC_BASE_URL = "https://proxy.example.com";
    assert.equal(
      resolveAnthropicBaseUrl(stubModel),
      "https://proxy.example.com",
    );
  } finally {
    if (saved === undefined) delete process.env.ANTHROPIC_BASE_URL;
    else process.env.ANTHROPIC_BASE_URL = saved;
  }
});

test("resolveAnthropicBaseUrl ignores whitespace-only ANTHROPIC_BASE_URL (#4140)", () => {
  const saved = process.env.ANTHROPIC_BASE_URL;
  try {
    process.env.ANTHROPIC_BASE_URL = " ";
    assert.equal(
      resolveAnthropicBaseUrl(stubModel),
      "https://api.anthropic.com",
    );
  } finally {
    if (saved === undefined) delete process.env.ANTHROPIC_BASE_URL;
    else process.env.ANTHROPIC_BASE_URL = saved;
  }
});

test("createClient uses resolveAnthropicBaseUrl for all auth paths (#4140)", () => {
  const source = readFileSync(
    join(__dirname, "..", "..", "src", "providers", "anthropic.ts"),
    "utf-8",
  );
  const directUsages = (source.match(/baseURL:\s*model\.baseUrl/g) ?? [])
    .length;
  assert.equal(
    directUsages,
    0,
    "createClient must not use model.baseUrl directly — use resolveAnthropicBaseUrl(model)",
  );
  assert.ok(
    source.includes("baseURL: resolveAnthropicBaseUrl(model)"),
    "all createClient branches should pass baseURL through resolveAnthropicBaseUrl",
  );
});
|
||||
|
|
@ -1,86 +0,0 @@
|
|||
// Tests for anthropic-shared helpers: cache_control placement on converted
// tools (the cache point must land on the last tool only) and Anthropic
// stop-reason mapping, including the pause_turn regression.
import assert from "node:assert/strict";
import { describe, it } from "vitest";
import { convertTools, mapStopReason } from "./anthropic-shared.js";

// Build a minimal Tool-shaped object; cast to any since only the fields
// convertTools reads are populated.
const makeTool = (name: string) =>
  ({
    name,
    description: `desc for ${name}`,
    parameters: {
      type: "object" as const,
      properties: { arg: { type: "string" } },
      required: ["arg"],
    },
  }) as any;

describe("convertTools cache_control", () => {
  it("adds cache_control to the last tool when cacheControl is provided", () => {
    const tools = [makeTool("Read"), makeTool("Write"), makeTool("Edit")];
    const cacheControl = { type: "ephemeral" as const };
    const result = convertTools(tools, false, cacheControl);

    assert.equal(result.length, 3);
    assert.equal((result[0] as any).cache_control, undefined);
    assert.equal((result[1] as any).cache_control, undefined);
    assert.deepEqual((result[2] as any).cache_control, { type: "ephemeral" });
  });

  it("does not add cache_control when cacheControl is undefined", () => {
    const tools = [makeTool("Read"), makeTool("Write")];
    const result = convertTools(tools, false);

    for (const tool of result) {
      assert.equal((tool as any).cache_control, undefined);
    }
  });

  it("handles empty tools array without error", () => {
    const result = convertTools([], false, { type: "ephemeral" });
    assert.equal(result.length, 0);
  });

  it("passes through ttl when provided", () => {
    const tools = [makeTool("Read")];
    const cacheControl = { type: "ephemeral" as const, ttl: "1h" as const };
    const result = convertTools(tools, false, cacheControl);

    assert.deepEqual((result[0] as any).cache_control, {
      type: "ephemeral",
      ttl: "1h",
    });
  });

  it("single tool gets cache_control", () => {
    const tools = [makeTool("Read")];
    const result = convertTools(tools, false, { type: "ephemeral" });

    assert.equal(result.length, 1);
    assert.deepEqual((result[0] as any).cache_control, { type: "ephemeral" });
  });
});

describe("mapStopReason", () => {
  it("maps end_turn to stop", () => {
    assert.equal(mapStopReason("end_turn"), "stop");
  });

  it("maps max_tokens to length", () => {
    assert.equal(mapStopReason("max_tokens"), "length");
  });

  it("maps tool_use to toolUse", () => {
    assert.equal(mapStopReason("tool_use"), "toolUse");
  });

  it("maps pause_turn to pauseTurn (not stop)", () => {
    // pause_turn means the server paused a long-running turn (e.g. native
    // web search hit its iteration limit). Mapping it to "stop" causes the
    // agent loop to exit, leaving an incomplete server_tool_use block in
    // history which triggers a 400 on the next request.
    assert.equal(mapStopReason("pause_turn"), "pauseTurn");
  });

  it("throws on unknown stop reason", () => {
    assert.throws(() => mapStopReason("bogus"), /Unhandled stop reason/);
  });
});
|
||||
|
|
@ -1,937 +0,0 @@
|
|||
/**
|
||||
* Shared utilities for Anthropic providers (direct API and Vertex AI).
|
||||
*/
|
||||
import type Anthropic from "@anthropic-ai/sdk";
|
||||
import type {
|
||||
CacheControlEphemeral,
|
||||
ContentBlockParam,
|
||||
MessageCreateParamsStreaming,
|
||||
MessageParam,
|
||||
RawMessageStreamEvent,
|
||||
ServerToolUseBlockParam,
|
||||
WebSearchToolResultBlockParam,
|
||||
} from "@anthropic-ai/sdk/resources/messages.js";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
CacheRetention,
|
||||
Context,
|
||||
ImageContent,
|
||||
Message,
|
||||
Model,
|
||||
ServerToolUseContent,
|
||||
StopReason,
|
||||
StreamOptions,
|
||||
TextContent,
|
||||
ThinkingContent,
|
||||
Tool,
|
||||
ToolCall,
|
||||
ToolResultMessage,
|
||||
WebSearchResultContent,
|
||||
} from "../types.js";
|
||||
|
||||
/** API types that use the Anthropic Messages protocol */
|
||||
export type AnthropicApi = "anthropic-messages" | "anthropic-vertex";
|
||||
|
||||
import type { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { parseAnthropicSSE } from "../utils/event-stream.js";
|
||||
import { parseStreamingJson } from "../utils/json-parse.js";
|
||||
import {
|
||||
hasXmlParameterTags,
|
||||
repairToolJson,
|
||||
} from "../utils/repair-tool-json.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import { transformMessagesWithReport } from "./transform-messages.js";
|
||||
|
||||
export type AnthropicEffort = "low" | "medium" | "high" | "max";
|
||||
|
||||
export interface AnthropicOptions extends StreamOptions {
|
||||
thinkingEnabled?: boolean;
|
||||
thinkingBudgetTokens?: number;
|
||||
effort?: AnthropicEffort;
|
||||
interleavedThinking?: boolean;
|
||||
toolChoice?: "auto" | "any" | "none" | { type: "tool"; name: string };
|
||||
}
|
||||
|
||||
const claudeCodeTools = [
|
||||
"Read",
|
||||
"Write",
|
||||
"Edit",
|
||||
"Bash",
|
||||
"Grep",
|
||||
"Glob",
|
||||
"AskUserQuestion",
|
||||
"EnterPlanMode",
|
||||
"ExitPlanMode",
|
||||
"KillShell",
|
||||
"NotebookEdit",
|
||||
"Skill",
|
||||
"Task",
|
||||
"TaskOutput",
|
||||
"TodoWrite",
|
||||
"WebFetch",
|
||||
"WebSearch",
|
||||
];
|
||||
|
||||
const ccToolLookup = new Map(claudeCodeTools.map((t) => [t.toLowerCase(), t]));
|
||||
|
||||
export const toClaudeCodeName = (name: string) =>
|
||||
ccToolLookup.get(name.toLowerCase()) ?? name;
|
||||
export const fromClaudeCodeName = (name: string, tools?: Tool[]) => {
|
||||
if (tools && tools.length > 0) {
|
||||
const lowerName = name.toLowerCase();
|
||||
const matchedTool = tools.find(
|
||||
(tool) => tool.name.toLowerCase() === lowerName,
|
||||
);
|
||||
if (matchedTool) return matchedTool.name;
|
||||
}
|
||||
return name;
|
||||
};
|
||||
|
||||
function resolveCacheRetention(
|
||||
cacheRetention?: CacheRetention,
|
||||
): CacheRetention {
|
||||
if (cacheRetention) {
|
||||
return cacheRetention;
|
||||
}
|
||||
if (
|
||||
typeof process !== "undefined" &&
|
||||
process.env.PI_CACHE_RETENTION === "long"
|
||||
) {
|
||||
return "long";
|
||||
}
|
||||
return "short";
|
||||
}
|
||||
|
||||
export function getCacheControl(
|
||||
baseUrl: string,
|
||||
cacheRetention?: CacheRetention,
|
||||
): {
|
||||
retention: CacheRetention;
|
||||
cacheControl?: { type: "ephemeral"; ttl?: "1h" };
|
||||
} {
|
||||
const retention = resolveCacheRetention(cacheRetention);
|
||||
if (retention === "none") {
|
||||
return { retention };
|
||||
}
|
||||
const ttl =
|
||||
retention === "long" && baseUrl.includes("api.anthropic.com")
|
||||
? "1h"
|
||||
: undefined;
|
||||
return {
|
||||
retention,
|
||||
cacheControl: { type: "ephemeral", ...(ttl && { ttl }) },
|
||||
};
|
||||
}
|
||||
|
||||
export function convertContentBlocks(content: (TextContent | ImageContent)[]):
|
||||
| string
|
||||
| Array<
|
||||
| { type: "text"; text: string }
|
||||
| {
|
||||
type: "image";
|
||||
source: {
|
||||
type: "base64";
|
||||
media_type: "image/jpeg" | "image/png" | "image/gif" | "image/webp";
|
||||
data: string;
|
||||
};
|
||||
}
|
||||
> {
|
||||
const hasImages = content.some((c) => c.type === "image");
|
||||
if (!hasImages) {
|
||||
return sanitizeSurrogates(
|
||||
content.map((c) => (c as TextContent).text).join("\n"),
|
||||
);
|
||||
}
|
||||
|
||||
const blocks = content.map((block) => {
|
||||
if (block.type === "text") {
|
||||
return {
|
||||
type: "text" as const,
|
||||
text: sanitizeSurrogates(block.text),
|
||||
};
|
||||
}
|
||||
return {
|
||||
type: "image" as const,
|
||||
source: {
|
||||
type: "base64" as const,
|
||||
media_type: block.mimeType as
|
||||
| "image/jpeg"
|
||||
| "image/png"
|
||||
| "image/gif"
|
||||
| "image/webp",
|
||||
data: block.data,
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
const hasText = blocks.some((b) => b.type === "text");
|
||||
if (!hasText) {
|
||||
blocks.unshift({
|
||||
type: "text" as const,
|
||||
text: "(see attached image)",
|
||||
});
|
||||
}
|
||||
|
||||
return blocks;
|
||||
}
|
||||
|
||||
export function supportsAdaptiveThinking(modelId: string): boolean {
|
||||
return (
|
||||
modelId.includes("opus-4-6") ||
|
||||
modelId.includes("opus-4.6") ||
|
||||
modelId.includes("sonnet-4-6") ||
|
||||
modelId.includes("sonnet-4.6") ||
|
||||
modelId.includes("sonnet-4-7") ||
|
||||
modelId.includes("sonnet-4.7") ||
|
||||
modelId.includes("haiku-4-5") ||
|
||||
modelId.includes("haiku-4.5")
|
||||
);
|
||||
}
|
||||
|
||||
export function mapThinkingLevelToEffort(
|
||||
level: string | undefined,
|
||||
modelId: string,
|
||||
): AnthropicEffort {
|
||||
switch (level) {
|
||||
case "auto":
|
||||
return "medium";
|
||||
case "minimal":
|
||||
return "low";
|
||||
case "low":
|
||||
return "low";
|
||||
case "medium":
|
||||
return "medium";
|
||||
case "high":
|
||||
return "high";
|
||||
case "xhigh":
|
||||
return modelId.includes("opus-4-6") || modelId.includes("opus-4.6")
|
||||
? "max"
|
||||
: "high";
|
||||
default:
|
||||
return "high";
|
||||
}
|
||||
}
|
||||
|
||||
export function isTransientNetworkError(error: unknown): boolean {
|
||||
if (!(error instanceof Error)) return false;
|
||||
const msg = error.message.toLowerCase();
|
||||
const code = (error as NodeJS.ErrnoException).code;
|
||||
return (
|
||||
code === "ECONNRESET" ||
|
||||
code === "EPIPE" ||
|
||||
code === "ETIMEDOUT" ||
|
||||
code === "ENOTFOUND" ||
|
||||
code === "EAI_AGAIN" ||
|
||||
msg.includes("connector_closed") ||
|
||||
msg.includes("socket hang up") ||
|
||||
msg.includes("network") ||
|
||||
(msg.includes("connection") && msg.includes("closed")) ||
|
||||
msg.includes("fetch failed")
|
||||
);
|
||||
}
|
||||
|
||||
export function extractRetryAfterMs(
|
||||
headers: Headers | { get(name: string): string | null },
|
||||
_errorText = "",
|
||||
): number | undefined {
|
||||
const normalizeDelay = (ms: number): number | undefined =>
|
||||
ms > 0 ? Math.ceil(ms + 1000) : undefined;
|
||||
|
||||
const retryAfter = headers.get("retry-after");
|
||||
if (retryAfter) {
|
||||
const seconds = Number(retryAfter);
|
||||
if (Number.isFinite(seconds)) {
|
||||
const delay = normalizeDelay(seconds * 1000);
|
||||
if (delay !== undefined) return delay;
|
||||
}
|
||||
const asDate = new Date(retryAfter).getTime();
|
||||
if (!Number.isNaN(asDate)) {
|
||||
const delay = normalizeDelay(asDate - Date.now());
|
||||
if (delay !== undefined) return delay;
|
||||
}
|
||||
}
|
||||
|
||||
for (const header of [
|
||||
"x-ratelimit-reset-requests",
|
||||
"x-ratelimit-reset-tokens",
|
||||
]) {
|
||||
const value = headers.get(header);
|
||||
if (value) {
|
||||
const resetSeconds = Number(value);
|
||||
if (Number.isFinite(resetSeconds)) {
|
||||
const delay = normalizeDelay(resetSeconds * 1000 - Date.now());
|
||||
if (delay !== undefined) return delay;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
export function normalizeToolCallId(id: string): string {
|
||||
return id.replace(/[^a-zA-Z0-9_-]/g, "_").slice(0, 64);
|
||||
}
|
||||
|
||||
export function convertMessages(
|
||||
messages: Message[],
|
||||
model: Model<AnthropicApi>,
|
||||
isOAuthToken: boolean,
|
||||
cacheControl?: { type: "ephemeral"; ttl?: "1h" },
|
||||
): MessageParam[] {
|
||||
const params: MessageParam[] = [];
|
||||
|
||||
const transformedMessages = transformMessagesWithReport(
|
||||
messages,
|
||||
model,
|
||||
normalizeToolCallId,
|
||||
"anthropic-messages",
|
||||
);
|
||||
|
||||
for (let i = 0; i < transformedMessages.length; i++) {
|
||||
const msg = transformedMessages[i];
|
||||
|
||||
if (msg.role === "user") {
|
||||
if (typeof msg.content === "string") {
|
||||
if (msg.content.trim().length > 0) {
|
||||
params.push({
|
||||
role: "user",
|
||||
content: sanitizeSurrogates(msg.content),
|
||||
});
|
||||
}
|
||||
} else {
|
||||
const blocks: ContentBlockParam[] = msg.content.map((item) => {
|
||||
if (item.type === "text") {
|
||||
return {
|
||||
type: "text",
|
||||
text: sanitizeSurrogates(item.text),
|
||||
};
|
||||
} else {
|
||||
return {
|
||||
type: "image",
|
||||
source: {
|
||||
type: "base64",
|
||||
media_type: item.mimeType as
|
||||
| "image/jpeg"
|
||||
| "image/png"
|
||||
| "image/gif"
|
||||
| "image/webp",
|
||||
data: item.data,
|
||||
},
|
||||
};
|
||||
}
|
||||
});
|
||||
let filteredBlocks = !model?.input.includes("image")
|
||||
? blocks.filter((b) => b.type !== "image")
|
||||
: blocks;
|
||||
filteredBlocks = filteredBlocks.filter((b) => {
|
||||
if (b.type === "text") {
|
||||
return b.text.trim().length > 0;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
if (filteredBlocks.length === 0) continue;
|
||||
params.push({
|
||||
role: "user",
|
||||
content: filteredBlocks,
|
||||
});
|
||||
}
|
||||
} else if (msg.role === "assistant") {
|
||||
const blocks: ContentBlockParam[] = [];
|
||||
|
||||
for (const block of msg.content) {
|
||||
if (block.type === "text") {
|
||||
if (block.text.trim().length === 0) continue;
|
||||
blocks.push({
|
||||
type: "text",
|
||||
text: sanitizeSurrogates(block.text),
|
||||
});
|
||||
} else if (block.type === "thinking") {
|
||||
if (block.redacted) {
|
||||
blocks.push({
|
||||
type: "redacted_thinking",
|
||||
data: block.thinkingSignature!,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
if (block.thinking.trim().length === 0) continue;
|
||||
if (
|
||||
!block.thinkingSignature ||
|
||||
block.thinkingSignature.trim().length === 0
|
||||
) {
|
||||
blocks.push({
|
||||
type: "text",
|
||||
text: sanitizeSurrogates(block.thinking),
|
||||
});
|
||||
} else {
|
||||
blocks.push({
|
||||
type: "thinking",
|
||||
thinking: sanitizeSurrogates(block.thinking),
|
||||
signature: block.thinkingSignature,
|
||||
});
|
||||
}
|
||||
} else if (block.type === "toolCall") {
|
||||
blocks.push({
|
||||
type: "tool_use",
|
||||
id: block.id,
|
||||
name: isOAuthToken ? toClaudeCodeName(block.name) : block.name,
|
||||
input: block.arguments ?? {},
|
||||
});
|
||||
} else if (block.type === "serverToolUse") {
|
||||
blocks.push({
|
||||
type: "server_tool_use",
|
||||
id: block.id,
|
||||
name: block.name as ServerToolUseBlockParam["name"],
|
||||
input: block.input ?? {},
|
||||
} as ServerToolUseBlockParam);
|
||||
} else if (block.type === "webSearchResult") {
|
||||
blocks.push({
|
||||
type: "web_search_tool_result",
|
||||
tool_use_id: block.toolUseId,
|
||||
content: block.content,
|
||||
} as WebSearchToolResultBlockParam);
|
||||
}
|
||||
}
|
||||
if (blocks.length === 0) continue;
|
||||
params.push({
|
||||
role: "assistant",
|
||||
content: blocks,
|
||||
});
|
||||
} else if (msg.role === "toolResult") {
|
||||
const toolResults: ContentBlockParam[] = [];
|
||||
|
||||
toolResults.push({
|
||||
type: "tool_result",
|
||||
tool_use_id: msg.toolCallId,
|
||||
content: convertContentBlocks(msg.content),
|
||||
is_error: msg.isError,
|
||||
});
|
||||
|
||||
let j = i + 1;
|
||||
while (
|
||||
j < transformedMessages.length &&
|
||||
transformedMessages[j].role === "toolResult"
|
||||
) {
|
||||
const nextMsg = transformedMessages[j] as ToolResultMessage;
|
||||
toolResults.push({
|
||||
type: "tool_result",
|
||||
tool_use_id: nextMsg.toolCallId,
|
||||
content: convertContentBlocks(nextMsg.content),
|
||||
is_error: nextMsg.isError,
|
||||
});
|
||||
j++;
|
||||
}
|
||||
|
||||
i = j - 1;
|
||||
|
||||
params.push({
|
||||
role: "user",
|
||||
content: toolResults,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (cacheControl && params.length > 0) {
|
||||
const lastMessage = params[params.length - 1];
|
||||
if (lastMessage.role === "user") {
|
||||
if (Array.isArray(lastMessage.content)) {
|
||||
const lastBlock = lastMessage.content[lastMessage.content.length - 1];
|
||||
if (
|
||||
lastBlock &&
|
||||
(lastBlock.type === "text" ||
|
||||
lastBlock.type === "image" ||
|
||||
lastBlock.type === "tool_result")
|
||||
) {
|
||||
// TextBlockParam, ImageBlockParam, and ToolResultBlockParam all
|
||||
// carry cache_control?: CacheControlEphemeral | null — the type
|
||||
// guard above narrows to exactly those three variants.
|
||||
(
|
||||
lastBlock as { cache_control?: CacheControlEphemeral | null }
|
||||
).cache_control = cacheControl;
|
||||
}
|
||||
} else if (typeof lastMessage.content === "string") {
|
||||
lastMessage.content = [
|
||||
{
|
||||
type: "text",
|
||||
text: lastMessage.content,
|
||||
cache_control: cacheControl,
|
||||
},
|
||||
];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
export function convertTools(
|
||||
tools: Tool[],
|
||||
isOAuthToken: boolean,
|
||||
cacheControl?: { type: "ephemeral"; ttl?: "1h" },
|
||||
): Anthropic.Messages.Tool[] {
|
||||
if (!tools) return [];
|
||||
|
||||
const result: Anthropic.Messages.Tool[] = tools.map((tool) => {
|
||||
// TSchema extends SchemaOptions which carries [prop: string]: any,
|
||||
// so .properties and .required are accessible without a cast.
|
||||
const jsonSchema = tool.parameters;
|
||||
|
||||
return {
|
||||
name: isOAuthToken ? toClaudeCodeName(tool.name) : tool.name,
|
||||
description: tool.description,
|
||||
input_schema: {
|
||||
type: "object" as const,
|
||||
properties: jsonSchema.properties || {},
|
||||
required: (jsonSchema.required as string[] | undefined) || [],
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
// Add cache breakpoint to last tool — covers entire tool block.
|
||||
// Anthropic.Messages.Tool carries cache_control?: CacheControlEphemeral | null.
|
||||
if (cacheControl && result.length > 0) {
|
||||
result[result.length - 1].cache_control = cacheControl;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
export function buildParams(
|
||||
model: Model<AnthropicApi>,
|
||||
context: Context,
|
||||
isOAuthToken: boolean,
|
||||
options?: AnthropicOptions,
|
||||
): MessageCreateParamsStreaming {
|
||||
const { cacheControl } = getCacheControl(
|
||||
model.baseUrl,
|
||||
options?.cacheRetention,
|
||||
);
|
||||
const apiModelId = model.id.replace(/\[.*\]$/, "");
|
||||
const params: MessageCreateParamsStreaming = {
|
||||
model: apiModelId,
|
||||
messages: convertMessages(
|
||||
context.messages,
|
||||
model,
|
||||
isOAuthToken,
|
||||
cacheControl,
|
||||
),
|
||||
max_tokens: options?.maxTokens || (model.maxTokens / 3) | 0,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
if (isOAuthToken) {
|
||||
params.system = [
|
||||
{
|
||||
type: "text",
|
||||
text: "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
...(cacheControl ? { cache_control: cacheControl } : {}),
|
||||
},
|
||||
];
|
||||
if (context.systemPrompt) {
|
||||
params.system.push({
|
||||
type: "text",
|
||||
text: sanitizeSurrogates(context.systemPrompt),
|
||||
...(cacheControl ? { cache_control: cacheControl } : {}),
|
||||
});
|
||||
}
|
||||
} else if (context.systemPrompt) {
|
||||
params.system = [
|
||||
{
|
||||
type: "text",
|
||||
text: sanitizeSurrogates(context.systemPrompt),
|
||||
...(cacheControl ? { cache_control: cacheControl } : {}),
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
if (options?.temperature !== undefined && !options?.thinkingEnabled) {
|
||||
params.temperature = options.temperature;
|
||||
}
|
||||
|
||||
if (context.tools && context.tools.length > 0) {
|
||||
params.tools = convertTools(context.tools, isOAuthToken, cacheControl);
|
||||
}
|
||||
|
||||
if (options?.thinkingEnabled && model.reasoning) {
|
||||
if (supportsAdaptiveThinking(model.id)) {
|
||||
params.thinking = { type: "adaptive" };
|
||||
if (options.effort) {
|
||||
params.output_config = { effort: options.effort };
|
||||
}
|
||||
} else if (model.capabilities?.thinkingNoBudget) {
|
||||
// Provider accepts {"type":"enabled"} without budget_tokens — model manages depth.
|
||||
// The Anthropic SDK type requires budget_tokens but the kimi-coding API does not,
|
||||
// so we bypass the SDK constraint via unknown to avoid falsely promising budget_tokens.
|
||||
(params as unknown as Record<string, unknown>).thinking = {
|
||||
type: "enabled",
|
||||
};
|
||||
} else {
|
||||
params.thinking = {
|
||||
type: "enabled",
|
||||
budget_tokens: options.thinkingBudgetTokens || 1024,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if (options?.metadata) {
|
||||
const userId = options.metadata.user_id;
|
||||
if (typeof userId === "string") {
|
||||
params.metadata = { user_id: userId };
|
||||
}
|
||||
}
|
||||
|
||||
if (options?.toolChoice) {
|
||||
if (typeof options.toolChoice === "string") {
|
||||
params.tool_choice = { type: options.toolChoice };
|
||||
} else {
|
||||
params.tool_choice = options.toolChoice;
|
||||
}
|
||||
}
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
export function mapStopReason(reason: string): StopReason {
|
||||
switch (reason) {
|
||||
case "end_turn":
|
||||
return "stop";
|
||||
case "max_tokens":
|
||||
return "length";
|
||||
case "tool_use":
|
||||
return "toolUse";
|
||||
case "refusal":
|
||||
return "error";
|
||||
case "pause_turn":
|
||||
return "pauseTurn";
|
||||
case "stop_sequence":
|
||||
return "stop";
|
||||
case "sensitive":
|
||||
return "error";
|
||||
default:
|
||||
throw new Error(`Unhandled stop reason: ${reason}`);
|
||||
}
|
||||
}
|
||||
|
||||
export interface StreamAnthropicArgs {
|
||||
client: Anthropic;
|
||||
model: Model<AnthropicApi>;
|
||||
context: Context;
|
||||
isOAuthToken: boolean;
|
||||
options?: AnthropicOptions;
|
||||
AnthropicSdkClass?: typeof Anthropic;
|
||||
}
|
||||
|
||||
export function processAnthropicStream(
|
||||
stream: AssistantMessageEventStream,
|
||||
args: StreamAnthropicArgs,
|
||||
): void {
|
||||
const { client, model, context, isOAuthToken, options, AnthropicSdkClass } =
|
||||
args;
|
||||
|
||||
(async () => {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: model.api as Api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
try {
|
||||
let params = buildParams(model, context, isOAuthToken, options);
|
||||
const nextParams = await options?.onPayload?.(params, model);
|
||||
if (nextParams !== undefined) {
|
||||
params = nextParams as MessageCreateParamsStreaming;
|
||||
}
|
||||
const apiPromise = client.messages.create(
|
||||
{ ...params, stream: true },
|
||||
{ signal: options?.signal },
|
||||
);
|
||||
const response = await apiPromise.asResponse();
|
||||
stream.push({ type: "start", partial: output });
|
||||
|
||||
type Block = (
|
||||
| ThinkingContent
|
||||
| TextContent
|
||||
| (ToolCall & { partialJson: string })
|
||||
| ServerToolUseContent
|
||||
| WebSearchResultContent
|
||||
) & { index: number };
|
||||
const blocks = output.content as Block[];
|
||||
|
||||
for await (const rawEvent of parseAnthropicSSE(
|
||||
response,
|
||||
options?.signal,
|
||||
)) {
|
||||
const event = rawEvent as RawMessageStreamEvent;
|
||||
if (event.type === "message_start") {
|
||||
output.usage.input = event.message.usage.input_tokens || 0;
|
||||
output.usage.output = event.message.usage.output_tokens || 0;
|
||||
output.usage.cacheRead =
|
||||
event.message.usage.cache_read_input_tokens || 0;
|
||||
output.usage.cacheWrite =
|
||||
event.message.usage.cache_creation_input_tokens || 0;
|
||||
output.usage.totalTokens =
|
||||
output.usage.input +
|
||||
output.usage.output +
|
||||
output.usage.cacheRead +
|
||||
output.usage.cacheWrite;
|
||||
calculateCost(model, output.usage);
|
||||
} else if (event.type === "content_block_start") {
|
||||
if (event.content_block.type === "text") {
|
||||
const block: Block = {
|
||||
type: "text",
|
||||
text: "",
|
||||
index: event.index,
|
||||
};
|
||||
output.content.push(block);
|
||||
stream.push({
|
||||
type: "text_start",
|
||||
contentIndex: output.content.length - 1,
|
||||
partial: output,
|
||||
});
|
||||
} else if (event.content_block.type === "thinking") {
|
||||
const block: Block = {
|
||||
type: "thinking",
|
||||
thinking: "",
|
||||
thinkingSignature: "",
|
||||
index: event.index,
|
||||
};
|
||||
output.content.push(block);
|
||||
stream.push({
|
||||
type: "thinking_start",
|
||||
contentIndex: output.content.length - 1,
|
||||
partial: output,
|
||||
});
|
||||
} else if (event.content_block.type === "redacted_thinking") {
|
||||
const block: Block = {
|
||||
type: "thinking",
|
||||
thinking: "[Reasoning redacted]",
|
||||
thinkingSignature: event.content_block.data,
|
||||
redacted: true,
|
||||
index: event.index,
|
||||
};
|
||||
output.content.push(block);
|
||||
stream.push({
|
||||
type: "thinking_start",
|
||||
contentIndex: output.content.length - 1,
|
||||
partial: output,
|
||||
});
|
||||
} else if (event.content_block.type === "tool_use") {
|
||||
const block: Block = {
|
||||
type: "toolCall",
|
||||
id: event.content_block.id,
|
||||
name: isOAuthToken
|
||||
? fromClaudeCodeName(event.content_block.name, context.tools)
|
||||
: event.content_block.name,
|
||||
arguments:
|
||||
(event.content_block.input as Record<string, any>) ?? {},
|
||||
partialJson: "",
|
||||
index: event.index,
|
||||
};
|
||||
output.content.push(block);
|
||||
stream.push({
|
||||
type: "toolcall_start",
|
||||
contentIndex: output.content.length - 1,
|
||||
partial: output,
|
||||
});
|
||||
} else if (event.content_block.type === "server_tool_use") {
|
||||
const serverBlock = event.content_block;
|
||||
const block: Block = {
|
||||
type: "serverToolUse",
|
||||
id: serverBlock.id,
|
||||
name: serverBlock.name,
|
||||
input: serverBlock.input,
|
||||
index: event.index,
|
||||
};
|
||||
output.content.push(block);
|
||||
stream.push({
|
||||
type: "server_tool_use",
|
||||
contentIndex: output.content.length - 1,
|
||||
partial: output,
|
||||
});
|
||||
} else if (event.content_block.type === "web_search_tool_result") {
|
||||
const resultBlock = event.content_block;
|
||||
const block: Block = {
|
||||
type: "webSearchResult",
|
||||
toolUseId: resultBlock.tool_use_id,
|
||||
content: resultBlock.content,
|
||||
index: event.index,
|
||||
};
|
||||
output.content.push(block);
|
||||
stream.push({
|
||||
type: "web_search_result",
|
||||
contentIndex: output.content.length - 1,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
} else if (event.type === "content_block_delta") {
|
||||
if (event.delta.type === "text_delta") {
|
||||
const index = blocks.findIndex((b) => b.index === event.index);
|
||||
const block = blocks[index];
|
||||
if (block && block.type === "text") {
|
||||
block.text += event.delta.text;
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
contentIndex: index,
|
||||
delta: event.delta.text,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
} else if (event.delta.type === "thinking_delta") {
|
||||
const index = blocks.findIndex((b) => b.index === event.index);
|
||||
const block = blocks[index];
|
||||
if (block && block.type === "thinking") {
|
||||
block.thinking += event.delta.thinking;
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
contentIndex: index,
|
||||
delta: event.delta.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
} else if (event.delta.type === "input_json_delta") {
|
||||
const index = blocks.findIndex((b) => b.index === event.index);
|
||||
const block = blocks[index];
|
||||
if (block && block.type === "toolCall") {
|
||||
block.partialJson += event.delta.partial_json;
|
||||
block.arguments = parseStreamingJson(block.partialJson);
|
||||
stream.push({
|
||||
type: "toolcall_delta",
|
||||
contentIndex: index,
|
||||
delta: event.delta.partial_json,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
} else if (event.delta.type === "signature_delta") {
|
||||
const index = blocks.findIndex((b) => b.index === event.index);
|
||||
const block = blocks[index];
|
||||
if (block && block.type === "thinking") {
|
||||
block.thinkingSignature = block.thinkingSignature || "";
|
||||
block.thinkingSignature += event.delta.signature;
|
||||
}
|
||||
}
|
||||
} else if (event.type === "content_block_stop") {
|
||||
const index = blocks.findIndex((b) => b.index === event.index);
|
||||
const block = blocks[index];
|
||||
if (block) {
|
||||
// `index` is an internal bookkeeping field added at block creation
|
||||
// and must be stripped before the block is exposed to callers.
|
||||
delete (block as { index?: number }).index;
|
||||
if (block.type === "text") {
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: index,
|
||||
content: block.text,
|
||||
partial: output,
|
||||
});
|
||||
} else if (block.type === "thinking") {
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: index,
|
||||
content: block.thinking,
|
||||
partial: output,
|
||||
});
|
||||
} else if (block.type === "toolCall") {
|
||||
// Try strict parse first; if it fails, attempt YAML bullet
|
||||
// repair (#2660) before falling back to the lenient streaming
|
||||
// parser which silently swallows errors.
|
||||
const raw = block.partialJson ?? "";
|
||||
const rawForParse = hasXmlParameterTags(raw)
|
||||
? repairToolJson(raw)
|
||||
: raw;
|
||||
let parsed: Record<string, any> | undefined;
|
||||
try {
|
||||
parsed = JSON.parse(rawForParse);
|
||||
} catch {
|
||||
try {
|
||||
parsed = JSON.parse(repairToolJson(rawForParse));
|
||||
} catch {
|
||||
// Fall through to streaming parser
|
||||
}
|
||||
}
|
||||
block.arguments = parsed ?? parseStreamingJson(block.partialJson);
|
||||
// `partialJson` is an internal streaming field that must not
|
||||
// appear on the final ToolCall exposed to callers.
|
||||
delete (block as { partialJson?: string }).partialJson;
|
||||
stream.push({
|
||||
type: "toolcall_end",
|
||||
contentIndex: index,
|
||||
toolCall: block,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if (event.type === "message_delta") {
|
||||
if (event.delta.stop_reason) {
|
||||
output.stopReason = mapStopReason(event.delta.stop_reason);
|
||||
}
|
||||
if (event.usage.input_tokens != null) {
|
||||
output.usage.input = event.usage.input_tokens;
|
||||
}
|
||||
if (event.usage.output_tokens != null) {
|
||||
output.usage.output = event.usage.output_tokens;
|
||||
}
|
||||
if (event.usage.cache_read_input_tokens != null) {
|
||||
output.usage.cacheRead = event.usage.cache_read_input_tokens;
|
||||
}
|
||||
if (event.usage.cache_creation_input_tokens != null) {
|
||||
output.usage.cacheWrite = event.usage.cache_creation_input_tokens;
|
||||
}
|
||||
output.usage.totalTokens =
|
||||
output.usage.input +
|
||||
output.usage.output +
|
||||
output.usage.cacheRead +
|
||||
output.usage.cacheWrite;
|
||||
calculateCost(model, output.usage);
|
||||
}
|
||||
}
|
||||
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
if (output.stopReason === "aborted" || output.stopReason === "error") {
|
||||
throw new Error("An unknown error occurred");
|
||||
}
|
||||
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
for (const block of output.content)
|
||||
delete (block as { index?: number }).index;
|
||||
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
|
||||
output.errorMessage =
|
||||
error instanceof Error ? error.message : JSON.stringify(error);
|
||||
if (model.provider === "alibaba-coding-plan") {
|
||||
output.errorMessage = `[alibaba-coding-plan] ${output.errorMessage}`;
|
||||
}
|
||||
if (
|
||||
AnthropicSdkClass &&
|
||||
error instanceof AnthropicSdkClass.APIError &&
|
||||
error.headers
|
||||
) {
|
||||
const retryAfterMs = extractRetryAfterMs(error.headers, error.message);
|
||||
if (retryAfterMs !== undefined) {
|
||||
output.retryAfterMs = retryAfterMs;
|
||||
}
|
||||
}
|
||||
if (isTransientNetworkError(error)) {
|
||||
output.retryAfterMs = output.retryAfterMs ?? 5000;
|
||||
}
|
||||
stream.push({ type: "error", reason: output.stopReason, error: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
}
|
||||
|
|
@ -1,161 +0,0 @@
|
|||
// Lazy-loaded: Anthropic Vertex SDK is imported on first use, not at startup.
|
||||
// This avoids penalizing users who don't use Anthropic Vertex models.
|
||||
import type Anthropic from "@anthropic-ai/sdk";
|
||||
import type { AnthropicVertex } from "@anthropic-ai/vertex-sdk";
|
||||
import { getEnvApiKey } from "../env-api-keys.js";
|
||||
import type {
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import {
|
||||
type AnthropicOptions,
|
||||
mapThinkingLevelToEffort,
|
||||
processAnthropicStream,
|
||||
supportsAdaptiveThinking,
|
||||
} from "./anthropic-shared.js";
|
||||
import {
|
||||
adjustMaxTokensForThinking,
|
||||
buildBaseOptions,
|
||||
isAutoReasoning,
|
||||
resolveReasoningLevel,
|
||||
} from "./simple-options.js";
|
||||
|
||||
let _AnthropicVertexClass: typeof AnthropicVertex | undefined;
|
||||
let _AnthropicSdkClass: typeof Anthropic | undefined;
|
||||
|
||||
async function getAnthropicVertexClass(): Promise<typeof AnthropicVertex> {
|
||||
if (!_AnthropicVertexClass) {
|
||||
const mod = await import("@anthropic-ai/vertex-sdk");
|
||||
_AnthropicVertexClass = mod.AnthropicVertex;
|
||||
}
|
||||
return _AnthropicVertexClass;
|
||||
}
|
||||
|
||||
async function getAnthropicSdkClass(): Promise<typeof Anthropic> {
|
||||
if (!_AnthropicSdkClass) {
|
||||
const mod = await import("@anthropic-ai/sdk");
|
||||
_AnthropicSdkClass = mod.default;
|
||||
}
|
||||
return _AnthropicSdkClass;
|
||||
}
|
||||
|
||||
function resolveProjectId(): string {
|
||||
const projectId =
|
||||
process.env.ANTHROPIC_VERTEX_PROJECT_ID ||
|
||||
process.env.GOOGLE_CLOUD_PROJECT ||
|
||||
process.env.GCLOUD_PROJECT;
|
||||
if (!projectId) {
|
||||
throw new Error(
|
||||
"Anthropic Vertex requires a project ID. Set ANTHROPIC_VERTEX_PROJECT_ID, GOOGLE_CLOUD_PROJECT, or GCLOUD_PROJECT.",
|
||||
);
|
||||
}
|
||||
return projectId;
|
||||
}
|
||||
|
||||
function resolveRegion(): string {
|
||||
return (
|
||||
process.env.CLOUD_ML_REGION ||
|
||||
process.env.GOOGLE_CLOUD_LOCATION ||
|
||||
"us-central1"
|
||||
);
|
||||
}
|
||||
|
||||
async function createVertexClient(): Promise<AnthropicVertex> {
|
||||
const AnthropicVertexClass = await getAnthropicVertexClass();
|
||||
const projectId = resolveProjectId();
|
||||
const region = resolveRegion();
|
||||
|
||||
return new AnthropicVertexClass({
|
||||
projectId,
|
||||
region,
|
||||
});
|
||||
}
|
||||
|
||||
export const streamAnthropicVertex: StreamFunction<
|
||||
"anthropic-vertex",
|
||||
AnthropicOptions
|
||||
> = (
|
||||
model: Model<"anthropic-vertex">,
|
||||
context: Context,
|
||||
options?: AnthropicOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const stream = new AssistantMessageEventStream();
|
||||
|
||||
(async () => {
|
||||
const client = await createVertexClient();
|
||||
const AnthropicSdk = await getAnthropicSdkClass();
|
||||
|
||||
processAnthropicStream(stream, {
|
||||
client: client as unknown as Anthropic,
|
||||
model,
|
||||
context,
|
||||
isOAuthToken: false,
|
||||
options,
|
||||
AnthropicSdkClass: AnthropicSdk,
|
||||
});
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
export const streamSimpleAnthropicVertex: StreamFunction<
|
||||
"anthropic-vertex",
|
||||
SimpleStreamOptions
|
||||
> = (
|
||||
model: Model<"anthropic-vertex">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
|
||||
if (!apiKey) {
|
||||
throw new Error(
|
||||
`No API key found for provider: ${model.provider}. Set ANTHROPIC_VERTEX_PROJECT_ID to use Claude on Vertex AI.`,
|
||||
);
|
||||
}
|
||||
|
||||
const base = buildBaseOptions(model, options, apiKey);
|
||||
if (!options?.reasoning) {
|
||||
return streamAnthropicVertex(model, context, {
|
||||
...base,
|
||||
thinkingEnabled: false,
|
||||
} satisfies AnthropicOptions);
|
||||
}
|
||||
|
||||
if (
|
||||
isAutoReasoning(options.reasoning) &&
|
||||
(supportsAdaptiveThinking(model.id) || model.capabilities?.thinkingNoBudget)
|
||||
) {
|
||||
return streamAnthropicVertex(model, context, {
|
||||
...base,
|
||||
thinkingEnabled: true,
|
||||
} satisfies AnthropicOptions);
|
||||
}
|
||||
|
||||
const effectiveReasoning = resolveReasoningLevel(model, options.reasoning)!;
|
||||
|
||||
if (supportsAdaptiveThinking(model.id)) {
|
||||
const effort = mapThinkingLevelToEffort(effectiveReasoning, model.id);
|
||||
return streamAnthropicVertex(model, context, {
|
||||
...base,
|
||||
thinkingEnabled: true,
|
||||
effort,
|
||||
} satisfies AnthropicOptions);
|
||||
}
|
||||
|
||||
const adjusted = adjustMaxTokensForThinking(
|
||||
base.maxTokens || 0,
|
||||
model.maxTokens,
|
||||
effectiveReasoning,
|
||||
options.thinkingBudgets,
|
||||
);
|
||||
|
||||
return streamAnthropicVertex(model, context, {
|
||||
...base,
|
||||
maxTokens: adjusted.maxTokens,
|
||||
thinkingEnabled: true,
|
||||
thinkingBudgetTokens: adjusted.thinkingBudget,
|
||||
} satisfies AnthropicOptions);
|
||||
};
|
||||
|
|
@ -1,263 +0,0 @@
|
|||
// Lazy-loaded: Anthropic SDK (~500ms) is imported on first use, not at startup.
|
||||
// This avoids penalizing users who don't use Anthropic models.
|
||||
import type Anthropic from "@anthropic-ai/sdk";
|
||||
import { getEnvApiKey } from "../env-api-keys.js";
|
||||
import type {
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import {
|
||||
type AnthropicEffort,
|
||||
type AnthropicOptions,
|
||||
extractRetryAfterMs,
|
||||
mapThinkingLevelToEffort,
|
||||
processAnthropicStream,
|
||||
supportsAdaptiveThinking,
|
||||
} from "./anthropic-shared.js";
|
||||
import {
|
||||
buildCopilotDynamicHeaders,
|
||||
hasCopilotVisionInput,
|
||||
} from "./github-copilot-headers.js";
|
||||
import {
|
||||
adjustMaxTokensForThinking,
|
||||
buildBaseOptions,
|
||||
isAutoReasoning,
|
||||
resolveReasoningLevel,
|
||||
} from "./simple-options.js";
|
||||
|
||||
// Re-export types used by other modules
|
||||
export type { AnthropicEffort, AnthropicOptions };
|
||||
export { extractRetryAfterMs };
|
||||
|
||||
/**
|
||||
* Resolve the base URL for Anthropic API requests.
|
||||
*
|
||||
* Resolution order:
|
||||
* 1. ANTHROPIC_BASE_URL environment variable (if set and non-empty after trim)
|
||||
* 2. model.baseUrl (from the model definition)
|
||||
*
|
||||
* This allows routing traffic through custom proxy endpoints (e.g. OpusMax,
|
||||
* local mirrors, corporate gateways) without modifying model definitions.
|
||||
*/
|
||||
export function resolveAnthropicBaseUrl(
|
||||
model: Model<"anthropic-messages">,
|
||||
): string {
|
||||
const envBaseUrl = process.env.ANTHROPIC_BASE_URL?.trim();
|
||||
if (envBaseUrl) {
|
||||
return envBaseUrl;
|
||||
}
|
||||
return model.baseUrl;
|
||||
}
|
||||
|
||||
let _AnthropicClass: typeof Anthropic | undefined;
|
||||
async function getAnthropicClass(): Promise<typeof Anthropic> {
|
||||
if (!_AnthropicClass) {
|
||||
const mod = await import("@anthropic-ai/sdk");
|
||||
_AnthropicClass = mod.default;
|
||||
}
|
||||
return _AnthropicClass;
|
||||
}
|
||||
|
||||
function mergeHeaders(
|
||||
...headerSources: (Record<string, string> | undefined)[]
|
||||
): Record<string, string> {
|
||||
const merged: Record<string, string> = {};
|
||||
for (const headers of headerSources) {
|
||||
if (headers) {
|
||||
Object.assign(merged, headers);
|
||||
}
|
||||
}
|
||||
return merged;
|
||||
}
|
||||
|
||||
export function usesAnthropicBearerAuth(
|
||||
provider: Model<"anthropic-messages">["provider"],
|
||||
): boolean {
|
||||
return (
|
||||
provider === "alibaba-coding-plan" ||
|
||||
provider === "minimax" ||
|
||||
provider === "minimax-cn" ||
|
||||
provider === "longcat" ||
|
||||
provider === "xiaomi"
|
||||
);
|
||||
}
|
||||
|
||||
async function createClient(
|
||||
model: Model<"anthropic-messages">,
|
||||
apiKey: string,
|
||||
interleavedThinking: boolean,
|
||||
optionsHeaders?: Record<string, string>,
|
||||
dynamicHeaders?: Record<string, string>,
|
||||
): Promise<{ client: Anthropic; isOAuthToken: boolean }> {
|
||||
const AnthropicClass = await getAnthropicClass();
|
||||
// Adaptive thinking models (Opus 4.6, Sonnet 4.6) have interleaved thinking built-in.
|
||||
// The beta header is deprecated on Opus 4.6 and redundant on Sonnet 4.6, so skip it.
|
||||
const needsInterleavedBeta =
|
||||
interleavedThinking && !supportsAdaptiveThinking(model.id);
|
||||
|
||||
// Copilot: Bearer auth, selective betas (no fine-grained-tool-streaming)
|
||||
if (model.provider === "github-copilot") {
|
||||
const betaFeatures: string[] = [];
|
||||
if (needsInterleavedBeta) {
|
||||
betaFeatures.push("interleaved-thinking-2025-05-14");
|
||||
}
|
||||
|
||||
const client = new AnthropicClass({
|
||||
apiKey: null,
|
||||
authToken: apiKey,
|
||||
baseURL: resolveAnthropicBaseUrl(model),
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: mergeHeaders(
|
||||
{
|
||||
accept: "application/json",
|
||||
"anthropic-dangerous-direct-browser-access": "true",
|
||||
...(betaFeatures.length > 0
|
||||
? { "anthropic-beta": betaFeatures.join(",") }
|
||||
: {}),
|
||||
},
|
||||
model.headers,
|
||||
dynamicHeaders,
|
||||
optionsHeaders,
|
||||
),
|
||||
});
|
||||
|
||||
return { client, isOAuthToken: false };
|
||||
}
|
||||
|
||||
// Skip beta headers for providers that don't support them (e.g., Alibaba Coding Plan)
|
||||
const skipBetaHeaders = model.provider === "alibaba-coding-plan";
|
||||
const betaFeatures = skipBetaHeaders
|
||||
? []
|
||||
: ["fine-grained-tool-streaming-2025-05-14"];
|
||||
if (needsInterleavedBeta && !skipBetaHeaders) {
|
||||
betaFeatures.push("interleaved-thinking-2025-05-14");
|
||||
}
|
||||
|
||||
// API key auth (Anthropic OAuth removed per TOS compliance — use API keys or Claude CLI)
|
||||
// Some Anthropic-compatible providers require Bearer auth instead of x-api-key.
|
||||
const usesBearerAuth = usesAnthropicBearerAuth(model.provider);
|
||||
const client = new AnthropicClass({
|
||||
apiKey: usesBearerAuth ? null : apiKey,
|
||||
authToken: usesBearerAuth ? apiKey : undefined,
|
||||
baseURL: resolveAnthropicBaseUrl(model),
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: mergeHeaders(
|
||||
{
|
||||
accept: "application/json",
|
||||
"anthropic-dangerous-direct-browser-access": "true",
|
||||
...(betaFeatures.length > 0
|
||||
? { "anthropic-beta": betaFeatures.join(",") }
|
||||
: {}),
|
||||
},
|
||||
model.headers,
|
||||
optionsHeaders,
|
||||
),
|
||||
});
|
||||
|
||||
return { client, isOAuthToken: false };
|
||||
}
|
||||
|
||||
export const streamAnthropic: StreamFunction<
|
||||
"anthropic-messages",
|
||||
AnthropicOptions
|
||||
> = (
|
||||
model: Model<"anthropic-messages">,
|
||||
context: Context,
|
||||
options?: AnthropicOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const stream = new AssistantMessageEventStream();
|
||||
|
||||
(async () => {
|
||||
const apiKey = options?.apiKey ?? getEnvApiKey(model.provider) ?? "";
|
||||
|
||||
let copilotDynamicHeaders: Record<string, string> | undefined;
|
||||
if (model.provider === "github-copilot") {
|
||||
const hasImages = hasCopilotVisionInput(context.messages);
|
||||
copilotDynamicHeaders = buildCopilotDynamicHeaders({
|
||||
messages: context.messages,
|
||||
hasImages,
|
||||
});
|
||||
}
|
||||
|
||||
const { client, isOAuthToken: isOAuth } = await createClient(
|
||||
model,
|
||||
apiKey,
|
||||
options?.interleavedThinking ?? true,
|
||||
options?.headers,
|
||||
copilotDynamicHeaders,
|
||||
);
|
||||
|
||||
processAnthropicStream(stream, {
|
||||
client,
|
||||
model,
|
||||
context,
|
||||
isOAuthToken: isOAuth,
|
||||
options,
|
||||
AnthropicSdkClass: _AnthropicClass,
|
||||
});
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
export const streamSimpleAnthropic: StreamFunction<
|
||||
"anthropic-messages",
|
||||
SimpleStreamOptions
|
||||
> = (
|
||||
model: Model<"anthropic-messages">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
const base = buildBaseOptions(model, options, apiKey);
|
||||
if (!options?.reasoning) {
|
||||
return streamAnthropic(model, context, {
|
||||
...base,
|
||||
thinkingEnabled: false,
|
||||
} satisfies AnthropicOptions);
|
||||
}
|
||||
|
||||
if (
|
||||
isAutoReasoning(options.reasoning) &&
|
||||
(supportsAdaptiveThinking(model.id) || model.capabilities?.thinkingNoBudget)
|
||||
) {
|
||||
return streamAnthropic(model, context, {
|
||||
...base,
|
||||
thinkingEnabled: true,
|
||||
} satisfies AnthropicOptions);
|
||||
}
|
||||
|
||||
const effectiveReasoning = resolveReasoningLevel(model, options.reasoning)!;
|
||||
|
||||
// For Opus 4.6 and Sonnet 4.6: use adaptive thinking with effort level
|
||||
// For older models: use budget-based thinking
|
||||
if (supportsAdaptiveThinking(model.id)) {
|
||||
const effort = mapThinkingLevelToEffort(effectiveReasoning, model.id);
|
||||
return streamAnthropic(model, context, {
|
||||
...base,
|
||||
thinkingEnabled: true,
|
||||
effort,
|
||||
} satisfies AnthropicOptions);
|
||||
}
|
||||
|
||||
const adjusted = adjustMaxTokensForThinking(
|
||||
base.maxTokens || 0,
|
||||
model.maxTokens,
|
||||
effectiveReasoning,
|
||||
options.thinkingBudgets,
|
||||
);
|
||||
|
||||
return streamAnthropic(model, context, {
|
||||
...base,
|
||||
maxTokens: adjusted.maxTokens,
|
||||
thinkingEnabled: true,
|
||||
thinkingBudgetTokens: adjusted.thinkingBudget,
|
||||
} satisfies AnthropicOptions);
|
||||
};
|
||||
|
|
@ -1,318 +0,0 @@
|
|||
// Lazy-loaded: OpenAI SDK (AzureOpenAI) is imported on first use, not at startup.
|
||||
// This avoids penalizing users who don't use Azure OpenAI models.
|
||||
import type { AzureOpenAI } from "openai";
|
||||
import type { ResponseCreateParamsStreaming } from "openai/resources/responses/responses.js";
|
||||
import { getEnvApiKey } from "../env-api-keys.js";
|
||||
import { supportsXhigh } from "../models.js";
|
||||
import type {
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import {
|
||||
convertResponsesMessages,
|
||||
convertResponsesTools,
|
||||
processResponsesStream,
|
||||
} from "./openai-responses-shared.js";
|
||||
import {
|
||||
assertStreamSuccess,
|
||||
buildInitialOutput,
|
||||
clampReasoningForModel,
|
||||
finalizeStream,
|
||||
handleStreamError,
|
||||
} from "./openai-shared.js";
|
||||
import {
|
||||
buildBaseOptions,
|
||||
clampReasoning,
|
||||
resolveReasoningLevel,
|
||||
} from "./simple-options.js";
|
||||
|
||||
let _AzureOpenAIClass: typeof AzureOpenAI | undefined;
|
||||
async function getAzureOpenAIClass(): Promise<typeof AzureOpenAI> {
|
||||
if (!_AzureOpenAIClass) {
|
||||
const mod = await import("openai");
|
||||
_AzureOpenAIClass = mod.AzureOpenAI;
|
||||
}
|
||||
return _AzureOpenAIClass;
|
||||
}
|
||||
|
||||
const DEFAULT_AZURE_API_VERSION = "v1";
|
||||
const AZURE_TOOL_CALL_PROVIDERS = new Set([
|
||||
"openai",
|
||||
"openai-codex",
|
||||
"opencode",
|
||||
"azure-openai-responses",
|
||||
]);
|
||||
|
||||
function parseDeploymentNameMap(
|
||||
value: string | undefined,
|
||||
): Map<string, string> {
|
||||
const map = new Map<string, string>();
|
||||
if (!value) return map;
|
||||
for (const entry of value.split(",")) {
|
||||
const trimmed = entry.trim();
|
||||
if (!trimmed) continue;
|
||||
const [modelId, deploymentName] = trimmed.split("=", 2);
|
||||
if (!modelId || !deploymentName) continue;
|
||||
map.set(modelId.trim(), deploymentName.trim());
|
||||
}
|
||||
return map;
|
||||
}
|
||||
|
||||
function resolveDeploymentName(
|
||||
model: Model<"azure-openai-responses">,
|
||||
options?: AzureOpenAIResponsesOptions,
|
||||
): string {
|
||||
if (options?.azureDeploymentName) {
|
||||
return options.azureDeploymentName;
|
||||
}
|
||||
const mappedDeployment = parseDeploymentNameMap(
|
||||
process.env.AZURE_OPENAI_DEPLOYMENT_NAME_MAP,
|
||||
).get(model.id);
|
||||
return mappedDeployment || model.id;
|
||||
}
|
||||
|
||||
// Azure OpenAI Responses-specific options
|
||||
export interface AzureOpenAIResponsesOptions extends StreamOptions {
|
||||
reasoningEffort?: "minimal" | "low" | "medium" | "high" | "xhigh";
|
||||
reasoningSummary?: "auto" | "detailed" | "concise" | null;
|
||||
azureApiVersion?: string;
|
||||
azureResourceName?: string;
|
||||
azureBaseUrl?: string;
|
||||
azureDeploymentName?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate function for Azure OpenAI Responses API
|
||||
*/
|
||||
export const streamAzureOpenAIResponses: StreamFunction<
|
||||
"azure-openai-responses",
|
||||
AzureOpenAIResponsesOptions
|
||||
> = (
|
||||
model: Model<"azure-openai-responses">,
|
||||
context: Context,
|
||||
options?: AzureOpenAIResponsesOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const stream = new AssistantMessageEventStream();
|
||||
|
||||
// Start async processing
|
||||
(async () => {
|
||||
const deploymentName = resolveDeploymentName(model, options);
|
||||
const output = buildInitialOutput(model);
|
||||
|
||||
try {
|
||||
// Create Azure OpenAI client
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
|
||||
const client = await createClient(model, apiKey, options);
|
||||
let params = buildParams(model, context, options, deploymentName);
|
||||
const nextParams = await options?.onPayload?.(params, model);
|
||||
if (nextParams !== undefined) {
|
||||
params = nextParams as ResponseCreateParamsStreaming;
|
||||
}
|
||||
const openaiStream = await client.responses.create(
|
||||
params,
|
||||
options?.signal ? { signal: options.signal } : undefined,
|
||||
);
|
||||
stream.push({ type: "start", partial: output });
|
||||
|
||||
await processResponsesStream(openaiStream, output, stream, model);
|
||||
|
||||
assertStreamSuccess(output, options?.signal);
|
||||
finalizeStream(stream, output);
|
||||
} catch (error) {
|
||||
handleStreamError(stream, output, error, options?.signal);
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
export const streamSimpleAzureOpenAIResponses: StreamFunction<
|
||||
"azure-openai-responses",
|
||||
SimpleStreamOptions
|
||||
> = (
|
||||
model: Model<"azure-openai-responses">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
const base = buildBaseOptions(model, options, apiKey);
|
||||
const effectiveReasoning = resolveReasoningLevel(model, options?.reasoning);
|
||||
const reasoningEffort = supportsXhigh(model)
|
||||
? effectiveReasoning
|
||||
: clampReasoning(effectiveReasoning);
|
||||
|
||||
return streamAzureOpenAIResponses(model, context, {
|
||||
...base,
|
||||
reasoningEffort,
|
||||
} satisfies AzureOpenAIResponsesOptions);
|
||||
};
|
||||
|
||||
function normalizeAzureBaseUrl(baseUrl: string): string {
|
||||
return baseUrl.replace(/\/+$/, "");
|
||||
}
|
||||
|
||||
function buildDefaultBaseUrl(resourceName: string): string {
|
||||
return `https://${resourceName}.openai.azure.com/openai/v1`;
|
||||
}
|
||||
|
||||
function isCognitiveServicesDomain(url: string): boolean {
|
||||
try {
|
||||
const hostname = new URL(url).hostname;
|
||||
return hostname.endsWith(".cognitiveservices.azure.com");
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
function normalizeCognitiveServicesUrl(url: string): string {
|
||||
// Azure Cognitive Services endpoints use /openai/deployments/{deployment}/chat/completions
|
||||
// We need to normalize to the OpenAI-compatible base path
|
||||
if (url.includes("/openai/deployments/")) {
|
||||
return url.split("/openai/deployments/")[0]!;
|
||||
}
|
||||
return url;
|
||||
}
|
||||
|
||||
function resolveAzureConfig(
|
||||
model: Model<"azure-openai-responses">,
|
||||
options?: AzureOpenAIResponsesOptions,
|
||||
): { baseUrl: string; apiVersion: string } {
|
||||
const apiVersion =
|
||||
options?.azureApiVersion ||
|
||||
process.env.AZURE_OPENAI_API_VERSION ||
|
||||
DEFAULT_AZURE_API_VERSION;
|
||||
|
||||
const baseUrl =
|
||||
options?.azureBaseUrl?.trim() ||
|
||||
process.env.AZURE_OPENAI_BASE_URL?.trim() ||
|
||||
undefined;
|
||||
const resourceName =
|
||||
options?.azureResourceName || process.env.AZURE_OPENAI_RESOURCE_NAME;
|
||||
|
||||
let resolvedBaseUrl = baseUrl;
|
||||
|
||||
if (!resolvedBaseUrl && resourceName) {
|
||||
resolvedBaseUrl = buildDefaultBaseUrl(resourceName);
|
||||
}
|
||||
|
||||
if (!resolvedBaseUrl && model.baseUrl) {
|
||||
resolvedBaseUrl = model.baseUrl;
|
||||
}
|
||||
|
||||
if (!resolvedBaseUrl) {
|
||||
throw new Error(
|
||||
"Azure OpenAI base URL is required. Set AZURE_OPENAI_BASE_URL or AZURE_OPENAI_RESOURCE_NAME, or pass azureBaseUrl, azureResourceName, or model.baseUrl.",
|
||||
);
|
||||
}
|
||||
|
||||
// Normalize Cognitive Services endpoints (e.g., .cognitiveservices.azure.com)
|
||||
if (isCognitiveServicesDomain(resolvedBaseUrl)) {
|
||||
resolvedBaseUrl = normalizeCognitiveServicesUrl(resolvedBaseUrl);
|
||||
}
|
||||
|
||||
return {
|
||||
baseUrl: normalizeAzureBaseUrl(resolvedBaseUrl),
|
||||
apiVersion,
|
||||
};
|
||||
}
|
||||
|
||||
async function createClient(
|
||||
model: Model<"azure-openai-responses">,
|
||||
apiKey: string,
|
||||
options?: AzureOpenAIResponsesOptions,
|
||||
) {
|
||||
if (!apiKey) {
|
||||
if (!process.env.AZURE_OPENAI_API_KEY) {
|
||||
throw new Error(
|
||||
"Azure OpenAI API key is required. Set AZURE_OPENAI_API_KEY environment variable or pass it as an argument.",
|
||||
);
|
||||
}
|
||||
apiKey = process.env.AZURE_OPENAI_API_KEY;
|
||||
}
|
||||
|
||||
const headers = { ...model.headers };
|
||||
|
||||
if (options?.headers) {
|
||||
Object.assign(headers, options.headers);
|
||||
}
|
||||
|
||||
const { baseUrl, apiVersion } = resolveAzureConfig(model, options);
|
||||
const AzureOpenAIClass = await getAzureOpenAIClass();
|
||||
|
||||
return new AzureOpenAIClass({
|
||||
apiKey,
|
||||
apiVersion,
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: headers,
|
||||
baseURL: baseUrl,
|
||||
});
|
||||
}
|
||||
|
||||
function buildParams(
|
||||
model: Model<"azure-openai-responses">,
|
||||
context: Context,
|
||||
options: AzureOpenAIResponsesOptions | undefined,
|
||||
deploymentName: string,
|
||||
) {
|
||||
const messages = convertResponsesMessages(
|
||||
model,
|
||||
context,
|
||||
AZURE_TOOL_CALL_PROVIDERS,
|
||||
);
|
||||
|
||||
const params: ResponseCreateParamsStreaming = {
|
||||
model: deploymentName,
|
||||
input: messages,
|
||||
stream: true,
|
||||
prompt_cache_key: options?.sessionId,
|
||||
};
|
||||
|
||||
if (options?.maxTokens) {
|
||||
params.max_output_tokens = options?.maxTokens;
|
||||
}
|
||||
|
||||
if (options?.temperature !== undefined) {
|
||||
params.temperature = options?.temperature;
|
||||
}
|
||||
|
||||
if (context.tools && context.tools.length > 0) {
|
||||
params.tools = convertResponsesTools(context.tools);
|
||||
}
|
||||
|
||||
if (model.reasoning) {
|
||||
if (options?.reasoningEffort || options?.reasoningSummary) {
|
||||
const effort = clampReasoningForModel(
|
||||
model.name,
|
||||
options?.reasoningEffort || "medium",
|
||||
) as typeof options.reasoningEffort;
|
||||
params.reasoning = {
|
||||
effort: effort || "medium",
|
||||
summary: options?.reasoningSummary || "auto",
|
||||
};
|
||||
params.include = ["reasoning.encrypted_content"];
|
||||
} else {
|
||||
if (model.name.toLowerCase().startsWith("gpt-5")) {
|
||||
// Jesus Christ, see https://community.openai.com/t/need-reasoning-false-option-for-gpt-5/1351588/7
|
||||
messages.push({
|
||||
role: "developer",
|
||||
content: [
|
||||
{
|
||||
type: "input_text",
|
||||
text: "# Juice: 0 !important",
|
||||
},
|
||||
],
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return params;
|
||||
}
|
||||
|
|
@ -1,429 +0,0 @@
|
|||
import type { ChildProcessWithoutNullStreams } from "node:child_process";
|
||||
import type * as NodeReadline from "node:readline";
|
||||
|
||||
type DynamicImport = (specifier: string) => Promise<unknown>;
|
||||
|
||||
const dynamicImport: DynamicImport = (specifier) => import(specifier);
|
||||
const NODE_CHILD_PROCESS_SPECIFIER = "node:" + "child_process";
|
||||
const NODE_READLINE_SPECIFIER = "node:" + "readline";
|
||||
|
||||
type RequestId = number;
|
||||
type JsonObject = Record<string, unknown>;
|
||||
|
||||
interface JsonRpcError {
|
||||
code: number;
|
||||
message: string;
|
||||
data?: unknown;
|
||||
}
|
||||
|
||||
interface JsonRpcResponse {
|
||||
id: RequestId;
|
||||
result?: unknown;
|
||||
error?: JsonRpcError;
|
||||
}
|
||||
|
||||
interface JsonRpcNotification {
|
||||
method: string;
|
||||
params?: unknown;
|
||||
}
|
||||
|
||||
interface JsonRpcServerRequest extends JsonRpcNotification {
|
||||
id: RequestId;
|
||||
}
|
||||
|
||||
interface PendingRequest {
|
||||
resolve: (value: unknown) => void;
|
||||
reject: (reason: Error) => void;
|
||||
}
|
||||
|
||||
export interface CodexAppServerClientOptions {
|
||||
cwd?: string;
|
||||
env?: NodeJS.ProcessEnv;
|
||||
extraArgs?: string[];
|
||||
clientInfo?: {
|
||||
name: string;
|
||||
title: string;
|
||||
version: string;
|
||||
};
|
||||
}
|
||||
|
||||
export type CodexAppServerNotification = JsonRpcNotification;
|
||||
|
||||
export type CodexAppServerNotificationHandler = (
|
||||
notification: CodexAppServerNotification,
|
||||
) => void;
|
||||
|
||||
const DEFAULT_CLIENT_INFO = {
|
||||
name: "singularity_forge_pi_ai",
|
||||
title: "Singularity Forge pi-ai",
|
||||
version: "0.0.0",
|
||||
};
|
||||
|
||||
let sharedClientPromise: Promise<CodexAppServerClient> | undefined;
|
||||
|
||||
/**
|
||||
* Return the session-wide Codex app-server client. Spawns `codex app-server` lazily and reuses it.
|
||||
*
|
||||
* Purpose: delegate ChatGPT auth and protocol drift to the installed Codex CLI while keeping pi-ai provider calls cheap.
|
||||
* Consumer: openai-codex-responses.ts for every OpenAI Codex provider stream.
|
||||
*/
|
||||
export function getCodexAppServerClient(
|
||||
options?: CodexAppServerClientOptions,
|
||||
): Promise<CodexAppServerClient> {
|
||||
if (!sharedClientPromise) {
|
||||
sharedClientPromise = CodexAppServerClient.connect(options);
|
||||
}
|
||||
return sharedClientPromise;
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset the session-wide Codex app-server client after the child process exits or is disposed.
|
||||
*
|
||||
* Purpose: allow the next provider call to recover from a crashed or deliberately closed Codex process.
|
||||
* Consumer: CodexAppServerClient lifecycle handlers in this module.
|
||||
*/
|
||||
export function clearCodexAppServerClient(client: CodexAppServerClient): void {
|
||||
if (sharedClientPromise) {
|
||||
sharedClientPromise.then(
|
||||
(current) => {
|
||||
if (current === client) sharedClientPromise = undefined;
|
||||
},
|
||||
() => {
|
||||
sharedClientPromise = undefined;
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* JSON-RPC client for a stdio `codex app-server` child process.
|
||||
*
|
||||
* Purpose: provide a small dependency-free transport that matches Codex's newline-delimited JSON protocol.
|
||||
* Consumer: getCodexAppServerClient and the OpenAI Codex provider adapter.
|
||||
*/
|
||||
export class CodexAppServerClient {
|
||||
private proc: ChildProcessWithoutNullStreams | undefined;
|
||||
private readline: NodeReadline.Interface | undefined;
|
||||
private nextId = 1;
|
||||
private stderr = "";
|
||||
private closed = false;
|
||||
private exitError: Error | undefined;
|
||||
private readonly pending = new Map<RequestId, PendingRequest>();
|
||||
private readonly notificationHandlers =
|
||||
new Set<CodexAppServerNotificationHandler>();
|
||||
|
||||
private constructor(private readonly options: CodexAppServerClientOptions) {}
|
||||
|
||||
/**
|
||||
* Spawn and initialize a Codex app-server process.
|
||||
*
|
||||
* Purpose: complete Codex's required initialize/initialized handshake before any thread or turn RPC.
|
||||
* Consumer: getCodexAppServerClient when creating the shared process.
|
||||
*/
|
||||
static async connect(
|
||||
options: CodexAppServerClientOptions = {},
|
||||
): Promise<CodexAppServerClient> {
|
||||
const client = new CodexAppServerClient(options);
|
||||
await client.initialize();
|
||||
return client;
|
||||
}
|
||||
|
||||
/**
|
||||
* Register a notification callback and return an unsubscribe function.
|
||||
*
|
||||
* Purpose: let provider streams observe only their own thread/turn notifications without owning the transport.
|
||||
* Consumer: streamOpenAICodexResponses while a turn is active.
|
||||
*/
|
||||
onNotification(handler: CodexAppServerNotificationHandler): () => void {
|
||||
this.notificationHandlers.add(handler);
|
||||
return () => {
|
||||
this.notificationHandlers.delete(handler);
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Send a JSON-RPC request and resolve with the response result.
|
||||
*
|
||||
* Purpose: provide typed call sites for app-server methods while centralizing response/error handling.
|
||||
* Consumer: provider setup, turn start, context injection, and cancellation paths.
|
||||
*/
|
||||
request(
|
||||
method: string,
|
||||
params?: unknown,
|
||||
signal?: AbortSignal,
|
||||
): Promise<unknown> {
|
||||
if (this.closed) {
|
||||
return Promise.reject(
|
||||
this.exitError ?? new Error("codex app-server is closed."),
|
||||
);
|
||||
}
|
||||
if (signal?.aborted) {
|
||||
return Promise.reject(new Error("Request was aborted"));
|
||||
}
|
||||
|
||||
const id = this.nextId++;
|
||||
const message =
|
||||
params === undefined ? { id, method } : { id, method, params };
|
||||
|
||||
return new Promise<unknown>((resolve, reject) => {
|
||||
const abort = () => {
|
||||
this.pending.delete(id);
|
||||
reject(new Error("Request was aborted"));
|
||||
};
|
||||
|
||||
this.pending.set(id, {
|
||||
resolve: (value) => {
|
||||
signal?.removeEventListener("abort", abort);
|
||||
resolve(value);
|
||||
},
|
||||
reject: (error) => {
|
||||
signal?.removeEventListener("abort", abort);
|
||||
reject(error);
|
||||
},
|
||||
});
|
||||
|
||||
signal?.addEventListener("abort", abort, { once: true });
|
||||
|
||||
this.send(message).catch((error: unknown) => {
|
||||
this.pending.delete(id);
|
||||
signal?.removeEventListener("abort", abort);
|
||||
reject(error instanceof Error ? error : new Error(String(error)));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Send a JSON-RPC notification.
|
||||
*
|
||||
* Purpose: acknowledge initialization and support fire-and-forget app-server protocol calls.
|
||||
* Consumer: initialize() for the required `initialized` notification.
|
||||
*/
|
||||
async notify(method: string, params?: unknown): Promise<void> {
|
||||
const message = params === undefined ? { method } : { method, params };
|
||||
await this.send(message);
|
||||
}
|
||||
|
||||
/**
|
||||
* Interrupt an active Codex turn.
|
||||
*
|
||||
* Purpose: translate an AbortSignal into Codex's cooperative turn cancellation RPC.
|
||||
* Consumer: openai-codex-responses.ts abort handling.
|
||||
*/
|
||||
async interruptTurn(threadId: string, turnId: string): Promise<void> {
|
||||
await this.request("turn/interrupt", { threadId, turnId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Dispose the child process and reject pending requests.
|
||||
*
|
||||
* Purpose: release the long-running Codex process when the owning session is shutting down.
|
||||
* Consumer: tests, future host lifecycle hooks, and crash recovery.
|
||||
*/
|
||||
async dispose(): Promise<void> {
|
||||
if (this.closed) return;
|
||||
this.closed = true;
|
||||
clearCodexAppServerClient(this);
|
||||
this.readline?.close();
|
||||
if (this.proc && !this.proc.killed) {
|
||||
this.proc.kill("SIGTERM");
|
||||
}
|
||||
this.rejectPending(new Error("codex app-server was disposed."));
|
||||
}
|
||||
|
||||
/**
|
||||
* Dispose the client if it has no pending requests and no active notification
|
||||
* handlers. The check is deferred by one event-loop turn so a consumer that is
|
||||
* about to register a new handler wins the race.
|
||||
*
|
||||
* Purpose: allow short-lived processes (smoke tests, one-shot scripts) to exit
|
||||
* cleanly without leaking the codex app-server child process, while keeping the
|
||||
* client alive across back-to-back turns in a long-running session.
|
||||
*/
|
||||
releaseIfIdle(): Promise<void> {
|
||||
return new Promise((resolve) => {
|
||||
setImmediate(() => {
|
||||
if (
|
||||
this.closed ||
|
||||
this.pending.size > 0 ||
|
||||
this.notificationHandlers.size > 0
|
||||
) {
|
||||
resolve();
|
||||
return;
|
||||
}
|
||||
this.dispose().then(
|
||||
() => resolve(),
|
||||
() => resolve(),
|
||||
);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/**
 * Spawn the `codex app-server` child process and perform the JSON-RPC handshake.
 *
 * Order matters: stderr/error/exit handlers must be wired before the first
 * request so early crashes surface as handleExit() rather than hanging the
 * pending `initialize` promise.
 */
private async initialize(): Promise<void> {
  // Dynamic import keeps node-only modules out of bundles that never spawn Codex.
  const childProcessModule = (await dynamicImport(
    NODE_CHILD_PROCESS_SPECIFIER,
  )) as typeof import("node:child_process");
  const readlineModule = (await dynamicImport(
    NODE_READLINE_SPECIFIER,
  )) as typeof import("node:readline");
  const args = [
    "app-server",
    "--listen",
    "stdio://",
    ...(this.options.extraArgs ?? []),
  ];

  try {
    this.proc = childProcessModule.spawn("codex", args, {
      cwd: this.options.cwd ?? process.cwd(),
      env: this.options.env ?? process.env,
      stdio: ["pipe", "pipe", "pipe"],
      // Windows needs a shell to resolve `codex` from PATH (.cmd shims).
      shell: process.platform === "win32",
      windowsHide: true,
    });
  } catch (error) {
    // Normalize ENOENT into a friendly "install the Codex CLI" message.
    throw this.toSpawnError(error);
  }

  this.proc.stdout.setEncoding("utf8");
  this.proc.stderr.setEncoding("utf8");
  this.proc.stderr.on("data", (chunk: string) => {
    // Keep only the last ~12KB of stderr for diagnostics in error messages.
    this.stderr = (this.stderr + chunk).slice(-12000);
  });
  this.proc.on("error", (error) => {
    this.handleExit(this.toSpawnError(error));
  });
  this.proc.on("exit", (code, signal) => {
    if (this.closed) return;
    const detail = signal ? `signal ${signal}` : `exit ${code ?? "unknown"}`;
    this.handleExit(
      new Error(
        `codex app-server exited unexpectedly (${detail}).${this.stderrSuffix()}`,
      ),
    );
  });

  // Line-delimited JSON-RPC: one message per stdout line.
  this.readline = readlineModule.createInterface({ input: this.proc.stdout });
  this.readline.on("line", (line) => this.handleLine(line));

  await this.request("initialize", {
    clientInfo: this.options.clientInfo ?? DEFAULT_CLIENT_INFO,
    capabilities: { experimentalApi: true },
  });
  // The protocol requires an `initialized` notification after the handshake.
  await this.notify("initialized");
}
|
||||
|
||||
/**
 * Serialize a JSON-RPC message and write it to the child's stdin as one line.
 *
 * Honors stream backpressure: when write() returns false, waits for the
 * `drain` event (or an `error`) before resolving, so callers cannot flood
 * the pipe. Throws the recorded exit error once the client is closed.
 */
private async send(message: JsonObject): Promise<void> {
  if (!this.proc?.stdin || this.closed) {
    throw (
      this.exitError ?? new Error("codex app-server stdin is not available.")
    );
  }
  const line = `${JSON.stringify(message)}\n`;
  // Fast path: the kernel buffer accepted the write.
  if (this.proc.stdin.write(line)) return;
  await new Promise<void>((resolve, reject) => {
    const stdin = this.proc?.stdin;
    if (!stdin) {
      reject(new Error("codex app-server stdin is not available."));
      return;
    }
    const onDrain = () => {
      cleanup();
      resolve();
    };
    const onError = (error: Error) => {
      cleanup();
      reject(error);
    };
    // Remove both listeners whichever event fires first, to avoid leaks.
    const cleanup = () => {
      stdin.off("drain", onDrain);
      stdin.off("error", onError);
    };
    stdin.once("drain", onDrain);
    stdin.once("error", onError);
  });
}
|
||||
|
||||
private handleLine(line: string): void {
|
||||
let message: unknown;
|
||||
try {
|
||||
message = JSON.parse(line);
|
||||
} catch {
|
||||
return;
|
||||
}
|
||||
if (!message || typeof message !== "object") return;
|
||||
const object = message as JsonObject;
|
||||
|
||||
if (typeof object.id === "number" && !("method" in object)) {
|
||||
this.handleResponse(object as unknown as JsonRpcResponse);
|
||||
return;
|
||||
}
|
||||
if (typeof object.method === "string" && typeof object.id === "number") {
|
||||
this.handleServerRequest(object as unknown as JsonRpcServerRequest);
|
||||
return;
|
||||
}
|
||||
if (typeof object.method === "string") {
|
||||
this.handleNotification({ method: object.method, params: object.params });
|
||||
}
|
||||
}
|
||||
|
||||
private handleResponse(response: JsonRpcResponse): void {
|
||||
const pending = this.pending.get(response.id);
|
||||
if (!pending) return;
|
||||
this.pending.delete(response.id);
|
||||
if (response.error) {
|
||||
pending.reject(
|
||||
new Error(`${response.error.message} (code ${response.error.code})`),
|
||||
);
|
||||
return;
|
||||
}
|
||||
pending.resolve(response.result);
|
||||
}
|
||||
|
||||
private handleNotification(notification: CodexAppServerNotification): void {
|
||||
for (const handler of this.notificationHandlers) {
|
||||
handler(notification);
|
||||
}
|
||||
}
|
||||
|
||||
private handleServerRequest(request: JsonRpcServerRequest): void {
|
||||
this.send({
|
||||
id: request.id,
|
||||
error: {
|
||||
code: -32601,
|
||||
message: `Unsupported codex app-server request: ${request.method}`,
|
||||
},
|
||||
}).catch(() => {});
|
||||
}
|
||||
|
||||
private handleExit(error?: Error): void {
|
||||
if (this.closed && !error) return;
|
||||
this.closed = true;
|
||||
this.exitError = error ?? new Error("codex app-server connection closed.");
|
||||
clearCodexAppServerClient(this);
|
||||
this.readline?.close();
|
||||
this.rejectPending(this.exitError);
|
||||
}
|
||||
|
||||
private rejectPending(error: Error): void {
|
||||
for (const pending of this.pending.values()) {
|
||||
pending.reject(error);
|
||||
}
|
||||
this.pending.clear();
|
||||
}
|
||||
|
||||
/**
 * Normalize a spawn failure. ENOENT (binary not on PATH) becomes an
 * actionable "install the Codex CLI" message; everything else passes through.
 */
private toSpawnError(error: unknown): Error {
  const base = error instanceof Error ? error : new Error(String(error));
  const code = (base as Error & { code?: string }).code;
  if (code !== "ENOENT") {
    return base;
  }
  return new Error(
    "Codex CLI was not found. Install the OpenAI Codex CLI and make sure `codex` is on PATH, then run `codex login` if needed.",
  );
}
|
||||
|
||||
/**
 * Format the captured stderr tail for appending to error messages.
 * Returns "" when no stderr was captured.
 */
private stderrSuffix(): string {
  const tail = this.stderr.trim();
  if (tail.length === 0) return "";
  return ` stderr: ${tail}`;
}
|
||||
}
|
||||
|
|
@ -1,37 +0,0 @@
|
|||
import type { Message } from "../types.js";
|
||||
|
||||
// Copilot expects X-Initiator to indicate whether the request is user-initiated
|
||||
// or agent-initiated (e.g. follow-up after assistant/tool messages).
|
||||
function inferCopilotInitiator(messages: Message[]): "user" | "agent" {
|
||||
const last = messages[messages.length - 1];
|
||||
return last && last.role !== "user" ? "agent" : "user";
|
||||
}
|
||||
|
||||
// Copilot requires Copilot-Vision-Request header when sending images
|
||||
export function hasCopilotVisionInput(messages: Message[]): boolean {
|
||||
return messages.some((msg) => {
|
||||
if (msg.role === "user" && Array.isArray(msg.content)) {
|
||||
return msg.content.some((c) => c.type === "image");
|
||||
}
|
||||
if (msg.role === "toolResult" && Array.isArray(msg.content)) {
|
||||
return msg.content.some((c) => c.type === "image");
|
||||
}
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
export function buildCopilotDynamicHeaders(params: {
|
||||
messages: Message[];
|
||||
hasImages: boolean;
|
||||
}): Record<string, string> {
|
||||
const headers: Record<string, string> = {
|
||||
"X-Initiator": inferCopilotInitiator(params.messages),
|
||||
"Openai-Intent": "conversation-edits",
|
||||
};
|
||||
|
||||
if (params.hasImages) {
|
||||
headers["Copilot-Vision-Request"] = "true";
|
||||
}
|
||||
|
||||
return headers;
|
||||
}
|
||||
|
|
@ -1,133 +0,0 @@
|
|||
# Re-platforming `google-gemini-cli` onto `@google/gemini-cli-core`
|
||||
|
||||
**Status:** Dependency installed (2026-04-19). Refactor pending next iteration.
|
||||
|
||||
## Goal
|
||||
|
||||
Replace the handwritten `fetch()` transport in `google-gemini-cli.ts` with calls
|
||||
into `@google/gemini-cli-core`'s `CodeAssistServer` so requests to
|
||||
`cloudcode-pa.googleapis.com` are byte-for-byte indistinguishable from the
|
||||
official `gemini` CLI. Upside: free OAuth quota treatment, automatic inheritance
|
||||
of upstream improvements, no reverse-engineered User-Agent / Client-Metadata
|
||||
drift.
|
||||
|
||||
## Scope
|
||||
|
||||
**In-scope**
|
||||
- `provider: "google-gemini-cli"` stream paths in `google-gemini-cli.ts`
|
||||
(functions `streamGoogleGeminiCli` and `streamSimpleGoogleGeminiCli`).
|
||||
|
||||
**Out-of-scope (keep handwritten)**
|
||||
- `provider: "google-antigravity"` — different sandbox endpoints
|
||||
(`daily-cloudcode-pa.sandbox.googleapis.com`), different auth contract
|
||||
(Antigravity IDE-scoped), different User-Agent requirements. cli-core
|
||||
does not target these endpoints.
|
||||
- `provider: "google"` (API key) and `provider: "google-vertex"` — unrelated
|
||||
transports, stay on `@google/genai` directly.
|
||||
|
||||
## API mapping (cli-core 0.38.2)
|
||||
|
||||
| Today (handwritten) | After (cli-core) |
|
||||
|------------------------------------------------------------|------------------------------------------------------------------------|
|
||||
| `fetch(CLOUD_CODE_ASSIST_ENDPOINT + ":streamGenerateContent?alt=sse", …)` | `await server.generateContentStream(req, promptId, role)` returns `AsyncGenerator<GenerateContentResponse>` |
|
||||
| Manual SSE body parsing (`response.body.getReader()` + `TextDecoder`) | cli-core yields already-parsed `GenerateContentResponse` chunks |
|
||||
| Custom retry loop (429/5xx with backoff, endpoint cascade) | cli-core has internal retry in `requestStreamingPost` |
|
||||
| Header assembly (`User-Agent`, `X-Goog-Api-Client`, `Client-Metadata`) | cli-core sets its own correct headers; just pass `httpOptions.headers` for extras |
|
||||
| OAuth token carried in SF `apiKey` as `{ token, projectId }` JSON | Either keep (build `OAuth2Client` + set credentials) OR let cli-core load from `~/.gemini/oauth_creds.json` via `getOauthClient()` |
|
||||
|
||||
Relevant cli-core exports:
|
||||
|
||||
```ts
|
||||
import { CodeAssistServer, CODE_ASSIST_ENDPOINT, type HttpOptions } from "@google/gemini-cli-core";
|
||||
import { getOauthClient } from "@google/gemini-cli-core/dist/src/code_assist/oauth2.js";
|
||||
import { AuthType } from "@google/gemini-cli-core";
|
||||
import type { GenerateContentParameters, GenerateContentResponse } from "@google/genai";
|
||||
```
|
||||
|
||||
## Two integration strategies
|
||||
|
||||
### Strategy A: Transport-only (incremental, lower risk)
|
||||
|
||||
Keep SF's existing auth storage (`apiKey` JSON blob with `{ token, projectId }`).
|
||||
At each request:
|
||||
|
||||
```ts
|
||||
import { OAuth2Client } from "google-auth-library";
|
||||
import { CodeAssistServer } from "@google/gemini-cli-core";
|
||||
|
||||
const authClient = new OAuth2Client();
|
||||
authClient.setCredentials({ access_token: token });
|
||||
const server = new CodeAssistServer(authClient, projectId, {
|
||||
headers: { /* extras if any */ },
|
||||
});
|
||||
|
||||
for await (const chunk of await server.generateContentStream(req, promptId, "USER")) {
|
||||
// feed chunk into existing AssistantMessageEventStream adapter
|
||||
}
|
||||
```
|
||||
|
||||
Pros: no SF auth-layer changes, minimal blast radius.
|
||||
Cons: SF still does OAuth refresh manually; cli-core's auto-refresh benefit lost.
|
||||
|
||||
### Strategy B: Full cli-core auth (target state)
|
||||
|
||||
Drop the `apiKey` unpacking for `google-gemini-cli`. At provider init:
|
||||
|
||||
```ts
|
||||
const authClient = await getOauthClient(AuthType.LOGIN_WITH_GOOGLE, config);
|
||||
const server = new CodeAssistServer(authClient, projectId);
|
||||
```
|
||||
|
||||
cli-core reads `~/.gemini/oauth_creds.json` (migrated to keychain on newer
|
||||
installs), refreshes tokens, writes back. SF's `/login` flow for this provider
|
||||
becomes "let cli-core own the login flow" instead of reimplementing Google OAuth in SF.
|
||||
|
||||
Pros: full integration benefit, SF drops ~80 lines of auth management.
|
||||
Cons: breaks existing SF auth storage path for this provider; users must
|
||||
re-authenticate via `gemini` CLI at least once.
|
||||
|
||||
Recommendation: **A first** (one commit, verifiable), **B second** as a
|
||||
follow-up once A is stable.
|
||||
|
||||
## Implementation checklist (Strategy A)
|
||||
|
||||
1. Add factory helper `buildCodeAssistServer(token, projectId)` that constructs
|
||||
`OAuth2Client` + `CodeAssistServer`. Put it alongside the existing helpers
|
||||
near the top of `google-gemini-cli.ts`.
|
||||
2. In `streamGoogleGeminiCli` (line 320): branch on `model.provider`. When
|
||||
`"google-gemini-cli"`, use the new helper and replace the `fetch()` block
|
||||
(lines ~392-450) with `server.generateContentStream()` consumption. When
|
||||
`"google-antigravity"`, keep the existing codepath unchanged.
|
||||
3. Convert cli-core's `GenerateContentResponse` chunks to SSE-equivalent
|
||||
processing via the existing `processStreamChunk` helper (or inline the
|
||||
minimal equivalent — cli-core already parses the JSON).
|
||||
4. Remove `GEMINI_CLI_HEADERS` constant (cli-core sets its own).
|
||||
5. Keep `ANTIGRAVITY_*` constants for the antigravity path.
|
||||
6. Update `streamSimpleGoogleGeminiCli` similarly.
|
||||
7. Tests:
|
||||
- Replace `global.fetch` mocks targeting `cloudcode-pa.googleapis.com` with
|
||||
`CodeAssistServer` prototype mocks (`generateContentStream` returns a
|
||||
mocked AsyncGenerator).
|
||||
- Keep antigravity tests unchanged (still fetch-based).
|
||||
8. Live smoke test against a `gemini-*` model in dr-repo or a scratch project,
|
||||
confirm OAuth flow works, streaming response arrives, cost is reported.
|
||||
|
||||
## Retry semantics
|
||||
|
||||
cli-core's internal retry on `requestStreamingPost` handles 429/5xx with
|
||||
exponential backoff and consults `Retry-After` headers. That subsumes the
|
||||
existing `MAX_RETRIES` / `BASE_DELAY_MS` loop in SF for this provider.
|
||||
Keep the loop for antigravity (different endpoint, different quirks).
|
||||
|
||||
`extractRetryDelay` and `isRetryableError` helpers become antigravity-only.
|
||||
|
||||
## Why this matters (recap)
|
||||
|
||||
- **Free OAuth quota**: Google subsidises the official CLI's free tier. Our
|
||||
requests blending in byte-for-byte preserves access.
|
||||
- **Bot-detection resilience**: User-Agent / Client-Metadata drift risk goes
|
||||
to zero — cli-core is the authoritative client.
|
||||
- **Upstream improvements**: new tool formats, grounding, session caching,
|
||||
quota displays ship via `npm update @google/gemini-cli-core`.
|
||||
- **Our proxy becomes "the CLI, programmable"**: identical upstream behavior,
|
||||
hookable local endpoint for any OpenAI-compatible tool.
|
||||
|
|
@ -1,83 +0,0 @@
|
|||
import assert from "node:assert/strict";
|
||||
import { describe, test, vi } from "vitest";
|
||||
import type { Context, Model } from "../types.js";
|
||||
|
||||
// Shared mutable state for the mocks below. vi.hoisted() makes it available
// inside vi.mock factories, which are hoisted above normal module evaluation.
const geminiCliCore = vi.hoisted(() => ({
  // Error the mocked retryWithBackoff should throw for the test scenario.
  retryError: undefined as Error | undefined,
  // Captures the options retryWithBackoff was called with (asserted on).
  retryOptions: undefined as Record<string, unknown> | undefined,
  // Captures the args passed to createGeminiCliContentGenerator (asserted on).
  helperArgs: undefined as Record<string, unknown> | undefined,
}));

// Mock cli-core: record the retry options and fail the turn so the provider's
// error path (not cli-core's internal retry loop) is exercised.
vi.mock("@google/gemini-cli-core", () => ({
  CodeAssistServer: class {
    async generateContentStream(): Promise<AsyncGenerator<unknown>> {
      return (async function* emptyStream() {})();
    }
  },
  retryWithBackoff: vi.fn(
    async (_fn: unknown, options: Record<string, unknown>) => {
      geminiCliCore.retryOptions = options;
      throw geminiCliCore.retryError ?? new Error("quota exhausted");
    },
  ),
}));

// Mock the helper package: capture its args and hand back a generator that
// yields nothing (the retry mock throws before the stream is consumed).
vi.mock("@singularity-forge/google-gemini-cli-provider", () => ({
  createGeminiCliContentGenerator: vi.fn(
    async (args: Record<string, unknown>) => {
      geminiCliCore.helperArgs = args;
      return {
        async generateContentStream(): Promise<AsyncGenerator<unknown>> {
          return (async function* emptyStream() {})();
        },
      };
    },
  ),
}));
|
||||
|
||||
import { streamGoogleGeminiCli } from "./google-gemini-cli.js";
|
||||
|
||||
function makeModel(): Model<"google-gemini-cli"> {
|
||||
return {
|
||||
id: "gemini-3-flash-preview",
|
||||
name: "Gemini 3 Flash Preview",
|
||||
api: "google-gemini-cli",
|
||||
provider: "google-gemini-cli",
|
||||
baseUrl: "",
|
||||
reasoning: true,
|
||||
input: ["text"],
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: 1_000_000,
|
||||
maxTokens: 8192,
|
||||
};
|
||||
}
|
||||
|
||||
function makeContext(): Context {
|
||||
return {
|
||||
messages: [{ role: "user", content: "hello", timestamp: 0 }],
|
||||
};
|
||||
}
|
||||
|
||||
describe("google-gemini-cli provider retry ownership", () => {
  test("google_gemini_cli_when_quota_resets_soon_returns_error_to_caller_without_cli_retry_loop", async () => {
    // Arrange: make the mocked retryWithBackoff throw a quota-exhausted error
    // that carries both a structured retryDelayMs hint and the human-readable
    // "reset after 54s" text.
    geminiCliCore.retryOptions = undefined;
    geminiCliCore.retryError = Object.assign(
      new Error(
        "You have exhausted your capacity on this model. Your quota will reset after 54s.",
      ),
      { retryDelayMs: 54_000 },
    );

    const stream = streamGoogleGeminiCli(makeModel(), makeContext());
    const result = await stream.result();

    // The provider must opt out of cli-core's internal retry loop
    // (maxAttempts: 1) so SF owns cross-model fallback.
    const retryOptions = geminiCliCore.retryOptions as
      | { maxAttempts?: unknown }
      | undefined;
    assert.equal(retryOptions?.maxAttempts, 1);
    // The helper package receives the model id for transport setup.
    assert.equal(geminiCliCore.helperArgs?.modelId, "gemini-3-flash-preview");
    // The quota error surfaces to the caller with the retry hint preserved.
    assert.equal(result.stopReason, "error");
    assert.match(result.errorMessage ?? "", /exhausted your capacity/i);
    assert.equal(result.retryAfterMs, 54_000);
  });
});
|
||||
|
|
@ -1,638 +0,0 @@
|
|||
/**
|
||||
* Google Gemini CLI provider.
|
||||
*
|
||||
* Delegates auth, project discovery, and the Code Assist transport setup to
|
||||
* the dedicated google-gemini-cli-provider package.
|
||||
* Request retry/fallback stays in the caller so SF can move to the next model.
|
||||
*/
|
||||
|
||||
import { retryWithBackoff } from "@google/gemini-cli-core";
|
||||
import type {
|
||||
Content,
|
||||
GenerateContentParameters,
|
||||
ThinkingConfig,
|
||||
} from "@google/genai";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
TextContent,
|
||||
ThinkingBudgets,
|
||||
ThinkingContent,
|
||||
ThinkingLevel,
|
||||
ToolCall,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import {
|
||||
convertMessages,
|
||||
convertTools,
|
||||
isThinkingPart,
|
||||
mapStopReasonString,
|
||||
mapToolChoice,
|
||||
retainThoughtSignature,
|
||||
} from "./google-shared.js";
|
||||
import {
|
||||
buildBaseOptions,
|
||||
clampReasoning,
|
||||
isAutoReasoning,
|
||||
resolveReasoningLevel,
|
||||
} from "./simple-options.js";
|
||||
import { createGeminiCliContentGenerator } from "@singularity-forge/google-gemini-cli-provider";
|
||||
|
||||
/**
 * Thinking level for Gemini 3 models.
 *
 * Gemini 3 Pro supports LOW/HIGH; Gemini 3 Flash supports MINIMAL/LOW/MEDIUM/HIGH.
 * These are the wire format values for `ThinkingConfig.thinkingLevel` sent to cli-core's
 * `CodeAssistServer.generateContentStream()`.
 *
 * Gemini 2.x models do not use this enum; they take a token budget instead
 * (see `GoogleGeminiCliOptions.thinking.budgetTokens`).
 */
export type GoogleThinkingLevel =
  | "THINKING_LEVEL_UNSPECIFIED"
  | "MINIMAL"
  | "LOW"
  | "MEDIUM"
  | "HIGH";
|
||||
|
||||
/**
 * Options for `streamGoogleGeminiCli()`.
 *
 * Delegates auth to the helper package (reads ~/.gemini/oauth_creds.json via
 * Gemini CLI Core's transport setup);
 * `projectId` is auto-discovered and not used by this provider (apiKey is ignored).
 * Thinking is configured separately from base `StreamOptions` because Gemini 2 and 3
 * models use incompatible enum formats (budgetTokens vs. level).
 */
export interface GoogleGeminiCliOptions extends StreamOptions {
  // Tool-selection policy forwarded to the request builder.
  toolChoice?: "auto" | "none" | "any";
  /**
   * Thinking/reasoning configuration.
   * - Gemini 2.x models: use `budgetTokens` to set the thinking budget
   * - Gemini 3 models (gemini-3-pro-*, gemini-3-flash-*): use `level` instead
   *
   * When using `streamSimple`, this is handled automatically based on the model.
   */
  thinking?: {
    enabled: boolean;
    /** Thinking budget in tokens. Use for Gemini 2.x models. */
    budgetTokens?: number;
    /** Thinking level. Use for Gemini 3 models (LOW/HIGH for Pro, MINIMAL/LOW/MEDIUM/HIGH for Flash). */
    level?: GoogleThinkingLevel;
  };
  // NOTE(review): project discovery is automatic per the doc above, so this
  // field looks unused by this provider — confirm against callers before removal.
  projectId?: string;
}
|
||||
|
||||
// Counter for generating unique tool call IDs.
// Module-level so IDs stay unique across all streams in this process
// (combined with the function name and Date.now() when minting an ID).
let toolCallCounter = 0;
|
||||
|
||||
function parseDurationMs(value: string): number | undefined {
|
||||
const match = value.match(/(?:(\d+)h)?(?:(\d+)m)?(?:(\d+)s)?/i);
|
||||
if (!match || !match[0]) return undefined;
|
||||
const hours = Number(match[1] ?? 0);
|
||||
const minutes = Number(match[2] ?? 0);
|
||||
const seconds = Number(match[3] ?? 0);
|
||||
const totalMs = ((hours * 60 + minutes) * 60 + seconds) * 1000;
|
||||
return totalMs > 0 ? totalMs : undefined;
|
||||
}
|
||||
|
||||
function extractRetryAfterMs(error: unknown): number | undefined {
|
||||
if (typeof error === "object" && error !== null && "retryDelayMs" in error) {
|
||||
const retryDelayMs = (error as { retryDelayMs?: unknown }).retryDelayMs;
|
||||
if (
|
||||
typeof retryDelayMs === "number" &&
|
||||
Number.isFinite(retryDelayMs) &&
|
||||
retryDelayMs > 0
|
||||
) {
|
||||
return retryDelayMs;
|
||||
}
|
||||
}
|
||||
const message =
|
||||
error instanceof Error ? error.message : JSON.stringify(error);
|
||||
const resetMatch = message.match(
|
||||
/(?:quota will reset|reset) after ([0-9hms]+)/i,
|
||||
);
|
||||
return resetMatch?.[1] ? parseDurationMs(resetMatch[1]) : undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the model is a Gemini 3 Pro variant (gemini-3*-pro).
|
||||
* Used to determine which thinking config enum to use (thinkingLevel vs. budgetTokens).
|
||||
*/
|
||||
function isGemini3ProModel(modelId: string): boolean {
|
||||
return /gemini-3(?:\.1)?-pro/.test(modelId.toLowerCase());
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the model is a Gemini 3 Flash variant (gemini-3*-flash).
|
||||
* Used to determine which thinking config enum to use (thinkingLevel vs. budgetTokens).
|
||||
*/
|
||||
function isGemini3FlashModel(modelId: string): boolean {
|
||||
return /gemini-3(?:\.1)?-flash/.test(modelId.toLowerCase());
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the model is any Gemini 3 variant (Pro or Flash).
|
||||
* Determines whether to use thinkingLevel enum (Gemini 3) vs. budgetTokens (Gemini 2.x).
|
||||
*/
|
||||
function isGemini3Model(modelId: string): boolean {
|
||||
return isGemini3ProModel(modelId) || isGemini3FlashModel(modelId);
|
||||
}
|
||||
|
||||
/**
 * Stream a chat completion from Google Gemini via the helper package and cli-core transport.
 *
 * The helper package owns the OAuth/bootstrap path against `@google/gemini-cli-core`, including
 * `~/.gemini/oauth_creds.json` and Gemini Code Assist project discovery. `apiKey` is ignored.
 * Casting the request as `any` works around the fact that cli-core bundles its own nested
 * `@google/genai` copy (nominal type split at packaging time; runtime shapes are byte-identical).
 * Returns a real-time stream emitting start, delta, end, and error events that accumulate into
 * an `AssistantMessage`.
 */
export const streamGoogleGeminiCli: StreamFunction<
  "google-gemini-cli",
  GoogleGeminiCliOptions
> = (
  model: Model<"google-gemini-cli">,
  context: Context,
  options?: GoogleGeminiCliOptions,
): AssistantMessageEventStream => {
  const stream = new AssistantMessageEventStream();

  // Fire-and-forget async worker; all outcomes are reported through `stream`.
  (async () => {
    // Accumulator for the final AssistantMessage; also pushed as `partial`
    // with every event so consumers can render progress.
    const output: AssistantMessage = {
      role: "assistant",
      content: [],
      api: "google-gemini-cli" as Api,
      provider: model.provider,
      model: model.id,
      usage: {
        input: 0,
        output: 0,
        cacheRead: 0,
        cacheWrite: 0,
        totalTokens: 0,
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
      },
      stopReason: "stop",
      timestamp: Date.now(),
    };

    try {
      let req = buildRequest(model, context, options);
      // Give the caller a chance to inspect/replace the outgoing payload.
      const nextReq = await options?.onPayload?.(req, model);
      if (nextReq !== undefined) {
        req = nextReq as GenerateContentParameters;
      }
      // cli-core handles auth + project discovery through the helper package.
      const server = await createGeminiCliContentGenerator({
        modelId: req.model,
      });
      const promptId = `pi-${Date.now()}-${Math.random().toString(36).slice(2, 11)}`;
      // Cast through `any` — cli-core bundles its own nested @google/genai copy,
      // so TypeScript sees two structurally-identical-but-distinct Content types.
      // The runtime shapes are byte-identical; the nominal split is a packaging
      // artefact.
      const streamGen = await retryWithBackoff(
        () => server.generateContentStream(req as any, promptId, "USER" as any),
        {
          // SF owns cross-model fallback. Let cli-core classify quota errors,
          // but do not let it hold the turn through its 10-attempt retry loop.
          maxAttempts: 1,
          signal: options?.signal,
        },
      );

      // The `start` event is deferred until the first content arrives.
      let started = false;
      const ensureStarted = () => {
        if (!started) {
          stream.push({ type: "start", partial: output });
          started = true;
        }
      };

      // Currently-open text/thinking block; tool calls close it.
      let currentBlock: TextContent | ThinkingContent | null = null;
      const blocks = output.content;
      const blockIndex = () => blocks.length - 1;

      for await (const chunk of streamGen) {
        if (options?.signal?.aborted) {
          throw new Error("Request was aborted");
        }

        const candidate = chunk?.candidates?.[0];
        if (candidate?.content?.parts) {
          for (const part of candidate.content.parts) {
            // Text / thinking block handling
            if (part.text !== undefined) {
              const isThinking = isThinkingPart(part);
              // Open a new block when none exists or the kind changed
              // (text <-> thinking), closing the previous one first.
              if (
                !currentBlock ||
                (isThinking && currentBlock.type !== "thinking") ||
                (!isThinking && currentBlock.type !== "text")
              ) {
                if (currentBlock) {
                  if (currentBlock.type === "text") {
                    stream.push({
                      type: "text_end",
                      contentIndex: blockIndex(),
                      content: currentBlock.text,
                      partial: output,
                    });
                  } else {
                    stream.push({
                      type: "thinking_end",
                      contentIndex: blockIndex(),
                      content: currentBlock.thinking,
                      partial: output,
                    });
                  }
                }
                if (isThinking) {
                  currentBlock = {
                    type: "thinking",
                    thinking: "",
                    thinkingSignature: undefined,
                  };
                  output.content.push(currentBlock);
                  ensureStarted();
                  stream.push({
                    type: "thinking_start",
                    contentIndex: blockIndex(),
                    partial: output,
                  });
                } else {
                  currentBlock = { type: "text", text: "" };
                  output.content.push(currentBlock);
                  ensureStarted();
                  stream.push({
                    type: "text_start",
                    contentIndex: blockIndex(),
                    partial: output,
                  });
                }
              }
              if (currentBlock.type === "thinking") {
                currentBlock.thinking += part.text;
                // Keep the last-seen thought signature for replay/caching.
                currentBlock.thinkingSignature = retainThoughtSignature(
                  currentBlock.thinkingSignature,
                  part.thoughtSignature,
                );
                stream.push({
                  type: "thinking_delta",
                  contentIndex: blockIndex(),
                  delta: part.text,
                  partial: output,
                });
              } else {
                currentBlock.text += part.text;
                currentBlock.textSignature = retainThoughtSignature(
                  currentBlock.textSignature,
                  part.thoughtSignature,
                );
                stream.push({
                  type: "text_delta",
                  contentIndex: blockIndex(),
                  delta: part.text,
                  partial: output,
                });
              }
            }

            // Tool-call part
            if (part.functionCall) {
              // A tool call always closes any open text/thinking block.
              if (currentBlock) {
                if (currentBlock.type === "text") {
                  stream.push({
                    type: "text_end",
                    contentIndex: blockIndex(),
                    content: currentBlock.text,
                    partial: output,
                  });
                } else {
                  stream.push({
                    type: "thinking_end",
                    contentIndex: blockIndex(),
                    content: currentBlock.thinking,
                    partial: output,
                  });
                }
                currentBlock = null;
              }

              // Mint a fresh id when the API omitted one or reused one
              // already present in this message.
              const providedId = part.functionCall.id;
              const needsNewId =
                !providedId ||
                output.content.some(
                  (b) => b.type === "toolCall" && b.id === providedId,
                );
              const toolCallId = needsNewId
                ? `${part.functionCall.name}_${Date.now()}_${++toolCallCounter}`
                : providedId;

              const toolCall: ToolCall = {
                type: "toolCall",
                id: toolCallId,
                name: part.functionCall.name || "",
                arguments:
                  (part.functionCall.args as Record<string, unknown>) ?? {},
                ...(part.thoughtSignature && {
                  thoughtSignature: part.thoughtSignature,
                }),
              };

              output.content.push(toolCall);
              ensureStarted();
              // Tool calls arrive whole, so start/delta/end are emitted
              // back-to-back for consumer-API uniformity.
              stream.push({
                type: "toolcall_start",
                contentIndex: blockIndex(),
                partial: output,
              });
              stream.push({
                type: "toolcall_delta",
                contentIndex: blockIndex(),
                delta: JSON.stringify(toolCall.arguments),
                partial: output,
              });
              stream.push({
                type: "toolcall_end",
                contentIndex: blockIndex(),
                toolCall,
                partial: output,
              });
            }
          }
        }

        if (candidate?.finishReason) {
          output.stopReason = mapStopReasonString(candidate.finishReason);
          // Any tool call overrides the model-reported finish reason.
          if (output.content.some((b) => b.type === "toolCall")) {
            output.stopReason = "toolUse";
          }
        }

        if (chunk?.usageMetadata) {
          const promptTokens = chunk.usageMetadata.promptTokenCount || 0;
          const cacheReadTokens =
            chunk.usageMetadata.cachedContentTokenCount || 0;
          // promptTokenCount includes cached tokens; report them separately.
          output.usage = {
            input: promptTokens - cacheReadTokens,
            output:
              (chunk.usageMetadata.candidatesTokenCount || 0) +
              (chunk.usageMetadata.thoughtsTokenCount || 0),
            cacheRead: cacheReadTokens,
            cacheWrite: 0,
            totalTokens: chunk.usageMetadata.totalTokenCount || 0,
            cost: {
              input: 0,
              output: 0,
              cacheRead: 0,
              cacheWrite: 0,
              total: 0,
            },
          };
          calculateCost(model, output.usage);
        }
      }

      // Close any open text/thinking block after stream ends
      if (currentBlock) {
        if (currentBlock.type === "text") {
          stream.push({
            type: "text_end",
            contentIndex: blockIndex(),
            content: currentBlock.text,
            partial: output,
          });
        } else {
          stream.push({
            type: "thinking_end",
            contentIndex: blockIndex(),
            content: currentBlock.thinking,
            partial: output,
          });
        }
      }

      if (options?.signal?.aborted) {
        throw new Error("Request was aborted");
      }

      // These stop reasons only ever come through the error path; surfacing
      // them here means something upstream misbehaved.
      if (output.stopReason === "aborted" || output.stopReason === "error") {
        throw new Error("An unknown error occurred");
      }

      stream.push({ type: "done", reason: output.stopReason, message: output });
      stream.end();
    } catch (error) {
      // Strip transient bookkeeping fields before handing the partial out.
      for (const block of output.content) {
        if ("index" in block) {
          delete (block as { index?: number }).index;
        }
      }
      output.stopReason = options?.signal?.aborted ? "aborted" : "error";
      output.errorMessage =
        error instanceof Error ? error.message : JSON.stringify(error);
      // Preserve quota-reset hints so the caller can schedule a retry.
      const retryAfterMs = extractRetryAfterMs(error);
      if (retryAfterMs !== undefined) {
        output.retryAfterMs = retryAfterMs;
      }
      stream.push({ type: "error", reason: output.stopReason, error: output });
      stream.end();
    }
  })();

  return stream;
};
|
||||
|
||||
/**
|
||||
* Simplified stream wrapper that auto-configures thinking based on model and reasoning level.
|
||||
*
|
||||
* Reasoning intent is resolved via `buildBaseOptions()` and the `reasoning` flag in `SimpleStreamOptions`.
|
||||
* For Gemini 3 models, uses the thinkingLevel enum (LOW/HIGH for Pro, MINIMAL/LOW/MEDIUM/HIGH for Flash).
|
||||
* For Gemini 2.x, maps the requested level to token budgets (default: minimal=1K, low=2K, medium=8K, high=16K).
|
||||
* Auth is still handled by cli-core (apiKey is ignored). Returns the same `AssistantMessageEventStream`
|
||||
* as `streamGoogleGeminiCli()` after delegating with appropriate `thinking` config.
|
||||
*/
|
||||
export const streamSimpleGoogleGeminiCli: StreamFunction<
|
||||
"google-gemini-cli",
|
||||
SimpleStreamOptions
|
||||
> = (
|
||||
model: Model<"google-gemini-cli">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
// cli-core sources auth from ~/.gemini/ — apiKey not required.
|
||||
const base = buildBaseOptions(model, options, options?.apiKey ?? "");
|
||||
if (!options?.reasoning) {
|
||||
return streamGoogleGeminiCli(model, context, {
|
||||
...base,
|
||||
thinking: { enabled: false },
|
||||
} satisfies GoogleGeminiCliOptions);
|
||||
}
|
||||
|
||||
if (isAutoReasoning(options.reasoning)) {
|
||||
if (isGemini3Model(model.id)) {
|
||||
return streamGoogleGeminiCli(model, context, {
|
||||
...base,
|
||||
thinking: {
|
||||
enabled: true,
|
||||
level: "THINKING_LEVEL_UNSPECIFIED",
|
||||
},
|
||||
} satisfies GoogleGeminiCliOptions);
|
||||
}
|
||||
|
||||
return streamGoogleGeminiCli(model, context, {
|
||||
...base,
|
||||
thinking: {
|
||||
enabled: true,
|
||||
budgetTokens: -1,
|
||||
},
|
||||
} satisfies GoogleGeminiCliOptions);
|
||||
}
|
||||
|
||||
const effort = clampReasoning(
|
||||
resolveReasoningLevel(model, options.reasoning),
|
||||
)!;
|
||||
if (isGemini3Model(model.id)) {
|
||||
return streamGoogleGeminiCli(model, context, {
|
||||
...base,
|
||||
thinking: {
|
||||
enabled: true,
|
||||
level: getGeminiCliThinkingLevel(effort, model.id),
|
||||
},
|
||||
} satisfies GoogleGeminiCliOptions);
|
||||
}
|
||||
|
||||
const defaultBudgets: ThinkingBudgets = {
|
||||
minimal: 1024,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: 16384,
|
||||
};
|
||||
const budgets = { ...defaultBudgets, ...options.thinkingBudgets };
|
||||
|
||||
const minOutputTokens = 1024;
|
||||
let thinkingBudget = budgets[effort]!;
|
||||
const maxTokens = Math.min(
|
||||
(base.maxTokens || 0) + thinkingBudget,
|
||||
model.maxTokens,
|
||||
);
|
||||
|
||||
if (maxTokens <= thinkingBudget) {
|
||||
thinkingBudget = Math.max(0, maxTokens - minOutputTokens);
|
||||
}
|
||||
|
||||
return streamGoogleGeminiCli(model, context, {
|
||||
...base,
|
||||
maxTokens,
|
||||
thinking: {
|
||||
enabled: true,
|
||||
budgetTokens: thinkingBudget,
|
||||
},
|
||||
} satisfies GoogleGeminiCliOptions);
|
||||
};
|
||||
|
||||
/**
|
||||
* Build a `GenerateContentParameters` payload for cli-core's `CodeAssistServer.generateContentStream()`.
|
||||
*
|
||||
* This is the raw genai Content/Config shape (`@google/genai`), not the legacy Cloud Code Assist envelope.
|
||||
* cli-core wraps it with project, requestId, User-Agent, and retry logic; we only provide content/tools/config.
|
||||
* Unlike the old path, we do NOT need to set `project` or `requestId` — cli-core infers project from `setupUser()`.
|
||||
* Returns the exact shape the server's `generateContentStream()` method expects (casting through `any`
|
||||
* at the call site handles the vendored `@google/genai` type split).
|
||||
*/
|
||||
function buildRequest(
|
||||
model: Model<"google-gemini-cli">,
|
||||
context: Context,
|
||||
options: GoogleGeminiCliOptions = {},
|
||||
): GenerateContentParameters {
|
||||
const contents = convertMessages(model, context);
|
||||
|
||||
const config: NonNullable<GenerateContentParameters["config"]> = {};
|
||||
if (options.temperature !== undefined)
|
||||
config.temperature = options.temperature;
|
||||
if (options.maxTokens !== undefined)
|
||||
config.maxOutputTokens = options.maxTokens;
|
||||
|
||||
// Thinking config
|
||||
if (options.thinking?.enabled && model.reasoning) {
|
||||
const thinkingConfig: ThinkingConfig = { includeThoughts: true };
|
||||
// Gemini 3 models use thinkingLevel, older models use thinkingBudget
|
||||
if (options.thinking.level !== undefined) {
|
||||
thinkingConfig.thinkingLevel = options.thinking
|
||||
.level as ThinkingConfig["thinkingLevel"];
|
||||
} else if (options.thinking.budgetTokens !== undefined) {
|
||||
thinkingConfig.thinkingBudget = options.thinking.budgetTokens;
|
||||
}
|
||||
config.thinkingConfig = thinkingConfig;
|
||||
}
|
||||
|
||||
if (context.systemPrompt) {
|
||||
config.systemInstruction = {
|
||||
parts: [{ text: sanitizeSurrogates(context.systemPrompt) }],
|
||||
} as Content;
|
||||
}
|
||||
|
||||
if (context.tools && context.tools.length > 0) {
|
||||
// Claude models historically needed the legacy `parameters` field, but
|
||||
// Claude via gemini-cli is no longer supported (Antigravity was the
|
||||
// only path). Keep the useParameters=false default.
|
||||
const useParameters = false;
|
||||
config.tools = convertTools(context.tools, useParameters) as NonNullable<
|
||||
GenerateContentParameters["config"]
|
||||
>["tools"];
|
||||
if (options.toolChoice) {
|
||||
config.toolConfig = {
|
||||
functionCallingConfig: {
|
||||
mode: mapToolChoice(options.toolChoice),
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
model: model.id,
|
||||
contents,
|
||||
config,
|
||||
};
|
||||
}
|
||||
|
||||
type ClampedThinkingLevel = Exclude<ThinkingLevel, "xhigh">;
|
||||
|
||||
/**
|
||||
* Map a normalized thinking level (minimal/low/medium/high) to the Gemini 3 wire format.
|
||||
*
|
||||
* Gemini 3 Pro only supports LOW/HIGH (maps minimal/low -> LOW, medium/high -> HIGH).
|
||||
* Gemini 3 Flash supports all four (MINIMAL/LOW/MEDIUM/HIGH one-to-one).
|
||||
* Used when `options.thinking.level` is set for Gemini 3 models.
|
||||
*/
|
||||
function getGeminiCliThinkingLevel(
|
||||
effort: ClampedThinkingLevel,
|
||||
modelId: string,
|
||||
): GoogleThinkingLevel {
|
||||
if (isGemini3ProModel(modelId)) {
|
||||
switch (effort) {
|
||||
case "minimal":
|
||||
case "low":
|
||||
return "LOW";
|
||||
case "medium":
|
||||
case "high":
|
||||
return "HIGH";
|
||||
}
|
||||
}
|
||||
switch (effort) {
|
||||
case "minimal":
|
||||
return "MINIMAL";
|
||||
case "low":
|
||||
return "LOW";
|
||||
case "medium":
|
||||
return "MEDIUM";
|
||||
case "high":
|
||||
return "HIGH";
|
||||
}
|
||||
}
|
||||
|
|
@ -1,137 +0,0 @@
|
|||
// Unit tests for sanitizeSchemaForGoogle: verifies that JSON Schema fields
// Google's function-declarations API rejects (patternProperties, string
// `const`) are stripped or rewritten, recursively, without touching anything
// else.
import assert from "node:assert/strict";
import { describe, it } from "vitest";
import { sanitizeSchemaForGoogle } from "./google-shared.js";

// ═══════════════════════════════════════════════════════════════════════════
// sanitizeSchemaForGoogle
// ═══════════════════════════════════════════════════════════════════════════

describe("sanitizeSchemaForGoogle", () => {
  // Non-object values (null, numbers, strings, booleans) must be returned as-is.
  it("passes through primitives unchanged", () => {
    assert.equal(sanitizeSchemaForGoogle(null), null);
    assert.equal(sanitizeSchemaForGoogle(42), 42);
    assert.equal(sanitizeSchemaForGoogle("hello"), "hello");
    assert.equal(sanitizeSchemaForGoogle(true), true);
  });

  it("passes through a valid schema with no banned fields", () => {
    const schema = {
      type: "object",
      properties: {
        name: { type: "string" },
        age: { type: "number" },
      },
      required: ["name"],
    };
    assert.deepEqual(sanitizeSchemaForGoogle(schema), schema);
  });

  it("removes top-level patternProperties", () => {
    const schema = {
      type: "object",
      patternProperties: { "^S_": { type: "string" } },
      properties: { foo: { type: "string" } },
    };
    const result = sanitizeSchemaForGoogle(schema) as Record<string, unknown>;
    assert.ok(!("patternProperties" in result));
    assert.deepEqual(result.properties, { foo: { type: "string" } });
  });

  it("removes nested patternProperties", () => {
    const schema = {
      type: "object",
      properties: {
        nested: {
          type: "object",
          patternProperties: { ".*": { type: "string" } },
        },
      },
    };
    const result = sanitizeSchemaForGoogle(schema) as any;
    assert.ok(!("patternProperties" in result.properties.nested));
  });

  // `const: "x"` must become the equivalent single-value `enum: ["x"]`.
  it("converts top-level const to enum", () => {
    const schema = { const: "fixed-value" };
    const result = sanitizeSchemaForGoogle(schema) as Record<string, unknown>;
    assert.deepEqual(result.enum, ["fixed-value"]);
    assert.ok(!("const" in result));
  });

  it("converts const to enum inside anyOf", () => {
    const schema = {
      anyOf: [{ const: "a" }, { const: "b" }, { type: "string" }],
    };
    const result = sanitizeSchemaForGoogle(schema) as any;
    assert.deepEqual(result.anyOf[0], { enum: ["a"] });
    assert.deepEqual(result.anyOf[1], { enum: ["b"] });
    assert.deepEqual(result.anyOf[2], { type: "string" });
  });

  it("converts const to enum inside oneOf", () => {
    const schema = {
      oneOf: [{ const: "x" }, { const: "y" }],
    };
    const result = sanitizeSchemaForGoogle(schema) as any;
    assert.deepEqual(result.oneOf[0], { enum: ["x"] });
    assert.deepEqual(result.oneOf[1], { enum: ["y"] });
  });

  it("recursively sanitizes deeply nested schemas", () => {
    const schema = {
      type: "object",
      properties: {
        level1: {
          type: "object",
          properties: {
            level2: {
              anyOf: [{ const: "deep" }, { type: "null" }],
              patternProperties: { ".*": { type: "string" } },
            },
          },
        },
      },
    };
    const result = sanitizeSchemaForGoogle(schema) as any;
    const level2 = result.properties.level1.properties.level2;
    assert.deepEqual(level2.anyOf[0], { enum: ["deep"] });
    assert.ok(!("patternProperties" in level2));
  });

  it("sanitizes items in array schemas", () => {
    const schema = {
      type: "array",
      items: {
        anyOf: [{ const: "foo" }, { type: "string" }],
      },
    };
    const result = sanitizeSchemaForGoogle(schema) as any;
    assert.deepEqual(result.items.anyOf[0], { enum: ["foo"] });
  });

  // A bare array of schemas is sanitized element-wise.
  it("sanitizes arrays of schemas", () => {
    const input = [{ const: "a" }, { const: "b" }];
    const result = sanitizeSchemaForGoogle(input) as any[];
    assert.deepEqual(result[0], { enum: ["a"] });
    assert.deepEqual(result[1], { enum: ["b"] });
  });

  it("preserves non-string const values unchanged", () => {
    // Only string const values are converted; number const is passed through
    const schema = { const: 42 };
    const result = sanitizeSchemaForGoogle(schema) as Record<string, unknown>;
    assert.equal(result.const, 42);
    assert.ok(!("enum" in result));
  });

  it("sanitizes additionalProperties", () => {
    const schema = {
      type: "object",
      additionalProperties: {
        patternProperties: { "^x-": { type: "string" } },
      },
    };
    const result = sanitizeSchemaForGoogle(schema) as any;
    assert.ok(!("patternProperties" in result.additionalProperties));
  });
});
|
||||
|
|
@ -1,423 +0,0 @@
|
|||
/**
|
||||
* Shared utilities for Google Generative AI and Google Cloud Code Assist providers.
|
||||
*/
|
||||
|
||||
import {
|
||||
type Content,
|
||||
FinishReason,
|
||||
FunctionCallingConfigMode,
|
||||
type Part,
|
||||
} from "@google/genai";
|
||||
import type {
|
||||
Context,
|
||||
ImageContent,
|
||||
Model,
|
||||
StopReason,
|
||||
TextContent,
|
||||
Tool,
|
||||
} from "../types.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import { transformMessagesWithReport } from "./transform-messages.js";
|
||||
|
||||
type GoogleApiType =
|
||||
| "google-generative-ai"
|
||||
| "google-gemini-cli"
|
||||
| "google-vertex";
|
||||
|
||||
/**
|
||||
* Determines whether a streamed Gemini `Part` should be treated as "thinking".
|
||||
*
|
||||
* Protocol note (Gemini / Vertex AI thought signatures):
|
||||
* - `thought: true` is the definitive marker for thinking content (thought summaries).
|
||||
* - `thoughtSignature` is an encrypted representation of the model's internal thought process
|
||||
* used to preserve reasoning context across multi-turn interactions.
|
||||
* - `thoughtSignature` can appear on ANY part type (text, functionCall, etc.) - it does NOT
|
||||
* indicate the part itself is thinking content.
|
||||
* - For non-functionCall responses, the signature appears on the last part for context replay.
|
||||
* - When persisting/replaying model outputs, signature-bearing parts must be preserved as-is;
|
||||
* do not merge/move signatures across parts.
|
||||
*
|
||||
* See: https://ai.google.dev/gemini-api/docs/thought-signatures
|
||||
*/
|
||||
export function isThinkingPart(
|
||||
part: Pick<Part, "thought" | "thoughtSignature">,
|
||||
): boolean {
|
||||
return part.thought === true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Retain thought signatures during streaming.
|
||||
*
|
||||
* Some backends only send `thoughtSignature` on the first delta for a given part/block; later deltas may omit it.
|
||||
* This helper preserves the last non-empty signature for the current block.
|
||||
*
|
||||
* Note: this does NOT merge or move signatures across distinct response parts. It only prevents
|
||||
* a signature from being overwritten with `undefined` within the same streamed block.
|
||||
*/
|
||||
export function retainThoughtSignature(
|
||||
existing: string | undefined,
|
||||
incoming: string | undefined,
|
||||
): string | undefined {
|
||||
if (typeof incoming === "string" && incoming.length > 0) return incoming;
|
||||
return existing;
|
||||
}
|
||||
|
||||
// Thought signatures must be base64 for Google APIs (TYPE_BYTES).
|
||||
const base64SignaturePattern = /^[A-Za-z0-9+/]+={0,2}$/;
|
||||
|
||||
// Sentinel value that tells the Gemini API to skip thought signature validation.
|
||||
// Used for unsigned function call parts (e.g. replayed from providers without thought signatures).
|
||||
// See: https://ai.google.dev/gemini-api/docs/thought-signatures
|
||||
const SKIP_THOUGHT_SIGNATURE = "skip_thought_signature_validator";
|
||||
|
||||
function isValidThoughtSignature(signature: string | undefined): boolean {
|
||||
if (!signature) return false;
|
||||
if (signature.length % 4 !== 0) return false;
|
||||
return base64SignaturePattern.test(signature);
|
||||
}
|
||||
|
||||
/**
|
||||
* Only keep signatures from the same provider/model and with valid base64.
|
||||
*/
|
||||
function resolveThoughtSignature(
|
||||
isSameProviderAndModel: boolean,
|
||||
signature: string | undefined,
|
||||
): string | undefined {
|
||||
return isSameProviderAndModel && isValidThoughtSignature(signature)
|
||||
? signature
|
||||
: undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Models via Google APIs that require explicit tool call IDs in function calls/responses.
|
||||
*/
|
||||
function requiresToolCallId(modelId: string): boolean {
|
||||
return modelId.startsWith("claude-") || modelId.startsWith("gpt-oss-");
|
||||
}
|
||||
|
||||
/**
 * Convert internal messages to Gemini Content[] format.
 *
 * Role handling:
 * - "user": text is surrogate-sanitized; image parts are dropped when the
 *   model does not accept image input; messages whose parts all get filtered
 *   out are skipped entirely.
 * - "assistant": text / thinking / toolCall blocks become model parts.
 *   Thinking blocks and thought signatures are only preserved when the
 *   message came from the same provider AND model; otherwise thinking is
 *   downgraded to plain text. Gemini 3 function calls without a signature
 *   get the skip-validation sentinel.
 * - "toolResult": emitted as a user-turn functionResponse part. Consecutive
 *   function responses are merged into one user turn (Cloud Code Assist
 *   requires this). Images are nested inside the response (Gemini 3) or
 *   appended as a separate user message (older models).
 */
export function convertMessages<T extends GoogleApiType>(
  model: Model<T>,
  context: Context,
): Content[] {
  const contents: Content[] = [];
  // For models that need IDs, restrict to [a-zA-Z0-9_-] and cap at 64 chars.
  const normalizeToolCallId = (id: string): string => {
    if (!requiresToolCallId(model.id)) return id;
    return id.replace(/[^a-zA-Z0-9_-]/g, "_").slice(0, 64);
  };

  const transformedMessages = transformMessagesWithReport(
    context.messages,
    model,
    normalizeToolCallId,
    "google-generative-ai",
  );

  for (const msg of transformedMessages) {
    if (msg.role === "user") {
      if (typeof msg.content === "string") {
        contents.push({
          role: "user",
          parts: [{ text: sanitizeSurrogates(msg.content) }],
        });
      } else {
        const parts: Part[] = msg.content.map((item) => {
          if (item.type === "text") {
            return { text: sanitizeSurrogates(item.text) };
          } else {
            return {
              inlineData: {
                mimeType: item.mimeType,
                data: item.data,
              },
            };
          }
        });
        // Drop image parts for text-only models; skip the message if nothing
        // survives the filter.
        const filteredParts = !model.input.includes("image")
          ? parts.filter((p) => p.text !== undefined)
          : parts;
        if (filteredParts.length === 0) continue;
        contents.push({
          role: "user",
          parts: filteredParts,
        });
      }
    } else if (msg.role === "assistant") {
      const parts: Part[] = [];
      // Check if message is from same provider and model - only then keep thinking blocks
      const isSameProviderAndModel =
        msg.provider === model.provider && msg.model === model.id;

      for (const block of msg.content) {
        if (block.type === "text") {
          // Skip empty text blocks - they can cause issues with some models
          if (!block.text || block.text.trim() === "") continue;
          const thoughtSignature = resolveThoughtSignature(
            isSameProviderAndModel,
            block.textSignature,
          );
          parts.push({
            text: sanitizeSurrogates(block.text),
            ...(thoughtSignature && { thoughtSignature }),
          });
        } else if (block.type === "thinking") {
          // Skip empty thinking blocks
          if (!block.thinking || block.thinking.trim() === "") continue;
          // Only keep as thinking block if same provider AND same model
          // Otherwise convert to plain text (no tags to avoid model mimicking them)
          if (isSameProviderAndModel) {
            const thoughtSignature = resolveThoughtSignature(
              isSameProviderAndModel,
              block.thinkingSignature,
            );
            parts.push({
              thought: true,
              text: sanitizeSurrogates(block.thinking),
              ...(thoughtSignature && { thoughtSignature }),
            });
          } else {
            parts.push({
              text: sanitizeSurrogates(block.thinking),
            });
          }
        } else if (block.type === "toolCall") {
          const thoughtSignature = resolveThoughtSignature(
            isSameProviderAndModel,
            block.thoughtSignature,
          );
          // Gemini 3 requires thoughtSignature on all function calls when thinking mode is enabled.
          // Use the skip_thought_signature_validator sentinel for unsigned function calls
          // (e.g. replayed from providers without thought signatures).
          const isGemini3 = model.id.toLowerCase().includes("gemini-3");
          const effectiveSignature =
            thoughtSignature ||
            (isGemini3 ? SKIP_THOUGHT_SIGNATURE : undefined);
          const part: Part = {
            functionCall: {
              name: block.name,
              args: block.arguments ?? {},
              ...(requiresToolCallId(model.id) ? { id: block.id } : {}),
            },
            ...(effectiveSignature && { thoughtSignature: effectiveSignature }),
          };
          parts.push(part);
        }
      }

      // An assistant message whose blocks were all skipped produces no turn.
      if (parts.length === 0) continue;
      contents.push({
        role: "model",
        parts,
      });
    } else if (msg.role === "toolResult") {
      // Extract text and image content
      const textContent = msg.content.filter(
        (c): c is TextContent => c.type === "text",
      );
      const textResult = textContent.map((c) => c.text).join("\n");
      const imageContent = model.input.includes("image")
        ? msg.content.filter((c): c is ImageContent => c.type === "image")
        : [];

      const hasText = textResult.length > 0;
      const hasImages = imageContent.length > 0;

      // Gemini 3 supports multimodal function responses with images nested inside functionResponse.parts
      // See: https://ai.google.dev/gemini-api/docs/function-calling#multimodal
      // Older models don't support this, so we put images in a separate user message.
      const supportsMultimodalFunctionResponse = model.id.includes("gemini-3");

      // Use "output" key for success, "error" key for errors as per SDK documentation
      const responseValue = hasText
        ? sanitizeSurrogates(textResult)
        : hasImages
          ? "(see attached image)"
          : "";

      const imageParts: Part[] = imageContent.map((imageBlock) => ({
        inlineData: {
          mimeType: imageBlock.mimeType,
          data: imageBlock.data,
        },
      }));

      const includeId = requiresToolCallId(model.id);
      const functionResponsePart: Part = {
        functionResponse: {
          name: msg.toolName,
          response: msg.isError
            ? { error: responseValue }
            : { output: responseValue },
          // Nest images inside functionResponse.parts for Gemini 3
          ...(hasImages &&
            supportsMultimodalFunctionResponse && { parts: imageParts }),
          ...(includeId ? { id: msg.toolCallId } : {}),
        },
      };

      // Cloud Code Assist API requires all function responses to be in a single user turn.
      // Check if the last content is already a user turn with function responses and merge.
      const lastContent = contents[contents.length - 1];
      if (
        lastContent?.role === "user" &&
        lastContent.parts?.some((p) => p.functionResponse)
      ) {
        lastContent.parts.push(functionResponsePart);
      } else {
        contents.push({
          role: "user",
          parts: [functionResponsePart],
        });
      }

      // For older models, add images in a separate user message
      if (hasImages && !supportsMultimodalFunctionResponse) {
        contents.push({
          role: "user",
          parts: [{ text: "Tool result image:" }, ...imageParts],
        });
      }
    }
  }

  return contents;
}
|
||||
|
||||
/**
|
||||
* Sanitize a JSON Schema for Google's function declarations API.
|
||||
* Google's API rejects `patternProperties` and `const` fields which are valid in JSON Schema.
|
||||
*
|
||||
* This function recursively:
|
||||
* - Removes all `patternProperties` fields
|
||||
* - Converts `const: "value"` to `enum: ["value"]` in anyOf/oneOf blocks
|
||||
*
|
||||
* Needed because Google Cloud Code Assist (google-gemini-cli provider) uses a
|
||||
* restricted subset of JSON Schema and rejects patternProperties / const.
|
||||
*/
|
||||
export function sanitizeSchemaForGoogle(schema: unknown): unknown {
|
||||
if (!schema || typeof schema !== "object") {
|
||||
return schema;
|
||||
}
|
||||
|
||||
if (Array.isArray(schema)) {
|
||||
return schema.map((item) => sanitizeSchemaForGoogle(item));
|
||||
}
|
||||
|
||||
const obj = schema as Record<string, unknown>;
|
||||
const sanitized: Record<string, unknown> = {};
|
||||
|
||||
for (const [key, value] of Object.entries(obj)) {
|
||||
// Skip patternProperties entirely — not supported by Google's API
|
||||
if (key === "patternProperties") {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Convert const to enum — Google's API rejects the const keyword
|
||||
if (key === "const" && typeof value === "string") {
|
||||
sanitized.enum = [value];
|
||||
continue;
|
||||
}
|
||||
|
||||
// Recursively sanitize all nested objects and arrays
|
||||
if (typeof value === "object") {
|
||||
sanitized[key] = sanitizeSchemaForGoogle(value);
|
||||
} else {
|
||||
sanitized[key] = value;
|
||||
}
|
||||
}
|
||||
|
||||
return sanitized;
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert tools to Gemini function declarations format.
|
||||
*
|
||||
* By default uses `parametersJsonSchema` which supports full JSON Schema (including
|
||||
* anyOf, oneOf, const, etc.). Set `useParameters` to true to use the legacy `parameters`
|
||||
* field instead (OpenAPI 3.03 Schema). This is needed for Cloud Code Assist with Claude
|
||||
* models, where the API translates `parameters` into Anthropic's `input_schema`.
|
||||
*
|
||||
* The schema is automatically sanitized to remove fields not supported by Google's
|
||||
* function declarations API (patternProperties, const converted to enum, etc.).
|
||||
*/
|
||||
export function convertTools(
|
||||
tools: Tool[],
|
||||
useParameters = false,
|
||||
): { functionDeclarations: Record<string, unknown>[] }[] | undefined {
|
||||
if (tools.length === 0) return undefined;
|
||||
return [
|
||||
{
|
||||
functionDeclarations: tools.map((tool) => ({
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
...(useParameters
|
||||
? { parameters: sanitizeSchemaForGoogle(tool.parameters) }
|
||||
: { parametersJsonSchema: sanitizeSchemaForGoogle(tool.parameters) }),
|
||||
})),
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
/**
|
||||
* Map tool choice string to Gemini FunctionCallingConfigMode.
|
||||
*/
|
||||
export function mapToolChoice(choice: string): FunctionCallingConfigMode {
|
||||
switch (choice) {
|
||||
case "auto":
|
||||
return FunctionCallingConfigMode.AUTO;
|
||||
case "none":
|
||||
return FunctionCallingConfigMode.NONE;
|
||||
case "any":
|
||||
return FunctionCallingConfigMode.ANY;
|
||||
default:
|
||||
return FunctionCallingConfigMode.AUTO;
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Map Gemini FinishReason to our StopReason.
 *
 * STOP -> "stop", MAX_TOKENS -> "length"; every safety / recitation /
 * malformed-output reason collapses to "error". The `never`-typed default
 * makes this switch fail to compile if the SDK adds a new FinishReason.
 */
export function mapStopReason(reason: FinishReason): StopReason {
  switch (reason) {
    case FinishReason.STOP:
      return "stop";
    case FinishReason.MAX_TOKENS:
      return "length";
    // All remaining reasons are treated as errors.
    case FinishReason.BLOCKLIST:
    case FinishReason.PROHIBITED_CONTENT:
    case FinishReason.SPII:
    case FinishReason.SAFETY:
    case FinishReason.IMAGE_SAFETY:
    case FinishReason.IMAGE_PROHIBITED_CONTENT:
    case FinishReason.IMAGE_RECITATION:
    case FinishReason.IMAGE_OTHER:
    case FinishReason.RECITATION:
    case FinishReason.FINISH_REASON_UNSPECIFIED:
    case FinishReason.OTHER:
    case FinishReason.LANGUAGE:
    case FinishReason.MALFORMED_FUNCTION_CALL:
    case FinishReason.UNEXPECTED_TOOL_CALL:
    case FinishReason.NO_IMAGE:
      return "error";
    default: {
      // Exhaustiveness guard: unreachable unless the enum grows.
      const _exhaustive: never = reason;
      throw new Error(`Unhandled stop reason: ${_exhaustive}`);
    }
  }
}
|
||||
|
||||
/**
|
||||
* Map string finish reason to our StopReason (for raw API responses).
|
||||
*/
|
||||
export function mapStopReasonString(reason: string): StopReason {
|
||||
switch (reason) {
|
||||
case "STOP":
|
||||
return "stop";
|
||||
case "MAX_TOKENS":
|
||||
return "length";
|
||||
default:
|
||||
return "error";
|
||||
}
|
||||
}
|
||||
|
|
@ -1,582 +0,0 @@
|
|||
// Lazy-loaded: Google GenAI SDK is imported on first use, not at startup.
|
||||
// This avoids penalizing users who don't use Google Vertex models.
|
||||
import type {
|
||||
GenerateContentConfig,
|
||||
GenerateContentParameters,
|
||||
GoogleGenAI,
|
||||
ThinkingConfig,
|
||||
} from "@google/genai";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Context,
|
||||
Model,
|
||||
ThinkingLevel as PiThinkingLevel,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
TextContent,
|
||||
ThinkingBudgets,
|
||||
ThinkingContent,
|
||||
ToolCall,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import type { GoogleThinkingLevel } from "./google-gemini-cli.js";
|
||||
import {
|
||||
convertMessages,
|
||||
convertTools,
|
||||
isThinkingPart,
|
||||
mapStopReason,
|
||||
mapToolChoice,
|
||||
retainThoughtSignature,
|
||||
} from "./google-shared.js";
|
||||
import {
|
||||
buildBaseOptions,
|
||||
clampReasoning,
|
||||
isAutoReasoning,
|
||||
resolveReasoningLevel,
|
||||
} from "./simple-options.js";
|
||||
|
||||
// Memoized constructor: filled in by getGoogleVertexClass() on first use.
let _GoogleVertexClass: typeof GoogleGenAI | undefined;

/**
 * Lazily import `@google/genai` and return its `GoogleGenAI` class.
 *
 * The dynamic import keeps SDK load cost off the startup path for users who
 * never touch Vertex models; subsequent calls reuse the cached class.
 */
async function getGoogleVertexClass(): Promise<typeof GoogleGenAI> {
  if (!_GoogleVertexClass) {
    const mod = await import("@google/genai");
    _GoogleVertexClass = mod.GoogleGenAI;
  }
  return _GoogleVertexClass;
}
|
||||
|
||||
/** Stream options specific to the Google Vertex provider. */
export interface GoogleVertexOptions extends StreamOptions {
  // Function calling mode: "auto" (model decides), "none", or "any" (forced).
  toolChoice?: "auto" | "none" | "any";
  // Thinking configuration; level vs budget precedence is decided by the
  // request builder — NOTE(review): confirm in buildParams (not in this chunk).
  thinking?: {
    enabled: boolean;
    budgetTokens?: number; // -1 for dynamic, 0 to disable
    level?: GoogleThinkingLevel;
  };
  // Google Cloud project id — presumably falls back to environment defaults
  // via resolveProject(); verify against that helper.
  project?: string;
  // Vertex region (e.g. "us-central1") — presumably defaulted by
  // resolveLocation(); verify against that helper.
  location?: string;
}
|
||||
|
||||
const API_VERSION = "v1";
|
||||
|
||||
// ThinkingLevel is a string enum where each value equals its key name.
|
||||
// Using string literals avoids importing the SDK at module load time.
|
||||
const THINKING_LEVEL_MAP: Record<GoogleThinkingLevel, string> = {
|
||||
THINKING_LEVEL_UNSPECIFIED: "THINKING_LEVEL_UNSPECIFIED",
|
||||
MINIMAL: "MINIMAL",
|
||||
LOW: "LOW",
|
||||
MEDIUM: "MEDIUM",
|
||||
HIGH: "HIGH",
|
||||
};
|
||||
|
||||
// Counter for generating unique tool call IDs
|
||||
let toolCallCounter = 0;
|
||||
|
||||
/**
 * Stream an assistant response from Google Vertex AI (Gemini via
 * @google/genai).
 *
 * Pushes incremental events (start, text/thinking/toolcall start, delta and
 * end, then done) onto the returned AssistantMessageEventStream while
 * accumulating the complete AssistantMessage. Failures and aborts surface as
 * an "error" event carrying the partially built message; the function itself
 * never rejects.
 */
export const streamGoogleVertex: StreamFunction<
  "google-vertex",
  GoogleVertexOptions
> = (
  model: Model<"google-vertex">,
  context: Context,
  options?: GoogleVertexOptions,
): AssistantMessageEventStream => {
  const stream = new AssistantMessageEventStream();

  // Fire-and-forget: all results, including errors, flow through `stream`.
  (async () => {
    // Message accumulator; mutated in place and shared with every event as
    // `partial`.
    const output: AssistantMessage = {
      role: "assistant",
      content: [],
      api: "google-vertex" as Api,
      provider: model.provider,
      model: model.id,
      usage: {
        input: 0,
        output: 0,
        cacheRead: 0,
        cacheWrite: 0,
        totalTokens: 0,
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
      },
      stopReason: "stop",
      timestamp: Date.now(),
    };

    try {
      const project = resolveProject(options);
      const location = resolveLocation(options);
      const client = await createClient(
        model,
        project,
        location,
        options?.headers,
      );
      let params = buildParams(model, context, options);
      // Let the caller observe and optionally replace the outgoing payload.
      const nextParams = await options?.onPayload?.(params, model);
      if (nextParams !== undefined) {
        params = nextParams as GenerateContentParameters;
      }
      const googleStream = await client.models.generateContentStream(params);

      stream.push({ type: "start", partial: output });
      // The text/thinking block currently receiving deltas, if any.
      let currentBlock: TextContent | ThinkingContent | null = null;
      const blocks = output.content;
      const blockIndex = () => blocks.length - 1;
      for await (const chunk of googleStream) {
        const candidate = chunk.candidates?.[0];
        if (candidate?.content?.parts) {
          for (const part of candidate.content.parts) {
            if (part.text !== undefined) {
              const isThinking = isThinkingPart(part);
              // Open a fresh block whenever the part kind (text vs thinking)
              // changes, closing the previous one first.
              if (
                !currentBlock ||
                (isThinking && currentBlock.type !== "thinking") ||
                (!isThinking && currentBlock.type !== "text")
              ) {
                if (currentBlock) {
                  if (currentBlock.type === "text") {
                    stream.push({
                      type: "text_end",
                      contentIndex: blocks.length - 1,
                      content: currentBlock.text,
                      partial: output,
                    });
                  } else {
                    stream.push({
                      type: "thinking_end",
                      contentIndex: blockIndex(),
                      content: currentBlock.thinking,
                      partial: output,
                    });
                  }
                }
                if (isThinking) {
                  currentBlock = {
                    type: "thinking",
                    thinking: "",
                    thinkingSignature: undefined,
                  };
                  output.content.push(currentBlock);
                  stream.push({
                    type: "thinking_start",
                    contentIndex: blockIndex(),
                    partial: output,
                  });
                } else {
                  currentBlock = { type: "text", text: "" };
                  output.content.push(currentBlock);
                  stream.push({
                    type: "text_start",
                    contentIndex: blockIndex(),
                    partial: output,
                  });
                }
              }
              if (currentBlock.type === "thinking") {
                currentBlock.thinking += part.text;
                currentBlock.thinkingSignature = retainThoughtSignature(
                  currentBlock.thinkingSignature,
                  part.thoughtSignature,
                );
                stream.push({
                  type: "thinking_delta",
                  contentIndex: blockIndex(),
                  delta: part.text,
                  partial: output,
                });
              } else {
                currentBlock.text += part.text;
                currentBlock.textSignature = retainThoughtSignature(
                  currentBlock.textSignature,
                  part.thoughtSignature,
                );
                stream.push({
                  type: "text_delta",
                  contentIndex: blockIndex(),
                  delta: part.text,
                  partial: output,
                });
              }
            }

            if (part.functionCall) {
              // A tool call terminates whatever text/thinking block was open.
              if (currentBlock) {
                if (currentBlock.type === "text") {
                  stream.push({
                    type: "text_end",
                    contentIndex: blockIndex(),
                    content: currentBlock.text,
                    partial: output,
                  });
                } else {
                  stream.push({
                    type: "thinking_end",
                    contentIndex: blockIndex(),
                    content: currentBlock.thinking,
                    partial: output,
                  });
                }
                currentBlock = null;
              }

              // Synthesize a unique id when the SDK omits one or repeats one
              // already present in this message.
              const providedId = part.functionCall.id;
              const needsNewId =
                !providedId ||
                output.content.some(
                  (b) => b.type === "toolCall" && b.id === providedId,
                );
              const toolCallId = needsNewId
                ? `${part.functionCall.name}_${Date.now()}_${++toolCallCounter}`
                : providedId;

              const toolCall: ToolCall = {
                type: "toolCall",
                id: toolCallId,
                name: part.functionCall.name || "",
                arguments:
                  (part.functionCall.args as Record<string, any>) ?? {},
                ...(part.thoughtSignature && {
                  thoughtSignature: part.thoughtSignature,
                }),
              };

              // functionCall parts arrive whole, so start/delta/end are
              // emitted back to back for each one.
              output.content.push(toolCall);
              stream.push({
                type: "toolcall_start",
                contentIndex: blockIndex(),
                partial: output,
              });
              stream.push({
                type: "toolcall_delta",
                contentIndex: blockIndex(),
                delta: JSON.stringify(toolCall.arguments),
                partial: output,
              });
              stream.push({
                type: "toolcall_end",
                contentIndex: blockIndex(),
                toolCall,
                partial: output,
              });
            }
          }
        }

        if (candidate?.finishReason) {
          output.stopReason = mapStopReason(candidate.finishReason);
          // The presence of any tool call overrides the provider's reason.
          if (output.content.some((b) => b.type === "toolCall")) {
            output.stopReason = "toolUse";
          }
        }

        if (chunk.usageMetadata) {
          // Overwrite usage with the latest metadata; the final chunk wins.
          output.usage = {
            input: chunk.usageMetadata.promptTokenCount || 0,
            output:
              (chunk.usageMetadata.candidatesTokenCount || 0) +
              (chunk.usageMetadata.thoughtsTokenCount || 0),
            cacheRead: chunk.usageMetadata.cachedContentTokenCount || 0,
            cacheWrite: 0,
            totalTokens: chunk.usageMetadata.totalTokenCount || 0,
            cost: {
              input: 0,
              output: 0,
              cacheRead: 0,
              cacheWrite: 0,
              total: 0,
            },
          };
          calculateCost(model, output.usage);
        }
      }

      // Flush the trailing block if the stream ended while one was open.
      if (currentBlock) {
        if (currentBlock.type === "text") {
          stream.push({
            type: "text_end",
            contentIndex: blockIndex(),
            content: currentBlock.text,
            partial: output,
          });
        } else {
          stream.push({
            type: "thinking_end",
            contentIndex: blockIndex(),
            content: currentBlock.thinking,
            partial: output,
          });
        }
      }

      if (options?.signal?.aborted) {
        throw new Error("Request was aborted");
      }

      if (output.stopReason === "aborted" || output.stopReason === "error") {
        throw new Error("An unknown error occurred");
      }

      stream.push({ type: "done", reason: output.stopReason, message: output });
      stream.end();
    } catch (error) {
      // Remove internal index property used during streaming
      for (const block of output.content) {
        if ("index" in block) {
          delete (block as { index?: number }).index;
        }
      }
      output.stopReason = options?.signal?.aborted ? "aborted" : "error";
      output.errorMessage =
        error instanceof Error ? error.message : JSON.stringify(error);
      stream.push({ type: "error", reason: output.stopReason, error: output });
      stream.end();
    }
  })();

  return stream;
};
|
||||
|
||||
/**
 * Simplified entry point for Vertex AI: maps SimpleStreamOptions.reasoning
 * onto the provider-specific thinking config and delegates to
 * streamGoogleVertex.
 */
export const streamSimpleGoogleVertex: StreamFunction<
  "google-vertex",
  SimpleStreamOptions
> = (
  model: Model<"google-vertex">,
  context: Context,
  options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
  const base = buildBaseOptions(model, options, undefined);
  // No reasoning requested: explicitly disable thinking.
  if (!options?.reasoning) {
    return streamGoogleVertex(model, context, {
      ...base,
      thinking: { enabled: false },
    } satisfies GoogleVertexOptions);
  }

  // "Auto" reasoning: Gemini 3 models get the unspecified thinking level,
  // everything else gets a dynamic (-1) token budget.
  if (isAutoReasoning(options.reasoning)) {
    const geminiModel = model as unknown as Model<"google-generative-ai">;
    if (isGemini3ProModel(geminiModel) || isGemini3FlashModel(geminiModel)) {
      return streamGoogleVertex(model, context, {
        ...base,
        thinking: {
          enabled: true,
          level: "THINKING_LEVEL_UNSPECIFIED",
        },
      } satisfies GoogleVertexOptions);
    }

    return streamGoogleVertex(model, context, {
      ...base,
      thinking: {
        enabled: true,
        budgetTokens: -1,
      },
    } satisfies GoogleVertexOptions);
  }

  // Explicit reasoning level: Gemini 3 models take a thinking level, other
  // models take a token budget derived from the effort.
  const effort = clampReasoning(
    resolveReasoningLevel(model, options.reasoning),
  )!;
  const geminiModel = model as unknown as Model<"google-generative-ai">;

  if (isGemini3ProModel(geminiModel) || isGemini3FlashModel(geminiModel)) {
    return streamGoogleVertex(model, context, {
      ...base,
      thinking: {
        enabled: true,
        level: getGemini3ThinkingLevel(effort, geminiModel),
      },
    } satisfies GoogleVertexOptions);
  }

  return streamGoogleVertex(model, context, {
    ...base,
    thinking: {
      enabled: true,
      budgetTokens: getGoogleBudget(
        geminiModel,
        effort,
        options.thinkingBudgets,
      ),
    },
  } satisfies GoogleVertexOptions);
};
|
||||
|
||||
/**
 * Build a GoogleGenAI client configured for Vertex AI.
 *
 * Headers from the model definition and the per-call options are merged,
 * with per-call headers taking precedence.
 */
async function createClient(
  model: Model<"google-vertex">,
  project: string,
  location: string,
  optionsHeaders?: Record<string, string>,
): Promise<GoogleGenAI> {
  const httpOptions: { headers?: Record<string, string> } = {};

  if (model.headers || optionsHeaders) {
    httpOptions.headers = { ...model.headers, ...optionsHeaders };
  }

  // Only pass httpOptions when at least one field is actually set.
  const hasHttpOptions = Object.values(httpOptions).some(Boolean);
  const GoogleGenAIClass = await getGoogleVertexClass();

  return new GoogleGenAIClass({
    vertexai: true,
    project,
    location,
    apiVersion: API_VERSION,
    httpOptions: hasHttpOptions ? httpOptions : undefined,
  });
}
|
||||
|
||||
function resolveProject(options?: GoogleVertexOptions): string {
|
||||
const project =
|
||||
options?.project ||
|
||||
process.env.GOOGLE_CLOUD_PROJECT ||
|
||||
process.env.GCLOUD_PROJECT;
|
||||
if (!project) {
|
||||
throw new Error(
|
||||
"Vertex AI requires a project ID. Set GOOGLE_CLOUD_PROJECT/GCLOUD_PROJECT or pass project in options.",
|
||||
);
|
||||
}
|
||||
return project;
|
||||
}
|
||||
|
||||
function resolveLocation(options?: GoogleVertexOptions): string {
|
||||
const location = options?.location || process.env.GOOGLE_CLOUD_LOCATION;
|
||||
if (!location) {
|
||||
throw new Error(
|
||||
"Vertex AI requires a location. Set GOOGLE_CLOUD_LOCATION or pass location in options.",
|
||||
);
|
||||
}
|
||||
return location;
|
||||
}
|
||||
|
||||
/**
 * Translate a forge Context plus options into GenerateContentParameters for
 * the @google/genai SDK: messages, system prompt, tools, tool choice,
 * thinking configuration and abort signal.
 */
function buildParams(
  model: Model<"google-vertex">,
  context: Context,
  options: GoogleVertexOptions = {},
): GenerateContentParameters {
  const contents = convertMessages(model, context);

  const generationConfig: GenerateContentConfig = {};
  if (options.temperature !== undefined) {
    generationConfig.temperature = options.temperature;
  }
  if (options.maxTokens !== undefined) {
    generationConfig.maxOutputTokens = options.maxTokens;
  }

  const config: GenerateContentConfig = {
    // Spread generation settings only when at least one was set.
    ...(Object.keys(generationConfig).length > 0 && generationConfig),
    ...(context.systemPrompt && {
      // Sanitize so the system prompt is valid Unicode on the wire.
      systemInstruction: sanitizeSurrogates(context.systemPrompt),
    }),
    ...(context.tools &&
      context.tools.length > 0 && { tools: convertTools(context.tools) }),
  };

  // Only send a tool config when tools exist and a choice was requested.
  if (context.tools && context.tools.length > 0 && options.toolChoice) {
    config.toolConfig = {
      functionCallingConfig: {
        mode: mapToolChoice(options.toolChoice),
      },
    };
  } else {
    config.toolConfig = undefined;
  }

  // Thinking is only configured for models that advertise reasoning support;
  // an explicit level takes precedence over a token budget.
  if (options.thinking?.enabled && model.reasoning) {
    const thinkingConfig: ThinkingConfig = { includeThoughts: true };
    if (options.thinking.level !== undefined) {
      // Cast safe: string values match ThinkingLevel enum values exactly
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      thinkingConfig.thinkingLevel = THINKING_LEVEL_MAP[
        options.thinking.level
      ] as any;
    } else if (options.thinking.budgetTokens !== undefined) {
      thinkingConfig.thinkingBudget = options.thinking.budgetTokens;
    }
    config.thinkingConfig = thinkingConfig;
  }

  if (options.signal) {
    // Fail fast when already aborted; otherwise let the SDK watch the signal.
    if (options.signal.aborted) {
      throw new Error("Request aborted");
    }
    config.abortSignal = options.signal;
  }

  const params: GenerateContentParameters = {
    model: model.id,
    contents,
    config,
  };

  return params;
}
|
||||
|
||||
type ClampedThinkingLevel = Exclude<PiThinkingLevel, "xhigh">;
|
||||
|
||||
function isGemini3ProModel(model: Model<"google-generative-ai">): boolean {
|
||||
return /gemini-3(?:\.\d+)?-pro/.test(model.id.toLowerCase());
|
||||
}
|
||||
|
||||
function isGemini3FlashModel(model: Model<"google-generative-ai">): boolean {
|
||||
return /gemini-3(?:\.\d+)?-flash/.test(model.id.toLowerCase());
|
||||
}
|
||||
|
||||
function getGemini3ThinkingLevel(
|
||||
effort: ClampedThinkingLevel,
|
||||
model: Model<"google-generative-ai">,
|
||||
): GoogleThinkingLevel {
|
||||
if (isGemini3ProModel(model)) {
|
||||
switch (effort) {
|
||||
case "minimal":
|
||||
case "low":
|
||||
return "LOW";
|
||||
case "medium":
|
||||
case "high":
|
||||
return "HIGH";
|
||||
}
|
||||
}
|
||||
switch (effort) {
|
||||
case "minimal":
|
||||
return "MINIMAL";
|
||||
case "low":
|
||||
return "LOW";
|
||||
case "medium":
|
||||
return "MEDIUM";
|
||||
case "high":
|
||||
return "HIGH";
|
||||
}
|
||||
}
|
||||
|
||||
function getGoogleBudget(
|
||||
model: Model<"google-generative-ai">,
|
||||
effort: ClampedThinkingLevel,
|
||||
customBudgets?: ThinkingBudgets,
|
||||
): number {
|
||||
if (customBudgets?.[effort] !== undefined) {
|
||||
return customBudgets[effort]!;
|
||||
}
|
||||
|
||||
if (model.id.includes("2.5-pro")) {
|
||||
const budgets: Record<ClampedThinkingLevel, number> = {
|
||||
minimal: 128,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: 32768,
|
||||
};
|
||||
return budgets[effort];
|
||||
}
|
||||
|
||||
if (model.id.includes("2.5-flash")) {
|
||||
const budgets: Record<ClampedThinkingLevel, number> = {
|
||||
minimal: 128,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: 24576,
|
||||
};
|
||||
return budgets[effort];
|
||||
}
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
|
@ -1,545 +0,0 @@
|
|||
// Lazy-loaded: Google GenAI SDK (~186ms) is imported on first use, not at startup.
|
||||
// This avoids penalizing users who don't use Google models.
|
||||
import type {
|
||||
GenerateContentConfig,
|
||||
GenerateContentParameters,
|
||||
GoogleGenAI,
|
||||
ThinkingConfig,
|
||||
} from "@google/genai";
|
||||
|
||||
// Cached constructor from the lazily imported @google/genai SDK.
let _GoogleGenAIClass: typeof GoogleGenAI | undefined;
// Import the SDK on first use and cache its GoogleGenAI class, so users who
// never touch Google models pay no startup cost.
async function getGoogleGenAIClass(): Promise<typeof GoogleGenAI> {
  if (!_GoogleGenAIClass) {
    const mod = await import("@google/genai");
    _GoogleGenAIClass = mod.GoogleGenAI;
  }
  return _GoogleGenAIClass;
}
|
||||
|
||||
import { getEnvApiKey } from "../env-api-keys.js";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
TextContent,
|
||||
ThinkingBudgets,
|
||||
ThinkingContent,
|
||||
ThinkingLevel,
|
||||
ToolCall,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import type { GoogleThinkingLevel } from "./google-gemini-cli.js";
|
||||
import {
|
||||
convertMessages,
|
||||
convertTools,
|
||||
isThinkingPart,
|
||||
mapStopReason,
|
||||
mapToolChoice,
|
||||
retainThoughtSignature,
|
||||
} from "./google-shared.js";
|
||||
import {
|
||||
buildBaseOptions,
|
||||
clampReasoning,
|
||||
isAutoReasoning,
|
||||
resolveReasoningLevel,
|
||||
} from "./simple-options.js";
|
||||
|
||||
/**
 * Provider-specific stream options for the Google Generative AI (Gemini) API.
 */
export interface GoogleOptions extends StreamOptions {
  // Tool-calling mode, translated via mapToolChoice into the SDK's
  // functionCallingConfig.mode.
  toolChoice?: "auto" | "none" | "any";
  // Reasoning ("thinking") configuration; only applied when the model
  // advertises reasoning support.
  thinking?: {
    enabled: boolean;
    budgetTokens?: number; // -1 for dynamic, 0 to disable
    level?: GoogleThinkingLevel;
  };
}

// Counter for generating unique tool call IDs
let toolCallCounter = 0;
|
||||
|
||||
/**
 * Stream an assistant response from the Google Generative AI (Gemini) API
 * via @google/genai.
 *
 * Pushes incremental events (start, text/thinking/toolcall start, delta and
 * end, then done) onto the returned AssistantMessageEventStream while
 * accumulating the complete AssistantMessage. Failures and aborts surface as
 * an "error" event carrying the partially built message; the function itself
 * never rejects.
 */
export const streamGoogle: StreamFunction<
  "google-generative-ai",
  GoogleOptions
> = (
  model: Model<"google-generative-ai">,
  context: Context,
  options?: GoogleOptions,
): AssistantMessageEventStream => {
  const stream = new AssistantMessageEventStream();

  // Fire-and-forget: all results, including errors, flow through `stream`.
  (async () => {
    // Message accumulator; mutated in place and shared with every event as
    // `partial`.
    const output: AssistantMessage = {
      role: "assistant",
      content: [],
      api: "google-generative-ai" as Api,
      provider: model.provider,
      model: model.id,
      usage: {
        input: 0,
        output: 0,
        cacheRead: 0,
        cacheWrite: 0,
        totalTokens: 0,
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
      },
      stopReason: "stop",
      timestamp: Date.now(),
    };

    try {
      const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
      const client = await createClient(model, apiKey, options?.headers);
      let params = buildParams(model, context, options);
      // Let the caller observe and optionally replace the outgoing payload.
      const nextParams = await options?.onPayload?.(params, model);
      if (nextParams !== undefined) {
        params = nextParams as GenerateContentParameters;
      }
      const googleStream = await client.models.generateContentStream(params);

      stream.push({ type: "start", partial: output });
      // The text/thinking block currently receiving deltas, if any.
      let currentBlock: TextContent | ThinkingContent | null = null;
      const blocks = output.content;
      const blockIndex = () => blocks.length - 1;
      for await (const chunk of googleStream) {
        const candidate = chunk.candidates?.[0];
        if (candidate?.content?.parts) {
          for (const part of candidate.content.parts) {
            if (part.text !== undefined) {
              const isThinking = isThinkingPart(part);
              // Open a fresh block whenever the part kind (text vs thinking)
              // changes, closing the previous one first.
              if (
                !currentBlock ||
                (isThinking && currentBlock.type !== "thinking") ||
                (!isThinking && currentBlock.type !== "text")
              ) {
                if (currentBlock) {
                  if (currentBlock.type === "text") {
                    stream.push({
                      type: "text_end",
                      contentIndex: blocks.length - 1,
                      content: currentBlock.text,
                      partial: output,
                    });
                  } else {
                    stream.push({
                      type: "thinking_end",
                      contentIndex: blockIndex(),
                      content: currentBlock.thinking,
                      partial: output,
                    });
                  }
                }
                if (isThinking) {
                  currentBlock = {
                    type: "thinking",
                    thinking: "",
                    thinkingSignature: undefined,
                  };
                  output.content.push(currentBlock);
                  stream.push({
                    type: "thinking_start",
                    contentIndex: blockIndex(),
                    partial: output,
                  });
                } else {
                  currentBlock = { type: "text", text: "" };
                  output.content.push(currentBlock);
                  stream.push({
                    type: "text_start",
                    contentIndex: blockIndex(),
                    partial: output,
                  });
                }
              }
              if (currentBlock.type === "thinking") {
                currentBlock.thinking += part.text;
                currentBlock.thinkingSignature = retainThoughtSignature(
                  currentBlock.thinkingSignature,
                  part.thoughtSignature,
                );
                stream.push({
                  type: "thinking_delta",
                  contentIndex: blockIndex(),
                  delta: part.text,
                  partial: output,
                });
              } else {
                currentBlock.text += part.text;
                currentBlock.textSignature = retainThoughtSignature(
                  currentBlock.textSignature,
                  part.thoughtSignature,
                );
                stream.push({
                  type: "text_delta",
                  contentIndex: blockIndex(),
                  delta: part.text,
                  partial: output,
                });
              }
            }

            if (part.functionCall) {
              // A tool call terminates whatever text/thinking block was open.
              if (currentBlock) {
                if (currentBlock.type === "text") {
                  stream.push({
                    type: "text_end",
                    contentIndex: blockIndex(),
                    content: currentBlock.text,
                    partial: output,
                  });
                } else {
                  stream.push({
                    type: "thinking_end",
                    contentIndex: blockIndex(),
                    content: currentBlock.thinking,
                    partial: output,
                  });
                }
                currentBlock = null;
              }

              // Generate unique ID if not provided or if it's a duplicate
              const providedId = part.functionCall.id;
              const needsNewId =
                !providedId ||
                output.content.some(
                  (b) => b.type === "toolCall" && b.id === providedId,
                );
              const toolCallId = needsNewId
                ? `${part.functionCall.name}_${Date.now()}_${++toolCallCounter}`
                : providedId;

              const toolCall: ToolCall = {
                type: "toolCall",
                id: toolCallId,
                name: part.functionCall.name || "",
                arguments:
                  (part.functionCall.args as Record<string, any>) ?? {},
                ...(part.thoughtSignature && {
                  thoughtSignature: part.thoughtSignature,
                }),
              };

              // functionCall parts arrive whole, so start/delta/end are
              // emitted back to back for each one.
              output.content.push(toolCall);
              stream.push({
                type: "toolcall_start",
                contentIndex: blockIndex(),
                partial: output,
              });
              stream.push({
                type: "toolcall_delta",
                contentIndex: blockIndex(),
                delta: JSON.stringify(toolCall.arguments),
                partial: output,
              });
              stream.push({
                type: "toolcall_end",
                contentIndex: blockIndex(),
                toolCall,
                partial: output,
              });
            }
          }
        }

        if (candidate?.finishReason) {
          output.stopReason = mapStopReason(candidate.finishReason);
          // The presence of any tool call overrides the provider's reason.
          if (output.content.some((b) => b.type === "toolCall")) {
            output.stopReason = "toolUse";
          }
        }

        if (chunk.usageMetadata) {
          // Overwrite usage with the latest metadata; the final chunk wins.
          output.usage = {
            input: chunk.usageMetadata.promptTokenCount || 0,
            output:
              (chunk.usageMetadata.candidatesTokenCount || 0) +
              (chunk.usageMetadata.thoughtsTokenCount || 0),
            cacheRead: chunk.usageMetadata.cachedContentTokenCount || 0,
            cacheWrite: 0,
            totalTokens: chunk.usageMetadata.totalTokenCount || 0,
            cost: {
              input: 0,
              output: 0,
              cacheRead: 0,
              cacheWrite: 0,
              total: 0,
            },
          };
          calculateCost(model, output.usage);
        }
      }

      // Flush the trailing block if the stream ended while one was open.
      if (currentBlock) {
        if (currentBlock.type === "text") {
          stream.push({
            type: "text_end",
            contentIndex: blockIndex(),
            content: currentBlock.text,
            partial: output,
          });
        } else {
          stream.push({
            type: "thinking_end",
            contentIndex: blockIndex(),
            content: currentBlock.thinking,
            partial: output,
          });
        }
      }

      if (options?.signal?.aborted) {
        throw new Error("Request was aborted");
      }

      if (output.stopReason === "aborted" || output.stopReason === "error") {
        throw new Error("An unknown error occurred");
      }

      stream.push({ type: "done", reason: output.stopReason, message: output });
      stream.end();
    } catch (error) {
      // Remove internal index property used during streaming
      for (const block of output.content) {
        if ("index" in block) {
          delete (block as { index?: number }).index;
        }
      }
      output.stopReason = options?.signal?.aborted ? "aborted" : "error";
      output.errorMessage =
        error instanceof Error ? error.message : JSON.stringify(error);
      stream.push({ type: "error", reason: output.stopReason, error: output });
      stream.end();
    }
  })();

  return stream;
};
|
||||
|
||||
/**
 * Simplified entry point for the Gemini API: maps
 * SimpleStreamOptions.reasoning onto the provider-specific thinking config
 * and delegates to streamGoogle. Throws immediately when no API key is
 * available.
 */
export const streamSimpleGoogle: StreamFunction<
  "google-generative-ai",
  SimpleStreamOptions
> = (
  model: Model<"google-generative-ai">,
  context: Context,
  options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
  const apiKey = options?.apiKey || getEnvApiKey(model.provider);
  if (!apiKey) {
    throw new Error(`No API key for provider: ${model.provider}`);
  }

  const base = buildBaseOptions(model, options, apiKey);
  // No reasoning requested: explicitly disable thinking.
  if (!options?.reasoning) {
    return streamGoogle(model, context, {
      ...base,
      thinking: { enabled: false },
    } satisfies GoogleOptions);
  }

  // "Auto" reasoning: Gemini 3 models get the unspecified thinking level,
  // everything else gets a dynamic (-1) token budget.
  if (isAutoReasoning(options.reasoning)) {
    const googleModel = model as Model<"google-generative-ai">;
    if (isGemini3ProModel(googleModel) || isGemini3FlashModel(googleModel)) {
      return streamGoogle(model, context, {
        ...base,
        thinking: {
          enabled: true,
          level: "THINKING_LEVEL_UNSPECIFIED",
        },
      } satisfies GoogleOptions);
    }

    return streamGoogle(model, context, {
      ...base,
      thinking: {
        enabled: true,
        budgetTokens: -1,
      },
    } satisfies GoogleOptions);
  }

  // Explicit reasoning level: Gemini 3 models take a thinking level, other
  // models take a token budget derived from the effort.
  const effort = clampReasoning(
    resolveReasoningLevel(model, options.reasoning),
  )!;
  const googleModel = model as Model<"google-generative-ai">;

  if (isGemini3ProModel(googleModel) || isGemini3FlashModel(googleModel)) {
    return streamGoogle(model, context, {
      ...base,
      thinking: {
        enabled: true,
        level: getGemini3ThinkingLevel(effort, googleModel),
      },
    } satisfies GoogleOptions);
  }

  return streamGoogle(model, context, {
    ...base,
    thinking: {
      enabled: true,
      budgetTokens: getGoogleBudget(
        googleModel,
        effort,
        options.thinkingBudgets,
      ),
    },
  } satisfies GoogleOptions);
};
|
||||
|
||||
/**
 * Build a GoogleGenAI client for the Gemini API.
 *
 * Honors a custom base URL from the model definition and merges headers from
 * the model definition with per-call headers (per-call wins).
 */
async function createClient(
  model: Model<"google-generative-ai">,
  apiKey?: string,
  optionsHeaders?: Record<string, string>,
): Promise<GoogleGenAI> {
  const httpOptions: {
    baseUrl?: string;
    apiVersion?: string;
    headers?: Record<string, string>;
  } = {};
  if (model.baseUrl) {
    httpOptions.baseUrl = model.baseUrl;
    httpOptions.apiVersion = ""; // baseUrl already includes version path, don't append
  }
  if (model.headers || optionsHeaders) {
    httpOptions.headers = { ...model.headers, ...optionsHeaders };
  }

  const GoogleGenAIClass = await getGoogleGenAIClass();
  return new GoogleGenAIClass({
    apiKey,
    // Only pass httpOptions when at least one field was set.
    httpOptions: Object.keys(httpOptions).length > 0 ? httpOptions : undefined,
  });
}
|
||||
|
||||
/**
 * Translate a forge Context plus options into GenerateContentParameters for
 * the @google/genai SDK: messages, system prompt, tools, tool choice,
 * thinking configuration and abort signal.
 */
function buildParams(
  model: Model<"google-generative-ai">,
  context: Context,
  options: GoogleOptions = {},
): GenerateContentParameters {
  const contents = convertMessages(model, context);

  const generationConfig: GenerateContentConfig = {};
  if (options.temperature !== undefined) {
    generationConfig.temperature = options.temperature;
  }
  if (options.maxTokens !== undefined) {
    generationConfig.maxOutputTokens = options.maxTokens;
  }

  const config: GenerateContentConfig = {
    // Spread generation settings only when at least one was set.
    ...(Object.keys(generationConfig).length > 0 && generationConfig),
    ...(context.systemPrompt && {
      // Sanitize so the system prompt is valid Unicode on the wire.
      systemInstruction: sanitizeSurrogates(context.systemPrompt),
    }),
    ...(context.tools &&
      context.tools.length > 0 && { tools: convertTools(context.tools) }),
  };

  // Only send a tool config when tools exist and a choice was requested.
  if (context.tools && context.tools.length > 0 && options.toolChoice) {
    config.toolConfig = {
      functionCallingConfig: {
        mode: mapToolChoice(options.toolChoice),
      },
    };
  } else {
    config.toolConfig = undefined;
  }

  // Thinking is only configured for models that advertise reasoning support;
  // an explicit level takes precedence over a token budget.
  if (options.thinking?.enabled && model.reasoning) {
    const thinkingConfig: ThinkingConfig = { includeThoughts: true };
    if (options.thinking.level !== undefined) {
      // Cast to any since our GoogleThinkingLevel mirrors Google's ThinkingLevel enum values
      thinkingConfig.thinkingLevel = options.thinking.level as any;
    } else if (options.thinking.budgetTokens !== undefined) {
      thinkingConfig.thinkingBudget = options.thinking.budgetTokens;
    }
    config.thinkingConfig = thinkingConfig;
  }

  if (options.signal) {
    // Fail fast when already aborted; otherwise let the SDK watch the signal.
    if (options.signal.aborted) {
      throw new Error("Request aborted");
    }
    config.abortSignal = options.signal;
  }

  const params: GenerateContentParameters = {
    model: model.id,
    contents,
    config,
  };

  return params;
}
|
||||
|
||||
type ClampedThinkingLevel = Exclude<ThinkingLevel, "xhigh">;
|
||||
|
||||
function isGemini3ProModel(model: Model<"google-generative-ai">): boolean {
|
||||
return /gemini-3(?:\.\d+)?-pro/.test(model.id.toLowerCase());
|
||||
}
|
||||
|
||||
function isGemini3FlashModel(model: Model<"google-generative-ai">): boolean {
|
||||
return /gemini-3(?:\.\d+)?-flash/.test(model.id.toLowerCase());
|
||||
}
|
||||
|
||||
function getGemini3ThinkingLevel(
|
||||
effort: ClampedThinkingLevel,
|
||||
model: Model<"google-generative-ai">,
|
||||
): GoogleThinkingLevel {
|
||||
if (isGemini3ProModel(model)) {
|
||||
switch (effort) {
|
||||
case "minimal":
|
||||
case "low":
|
||||
return "LOW";
|
||||
case "medium":
|
||||
case "high":
|
||||
return "HIGH";
|
||||
}
|
||||
}
|
||||
switch (effort) {
|
||||
case "minimal":
|
||||
return "MINIMAL";
|
||||
case "low":
|
||||
return "LOW";
|
||||
case "medium":
|
||||
return "MEDIUM";
|
||||
case "high":
|
||||
return "HIGH";
|
||||
}
|
||||
}
|
||||
|
||||
function getGoogleBudget(
|
||||
model: Model<"google-generative-ai">,
|
||||
effort: ClampedThinkingLevel,
|
||||
customBudgets?: ThinkingBudgets,
|
||||
): number {
|
||||
if (customBudgets?.[effort] !== undefined) {
|
||||
return customBudgets[effort]!;
|
||||
}
|
||||
|
||||
if (model.id.includes("2.5-pro")) {
|
||||
const budgets: Record<ClampedThinkingLevel, number> = {
|
||||
minimal: 128,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: 32768,
|
||||
};
|
||||
return budgets[effort];
|
||||
}
|
||||
|
||||
if (model.id.includes("2.5-flash")) {
|
||||
const budgets: Record<ClampedThinkingLevel, number> = {
|
||||
minimal: 128,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: 24576,
|
||||
};
|
||||
return budgets[effort];
|
||||
}
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
|
@ -1,762 +0,0 @@
|
|||
// Lazy-loaded: Mistral SDK (~369ms) is imported on first use, not at startup.
|
||||
// This avoids penalizing users who don't use Mistral models.
|
||||
import type { Mistral } from "@mistralai/mistralai";
|
||||
import type { RequestOptions } from "@mistralai/mistralai/lib/sdks.js";
|
||||
import type {
|
||||
ChatCompletionStreamRequest,
|
||||
ChatCompletionStreamRequestMessage,
|
||||
CompletionEvent,
|
||||
ContentChunk,
|
||||
FunctionTool,
|
||||
} from "@mistralai/mistralai/models/components/index.js";
|
||||
|
||||
// Cached constructor from the lazily imported Mistral SDK.
let _MistralClass: typeof Mistral | undefined;
// Import @mistralai/mistralai on first use and cache the Mistral class, so
// users who never touch Mistral models pay no startup cost.
async function getMistralClass(): Promise<typeof Mistral> {
  if (!_MistralClass) {
    const mod = await import("@mistralai/mistralai");
    _MistralClass = mod.Mistral;
  }
  return _MistralClass;
}
|
||||
|
||||
import { getEnvApiKey } from "../env-api-keys.js";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
AssistantMessage,
|
||||
Context,
|
||||
Message,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StopReason,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
TextContent,
|
||||
ThinkingContent,
|
||||
Tool,
|
||||
ToolCall,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { shortHash } from "../utils/hash.js";
|
||||
import { parseStreamingJson } from "../utils/json-parse.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import {
|
||||
buildBaseOptions,
|
||||
clampReasoning,
|
||||
resolveReasoningLevel,
|
||||
} from "./simple-options.js";
|
||||
import { transformMessagesWithReport } from "./transform-messages.js";
|
||||
|
||||
const MISTRAL_TOOL_CALL_ID_LENGTH = 9;
|
||||
const MAX_MISTRAL_ERROR_BODY_CHARS = 4000;
|
||||
|
||||
/**
|
||||
* Provider-specific options for the Mistral API.
|
||||
*/
|
||||
export interface MistralOptions extends StreamOptions {
|
||||
toolChoice?:
|
||||
| "auto"
|
||||
| "none"
|
||||
| "any"
|
||||
| "required"
|
||||
| { type: "function"; function: { name: string } };
|
||||
promptMode?: "reasoning";
|
||||
}
|
||||
|
||||
/**
|
||||
* Stream responses from Mistral using `chat.stream`.
|
||||
*/
|
||||
export const streamMistral: StreamFunction<
|
||||
"mistral-conversations",
|
||||
MistralOptions
|
||||
> = (
|
||||
model: Model<"mistral-conversations">,
|
||||
context: Context,
|
||||
options?: MistralOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const stream = new AssistantMessageEventStream();
|
||||
|
||||
(async () => {
|
||||
const output = createOutput(model);
|
||||
|
||||
try {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
// Intentionally per-request: avoids shared SDK mutable state across concurrent consumers.
|
||||
const MistralSDK = await getMistralClass();
|
||||
const mistral = new MistralSDK({
|
||||
apiKey,
|
||||
serverURL: model.baseUrl,
|
||||
});
|
||||
|
||||
const normalizeMistralToolCallId = createMistralToolCallIdNormalizer();
|
||||
const transformedMessages = transformMessagesWithReport(
|
||||
context.messages,
|
||||
model,
|
||||
(id) => normalizeMistralToolCallId(id),
|
||||
"mistral-conversations",
|
||||
);
|
||||
|
||||
let payload = buildChatPayload(
|
||||
model,
|
||||
context,
|
||||
transformedMessages,
|
||||
options,
|
||||
);
|
||||
const nextPayload = await options?.onPayload?.(payload, model);
|
||||
if (nextPayload !== undefined) {
|
||||
payload = nextPayload as ChatCompletionStreamRequest;
|
||||
}
|
||||
const mistralStream = await mistral.chat.stream(
|
||||
payload,
|
||||
buildRequestOptions(model, options),
|
||||
);
|
||||
stream.push({ type: "start", partial: output });
|
||||
await consumeChatStream(model, output, stream, mistralStream);
|
||||
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
if (output.stopReason === "aborted" || output.stopReason === "error") {
|
||||
throw new Error("An unknown error occurred");
|
||||
}
|
||||
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
|
||||
output.errorMessage = formatMistralError(error);
|
||||
stream.push({ type: "error", reason: output.stopReason, error: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
/**
|
||||
* Maps provider-agnostic `SimpleStreamOptions` to Mistral options.
|
||||
*/
|
||||
export const streamSimpleMistral: StreamFunction<
|
||||
"mistral-conversations",
|
||||
SimpleStreamOptions
|
||||
> = (
|
||||
model: Model<"mistral-conversations">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
const base = buildBaseOptions(model, options, apiKey);
|
||||
const reasoning = clampReasoning(
|
||||
resolveReasoningLevel(model, options?.reasoning),
|
||||
);
|
||||
|
||||
return streamMistral(model, context, {
|
||||
...base,
|
||||
promptMode: shouldUseMistralReasoningPromptMode(model, reasoning)
|
||||
? "reasoning"
|
||||
: undefined,
|
||||
} satisfies MistralOptions);
|
||||
};
|
||||
|
||||
export function shouldUseMistralReasoningPromptMode(
|
||||
model: Model<"mistral-conversations">,
|
||||
reasoning?: string | null,
|
||||
): boolean {
|
||||
if (!model.reasoning || !reasoning) return false;
|
||||
const id = model.id.toLowerCase();
|
||||
return id.startsWith("magistral");
|
||||
}
|
||||
|
||||
function createOutput(model: Model<"mistral-conversations">): AssistantMessage {
|
||||
return {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: model.api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
}
|
||||
|
||||
function createMistralToolCallIdNormalizer(): (id: string) => string {
|
||||
const idMap = new Map<string, string>();
|
||||
const reverseMap = new Map<string, string>();
|
||||
|
||||
return (id: string): string => {
|
||||
const existing = idMap.get(id);
|
||||
if (existing) return existing;
|
||||
|
||||
let attempt = 0;
|
||||
while (true) {
|
||||
const candidate = deriveMistralToolCallId(id, attempt);
|
||||
const owner = reverseMap.get(candidate);
|
||||
if (!owner || owner === id) {
|
||||
idMap.set(id, candidate);
|
||||
reverseMap.set(candidate, id);
|
||||
return candidate;
|
||||
}
|
||||
attempt++;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
function deriveMistralToolCallId(id: string, attempt: number): string {
|
||||
const normalized = id.replace(/[^a-zA-Z0-9]/g, "");
|
||||
if (attempt === 0 && normalized.length === MISTRAL_TOOL_CALL_ID_LENGTH)
|
||||
return normalized;
|
||||
const seedBase = normalized || id;
|
||||
const seed = attempt === 0 ? seedBase : `${seedBase}:${attempt}`;
|
||||
return shortHash(seed)
|
||||
.replace(/[^a-zA-Z0-9]/g, "")
|
||||
.slice(0, MISTRAL_TOOL_CALL_ID_LENGTH);
|
||||
}
|
||||
|
||||
function formatMistralError(error: unknown): string {
|
||||
if (error instanceof Error) {
|
||||
const sdkError = error as Error & { statusCode?: unknown; body?: unknown };
|
||||
const statusCode =
|
||||
typeof sdkError.statusCode === "number" ? sdkError.statusCode : undefined;
|
||||
const bodyText =
|
||||
typeof sdkError.body === "string" ? sdkError.body.trim() : undefined;
|
||||
if (statusCode !== undefined && bodyText) {
|
||||
return `Mistral API error (${statusCode}): ${truncateErrorText(bodyText, MAX_MISTRAL_ERROR_BODY_CHARS)}`;
|
||||
}
|
||||
if (statusCode !== undefined)
|
||||
return `Mistral API error (${statusCode}): ${error.message}`;
|
||||
return error.message;
|
||||
}
|
||||
return safeJsonStringify(error);
|
||||
}
|
||||
|
||||
function truncateErrorText(text: string, maxChars: number): string {
|
||||
if (text.length <= maxChars) return text;
|
||||
return `${text.slice(0, maxChars)}... [truncated ${text.length - maxChars} chars]`;
|
||||
}
|
||||
|
||||
function safeJsonStringify(value: unknown): string {
|
||||
try {
|
||||
const serialized = JSON.stringify(value);
|
||||
return serialized === undefined ? String(value) : serialized;
|
||||
} catch {
|
||||
return String(value);
|
||||
}
|
||||
}
|
||||
|
||||
function buildRequestOptions(
|
||||
model: Model<"mistral-conversations">,
|
||||
options?: MistralOptions,
|
||||
): RequestOptions {
|
||||
const requestOptions: RequestOptions = {};
|
||||
if (options?.signal) requestOptions.signal = options.signal;
|
||||
requestOptions.retries = { strategy: "none" };
|
||||
|
||||
const headers: Record<string, string> = {};
|
||||
if (model.headers) Object.assign(headers, model.headers);
|
||||
if (options?.headers) Object.assign(headers, options.headers);
|
||||
|
||||
// Mistral infrastructure uses `x-affinity` for KV-cache reuse (prefix caching).
|
||||
// Respect explicit caller-provided header values.
|
||||
if (options?.sessionId && !headers["x-affinity"]) {
|
||||
headers["x-affinity"] = options.sessionId;
|
||||
}
|
||||
|
||||
if (Object.keys(headers).length > 0) {
|
||||
requestOptions.headers = headers;
|
||||
}
|
||||
|
||||
return requestOptions;
|
||||
}
|
||||
|
||||
function buildChatPayload(
|
||||
model: Model<"mistral-conversations">,
|
||||
context: Context,
|
||||
messages: Message[],
|
||||
options?: MistralOptions,
|
||||
): ChatCompletionStreamRequest {
|
||||
const payload: ChatCompletionStreamRequest = {
|
||||
model: model.id,
|
||||
stream: true,
|
||||
messages: toChatMessages(messages, model.input.includes("image")),
|
||||
};
|
||||
|
||||
if (context.tools?.length) payload.tools = toFunctionTools(context.tools);
|
||||
if (options?.temperature !== undefined)
|
||||
payload.temperature = options.temperature;
|
||||
if (options?.maxTokens !== undefined) payload.maxTokens = options.maxTokens;
|
||||
if (options?.toolChoice)
|
||||
payload.toolChoice = mapToolChoice(options.toolChoice);
|
||||
if (options?.promptMode) payload.promptMode = options.promptMode as any;
|
||||
|
||||
if (context.systemPrompt) {
|
||||
payload.messages.unshift({
|
||||
role: "system",
|
||||
content: sanitizeSurrogates(context.systemPrompt),
|
||||
});
|
||||
}
|
||||
|
||||
return payload;
|
||||
}
|
||||
|
||||
async function consumeChatStream(
|
||||
model: Model<"mistral-conversations">,
|
||||
output: AssistantMessage,
|
||||
stream: AssistantMessageEventStream,
|
||||
mistralStream: AsyncIterable<CompletionEvent>,
|
||||
): Promise<void> {
|
||||
let currentBlock: TextContent | ThinkingContent | null = null;
|
||||
const blocks = output.content;
|
||||
const blockIndex = () => blocks.length - 1;
|
||||
const toolBlocksByKey = new Map<string, number>();
|
||||
|
||||
const finishCurrentBlock = (block?: typeof currentBlock) => {
|
||||
if (!block) return;
|
||||
if (block.type === "text") {
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: block.text,
|
||||
partial: output,
|
||||
});
|
||||
return;
|
||||
}
|
||||
if (block.type === "thinking") {
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: block.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
for await (const event of mistralStream) {
|
||||
const chunk = event.data;
|
||||
|
||||
if (chunk.usage) {
|
||||
output.usage.input = chunk.usage.promptTokens || 0;
|
||||
output.usage.output = chunk.usage.completionTokens || 0;
|
||||
output.usage.cacheRead = 0;
|
||||
output.usage.cacheWrite = 0;
|
||||
output.usage.totalTokens =
|
||||
chunk.usage.totalTokens || output.usage.input + output.usage.output;
|
||||
calculateCost(model, output.usage);
|
||||
}
|
||||
|
||||
const choice = chunk.choices[0];
|
||||
if (!choice) continue;
|
||||
|
||||
if (choice.finishReason) {
|
||||
output.stopReason = mapChatStopReason(choice.finishReason);
|
||||
}
|
||||
|
||||
const delta = choice.delta;
|
||||
if (delta.content !== null && delta.content !== undefined) {
|
||||
const contentItems =
|
||||
typeof delta.content === "string" ? [delta.content] : delta.content;
|
||||
for (const item of contentItems) {
|
||||
if (typeof item === "string") {
|
||||
const textDelta = sanitizeSurrogates(item);
|
||||
if (!currentBlock || currentBlock.type !== "text") {
|
||||
finishCurrentBlock(currentBlock);
|
||||
currentBlock = { type: "text", text: "" };
|
||||
output.content.push(currentBlock);
|
||||
stream.push({
|
||||
type: "text_start",
|
||||
contentIndex: blockIndex(),
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
currentBlock.text += textDelta;
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: textDelta,
|
||||
partial: output,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
if (item.type === "thinking") {
|
||||
const deltaText = item.thinking
|
||||
.map((part) => ("text" in part ? part.text : ""))
|
||||
.filter((text) => text.length > 0)
|
||||
.join("");
|
||||
const thinkingDelta = sanitizeSurrogates(deltaText);
|
||||
if (!thinkingDelta) continue;
|
||||
if (!currentBlock || currentBlock.type !== "thinking") {
|
||||
finishCurrentBlock(currentBlock);
|
||||
currentBlock = { type: "thinking", thinking: "" };
|
||||
output.content.push(currentBlock);
|
||||
stream.push({
|
||||
type: "thinking_start",
|
||||
contentIndex: blockIndex(),
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
currentBlock.thinking += thinkingDelta;
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: thinkingDelta,
|
||||
partial: output,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
if (item.type === "text") {
|
||||
const textDelta = sanitizeSurrogates(item.text);
|
||||
if (!currentBlock || currentBlock.type !== "text") {
|
||||
finishCurrentBlock(currentBlock);
|
||||
currentBlock = { type: "text", text: "" };
|
||||
output.content.push(currentBlock);
|
||||
stream.push({
|
||||
type: "text_start",
|
||||
contentIndex: blockIndex(),
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
currentBlock.text += textDelta;
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: textDelta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const toolCalls = delta.toolCalls || [];
|
||||
for (const toolCall of toolCalls) {
|
||||
if (currentBlock) {
|
||||
finishCurrentBlock(currentBlock);
|
||||
currentBlock = null;
|
||||
}
|
||||
const callId =
|
||||
toolCall.id && toolCall.id !== "null"
|
||||
? toolCall.id
|
||||
: deriveMistralToolCallId(`toolcall:${toolCall.index ?? 0}`, 0);
|
||||
const key = `${callId}:${toolCall.index || 0}`;
|
||||
const existingIndex = toolBlocksByKey.get(key);
|
||||
let block: (ToolCall & { partialArgs?: string }) | undefined;
|
||||
|
||||
if (existingIndex !== undefined) {
|
||||
const existing = output.content[existingIndex];
|
||||
if (existing?.type === "toolCall") {
|
||||
block = existing as ToolCall & { partialArgs?: string };
|
||||
}
|
||||
}
|
||||
|
||||
if (!block) {
|
||||
block = {
|
||||
type: "toolCall",
|
||||
id: callId,
|
||||
name: toolCall.function.name,
|
||||
arguments: {},
|
||||
partialArgs: "",
|
||||
};
|
||||
output.content.push(block);
|
||||
toolBlocksByKey.set(key, output.content.length - 1);
|
||||
stream.push({
|
||||
type: "toolcall_start",
|
||||
contentIndex: output.content.length - 1,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
|
||||
const argsDelta =
|
||||
typeof toolCall.function.arguments === "string"
|
||||
? toolCall.function.arguments
|
||||
: JSON.stringify(toolCall.function.arguments || {});
|
||||
block.partialArgs = (block.partialArgs || "") + argsDelta;
|
||||
block.arguments = parseStreamingJson<Record<string, unknown>>(
|
||||
block.partialArgs,
|
||||
);
|
||||
stream.push({
|
||||
type: "toolcall_delta",
|
||||
contentIndex: toolBlocksByKey.get(key)!,
|
||||
delta: argsDelta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
finishCurrentBlock(currentBlock);
|
||||
for (const index of toolBlocksByKey.values()) {
|
||||
const block = output.content[index];
|
||||
if (block.type !== "toolCall") continue;
|
||||
const toolBlock = block as ToolCall & { partialArgs?: string };
|
||||
toolBlock.arguments = parseStreamingJson<Record<string, unknown>>(
|
||||
toolBlock.partialArgs,
|
||||
);
|
||||
delete toolBlock.partialArgs;
|
||||
stream.push({
|
||||
type: "toolcall_end",
|
||||
contentIndex: index,
|
||||
toolCall: toolBlock,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
export function sanitizeMistralToolParameters(
|
||||
value: unknown,
|
||||
): Record<string, unknown> {
|
||||
const sanitized = sanitizeJsonSchemaValue(value);
|
||||
if (isPlainRecord(sanitized)) return sanitized;
|
||||
return { type: "object", properties: {} };
|
||||
}
|
||||
|
||||
function sanitizeJsonSchemaValue(value: unknown): unknown {
|
||||
if (value === null) return null;
|
||||
if (Array.isArray(value)) {
|
||||
return value
|
||||
.map((item) => sanitizeJsonSchemaValue(item))
|
||||
.filter((item) => item !== undefined);
|
||||
}
|
||||
if (isPlainRecord(value)) {
|
||||
const result: Record<string, unknown> = {};
|
||||
for (const [key, item] of Object.entries(value)) {
|
||||
const sanitized = sanitizeJsonSchemaValue(item);
|
||||
if (sanitized !== undefined) result[key] = sanitized;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
if (
|
||||
typeof value === "string" ||
|
||||
typeof value === "number" ||
|
||||
typeof value === "boolean"
|
||||
) {
|
||||
return value;
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
function isPlainRecord(value: unknown): value is Record<string, unknown> {
|
||||
return typeof value === "object" && value !== null && !Array.isArray(value);
|
||||
}
|
||||
|
||||
function toFunctionTools(
|
||||
tools: Tool[],
|
||||
): Array<FunctionTool & { type: "function" }> {
|
||||
return tools.map((tool) => ({
|
||||
type: "function",
|
||||
function: {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: sanitizeMistralToolParameters(tool.parameters),
|
||||
strict: false,
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
function toChatMessages(
|
||||
messages: Message[],
|
||||
supportsImages: boolean,
|
||||
): ChatCompletionStreamRequestMessage[] {
|
||||
const result: ChatCompletionStreamRequestMessage[] = [];
|
||||
|
||||
for (const msg of messages) {
|
||||
if (msg.role === "user") {
|
||||
if (typeof msg.content === "string") {
|
||||
result.push({ role: "user", content: sanitizeSurrogates(msg.content) });
|
||||
continue;
|
||||
}
|
||||
const hadImages = msg.content.some((item) => item.type === "image");
|
||||
const content: ContentChunk[] = msg.content
|
||||
.filter((item) => item.type === "text" || supportsImages)
|
||||
.map((item) => {
|
||||
if (item.type === "text")
|
||||
return { type: "text", text: sanitizeSurrogates(item.text) };
|
||||
return {
|
||||
type: "image_url",
|
||||
imageUrl: `data:${item.mimeType};base64,${item.data}`,
|
||||
};
|
||||
});
|
||||
if (content.length > 0) {
|
||||
result.push({ role: "user", content });
|
||||
continue;
|
||||
}
|
||||
if (hadImages && !supportsImages) {
|
||||
result.push({
|
||||
role: "user",
|
||||
content: "(image omitted: model does not support images)",
|
||||
});
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (msg.role === "assistant") {
|
||||
const contentParts: ContentChunk[] = [];
|
||||
const toolCalls: Array<{
|
||||
id: string;
|
||||
type: "function";
|
||||
function: { name: string; arguments: string };
|
||||
}> = [];
|
||||
|
||||
for (const block of msg.content) {
|
||||
if (block.type === "text") {
|
||||
if (block.text.trim().length > 0) {
|
||||
contentParts.push({
|
||||
type: "text",
|
||||
text: sanitizeSurrogates(block.text),
|
||||
});
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (block.type === "thinking") {
|
||||
if (block.thinking.trim().length > 0) {
|
||||
contentParts.push({
|
||||
type: "thinking",
|
||||
thinking: [
|
||||
{ type: "text", text: sanitizeSurrogates(block.thinking) },
|
||||
],
|
||||
});
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (block.type !== "toolCall") {
|
||||
continue;
|
||||
}
|
||||
toolCalls.push({
|
||||
id: block.id,
|
||||
type: "function",
|
||||
function: {
|
||||
name: block.name,
|
||||
arguments: JSON.stringify(block.arguments || {}),
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
const assistantMessage: ChatCompletionStreamRequestMessage = {
|
||||
role: "assistant",
|
||||
};
|
||||
if (contentParts.length > 0) assistantMessage.content = contentParts;
|
||||
if (toolCalls.length > 0) assistantMessage.toolCalls = toolCalls;
|
||||
if (contentParts.length > 0 || toolCalls.length > 0)
|
||||
result.push(assistantMessage);
|
||||
continue;
|
||||
}
|
||||
|
||||
const toolContent: ContentChunk[] = [];
|
||||
const textResult = msg.content
|
||||
.filter((part) => part.type === "text")
|
||||
.map((part) =>
|
||||
part.type === "text" ? sanitizeSurrogates(part.text) : "",
|
||||
)
|
||||
.join("\n");
|
||||
const hasImages = msg.content.some((part) => part.type === "image");
|
||||
const toolText = buildToolResultText(
|
||||
textResult,
|
||||
hasImages,
|
||||
supportsImages,
|
||||
msg.isError,
|
||||
);
|
||||
toolContent.push({ type: "text", text: toolText });
|
||||
for (const part of msg.content) {
|
||||
if (!supportsImages) continue;
|
||||
if (part.type !== "image") continue;
|
||||
toolContent.push({
|
||||
type: "image_url",
|
||||
imageUrl: `data:${part.mimeType};base64,${part.data}`,
|
||||
});
|
||||
}
|
||||
result.push({
|
||||
role: "tool",
|
||||
toolCallId: msg.toolCallId,
|
||||
name: msg.toolName,
|
||||
content: toolContent,
|
||||
});
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
function buildToolResultText(
|
||||
text: string,
|
||||
hasImages: boolean,
|
||||
supportsImages: boolean,
|
||||
isError: boolean,
|
||||
): string {
|
||||
const trimmed = text.trim();
|
||||
const errorPrefix = isError ? "[tool error] " : "";
|
||||
|
||||
if (trimmed.length > 0) {
|
||||
const imageSuffix =
|
||||
hasImages && !supportsImages
|
||||
? "\n[tool image omitted: model does not support images]"
|
||||
: "";
|
||||
return `${errorPrefix}${trimmed}${imageSuffix}`;
|
||||
}
|
||||
|
||||
if (hasImages) {
|
||||
if (supportsImages) {
|
||||
return isError
|
||||
? "[tool error] (see attached image)"
|
||||
: "(see attached image)";
|
||||
}
|
||||
return isError
|
||||
? "[tool error] (image omitted: model does not support images)"
|
||||
: "(image omitted: model does not support images)";
|
||||
}
|
||||
|
||||
return isError ? "[tool error] (no tool output)" : "(no tool output)";
|
||||
}
|
||||
|
||||
function mapToolChoice(
|
||||
choice: MistralOptions["toolChoice"],
|
||||
):
|
||||
| "auto"
|
||||
| "none"
|
||||
| "any"
|
||||
| "required"
|
||||
| { type: "function"; function: { name: string } }
|
||||
| undefined {
|
||||
if (!choice) return undefined;
|
||||
if (
|
||||
choice === "auto" ||
|
||||
choice === "none" ||
|
||||
choice === "any" ||
|
||||
choice === "required"
|
||||
) {
|
||||
return choice as any;
|
||||
}
|
||||
return {
|
||||
type: "function",
|
||||
function: { name: choice.function.name },
|
||||
};
|
||||
}
|
||||
|
||||
function mapChatStopReason(reason: string | null): StopReason {
|
||||
if (reason === null) return "stop";
|
||||
switch (reason) {
|
||||
case "stop":
|
||||
return "stop";
|
||||
case "length":
|
||||
case "model_length":
|
||||
return "length";
|
||||
case "tool_calls":
|
||||
return "toolUse";
|
||||
case "error":
|
||||
return "error";
|
||||
default:
|
||||
return "stop";
|
||||
}
|
||||
}
|
||||
|
|
@ -1,672 +0,0 @@
|
|||
import { supportsXhigh } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Context,
|
||||
ImageContent,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
ToolCall,
|
||||
Usage,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { parseStreamingJson } from "../utils/json-parse.js";
|
||||
import {
|
||||
type CodexAppServerNotification,
|
||||
getCodexAppServerClient,
|
||||
} from "./codex-app-server-client.js";
|
||||
import { convertResponsesMessages } from "./openai-responses-shared.js";
|
||||
import {
|
||||
buildBaseOptions,
|
||||
clampReasoning,
|
||||
resolveReasoningLevel,
|
||||
} from "./simple-options.js";
|
||||
|
||||
export interface OpenAICodexResponsesOptions extends StreamOptions {
|
||||
reasoningEffort?: "none" | "minimal" | "low" | "medium" | "high" | "xhigh";
|
||||
reasoningSummary?: "auto" | "concise" | "detailed" | "off" | "on" | null;
|
||||
textVerbosity?: "low" | "medium" | "high";
|
||||
network_access?: boolean;
|
||||
web_search?: boolean;
|
||||
}
|
||||
|
||||
type AppServerReasoningSummary = "auto" | "concise" | "detailed" | "none";
|
||||
type JsonObject = Record<string, unknown>;
|
||||
|
||||
interface ThreadStartResponse {
|
||||
thread: { id: string };
|
||||
}
|
||||
|
||||
interface TurnStartResponse {
|
||||
turn: { id: string };
|
||||
}
|
||||
|
||||
interface AppServerItem {
|
||||
type: string;
|
||||
id?: string;
|
||||
text?: string;
|
||||
summary?: string[];
|
||||
content?: string[];
|
||||
server?: string;
|
||||
tool?: string;
|
||||
namespace?: string | null;
|
||||
arguments?: unknown;
|
||||
query?: string;
|
||||
}
|
||||
|
||||
interface TokenUsageBreakdown {
|
||||
totalTokens: number;
|
||||
inputTokens: number;
|
||||
cachedInputTokens: number;
|
||||
outputTokens: number;
|
||||
reasoningOutputTokens: number;
|
||||
}
|
||||
|
||||
const CODEX_TOOL_CALL_PROVIDERS = new Set([
|
||||
"openai",
|
||||
"openai-codex",
|
||||
"opencode",
|
||||
]);
|
||||
|
||||
/**
|
||||
* Stream a Codex turn through the installed `codex app-server`.
|
||||
*
|
||||
* Purpose: reuse Codex CLI's authenticated JSON-RPC backend instead of maintaining a hand-rolled ChatGPT transport.
|
||||
* Consumer: built-in provider registry for the `openai-codex-responses` API.
|
||||
*/
|
||||
export const streamOpenAICodexResponses: StreamFunction<
|
||||
"openai-codex-responses",
|
||||
OpenAICodexResponsesOptions
|
||||
> = (
|
||||
model: Model<"openai-codex-responses">,
|
||||
context: Context,
|
||||
options?: OpenAICodexResponsesOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const stream = new AssistantMessageEventStream();
|
||||
|
||||
(async () => {
|
||||
const output = createAssistantMessage(model);
|
||||
let client: Awaited<ReturnType<typeof getCodexAppServerClient>> | undefined;
|
||||
try {
|
||||
client = await getCodexAppServerClient({
|
||||
cwd: process.cwd(),
|
||||
extraArgs: buildProcessConfig(model, options),
|
||||
});
|
||||
const thread = await client.request(
|
||||
"thread/start",
|
||||
buildThreadStartParams(model, context, options),
|
||||
options?.signal,
|
||||
);
|
||||
const threadId = readThreadId(thread);
|
||||
await injectPriorContext(
|
||||
client,
|
||||
threadId,
|
||||
model,
|
||||
context,
|
||||
options?.signal,
|
||||
);
|
||||
|
||||
const turnInput = buildTurnInput(context, model);
|
||||
stream.push({ type: "start", partial: output });
|
||||
|
||||
let activeTurnId: string | undefined;
|
||||
let cleanupTurnNotifications: (() => void) | undefined;
|
||||
const turnDone = new Promise<void>((resolve, reject) => {
|
||||
const mapper = new CodexNotificationMapper(
|
||||
threadId,
|
||||
output,
|
||||
stream,
|
||||
resolve,
|
||||
reject,
|
||||
);
|
||||
const unsubscribe = client!.onNotification((notification) =>
|
||||
mapper.handle(notification),
|
||||
);
|
||||
const onAbort = () => {
|
||||
if (activeTurnId) {
|
||||
client!.interruptTurn(threadId, activeTurnId).catch(() => {});
|
||||
}
|
||||
reject(new Error("Request was aborted"));
|
||||
};
|
||||
options?.signal?.addEventListener("abort", onAbort, { once: true });
|
||||
mapper.onDispose = () => {
|
||||
unsubscribe();
|
||||
options?.signal?.removeEventListener("abort", onAbort);
|
||||
};
|
||||
cleanupTurnNotifications = mapper.onDispose;
|
||||
});
|
||||
|
||||
let turn: unknown;
|
||||
try {
|
||||
turn = await client.request(
|
||||
"turn/start",
|
||||
buildTurnStartParams(threadId, turnInput, model, options),
|
||||
options?.signal,
|
||||
);
|
||||
} catch (error) {
|
||||
cleanupTurnNotifications?.();
|
||||
throw error;
|
||||
}
|
||||
activeTurnId = readTurnId(turn);
|
||||
await turnDone;
|
||||
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
stream.push({
|
||||
type: "done",
|
||||
reason: output.stopReason === "toolUse" ? "toolUse" : "stop",
|
||||
message: output,
|
||||
});
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
|
||||
output.errorMessage =
|
||||
error instanceof Error ? error.message : String(error);
|
||||
stream.push({ type: "error", reason: output.stopReason, error: output });
|
||||
stream.end(output);
|
||||
} finally {
|
||||
client?.releaseIfIdle().catch(() => {});
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
/**
|
||||
* Stream a simple Codex request while preserving the shared pi-ai option surface.
|
||||
*
|
||||
* Purpose: map simple reasoning options to Codex app-server turn parameters without requiring callers to know app-server details.
|
||||
* Consumer: built-in provider registry `streamSimple` calls.
|
||||
*/
|
||||
export const streamSimpleOpenAICodexResponses: StreamFunction<
|
||||
"openai-codex-responses",
|
||||
SimpleStreamOptions
|
||||
> = (
|
||||
model: Model<"openai-codex-responses">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const base = buildBaseOptions(model, options);
|
||||
const effectiveReasoning = resolveReasoningLevel(model, options?.reasoning);
|
||||
const reasoningEffort = supportsXhigh(model)
|
||||
? effectiveReasoning
|
||||
: clampReasoning(effectiveReasoning);
|
||||
|
||||
return streamOpenAICodexResponses(model, context, {
|
||||
...base,
|
||||
reasoningEffort,
|
||||
} satisfies OpenAICodexResponsesOptions);
|
||||
};
|
||||
|
||||
class CodexNotificationMapper {
|
||||
private readonly blocksByItemId = new Map<string, number>();
|
||||
private usage: Usage | undefined;
|
||||
onDispose?: () => void;
|
||||
|
||||
constructor(
|
||||
private readonly threadId: string,
|
||||
private readonly output: AssistantMessage,
|
||||
private readonly stream: AssistantMessageEventStream,
|
||||
private readonly resolve: () => void,
|
||||
private readonly reject: (reason: Error) => void,
|
||||
) {}
|
||||
|
||||
handle(notification: CodexAppServerNotification): void {
|
||||
try {
|
||||
const params = asObject(notification.params);
|
||||
const notificationThreadId = readString(params?.threadId);
|
||||
if (
|
||||
notificationThreadId !== undefined &&
|
||||
notificationThreadId !== this.threadId
|
||||
)
|
||||
return;
|
||||
|
||||
if (notification.method === "item/started")
|
||||
this.handleItemStarted(params);
|
||||
else if (notification.method === "item/agent_message/delta")
|
||||
this.handleAgentMessageDelta(params);
|
||||
else if (
|
||||
notification.method === "item/reasoning/text_delta" ||
|
||||
notification.method === "item/reasoning/summary_text_delta"
|
||||
)
|
||||
this.handleReasoningDelta(params);
|
||||
else if (notification.method === "item/completed")
|
||||
this.handleItemCompleted(params);
|
||||
else if (notification.method === "response/item/completed")
|
||||
this.handleRawResponseItemCompleted(params);
|
||||
else if (notification.method === "thread/token_usage/updated")
|
||||
this.handleUsage(params);
|
||||
else if (notification.method === "turn/completed")
|
||||
this.handleTurnCompleted(params);
|
||||
else if (notification.method === "turn/failed")
|
||||
this.reject(new Error(readErrorMessage(params) ?? "Codex turn failed"));
|
||||
} catch (error) {
|
||||
this.dispose();
|
||||
this.reject(error instanceof Error ? error : new Error(String(error)));
|
||||
}
|
||||
}
|
||||
|
||||
private handleItemStarted(params: JsonObject | undefined): void {
|
||||
const item = asObject(params?.item) as unknown as AppServerItem | undefined;
|
||||
if (!item?.id) return;
|
||||
if (item.type === "agentMessage") this.startText(item.id);
|
||||
else if (item.type === "reasoning") this.startThinking(item.id);
|
||||
}
|
||||
|
||||
private handleAgentMessageDelta(params: JsonObject | undefined): void {
|
||||
const itemId = readString(params?.itemId);
|
||||
const delta = readString(params?.delta);
|
||||
if (!itemId || delta === undefined) return;
|
||||
const index = this.startText(itemId);
|
||||
const block = this.output.content[index];
|
||||
if (block?.type !== "text") return;
|
||||
block.text += delta;
|
||||
this.stream.push({
|
||||
type: "text_delta",
|
||||
contentIndex: index,
|
||||
delta,
|
||||
partial: this.output,
|
||||
});
|
||||
}
|
||||
|
||||
/**
 * Appends a reasoning (text or summary) delta to the thinking block tracked
 * for its itemId and emits a thinking_delta event. Mirrors
 * handleAgentMessageDelta for thinking content.
 */
private handleReasoningDelta(params: JsonObject | undefined): void {
  const itemId = readString(params?.itemId);
  const delta = readString(params?.delta);
  if (!itemId || delta === undefined) return;
  const index = this.startThinking(itemId);
  const block = this.output.content[index];
  // Defensive: skip if the tracked index is not a thinking block.
  if (block?.type !== "thinking") return;
  block.thinking += delta;
  this.stream.push({
    type: "thinking_delta",
    contentIndex: index,
    delta,
    partial: this.output,
  });
}
|
||||
|
||||
/**
 * Closes the block for a completed item. Text and thinking blocks are
 * finalized with the item's full content (reasoning joins summary + content
 * paragraphs); tool-call-like items are emitted as completed tool calls.
 */
private handleItemCompleted(params: JsonObject | undefined): void {
  const item = asObject(params?.item) as unknown as AppServerItem | undefined;
  if (!item?.id) return;
  if (item.type === "agentMessage") this.endText(item.id, item.text ?? "");
  else if (item.type === "reasoning")
    this.endThinking(
      item.id,
      // Summary paragraphs first, then full reasoning content.
      [...(item.summary ?? []), ...(item.content ?? [])].join("\n\n"),
    );
  else if (
    item.type === "dynamicToolCall" ||
    item.type === "mcpToolCall" ||
    item.type === "webSearch"
  ) {
    this.emitToolCall(item);
  }
}
|
||||
|
||||
/**
 * Handles a raw Responses-API item (presumably delivered because
 * thread/start sets experimentalRawEvents — confirm), converting completed
 * function_call items into tool calls. Other raw item types are ignored.
 */
private handleRawResponseItemCompleted(params: JsonObject | undefined): void {
  const item = asObject(params?.item);
  if (readString(item?.type) !== "function_call") return;
  const callId = readString(item?.call_id);
  const name = readString(item?.name);
  if (!callId || !name) return;
  this.emitToolCall({
    type: "function_call",
    id: callId,
    tool: name,
    // Raw function_call arguments arrive as a JSON string; default to "{}".
    arguments: readString(item?.arguments) ?? "{}",
  });
}
|
||||
|
||||
/**
 * Records the most recent token-usage snapshot (tokenUsage.last) so
 * handleTurnCompleted can attach it to the final message. Cost is zeroed
 * here; no pricing information is available in this handler.
 */
private handleUsage(params: JsonObject | undefined): void {
  const tokenUsage = asObject(params?.tokenUsage);
  const last = asObject(tokenUsage?.last) as unknown as
    | TokenUsageBreakdown
    | undefined;
  if (!last) return;
  this.usage = {
    // Assumes inputTokens includes cached tokens, so subtract (clamped at 0)
    // to get the non-cached portion — TODO confirm against app-server docs.
    input: Math.max(0, last.inputTokens - last.cachedInputTokens),
    output: last.outputTokens,
    cacheRead: last.cachedInputTokens,
    cacheWrite: 0,
    totalTokens: last.totalTokens,
    cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
  };
}
|
||||
|
||||
/**
 * Finalizes the turn: applies the accumulated usage snapshot, maps the turn
 * status to a stop reason, runs cleanup, then settles the pending promise
 * (reject on abort/error, resolve otherwise).
 */
private handleTurnCompleted(params: JsonObject | undefined): void {
  const turn = asObject(params?.turn);
  const status = readString(turn?.status);
  if (this.usage) this.output.usage = this.usage;
  // Preserve an earlier stopReason (e.g. "toolUse" set by emitToolCall)
  // unless the turn was interrupted or failed.
  this.output.stopReason =
    status === "interrupted"
      ? "aborted"
      : status === "failed"
        ? "error"
        : this.output.stopReason;
  this.dispose();
  if (this.output.stopReason === "aborted") {
    this.reject(new Error("Request was aborted"));
    return;
  }
  if (this.output.stopReason === "error") {
    this.reject(new Error(readErrorMessage(turn) ?? "Codex turn failed"));
    return;
  }
  this.resolve();
}
|
||||
|
||||
/**
 * Ensures a text content block exists for the given itemId, emitting
 * text_start the first time it is seen.
 *
 * @returns the block's index in output.content (existing or newly created).
 */
private startText(itemId: string): number {
  const existing = this.blocksByItemId.get(itemId);
  if (existing !== undefined) return existing;
  const index = this.output.content.push({ type: "text", text: "" }) - 1;
  this.blocksByItemId.set(itemId, index);
  this.stream.push({
    type: "text_start",
    contentIndex: index,
    partial: this.output,
  });
  return index;
}
|
||||
|
||||
/**
 * Finalizes the text block for itemId: replaces accumulated deltas with the
 * authoritative full text (when non-empty), stamps the itemId as the block's
 * signature, and emits text_end.
 */
private endText(itemId: string, text: string): void {
  const index = this.startText(itemId);
  const block = this.output.content[index];
  if (block?.type !== "text") return;
  // Empty final text keeps whatever the deltas accumulated.
  if (text) block.text = text;
  block.textSignature = itemId;
  this.stream.push({
    type: "text_end",
    contentIndex: index,
    content: block.text,
    partial: this.output,
  });
}
|
||||
|
||||
/**
 * Ensures a thinking content block exists for the given itemId, emitting
 * thinking_start the first time it is seen.
 *
 * @returns the block's index in output.content (existing or newly created).
 */
private startThinking(itemId: string): number {
  const existing = this.blocksByItemId.get(itemId);
  if (existing !== undefined) return existing;
  const index =
    this.output.content.push({ type: "thinking", thinking: "" }) - 1;
  this.blocksByItemId.set(itemId, index);
  this.stream.push({
    type: "thinking_start",
    contentIndex: index,
    partial: this.output,
  });
  return index;
}
|
||||
|
||||
/**
 * Finalizes the thinking block for itemId: replaces accumulated deltas with
 * the authoritative full text (when non-empty), stamps the itemId as the
 * block's signature, and emits thinking_end.
 */
private endThinking(itemId: string, thinking: string): void {
  const index = this.startThinking(itemId);
  const block = this.output.content[index];
  if (block?.type !== "thinking") return;
  // Empty final thinking keeps whatever the deltas accumulated.
  if (thinking) block.thinking = thinking;
  block.thinkingSignature = itemId;
  this.stream.push({
    type: "thinking_end",
    contentIndex: index,
    content: block.thinking,
    partial: this.output,
  });
}
|
||||
|
||||
/**
 * Appends a completed tool call to the output and emits paired
 * toolcall_start / toolcall_end events back-to-back — tool calls arrive
 * here already finished, so there is no delta phase. Sets stopReason to
 * "toolUse" to signal the model expects tool results.
 */
private emitToolCall(item: AppServerItem): void {
  // webSearch maps to a fixed name; other tools join whichever of
  // namespace/server/tool are present with dots.
  const name =
    item.type === "webSearch"
      ? "web_search"
      : [item.namespace, item.server, item.tool].filter(Boolean).join(".");
  if (!name) return;
  const toolCall: ToolCall = {
    type: "toolCall",
    id: item.id ?? `${name}-${Date.now()}`,
    name,
    arguments: normalizeArguments(
      item.type === "webSearch"
        ? { query: item.query ?? "" }
        : item.arguments,
    ),
  };
  const index = this.output.content.push(toolCall) - 1;
  this.stream.push({
    type: "toolcall_start",
    contentIndex: index,
    partial: this.output,
  });
  this.stream.push({
    type: "toolcall_end",
    contentIndex: index,
    toolCall,
    partial: this.output,
  });
  this.output.stopReason = "toolUse";
}
|
||||
|
||||
/** Runs the registered cleanup callback once; subsequent calls are no-ops. */
private dispose(): void {
  this.onDispose?.();
  this.onDispose = undefined;
}
|
||||
}
|
||||
|
||||
function createAssistantMessage(
|
||||
model: Model<"openai-codex-responses">,
|
||||
): AssistantMessage {
|
||||
return {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: "openai-codex-responses" as Api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
}
|
||||
|
||||
function buildThreadStartParams(
|
||||
model: Model<"openai-codex-responses">,
|
||||
context: Context,
|
||||
options?: OpenAICodexResponsesOptions,
|
||||
): JsonObject {
|
||||
const params: JsonObject = {
|
||||
model: model.id,
|
||||
cwd: process.cwd(),
|
||||
baseInstructions: context.systemPrompt ?? null,
|
||||
approvalPolicy: "never",
|
||||
sandbox: "workspace-write",
|
||||
experimentalRawEvents: true,
|
||||
persistExtendedHistory: true,
|
||||
};
|
||||
const config = buildConfig(model, options);
|
||||
if (Object.keys(config).length > 0) params.config = config;
|
||||
return params;
|
||||
}
|
||||
|
||||
function buildTurnStartParams(
|
||||
threadId: string,
|
||||
input: JsonObject[],
|
||||
model: Model<"openai-codex-responses">,
|
||||
options?: OpenAICodexResponsesOptions,
|
||||
): JsonObject {
|
||||
return {
|
||||
threadId,
|
||||
input,
|
||||
cwd: process.cwd(),
|
||||
model: model.id,
|
||||
effort: options?.reasoningEffort
|
||||
? clampReasoningEffort(model.id, options.reasoningEffort)
|
||||
: null,
|
||||
summary: normalizeReasoningSummary(options?.reasoningSummary),
|
||||
};
|
||||
}
|
||||
|
||||
function buildProcessConfig(
|
||||
model: Model<"openai-codex-responses">,
|
||||
options?: OpenAICodexResponsesOptions,
|
||||
): string[] {
|
||||
const config = buildConfig(model, options);
|
||||
return Object.entries(config).flatMap(([key, value]) => [
|
||||
"-c",
|
||||
`${key}=${JSON.stringify(value)}`,
|
||||
]);
|
||||
}
|
||||
|
||||
function buildConfig(
|
||||
model: Model<"openai-codex-responses">,
|
||||
options?: OpenAICodexResponsesOptions,
|
||||
): JsonObject {
|
||||
const config: JsonObject = { model: model.id };
|
||||
if (options?.reasoningEffort)
|
||||
config.model_reasoning_effort = clampReasoningEffort(
|
||||
model.id,
|
||||
options.reasoningEffort,
|
||||
);
|
||||
if (typeof options?.network_access === "boolean")
|
||||
config.network_access = options.network_access;
|
||||
if (typeof options?.web_search === "boolean")
|
||||
config.web_search = options.web_search;
|
||||
return config;
|
||||
}
|
||||
|
||||
/**
 * Replays conversation history from before the last user message into a
 * fresh app-server thread via thread/inject_items, so the upcoming turn sees
 * prior context. No-op when the last user message is first (nothing prior)
 * or conversion yields no items.
 */
async function injectPriorContext(
  client: Awaited<ReturnType<typeof getCodexAppServerClient>>,
  threadId: string,
  model: Model<"openai-codex-responses">,
  context: Context,
  signal?: AbortSignal,
): Promise<void> {
  const lastUserIndex = findLastUserMessageIndex(context);
  if (lastUserIndex <= 0) return;
  // Everything before the last user message counts as "prior" context.
  const priorContext = {
    ...context,
    messages: context.messages.slice(0, lastUserIndex),
  };
  const items = convertResponsesMessages(
    model,
    priorContext,
    CODEX_TOOL_CALL_PROVIDERS,
    // System prompt is delivered separately as thread/start baseInstructions.
    { includeSystemPrompt: false },
  );
  if (items.length === 0) return;
  await client.request("thread/inject_items", { threadId, items }, signal);
}
|
||||
|
||||
function buildTurnInput(
|
||||
context: Context,
|
||||
model: Model<"openai-codex-responses">,
|
||||
): JsonObject[] {
|
||||
const lastUserIndex = findLastUserMessageIndex(context);
|
||||
const message =
|
||||
lastUserIndex >= 0 ? context.messages[lastUserIndex] : undefined;
|
||||
if (!message || message.role !== "user") {
|
||||
return [{ type: "text", text: "", text_elements: [] }];
|
||||
}
|
||||
if (typeof message.content === "string") {
|
||||
return [{ type: "text", text: message.content, text_elements: [] }];
|
||||
}
|
||||
const input: JsonObject[] = [];
|
||||
for (const block of message.content) {
|
||||
if (block.type === "text") {
|
||||
input.push({ type: "text", text: block.text, text_elements: [] });
|
||||
} else if (model.input.includes("image")) {
|
||||
input.push(imageBlockToUserInput(block));
|
||||
}
|
||||
}
|
||||
return input.length > 0
|
||||
? input
|
||||
: [{ type: "text", text: "", text_elements: [] }];
|
||||
}
|
||||
|
||||
function imageBlockToUserInput(block: ImageContent): JsonObject {
|
||||
return { type: "image", url: `data:${block.mimeType};base64,${block.data}` };
|
||||
}
|
||||
|
||||
function findLastUserMessageIndex(context: Context): number {
|
||||
for (let i = context.messages.length - 1; i >= 0; i--) {
|
||||
if (context.messages[i]?.role === "user") return i;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
function clampReasoningEffort(modelId: string, effort: string): string {
|
||||
const id = modelId.includes("/") ? modelId.split("/").pop()! : modelId;
|
||||
if (
|
||||
(id.startsWith("gpt-5.2") ||
|
||||
id.startsWith("gpt-5.3") ||
|
||||
id.startsWith("gpt-5.4")) &&
|
||||
effort === "minimal"
|
||||
)
|
||||
return "low";
|
||||
if (id === "gpt-5.1" && effort === "xhigh") return "high";
|
||||
if (id === "gpt-5.1-codex-mini")
|
||||
return effort === "high" || effort === "xhigh" ? "high" : "medium";
|
||||
return effort;
|
||||
}
|
||||
|
||||
function normalizeReasoningSummary(
|
||||
value: OpenAICodexResponsesOptions["reasoningSummary"],
|
||||
): AppServerReasoningSummary | null {
|
||||
if (value === "off") return "none";
|
||||
if (value === "on") return "auto";
|
||||
return value ?? null;
|
||||
}
|
||||
|
||||
function readThreadId(value: unknown): string {
|
||||
const response = value as ThreadStartResponse;
|
||||
if (!response.thread?.id)
|
||||
throw new Error(
|
||||
"Codex app-server thread/start response did not include thread.id",
|
||||
);
|
||||
return response.thread.id;
|
||||
}
|
||||
|
||||
function readTurnId(value: unknown): string {
|
||||
const response = value as TurnStartResponse;
|
||||
if (!response.turn?.id)
|
||||
throw new Error(
|
||||
"Codex app-server turn/start response did not include turn.id",
|
||||
);
|
||||
return response.turn.id;
|
||||
}
|
||||
|
||||
function normalizeArguments(value: unknown): Record<string, unknown> {
|
||||
if (typeof value === "string") return parseStreamingJson(value);
|
||||
if (value && typeof value === "object" && !Array.isArray(value))
|
||||
return value as Record<string, unknown>;
|
||||
return {};
|
||||
}
|
||||
|
||||
function asObject(value: unknown): JsonObject | undefined {
|
||||
return value && typeof value === "object" && !Array.isArray(value)
|
||||
? (value as JsonObject)
|
||||
: undefined;
|
||||
}
|
||||
|
||||
function readString(value: unknown): string | undefined {
|
||||
return typeof value === "string" ? value : undefined;
|
||||
}
|
||||
|
||||
function readErrorMessage(value: unknown): string | undefined {
|
||||
const object = asObject(value);
|
||||
const _error = asObject(object?.error);
|
||||
return readNestedCodexErrorMessage(object) ?? readString(object?.message);
|
||||
}
|
||||
|
||||
function readNestedCodexErrorMessage(
|
||||
event: JsonObject | undefined,
|
||||
): string | undefined {
|
||||
const errorObj = asObject(event?.error);
|
||||
const message = readString(errorObj?.message);
|
||||
const type = readString(errorObj?.type);
|
||||
if (message && type) return `${type}: ${message}`;
|
||||
return message ?? type;
|
||||
}
|
||||
|
|
@ -1,960 +0,0 @@
|
|||
// Lazy-loaded: OpenAI SDK is imported on first use, not at startup.
|
||||
// This avoids penalizing users who don't use OpenAI models.
|
||||
import type OpenAI from "openai";
|
||||
import type {
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionChunk,
|
||||
ChatCompletionContentPart,
|
||||
ChatCompletionContentPartImage,
|
||||
ChatCompletionContentPartText,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionReasoningEffort,
|
||||
ChatCompletionToolMessageParam,
|
||||
} from "openai/resources/chat/completions.js";
|
||||
import type { FunctionParameters } from "openai/resources/shared.js";
|
||||
import { getEnvApiKey } from "../env-api-keys.js";
|
||||
import { calculateCost, supportsXhigh } from "../models.js";
|
||||
import type {
|
||||
Context,
|
||||
ImageContent,
|
||||
Message,
|
||||
Model,
|
||||
OpenAICompletionsCompat,
|
||||
SimpleStreamOptions,
|
||||
StopReason,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
TextContent,
|
||||
ThinkingContent,
|
||||
Tool,
|
||||
ToolCall,
|
||||
ToolResultMessage,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { parseStreamingJson } from "../utils/json-parse.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import {
|
||||
assertStreamSuccess,
|
||||
buildInitialOutput,
|
||||
createOpenAIClient,
|
||||
finalizeStream,
|
||||
handleStreamError,
|
||||
} from "./openai-shared.js";
|
||||
import { sanitizeToolCallArgumentsForSerialization } from "./sanitize-tool-arguments.js";
|
||||
import {
|
||||
buildBaseOptions,
|
||||
clampReasoning,
|
||||
resolveReasoningLevel,
|
||||
} from "./simple-options.js";
|
||||
import { transformMessagesWithReport } from "./transform-messages.js";
|
||||
|
||||
/**
|
||||
* Check if conversation messages contain tool calls or tool results.
|
||||
* This is needed because Anthropic (via proxy) requires the tools param
|
||||
* to be present when messages include tool_calls or tool role messages.
|
||||
*/
|
||||
function hasToolHistory(messages: Message[]): boolean {
|
||||
for (const msg of messages) {
|
||||
if (msg.role === "toolResult") {
|
||||
return true;
|
||||
}
|
||||
if (msg.role === "assistant") {
|
||||
if (msg.content.some((block) => block.type === "toolCall")) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/** Streaming options specific to the OpenAI Chat Completions API path. */
export interface OpenAICompletionsOptions extends StreamOptions {
  // Forwarded to the request's `tool_choice` parameter; the object form
  // forces a specific function call.
  toolChoice?:
    | "auto"
    | "none"
    | "required"
    | { type: "function"; function: { name: string } };
  // Forwarded as `reasoning_effort` (subject to per-provider mapping and
  // clamping for models without "xhigh" support).
  reasoningEffort?: "minimal" | "low" | "medium" | "high" | "xhigh";
}
|
||||
|
||||
/**
 * Streams a chat completion from an OpenAI-compatible endpoint, translating
 * SSE chunks into AssistantMessage events (text / thinking / toolCall blocks
 * with start/delta/end lifecycle). Returns the event stream immediately; the
 * request runs in a deliberately un-awaited async IIFE that settles the
 * stream via finalizeStream/handleStreamError.
 */
export const streamOpenAICompletions: StreamFunction<
  "openai-completions",
  OpenAICompletionsOptions
> = (
  model: Model<"openai-completions">,
  context: Context,
  options?: OpenAICompletionsOptions,
): AssistantMessageEventStream => {
  const stream = new AssistantMessageEventStream();

  (async () => {
    const output = buildInitialOutput(model);

    try {
      const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
      // Z.ai endpoints get a longer timeout and more retries.
      const isZai =
        model.provider === "zai" || model.baseUrl.includes("api.z.ai");
      const client = await createOpenAIClient(model, context, apiKey, {
        optionsHeaders: options?.headers,
        extraClientOptions: isZai
          ? { timeout: 100_000, maxRetries: 4 }
          : undefined,
      });
      let params = buildParams(model, context, options);
      // onPayload may replace the request params wholesale.
      const nextParams = await options?.onPayload?.(params, model);
      if (nextParams !== undefined) {
        params =
          nextParams as OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming;
      }
      const openaiStream = await client.chat.completions.create(params, {
        signal: options?.signal,
      });
      stream.push({ type: "start", partial: output });

      // One content block is "open" at a time; switching kinds (or tool-call
      // ids) closes the previous block first via finishCurrentBlock.
      let currentBlock:
        | TextContent
        | ThinkingContent
        | (ToolCall & { partialArgs?: string })
        | null = null;
      const blocks = output.content;
      const blockIndex = () => blocks.length - 1;
      const finishCurrentBlock = (block?: typeof currentBlock) => {
        if (block) {
          if (block.type === "text") {
            stream.push({
              type: "text_end",
              contentIndex: blockIndex(),
              content: block.text,
              partial: output,
            });
          } else if (block.type === "thinking") {
            stream.push({
              type: "thinking_end",
              contentIndex: blockIndex(),
              content: block.thinking,
              partial: output,
            });
          } else if (block.type === "toolCall") {
            // Final parse of the accumulated argument string; the transient
            // partialArgs scratch field is stripped before emitting.
            block.arguments = parseStreamingJson(block.partialArgs);
            delete block.partialArgs;
            stream.push({
              type: "toolcall_end",
              contentIndex: blockIndex(),
              toolCall: block,
              partial: output,
            });
          }
        }
      };

      for await (const chunk of openaiStream) {
        if (!chunk || typeof chunk !== "object") continue;

        if (chunk.usage) {
          const cachedTokens =
            chunk.usage.prompt_tokens_details?.cached_tokens || 0;
          const reasoningTokens =
            chunk.usage.completion_tokens_details?.reasoning_tokens || 0;
          const input = (chunk.usage.prompt_tokens || 0) - cachedTokens;
          const outputTokens =
            (chunk.usage.completion_tokens || 0) + reasoningTokens;
          output.usage = {
            // OpenAI includes cached tokens in prompt_tokens, so subtract to get non-cached input
            input,
            output: outputTokens,
            cacheRead: cachedTokens,
            cacheWrite: 0,
            // Compute totalTokens ourselves since we add reasoning_tokens to output
            // and some providers (e.g., Groq) don't include them in total_tokens
            totalTokens: input + outputTokens + cachedTokens,
            cost: {
              input: 0,
              output: 0,
              cacheRead: 0,
              cacheWrite: 0,
              total: 0,
            },
          };
          calculateCost(model, output.usage);
        }

        const choice = Array.isArray(chunk.choices)
          ? chunk.choices[0]
          : undefined;
        if (!choice) continue;

        if (choice.finish_reason) {
          output.stopReason = mapStopReason(choice.finish_reason);
        }

        if (choice.delta) {
          // --- plain text content ---
          if (
            choice.delta.content !== null &&
            choice.delta.content !== undefined &&
            choice.delta.content.length > 0
          ) {
            if (!currentBlock || currentBlock.type !== "text") {
              finishCurrentBlock(currentBlock);
              currentBlock = { type: "text", text: "" };
              output.content.push(currentBlock);
              stream.push({
                type: "text_start",
                contentIndex: blockIndex(),
                partial: output,
              });
            }

            if (currentBlock.type === "text") {
              currentBlock.text += choice.delta.content;
              stream.push({
                type: "text_delta",
                contentIndex: blockIndex(),
                delta: choice.delta.content,
                partial: output,
              });
            }
          }

          // Some endpoints return reasoning in reasoning_content (llama.cpp),
          // or reasoning (other openai compatible endpoints)
          // Use the first non-empty reasoning field to avoid duplication
          // (e.g., chutes.ai returns both reasoning_content and reasoning with same content)
          // SDK-divergence: reasoning_content / reasoning / reasoning_text are vendor extensions
          // not present on ChatCompletionChunk.Choice.Delta in the official OpenAI SDK.
          const deltaExt = choice.delta as unknown as Record<string, unknown>;
          const reasoningFields = [
            "reasoning_content",
            "reasoning",
            "reasoning_text",
          ];
          let foundReasoningField: string | null = null;
          for (const field of reasoningFields) {
            if (
              deltaExt[field] !== null &&
              deltaExt[field] !== undefined &&
              (deltaExt[field] as string).length > 0
            ) {
              if (!foundReasoningField) {
                foundReasoningField = field;
                break;
              }
            }
          }

          // --- reasoning/thinking content ---
          if (foundReasoningField) {
            if (!currentBlock || currentBlock.type !== "thinking") {
              finishCurrentBlock(currentBlock);
              currentBlock = {
                type: "thinking",
                thinking: "",
                // Record which vendor field supplied the reasoning text.
                thinkingSignature: foundReasoningField,
              };
              output.content.push(currentBlock);
              stream.push({
                type: "thinking_start",
                contentIndex: blockIndex(),
                partial: output,
              });
            }

            if (currentBlock.type === "thinking") {
              const delta = deltaExt[foundReasoningField] as string;
              currentBlock.thinking += delta;
              stream.push({
                type: "thinking_delta",
                contentIndex: blockIndex(),
                delta,
                partial: output,
              });
            }
          }

          // --- tool calls (arguments stream as JSON fragments) ---
          if (choice?.delta?.tool_calls) {
            for (const toolCall of choice.delta.tool_calls) {
              // A new id (or a non-toolCall current block) starts a new call.
              if (
                !currentBlock ||
                currentBlock.type !== "toolCall" ||
                (toolCall.id && currentBlock.id !== toolCall.id)
              ) {
                finishCurrentBlock(currentBlock);
                currentBlock = {
                  type: "toolCall",
                  id: toolCall.id || "",
                  name: toolCall.function?.name || "",
                  arguments: {},
                  partialArgs: "",
                };
                output.content.push(currentBlock);
                stream.push({
                  type: "toolcall_start",
                  contentIndex: blockIndex(),
                  partial: output,
                });
              }

              if (currentBlock.type === "toolCall") {
                if (toolCall.id) currentBlock.id = toolCall.id;
                if (toolCall.function?.name)
                  currentBlock.name = toolCall.function.name;
                let delta = "";
                if (toolCall.function?.arguments) {
                  delta = toolCall.function.arguments;
                  currentBlock.partialArgs += toolCall.function.arguments;
                  // Keep a best-effort parse of the incomplete JSON so
                  // consumers always see an object.
                  currentBlock.arguments = parseStreamingJson(
                    currentBlock.partialArgs,
                  );
                }
                stream.push({
                  type: "toolcall_delta",
                  contentIndex: blockIndex(),
                  delta,
                  partial: output,
                });
              }
            }
          }

          // SDK-divergence: reasoning_details is a vendor extension (OpenAI Responses API via
          // completions-compat path) not present on ChatCompletionChunk.Choice.Delta.
          const reasoningDetails = deltaExt.reasoning_details;
          if (reasoningDetails && Array.isArray(reasoningDetails)) {
            for (const detail of reasoningDetails) {
              if (
                detail.type === "reasoning.encrypted" &&
                detail.id &&
                detail.data
              ) {
                // Attach encrypted reasoning to the tool call sharing its id.
                const matchingToolCall = output.content.find(
                  (b) => b.type === "toolCall" && b.id === detail.id,
                ) as ToolCall | undefined;
                if (matchingToolCall) {
                  matchingToolCall.thoughtSignature = JSON.stringify(detail);
                }
              }
            }
          }
        }
      }

      finishCurrentBlock(currentBlock);
      assertStreamSuccess(output, options?.signal);
      finalizeStream(stream, output);
    } catch (error) {
      // Some providers via OpenRouter give additional information in this field.
      // SDK-divergence: APIError.error is typed as Object | undefined; the nested
      // metadata.raw field is an OpenRouter-specific extension not in the SDK type.
      const rawMetadata = (
        error as unknown as { error?: { metadata?: { raw?: string } } }
      )?.error?.metadata?.raw;
      handleStreamError(stream, output, error, options?.signal, rawMetadata);
    }
  })();

  return stream;
};
|
||||
|
||||
export const streamSimpleOpenAICompletions: StreamFunction<
|
||||
"openai-completions",
|
||||
SimpleStreamOptions
|
||||
> = (
|
||||
model: Model<"openai-completions">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
const base = buildBaseOptions(model, options, apiKey);
|
||||
const effectiveReasoning = resolveReasoningLevel(model, options?.reasoning);
|
||||
const reasoningEffort = supportsXhigh(model)
|
||||
? effectiveReasoning
|
||||
: clampReasoning(effectiveReasoning);
|
||||
const toolChoice = (options as OpenAICompletionsOptions | undefined)
|
||||
?.toolChoice;
|
||||
|
||||
return streamOpenAICompletions(model, context, {
|
||||
...base,
|
||||
reasoningEffort,
|
||||
toolChoice,
|
||||
} satisfies OpenAICompletionsOptions);
|
||||
};
|
||||
|
||||
/**
 * Builds the ChatCompletion request params from the model's compat profile:
 * message conversion, streaming usage, token limits, temperature, tools,
 * tool choice, reasoning settings, and vendor routing extensions
 * (OpenRouter / Vercel AI Gateway / Z.ai / Qwen).
 */
function buildParams(
  model: Model<"openai-completions">,
  context: Context,
  options?: OpenAICompletionsOptions,
) {
  const compat = getCompat(model);
  const messages = convertMessages(model, context, compat);
  maybeAddOpenRouterAnthropicCacheControl(model, messages);

  const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
    model: model.id,
    messages,
    stream: true,
  };

  // Usage reporting in the stream is on unless the compat profile forbids it.
  if (compat.supportsUsageInStreaming !== false) {
    params.stream_options = { include_usage: true };
  }

  // Opt out of response storage where the provider supports the flag.
  if (compat.supportsStore) {
    params.store = false;
  }

  if (options?.maxTokens) {
    if (compat.maxTokensField === "max_tokens") {
      // max_tokens is a deprecated but valid field on ChatCompletionCreateParamsBase,
      // kept for providers (e.g. chutes.ai) that reject max_completion_tokens.
      params.max_tokens = options.maxTokens;
    } else {
      params.max_completion_tokens = options.maxTokens;
    }
  }

  if (options?.temperature !== undefined) {
    params.temperature = options.temperature;
  }

  if (context.tools && context.tools.length > 0) {
    params.tools = convertTools(context.tools, compat);
    maybeAddOpenRouterAnthropicToolCacheControl(model, params.tools);
  } else if (hasToolHistory(context.messages)) {
    // Anthropic (via LiteLLM/proxy) requires tools param when conversation has tool_calls/tool_results
    params.tools = [];
  }

  if (options?.toolChoice) {
    params.tool_choice = options.toolChoice;
  }

  // For vendor-specific fields not in ChatCompletionCreateParamsBase, use a typed extension object.
  const paramsExt = params as unknown as Record<string, unknown>;

  if (
    (compat.thinkingFormat === "zai" || compat.thinkingFormat === "qwen") &&
    model.reasoning
  ) {
    // SDK-divergence: enable_thinking is a Z.ai / Qwen vendor extension not in the OpenAI SDK type.
    paramsExt.enable_thinking = !!options?.reasoningEffort;
  } else if (
    options?.reasoningEffort &&
    model.reasoning &&
    compat.supportsReasoningEffort
  ) {
    // reasoning_effort is in ChatCompletionCreateParamsBase, but mapReasoningEffort returns a
    // plain string (from a provider-specific map) which may not match the SDK's ReasoningEffort
    // literal union — cast to the SDK type to satisfy the checker.
    params.reasoning_effort = mapReasoningEffort(
      options.reasoningEffort,
      compat.reasoningEffortMap,
    ) as ChatCompletionReasoningEffort;
  }

  // OpenRouter provider routing preferences
  if (
    model.baseUrl.includes("openrouter.ai") &&
    model.compat?.openRouterRouting
  ) {
    // SDK-divergence: provider routing is an OpenRouter vendor extension not in the OpenAI SDK type.
    paramsExt.provider = model.compat.openRouterRouting;
  }

  // Vercel AI Gateway provider routing preferences
  if (
    model.baseUrl.includes("ai-gateway.vercel.sh") &&
    model.compat?.vercelGatewayRouting
  ) {
    const routing = model.compat.vercelGatewayRouting;
    if (routing.only || routing.order) {
      const gatewayOptions: Record<string, string[]> = {};
      if (routing.only) gatewayOptions.only = routing.only;
      if (routing.order) gatewayOptions.order = routing.order;
      // SDK-divergence: providerOptions is a Vercel AI Gateway vendor extension not in the OpenAI SDK type.
      paramsExt.providerOptions = { gateway: gatewayOptions };
    }
  }

  return params;
}
|
||||
|
||||
function maybeAddOpenRouterAnthropicToolCacheControl(
|
||||
model: Model<"openai-completions">,
|
||||
tools: OpenAI.Chat.Completions.ChatCompletionTool[] | undefined,
|
||||
): void {
|
||||
if (model.provider !== "openrouter" || !model.id.startsWith("anthropic/"))
|
||||
return;
|
||||
if (!tools?.length) return;
|
||||
|
||||
const lastTool = tools[tools.length - 1];
|
||||
if ("function" in lastTool) {
|
||||
Object.assign(lastTool.function, { cache_control: { type: "ephemeral" } });
|
||||
}
|
||||
}
|
||||
|
||||
function mapReasoningEffort(
|
||||
effort: NonNullable<OpenAICompletionsOptions["reasoningEffort"]>,
|
||||
reasoningEffortMap: Partial<
|
||||
Record<NonNullable<OpenAICompletionsOptions["reasoningEffort"]>, string>
|
||||
>,
|
||||
): string {
|
||||
return reasoningEffortMap[effort] ?? effort;
|
||||
}
|
||||
|
||||
function maybeAddOpenRouterAnthropicCacheControl(
|
||||
model: Model<"openai-completions">,
|
||||
messages: ChatCompletionMessageParam[],
|
||||
): void {
|
||||
if (model.provider !== "openrouter" || !model.id.startsWith("anthropic/"))
|
||||
return;
|
||||
|
||||
// Anthropic-style caching requires cache_control on a text part. Add a breakpoint
|
||||
// on the last user/assistant message (walking backwards until we find text content).
|
||||
for (let i = messages.length - 1; i >= 0; i--) {
|
||||
const msg = messages[i];
|
||||
if (msg.role !== "user" && msg.role !== "assistant") continue;
|
||||
|
||||
const content = msg.content;
|
||||
if (typeof content === "string") {
|
||||
msg.content = [
|
||||
Object.assign(
|
||||
{ type: "text" as const, text: content },
|
||||
{ cache_control: { type: "ephemeral" } },
|
||||
),
|
||||
];
|
||||
return;
|
||||
}
|
||||
|
||||
if (!Array.isArray(content)) continue;
|
||||
|
||||
// Find last text part and add cache_control
|
||||
for (let j = content.length - 1; j >= 0; j--) {
|
||||
const part = content[j];
|
||||
if (part?.type === "text") {
|
||||
Object.assign(part, { cache_control: { type: "ephemeral" } });
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Convert the internal Context (system prompt + message history) into OpenAI
 * Chat Completions message params, applying per-provider compat quirks.
 *
 * Handles: tool-call ID normalization, system/developer role selection,
 * text/image user content, assistant text + thinking + tool calls, tool
 * results (with image follow-up messages), and synthetic assistant bridges
 * for providers that reject user messages directly after tool results.
 *
 * @param model - Target model; drives provider-specific behavior.
 * @param context - System prompt and message history to convert.
 * @param compat - Fully resolved compatibility flags for the provider.
 * @returns Message params ready to send to the Chat Completions API.
 */
export function convertMessages(
  model: Model<"openai-completions">,
  context: Context,
  compat: Required<OpenAICompletionsCompat>,
): ChatCompletionMessageParam[] {
  const params: ChatCompletionMessageParam[] = [];

  const normalizeToolCallId = (id: string): string => {
    // Handle pipe-separated IDs from OpenAI Responses API
    // Format: {call_id}|{id} where {id} can be 400+ chars with special chars (+, /, =)
    // These come from providers like github-copilot, openai-codex, opencode
    // Extract just the call_id part and normalize it
    if (id.includes("|")) {
      const [callId] = id.split("|");
      // Sanitize to allowed chars and truncate to 40 chars (OpenAI limit)
      return callId.replace(/[^a-zA-Z0-9_-]/g, "_").slice(0, 40);
    }

    // Plain IDs: only OpenAI itself enforces the 40-char limit.
    if (model.provider === "openai")
      return id.length > 40 ? id.slice(0, 40) : id;
    return id;
  };

  const transformedMessages = transformMessagesWithReport(
    context.messages,
    model,
    (id) => normalizeToolCallId(id),
    "openai-completions",
  );

  if (context.systemPrompt) {
    // Reasoning models use the "developer" role where the provider supports it.
    const useDeveloperRole = model.reasoning && compat.supportsDeveloperRole;
    const role = useDeveloperRole ? "developer" : "system";
    params.push({
      role: role,
      content: sanitizeSurrogates(context.systemPrompt),
    });
  }

  // Tracks the role of the previously emitted message ("toolResult" / "user" /
  // ...) so the bridge logic below can detect illegal adjacencies.
  let lastRole: string | null = null;

  for (let i = 0; i < transformedMessages.length; i++) {
    const msg = transformedMessages[i];
    // Some providers don't allow user messages directly after tool results
    // Insert a synthetic assistant message to bridge the gap
    if (
      compat.requiresAssistantAfterToolResult &&
      lastRole === "toolResult" &&
      msg.role === "user"
    ) {
      params.push({
        role: "assistant",
        content: "I have processed the tool results.",
      });
    }

    if (msg.role === "user") {
      if (typeof msg.content === "string") {
        params.push({
          role: "user",
          content: sanitizeSurrogates(msg.content),
        });
      } else {
        // Structured content: map text and image items to OpenAI part shapes.
        const content: ChatCompletionContentPart[] = msg.content.map(
          (item): ChatCompletionContentPart => {
            if (item.type === "text") {
              return {
                type: "text",
                text: sanitizeSurrogates(item.text),
              } satisfies ChatCompletionContentPartText;
            } else {
              return {
                type: "image_url",
                image_url: {
                  url: `data:${item.mimeType};base64,${item.data}`,
                },
              } satisfies ChatCompletionContentPartImage;
            }
          },
        );
        // Drop images for models that don't accept image input.
        const filteredContent = !model.input.includes("image")
          ? content.filter((c) => c.type !== "image_url")
          : content;
        if (filteredContent.length === 0) continue;
        params.push({
          role: "user",
          content: filteredContent,
        });
      }
    } else if (msg.role === "assistant") {
      // Some providers don't accept null content, use empty string instead
      const assistantMsg: ChatCompletionAssistantMessageParam = {
        role: "assistant",
        content: compat.requiresAssistantAfterToolResult ? "" : null,
      };

      const textBlocks = msg.content.filter(
        (b) => b.type === "text",
      ) as TextContent[];
      // Filter out empty text blocks to avoid API validation errors
      const nonEmptyTextBlocks = textBlocks.filter(
        (b) => b.text && b.text.trim().length > 0,
      );
      if (nonEmptyTextBlocks.length > 0) {
        // GitHub Copilot requires assistant content as a string, not an array.
        // Sending as array causes Claude models to re-answer all previous prompts.
        if (model.provider === "github-copilot") {
          assistantMsg.content = nonEmptyTextBlocks
            .map((b) => sanitizeSurrogates(b.text))
            .join("");
        } else {
          assistantMsg.content = nonEmptyTextBlocks.map((b) => {
            return { type: "text", text: sanitizeSurrogates(b.text) };
          });
        }
      }

      // Handle thinking blocks
      const thinkingBlocks = msg.content.filter(
        (b) => b.type === "thinking",
      ) as ThinkingContent[];
      // Filter out empty thinking blocks to avoid API validation errors
      const nonEmptyThinkingBlocks = thinkingBlocks.filter(
        (b) => b.thinking && b.thinking.trim().length > 0,
      );
      if (nonEmptyThinkingBlocks.length > 0) {
        if (compat.requiresThinkingAsText) {
          // Convert thinking blocks to plain text (no tags to avoid model mimicking them)
          const thinkingText = nonEmptyThinkingBlocks
            .map((b) => b.thinking)
            .join("\n\n");
          const textContent = assistantMsg.content as Array<{
            type: "text";
            text: string;
          }> | null;
          if (textContent) {
            // Thinking goes before the visible answer text.
            textContent.unshift({ type: "text", text: thinkingText });
          } else {
            assistantMsg.content = [{ type: "text", text: thinkingText }];
          }
        } else {
          // Use the signature from the first thinking block if available (for llama.cpp server + gpt-oss)
          const signature = nonEmptyThinkingBlocks[0].thinkingSignature;
          if (signature && signature.length > 0) {
            // SDK-divergence: llama.cpp / gpt-oss return a dynamic per-field name for the
            // reasoning content (e.g. "reasoning_content"). The field is not in
            // ChatCompletionAssistantMessageParam, so we use a typed extension object.
            (assistantMsg as unknown as Record<string, unknown>)[signature] =
              nonEmptyThinkingBlocks.map((b) => b.thinking).join("\n");
          }
        }
      }

      const toolCalls = msg.content.filter(
        (b) => b.type === "toolCall",
      ) as ToolCall[];
      if (toolCalls.length > 0) {
        assistantMsg.tool_calls = toolCalls.map((tc) => ({
          id: tc.id,
          type: "function" as const,
          function: {
            name: tc.name,
            arguments: JSON.stringify(
              sanitizeToolCallArgumentsForSerialization(tc.arguments),
            ),
          },
        }));
        // thoughtSignature holds a JSON-serialized reasoning detail; malformed
        // entries are dropped rather than failing the whole conversion.
        const reasoningDetails = toolCalls
          .filter((tc) => tc.thoughtSignature)
          .map((tc) => {
            try {
              return JSON.parse(tc.thoughtSignature!);
            } catch {
              return null;
            }
          })
          .filter(Boolean);
        if (reasoningDetails.length > 0) {
          // SDK-divergence: reasoning_details is a vendor extension (gpt-oss / OpenAI Responses
          // compat path) not present on ChatCompletionAssistantMessageParam.
          (
            assistantMsg as unknown as Record<string, unknown>
          ).reasoning_details = reasoningDetails;
        }
      }
      // Skip assistant messages that have no content and no tool calls.
      // Some providers require "either content or tool_calls, but not none".
      // Other providers also don't accept empty assistant messages.
      // This handles aborted assistant responses that got no content.
      const content = assistantMsg.content;
      const hasContent =
        content !== null &&
        content !== undefined &&
        (typeof content === "string" ? content.length > 0 : content.length > 0);
      if (!hasContent && !assistantMsg.tool_calls) {
        continue;
      }
      params.push(assistantMsg);
    } else if (msg.role === "toolResult") {
      // Images collected across the whole run of consecutive tool results;
      // they are flushed as one follow-up user message afterwards.
      const imageBlocks: Array<{
        type: "image_url";
        image_url: { url: string };
      }> = [];
      let j = i;

      // Consume the contiguous run of toolResult messages starting at i.
      for (
        ;
        j < transformedMessages.length &&
        transformedMessages[j].role === "toolResult";
        j++
      ) {
        const toolMsg = transformedMessages[j] as ToolResultMessage;

        // Extract text and image content
        const textResult = toolMsg.content
          .filter((c) => c.type === "text")
          .map((c) => (c as TextContent).text)
          .join("\n")
        const hasImages = toolMsg.content.some((c) => c.type === "image");

        // Always send tool result with text (or placeholder if only images)
        const hasText = textResult.length > 0;
        // Some providers require the 'name' field in tool results
        const toolResultMsg: ChatCompletionToolMessageParam = {
          role: "tool",
          content: sanitizeSurrogates(
            hasText ? textResult : "(see attached image)",
          ),
          tool_call_id: toolMsg.toolCallId,
        };
        if (compat.requiresToolResultName && toolMsg.toolName) {
          // SDK-divergence: the `name` field on tool results is required by some providers
          // (e.g., Mistral) but is not part of the ChatCompletionToolMessageParam type.
          (toolResultMsg as unknown as Record<string, unknown>).name =
            toolMsg.toolName;
        }
        params.push(toolResultMsg);

        if (hasImages && model.input.includes("image")) {
          for (const block of toolMsg.content) {
            if (block.type === "image") {
              const imageBlock = block as ImageContent;
              imageBlocks.push({
                type: "image_url",
                image_url: {
                  url: `data:${imageBlock.mimeType};base64,${imageBlock.data}`,
                },
              });
            }
          }
        }
      }

      // Resume the outer loop after the consumed run (i++ will land on j).
      i = j - 1;

      if (imageBlocks.length > 0) {
        if (compat.requiresAssistantAfterToolResult) {
          params.push({
            role: "assistant",
            content: "I have processed the tool results.",
          });
        }

        params.push({
          role: "user",
          content: [
            {
              type: "text",
              text: "Attached image(s) from tool result:",
            },
            ...imageBlocks,
          ],
        });
        lastRole = "user";
      } else {
        lastRole = "toolResult";
      }
      continue;
    }

    lastRole = msg.role;
  }

  return params;
}
|
||||
|
||||
function convertTools(
|
||||
tools: Tool[],
|
||||
compat: Required<OpenAICompletionsCompat>,
|
||||
): OpenAI.Chat.Completions.ChatCompletionTool[] {
|
||||
return tools.map((tool) => ({
|
||||
type: "function",
|
||||
function: {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: tool.parameters as unknown as FunctionParameters, // TypeBox TSchema is a valid JSON Schema object
|
||||
// Only include strict if provider supports it. Some reject unknown fields.
|
||||
...(compat.supportsStrictMode !== false && { strict: false }),
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
function mapStopReason(
|
||||
reason: ChatCompletionChunk.Choice["finish_reason"],
|
||||
): StopReason {
|
||||
if (reason === null) return "stop";
|
||||
switch (reason) {
|
||||
case "stop":
|
||||
return "stop";
|
||||
case "length":
|
||||
return "length";
|
||||
case "function_call":
|
||||
case "tool_calls":
|
||||
return "toolUse";
|
||||
case "content_filter":
|
||||
return "error";
|
||||
default:
|
||||
// Third-party and community models (e.g. Qwen GGUF quants) may emit
|
||||
// non-standard finish_reason values like "eos_token", "eos", or
|
||||
// "end_of_turn". The OpenAI spec defines finish_reason as a string,
|
||||
// so we treat unrecognized values as a normal stop rather than
|
||||
// throwing — which would abort in-flight tool calls (#863).
|
||||
return "stop";
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Detect compatibility settings from provider and baseUrl for known providers.
|
||||
* Provider takes precedence over URL-based detection since it's explicitly configured.
|
||||
* Returns a fully resolved OpenAICompletionsCompat object with all fields set.
|
||||
*/
|
||||
function detectCompat(
|
||||
model: Model<"openai-completions">,
|
||||
): Required<OpenAICompletionsCompat> {
|
||||
const provider = model.provider;
|
||||
const baseUrl = model.baseUrl;
|
||||
|
||||
const isZai = provider === "zai" || baseUrl.includes("api.z.ai");
|
||||
|
||||
const isNonStandard =
|
||||
provider === "cerebras" ||
|
||||
baseUrl.includes("cerebras.ai") ||
|
||||
provider === "xai" ||
|
||||
baseUrl.includes("api.x.ai") ||
|
||||
baseUrl.includes("chutes.ai") ||
|
||||
baseUrl.includes("deepseek.com") ||
|
||||
isZai ||
|
||||
provider === "opencode" ||
|
||||
baseUrl.includes("opencode.ai");
|
||||
|
||||
const useMaxTokens = baseUrl.includes("chutes.ai");
|
||||
|
||||
const isGrok = provider === "xai" || baseUrl.includes("api.x.ai");
|
||||
const isGroq = provider === "groq" || baseUrl.includes("groq.com");
|
||||
|
||||
const reasoningEffortMap =
|
||||
isGroq && model.id === "qwen/qwen3-32b"
|
||||
? {
|
||||
minimal: "default",
|
||||
low: "default",
|
||||
medium: "default",
|
||||
high: "default",
|
||||
xhigh: "default",
|
||||
}
|
||||
: {};
|
||||
return {
|
||||
supportsStore: !isNonStandard,
|
||||
supportsDeveloperRole: !isNonStandard,
|
||||
supportsReasoningEffort: !isGrok && !isZai,
|
||||
reasoningEffortMap,
|
||||
supportsUsageInStreaming: true,
|
||||
maxTokensField: useMaxTokens ? "max_tokens" : "max_completion_tokens",
|
||||
requiresToolResultName: false,
|
||||
requiresAssistantAfterToolResult: false,
|
||||
requiresThinkingAsText: false,
|
||||
thinkingFormat: isZai ? "zai" : "openai",
|
||||
openRouterRouting: {},
|
||||
vercelGatewayRouting: {},
|
||||
supportsStrictMode: true,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Get resolved compatibility settings for a model.
|
||||
* Uses explicit model.compat if provided, otherwise auto-detects from provider/URL.
|
||||
*/
|
||||
function getCompat(
|
||||
model: Model<"openai-completions">,
|
||||
): Required<OpenAICompletionsCompat> {
|
||||
const detected = detectCompat(model);
|
||||
if (!model.compat) return detected;
|
||||
|
||||
return {
|
||||
supportsStore: model.compat.supportsStore ?? detected.supportsStore,
|
||||
supportsDeveloperRole:
|
||||
model.compat.supportsDeveloperRole ?? detected.supportsDeveloperRole,
|
||||
supportsReasoningEffort:
|
||||
model.compat.supportsReasoningEffort ?? detected.supportsReasoningEffort,
|
||||
reasoningEffortMap:
|
||||
model.compat.reasoningEffortMap ?? detected.reasoningEffortMap,
|
||||
supportsUsageInStreaming:
|
||||
model.compat.supportsUsageInStreaming ??
|
||||
detected.supportsUsageInStreaming,
|
||||
maxTokensField: model.compat.maxTokensField ?? detected.maxTokensField,
|
||||
requiresToolResultName:
|
||||
model.compat.requiresToolResultName ?? detected.requiresToolResultName,
|
||||
requiresAssistantAfterToolResult:
|
||||
model.compat.requiresAssistantAfterToolResult ??
|
||||
detected.requiresAssistantAfterToolResult,
|
||||
requiresThinkingAsText:
|
||||
model.compat.requiresThinkingAsText ?? detected.requiresThinkingAsText,
|
||||
thinkingFormat: model.compat.thinkingFormat ?? detected.thinkingFormat,
|
||||
openRouterRouting: model.compat.openRouterRouting ?? {},
|
||||
vercelGatewayRouting:
|
||||
model.compat.vercelGatewayRouting ?? detected.vercelGatewayRouting,
|
||||
supportsStrictMode:
|
||||
model.compat.supportsStrictMode ?? detected.supportsStrictMode,
|
||||
};
|
||||
}
|
||||
|
|
@ -1,586 +0,0 @@
|
|||
import type OpenAI from "openai";
|
||||
import type {
|
||||
Tool as OpenAITool,
|
||||
ResponseCreateParamsStreaming,
|
||||
ResponseFunctionToolCall,
|
||||
ResponseInput,
|
||||
ResponseInputContent,
|
||||
ResponseInputImage,
|
||||
ResponseInputText,
|
||||
ResponseOutputMessage,
|
||||
ResponseReasoningItem,
|
||||
ResponseStreamEvent,
|
||||
} from "openai/resources/responses/responses.js";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Context,
|
||||
ImageContent,
|
||||
Model,
|
||||
StopReason,
|
||||
TextContent,
|
||||
TextSignatureV1,
|
||||
ThinkingContent,
|
||||
Tool,
|
||||
ToolCall,
|
||||
Usage,
|
||||
} from "../types.js";
|
||||
import type { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { shortHash } from "../utils/hash.js";
|
||||
import { parseStreamingJson } from "../utils/json-parse.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import { sanitizeToolCallArgumentsForSerialization } from "./sanitize-tool-arguments.js";
|
||||
import { transformMessagesWithReport } from "./transform-messages.js";
|
||||
|
||||
// =============================================================================
|
||||
// Utilities
|
||||
// =============================================================================
|
||||
|
||||
function encodeTextSignatureV1(
|
||||
id: string,
|
||||
phase?: TextSignatureV1["phase"],
|
||||
): string {
|
||||
const payload: TextSignatureV1 = { v: 1, id };
|
||||
if (phase) payload.phase = phase;
|
||||
return JSON.stringify(payload);
|
||||
}
|
||||
|
||||
function parseTextSignature(
|
||||
signature: string | undefined,
|
||||
): { id: string; phase?: TextSignatureV1["phase"] } | undefined {
|
||||
if (!signature) return undefined;
|
||||
if (signature.startsWith("{")) {
|
||||
try {
|
||||
const parsed = JSON.parse(signature) as Partial<TextSignatureV1>;
|
||||
if (parsed.v === 1 && typeof parsed.id === "string") {
|
||||
if (parsed.phase === "commentary" || parsed.phase === "final_answer") {
|
||||
return { id: parsed.id, phase: parsed.phase };
|
||||
}
|
||||
return { id: parsed.id };
|
||||
}
|
||||
} catch {
|
||||
// Fall through to legacy plain-string handling.
|
||||
}
|
||||
}
|
||||
return { id: signature };
|
||||
}
|
||||
|
||||
/**
 * Optional knobs for processing an OpenAI Responses API stream.
 */
export interface OpenAIResponsesStreamOptions {
  // Requested service tier, forwarded from the streaming request params.
  serviceTier?: ResponseCreateParamsStreaming["service_tier"];
  // Hook invoked to adjust usage/cost figures for the service tier
  // (undefined when no tier was requested).
  applyServiceTierPricing?: (
    usage: Usage,
    serviceTier: ResponseCreateParamsStreaming["service_tier"] | undefined,
  ) => void;
}
|
||||
|
||||
/** Options for convertResponsesMessages. */
export interface ConvertResponsesMessagesOptions {
  // When false, the system/developer prompt is omitted from the output
  // (callers that don't pass this get the default of true).
  includeSystemPrompt?: boolean;
}
|
||||
|
||||
/** Options for convertResponsesTools. */
export interface ConvertResponsesToolsOptions {
  // Value for the OpenAI `strict` tool field; when undefined, false is used.
  strict?: boolean | null;
}
|
||||
|
||||
// =============================================================================
|
||||
// Message conversion
|
||||
// =============================================================================
|
||||
|
||||
/**
 * Convert the internal Context into OpenAI Responses API input items.
 *
 * Maps user text/image content, assistant output (reasoning items, output
 * messages with stable ids, function calls), and tool results (as
 * function_call_output plus an optional image follow-up user message).
 *
 * @param model - Target model; drives role selection and image filtering.
 * @param context - System prompt and message history to convert.
 * @param allowedToolCallProviders - Providers whose pipe-separated tool-call
 *   IDs get sanitized/normalized; other providers' IDs pass through unchanged.
 * @param options - Optional conversion flags (e.g. includeSystemPrompt).
 * @returns Input items ready for a Responses API request.
 */
export function convertResponsesMessages<TApi extends Api>(
  model: Model<TApi>,
  context: Context,
  allowedToolCallProviders: ReadonlySet<string>,
  options?: ConvertResponsesMessagesOptions,
): ResponseInput {
  const messages: ResponseInput = [];

  // Normalize "{call_id}|{item_id}" IDs for providers that need it: sanitize
  // characters, enforce the fc prefix, cap length, and strip trailing
  // underscores on both halves.
  const normalizeToolCallId = (id: string): string => {
    if (!allowedToolCallProviders.has(model.provider)) return id;
    if (!id.includes("|")) return id;
    const [callId, itemId] = id.split("|");
    const sanitizedCallId = callId.replace(/[^a-zA-Z0-9_-]/g, "_");
    let sanitizedItemId = itemId.replace(/[^a-zA-Z0-9_-]/g, "_");
    // OpenAI Responses API requires item id to start with "fc"
    if (!sanitizedItemId.startsWith("fc")) {
      sanitizedItemId = `fc_${sanitizedItemId}`;
    }
    // Truncate to 64 chars and strip trailing underscores (OpenAI Codex rejects them)
    let normalizedCallId =
      sanitizedCallId.length > 64
        ? sanitizedCallId.slice(0, 64)
        : sanitizedCallId;
    let normalizedItemId =
      sanitizedItemId.length > 64
        ? sanitizedItemId.slice(0, 64)
        : sanitizedItemId;
    normalizedCallId = normalizedCallId.replace(/_+$/, "");
    normalizedItemId = normalizedItemId.replace(/_+$/, "");
    return `${normalizedCallId}|${normalizedItemId}`;
  };

  const transformedMessages = transformMessagesWithReport(
    context.messages,
    model,
    normalizeToolCallId,
    "openai-responses",
  );

  const includeSystemPrompt = options?.includeSystemPrompt ?? true;
  if (includeSystemPrompt && context.systemPrompt) {
    // Reasoning models take the prompt via the "developer" role.
    const role = model.reasoning ? "developer" : "system";
    messages.push({
      role,
      content: sanitizeSurrogates(context.systemPrompt),
    });
  }

  // Index used only to synthesize stable fallback ids for output messages.
  let msgIndex = 0;
  for (const msg of transformedMessages) {
    if (msg.role === "user") {
      if (typeof msg.content === "string") {
        messages.push({
          role: "user",
          content: [
            { type: "input_text", text: sanitizeSurrogates(msg.content) },
          ],
        });
      } else {
        // Structured content: map text and image items to input part shapes.
        const content: ResponseInputContent[] = msg.content.map(
          (item): ResponseInputContent => {
            if (item.type === "text") {
              return {
                type: "input_text",
                text: sanitizeSurrogates(item.text),
              } satisfies ResponseInputText;
            }
            return {
              type: "input_image",
              detail: "auto",
              image_url: `data:${item.mimeType};base64,${item.data}`,
            } satisfies ResponseInputImage;
          },
        );
        // Drop images for models that don't accept image input.
        const filteredContent = !model.input.includes("image")
          ? content.filter((c) => c.type !== "input_image")
          : content;
        if (filteredContent.length === 0) continue;
        messages.push({
          role: "user",
          content: filteredContent,
        });
      }
    } else if (msg.role === "assistant") {
      const output: ResponseInput = [];
      const assistantMsg = msg as AssistantMessage;
      // Same provider+api but a different model id: the message came from a
      // sibling model, so reasoning/tool-call pairing ids can't be trusted.
      const isDifferentModel =
        assistantMsg.model !== model.id &&
        assistantMsg.provider === model.provider &&
        assistantMsg.api === model.api;

      for (const block of msg.content) {
        if (block.type === "thinking") {
          // thinkingSignature carries the full serialized reasoning item;
          // without one, the thinking block cannot be replayed and is skipped.
          if (block.thinkingSignature) {
            const reasoningItem = JSON.parse(
              block.thinkingSignature,
            ) as ResponseReasoningItem;
            output.push(reasoningItem);
          }
        } else if (block.type === "text") {
          const textBlock = block as TextContent;
          const parsedSignature = parseTextSignature(textBlock.textSignature);
          // OpenAI requires id to be max 64 characters
          let msgId = parsedSignature?.id;
          if (!msgId) {
            msgId = `msg_${msgIndex}`;
          } else if (msgId.length > 64) {
            msgId = `msg_${shortHash(msgId)}`;
          }
          output.push({
            type: "message",
            role: "assistant",
            content: [
              {
                type: "output_text",
                text: sanitizeSurrogates(textBlock.text),
                annotations: [],
              },
            ],
            status: "completed",
            id: msgId,
            phase: parsedSignature?.phase,
          } satisfies ResponseOutputMessage);
        } else if (block.type === "toolCall") {
          const toolCall = block as ToolCall;
          const [callId, itemIdRaw] = toolCall.id.split("|");
          let itemId: string | undefined = itemIdRaw;

          // For different-model messages, set id to undefined to avoid pairing validation.
          // OpenAI tracks which fc_xxx IDs were paired with rs_xxx reasoning items.
          // By omitting the id, we avoid triggering that validation (like cross-provider does).
          if (isDifferentModel && itemId?.startsWith("fc_")) {
            itemId = undefined;
          }

          output.push({
            type: "function_call",
            id: itemId,
            call_id: callId,
            name: toolCall.name,
            arguments: JSON.stringify(
              sanitizeToolCallArgumentsForSerialization(toolCall.arguments),
            ),
          });
        }
      }
      // Nothing replayable in this assistant turn (e.g. only unsigned
      // thinking blocks) — emit nothing for it.
      if (output.length === 0) continue;
      messages.push(...output);
    } else if (msg.role === "toolResult") {
      // Extract text and image content
      const textResult = msg.content
        .filter((c): c is TextContent => c.type === "text")
        .map((c) => c.text)
        .join("\n");
      const hasImages = msg.content.some(
        (c): c is ImageContent => c.type === "image",
      );

      // Always send function_call_output with text (or placeholder if only images)
      const hasText = textResult.length > 0;
      const [callId] = msg.toolCallId.split("|");
      messages.push({
        type: "function_call_output",
        call_id: callId,
        output: sanitizeSurrogates(
          hasText ? textResult : "(see attached image)",
        ),
      });

      // If there are images and model supports them, send a follow-up user message with images
      if (hasImages && model.input.includes("image")) {
        const contentParts: ResponseInputContent[] = [];

        // Add text prefix
        contentParts.push({
          type: "input_text",
          text: "Attached image(s) from tool result:",
        } satisfies ResponseInputText);

        // Add images
        for (const block of msg.content) {
          if (block.type === "image") {
            contentParts.push({
              type: "input_image",
              detail: "auto",
              image_url: `data:${block.mimeType};base64,${block.data}`,
            } satisfies ResponseInputImage);
          }
        }

        messages.push({
          role: "user",
          content: contentParts,
        });
      }
    }
    msgIndex++;
  }

  return messages;
}
|
||||
|
||||
// =============================================================================
|
||||
// Tool conversion
|
||||
// =============================================================================
|
||||
|
||||
export function convertResponsesTools(
|
||||
tools: Tool[],
|
||||
options?: ConvertResponsesToolsOptions,
|
||||
): OpenAITool[] {
|
||||
const strict = options?.strict === undefined ? false : options.strict;
|
||||
return tools.map((tool) => ({
|
||||
type: "function",
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: tool.parameters as any, // TypeBox already generates JSON Schema
|
||||
strict,
|
||||
}));
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Stream processing
|
||||
// =============================================================================
|
||||
|
||||
export async function processResponsesStream<TApi extends Api>(
|
||||
openaiStream: AsyncIterable<ResponseStreamEvent>,
|
||||
output: AssistantMessage,
|
||||
stream: AssistantMessageEventStream,
|
||||
model: Model<TApi>,
|
||||
options?: OpenAIResponsesStreamOptions,
|
||||
): Promise<void> {
|
||||
let currentItem:
|
||||
| ResponseReasoningItem
|
||||
| ResponseOutputMessage
|
||||
| ResponseFunctionToolCall
|
||||
| null = null;
|
||||
let currentBlock:
|
||||
| ThinkingContent
|
||||
| TextContent
|
||||
| (ToolCall & { partialJson: string })
|
||||
| null = null;
|
||||
const blocks = output.content;
|
||||
const blockIndex = () => blocks.length - 1;
|
||||
|
||||
for await (const event of openaiStream) {
|
||||
if (event.type === "response.output_item.added") {
|
||||
const item = event.item;
|
||||
if (item.type === "reasoning") {
|
||||
currentItem = item;
|
||||
currentBlock = { type: "thinking", thinking: "" };
|
||||
output.content.push(currentBlock);
|
||||
stream.push({
|
||||
type: "thinking_start",
|
||||
contentIndex: blockIndex(),
|
||||
partial: output,
|
||||
});
|
||||
} else if (item.type === "message") {
|
||||
currentItem = item;
|
||||
currentBlock = { type: "text", text: "" };
|
||||
output.content.push(currentBlock);
|
||||
stream.push({
|
||||
type: "text_start",
|
||||
contentIndex: blockIndex(),
|
||||
partial: output,
|
||||
});
|
||||
} else if (item.type === "function_call") {
|
||||
currentItem = item;
|
||||
currentBlock = {
|
||||
type: "toolCall",
|
||||
id: `${item.call_id}|${item.id}`,
|
||||
name: item.name,
|
||||
arguments: {},
|
||||
partialJson: item.arguments || "",
|
||||
};
|
||||
output.content.push(currentBlock);
|
||||
stream.push({
|
||||
type: "toolcall_start",
|
||||
contentIndex: blockIndex(),
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
} else if (event.type === "response.reasoning_summary_part.added") {
|
||||
if (currentItem && currentItem.type === "reasoning") {
|
||||
currentItem.summary = currentItem.summary || [];
|
||||
currentItem.summary.push(event.part);
|
||||
}
|
||||
} else if (event.type === "response.reasoning_summary_text.delta") {
|
||||
if (
|
||||
currentItem?.type === "reasoning" &&
|
||||
currentBlock?.type === "thinking"
|
||||
) {
|
||||
currentItem.summary = currentItem.summary || [];
|
||||
const lastPart = currentItem.summary[currentItem.summary.length - 1];
|
||||
if (lastPart) {
|
||||
currentBlock.thinking += event.delta;
|
||||
lastPart.text += event.delta;
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: event.delta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if (event.type === "response.reasoning_summary_part.done") {
|
||||
if (
|
||||
currentItem?.type === "reasoning" &&
|
||||
currentBlock?.type === "thinking"
|
||||
) {
|
||||
currentItem.summary = currentItem.summary || [];
|
||||
const lastPart = currentItem.summary[currentItem.summary.length - 1];
|
||||
if (lastPart) {
|
||||
currentBlock.thinking += "\n\n";
|
||||
lastPart.text += "\n\n";
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: "\n\n",
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if (event.type === "response.content_part.added") {
|
||||
if (currentItem?.type === "message") {
|
||||
currentItem.content = currentItem.content || [];
|
||||
// Filter out ReasoningText, only accept output_text and refusal
|
||||
if (
|
||||
event.part.type === "output_text" ||
|
||||
event.part.type === "refusal"
|
||||
) {
|
||||
currentItem.content.push(event.part);
|
||||
}
|
||||
}
|
||||
} else if (event.type === "response.output_text.delta") {
|
||||
if (currentItem?.type === "message" && currentBlock?.type === "text") {
|
||||
if (!currentItem.content || currentItem.content.length === 0) {
|
||||
continue;
|
||||
}
|
||||
const lastPart = currentItem.content[currentItem.content.length - 1];
|
||||
if (lastPart?.type === "output_text") {
|
||||
currentBlock.text += event.delta;
|
||||
lastPart.text += event.delta;
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: event.delta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if (event.type === "response.refusal.delta") {
|
||||
if (currentItem?.type === "message" && currentBlock?.type === "text") {
|
||||
if (!currentItem.content || currentItem.content.length === 0) {
|
||||
continue;
|
||||
}
|
||||
const lastPart = currentItem.content[currentItem.content.length - 1];
|
||||
if (lastPart?.type === "refusal") {
|
||||
currentBlock.text += event.delta;
|
||||
lastPart.refusal += event.delta;
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: event.delta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if (event.type === "response.function_call_arguments.delta") {
|
||||
if (
|
||||
currentItem?.type === "function_call" &&
|
||||
currentBlock?.type === "toolCall"
|
||||
) {
|
||||
currentBlock.partialJson += event.delta;
|
||||
currentBlock.arguments = parseStreamingJson(currentBlock.partialJson);
|
||||
stream.push({
|
||||
type: "toolcall_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: event.delta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
} else if (event.type === "response.function_call_arguments.done") {
|
||||
if (
|
||||
currentItem?.type === "function_call" &&
|
||||
currentBlock?.type === "toolCall"
|
||||
) {
|
||||
currentBlock.partialJson = event.arguments;
|
||||
currentBlock.arguments = parseStreamingJson(currentBlock.partialJson);
|
||||
}
|
||||
} else if (event.type === "response.output_item.done") {
|
||||
const item = event.item;
|
||||
|
||||
if (item.type === "reasoning" && currentBlock?.type === "thinking") {
|
||||
currentBlock.thinking =
|
||||
item.summary?.map((s) => s.text).join("\n\n") || "";
|
||||
currentBlock.thinkingSignature = JSON.stringify(item);
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
currentBlock = null;
|
||||
} else if (item.type === "message" && currentBlock?.type === "text") {
|
||||
currentBlock.text = item.content
|
||||
.map((c) => (c.type === "output_text" ? c.text : c.refusal))
|
||||
.join("");
|
||||
currentBlock.textSignature = encodeTextSignatureV1(
|
||||
item.id,
|
||||
item.phase ?? undefined,
|
||||
);
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
currentBlock = null;
|
||||
} else if (item.type === "function_call") {
|
||||
const args =
|
||||
currentBlock?.type === "toolCall" && currentBlock.partialJson
|
||||
? parseStreamingJson(currentBlock.partialJson)
|
||||
: parseStreamingJson(item.arguments || "{}");
|
||||
const toolCall: ToolCall = {
|
||||
type: "toolCall",
|
||||
id: `${item.call_id}|${item.id}`,
|
||||
name: item.name,
|
||||
arguments: args,
|
||||
};
|
||||
|
||||
currentBlock = null;
|
||||
stream.push({
|
||||
type: "toolcall_end",
|
||||
contentIndex: blockIndex(),
|
||||
toolCall,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
} else if (event.type === "response.completed") {
|
||||
const response = event.response;
|
||||
if (response?.usage) {
|
||||
const cachedTokens =
|
||||
response.usage.input_tokens_details?.cached_tokens || 0;
|
||||
output.usage = {
|
||||
// OpenAI includes cached tokens in input_tokens, so subtract to get non-cached input
|
||||
input: (response.usage.input_tokens || 0) - cachedTokens,
|
||||
output: response.usage.output_tokens || 0,
|
||||
cacheRead: cachedTokens,
|
||||
cacheWrite: 0,
|
||||
totalTokens: response.usage.total_tokens || 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
};
|
||||
}
|
||||
calculateCost(model, output.usage);
|
||||
if (options?.applyServiceTierPricing) {
|
||||
const serviceTier = response?.service_tier ?? options.serviceTier;
|
||||
options.applyServiceTierPricing(output.usage, serviceTier);
|
||||
}
|
||||
// Map status to stop reason
|
||||
output.stopReason = mapStopReason(response?.status);
|
||||
if (
|
||||
output.content.some((b) => b.type === "toolCall") &&
|
||||
output.stopReason === "stop"
|
||||
) {
|
||||
output.stopReason = "toolUse";
|
||||
}
|
||||
} else if (event.type === "error") {
|
||||
throw new Error(
|
||||
`Error Code ${event.code}: ${event.message}` || "Unknown error",
|
||||
);
|
||||
} else if (event.type === "response.failed") {
|
||||
throw new Error("Unknown error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function mapStopReason(
|
||||
status: OpenAI.Responses.ResponseStatus | undefined,
|
||||
): StopReason {
|
||||
if (!status) return "stop";
|
||||
switch (status) {
|
||||
case "completed":
|
||||
return "stop";
|
||||
case "incomplete":
|
||||
return "length";
|
||||
case "failed":
|
||||
case "cancelled":
|
||||
return "error";
|
||||
// These two are wonky ...
|
||||
case "in_progress":
|
||||
case "queued":
|
||||
return "stop";
|
||||
default: {
|
||||
const _exhaustive: never = status;
|
||||
throw new Error(`Unhandled stop reason: ${_exhaustive}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,267 +0,0 @@
|
|||
// Lazy-loaded: OpenAI SDK is imported on first use, not at startup.
|
||||
// This avoids penalizing users who don't use OpenAI models.
|
||||
import type { ResponseCreateParamsStreaming } from "openai/resources/responses/responses.js";
|
||||
import { getEnvApiKey } from "../env-api-keys.js";
|
||||
import { supportsXhigh } from "../models.js";
|
||||
import type {
|
||||
CacheRetention,
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
Usage,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import {
|
||||
convertResponsesMessages,
|
||||
convertResponsesTools,
|
||||
processResponsesStream,
|
||||
} from "./openai-responses-shared.js";
|
||||
import {
|
||||
assertStreamSuccess,
|
||||
buildInitialOutput,
|
||||
clampReasoningForModel,
|
||||
createOpenAIClient,
|
||||
finalizeStream,
|
||||
handleStreamError,
|
||||
} from "./openai-shared.js";
|
||||
import {
|
||||
buildBaseOptions,
|
||||
clampReasoning,
|
||||
isAutoReasoning,
|
||||
resolveReasoningLevel,
|
||||
} from "./simple-options.js";
|
||||
|
||||
const OPENAI_TOOL_CALL_PROVIDERS = new Set([
|
||||
"openai",
|
||||
"openai-codex",
|
||||
"opencode",
|
||||
]);
|
||||
|
||||
/**
|
||||
* Resolve cache retention preference.
|
||||
* Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility.
|
||||
*/
|
||||
function resolveCacheRetention(
|
||||
cacheRetention?: CacheRetention,
|
||||
): CacheRetention {
|
||||
if (cacheRetention) {
|
||||
return cacheRetention;
|
||||
}
|
||||
if (
|
||||
typeof process !== "undefined" &&
|
||||
process.env.PI_CACHE_RETENTION === "long"
|
||||
) {
|
||||
return "long";
|
||||
}
|
||||
return "short";
|
||||
}
|
||||
|
||||
/**
|
||||
* Get prompt cache retention based on cacheRetention and base URL.
|
||||
* Only applies to direct OpenAI API calls (api.openai.com).
|
||||
*/
|
||||
function getPromptCacheRetention(
|
||||
baseUrl: string,
|
||||
cacheRetention: CacheRetention,
|
||||
): "24h" | undefined {
|
||||
if (cacheRetention !== "long") {
|
||||
return undefined;
|
||||
}
|
||||
if (baseUrl.includes("api.openai.com")) {
|
||||
return "24h";
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// OpenAI Responses-specific options
export interface OpenAIResponsesOptions extends StreamOptions {
  /** "auto" means no effort constraint — model decides its own reasoning depth (GPT-5+). */
  reasoningEffort?: "auto" | "minimal" | "low" | "medium" | "high" | "xhigh";
  /** Reasoning summary style; falsy values fall back to "auto" when the request is built. */
  reasoningSummary?: "auto" | "detailed" | "concise" | null;
  /** Service tier forwarded on the request; also scales usage cost via service-tier pricing. */
  serviceTier?: ResponseCreateParamsStreaming["service_tier"];
}
|
||||
|
||||
/**
 * Generate function for OpenAI Responses API
 *
 * Returns an event stream immediately; the request itself runs in a detached
 * async IIFE that pushes events onto the stream and closes it on completion
 * or error. Callers consume the returned stream; errors surface as "error"
 * events rather than rejected promises.
 */
export const streamOpenAIResponses: StreamFunction<
  "openai-responses",
  OpenAIResponsesOptions
> = (
  model: Model<"openai-responses">,
  context: Context,
  options?: OpenAIResponsesOptions,
): AssistantMessageEventStream => {
  const stream = new AssistantMessageEventStream();

  // Start async processing
  (async () => {
    const output = buildInitialOutput(model);

    try {
      // Create OpenAI client
      const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
      const client = await createOpenAIClient(model, context, apiKey, {
        optionsHeaders: options?.headers,
      });
      let params = buildParams(model, context, options);
      // onPayload may replace the request params wholesale; returning
      // undefined keeps the originals.
      const nextParams = await options?.onPayload?.(params, model);
      if (nextParams !== undefined) {
        params = nextParams as ResponseCreateParamsStreaming;
      }
      const openaiStream = await client.responses.create(
        params,
        options?.signal ? { signal: options.signal } : undefined,
      );
      stream.push({ type: "start", partial: output });

      // Shared stream-event loop; service-tier pricing is applied to usage
      // when the response completes.
      await processResponsesStream(openaiStream, output, stream, model, {
        serviceTier: options?.serviceTier,
        applyServiceTierPricing,
      });

      assertStreamSuccess(output, options?.signal);
      finalizeStream(stream, output);
    } catch (error) {
      handleStreamError(stream, output, error, options?.signal);
    }
  })();

  return stream;
};
|
||||
|
||||
export const streamSimpleOpenAIResponses: StreamFunction<
|
||||
"openai-responses",
|
||||
SimpleStreamOptions
|
||||
> = (
|
||||
model: Model<"openai-responses">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
const base = buildBaseOptions(model, options, apiKey);
|
||||
const reasoningEffort: OpenAIResponsesOptions["reasoningEffort"] =
|
||||
isAutoReasoning(options?.reasoning)
|
||||
? "auto"
|
||||
: supportsXhigh(model)
|
||||
? resolveReasoningLevel(model, options?.reasoning)
|
||||
: clampReasoning(resolveReasoningLevel(model, options?.reasoning));
|
||||
|
||||
return streamOpenAIResponses(model, context, {
|
||||
...base,
|
||||
reasoningEffort,
|
||||
} satisfies OpenAIResponsesOptions);
|
||||
};
|
||||
|
||||
/**
 * Build the streaming request parameters for the Responses API.
 *
 * Converts messages/tools, wires prompt caching, applies token, temperature
 * and service-tier options, and configures reasoning for reasoning-capable
 * models. May append a "developer" message to `messages` as a side effect
 * (GPT-5 reasoning suppression, see below).
 */
function buildParams(
  model: Model<"openai-responses">,
  context: Context,
  options?: OpenAIResponsesOptions,
) {
  const messages = convertResponsesMessages(
    model,
    context,
    OPENAI_TOOL_CALL_PROVIDERS,
  );

  const cacheRetention = resolveCacheRetention(options?.cacheRetention);
  const params: ResponseCreateParamsStreaming = {
    model: model.id,
    input: messages,
    stream: true,
    // Cache is keyed by session; "none" disables the cache key entirely.
    prompt_cache_key:
      cacheRetention === "none" ? undefined : options?.sessionId,
    prompt_cache_retention: getPromptCacheRetention(
      model.baseUrl,
      cacheRetention,
    ),
    store: false,
  };

  // NOTE(review): truthy check means maxTokens === 0 is treated as unset.
  if (options?.maxTokens) {
    params.max_output_tokens = options?.maxTokens;
  }

  // Explicit undefined check so temperature 0 is still forwarded.
  if (options?.temperature !== undefined) {
    params.temperature = options?.temperature;
  }

  if (options?.serviceTier !== undefined) {
    params.service_tier = options.serviceTier;
  }

  if (context.tools && context.tools.length > 0) {
    params.tools = convertResponsesTools(context.tools);
  }

  if (model.reasoning) {
    params.include = ["reasoning.encrypted_content"];
    if (options?.reasoningEffort === "auto") {
      // Let the model decide its own reasoning depth — no effort constraint.
      // GPT-5+ will reason as much as it judges necessary, same as
      // THINKING_LEVEL_UNSPECIFIED for Gemini 2.5.
      params.reasoning = { summary: options?.reasoningSummary || "auto" };
    } else if (options?.reasoningEffort || options?.reasoningSummary) {
      // Explicit effort and/or summary: clamp effort for models that lack
      // "minimal", defaulting missing pieces to "medium"/"auto".
      const effort = clampReasoningForModel(
        model.name,
        options?.reasoningEffort || "medium",
      ) as typeof options.reasoningEffort;
      params.reasoning = {
        effort: effort || "medium",
        summary: options?.reasoningSummary || "auto",
      };
    } else {
      // No reasoning options at all: actively suppress reasoning on GPT-5.
      if (model.name.startsWith("gpt-5")) {
        // Jesus Christ, see https://community.openai.com/t/need-reasoning-false-option-for-gpt-5/1351588/7
        messages.push({
          role: "developer",
          content: [
            {
              type: "input_text",
              text: "# Juice: 0 !important",
            },
          ],
        });
      }
    }
  }

  return params;
}
|
||||
|
||||
function getServiceTierCostMultiplier(
|
||||
serviceTier: ResponseCreateParamsStreaming["service_tier"] | undefined,
|
||||
): number {
|
||||
switch (serviceTier) {
|
||||
case "flex":
|
||||
return 0.5;
|
||||
case "priority":
|
||||
return 2;
|
||||
default:
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
function applyServiceTierPricing(
|
||||
usage: Usage,
|
||||
serviceTier: ResponseCreateParamsStreaming["service_tier"] | undefined,
|
||||
) {
|
||||
const multiplier = getServiceTierCostMultiplier(serviceTier);
|
||||
if (multiplier === 1) return;
|
||||
|
||||
usage.cost.input *= multiplier;
|
||||
usage.cost.output *= multiplier;
|
||||
usage.cost.cacheRead *= multiplier;
|
||||
usage.cost.cacheWrite *= multiplier;
|
||||
usage.cost.total =
|
||||
usage.cost.input +
|
||||
usage.cost.output +
|
||||
usage.cost.cacheRead +
|
||||
usage.cost.cacheWrite;
|
||||
}
|
||||
|
|
@ -1,215 +0,0 @@
|
|||
/**
|
||||
* Shared utilities for OpenAI Completions and Responses providers.
|
||||
*
|
||||
* This module consolidates code that is identical (or near-identical) across
|
||||
* openai-completions.ts and openai-responses.ts to reduce duplication while
|
||||
* preserving the subtle behavioural differences of each provider.
|
||||
*/
|
||||
|
||||
import type OpenAI from "openai";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Context,
|
||||
Model,
|
||||
StopReason,
|
||||
} from "../types.js";
|
||||
import type { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import {
|
||||
buildCopilotDynamicHeaders,
|
||||
hasCopilotVisionInput,
|
||||
} from "./github-copilot-headers.js";
|
||||
|
||||
// =============================================================================
|
||||
// Lazy SDK loading
|
||||
// =============================================================================
|
||||
|
||||
// Cached SDK constructor; undefined until the first successful import.
let _openAIClass: typeof OpenAI | undefined;

/**
 * Lazy-load the OpenAI SDK default export.
 * Shared between Completions and Responses providers so the module is only
 * imported once regardless of which provider is used first.
 *
 * @returns The OpenAI client class (cached after the first call).
 */
export async function getOpenAIClass(): Promise<typeof OpenAI> {
  if (!_openAIClass) {
    // First use: dynamic import keeps the SDK off the startup path.
    const mod = await import("openai");
    _openAIClass = mod.default;
  }
  return _openAIClass;
}
|
||||
|
||||
// =============================================================================
|
||||
// Client creation
|
||||
// =============================================================================
|
||||
|
||||
/** Options accepted by createOpenAIClient. */
export interface CreateClientOptions {
  /** Extra headers from the options bag (merged last, can override defaults). */
  optionsHeaders?: Record<string, string>;
  /** Provider-specific client constructor options (e.g. timeout, maxRetries for Z.ai). */
  extraClientOptions?: Record<string, unknown>;
}
|
||||
|
||||
/**
|
||||
* Create an OpenAI SDK client instance.
|
||||
*
|
||||
* Handles:
|
||||
* - API key resolution (explicit > env)
|
||||
* - GitHub Copilot dynamic headers
|
||||
* - Options header merging
|
||||
* - Lazy SDK loading
|
||||
*/
|
||||
export async function createOpenAIClient<TApi extends Api>(
|
||||
model: Model<TApi>,
|
||||
context: Context,
|
||||
apiKey: string | undefined,
|
||||
options?: CreateClientOptions,
|
||||
): Promise<OpenAI> {
|
||||
if (!apiKey) {
|
||||
if (!process.env.OPENAI_API_KEY) {
|
||||
throw new Error(
|
||||
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass it as an argument.",
|
||||
);
|
||||
}
|
||||
apiKey = process.env.OPENAI_API_KEY;
|
||||
}
|
||||
|
||||
const headers = { ...model.headers };
|
||||
if (model.provider === "github-copilot") {
|
||||
const hasImages = hasCopilotVisionInput(context.messages);
|
||||
const copilotHeaders = buildCopilotDynamicHeaders({
|
||||
messages: context.messages,
|
||||
hasImages,
|
||||
});
|
||||
Object.assign(headers, copilotHeaders);
|
||||
}
|
||||
|
||||
// Merge options headers last so they can override defaults
|
||||
if (options?.optionsHeaders) {
|
||||
Object.assign(headers, options.optionsHeaders);
|
||||
}
|
||||
|
||||
const OpenAIClass = await getOpenAIClass();
|
||||
return new OpenAIClass({
|
||||
apiKey,
|
||||
baseURL: model.baseUrl,
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: headers,
|
||||
...options?.extraClientOptions,
|
||||
});
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Initial output construction
|
||||
// =============================================================================
|
||||
|
||||
/**
|
||||
* Build the initial AssistantMessage output object used by all OpenAI stream
|
||||
* handlers. Every field is initialised to its zero/default value.
|
||||
*/
|
||||
export function buildInitialOutput<TApi extends Api>(
|
||||
model: Model<TApi>,
|
||||
): AssistantMessage {
|
||||
return {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: model.api as Api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Stream lifecycle helpers
|
||||
// =============================================================================
|
||||
|
||||
/**
|
||||
* Shared post-stream checks. Call after the provider-specific stream loop
|
||||
* finishes successfully (before pushing the "done" event).
|
||||
*
|
||||
* Throws if the request was aborted or the output indicates an error.
|
||||
*/
|
||||
export function assertStreamSuccess(
|
||||
output: AssistantMessage,
|
||||
signal?: AbortSignal,
|
||||
): void {
|
||||
if (signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
if (output.stopReason === "aborted" || output.stopReason === "error") {
|
||||
throw new Error("An unknown error occurred");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Emit the "done" event and close the stream.
|
||||
*/
|
||||
export function finalizeStream(
|
||||
stream: AssistantMessageEventStream,
|
||||
output: AssistantMessage,
|
||||
): void {
|
||||
stream.push({
|
||||
type: "done",
|
||||
reason: output.stopReason as Extract<
|
||||
StopReason,
|
||||
"stop" | "length" | "toolUse"
|
||||
>,
|
||||
message: output,
|
||||
});
|
||||
stream.end();
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle an error during streaming.
|
||||
*
|
||||
* Cleans up any leftover `index` properties on content blocks, sets the
|
||||
* appropriate stop reason and error message, then emits the "error" event.
|
||||
*/
|
||||
export function handleStreamError(
|
||||
stream: AssistantMessageEventStream,
|
||||
output: AssistantMessage,
|
||||
error: unknown,
|
||||
signal?: AbortSignal,
|
||||
/** Extra error metadata to append (e.g. OpenRouter raw metadata). */
|
||||
extraMessage?: string,
|
||||
): void {
|
||||
for (const block of output.content)
|
||||
delete (block as { index?: number }).index;
|
||||
output.stopReason = signal?.aborted ? "aborted" : "error";
|
||||
output.errorMessage =
|
||||
error instanceof Error ? error.message : JSON.stringify(error);
|
||||
if (extraMessage) output.errorMessage += `\n${extraMessage}`;
|
||||
stream.push({ type: "error", reason: output.stopReason, error: output });
|
||||
stream.end();
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Reasoning helpers
|
||||
// =============================================================================
|
||||
|
||||
/**
|
||||
* Clamp reasoning effort for models that don't support all levels.
|
||||
* gpt-5.x models don't support "minimal" -- map to "low".
|
||||
*
|
||||
* Used by both openai-responses.ts and azure-openai-responses.ts.
|
||||
*/
|
||||
export function clampReasoningForModel(
|
||||
modelName: string,
|
||||
effort: string,
|
||||
): string {
|
||||
const name = modelName.includes("/")
|
||||
? modelName.split("/").pop()!
|
||||
: modelName;
|
||||
if (name.startsWith("gpt-5") && effort === "minimal") return "low";
|
||||
return effort;
|
||||
}
|
||||
|
|
@ -1,280 +0,0 @@
|
|||
// SF — Provider Capabilities Registry Tests (ADR-005 Phase 1)
|
||||
|
||||
import assert from "node:assert/strict";
|
||||
import { describe, test } from "vitest";
|
||||
import {
|
||||
sanitizeMistralToolParameters,
|
||||
shouldUseMistralReasoningPromptMode,
|
||||
} from "./mistral.js";
|
||||
import {
|
||||
getProviderCapabilities,
|
||||
getRegisteredApis,
|
||||
getUnsupportedFeatures,
|
||||
mergeCapabilityOverrides,
|
||||
PROVIDER_CAPABILITIES,
|
||||
} from "./provider-capabilities.js";
|
||||
|
||||
// ─── Registry Completeness ──────────────────────────────────────────────────
|
||||
|
||||
// Structural checks: every shipped API has an entry and every entry has a
// complete, well-typed capability profile.
describe("PROVIDER_CAPABILITIES registry", () => {
  // Keep in sync with the stream functions the library actually exports.
  const EXPECTED_APIS = [
    "anthropic-messages",
    "anthropic-vertex",
    "openai-responses",
    "azure-openai-responses",
    "openai-codex-responses",
    "openai-completions",
    "google-generative-ai",
    "google-gemini-cli",
    "google-vertex",
    "mistral-conversations",
    "bedrock-converse-stream",
    "ollama-chat",
  ];

  test("covers all expected API providers", () => {
    for (const api of EXPECTED_APIS) {
      assert.ok(
        PROVIDER_CAPABILITIES[api],
        `Missing capability entry for API: ${api}`,
      );
    }
  });

  test("getRegisteredApis returns all entries", () => {
    const registered = getRegisteredApis();
    for (const api of EXPECTED_APIS) {
      assert.ok(registered.includes(api), `getRegisteredApis missing: ${api}`);
    }
  });

  // Shape check: each field exists with the type declared on
  // ProviderCapabilities, for every registered entry.
  test("all entries have required fields", () => {
    for (const [api, caps] of Object.entries(PROVIDER_CAPABILITIES)) {
      assert.equal(typeof caps.toolCalling, "boolean", `${api}.toolCalling`);
      assert.equal(typeof caps.maxTools, "number", `${api}.maxTools`);
      assert.equal(
        typeof caps.imageToolResults,
        "boolean",
        `${api}.imageToolResults`,
      );
      assert.equal(
        typeof caps.structuredOutput,
        "boolean",
        `${api}.structuredOutput`,
      );
      assert.ok(caps.toolCallIdFormat, `${api}.toolCallIdFormat`);
      assert.equal(
        typeof caps.toolCallIdFormat.maxLength,
        "number",
        `${api}.toolCallIdFormat.maxLength`,
      );
      assert.ok(
        caps.toolCallIdFormat.allowedChars instanceof RegExp,
        `${api}.toolCallIdFormat.allowedChars`,
      );
      assert.ok(
        ["full", "text-only", "none"].includes(caps.thinkingPersistence),
        `${api}.thinkingPersistence is "${caps.thinkingPersistence}"`,
      );
      assert.ok(
        Array.isArray(caps.unsupportedSchemaFeatures),
        `${api}.unsupportedSchemaFeatures`,
      );
    }
  });
});
|
||||
|
||||
// ─── Provider-specific Values ───────────────────────────────────────────────
|
||||
|
||||
// Spot-checks of individual registry entries and the Mistral helper
// functions, pinning the provider constraints documented in the registry.
describe("provider-specific capabilities", () => {
  test("Anthropic supports full thinking persistence", () => {
    assert.equal(
      PROVIDER_CAPABILITIES["anthropic-messages"].thinkingPersistence,
      "full",
    );
  });

  test("Anthropic supports image tool results", () => {
    assert.equal(
      PROVIDER_CAPABILITIES["anthropic-messages"].imageToolResults,
      true,
    );
  });

  test("Anthropic tool call ID is 64 chars max", () => {
    assert.equal(
      PROVIDER_CAPABILITIES["anthropic-messages"].toolCallIdFormat.maxLength,
      64,
    );
  });

  test("Mistral tool call ID is 9 chars max", () => {
    assert.equal(
      PROVIDER_CAPABILITIES["mistral-conversations"].toolCallIdFormat.maxLength,
      9,
    );
  });

  test("Mistral has no thinking persistence", () => {
    assert.equal(
      PROVIDER_CAPABILITIES["mistral-conversations"].thinkingPersistence,
      "none",
    );
  });

  test("Mistral reasoning prompt mode is limited to Magistral models", () => {
    // Minimal structural stand-in for a Model; only id/reasoning are read.
    const baseModel = {
      id: "mistral-small-latest",
      reasoning: true,
    } as any;

    assert.equal(
      shouldUseMistralReasoningPromptMode(baseModel, "medium"),
      false,
    );
    assert.equal(
      shouldUseMistralReasoningPromptMode(
        { ...baseModel, id: "magistral-medium-latest" },
        "medium",
      ),
      true,
    );
  });

  test("Mistral tool schema drops TypeBox symbol metadata", () => {
    // TypeBox schemas carry symbol-keyed metadata that Mistral rejects;
    // sanitization must strip symbols at every nesting level.
    const kind = Symbol("TypeBox.Kind");
    const schema = {
      type: "object",
      required: ["path"],
      properties: {
        path: {
          type: "string",
          [kind]: "String",
        },
      },
      [kind]: "Object",
    };

    const sanitized = sanitizeMistralToolParameters(schema);

    assert.deepEqual(Object.getOwnPropertySymbols(sanitized), []);
    assert.deepEqual(
      Object.getOwnPropertySymbols((sanitized.properties as any).path),
      [],
    );
    assert.deepEqual(sanitized, {
      type: "object",
      required: ["path"],
      properties: {
        path: {
          type: "string",
        },
      },
    });
  });

  test("Google does not support patternProperties", () => {
    assert.ok(
      PROVIDER_CAPABILITIES[
        "google-generative-ai"
      ].unsupportedSchemaFeatures.includes("patternProperties"),
    );
  });

  test("Google does not support const", () => {
    assert.ok(
      PROVIDER_CAPABILITIES[
        "google-generative-ai"
      ].unsupportedSchemaFeatures.includes("const"),
    );
  });

  test("OpenAI Responses does not support image tool results", () => {
    assert.equal(
      PROVIDER_CAPABILITIES["openai-responses"].imageToolResults,
      false,
    );
  });

  test("OpenAI Responses has text-only thinking persistence", () => {
    assert.equal(
      PROVIDER_CAPABILITIES["openai-responses"].thinkingPersistence,
      "text-only",
    );
  });
});
|
||||
|
||||
// ─── getProviderCapabilities ────────────────────────────────────────────────
|
||||
|
||||
// Lookup behavior: registered APIs return their entry; unknown APIs fall
// back to a permissive default profile.
describe("getProviderCapabilities", () => {
  test("returns known provider capabilities", () => {
    const caps = getProviderCapabilities("anthropic-messages");
    assert.equal(caps.toolCalling, true);
    assert.equal(caps.thinkingPersistence, "full");
  });

  test("returns permissive defaults for unknown providers", () => {
    const caps = getProviderCapabilities("unknown-provider-xyz");
    assert.equal(caps.toolCalling, true);
    assert.equal(caps.imageToolResults, true);
    assert.deepEqual(caps.unsupportedSchemaFeatures, []);
  });
});
|
||||
|
||||
// ─── getUnsupportedFeatures ─────────────────────────────────────────────────
|
||||
|
||||
// Filtering: only features listed in the provider's unsupported set are
// returned; unknown providers report nothing unsupported.
describe("getUnsupportedFeatures", () => {
  test("returns unsupported features for Google", () => {
    const unsupported = getUnsupportedFeatures("google-generative-ai", [
      "patternProperties",
      "const",
    ]);
    assert.deepEqual(unsupported, ["patternProperties", "const"]);
  });

  test("returns empty for Anthropic with any features", () => {
    const unsupported = getUnsupportedFeatures("anthropic-messages", [
      "patternProperties",
      "const",
    ]);
    assert.deepEqual(unsupported, []);
  });

  test("returns empty for unknown provider", () => {
    const unsupported = getUnsupportedFeatures("unknown-xyz", [
      "patternProperties",
    ]);
    assert.deepEqual(unsupported, []);
  });
});
|
||||
|
||||
// ─── mergeCapabilityOverrides ───────────────────────────────────────────────
|
||||
|
||||
// Merging: overrides replace individual fields, toolCallIdFormat merges
// deeply, and unknown providers merge onto the permissive defaults.
describe("mergeCapabilityOverrides", () => {
  test("overrides individual fields", () => {
    const merged = mergeCapabilityOverrides("openai-responses", {
      imageToolResults: true,
    });
    assert.equal(merged.imageToolResults, true);
    // Non-overridden fields preserved
    assert.equal(merged.toolCalling, true);
    assert.equal(merged.thinkingPersistence, "text-only");
  });

  test("deep-merges toolCallIdFormat", () => {
    const merged = mergeCapabilityOverrides("anthropic-messages", {
      toolCallIdFormat: { maxLength: 128 },
    });
    assert.equal(merged.toolCallIdFormat.maxLength, 128);
    // allowedChars preserved from base
    assert.ok(merged.toolCallIdFormat.allowedChars instanceof RegExp);
  });

  test("uses permissive defaults for unknown provider", () => {
    const merged = mergeCapabilityOverrides("unknown-xyz", {
      imageToolResults: false,
    });
    assert.equal(merged.imageToolResults, false);
    assert.equal(merged.toolCalling, true); // from default
  });
});
|
||||
|
|
@ -1,218 +0,0 @@
|
|||
// SF — Provider Capabilities Registry (ADR-005 Phase 1)
|
||||
// Declarative registry of what each API provider supports, consolidating
|
||||
// scattered knowledge from *-shared.ts files into a queryable data structure.
|
||||
|
||||
// ─── Types ──────────────────────────────────────────────────────────────────
|
||||
|
||||
/**
 * Declarative capability profile for an API provider.
 * Used by the model router to filter incompatible models and by the tool
 * system to adjust tool sets per provider.
 */
export interface ProviderCapabilities {
  /** Whether models from this provider support tool/function calling */
  toolCalling: boolean;
  /** Maximum number of tools the provider handles well (0 = unlimited) */
  maxTools: number;
  /** Whether tool results can contain images */
  imageToolResults: boolean;
  /** Whether the provider supports structured JSON output */
  structuredOutput: boolean;
  /** Tool call ID format constraints */
  toolCallIdFormat: {
    /** Maximum tool-call ID length the provider accepts */
    maxLength: number;
    /** Characters the provider permits in a tool-call ID */
    allowedChars: RegExp;
  };
  /** Whether thinking/reasoning blocks are preserved cross-turn */
  thinkingPersistence: "full" | "text-only" | "none";
  /** Schema features NOT supported (tools using these get filtered) */
  unsupportedSchemaFeatures: string[];
}
|
||||
|
||||
// ─── Registry ───────────────────────────────────────────────────────────────
|
||||
|
||||
/**
 * Built-in provider capability profiles.
 *
 * Sources (consolidated from scattered *-shared.ts files):
 * - anthropic-shared.ts: normalizeToolCallId (64-char, [a-zA-Z0-9_-])
 * - openai-responses-shared.ts: ID normalization (64-char, fc_ prefix), image-in-tool-result workaround
 * - google-shared.ts: sanitizeSchemaForGoogle (patternProperties, const), requiresToolCallId
 * - mistral.ts: MISTRAL_TOOL_CALL_ID_LENGTH = 9
 * - amazon-bedrock.ts: normalizeToolCallId (64-char, [a-zA-Z0-9_-])
 */
export const PROVIDER_CAPABILITIES: Record<string, ProviderCapabilities> = {
  "anthropic-messages": {
    toolCalling: true,
    maxTools: 0,
    imageToolResults: true,
    structuredOutput: true,
    toolCallIdFormat: { maxLength: 64, allowedChars: /^[a-zA-Z0-9_-]+$/ },
    thinkingPersistence: "full",
    unsupportedSchemaFeatures: [],
  },
  "anthropic-vertex": {
    toolCalling: true,
    maxTools: 0,
    imageToolResults: true,
    structuredOutput: true,
    toolCallIdFormat: { maxLength: 64, allowedChars: /^[a-zA-Z0-9_-]+$/ },
    thinkingPersistence: "full",
    unsupportedSchemaFeatures: [],
  },
  "openai-responses": {
    toolCalling: true,
    maxTools: 0,
    imageToolResults: false, // images sent as separate user message, not in tool result
    structuredOutput: true,
    toolCallIdFormat: { maxLength: 512, allowedChars: /^.+$/ },
    thinkingPersistence: "text-only",
    unsupportedSchemaFeatures: [],
  },
  "azure-openai-responses": {
    toolCalling: true,
    maxTools: 0,
    imageToolResults: false,
    structuredOutput: true,
    toolCallIdFormat: { maxLength: 512, allowedChars: /^.+$/ },
    thinkingPersistence: "text-only",
    unsupportedSchemaFeatures: [],
  },
  "openai-codex-responses": {
    toolCalling: true,
    maxTools: 0,
    imageToolResults: false,
    structuredOutput: true,
    toolCallIdFormat: { maxLength: 64, allowedChars: /^[a-zA-Z0-9_-]+$/ },
    thinkingPersistence: "text-only",
    unsupportedSchemaFeatures: [],
  },
  "openai-completions": {
    toolCalling: true,
    maxTools: 0,
    imageToolResults: false,
    structuredOutput: true,
    toolCallIdFormat: { maxLength: 64, allowedChars: /^[a-zA-Z0-9_-]+$/ },
    thinkingPersistence: "text-only",
    unsupportedSchemaFeatures: [],
  },
  "google-generative-ai": {
    toolCalling: true,
    maxTools: 0,
    imageToolResults: true,
    structuredOutput: true,
    toolCallIdFormat: { maxLength: 64, allowedChars: /^[a-zA-Z0-9_-]+$/ },
    thinkingPersistence: "text-only",
    unsupportedSchemaFeatures: ["patternProperties", "const"],
  },
  "google-gemini-cli": {
    toolCalling: true,
    maxTools: 0,
    imageToolResults: true,
    structuredOutput: true,
    toolCallIdFormat: { maxLength: 64, allowedChars: /^[a-zA-Z0-9_-]+$/ },
    thinkingPersistence: "text-only",
    unsupportedSchemaFeatures: ["patternProperties", "const"],
  },
  "google-vertex": {
    toolCalling: true,
    maxTools: 0,
    imageToolResults: true,
    structuredOutput: true,
    toolCallIdFormat: { maxLength: 64, allowedChars: /^[a-zA-Z0-9_-]+$/ },
    thinkingPersistence: "text-only",
    unsupportedSchemaFeatures: ["patternProperties", "const"],
  },
  "mistral-conversations": {
    toolCalling: true,
    maxTools: 0,
    imageToolResults: false,
    structuredOutput: true,
    toolCallIdFormat: { maxLength: 9, allowedChars: /^[a-zA-Z0-9]+$/ },
    thinkingPersistence: "none",
    unsupportedSchemaFeatures: [],
  },
  "bedrock-converse-stream": {
    toolCalling: true,
    maxTools: 0,
    imageToolResults: true, // Bedrock supports image content blocks in tool results
    structuredOutput: true,
    toolCallIdFormat: { maxLength: 64, allowedChars: /^[a-zA-Z0-9_-]+$/ },
    thinkingPersistence: "text-only",
    unsupportedSchemaFeatures: [],
  },
  "ollama-chat": {
    toolCalling: true,
    maxTools: 0,
    imageToolResults: false,
    structuredOutput: false,
    toolCallIdFormat: { maxLength: 64, allowedChars: /^[a-zA-Z0-9_-]+$/ },
    thinkingPersistence: "none",
    unsupportedSchemaFeatures: [],
  },
};
|
||||
|
||||
// ─── Default (permissive) profile for unknown providers ─────────────────────
|
||||
|
||||
// Permissive fallback profile for APIs not present in the registry; keeps
// unknown providers fully functional (ADR-005 principle 5).
const DEFAULT_CAPABILITIES: ProviderCapabilities = {
  toolCalling: true,
  maxTools: 0,
  imageToolResults: true,
  structuredOutput: true,
  toolCallIdFormat: { maxLength: 512, allowedChars: /^.+$/ },
  thinkingPersistence: "text-only",
  unsupportedSchemaFeatures: [],
};
|
||||
|
||||
// ─── Public API ─────────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Get capabilities for a provider API. Returns a permissive default for
|
||||
* unknown providers (preserving existing behavior per ADR-005 principle 5).
|
||||
*/
|
||||
export function getProviderCapabilities(api: string): ProviderCapabilities {
|
||||
return PROVIDER_CAPABILITIES[api] ?? DEFAULT_CAPABILITIES;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a provider supports all required schema features.
|
||||
* Returns the list of unsupported features (empty if all supported).
|
||||
*/
|
||||
export function getUnsupportedFeatures(
|
||||
api: string,
|
||||
requiredFeatures: string[],
|
||||
): string[] {
|
||||
const caps = getProviderCapabilities(api);
|
||||
return requiredFeatures.filter((f) =>
|
||||
caps.unsupportedSchemaFeatures.includes(f),
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Deep-merge user-provided capability overrides with built-in defaults.
|
||||
* Partial overrides merge with the built-in profile for the given API.
|
||||
*/
|
||||
export function mergeCapabilityOverrides(
|
||||
api: string,
|
||||
overrides: Partial<Omit<ProviderCapabilities, "toolCallIdFormat">> & {
|
||||
toolCallIdFormat?: Partial<ProviderCapabilities["toolCallIdFormat"]>;
|
||||
},
|
||||
): ProviderCapabilities {
|
||||
const base = getProviderCapabilities(api);
|
||||
return {
|
||||
...base,
|
||||
...overrides,
|
||||
toolCallIdFormat: overrides.toolCallIdFormat
|
||||
? { ...base.toolCallIdFormat, ...overrides.toolCallIdFormat }
|
||||
: base.toolCallIdFormat,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all registered API names in the capability registry.
|
||||
* Used by lint rules to verify all providers in register-builtins.ts
|
||||
* have corresponding capability entries.
|
||||
*/
|
||||
export function getRegisteredApis(): string[] {
|
||||
return Object.keys(PROVIDER_CAPABILITIES);
|
||||
}
|
||||
|
|
@ -1,226 +0,0 @@
|
|||
import { clearApiProviders, registerApiProvider } from "../api-registry.js";
|
||||
import type {
|
||||
AssistantMessage,
|
||||
AssistantMessageEvent,
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import type { BedrockOptions } from "./amazon-bedrock.js";
|
||||
import { streamAnthropic, streamSimpleAnthropic } from "./anthropic.js";
|
||||
import {
|
||||
streamAnthropicVertex,
|
||||
streamSimpleAnthropicVertex,
|
||||
} from "./anthropic-vertex.js";
|
||||
import {
|
||||
streamAzureOpenAIResponses,
|
||||
streamSimpleAzureOpenAIResponses,
|
||||
} from "./azure-openai-responses.js";
|
||||
import { streamGoogle, streamSimpleGoogle } from "./google.js";
|
||||
import {
|
||||
streamGoogleGeminiCli,
|
||||
streamSimpleGoogleGeminiCli,
|
||||
} from "./google-gemini-cli.js";
|
||||
import {
|
||||
streamGoogleVertex,
|
||||
streamSimpleGoogleVertex,
|
||||
} from "./google-vertex.js";
|
||||
import { streamMistral, streamSimpleMistral } from "./mistral.js";
|
||||
import {
|
||||
streamOpenAICodexResponses,
|
||||
streamSimpleOpenAICodexResponses,
|
||||
} from "./openai-codex-responses.js";
|
||||
import {
|
||||
streamOpenAICompletions,
|
||||
streamSimpleOpenAICompletions,
|
||||
} from "./openai-completions.js";
|
||||
import {
|
||||
streamOpenAIResponses,
|
||||
streamSimpleOpenAIResponses,
|
||||
} from "./openai-responses.js";
|
||||
|
||||
/**
 * Shape of the lazily-imported Bedrock provider module (./amazon-bedrock.js).
 * Only the two streaming entry points are consumed by the lazy wrappers below.
 */
interface BedrockProviderModule {
  // Full streaming entry point (tool calls, thinking, etc.).
  streamBedrock: (
    model: Model<"bedrock-converse-stream">,
    context: Context,
    options?: BedrockOptions,
  ) => AsyncIterable<AssistantMessageEvent>;
  // Simplified streaming entry point driven by SimpleStreamOptions.
  streamSimpleBedrock: (
    model: Model<"bedrock-converse-stream">,
    context: Context,
    options?: SimpleStreamOptions,
  ) => AsyncIterable<AssistantMessageEvent>;
}
|
||||
|
||||
type DynamicImport = (specifier: string) => Promise<unknown>;

// Indirection over native import() for the lazy Bedrock load.
const dynamicImport: DynamicImport = (specifier) => import(specifier);
// Split concatenation — presumably to keep bundlers from statically resolving
// (and eagerly bundling) the optional Bedrock provider; TODO confirm.
const BEDROCK_PROVIDER_SPECIFIER = "./amazon-" + "bedrock.js";
|
||||
|
||||
// Test seam: when set, this module is used instead of the dynamic import.
let bedrockProviderModuleOverride: BedrockProviderModule | undefined;

/** Inject a Bedrock provider module, bypassing the lazy dynamic import. */
export function setBedrockProviderModule(module: BedrockProviderModule): void {
  bedrockProviderModuleOverride = module;
}
|
||||
|
||||
async function loadBedrockProviderModule(): Promise<BedrockProviderModule> {
|
||||
if (bedrockProviderModuleOverride) {
|
||||
return bedrockProviderModuleOverride;
|
||||
}
|
||||
const module = await dynamicImport(BEDROCK_PROVIDER_SPECIFIER);
|
||||
return module as BedrockProviderModule;
|
||||
}
|
||||
|
||||
function forwardStream(
|
||||
target: AssistantMessageEventStream,
|
||||
source: AsyncIterable<AssistantMessageEvent>,
|
||||
): void {
|
||||
(async () => {
|
||||
for await (const event of source) {
|
||||
target.push(event);
|
||||
}
|
||||
target.end();
|
||||
})();
|
||||
}
|
||||
|
||||
function createLazyLoadErrorMessage(
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
error: unknown,
|
||||
): AssistantMessage {
|
||||
return {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: "bedrock-converse-stream",
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "error",
|
||||
errorMessage: error instanceof Error ? error.message : String(error),
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
}
|
||||
|
||||
function streamBedrockLazy(
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
context: Context,
|
||||
options?: BedrockOptions,
|
||||
): AssistantMessageEventStream {
|
||||
const outer = new AssistantMessageEventStream();
|
||||
|
||||
loadBedrockProviderModule()
|
||||
.then((module) => {
|
||||
const inner = module.streamBedrock(model, context, options);
|
||||
forwardStream(outer, inner);
|
||||
})
|
||||
.catch((error) => {
|
||||
const message = createLazyLoadErrorMessage(model, error);
|
||||
outer.push({ type: "error", reason: "error", error: message });
|
||||
outer.end(message);
|
||||
});
|
||||
|
||||
return outer;
|
||||
}
|
||||
|
||||
function streamSimpleBedrockLazy(
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream {
|
||||
const outer = new AssistantMessageEventStream();
|
||||
|
||||
loadBedrockProviderModule()
|
||||
.then((module) => {
|
||||
const inner = module.streamSimpleBedrock(model, context, options);
|
||||
forwardStream(outer, inner);
|
||||
})
|
||||
.catch((error) => {
|
||||
const message = createLazyLoadErrorMessage(model, error);
|
||||
outer.push({ type: "error", reason: "error", error: message });
|
||||
outer.end(message);
|
||||
});
|
||||
|
||||
return outer;
|
||||
}
|
||||
|
||||
/**
 * Register every built-in API provider with the global registry.
 * Bedrock is registered through lazy wrappers so its (heavy) AWS SDK module
 * is only imported on first actual use.
 */
function registerBuiltInApiProviders(): void {
  registerApiProvider({
    api: "anthropic-messages",
    stream: streamAnthropic,
    streamSimple: streamSimpleAnthropic,
  });

  registerApiProvider({
    api: "openai-completions",
    stream: streamOpenAICompletions,
    streamSimple: streamSimpleOpenAICompletions,
  });

  registerApiProvider({
    api: "mistral-conversations",
    stream: streamMistral,
    streamSimple: streamSimpleMistral,
  });

  registerApiProvider({
    api: "openai-responses",
    stream: streamOpenAIResponses,
    streamSimple: streamSimpleOpenAIResponses,
  });

  registerApiProvider({
    api: "azure-openai-responses",
    stream: streamAzureOpenAIResponses,
    streamSimple: streamSimpleAzureOpenAIResponses,
  });

  registerApiProvider({
    api: "openai-codex-responses",
    stream: streamOpenAICodexResponses,
    streamSimple: streamSimpleOpenAICodexResponses,
  });

  registerApiProvider({
    api: "google-generative-ai",
    stream: streamGoogle,
    streamSimple: streamSimpleGoogle,
  });

  registerApiProvider({
    api: "google-gemini-cli",
    stream: streamGoogleGeminiCli,
    streamSimple: streamSimpleGoogleGeminiCli,
  });

  registerApiProvider({
    api: "google-vertex",
    stream: streamGoogleVertex,
    streamSimple: streamSimpleGoogleVertex,
  });

  registerApiProvider({
    api: "anthropic-vertex",
    stream: streamAnthropicVertex,
    streamSimple: streamSimpleAnthropicVertex,
  });

  // Lazy wrappers: defer the dynamic import of ./amazon-bedrock.js.
  registerApiProvider({
    api: "bedrock-converse-stream",
    stream: streamBedrockLazy,
    streamSimple: streamSimpleBedrockLazy,
  });
}
|
||||
|
||||
/** Clear the provider registry and re-register all built-ins (test helper). */
export function resetApiProviders(): void {
  clearApiProviders();
  registerBuiltInApiProviders();
}

// Module side effect: built-ins are registered when this module is imported.
registerBuiltInApiProviders();
|
||||
|
|
@ -1,87 +0,0 @@
|
|||
import { shortHash } from "../utils/hash.js";
|
||||
|
||||
// Longest tool-argument key serialized verbatim; longer keys are hashed.
const MAX_TOOL_ARGUMENT_KEY_LENGTH = 256;
// Prefix for replacement keys derived from a hash of the original key.
const LONG_KEY_PREFIX = "tool_arg_";
|
||||
|
||||
function isObject(value: unknown): value is Record<string, unknown> {
|
||||
return typeof value === "object" && value !== null && !Array.isArray(value);
|
||||
}
|
||||
|
||||
function clampKey(base: string, maxLength: number): string {
|
||||
return base.length <= maxLength ? base : base.slice(0, maxLength);
|
||||
}
|
||||
|
||||
function makeSafeKey(
|
||||
key: string,
|
||||
maxLength: number,
|
||||
usedKeys: Set<string>,
|
||||
seen: Map<string, string>,
|
||||
): string {
|
||||
if (key.length <= maxLength && !usedKeys.has(key)) {
|
||||
return key;
|
||||
}
|
||||
|
||||
if (usedKeys.has(key)) {
|
||||
const base = `${LONG_KEY_PREFIX}${shortHash(key)}`;
|
||||
const safeBase = clampKey(base, maxLength);
|
||||
let next = 0;
|
||||
let candidate = safeBase;
|
||||
while (usedKeys.has(candidate)) {
|
||||
candidate = clampKey(`${safeBase}_${next}`, maxLength);
|
||||
next += 1;
|
||||
}
|
||||
seen.set(key, candidate);
|
||||
return candidate;
|
||||
}
|
||||
|
||||
const existing = seen.get(key);
|
||||
if (existing) {
|
||||
let next = 0;
|
||||
let candidate = existing;
|
||||
while (usedKeys.has(candidate)) {
|
||||
candidate = clampKey(`${existing}_${next}`, maxLength);
|
||||
next += 1;
|
||||
}
|
||||
return candidate;
|
||||
}
|
||||
|
||||
const base = `${LONG_KEY_PREFIX}${shortHash(key)}`;
|
||||
const safeBase = clampKey(base, maxLength);
|
||||
let next = 0;
|
||||
let candidate = safeBase;
|
||||
while (usedKeys.has(candidate)) {
|
||||
candidate = clampKey(`${safeBase}_${next}`, maxLength);
|
||||
next += 1;
|
||||
}
|
||||
seen.set(key, candidate);
|
||||
return candidate;
|
||||
}
|
||||
|
||||
export function sanitizeToolCallArgumentsForSerialization(
|
||||
args: unknown,
|
||||
maxKeyLength = MAX_TOOL_ARGUMENT_KEY_LENGTH,
|
||||
): unknown {
|
||||
if (isObject(args)) {
|
||||
const output: Record<string, unknown> = {};
|
||||
const usedKeys = new Set<string>();
|
||||
const replacements = new Map<string, string>();
|
||||
|
||||
for (const [key, value] of Object.entries(args)) {
|
||||
const safeKey = makeSafeKey(key, maxKeyLength, usedKeys, replacements);
|
||||
output[safeKey] = sanitizeToolCallArgumentsForSerialization(
|
||||
value,
|
||||
maxKeyLength,
|
||||
);
|
||||
usedKeys.add(safeKey);
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
if (Array.isArray(args)) {
|
||||
return args.map((entry) =>
|
||||
sanitizeToolCallArgumentsForSerialization(entry, maxKeyLength),
|
||||
);
|
||||
}
|
||||
|
||||
return args;
|
||||
}
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
import assert from "node:assert/strict";
|
||||
import { describe, it } from "vitest";
|
||||
import type { Model } from "../types.js";
|
||||
import { isAutoReasoning, resolveReasoningLevel } from "./simple-options.js";
|
||||
|
||||
function createModel(overrides: Partial<Model<any>> = {}): Model<any> {
|
||||
return {
|
||||
id: "test-model",
|
||||
name: "Test Model",
|
||||
provider: "openai",
|
||||
api: "openai-responses",
|
||||
baseUrl: "https://api.openai.com/v1",
|
||||
contextWindow: 128_000,
|
||||
maxTokens: 16_384,
|
||||
input: ["text"],
|
||||
reasoning: true,
|
||||
cost: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
},
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
// Unit tests for the reasoning-level helpers in simple-options.ts:
// "auto" maps to "medium" only when the model supports reasoning.
describe("simple-options reasoning helpers", () => {
  it("recognizes auto reasoning requests", () => {
    assert.equal(isAutoReasoning("auto"), true);
    assert.equal(isAutoReasoning("medium"), false);
    assert.equal(isAutoReasoning(undefined), false);
  });

  it("maps auto to medium for reasoning-capable models", () => {
    assert.equal(resolveReasoningLevel(createModel(), "auto"), "medium");
  });

  it("maps auto to undefined for models without reasoning support", () => {
    assert.equal(
      resolveReasoningLevel(createModel({ reasoning: false }), "auto"),
      undefined,
    );
  });

  it("passes through explicit reasoning levels unchanged", () => {
    // "xhigh" is clamped elsewhere, not here.
    assert.equal(resolveReasoningLevel(createModel(), "xhigh"), "xhigh");
  });
});
|
||||
|
|
@ -1,77 +0,0 @@
|
|||
import type {
|
||||
Api,
|
||||
Model,
|
||||
RequestedThinkingLevel,
|
||||
SimpleStreamOptions,
|
||||
StreamOptions,
|
||||
ThinkingBudgets,
|
||||
ThinkingLevel,
|
||||
} from "../types.js";
|
||||
|
||||
export function buildBaseOptions(
|
||||
model: Model<Api>,
|
||||
options?: SimpleStreamOptions,
|
||||
apiKey?: string,
|
||||
): StreamOptions {
|
||||
return {
|
||||
temperature: options?.temperature,
|
||||
maxTokens: options?.maxTokens || Math.min(model.maxTokens, 32000),
|
||||
signal: options?.signal,
|
||||
apiKey: apiKey || options?.apiKey,
|
||||
cacheRetention: options?.cacheRetention,
|
||||
sessionId: options?.sessionId,
|
||||
headers: options?.headers,
|
||||
onPayload: options?.onPayload,
|
||||
maxRetryDelayMs: options?.maxRetryDelayMs,
|
||||
metadata: options?.metadata,
|
||||
};
|
||||
}
|
||||
|
||||
export function clampReasoning(
|
||||
effort: ThinkingLevel | undefined,
|
||||
): Exclude<ThinkingLevel, "xhigh"> | undefined {
|
||||
return effort === "xhigh" ? "high" : effort;
|
||||
}
|
||||
|
||||
export function isAutoReasoning(
|
||||
effort: RequestedThinkingLevel | undefined,
|
||||
): effort is Extract<RequestedThinkingLevel, "auto"> {
|
||||
return effort === "auto";
|
||||
}
|
||||
|
||||
export function resolveReasoningLevel(
|
||||
model: Model<Api>,
|
||||
effort: RequestedThinkingLevel | undefined,
|
||||
): ThinkingLevel | undefined {
|
||||
if (!effort || effort === "auto") {
|
||||
if (!model.reasoning) return undefined;
|
||||
return "medium";
|
||||
}
|
||||
return effort;
|
||||
}
|
||||
|
||||
export function adjustMaxTokensForThinking(
|
||||
baseMaxTokens: number,
|
||||
modelMaxTokens: number,
|
||||
reasoningLevel: ThinkingLevel,
|
||||
customBudgets?: ThinkingBudgets,
|
||||
): { maxTokens: number; thinkingBudget: number } {
|
||||
const defaultBudgets: ThinkingBudgets = {
|
||||
minimal: 1024,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: 16384,
|
||||
};
|
||||
const budgets = { ...defaultBudgets, ...customBudgets };
|
||||
|
||||
const minOutputTokens = 1024;
|
||||
const level = clampReasoning(reasoningLevel)!;
|
||||
let thinkingBudget = budgets[level]!;
|
||||
const maxTokens = Math.min(baseMaxTokens + thinkingBudget, modelMaxTokens);
|
||||
|
||||
if (maxTokens <= thinkingBudget) {
|
||||
thinkingBudget = Math.max(0, maxTokens - minOutputTokens);
|
||||
}
|
||||
|
||||
return { maxTokens, thinkingBudget };
|
||||
}
|
||||
|
|
@ -1,229 +0,0 @@
|
|||
// SF — ProviderSwitchReport Tests (ADR-005 Phase 3)
|
||||
|
||||
import assert from "node:assert/strict";
|
||||
import { describe, test } from "vitest";
|
||||
import type { AssistantMessage, Message, Model, ToolCall } from "../types.js";
|
||||
import {
|
||||
createEmptyReport,
|
||||
hasTransformations,
|
||||
transformMessages,
|
||||
} from "./transform-messages.js";
|
||||
|
||||
// ─── Helpers ────────────────────────────────────────────────────────────────
|
||||
|
||||
function makeModel(overrides: Partial<Model<any>> = {}): Model<any> {
|
||||
return {
|
||||
id: "claude-sonnet-4-6",
|
||||
name: "Claude Sonnet 4.6",
|
||||
api: "anthropic-messages",
|
||||
provider: "anthropic",
|
||||
baseUrl: "",
|
||||
reasoning: false,
|
||||
input: ["text", "image"],
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: 200000,
|
||||
maxTokens: 8192,
|
||||
...overrides,
|
||||
} as Model<any>;
|
||||
}
|
||||
|
||||
function makeAssistantMsg(
|
||||
overrides: Partial<AssistantMessage> = {},
|
||||
): AssistantMessage {
|
||||
return {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: "anthropic-messages",
|
||||
provider: "anthropic",
|
||||
model: "claude-sonnet-4-6",
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
// ─── createEmptyReport / hasTransformations ─────────────────────────────────
|
||||
|
||||
// createEmptyReport: all counters start at zero, APIs are recorded verbatim.
describe("createEmptyReport", () => {
  test("creates report with zero counters", () => {
    const report = createEmptyReport("anthropic-messages", "openai-responses");
    assert.equal(report.fromApi, "anthropic-messages");
    assert.equal(report.toApi, "openai-responses");
    assert.equal(report.thinkingBlocksDropped, 0);
    assert.equal(report.thinkingBlocksDowngraded, 0);
    assert.equal(report.toolCallIdsRemapped, 0);
    assert.equal(report.syntheticToolResultsInserted, 0);
    assert.equal(report.thoughtSignaturesDropped, 0);
  });
});
|
||||
|
||||
// hasTransformations: true iff any counter is non-zero.
describe("hasTransformations", () => {
  test("returns false for empty report", () => {
    const report = createEmptyReport("a", "b");
    assert.equal(hasTransformations(report), false);
  });

  test("returns true when any counter is non-zero", () => {
    const report = createEmptyReport("a", "b");
    report.thinkingBlocksDropped = 1;
    assert.equal(hasTransformations(report), true);
  });
});
|
||||
|
||||
// ─── Report Tracking in transformMessages ───────────────────────────────────
|
||||
|
||||
// Verifies that transformMessages increments the right ProviderSwitchReport
// counter for each kind of cross-provider transformation, and stays silent
// for same-model replays.
describe("transformMessages with report tracking", () => {
  test("tracks thinking blocks dropped for redacted cross-model", () => {
    const model = makeModel({
      id: "gpt-5",
      api: "openai-responses",
      provider: "openai",
    });
    const messages: Message[] = [
      makeAssistantMsg({
        content: [
          // Redacted thinking is opaque to other models, so it must be dropped.
          { type: "thinking", thinking: "", redacted: true },
          { type: "text", text: "Hello" },
        ],
      }),
    ];
    const report = createEmptyReport("anthropic-messages", "openai-responses");
    transformMessages(messages, model, undefined, report);
    assert.equal(report.thinkingBlocksDropped, 1);
  });

  test("tracks thinking blocks downgraded to plain text", () => {
    const model = makeModel({
      id: "gpt-5",
      api: "openai-responses",
      provider: "openai",
    });
    const messages: Message[] = [
      makeAssistantMsg({
        content: [
          { type: "thinking", thinking: "Let me think about this..." },
          { type: "text", text: "Here is my answer" },
        ],
      }),
    ];
    const report = createEmptyReport("anthropic-messages", "openai-responses");
    transformMessages(messages, model, undefined, report);
    assert.equal(report.thinkingBlocksDowngraded, 1);
  });

  test("tracks tool call IDs remapped", () => {
    const model = makeModel({
      id: "claude-sonnet-4-6",
      api: "anthropic-messages",
      provider: "anthropic",
    });
    const toolCall: ToolCall = {
      type: "toolCall",
      id: "original-long-id-that-needs-normalization|with-special-chars",
      name: "bash",
      arguments: { command: "ls" },
    };
    const messages: Message[] = [
      makeAssistantMsg({
        provider: "openai",
        api: "openai-responses",
        model: "gpt-5",
        content: [toolCall],
      }),
    ];
    // Anthropic-style normalizer: strip disallowed chars, clamp to 64.
    const normalizer = (id: string) =>
      id.replace(/[^a-zA-Z0-9_-]/g, "_").slice(0, 64);
    const report = createEmptyReport("openai-responses", "anthropic-messages");
    transformMessages(messages, model, normalizer, report);
    assert.equal(report.toolCallIdsRemapped, 1);
  });

  test("tracks thought signatures dropped", () => {
    const model = makeModel({
      id: "claude-sonnet-4-6",
      api: "anthropic-messages",
      provider: "anthropic",
    });
    const toolCall: ToolCall = {
      type: "toolCall",
      id: "tc_001",
      name: "bash",
      arguments: { command: "ls" },
      // Google-specific opaque context, meaningless to other providers.
      thoughtSignature: "some-opaque-signature",
    };
    const messages: Message[] = [
      makeAssistantMsg({
        provider: "google",
        api: "google-generative-ai",
        model: "gemini-2.5-pro",
        content: [toolCall],
      }),
    ];
    const report = createEmptyReport(
      "google-generative-ai",
      "anthropic-messages",
    );
    transformMessages(messages, model, undefined, report);
    assert.equal(report.thoughtSignaturesDropped, 1);
  });

  test("tracks synthetic tool results inserted", () => {
    const model = makeModel();
    const toolCall: ToolCall = {
      type: "toolCall",
      id: "tc_orphan",
      name: "bash",
      arguments: { command: "ls" },
    };
    // Assistant message with tool call followed by another assistant (no tool result)
    const messages: Message[] = [
      makeAssistantMsg({
        content: [toolCall, { type: "text", text: "Using bash" }],
      }),
      makeAssistantMsg({ content: [{ type: "text", text: "Next message" }] }),
    ];
    const report = createEmptyReport(
      "anthropic-messages",
      "anthropic-messages",
    );
    transformMessages(messages, model, undefined, report);
    assert.equal(report.syntheticToolResultsInserted, 1);
  });

  test("does not count transformations for same-model messages", () => {
    const model = makeModel();
    const messages: Message[] = [
      makeAssistantMsg({
        content: [
          { type: "thinking", thinking: "Let me think..." },
          { type: "text", text: "Answer" },
        ],
      }),
    ];
    const report = createEmptyReport(
      "anthropic-messages",
      "anthropic-messages",
    );
    transformMessages(messages, model, undefined, report);
    assert.equal(report.thinkingBlocksDowngraded, 0);
    assert.equal(report.thinkingBlocksDropped, 0);
  });

  test("works without report parameter (backward compatible)", () => {
    const model = makeModel();
    const messages: Message[] = [
      makeAssistantMsg({ content: [{ type: "text", text: "Hello" }] }),
    ];
    // Should not throw
    const result = transformMessages(messages, model);
    assert.ok(Array.isArray(result));
  });
});
|
||||
|
|
@ -1,307 +0,0 @@
|
|||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Message,
|
||||
Model,
|
||||
ToolCall,
|
||||
ToolResultMessage,
|
||||
} from "../types.js";
|
||||
|
||||
/**
 * Report of context transformations during a cross-provider switch (ADR-005 Phase 3).
 * Tracks what was lost or downgraded when replaying conversation history to a different provider.
 */
export interface ProviderSwitchReport {
  /** API of the messages being transformed from */
  fromApi: string;
  /** API of the target model */
  toApi: string;
  /** Number of thinking blocks completely dropped (redacted/encrypted, cross-model) */
  thinkingBlocksDropped: number;
  /** Number of thinking blocks downgraded from structured to plain text */
  thinkingBlocksDowngraded: number;
  /** Number of tool call IDs that were remapped/normalized */
  toolCallIdsRemapped: number;
  /** Number of synthetic tool results inserted for orphaned tool calls */
  syntheticToolResultsInserted: number;
  /** Number of thought signatures dropped (Google-specific opaque context) */
  thoughtSignaturesDropped: number;
}
|
||||
|
||||
/**
|
||||
* Create an empty provider switch report.
|
||||
*/
|
||||
export function createEmptyReport(
|
||||
fromApi: string,
|
||||
toApi: string,
|
||||
): ProviderSwitchReport {
|
||||
return {
|
||||
fromApi,
|
||||
toApi,
|
||||
thinkingBlocksDropped: 0,
|
||||
thinkingBlocksDowngraded: 0,
|
||||
toolCallIdsRemapped: 0,
|
||||
syntheticToolResultsInserted: 0,
|
||||
thoughtSignaturesDropped: 0,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a provider switch report has any non-zero transformations.
|
||||
*/
|
||||
export function hasTransformations(report: ProviderSwitchReport): boolean {
|
||||
return (
|
||||
report.thinkingBlocksDropped > 0 ||
|
||||
report.thinkingBlocksDowngraded > 0 ||
|
||||
report.toolCallIdsRemapped > 0 ||
|
||||
report.syntheticToolResultsInserted > 0 ||
|
||||
report.thoughtSignaturesDropped > 0
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a report, run transformMessages, and log if non-empty.
|
||||
* Convenience wrapper for provider adapters (ADR-005).
|
||||
*/
|
||||
export function transformMessagesWithReport<TApi extends Api>(
|
||||
messages: Message[],
|
||||
model: Model<TApi>,
|
||||
normalizeToolCallId?: (
|
||||
id: string,
|
||||
model: Model<TApi>,
|
||||
source: AssistantMessage,
|
||||
) => string,
|
||||
sourceApi?: string,
|
||||
): Message[] {
|
||||
const report = createEmptyReport(sourceApi ?? "unknown", model.api);
|
||||
const result = transformMessages(
|
||||
messages,
|
||||
model,
|
||||
normalizeToolCallId,
|
||||
report,
|
||||
);
|
||||
if (hasTransformations(report)) {
|
||||
logProviderSwitchReport(report);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/** Log a non-empty ProviderSwitchReport as a debug-level warning. */
|
||||
function logProviderSwitchReport(report: ProviderSwitchReport): void {
|
||||
const parts: string[] = [
|
||||
`Provider switch ${report.fromApi} → ${report.toApi}:`,
|
||||
];
|
||||
if (report.thinkingBlocksDropped > 0)
|
||||
parts.push(`${report.thinkingBlocksDropped} thinking blocks dropped`);
|
||||
if (report.thinkingBlocksDowngraded > 0)
|
||||
parts.push(`${report.thinkingBlocksDowngraded} thinking blocks downgraded`);
|
||||
if (report.toolCallIdsRemapped > 0)
|
||||
parts.push(`${report.toolCallIdsRemapped} tool call IDs remapped`);
|
||||
if (report.syntheticToolResultsInserted > 0)
|
||||
parts.push(
|
||||
`${report.syntheticToolResultsInserted} synthetic tool results inserted`,
|
||||
);
|
||||
if (report.thoughtSignaturesDropped > 0)
|
||||
parts.push(`${report.thoughtSignaturesDropped} thought signatures dropped`);
|
||||
// Use process.stderr for debug output — this is observable in verbose/debug modes
|
||||
// without polluting stdout which may be used for structured output (RPC/MCP).
|
||||
if (process.env.SF_VERBOSE === "1" || process.env.PI_VERBOSE === "1") {
|
||||
process.stderr.write(`[provider-switch] ${parts.join(", ")}\n`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Normalize tool call ID for cross-provider compatibility.
|
||||
* OpenAI Responses API generates IDs that are 450+ chars with special characters like `|`.
|
||||
* Anthropic APIs require IDs matching ^[a-zA-Z0-9_-]+$ (max 64 chars).
|
||||
*/
|
||||
export function transformMessages<TApi extends Api>(
|
||||
messages: Message[],
|
||||
model: Model<TApi>,
|
||||
normalizeToolCallId?: (
|
||||
id: string,
|
||||
model: Model<TApi>,
|
||||
source: AssistantMessage,
|
||||
) => string,
|
||||
report?: ProviderSwitchReport,
|
||||
): Message[] {
|
||||
// Build a map of original tool call IDs to normalized IDs
|
||||
const toolCallIdMap = new Map<string, string>();
|
||||
|
||||
// First pass: transform messages (thinking blocks, tool call ID normalization)
|
||||
const transformed = messages.map((msg) => {
|
||||
// User messages pass through unchanged
|
||||
if (msg.role === "user") {
|
||||
return msg;
|
||||
}
|
||||
|
||||
// Handle toolResult messages - normalize toolCallId if we have a mapping
|
||||
if (msg.role === "toolResult") {
|
||||
const normalizedId = toolCallIdMap.get(msg.toolCallId);
|
||||
if (normalizedId && normalizedId !== msg.toolCallId) {
|
||||
return { ...msg, toolCallId: normalizedId };
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
||||
// Assistant messages need transformation check
|
||||
if (msg.role === "assistant") {
|
||||
const assistantMsg = msg as AssistantMessage;
|
||||
const isSameModel =
|
||||
assistantMsg.provider === model.provider &&
|
||||
assistantMsg.api === model.api &&
|
||||
assistantMsg.model === model.id;
|
||||
|
||||
const transformedContent = assistantMsg.content.flatMap((block) => {
|
||||
if (block.type === "thinking") {
|
||||
// Redacted thinking is opaque encrypted content, only valid for the same model.
|
||||
// Drop it for cross-model to avoid API errors.
|
||||
if (block.redacted) {
|
||||
if (!isSameModel && report) report.thinkingBlocksDropped++;
|
||||
return isSameModel ? block : [];
|
||||
}
|
||||
// For same model: keep thinking blocks with signatures (needed for replay)
|
||||
// even if the thinking text is empty (OpenAI encrypted reasoning)
|
||||
if (isSameModel && block.thinkingSignature) return block;
|
||||
// Skip empty thinking blocks, convert others to plain text
|
||||
if (!block.thinking || block.thinking.trim() === "") {
|
||||
if (!isSameModel && report) report.thinkingBlocksDropped++;
|
||||
return [];
|
||||
}
|
||||
if (isSameModel) return block;
|
||||
// Downgrade: structured thinking → plain text
|
||||
if (report) report.thinkingBlocksDowngraded++;
|
||||
return {
|
||||
type: "text" as const,
|
||||
text: block.thinking,
|
||||
};
|
||||
}
|
||||
|
||||
if (block.type === "text") {
|
||||
if (isSameModel) return block;
|
||||
return {
|
||||
type: "text" as const,
|
||||
text: block.text,
|
||||
};
|
||||
}
|
||||
|
||||
if (block.type === "toolCall") {
|
||||
const toolCall = block as ToolCall;
|
||||
let normalizedToolCall: ToolCall = toolCall;
|
||||
|
||||
if (!isSameModel && toolCall.thoughtSignature) {
|
||||
normalizedToolCall = { ...toolCall };
|
||||
delete (normalizedToolCall as { thoughtSignature?: string })
|
||||
.thoughtSignature;
|
||||
if (report) report.thoughtSignaturesDropped++;
|
||||
}
|
||||
|
||||
if (!isSameModel && normalizeToolCallId) {
|
||||
const normalizedId = normalizeToolCallId(
|
||||
toolCall.id,
|
||||
model,
|
||||
assistantMsg,
|
||||
);
|
||||
if (normalizedId !== toolCall.id) {
|
||||
toolCallIdMap.set(toolCall.id, normalizedId);
|
||||
normalizedToolCall = { ...normalizedToolCall, id: normalizedId };
|
||||
if (report) report.toolCallIdsRemapped++;
|
||||
}
|
||||
}
|
||||
|
||||
return normalizedToolCall;
|
||||
}
|
||||
|
||||
return block;
|
||||
});
|
||||
|
||||
return {
|
||||
...assistantMsg,
|
||||
content: transformedContent,
|
||||
};
|
||||
}
|
||||
return msg;
|
||||
});
|
||||
|
||||
// Second pass: insert synthetic empty tool results for orphaned tool calls
|
||||
// This preserves thinking signatures and satisfies API requirements
|
||||
const result: Message[] = [];
|
||||
let pendingToolCalls: ToolCall[] = [];
|
||||
let existingToolResultIds = new Set<string>();
|
||||
|
||||
for (let i = 0; i < transformed.length; i++) {
|
||||
const msg = transformed[i];
|
||||
|
||||
if (msg.role === "assistant") {
|
||||
// If we have pending orphaned tool calls from a previous assistant, insert synthetic results now
|
||||
if (pendingToolCalls.length > 0) {
|
||||
for (const tc of pendingToolCalls) {
|
||||
if (!existingToolResultIds.has(tc.id)) {
|
||||
result.push({
|
||||
role: "toolResult",
|
||||
toolCallId: tc.id,
|
||||
toolName: tc.name,
|
||||
content: [{ type: "text", text: "No result provided" }],
|
||||
isError: true,
|
||||
timestamp: Date.now(),
|
||||
} as ToolResultMessage);
|
||||
if (report) report.syntheticToolResultsInserted++;
|
||||
}
|
||||
}
|
||||
pendingToolCalls = [];
|
||||
existingToolResultIds = new Set();
|
||||
}
|
||||
|
||||
// Skip errored/aborted assistant messages entirely.
|
||||
// These are incomplete turns that shouldn't be replayed:
|
||||
// - May have partial content (reasoning without message, incomplete tool calls)
|
||||
// - Replaying them can cause API errors (e.g., OpenAI "reasoning without following item")
|
||||
// - The model should retry from the last valid state
|
||||
const assistantMsg = msg as AssistantMessage;
|
||||
if (
|
||||
assistantMsg.stopReason === "error" ||
|
||||
assistantMsg.stopReason === "aborted"
|
||||
) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Track tool calls from this assistant message
|
||||
const toolCalls = assistantMsg.content.filter(
|
||||
(b) => b.type === "toolCall",
|
||||
) as ToolCall[];
|
||||
if (toolCalls.length > 0) {
|
||||
pendingToolCalls = toolCalls;
|
||||
existingToolResultIds = new Set();
|
||||
}
|
||||
|
||||
result.push(msg);
|
||||
} else if (msg.role === "toolResult") {
|
||||
existingToolResultIds.add(msg.toolCallId);
|
||||
result.push(msg);
|
||||
} else if (msg.role === "user") {
|
||||
// User message interrupts tool flow - insert synthetic results for orphaned calls
|
||||
if (pendingToolCalls.length > 0) {
|
||||
for (const tc of pendingToolCalls) {
|
||||
if (!existingToolResultIds.has(tc.id)) {
|
||||
result.push({
|
||||
role: "toolResult",
|
||||
toolCallId: tc.id,
|
||||
toolName: tc.name,
|
||||
content: [{ type: "text", text: "No result provided" }],
|
||||
isError: true,
|
||||
timestamp: Date.now(),
|
||||
} as ToolResultMessage);
|
||||
if (report) report.syntheticToolResultsInserted++;
|
||||
}
|
||||
}
|
||||
pendingToolCalls = [];
|
||||
existingToolResultIds = new Set();
|
||||
}
|
||||
result.push(msg);
|
||||
} else {
|
||||
result.push(msg);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
|
@ -1,59 +0,0 @@
|
|||
import "./providers/register-builtins.js";
|
||||
|
||||
import { getApiProvider } from "./api-registry.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
AssistantMessageEventStream,
|
||||
Context,
|
||||
Model,
|
||||
ProviderStreamOptions,
|
||||
SimpleStreamOptions,
|
||||
StreamOptions,
|
||||
} from "./types.js";
|
||||
|
||||
export { getEnvApiKey } from "./env-api-keys.js";
|
||||
|
||||
function resolveApiProvider(api: Api) {
|
||||
const provider = getApiProvider(api);
|
||||
if (!provider) {
|
||||
throw new Error(`No API provider registered for api: ${api}`);
|
||||
}
|
||||
return provider;
|
||||
}
|
||||
|
||||
export function stream<TApi extends Api>(
|
||||
model: Model<TApi>,
|
||||
context: Context,
|
||||
options?: ProviderStreamOptions,
|
||||
): AssistantMessageEventStream {
|
||||
const provider = resolveApiProvider(model.api);
|
||||
return provider.stream(model, context, options as StreamOptions);
|
||||
}
|
||||
|
||||
export async function complete<TApi extends Api>(
|
||||
model: Model<TApi>,
|
||||
context: Context,
|
||||
options?: ProviderStreamOptions,
|
||||
): Promise<AssistantMessage> {
|
||||
const s = stream(model, context, options);
|
||||
return s.result();
|
||||
}
|
||||
|
||||
export function streamSimple<TApi extends Api>(
|
||||
model: Model<TApi>,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream {
|
||||
const provider = resolveApiProvider(model.api);
|
||||
return provider.streamSimple(model, context, options);
|
||||
}
|
||||
|
||||
export async function completeSimple<TApi extends Api>(
|
||||
model: Model<TApi>,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): Promise<AssistantMessage> {
|
||||
const s = streamSimple(model, context, options);
|
||||
return s.result();
|
||||
}
|
||||
|
|
@ -1,465 +0,0 @@
|
|||
import type { AssistantMessageEventStream } from "./utils/event-stream.js";
|
||||
|
||||
export type { AssistantMessageEventStream } from "./utils/event-stream.js";
|
||||
|
||||
export type KnownApi =
|
||||
| "openai-completions"
|
||||
| "mistral-conversations"
|
||||
| "openai-responses"
|
||||
| "azure-openai-responses"
|
||||
| "openai-codex-responses"
|
||||
| "anthropic-messages"
|
||||
| "anthropic-vertex"
|
||||
| "bedrock-converse-stream"
|
||||
| "google-generative-ai"
|
||||
| "google-gemini-cli"
|
||||
| "google-vertex"
|
||||
| "ollama-chat";
|
||||
|
||||
export type Api = KnownApi | (string & {});
|
||||
|
||||
export type KnownProvider =
|
||||
| "amazon-bedrock"
|
||||
| "anthropic"
|
||||
| "anthropic-vertex"
|
||||
| "google"
|
||||
| "google-gemini-cli"
|
||||
| "google-vertex"
|
||||
| "openai"
|
||||
| "azure-openai-responses"
|
||||
| "openai-codex"
|
||||
| "github-copilot"
|
||||
| "xai"
|
||||
| "groq"
|
||||
| "cerebras"
|
||||
| "openrouter"
|
||||
| "vercel-ai-gateway"
|
||||
| "zai"
|
||||
| "mistral"
|
||||
| "minimax"
|
||||
| "minimax-cn"
|
||||
| "huggingface"
|
||||
| "opencode"
|
||||
| "opencode-go"
|
||||
| "kimi-coding"
|
||||
| "xiaomi"
|
||||
| "xiaomi-token-plan-ams"
|
||||
| "xiaomi-token-plan-sgp"
|
||||
| "xiaomi-token-plan-cn"
|
||||
| "alibaba-coding-plan"
|
||||
| "alibaba-dashscope"
|
||||
| "ollama"
|
||||
| "ollama-cloud"
|
||||
| "longcat";
|
||||
export type Provider = KnownProvider | string;
|
||||
|
||||
export type ThinkingLevel = "minimal" | "low" | "medium" | "high" | "xhigh";
|
||||
export type RequestedThinkingLevel = "auto" | ThinkingLevel;
|
||||
|
||||
/** Token budgets for each thinking level (token-based providers only) */
|
||||
export interface ThinkingBudgets {
|
||||
minimal?: number;
|
||||
low?: number;
|
||||
medium?: number;
|
||||
high?: number;
|
||||
}
|
||||
|
||||
// Base options all providers share
|
||||
export type CacheRetention = "none" | "short" | "long";
|
||||
|
||||
export type Transport = "sse" | "websocket" | "auto";
|
||||
|
||||
export interface StreamOptions {
|
||||
temperature?: number;
|
||||
maxTokens?: number;
|
||||
signal?: AbortSignal;
|
||||
apiKey?: string;
|
||||
/**
|
||||
* Preferred transport for providers that support multiple transports.
|
||||
* Providers that do not support this option ignore it.
|
||||
*/
|
||||
transport?: Transport;
|
||||
/**
|
||||
* Prompt cache retention preference. Providers map this to their supported values.
|
||||
* Default: "short".
|
||||
*/
|
||||
cacheRetention?: CacheRetention;
|
||||
/**
|
||||
* Optional session identifier for providers that support session-based caching.
|
||||
* Providers can use this to enable prompt caching, request routing, or other
|
||||
* session-aware features. Ignored by providers that don't support it.
|
||||
*/
|
||||
sessionId?: string;
|
||||
/**
|
||||
* Optional callback for inspecting or replacing provider payloads before sending.
|
||||
* Return undefined to keep the payload unchanged.
|
||||
*/
|
||||
onPayload?: (
|
||||
payload: unknown,
|
||||
model: Model<Api>,
|
||||
) => unknown | undefined | Promise<unknown | undefined>;
|
||||
/**
|
||||
* Optional custom HTTP headers to include in API requests.
|
||||
* Merged with provider defaults; can override default headers.
|
||||
* Not supported by all providers (e.g., AWS Bedrock uses SDK auth).
|
||||
*/
|
||||
headers?: Record<string, string>;
|
||||
/**
|
||||
* Maximum delay in milliseconds to wait for a retry when the server requests a long wait.
|
||||
* If the server's requested delay exceeds this value, the request fails immediately
|
||||
* with an error containing the requested delay, allowing higher-level retry logic
|
||||
* to handle it with user visibility.
|
||||
* Default: 60000 (60 seconds). Set to 0 to disable the cap.
|
||||
*/
|
||||
maxRetryDelayMs?: number;
|
||||
/**
|
||||
* Optional metadata to include in API requests.
|
||||
* Providers extract the fields they understand and ignore the rest.
|
||||
* For example, Anthropic uses `user_id` for abuse tracking and rate limiting.
|
||||
*/
|
||||
metadata?: Record<string, unknown>;
|
||||
}
|
||||
|
||||
export type ProviderStreamOptions = StreamOptions & Record<string, unknown>;
|
||||
|
||||
// Unified options with reasoning passed to streamSimple() and completeSimple()
|
||||
export interface SimpleStreamOptions extends StreamOptions {
|
||||
reasoning?: RequestedThinkingLevel;
|
||||
/** Custom token budgets for thinking levels (token-based providers only) */
|
||||
thinkingBudgets?: ThinkingBudgets;
|
||||
}
|
||||
|
||||
// Generic StreamFunction with typed options
|
||||
export type StreamFunction<
|
||||
TApi extends Api = Api,
|
||||
TOptions extends StreamOptions = StreamOptions,
|
||||
> = (
|
||||
model: Model<TApi>,
|
||||
context: Context,
|
||||
options?: TOptions,
|
||||
) => AssistantMessageEventStream;
|
||||
|
||||
export interface TextSignatureV1 {
|
||||
v: 1;
|
||||
id: string;
|
||||
phase?: "commentary" | "final_answer";
|
||||
}
|
||||
|
||||
export interface TextContent {
|
||||
type: "text";
|
||||
text: string;
|
||||
textSignature?: string; // e.g., for OpenAI responses, message metadata (legacy id string or TextSignatureV1 JSON)
|
||||
}
|
||||
|
||||
export interface ThinkingContent {
|
||||
type: "thinking";
|
||||
thinking: string;
|
||||
thinkingSignature?: string; // e.g., for OpenAI responses, the reasoning item ID
|
||||
/** When true, the thinking content was redacted by safety filters. The opaque
|
||||
* encrypted payload is stored in `thinkingSignature` so it can be passed back
|
||||
* to the API for multi-turn continuity. */
|
||||
redacted?: boolean;
|
||||
}
|
||||
|
||||
export interface ImageContent {
|
||||
type: "image";
|
||||
data: string; // base64 encoded image data
|
||||
mimeType: string; // e.g., "image/jpeg", "image/png"
|
||||
}
|
||||
|
||||
export interface ToolCall {
|
||||
type: "toolCall";
|
||||
id: string;
|
||||
name: string;
|
||||
arguments: Record<string, any>;
|
||||
thoughtSignature?: string; // Google-specific: opaque signature for reusing thought context
|
||||
}
|
||||
|
||||
/** Server-side tool use (e.g., Anthropic native web search). Executed by the API, not the client. */
|
||||
export interface ServerToolUseContent {
|
||||
type: "serverToolUse";
|
||||
id: string;
|
||||
name: string; // e.g., "web_search"
|
||||
input: unknown;
|
||||
}
|
||||
|
||||
/** Result of a server-side tool execution, paired with a ServerToolUseContent by toolUseId. */
|
||||
export interface WebSearchResultContent {
|
||||
type: "webSearchResult";
|
||||
toolUseId: string;
|
||||
/** Search results or error from the server. Opaque — stored for API replay. */
|
||||
content: unknown;
|
||||
}
|
||||
|
||||
export interface Usage {
|
||||
input: number;
|
||||
output: number;
|
||||
cacheRead: number;
|
||||
cacheWrite: number;
|
||||
totalTokens: number;
|
||||
cost: {
|
||||
input: number;
|
||||
output: number;
|
||||
cacheRead: number;
|
||||
cacheWrite: number;
|
||||
total: number;
|
||||
};
|
||||
}
|
||||
|
||||
export type StopReason =
|
||||
| "stop"
|
||||
| "length"
|
||||
| "toolUse"
|
||||
| "pauseTurn"
|
||||
| "error"
|
||||
| "aborted";
|
||||
|
||||
export interface UserMessage {
|
||||
role: "user";
|
||||
content: string | (TextContent | ImageContent)[];
|
||||
timestamp: number; // Unix timestamp in milliseconds
|
||||
}
|
||||
|
||||
export interface AssistantMessage {
|
||||
role: "assistant";
|
||||
content: (
|
||||
| TextContent
|
||||
| ThinkingContent
|
||||
| ToolCall
|
||||
| ServerToolUseContent
|
||||
| WebSearchResultContent
|
||||
)[];
|
||||
api: Api;
|
||||
provider: Provider;
|
||||
model: string;
|
||||
usage: Usage;
|
||||
stopReason: StopReason;
|
||||
errorMessage?: string;
|
||||
/** Server-requested retry delay in milliseconds (from Retry-After or rate limit headers). */
|
||||
retryAfterMs?: number;
|
||||
/** Provider inference performance metrics (e.g. tokens/sec from local models). */
|
||||
inferenceMetrics?: InferenceMetrics;
|
||||
timestamp: number; // Unix timestamp in milliseconds
|
||||
}
|
||||
|
||||
/** Inference performance metrics reported by providers that support it (e.g. Ollama). */
|
||||
export interface InferenceMetrics {
|
||||
/** Tokens generated per second during eval phase. */
|
||||
tokensPerSecond: number;
|
||||
/** Wall-clock duration of the full request in milliseconds. */
|
||||
totalDurationMs: number;
|
||||
/** Duration of the eval (generation) phase in milliseconds. */
|
||||
evalDurationMs: number;
|
||||
/** Duration of the prompt eval phase in milliseconds. */
|
||||
promptEvalDurationMs: number;
|
||||
}
|
||||
|
||||
export interface ToolResultMessage<TDetails = any> {
|
||||
role: "toolResult";
|
||||
toolCallId: string;
|
||||
toolName: string;
|
||||
content: (TextContent | ImageContent)[]; // Supports text and images
|
||||
details?: TDetails;
|
||||
isError: boolean;
|
||||
timestamp: number; // Unix timestamp in milliseconds
|
||||
}
|
||||
|
||||
export type Message = UserMessage | AssistantMessage | ToolResultMessage;
|
||||
|
||||
import type { TSchema } from "@sinclair/typebox";
|
||||
|
||||
export interface Tool<TParameters extends TSchema = TSchema> {
|
||||
name: string;
|
||||
description: string;
|
||||
parameters: TParameters;
|
||||
}
|
||||
|
||||
export interface Context {
|
||||
systemPrompt?: string;
|
||||
messages: Message[];
|
||||
tools?: Tool[];
|
||||
}
|
||||
|
||||
export type AssistantMessageEvent =
|
||||
| { type: "start"; partial: AssistantMessage }
|
||||
| { type: "text_start"; contentIndex: number; partial: AssistantMessage }
|
||||
| {
|
||||
type: "text_delta";
|
||||
contentIndex: number;
|
||||
delta: string;
|
||||
partial: AssistantMessage;
|
||||
}
|
||||
| {
|
||||
type: "text_end";
|
||||
contentIndex: number;
|
||||
content: string;
|
||||
partial: AssistantMessage;
|
||||
}
|
||||
| { type: "thinking_start"; contentIndex: number; partial: AssistantMessage }
|
||||
| {
|
||||
type: "thinking_delta";
|
||||
contentIndex: number;
|
||||
delta: string;
|
||||
partial: AssistantMessage;
|
||||
}
|
||||
| {
|
||||
type: "thinking_end";
|
||||
contentIndex: number;
|
||||
content: string;
|
||||
partial: AssistantMessage;
|
||||
}
|
||||
| { type: "toolcall_start"; contentIndex: number; partial: AssistantMessage }
|
||||
| {
|
||||
type: "toolcall_delta";
|
||||
contentIndex: number;
|
||||
delta: string;
|
||||
partial: AssistantMessage;
|
||||
}
|
||||
| {
|
||||
type: "toolcall_end";
|
||||
contentIndex: number;
|
||||
toolCall: ToolCall;
|
||||
partial: AssistantMessage;
|
||||
malformedArguments?: boolean;
|
||||
}
|
||||
| { type: "server_tool_use"; contentIndex: number; partial: AssistantMessage }
|
||||
| {
|
||||
type: "web_search_result";
|
||||
contentIndex: number;
|
||||
partial: AssistantMessage;
|
||||
}
|
||||
| {
|
||||
type: "done";
|
||||
reason: Extract<StopReason, "stop" | "length" | "toolUse" | "pauseTurn">;
|
||||
message: AssistantMessage;
|
||||
}
|
||||
| {
|
||||
type: "error";
|
||||
reason: Extract<StopReason, "aborted" | "error">;
|
||||
error: AssistantMessage;
|
||||
};
|
||||
|
||||
/**
|
||||
* Compatibility settings for OpenAI-compatible completions APIs.
|
||||
* Use this to override URL-based auto-detection for custom providers.
|
||||
*/
|
||||
export interface OpenAICompletionsCompat {
|
||||
/** Whether the provider supports the `store` field. Default: auto-detected from URL. */
|
||||
supportsStore?: boolean;
|
||||
/** Whether the provider supports the `developer` role (vs `system`). Default: auto-detected from URL. */
|
||||
supportsDeveloperRole?: boolean;
|
||||
/** Whether the provider supports `reasoning_effort`. Default: auto-detected from URL. */
|
||||
supportsReasoningEffort?: boolean;
|
||||
/** Optional mapping from pi-ai reasoning levels to provider/model-specific `reasoning_effort` values. */
|
||||
reasoningEffortMap?: Partial<Record<ThinkingLevel, string>>;
|
||||
/** Whether the provider supports `stream_options: { include_usage: true }` for token usage in streaming responses. Default: true. */
|
||||
supportsUsageInStreaming?: boolean;
|
||||
/** Which field to use for max tokens. Default: auto-detected from URL. */
|
||||
maxTokensField?: "max_completion_tokens" | "max_tokens";
|
||||
/** Whether tool results require the `name` field. Default: auto-detected from URL. */
|
||||
requiresToolResultName?: boolean;
|
||||
/** Whether a user message after tool results requires an assistant message in between. Default: auto-detected from URL. */
|
||||
requiresAssistantAfterToolResult?: boolean;
|
||||
/** Whether thinking blocks must be converted to text blocks with <thinking> delimiters. Default: auto-detected from URL. */
|
||||
requiresThinkingAsText?: boolean;
|
||||
/** Format for reasoning/thinking parameter. "openai" uses reasoning_effort, "zai" uses thinking: { type: "enabled" }, "qwen" uses enable_thinking: boolean. Default: "openai". */
|
||||
thinkingFormat?: "openai" | "zai" | "qwen";
|
||||
/** OpenRouter-specific routing preferences. Only used when baseUrl points to OpenRouter. */
|
||||
openRouterRouting?: OpenRouterRouting;
|
||||
/** Vercel AI Gateway routing preferences. Only used when baseUrl points to Vercel AI Gateway. */
|
||||
vercelGatewayRouting?: VercelGatewayRouting;
|
||||
/** Whether the provider supports the `strict` field in tool definitions. Default: true. */
|
||||
supportsStrictMode?: boolean;
|
||||
}
|
||||
|
||||
/** Compatibility settings for OpenAI Responses APIs. */
|
||||
export type OpenAIResponsesCompat = Record<keyof any, never>;
|
||||
|
||||
/**
|
||||
* OpenRouter provider routing preferences.
|
||||
* Controls which upstream providers OpenRouter routes requests to.
|
||||
* @see https://openrouter.ai/docs/provider-routing
|
||||
*/
|
||||
export interface OpenRouterRouting {
|
||||
/** List of provider slugs to exclusively use for this request (e.g., ["amazon-bedrock", "anthropic"]). */
|
||||
only?: string[];
|
||||
/** List of provider slugs to try in order (e.g., ["anthropic", "openai"]). */
|
||||
order?: string[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Vercel AI Gateway routing preferences.
|
||||
* Controls which upstream providers the gateway routes requests to.
|
||||
* @see https://vercel.com/docs/ai-gateway/models-and-providers/provider-options
|
||||
*/
|
||||
export interface VercelGatewayRouting {
|
||||
/** List of provider slugs to exclusively use for this request (e.g., ["bedrock", "anthropic"]). */
|
||||
only?: string[];
|
||||
/** List of provider slugs to try in order (e.g., ["anthropic", "openai"]). */
|
||||
order?: string[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Provider-agnostic capability declarations for a model.
|
||||
*
|
||||
* These fields allow models to self-declare supported features so that call
|
||||
* sites can read from metadata rather than pattern-matching on model IDs or
|
||||
* provider names. Add fields here as new cross-provider capabilities emerge.
|
||||
*/
|
||||
export interface ModelCapabilities {
|
||||
/** Whether the model supports xhigh thinking level. */
|
||||
supportsXhigh?: boolean;
|
||||
/**
|
||||
* Whether tool call IDs must be included and normalised in tool results for
|
||||
* this model. Relevant for models deployed cross-provider (e.g. Claude or
|
||||
* GPT variants via Google APIs) where the host API imposes stricter ID rules.
|
||||
*/
|
||||
requiresToolCallId?: boolean;
|
||||
/** Whether OpenAI-style service tiers (priority/flex) apply to this model. */
|
||||
supportsServiceTier?: boolean;
|
||||
/**
|
||||
* Approximate characters per token for this model.
|
||||
* Used as a fallback when an accurate tokenizer is unavailable.
|
||||
* If omitted, the provider-level default is used.
|
||||
*/
|
||||
charsPerToken?: number;
|
||||
/**
|
||||
* Whether this model's Anthropic-compatible thinking API accepts {"type":"enabled"}
|
||||
* without a budget_tokens field. When true, reasoning:"auto" sends no budget
|
||||
* and lets the model decide its own reasoning depth (e.g. Kimi via kimi-coding).
|
||||
*/
|
||||
thinkingNoBudget?: boolean;
|
||||
}
|
||||
|
||||
// Model interface for the unified model system
|
||||
export interface Model<TApi extends Api> {
|
||||
id: string;
|
||||
name: string;
|
||||
api: TApi;
|
||||
provider: Provider;
|
||||
baseUrl: string;
|
||||
reasoning: boolean;
|
||||
input: ("text" | "image")[];
|
||||
cost: {
|
||||
input: number; // $/million tokens
|
||||
output: number; // $/million tokens
|
||||
cacheRead: number; // $/million tokens
|
||||
cacheWrite: number; // $/million tokens
|
||||
};
|
||||
contextWindow: number;
|
||||
maxTokens: number;
|
||||
headers?: Record<string, string>;
|
||||
/** Compatibility overrides for OpenAI-compatible APIs. If not set, auto-detected from baseUrl. */
|
||||
compat?: TApi extends "openai-completions"
|
||||
? OpenAICompletionsCompat
|
||||
: TApi extends "openai-responses"
|
||||
? OpenAIResponsesCompat
|
||||
: never;
|
||||
/**
|
||||
* Provider-agnostic capability declarations for this model.
|
||||
* Read these fields instead of pattern-matching on model IDs or provider names.
|
||||
*/
|
||||
capabilities?: ModelCapabilities;
|
||||
/** Opaque provider-specific options. Cast to the appropriate type in the provider's stream handler. */
|
||||
providerOptions?: Record<string, unknown>;
|
||||
}
|
||||
|
|
@ -1,138 +0,0 @@
|
|||
import assert from "node:assert/strict";
|
||||
import { describe, it } from "vitest";
|
||||
import { parseAnthropicSSE } from "./event-stream.js";
|
||||
|
||||
function createMockResponse(chunks: string[]): Response {
|
||||
let index = 0;
|
||||
const encoder = new TextEncoder();
|
||||
const stream = new ReadableStream<Uint8Array>({
|
||||
pull(controller) {
|
||||
if (index < chunks.length) {
|
||||
controller.enqueue(encoder.encode(chunks[index++]));
|
||||
} else {
|
||||
controller.close();
|
||||
}
|
||||
},
|
||||
});
|
||||
return new Response(stream);
|
||||
}
|
||||
|
||||
describe("parseAnthropicSSE", () => {
|
||||
it("yields parsed JSON for known Anthropic events", async () => {
|
||||
const sse =
|
||||
"event: message_start\n" +
|
||||
'data: {"type":"message_start","message":{"id":"msg_1","role":"assistant","content":[],"model":"claude-3","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":0,"cache_creation_input_tokens":0,"cache_read_input_tokens":0}}}\n' +
|
||||
"\n" +
|
||||
"event: content_block_start\n" +
|
||||
'data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}\n' +
|
||||
"\n" +
|
||||
"event: content_block_delta\n" +
|
||||
'data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}\n' +
|
||||
"\n" +
|
||||
"event: content_block_stop\n" +
|
||||
'data: {"type":"content_block_stop","index":0}\n' +
|
||||
"\n" +
|
||||
"event: message_delta\n" +
|
||||
'data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"input_tokens":10,"output_tokens":1,"cache_creation_input_tokens":0,"cache_read_input_tokens":0}}\n' +
|
||||
"\n" +
|
||||
"event: message_stop\n" +
|
||||
'data: {"type":"message_stop"}\n' +
|
||||
"\n";
|
||||
|
||||
const response = createMockResponse([sse]);
|
||||
const events: unknown[] = [];
|
||||
for await (const event of parseAnthropicSSE(response)) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
assert.equal(events.length, 6);
|
||||
assert.equal((events[0] as any).type, "message_start");
|
||||
assert.equal((events[1] as any).type, "content_block_start");
|
||||
assert.equal((events[2] as any).type, "content_block_delta");
|
||||
assert.equal((events[3] as any).type, "content_block_stop");
|
||||
assert.equal((events[4] as any).type, "message_delta");
|
||||
assert.equal((events[5] as any).type, "message_stop");
|
||||
});
|
||||
|
||||
it("silently drops unknown events (e.g. OpenAI-style done)", async () => {
|
||||
const sse =
|
||||
"event: message_start\n" +
|
||||
'data: {"type":"message_start","message":{"id":"msg_1","role":"assistant","content":[],"model":"claude-3","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":0,"cache_creation_input_tokens":0,"cache_read_input_tokens":0}}}\n' +
|
||||
"\n" +
|
||||
"event: done\n" +
|
||||
"data: [DONE]\n" +
|
||||
"\n" +
|
||||
"event: content_block_start\n" +
|
||||
'data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}\n' +
|
||||
"\n";
|
||||
|
||||
const response = createMockResponse([sse]);
|
||||
const events: unknown[] = [];
|
||||
for await (const event of parseAnthropicSSE(response)) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
assert.equal(events.length, 2);
|
||||
assert.equal((events[0] as any).type, "message_start");
|
||||
assert.equal((events[1] as any).type, "content_block_start");
|
||||
});
|
||||
|
||||
it("ignores ping events", async () => {
|
||||
const sse =
|
||||
"event: message_start\n" +
|
||||
'data: {"type":"message_start","message":{"id":"msg_1","role":"assistant","content":[],"model":"claude-3","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":0,"cache_creation_input_tokens":0,"cache_read_input_tokens":0}}}\n' +
|
||||
"\n" +
|
||||
"event: ping\n" +
|
||||
"data: {}\n" +
|
||||
"\n" +
|
||||
"event: message_stop\n" +
|
||||
'data: {"type":"message_stop"}\n' +
|
||||
"\n";
|
||||
|
||||
const response = createMockResponse([sse]);
|
||||
const events: unknown[] = [];
|
||||
for await (const event of parseAnthropicSSE(response)) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
assert.equal(events.length, 2);
|
||||
assert.equal((events[0] as any).type, "message_start");
|
||||
assert.equal((events[1] as any).type, "message_stop");
|
||||
});
|
||||
|
||||
it("handles chunked SSE data across multiple reads", async () => {
|
||||
const chunks = [
|
||||
"event: message_start\n",
|
||||
'data: {"type":"message_start","message":{"id":"msg_1","role":"assistant","content":[],"model":"claude-3","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":0,"cache_creation_input_tokens":0,"cache_read_input_tokens":0}}}\n\n',
|
||||
"event: message_stop\n",
|
||||
'data: {"type":"message_stop"}\n\n',
|
||||
];
|
||||
|
||||
const response = createMockResponse(chunks);
|
||||
const events: unknown[] = [];
|
||||
for await (const event of parseAnthropicSSE(response)) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
assert.equal(events.length, 2);
|
||||
assert.equal((events[0] as any).type, "message_start");
|
||||
assert.equal((events[1] as any).type, "message_stop");
|
||||
});
|
||||
|
||||
it("handles comment lines", async () => {
|
||||
const sse =
|
||||
": comment line\n" +
|
||||
"event: message_start\n" +
|
||||
'data: {"type":"message_start","message":{"id":"msg_1","role":"assistant","content":[],"model":"claude-3","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":0,"cache_creation_input_tokens":0,"cache_read_input_tokens":0}}}\n' +
|
||||
"\n";
|
||||
|
||||
const response = createMockResponse([sse]);
|
||||
const events: unknown[] = [];
|
||||
for await (const event of parseAnthropicSSE(response)) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
assert.equal(events.length, 1);
|
||||
assert.equal((events[0] as any).type, "message_start");
|
||||
});
|
||||
});
|
||||
|
|
@ -1,223 +0,0 @@
|
|||
import type { AssistantMessage, AssistantMessageEvent } from "../types.js";
|
||||
|
||||
/** Known Anthropic SSE event types that we handle. Unknown events are silently dropped. */
|
||||
const KNOWN_ANTHROPIC_EVENTS = new Set([
|
||||
"message_start",
|
||||
"message_delta",
|
||||
"message_stop",
|
||||
"content_block_start",
|
||||
"content_block_delta",
|
||||
"content_block_stop",
|
||||
"ping",
|
||||
"error",
|
||||
]);
|
||||
|
||||
/**
|
||||
* Parse a raw SSE (Server-Sent Events) stream response into JSON events.
|
||||
*
|
||||
* Purpose: give us full control over SSE parsing so that non-Anthropic events
|
||||
* (e.g. OpenAI-style "done" events injected by proxies) are silently dropped
|
||||
* instead of corrupting the stream.
|
||||
*
|
||||
* Consumer: processAnthropicStream in anthropic-shared.ts.
|
||||
*/
|
||||
export async function* parseAnthropicSSE(
|
||||
response: Response,
|
||||
signal?: AbortSignal,
|
||||
): AsyncGenerator<unknown, void, unknown> {
|
||||
if (!response.body) {
|
||||
throw new Error("Attempted to iterate over a response with no body");
|
||||
}
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
let eventName: string | null = null;
|
||||
let dataLines: string[] = [];
|
||||
|
||||
function flushEvent(): unknown | undefined {
|
||||
if (eventName === null && dataLines.length === 0) {
|
||||
return undefined;
|
||||
}
|
||||
const data = dataLines.join("\n");
|
||||
const name = eventName ?? "";
|
||||
eventName = null;
|
||||
dataLines = [];
|
||||
|
||||
if (name === "ping") {
|
||||
return undefined;
|
||||
}
|
||||
if (name === "error") {
|
||||
try {
|
||||
return JSON.parse(data);
|
||||
} catch {
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
if (!KNOWN_ANTHROPIC_EVENTS.has(name)) {
|
||||
// Silently drop unknown events (e.g. OpenAI-style "done" from proxies)
|
||||
return undefined;
|
||||
}
|
||||
try {
|
||||
return JSON.parse(data);
|
||||
} catch {
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
function processLine(line: string): unknown | undefined {
|
||||
const trimmed = line.trim();
|
||||
if (trimmed === "") {
|
||||
// Empty line means end of an SSE event
|
||||
return flushEvent();
|
||||
}
|
||||
|
||||
if (trimmed.startsWith(":")) {
|
||||
// Comment line, ignore
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const colonIndex = trimmed.indexOf(":");
|
||||
if (colonIndex === -1) return undefined;
|
||||
|
||||
const field = trimmed.slice(0, colonIndex);
|
||||
const value = trimmed.slice(colonIndex + 1).trimStart();
|
||||
|
||||
if (field === "event") {
|
||||
eventName = value;
|
||||
} else if (field === "data") {
|
||||
dataLines.push(value);
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
if (signal?.aborted) {
|
||||
return;
|
||||
}
|
||||
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
|
||||
let newlineIndex: number;
|
||||
while ((newlineIndex = buffer.indexOf("\n")) !== -1) {
|
||||
const line = buffer.slice(0, newlineIndex);
|
||||
buffer = buffer.slice(newlineIndex + 1);
|
||||
const event = processLine(line);
|
||||
if (event !== undefined) {
|
||||
yield event;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Flush remaining buffer as a final line
|
||||
if (buffer.length > 0) {
|
||||
const event = processLine(buffer);
|
||||
if (event !== undefined) {
|
||||
yield event;
|
||||
}
|
||||
}
|
||||
|
||||
// Flush any pending event
|
||||
const event = flushEvent();
|
||||
if (event !== undefined) {
|
||||
yield event;
|
||||
}
|
||||
} finally {
|
||||
reader.releaseLock();
|
||||
}
|
||||
}
|
||||
|
||||
// Generic event stream class for async iteration
|
||||
export class EventStream<T, R = T> implements AsyncIterable<T> {
|
||||
private queue: T[] = [];
|
||||
private waiting: ((value: IteratorResult<T>) => void)[] = [];
|
||||
private done = false;
|
||||
private finalResultPromise: Promise<R>;
|
||||
private resolveFinalResult!: (result: R) => void;
|
||||
|
||||
constructor(
|
||||
private isComplete: (event: T) => boolean,
|
||||
private extractResult: (event: T) => R,
|
||||
) {
|
||||
this.finalResultPromise = new Promise((resolve) => {
|
||||
this.resolveFinalResult = resolve;
|
||||
});
|
||||
}
|
||||
|
||||
push(event: T): void {
|
||||
if (this.done) return;
|
||||
|
||||
if (this.isComplete(event)) {
|
||||
this.done = true;
|
||||
this.resolveFinalResult(this.extractResult(event));
|
||||
}
|
||||
|
||||
// Deliver to waiting consumer or queue it
|
||||
const waiter = this.waiting.shift();
|
||||
if (waiter) {
|
||||
waiter({ value: event, done: false });
|
||||
} else {
|
||||
this.queue.push(event);
|
||||
}
|
||||
}
|
||||
|
||||
end(result?: R): void {
|
||||
this.done = true;
|
||||
if (result !== undefined) {
|
||||
this.resolveFinalResult(result);
|
||||
}
|
||||
// Notify all waiting consumers that we're done
|
||||
while (this.waiting.length > 0) {
|
||||
const waiter = this.waiting.shift()!;
|
||||
waiter({ value: undefined as any, done: true });
|
||||
}
|
||||
}
|
||||
|
||||
async *[Symbol.asyncIterator](): AsyncIterator<T> {
|
||||
while (true) {
|
||||
if (this.queue.length > 0) {
|
||||
yield this.queue.shift()!;
|
||||
} else if (this.done) {
|
||||
return;
|
||||
} else {
|
||||
const result = await new Promise<IteratorResult<T>>((resolve) =>
|
||||
this.waiting.push(resolve),
|
||||
);
|
||||
if (result.done) return;
|
||||
yield result.value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result(): Promise<R> {
|
||||
return this.finalResultPromise;
|
||||
}
|
||||
}
|
||||
|
||||
export class AssistantMessageEventStream extends EventStream<
|
||||
AssistantMessageEvent,
|
||||
AssistantMessage
|
||||
> {
|
||||
constructor() {
|
||||
super(
|
||||
(event) => event.type === "done" || event.type === "error",
|
||||
(event) => {
|
||||
if (event.type === "done") {
|
||||
return event.message;
|
||||
} else if (event.type === "error") {
|
||||
return event.error;
|
||||
}
|
||||
throw new Error("Unexpected event type for final result");
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/** Factory function for AssistantMessageEventStream (for use by package consumers). */
|
||||
export function createAssistantMessageEventStream(): AssistantMessageEventStream {
|
||||
return new AssistantMessageEventStream();
|
||||
}
|
||||
|
|
@ -1,17 +0,0 @@
|
|||
/** Fast deterministic hash to shorten long strings */
|
||||
export function shortHash(str: string): string {
|
||||
let h1 = 0xdeadbeef;
|
||||
let h2 = 0x41c6ce57;
|
||||
for (let i = 0; i < str.length; i++) {
|
||||
const ch = str.charCodeAt(i);
|
||||
h1 = Math.imul(h1 ^ ch, 2654435761);
|
||||
h2 = Math.imul(h2 ^ ch, 1597334677);
|
||||
}
|
||||
h1 =
|
||||
Math.imul(h1 ^ (h1 >>> 16), 2246822507) ^
|
||||
Math.imul(h2 ^ (h2 >>> 13), 3266489909);
|
||||
h2 =
|
||||
Math.imul(h2 ^ (h2 >>> 16), 2246822507) ^
|
||||
Math.imul(h1 ^ (h1 >>> 13), 3266489909);
|
||||
return (h2 >>> 0).toString(36) + (h1 >>> 0).toString(36);
|
||||
}
|
||||
|
|
@ -1,85 +0,0 @@
|
|||
import { parseStreamingJson as nativeParseStreamingJson } from "@singularity-forge/native";
|
||||
import {
|
||||
hasXmlParameterTags,
|
||||
hasYamlBulletLists,
|
||||
repairToolJsonWithReport,
|
||||
} from "./repair-tool-json.js";
|
||||
|
||||
/**
|
||||
* Attempts to parse potentially incomplete JSON during streaming.
|
||||
* Always returns a valid object, even if the JSON is incomplete.
|
||||
*
|
||||
* Uses the native Rust streaming JSON parser for performance.
|
||||
* Falls back to YAML bullet-list repair when the native parser
|
||||
* returns an empty object from input that contains YAML-style
|
||||
* bullet lists copied from template formatting (#2660).
|
||||
*
|
||||
* @param partialJson The partial JSON string from streaming
|
||||
* @returns Parsed object or empty object if parsing fails
|
||||
*/
|
||||
export function parseStreamingJson<T = any>(
|
||||
partialJson: string | undefined,
|
||||
): T {
|
||||
if (!partialJson || partialJson.trim() === "") {
|
||||
return {} as T;
|
||||
}
|
||||
if (looksLikeIncompleteObjectValue(partialJson)) {
|
||||
return {} as T;
|
||||
}
|
||||
|
||||
// Fast path: try native streaming parser first
|
||||
const result = nativeParseStreamingJson<T>(partialJson);
|
||||
|
||||
// XML parameter tags can be trapped inside otherwise valid JSON strings,
|
||||
// so run repair before trusting the native parse result.
|
||||
if (hasXmlParameterTags(partialJson)) {
|
||||
try {
|
||||
return JSON.parse(repairToolJsonWithReport(partialJson).output) as T;
|
||||
} catch {
|
||||
// Fall through to the native parser result on incomplete partials
|
||||
}
|
||||
}
|
||||
|
||||
// If the native parser returned a non-empty result, use it.
|
||||
// Only attempt repair when the result is empty AND the input looks like a
|
||||
// complete malformed object or YAML-shaped map. This avoids inventing
|
||||
// values for ordinary incomplete streaming chunks.
|
||||
if (
|
||||
result &&
|
||||
typeof result === "object" &&
|
||||
Object.keys(result as object).length === 0 &&
|
||||
shouldAttemptRepair(partialJson)
|
||||
) {
|
||||
try {
|
||||
return JSON.parse(repairToolJsonWithReport(partialJson).output) as T;
|
||||
} catch {
|
||||
// Repair failed — return the empty object from native parser
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
function looksLikeIncompleteObjectValue(input: string): boolean {
|
||||
const trimmed = input.trim();
|
||||
return trimmed.startsWith("{") && /:\s*$/.test(trimmed);
|
||||
}
|
||||
|
||||
function shouldAttemptRepair(input: string): boolean {
|
||||
if (hasXmlParameterTags(input) || hasYamlBulletLists(input)) return true;
|
||||
|
||||
const trimmed = input.trim();
|
||||
if (!trimmed) return false;
|
||||
if (
|
||||
(trimmed.startsWith("{") && trimmed.endsWith("}")) ||
|
||||
(trimmed.startsWith("[") && trimmed.endsWith("]"))
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Full YAML map/list tool arguments from weaker models. Require a newline
|
||||
// so normal prose with a colon does not get parsed as a scalar/map.
|
||||
return (
|
||||
/^[A-Za-z_][A-Za-z0-9_-]*\s*:/m.test(trimmed) && trimmed.includes("\n")
|
||||
);
|
||||
}
|
||||
|
|
@ -1,100 +0,0 @@
|
|||
import assert from "node:assert/strict";
|
||||
import { test } from "vitest";
|
||||
|
||||
import type { Api, Model } from "../../types.js";
|
||||
import { githubCopilotOAuthProvider } from "./github-copilot.js";
|
||||
import type { OAuthCredentials } from "./index.js";
|
||||
|
||||
function makeModel(provider: string, id: string): Model<Api> {
|
||||
return {
|
||||
id,
|
||||
name: id,
|
||||
api: "openai-completions",
|
||||
provider,
|
||||
baseUrl: `${provider}:`,
|
||||
reasoning: false,
|
||||
input: ["text"],
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: 128000,
|
||||
maxTokens: 16384,
|
||||
};
|
||||
}
|
||||
|
||||
function makeCredentials(
|
||||
overrides: Partial<
|
||||
OAuthCredentials & {
|
||||
modelLimits?: Record<
|
||||
string,
|
||||
{ contextWindow: number; maxTokens: number }
|
||||
>;
|
||||
}
|
||||
> = {},
|
||||
) {
|
||||
return {
|
||||
type: "oauth" as const,
|
||||
access: "copilot-token",
|
||||
refresh: "refresh-token",
|
||||
expires: Date.now() + 60_000,
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
test("githubCopilotOAuthProvider.modifyModels filters unavailable copilot models (#3849)", () => {
|
||||
const models = [
|
||||
makeModel("github-copilot", "gpt-5"),
|
||||
makeModel("github-copilot", "claude-sonnet-4"),
|
||||
makeModel("openai", "gpt-4.1"),
|
||||
];
|
||||
|
||||
assert.ok(
|
||||
githubCopilotOAuthProvider.modifyModels,
|
||||
"github copilot provider should expose modifyModels",
|
||||
);
|
||||
const modified = githubCopilotOAuthProvider.modifyModels(
|
||||
models,
|
||||
makeCredentials({
|
||||
modelLimits: {
|
||||
"gpt-5": { contextWindow: 256000, maxTokens: 32000 },
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
assert.deepEqual(
|
||||
modified.map((model) => `${model.provider}/${model.id}`),
|
||||
["github-copilot/gpt-5", "openai/gpt-4.1"],
|
||||
);
|
||||
|
||||
const copilotModel = modified.find(
|
||||
(model) => model.provider === "github-copilot" && model.id === "gpt-5",
|
||||
);
|
||||
assert.ok(copilotModel, "available copilot model should remain");
|
||||
assert.equal(copilotModel.contextWindow, 256000);
|
||||
assert.equal(copilotModel.maxTokens, 32000);
|
||||
assert.match(copilotModel.baseUrl, /githubcopilot\.com/);
|
||||
});
|
||||
|
||||
test("githubCopilotOAuthProvider.modifyModels keeps all copilot models when limits are unavailable", () => {
|
||||
const models = [
|
||||
makeModel("github-copilot", "gpt-5"),
|
||||
makeModel("github-copilot", "claude-sonnet-4"),
|
||||
];
|
||||
|
||||
assert.ok(
|
||||
githubCopilotOAuthProvider.modifyModels,
|
||||
"github copilot provider should expose modifyModels",
|
||||
);
|
||||
const modified = githubCopilotOAuthProvider.modifyModels(
|
||||
models,
|
||||
makeCredentials(),
|
||||
);
|
||||
|
||||
assert.equal(
|
||||
modified.length,
|
||||
2,
|
||||
"lack of limits should not hide every copilot model",
|
||||
);
|
||||
assert.ok(modified.every((model) => model.provider === "github-copilot"));
|
||||
assert.ok(
|
||||
modified.every((model) => model.baseUrl.includes("githubcopilot.com")),
|
||||
);
|
||||
});
|
||||
|
|
@ -1,613 +0,0 @@
|
|||
/**
|
||||
* GitHub Copilot OAuth flow
|
||||
*
|
||||
* UPSTREAM AUDIT (2026-05-02): STAY HAND-ROLLED
|
||||
*
|
||||
* Candidate: @octokit/auth-oauth-device (v8.0.3)
|
||||
* Coverage: device-code initiation + authorization_pending/slow_down polling — the
|
||||
* ~120 LOC in startDeviceFlow + pollForGitHubAccessToken only.
|
||||
* Why we're not delegating:
|
||||
* 1. AbortSignal cancellation — the library has no signal/abort support; our
|
||||
* abortableSleep + signal checks are load-bearing for the login-cancel UX.
|
||||
* 2. 74% of this file is Copilot-proprietary with no upstream equivalent:
|
||||
* - copilot_internal/v2/token exchange (refreshGitHubCopilotToken)
|
||||
* - proxy-ep token parsing / base-URL derivation (getBaseUrlFromToken)
|
||||
* - Model policy enablement (enableAllGitHubCopilotModels)
|
||||
* - Model limits fetch (fetchCopilotModelLimits)
|
||||
* - Enterprise domain normalization (normalizeDomain / getUrls)
|
||||
* 3. The library would add a dependency + lose abort support for a ~26% LOC
|
||||
* reduction in a 460-line file — not worth it.
|
||||
*
|
||||
* Re-audit trigger: if @octokit/auth-oauth-device adds AbortSignal support AND
|
||||
* the Copilot-specific surface area shrinks (e.g., models API becomes public SDK).
|
||||
*
|
||||
* UPSTREAM AUDIT (2026-05-02): opencode-copilot-auth + three other candidates — STAY HAND-ROLLED
|
||||
*
|
||||
* Four packages inspected (full source read for each):
|
||||
*
|
||||
* 1. opencode-copilot-auth@0.0.12 (thdxr / ironbay.co; 13.6 kB unpacked; 0 runtime deps;
|
||||
* latest version 2026-01-11; single maintainer; pre-1.0 with 11 versions since Aug 2025)
|
||||
* Source: /tmp/package/index.mjs (extracted from tarball)
|
||||
* Exports: one named export — CopilotAuthPlugin({ client }) — an opencode plugin factory.
|
||||
* ALL logic is inlined inside that single async function and is NOT separately callable:
|
||||
* - authorize() [lines 210-299]: device-code initiation + polling loop. Covers our
|
||||
* startDeviceFlow + pollForGitHubAccessToken. BUT: (a) infinite while(true) loop with
|
||||
* no expiry deadline (our pollForGitHubAccessToken has Date.now() < deadline guard);
|
||||
* (b) no AbortSignal support — cannot be cancelled; (c) missing slow_down interval
|
||||
* back-off (our file handles slow_down by adding 5 s to intervalMs); (d) returns
|
||||
* { type:"success", refresh, access:"", expires:0 } — defers the copilot-token
|
||||
* exchange to the loader, not the authorize step.
|
||||
* - loader() [lines 45-165]: inline copilot_internal/v2/token refresh + expiry check.
|
||||
* Covers our refreshGitHubCopilotToken. BUT: immediately calls client.auth.set() to
|
||||
* persist — the storage call cannot be skipped; there is no way to get the token
|
||||
* without also writing it through opencode's auth API.
|
||||
* - Does NOT cover: proxy-ep URL parsing (hardcodes api.githubcopilot.com), model
|
||||
* policy enablement, fetchCopilotModelLimits, or AbortSignal cancellation.
|
||||
* Net coverage of our 480 LOC: ~30% (device-code dance + token exchange) — below the
|
||||
* 60% threshold AND the usable subset requires surgery to detach from client.auth.set().
|
||||
* Risk factors: pre-1.0, single maintainer, Proprietary license (not MIT/Apache).
|
||||
*
|
||||
* 2. copilot-api@0.7.0 (ericc-ch / echristian; 171.5 kB unpacked; 11 deps)
|
||||
* Source: dist/main.js — starts with #!/usr/bin/env node shebang.
|
||||
* It is a CLI proxy server (hono + srvx), not a library. Exports: `export { }` (empty).
|
||||
* All auth code (device flow, copilot_internal/v2/token, refresh loop) is internal to
|
||||
* the CLI's start command and not callable as a module. Zero reuse possible.
|
||||
*
|
||||
* 3. @github/copilot-language-server@1.480.0 (official GitHub; 134 MB unpacked)
|
||||
* dist/api/types.d.ts exports only ContextProviderApiV1 (a VS Code extension host API
|
||||
* for injecting prompt context). The package ships a single language-server binary;
|
||||
* no OAuth functions are exported or accessible programmatically.
|
||||
*
|
||||
* 4. @octokit/auth-oauth-device@8.0.3 — already audited above (26% coverage, no AbortSignal).
|
||||
*
|
||||
* Conclusion: No candidate reaches the 60% coverage bar. The closest (opencode-copilot-auth)
|
||||
* covers ~30% and cannot be used without its storage side-effect. STAY HAND-ROLLED.
|
||||
*
|
||||
* CREDS-FILE AUDIT (2026-05-02): STAY HAND-ROLLED — device-code dance still needed
|
||||
*
|
||||
* Investigated five candidate file sources for a "consume existing token" fast-path:
|
||||
* 1. ~/.copilot/session-state/*.jsonl — conversation history (type/data/id/timestamp),
|
||||
* no OAuth tokens of any kind.
|
||||
* 2. ~/.config/github-copilot/hosts.json and apps.json (Neovim plugin pattern) —
|
||||
* neither file nor directory exists on this machine.
|
||||
* 3. ~/.config/gh/hosts.yml — contains a gho_* token (40 chars) scoped to
|
||||
* repo/gist/read:org etc.; copilot_internal scope absent. Exchange against
|
||||
* GET api.github.com/copilot_internal/v2/token returns HTTP 404 "Not Found".
|
||||
* 4. ~/.maschine/copilot-token.json — OUR OWN app's cache (singularity/machine
|
||||
* CopilotSubscriptionProvider). Has { githubToken, copilotToken, expiresAt,
|
||||
* refreshIn }. The stored githubToken IS a Copilot-scoped gho_* (40 chars)
|
||||
* and DOES exchange successfully (HTTP 200, fresh 353-char proxy-ep token).
|
||||
* However: (a) that file is written by singularity/machine after device-flow
|
||||
* login there, not by a third-party tool we can rely on; (b) it was last
|
||||
* written 2025-12-30, so on a fresh machine it won't exist; (c) consuming it
|
||||
* here would create cross-app token sharing with no clear ownership boundary.
|
||||
* 5. opencode-copilot-auth@0.0.9 in ~/.bun/install/cache — a bun plugin that
|
||||
* also does the device-code dance and stores state via opencode's auth.set()
|
||||
* API; not a plain filesystem file we can read.
|
||||
*
|
||||
* Conclusion: No third-party-written creds file exists on this machine that carries
|
||||
* a Copilot-scoped token. The gh CLI token lacks the required Copilot scope.
|
||||
* The device-code dance (startDeviceFlow + pollForGitHubAccessToken) is the only
|
||||
* way to obtain a fresh Copilot-authorized github token for a new login.
|
||||
*
|
||||
* Future: if the user installs the Neovim Copilot plugin or VS Code's Copilot
|
||||
* extension writes ~/.config/github-copilot/apps.json (format:
|
||||
* { "github.com:Iv1.b507a08c87ecfe98": { "oauth_token": "gho_..." } }),
|
||||
* we could consume that as a fast-path that skips the device-code dance entirely
|
||||
* and goes straight to refreshGitHubCopilotToken(). Worth adding then.
|
||||
*/
|
||||
|
||||
import { getModels } from "../../models.js";
|
||||
import type { Api, Model } from "../../types.js";
|
||||
import type {
|
||||
OAuthCredentials,
|
||||
OAuthLoginCallbacks,
|
||||
OAuthProviderInterface,
|
||||
} from "./types.js";
|
||||
|
||||
/** OAuth credentials plus the Copilot-specific fields this provider persists. */
type CopilotCredentials = OAuthCredentials & {
  // Enterprise domain (GHE) the credentials were issued for, when applicable.
  enterpriseUrl?: string;
  /** Model limits from the /models API, keyed by model ID */
  modelLimits?: Record<string, { contextWindow: number; maxTokens: number }>;
};

// OAuth client ID stored base64-encoded and decoded at module load.
const decode = (s: string) => atob(s);
const CLIENT_ID = decode("SXYxLmI1MDdhMDhjODdlY2ZlOTg=");

// Headers imitating the VS Code Copilot Chat client, sent on Copilot API
// requests throughout this file. NOTE(review): presumably required by the
// Copilot backend to accept the integration — confirm before changing.
const COPILOT_HEADERS = {
  "User-Agent": "GitHubCopilotChat/0.35.0",
  "Editor-Version": "vscode/1.107.0",
  "Editor-Plugin-Version": "copilot-chat/0.35.0",
  "Copilot-Integration-Id": "vscode-chat",
} as const;

/** Response shape expected from the device-code initiation endpoint. */
type DeviceCodeResponse = {
  device_code: string;
  user_code: string;
  verification_uri: string;
  interval: number;
  expires_in: number;
};

/** Success payload from the access-token polling endpoint. */
type DeviceTokenSuccessResponse = {
  access_token: string;
  token_type?: string;
  scope?: string;
};

/** Error payload from the access-token polling endpoint (e.g. authorization_pending). */
type DeviceTokenErrorResponse = {
  error: string;
  error_description?: string;
  interval?: number;
};
|
||||
|
||||
export function normalizeDomain(input: string): string | null {
|
||||
const trimmed = input.trim();
|
||||
if (!trimmed) return null;
|
||||
try {
|
||||
const url = trimmed.includes("://")
|
||||
? new URL(trimmed)
|
||||
: new URL(`https://${trimmed}`);
|
||||
return url.hostname;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function getUrls(domain: string): {
|
||||
deviceCodeUrl: string;
|
||||
accessTokenUrl: string;
|
||||
copilotTokenUrl: string;
|
||||
} {
|
||||
return {
|
||||
deviceCodeUrl: `https://${domain}/login/device/code`,
|
||||
accessTokenUrl: `https://${domain}/login/oauth/access_token`,
|
||||
copilotTokenUrl: `https://api.${domain}/copilot_internal/v2/token`,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse the proxy-ep from a Copilot token and convert to API base URL.
|
||||
* Token format: tid=...;exp=...;proxy-ep=proxy.individual.githubcopilot.com;...
|
||||
* Returns API URL like https://api.individual.githubcopilot.com
|
||||
*/
|
||||
function getBaseUrlFromToken(token: string): string | null {
|
||||
const match = token.match(/proxy-ep=([^;]+)/);
|
||||
if (!match) return null;
|
||||
const proxyHost = match[1];
|
||||
// Convert proxy.xxx to api.xxx
|
||||
const apiHost = proxyHost.replace(/^proxy\./, "api.");
|
||||
return `https://${apiHost}`;
|
||||
}
|
||||
|
||||
export function getGitHubCopilotBaseUrl(
|
||||
token?: string,
|
||||
enterpriseDomain?: string,
|
||||
): string {
|
||||
// If we have a token, extract the base URL from proxy-ep
|
||||
if (token) {
|
||||
const urlFromToken = getBaseUrlFromToken(token);
|
||||
if (urlFromToken) return urlFromToken;
|
||||
}
|
||||
// Fallback for enterprise or if token parsing fails
|
||||
if (enterpriseDomain) return `https://copilot-api.${enterpriseDomain}`;
|
||||
return "https://api.individual.githubcopilot.com";
|
||||
}
|
||||
|
||||
async function fetchJson(url: string, init: RequestInit): Promise<unknown> {
|
||||
const response = await fetch(url, {
|
||||
...init,
|
||||
signal: init.signal ?? AbortSignal.timeout(30_000),
|
||||
});
|
||||
if (!response.ok) {
|
||||
const text = await response.text();
|
||||
throw new Error(`${response.status} ${response.statusText}: ${text}`);
|
||||
}
|
||||
return response.json();
|
||||
}
|
||||
|
||||
async function startDeviceFlow(domain: string): Promise<DeviceCodeResponse> {
|
||||
const urls = getUrls(domain);
|
||||
const data = await fetchJson(urls.deviceCodeUrl, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
Accept: "application/json",
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "GitHubCopilotChat/0.35.0",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
client_id: CLIENT_ID,
|
||||
scope: "read:user",
|
||||
}),
|
||||
});
|
||||
|
||||
if (!data || typeof data !== "object") {
|
||||
throw new Error("Invalid device code response");
|
||||
}
|
||||
|
||||
const deviceCode = (data as Record<string, unknown>).device_code;
|
||||
const userCode = (data as Record<string, unknown>).user_code;
|
||||
const verificationUri = (data as Record<string, unknown>).verification_uri;
|
||||
const interval = (data as Record<string, unknown>).interval;
|
||||
const expiresIn = (data as Record<string, unknown>).expires_in;
|
||||
|
||||
if (
|
||||
typeof deviceCode !== "string" ||
|
||||
typeof userCode !== "string" ||
|
||||
typeof verificationUri !== "string" ||
|
||||
typeof interval !== "number" ||
|
||||
typeof expiresIn !== "number"
|
||||
) {
|
||||
throw new Error("Invalid device code response fields");
|
||||
}
|
||||
|
||||
return {
|
||||
device_code: deviceCode,
|
||||
user_code: userCode,
|
||||
verification_uri: verificationUri,
|
||||
interval,
|
||||
expires_in: expiresIn,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Sleep that can be interrupted by an AbortSignal
|
||||
*/
|
||||
function abortableSleep(ms: number, signal?: AbortSignal): Promise<void> {
|
||||
return new Promise((resolve, reject) => {
|
||||
if (signal?.aborted) {
|
||||
reject(new Error("Login cancelled"));
|
||||
return;
|
||||
}
|
||||
|
||||
const timeout = setTimeout(resolve, ms);
|
||||
|
||||
signal?.addEventListener(
|
||||
"abort",
|
||||
() => {
|
||||
clearTimeout(timeout);
|
||||
reject(new Error("Login cancelled"));
|
||||
},
|
||||
{ once: true },
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
async function pollForGitHubAccessToken(
|
||||
domain: string,
|
||||
deviceCode: string,
|
||||
intervalSeconds: number,
|
||||
expiresIn: number,
|
||||
signal?: AbortSignal,
|
||||
) {
|
||||
const urls = getUrls(domain);
|
||||
const deadline = Date.now() + expiresIn * 1000;
|
||||
let intervalMs = Math.max(1000, Math.floor(intervalSeconds * 1000));
|
||||
|
||||
while (Date.now() < deadline) {
|
||||
if (signal?.aborted) {
|
||||
throw new Error("Login cancelled");
|
||||
}
|
||||
|
||||
const raw = await fetchJson(urls.accessTokenUrl, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
Accept: "application/json",
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "GitHubCopilotChat/0.35.0",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
client_id: CLIENT_ID,
|
||||
device_code: deviceCode,
|
||||
grant_type: "urn:ietf:params:oauth:grant-type:device_code",
|
||||
}),
|
||||
});
|
||||
|
||||
if (
|
||||
raw &&
|
||||
typeof raw === "object" &&
|
||||
typeof (raw as DeviceTokenSuccessResponse).access_token === "string"
|
||||
) {
|
||||
return (raw as DeviceTokenSuccessResponse).access_token;
|
||||
}
|
||||
|
||||
if (
|
||||
raw &&
|
||||
typeof raw === "object" &&
|
||||
typeof (raw as DeviceTokenErrorResponse).error === "string"
|
||||
) {
|
||||
const err = (raw as DeviceTokenErrorResponse).error;
|
||||
if (err === "authorization_pending") {
|
||||
await abortableSleep(intervalMs, signal);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (err === "slow_down") {
|
||||
intervalMs += 5000;
|
||||
await abortableSleep(intervalMs, signal);
|
||||
continue;
|
||||
}
|
||||
|
||||
throw new Error(`Device flow failed: ${err}`);
|
||||
}
|
||||
|
||||
await abortableSleep(intervalMs, signal);
|
||||
}
|
||||
|
||||
throw new Error("Device flow timed out");
|
||||
}
|
||||
|
||||
/**
|
||||
* Refresh GitHub Copilot token
|
||||
*/
|
||||
export async function refreshGitHubCopilotToken(
|
||||
refreshToken: string,
|
||||
enterpriseDomain?: string,
|
||||
): Promise<OAuthCredentials> {
|
||||
const domain = enterpriseDomain || "github.com";
|
||||
const urls = getUrls(domain);
|
||||
|
||||
const raw = await fetchJson(urls.copilotTokenUrl, {
|
||||
headers: {
|
||||
Accept: "application/json",
|
||||
Authorization: `Bearer ${refreshToken}`,
|
||||
...COPILOT_HEADERS,
|
||||
},
|
||||
});
|
||||
|
||||
if (!raw || typeof raw !== "object") {
|
||||
throw new Error("Invalid Copilot token response");
|
||||
}
|
||||
|
||||
const token = (raw as Record<string, unknown>).token;
|
||||
const expiresAt = (raw as Record<string, unknown>).expires_at;
|
||||
|
||||
if (typeof token !== "string" || typeof expiresAt !== "number") {
|
||||
throw new Error("Invalid Copilot token response fields");
|
||||
}
|
||||
|
||||
return {
|
||||
refresh: refreshToken,
|
||||
access: token,
|
||||
expires: expiresAt * 1000 - 5 * 60 * 1000,
|
||||
enterpriseUrl: enterpriseDomain,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Enable a model for the user's GitHub Copilot account.
|
||||
* This is required for some models (like Claude, Grok) before they can be used.
|
||||
*/
|
||||
async function enableGitHubCopilotModel(
|
||||
token: string,
|
||||
modelId: string,
|
||||
enterpriseDomain?: string,
|
||||
): Promise<boolean> {
|
||||
const baseUrl = getGitHubCopilotBaseUrl(token, enterpriseDomain);
|
||||
const url = `${baseUrl}/models/${modelId}/policy`;
|
||||
|
||||
try {
|
||||
const response = await fetch(url, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${token}`,
|
||||
...COPILOT_HEADERS,
|
||||
"openai-intent": "chat-policy",
|
||||
"x-interaction-type": "chat-policy",
|
||||
},
|
||||
body: JSON.stringify({ state: "enabled" }),
|
||||
signal: AbortSignal.timeout(30_000),
|
||||
});
|
||||
return response.ok;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Enable all known GitHub Copilot models that may require policy acceptance.
|
||||
* Called after successful login to ensure all models are available.
|
||||
*/
|
||||
async function enableAllGitHubCopilotModels(
|
||||
token: string,
|
||||
enterpriseDomain?: string,
|
||||
onProgress?: (model: string, success: boolean) => void,
|
||||
): Promise<void> {
|
||||
const models = getModels("github-copilot");
|
||||
await Promise.all(
|
||||
models.map(async (model) => {
|
||||
const success = await enableGitHubCopilotModel(
|
||||
token,
|
||||
model.id,
|
||||
enterpriseDomain,
|
||||
);
|
||||
onProgress?.(model.id, success);
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
async function fetchCopilotModelLimits(
|
||||
token: string,
|
||||
enterpriseDomain?: string,
|
||||
): Promise<Record<string, { contextWindow: number; maxTokens: number }>> {
|
||||
const baseUrl = getGitHubCopilotBaseUrl(token, enterpriseDomain);
|
||||
try {
|
||||
const response = await fetch(`${baseUrl}/models`, {
|
||||
headers: {
|
||||
Accept: "application/json",
|
||||
Authorization: `Bearer ${token}`,
|
||||
"X-GitHub-Api-Version": "2025-05-01",
|
||||
...COPILOT_HEADERS,
|
||||
},
|
||||
signal: AbortSignal.timeout(30_000),
|
||||
});
|
||||
if (!response.ok) return {};
|
||||
const data = (await response.json()) as {
|
||||
data?: Array<{
|
||||
id: string;
|
||||
capabilities?: {
|
||||
limits?: {
|
||||
max_context_window_tokens?: number;
|
||||
max_output_tokens?: number;
|
||||
};
|
||||
};
|
||||
}>;
|
||||
};
|
||||
const limits: Record<string, { contextWindow: number; maxTokens: number }> =
|
||||
{};
|
||||
for (const m of data.data || []) {
|
||||
const ctx = m.capabilities?.limits?.max_context_window_tokens;
|
||||
const out = m.capabilities?.limits?.max_output_tokens;
|
||||
if (
|
||||
typeof ctx === "number" &&
|
||||
typeof out === "number" &&
|
||||
ctx > 0 &&
|
||||
out > 0 &&
|
||||
Number.isFinite(ctx) &&
|
||||
Number.isFinite(out)
|
||||
) {
|
||||
limits[m.id] = { contextWindow: ctx, maxTokens: out };
|
||||
}
|
||||
}
|
||||
return limits;
|
||||
} catch {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Login with GitHub Copilot OAuth (device code flow)
|
||||
*
|
||||
* @param options.onAuth - Callback with URL and optional instructions (user code)
|
||||
* @param options.onPrompt - Callback to prompt user for input
|
||||
* @param options.onProgress - Optional progress callback
|
||||
* @param options.signal - Optional AbortSignal for cancellation
|
||||
*/
|
||||
export async function loginGitHubCopilot(options: {
|
||||
onAuth: (url: string, instructions?: string) => void;
|
||||
onPrompt: (prompt: {
|
||||
message: string;
|
||||
placeholder?: string;
|
||||
allowEmpty?: boolean;
|
||||
}) => Promise<string>;
|
||||
onProgress?: (message: string) => void;
|
||||
signal?: AbortSignal;
|
||||
}): Promise<OAuthCredentials> {
|
||||
const input = await options.onPrompt({
|
||||
message: "GitHub Enterprise URL/domain (blank for github.com)",
|
||||
placeholder: "company.ghe.com",
|
||||
allowEmpty: true,
|
||||
});
|
||||
|
||||
if (options.signal?.aborted) {
|
||||
throw new Error("Login cancelled");
|
||||
}
|
||||
|
||||
const trimmed = input.trim();
|
||||
const enterpriseDomain = normalizeDomain(input);
|
||||
if (trimmed && !enterpriseDomain) {
|
||||
throw new Error("Invalid GitHub Enterprise URL/domain");
|
||||
}
|
||||
const domain = enterpriseDomain || "github.com";
|
||||
|
||||
const device = await startDeviceFlow(domain);
|
||||
options.onAuth(device.verification_uri, `Enter code: ${device.user_code}`);
|
||||
|
||||
const githubAccessToken = await pollForGitHubAccessToken(
|
||||
domain,
|
||||
device.device_code,
|
||||
device.interval,
|
||||
device.expires_in,
|
||||
options.signal,
|
||||
);
|
||||
const credentials = await refreshGitHubCopilotToken(
|
||||
githubAccessToken,
|
||||
enterpriseDomain ?? undefined,
|
||||
);
|
||||
|
||||
// Enable all models after successful login
|
||||
options.onProgress?.("Enabling models...");
|
||||
await enableAllGitHubCopilotModels(
|
||||
credentials.access,
|
||||
enterpriseDomain ?? undefined,
|
||||
);
|
||||
|
||||
// Fetch real model limits from the Copilot API
|
||||
options.onProgress?.("Fetching model limits...");
|
||||
const modelLimits = await fetchCopilotModelLimits(
|
||||
credentials.access,
|
||||
enterpriseDomain ?? undefined,
|
||||
);
|
||||
if (Object.keys(modelLimits).length > 0) {
|
||||
(credentials as CopilotCredentials).modelLimits = modelLimits;
|
||||
}
|
||||
|
||||
return credentials;
|
||||
}
|
||||
|
||||
export const githubCopilotOAuthProvider: OAuthProviderInterface = {
|
||||
id: "github-copilot",
|
||||
name: "GitHub Copilot",
|
||||
|
||||
async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
|
||||
return loginGitHubCopilot({
|
||||
onAuth: (url, instructions) => callbacks.onAuth({ url, instructions }),
|
||||
onPrompt: callbacks.onPrompt,
|
||||
onProgress: callbacks.onProgress,
|
||||
signal: callbacks.signal,
|
||||
});
|
||||
},
|
||||
|
||||
async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
|
||||
const creds = credentials as CopilotCredentials;
|
||||
const refreshed = await refreshGitHubCopilotToken(
|
||||
creds.refresh,
|
||||
creds.enterpriseUrl,
|
||||
);
|
||||
try {
|
||||
const modelLimits = await fetchCopilotModelLimits(
|
||||
refreshed.access,
|
||||
creds.enterpriseUrl,
|
||||
);
|
||||
if (Object.keys(modelLimits).length > 0) {
|
||||
(refreshed as CopilotCredentials).modelLimits = modelLimits;
|
||||
}
|
||||
} catch {
|
||||
// Model limits fetch is best-effort; don't block token refresh
|
||||
}
|
||||
return refreshed;
|
||||
},
|
||||
|
||||
getApiKey(credentials: OAuthCredentials): string {
|
||||
return credentials.access;
|
||||
},
|
||||
|
||||
modifyModels(
|
||||
models: Model<Api>[],
|
||||
credentials: OAuthCredentials,
|
||||
): Model<Api>[] {
|
||||
const creds = credentials as CopilotCredentials;
|
||||
const domain = creds.enterpriseUrl
|
||||
? (normalizeDomain(creds.enterpriseUrl) ?? undefined)
|
||||
: undefined;
|
||||
const baseUrl = getGitHubCopilotBaseUrl(creds.access, domain);
|
||||
const limits = creds.modelLimits;
|
||||
const availableModelIds = limits ? new Set(Object.keys(limits)) : null;
|
||||
const shouldFilterByAvailability =
|
||||
!!availableModelIds && availableModelIds.size > 0;
|
||||
return models.flatMap((m) => {
|
||||
if (m.provider !== "github-copilot") return m;
|
||||
if (shouldFilterByAvailability && !availableModelIds.has(m.id)) return [];
|
||||
const modelLimits = limits?.[m.id];
|
||||
return {
|
||||
...m,
|
||||
baseUrl,
|
||||
...(modelLimits && {
|
||||
contextWindow: modelLimits.contextWindow,
|
||||
maxTokens: modelLimits.maxTokens,
|
||||
}),
|
||||
};
|
||||
});
|
||||
},
|
||||
};
|
||||
|
|
@ -1,152 +0,0 @@
|
|||
/**
|
||||
* OAuth credential management for AI providers.
|
||||
*
|
||||
* This module handles login, token refresh, and credential storage
|
||||
* for OAuth-based providers:
|
||||
* - GitHub Copilot
|
||||
*
|
||||
* Note: Anthropic OAuth was removed per TOS compliance (see docs/user-docs/claude-code-auth-compliance.md).
|
||||
* Use API keys or the local Claude Code CLI for Anthropic access.
|
||||
*
|
||||
* Note: Google Cloud Code Assist (google-gemini-cli) is not handled here.
|
||||
* The provider delegates to @google/gemini-cli-core, which reads
|
||||
* ~/.gemini/oauth_creds.json when present and owns any login flow it needs.
|
||||
* SF uses cli-core directly and does not spawn a separate provider CLI process.
|
||||
*
|
||||
* Note: OpenAI Codex (ChatGPT) is not handled here via OAuth flows.
|
||||
* The real `codex` CLI writes auth state to ~/.codex/auth.json after login.
|
||||
* We read that file directly — no PKCE, no callback server in our code.
|
||||
* Users authenticate with: codex auth login
|
||||
*/
|
||||
|
||||
// GitHub Copilot
|
||||
export {
|
||||
getGitHubCopilotBaseUrl,
|
||||
githubCopilotOAuthProvider,
|
||||
loginGitHubCopilot,
|
||||
normalizeDomain,
|
||||
refreshGitHubCopilotToken,
|
||||
} from "./github-copilot.js";
|
||||
// OpenAI Codex — shim provider (login defers to real `codex` CLI)
|
||||
export { openaiCodexOAuthProvider } from "./openai-codex.js";
|
||||
|
||||
export * from "./types.js";
|
||||
|
||||
// ============================================================================
|
||||
// Provider Registry
|
||||
// ============================================================================
|
||||
|
||||
import { githubCopilotOAuthProvider } from "./github-copilot.js";
|
||||
import { openaiCodexOAuthProvider } from "./openai-codex.js";
|
||||
import type {
|
||||
OAuthCredentials,
|
||||
OAuthProviderId,
|
||||
OAuthProviderInterface,
|
||||
} from "./types.js";
|
||||
|
||||
const BUILT_IN_OAUTH_PROVIDERS: OAuthProviderInterface[] = [
|
||||
githubCopilotOAuthProvider,
|
||||
openaiCodexOAuthProvider,
|
||||
];
|
||||
|
||||
const oauthProviderRegistry = new Map<string, OAuthProviderInterface>(
|
||||
BUILT_IN_OAUTH_PROVIDERS.map((provider) => [provider.id, provider]),
|
||||
);
|
||||
|
||||
/**
|
||||
* Get an OAuth provider by ID.
|
||||
*
|
||||
* Returns the provider if registered (built-in or custom), otherwise undefined.
|
||||
*/
|
||||
export function getOAuthProvider(
|
||||
id: OAuthProviderId,
|
||||
): OAuthProviderInterface | undefined {
|
||||
return oauthProviderRegistry.get(id);
|
||||
}
|
||||
|
||||
/**
|
||||
* Register a custom OAuth provider.
|
||||
*
|
||||
* Custom providers override built-ins with the same ID during the session.
|
||||
* Use `resetOAuthProviders` to restore built-ins.
|
||||
*/
|
||||
export function registerOAuthProvider(provider: OAuthProviderInterface): void {
|
||||
oauthProviderRegistry.set(provider.id, provider);
|
||||
}
|
||||
|
||||
/**
|
||||
* Unregister an OAuth provider.
|
||||
*
|
||||
* If the provider is built-in, restores the built-in implementation.
|
||||
* Custom providers are removed completely.
|
||||
*/
|
||||
export function unregisterOAuthProvider(id: string): void {
|
||||
const builtInProvider = BUILT_IN_OAUTH_PROVIDERS.find(
|
||||
(provider) => provider.id === id,
|
||||
);
|
||||
if (builtInProvider) {
|
||||
oauthProviderRegistry.set(id, builtInProvider);
|
||||
return;
|
||||
}
|
||||
oauthProviderRegistry.delete(id);
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset OAuth providers to built-ins.
|
||||
*
|
||||
* Clears custom providers and restores only GitHub Copilot and OpenAI Codex.
|
||||
*/
|
||||
export function resetOAuthProviders(): void {
|
||||
oauthProviderRegistry.clear();
|
||||
for (const provider of BUILT_IN_OAUTH_PROVIDERS) {
|
||||
oauthProviderRegistry.set(provider.id, provider);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all registered OAuth providers.
|
||||
*
|
||||
* Returns both built-in and custom providers currently in the registry.
|
||||
*/
|
||||
export function getOAuthProviders(): OAuthProviderInterface[] {
|
||||
return Array.from(oauthProviderRegistry.values());
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// High-level API (uses provider registry)
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Get API key for a provider from OAuth credentials, refreshing if expired.
|
||||
*
|
||||
* Returns the API key along with updated credentials (if refreshed), or null if no credentials exist.
|
||||
* Throws if the provider is unknown or token refresh fails.
|
||||
*/
|
||||
export async function getOAuthApiKey(
|
||||
providerId: OAuthProviderId,
|
||||
credentials: Record<string, OAuthCredentials>,
|
||||
): Promise<{ newCredentials: OAuthCredentials; apiKey: string } | null> {
|
||||
const provider = getOAuthProvider(providerId);
|
||||
if (!provider) {
|
||||
throw new Error(`Unknown OAuth provider: ${providerId}`);
|
||||
}
|
||||
|
||||
let creds = credentials[providerId];
|
||||
if (!creds) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Refresh if expired
|
||||
if (Date.now() >= creds.expires) {
|
||||
try {
|
||||
creds = await provider.refreshToken(creds);
|
||||
} catch (error) {
|
||||
throw new Error(`Failed to refresh OAuth token for ${providerId}`, {
|
||||
cause: error,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const apiKey = provider.getApiKey(creds);
|
||||
return { newCredentials: creds, apiKey };
|
||||
}
|
||||
|
|
@ -1,238 +0,0 @@
|
|||
/**
|
||||
* OpenAI Codex auth helper — reads ~/.codex/auth.json
|
||||
*
|
||||
* The real `codex` CLI writes its auth state to ~/.codex/auth.json after the
|
||||
* user authenticates. We simply read that file and, if the token is stale,
|
||||
* refresh it against OpenAI's token endpoint.
|
||||
*
|
||||
* No PKCE flow, no callback server, no browser dance in our code.
|
||||
* Users authenticate with the real `codex` CLI; we just consume its output.
|
||||
*
|
||||
* File shape (verified against ~/.codex/auth.json):
|
||||
* {
|
||||
* "auth_mode": "chatgpt" | "apikey", // lowercase
|
||||
* "OPENAI_API_KEY": string | null,
|
||||
* "tokens": {
|
||||
* "id_token": string,
|
||||
* "access_token": string,
|
||||
* "refresh_token": string,
|
||||
* "account_id": string
|
||||
* },
|
||||
* "last_refresh": string // ISO timestamp
|
||||
* }
|
||||
*/
|
||||
|
||||
// NEVER convert to top-level imports - breaks browser/Vite builds (web-ui)
|
||||
let _os: typeof import("node:os") | null = null;
|
||||
let _fs: typeof import("node:fs") | null = null;
|
||||
if (
|
||||
typeof process !== "undefined" &&
|
||||
(process.versions?.node || process.versions?.bun)
|
||||
) {
|
||||
import("node:os").then((m) => {
|
||||
_os = m;
|
||||
});
|
||||
import("node:fs").then((m) => {
|
||||
_fs = m;
|
||||
});
|
||||
}
|
||||
|
||||
import type { OAuthCredentials, OAuthProviderInterface } from "./types.js";
|
||||
|
||||
const TOKEN_URL = "https://auth.openai.com/oauth/token";
|
||||
const CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann";
|
||||
|
||||
// Refresh threshold: 1 hour (conservative; the real codex CLI uses a similar window)
|
||||
const REFRESH_THRESHOLD_MS = 60 * 60 * 1000;
|
||||
|
||||
// ============================================================================
|
||||
// ~/.codex/auth.json types
|
||||
// ============================================================================
|
||||
|
||||
interface CodexAuthFile {
|
||||
auth_mode?: string;
|
||||
OPENAI_API_KEY?: string | null;
|
||||
tokens?: {
|
||||
id_token?: string;
|
||||
access_token?: string;
|
||||
refresh_token?: string;
|
||||
account_id?: string;
|
||||
};
|
||||
last_refresh?: string;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// File reader
|
||||
// ============================================================================
|
||||
|
||||
function getCodexAuthPath(): string {
|
||||
if (!_os) throw new Error("node:os not available");
|
||||
return `${_os.homedir()}/.codex/auth.json`;
|
||||
}
|
||||
|
||||
function readCodexAuthFile(): CodexAuthFile {
|
||||
if (!_fs) {
|
||||
throw new Error(
|
||||
"OpenAI Codex auth is only available in Node.js environments",
|
||||
);
|
||||
}
|
||||
const authPath = getCodexAuthPath();
|
||||
if (!_fs.existsSync(authPath)) {
|
||||
throw new Error(
|
||||
`~/.codex/auth.json not found.\n\n` +
|
||||
`Authenticate with the real \`codex\` CLI first:\n` +
|
||||
` codex auth login\n\n` +
|
||||
`Then re-run your command.`,
|
||||
);
|
||||
}
|
||||
const raw = _fs.readFileSync(authPath, "utf-8");
|
||||
return JSON.parse(raw) as CodexAuthFile;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Token refresh
|
||||
// ============================================================================
|
||||
|
||||
async function refreshCodexToken(
|
||||
refreshToken: string,
|
||||
): Promise<{ access_token: string; refresh_token: string }> {
|
||||
const response = await fetch(TOKEN_URL, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/x-www-form-urlencoded" },
|
||||
body: new URLSearchParams({
|
||||
grant_type: "refresh_token",
|
||||
refresh_token: refreshToken,
|
||||
client_id: CLIENT_ID,
|
||||
}),
|
||||
signal: AbortSignal.timeout(30_000),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const text = await response.text().catch(() => "");
|
||||
throw new Error(
|
||||
`[openai-codex] Token refresh failed: ${response.status} ${text}`,
|
||||
);
|
||||
}
|
||||
|
||||
const json = (await response.json()) as {
|
||||
access_token?: string;
|
||||
refresh_token?: string;
|
||||
};
|
||||
|
||||
if (!json.access_token || !json.refresh_token) {
|
||||
throw new Error(
|
||||
"[openai-codex] Token refresh response missing access_token or refresh_token",
|
||||
);
|
||||
}
|
||||
|
||||
return { access_token: json.access_token, refresh_token: json.refresh_token };
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Public API
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Read ~/.codex/auth.json and return the active access token.
|
||||
*
|
||||
* - For auth_mode "apikey": returns OPENAI_API_KEY directly.
|
||||
* - For auth_mode "chatgpt": returns tokens.access_token, refreshing first
|
||||
* if last_refresh is more than REFRESH_THRESHOLD_MS ago.
|
||||
*
|
||||
* Throws a clear error if the file is missing or malformed.
|
||||
*/
|
||||
export async function getCodexAccessToken(): Promise<string> {
|
||||
const auth = readCodexAuthFile();
|
||||
const mode = (auth.auth_mode ?? "").toLowerCase();
|
||||
|
||||
if (mode === "apikey") {
|
||||
const key = auth.OPENAI_API_KEY;
|
||||
if (!key) {
|
||||
throw new Error(
|
||||
`~/.codex/auth.json has auth_mode "apikey" but OPENAI_API_KEY is empty.\n` +
|
||||
`Re-authenticate with: codex auth login`,
|
||||
);
|
||||
}
|
||||
return key;
|
||||
}
|
||||
|
||||
// Default: ChatGPT OAuth
|
||||
const tokens = auth.tokens;
|
||||
if (!tokens?.access_token || !tokens?.refresh_token) {
|
||||
throw new Error(
|
||||
`~/.codex/auth.json is missing OAuth tokens.\n` +
|
||||
`Re-authenticate with: codex auth login`,
|
||||
);
|
||||
}
|
||||
|
||||
// Refresh if stale
|
||||
const lastRefresh = auth.last_refresh
|
||||
? new Date(auth.last_refresh).getTime()
|
||||
: 0;
|
||||
const isStale =
|
||||
!lastRefresh || Date.now() - lastRefresh > REFRESH_THRESHOLD_MS;
|
||||
|
||||
if (isStale) {
|
||||
try {
|
||||
const refreshed = await refreshCodexToken(tokens.refresh_token);
|
||||
return refreshed.access_token;
|
||||
} catch (err) {
|
||||
// If refresh fails, fall back to the stored access_token (it may still work)
|
||||
console.warn(
|
||||
`[openai-codex] Token refresh failed, using stored token: ${err instanceof Error ? err.message : String(err)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return tokens.access_token;
|
||||
}
|
||||
|
||||
/**
|
||||
* Read account_id from ~/.codex/auth.json (required as a request header).
|
||||
* Falls back to extracting from the access_token JWT payload if not stored.
|
||||
*/
|
||||
export function getCodexAccountId(): string {
|
||||
const auth = readCodexAuthFile();
|
||||
const accountId = auth.tokens?.account_id;
|
||||
if (accountId) return accountId;
|
||||
throw new Error(
|
||||
`~/.codex/auth.json is missing tokens.account_id.\n` +
|
||||
`Re-authenticate with: codex auth login`,
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* OAuthProviderInterface shim — kept so oauth/index.ts registry and
|
||||
* auth-storage OAuth refresh flow continue to compile.
|
||||
*
|
||||
* login() is a no-op: users authenticate with the real `codex` CLI.
|
||||
* getApiKey() reads the access token from ~/.codex/auth.json.
|
||||
* refreshToken() is a no-op: getCodexAccessToken() refreshes inline.
|
||||
*/
|
||||
export const openaiCodexOAuthProvider: OAuthProviderInterface = {
|
||||
id: "openai-codex",
|
||||
name: "ChatGPT Plus/Pro (Codex Subscription)",
|
||||
usesCallbackServer: false,
|
||||
|
||||
async login(): Promise<OAuthCredentials> {
|
||||
throw new Error(
|
||||
`OpenAI Codex login is handled by the real \`codex\` CLI.\n` +
|
||||
`Run: codex auth login\n\n` +
|
||||
`Then use pi normally — it reads ~/.codex/auth.json automatically.`,
|
||||
);
|
||||
},
|
||||
|
||||
async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
|
||||
// Inline refresh via getCodexAccessToken(); return same shape
|
||||
return credentials;
|
||||
},
|
||||
|
||||
getApiKey(_credentials: OAuthCredentials): string {
|
||||
// Synchronous fallback — callers that need the real token should await
|
||||
// getCodexAccessToken() directly. This path is used for legacy callers.
|
||||
const auth = readCodexAuthFile();
|
||||
const mode = (auth.auth_mode ?? "").toLowerCase();
|
||||
if (mode === "apikey") return auth.OPENAI_API_KEY ?? "";
|
||||
return auth.tokens?.access_token ?? "";
|
||||
},
|
||||
};
|
||||
|
|
@ -1,37 +0,0 @@
|
|||
/**
|
||||
* PKCE utilities using Web Crypto API.
|
||||
* Works in both Node.js 20+ and browsers.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Encode bytes as base64url string.
|
||||
*/
|
||||
function base64urlEncode(bytes: Uint8Array): string {
|
||||
let binary = "";
|
||||
for (const byte of bytes) {
|
||||
binary += String.fromCharCode(byte);
|
||||
}
|
||||
return btoa(binary).replace(/\+/g, "-").replace(/\//g, "_").replace(/=/g, "");
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate PKCE code verifier and challenge.
|
||||
* Uses Web Crypto API for cross-platform compatibility.
|
||||
*/
|
||||
export async function generatePKCE(): Promise<{
|
||||
verifier: string;
|
||||
challenge: string;
|
||||
}> {
|
||||
// Generate random verifier
|
||||
const verifierBytes = new Uint8Array(32);
|
||||
crypto.getRandomValues(verifierBytes);
|
||||
const verifier = base64urlEncode(verifierBytes);
|
||||
|
||||
// Compute SHA-256 challenge
|
||||
const encoder = new TextEncoder();
|
||||
const data = encoder.encode(verifier);
|
||||
const hashBuffer = await crypto.subtle.digest("SHA-256", data);
|
||||
const challenge = base64urlEncode(new Uint8Array(hashBuffer));
|
||||
|
||||
return { verifier, challenge };
|
||||
}
|
||||
|
|
@ -1,52 +0,0 @@
|
|||
import type { Api, Model } from "../../types.js";
|
||||
|
||||
export type OAuthCredentials = {
|
||||
refresh: string;
|
||||
access: string;
|
||||
expires: number;
|
||||
[key: string]: unknown;
|
||||
};
|
||||
|
||||
export type OAuthProviderId = string;
|
||||
|
||||
export type OAuthPrompt = {
|
||||
message: string;
|
||||
placeholder?: string;
|
||||
allowEmpty?: boolean;
|
||||
};
|
||||
|
||||
export type OAuthAuthInfo = {
|
||||
url: string;
|
||||
instructions?: string;
|
||||
};
|
||||
|
||||
export interface OAuthLoginCallbacks {
|
||||
onAuth: (info: OAuthAuthInfo) => void;
|
||||
onPrompt: (prompt: OAuthPrompt) => Promise<string>;
|
||||
onProgress?: (message: string) => void;
|
||||
onManualCodeInput?: () => Promise<string>;
|
||||
signal?: AbortSignal;
|
||||
}
|
||||
|
||||
export interface OAuthProviderInterface {
|
||||
readonly id: OAuthProviderId;
|
||||
readonly name: string;
|
||||
|
||||
/** Run the login flow, return credentials to persist */
|
||||
login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials>;
|
||||
|
||||
/** Whether login uses a local callback server and supports manual code input. */
|
||||
usesCallbackServer?: boolean;
|
||||
|
||||
/** Refresh expired credentials, return updated credentials to persist */
|
||||
refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials>;
|
||||
|
||||
/** Convert credentials to API key string for the provider */
|
||||
getApiKey(credentials: OAuthCredentials): string;
|
||||
|
||||
/** Optional: modify models for this provider (e.g., update baseUrl) */
|
||||
modifyModels?(
|
||||
models: Model<Api>[],
|
||||
credentials: OAuthCredentials,
|
||||
): Model<Api>[];
|
||||
}
|
||||
|
|
@ -1,134 +0,0 @@
|
|||
import type { AssistantMessage } from "../types.js";
|
||||
|
||||
/**
|
||||
* Regex patterns to detect context overflow errors from different providers.
|
||||
*
|
||||
* These patterns match error messages returned when the input exceeds
|
||||
* the model's context window.
|
||||
*
|
||||
* Provider-specific patterns (with example error messages):
|
||||
*
|
||||
* - Anthropic: "prompt is too long: 213462 tokens > 200000 maximum"
|
||||
* - OpenAI: "Your input exceeds the context window of this model"
|
||||
* - Google: "The input token count (1196265) exceeds the maximum number of tokens allowed (1048575)"
|
||||
* - xAI: "This model's maximum prompt length is 131072 but the request contains 537812 tokens"
|
||||
* - Groq: "Please reduce the length of the messages or completion"
|
||||
* - OpenRouter: "This endpoint's maximum context length is X tokens. However, you requested about Y tokens"
|
||||
* - llama.cpp: "the request exceeds the available context size, try increasing it"
|
||||
* - LM Studio: "tokens to keep from the initial prompt is greater than the context length"
|
||||
* - GitHub Copilot: "prompt token count of X exceeds the limit of Y"
|
||||
* - MiniMax: "invalid params, context window exceeds limit"
|
||||
* - Kimi For Coding: "Your request exceeded model token limit: X (requested: Y)"
|
||||
* - Cerebras: Returns "400/413 status code (no body)" - handled separately below
|
||||
* - Mistral: "Prompt contains X tokens ... too large for model with Y maximum context length"
|
||||
* - z.ai: Does NOT error, accepts overflow silently - handled via usage.input > contextWindow
|
||||
* - Ollama: Silently truncates input - not detectable via error message
|
||||
*/
|
||||
const OVERFLOW_PATTERNS = [
|
||||
/prompt is too long/i, // Anthropic
|
||||
/input is too long for requested model/i, // Amazon Bedrock
|
||||
/exceeds the context window/i, // OpenAI (Completions & Responses API)
|
||||
/input token count.*exceeds the maximum/i, // Google (Gemini)
|
||||
/maximum prompt length is \d+/i, // xAI (Grok)
|
||||
/reduce the length of the messages/i, // Groq
|
||||
/maximum context length is \d+ tokens/i, // OpenRouter (all backends)
|
||||
/exceeds the limit of \d+/i, // GitHub Copilot
|
||||
/exceeds the available context size/i, // llama.cpp server
|
||||
/greater than the context length/i, // LM Studio
|
||||
/context window exceeds limit/i, // MiniMax
|
||||
/exceeded model token limit/i, // Kimi For Coding
|
||||
/too large for model with \d+ maximum context length/i, // Mistral
|
||||
/model_context_window_exceeded/i, // z.ai non-standard finish_reason surfaced as error text
|
||||
/context[_ ]length[_ ]exceeded/i, // Generic fallback
|
||||
/too many tokens/i, // Generic fallback
|
||||
/token limit exceeded/i, // Generic fallback
|
||||
];
|
||||
|
||||
/**
|
||||
* Check if an assistant message represents a context overflow error.
|
||||
*
|
||||
* This handles two cases:
|
||||
* 1. Error-based overflow: Most providers return stopReason "error" with a
|
||||
* specific error message pattern.
|
||||
* 2. Silent overflow: Some providers accept overflow requests and return
|
||||
* successfully. For these, we check if usage.input exceeds the context window.
|
||||
*
|
||||
* ## Reliability by Provider
|
||||
*
|
||||
* **Reliable detection (returns error with detectable message):**
|
||||
* - Anthropic: "prompt is too long: X tokens > Y maximum"
|
||||
* - OpenAI (Completions & Responses): "exceeds the context window"
|
||||
* - Google Gemini: "input token count exceeds the maximum"
|
||||
* - xAI (Grok): "maximum prompt length is X but request contains Y"
|
||||
* - Groq: "reduce the length of the messages"
|
||||
* - Cerebras: 400/413 status code (no body)
|
||||
* - Mistral: "Prompt contains X tokens ... too large for model with Y maximum context length"
|
||||
* - OpenRouter (all backends): "maximum context length is X tokens"
|
||||
* - llama.cpp: "exceeds the available context size"
|
||||
* - LM Studio: "greater than the context length"
|
||||
* - Kimi For Coding: "exceeded model token limit: X (requested: Y)"
|
||||
*
|
||||
* **Unreliable detection:**
|
||||
* - z.ai: Sometimes accepts overflow silently (detectable via usage.input > contextWindow),
|
||||
* sometimes returns rate limit errors. Pass contextWindow param to detect silent overflow.
|
||||
* - Ollama: Silently truncates input without error. Cannot be detected via this function.
|
||||
* The response will have usage.input < expected, but we don't know the expected value.
|
||||
*
|
||||
* ## Custom Providers
|
||||
*
|
||||
* If you've added custom models via settings.json, this function may not detect
|
||||
* overflow errors from those providers. To add support:
|
||||
*
|
||||
* 1. Send a request that exceeds the model's context window
|
||||
* 2. Check the errorMessage in the response
|
||||
* 3. Create a regex pattern that matches the error
|
||||
* 4. The pattern should be added to OVERFLOW_PATTERNS in this file, or
|
||||
* check the errorMessage yourself before calling this function
|
||||
*
|
||||
* @param message - The assistant message to check
|
||||
* @param contextWindow - Optional context window size for detecting silent overflow (z.ai)
|
||||
* @returns true if the message indicates a context overflow
|
||||
*/
|
||||
export function isContextOverflow(
|
||||
message: AssistantMessage,
|
||||
contextWindow?: number,
|
||||
): boolean {
|
||||
// Case 1: Check error message patterns
|
||||
if (message.stopReason === "error" && message.errorMessage) {
|
||||
// Check known patterns
|
||||
if (OVERFLOW_PATTERNS.some((p) => p.test(message.errorMessage!))) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Cerebras returns 400/413 with no body for context overflow
|
||||
// Note: 429 is rate limiting (requests/tokens per time), NOT context overflow
|
||||
if (
|
||||
/^4(00|13)\s*(status code)?\s*\(no body\)/i.test(message.errorMessage)
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Some providers surface overflow as assistant text while putting a generic
|
||||
// classifier value in errorMessage (e.g. claude-code: errorMessage="success",
|
||||
// text="Prompt is too long"). Check rendered text as a fallback.
|
||||
if (message.stopReason === "error") {
|
||||
const assistantText = message.content
|
||||
.filter((block) => block.type === "text")
|
||||
.map((block) => block.text)
|
||||
.join("\n");
|
||||
if (assistantText && OVERFLOW_PATTERNS.some((p) => p.test(assistantText))) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Case 2: Silent overflow (z.ai style) - successful but usage exceeds context
|
||||
if (contextWindow && message.stopReason === "stop") {
|
||||
const inputTokens = message.usage.input + message.usage.cacheRead;
|
||||
if (inputTokens > contextWindow) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
|
@ -1,340 +0,0 @@
|
|||
/**
|
||||
* Repair malformed JSON in LLM tool-call arguments.
|
||||
*
|
||||
* LLMs sometimes copy YAML template formatting into JSON tool arguments,
|
||||
* producing patterns like:
|
||||
*
|
||||
* "keyDecisions": - Used Web Notification API...,
|
||||
* "keyFiles": - src-tauri/src/lib.rs — Extended...
|
||||
*
|
||||
* instead of valid JSON arrays:
|
||||
*
|
||||
* "keyDecisions": ["Used Web Notification API..."],
|
||||
* "keyFiles": ["src-tauri/src/lib.rs — Extended..."]
|
||||
*
|
||||
* This module detects and repairs such patterns before JSON.parse is called.
|
||||
*
|
||||
* @see https://github.com/singularity-forge/sf-run/issues/2660
|
||||
*/
|
||||
|
||||
import { jsonrepair } from "jsonrepair";
|
||||
import { parse as parseYaml } from "yaml";
|
||||
|
||||
export const TOOL_JSON_REPAIR_PIPELINE_VERSION = 1;
|
||||
|
||||
export interface ToolJsonRepairReport {
|
||||
version: number;
|
||||
input: string;
|
||||
output: string;
|
||||
changed: boolean;
|
||||
repairs: string[];
|
||||
parseable: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Detect whether a JSON string contains YAML-style bullet-list values
|
||||
* (i.e. `"key": - item` instead of `"key": ["item"]`).
|
||||
*/
|
||||
export function hasYamlBulletLists(json: string): boolean {
|
||||
// Match: "key": followed by whitespace then a dash-space pattern (YAML bullet)
|
||||
// The negative lookahead excludes negative numbers (e.g. "key": -1)
|
||||
return /"\s*:\s*-\s+(?!\d)/.test(json);
|
||||
}
|
||||
|
||||
/**
|
||||
* Detect whether a JSON string contains XML parameter tags
|
||||
* (i.e. `<parameter name="X">value</parameter>`).
|
||||
*
|
||||
* Some models mix XML tool-call syntax into JSON string values,
|
||||
* producing hybrid output that fails JSON.parse.
|
||||
*
|
||||
* @see https://github.com/singularity-forge/sf-run/issues/3403
|
||||
*/
|
||||
export function hasXmlParameterTags(json: string): boolean {
|
||||
return /<\/?parameter[\s>]/.test(json);
|
||||
}
|
||||
|
||||
/**
|
||||
* Detect whether a JSON string contains truncated numeric values
|
||||
* (e.g. `"exitCode": -,` or `"durationMs": ,`).
|
||||
*
|
||||
* Smaller models sometimes emit incomplete numbers when the value
|
||||
* is cut off mid-generation.
|
||||
*
|
||||
* @see https://github.com/singularity-forge/sf-run/issues/3464
|
||||
*/
|
||||
export function hasTruncatedNumbers(json: string): boolean {
|
||||
// Match: colon, optional whitespace, then a comma or } without a value
|
||||
// Or: colon, optional whitespace, bare minus sign followed by comma/}
|
||||
return /:\s*,/.test(json) || /:\s*-\s*[,}]/.test(json);
|
||||
}
|
||||
|
||||
// One extracted `<parameter name="...">...</parameter>` block: the declared
// parameter name plus its parsed (or raw-string) value.
type XmlParameterBlock = {
  name: string;
  value: unknown;
};

// Matches a complete parameter block. The body capture is lazy ([\s\S]*?)
// so adjacent blocks are captured separately rather than greedily merged.
const xmlParameterBlockPattern =
  /<parameter\s+name="([^"]+)"\s*>([\s\S]*?)<\/parameter>/g;
|
||||
|
||||
function parseXmlParameterValue(raw: string): unknown {
|
||||
const trimmed = raw.trim();
|
||||
if (trimmed === "") return "";
|
||||
try {
|
||||
return JSON.parse(trimmed);
|
||||
} catch {
|
||||
return trimmed;
|
||||
}
|
||||
}
|
||||
|
||||
function extractXmlParameterBlocks(text: string): XmlParameterBlock[] {
|
||||
const blocks: XmlParameterBlock[] = [];
|
||||
for (const match of text.matchAll(xmlParameterBlockPattern)) {
|
||||
blocks.push({
|
||||
name: match[1],
|
||||
value: parseXmlParameterValue(match[2] ?? ""),
|
||||
});
|
||||
}
|
||||
return blocks;
|
||||
}
|
||||
|
||||
function trimLeakedXmlTail(fieldName: string, value: string): string {
|
||||
let cut = value.length;
|
||||
const parameterIndex = value.indexOf("<parameter");
|
||||
if (parameterIndex >= 0) cut = Math.min(cut, parameterIndex);
|
||||
|
||||
const closingTagIndex = value.indexOf(`</${fieldName}>`);
|
||||
if (closingTagIndex >= 0) cut = Math.min(cut, closingTagIndex);
|
||||
|
||||
return value.slice(0, cut).trimEnd();
|
||||
}
|
||||
|
||||
/**
|
||||
* Strip XML `<parameter>` tags from a JSON string, leaving only the
|
||||
* text content. This handles the case where the LLM mixes XML
|
||||
* tool-call format into JSON string values.
|
||||
*/
|
||||
function stripXmlParameterTags(json: string): string {
|
||||
// Remove opening tags: <parameter name="X">
|
||||
let cleaned = json.replace(/<parameter\s+name="[^"]*"\s*>/g, "");
|
||||
// Remove closing tags: </parameter>
|
||||
cleaned = cleaned.replace(/<\/parameter>/g, "");
|
||||
return cleaned;
|
||||
}
|
||||
|
||||
// Parse the (possibly valid) JSON, find string fields whose values contain
// leaked `<parameter>` blocks, hoist each block to a top-level field, and
// trim the leaked XML tail off the host field. Falls back to plain tag
// stripping whenever the input is not a parseable JSON object.
function promoteXmlParametersToTopLevel(json: string): string {
  try {
    const parsed = JSON.parse(json) as Record<string, unknown>;
    // Only plain objects can receive promoted fields; arrays/scalars get
    // the cheaper tag-stripping treatment instead.
    if (!parsed || typeof parsed !== "object" || Array.isArray(parsed)) {
      return stripXmlParameterTags(json);
    }

    let changed = false;
    for (const [fieldName, value] of Object.entries(parsed)) {
      if (typeof value !== "string" || !hasXmlParameterTags(value)) continue;

      const blocks = extractXmlParameterBlocks(value);
      if (blocks.length === 0) continue;

      // Cut the leaked XML off the host string, then add each extracted
      // parameter as its own top-level field — never clobbering a field
      // that already exists.
      parsed[fieldName] = trimLeakedXmlTail(fieldName, value);
      for (const block of blocks) {
        if (!(block.name in parsed)) {
          parsed[block.name] = block.value;
        }
      }
      changed = true;
    }

    return changed ? JSON.stringify(parsed) : stripXmlParameterTags(json);
  } catch {
    // Input was not valid JSON — strip tags on the raw string instead.
    return stripXmlParameterTags(json);
  }
}
|
||||
|
||||
/**
|
||||
* Replace truncated numeric values with 0.
|
||||
* Handles: `"key": ,` → `"key": 0,` and `"key": -,` → `"key": 0,`
|
||||
*/
|
||||
function repairTruncatedNumbers(json: string): string {
|
||||
// Bare comma after colon (missing value entirely)
|
||||
let repaired = json.replace(/:\s*,/g, ": 0,");
|
||||
// Bare minus sign followed by comma or closing brace
|
||||
repaired = repaired.replace(/:\s*-\s*([,}])/g, ": 0$1");
|
||||
return repaired;
|
||||
}
|
||||
|
||||
function isParseableJson(json: string): boolean {
|
||||
try {
|
||||
JSON.parse(json);
|
||||
return true;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Run the third-party jsonrepair pass; keep its output only when the result
// actually parses, otherwise return the input untouched.
function repairWithJsonRepair(json: string): string {
  try {
    const repaired = jsonrepair(json);
    return isParseableJson(repaired) ? repaired : json;
  } catch {
    // jsonrepair throws on input it cannot handle — treat as "no repair".
    return json;
  }
}
|
||||
|
||||
// Try interpreting the input as YAML and re-serializing it to JSON. Only
// object-shaped results are accepted: scalars, null, and Date values would
// fabricate meaning out of arbitrary text. Returns the input unchanged when
// YAML parsing fails or the result is rejected.
function repairWithYaml(json: string): string {
  try {
    const parsed = parseYaml(json);
    if (
      parsed === null ||
      typeof parsed !== "object" ||
      parsed instanceof Date
    ) {
      return json;
    }
    const repaired = JSON.stringify(parsed);
    return isParseableJson(repaired) ? repaired : json;
  } catch {
    return json;
  }
}
|
||||
|
||||
// Escalating generic repair: YAML first when the input looks like a YAML
// object, then jsonrepair, then YAML as a last resort. Returns the input
// unchanged (with no repair labels) when it already parses or nothing helps.
function applyGenericRepairs(json: string): {
  output: string;
  repairs: string[];
} {
  if (isParseableJson(json)) return { output: json, repairs: [] };

  // YAML-shaped input gets the YAML pass first so jsonrepair's aggressive
  // heuristics don't mangle it.
  if (looksLikeYamlObject(json)) {
    const yamlRepaired = repairWithYaml(json);
    if (yamlRepaired !== json) {
      return { output: yamlRepaired, repairs: ["yaml"] };
    }
  }

  const jsonRepaired = repairWithJsonRepair(json);
  if (jsonRepaired !== json) {
    return { output: jsonRepaired, repairs: ["jsonrepair"] };
  }

  // NOTE(review): when looksLikeYamlObject was true above, this repeats the
  // same (already failed) YAML attempt — harmless but redundant; confirm
  // intent before simplifying.
  const yamlRepaired = repairWithYaml(json);
  if (yamlRepaired !== json) {
    return { output: yamlRepaired, repairs: ["yaml"] };
  }

  return { output: json, repairs: [] };
}
|
||||
|
||||
function looksLikeYamlObject(input: string): boolean {
|
||||
const trimmed = input.trim();
|
||||
return (
|
||||
/^[A-Za-z_][A-Za-z0-9_-]*\s*:/m.test(trimmed) && trimmed.includes("\n")
|
||||
);
|
||||
}
|
||||
|
||||
/**
 * Attempt to repair malformed JSON in LLM tool-call arguments.
 *
 * Handles three categories of malformation:
 *
 * 1. **YAML bullet lists** (#2660): `"key": - item1\n - item2` → `"key": ["item1", "item2"]`
 * 2. **XML parameter tags** (#3403): `<parameter name="X">value</parameter>` → stripped to content
 * 3. **Truncated numbers** (#3464): `"exitCode": -,` → `"exitCode": 0,`
 *
 * Returns the original string unchanged if no patterns are detected
 * or if the repair itself would produce invalid JSON.
 *
 * The returned report carries provenance: which phases fired, whether the
 * text changed, and whether the final output is parseable.
 */
export function repairToolJsonWithReport(json: string): ToolJsonRepairReport {
  let repaired = json;
  const repairs: string[] = [];

  // Phase 1: Strip XML parameter tags
  if (hasXmlParameterTags(repaired)) {
    repaired = promoteXmlParametersToTopLevel(repaired);
    repairs.push("xml-parameter-tags");
  }

  // Phase 2: Repair truncated numbers
  if (hasTruncatedNumbers(repaired)) {
    repaired = repairTruncatedNumbers(repaired);
    repairs.push("truncated-numbers");
  }

  // Phase 3: Repair YAML bullet lists
  // No bullets detected: finish with the generic repair pass and report.
  if (!hasYamlBulletLists(repaired)) {
    const generic = applyGenericRepairs(repaired);
    repairs.push(...generic.repairs);
    const output = generic.output;
    return {
      version: TOOL_JSON_REPAIR_PIPELINE_VERSION,
      input: json,
      output,
      changed: output !== json,
      repairs,
      parseable: isParseableJson(output),
    };
  }
  repairs.push("yaml-bullet-lists");

  // Strategy: find each `"key": - item1\n - item2\n - item3` region and
  // wrap items in a JSON array.
  //
  // We work on the raw string because the JSON is not parseable yet.
  // The pattern we target:
  //   "someKey":\s*- item text (possibly multiline)
  //   optionally followed by more `- item` lines
  //   terminated by the next `"key":` or `}` or end of string.

  // Match a key followed by YAML-style bullet list.
  // Capture: (1) the key portion including colon, (2) the bullet-list body,
  // (3) the separator (comma or empty) before the next key/bracket.
  // The bullet list body ends at the next `"key":` or `}` or `]` or end of string.
  const keyBulletPattern =
    /("(?:[^"\\]|\\.)*"\s*:\s*)(- .+?)(,?\s*)(?="(?:[^"\\]|\\.)*"\s*:|[}\]]|$)/gs;

  repaired = repaired.replace(
    keyBulletPattern,
    (_match, keyPart: string, bulletBody: string, separator: string) => {
      // Split the bullet body into individual items on `- ` boundaries.
      // Items may contain embedded newlines for multi-line values.
      const items = bulletBody
        .split(/\n?\s*- /)
        .filter((s) => s.trim().length > 0)
        .map((s) => s.replace(/,\s*$/, "").trim());

      // JSON-encode each item as a string, then wrap in an array.
      const jsonArray =
        "[" + items.map((item) => JSON.stringify(item)).join(", ") + "]";

      // Re-emit the separator (comma) so the next key is properly delimited
      // NOTE(review): when `separator` is whitespace-only, `separator + "x"`
      // can never begin with a quote, so the ", " branch below appears
      // unreachable — confirm intent before relying on it.
      const sep = separator.trim()
        ? separator
        : /^\s*"/.test(separator + "x")
          ? ", "
          : "";
      return keyPart + jsonArray + sep;
    },
  );

  // Strip trailing commas before } or ] (common in repaired JSON)
  repaired = repaired.replace(/,(\s*[}\]])/g, "$1");

  // Final phase: general-purpose repair for common JSON-ish model output:
  // unquoted keys, single quotes, trailing commas, missing quotes, etc.
  // This runs after SF-specific repairs so battle-tested generic repair
  // handles broad syntax cleanup without weakening known field semantics.
  const generic = applyGenericRepairs(repaired);
  repairs.push(...generic.repairs);
  const output = generic.output;
  return {
    version: TOOL_JSON_REPAIR_PIPELINE_VERSION,
    input: json,
    output,
    changed: output !== json,
    repairs,
    parseable: isParseableJson(output),
  };
}
|
||||
|
||||
/**
 * Convenience wrapper: run the full repair pipeline and return only the
 * repaired text, discarding the provenance report.
 */
export function repairToolJson(json: string): string {
  return repairToolJsonWithReport(json).output;
}
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
/**
|
||||
* Removes unpaired Unicode surrogate characters from a string.
|
||||
*
|
||||
* Unpaired surrogates (high surrogates 0xD800-0xDBFF without matching low surrogates 0xDC00-0xDFFF,
|
||||
* or vice versa) cause JSON serialization errors in many API providers.
|
||||
*
|
||||
* Valid emoji and other characters outside the Basic Multilingual Plane use properly paired
|
||||
* surrogates and will NOT be affected by this function.
|
||||
*
|
||||
* @param text - The text to sanitize
|
||||
* @returns The sanitized text with unpaired surrogates removed
|
||||
*
|
||||
* @example
|
||||
* // Valid emoji (properly paired surrogates) are preserved
|
||||
* sanitizeSurrogates("Hello 🙈 World") // => "Hello 🙈 World"
|
||||
*
|
||||
* // Unpaired high surrogate is removed
|
||||
* const unpaired = String.fromCharCode(0xD83D); // high surrogate without low
|
||||
* sanitizeSurrogates(`Text ${unpaired} here`) // => "Text here"
|
||||
*/
|
||||
export function sanitizeSurrogates(text: string): string {
|
||||
// Replace unpaired high surrogates (0xD800-0xDBFF not followed by low surrogate)
|
||||
// Replace unpaired low surrogates (0xDC00-0xDFFF not preceded by high surrogate)
|
||||
return text.replace(
|
||||
/[\uD800-\uDBFF](?![\uDC00-\uDFFF])|(?<![\uD800-\uDBFF])[\uDC00-\uDFFF]/g,
|
||||
"",
|
||||
);
|
||||
}
|
||||
|
|
@ -1,43 +0,0 @@
|
|||
// Tests for parseStreamingJson's malformed-argument recovery: XML parameter
// promotion (#3751) plus generic JSON-ish / YAML-shaped repairs, and the
// guarantee that incomplete streaming chunks are not "repaired" into
// fabricated values.
import assert from "node:assert/strict";
import { describe, test } from "vitest";
import { parseStreamingJson } from "../json-parse.js";

describe("parseStreamingJson — XML parameter recovery (#3751)", () => {
  test("promotes XML parameters trapped inside valid JSON string values", () => {
    const malformed =
      '{"narrative":"text.</narrative>\\n<parameter name=\\"verification\\">all tests pass</parameter>\\n<parameter name=\\"verificationEvidence\\">[\\"npm test\\"]</parameter>","oneLiner":"done"}';

    const parsed = parseStreamingJson<Record<string, unknown>>(malformed);

    assert.equal(parsed.narrative, "text.");
    assert.equal(parsed.verification, "all tests pass");
    assert.deepEqual(parsed.verificationEvidence, ["npm test"]);
    assert.equal(parsed.oneLiner, "done");
  });
});

describe("parseStreamingJson — generic malformed tool argument recovery", () => {
  test("repairs complete JSON-ish objects with unquoted keys", () => {
    const parsed = parseStreamingJson<Record<string, unknown>>(
      "{title: 'Done', verificationPassed: true,}",
    );

    assert.equal(parsed.title, "Done");
    assert.equal(parsed.verificationPassed, true);
  });

  test("repairs full YAML-shaped object arguments", () => {
    const parsed = parseStreamingJson<Record<string, unknown>>(
      "title: Done\nverificationPassed: true\n",
    );

    assert.equal(parsed.title, "Done");
    assert.equal(parsed.verificationPassed, true);
  });

  test("does not repair incomplete streaming chunks into fabricated values", () => {
    const parsed = parseStreamingJson<Record<string, unknown>>('{"title":');

    assert.deepEqual(parsed, {});
  });
});
|
||||
|
|
@ -1,59 +0,0 @@
|
|||
// Tests for isContextOverflow: detection via the provider errorMessage, via
// claude-code's error-in-text shape (#3925), and the negative case where the
// overflow phrase appears in ordinary assistant text.
import assert from "node:assert/strict";
import { describe, test } from "vitest";
import type { AssistantMessage } from "../../types.js";
import { isContextOverflow } from "../overflow.js";

// Builds a minimal AssistantMessage with zeroed usage; overrides let each
// test customize only the fields under test.
function makeAssistantMessage(
  overrides: Partial<AssistantMessage> = {},
): AssistantMessage {
  return {
    role: "assistant",
    content: [],
    api: "anthropic-messages",
    provider: "anthropic",
    model: "claude-sonnet-4-6",
    usage: {
      input: 0,
      output: 0,
      cacheRead: 0,
      cacheWrite: 0,
      totalTokens: 0,
      cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
    },
    stopReason: "error",
    timestamp: Date.now(),
    ...overrides,
  };
}

describe("isContextOverflow", () => {
  test("detects overflow from provider errorMessage", () => {
    const message = makeAssistantMessage({
      errorMessage: "prompt is too long: 213462 tokens > 200000 maximum",
    });

    assert.equal(isContextOverflow(message, 200000), true);
  });

  test("detects claude-code overflow when text contains the error but errorMessage is generic (#3925)", () => {
    const message = makeAssistantMessage({
      provider: "claude-code",
      api: "anthropic-messages",
      model: "claude-sonnet-4-6",
      errorMessage: "success",
      content: [{ type: "text", text: "Prompt is too long" }],
    });

    assert.equal(isContextOverflow(message, 200000), true);
  });

  test("does not treat normal non-error text as overflow", () => {
    const message = makeAssistantMessage({
      stopReason: "stop",
      errorMessage: undefined,
      content: [{ type: "text", text: "Prompt is too long" }],
    });

    assert.equal(isContextOverflow(message, 200000), false);
  });
});
|
||||
|
|
@ -1,291 +0,0 @@
|
|||
// Unit tests for the tool-JSON repair pipeline (repair-tool-json.ts).
// Covers the detection helpers and each repair phase: YAML bullet lists
// (#2660), generic jsonrepair/YAML fallbacks, XML parameter tags (#3403),
// and truncated numbers (#3464).
import assert from "node:assert/strict";
import { describe, test } from "vitest";
import {
  hasTruncatedNumbers,
  hasXmlParameterTags,
  hasYamlBulletLists,
  repairToolJson,
  repairToolJsonWithReport,
  TOOL_JSON_REPAIR_PIPELINE_VERSION,
} from "../repair-tool-json.js";

describe("repairToolJson — YAML bullet list repair (#2660)", () => {
  // ── Detection ──────────────────────────────────────────────────────────

  test("hasYamlBulletLists detects YAML-style bullets", () => {
    assert.equal(
      hasYamlBulletLists('"keyDecisions": - Used Web Notification API'),
      true,
    );
  });

  test("hasYamlBulletLists ignores negative numbers", () => {
    assert.equal(
      hasYamlBulletLists('"offset": -1'),
      false,
      "negative number should not be detected as YAML bullet",
    );
  });

  test("hasYamlBulletLists returns false for valid JSON", () => {
    assert.equal(
      hasYamlBulletLists('{"keyDecisions": ["item1", "item2"]}'),
      false,
    );
  });

  // ── Single bullet item ────────────────────────────────────────────────

  test("repairs single YAML bullet to JSON array", () => {
    const malformed = '{"keyDecisions": - Used Web Notification API}';
    const repaired = repairToolJson(malformed);
    const parsed = JSON.parse(repaired);
    assert.deepEqual(parsed.keyDecisions, ["Used Web Notification API"]);
  });

  // ── Multiple bullet items (newline-separated) ─────────────────────────

  test("repairs multiple YAML bullets separated by newlines", () => {
    const malformed =
      '{"keyDecisions": - Used Web Notification API\n - Chose Tauri over Electron\n - Adopted SQLite for storage, "title": "M005"}';
    const repaired = repairToolJson(malformed);
    const parsed = JSON.parse(repaired);
    assert.deepEqual(parsed.keyDecisions, [
      "Used Web Notification API",
      "Chose Tauri over Electron",
      "Adopted SQLite for storage",
    ]);
    assert.equal(parsed.title, "M005");
  });

  // ── Multiple fields with YAML bullets ─────────────────────────────────

  test("repairs multiple fields each with YAML bullet lists", () => {
    const malformed =
      '{"keyDecisions": - decision one\n - decision two, "keyFiles": - src/lib.rs — Extended menu\n - src/main.ts — Entry point, "title": "done"}';
    const repaired = repairToolJson(malformed);
    const parsed = JSON.parse(repaired);
    assert.deepEqual(parsed.keyDecisions, ["decision one", "decision two"]);
    assert.deepEqual(parsed.keyFiles, [
      "src/lib.rs \u2014 Extended menu",
      "src/main.ts \u2014 Entry point",
    ]);
    assert.equal(parsed.title, "done");
  });

  // ── Exact reproduction from issue #2660 ───────────────────────────────

  test("repairs the exact malformed JSON from issue #2660", () => {
    const malformed = `{"milestoneId": "M005", "title": "Native Desktop Polish", "oneLiner": "summary", "narrative": "details", "successCriteriaResults": "all pass", "definitionOfDoneResults": "all done", "requirementOutcomes": "met", "keyDecisions": - Used Web Notification API (new window.Notification()) instead of Tauri sendNotification wrapper, "keyFiles": - src-tauri/src/lib.rs \u2014 Extended menu builder with notification toggle, "lessonsLearned": - Always test notification permissions before sending, "followUps": "none", "deviations": "none", "verificationPassed": true}`;

    const repaired = repairToolJson(malformed);
    const parsed = JSON.parse(repaired);

    assert.equal(parsed.milestoneId, "M005");
    assert.equal(parsed.title, "Native Desktop Polish");
    assert.ok(
      Array.isArray(parsed.keyDecisions),
      "keyDecisions should be an array",
    );
    assert.ok(parsed.keyDecisions[0].includes("Web Notification API"));
    assert.ok(Array.isArray(parsed.keyFiles), "keyFiles should be an array");
    assert.ok(parsed.keyFiles[0].includes("src-tauri/src/lib.rs"));
    assert.ok(
      Array.isArray(parsed.lessonsLearned),
      "lessonsLearned should be an array",
    );
    assert.equal(parsed.verificationPassed, true);
  });

  // ── Passthrough for valid JSON ────────────────────────────────────────

  test("returns valid JSON unchanged", () => {
    const valid = '{"keyDecisions": ["item1", "item2"], "count": -5}';
    const result = repairToolJson(valid);
    assert.equal(result, valid, "valid JSON should be returned unchanged");
  });

  // ── Negative numbers are preserved ────────────────────────────────────

  test("does not mangle negative numbers", () => {
    const valid = '{"offset": -1, "limit": -100}';
    const result = repairToolJson(valid);
    assert.equal(result, valid);
  });
});

// ═══════════════════════════════════════════════════════════════════════════
// General JSON repair via jsonrepair
// ═══════════════════════════════════════════════════════════════════════════

describe("repairToolJson — general JSON repair via jsonrepair", () => {
  test("repairs unquoted keys and trailing commas", () => {
    const malformed = "{title: 'Done', count: 2,}";
    const repaired = repairToolJson(malformed);
    const parsed = JSON.parse(repaired);

    assert.deepEqual(parsed, { title: "Done", count: 2 });
  });

  test("repairs single-quoted strings", () => {
    const malformed = "{'milestoneId':'M001','title':'Plan'}";
    const repaired = repairToolJson(malformed);
    const parsed = JSON.parse(repaired);

    assert.deepEqual(parsed, { milestoneId: "M001", title: "Plan" });
  });

  test("returns a versioned repair report with provenance", () => {
    const report = repairToolJsonWithReport("{title: 'Done', count: 2,}");

    assert.equal(report.version, TOOL_JSON_REPAIR_PIPELINE_VERSION);
    assert.equal(report.changed, true);
    assert.equal(report.parseable, true);
    assert.ok(report.repairs.includes("jsonrepair"));
    assert.deepEqual(JSON.parse(report.output), { title: "Done", count: 2 });
  });
});

describe("repairToolJson — full YAML object fallback", () => {
  test("repairs YAML-shaped tool arguments to JSON", () => {
    const malformed = [
      "title: Done",
      "keyDecisions:",
      " - Keep semantic model aliases",
      " - Prefer strict validation",
      "verificationPassed: true",
    ].join("\n");
    const report = repairToolJsonWithReport(malformed);
    const parsed = JSON.parse(report.output);

    assert.ok(report.repairs.includes("yaml"));
    assert.deepEqual(parsed, {
      title: "Done",
      keyDecisions: ["Keep semantic model aliases", "Prefer strict validation"],
      verificationPassed: true,
    });
  });
});

// ═══════════════════════════════════════════════════════════════════════════
// XML parameter tag repair (#3403)
// ═══════════════════════════════════════════════════════════════════════════

describe("repairToolJson — XML parameter tag stripping (#3403)", () => {
  test("hasXmlParameterTags detects opening tags", () => {
    assert.equal(
      hasXmlParameterTags('<parameter name="narrative">some text</parameter>'),
      true,
    );
  });

  test("hasXmlParameterTags returns false for clean JSON", () => {
    assert.equal(hasXmlParameterTags('{"narrative": "some text"}'), false);
  });

  test("strips XML parameter tags from JSON values", () => {
    const malformed =
      '{"sliceId": "S03", "narrative": <parameter name="narrative">The slice work</parameter>}';
    const repaired = repairToolJson(malformed);
    // After stripping tags, the content should be parseable or at least tag-free
    assert.ok(
      !repaired.includes("<parameter"),
      "should not contain <parameter tags",
    );
    assert.ok(
      !repaired.includes("</parameter>"),
      "should not contain </parameter> tags",
    );
  });

  test("handles mixed XML and JSON content", () => {
    const malformed =
      '{"oneLiner": "done", "verification": <parameter name="verification">all tests pass</parameter>}';
    const repaired = repairToolJson(malformed);
    assert.ok(!repaired.includes("<parameter"), "XML tags should be stripped");
    assert.ok(
      repaired.includes("all tests pass"),
      "content should be preserved",
    );
  });

  test("promotes XML parameters trapped inside valid JSON string values", () => {
    const malformed =
      '{"narrative":"text.</narrative>\\n<parameter name=\\"verification\\">all tests pass</parameter>\\n<parameter name=\\"verificationEvidence\\">[\\"npm test\\"]</parameter>","oneLiner":"done"}';
    const repaired = repairToolJson(malformed);
    const parsed = JSON.parse(repaired);

    assert.equal(parsed.narrative, "text.");
    assert.equal(parsed.verification, "all tests pass");
    assert.deepEqual(parsed.verificationEvidence, ["npm test"]);
    assert.equal(parsed.oneLiner, "done");
    assert.ok(
      !parsed.narrative.includes("<parameter"),
      "narrative should not retain leaked XML",
    );
  });
});

// ═══════════════════════════════════════════════════════════════════════════
// Truncated number repair (#3464)
// ═══════════════════════════════════════════════════════════════════════════

describe("repairToolJson — truncated number repair (#3464)", () => {
  test("hasTruncatedNumbers detects bare comma after colon", () => {
    assert.equal(hasTruncatedNumbers('"exitCode": ,'), true);
  });

  test("hasTruncatedNumbers detects bare minus before comma", () => {
    assert.equal(hasTruncatedNumbers('"exitCode": -,'), true);
  });

  test("hasTruncatedNumbers detects bare minus before closing brace", () => {
    assert.equal(hasTruncatedNumbers('"durationMs": -}'), true);
  });

  test("hasTruncatedNumbers returns false for valid numbers", () => {
    assert.equal(
      hasTruncatedNumbers('"exitCode": 0, "durationMs": 1234'),
      false,
    );
  });

  test("hasTruncatedNumbers returns false for negative numbers", () => {
    assert.equal(hasTruncatedNumbers('"exitCode": -1, "offset": -100'), false);
  });

  test("repairs truncated exitCode with bare comma", () => {
    const malformed =
      '{"command": "npm test", "exitCode": , "verdict": "pass", "durationMs": 500}';
    const repaired = repairToolJson(malformed);
    const parsed = JSON.parse(repaired);
    assert.equal(parsed.exitCode, 0);
    assert.equal(parsed.durationMs, 500);
  });

  test("repairs truncated exitCode with bare minus", () => {
    const malformed =
      '{"command": "npm test", "exitCode": -, "verdict": "pass", "durationMs": 1234}';
    const repaired = repairToolJson(malformed);
    const parsed = JSON.parse(repaired);
    assert.equal(parsed.exitCode, 0);
    assert.equal(parsed.verdict, "pass");
  });

  test("repairs truncated durationMs at end of object", () => {
    const malformed =
      '{"command": "npm test", "exitCode": 0, "verdict": "pass", "durationMs": -}';
    const repaired = repairToolJson(malformed);
    const parsed = JSON.parse(repaired);
    assert.equal(parsed.durationMs, 0);
    assert.equal(parsed.exitCode, 0);
  });

  test("does not mangle valid negative numbers", () => {
    const valid = '{"exitCode": -1, "offset": -100}';
    const repaired = repairToolJson(valid);
    const parsed = JSON.parse(repaired);
    assert.equal(parsed.exitCode, -1);
    assert.equal(parsed.offset, -100);
  });
});
|
||||
|
|
@ -1,24 +0,0 @@
|
|||
import { type TUnsafe, Type } from "@sinclair/typebox";
|
||||
|
||||
/**
|
||||
* Creates a string enum schema compatible with Google's API and other providers
|
||||
* that don't support anyOf/const patterns.
|
||||
*
|
||||
* @example
|
||||
* const OperationSchema = StringEnum(["add", "subtract", "multiply", "divide"], {
|
||||
* description: "The operation to perform"
|
||||
* });
|
||||
*
|
||||
* type Operation = Static<typeof OperationSchema>; // "add" | "subtract" | "multiply" | "divide"
|
||||
*/
|
||||
export function StringEnum<T extends readonly string[]>(
|
||||
values: T,
|
||||
options?: { description?: string; default?: T[number] },
|
||||
): TUnsafe<T[number]> {
|
||||
return Type.Unsafe<T[number]>({
|
||||
type: "string",
|
||||
enum: values as any,
|
||||
...(options?.description && { description: options.description }),
|
||||
...(options?.default && { default: options.default }),
|
||||
});
|
||||
}
|
||||
|
|
@ -1,124 +0,0 @@
|
|||
import AjvModule from "ajv";
|
||||
import addFormatsModule from "ajv-formats";
|
||||
|
||||
// Handle both default and named exports
// (CJS/ESM interop: depending on the bundler, the imported module object may
// itself be the callable, or may wrap it under `.default`).
const Ajv = (AjvModule as any).default || AjvModule;
const addFormats = (addFormatsModule as any).default || addFormatsModule;
|
||||
|
||||
import type { Tool, ToolCall } from "../types.js";
|
||||
|
||||
type JsonSchemaObject = Record<string, unknown>;
|
||||
|
||||
function isRecord(value: unknown): value is JsonSchemaObject {
|
||||
return value !== null && typeof value === "object" && !Array.isArray(value);
|
||||
}
|
||||
|
||||
function isStringArraySchema(schema: unknown): schema is JsonSchemaObject {
|
||||
if (!isRecord(schema) || schema.type !== "array") return false;
|
||||
const items = schema.items;
|
||||
return isRecord(items) && items.type === "string";
|
||||
}
|
||||
|
||||
// Recursively walk a JSON schema alongside a value, wrapping bare strings in
// single-element arrays wherever the schema declares
// `{ type: "array", items: { type: "string" } }`. Copy-on-write: the original
// value is returned (same reference) when nothing changed.
function coerceSchemaValue(schema: unknown, value: unknown): unknown {
  if (!isRecord(schema)) return value;
  // Core coercion: a lone string where a string[] is expected becomes [string].
  if (isStringArraySchema(schema) && typeof value === "string") {
    return [value];
  }

  if (Array.isArray(value)) {
    const items = schema.items;
    if (!isRecord(items)) return value;
    // Coerce each element against the item schema.
    return value.map((item) => coerceSchemaValue(items, item));
  }

  if (!isRecord(value)) return value;

  const properties = schema.properties;
  if (!isRecord(properties)) return value;

  // Clone lazily: only allocate a copy once a property actually changes.
  let next: JsonSchemaObject | null = null;
  for (const [key, propertySchema] of Object.entries(properties)) {
    if (!Object.hasOwn(value, key)) continue;
    const coercedValue = coerceSchemaValue(propertySchema, value[key]);
    if (coercedValue !== value[key]) {
      next ??= { ...value };
      next[key] = coercedValue;
    }
  }
  return next ?? value;
}
|
||||
|
||||
/**
 * Wraps bare strings for JSON-schema fields declared as string arrays before AJV validation.
 *
 * Thin public entry point over the recursive walker; returns `params`
 * unchanged (same reference) when no coercion applies.
 */
export function coerceStringArrays(schema: unknown, params: unknown): unknown {
  return coerceSchemaValue(schema, params);
}
|
||||
|
||||
// Detect if we're in a browser extension environment with strict CSP
// Chrome extensions with Manifest V3 don't allow eval/Function constructor
const isBrowserExtension =
  typeof globalThis !== "undefined" &&
  (globalThis as any).chrome?.runtime?.id !== undefined;

// Create a singleton AJV instance with formats (only if not in browser extension)
// AJV requires 'unsafe-eval' CSP which is not allowed in Manifest V3
let ajv: any = null;
if (!isBrowserExtension) {
  try {
    ajv = new Ajv({
      allErrors: true, // collect every violation, not just the first
      strict: false, // don't reject schemas with unknown keywords
      coerceTypes: true, // let AJV align argument types to the schema
    });
    addFormats(ajv);
  } catch (_e) {
    // AJV initialization failed (likely CSP restriction)
    console.warn("AJV validation disabled due to CSP restrictions");
  }
}
|
||||
|
||||
/**
 * Validates tool call arguments against the tool's TypeBox schema
 * @param tool The tool definition with TypeBox schema
 * @param toolCall The tool call from the LLM
 * @returns The validated (and potentially coerced) arguments
 * @throws Error with formatted message if validation fails
 */
export function validateToolArguments(tool: Tool, toolCall: ToolCall): any {
  // Skip validation in browser extension environment (CSP restrictions prevent AJV from working)
  if (!ajv || isBrowserExtension) {
    // Trust the LLM's output without validation
    // Browser extensions can't use AJV due to Manifest V3 CSP restrictions
    return toolCall.arguments;
  }

  // Compile the schema
  // NOTE(review): compiled on every call — consider caching the compiled
  // validator per tool schema if this shows up in profiles.
  const validate = ajv.compile(tool.parameters);

  // Clone arguments so AJV can safely mutate for type coercion
  const args = coerceStringArrays(
    tool.parameters,
    structuredClone(toolCall.arguments),
  );

  // Validate the arguments (AJV mutates args in-place for type coercion)
  if (validate(args)) {
    return args;
  }

  // Format validation errors nicely
  const errors =
    validate.errors
      ?.map((err: any) => {
        // Prefer the instance path (strip the leading "/"); fall back to the
        // missing property name for `required` errors, then to "root".
        const path = err.instancePath
          ? err.instancePath.substring(1)
          : err.params.missingProperty || "root";
        return ` - ${path}: ${err.message}`;
      })
      .join("\n") || "Unknown validation error";

  const errorMessage = `Validation failed for tool "${toolCall.name}":\n${errors}\n\nReceived arguments:\n${JSON.stringify(toolCall.arguments, null, 2)}`;

  throw new Error(errorMessage);
}
|
||||
|
|
@ -1,122 +0,0 @@
|
|||
import { existsSync } from "node:fs";
|
||||
import { homedir } from "node:os";
|
||||
import { join } from "node:path";
|
||||
|
||||
import type { KnownProvider } from "./types.js";
|
||||
|
||||
let cachedVertexAdcCredentialsExists: boolean | null = null;
|
||||
|
||||
function hasVertexAdcCredentials(): boolean {
|
||||
if (cachedVertexAdcCredentialsExists !== null) {
|
||||
return cachedVertexAdcCredentialsExists;
|
||||
}
|
||||
|
||||
const gacPath = process.env.GOOGLE_APPLICATION_CREDENTIALS;
|
||||
cachedVertexAdcCredentialsExists = gacPath
|
||||
? existsSync(gacPath)
|
||||
: existsSync(
|
||||
join(
|
||||
homedir(),
|
||||
".config",
|
||||
"gcloud",
|
||||
"application_default_credentials.json",
|
||||
),
|
||||
);
|
||||
|
||||
return cachedVertexAdcCredentialsExists;
|
||||
}
|
||||
|
||||
/**
|
||||
* Node-only env-key lookup for the standalone web host.
|
||||
*
|
||||
* This intentionally avoids the browser-safe dynamic-import pattern from the
|
||||
* shared pi-ai runtime because the packaged Next standalone server turns that
|
||||
* pattern into a failing "Cannot find module as expression is too dynamic"
|
||||
* runtime branch.
|
||||
*/
|
||||
export function getEnvApiKey(provider: KnownProvider): string | undefined;
|
||||
export function getEnvApiKey(provider: string): string | undefined;
|
||||
export function getEnvApiKey(provider: string): string | undefined {
|
||||
if (provider === "github-copilot") {
|
||||
return (
|
||||
process.env.COPILOT_GITHUB_TOKEN ||
|
||||
process.env.GH_TOKEN ||
|
||||
process.env.GITHUB_TOKEN
|
||||
);
|
||||
}
|
||||
|
||||
if (provider === "anthropic") {
|
||||
return process.env.ANTHROPIC_OAUTH_TOKEN || process.env.ANTHROPIC_API_KEY;
|
||||
}
|
||||
|
||||
if (provider === "google-vertex") {
|
||||
const hasCredentials = hasVertexAdcCredentials();
|
||||
const hasProject = !!(
|
||||
process.env.GOOGLE_CLOUD_PROJECT || process.env.GCLOUD_PROJECT
|
||||
);
|
||||
const hasLocation = !!process.env.GOOGLE_CLOUD_LOCATION;
|
||||
if (hasCredentials && hasProject && hasLocation) {
|
||||
return "<authenticated>";
|
||||
}
|
||||
}
|
||||
|
||||
// Xiaomi MiMo token-plan providers share a single key; allow legacy fallbacks.
|
||||
if (
|
||||
provider === "xiaomi" ||
|
||||
provider === "xiaomi-token-plan-ams" ||
|
||||
provider === "xiaomi-token-plan-sgp" ||
|
||||
provider === "xiaomi-token-plan-cn"
|
||||
) {
|
||||
return (
|
||||
process.env.XIAOMI_API_KEY ||
|
||||
process.env.XIAOMI_TOKEN_PLAN_API_KEY ||
|
||||
process.env.MIMO_API_KEY
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
provider === "amazon-bedrock" &&
|
||||
(process.env.AWS_PROFILE ||
|
||||
(process.env.AWS_ACCESS_KEY_ID && process.env.AWS_SECRET_ACCESS_KEY) ||
|
||||
process.env.AWS_BEARER_TOKEN_BEDROCK ||
|
||||
process.env.AWS_CONTAINER_CREDENTIALS_RELATIVE_URI ||
|
||||
process.env.AWS_CONTAINER_CREDENTIALS_FULL_URI ||
|
||||
process.env.AWS_WEB_IDENTITY_TOKEN_FILE)
|
||||
) {
|
||||
return "<authenticated>";
|
||||
}
|
||||
|
||||
const envMap: Record<string, string | string[]> = {
|
||||
openai: "OPENAI_API_KEY",
|
||||
"azure-openai-responses": "AZURE_OPENAI_API_KEY",
|
||||
google: ["GEMINI_API_KEY", "GOOGLE_GENERATIVE_AI_API_KEY"],
|
||||
groq: "GROQ_API_KEY",
|
||||
cerebras: "CEREBRAS_API_KEY",
|
||||
xai: "XAI_API_KEY",
|
||||
openrouter: "OPENROUTER_API_KEY",
|
||||
"vercel-ai-gateway": "AI_GATEWAY_API_KEY",
|
||||
zai: "ZAI_API_KEY",
|
||||
mistral: "MISTRAL_API_KEY",
|
||||
minimax: "MINIMAX_API_KEY",
|
||||
"minimax-cn": "MINIMAX_CN_API_KEY",
|
||||
huggingface: "HF_TOKEN",
|
||||
opencode: "OPENCODE_API_KEY",
|
||||
"opencode-go": ["OPENCODE_GO_API_KEY", "OPENCODE_API_KEY"],
|
||||
"kimi-coding": "KIMI_API_KEY",
|
||||
xiaomi: "XIAOMI_API_KEY",
|
||||
"xiaomi-token-plan-ams": "XIAOMI_API_KEY",
|
||||
"xiaomi-token-plan-sgp": "XIAOMI_API_KEY",
|
||||
"xiaomi-token-plan-cn": "XIAOMI_API_KEY",
|
||||
"alibaba-coding-plan": "ALIBABA_API_KEY",
|
||||
};
|
||||
|
||||
const envVar = envMap[provider];
|
||||
if (Array.isArray(envVar)) {
|
||||
for (const name of envVar) {
|
||||
const value = process.env[name];
|
||||
if (value) return value;
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
return envVar ? process.env[envVar] : undefined;
|
||||
}
|
||||
|
|
@ -1,9 +0,0 @@
|
|||
// Barrel module: re-exports the public OAuth surface (the provider lookup
// helpers plus the credential/prompt/callback types) from the implementation
// module so consumers import from one stable path.
export {
  getOAuthProvider,
  getOAuthProviders,
  type OAuthAuthInfo,
  type OAuthCredentials,
  type OAuthLoginCallbacks,
  type OAuthPrompt,
  type OAuthProviderInterface,
} from "./oauth.js";
|
|
@ -1,28 +0,0 @@
|
|||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2024",
|
||||
"module": "Node16",
|
||||
"lib": ["ES2024"],
|
||||
"strict": true,
|
||||
"esModuleInterop": true,
|
||||
"skipLibCheck": true,
|
||||
"incremental": true,
|
||||
"forceConsistentCasingInFileNames": true,
|
||||
"declaration": true,
|
||||
"declarationMap": true,
|
||||
"sourceMap": true,
|
||||
"inlineSources": true,
|
||||
"inlineSourceMap": false,
|
||||
"moduleResolution": "Node16",
|
||||
"resolveJsonModule": true,
|
||||
"allowImportingTsExtensions": false,
|
||||
"experimentalDecorators": true,
|
||||
"emitDecoratorMetadata": true,
|
||||
"useDefineForClassFields": false,
|
||||
"types": ["node"],
|
||||
"outDir": "./dist",
|
||||
"rootDir": "./src"
|
||||
},
|
||||
"include": ["src/**/*.ts"],
|
||||
"exclude": ["node_modules", "dist", "**/*.d.ts", "src/**/*.d.ts"]
|
||||
}
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
{
|
||||
"name": "@singularity-forge/pi-coding-agent",
|
||||
"version": "2.75.3",
|
||||
"description": "Coding agent CLI (vendored from pi-mono)",
|
||||
"type": "module",
|
||||
"piConfig": {
|
||||
"name": "sf",
|
||||
"configDir": ".sf"
|
||||
},
|
||||
"main": "./dist/index.js",
|
||||
"types": "./dist/index.d.ts",
|
||||
"exports": {
|
||||
".": {
|
||||
"types": "./dist/index.d.ts",
|
||||
"import": "./dist/index.js"
|
||||
}
|
||||
},
|
||||
"scripts": {
|
||||
"build": "tsc -p tsconfig.json && npm run copy-assets",
|
||||
"copy-assets": "node scripts/copy-assets.cjs"
|
||||
},
|
||||
"dependencies": {
|
||||
"@mariozechner/jiti": "^2.6.2",
|
||||
"@silvia-odwyer/photon-node": "^0.3.4",
|
||||
"chalk": "^5.5.0",
|
||||
"diff": "^9.0.0",
|
||||
"express": "^5.2.1",
|
||||
"extract-zip": "^2.0.1",
|
||||
"file-type": "^21.3.4",
|
||||
"hosted-git-info": "^9.0.3",
|
||||
"ignore": "^7.0.5",
|
||||
"marked": "^18.0.3",
|
||||
"minimatch": "^10.2.5",
|
||||
"proper-lockfile": "^4.1.2",
|
||||
"strip-ansi": "^7.2.0",
|
||||
"undici": "^8.2.0",
|
||||
"yaml": "^2.8.4"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/diff": "^7.0.2",
|
||||
"@types/express": "^4.17.21",
|
||||
"@types/hosted-git-info": "^3.0.5",
|
||||
"@types/proper-lockfile": "^4.1.4"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=26.1.0"
|
||||
}
|
||||
}
|
||||
|
|
@ -1,64 +0,0 @@
|
|||
#!/usr/bin/env node
|
||||
const { mkdirSync, cpSync, copyFileSync, readdirSync } = require("node:fs");
|
||||
const { join } = require("node:path");
|
||||
|
||||
/**
|
||||
* Recursive directory copy using copyFileSync — workaround for cpSync failures
|
||||
* on Windows paths containing non-ASCII characters (#1178).
|
||||
*/
|
||||
/**
 * Copy with cpSync, falling back to a manual copyFileSync-based walk when
 * cpSync fails — workaround for cpSync errors on Windows paths containing
 * non-ASCII characters (#1178).
 */
function safeCpSync(src, dest, options) {
  try {
    cpSync(src, dest, options);
    return;
  } catch {
    // Fall back below.
  }
  if (options?.recursive) {
    copyDirRecursive(src, dest, options?.filter);
  } else {
    copyFileSync(src, dest);
  }
}
|
||||
|
||||
/**
 * Recursively copy a directory tree using only mkdirSync/copyFileSync.
 * An optional filter(path) predicate skips entries (files or whole subtrees)
 * for which it returns false.
 */
function copyDirRecursive(src, dest, filter) {
  mkdirSync(dest, { recursive: true });
  const entries = readdirSync(src, { withFileTypes: true });
  for (const item of entries) {
    const from = join(src, item.name);
    // Filter runs before the type check, so it can prune directories too.
    if (filter && !filter(from)) continue;
    const to = join(dest, item.name);
    if (item.isDirectory()) {
      copyDirRecursive(from, to, filter);
    } else {
      copyFileSync(from, to);
    }
  }
}
|
||||
|
||||
// Theme assets (everything except TypeScript sources).
mkdirSync("dist/modes/interactive/theme", { recursive: true });
safeCpSync("src/modes/interactive/theme", "dist/modes/interactive/theme", {
  recursive: true,
  filter: (s) => !s.endsWith(".ts"),
});

// Export-HTML templates and vendored support files. The vendor mkdir runs
// first: its recursive flag also creates dist/core/export-html, which the
// template copies below rely on.
mkdirSync("dist/core/export-html/vendor", { recursive: true });
for (const asset of ["template.html", "template.css", "template.js"]) {
  safeCpSync(`src/core/export-html/${asset}`, `dist/core/export-html/${asset}`);
}
safeCpSync("src/core/export-html/vendor", "dist/core/export-html/vendor", {
  recursive: true,
  filter: (s) => !s.endsWith(".ts"),
});

// LSP default configuration and documentation.
mkdirSync("dist/core/lsp", { recursive: true });
safeCpSync("src/core/lsp/defaults.json", "dist/core/lsp/defaults.json");
safeCpSync("src/core/lsp/lsp.md", "dist/core/lsp/lsp.md");
|
||||
|
|
@ -1,24 +0,0 @@
|
|||
#!/usr/bin/env node
/**
 * CLI entry point for the refactored coding agent.
 * Uses main.ts with AgentSession and new mode modules.
 *
 * Test with: npx tsx src/cli-new.ts [args...]
 */
// Shown in `ps` / task managers. Note this runs after the hoisted ESM imports
// below, despite appearing first in source order.
process.title = "pi";

import { setBedrockProviderModule } from "@singularity-forge/pi-ai";
import { bedrockProviderModule } from "@singularity-forge/pi-ai/bedrock-provider";
import { EnvHttpProxyAgent, setGlobalDispatcher } from "undici";
import { main } from "./main.js";

// bodyTimeout/headersTimeout default to 300s in undici; long local-LLM stalls
// (e.g. vLLM buffering a large tool call) exceed that and abort the SSE stream
// with UND_ERR_BODY_TIMEOUT. Disable both — provider SDKs enforce their own
// AbortController-based deadlines via retry.provider.timeoutMs.
setGlobalDispatcher(
  new EnvHttpProxyAgent({ bodyTimeout: 0, headersTimeout: 0 }),
);
// Register the Bedrock provider module with the AI runtime via its setter
// (injection — presumably keeps the AWS SDK out of default bundles; confirm).
setBedrockProviderModule(bedrockProviderModule);

// Hand the raw CLI args (minus node + script path) to the agent entry point.
main(process.argv.slice(2));
|
||||
|
|
@ -1,21 +0,0 @@
|
|||
import assert from "node:assert/strict";
|
||||
import { describe, it } from "vitest";
|
||||
import { parseArgs } from "./args.js";
|
||||
|
||||
describe("parseArgs", () => {
|
||||
it("parses optional-value extension flags with implicit and explicit values", () => {
|
||||
const extensionFlags = new Map([
|
||||
["demo-flag", { type: "string" as const, allowNoValue: true }],
|
||||
]);
|
||||
const defaultFlagArgs = parseArgs(["--demo-flag"], extensionFlags);
|
||||
const explicitFlagArgs = parseArgs(["--demo-flag=8080"], extensionFlags);
|
||||
|
||||
assert.deepEqual(
|
||||
[
|
||||
defaultFlagArgs.unknownFlags.get("demo-flag"),
|
||||
explicitFlagArgs.unknownFlags.get("demo-flag"),
|
||||
],
|
||||
[true, "8080"],
|
||||
);
|
||||
});
|
||||
});
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue