evaluate_comparative.js 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. import { v4 as uuid4, validate } from "uuid";
  2. import { Client } from "../index.js";
  3. import { shuffle } from "../utils/shuffle.js";
  4. import { AsyncCaller } from "../utils/async_caller.js";
  5. import pRetry from "p-retry";
  6. import { getCurrentRunTree, traceable } from "../traceable.js";
  /**
   * Reports whether a list of experiment references contains at least one
   * entry that is not a plain string.
   *
   * NOTE(review): judging by the name, non-string entries are experiment
   * result objects; this check itself only looks at `typeof`.
   *
   * @param {Array} value - Entries to inspect.
   * @returns {boolean} `true` if any entry is not a string, `false` when all
   *   entries are strings (including when the list is empty).
   */
  function isExperimentResultsList(value) {
    const allEntriesAreStrings = value.every((entry) => typeof entry === "string");
    return !allEntriesAreStrings;
  }
  /**
   * Reads the project that backs an experiment.
   *
   * The experiment may be given either as a string or as an object exposing
   * `experimentName`. If the resulting identifier passes `validate` (the UUID
   * validator imported from "uuid") it is looked up by project id, otherwise
   * by project name.
   *
   * @param {object} client - Client exposing `readProject`.
   * @param {string|{experimentName: string}} experiment - Experiment reference.
   * @returns {Promise<object>} Whatever `client.readProject` resolves to.
   */
  async function loadExperiment(client, experiment) {
    let identifier;
    if (typeof experiment === "string") {
      identifier = experiment;
    } else {
      identifier = experiment.experimentName;
    }

    const query = {};
    if (validate(identifier)) {
      query.projectId = identifier;
    } else {
      query.projectName = identifier;
    }
    return client.readProject(query);
  }
  /**
   * Lists the runs of an experiment and returns those without a parent, with
   * every listed child run attached to its parent's `child_runs`.
   *
   * @param {object} client - Client exposing `listRuns`; its (awaited) result
   *   is consumed with `for await`.
   * @param {string} experiment - Project identifier. Sent as `projectId` when
   *   it passes `validate` (the UUID validator), otherwise as `projectName`.
   * @param {{loadNested?: boolean}} options - When `loadNested` is falsy the
   *   listing is requested with `executionOrder: 1` (presumably restricting it
   *   to top-level runs — confirm against the `listRuns` API); otherwise
   *   `executionOrder` is left undefined.
   * @returns {Promise<object[]>} Runs whose `parent_run_id` is null/undefined,
   *   in listing order. Parents found in the listing get a `child_runs` array
   *   sorted by `dotted_order`.
   */
  async function loadTraces(client, experiment, options) {
    const executionOrder = options.loadNested ? undefined : 1;
    const runs = await client.listRuns(validate(experiment)
      ? { projectId: experiment, executionOrder }
      : { projectName: experiment, executionOrder });

    // Maps rather than plain objects: the keys are dynamic run ids.
    const childrenByParentId = new Map();
    const runsById = new Map();
    const rootRuns = [];
    for await (const run of runs) {
      if (run.parent_run_id != null) {
        const siblings = childrenByParentId.get(run.parent_run_id);
        if (siblings === undefined) {
          childrenByParentId.set(run.parent_run_id, [run]);
        } else {
          siblings.push(run);
        }
      } else {
        rootRuns.push(run);
      }
      runsById.set(run.id, run);
    }

    for (const [parentRunId, childRuns] of childrenByParentId) {
      const parentRun = runsById.get(parentRunId);
      if (parentRun === undefined) {
        // The parent was not part of the listing, so there is nothing to
        // attach these children to. Skipping avoids a TypeError on
        // `undefined.child_runs`; the returned root runs are unaffected.
        continue;
      }
      parentRun.child_runs = childRuns.sort((a, b) => {
        // Runs without a dotted_order cannot be ordered relative to others.
        if (a.dotted_order == null || b.dotted_order == null) {
          return 0;
        }
        return a.dotted_order.localeCompare(b.dotted_order);
      });
    }
    return rootRuns;
  }
  /**
   * Runs comparative evaluators over the runs that several experiments
   * produced for the same dataset examples, and records the returned scores as
   * feedback attached to a newly created comparative experiment.
   *
   * Outline:
   *  1. Validate arguments and resolve the experiments to projects.
   *  2. Require a single shared reference dataset; warn on differing versions.
   *  3. Create the comparative experiment and print a link to it if one can
   *     be derived from the project URL.
   *  4. Load each experiment's runs, keep the example ids common to all of
   *     them, and fetch those examples.
   *  5. For every common example and every evaluator, run the evaluator on
   *     that example's runs and submit one feedback entry per scored run.
   *
   * @param {Array} experiments - At least two experiments. Entries may be
   *   promises (they are awaited). When every resolved entry is a string the
   *   projects are loaded directly; otherwise each project is re-read (with
   *   retries) until its `run_count` equals the entry's `results.length`.
   * @param {object} options
   * @param {Function[]} options.evaluators - Non-empty. An evaluator declaring
   *   exactly one parameter receives
   *   `{ runs, example, inputs, outputs, referenceOutputs }`; any other arity
   *   is called as `(runs, example)`. It must return `{ key, scores }` where
   *   `scores` maps run ids to scores, optionally with `source_run_id`.
   * @param {object} [options.client] - Client to use; defaults to `new Client()`.
   * @param {number} [options.maxConcurrency] - Forwarded to `AsyncCaller`.
   * @param {string} [options.experimentPrefix] - Name prefix; when absent the
   *   project names joined by " vs. " are used.
   * @param {string} [options.description] - Forwarded to the comparative experiment.
   * @param {object} [options.metadata] - Forwarded to the comparative experiment.
   * @param {boolean} [options.loadNested] - Forwarded to `loadTraces`.
   * @param {boolean} [options.randomizeOrder] - When truthy, object-style
   *   evaluators receive the runs (and the matching outputs) in the order
   *   produced by `shuffle`. Two-argument evaluators are unaffected.
   * @returns {Promise<{experimentName: string, results: object[]}>} The
   *   generated name and one evaluator result per (example, evaluator) pair.
   * @throws {Error} On invalid arguments, mismatching/missing reference
   *   datasets, no examples in common, an example that cannot be fetched, or
   *   an evaluator scoring a run id that was not part of its input.
   */
  export async function evaluateComparative(experiments, options) {
    if (experiments.length < 2) {
      throw new Error("Comparative evaluation requires at least 2 experiments.");
    }
    if (!options.evaluators.length) {
      throw new Error("At least one evaluator is required for comparative evaluation.");
    }
    if (options.maxConcurrency && options.maxConcurrency < 0) {
      throw new Error("maxConcurrency must be a positive number.");
    }
    const client = options.client ?? new Client();
    const resolvedExperiments = await Promise.all(experiments);
    const projects = await (() => {
      if (!isExperimentResultsList(resolvedExperiments)) {
        return Promise.all(resolvedExperiments.map((experiment) => loadExperiment(client, experiment)));
      }
      // if we know the number of runs beforehand, check if the
      // number of runs in the project matches the expected number of runs
      return Promise.all(resolvedExperiments.map((experiment) => pRetry(async () => {
        const project = await loadExperiment(client, experiment);
        if (project.run_count !== experiment?.results.length) {
          throw new Error("Experiment is missing runs. Retrying.");
        }
        return project;
      }, { factor: 2, minTimeout: 1000, retries: 10 })));
    })();

    if (new Set(projects.map((p) => p.reference_dataset_id)).size > 1) {
      throw new Error("All experiments must have the same reference dataset.");
    }
    const referenceDatasetId = projects.at(0)?.reference_dataset_id;
    if (!referenceDatasetId) {
      throw new Error("Reference dataset is required for comparative evaluation.");
    }
    if (new Set(projects.map((p) => p.extra?.metadata?.dataset_version)).size > 1) {
      console.warn("Detected multiple dataset versions used by experiments, which may lead to inaccurate results.");
    }
    // Examples are fetched as of the first project's dataset version.
    const datasetVersion = projects.at(0)?.extra?.metadata?.dataset_version;

    const id = uuid4();
    const experimentName = (() => {
      if (!options.experimentPrefix) {
        const names = projects
          .map((p) => p.name)
          .filter(Boolean)
          .join(" vs. ");
        return `${names}-${uuid4().slice(0, 4)}`;
      }
      return `${options.experimentPrefix}-${uuid4().slice(0, 4)}`;
    })();
    // TODO: add URL to the comparative experiment
    console.log(`Starting pairwise evaluation of: ${experimentName}`);
    const comparativeExperiment = await client.createComparativeExperiment({
      id,
      name: experimentName,
      experimentIds: projects.map((p) => p.id),
      description: options.description,
      metadata: options.metadata,
      // Same value as projects.at(0)?.reference_dataset_id, validated above.
      referenceDatasetId,
    });

    const viewUrl = await (async () => {
      const projectId = projects.at(0)?.id ?? projects.at(1)?.id;
      const datasetId = comparativeExperiment?.reference_dataset_id;
      if (projectId && datasetId) {
        // Keep only the part of the project URL before "/projects/p/".
        const hostUrl = (await client.getProjectUrl({ projectId }))
          .split("/projects/p/")
          .at(0);
        const result = new URL(`${hostUrl}/datasets/${datasetId}/compare`);
        result.searchParams.set("selectedSessions", projects.map((p) => p.id).join(","));
        result.searchParams.set("comparativeExperiment", comparativeExperiment.id);
        return result.toString();
      }
      return null;
    })();
    if (viewUrl != null) {
      console.log(`View results at: ${viewUrl}`);
    }

    const experimentRuns = await Promise.all(projects.map((p) => loadTraces(client, p.id, { loadNested: !!options.loadNested })));

    // Intersect the example ids referenced by each experiment's runs.
    let commonExampleIds;
    for (const runs of experimentRuns) {
      const idsInExperiment = new Set(runs
        .map((r) => r.reference_example_id)
        .filter((x) => x != null));
      commonExampleIds = commonExampleIds === undefined
        ? idsInExperiment
        : new Set([...commonExampleIds].filter((x) => idsInExperiment.has(x)));
    }
    const exampleIds = [...(commonExampleIds ?? [])];
    if (!exampleIds.length) {
      throw new Error("No examples found in common between experiments.");
    }

    // Fetch examples in chunks of 99 ids per request.
    // NOTE(review): the chunk size presumably reflects an API limit — confirm.
    const exampleIdChunkSize = 99;
    const exampleMap = {};
    for (let start = 0; start < exampleIds.length; start += exampleIdChunkSize) {
      const exampleIdsChunk = exampleIds.slice(start, start + exampleIdChunkSize);
      for await (const example of client.listExamples({
        datasetId: referenceDatasetId,
        exampleIds: exampleIdsChunk,
        asOf: datasetVersion,
      })) {
        exampleMap[example.id] = example;
      }
    }

    // Group the runs of all experiments by the example they were produced for,
    // dropping runs whose example is not shared by every experiment.
    const runMapByExampleId = {};
    for (const runs of experimentRuns) {
      for (const run of runs) {
        if (run.reference_example_id == null ||
          !commonExampleIds.has(run.reference_example_id)) {
          continue;
        }
        runMapByExampleId[run.reference_example_id] ??= [];
        runMapByExampleId[run.reference_example_id].push(run);
      }
    }

    const caller = new AsyncCaller({
      maxConcurrency: options.maxConcurrency,
      debug: client.debug,
    });

    // Runs one traced evaluator on one example's runs and submits a feedback
    // entry for every run id it scored.
    async function evaluateAndSubmitFeedback(runs, example, evaluator) {
      const expectedRunIds = new Set(runs.map((r) => r.id));
      // `evaluator` is always one of `tracedEvaluators`, whose wrapped function
      // takes (runs, example); choosing between object-style and positional
      // user evaluators happens inside that wrapper.
      const result = await evaluator(runs, example);
      for (const [runId, score] of Object.entries(result.scores)) {
        // Reject scores for runs that were not part of this comparison.
        if (!expectedRunIds.has(runId)) {
          throw new Error(`Returning an invalid run id ${runId} from evaluator.`);
        }
        await client.createFeedback(runId, result.key, {
          score,
          sourceRunId: result.source_run_id,
          comparativeExperimentId: comparativeExperiment.id,
        });
      }
      return result;
    }

    const tracedEvaluators = options.evaluators.map((evaluator) => traceable(async (runs, example) => {
      const evaluatorRun = getCurrentRunTree();
      let result;
      if (evaluator.length === 1) {
        // Object-style evaluator. Derive `runs` and `outputs` from the same
        // ordered list so that outputs[i] always belongs to runs[i], also when
        // the order is randomized.
        const orderedRuns = options.randomizeOrder ? shuffle(runs) : runs;
        result = await evaluator({
          runs: orderedRuns,
          example,
          inputs: example.inputs,
          outputs: orderedRuns.map((run) => run.outputs || {}),
          referenceOutputs: example.outputs || {},
        });
      } else {
        result = await evaluator(runs, example);
      }
      // sanitise the payload before sending to LangSmith
      evaluatorRun.inputs = { runs: runs, example: example };
      evaluatorRun.outputs = result;
      return {
        ...result,
        // Fall back to the evaluator's own trace as the feedback source.
        source_run_id: result.source_run_id ?? evaluatorRun.id,
      };
    }, {
      project_name: "evaluators",
      name: evaluator.name || "evaluator",
    }));

    // Verify every example up front so that a missing one fails before any
    // evaluation has been scheduled (no in-flight promises left behind).
    const runsPerExample = Object.entries(runMapByExampleId);
    for (const [exampleId] of runsPerExample) {
      if (!exampleMap[exampleId]) {
        throw new Error(`Example ${exampleId} not found.`);
      }
    }
    const promises = runsPerExample.flatMap(([exampleId, runs]) => {
      const example = exampleMap[exampleId];
      return tracedEvaluators.map((evaluator) => caller.call(evaluateAndSubmitFeedback, runs, example, evaluator));
    });
    const results = await Promise.all(promises);
    return { experimentName, results };
  }