Просмотр исходного кода

perf(db): batch node lookups, fix insertNode cache, run maintenance after writes (#108)

Batch getNodesByIds to collapse N+1 reads in graph traversal, invalidate the
insertNode LRU cache so INSERT OR REPLACE doesn't serve a stale row, and run
incremental PRAGMA optimize + passive WAL checkpoint after bulk writes.

Closes #108
andreinknv 1 месяц назад
Родитель
Сommit
b13f2f1ba1
5 измененных файлов с 330 добавлено и 47 удалено
  1. 161 0
      __tests__/db-perf.test.ts
  2. 30 0
      src/db/index.ts
  3. 59 0
      src/db/queries.ts
  4. 69 47
      src/graph/traversal.ts
  5. 11 0
      src/index.ts

+ 161 - 0
__tests__/db-perf.test.ts

@@ -0,0 +1,161 @@
+/**
+ * DB Performance / Correctness Tests
+ *
+ * Regression tests for three changes:
+ *   1. Batch `getNodesByIds` collapses graph-traversal N+1 reads.
+ *   2. `insertNode` invalidates the LRU cache so INSERT OR REPLACE
+ *      doesn't serve a stale cached row on next `getNodeById`.
+ *   3. `runMaintenance` runs `PRAGMA optimize` + `wal_checkpoint(PASSIVE)`
+ *      after indexAll/sync without throwing.
+ */
+
+import { describe, it, expect, beforeEach, afterEach } from 'vitest';
+import * as fs from 'fs';
+import * as path from 'path';
+import * as os from 'os';
+import { DatabaseConnection } from '../src/db';
+import { QueryBuilder } from '../src/db/queries';
+import { Node } from '../src/types';
+
+function makeNode(id: string, name = id): Node {
+  return {
+    id,
+    kind: 'function',
+    name,
+    qualifiedName: name,
+    filePath: 'a.ts',
+    language: 'typescript',
+    startLine: 1,
+    endLine: 1,
+    startColumn: 0,
+    endColumn: 0,
+    updatedAt: Date.now(),
+  };
+}
+
+describe('getNodesByIds (batch lookup)', () => {
+  let dir: string;
+  let db: DatabaseConnection;
+  let q: QueryBuilder;
+
+  beforeEach(() => {
+    dir = fs.mkdtempSync(path.join(os.tmpdir(), 'db-perf-batch-'));
+    db = DatabaseConnection.initialize(path.join(dir, 'test.db'));
+    q = new QueryBuilder(db.getDb());
+  });
+
+  afterEach(() => {
+    db.close();
+    if (fs.existsSync(dir)) fs.rmSync(dir, { recursive: true, force: true });
+  });
+
+  it('returns a Map keyed by id, with one entry per existing node', () => {
+    q.insertNodes([makeNode('n1'), makeNode('n2'), makeNode('n3')]);
+    const out = q.getNodesByIds(['n1', 'n2', 'n3']);
+    expect(out.size).toBe(3);
+    expect(out.get('n1')!.name).toBe('n1');
+    expect(out.get('n3')!.name).toBe('n3');
+  });
+
+  it('omits missing IDs from the result map (no nulls, no exceptions)', () => {
+    q.insertNodes([makeNode('n1'), makeNode('n2')]);
+    const out = q.getNodesByIds(['n1', 'missing', 'n2']);
+    expect(out.size).toBe(2);
+    expect(out.has('missing')).toBe(false);
+    expect(out.has('n1')).toBe(true);
+    expect(out.has('n2')).toBe(true);
+  });
+
+  it('handles an empty input array', () => {
+    expect(q.getNodesByIds([]).size).toBe(0);
+  });
+
+  it('handles batches over the SQLite parameter limit (chunking)', () => {
+    // Insert 1500 nodes; the helper chunks at 500 internally.
+    const nodes = Array.from({ length: 1500 }, (_, i) => makeNode(`n${i}`));
+    q.insertNodes(nodes);
+    const ids = nodes.map((n) => n.id);
+    const out = q.getNodesByIds(ids);
+    expect(out.size).toBe(1500);
+    // Spot-check a few from the first / middle / last chunk.
+    expect(out.has('n0')).toBe(true);
+    expect(out.has('n750')).toBe(true);
+    expect(out.has('n1499')).toBe(true);
+  });
+
+  it('serves cache hits from memory and queries only the misses', () => {
+    q.insertNodes([makeNode('n1'), makeNode('n2'), makeNode('n3')]);
+    // Warm the cache for n1 only.
+    q.getNodeById('n1');
+    // Replace the underlying row to make a miss-vs-cache-hit detectable.
+    db.getDb().prepare('UPDATE nodes SET name = ? WHERE id = ?').run('changed', 'n1');
+    const out = q.getNodesByIds(['n1', 'n2']);
+    // The cached n1 (still 'n1', not 'changed') must be returned.
+    expect(out.get('n1')!.name).toBe('n1');
+    expect(out.get('n2')!.name).toBe('n2');
+  });
+});
+
+describe('insertNode cache invalidation', () => {
+  let dir: string;
+  let db: DatabaseConnection;
+  let q: QueryBuilder;
+
+  beforeEach(() => {
+    dir = fs.mkdtempSync(path.join(os.tmpdir(), 'db-perf-cache-'));
+    db = DatabaseConnection.initialize(path.join(dir, 'test.db'));
+    q = new QueryBuilder(db.getDb());
+  });
+
+  afterEach(() => {
+    db.close();
+    if (fs.existsSync(dir)) fs.rmSync(dir, { recursive: true, force: true });
+  });
+
+  it('does not serve a stale cached node after INSERT OR REPLACE', () => {
+    // Regression: insertNode (which uses INSERT OR REPLACE) used to skip
+    // cache invalidation, so the next getNodeById returned the pre-replace
+    // version until LRU eviction.
+    const original = makeNode('n1', 'oldName');
+    q.insertNode(original);
+    const beforeReplace = q.getNodeById('n1');
+    expect(beforeReplace!.name).toBe('oldName');
+
+    // Replace via insertNode (the bug path).
+    q.insertNode({ ...original, name: 'newName', updatedAt: Date.now() });
+    const afterReplace = q.getNodeById('n1');
+    expect(afterReplace!.name).toBe('newName');
+  });
+});
+
+describe('runMaintenance', () => {
+  let dir: string;
+  let db: DatabaseConnection;
+
+  beforeEach(() => {
+    dir = fs.mkdtempSync(path.join(os.tmpdir(), 'db-perf-maint-'));
+    db = DatabaseConnection.initialize(path.join(dir, 'test.db'));
+  });
+
+  afterEach(() => {
+    db.close();
+    if (fs.existsSync(dir)) fs.rmSync(dir, { recursive: true, force: true });
+  });
+
+  it('runs without throwing on a fresh database', () => {
+    expect(() => db.runMaintenance()).not.toThrow();
+  });
+
+  it('runs without throwing after writes', () => {
+    const q = new QueryBuilder(db.getDb());
+    q.insertNodes([makeNode('n1'), makeNode('n2')]);
+    expect(() => db.runMaintenance()).not.toThrow();
+  });
+
+  it('swallows failures rather than propagating (best-effort)', () => {
+    // Close the DB so the underlying handle would normally throw on any
+    // exec(). runMaintenance must still not propagate.
+    db.close();
+    expect(() => db.runMaintenance()).not.toThrow();
+  });
+});

+ 30 - 0
src/db/index.ts

@@ -186,6 +186,36 @@ export class DatabaseConnection {
     this.db.exec('ANALYZE');
   }
 
+  /**
+   * Lightweight, non-blocking maintenance to run after bulk writes
+   * (indexAll, sync). Two operations:
+   *
+   *   - `PRAGMA optimize` — incremental ANALYZE; SQLite only re-analyzes
+   *     tables whose row counts changed materially since the last
+   *     ANALYZE. Without it, the query planner has no statistics on the
+   *     freshly-bulk-loaded tables and can pick suboptimal indexes.
+   *
+   *   - `PRAGMA wal_checkpoint(PASSIVE)` — fold pending WAL pages back
+   *     into the main database file so the WAL file doesn't grow
+   *     unboundedly between automatic checkpoints (auto-fires at 1000
+   *     pages by default; large indexAll runs blow past that).
+   *
+   * Both operations are silently swallowed on failure — they're a
+   * best-effort optimization, never load-bearing for correctness.
+   */
+  runMaintenance(): void {
+    try {
+      this.db.exec('PRAGMA optimize');
+    } catch {
+      // ignore
+    }
+    try {
+      this.db.exec('PRAGMA wal_checkpoint(PASSIVE)');
+    } catch {
+      // ignore (e.g., not in WAL mode)
+    }
+  }
+
   /**
    * Close the database connection
    */

+ 59 - 0
src/db/queries.ts

@@ -224,6 +224,12 @@ export class QueryBuilder {
       return;
     }
 
+    // INSERT OR REPLACE may overwrite a node we have cached. Drop the
+    // stale entry so the next getNodeById sees the new row, not the old
+    // one (matches the cache-invalidation pattern used by updateNode and
+    // deleteNode below).
+    this.nodeCache.delete(node.id);
+
     try {
       this.stmts.insertNode.run({
         id: node.id,
@@ -380,6 +386,59 @@ export class QueryBuilder {
     return node;
   }
 
+  /**
+   * Batch lookup: fetch many nodes by ID in a single SQL round-trip.
+   *
+   * Replaces the N+1 pattern in graph traversal where every edge would
+   * trigger its own `getNodeById` call. For a function with 50 callers
+   * this collapses 50 point reads into one IN-list query (~10-50x
+   * faster end-to-end).
+   *
+   * Returns a Map keyed by id so callers can preserve their own ordering
+   * (typically the order edges were returned from the graph). Missing IDs
+   * are simply absent from the map.
+   *
+   * Cache-aware: ids already in the LRU cache are served from memory and
+   * the SQL query only touches the misses.
+   */
+  getNodesByIds(ids: readonly string[]): Map<string, Node> {
+    const out = new Map<string, Node>();
+    if (ids.length === 0) return out;
+
+    // Serve cache hits first; build the miss list for SQL.
+    const misses: string[] = [];
+    for (const id of ids) {
+      const cached = this.nodeCache.get(id);
+      if (cached !== undefined) {
+        // LRU touch
+        this.nodeCache.delete(id);
+        this.nodeCache.set(id, cached);
+        out.set(id, cached);
+      } else {
+        misses.push(id);
+      }
+    }
+    if (misses.length === 0) return out;
+
+    // Chunk under SQLite's parameter limit (default 999, raised to 32766
+    // in better-sqlite3 builds — chunk at 500 for safety across both
+    // backends and to keep the query plan simple).
+    const CHUNK = 500;
+    for (let i = 0; i < misses.length; i += CHUNK) {
+      const chunk = misses.slice(i, i + CHUNK);
+      const placeholders = chunk.map(() => '?').join(',');
+      const rows = this.db
+        .prepare(`SELECT * FROM nodes WHERE id IN (${placeholders})`)
+        .all(...chunk) as NodeRow[];
+      for (const row of rows) {
+        const node = rowToNode(row);
+        out.set(node.id, node);
+        this.cacheNode(node);
+      }
+    }
+    return out;
+  }
+
   /**
    * Add a node to the cache, evicting oldest if needed
    */

+ 69 - 47
src/graph/traversal.ts

@@ -90,29 +90,24 @@ export class GraphTraverser {
         return priority(a) - priority(b);
       });
 
+      // Batch-fetch the unvisited neighbors in one query (was N+1 per BFS step).
+      const wantIds = adjacentEdges
+        .map((e) => (e.source === node.id ? e.target : e.source))
+        .filter((id) => !visited.has(id));
+      const neighborNodes = wantIds.length > 0 ? this.queries.getNodesByIds(wantIds) : new Map();
+
       for (const adjEdge of adjacentEdges) {
-        // Determine next node: for 'both' direction, edges can be either
-        // incoming or outgoing, so pick whichever end is not the current node
         const nextNodeId = adjEdge.source === node.id ? adjEdge.target : adjEdge.source;
+        if (visited.has(nextNodeId)) continue;
 
-        if (visited.has(nextNodeId)) {
-          continue;
-        }
-
-        const nextNode = this.queries.getNodeById(nextNodeId);
-        if (!nextNode) {
-          continue;
-        }
+        const nextNode = neighborNodes.get(nextNodeId);
+        if (!nextNode) continue;
 
-        // Apply node kind filter
         if (opts.nodeKinds && opts.nodeKinds.length > 0 && !opts.nodeKinds.includes(nextNode.kind)) {
           continue;
         }
 
-        // Add node to result
         nodes.set(nextNode.id, nextNode);
-
-        // Queue for further traversal
         queue.push({ node: nextNode, edge: adjEdge, depth: depth + 1 });
       }
     }
@@ -176,19 +171,18 @@ export class GraphTraverser {
     // Get adjacent edges
     const adjacentEdges = this.getAdjacentEdges(node.id, opts.direction, opts.edgeKinds);
 
+    // Batch-fetch unvisited neighbors (was N+1 per DFS step).
+    const wantIds = adjacentEdges
+      .map((e) => (e.source === node.id ? e.target : e.source))
+      .filter((id) => !visited.has(id));
+    const neighborNodes = wantIds.length > 0 ? this.queries.getNodesByIds(wantIds) : new Map();
+
     for (const edge of adjacentEdges) {
-      // Determine next node: for 'both' direction, edges can be either
-      // incoming or outgoing, so pick whichever end is not the current node
       const nextNodeId = edge.source === node.id ? edge.target : edge.source;
+      if (visited.has(nextNodeId)) continue;
 
-      if (visited.has(nextNodeId)) {
-        continue;
-      }
-
-      const nextNode = this.queries.getNodeById(nextNodeId);
-      if (!nextNode) {
-        continue;
-      }
+      const nextNode = neighborNodes.get(nextNodeId);
+      if (!nextNode) continue;
 
       // Apply node kind filter
       if (opts.nodeKinds && opts.nodeKinds.length > 0 && !opts.nodeKinds.includes(nextNode.kind)) {
@@ -255,9 +249,15 @@ export class GraphTraverser {
     visited.add(nodeId);
 
     const incomingEdges = this.queries.getIncomingEdges(nodeId, ['calls', 'references', 'imports']);
+    if (incomingEdges.length === 0) return;
+
+    // Batch-fetch all caller nodes in one round-trip instead of one
+    // getNodeById per edge (was N+1 — meaningful on functions with many callers).
+    const sourceIds = incomingEdges.map((e) => e.source);
+    const callerNodes = this.queries.getNodesByIds(sourceIds);
 
     for (const edge of incomingEdges) {
-      const callerNode = this.queries.getNodeById(edge.source);
+      const callerNode = callerNodes.get(edge.source);
       if (callerNode && !visited.has(callerNode.id)) {
         result.push({ node: callerNode, edge });
         this.getCallersRecursive(callerNode.id, maxDepth, currentDepth + 1, result, visited);
@@ -294,9 +294,14 @@ export class GraphTraverser {
     visited.add(nodeId);
 
     const outgoingEdges = this.queries.getOutgoingEdges(nodeId, ['calls', 'references', 'imports']);
+    if (outgoingEdges.length === 0) return;
+
+    // Batch-fetch callee nodes (was N+1 — see getCallersRecursive note).
+    const targetIds = outgoingEdges.map((e) => e.target);
+    const calleeNodes = this.queries.getNodesByIds(targetIds);
 
     for (const edge of outgoingEdges) {
-      const calleeNode = this.queries.getNodeById(edge.target);
+      const calleeNode = calleeNodes.get(edge.target);
       if (calleeNode && !visited.has(calleeNode.id)) {
         result.push({ node: calleeNode, edge });
         this.getCalleesRecursive(calleeNode.id, maxDepth, currentDepth + 1, result, visited);
@@ -388,9 +393,11 @@ export class GraphTraverser {
     visited.add(nodeId);
 
     const outgoingEdges = this.queries.getOutgoingEdges(nodeId, ['extends', 'implements']);
+    if (outgoingEdges.length === 0) return;
+    const parents = this.queries.getNodesByIds(outgoingEdges.map((e) => e.target));
 
     for (const edge of outgoingEdges) {
-      const parentNode = this.queries.getNodeById(edge.target);
+      const parentNode = parents.get(edge.target);
       if (parentNode && !nodes.has(parentNode.id)) {
         nodes.set(parentNode.id, parentNode);
         edges.push(edge);
@@ -411,9 +418,11 @@ export class GraphTraverser {
     visited.add(nodeId);
 
     const incomingEdges = this.queries.getIncomingEdges(nodeId, ['extends', 'implements']);
+    if (incomingEdges.length === 0) return;
+    const children = this.queries.getNodesByIds(incomingEdges.map((e) => e.source));
 
     for (const edge of incomingEdges) {
-      const childNode = this.queries.getNodeById(edge.source);
+      const childNode = children.get(edge.source);
       if (childNode && !nodes.has(childNode.id)) {
         nodes.set(childNode.id, childNode);
         edges.push(edge);
@@ -433,12 +442,13 @@ export class GraphTraverser {
 
     // Get all incoming edges (references, calls, type_of, etc.)
     const incomingEdges = this.queries.getIncomingEdges(nodeId);
+    if (incomingEdges.length === 0) return result;
 
+    // Batch-fetch source nodes (was N+1).
+    const sources = this.queries.getNodesByIds(incomingEdges.map((e) => e.source));
     for (const edge of incomingEdges) {
-      const sourceNode = this.queries.getNodeById(edge.source);
-      if (sourceNode) {
-        result.push({ node: sourceNode, edge });
-      }
+      const sourceNode = sources.get(edge.source);
+      if (sourceNode) result.push({ node: sourceNode, edge });
     }
 
     return result;
@@ -496,13 +506,16 @@ export class GraphTraverser {
       const containerKinds = new Set(['class', 'interface', 'struct', 'trait', 'protocol', 'module', 'enum']);
       if (containerKinds.has(focalNode.kind)) {
         const containsEdges = this.queries.getOutgoingEdges(nodeId, ['contains']);
-        for (const edge of containsEdges) {
-          const childNode = this.queries.getNodeById(edge.target);
-          if (childNode && !visited.has(childNode.id)) {
-            nodes.set(childNode.id, childNode);
-            edges.push(edge);
-            // Recurse into children at the same depth (they're part of the same symbol)
-            this.getImpactRecursive(childNode.id, maxDepth, currentDepth, nodes, edges, visited);
+        if (containsEdges.length > 0) {
+          const children = this.queries.getNodesByIds(containsEdges.map((e) => e.target));
+          for (const edge of containsEdges) {
+            const childNode = children.get(edge.target);
+            if (childNode && !visited.has(childNode.id)) {
+              nodes.set(childNode.id, childNode);
+              edges.push(edge);
+              // Recurse into children at the same depth (they're part of the same symbol)
+              this.getImpactRecursive(childNode.id, maxDepth, currentDepth, nodes, edges, visited);
+            }
           }
         }
       }
@@ -510,9 +523,11 @@ export class GraphTraverser {
 
     // Get all incoming edges (things that depend on this node)
     const incomingEdges = this.queries.getIncomingEdges(nodeId);
+    if (incomingEdges.length === 0) return;
+    const sources = this.queries.getNodesByIds(incomingEdges.map((e) => e.source));
 
     for (const edge of incomingEdges) {
-      const sourceNode = this.queries.getNodeById(edge.source);
+      const sourceNode = sources.get(edge.source);
       if (sourceNode && !nodes.has(sourceNode.id)) {
         nodes.set(sourceNode.id, sourceNode);
         edges.push(edge);
@@ -564,10 +579,17 @@ export class GraphTraverser {
         nodeId,
         edgeKinds.length > 0 ? edgeKinds : undefined
       );
+      if (outgoingEdges.length === 0) continue;
+
+      // Batch-fetch only the unvisited targets (was N+1 per BFS frontier).
+      const wantIds = outgoingEdges
+        .map((e) => e.target)
+        .filter((id) => !visited.has(id));
+      const nextNodes = wantIds.length > 0 ? this.queries.getNodesByIds(wantIds) : new Map();
 
       for (const edge of outgoingEdges) {
         if (!visited.has(edge.target)) {
-          const nextNode = this.queries.getNodeById(edge.target);
+          const nextNode = nextNodes.get(edge.target);
           if (nextNode) {
             queue.push({
               nodeId: edge.target,
@@ -627,15 +649,15 @@ export class GraphTraverser {
    */
   getChildren(nodeId: string): Node[] {
     const containsEdges = this.queries.getOutgoingEdges(nodeId, ['contains']);
-    const children: Node[] = [];
+    if (containsEdges.length === 0) return [];
 
+    // Batch-fetch (was N+1).
+    const childNodes = this.queries.getNodesByIds(containsEdges.map((e) => e.target));
+    const children: Node[] = [];
     for (const edge of containsEdges) {
-      const childNode = this.queries.getNodeById(edge.target);
-      if (childNode) {
-        children.push(childNode);
-      }
+      const childNode = childNodes.get(edge.target);
+      if (childNode) children.push(childNode);
     }
-
     return children;
   }
 }

+ 11 - 0
src/index.ts

@@ -347,6 +347,12 @@ export class CodeGraph {
           });
         }
 
+        // Refresh planner stats + checkpoint the WAL after bulk writes.
+        // Cheap and non-blocking; never load-bearing for correctness.
+        if (result.success && result.filesIndexed > 0) {
+          this.db.runMaintenance();
+        }
+
         return result;
       } finally {
         this.fileLock.release();
@@ -428,6 +434,11 @@ export class CodeGraph {
           }
         }
 
+        // Refresh planner stats + checkpoint the WAL after bulk writes.
+        if (result.filesAdded > 0 || result.filesModified > 0 || result.filesRemoved > 0) {
+          this.db.runMaintenance();
+        }
+
         return result;
       } finally {
         this.fileLock.release();