evaluate_comparative.cjs 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. "use strict";
  2. var __importDefault = (this && this.__importDefault) || function (mod) {
  3. return (mod && mod.__esModule) ? mod : { "default": mod };
  4. };
  5. Object.defineProperty(exports, "__esModule", { value: true });
  6. exports.evaluateComparative = evaluateComparative;
  7. const uuid_1 = require("uuid");
  8. const index_js_1 = require("../index.cjs");
  9. const shuffle_js_1 = require("../utils/shuffle.cjs");
  10. const async_caller_js_1 = require("../utils/async_caller.cjs");
  11. const p_retry_1 = __importDefault(require("p-retry"));
  12. const traceable_js_1 = require("../traceable.cjs");
  13. function isExperimentResultsList(value) {
  14. return value.some((x) => typeof x !== "string");
  15. }
  16. async function loadExperiment(client, experiment) {
  17. const value = typeof experiment === "string" ? experiment : experiment.experimentName;
  18. return client.readProject((0, uuid_1.validate)(value) ? { projectId: value } : { projectName: value });
  19. }
  20. async function loadTraces(client, experiment, options) {
  21. const executionOrder = options.loadNested ? undefined : 1;
  22. const runs = await client.listRuns((0, uuid_1.validate)(experiment)
  23. ? { projectId: experiment, executionOrder }
  24. : { projectName: experiment, executionOrder });
  25. const treeMap = {};
  26. const runIdMap = {};
  27. const results = [];
  28. for await (const run of runs) {
  29. if (run.parent_run_id != null) {
  30. treeMap[run.parent_run_id] ??= [];
  31. treeMap[run.parent_run_id].push(run);
  32. }
  33. else {
  34. results.push(run);
  35. }
  36. runIdMap[run.id] = run;
  37. }
  38. for (const [parentRunId, childRuns] of Object.entries(treeMap)) {
  39. const parentRun = runIdMap[parentRunId];
  40. parentRun.child_runs = childRuns.sort((a, b) => {
  41. if (a.dotted_order == null || b.dotted_order == null)
  42. return 0;
  43. return a.dotted_order.localeCompare(b.dotted_order);
  44. });
  45. }
  46. return results;
  47. }
  48. async function evaluateComparative(experiments, options) {
  49. if (experiments.length < 2) {
  50. throw new Error("Comparative evaluation requires at least 2 experiments.");
  51. }
  52. if (!options.evaluators.length) {
  53. throw new Error("At least one evaluator is required for comparative evaluation.");
  54. }
  55. if (options.maxConcurrency && options.maxConcurrency < 0) {
  56. throw new Error("maxConcurrency must be a positive number.");
  57. }
  58. const client = options.client ?? new index_js_1.Client();
  59. const resolvedExperiments = await Promise.all(experiments);
  60. const projects = await (() => {
  61. if (!isExperimentResultsList(resolvedExperiments)) {
  62. return Promise.all(resolvedExperiments.map((experiment) => loadExperiment(client, experiment)));
  63. }
  64. // if we know the number of runs beforehand, check if the
  65. // number of runs in the project matches the expected number of runs
  66. return Promise.all(resolvedExperiments.map((experiment) => (0, p_retry_1.default)(async () => {
  67. const project = await loadExperiment(client, experiment);
  68. if (project.run_count !== experiment?.results.length) {
  69. throw new Error("Experiment is missing runs. Retrying.");
  70. }
  71. return project;
  72. }, { factor: 2, minTimeout: 1000, retries: 10 })));
  73. })();
  74. if (new Set(projects.map((p) => p.reference_dataset_id)).size > 1) {
  75. throw new Error("All experiments must have the same reference dataset.");
  76. }
  77. const referenceDatasetId = projects.at(0)?.reference_dataset_id;
  78. if (!referenceDatasetId) {
  79. throw new Error("Reference dataset is required for comparative evaluation.");
  80. }
  81. if (new Set(projects.map((p) => p.extra?.metadata?.dataset_version)).size > 1) {
  82. console.warn("Detected multiple dataset versions used by experiments, which may lead to inaccurate results.");
  83. }
  84. const datasetVersion = projects.at(0)?.extra?.metadata?.dataset_version;
  85. const id = (0, uuid_1.v4)();
  86. const experimentName = (() => {
  87. if (!options.experimentPrefix) {
  88. const names = projects
  89. .map((p) => p.name)
  90. .filter(Boolean)
  91. .join(" vs. ");
  92. return `${names}-${(0, uuid_1.v4)().slice(0, 4)}`;
  93. }
  94. return `${options.experimentPrefix}-${(0, uuid_1.v4)().slice(0, 4)}`;
  95. })();
  96. // TODO: add URL to the comparative experiment
  97. console.log(`Starting pairwise evaluation of: ${experimentName}`);
  98. const comparativeExperiment = await client.createComparativeExperiment({
  99. id,
  100. name: experimentName,
  101. experimentIds: projects.map((p) => p.id),
  102. description: options.description,
  103. metadata: options.metadata,
  104. referenceDatasetId: projects.at(0)?.reference_dataset_id,
  105. });
  106. const viewUrl = await (async () => {
  107. const projectId = projects.at(0)?.id ?? projects.at(1)?.id;
  108. const datasetId = comparativeExperiment?.reference_dataset_id;
  109. if (projectId && datasetId) {
  110. const hostUrl = (await client.getProjectUrl({ projectId }))
  111. .split("/projects/p/")
  112. .at(0);
  113. const result = new URL(`${hostUrl}/datasets/${datasetId}/compare`);
  114. result.searchParams.set("selectedSessions", projects.map((p) => p.id).join(","));
  115. result.searchParams.set("comparativeExperiment", comparativeExperiment.id);
  116. return result.toString();
  117. }
  118. return null;
  119. })();
  120. if (viewUrl != null) {
  121. console.log(`View results at: ${viewUrl}`);
  122. }
  123. const experimentRuns = await Promise.all(projects.map((p) => loadTraces(client, p.id, { loadNested: !!options.loadNested })));
  124. let exampleIdsIntersect;
  125. for (const runs of experimentRuns) {
  126. const exampleIdsSet = new Set(runs
  127. .map((r) => r.reference_example_id)
  128. .filter((x) => x != null));
  129. if (!exampleIdsIntersect) {
  130. exampleIdsIntersect = exampleIdsSet;
  131. }
  132. else {
  133. exampleIdsIntersect = new Set([...exampleIdsIntersect].filter((x) => exampleIdsSet.has(x)));
  134. }
  135. }
  136. const exampleIds = [...(exampleIdsIntersect ?? [])];
  137. if (!exampleIds.length) {
  138. throw new Error("No examples found in common between experiments.");
  139. }
  140. const exampleMap = {};
  141. for (let start = 0; start < exampleIds.length; start += 99) {
  142. const exampleIdsChunk = exampleIds.slice(start, start + 99);
  143. for await (const example of client.listExamples({
  144. datasetId: referenceDatasetId,
  145. exampleIds: exampleIdsChunk,
  146. asOf: datasetVersion,
  147. })) {
  148. exampleMap[example.id] = example;
  149. }
  150. }
  151. const runMapByExampleId = {};
  152. for (const runs of experimentRuns) {
  153. for (const run of runs) {
  154. if (run.reference_example_id == null ||
  155. !exampleIds.includes(run.reference_example_id)) {
  156. continue;
  157. }
  158. runMapByExampleId[run.reference_example_id] ??= [];
  159. runMapByExampleId[run.reference_example_id].push(run);
  160. }
  161. }
  162. const caller = new async_caller_js_1.AsyncCaller({
  163. maxConcurrency: options.maxConcurrency,
  164. debug: client.debug,
  165. });
  166. async function evaluateAndSubmitFeedback(runs, example, evaluator) {
  167. const expectedRunIds = new Set(runs.map((r) => r.id));
  168. // Check if evaluator expects an object parameter
  169. const result = evaluator.length === 1
  170. ? await evaluator({
  171. runs: options.randomizeOrder ? (0, shuffle_js_1.shuffle)(runs) : runs,
  172. example,
  173. inputs: example.inputs,
  174. outputs: runs.map((run) => run.outputs || {}),
  175. referenceOutputs: example.outputs || {},
  176. })
  177. : await evaluator(runs, example);
  178. for (const [runId, score] of Object.entries(result.scores)) {
  179. // validate if the run id
  180. if (!expectedRunIds.has(runId)) {
  181. throw new Error(`Returning an invalid run id ${runId} from evaluator.`);
  182. }
  183. await client.createFeedback(runId, result.key, {
  184. score,
  185. sourceRunId: result.source_run_id,
  186. comparativeExperimentId: comparativeExperiment.id,
  187. });
  188. }
  189. return result;
  190. }
  191. const tracedEvaluators = options.evaluators.map((evaluator) => (0, traceable_js_1.traceable)(async (runs, example) => {
  192. const evaluatorRun = (0, traceable_js_1.getCurrentRunTree)();
  193. const result = evaluator.length === 1
  194. ? await evaluator({
  195. runs: options.randomizeOrder ? (0, shuffle_js_1.shuffle)(runs) : runs,
  196. example,
  197. inputs: example.inputs,
  198. outputs: runs.map((run) => run.outputs || {}),
  199. referenceOutputs: example.outputs || {},
  200. })
  201. : await evaluator(runs, example);
  202. // sanitise the payload before sending to LangSmith
  203. evaluatorRun.inputs = { runs: runs, example: example };
  204. evaluatorRun.outputs = result;
  205. return {
  206. ...result,
  207. source_run_id: result.source_run_id ?? evaluatorRun.id,
  208. };
  209. }, {
  210. project_name: "evaluators",
  211. name: evaluator.name || "evaluator",
  212. }));
  213. const promises = Object.entries(runMapByExampleId).flatMap(([exampleId, runs]) => {
  214. const example = exampleMap[exampleId];
  215. if (!example)
  216. throw new Error(`Example ${exampleId} not found.`);
  217. return tracedEvaluators.map((evaluator) => caller.call(evaluateAndSubmitFeedback, runs, exampleMap[exampleId], evaluator));
  218. });
  219. const results = await Promise.all(promises);
  220. return { experimentName, results };
  221. }