Add support for rouge NLP metric

This commit is contained in:
Ian Webster
2023-06-10 15:37:24 -07:00
parent d2b093ab1e
commit 0c2f5d4c8c
8 changed files with 100 additions and 3 deletions

View File

@@ -98,9 +98,10 @@ See [Test assertions](https://promptfoo.dev/docs/configuration/expected-outputs)
| `is-json` | output is valid json |
| `contains-json` | output contains valid json |
| `javascript` | provided Javascript function validates the output |
| `webhook` | provided webhook returns `{pass: true} |
| `webhook` | provided webhook returns `{pass: true}` |
| `similar` | embeddings and cosine similarity are above a threshold |
| `llm-rubric` | LLM output matches a given rubric, using a Language Model to grade output |
| `rouge-n` | Rouge-N score is above a given threshold |
Every test type can be negated by prepending `not-`. For example, `not-equals` or `not-regex`.

28
package-lock.json generated
View File

@@ -27,6 +27,7 @@
"node-fetch": "^2.6.7",
"nunjucks": "^3.2.4",
"opener": "^1.5.2",
"rouge": "^1.0.3",
"socket.io": "^4.6.1",
"tiny-invariant": "^1.3.1",
"winston": "^3.8.2"
@@ -3973,6 +3974,12 @@
"signal-exit": "^3.0.2"
}
},
"node_modules/lodash-node": {
"version": "2.4.1",
"resolved": "https://registry.npmjs.org/lodash-node/-/lodash-node-2.4.1.tgz",
"integrity": "sha512-egEt8eNQp2kZWRmngahiqMoDCDCENv3uM188S7Ed5t4k3v6RrLELXC+FqLNMUnhCo7gvQX3G1V8opK/Lcslahg==",
"deprecated": "This package is discontinued. Use lodash@^4.0.0."
},
"node_modules/lodash.clonedeep": {
"version": "4.5.0",
"resolved": "https://registry.npmjs.org/lodash.clonedeep/-/lodash.clonedeep-4.5.0.tgz",
@@ -4665,6 +4672,14 @@
"node": ">=10"
}
},
"node_modules/rouge": {
"version": "1.0.3",
"resolved": "https://registry.npmjs.org/rouge/-/rouge-1.0.3.tgz",
"integrity": "sha512-YCt74Dxsi99E8/uh943FTa80EmGboaOu1ij4q8WD4EAGyvyWYaH7MRHorrDbGgLY7iFUwDwyW/g9KJZx7D5fUQ==",
"dependencies": {
"lodash-node": "^2.4.1"
}
},
"node_modules/safe-buffer": {
"version": "5.2.1",
"resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz",
@@ -8647,6 +8662,11 @@
"signal-exit": "^3.0.2"
}
},
"lodash-node": {
"version": "2.4.1",
"resolved": "https://registry.npmjs.org/lodash-node/-/lodash-node-2.4.1.tgz",
"integrity": "sha512-egEt8eNQp2kZWRmngahiqMoDCDCENv3uM188S7Ed5t4k3v6RrLELXC+FqLNMUnhCo7gvQX3G1V8opK/Lcslahg=="
},
"lodash.clonedeep": {
"version": "4.5.0",
"resolved": "https://registry.npmjs.org/lodash.clonedeep/-/lodash.clonedeep-4.5.0.tgz",
@@ -9130,6 +9150,14 @@
"integrity": "sha512-X2UW6Nw3n/aMgDVy+0rSqgHlv39WZAlZrXCdnbyEiKm17DSqHX4MmQMaST3FbeWR5FTuRcUwYAziZajji0Y7mg==",
"dev": true
},
"rouge": {
"version": "1.0.3",
"resolved": "https://registry.npmjs.org/rouge/-/rouge-1.0.3.tgz",
"integrity": "sha512-YCt74Dxsi99E8/uh943FTa80EmGboaOu1ij4q8WD4EAGyvyWYaH7MRHorrDbGgLY7iFUwDwyW/g9KJZx7D5fUQ==",
"requires": {
"lodash-node": "^2.4.1"
}
},
"safe-buffer": {
"version": "5.2.1",
"resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz",

View File

@@ -77,6 +77,7 @@
"node-fetch": "^2.6.7",
"nunjucks": "^3.2.4",
"opener": "^1.5.2",
"rouge": "^1.0.3",
"socket.io": "^4.6.1",
"tiny-invariant": "^1.3.1",
"winston": "^3.8.2"

View File

@@ -1,3 +1,4 @@
import rouge from 'rouge';
import invariant from 'tiny-invariant';
import nunjucks from 'nunjucks';
@@ -18,6 +19,31 @@ const SIMILAR_REGEX = /similar(?::|\((\d+(\.\d+)?)\):)/;
const DEFAULT_SEMANTIC_SIMILARITY_THRESHOLD = 0.8;
function handleRougeScore(
baseType: 'rouge-n',
assertion: Assertion,
expected: string | string[],
output: string,
inverted: boolean,
): GradingResult {
const fnName = baseType[baseType.length - 1] as 'n' | 'l' | 's';
const rougeMethod = rouge[fnName];
const score = rougeMethod(output, expected);
console.log(output, expected, score);
const pass = score >= (assertion.threshold || 0.75) != inverted;
return {
pass,
reason: pass
? `${baseType.toUpperCase()} score ${score} is greater than or equal to threshold ${
assertion.threshold || 0.75
}`
: `${baseType.toUpperCase()} score ${score} is less than threshold ${
assertion.threshold || 0.75
}`,
};
}
export async function runAssertions(test: AtomicTestCase, output: string): Promise<GradingResult> {
const tokensUsed = {
total: 0,
@@ -248,6 +274,11 @@ ${assertion.value}`,
};
}
if (baseType === 'rouge-n') {
invariant(assertion.value, '"rouge" assertion type must a value (string or string array)');
return handleRougeScore(baseType, assertion, assertion.value, output, inverse);
}
throw new Error('Unknown assertion type: ' + assertion.type);
}

