ai.js 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. import { parsePartialJson } from "../utils/json.js";
  2. import { BaseMessage, BaseMessageChunk, mergeContent, _mergeDicts, _mergeLists, } from "./base.js";
  3. import { defaultToolCallParser, } from "./tool.js";
  4. /**
  5. * Represents an AI message in a conversation.
  6. */
  7. export class AIMessage extends BaseMessage {
  8. get lc_aliases() {
  9. // exclude snake case conversion to pascal case
  10. return {
  11. ...super.lc_aliases,
  12. tool_calls: "tool_calls",
  13. invalid_tool_calls: "invalid_tool_calls",
  14. };
  15. }
  16. constructor(fields,
  17. /** @deprecated */
  18. kwargs) {
  19. let initParams;
  20. if (typeof fields === "string") {
  21. initParams = {
  22. content: fields,
  23. tool_calls: [],
  24. invalid_tool_calls: [],
  25. additional_kwargs: kwargs ?? {},
  26. };
  27. }
  28. else {
  29. initParams = fields;
  30. const rawToolCalls = initParams.additional_kwargs?.tool_calls;
  31. const toolCalls = initParams.tool_calls;
  32. if (!(rawToolCalls == null) &&
  33. rawToolCalls.length > 0 &&
  34. (toolCalls === undefined || toolCalls.length === 0)) {
  35. console.warn([
  36. "New LangChain packages are available that more efficiently handle",
  37. "tool calling.\n\nPlease upgrade your packages to versions that set",
  38. "message tool calls. e.g., `yarn add @langchain/anthropic`,",
  39. "yarn add @langchain/openai`, etc.",
  40. ].join(" "));
  41. }
  42. try {
  43. if (!(rawToolCalls == null) && toolCalls === undefined) {
  44. const [toolCalls, invalidToolCalls] = defaultToolCallParser(rawToolCalls);
  45. initParams.tool_calls = toolCalls ?? [];
  46. initParams.invalid_tool_calls = invalidToolCalls ?? [];
  47. }
  48. else {
  49. initParams.tool_calls = initParams.tool_calls ?? [];
  50. initParams.invalid_tool_calls = initParams.invalid_tool_calls ?? [];
  51. }
  52. }
  53. catch (e) {
  54. // Do nothing if parsing fails
  55. initParams.tool_calls = [];
  56. initParams.invalid_tool_calls = [];
  57. }
  58. }
  59. // Sadly, TypeScript only allows super() calls at root if the class has
  60. // properties with initializers, so we have to check types twice.
  61. super(initParams);
  62. // These are typed as optional to avoid breaking changes and allow for casting
  63. // from BaseMessage.
  64. Object.defineProperty(this, "tool_calls", {
  65. enumerable: true,
  66. configurable: true,
  67. writable: true,
  68. value: []
  69. });
  70. Object.defineProperty(this, "invalid_tool_calls", {
  71. enumerable: true,
  72. configurable: true,
  73. writable: true,
  74. value: []
  75. });
  76. /**
  77. * If provided, token usage information associated with the message.
  78. */
  79. Object.defineProperty(this, "usage_metadata", {
  80. enumerable: true,
  81. configurable: true,
  82. writable: true,
  83. value: void 0
  84. });
  85. if (typeof initParams !== "string") {
  86. this.tool_calls = initParams.tool_calls ?? this.tool_calls;
  87. this.invalid_tool_calls =
  88. initParams.invalid_tool_calls ?? this.invalid_tool_calls;
  89. }
  90. this.usage_metadata = initParams.usage_metadata;
  91. }
  92. static lc_name() {
  93. return "AIMessage";
  94. }
  95. _getType() {
  96. return "ai";
  97. }
  98. get _printableFields() {
  99. return {
  100. ...super._printableFields,
  101. tool_calls: this.tool_calls,
  102. invalid_tool_calls: this.invalid_tool_calls,
  103. usage_metadata: this.usage_metadata,
  104. };
  105. }
  106. }
  107. export function isAIMessage(x) {
  108. return x._getType() === "ai";
  109. }
  110. export function isAIMessageChunk(x) {
  111. return x._getType() === "ai";
  112. }
  113. /**
  114. * Represents a chunk of an AI message, which can be concatenated with
  115. * other AI message chunks.
  116. */
  117. export class AIMessageChunk extends BaseMessageChunk {
  118. constructor(fields) {
  119. let initParams;
  120. if (typeof fields === "string") {
  121. initParams = {
  122. content: fields,
  123. tool_calls: [],
  124. invalid_tool_calls: [],
  125. tool_call_chunks: [],
  126. };
  127. }
  128. else if (fields.tool_call_chunks === undefined) {
  129. initParams = {
  130. ...fields,
  131. tool_calls: fields.tool_calls ?? [],
  132. invalid_tool_calls: [],
  133. tool_call_chunks: [],
  134. usage_metadata: fields.usage_metadata !== undefined
  135. ? fields.usage_metadata
  136. : undefined,
  137. };
  138. }
  139. else {
  140. const toolCalls = [];
  141. const invalidToolCalls = [];
  142. for (const toolCallChunk of fields.tool_call_chunks) {
  143. let parsedArgs = {};
  144. try {
  145. parsedArgs = parsePartialJson(toolCallChunk.args || "{}");
  146. if (parsedArgs === null ||
  147. typeof parsedArgs !== "object" ||
  148. Array.isArray(parsedArgs)) {
  149. throw new Error("Malformed tool call chunk args.");
  150. }
  151. toolCalls.push({
  152. name: toolCallChunk.name ?? "",
  153. args: parsedArgs,
  154. id: toolCallChunk.id,
  155. type: "tool_call",
  156. });
  157. }
  158. catch (e) {
  159. invalidToolCalls.push({
  160. name: toolCallChunk.name,
  161. args: toolCallChunk.args,
  162. id: toolCallChunk.id,
  163. error: "Malformed args.",
  164. type: "invalid_tool_call",
  165. });
  166. }
  167. }
  168. initParams = {
  169. ...fields,
  170. tool_calls: toolCalls,
  171. invalid_tool_calls: invalidToolCalls,
  172. usage_metadata: fields.usage_metadata !== undefined
  173. ? fields.usage_metadata
  174. : undefined,
  175. };
  176. }
  177. // Sadly, TypeScript only allows super() calls at root if the class has
  178. // properties with initializers, so we have to check types twice.
  179. super(initParams);
  180. // Must redeclare tool call fields since there is no multiple inheritance in JS.
  181. // These are typed as optional to avoid breaking changes and allow for casting
  182. // from BaseMessage.
  183. Object.defineProperty(this, "tool_calls", {
  184. enumerable: true,
  185. configurable: true,
  186. writable: true,
  187. value: []
  188. });
  189. Object.defineProperty(this, "invalid_tool_calls", {
  190. enumerable: true,
  191. configurable: true,
  192. writable: true,
  193. value: []
  194. });
  195. Object.defineProperty(this, "tool_call_chunks", {
  196. enumerable: true,
  197. configurable: true,
  198. writable: true,
  199. value: []
  200. });
  201. /**
  202. * If provided, token usage information associated with the message.
  203. */
  204. Object.defineProperty(this, "usage_metadata", {
  205. enumerable: true,
  206. configurable: true,
  207. writable: true,
  208. value: void 0
  209. });
  210. this.tool_call_chunks =
  211. initParams.tool_call_chunks ?? this.tool_call_chunks;
  212. this.tool_calls = initParams.tool_calls ?? this.tool_calls;
  213. this.invalid_tool_calls =
  214. initParams.invalid_tool_calls ?? this.invalid_tool_calls;
  215. this.usage_metadata = initParams.usage_metadata;
  216. }
  217. get lc_aliases() {
  218. // exclude snake case conversion to pascal case
  219. return {
  220. ...super.lc_aliases,
  221. tool_calls: "tool_calls",
  222. invalid_tool_calls: "invalid_tool_calls",
  223. tool_call_chunks: "tool_call_chunks",
  224. };
  225. }
  226. static lc_name() {
  227. return "AIMessageChunk";
  228. }
  229. _getType() {
  230. return "ai";
  231. }
  232. get _printableFields() {
  233. return {
  234. ...super._printableFields,
  235. tool_calls: this.tool_calls,
  236. tool_call_chunks: this.tool_call_chunks,
  237. invalid_tool_calls: this.invalid_tool_calls,
  238. usage_metadata: this.usage_metadata,
  239. };
  240. }
  241. concat(chunk) {
  242. const combinedFields = {
  243. content: mergeContent(this.content, chunk.content),
  244. additional_kwargs: _mergeDicts(this.additional_kwargs, chunk.additional_kwargs),
  245. response_metadata: _mergeDicts(this.response_metadata, chunk.response_metadata),
  246. tool_call_chunks: [],
  247. id: this.id ?? chunk.id,
  248. };
  249. if (this.tool_call_chunks !== undefined ||
  250. chunk.tool_call_chunks !== undefined) {
  251. const rawToolCalls = _mergeLists(this.tool_call_chunks, chunk.tool_call_chunks);
  252. if (rawToolCalls !== undefined && rawToolCalls.length > 0) {
  253. combinedFields.tool_call_chunks = rawToolCalls;
  254. }
  255. }
  256. if (this.usage_metadata !== undefined ||
  257. chunk.usage_metadata !== undefined) {
  258. const inputTokenDetails = {
  259. ...((this.usage_metadata?.input_token_details?.audio !== undefined ||
  260. chunk.usage_metadata?.input_token_details?.audio !== undefined) && {
  261. audio: (this.usage_metadata?.input_token_details?.audio ?? 0) +
  262. (chunk.usage_metadata?.input_token_details?.audio ?? 0),
  263. }),
  264. ...((this.usage_metadata?.input_token_details?.cache_read !==
  265. undefined ||
  266. chunk.usage_metadata?.input_token_details?.cache_read !==
  267. undefined) && {
  268. cache_read: (this.usage_metadata?.input_token_details?.cache_read ?? 0) +
  269. (chunk.usage_metadata?.input_token_details?.cache_read ?? 0),
  270. }),
  271. ...((this.usage_metadata?.input_token_details?.cache_creation !==
  272. undefined ||
  273. chunk.usage_metadata?.input_token_details?.cache_creation !==
  274. undefined) && {
  275. cache_creation: (this.usage_metadata?.input_token_details?.cache_creation ?? 0) +
  276. (chunk.usage_metadata?.input_token_details?.cache_creation ?? 0),
  277. }),
  278. };
  279. const outputTokenDetails = {
  280. ...((this.usage_metadata?.output_token_details?.audio !== undefined ||
  281. chunk.usage_metadata?.output_token_details?.audio !== undefined) && {
  282. audio: (this.usage_metadata?.output_token_details?.audio ?? 0) +
  283. (chunk.usage_metadata?.output_token_details?.audio ?? 0),
  284. }),
  285. ...((this.usage_metadata?.output_token_details?.reasoning !==
  286. undefined ||
  287. chunk.usage_metadata?.output_token_details?.reasoning !==
  288. undefined) && {
  289. reasoning: (this.usage_metadata?.output_token_details?.reasoning ?? 0) +
  290. (chunk.usage_metadata?.output_token_details?.reasoning ?? 0),
  291. }),
  292. };
  293. const left = this.usage_metadata ?? {
  294. input_tokens: 0,
  295. output_tokens: 0,
  296. total_tokens: 0,
  297. };
  298. const right = chunk.usage_metadata ?? {
  299. input_tokens: 0,
  300. output_tokens: 0,
  301. total_tokens: 0,
  302. };
  303. const usage_metadata = {
  304. input_tokens: left.input_tokens + right.input_tokens,
  305. output_tokens: left.output_tokens + right.output_tokens,
  306. total_tokens: left.total_tokens + right.total_tokens,
  307. // Do not include `input_token_details` / `output_token_details` keys in combined fields
  308. // unless their values are defined.
  309. ...(Object.keys(inputTokenDetails).length > 0 && {
  310. input_token_details: inputTokenDetails,
  311. }),
  312. ...(Object.keys(outputTokenDetails).length > 0 && {
  313. output_token_details: outputTokenDetails,
  314. }),
  315. };
  316. combinedFields.usage_metadata = usage_metadata;
  317. }
  318. return new AIMessageChunk(combinedFields);
  319. }
  320. }