Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,10 @@ export const useAssistantOverlay = (
*/
replacements?: Replacements | null
): UseAssistantOverlay => {
const { http } = useAssistantContext();
const { http, inferenceEnabled } = useAssistantContext();
const { data: connectors } = useLoadConnectors({
http,
inferenceEnabled,
});

const defaultConnector = useMemo(() => getDefaultConnector(connectors), [connectors]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,18 @@ export const ConnectorSelector: React.FC<Props> = React.memo(
setIsOpen,
stats = null,
}) => {
const { actionTypeRegistry, http, assistantAvailability } = useAssistantContext();
const { actionTypeRegistry, http, assistantAvailability, inferenceEnabled } =
useAssistantContext();
// Connector Modal State
const [isConnectorModalVisible, setIsConnectorModalVisible] = useState<boolean>(false);
const { data: actionTypes } = useLoadActionTypes({ http });

const [selectedActionType, setSelectedActionType] = useState<ActionType | null>(null);

const { data: aiConnectors, refetch: refetchConnectors } = useLoadConnectors({ http });
const { data: aiConnectors, refetch: refetchConnectors } = useLoadConnectors({
http,
inferenceEnabled,
});

const localIsDisabled = isDisabled || !assistantAvailability.hasConnectorsReadPrivilege;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ export const ConnectorSetup = ({
);
const { setApiConfig } = useConversation();
// Access all conversations so we can add connector to all on initial setup
const { actionTypeRegistry, http } = useAssistantContext();
const { actionTypeRegistry, http, inferenceEnabled } = useAssistantContext();

const { refetch: refetchConnectors } = useLoadConnectors({ http });
const { refetch: refetchConnectors } = useLoadConnectors({ http, inferenceEnabled });

const { data: actionTypes } = useLoadActionTypes({ http });

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import { waitFor, renderHook } from '@testing-library/react';
import { useLoadConnectors, Props } from '.';
import { mockConnectors } from '../../mock/connectors';
import { TestProviders } from '../../mock/test_providers/test_providers';
import React, { ReactNode } from 'react';

const mockConnectorsAndExtras = [
...mockConnectors,
Expand Down Expand Up @@ -55,13 +54,6 @@ const toasts = {
};
const defaultProps = { http, toasts } as unknown as Props;

const createWrapper = (inferenceEnabled = false) => {
// eslint-disable-next-line react/display-name
return ({ children }: { children: ReactNode }) => (
<TestProviders providerContext={{ inferenceEnabled }}>{children}</TestProviders>
);
};

describe('useLoadConnectors', () => {
beforeEach(() => {
jest.clearAllMocks();
Expand Down Expand Up @@ -91,9 +83,12 @@ describe('useLoadConnectors', () => {
});

it('includes preconfigured .inference results when inferenceEnabled is true', async () => {
const { result } = renderHook(() => useLoadConnectors(defaultProps), {
wrapper: createWrapper(true),
});
const { result } = renderHook(
() => useLoadConnectors({ ...defaultProps, inferenceEnabled: true }),
{
wrapper: TestProviders,
}
);
await waitFor(() => {
expect(result.current.data).toStrictEqual(
mockConnectors
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import type { IHttpFetchError } from '@kbn/core-http-browser';
import { HttpSetup } from '@kbn/core-http-browser';
import { IToasts } from '@kbn/core-notifications-browser';
import { OpenAiProviderType } from '@kbn/stack-connectors-plugin/common/openai/constants';
import { useAssistantContext } from '../../assistant_context';
import { AIConnector } from '../connector_selector';
import * as i18n from '../translations';

Expand All @@ -26,15 +25,16 @@ const QUERY_KEY = ['elastic-assistant, load-connectors'];
export interface Props {
http: HttpSetup;
toasts?: IToasts;
inferenceEnabled?: boolean;
}

const actionTypes = ['.bedrock', '.gen-ai', '.gemini'];

export const useLoadConnectors = ({
http,
toasts,
inferenceEnabled = false,
}: Props): UseQueryResult<AIConnector[], IHttpFetchError> => {
const { inferenceEnabled } = useAssistantContext();
if (inferenceEnabled) {
actionTypes.push('.inference');
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ const actionType = { id: '.bedrock', name: 'Bedrock', iconClass: 'logoBedrock' }
mockServices.triggersActionsUi.actionTypeRegistry.register(
actionType as unknown as ActionTypeModel
);

const inferenceActionType = { id: '.inference', name: 'Inference', iconClass: 'logoInference' };
mockServices.triggersActionsUi.actionTypeRegistry.register(
inferenceActionType as unknown as ActionTypeModel
);

jest.mock('@kbn/elastic-assistant/impl/connectorland/use_load_action_types', () => ({
useLoadActionTypes: jest.fn(() => ({ data: [actionType] })),
}));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,23 @@ interface ConnectorStepProps {
}
export const ConnectorStep = React.memo<ConnectorStepProps>(({ connector }) => {
const { euiTheme } = useEuiTheme();
const { http, notifications } = useKibana().services;
const { http, notifications, triggersActionsUi } = useKibana().services;
const { setConnector, completeStep } = useActions();

const [connectors, setConnectors] = useState<AIConnector[]>();
let inferenceEnabled: boolean = false;

if (triggersActionsUi.actionTypeRegistry.has('.inference')) {
inferenceEnabled = triggersActionsUi.actionTypeRegistry.get('.inference') as unknown as boolean;
}
if (inferenceEnabled) {
AllowedActionTypeIds.push('.inference');
}
const {
isLoading,
data: aiConnectors,
refetch: refetchConnectors,
} = useLoadConnectors({ http, toasts: notifications.toasts });
} = useLoadConnectors({ http, toasts: notifications.toasts, inferenceEnabled });

useEffect(() => {
if (aiConnectors != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ export const getLLMType = (actionTypeId: string): string | undefined => {
[`.gen-ai`]: `openai`,
[`.bedrock`]: `bedrock`,
[`.gemini`]: `gemini`,
[`.inference`]: `inference`,
};
return llmTypeDictionary[actionTypeId];
};

export const getLLMClass = (llmType?: string) =>
llmType === 'openai'
llmType === 'openai' || llmType === 'inference'
? ActionsClientChatOpenAI
: llmType === 'bedrock'
? ActionsClientBedrockChatModel
Expand Down