Przeglądaj źródła

feat: Add Rust trait inheritance and impl block extraction with method receiver type support

Addresses Rust's impl block syntax where trait implementations (`impl Trait for Type`) and trait supertraits (`trait Sub: Super`) create inheritance relationships. Adds getReceiverType to extract method receiver types from impl blocks, enabling proper method-to-struct relationships and qualified name resolution. Verified against Deno codebase and moved from "Needs Verification" to completed language support.
Colby McHenry 2 miesięcy temu
rodzic
commit
2d14503258

+ 72 - 0
__tests__/extraction.test.ts

@@ -650,6 +650,78 @@ pub trait Repository {
     expect(traitNode).toBeDefined();
     expect(traitNode?.name).toBe('Repository');
   });
+
+  it('should extract impl Trait for Type as implements edges', () => {
+    const code = `
+pub struct MyCache {}
+
+pub trait Cache {
+    fn get(&self, key: &str) -> Option<String>;
+}
+
+impl Cache for MyCache {
+    fn get(&self, key: &str) -> Option<String> {
+        None
+    }
+}
+`;
+    const result = extractFromSource('cache.rs', code);
+
+    // `impl Cache for MyCache` should yield an unresolved 'implements' reference naming the trait
+    const implRef = result.unresolvedReferences.find(
+      (r) => r.referenceKind === 'implements' && r.referenceName === 'Cache'
+    );
+    expect(implRef).toBeDefined();
+
+    // The reference must originate from the MyCache struct node (matched by node id)
+    const myCacheNode = result.nodes.find((n) => n.name === 'MyCache' && n.kind === 'struct');
+    expect(myCacheNode).toBeDefined();
+    expect(implRef?.fromNodeId).toBe(myCacheNode?.id);
+  });
+
+  it('should extract trait supertraits as extends references', () => {
+    const code = `
+pub trait Display {}
+
+pub trait Error: Display {
+    fn description(&self) -> &str;
+}
+`;
+    const result = extractFromSource('error.rs', code);
+    // The supertrait bound in `trait Error: Display` should surface as an 'extends' reference
+    const extendsRef = result.unresolvedReferences.find(
+      (r) => r.referenceKind === 'extends' && r.referenceName === 'Display'
+    );
+    expect(extendsRef).toBeDefined();
+    // The reference must originate from the Error trait node (matched by node id)
+    const errorTrait = result.nodes.find((n) => n.name === 'Error' && n.kind === 'trait');
+    expect(errorTrait).toBeDefined();
+    expect(extendsRef?.fromNodeId).toBe(errorTrait?.id);
+  });
+
+  it('should not create implements edges for plain impl blocks', () => {
+    const code = `
+pub struct Counter {
+    count: u32,
+}
+
+impl Counter {
+    pub fn new() -> Counter {
+        Counter { count: 0 }
+    }
+    pub fn increment(&mut self) {
+        self.count += 1;
+    }
+}
+`;
+    const result = extractFromSource('counter.rs', code);
+
+    // `impl Counter { ... }` names no trait, so it must not produce any 'implements' references
+    const implRefs = result.unresolvedReferences.filter(
+      (r) => r.referenceKind === 'implements'
+    );
+    expect(implRefs).toHaveLength(0);
+  });
 });
 
 describe('Java Extraction', () => {

+ 1 - 1
docs/SEARCH_QUALITY_LOOP.md

@@ -523,12 +523,12 @@ if (receiverType) {
 - [x] **Swift** — NOT needed. Tree-sitter nests methods inside class/extension bodies
 - [x] **Java** — NOT needed. Methods nested in class body. Verified against Guava
 - [x] **Python** — NOT needed. Methods nested in class body. Verified against Flask
+- [x] **Rust** — `getReceiverType` walks up to parent `impl_item` to extract type name. Also adds `contains` edges from struct to impl methods. Verified against Deno
 
 ### Needs Verification
 
 Check these — may need `getReceiverType` if methods are top-level in the AST:
 
-- [ ] Rust — methods in `impl Type { }` blocks
 - [ ] C++ — out-of-class method definitions `Type::method()`
 - [ ] Kotlin — extension functions `fun Type.method()`
 

+ 37 - 0
src/extraction/languages/rust.ts

@@ -45,6 +45,43 @@ export const rustExtractor: LanguageExtractor = {
     }
     return 'private'; // Rust defaults to private
   },
+  /**
+   * Name of the type whose `impl` block encloses `node`, or `undefined` when
+   * `node` is not inside an `impl_item` or the implementing type has no plain
+   * name (references, tuples, scoped paths, ...).
+   *
+   * Covers `impl Type`, `impl<T> Type<T>`, `impl Trait for Type` and
+   * `impl Trait<A> for Type<B>`. Generic arguments are dropped, and the trait
+   * name is never returned in place of the implementing type.
+   *
+   * @param node - A node that may be nested in an impl block (e.g. function_item).
+   * @param source - Full source text that the node's indices refer to.
+   * @returns The bare implementing type name, or `undefined`.
+   */
+  getReceiverType: (node, source) => {
+    // Walk up the tree-sitter AST to the nearest enclosing impl_item, however
+    // deeply `node` is nested inside it
+    let parent = node.parent;
+    while (parent && parent.type !== 'impl_item') {
+      parent = parent.parent;
+    }
+    if (!parent) return undefined;
+
+    // tree-sitter-rust tags the implementing type with the `type` field and the
+    // trait, if any, with `trait`. Selecting by field instead of by position
+    // avoids returning the trait name for `impl Trait for Type<T>`, where the
+    // trait is the only direct type_identifier child of the impl_item.
+    let typeNode = parent.childForFieldName('type');
+    if (typeNode?.type === 'generic_type') {
+      // `Type<T>`: the base type sits in the generic_type's own `type` field
+      typeNode = typeNode.childForFieldName('type');
+    }
+
+    // Anything other than a plain identifier has no single owning type name
+    if (!typeNode || typeNode.type !== 'type_identifier') return undefined;
+    return source.substring(typeNode.startIndex, typeNode.endIndex);
+  },
+
   extractImport: (node, source) => {
     const importText = source.substring(node.startIndex, node.endIndex).trim();
 

+ 133 - 6
src/extraction/tree-sitter.ts

@@ -303,6 +303,10 @@ export class TreeSitterExtractor {
     else if (this.extractor.callTypes.includes(nodeType)) {
       this.extractCall(node);
     }
+    // Rust: `impl Trait for Type { ... }` — creates implements edge from Type to Trait
+    else if (nodeType === 'impl_item') {
+      this.extractRustImplItem(node);
+    }
 
     // Visit children (unless the extract method already visited them)
     if (!skipChildren) {
@@ -406,6 +410,13 @@ export class TreeSitterExtractor {
   private extractFunction(node: SyntaxNode): void {
     if (!this.extractor) return;
 
+    // If the language provides getReceiverType and this function has a receiver
+    // (e.g., Rust function_item inside an impl block), extract as method instead
+    if (this.extractor.getReceiverType?.(node, this.source)) {
+      this.extractMethod(node);
+      return;
+    }
+
     let name = extractName(node, this.source, this.extractor);
     // For arrow functions and function expressions assigned to variables,
     // resolve the name from the parent variable_declarator.
@@ -498,10 +509,15 @@ export class TreeSitterExtractor {
   private extractMethod(node: SyntaxNode): void {
     if (!this.extractor) return;
 
+    // For languages with receiver types (Go, Rust), include receiver in qualified name
+    // so FTS can match "scrapeLoop.run" → qualified_name "...::scrapeLoop::run"
+    const receiverType = this.extractor.getReceiverType?.(node, this.source);
+
     // For most languages, only extract as method if inside a class-like node
     // Languages with methodsAreTopLevel (e.g. Go) always treat them as methods
-    if (!this.isInsideClassLikeNode() && !this.extractor.methodsAreTopLevel) {
-      // Not inside a class-like node and not Go, treat as function
+    // Languages with getReceiverType (e.g. Rust) extract as method when receiver is found
+    if (!this.isInsideClassLikeNode() && !this.extractor.methodsAreTopLevel && !receiverType) {
+      // Not inside a class-like node and no receiver type, treat as function
       this.extractFunction(node);
       return;
     }
@@ -512,10 +528,6 @@ export class TreeSitterExtractor {
     const visibility = this.extractor.getVisibility?.(node);
     const isAsync = this.extractor.isAsync?.(node);
     const isStatic = this.extractor.isStatic?.(node);
-
-    // For languages with receiver types (Go), include receiver in qualified name
-    // so FTS can match "scrapeLoop.run" → qualified_name "...::scrapeLoop::run"
-    const receiverType = this.extractor.getReceiverType?.(node, this.source);
     const extraProps: Partial<Node> = {
       docstring,
       signature,
@@ -530,6 +542,24 @@ export class TreeSitterExtractor {
     const methodNode = this.createNode('method', name, node, extraProps);
     if (!methodNode) return;
 
+    // For methods with a receiver type but no class-like parent on the stack
+    // (e.g., Rust impl blocks), add a contains edge from the owning struct/trait
+    if (receiverType && !this.isInsideClassLikeNode()) {
+      const ownerNode = this.nodes.find(
+        (n) =>
+          n.name === receiverType &&
+          n.filePath === this.filePath &&
+          (n.kind === 'struct' || n.kind === 'class' || n.kind === 'enum' || n.kind === 'trait')
+      );
+      if (ownerNode) {
+        this.edges.push({
+          source: ownerNode.id,
+          target: methodNode.id,
+          kind: 'contains',
+        });
+      }
+    }
+
     // Extract type annotations (parameter types and return type)
     this.extractTypeAnnotations(node, methodNode.id);
 
@@ -1311,6 +1341,40 @@ export class TreeSitterExtractor {
         }
       }
 
+      // Rust trait supertraits: `trait SubTrait: SuperTrait + Display { ... }`
+      // trait_bounds contains type_identifier, generic_type, or higher_ranked_trait_bound children
+      if (child.type === 'trait_bounds') {
+        for (const bound of child.namedChildren) {
+          let typeName: string | undefined;
+          let posNode: SyntaxNode | undefined;
+
+          if (bound.type === 'type_identifier') {
+            typeName = getNodeText(bound, this.source);
+            posNode = bound;
+          } else if (bound.type === 'generic_type') {
+            // e.g. `Deserialize<'de>`
+            const inner = bound.namedChildren.find((c: SyntaxNode) => c.type === 'type_identifier');
+            if (inner) { typeName = getNodeText(inner, this.source); posNode = inner; }
+          } else if (bound.type === 'higher_ranked_trait_bound') {
+            // e.g. `for<'de> Deserialize<'de>`
+            const generic = bound.namedChildren.find((c: SyntaxNode) => c.type === 'generic_type');
+            const typeId = generic?.namedChildren.find((c: SyntaxNode) => c.type === 'type_identifier')
+              ?? bound.namedChildren.find((c: SyntaxNode) => c.type === 'type_identifier');
+            if (typeId) { typeName = getNodeText(typeId, this.source); posNode = typeId; }
+          }
+
+          if (typeName && posNode) {
+            this.unresolvedReferences.push({
+              fromNodeId: classId,
+              referenceName: typeName,
+              referenceKind: 'extends',
+              line: posNode.startPosition.row + 1,
+              column: posNode.startPosition.column,
+            });
+          }
+        }
+      }
+
       // Swift: inheritance_specifier > user_type > type_identifier
       // Used for class inheritance, protocol conformance, and protocol inheritance
       if (child.type === 'inheritance_specifier') {
@@ -1336,6 +1400,69 @@ export class TreeSitterExtractor {
     }
   }
 
+  /**
+   * Rust `impl Trait for Type` — records an unresolved `implements` reference
+   * from Type to Trait.
+   *
+   * Nothing is recorded for plain `impl Type { ... }` blocks (no trait), for
+   * negative impls (`impl !Trait for Type`), or when no struct/enum/class node
+   * named Type has been extracted so far.
+   *
+   * @param node - The tree-sitter `impl_item` node to inspect.
+   */
+  private extractRustImplItem(node: SyntaxNode): void {
+    // tree-sitter-rust tags the two operands with the `trait` and `type` fields;
+    // `trait` is absent for inherent impls, which create no inheritance edge.
+    const traitField = node.childForFieldName('trait');
+    const typeField = node.childForFieldName('type');
+    if (!traitField || !typeField) return;
+
+    // `impl !Send for Type` declares that Type does NOT implement the trait,
+    // so recording an implements reference for it would be wrong.
+    const isNegativeImpl = node.children.some(
+      (c: SyntaxNode) => c.type === '!' && !c.isNamed
+    );
+    if (isNegativeImpl) return;
+
+    // Drop generic arguments on both sides: `From<X>` -> `From` and
+    // `MyStruct<T>` -> `MyStruct`. Generic traits are then referenced by their
+    // bare name, the same way supertrait bounds are.
+    const withoutTypeArgs = (n: SyntaxNode): SyntaxNode =>
+      n.type === 'generic_type' ? (n.childForFieldName('type') ?? n) : n;
+    const traitNode = withoutTypeArgs(traitField);
+    const typeNode = withoutTypeArgs(typeField);
+
+    // Scoped paths such as std::fmt::Display are kept in full
+    const traitName = getNodeText(traitNode, this.source);
+    const typeName = getNodeText(typeNode, this.source);
+
+    // Only link when the implementing type is already among the extracted
+    // nodes; otherwise there is no node id to attach the reference to.
+    const typeNodeId = this.findNodeByName(typeName);
+    if (!typeNodeId) return;
+
+    // Position the reference at the trait name as written in the impl header
+    this.unresolvedReferences.push({
+      fromNodeId: typeNodeId,
+      referenceName: traitName,
+      referenceKind: 'implements',
+      line: traitNode.startPosition.row + 1,
+      column: traitNode.startPosition.column,
+    });
+  }
+
+  /**
+   * Look up an already-extracted struct, enum, or class node by name and return its id
+   * (undefined when no such node exists). Used for back-references like impl blocks.
+   */
+  private findNodeByName(name: string): string | undefined {
+    const ownerKinds = ['struct', 'enum', 'class'];
+    const match = this.nodes.find(
+      (candidate) => candidate.name === name && ownerKinds.includes(candidate.kind)
+    );
+    return match?.id;
+  }
+
   /**
    * Languages that support type annotations (TypeScript, etc.)
    */