View File

@@ -125,7 +125,10 @@ type BaseAssertionTypes =
| 'javascript'
| 'similar'
| 'llm-rubric'
| 'webhook';
| 'webhook'
| 'rouge-n'
| 'rouge-s'
| 'rouge-l';
type NotPrefixed<T extends string> = `not-${T}`;

View File

@@ -452,6 +452,29 @@ describe('runAssertion', () => {
expect(result.pass).toBeFalsy();
expect(result.reason).toBe('Webhook error: Webhook response status: 500');
});
// Test for rouge-n assertion
const rougeNAssertion: Assertion = {
type: 'rouge-n',
value: 'This is the expected output.',
threshold: 0.75,
};
it('should pass when the rouge-n assertion passes', async () => {
const output = 'This is the expected output.';
const result: GradingResult = await runAssertion(rougeNAssertion, {} as AtomicTestCase, output);
expect(result.pass).toBeTruthy();
expect(result.reason).toBe('ROUGE-N score 1 is greater than or equal to threshold 0.75');
});
it('should fail when the rouge-n assertion fails', async () => {
const output = 'some different output';
const result: GradingResult = await runAssertion(rougeNAssertion, {} as AtomicTestCase, output);
expect(result.pass).toBeFalsy();
expect(result.reason).toBe('ROUGE-N score 0.2 is less than threshold 0.75');
});
});
describe('assertionFromString', () => {

View File

@@ -11,6 +11,6 @@
"sourceMap": true,
"strict": true
},
"include": ["src/"],
"include": ["src/", "typings/**/*"],
"exclude": ["node_modules", "dist", "src/web/client/**/*"]
}

10
typings/rouge.d.ts vendored Normal file
View File

@@ -0,0 +1,10 @@
declare module 'rouge' {
function n(
candidate: string,
reference: string | string[],
n?: number,
jackknife?: boolean,
): number;
function l(candidate: string, reference: string | string[], jackknife?: boolean): number;
function s(candidate: string, reference: string | string[], jackknife?: boolean): number;
}