Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ const StyledEuiModal = styled(EuiModal)`
`;

/**
* Modal container for Security Assistant conversations, receiving the page contents as context, plus whatever
* Modal container for Elastic AI Assistant conversations, receiving the page contents as context, plus whatever
* component currently has focus and any specific context it may provide through the SAssInterface.
*/
export const AssistantOverlay: React.FC = React.memo(() => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
*/

import React from 'react';
import { fireEvent, render, screen } from '@testing-library/react';
import { fireEvent, render, screen, waitFor } from '@testing-library/react';
import userEvent from '@testing-library/user-event';

import { TestProviders } from '../../mock/test_providers/test_providers';
import type { PromptContext } from '../prompt_context/types';
import type { PromptContext, SelectedPromptContext } from '../prompt_context/types';
import { ContextPills } from '.';

const mockPromptContexts: Record<string, PromptContext> = {
Expand All @@ -30,16 +30,22 @@ const mockPromptContexts: Record<string, PromptContext> = {
},
};

const defaultProps = {
defaultAllow: [],
defaultAllowReplacement: [],
promptContexts: mockPromptContexts,
};

describe('ContextPills', () => {
beforeEach(() => jest.clearAllMocks());

it('renders the context pill descriptions', () => {
render(
<TestProviders>
<ContextPills
promptContexts={mockPromptContexts}
selectedPromptContextIds={[]}
setSelectedPromptContextIds={jest.fn()}
{...defaultProps}
selectedPromptContexts={{}}
setSelectedPromptContexts={jest.fn()}
/>
</TestProviders>
);
Expand All @@ -49,54 +55,74 @@ describe('ContextPills', () => {
});
});

it('invokes setSelectedPromptContextIds() when the prompt is NOT already selected', () => {
it('invokes setSelectedPromptContexts() when the prompt is NOT already selected', async () => {
const context = mockPromptContexts.context1;
const setSelectedPromptContextIds = jest.fn();
const setSelectedPromptContexts = jest.fn();

render(
<TestProviders>
<ContextPills
promptContexts={mockPromptContexts}
selectedPromptContextIds={[]} // <-- the prompt is NOT selected
setSelectedPromptContextIds={setSelectedPromptContextIds}
{...defaultProps}
selectedPromptContexts={{}} // <-- the prompt is NOT selected
setSelectedPromptContexts={setSelectedPromptContexts}
/>
</TestProviders>
);

userEvent.click(screen.getByTestId(`pillButton-${context.id}`));

expect(setSelectedPromptContextIds).toBeCalled();
await waitFor(() => {
expect(setSelectedPromptContexts).toBeCalled();
});
});

it('it does NOT invoke setSelectedPromptContextIds() when the prompt is already selected', () => {
it('it does NOT invoke setSelectedPromptContexts() when the prompt is already selected', async () => {
const context = mockPromptContexts.context1;
const setSelectedPromptContextIds = jest.fn();
const mockSelectedPromptContext: SelectedPromptContext = {
allow: [],
allowReplacement: [],
promptContextId: context.id,
rawData: 'test-raw-data',
};
const setSelectedPromptContexts = jest.fn();

render(
<TestProviders>
<ContextPills
promptContexts={mockPromptContexts}
selectedPromptContextIds={[context.id]} // <-- the context is already selected
setSelectedPromptContextIds={setSelectedPromptContextIds}
{...defaultProps}
selectedPromptContexts={{
[context.id]: mockSelectedPromptContext,
}} // <-- the context is already selected
setSelectedPromptContexts={setSelectedPromptContexts}
/>
</TestProviders>
);

// NOTE: this test uses `fireEvent` instead of `userEvent` to bypass the disabled button:
fireEvent.click(screen.getByTestId(`pillButton-${context.id}`));

expect(setSelectedPromptContextIds).not.toBeCalled();
await waitFor(() => {
expect(setSelectedPromptContexts).not.toBeCalled();
});
});

it('disables selected context pills', () => {
const context = mockPromptContexts.context1;
const mockSelectedPromptContext: SelectedPromptContext = {
allow: [],
allowReplacement: [],
promptContextId: context.id,
rawData: 'test-raw-data',
};

render(
<TestProviders>
<ContextPills
promptContexts={mockPromptContexts}
selectedPromptContextIds={[context.id]} // <-- context1 is selected
setSelectedPromptContextIds={jest.fn()}
{...defaultProps}
selectedPromptContexts={{
[context.id]: mockSelectedPromptContext,
}} // <-- the context is selected
setSelectedPromptContexts={jest.fn()}
/>
</TestProviders>
);
Expand All @@ -110,9 +136,9 @@ describe('ContextPills', () => {
render(
<TestProviders>
<ContextPills
promptContexts={mockPromptContexts}
selectedPromptContextIds={['context2']} // context1 is NOT selected
setSelectedPromptContextIds={jest.fn()}
{...defaultProps}
selectedPromptContexts={{}} // context1 is NOT selected
setSelectedPromptContexts={jest.fn()}
/>
</TestProviders>
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,35 +11,57 @@ import React, { useCallback, useMemo } from 'react';
// eslint-disable-next-line @kbn/eslint/module_migration
import styled from 'styled-components';

import type { PromptContext } from '../prompt_context/types';
import { getNewSelectedPromptContext } from '../../data_anonymization/get_new_selected_prompt_context';
import type { PromptContext, SelectedPromptContext } from '../prompt_context/types';

const PillButton = styled(EuiButton)`
margin-right: ${({ theme }) => theme.eui.euiSizeXS};
`;

interface Props {
defaultAllow: string[];
defaultAllowReplacement: string[];
promptContexts: Record<string, PromptContext>;
selectedPromptContextIds: string[];
setSelectedPromptContextIds: React.Dispatch<React.SetStateAction<string[]>>;
selectedPromptContexts: Record<string, SelectedPromptContext>;
setSelectedPromptContexts: React.Dispatch<
React.SetStateAction<Record<string, SelectedPromptContext>>
>;
}

const ContextPillsComponent: React.FC<Props> = ({
defaultAllow,
defaultAllowReplacement,
promptContexts,
selectedPromptContextIds,
setSelectedPromptContextIds,
selectedPromptContexts,
setSelectedPromptContexts,
}) => {
const sortedPromptContexts = useMemo(
() => sortBy('description', Object.values(promptContexts)),
[promptContexts]
);

const selectPromptContext = useCallback(
(id: string) => {
if (!selectedPromptContextIds.includes(id)) {
setSelectedPromptContextIds((prev) => [...prev, id]);
async (id: string) => {
if (selectedPromptContexts[id] == null && promptContexts[id] != null) {
const newSelectedPromptContext = await getNewSelectedPromptContext({
defaultAllow,
defaultAllowReplacement,
promptContext: promptContexts[id],
});

setSelectedPromptContexts((prev) => ({
...prev,
[id]: newSelectedPromptContext,
}));
}
},
[selectedPromptContextIds, setSelectedPromptContextIds]
[
defaultAllow,
defaultAllowReplacement,
promptContexts,
selectedPromptContexts,
setSelectedPromptContexts,
]
);

return (
Expand All @@ -49,7 +71,7 @@ const ContextPillsComponent: React.FC<Props> = ({
<EuiToolTip content={tooltip}>
<PillButton
data-test-subj={`pillButton-${id}`}
disabled={selectedPromptContextIds.includes(id)}
disabled={selectedPromptContexts[id] != null}
iconSide="left"
iconType="plus"
onClick={() => selectPromptContext(id)}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { invert } from 'lodash/fp';

import { getAnonymizedValue } from '.';

jest.mock('uuid', () => ({
v4: () => 'test-uuid',
}));

describe('getAnonymizedValue', () => {
beforeEach(() => jest.clearAllMocks());

it('returns a new UUID when currentReplacements is not provided', () => {
const currentReplacements = undefined;
const rawValue = 'test';

const result = getAnonymizedValue({ currentReplacements, rawValue });

expect(result).toBe('test-uuid');
});

it('returns an existing anonymized value when currentReplacements contains an entry for it', () => {
const rawValue = 'test';
const currentReplacements = { anonymized: 'test' };
const rawValueToReplacement = invert(currentReplacements);

const result = getAnonymizedValue({ currentReplacements, rawValue });
expect(result).toBe(rawValueToReplacement[rawValue]);
});

it('returns a new UUID with currentReplacements if no existing match', () => {
const rawValue = 'test';
const currentReplacements = { anonymized: 'other' };

const result = getAnonymizedValue({ currentReplacements, rawValue });

expect(result).toBe('test-uuid');
});
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { invert } from 'lodash/fp';
import { v4 } from 'uuid';

export const getAnonymizedValue = ({
currentReplacements,
rawValue,
}: {
currentReplacements: Record<string, string> | undefined;
rawValue: string;
}): string => {
if (currentReplacements != null) {
const rawValueToReplacement: Record<string, string> = invert(currentReplacements);
const existingReplacement: string | undefined = rawValueToReplacement[rawValue];

return existingReplacement != null ? existingReplacement : v4();
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank-you for the explanation here while pairing -- appreciate the context! 😅

}

return v4();
};
Loading