branch.js 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. import { Runnable, _coerceToDict, _coerceToRunnable, } from "./base.js";
  2. import { getCallbackManagerForConfig, patchConfig, } from "./config.js";
  3. import { concat } from "../utils/stream.js";
  4. /**
  5. * Class that represents a runnable branch. The RunnableBranch is
  6. * initialized with an array of branches and a default branch. When invoked,
  7. * it evaluates the condition of each branch in order and executes the
  8. * corresponding branch if the condition is true. If none of the conditions
  9. * are true, it executes the default branch.
  10. * @example
  11. * ```typescript
  12. * const branch = RunnableBranch.from([
  13. * [
  14. * (x: { topic: string; question: string }) =>
  15. * x.topic.toLowerCase().includes("anthropic"),
  16. * anthropicChain,
  17. * ],
  18. * [
  19. * (x: { topic: string; question: string }) =>
  20. * x.topic.toLowerCase().includes("langchain"),
  21. * langChainChain,
  22. * ],
  23. * generalChain,
  24. * ]);
  25. *
  26. * const fullChain = RunnableSequence.from([
  27. * {
  28. * topic: classificationChain,
  29. * question: (input: { question: string }) => input.question,
  30. * },
  31. * branch,
  32. * ]);
  33. *
  34. * const result = await fullChain.invoke({
  35. * question: "how do I use LangChain?",
  36. * });
  37. * ```
  38. */
  39. // eslint-disable-next-line @typescript-eslint/no-explicit-any
  40. export class RunnableBranch extends Runnable {
  41. static lc_name() {
  42. return "RunnableBranch";
  43. }
  44. constructor(fields) {
  45. super(fields);
  46. Object.defineProperty(this, "lc_namespace", {
  47. enumerable: true,
  48. configurable: true,
  49. writable: true,
  50. value: ["langchain_core", "runnables"]
  51. });
  52. Object.defineProperty(this, "lc_serializable", {
  53. enumerable: true,
  54. configurable: true,
  55. writable: true,
  56. value: true
  57. });
  58. Object.defineProperty(this, "default", {
  59. enumerable: true,
  60. configurable: true,
  61. writable: true,
  62. value: void 0
  63. });
  64. Object.defineProperty(this, "branches", {
  65. enumerable: true,
  66. configurable: true,
  67. writable: true,
  68. value: void 0
  69. });
  70. this.branches = fields.branches;
  71. this.default = fields.default;
  72. }
  73. /**
  74. * Convenience method for instantiating a RunnableBranch from
  75. * RunnableLikes (objects, functions, or Runnables).
  76. *
  77. * Each item in the input except for the last one should be a
  78. * tuple with two items. The first is a "condition" RunnableLike that
  79. * returns "true" if the second RunnableLike in the tuple should run.
  80. *
  81. * The final item in the input should be a RunnableLike that acts as a
  82. * default branch if no other branches match.
  83. *
  84. * @example
  85. * ```ts
  86. * import { RunnableBranch } from "@langchain/core/runnables";
  87. *
  88. * const branch = RunnableBranch.from([
  89. * [(x: number) => x > 0, (x: number) => x + 1],
  90. * [(x: number) => x < 0, (x: number) => x - 1],
  91. * (x: number) => x
  92. * ]);
  93. * ```
  94. * @param branches An array where the every item except the last is a tuple of [condition, runnable]
  95. * pairs. The last item is a default runnable which is invoked if no other condition matches.
  96. * @returns A new RunnableBranch.
  97. */
  98. // eslint-disable-next-line @typescript-eslint/no-explicit-any
  99. static from(branches) {
  100. if (branches.length < 1) {
  101. throw new Error("RunnableBranch requires at least one branch");
  102. }
  103. const branchLikes = branches.slice(0, -1);
  104. const coercedBranches = branchLikes.map(([condition, runnable]) => [
  105. _coerceToRunnable(condition),
  106. _coerceToRunnable(runnable),
  107. ]);
  108. const defaultBranch = _coerceToRunnable(branches[branches.length - 1]);
  109. return new this({
  110. branches: coercedBranches,
  111. default: defaultBranch,
  112. });
  113. }
  114. async _invoke(input, config, runManager) {
  115. let result;
  116. for (let i = 0; i < this.branches.length; i += 1) {
  117. const [condition, branchRunnable] = this.branches[i];
  118. const conditionValue = await condition.invoke(input, patchConfig(config, {
  119. callbacks: runManager?.getChild(`condition:${i + 1}`),
  120. }));
  121. if (conditionValue) {
  122. result = await branchRunnable.invoke(input, patchConfig(config, {
  123. callbacks: runManager?.getChild(`branch:${i + 1}`),
  124. }));
  125. break;
  126. }
  127. }
  128. if (!result) {
  129. result = await this.default.invoke(input, patchConfig(config, {
  130. callbacks: runManager?.getChild("branch:default"),
  131. }));
  132. }
  133. return result;
  134. }
  135. async invoke(input, config = {}) {
  136. return this._callWithConfig(this._invoke, input, config);
  137. }
  138. async *_streamIterator(input, config) {
  139. const callbackManager_ = await getCallbackManagerForConfig(config);
  140. const runManager = await callbackManager_?.handleChainStart(this.toJSON(), _coerceToDict(input, "input"), config?.runId, undefined, undefined, undefined, config?.runName);
  141. let finalOutput;
  142. let finalOutputSupported = true;
  143. let stream;
  144. try {
  145. for (let i = 0; i < this.branches.length; i += 1) {
  146. const [condition, branchRunnable] = this.branches[i];
  147. const conditionValue = await condition.invoke(input, patchConfig(config, {
  148. callbacks: runManager?.getChild(`condition:${i + 1}`),
  149. }));
  150. if (conditionValue) {
  151. stream = await branchRunnable.stream(input, patchConfig(config, {
  152. callbacks: runManager?.getChild(`branch:${i + 1}`),
  153. }));
  154. for await (const chunk of stream) {
  155. yield chunk;
  156. if (finalOutputSupported) {
  157. if (finalOutput === undefined) {
  158. finalOutput = chunk;
  159. }
  160. else {
  161. try {
  162. finalOutput = concat(finalOutput, chunk);
  163. }
  164. catch (e) {
  165. finalOutput = undefined;
  166. finalOutputSupported = false;
  167. }
  168. }
  169. }
  170. }
  171. break;
  172. }
  173. }
  174. if (stream === undefined) {
  175. stream = await this.default.stream(input, patchConfig(config, {
  176. callbacks: runManager?.getChild("branch:default"),
  177. }));
  178. for await (const chunk of stream) {
  179. yield chunk;
  180. if (finalOutputSupported) {
  181. if (finalOutput === undefined) {
  182. finalOutput = chunk;
  183. }
  184. else {
  185. try {
  186. finalOutput = concat(finalOutput, chunk);
  187. }
  188. catch (e) {
  189. finalOutput = undefined;
  190. finalOutputSupported = false;
  191. }
  192. }
  193. }
  194. }
  195. }
  196. }
  197. catch (e) {
  198. await runManager?.handleChainError(e);
  199. throw e;
  200. }
  201. await runManager?.handleChainEnd(finalOutput ?? {});
  202. }
  203. }