diff --git a/README.md b/README.md index 9c77638..a4eabd6 100644 --- a/README.md +++ b/README.md @@ -98,9 +98,10 @@ See [Test assertions](https://promptfoo.dev/docs/configuration/expected-outputs) | `is-json` | output is valid json | | `contains-json` | output contains valid json | | `javascript` | provided Javascript function validates the output | -| `webhook` | provided webhook returns `{pass: true} | +| `webhook` | provided webhook returns `{pass: true}` | | `similar` | embeddings and cosine similarity are above a threshold | | `llm-rubric` | LLM output matches a given rubric, using a Language Model to grade output | +| `rouge-n` | Rouge-N score is above a given threshold | Every test type can be negated by prepending `not-`. For example, `not-equals` or `not-regex`. diff --git a/package-lock.json b/package-lock.json index d29f0cc..ccba545 100644 --- a/package-lock.json +++ b/package-lock.json @@ -27,6 +27,7 @@ "node-fetch": "^2.6.7", "nunjucks": "^3.2.4", "opener": "^1.5.2", + "rouge": "^1.0.3", "socket.io": "^4.6.1", "tiny-invariant": "^1.3.1", "winston": "^3.8.2" @@ -3973,6 +3974,12 @@ "signal-exit": "^3.0.2" } }, + "node_modules/lodash-node": { + "version": "2.4.1", + "resolved": "https://registry.npmjs.org/lodash-node/-/lodash-node-2.4.1.tgz", + "integrity": "sha512-egEt8eNQp2kZWRmngahiqMoDCDCENv3uM188S7Ed5t4k3v6RrLELXC+FqLNMUnhCo7gvQX3G1V8opK/Lcslahg==", + "deprecated": "This package is discontinued. Use lodash@^4.0.0." + }, "node_modules/lodash.clonedeep": { "version": "4.5.0", "resolved": "https://registry.npmjs.org/lodash.clonedeep/-/lodash.clonedeep-4.5.0.tgz", @@ -4665,6 +4672,14 @@ "node": ">=10" } }, + "node_modules/rouge": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/rouge/-/rouge-1.0.3.tgz", + "integrity": "sha512-YCt74Dxsi99E8/uh943FTa80EmGboaOu1ij4q8WD4EAGyvyWYaH7MRHorrDbGgLY7iFUwDwyW/g9KJZx7D5fUQ==", + "dependencies": { + "lodash-node": "^2.4.1" + } + }, "node_modules/safe-buffer": { "version": "5.2.1", "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz", @@ -8647,6 +8662,11 @@ "signal-exit": "^3.0.2" } }, + "lodash-node": { + "version": "2.4.1", + "resolved": "https://registry.npmjs.org/lodash-node/-/lodash-node-2.4.1.tgz", + "integrity": "sha512-egEt8eNQp2kZWRmngahiqMoDCDCENv3uM188S7Ed5t4k3v6RrLELXC+FqLNMUnhCo7gvQX3G1V8opK/Lcslahg==" + }, "lodash.clonedeep": { "version": "4.5.0", "resolved": "https://registry.npmjs.org/lodash.clonedeep/-/lodash.clonedeep-4.5.0.tgz", @@ -9130,6 +9150,14 @@ "integrity": "sha512-X2UW6Nw3n/aMgDVy+0rSqgHlv39WZAlZrXCdnbyEiKm17DSqHX4MmQMaST3FbeWR5FTuRcUwYAziZajji0Y7mg==", "dev": true }, + "rouge": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/rouge/-/rouge-1.0.3.tgz", + "integrity": "sha512-YCt74Dxsi99E8/uh943FTa80EmGboaOu1ij4q8WD4EAGyvyWYaH7MRHorrDbGgLY7iFUwDwyW/g9KJZx7D5fUQ==", + "requires": { + "lodash-node": "^2.4.1" + } + }, "safe-buffer": { "version": "5.2.1", "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz", diff --git a/package.json b/package.json index 58905ae..e377bd8 100644 --- a/package.json +++ b/package.json @@ -77,6 +77,7 @@ "node-fetch": "^2.6.7", "nunjucks": "^3.2.4", "opener": "^1.5.2", + "rouge": "^1.0.3", "socket.io": "^4.6.1", "tiny-invariant": "^1.3.1", "winston": "^3.8.2" diff --git a/src/assertions.ts b/src/assertions.ts index 3a5c07a..9bf338b 100644 --- a/src/assertions.ts +++ b/src/assertions.ts @@ -1,3 +1,4 @@ +import rouge from 'rouge'; import invariant from 'tiny-invariant'; import nunjucks from 'nunjucks'; @@ -18,6 +19,31 @@ const SIMILAR_REGEX = /similar(?::|\((\d+(\.\d+)?)\):)/; const DEFAULT_SEMANTIC_SIMILARITY_THRESHOLD = 0.8; +function handleRougeScore( + baseType: 'rouge-n', + assertion: Assertion, + expected: string | string[], + output: string, + inverted: boolean, +): GradingResult { + const fnName = baseType[baseType.length - 1] as 'n' | 'l' | 's'; + const rougeMethod = rouge[fnName]; + const score = rougeMethod(output, expected); + console.log(output, expected, score); + const pass = score >= (assertion.threshold || 0.75) != inverted; + + return { + pass, + reason: pass + ? `${baseType.toUpperCase()} score ${score} is greater than or equal to threshold ${ + assertion.threshold || 0.75 + }` + : `${baseType.toUpperCase()} score ${score} is less than threshold ${ + assertion.threshold || 0.75 + }`, + }; +} + export async function runAssertions(test: AtomicTestCase, output: string): Promise { const tokensUsed = { total: 0, @@ -248,6 +274,11 @@ ${assertion.value}`, }; } + if (baseType === 'rouge-n') { + invariant(assertion.value, '"rouge" assertion type must a value (string or string array)'); + return handleRougeScore(baseType, assertion, assertion.value, output, inverse); + } + throw new Error('Unknown assertion type: ' + assertion.type); } diff --git a/src/types.ts b/src/types.ts index 5b13ac3..4b4b7d7 100644 --- a/src/types.ts +++ b/src/types.ts @@ -125,7 +125,10 @@ type BaseAssertionTypes = | 'javascript' | 'similar' | 'llm-rubric' - | 'webhook'; + | 'webhook' + | 'rouge-n' + | 'rouge-s' + | 'rouge-l'; type NotPrefixed = `not-${T}`; diff --git a/test/assertions.test.ts b/test/assertions.test.ts index f393840..42fc633 100644 --- a/test/assertions.test.ts +++ b/test/assertions.test.ts @@ -452,6 +452,29 @@ describe('runAssertion', () => { expect(result.pass).toBeFalsy(); expect(result.reason).toBe('Webhook error: Webhook response status: 500'); }); + + // Test for rouge-n assertion + const rougeNAssertion: Assertion = { + type: 'rouge-n', + value: 'This is the expected output.', + threshold: 0.75, + }; + + it('should pass when the rouge-n assertion passes', async () => { + const output = 'This is the expected output.'; + + const result: GradingResult = await runAssertion(rougeNAssertion, {} as AtomicTestCase, output); + expect(result.pass).toBeTruthy(); + expect(result.reason).toBe('ROUGE-N score 1 is greater than or equal to threshold 0.75'); + }); + + it('should fail when the rouge-n assertion fails', async () => { + const output = 'some different output'; + + const result: GradingResult = await runAssertion(rougeNAssertion, {} as AtomicTestCase, output); + expect(result.pass).toBeFalsy(); + expect(result.reason).toBe('ROUGE-N score 0.2 is less than threshold 0.75'); + }); }); describe('assertionFromString', () => { diff --git a/tsconfig.json b/tsconfig.json index 69812f1..e3d7bb3 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -11,6 +11,6 @@ "sourceMap": true, "strict": true }, - "include": ["src/"], + "include": ["src/", "typings/**/*"], "exclude": ["node_modules", "dist", "src/web/client/**/*"] } diff --git a/typings/rouge.d.ts b/typings/rouge.d.ts new file mode 100644 index 0000000..9a6292f --- /dev/null +++ b/typings/rouge.d.ts @@ -0,0 +1,10 @@ +declare module 'rouge' { + function n( + candidate: string, + reference: string | string[], + n?: number, + jackknife?: boolean, + ): number; + function l(candidate: string, reference: string | string[], jackknife?: boolean): number; + function s(candidate: string, reference: string | string[], jackknife?: boolean): number; +}