Skip to content

Commit 146b004

Browse files
author
harry_squater
committed
add full input chains validation
1 parent c6825ba commit 146b004

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

langchain/src/chains/simple_sequential_chain.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ export class SimpleSequentialChain
6464
return [this.inputKey];
6565
}
6666

67+
get outputKeys(): string[] {
68+
return [this.outputKey];
69+
}
70+
6771
constructor(fields: SimpleSequentialChainInput) {
6872
super(fields.memory, fields.verbose, fields.callbackManager);
6973
this.chains = fields.chains;
@@ -80,6 +84,13 @@ export class SimpleSequentialChain
8084
} for ${chain._chainType()}.`
8185
);
8286
}
87+
if (chain.outputKeys.length !== 1) {
88+
throw new Error(
89+
`Chains used in SimpleSequentialChain should all have one output, got ${
90+
chain.outputKeys.length
91+
} for ${chain._chainType()}.`
92+
);
93+
}
8394
}
8495
}
8596

langchain/src/chains/tests/simple_sequential_chain.test.ts

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@ import { LLMResult } from "../../schema/index.js";
44
import { LLMChain } from "../llm_chain.js";
55
import { PromptTemplate } from "../../prompts/index.js";
66
import { SimpleSequentialChain } from "../simple_sequential_chain.js";
7+
import { AnalyzeDocumentChain } from "../analyze_documents_chain.js";
8+
import { ConversationalRetrievalQAChain } from "../conversational_retrieval_chain.js";
9+
import { VectorStoreRetriever } from "../../vectorstores/base.js";
10+
import { FakeEmbeddings } from "../../embeddings/fake.js";
11+
import { MemoryVectorStore } from "../../vectorstores/memory.js";
712

813
class FakeLLM1 extends BaseLLM {
914
nrMapCalls = 0;
@@ -64,3 +69,50 @@ test("Test SimpleSequentialChain", async () => {
6469
const response = await combinedChain.run("initial question");
6570
expect(response).toEqual("final answer");
6671
});
72+
73+
test("Test SimpleSequentialChain input chains' single input validation", async () => {
74+
const model1 = new FakeLLM1({});
75+
const model2 = new FakeLLM2({});
76+
const template = "Some arbitrary template with fake {input1} and {input2}.";
77+
const prompt = new PromptTemplate({
78+
template,
79+
inputVariables: ["input1", "input2"],
80+
});
81+
const chain1 = new LLMChain({ llm: model1, prompt });
82+
const chain2 = new LLMChain({ llm: model2, prompt });
83+
expect(() => {
84+
/* eslint-disable no-new */
85+
new SimpleSequentialChain({ chains: [chain1, chain2] });
86+
}).toThrowErrorMatchingInlineSnapshot(
87+
`"Chains used in SimpleSequentialChain should all have one input, got 2 for llm_chain."`
88+
);
89+
});
90+
91+
test("Test SimpleSequentialChain input chains' single ouput validation", async () => {
92+
const model1 = new FakeLLM1({});
93+
const fakeEmbeddings = new FakeEmbeddings();
94+
const anyStore = new MemoryVectorStore(fakeEmbeddings);
95+
const retriever = new VectorStoreRetriever({
96+
vectorStore: anyStore,
97+
});
98+
const template = "Some arbitrary template with fake {input}.";
99+
const prompt = new PromptTemplate({ template, inputVariables: ["input"] });
100+
const chain1 = new LLMChain({ llm: model1, prompt });
101+
const chain2 = new ConversationalRetrievalQAChain({
102+
retriever,
103+
combineDocumentsChain: chain1,
104+
questionGeneratorChain: chain1,
105+
returnSourceDocuments: true,
106+
});
107+
// Chain below is is not meant to work in a real-life scenario.
108+
// It's only combined this way to get one input/multiple outputs chain.
109+
const multipleOutputChain = new AnalyzeDocumentChain({
110+
combineDocumentsChain: chain2,
111+
});
112+
expect(() => {
113+
/* eslint-disable no-new */
114+
new SimpleSequentialChain({ chains: [chain1, multipleOutputChain] });
115+
}).toThrowErrorMatchingInlineSnapshot(
116+
`"Chains used in SimpleSequentialChain should all have one output, got 2 for analyze_document_chain."`
117+
);
118+
});

0 commit comments

Comments
 (0)