feat: dynamic model discovery & provider management UX (#581)
This commit is contained in:
parent
570f6195be
commit
9ed812ed54
25 changed files with 2122 additions and 23 deletions
27
.plans/dynamic-model-discovery.md
Normal file
27
.plans/dynamic-model-discovery.md
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
# Dynamic Model Discovery
|
||||
|
||||
## Overview
|
||||
Runtime model discovery from provider APIs with caching, TUI management, and CLI flags.
|
||||
|
||||
## Components
|
||||
1. **model-discovery.ts** — Provider adapters (OpenAI, Ollama, OpenRouter, Google) + static adapters
|
||||
2. **discovery-cache.ts** — Disk cache at `{agentDir}/discovery-cache.json` with per-provider TTLs
|
||||
3. **models-json-writer.ts** — Safe read-modify-write for `models.json` with file locking
|
||||
4. **provider-manager.ts** — TUI component for provider management (`/provider` command)
|
||||
5. **model-registry.ts** — Extended with `discoverModels()`, `getAllWithDiscovered()`, cache integration
|
||||
6. **settings-manager.ts** — `modelDiscovery` settings (enabled, providers, ttlMinutes, autoRefreshOnModelSelect)
|
||||
7. **args.ts** — `--discover`, `--add-provider`, `--base-url`, `--discover-models` CLI flags
|
||||
8. **list-models.ts** — Rewritten with `[discovered]` badge support
|
||||
9. **main.ts** — CLI handlers for new flags
|
||||
10. **interactive-mode.ts** — `/provider` command handler
|
||||
11. **preferences.ts** — `updatePreferencesModels()` and `validateModelId()` helpers
|
||||
|
||||
## TTL Strategy
|
||||
- Ollama: 5 min (local, models change often)
|
||||
- OpenAI / Google / OpenRouter: 1 hour
|
||||
- Default: 24 hours
|
||||
|
||||
## Merge Rules
|
||||
- Discovered models never override existing built-in or custom models
|
||||
- Discovered models are appended to the registry with `[discovered]` badge
|
||||
- Background discovery is opt-in via `modelDiscovery.enabled` setting
|
||||
49
.plans/preferences-wizard-completeness.md
Normal file
49
.plans/preferences-wizard-completeness.md
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
# Preferences Wizard Completeness
|
||||
|
||||
## Problem
|
||||
The `/gsd prefs wizard` currently only configures 6 of 18+ preference fields. Users must hand-edit YAML for the rest.
|
||||
|
||||
## Current Wizard Coverage
|
||||
1. Models (per phase) ✓
|
||||
2. Auto-supervisor timeouts ✓
|
||||
3. Git main_branch ✓
|
||||
4. Skill discovery mode ✓
|
||||
5. Unique milestone IDs ✓
|
||||
|
||||
## Missing Fields to Add
|
||||
|
||||
### Group 1: Git Settings (expand existing section)
|
||||
- `auto_push` (boolean) — auto-push commits ✓
|
||||
- `push_branches` (boolean) — push milestone branches ✓
|
||||
- `remote` (string) — git remote name ✓
|
||||
- `snapshots` (boolean) — WIP snapshot commits ✓
|
||||
- `pre_merge_check` (boolean | "auto") — pre-merge validation ✓
|
||||
- `commit_type` (select) — conventional commit prefix ✓
|
||||
- `merge_strategy` (select) — squash vs merge ✓
|
||||
- `isolation` (select) — worktree vs branch ✓
|
||||
|
||||
### Group 2: Budget & Cost Control ✓
|
||||
- `budget_ceiling` (number) — dollar limit
|
||||
- `budget_enforcement` (select: warn/pause/halt)
|
||||
- `context_pause_threshold` (number 0-100)
|
||||
|
||||
### Group 3: Notifications ✓
|
||||
- `notifications.enabled` (boolean)
|
||||
- `notifications.on_complete` (boolean)
|
||||
- `notifications.on_error` (boolean)
|
||||
- `notifications.on_budget` (boolean)
|
||||
- `notifications.on_milestone` (boolean)
|
||||
- `notifications.on_attention` (boolean)
|
||||
|
||||
### Group 4: Behavior Toggles ✓
|
||||
- `uat_dispatch` (boolean)
|
||||
|
||||
### Group 5: Update Serialization Order ✓
|
||||
- Added missing keys to `orderedKeys` in `serializePreferencesToFrontmatter()`
|
||||
|
||||
### Group 6: Update Template & Docs ✓
|
||||
- Updated `templates/preferences.md` with new fields
|
||||
- Updated `docs/preferences-reference.md` with budget, notifications, git, hooks
|
||||
|
||||
### Group 7: Tests ✓
|
||||
- Added `preferences-wizard-fields.test.ts` covering all new fields
|
||||
|
|
@ -38,6 +38,11 @@ export interface Args {
|
|||
themes?: string[];
|
||||
noThemes?: boolean;
|
||||
listModels?: string | true;
|
||||
discover?: boolean;
|
||||
addProvider?: string;
|
||||
addProviderBaseUrl?: string;
|
||||
addProviderApiKey?: string;
|
||||
discoverModels?: string | true;
|
||||
offline?: boolean;
|
||||
verbose?: boolean;
|
||||
messages: string[];
|
||||
|
|
@ -150,6 +155,18 @@ export function parseArgs(args: string[], extensionFlags?: Map<string, { type: "
|
|||
} else {
|
||||
result.listModels = true;
|
||||
}
|
||||
} else if (arg === "--discover") {
|
||||
result.discover = true;
|
||||
} else if (arg === "--add-provider" && i + 1 < args.length) {
|
||||
result.addProvider = args[++i];
|
||||
} else if (arg === "--base-url" && i + 1 < args.length) {
|
||||
result.addProviderBaseUrl = args[++i];
|
||||
} else if (arg === "--discover-models") {
|
||||
if (i + 1 < args.length && !args[i + 1].startsWith("-") && !args[i + 1].startsWith("@")) {
|
||||
result.discoverModels = args[++i];
|
||||
} else {
|
||||
result.discoverModels = true;
|
||||
}
|
||||
} else if (arg === "--verbose") {
|
||||
result.verbose = true;
|
||||
} else if (arg === "--offline") {
|
||||
|
|
@ -219,6 +236,10 @@ ${chalk.bold("Options:")}
|
|||
--no-themes Disable theme discovery and loading
|
||||
--export <file> Export session file to HTML and exit
|
||||
--list-models [search] List available models (with optional fuzzy search)
|
||||
--discover Include discovered models in --list-models output
|
||||
--discover-models [provider] Discover models from provider APIs (all or specific)
|
||||
--add-provider <name> Add a provider to models.json (use with --base-url, --api-key)
|
||||
--base-url <url> Base URL for --add-provider
|
||||
--verbose Force verbose startup (overrides quietStartup setting)
|
||||
--offline Disable startup network operations (same as PI_OFFLINE=1)
|
||||
--help, -h Show this help
|
||||
|
|
|
|||
|
|
@ -1,11 +1,18 @@
|
|||
/**
|
||||
* List available models with optional fuzzy search
|
||||
* List available models with optional fuzzy search and discovery support
|
||||
*/
|
||||
|
||||
import type { Api, Model } from "@gsd/pi-ai";
|
||||
import { fuzzyFilter } from "@gsd/pi-tui";
|
||||
import type { ModelRegistry } from "../core/model-registry.js";
|
||||
|
||||
export interface ListModelsOptions {
|
||||
/** Include discovered models in output */
|
||||
discover?: boolean;
|
||||
/** Search pattern for fuzzy filtering */
|
||||
searchPattern?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Format a number as human-readable (e.g., 200000 -> "200K", 1000000 -> "1M")
|
||||
*/
|
||||
|
|
@ -22,10 +29,48 @@ function formatTokenCount(count: number): string {
|
|||
}
|
||||
|
||||
/**
|
||||
* List available models, optionally filtered by search pattern
|
||||
* Discover models from provider APIs and print results.
|
||||
*/
|
||||
export async function listModels(modelRegistry: ModelRegistry, searchPattern?: string): Promise<void> {
|
||||
const models = modelRegistry.getAvailable();
|
||||
export async function discoverAndPrintModels(
|
||||
modelRegistry: ModelRegistry,
|
||||
provider?: string,
|
||||
): Promise<void> {
|
||||
const providers = provider ? [provider] : undefined;
|
||||
|
||||
console.log("Discovering models...");
|
||||
const results = await modelRegistry.discoverModels(providers);
|
||||
|
||||
for (const result of results) {
|
||||
if (result.error) {
|
||||
console.log(` ${result.provider}: error - ${result.error}`);
|
||||
} else {
|
||||
console.log(` ${result.provider}: ${result.models.length} models found`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* List available models, optionally filtered by search pattern.
|
||||
* Accepts either a string (backward compat) or ListModelsOptions.
|
||||
*/
|
||||
export async function listModels(
|
||||
modelRegistry: ModelRegistry,
|
||||
optionsOrSearch?: string | ListModelsOptions,
|
||||
): Promise<void> {
|
||||
const options: ListModelsOptions =
|
||||
typeof optionsOrSearch === "string"
|
||||
? { searchPattern: optionsOrSearch }
|
||||
: optionsOrSearch ?? {};
|
||||
|
||||
// If discover flag is set, run discovery first
|
||||
if (options.discover) {
|
||||
await modelRegistry.discoverModels();
|
||||
}
|
||||
|
||||
// Get models — include discovered if discovery was run
|
||||
const models = options.discover
|
||||
? modelRegistry.getAllWithDiscovered()
|
||||
: modelRegistry.getAvailable();
|
||||
|
||||
if (models.length === 0) {
|
||||
console.log("No models available. Set API keys in environment variables.");
|
||||
|
|
@ -34,12 +79,12 @@ export async function listModels(modelRegistry: ModelRegistry, searchPattern?: s
|
|||
|
||||
// Apply fuzzy filter if search pattern provided
|
||||
let filteredModels: Model<Api>[] = models;
|
||||
if (searchPattern) {
|
||||
filteredModels = fuzzyFilter(models, searchPattern, (m) => `${m.provider} ${m.id}`);
|
||||
if (options.searchPattern) {
|
||||
filteredModels = fuzzyFilter(models, options.searchPattern, (m) => `${m.provider} ${m.id}`);
|
||||
}
|
||||
|
||||
if (filteredModels.length === 0) {
|
||||
console.log(`No models matching "${searchPattern}"`);
|
||||
console.log(`No models matching "${options.searchPattern}"`);
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -53,15 +98,19 @@ export async function listModels(modelRegistry: ModelRegistry, searchPattern?: s
|
|||
});
|
||||
|
||||
// Calculate column widths
|
||||
const rows = filteredModels.map((m) => ({
|
||||
provider: m.provider,
|
||||
model: m.id,
|
||||
name: m.name,
|
||||
context: formatTokenCount(m.contextWindow),
|
||||
maxOut: formatTokenCount(m.maxTokens),
|
||||
thinking: m.reasoning ? "yes" : "no",
|
||||
images: m.input.includes("image") ? "yes" : "no",
|
||||
}));
|
||||
const rows = filteredModels.map((m) => {
|
||||
const isDiscovered = options.discover && modelRegistry.isDiscovered(m);
|
||||
return {
|
||||
provider: m.provider,
|
||||
model: m.id,
|
||||
name: m.name,
|
||||
context: formatTokenCount(m.contextWindow),
|
||||
maxOut: formatTokenCount(m.maxTokens),
|
||||
thinking: m.reasoning ? "yes" : "no",
|
||||
images: m.input.includes("image") ? "yes" : "no",
|
||||
badge: isDiscovered ? "[discovered]" : "",
|
||||
};
|
||||
});
|
||||
|
||||
const headers = {
|
||||
provider: "provider",
|
||||
|
|
@ -71,6 +120,7 @@ export async function listModels(modelRegistry: ModelRegistry, searchPattern?: s
|
|||
maxOut: "max-out",
|
||||
thinking: "thinking",
|
||||
images: "images",
|
||||
badge: "",
|
||||
};
|
||||
|
||||
const widths = {
|
||||
|
|
@ -105,7 +155,10 @@ export async function listModels(modelRegistry: ModelRegistry, searchPattern?: s
|
|||
row.maxOut.padEnd(widths.maxOut),
|
||||
row.thinking.padEnd(widths.thinking),
|
||||
row.images.padEnd(widths.images),
|
||||
].join(" ");
|
||||
row.badge,
|
||||
]
|
||||
.join(" ")
|
||||
.trimEnd();
|
||||
console.log(line);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
170
packages/pi-coding-agent/src/core/discovery-cache.test.ts
Normal file
170
packages/pi-coding-agent/src/core/discovery-cache.test.ts
Normal file
|
|
@ -0,0 +1,170 @@
|
|||
import assert from "node:assert/strict";
|
||||
import { existsSync, mkdirSync, rmSync, writeFileSync } from "node:fs";
|
||||
import { tmpdir } from "node:os";
|
||||
import { join } from "node:path";
|
||||
import { afterEach, beforeEach, describe, it } from "node:test";
|
||||
import { ModelDiscoveryCache } from "./discovery-cache.js";
|
||||
|
||||
let testDir: string;
|
||||
let cachePath: string;
|
||||
|
||||
beforeEach(() => {
|
||||
testDir = join(tmpdir(), `discovery-cache-test-${Date.now()}-${Math.random().toString(36).slice(2)}`);
|
||||
mkdirSync(testDir, { recursive: true });
|
||||
cachePath = join(testDir, "discovery-cache.json");
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
try {
|
||||
rmSync(testDir, { recursive: true, force: true });
|
||||
} catch {
|
||||
// Cleanup best-effort
|
||||
}
|
||||
});
|
||||
|
||||
// ─── basic operations ────────────────────────────────────────────────────────
|
||||
|
||||
describe("ModelDiscoveryCache — basic operations", () => {
|
||||
it("starts with no entries", () => {
|
||||
const cache = new ModelDiscoveryCache(cachePath);
|
||||
assert.equal(cache.get("openai"), undefined);
|
||||
});
|
||||
|
||||
it("stores and retrieves models", () => {
|
||||
const cache = new ModelDiscoveryCache(cachePath);
|
||||
const models = [{ id: "gpt-4o", name: "GPT-4o" }];
|
||||
cache.set("openai", models);
|
||||
|
||||
const entry = cache.get("openai");
|
||||
assert.ok(entry);
|
||||
assert.deepEqual(entry.models, models);
|
||||
assert.ok(entry.fetchedAt > 0);
|
||||
assert.ok(entry.ttlMs > 0);
|
||||
});
|
||||
|
||||
it("persists to disk and reloads", () => {
|
||||
const cache1 = new ModelDiscoveryCache(cachePath);
|
||||
cache1.set("openai", [{ id: "gpt-4o" }]);
|
||||
|
||||
const cache2 = new ModelDiscoveryCache(cachePath);
|
||||
const entry = cache2.get("openai");
|
||||
assert.ok(entry);
|
||||
assert.equal(entry.models[0].id, "gpt-4o");
|
||||
});
|
||||
|
||||
it("clear removes a specific provider", () => {
|
||||
const cache = new ModelDiscoveryCache(cachePath);
|
||||
cache.set("openai", [{ id: "gpt-4o" }]);
|
||||
cache.set("google", [{ id: "gemini-pro" }]);
|
||||
|
||||
cache.clear("openai");
|
||||
assert.equal(cache.get("openai"), undefined);
|
||||
assert.ok(cache.get("google"));
|
||||
});
|
||||
|
||||
it("clear without provider removes all entries", () => {
|
||||
const cache = new ModelDiscoveryCache(cachePath);
|
||||
cache.set("openai", [{ id: "gpt-4o" }]);
|
||||
cache.set("google", [{ id: "gemini-pro" }]);
|
||||
|
||||
cache.clear();
|
||||
assert.equal(cache.get("openai"), undefined);
|
||||
assert.equal(cache.get("google"), undefined);
|
||||
});
|
||||
});
|
||||
|
||||
// ─── staleness ───────────────────────────────────────────────────────────────
|
||||
|
||||
describe("ModelDiscoveryCache — staleness", () => {
|
||||
it("newly set entries are not stale", () => {
|
||||
const cache = new ModelDiscoveryCache(cachePath);
|
||||
cache.set("openai", [{ id: "gpt-4o" }]);
|
||||
assert.equal(cache.isStale("openai"), false);
|
||||
});
|
||||
|
||||
it("missing providers are stale", () => {
|
||||
const cache = new ModelDiscoveryCache(cachePath);
|
||||
assert.equal(cache.isStale("unknown"), true);
|
||||
});
|
||||
|
||||
it("entries with expired TTL are stale", () => {
|
||||
const cache = new ModelDiscoveryCache(cachePath);
|
||||
cache.set("openai", [{ id: "gpt-4o" }], 1); // 1ms TTL
|
||||
|
||||
// Wait for TTL to expire
|
||||
const start = Date.now();
|
||||
while (Date.now() - start < 5) {
|
||||
// busy wait
|
||||
}
|
||||
|
||||
assert.equal(cache.isStale("openai"), true);
|
||||
});
|
||||
});
|
||||
|
||||
// ─── getAll ──────────────────────────────────────────────────────────────────
|
||||
|
||||
describe("ModelDiscoveryCache — getAll", () => {
|
||||
it("returns non-stale entries by default", () => {
|
||||
const cache = new ModelDiscoveryCache(cachePath);
|
||||
cache.set("openai", [{ id: "gpt-4o" }]);
|
||||
cache.set("stale", [{ id: "old" }], 1);
|
||||
|
||||
// Wait for stale TTL
|
||||
const start = Date.now();
|
||||
while (Date.now() - start < 5) {
|
||||
// busy wait
|
||||
}
|
||||
|
||||
const all = cache.getAll();
|
||||
assert.ok(all.has("openai"));
|
||||
assert.ok(!all.has("stale"));
|
||||
});
|
||||
|
||||
it("returns all entries when includeStale is true", () => {
|
||||
const cache = new ModelDiscoveryCache(cachePath);
|
||||
cache.set("openai", [{ id: "gpt-4o" }]);
|
||||
cache.set("stale", [{ id: "old" }], 1);
|
||||
|
||||
// Wait for stale TTL
|
||||
const start = Date.now();
|
||||
while (Date.now() - start < 5) {
|
||||
// busy wait
|
||||
}
|
||||
|
||||
const all = cache.getAll(true);
|
||||
assert.ok(all.has("openai"));
|
||||
assert.ok(all.has("stale"));
|
||||
});
|
||||
});
|
||||
|
||||
// ─── edge cases ──────────────────────────────────────────────────────────────
|
||||
|
||||
describe("ModelDiscoveryCache — edge cases", () => {
|
||||
it("handles corrupted cache file gracefully", () => {
|
||||
writeFileSync(cachePath, "not valid json", "utf-8");
|
||||
const cache = new ModelDiscoveryCache(cachePath);
|
||||
assert.equal(cache.get("openai"), undefined);
|
||||
});
|
||||
|
||||
it("handles wrong version gracefully", () => {
|
||||
writeFileSync(cachePath, JSON.stringify({ version: 99, entries: {} }), "utf-8");
|
||||
const cache = new ModelDiscoveryCache(cachePath);
|
||||
assert.equal(cache.get("openai"), undefined);
|
||||
});
|
||||
|
||||
it("handles missing cache file", () => {
|
||||
const cache = new ModelDiscoveryCache(join(testDir, "nonexistent", "cache.json"));
|
||||
assert.equal(cache.get("openai"), undefined);
|
||||
});
|
||||
|
||||
it("overwrites existing entry for same provider", () => {
|
||||
const cache = new ModelDiscoveryCache(cachePath);
|
||||
cache.set("openai", [{ id: "gpt-4o" }]);
|
||||
cache.set("openai", [{ id: "gpt-4o-mini" }]);
|
||||
|
||||
const entry = cache.get("openai");
|
||||
assert.ok(entry);
|
||||
assert.equal(entry.models.length, 1);
|
||||
assert.equal(entry.models[0].id, "gpt-4o-mini");
|
||||
});
|
||||
});
|
||||
97
packages/pi-coding-agent/src/core/discovery-cache.ts
Normal file
97
packages/pi-coding-agent/src/core/discovery-cache.ts
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
/**
|
||||
* Disk-based cache for discovered models.
|
||||
* Stores results at {agentDir}/discovery-cache.json with per-provider TTLs.
|
||||
*/
|
||||
|
||||
import { existsSync, mkdirSync, readFileSync, writeFileSync } from "fs";
|
||||
import { dirname, join } from "path";
|
||||
import { getAgentDir } from "../config.js";
|
||||
import { type DiscoveredModel, getDefaultTTL } from "./model-discovery.js";
|
||||
|
||||
export interface DiscoveryCacheEntry {
|
||||
models: DiscoveredModel[];
|
||||
fetchedAt: number;
|
||||
ttlMs: number;
|
||||
}
|
||||
|
||||
export interface DiscoveryCacheData {
|
||||
version: 1;
|
||||
entries: Record<string, DiscoveryCacheEntry>;
|
||||
}
|
||||
|
||||
export class ModelDiscoveryCache {
|
||||
private data: DiscoveryCacheData;
|
||||
private cachePath: string;
|
||||
|
||||
constructor(cachePath?: string) {
|
||||
this.cachePath = cachePath ?? join(getAgentDir(), "discovery-cache.json");
|
||||
this.data = { version: 1, entries: {} };
|
||||
this.load();
|
||||
}
|
||||
|
||||
get(provider: string): DiscoveryCacheEntry | undefined {
|
||||
const entry = this.data.entries[provider];
|
||||
return entry;
|
||||
}
|
||||
|
||||
set(provider: string, models: DiscoveredModel[], ttlMs?: number): void {
|
||||
this.data.entries[provider] = {
|
||||
models,
|
||||
fetchedAt: Date.now(),
|
||||
ttlMs: ttlMs ?? getDefaultTTL(provider),
|
||||
};
|
||||
this.save();
|
||||
}
|
||||
|
||||
isStale(provider: string): boolean {
|
||||
const entry = this.data.entries[provider];
|
||||
if (!entry) return true;
|
||||
return Date.now() - entry.fetchedAt > entry.ttlMs;
|
||||
}
|
||||
|
||||
clear(provider?: string): void {
|
||||
if (provider) {
|
||||
delete this.data.entries[provider];
|
||||
} else {
|
||||
this.data.entries = {};
|
||||
}
|
||||
this.save();
|
||||
}
|
||||
|
||||
getAll(includeStale = false): Map<string, DiscoveryCacheEntry> {
|
||||
const result = new Map<string, DiscoveryCacheEntry>();
|
||||
for (const [provider, entry] of Object.entries(this.data.entries)) {
|
||||
if (includeStale || !this.isStale(provider)) {
|
||||
result.set(provider, entry);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
load(): void {
|
||||
try {
|
||||
if (existsSync(this.cachePath)) {
|
||||
const content = readFileSync(this.cachePath, "utf-8");
|
||||
const parsed = JSON.parse(content) as DiscoveryCacheData;
|
||||
if (parsed.version === 1 && parsed.entries) {
|
||||
this.data = parsed;
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// Corrupted or unreadable cache — start fresh
|
||||
this.data = { version: 1, entries: {} };
|
||||
}
|
||||
}
|
||||
|
||||
save(): void {
|
||||
try {
|
||||
const dir = dirname(this.cachePath);
|
||||
if (!existsSync(dir)) {
|
||||
mkdirSync(dir, { recursive: true });
|
||||
}
|
||||
writeFileSync(this.cachePath, JSON.stringify(this.data, null, 2), "utf-8");
|
||||
} catch {
|
||||
// Silently ignore write failures (read-only FS, permissions, etc.)
|
||||
}
|
||||
}
|
||||
}
|
||||
125
packages/pi-coding-agent/src/core/model-discovery.test.ts
Normal file
125
packages/pi-coding-agent/src/core/model-discovery.test.ts
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
import assert from "node:assert/strict";
|
||||
import { describe, it } from "node:test";
|
||||
import {
|
||||
DISCOVERY_TTLS,
|
||||
getDefaultTTL,
|
||||
getDiscoverableProviders,
|
||||
getDiscoveryAdapter,
|
||||
} from "./model-discovery.js";
|
||||
|
||||
// ─── getDiscoveryAdapter ─────────────────────────────────────────────────────
|
||||
|
||||
describe("getDiscoveryAdapter", () => {
|
||||
it("returns an adapter for openai", () => {
|
||||
const adapter = getDiscoveryAdapter("openai");
|
||||
assert.equal(adapter.provider, "openai");
|
||||
assert.equal(adapter.supportsDiscovery, true);
|
||||
});
|
||||
|
||||
it("returns an adapter for ollama", () => {
|
||||
const adapter = getDiscoveryAdapter("ollama");
|
||||
assert.equal(adapter.provider, "ollama");
|
||||
assert.equal(adapter.supportsDiscovery, true);
|
||||
});
|
||||
|
||||
it("returns an adapter for openrouter", () => {
|
||||
const adapter = getDiscoveryAdapter("openrouter");
|
||||
assert.equal(adapter.provider, "openrouter");
|
||||
assert.equal(adapter.supportsDiscovery, true);
|
||||
});
|
||||
|
||||
it("returns an adapter for google", () => {
|
||||
const adapter = getDiscoveryAdapter("google");
|
||||
assert.equal(adapter.provider, "google");
|
||||
assert.equal(adapter.supportsDiscovery, true);
|
||||
});
|
||||
|
||||
it("returns a static adapter for anthropic", () => {
|
||||
const adapter = getDiscoveryAdapter("anthropic");
|
||||
assert.equal(adapter.provider, "anthropic");
|
||||
assert.equal(adapter.supportsDiscovery, false);
|
||||
});
|
||||
|
||||
it("returns a static adapter for bedrock", () => {
|
||||
const adapter = getDiscoveryAdapter("bedrock");
|
||||
assert.equal(adapter.provider, "bedrock");
|
||||
assert.equal(adapter.supportsDiscovery, false);
|
||||
});
|
||||
|
||||
it("returns a static adapter for unknown providers", () => {
|
||||
const adapter = getDiscoveryAdapter("unknown-provider");
|
||||
assert.equal(adapter.provider, "unknown-provider");
|
||||
assert.equal(adapter.supportsDiscovery, false);
|
||||
});
|
||||
|
||||
it("static adapter fetchModels returns empty array", async () => {
|
||||
const adapter = getDiscoveryAdapter("anthropic");
|
||||
const models = await adapter.fetchModels("key");
|
||||
assert.deepEqual(models, []);
|
||||
});
|
||||
});
|
||||
|
||||
// ─── getDiscoverableProviders ────────────────────────────────────────────────
|
||||
|
||||
describe("getDiscoverableProviders", () => {
|
||||
it("returns only providers that support discovery", () => {
|
||||
const providers = getDiscoverableProviders();
|
||||
assert.ok(providers.includes("openai"));
|
||||
assert.ok(providers.includes("ollama"));
|
||||
assert.ok(providers.includes("openrouter"));
|
||||
assert.ok(providers.includes("google"));
|
||||
assert.ok(!providers.includes("anthropic"));
|
||||
assert.ok(!providers.includes("bedrock"));
|
||||
});
|
||||
|
||||
it("returns an array of strings", () => {
|
||||
const providers = getDiscoverableProviders();
|
||||
assert.ok(Array.isArray(providers));
|
||||
for (const p of providers) {
|
||||
assert.equal(typeof p, "string");
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// ─── getDefaultTTL ───────────────────────────────────────────────────────────
|
||||
|
||||
describe("getDefaultTTL", () => {
|
||||
it("returns 5 minutes for ollama", () => {
|
||||
assert.equal(getDefaultTTL("ollama"), 5 * 60 * 1000);
|
||||
});
|
||||
|
||||
it("returns 1 hour for openai", () => {
|
||||
assert.equal(getDefaultTTL("openai"), 60 * 60 * 1000);
|
||||
});
|
||||
|
||||
it("returns 1 hour for google", () => {
|
||||
assert.equal(getDefaultTTL("google"), 60 * 60 * 1000);
|
||||
});
|
||||
|
||||
it("returns 1 hour for openrouter", () => {
|
||||
assert.equal(getDefaultTTL("openrouter"), 60 * 60 * 1000);
|
||||
});
|
||||
|
||||
it("returns 24 hours for unknown providers", () => {
|
||||
assert.equal(getDefaultTTL("some-custom"), 24 * 60 * 60 * 1000);
|
||||
});
|
||||
});
|
||||
|
||||
// ─── DISCOVERY_TTLS ──────────────────────────────────────────────────────────
|
||||
|
||||
describe("DISCOVERY_TTLS", () => {
|
||||
it("has expected keys", () => {
|
||||
assert.ok("ollama" in DISCOVERY_TTLS);
|
||||
assert.ok("openai" in DISCOVERY_TTLS);
|
||||
assert.ok("google" in DISCOVERY_TTLS);
|
||||
assert.ok("openrouter" in DISCOVERY_TTLS);
|
||||
assert.ok("default" in DISCOVERY_TTLS);
|
||||
});
|
||||
|
||||
it("all values are positive numbers", () => {
|
||||
for (const [, value] of Object.entries(DISCOVERY_TTLS)) {
|
||||
assert.equal(typeof value, "number");
|
||||
assert.ok(value > 0);
|
||||
}
|
||||
});
|
||||
});
|
||||
231
packages/pi-coding-agent/src/core/model-discovery.ts
Normal file
231
packages/pi-coding-agent/src/core/model-discovery.ts
Normal file
|
|
@ -0,0 +1,231 @@
|
|||
/**
|
||||
* Provider discovery adapters for runtime model enumeration.
|
||||
* Each adapter implements ProviderDiscoveryAdapter to fetch models from provider APIs.
|
||||
*/
|
||||
|
||||
export interface DiscoveredModel {
|
||||
id: string;
|
||||
name?: string;
|
||||
contextWindow?: number;
|
||||
maxTokens?: number;
|
||||
reasoning?: boolean;
|
||||
input?: ("text" | "image")[];
|
||||
cost?: { input: number; output: number; cacheRead: number; cacheWrite: number };
|
||||
}
|
||||
|
||||
export interface DiscoveryResult {
|
||||
provider: string;
|
||||
models: DiscoveredModel[];
|
||||
fetchedAt: number;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
export interface ProviderDiscoveryAdapter {
|
||||
provider: string;
|
||||
supportsDiscovery: boolean;
|
||||
fetchModels(apiKey: string, baseUrl?: string): Promise<DiscoveredModel[]>;
|
||||
}
|
||||
|
||||
/** Per-provider TTLs in milliseconds */
|
||||
export const DISCOVERY_TTLS: Record<string, number> = {
|
||||
ollama: 5 * 60 * 1000, // 5 minutes (local, models change often)
|
||||
openai: 60 * 60 * 1000, // 1 hour
|
||||
google: 60 * 60 * 1000, // 1 hour
|
||||
openrouter: 60 * 60 * 1000, // 1 hour
|
||||
default: 24 * 60 * 60 * 1000, // 24 hours
|
||||
};
|
||||
|
||||
export function getDefaultTTL(provider: string): number {
|
||||
return DISCOVERY_TTLS[provider] ?? DISCOVERY_TTLS.default;
|
||||
}
|
||||
|
||||
async function fetchWithTimeout(url: string, options: RequestInit = {}, timeoutMs = 5000): Promise<Response> {
|
||||
const controller = new AbortController();
|
||||
const timeout = setTimeout(() => controller.abort(), timeoutMs);
|
||||
try {
|
||||
return await fetch(url, { ...options, signal: controller.signal });
|
||||
} finally {
|
||||
clearTimeout(timeout);
|
||||
}
|
||||
}
|
||||
|
||||
// ─── OpenAI Adapter ──────────────────────────────────────────────────────────
|
||||
|
||||
const OPENAI_EXCLUDED_PREFIXES = ["embedding", "tts", "dall-e", "whisper", "text-embedding", "davinci", "babbage"];
|
||||
|
||||
class OpenAIDiscoveryAdapter implements ProviderDiscoveryAdapter {
|
||||
provider = "openai";
|
||||
supportsDiscovery = true;
|
||||
|
||||
async fetchModels(apiKey: string, baseUrl?: string): Promise<DiscoveredModel[]> {
|
||||
const url = `${baseUrl ?? "https://api.openai.com"}/v1/models`;
|
||||
const response = await fetchWithTimeout(url, {
|
||||
headers: { Authorization: `Bearer ${apiKey}` },
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`OpenAI models API returned ${response.status}: ${response.statusText}`);
|
||||
}
|
||||
|
||||
const data = (await response.json()) as { data: Array<{ id: string; owned_by?: string }> };
|
||||
return data.data
|
||||
.filter((m) => !OPENAI_EXCLUDED_PREFIXES.some((prefix) => m.id.startsWith(prefix)))
|
||||
.map((m) => ({
|
||||
id: m.id,
|
||||
name: m.id,
|
||||
input: ["text" as const, "image" as const],
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Ollama Adapter ──────────────────────────────────────────────────────────
|
||||
|
||||
class OllamaDiscoveryAdapter implements ProviderDiscoveryAdapter {
|
||||
provider = "ollama";
|
||||
supportsDiscovery = true;
|
||||
|
||||
async fetchModels(_apiKey: string, baseUrl?: string): Promise<DiscoveredModel[]> {
|
||||
const url = `${baseUrl ?? "http://localhost:11434"}/api/tags`;
|
||||
const response = await fetchWithTimeout(url);
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Ollama tags API returned ${response.status}: ${response.statusText}`);
|
||||
}
|
||||
|
||||
const data = (await response.json()) as {
|
||||
models: Array<{ name: string; size: number; details?: { parameter_size?: string } }>;
|
||||
};
|
||||
|
||||
return (data.models ?? []).map((m) => ({
|
||||
id: m.name,
|
||||
name: m.name,
|
||||
input: ["text" as const],
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
// ─── OpenRouter Adapter ──────────────────────────────────────────────────────
|
||||
|
||||
class OpenRouterDiscoveryAdapter implements ProviderDiscoveryAdapter {
|
||||
provider = "openrouter";
|
||||
supportsDiscovery = true;
|
||||
|
||||
async fetchModels(apiKey: string, baseUrl?: string): Promise<DiscoveredModel[]> {
|
||||
const url = `${baseUrl ?? "https://openrouter.ai"}/api/v1/models`;
|
||||
const response = await fetchWithTimeout(url, {
|
||||
headers: { Authorization: `Bearer ${apiKey}` },
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`OpenRouter models API returned ${response.status}: ${response.statusText}`);
|
||||
}
|
||||
|
||||
const data = (await response.json()) as {
|
||||
data: Array<{
|
||||
id: string;
|
||||
name: string;
|
||||
context_length?: number;
|
||||
top_provider?: { max_completion_tokens?: number };
|
||||
pricing?: { prompt: string; completion: string };
|
||||
}>;
|
||||
};
|
||||
|
||||
return (data.data ?? []).map((m) => {
|
||||
const cost =
|
||||
m.pricing?.prompt !== undefined && m.pricing?.completion !== undefined
|
||||
? {
|
||||
input: parseFloat(m.pricing.prompt) * 1_000_000,
|
||||
output: parseFloat(m.pricing.completion) * 1_000_000,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
}
|
||||
: undefined;
|
||||
|
||||
return {
|
||||
id: m.id,
|
||||
name: m.name,
|
||||
contextWindow: m.context_length,
|
||||
maxTokens: m.top_provider?.max_completion_tokens,
|
||||
cost,
|
||||
input: ["text" as const, "image" as const],
|
||||
};
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Google/Gemini Adapter ───────────────────────────────────────────────────
|
||||
|
||||
class GoogleDiscoveryAdapter implements ProviderDiscoveryAdapter {
|
||||
provider = "google";
|
||||
supportsDiscovery = true;
|
||||
|
||||
async fetchModels(apiKey: string, baseUrl?: string): Promise<DiscoveredModel[]> {
|
||||
const url = `${baseUrl ?? "https://generativelanguage.googleapis.com"}/v1beta/models?key=${apiKey}`;
|
||||
const response = await fetchWithTimeout(url);
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Google models API returned ${response.status}: ${response.statusText}`);
|
||||
}
|
||||
|
||||
const data = (await response.json()) as {
|
||||
models: Array<{
|
||||
name: string;
|
||||
displayName: string;
|
||||
supportedGenerationMethods?: string[];
|
||||
inputTokenLimit?: number;
|
||||
outputTokenLimit?: number;
|
||||
}>;
|
||||
};
|
||||
|
||||
return (data.models ?? [])
|
||||
.filter((m) => m.supportedGenerationMethods?.includes("generateContent"))
|
||||
.map((m) => ({
|
||||
id: m.name.replace("models/", ""),
|
||||
name: m.displayName,
|
||||
contextWindow: m.inputTokenLimit,
|
||||
maxTokens: m.outputTokenLimit,
|
||||
input: ["text" as const, "image" as const],
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Static Adapter (no discovery) ───────────────────────────────────────────
|
||||
|
||||
class StaticDiscoveryAdapter implements ProviderDiscoveryAdapter {
|
||||
provider: string;
|
||||
supportsDiscovery = false;
|
||||
|
||||
constructor(provider: string) {
|
||||
this.provider = provider;
|
||||
}
|
||||
|
||||
async fetchModels(): Promise<DiscoveredModel[]> {
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Registry ────────────────────────────────────────────────────────────────
|
||||
|
||||
const adapters: Record<string, ProviderDiscoveryAdapter> = {
|
||||
openai: new OpenAIDiscoveryAdapter(),
|
||||
ollama: new OllamaDiscoveryAdapter(),
|
||||
openrouter: new OpenRouterDiscoveryAdapter(),
|
||||
google: new GoogleDiscoveryAdapter(),
|
||||
anthropic: new StaticDiscoveryAdapter("anthropic"),
|
||||
bedrock: new StaticDiscoveryAdapter("bedrock"),
|
||||
"azure-openai": new StaticDiscoveryAdapter("azure-openai"),
|
||||
groq: new StaticDiscoveryAdapter("groq"),
|
||||
cerebras: new StaticDiscoveryAdapter("cerebras"),
|
||||
xai: new StaticDiscoveryAdapter("xai"),
|
||||
mistral: new StaticDiscoveryAdapter("mistral"),
|
||||
};
|
||||
|
||||
export function getDiscoveryAdapter(provider: string): ProviderDiscoveryAdapter {
|
||||
return adapters[provider] ?? new StaticDiscoveryAdapter(provider);
|
||||
}
|
||||
|
||||
export function getDiscoverableProviders(): string[] {
|
||||
return Object.entries(adapters)
|
||||
.filter(([, adapter]) => adapter.supportsDiscovery)
|
||||
.map(([name]) => name);
|
||||
}
|
||||
|
|
@ -0,0 +1,135 @@
|
|||
import assert from "node:assert/strict";
|
||||
import { mkdirSync, rmSync, writeFileSync } from "node:fs";
|
||||
import { tmpdir } from "node:os";
|
||||
import { join } from "node:path";
|
||||
import { afterEach, beforeEach, describe, it } from "node:test";
|
||||
import { AuthStorage } from "./auth-storage.js";
|
||||
import { ModelDiscoveryCache } from "./discovery-cache.js";
|
||||
import { getDefaultTTL, getDiscoverableProviders, getDiscoveryAdapter } from "./model-discovery.js";
|
||||
|
||||
let testDir: string;
|
||||
|
||||
beforeEach(() => {
|
||||
testDir = join(tmpdir(), `model-registry-discovery-test-${Date.now()}-${Math.random().toString(36).slice(2)}`);
|
||||
mkdirSync(testDir, { recursive: true });
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
try {
|
||||
rmSync(testDir, { recursive: true, force: true });
|
||||
} catch {
|
||||
// Cleanup best-effort
|
||||
}
|
||||
});
|
||||
|
||||
// ─── discovery cache integration ─────────────────────────────────────────────
|
||||
|
||||
describe("ModelDiscoveryCache — integration with discovery", () => {
|
||||
it("cache respects provider-specific TTLs", () => {
|
||||
const cachePath = join(testDir, "cache.json");
|
||||
const cache = new ModelDiscoveryCache(cachePath);
|
||||
|
||||
cache.set("ollama", [{ id: "llama2" }]);
|
||||
const entry = cache.get("ollama");
|
||||
assert.ok(entry);
|
||||
assert.equal(entry.ttlMs, getDefaultTTL("ollama"));
|
||||
});
|
||||
|
||||
it("cache uses custom TTL when provided", () => {
|
||||
const cachePath = join(testDir, "cache.json");
|
||||
const cache = new ModelDiscoveryCache(cachePath);
|
||||
|
||||
cache.set("openai", [{ id: "gpt-4o" }], 999);
|
||||
const entry = cache.get("openai");
|
||||
assert.ok(entry);
|
||||
assert.equal(entry.ttlMs, 999);
|
||||
});
|
||||
});
|
||||
|
||||
// ─── adapter resolution ─────────────────────────────────────────────────────
|
||||
|
||||
describe("Discovery adapter resolution", () => {
|
||||
it("all discoverable providers have adapters", () => {
|
||||
const providers = getDiscoverableProviders();
|
||||
for (const provider of providers) {
|
||||
const adapter = getDiscoveryAdapter(provider);
|
||||
assert.equal(adapter.supportsDiscovery, true, `${provider} should support discovery`);
|
||||
}
|
||||
});
|
||||
|
||||
it("static adapters return empty model lists", async () => {
|
||||
const staticProviders = ["anthropic", "bedrock", "azure-openai", "groq", "cerebras"];
|
||||
for (const provider of staticProviders) {
|
||||
const adapter = getDiscoveryAdapter(provider);
|
||||
assert.equal(adapter.supportsDiscovery, false, `${provider} should not support discovery`);
|
||||
const models = await adapter.fetchModels("dummy-key");
|
||||
assert.deepEqual(models, [], `${provider} should return empty models`);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// ─── AuthStorage hasAuth for discovery ───────────────────────────────────────
|
||||
|
||||
describe("AuthStorage — hasAuth for discovery providers", () => {
|
||||
it("returns false for providers without auth", () => {
|
||||
const storage = AuthStorage.inMemory({});
|
||||
assert.equal(storage.hasAuth("openai"), false);
|
||||
assert.equal(storage.hasAuth("ollama"), false);
|
||||
});
|
||||
|
||||
it("returns true for providers with stored keys", () => {
|
||||
const storage = AuthStorage.inMemory({
|
||||
openai: { type: "api_key" as const, key: "sk-test" },
|
||||
});
|
||||
assert.equal(storage.hasAuth("openai"), true);
|
||||
assert.equal(storage.hasAuth("ollama"), false);
|
||||
});
|
||||
});
|
||||
|
||||
// ─── cache persistence across instances ──────────────────────────────────────
|
||||
|
||||
describe("ModelDiscoveryCache — persistence", () => {
|
||||
it("data survives across cache instances", () => {
|
||||
const cachePath = join(testDir, "persist.json");
|
||||
|
||||
const cache1 = new ModelDiscoveryCache(cachePath);
|
||||
cache1.set("openai", [
|
||||
{ id: "gpt-4o", name: "GPT-4o", contextWindow: 128000 },
|
||||
{ id: "gpt-4o-mini", name: "GPT-4o Mini" },
|
||||
]);
|
||||
|
||||
const cache2 = new ModelDiscoveryCache(cachePath);
|
||||
const entry = cache2.get("openai");
|
||||
assert.ok(entry);
|
||||
assert.equal(entry.models.length, 2);
|
||||
assert.equal(entry.models[0].contextWindow, 128000);
|
||||
});
|
||||
|
||||
it("clear persists across instances", () => {
|
||||
const cachePath = join(testDir, "clear.json");
|
||||
|
||||
const cache1 = new ModelDiscoveryCache(cachePath);
|
||||
cache1.set("openai", [{ id: "gpt-4o" }]);
|
||||
cache1.clear("openai");
|
||||
|
||||
const cache2 = new ModelDiscoveryCache(cachePath);
|
||||
assert.equal(cache2.get("openai"), undefined);
|
||||
});
|
||||
});
|
||||
|
||||
// ─── discovery TTL values ────────────────────────────────────────────────────
|
||||
|
||||
describe("Discovery TTL configuration", () => {
|
||||
it("ollama has shortest TTL (local models change often)", () => {
|
||||
const ollamaTTL = getDefaultTTL("ollama");
|
||||
const openaiTTL = getDefaultTTL("openai");
|
||||
assert.ok(ollamaTTL < openaiTTL, "ollama TTL should be shorter than openai");
|
||||
});
|
||||
|
||||
it("unknown providers get default TTL", () => {
|
||||
const customTTL = getDefaultTTL("my-custom-provider");
|
||||
const defaultTTL = getDefaultTTL("default");
|
||||
// Unknown providers should get the same TTL as the explicit "default" key
|
||||
assert.equal(customTTL, defaultTTL);
|
||||
});
|
||||
});
|
||||
|
|
@ -24,6 +24,9 @@ import { existsSync, readFileSync } from "fs";
|
|||
import { join } from "path";
|
||||
import { getAgentDir } from "../config.js";
|
||||
import type { AuthStorage } from "./auth-storage.js";
|
||||
import { ModelDiscoveryCache } from "./discovery-cache.js";
|
||||
import type { DiscoveredModel, DiscoveryResult } from "./model-discovery.js";
|
||||
import { getDefaultTTL, getDiscoverableProviders, getDiscoveryAdapter } from "./model-discovery.js";
|
||||
import { clearConfigValueCache, resolveConfigValue, resolveHeaders } from "./resolve-config-value.js";
|
||||
|
||||
const Ajv = (AjvModule as any).default || AjvModule;
|
||||
|
|
@ -221,6 +224,8 @@ export const clearApiKeyCache = clearConfigValueCache;
|
|||
*/
|
||||
export class ModelRegistry {
|
||||
private models: Model<Api>[] = [];
|
||||
private discoveredModels: Model<Api>[] = [];
|
||||
private discoveryCache: ModelDiscoveryCache;
|
||||
private customProviderApiKeys: Map<string, string> = new Map();
|
||||
private registeredProviders: Map<string, ProviderConfigInput> = new Map();
|
||||
private loadError: string | undefined = undefined;
|
||||
|
|
@ -229,6 +234,8 @@ export class ModelRegistry {
|
|||
readonly authStorage: AuthStorage,
|
||||
private modelsJsonPath: string | undefined = join(getAgentDir(), "models.json"),
|
||||
) {
|
||||
this.discoveryCache = new ModelDiscoveryCache();
|
||||
|
||||
// Set up fallback resolver for custom provider API keys
|
||||
this.authStorage.setFallbackResolver((provider) => {
|
||||
const keyConfig = this.customProviderApiKeys.get(provider);
|
||||
|
|
@ -666,6 +673,106 @@ export class ModelRegistry {
|
|||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Discover models from all providers that support discovery.
|
||||
* Results are cached and merged into the registry (never overrides existing models).
|
||||
*/
|
||||
async discoverModels(providers?: string[]): Promise<DiscoveryResult[]> {
|
||||
const targetProviders = providers ?? getDiscoverableProviders();
|
||||
const results: DiscoveryResult[] = [];
|
||||
|
||||
for (const providerName of targetProviders) {
|
||||
const adapter = getDiscoveryAdapter(providerName);
|
||||
if (!adapter.supportsDiscovery) continue;
|
||||
|
||||
// Skip if cache is still fresh
|
||||
if (!this.discoveryCache.isStale(providerName)) {
|
||||
const cached = this.discoveryCache.get(providerName);
|
||||
if (cached) {
|
||||
results.push({
|
||||
provider: providerName,
|
||||
models: cached.models,
|
||||
fetchedAt: cached.fetchedAt,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
const apiKey = await this.authStorage.getApiKey(providerName);
|
||||
if (!apiKey && providerName !== "ollama") continue;
|
||||
|
||||
const models = await adapter.fetchModels(apiKey ?? "", undefined);
|
||||
this.discoveryCache.set(providerName, models);
|
||||
results.push({
|
||||
provider: providerName,
|
||||
models,
|
||||
fetchedAt: Date.now(),
|
||||
});
|
||||
} catch (error) {
|
||||
results.push({
|
||||
provider: providerName,
|
||||
models: [],
|
||||
fetchedAt: Date.now(),
|
||||
error: error instanceof Error ? error.message : String(error),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Convert and merge discovered models
|
||||
this.discoveredModels = this.convertDiscoveredModels(results);
|
||||
return results;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all models including discovered ones.
|
||||
* Discovered models are appended but never override existing models.
|
||||
*/
|
||||
getAllWithDiscovered(): Model<Api>[] {
|
||||
const existingIds = new Set(this.models.map((m) => `${m.provider}/${m.id}`));
|
||||
const unique = this.discoveredModels.filter((m) => !existingIds.has(`${m.provider}/${m.id}`));
|
||||
return [...this.models, ...unique];
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a model was added via discovery (not built-in or custom).
|
||||
*/
|
||||
isDiscovered(model: Model<Api>): boolean {
|
||||
return this.discoveredModels.some((m) => m.provider === model.provider && m.id === model.id);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the discovery cache instance.
|
||||
*/
|
||||
getDiscoveryCache(): ModelDiscoveryCache {
|
||||
return this.discoveryCache;
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert DiscoveryResult[] into Model<Api>[] with default values.
|
||||
*/
|
||||
private convertDiscoveredModels(results: DiscoveryResult[]): Model<Api>[] {
|
||||
const converted: Model<Api>[] = [];
|
||||
for (const result of results) {
|
||||
if (result.error) continue;
|
||||
for (const dm of result.models) {
|
||||
converted.push({
|
||||
id: dm.id,
|
||||
name: dm.name ?? dm.id,
|
||||
api: "openai" as Api,
|
||||
provider: result.provider,
|
||||
baseUrl: "",
|
||||
reasoning: dm.reasoning ?? false,
|
||||
input: dm.input ?? ["text"],
|
||||
cost: dm.cost ?? { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: dm.contextWindow ?? 128000,
|
||||
maxTokens: dm.maxTokens ?? 16384,
|
||||
} as Model<Api>);
|
||||
}
|
||||
}
|
||||
return converted;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
145
packages/pi-coding-agent/src/core/models-json-writer.test.ts
Normal file
145
packages/pi-coding-agent/src/core/models-json-writer.test.ts
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
import assert from "node:assert/strict";
|
||||
import { existsSync, mkdirSync, readFileSync, rmSync } from "node:fs";
|
||||
import { tmpdir } from "node:os";
|
||||
import { join } from "node:path";
|
||||
import { afterEach, beforeEach, describe, it } from "node:test";
|
||||
import { ModelsJsonWriter } from "./models-json-writer.js";
|
||||
|
||||
let testDir: string;
|
||||
let modelsJsonPath: string;
|
||||
|
||||
beforeEach(() => {
|
||||
testDir = join(tmpdir(), `models-json-writer-test-${Date.now()}-${Math.random().toString(36).slice(2)}`);
|
||||
mkdirSync(testDir, { recursive: true });
|
||||
modelsJsonPath = join(testDir, "models.json");
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
try {
|
||||
rmSync(testDir, { recursive: true, force: true });
|
||||
} catch {
|
||||
// Cleanup best-effort
|
||||
}
|
||||
});
|
||||
|
||||
function readModels(): Record<string, unknown> {
|
||||
return JSON.parse(readFileSync(modelsJsonPath, "utf-8"));
|
||||
}
|
||||
|
||||
// ─── addModel ────────────────────────────────────────────────────────────────
|
||||
|
||||
describe("ModelsJsonWriter — addModel", () => {
|
||||
it("creates file and adds model to new provider", () => {
|
||||
const writer = new ModelsJsonWriter(modelsJsonPath);
|
||||
writer.addModel("openai", { id: "gpt-4o", name: "GPT-4o" }, { baseUrl: "https://api.openai.com", apiKey: "env:OPENAI_API_KEY", api: "openai" });
|
||||
|
||||
const config = readModels() as any;
|
||||
assert.ok(config.providers.openai);
|
||||
assert.equal(config.providers.openai.models.length, 1);
|
||||
assert.equal(config.providers.openai.models[0].id, "gpt-4o");
|
||||
});
|
||||
|
||||
it("appends model to existing provider", () => {
|
||||
const writer = new ModelsJsonWriter(modelsJsonPath);
|
||||
writer.addModel("openai", { id: "gpt-4o" }, { baseUrl: "https://api.openai.com", apiKey: "env:OPENAI_API_KEY", api: "openai" });
|
||||
writer.addModel("openai", { id: "gpt-4o-mini" });
|
||||
|
||||
const config = readModels() as any;
|
||||
assert.equal(config.providers.openai.models.length, 2);
|
||||
});
|
||||
|
||||
it("replaces model with same id", () => {
|
||||
const writer = new ModelsJsonWriter(modelsJsonPath);
|
||||
writer.addModel("openai", { id: "gpt-4o", name: "Old" }, { baseUrl: "https://api.openai.com", apiKey: "env:OPENAI_API_KEY", api: "openai" });
|
||||
writer.addModel("openai", { id: "gpt-4o", name: "New" });
|
||||
|
||||
const config = readModels() as any;
|
||||
assert.equal(config.providers.openai.models.length, 1);
|
||||
assert.equal(config.providers.openai.models[0].name, "New");
|
||||
});
|
||||
});
|
||||
|
||||
// ─── removeModel ─────────────────────────────────────────────────────────────
|
||||
|
||||
describe("ModelsJsonWriter — removeModel", () => {
|
||||
it("removes a model from provider", () => {
|
||||
const writer = new ModelsJsonWriter(modelsJsonPath);
|
||||
writer.addModel("openai", { id: "gpt-4o" }, { baseUrl: "https://api.openai.com", apiKey: "env:OPENAI_API_KEY", api: "openai" });
|
||||
writer.addModel("openai", { id: "gpt-4o-mini" });
|
||||
|
||||
writer.removeModel("openai", "gpt-4o");
|
||||
|
||||
const config = readModels() as any;
|
||||
assert.equal(config.providers.openai.models.length, 1);
|
||||
assert.equal(config.providers.openai.models[0].id, "gpt-4o-mini");
|
||||
});
|
||||
|
||||
it("removes provider when last model is removed", () => {
|
||||
const writer = new ModelsJsonWriter(modelsJsonPath);
|
||||
writer.addModel("openai", { id: "gpt-4o" }, { baseUrl: "https://api.openai.com", apiKey: "env:OPENAI_API_KEY", api: "openai" });
|
||||
|
||||
writer.removeModel("openai", "gpt-4o");
|
||||
|
||||
const config = readModels() as any;
|
||||
assert.equal(config.providers.openai, undefined);
|
||||
});
|
||||
|
||||
it("handles removing from nonexistent provider", () => {
|
||||
const writer = new ModelsJsonWriter(modelsJsonPath);
|
||||
// Should not throw
|
||||
writer.removeModel("nonexistent", "model-id");
|
||||
});
|
||||
});
|
||||
|
||||
// ─── setProvider / removeProvider ────────────────────────────────────────────
|
||||
|
||||
describe("ModelsJsonWriter — provider operations", () => {
|
||||
it("sets a provider configuration", () => {
|
||||
const writer = new ModelsJsonWriter(modelsJsonPath);
|
||||
writer.setProvider("custom", {
|
||||
baseUrl: "http://localhost:8080",
|
||||
apiKey: "test-key",
|
||||
api: "openai",
|
||||
models: [{ id: "local-model" }],
|
||||
});
|
||||
|
||||
const config = readModels() as any;
|
||||
assert.ok(config.providers.custom);
|
||||
assert.equal(config.providers.custom.baseUrl, "http://localhost:8080");
|
||||
});
|
||||
|
||||
it("removes a provider", () => {
|
||||
const writer = new ModelsJsonWriter(modelsJsonPath);
|
||||
writer.setProvider("custom", { baseUrl: "http://localhost:8080" });
|
||||
writer.removeProvider("custom");
|
||||
|
||||
const config = readModels() as any;
|
||||
assert.equal(config.providers.custom, undefined);
|
||||
});
|
||||
|
||||
it("handles removing nonexistent provider", () => {
|
||||
const writer = new ModelsJsonWriter(modelsJsonPath);
|
||||
writer.removeProvider("nonexistent");
|
||||
// Should not throw
|
||||
});
|
||||
});
|
||||
|
||||
// ─── listProviders ───────────────────────────────────────────────────────────
|
||||
|
||||
describe("ModelsJsonWriter — listProviders", () => {
|
||||
it("returns empty config when file does not exist", () => {
|
||||
const writer = new ModelsJsonWriter(join(testDir, "nonexistent.json"));
|
||||
const config = writer.listProviders();
|
||||
assert.deepEqual(config, { providers: {} });
|
||||
});
|
||||
|
||||
it("returns current provider config", () => {
|
||||
const writer = new ModelsJsonWriter(modelsJsonPath);
|
||||
writer.setProvider("openai", { baseUrl: "https://api.openai.com" });
|
||||
writer.setProvider("ollama", { baseUrl: "http://localhost:11434" });
|
||||
|
||||
const config = writer.listProviders();
|
||||
assert.ok(config.providers.openai);
|
||||
assert.ok(config.providers.ollama);
|
||||
});
|
||||
});
|
||||
188
packages/pi-coding-agent/src/core/models-json-writer.ts
Normal file
188
packages/pi-coding-agent/src/core/models-json-writer.ts
Normal file
|
|
@ -0,0 +1,188 @@
|
|||
/**
|
||||
* Safe read-modify-write for models.json with file locking.
|
||||
* Prevents concurrent writes from corrupting the config file.
|
||||
*/
|
||||
|
||||
import { existsSync, mkdirSync, readFileSync, writeFileSync } from "fs";
|
||||
import { dirname, join } from "path";
|
||||
import lockfile from "proper-lockfile";
|
||||
import { getAgentDir } from "../config.js";
|
||||
|
||||
interface ModelDefinition {
|
||||
id: string;
|
||||
name?: string;
|
||||
api?: string;
|
||||
baseUrl?: string;
|
||||
reasoning?: boolean;
|
||||
input?: ("text" | "image")[];
|
||||
cost?: { input: number; output: number; cacheRead: number; cacheWrite: number };
|
||||
contextWindow?: number;
|
||||
maxTokens?: number;
|
||||
}
|
||||
|
||||
interface ProviderConfig {
|
||||
baseUrl?: string;
|
||||
apiKey?: string;
|
||||
api?: string;
|
||||
headers?: Record<string, string>;
|
||||
authHeader?: boolean;
|
||||
models?: ModelDefinition[];
|
||||
modelOverrides?: Record<string, Record<string, unknown>>;
|
||||
}
|
||||
|
||||
interface ModelsConfig {
|
||||
providers: Record<string, ProviderConfig>;
|
||||
}
|
||||
|
||||
export class ModelsJsonWriter {
|
||||
private modelsJsonPath: string;
|
||||
|
||||
constructor(modelsJsonPath?: string) {
|
||||
this.modelsJsonPath = modelsJsonPath ?? join(getAgentDir(), "models.json");
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a model to a provider. Creates the provider if it doesn't exist.
|
||||
*/
|
||||
addModel(provider: string, model: ModelDefinition, providerConfig?: Partial<ProviderConfig>): void {
|
||||
this.withLock((config) => {
|
||||
if (!config.providers[provider]) {
|
||||
config.providers[provider] = {
|
||||
...providerConfig,
|
||||
models: [],
|
||||
};
|
||||
}
|
||||
|
||||
const providerEntry = config.providers[provider];
|
||||
if (!providerEntry.models) {
|
||||
providerEntry.models = [];
|
||||
}
|
||||
|
||||
// Replace existing model with same id, or append
|
||||
const existingIndex = providerEntry.models.findIndex((m) => m.id === model.id);
|
||||
if (existingIndex >= 0) {
|
||||
providerEntry.models[existingIndex] = model;
|
||||
} else {
|
||||
providerEntry.models.push(model);
|
||||
}
|
||||
|
||||
return config;
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove a model from a provider. Removes the provider if no models remain.
|
||||
*/
|
||||
removeModel(provider: string, modelId: string): void {
|
||||
this.withLock((config) => {
|
||||
const providerEntry = config.providers[provider];
|
||||
if (!providerEntry?.models) return config;
|
||||
|
||||
providerEntry.models = providerEntry.models.filter((m) => m.id !== modelId);
|
||||
|
||||
// Clean up empty provider (no models and no overrides)
|
||||
if (providerEntry.models.length === 0 && !providerEntry.modelOverrides) {
|
||||
delete config.providers[provider];
|
||||
}
|
||||
|
||||
return config;
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Set or update an entire provider configuration.
|
||||
*/
|
||||
setProvider(provider: string, providerConfig: ProviderConfig): void {
|
||||
this.withLock((config) => {
|
||||
config.providers[provider] = providerConfig;
|
||||
return config;
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove a provider and all its models.
|
||||
*/
|
||||
removeProvider(provider: string): void {
|
||||
this.withLock((config) => {
|
||||
delete config.providers[provider];
|
||||
return config;
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* List all providers and their configurations.
|
||||
*/
|
||||
listProviders(): ModelsConfig {
|
||||
return this.readConfig();
|
||||
}
|
||||
|
||||
private readConfig(): ModelsConfig {
|
||||
if (!existsSync(this.modelsJsonPath)) {
|
||||
return { providers: {} };
|
||||
}
|
||||
try {
|
||||
const content = readFileSync(this.modelsJsonPath, "utf-8");
|
||||
return JSON.parse(content) as ModelsConfig;
|
||||
} catch {
|
||||
return { providers: {} };
|
||||
}
|
||||
}
|
||||
|
||||
private writeConfig(config: ModelsConfig): void {
|
||||
const dir = dirname(this.modelsJsonPath);
|
||||
if (!existsSync(dir)) {
|
||||
mkdirSync(dir, { recursive: true });
|
||||
}
|
||||
writeFileSync(this.modelsJsonPath, JSON.stringify(config, null, 2), "utf-8");
|
||||
}
|
||||
|
||||
private acquireLockWithRetry(): () => void {
|
||||
const maxAttempts = 10;
|
||||
const delayMs = 20;
|
||||
let lastError: unknown;
|
||||
|
||||
// Ensure file exists for locking
|
||||
const dir = dirname(this.modelsJsonPath);
|
||||
if (!existsSync(dir)) {
|
||||
mkdirSync(dir, { recursive: true });
|
||||
}
|
||||
if (!existsSync(this.modelsJsonPath)) {
|
||||
writeFileSync(this.modelsJsonPath, JSON.stringify({ providers: {} }, null, 2), "utf-8");
|
||||
}
|
||||
|
||||
for (let attempt = 1; attempt <= maxAttempts; attempt++) {
|
||||
try {
|
||||
return lockfile.lockSync(this.modelsJsonPath, { realpath: false });
|
||||
} catch (error) {
|
||||
const code =
|
||||
typeof error === "object" && error !== null && "code" in error
|
||||
? String((error as { code?: unknown }).code)
|
||||
: undefined;
|
||||
if (code !== "ELOCKED" || attempt === maxAttempts) {
|
||||
throw error;
|
||||
}
|
||||
lastError = error;
|
||||
const start = Date.now();
|
||||
while (Date.now() - start < delayMs) {
|
||||
// Busy-wait (same pattern as auth-storage.ts)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
throw (lastError as Error) ?? new Error("Failed to acquire models.json lock");
|
||||
}
|
||||
|
||||
private withLock(fn: (config: ModelsConfig) => ModelsConfig): void {
|
||||
let release: (() => void) | undefined;
|
||||
try {
|
||||
release = this.acquireLockWithRetry();
|
||||
const config = this.readConfig();
|
||||
const updated = fn(config);
|
||||
this.writeConfig(updated);
|
||||
} finally {
|
||||
if (release) {
|
||||
release();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -79,6 +79,13 @@ export interface FallbackSettings {
|
|||
chains?: Record<string, FallbackChainEntry[]>; // keyed by chain name
|
||||
}
|
||||
|
||||
export interface ModelDiscoverySettings {
|
||||
enabled?: boolean; // default: false
|
||||
providers?: string[]; // limit discovery to specific providers
|
||||
ttlMinutes?: number; // override default TTLs (in minutes)
|
||||
autoRefreshOnModelSelect?: boolean; // default: false - refresh discovery when opening model selector
|
||||
}
|
||||
|
||||
export type TransportSetting = Transport;
|
||||
|
||||
/**
|
||||
|
|
@ -134,6 +141,7 @@ export interface Settings {
|
|||
bashInterceptor?: BashInterceptorSettings;
|
||||
taskIsolation?: TaskIsolationSettings;
|
||||
fallback?: FallbackSettings;
|
||||
modelDiscovery?: ModelDiscoverySettings;
|
||||
}
|
||||
|
||||
/** Deep merge settings: project/overrides take precedence, nested objects merge recursively */
|
||||
|
|
@ -1076,4 +1084,17 @@ export class SettingsManager {
|
|||
chains: this.getFallbackChains(),
|
||||
};
|
||||
}
|
||||
|
||||
getModelDiscoverySettings(): ModelDiscoverySettings {
|
||||
return this.settings.modelDiscovery ?? {};
|
||||
}
|
||||
|
||||
setModelDiscoveryEnabled(enabled: boolean): void {
|
||||
if (!this.globalSettings.modelDiscovery) {
|
||||
this.globalSettings.modelDiscovery = {};
|
||||
}
|
||||
this.globalSettings.modelDiscovery.enabled = enabled;
|
||||
this.markModified("modelDiscovery", "enabled");
|
||||
this.save();
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ export const BUILTIN_SLASH_COMMANDS: ReadonlyArray<BuiltinSlashCommand> = [
|
|||
{ name: "hotkeys", description: "Show all keyboard shortcuts" },
|
||||
{ name: "fork", description: "Create a new fork from a previous message" },
|
||||
{ name: "tree", description: "Navigate session tree (switch branches)" },
|
||||
{ name: "provider", description: "Manage provider configuration" },
|
||||
{ name: "login", description: "Login with OAuth provider" },
|
||||
{ name: "logout", description: "Logout from OAuth provider" },
|
||||
{ name: "new", description: "Start a new session" },
|
||||
|
|
|
|||
|
|
@ -143,7 +143,11 @@ export {
|
|||
// Footer data provider (git branch + extension statuses - data not otherwise available to extensions)
|
||||
export type { ReadonlyFooterDataProvider } from "./core/footer-data-provider.js";
|
||||
export { convertToLlm } from "./core/messages.js";
|
||||
export { ModelDiscoveryCache } from "./core/discovery-cache.js";
|
||||
export type { DiscoveredModel, DiscoveryResult, ProviderDiscoveryAdapter } from "./core/model-discovery.js";
|
||||
export { getDiscoverableProviders, getDiscoveryAdapter } from "./core/model-discovery.js";
|
||||
export { ModelRegistry } from "./core/model-registry.js";
|
||||
export { ModelsJsonWriter } from "./core/models-json-writer.js";
|
||||
export type {
|
||||
PackageManager,
|
||||
PathMetadata,
|
||||
|
|
@ -307,6 +311,7 @@ export {
|
|||
LoginDialogComponent,
|
||||
ModelSelectorComponent,
|
||||
OAuthSelectorComponent,
|
||||
ProviderManagerComponent,
|
||||
type RenderDiffOptions,
|
||||
rawKeyHint,
|
||||
renderDiff,
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import { createInterface } from "readline";
|
|||
import { type Args, parseArgs, printHelp } from "./cli/args.js";
|
||||
import { selectConfig } from "./cli/config-selector.js";
|
||||
import { processFileArguments } from "./cli/file-processor.js";
|
||||
import { listModels } from "./cli/list-models.js";
|
||||
import { discoverAndPrintModels, listModels } from "./cli/list-models.js";
|
||||
import { selectSession } from "./cli/session-picker.js";
|
||||
import { APP_NAME, getAgentDir, getModelsPath, VERSION } from "./config.js";
|
||||
import { AuthStorage } from "./core/auth-storage.js";
|
||||
|
|
@ -660,9 +660,26 @@ export async function main(args: string[]) {
|
|||
process.exit(0);
|
||||
}
|
||||
|
||||
if (parsed.addProvider) {
|
||||
const { ModelsJsonWriter } = await import("./core/models-json-writer.js");
|
||||
const writer = new ModelsJsonWriter();
|
||||
writer.setProvider(parsed.addProvider, {
|
||||
baseUrl: parsed.addProviderBaseUrl,
|
||||
apiKey: parsed.apiKey,
|
||||
});
|
||||
console.log(`Provider "${parsed.addProvider}" added to models.json`);
|
||||
process.exit(0);
|
||||
}
|
||||
|
||||
if (parsed.discoverModels !== undefined) {
|
||||
const provider = typeof parsed.discoverModels === "string" ? parsed.discoverModels : undefined;
|
||||
await discoverAndPrintModels(modelRegistry, provider);
|
||||
process.exit(0);
|
||||
}
|
||||
|
||||
if (parsed.listModels !== undefined) {
|
||||
const searchPattern = typeof parsed.listModels === "string" ? parsed.listModels : undefined;
|
||||
await listModels(modelRegistry, searchPattern);
|
||||
await listModels(modelRegistry, { searchPattern, discover: parsed.discover });
|
||||
process.exit(0);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ export { appKey, appKeyHint, editorKey, keyHint, rawKeyHint } from "./keybinding
|
|||
export { LoginDialogComponent } from "./login-dialog.js";
|
||||
export { ModelSelectorComponent } from "./model-selector.js";
|
||||
export { OAuthSelectorComponent } from "./oauth-selector.js";
|
||||
export { ProviderManagerComponent } from "./provider-manager.js";
|
||||
export { type ModelsCallbacks, type ModelsConfig, ScopedModelsSelectorComponent } from "./scoped-models-selector.js";
|
||||
export { SessionSelectorComponent } from "./session-selector.js";
|
||||
export { type SettingsCallbacks, type SettingsConfig, SettingsSelectorComponent } from "./settings-selector.js";
|
||||
|
|
|
|||
|
|
@ -160,7 +160,7 @@ export class ModelSelectorComponent extends Container implements Focusable {
|
|||
|
||||
// Load available models (built-in models still work even if models.json failed)
|
||||
try {
|
||||
const availableModels = await this.modelRegistry.getAvailable();
|
||||
const availableModels = this.modelRegistry.getAvailable();
|
||||
models = availableModels.map((model: Model<any>) => ({
|
||||
provider: model.provider,
|
||||
id: model.id,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,163 @@
|
|||
/**
|
||||
* TUI component for managing provider configurations.
|
||||
* Shows providers with auth status, discovery support, and model counts.
|
||||
*/
|
||||
|
||||
import {
|
||||
Container,
|
||||
type Focusable,
|
||||
getEditorKeybindings,
|
||||
Spacer,
|
||||
Text,
|
||||
type TUI,
|
||||
} from "@gsd/pi-tui";
|
||||
import type { AuthStorage } from "../../../core/auth-storage.js";
|
||||
import { getDiscoverableProviders } from "../../../core/model-discovery.js";
|
||||
import type { ModelRegistry } from "../../../core/model-registry.js";
|
||||
import { theme } from "../theme/theme.js";
|
||||
import { rawKeyHint } from "./keybinding-hints.js";
|
||||
|
||||
interface ProviderInfo {
|
||||
name: string;
|
||||
hasAuth: boolean;
|
||||
supportsDiscovery: boolean;
|
||||
modelCount: number;
|
||||
}
|
||||
|
||||
export class ProviderManagerComponent extends Container implements Focusable {
|
||||
private _focused = false;
|
||||
get focused(): boolean {
|
||||
return this._focused;
|
||||
}
|
||||
set focused(value: boolean) {
|
||||
this._focused = value;
|
||||
}
|
||||
|
||||
private providers: ProviderInfo[] = [];
|
||||
private selectedIndex = 0;
|
||||
private listContainer: Container;
|
||||
private tui: TUI;
|
||||
private authStorage: AuthStorage;
|
||||
private modelRegistry: ModelRegistry;
|
||||
private onDone: () => void;
|
||||
private onDiscover: (provider: string) => void;
|
||||
|
||||
constructor(
|
||||
tui: TUI,
|
||||
authStorage: AuthStorage,
|
||||
modelRegistry: ModelRegistry,
|
||||
onDone: () => void,
|
||||
onDiscover: (provider: string) => void,
|
||||
) {
|
||||
super();
|
||||
|
||||
this.tui = tui;
|
||||
this.authStorage = authStorage;
|
||||
this.modelRegistry = modelRegistry;
|
||||
this.onDone = onDone;
|
||||
this.onDiscover = onDiscover;
|
||||
|
||||
// Header
|
||||
this.addChild(new Text(theme.fg("accent", "Provider Manager"), 0, 0));
|
||||
this.addChild(new Spacer(1));
|
||||
|
||||
// Hints
|
||||
const hints = [
|
||||
rawKeyHint("d", "discover"),
|
||||
rawKeyHint("r", "remove auth"),
|
||||
rawKeyHint("esc", "close"),
|
||||
].join(" ");
|
||||
this.addChild(new Text(hints, 0, 0));
|
||||
this.addChild(new Spacer(1));
|
||||
|
||||
// List
|
||||
this.listContainer = new Container();
|
||||
this.addChild(this.listContainer);
|
||||
|
||||
this.loadProviders();
|
||||
this.updateList();
|
||||
}
|
||||
|
||||
private loadProviders(): void {
|
||||
const discoverableSet = new Set(getDiscoverableProviders());
|
||||
const allModels = this.modelRegistry.getAll();
|
||||
|
||||
// Group models by provider
|
||||
const providerModelCounts = new Map<string, number>();
|
||||
for (const model of allModels) {
|
||||
providerModelCounts.set(model.provider, (providerModelCounts.get(model.provider) ?? 0) + 1);
|
||||
}
|
||||
|
||||
// Build provider list from all known providers
|
||||
const providerNames = new Set([
|
||||
...providerModelCounts.keys(),
|
||||
...discoverableSet,
|
||||
]);
|
||||
|
||||
this.providers = Array.from(providerNames)
|
||||
.sort()
|
||||
.map((name) => ({
|
||||
name,
|
||||
hasAuth: this.authStorage.hasAuth(name),
|
||||
supportsDiscovery: discoverableSet.has(name),
|
||||
modelCount: providerModelCounts.get(name) ?? 0,
|
||||
}));
|
||||
}
|
||||
|
||||
private updateList(): void {
|
||||
this.listContainer.clear();
|
||||
|
||||
for (let i = 0; i < this.providers.length; i++) {
|
||||
const p = this.providers[i];
|
||||
const isSelected = i === this.selectedIndex;
|
||||
|
||||
const authBadge = p.hasAuth ? theme.fg("success", "[auth]") : theme.fg("muted", "[no auth]");
|
||||
const discoveryBadge = p.supportsDiscovery ? theme.fg("accent", "[discovery]") : "";
|
||||
const countBadge = theme.fg("muted", `(${p.modelCount} models)`);
|
||||
|
||||
const prefix = isSelected ? theme.fg("accent", "> ") : " ";
|
||||
const nameText = isSelected ? theme.fg("accent", p.name) : p.name;
|
||||
|
||||
const parts = [prefix, nameText, " ", authBadge];
|
||||
if (discoveryBadge) parts.push(" ", discoveryBadge);
|
||||
parts.push(" ", countBadge);
|
||||
|
||||
this.listContainer.addChild(new Text(parts.join(""), 0, 0));
|
||||
}
|
||||
|
||||
if (this.providers.length === 0) {
|
||||
this.listContainer.addChild(new Text(theme.fg("muted", " No providers configured"), 0, 0));
|
||||
}
|
||||
}
|
||||
|
||||
handleInput(keyData: string): void {
|
||||
const kb = getEditorKeybindings();
|
||||
|
||||
if (kb.matches(keyData, "selectUp")) {
|
||||
if (this.providers.length === 0) return;
|
||||
this.selectedIndex = this.selectedIndex === 0 ? this.providers.length - 1 : this.selectedIndex - 1;
|
||||
this.updateList();
|
||||
this.tui.requestRender();
|
||||
} else if (kb.matches(keyData, "selectDown")) {
|
||||
if (this.providers.length === 0) return;
|
||||
this.selectedIndex = this.selectedIndex === this.providers.length - 1 ? 0 : this.selectedIndex + 1;
|
||||
this.updateList();
|
||||
this.tui.requestRender();
|
||||
} else if (kb.matches(keyData, "selectCancel")) {
|
||||
this.onDone();
|
||||
} else if (keyData === "d" || keyData === "D") {
|
||||
const provider = this.providers[this.selectedIndex];
|
||||
if (provider?.supportsDiscovery) {
|
||||
this.onDiscover(provider.name);
|
||||
}
|
||||
} else if (keyData === "r" || keyData === "R") {
|
||||
const provider = this.providers[this.selectedIndex];
|
||||
if (provider?.hasAuth) {
|
||||
this.authStorage.remove(provider.name);
|
||||
this.loadProviders();
|
||||
this.updateList();
|
||||
this.tui.requestRender();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -83,6 +83,7 @@ import { appKey, appKeyHint, editorKey, formatKeyForDisplay, keyHint, rawKeyHint
|
|||
import { LoginDialogComponent } from "./components/login-dialog.js";
|
||||
import { ModelSelectorComponent } from "./components/model-selector.js";
|
||||
import { OAuthSelectorComponent } from "./components/oauth-selector.js";
|
||||
import { ProviderManagerComponent } from "./components/provider-manager.js";
|
||||
import { ScopedModelsSelectorComponent } from "./components/scoped-models-selector.js";
|
||||
import { SessionSelectorComponent } from "./components/session-selector.js";
|
||||
import { SelectSubmenu, SettingsSelectorComponent, THINKING_DESCRIPTIONS } from "./components/settings-selector.js";
|
||||
|
|
@ -1997,6 +1998,11 @@ export class InteractiveMode {
|
|||
this.editor.setText("");
|
||||
return;
|
||||
}
|
||||
if (text === "/provider") {
|
||||
this.showProviderManager();
|
||||
this.editor.setText("");
|
||||
return;
|
||||
}
|
||||
if (text === "/login") {
|
||||
this.showOAuthSelector("login");
|
||||
this.editor.setText("");
|
||||
|
|
@ -3746,6 +3752,37 @@ export class InteractiveMode {
|
|||
this.showStatus("Resumed session");
|
||||
}
|
||||
|
||||
private showProviderManager(): void {
|
||||
this.showSelector((done) => {
|
||||
const component = new ProviderManagerComponent(
|
||||
this.ui,
|
||||
this.session.modelRegistry.authStorage,
|
||||
this.session.modelRegistry,
|
||||
() => {
|
||||
done();
|
||||
this.ui.requestRender();
|
||||
},
|
||||
async (provider: string) => {
|
||||
this.showStatus(`Discovering models for ${provider}...`);
|
||||
try {
|
||||
const results = await this.session.modelRegistry.discoverModels([provider]);
|
||||
const result = results[0];
|
||||
if (result?.error) {
|
||||
this.showError(`Discovery failed: ${result.error}`);
|
||||
} else {
|
||||
this.showStatus(`Discovered ${result?.models.length ?? 0} models from ${provider}`);
|
||||
}
|
||||
} catch (error) {
|
||||
this.showError(error instanceof Error ? error.message : String(error));
|
||||
}
|
||||
done();
|
||||
this.ui.requestRender();
|
||||
},
|
||||
);
|
||||
return { component, focus: component };
|
||||
});
|
||||
}
|
||||
|
||||
private async showOAuthSelector(mode: "login" | "logout"): Promise<void> {
|
||||
if (mode === "logout") {
|
||||
const providers = this.session.modelRegistry.authStorage.list();
|
||||
|
|
|
|||
|
|
@ -511,8 +511,10 @@ async function handlePrefsWizard(
|
|||
prefs.auto_supervisor = autoSup;
|
||||
}
|
||||
|
||||
// ─── Git main branch ────────────────────────────────────────────────────
|
||||
// ─── Git settings ───────────────────────────────────────────────────────
|
||||
const git: Record<string, unknown> = (prefs.git as Record<string, unknown>) ?? {};
|
||||
|
||||
// main_branch
|
||||
const currentBranch = git.main_branch ? String(git.main_branch) : "";
|
||||
const branchInput = await ctx.ui.input(
|
||||
`Git main branch${currentBranch ? ` (current: ${currentBranch})` : ""}:`,
|
||||
|
|
@ -526,6 +528,90 @@ async function handlePrefsWizard(
|
|||
delete git.main_branch;
|
||||
}
|
||||
}
|
||||
|
||||
// Boolean git toggles
|
||||
const gitBooleanFields = [
|
||||
{ key: "auto_push", label: "Auto-push commits after committing", defaultVal: false },
|
||||
{ key: "push_branches", label: "Push milestone branches to remote", defaultVal: false },
|
||||
{ key: "snapshots", label: "Create WIP snapshot commits during long tasks", defaultVal: false },
|
||||
] as const;
|
||||
|
||||
for (const field of gitBooleanFields) {
|
||||
const current = git[field.key];
|
||||
const currentStr = current !== undefined ? String(current) : "";
|
||||
const choice = await ctx.ui.select(
|
||||
`${field.label}${currentStr ? ` (current: ${currentStr})` : ` (default: ${field.defaultVal})`}:`,
|
||||
["true", "false", "(keep current)"],
|
||||
);
|
||||
if (choice && choice !== "(keep current)") {
|
||||
git[field.key] = choice === "true";
|
||||
}
|
||||
}
|
||||
|
||||
// remote
|
||||
const currentRemote = git.remote ? String(git.remote) : "";
|
||||
const remoteInput = await ctx.ui.input(
|
||||
`Git remote name${currentRemote ? ` (current: ${currentRemote})` : " (default: origin)"}:`,
|
||||
currentRemote || "origin",
|
||||
);
|
||||
if (remoteInput !== null && remoteInput !== undefined) {
|
||||
const val = remoteInput.trim();
|
||||
if (val && val !== "origin") {
|
||||
git.remote = val;
|
||||
} else if (!val && currentRemote) {
|
||||
delete git.remote;
|
||||
}
|
||||
}
|
||||
|
||||
// pre_merge_check
|
||||
const currentPreMerge = git.pre_merge_check !== undefined ? String(git.pre_merge_check) : "";
|
||||
const preMergeChoice = await ctx.ui.select(
|
||||
`Pre-merge check${currentPreMerge ? ` (current: ${currentPreMerge})` : " (default: false)"}:`,
|
||||
["true", "false", "auto", "(keep current)"],
|
||||
);
|
||||
if (preMergeChoice && preMergeChoice !== "(keep current)") {
|
||||
if (preMergeChoice === "auto") {
|
||||
git.pre_merge_check = "auto";
|
||||
} else {
|
||||
git.pre_merge_check = preMergeChoice === "true";
|
||||
}
|
||||
}
|
||||
|
||||
// commit_type
|
||||
const currentCommitType = git.commit_type ? String(git.commit_type) : "";
|
||||
const commitTypes = ["feat", "fix", "refactor", "docs", "test", "chore", "perf", "ci", "build", "style", "(inferred — default)", "(keep current)"];
|
||||
const commitChoice = await ctx.ui.select(
|
||||
`Default commit type${currentCommitType ? ` (current: ${currentCommitType})` : ""}:`,
|
||||
commitTypes,
|
||||
);
|
||||
if (commitChoice && typeof commitChoice === "string" && commitChoice !== "(keep current)") {
|
||||
if ((commitChoice as string).startsWith("(inferred")) {
|
||||
delete git.commit_type;
|
||||
} else {
|
||||
git.commit_type = commitChoice;
|
||||
}
|
||||
}
|
||||
|
||||
// merge_strategy
|
||||
const currentMerge = git.merge_strategy ? String(git.merge_strategy) : "";
|
||||
const mergeChoice = await ctx.ui.select(
|
||||
`Merge strategy${currentMerge ? ` (current: ${currentMerge})` : ""}:`,
|
||||
["squash", "merge", "(keep current)"],
|
||||
);
|
||||
if (mergeChoice && mergeChoice !== "(keep current)") {
|
||||
git.merge_strategy = mergeChoice;
|
||||
}
|
||||
|
||||
// isolation
|
||||
const currentIsolation = git.isolation ? String(git.isolation) : "";
|
||||
const isolationChoice = await ctx.ui.select(
|
||||
`Git isolation strategy${currentIsolation ? ` (current: ${currentIsolation})` : " (default: worktree)"}:`,
|
||||
["worktree", "branch", "(keep current)"],
|
||||
);
|
||||
if (isolationChoice && isolationChoice !== "(keep current)") {
|
||||
git.isolation = isolationChoice;
|
||||
}
|
||||
|
||||
// ─── Git commit_docs ────────────────────────────────────────────────────
|
||||
const currentCommitDocs = git.commit_docs;
|
||||
const commitDocsChoice = await ctx.ui.select(
|
||||
|
|
@ -560,6 +646,89 @@ async function handlePrefsWizard(
|
|||
prefs.unique_milestone_ids = uniqueChoice === "true";
|
||||
}
|
||||
|
||||
// ─── Budget & cost control ────────────────────────────────────────────
|
||||
const currentCeiling = prefs.budget_ceiling;
|
||||
const ceilingStr = currentCeiling !== undefined ? String(currentCeiling) : "";
|
||||
const ceilingInput = await ctx.ui.input(
|
||||
`Budget ceiling (USD)${ceilingStr ? ` (current: $${ceilingStr})` : " (default: no limit)"}:`,
|
||||
ceilingStr || "",
|
||||
);
|
||||
if (ceilingInput !== null && ceilingInput !== undefined) {
|
||||
const val = ceilingInput.trim().replace(/^\$/, "");
|
||||
if (val && !isNaN(Number(val)) && isFinite(Number(val))) {
|
||||
prefs.budget_ceiling = Number(val);
|
||||
} else if (val && (isNaN(Number(val)) || !isFinite(Number(val)))) {
|
||||
ctx.ui.notify(`Invalid budget ceiling "${val}" — must be a number. Keeping previous value.`, "warning");
|
||||
} else if (!val && ceilingStr) {
|
||||
delete prefs.budget_ceiling;
|
||||
}
|
||||
}
|
||||
|
||||
const currentEnforcement = (prefs.budget_enforcement as string) ?? "";
|
||||
const enforcementChoice = await ctx.ui.select(
|
||||
`Budget enforcement${currentEnforcement ? ` (current: ${currentEnforcement})` : " (default: pause)"}:`,
|
||||
["warn", "pause", "halt", "(keep current)"],
|
||||
);
|
||||
if (enforcementChoice && enforcementChoice !== "(keep current)") {
|
||||
prefs.budget_enforcement = enforcementChoice;
|
||||
}
|
||||
|
||||
const currentContextPause = prefs.context_pause_threshold;
|
||||
const contextPauseStr = currentContextPause !== undefined ? String(currentContextPause) : "";
|
||||
const contextPauseInput = await ctx.ui.input(
|
||||
`Context pause threshold (0-100%, 0=disabled)${contextPauseStr ? ` (current: ${contextPauseStr}%)` : " (default: 0)"}:`,
|
||||
contextPauseStr || "0",
|
||||
);
|
||||
if (contextPauseInput !== null && contextPauseInput !== undefined) {
|
||||
const val = contextPauseInput.trim().replace(/%$/, "");
|
||||
if (val && !isNaN(Number(val)) && Number(val) >= 0 && Number(val) <= 100) {
|
||||
const num = Number(val);
|
||||
if (num === 0) {
|
||||
delete prefs.context_pause_threshold;
|
||||
} else {
|
||||
prefs.context_pause_threshold = num;
|
||||
}
|
||||
} else if (val && (isNaN(Number(val)) || Number(val) < 0 || Number(val) > 100)) {
|
||||
ctx.ui.notify(`Invalid context pause threshold "${val}" — must be 0-100. Keeping previous value.`, "warning");
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Notifications ────────────────────────────────────────────────────
|
||||
const notif: Record<string, boolean> = (prefs.notifications as Record<string, boolean>) ?? {};
|
||||
const notifFields = [
|
||||
{ key: "enabled", label: "Notifications enabled (master toggle)", defaultVal: true },
|
||||
{ key: "on_complete", label: "Notify on unit completion", defaultVal: true },
|
||||
{ key: "on_error", label: "Notify on errors", defaultVal: true },
|
||||
{ key: "on_budget", label: "Notify on budget thresholds", defaultVal: true },
|
||||
{ key: "on_milestone", label: "Notify on milestone completion", defaultVal: true },
|
||||
{ key: "on_attention", label: "Notify when manual attention needed", defaultVal: true },
|
||||
] as const;
|
||||
|
||||
for (const field of notifFields) {
|
||||
const current = notif[field.key];
|
||||
const currentStr = current !== undefined ? String(current) : "";
|
||||
const choice = await ctx.ui.select(
|
||||
`${field.label}${currentStr ? ` (current: ${currentStr})` : ` (default: ${field.defaultVal})`}:`,
|
||||
["true", "false", "(keep current)"],
|
||||
);
|
||||
if (choice && choice !== "(keep current)") {
|
||||
notif[field.key] = choice === "true";
|
||||
}
|
||||
}
|
||||
if (Object.keys(notif).length > 0) {
|
||||
prefs.notifications = notif;
|
||||
}
|
||||
|
||||
// ─── UAT dispatch ─────────────────────────────────────────────────────
|
||||
const currentUat = prefs.uat_dispatch;
|
||||
const uatChoice = await ctx.ui.select(
|
||||
`UAT dispatch mode${currentUat !== undefined ? ` (current: ${currentUat})` : " (default: false)"}:`,
|
||||
["true", "false", "(keep current)"],
|
||||
);
|
||||
if (uatChoice && uatChoice !== "(keep current)") {
|
||||
prefs.uat_dispatch = uatChoice === "true";
|
||||
}
|
||||
|
||||
// ─── Serialize to frontmatter ───────────────────────────────────────────
|
||||
prefs.version = prefs.version || 1;
|
||||
const frontmatter = serializePreferencesToFrontmatter(prefs);
|
||||
|
|
@ -650,7 +819,10 @@ function serializePreferencesToFrontmatter(prefs: Record<string, unknown>): stri
|
|||
const orderedKeys = [
|
||||
"version", "always_use_skills", "prefer_skills", "avoid_skills",
|
||||
"skill_rules", "custom_instructions", "models", "skill_discovery",
|
||||
"auto_supervisor", "uat_dispatch", "unique_milestone_ids", "budget_ceiling", "remote_questions", "git",
|
||||
"auto_supervisor", "uat_dispatch", "unique_milestone_ids",
|
||||
"budget_ceiling", "budget_enforcement", "context_pause_threshold",
|
||||
"notifications", "remote_questions", "git",
|
||||
"post_unit_hooks", "pre_dispatch_hooks",
|
||||
];
|
||||
|
||||
const seen = new Set<string>();
|
||||
|
|
|
|||
|
|
@ -108,10 +108,51 @@ Setting `prefer_skills: []` does **not** disable skill discovery — it just mea
|
|||
- `pre_merge_check`: boolean or `"auto"` — run pre-merge checks before merging a worktree back to the integration branch. `true` always runs, `false` never runs, `"auto"` runs when CI is detected. Default: `false`.
|
||||
- `commit_type`: string — override the conventional commit type prefix. Must be one of: `feat`, `fix`, `refactor`, `docs`, `test`, `chore`, `perf`, `ci`, `build`, `style`. Default: inferred from diff content.
|
||||
- `main_branch`: string — the primary branch name for new git repos (e.g., `"main"`, `"master"`, `"trunk"`). Also used by `getMainBranch()` as the preferred branch when auto-detection is ambiguous. Default: `"main"`.
|
||||
- `merge_strategy`: `"squash"` or `"merge"` — controls how worktree branches are merged back. `"squash"` combines all commits into one; `"merge"` preserves individual commits. Default: `"squash"`.
|
||||
- `isolation`: `"worktree"` or `"branch"` — controls auto-mode git isolation strategy. `"worktree"` creates a milestone worktree for isolated work; `"branch"` works directly in the project root (useful for submodule-heavy repos). Default: `"worktree"`.
|
||||
- `commit_docs`: boolean — when `false`, prevents GSD from committing `.gsd/` planning artifacts to git. The `.gsd/` folder is added to `.gitignore` and kept local-only. Useful for teams where only some members use GSD, or when company policy requires a clean repository. Default: `true`.
|
||||
|
||||
- `unique_milestone_ids`: boolean — when `true`, generates milestone IDs in `M{seq}-{rand6}` format (e.g. `M001-eh88as`) instead of plain sequential `M001`. Prevents ID collisions in team workflows where multiple contributors create milestones concurrently. Both formats coexist — existing `M001`-style milestones remain valid. Default: `false`.
|
||||
|
||||
- `budget_ceiling`: number — maximum dollar amount to spend on auto-mode. When reached, behavior is controlled by `budget_enforcement`. Default: no limit.
|
||||
|
||||
- `budget_enforcement`: `"warn"`, `"pause"`, or `"halt"` — action taken when `budget_ceiling` is reached.
|
||||
- `warn` — log a warning but continue execution.
|
||||
- `pause` — pause auto-mode and wait for user confirmation.
|
||||
- `halt` — stop auto-mode immediately.
|
||||
- Default: `"pause"`.
|
||||
|
||||
- `context_pause_threshold`: number (0-100) — context window usage percentage at which auto-mode should pause to suggest checkpointing. Set to `0` to disable. Default: `0` (disabled).
|
||||
|
||||
- `notifications`: configures desktop notification behavior during auto-mode. Keys:
|
||||
- `enabled`: boolean — master toggle for all notifications. Default: `true`.
|
||||
- `on_complete`: boolean — notify when a unit completes. Default: `true`.
|
||||
- `on_error`: boolean — notify on errors. Default: `true`.
|
||||
- `on_budget`: boolean — notify when budget thresholds are reached. Default: `true`.
|
||||
- `on_milestone`: boolean — notify when a milestone finishes. Default: `true`.
|
||||
- `on_attention`: boolean — notify when manual attention is needed. Default: `true`.
|
||||
|
||||
- `uat_dispatch`: boolean — when `true`, enables UAT (User Acceptance Testing) dispatch mode. Default: `false`.
|
||||
|
||||
- `post_unit_hooks`: array — hooks that fire after a unit completes. Each entry has:
|
||||
- `name`: string — unique hook identifier.
|
||||
- `after`: string[] — unit types that trigger this hook (e.g., `["execute-task"]`).
|
||||
- `prompt`: string — prompt sent to the LLM. Supports `{milestoneId}`, `{sliceId}`, `{taskId}` substitutions.
|
||||
- `max_cycles`: number — max times this hook fires per trigger (default: 1, max: 10).
|
||||
- `model`: string — optional model override.
|
||||
- `artifact`: string — expected output file (skip if exists).
|
||||
- `retry_on`: string — file that triggers re-run of the trigger unit.
|
||||
- `enabled`: boolean — toggle without removing (default: `true`).
|
||||
|
||||
- `pre_dispatch_hooks`: array — hooks that fire before a unit is dispatched. Each entry has:
|
||||
- `name`: string — unique hook identifier.
|
||||
- `before`: string[] — unit types to intercept.
|
||||
- `action`: `"modify"`, `"skip"`, or `"replace"` — what to do with the unit.
|
||||
- `prepend`: string — text prepended to unit prompt (for `"modify"` action).
|
||||
- `append`: string — text appended to unit prompt (for `"modify"` action).
|
||||
- `prompt`: string — replacement prompt (for `"replace"` action).
|
||||
- `enabled`: boolean — toggle without removing (default: `true`).
|
||||
|
||||
---
|
||||
|
||||
## Best Practices
|
||||
|
|
@ -277,3 +318,56 @@ git:
|
|||
```
|
||||
|
||||
All git fields are optional. Omit any field to use the default behavior. Project-level preferences override global preferences on a per-field basis.
|
||||
|
||||
---
|
||||
|
||||
## Budget & Cost Control Example
|
||||
|
||||
```yaml
|
||||
---
|
||||
version: 1
|
||||
budget_ceiling: 10.00
|
||||
budget_enforcement: pause
|
||||
context_pause_threshold: 80
|
||||
---
|
||||
```
|
||||
|
||||
Sets a $10 budget ceiling. Auto-mode pauses when the ceiling is reached. Context window pauses at 80% usage for checkpointing.
|
||||
|
||||
---
|
||||
|
||||
## Notifications Example
|
||||
|
||||
```yaml
|
||||
---
|
||||
version: 1
|
||||
notifications:
|
||||
enabled: true
|
||||
on_complete: false
|
||||
on_error: true
|
||||
on_budget: true
|
||||
on_milestone: true
|
||||
on_attention: true
|
||||
---
|
||||
```
|
||||
|
||||
Disables per-unit completion notifications (noisy in long runs) while keeping error, budget, milestone, and attention notifications enabled.
|
||||
|
||||
---
|
||||
|
||||
## Post-Unit Hooks Example
|
||||
|
||||
```yaml
|
||||
---
|
||||
version: 1
|
||||
post_unit_hooks:
|
||||
- name: code-review
|
||||
after:
|
||||
- execute-task
|
||||
prompt: "Review the code changes in {sliceId}/{taskId} for quality, security, and test coverage."
|
||||
max_cycles: 1
|
||||
artifact: REVIEW.md
|
||||
---
|
||||
```
|
||||
|
||||
Runs an automated code review after each task execution. Skips if `REVIEW.md` already exists (idempotent).
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import { existsSync, readdirSync, readFileSync, statSync } from "node:fs";
|
||||
import { existsSync, readdirSync, readFileSync, statSync, writeFileSync } from "node:fs";
|
||||
import { homedir } from "node:os";
|
||||
import { isAbsolute, join } from "node:path";
|
||||
import { getAgentDir } from "@gsd/pi-coding-agent";
|
||||
|
|
@ -1252,3 +1252,61 @@ export function resolvePreDispatchHooks(): PreDispatchHookConfig[] {
|
|||
return (prefs?.preferences.pre_dispatch_hooks ?? [])
|
||||
.filter(h => h.enabled !== false);
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate a model ID string.
|
||||
* Returns true if the ID looks like a valid model identifier.
|
||||
*/
|
||||
export function validateModelId(modelId: string): boolean {
|
||||
if (!modelId || typeof modelId !== "string") return false;
|
||||
const trimmed = modelId.trim();
|
||||
if (trimmed.length === 0 || trimmed.length > 256) return false;
|
||||
// Allow alphanumeric, hyphens, underscores, dots, slashes, colons
|
||||
return /^[a-zA-Z0-9\-_./:]+$/.test(trimmed);
|
||||
}
|
||||
|
||||
/**
|
||||
* Update the models section of the global GSD preferences file.
|
||||
* Performs a safe read-modify-write: reads current content, updates the models
|
||||
* YAML block, and writes back. Creates the file if it doesn't exist.
|
||||
*/
|
||||
export function updatePreferencesModels(models: GSDModelConfigV2): void {
|
||||
const prefsPath = getGlobalGSDPreferencesPath();
|
||||
|
||||
let content = "";
|
||||
if (existsSync(prefsPath)) {
|
||||
content = readFileSync(prefsPath, "utf-8");
|
||||
}
|
||||
|
||||
// Build the new models block
|
||||
const lines: string[] = ["models:"];
|
||||
for (const [phase, value] of Object.entries(models)) {
|
||||
if (typeof value === "string") {
|
||||
lines.push(` ${phase}: ${value}`);
|
||||
} else if (value && typeof value === "object") {
|
||||
const config = value as GSDPhaseModelConfig;
|
||||
lines.push(` ${phase}:`);
|
||||
lines.push(` model: ${config.model}`);
|
||||
if (config.provider) {
|
||||
lines.push(` provider: ${config.provider}`);
|
||||
}
|
||||
if (config.fallbacks && config.fallbacks.length > 0) {
|
||||
lines.push(` fallbacks:`);
|
||||
for (const fb of config.fallbacks) {
|
||||
lines.push(` - ${fb}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
const modelsBlock = lines.join("\n");
|
||||
|
||||
// Replace existing models block or append
|
||||
const modelsRegex = /^models:[\s\S]*?(?=\n[a-z_]|\n*$)/m;
|
||||
if (modelsRegex.test(content)) {
|
||||
content = content.replace(modelsRegex, modelsBlock);
|
||||
} else {
|
||||
content = content.trimEnd() + "\n\n" + modelsBlock + "\n";
|
||||
}
|
||||
|
||||
writeFileSync(prefsPath, content, "utf-8");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,7 +15,21 @@ git:
|
|||
snapshots:
|
||||
pre_merge_check:
|
||||
commit_type:
|
||||
main_branch:
|
||||
merge_strategy:
|
||||
isolation:
|
||||
unique_milestone_ids:
|
||||
budget_ceiling:
|
||||
budget_enforcement:
|
||||
context_pause_threshold:
|
||||
notifications:
|
||||
enabled:
|
||||
on_complete:
|
||||
on_error:
|
||||
on_budget:
|
||||
on_milestone:
|
||||
on_attention:
|
||||
uat_dispatch:
|
||||
---
|
||||
|
||||
# GSD Skill Preferences
|
||||
|
|
|
|||
|
|
@ -0,0 +1,168 @@
|
|||
/**
|
||||
* preferences-wizard-fields.test.ts — Validates that all wizard-configurable
|
||||
* preference fields are properly validated and round-trip through the schema.
|
||||
*/
|
||||
|
||||
import { createTestContext } from "./test-helpers.ts";
|
||||
import { validatePreferences } from "../preferences.ts";
|
||||
import type { GSDPreferences } from "../preferences.ts";
|
||||
|
||||
const { assertEq, assertTrue, report } = createTestContext();
|
||||
|
||||
async function main(): Promise<void> {
|
||||
console.log("\n=== budget fields validate correctly ===");
|
||||
|
||||
{
|
||||
const { preferences, errors } = validatePreferences({
|
||||
budget_ceiling: 25.50,
|
||||
budget_enforcement: "warn",
|
||||
context_pause_threshold: 80,
|
||||
});
|
||||
assertEq(errors.length, 0, "valid budget fields produce no errors");
|
||||
assertEq(preferences.budget_ceiling, 25.50, "budget_ceiling passes through");
|
||||
assertEq(preferences.budget_enforcement, "warn", "budget_enforcement passes through");
|
||||
assertEq(preferences.context_pause_threshold, 80, "context_pause_threshold passes through");
|
||||
}
|
||||
|
||||
{
|
||||
const { preferences, errors } = validatePreferences({
|
||||
budget_enforcement: "pause",
|
||||
});
|
||||
assertEq(errors.length, 0, "budget_enforcement 'pause' is valid");
|
||||
assertEq(preferences.budget_enforcement, "pause", "pause passes through");
|
||||
}
|
||||
|
||||
{
|
||||
const { preferences, errors } = validatePreferences({
|
||||
budget_enforcement: "halt",
|
||||
});
|
||||
assertEq(errors.length, 0, "budget_enforcement 'halt' is valid");
|
||||
assertEq(preferences.budget_enforcement, "halt", "halt passes through");
|
||||
}
|
||||
|
||||
{
|
||||
const { errors } = validatePreferences({
|
||||
budget_enforcement: "invalid",
|
||||
} as unknown as GSDPreferences);
|
||||
assertTrue(errors.some(e => e.includes("budget_enforcement")), "invalid budget_enforcement rejected");
|
||||
}
|
||||
|
||||
console.log("\n=== notification fields validate correctly ===");
|
||||
|
||||
{
|
||||
const { preferences, errors } = validatePreferences({
|
||||
notifications: {
|
||||
enabled: true,
|
||||
on_complete: false,
|
||||
on_error: true,
|
||||
on_budget: true,
|
||||
on_milestone: false,
|
||||
on_attention: true,
|
||||
},
|
||||
});
|
||||
assertEq(errors.length, 0, "valid notifications produce no errors");
|
||||
assertEq(preferences.notifications?.enabled, true, "notifications.enabled passes through");
|
||||
assertEq(preferences.notifications?.on_complete, false, "notifications.on_complete passes through");
|
||||
assertEq(preferences.notifications?.on_milestone, false, "notifications.on_milestone passes through");
|
||||
}
|
||||
|
||||
{
|
||||
const { errors } = validatePreferences({
|
||||
notifications: "invalid",
|
||||
} as unknown as GSDPreferences);
|
||||
assertTrue(errors.some(e => e.includes("notifications")), "invalid notifications rejected");
|
||||
}
|
||||
|
||||
console.log("\n=== git fields validate correctly ===");
|
||||
|
||||
{
|
||||
const { preferences, errors } = validatePreferences({
|
||||
git: {
|
||||
auto_push: true,
|
||||
push_branches: false,
|
||||
remote: "upstream",
|
||||
snapshots: true,
|
||||
pre_merge_check: "auto",
|
||||
commit_type: "feat",
|
||||
main_branch: "develop",
|
||||
merge_strategy: "squash",
|
||||
isolation: "branch",
|
||||
},
|
||||
});
|
||||
assertEq(errors.length, 0, "valid git fields produce no errors");
|
||||
assertEq(preferences.git?.auto_push, true, "git.auto_push passes through");
|
||||
assertEq(preferences.git?.push_branches, false, "git.push_branches passes through");
|
||||
assertEq(preferences.git?.remote, "upstream", "git.remote passes through");
|
||||
assertEq(preferences.git?.snapshots, true, "git.snapshots passes through");
|
||||
assertEq(preferences.git?.pre_merge_check, "auto", "git.pre_merge_check passes through");
|
||||
assertEq(preferences.git?.commit_type, "feat", "git.commit_type passes through");
|
||||
assertEq(preferences.git?.main_branch, "develop", "git.main_branch passes through");
|
||||
assertEq(preferences.git?.merge_strategy, "squash", "git.merge_strategy passes through");
|
||||
assertEq(preferences.git?.isolation, "branch", "git.isolation passes through");
|
||||
}
|
||||
|
||||
console.log("\n=== uat_dispatch validates correctly ===");
|
||||
|
||||
{
|
||||
const { preferences, errors } = validatePreferences({ uat_dispatch: true });
|
||||
assertEq(errors.length, 0, "valid uat_dispatch produces no errors");
|
||||
assertEq(preferences.uat_dispatch, true, "uat_dispatch true passes through");
|
||||
}
|
||||
|
||||
{
|
||||
const { preferences, errors } = validatePreferences({ uat_dispatch: false });
|
||||
assertEq(errors.length, 0, "valid uat_dispatch false produces no errors");
|
||||
assertEq(preferences.uat_dispatch, false, "uat_dispatch false passes through");
|
||||
}
|
||||
|
||||
console.log("\n=== unique_milestone_ids validates correctly ===");
|
||||
|
||||
{
|
||||
const { preferences, errors } = validatePreferences({ unique_milestone_ids: true });
|
||||
assertEq(errors.length, 0, "valid unique_milestone_ids produces no errors");
|
||||
assertEq(preferences.unique_milestone_ids, true, "unique_milestone_ids passes through");
|
||||
}
|
||||
|
||||
console.log("\n=== all wizard fields together produce no errors ===");
|
||||
|
||||
{
|
||||
const fullPrefs: GSDPreferences = {
|
||||
version: 1,
|
||||
models: { research: "claude-opus-4-6", planning: "claude-sonnet-4-6" },
|
||||
auto_supervisor: { soft_timeout_minutes: 15, idle_timeout_minutes: 5, hard_timeout_minutes: 25 },
|
||||
git: {
|
||||
main_branch: "main",
|
||||
auto_push: true,
|
||||
push_branches: false,
|
||||
remote: "origin",
|
||||
snapshots: true,
|
||||
pre_merge_check: "auto",
|
||||
commit_type: "feat",
|
||||
merge_strategy: "squash",
|
||||
isolation: "worktree",
|
||||
},
|
||||
skill_discovery: "suggest",
|
||||
unique_milestone_ids: false,
|
||||
budget_ceiling: 50,
|
||||
budget_enforcement: "pause",
|
||||
context_pause_threshold: 75,
|
||||
notifications: {
|
||||
enabled: true,
|
||||
on_complete: true,
|
||||
on_error: true,
|
||||
on_budget: true,
|
||||
on_milestone: true,
|
||||
on_attention: true,
|
||||
},
|
||||
uat_dispatch: false,
|
||||
};
|
||||
const { errors, warnings } = validatePreferences(fullPrefs);
|
||||
const unknownWarnings = warnings.filter(w => w.includes("unknown"));
|
||||
assertEq(errors.length, 0, "full wizard prefs produce no errors");
|
||||
assertEq(unknownWarnings.length, 0, "full wizard prefs produce no unknown-key warnings");
|
||||
}
|
||||
|
||||
report();
|
||||
}
|
||||
|
||||
main();
|
||||
Loading…
Add table
Reference in a new issue