branch.cjs 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. "use strict";
  2. Object.defineProperty(exports, "__esModule", { value: true });
  3. exports.RunnableBranch = void 0;
  4. const base_js_1 = require("./base.cjs");
  5. const config_js_1 = require("./config.cjs");
  6. const stream_js_1 = require("../utils/stream.cjs");
  7. /**
  8. * Class that represents a runnable branch. The RunnableBranch is
  9. * initialized with an array of branches and a default branch. When invoked,
  10. * it evaluates the condition of each branch in order and executes the
  11. * corresponding branch if the condition is true. If none of the conditions
  12. * are true, it executes the default branch.
  13. * @example
  14. * ```typescript
  15. * const branch = RunnableBranch.from([
  16. * [
  17. * (x: { topic: string; question: string }) =>
  18. * x.topic.toLowerCase().includes("anthropic"),
  19. * anthropicChain,
  20. * ],
  21. * [
  22. * (x: { topic: string; question: string }) =>
  23. * x.topic.toLowerCase().includes("langchain"),
  24. * langChainChain,
  25. * ],
  26. * generalChain,
  27. * ]);
  28. *
  29. * const fullChain = RunnableSequence.from([
  30. * {
  31. * topic: classificationChain,
  32. * question: (input: { question: string }) => input.question,
  33. * },
  34. * branch,
  35. * ]);
  36. *
  37. * const result = await fullChain.invoke({
  38. * question: "how do I use LangChain?",
  39. * });
  40. * ```
  41. */
  42. // eslint-disable-next-line @typescript-eslint/no-explicit-any
  43. class RunnableBranch extends base_js_1.Runnable {
  44. static lc_name() {
  45. return "RunnableBranch";
  46. }
  47. constructor(fields) {
  48. super(fields);
  49. Object.defineProperty(this, "lc_namespace", {
  50. enumerable: true,
  51. configurable: true,
  52. writable: true,
  53. value: ["langchain_core", "runnables"]
  54. });
  55. Object.defineProperty(this, "lc_serializable", {
  56. enumerable: true,
  57. configurable: true,
  58. writable: true,
  59. value: true
  60. });
  61. Object.defineProperty(this, "default", {
  62. enumerable: true,
  63. configurable: true,
  64. writable: true,
  65. value: void 0
  66. });
  67. Object.defineProperty(this, "branches", {
  68. enumerable: true,
  69. configurable: true,
  70. writable: true,
  71. value: void 0
  72. });
  73. this.branches = fields.branches;
  74. this.default = fields.default;
  75. }
  76. /**
  77. * Convenience method for instantiating a RunnableBranch from
  78. * RunnableLikes (objects, functions, or Runnables).
  79. *
  80. * Each item in the input except for the last one should be a
  81. * tuple with two items. The first is a "condition" RunnableLike that
  82. * returns "true" if the second RunnableLike in the tuple should run.
  83. *
  84. * The final item in the input should be a RunnableLike that acts as a
  85. * default branch if no other branches match.
  86. *
  87. * @example
  88. * ```ts
  89. * import { RunnableBranch } from "@langchain/core/runnables";
  90. *
  91. * const branch = RunnableBranch.from([
  92. * [(x: number) => x > 0, (x: number) => x + 1],
  93. * [(x: number) => x < 0, (x: number) => x - 1],
  94. * (x: number) => x
  95. * ]);
  96. * ```
  97. * @param branches An array where the every item except the last is a tuple of [condition, runnable]
  98. * pairs. The last item is a default runnable which is invoked if no other condition matches.
  99. * @returns A new RunnableBranch.
  100. */
  101. // eslint-disable-next-line @typescript-eslint/no-explicit-any
  102. static from(branches) {
  103. if (branches.length < 1) {
  104. throw new Error("RunnableBranch requires at least one branch");
  105. }
  106. const branchLikes = branches.slice(0, -1);
  107. const coercedBranches = branchLikes.map(([condition, runnable]) => [
  108. (0, base_js_1._coerceToRunnable)(condition),
  109. (0, base_js_1._coerceToRunnable)(runnable),
  110. ]);
  111. const defaultBranch = (0, base_js_1._coerceToRunnable)(branches[branches.length - 1]);
  112. return new this({
  113. branches: coercedBranches,
  114. default: defaultBranch,
  115. });
  116. }
  117. async _invoke(input, config, runManager) {
  118. let result;
  119. for (let i = 0; i < this.branches.length; i += 1) {
  120. const [condition, branchRunnable] = this.branches[i];
  121. const conditionValue = await condition.invoke(input, (0, config_js_1.patchConfig)(config, {
  122. callbacks: runManager?.getChild(`condition:${i + 1}`),
  123. }));
  124. if (conditionValue) {
  125. result = await branchRunnable.invoke(input, (0, config_js_1.patchConfig)(config, {
  126. callbacks: runManager?.getChild(`branch:${i + 1}`),
  127. }));
  128. break;
  129. }
  130. }
  131. if (!result) {
  132. result = await this.default.invoke(input, (0, config_js_1.patchConfig)(config, {
  133. callbacks: runManager?.getChild("branch:default"),
  134. }));
  135. }
  136. return result;
  137. }
  138. async invoke(input, config = {}) {
  139. return this._callWithConfig(this._invoke, input, config);
  140. }
  141. async *_streamIterator(input, config) {
  142. const callbackManager_ = await (0, config_js_1.getCallbackManagerForConfig)(config);
  143. const runManager = await callbackManager_?.handleChainStart(this.toJSON(), (0, base_js_1._coerceToDict)(input, "input"), config?.runId, undefined, undefined, undefined, config?.runName);
  144. let finalOutput;
  145. let finalOutputSupported = true;
  146. let stream;
  147. try {
  148. for (let i = 0; i < this.branches.length; i += 1) {
  149. const [condition, branchRunnable] = this.branches[i];
  150. const conditionValue = await condition.invoke(input, (0, config_js_1.patchConfig)(config, {
  151. callbacks: runManager?.getChild(`condition:${i + 1}`),
  152. }));
  153. if (conditionValue) {
  154. stream = await branchRunnable.stream(input, (0, config_js_1.patchConfig)(config, {
  155. callbacks: runManager?.getChild(`branch:${i + 1}`),
  156. }));
  157. for await (const chunk of stream) {
  158. yield chunk;
  159. if (finalOutputSupported) {
  160. if (finalOutput === undefined) {
  161. finalOutput = chunk;
  162. }
  163. else {
  164. try {
  165. finalOutput = (0, stream_js_1.concat)(finalOutput, chunk);
  166. }
  167. catch (e) {
  168. finalOutput = undefined;
  169. finalOutputSupported = false;
  170. }
  171. }
  172. }
  173. }
  174. break;
  175. }
  176. }
  177. if (stream === undefined) {
  178. stream = await this.default.stream(input, (0, config_js_1.patchConfig)(config, {
  179. callbacks: runManager?.getChild("branch:default"),
  180. }));
  181. for await (const chunk of stream) {
  182. yield chunk;
  183. if (finalOutputSupported) {
  184. if (finalOutput === undefined) {
  185. finalOutput = chunk;
  186. }
  187. else {
  188. try {
  189. finalOutput = (0, stream_js_1.concat)(finalOutput, chunk);
  190. }
  191. catch (e) {
  192. finalOutput = undefined;
  193. finalOutputSupported = false;
  194. }
  195. }
  196. }
  197. }
  198. }
  199. }
  200. catch (e) {
  201. await runManager?.handleChainError(e);
  202. throw e;
  203. }
  204. await runManager?.handleChainEnd(finalOutput ?? {});
  205. }
  206. }
  207. exports.RunnableBranch = RunnableBranch;