1
0
mirror of synced 2026-05-22 22:53:35 +00:00

customer-support-assistant v1 (#103)

* feat(customer-support): updated code

* Delete 02-use-cases/customer-support-assistant/Dockerfile

Signed-off-by: Eashan Kaushik <50113394+EashanKaushik@users.noreply.github.com>

* Update .gitignore

Signed-off-by: Eashan Kaushik <50113394+EashanKaushik@users.noreply.github.com>

---------

Signed-off-by: Eashan Kaushik <50113394+EashanKaushik@users.noreply.github.com>
This commit is contained in:
Eashan Kaushik
2025-07-21 11:34:00 -04:00
committed by GitHub
parent 176ef7bd91
commit 88e19eddc9
48 changed files with 2227 additions and 2240 deletions
@@ -0,0 +1,69 @@
# Build artifacts
build/
dist/
*.egg-info/
*.egg
# Python cache
__pycache__/
__pycache__*
*.py[cod]
*$py.class
*.so
.Python
# Virtual environments
.venv/
.env
venv/
env/
ENV/
# Testing
.pytest_cache/
.coverage
.coverage*
htmlcov/
.tox/
*.cover
.hypothesis/
.mypy_cache/
.ruff_cache/
# Development
*.log
*.bak
*.swp
*.swo
*~
.DS_Store
# IDEs
.vscode/
.idea/
# Version control
.git/
.gitignore
.gitattributes
# Documentation
docs/
*.md
!README.md
# CI/CD
.github/
.gitlab-ci.yml
.travis.yml
# Project specific
tests/
# Bedrock AgentCore specific - keep config but exclude runtime files
.bedrock_agentcore.yaml
Dockerfile
.dockerignore
# Keep wheelhouse for offline installations
# wheelhouse/
@@ -0,0 +1,220 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[codz]
*$py.class
# C extensions
*.so
*.whl
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py.cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
#poetry.toml
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
#pdm.lock
#pdm.toml
.pdm-python
.pdm-build/
# pixi
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
#pixi.lock
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
# in the .venv directory. It is recommended not to include this directory in version control.
.pixi
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.envrc
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# Abstra
# Abstra is an AI-powered process automation framework.
# Ignore directories containing user credentials, local state, and settings.
# Learn more at https://abstra.io/docs
.abstra/
# Visual Studio Code
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
# and can be added to the global gitignore or merged into this file. However, if you prefer,
# you could uncomment the following to ignore the entire vscode folder
# .vscode/
# Ruff stuff:
.ruff_cache/
# PyPI configuration file
.pypirc
# Cursor
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
# refer to https://docs.cursor.com/context/ignore-files
.cursorignore
.cursorindexingignore
# Marimo
marimo/_static/
marimo/_lsp/
__marimo__/
# Streamlit
.streamlit/secrets.toml
wheelhouse/
*.zip
credentials.json
hooks.ipynb
gateway.config
.agentcore.yaml
.bedrock_agentcore.yaml
model
Dockerfile
@@ -7,6 +7,35 @@ This is a customer support agent implementation using AWS Bedrock AgentCore fram
![architecture](./images/architecture.png)
## Table of Contents
- [Customer Support Agent](#customer-support-agent)
- [Table of Contents](#table-of-contents)
- [Prerequisites](#prerequisites)
- [AWS Account Setup](#aws-account-setup)
- [Deploy](#deploy)
- [Sample Queries](#sample-queries)
- [Scripts](#scripts)
- [Amazon Bedrock AgentCore Gateway](#amazon-bedrock-agentcore-gateway)
- [Create Amazon Bedrock AgentCore Gateway](#create-amazon-bedrock-agentcore-gateway)
- [Delete Amazon Bedrock AgentCore Gateway](#delete-amazon-bedrock-agentcore-gateway)
- [Amazon Bedrock AgentCore Memory](#amazon-bedrock-agentcore-memory)
- [Create Amazon Bedrock AgentCore Memory](#create-amazon-bedrock-agentcore-memory)
- [Delete Amazon Bedrock AgentCore Memory](#delete-amazon-bedrock-agentcore-memory)
- [Cognito Credentials Provider](#cognito-credentials-provider)
- [Create Cognito Credentials Provider](#create-cognito-credentials-provider)
- [Delete Cognito Credentials Provider](#delete-cognito-credentials-provider)
- [Google Credentials Provider](#google-credentials-provider)
- [Create Credentials Provider](#create-credentials-provider)
- [Delete Credentials Provider](#delete-credentials-provider)
- [Agent Runtime](#agent-runtime)
- [Delete Agent Runtime](#delete-agent-runtime)
- [Cleanup](#cleanup)
- [🤝 Contributing](#-contributing)
- [📄 License](#-license)
- [🆘 Support](#-support)
- [🔄 Updates](#-updates)
## Prerequisites
### AWS Account Setup
@@ -26,24 +55,24 @@ This is a customer support agent implementation using AWS Bedrock AgentCore fram
3. **Bedrock Model Access**: Enable access to Amazon Bedrock Anthropic Claude 4.0 models in your AWS region
- Navigate to [Amazon Bedrock Console](https://console.aws.amazon.com/bedrock/)
- Go to "Model access" and request access to:
- Anthropic Claude Sonnet models
- Anthropic Claude 4.0 Sonnet model
- Anthropic Claude 3.5 Haiku model
- [Bedrock Model Access Guide](https://docs.aws.amazon.com/bedrock/latest/userguide/model-access.html)
4. **Python 3.8+**: Required for running the application
4. **Python 3.10+**: Required for running the application
- [Python Downloads](https://www.python.org/downloads/)
5. **Create OAuth 2.0 credentials for calendar access** : For Google Calendar integration
- Follow [Google OAuth Setup](./prerequisite/google_oauth_setup.md)
6. **Install [uv](https://docs.astral.sh/uv/getting-started/installation/) package manager**.
## Deploy
1. Create infrastructure
1. **Create infrastructure**
```bash
python -m venv .venv
source .venv/bin/activate
pip install -r dev-requirements.txt
chmod +x scripts/prereq.sh
./scripts/prereq.sh
@@ -52,17 +81,18 @@ This is a customer support agent implementation using AWS Bedrock AgentCore fram
./scripts/list_ssm_parameters.sh
```
2. Create Agentcore Gateway
> [!CAUTION]
> Please prefix all the resource name with `customersupport`.
2. **Create Agentcore Gateway**
```bash
python scripts/agentcore_gateway.py create --name customersupgateway
python scripts/agentcore_gateway.py create --name customersupport-gw
```
This create `gateway.config` file.
3. **Setup Agentcore Identity**
3. Setup Agentcore Identity
- Setup Cognito Credential Provider
- **Setup Cognito Credential Provider**
```bash
python scripts/cognito_credentials_provider.py create --name customersupport-gateways
@@ -70,7 +100,7 @@ This is a customer support agent implementation using AWS Bedrock AgentCore fram
python test/test_gateway.py --prompt "Check warranty with serial number MNO33333333"
```
- Setup Google Credential Provider
- **Setup Google Credential Provider**
Follow instructions to setup [Google Credentials](./prerequisite/google_oauth_setup.md).
@@ -80,24 +110,26 @@ This is a customer support agent implementation using AWS Bedrock AgentCore fram
python test/test_google_tool.py
```
4. Create Memory
4. **Create Memory**
```bash
python scripts/agentcore_memory.py create --name customersupport
python test/test_memory.py load-conversation
python test/test_memory.py load-prompt "My preference of gaming console is V5 Pro"
python test/test_memory.py list-memory
```
5. Setup Agent Runtime
5. **Setup Agent Runtime**
```bash
agentcore configure --entrypoint main.py -er arn:aws:iam::<Account-Id>:role/<Role> --name customersupport<AgentName>
```
Use `./scripts/list_ssm_parameters.sh` to fill:
- `Role = ValueOf(/app/customersupport/agentcore/agentcore_iam_role)`
- `Oath Discovery URL = ValueOf(/app/customersupport/agentcore/cognito_discovery_url)`
- `Oath client id = ValueOf(/app/customersupport/agentcore/web_client_id)`.
- `OAuth Discovery URL = ValueOf(/app/customersupport/agentcore/cognito_discovery_url)`
- `OAuth client id = ValueOf(/app/customersupport/agentcore/web_client_id)`.
![configure](./images/runtime_configure.png)
@@ -107,13 +139,29 @@ This is a customer support agent implementation using AWS Bedrock AgentCore fram
python test/test_agent.py customersupport<AgentName> -p "Hi"
```
6. Local Host Streamlit UI
![code](./images/code.png)
6. **Local Host Streamlit UI**
> [!CAUTION]
> Streamlit app should only run on port `8501`.
```bash
pip install streamlit
streamlit run app.py -- --agent=customersupport<AgentName>
streamlit run app.py --server.port 8501 -- --agent=customersupport<AgentName>
```
## Sample Queries
1. I have a Gaming Console Pro device , I want to check my warranty status, warranty serial number is MNO33333333.
2. What are the warranty support guidelines ?
3. Whats my agenda for today?
4. Can you create an event to setup call to renew warranty?
5. I have overheating issues with my device, help me debug.
## Scripts
### Amazon Bedrock AgentCore Gateway
@@ -197,6 +245,21 @@ python scripts/google_credentials_provider.py delete --name customersupport-goog
python scripts/google_credentials_provider.py delete --confirm
```
### Agent Runtime
#### Delete Agent Runtime
```bash
# Delete specific agent runtime by name
python scripts/agentcore_agent_runtime.py customersupport
# Preview what would be deleted without actually deleting
python scripts/agentcore_agent_runtime.py --dry-run customersupport
# Delete any agent runtime by name
python scripts/agentcore_agent_runtime.py <agent-name>
```
## Cleanup
```bash
@@ -207,7 +270,10 @@ python scripts/google_credentials_provider.py delete
python scripts/cognito_credentials_provider.py delete
python scripts/agentcore_memory.py delete
python scripts/agentcore_gateway.py delete
python scripts/agencore_agent_runtime.py delete
python scripts/agentcore_agent_runtime.py customersupport<AgentName>
rm .agentcore.yaml
rm .bedrock_agentcore.yaml
```
## 🤝 Contributing
@@ -0,0 +1,11 @@
from .utils import get_ssm_parameter
from bedrock_agentcore.identity.auth import requires_access_token
@requires_access_token(
provider_name=get_ssm_parameter("/app/customersupport/agentcore/cognito_provider"),
scopes=[], # Optional unless required
auth_flow="M2M",
)
async def get_gateway_access_token(access_token: str):
return access_token
@@ -1,11 +1,11 @@
from typing import List
from strands.tools.mcp import MCPClient
from strands import Agent
from strands.models import BedrockModel
from .utils import get_ssm_parameter
from agent_config.memory_hook_provider import MemoryHook
from mcp.client.streamable_http import streamablehttp_client
from strands import Agent
from strands_tools import current_time, retrieve
from memory_hook_provider import MemoryHook
from scripts.utils import read_config
from strands.models import BedrockModel
from strands.tools.mcp import MCPClient
from typing import List
class CustomerSupport:
@@ -41,15 +41,13 @@ class CustomerSupport:
"""
)
self.gateway_config = read_config("gateway.config")
print(
f"Gateway Endpoint - MCP URL: {self.gateway_config['gateway']['gateway_url']}mcp"
)
gateway_url = get_ssm_parameter("/app/customersupport/agentcore/gateway_url")
print(f"Gateway Endpoint - MCP URL: {gateway_url}")
try:
self.gateway_client = MCPClient(
lambda: streamablehttp_client(
f"{self.gateway_config['gateway']['gateway_url']}",
gateway_url,
headers={"Authorization": f"Bearer {bearer_token}"},
)
)
@@ -0,0 +1,53 @@
from .context import (
get_agent_ctx,
get_gateway_token_ctx,
get_response_queue_ctx,
set_agent_ctx,
)
from .memory_hook_provider import MemoryHook
from .utils import get_ssm_parameter
from agent_config.agent import CustomerSupport # Your custom agent class
from agent_config.tools.google import get_calendar_events_today, create_calendar_event
from bedrock_agentcore.memory import MemoryClient
import logging
# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
memory_client = MemoryClient()
async def agent_task(user_message: str, session_id: str, actor_id: str):
agent = get_agent_ctx()
response_queue = get_response_queue_ctx()
gateway_access_token = get_gateway_token_ctx()
if not gateway_access_token:
raise RuntimeError("Gateway Access token is none")
try:
if agent is None:
memory_hook = MemoryHook(
memory_client=memory_client,
memory_id=get_ssm_parameter("/app/customersupport/agentcore/memory_id"),
actor_id=actor_id,
session_id=session_id,
)
agent = CustomerSupport(
bearer_token=gateway_access_token,
memory_hook=memory_hook,
tools=[get_calendar_events_today, create_calendar_event],
)
set_agent_ctx(agent)
async for chunk in agent.stream(user_query=user_message, session_id=session_id):
await response_queue.put(chunk)
except Exception as e:
logger.exception("Agent execution failed.")
await response_queue.put(f"Error: {str(e)}")
finally:
await response_queue.finish()
@@ -0,0 +1,45 @@
from .agent import CustomerSupport
from contextvars import ContextVar
from typing import Optional
import asyncio
# Context variables for application state
google_token_ctx: ContextVar[Optional[str]] = ContextVar("google_token", default=None)
gateway_token_ctx: ContextVar[Optional[str]] = ContextVar("gateway_token", default=None)
response_queue_ctx: ContextVar[Optional[asyncio.Queue]] = ContextVar(
"response_queue", default=None
)
agent_ctx: ContextVar[Optional[CustomerSupport]] = ContextVar("agent", default=None)
# Helper functions
def get_google_token_ctx() -> Optional[str]:
return google_token_ctx.get()
def set_google_token_ctx(token: str) -> None:
google_token_ctx.set(token)
def get_response_queue_ctx() -> Optional[asyncio.Queue]:
return response_queue_ctx.get()
def set_response_queue_ctx(queue: asyncio.Queue) -> None:
response_queue_ctx.set(queue)
def get_gateway_token_ctx() -> Optional[str]:
return gateway_token_ctx.get()
def set_gateway_token_ctx(token: str) -> None:
gateway_token_ctx.set(token)
def get_agent_ctx() -> Optional[CustomerSupport]:
return agent_ctx.get()
def set_agent_ctx(agent: CustomerSupport) -> None:
agent_ctx.set(agent)
@@ -0,0 +1,107 @@
from bedrock_agentcore.memory import MemoryClient
from strands.hooks.events import AgentInitializedEvent, MessageAddedEvent
from strands.hooks.registry import HookProvider, HookRegistry
import copy
class MemoryHook(HookProvider):
def __init__(
self,
memory_client: MemoryClient,
memory_id: str,
actor_id: str,
session_id: str,
):
self.memory_client = memory_client
self.memory_id = memory_id
self.actor_id = actor_id
self.session_id = session_id
def on_agent_initialized(self, event: AgentInitializedEvent):
"""Load recent conversation history when agent starts"""
try:
# Load the last 5 conversation turns from memory
recent_turns = self.memory_client.get_last_k_turns(
memory_id=self.memory_id,
actor_id=self.actor_id,
session_id=self.session_id,
k=5,
)
if recent_turns:
# Format conversation history for context
context_messages = []
for turn in recent_turns:
for message in turn:
role = "assistant" if message["role"] == "ASSISTANT" else "user"
content = message["content"]["text"]
context_messages.append(
{"role": role, "content": [{"text": content}]}
)
# context = "\n".join(context_messages)
# Add context to agent's system prompt.
event.agent.system_prompt += """
Do not respond to user preferences or user facts.
Strictly use user preferences and user facts to know more about the user.
"""
event.agent.messages = context_messages
except Exception as e:
print(f"Memory load error: {e}")
def _add_context_user_query(
self, namespace: str, query: str, init_content: str, event: MessageAddedEvent
):
content = None
memories = self.memory_client.retrieve_memories(
memory_id=self.memory_id, namespace=namespace, query=query, top_k=3
)
for memory in memories:
if not content:
content = "\n\n" + init_content + "\n\n"
content += memory["content"]["text"]
if content:
event.agent.messages[-1]["content"][0]["text"] += content + "\n\n"
def on_message_added(self, event: MessageAddedEvent):
"""Store messages in memory"""
messages = copy.deepcopy(event.agent.messages)
try:
if messages[-1]["role"] == "user" or messages[-1]["role"] == "assistant":
if "text" not in messages[-1]["content"][0]:
return
if messages[-1]["role"] == "user":
self._add_context_user_query(
namespace=f"support/user/{self.actor_id}/preferences",
query=messages[-1]["content"][0]["text"],
init_content="These are user preferences:",
event=event,
)
self._add_context_user_query(
namespace=f"support/user/{self.actor_id}/facts",
query=messages[-1]["content"][0]["text"],
init_content="These are user facts:",
event=event,
)
self.memory_client.save_conversation(
memory_id=self.memory_id,
actor_id=self.actor_id,
session_id=self.session_id,
messages=[
(messages[-1]["content"][0]["text"], messages[-1]["role"])
],
)
except Exception as e:
print(messages[-1])
raise RuntimeError(f"Memory save error: {e}")
def register_hooks(self, registry: HookRegistry):
registry.add_callback(MessageAddedEvent, self.on_message_added)
registry.add_callback(AgentInitializedEvent, self.on_agent_initialized)
@@ -0,0 +1,22 @@
import asyncio
# Queue for streaming responses
class StreamingQueue:
def __init__(self):
self.finished = False
self.queue = asyncio.Queue()
async def put(self, item):
await self.queue.put(item)
async def finish(self):
self.finished = True
await self.queue.put(None)
async def stream(self):
while True:
item = await self.queue.get()
if item is None and self.finished:
break
yield item
@@ -0,0 +1,140 @@
from ..context import get_google_token_ctx, get_response_queue_ctx, set_google_token_ctx
from bedrock_agentcore.identity.auth import requires_access_token
from datetime import datetime, timedelta
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
from scripts.utils import get_ssm_parameter
from strands import tool
import json
SCOPES = ["https://www.googleapis.com/auth/calendar"]
async def on_auth_url(url: str):
response_queue = get_response_queue_ctx()
print(f"Authorization url: {url}")
await response_queue.put(f"Authorization url: {url}")
# This annotation helps agent developer to obtain access tokens from external applications
@requires_access_token(
provider_name=get_ssm_parameter("/app/customersupport/agentcore/google_provider"),
scopes=SCOPES, # Google OAuth2 scopes
auth_flow="USER_FEDERATION", # On-behalf-of user (3LO) flow
on_auth_url=on_auth_url, # prints authorization URL to console
force_authentication=True,
into="access_token",
)
async def get_google_access_token(access_token: str):
return access_token
@tool(
name="Create_calendar_event",
description="Creates a new event on your Google Calendar",
)
async def create_calendar_event() -> str:
google_access_token = get_google_token_ctx() # Get from context instead of global
if not google_access_token:
try:
google_access_token = await get_google_access_token(
access_token=google_access_token
)
set_google_token_ctx(token=google_access_token)
except Exception as e:
return "Error Authentication with Google: " + str(e)
creds = Credentials(token=google_access_token, scopes=SCOPES)
try:
service = build("calendar", "v3", credentials=creds)
# Define event details
start_time = datetime.now() + timedelta(hours=1)
end_time = start_time + timedelta(hours=1)
event = {
"summary": "Customer Support Call",
"location": "Virtual",
"description": "This event was created by Customer Support Assistant.",
"start": {
"dateTime": start_time.isoformat() + "Z", # UTC time
"timeZone": "UTC",
},
"end": {
"dateTime": end_time.isoformat() + "Z",
"timeZone": "UTC",
},
}
created_event = (
service.events().insert(calendarId="primary", body=event).execute()
)
return json.dumps(
{
"event_created": True,
"event_id": created_event.get("id"),
"htmlLink": created_event.get("htmlLink"),
}
)
except HttpError as error:
return json.dumps({"error": str(error), "event_created": False})
except Exception as e:
return json.dumps({"error": str(e), "event_created": False})
@tool(
name="Get_calendar_events_today",
description="Retrieves the calendar events for the day from your Google Calendar",
)
async def get_calendar_events_today() -> str:
google_access_token = get_google_token_ctx() # Get from context instead of global
if not google_access_token:
try:
google_access_token = await get_google_access_token(
access_token=google_access_token
)
set_google_token_ctx(token=google_access_token)
except Exception as e:
return "Error Authentication with Google: " + str(e)
# Create credentials from the provided access token
creds = Credentials(token=google_access_token, scopes=SCOPES)
try:
service = build("calendar", "v3", credentials=creds)
# Call the Calendar API
today_start = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
today_end = today_start.replace(hour=23, minute=59, second=59)
# Format with CDT timezone (-05:00)
timeMin = today_start.strftime("%Y-%m-%dT00:00:00-05:00")
timeMax = today_end.strftime("%Y-%m-%dT23:59:59-05:00")
events_result = (
service.events()
.list(
calendarId="primary",
timeMin=timeMin,
timeMax=timeMax,
singleEvents=True,
orderBy="startTime",
)
.execute()
)
events = events_result.get("items", [])
print(events)
if not events:
return json.dumps({"events": []}) # Return empty events array as JSON
return json.dumps({"events": events}) # Return events wrapped in an object
except HttpError as error:
error_message = str(error)
return json.dumps({"error": error_message, "events": []})
except Exception as e:
error_message = str(e)
return json.dumps({"error": error_message, "events": []})
@@ -0,0 +1,9 @@
import boto3
def get_ssm_parameter(name: str, with_decryption: bool = True) -> str:
ssm = boto3.client("ssm")
response = ssm.get_parameter(Name=name, WithDecryption=with_decryption)
return response["Parameter"]["Value"]
+3 -589
View File
@@ -1,590 +1,4 @@
# import ast
import sys
from typing import Any, Optional
import streamlit as st
import requests
import base64
import hashlib
import os
import uuid
from urllib.parse import urlencode
from app_modules.main import main
# from streamlit_cookies_manager import EncryptedCookieManager
import json
import jwt
import time
import re
import urllib
from scripts.utils import read_config, get_aws_region, get_ssm_parameter
from streamlit_cookies_controller import CookieController
# ==== Configuration ====
AGENT_NAME = "default"
# crude way to parse args
if len(sys.argv) > 1:
for arg in sys.argv:
if arg.startswith("--agent="):
AGENT_NAME = arg.split("=")[1]
COGNITO_DOMAIN = get_ssm_parameter(
"/app/customersupport/agentcore/cognito_domain"
).replace("https://", "")
CLIENT_ID = get_ssm_parameter("/app/customersupport/agentcore/web_client_id")
REDIRECT_URI = "http://localhost:8501/"
SCOPES = "email openid profile"
# ==== Initialize cookies manager ====
cookies = CookieController()
st.set_page_config(layout="wide")
# if not cookies.ready():
# st.stop() # Wait for cookies to load
# ==== PKCE Helpers ====
def generate_pkce_pair():
code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8").rstrip("=")
code_challenge = (
base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest())
.decode("utf-8")
.rstrip("=")
)
return code_verifier, code_challenge
# ==== Clickable URL Helpers ====
def make_urls_clickable(text):
"""Convert URLs in text to clickable HTML links."""
# Comprehensive URL regex pattern
url_pattern = r"https?://(?:[-\w.])+(?:\:[0-9]+)?(?:/(?:[\w/_.])*(?:\?(?:[\w&=%.])*)?(?:\#(?:[\w.])*)?)?"
def replace_url(match):
url = match.group(0)
# Clean URL and create clickable link with styling to match theme
return f'<a href="{url}" target="_blank" style="color:#4fc3f7;text-decoration:underline;">{url}</a>'
return re.sub(url_pattern, replace_url, text)
def create_safe_markdown_text(text, message_placeholder):
safe_text = text.encode("utf-16", "surrogatepass").decode("utf-16")
message_placeholder.markdown(safe_text, unsafe_allow_html=True)
# ==== Logout function ====
def logout():
cookies.remove("tokens")
# Clear cookies on logout as well (in case)
# cookies.remove("code_verifier")
# cookies.remove("code_challenge")
# cookies.remove("oauth_state")
# cookies.save()
del st.session_state["session_id"]
del st.session_state["messages"]
del st.session_state["agent_arn"]
del st.session_state["pending_assistant"]
del st.session_state["region"]
logout_url = f"https://{COGNITO_DOMAIN}/logout?" + urlencode(
{"client_id": CLIENT_ID, "logout_uri": REDIRECT_URI}
)
create_safe_markdown_text(
f'<meta http-equiv="refresh" content="0;url={logout_url}">', st
)
st.rerun()
# ==== Styles ====
st.markdown(
"""
<style>
body {
background: #181c24 !important;
}
.stApp {
background: #181c24 !important;
}
.css-1v0mbdj, .css-1dp5vir {
border-radius: 14px !important;
padding: 0.5rem 1rem !important;
}
.user-bubble {
background: #23272f;
color: #e6e6e6;
border-radius: 16px;
padding: 0.7rem 1.2rem;
margin-bottom: 0.5rem;
display: inline-block;
border: 1px solid #3a3f4b;
}
.assistant-bubble {
background: #0b2545;
color: #e6e6e6;
border-radius: 16px;
padding: 0.7rem 1.2rem;
margin-bottom: 0.5rem;
display: block;
border: 1px solid #298dff;
animation: fadeInUp 0.3s ease-out;
white-space: pre-wrap;
word-wrap: break-word;
max-width: 100%;
}
.assistant-bubble.streaming {
border: 1px solid #4fc3f7;
box-shadow: 0 0 10px rgba(79, 195, 247, 0.3);
animation: pulse-border 2s infinite, fadeInUp 0.3s ease-out;
}
.thinking-bubble {
background: #0b2545;
color: #e6e6e6;
border-radius: 16px;
padding: 0.7rem 1.2rem;
margin-bottom: 0.5rem;
display: inline-block;
border: 1px solid #298dff;
animation: thinking-pulse 1.5s infinite, fadeInUp 0.3s ease-out;
}
.typing-cursor::after {
content: '';
color: #4fc3f7;
animation: cursor-blink 1s infinite;
margin-left: 2px;
}
@keyframes fadeInUp {
from {
opacity: 0;
transform: translateY(20px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
@keyframes pulse-border {
0%, 100% {
border-color: #298dff;
box-shadow: 0 0 5px rgba(41, 141, 255, 0.3);
}
50% {
border-color: #4fc3f7;
box-shadow: 0 0 15px rgba(79, 195, 247, 0.6);
}
}
@keyframes thinking-pulse {
0%, 100% {
opacity: 1;
transform: scale(1);
}
50% {
opacity: 0.8;
transform: scale(1.02);
}
}
@keyframes cursor-blink {
0%, 50% {
opacity: 1;
}
51%, 100% {
opacity: 0;
}
}
@keyframes slideInLeft {
from {
opacity: 0;
transform: translateX(-30px);
}
to {
opacity: 1;
transform: translateX(0);
}
}
.sidebar .sidebar-content {
background: #181c24 !important;
}
h1, h2, h3, h4, h5, h6, p, label, .css-10trblm, .css-1cpxqw2 {
color: #e6e6e6 !important;
}
hr {
border: 1px solid #298dff !important;
}
</style>
""",
unsafe_allow_html=True,
)
# ==== Handle OAuth callback ====
query_params = st.query_params
if query_params.get("code") and query_params.get("state") and not cookies.get("tokens"):
auth_code = query_params.get("code")
returned_state = query_params.get("state")
code_verifier = cookies.get("code_verifier")
state = cookies.get("oauth_state")
print(f"Check state {cookies.get('oauth_state')} against {returned_state}")
if not state:
st.stop()
else:
if returned_state != state:
st.error("State mismatch - potential CSRF detected")
st.stop()
# Exchange authorization code for tokens
token_url = f"https://{COGNITO_DOMAIN}/oauth2/token"
data = {
"grant_type": "authorization_code",
"client_id": CLIENT_ID,
"code": auth_code,
"redirect_uri": REDIRECT_URI,
"code_verifier": code_verifier,
}
headers = {"Content-Type": "application/x-www-form-urlencoded"}
response = requests.post(token_url, data=data, headers=headers)
if response.ok:
tokens = response.json()
# st.success("Logged in successfully!")
# Clear the cookies after login to avoid reuse of old code_verifier and state
cookies.set("tokens", json.dumps(tokens))
cookies.remove("code_verifier")
cookies.remove("code_challenge")
cookies.remove("oauth_state")
# cookies.save()
st.query_params.clear()
# st.rerun()
else:
st.error(f"Failed to exchange token: {response.status_code} - {response.text}")
# ==== Sidebar with welcome, tokens, and logout ====
st.sidebar.title("Access Tokens")
def invoke_endpoint(
agent_arn: str,
payload,
session_id: str,
bearer_token: Optional[str], # noqa: F821
endpoint_name: str = "DEFAULT",
) -> Any:
"""Invoke agent endpoint using HTTP request with bearer token.
Args:
agent_arn: Agent ARN to invoke
payload: Payload to send (dict or string)
session_id: Session ID for the request
bearer_token: Bearer token for authentication
endpoint_name: Endpoint name, defaults to "DEFAULT"
Returns:
Response from the agent endpoint
"""
# Escape agent ARN for URL
escaped_arn = urllib.parse.quote(agent_arn, safe="")
# Build URL
url = f"https://bedrock-agentcore.{st.session_state['region']}.amazonaws.com/runtimes/{escaped_arn}/invocations"
# Headers
headers = {
"Authorization": f"Bearer {bearer_token}",
"Content-Type": "application/json",
"X-Amzn-Bedrock-AgentCore-Runtime-Session-Id": session_id,
}
# Parse the payload string back to JSON object to send properly
# This ensures consistent payload structure between boto3 and HTTP clients
try:
body = json.loads(payload) if isinstance(payload, str) else payload
except json.JSONDecodeError:
# Fallback for non-JSON strings - wrap in payload object
body = {"payload": payload}
try:
# Make request with timeout
response = requests.post(
url,
params={"qualifier": endpoint_name},
headers=headers,
json=body,
timeout=100,
stream=True,
)
last_data = False
for line in response.iter_lines(chunk_size=1):
if line:
line = line.decode("utf-8")
if line.startswith("data: "):
last_data = True
line = line[6:]
yield line
elif line:
if last_data:
yield "\n" + line
last_data = False
except requests.exceptions.RequestException as e:
print("Failed to invoke agent endpoint: %s", str(e))
raise
# ==== Main app ====
if cookies.get("tokens"):
st.sidebar.code(cookies.get("tokens"))
if st.sidebar.button("Logout"):
logout()
if "session_id" not in st.session_state:
st.session_state["session_id"] = str(uuid.uuid4())
if "agent_arn" not in st.session_state:
runtime_config = read_config(".bedrock_agentcore.yaml")
st.session_state["agent_arn"] = runtime_config["agents"][AGENT_NAME][
"bedrock_agentcore"
]["agent_arn"]
if "region" not in st.session_state:
st.session_state["region"] = get_aws_region()
st.sidebar.write("Agent Arn")
st.sidebar.code(st.session_state["agent_arn"])
st.sidebar.write("Session Id")
st.sidebar.code(st.session_state["session_id"])
token = json.loads(cookies.get("tokens"))
claims = jwt.decode(token["id_token"], options={"verify_signature": False})
st.title("Customer Support Assistant")
st.markdown(
"""
<hr style='border:1px solid #298dff;'>
""",
unsafe_allow_html=True,
)
# Initialize chat history
if "messages" not in st.session_state:
default_prompt = (
f"Hi my name is Maira Ladeira Tanke and my email is {claims.get('email')}"
)
st.session_state.messages = [
{
"role": "user",
"content": default_prompt,
}
]
with st.chat_message("user"):
create_safe_markdown_text(
f'<span class="user-bubble">🧑‍💻 {default_prompt}</span>', st
)
st.session_state["pending_assistant"] = True
start_time = int()
with st.chat_message("assistant"):
message_placeholder = st.empty()
start_time = time.time()
create_safe_markdown_text(
'<span class="thinking-bubble">🤖 💭 Customer Support Assistant is thinking...</span>',
message_placeholder,
)
# Stream the response with animations
chunk_count = 0
formatted_response = ""
accumulated_response = ""
for chunk in invoke_endpoint(
agent_arn=st.session_state["agent_arn"],
payload=json.dumps(
{
"prompt": default_prompt,
"actor_id": claims.get("cognito:username"),
}
),
bearer_token=token["access_token"],
session_id=st.session_state["session_id"],
):
chunk = str(chunk)
if chunk.strip(): # Only process non-empty chunks
accumulated_response += chunk
chunk_count += 1
if chunk_count % 3 == 0: # Add cursor every few chunks for effect
accumulated_response += ""
# Update display with streaming animation (make URLs clickable)
clickable_streaming_text = make_urls_clickable(accumulated_response)
create_safe_markdown_text(
f'<div class="assistant-bubble streaming typing-cursor">🤖 {clickable_streaming_text}</div>',
message_placeholder,
)
# Small delay to make streaming visible and smooth
time.sleep(0.02)
elapsed = time.time() - start_time
clickable_answer = make_urls_clickable(accumulated_response)
create_safe_markdown_text(
f'<div class="assistant-bubble">🤖 {clickable_answer}<br><span style="font-size:0.9em;color:#888;">⏱️ Response time: {elapsed:.2f} seconds</span></div>',
message_placeholder,
)
# Add user message to chat history
st.session_state.messages.append(
{"role": "assistant", "content": accumulated_response, "elapsed": elapsed}
)
st.session_state["pending_assistant"] = False
st.rerun()
else:
# Display chat messages from history on app rerun
messages_to_show = st.session_state.messages[:]
# If waiting for assistant, don't show the last user message here (it will be shown in pending section)
if (
st.session_state.get("pending_assistant", False)
and messages_to_show
and messages_to_show[-1]["role"] == "user"
):
messages_to_show = messages_to_show[:-1]
for message in messages_to_show:
bubble_class = (
"user-bubble" if message["role"] == "user" else "assistant-bubble"
)
emoji = "🧑‍💻" if message["role"] == "user" else "🤖"
with st.chat_message(message["role"]):
if message["role"] == "assistant" and "elapsed" in message:
clickable_content = make_urls_clickable(message["content"])
create_safe_markdown_text(
f'<div class="{bubble_class}">{emoji} {clickable_content}<br><span style="font-size:0.9em;color:#888;">⏱️ Response time: {message["elapsed"]:.2f} seconds</span></div>',
st,
)
else:
if message["role"] == "assistant":
clickable_content = make_urls_clickable(message["content"])
create_safe_markdown_text(
f'<div class="{bubble_class}">{emoji} {clickable_content}</div>',
st,
)
else:
create_safe_markdown_text(
f'<span class="{bubble_class}">{emoji} {message["content"]}</span>',
st,
)
if prompt := st.chat_input("Ask customer support assistant questions!"):
# Display user message in chat message container
st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("user"):
create_safe_markdown_text(
f'<span class="user-bubble">🧑‍💻 {prompt}</span>', st
)
st.session_state["pending_assistant"] = True
start_time = int()
# Display assistant response in chat message container
with st.chat_message("assistant"):
message_placeholder = st.empty()
start_time = time.time()
message_placeholder.markdown(
'<span class="thinking-bubble">🤖 💭 Customer Support Assistant is thinking...</span>',
unsafe_allow_html=True,
)
# Stream the response with animations
chunk_count = 0
formatted_response = ""
accumulated_response = ""
for chunk in invoke_endpoint(
agent_arn=st.session_state["agent_arn"],
payload=json.dumps(
{"prompt": prompt, "actor_id": claims.get("cognito:username")}
),
bearer_token=token["access_token"],
session_id=st.session_state["session_id"],
):
chunk = str(chunk)
if chunk.strip(): # Only process non-empty chunks
if ".prod.agent-credential-provider.cognito.aws.dev" in chunk:
accumulated_response = f"Please use {chunk}"
else:
accumulated_response += chunk
chunk_count += 1
if chunk_count % 3 == 0: # Add cursor every few chunks for effect
accumulated_response += ""
# Update display with streaming animation (make URLs clickable)
clickable_streaming_text = make_urls_clickable(accumulated_response)
create_safe_markdown_text(
f'<div class="assistant-bubble streaming typing-cursor">🤖 {clickable_streaming_text}</div>',
message_placeholder,
)
if (
".prod.agent-credential-provider.cognito.aws.dev"
in accumulated_response
):
accumulated_response = str()
# Small delay to make streaming visible and smooth
time.sleep(0.02)
elapsed = time.time() - start_time
clickable_streaming_text = make_urls_clickable(accumulated_response)
# clickable_answer = make_urls_clickable(accumulated_response)
create_safe_markdown_text(
f'<div class="assistant-bubble">🤖 {clickable_streaming_text}<br><span style="font-size:0.9em;color:#888;">⏱️ Response time: {elapsed:.2f} seconds</span></div>',
message_placeholder,
)
# Add user message to chat history
st.session_state.messages.append(
{"role": "assistant", "content": accumulated_response, "elapsed": elapsed}
)
st.session_state["pending_assistant"] = False
else:
code_verifier, code_challenge = generate_pkce_pair()
cookies.set("code_verifier", code_verifier)
cookies.set("code_challenge", code_challenge)
state = str(uuid.uuid4())
cookies.set("oauth_state", state)
# cookies.save()
# Show login link
login_params = {
"response_type": "code",
"client_id": CLIENT_ID,
"redirect_uri": REDIRECT_URI,
"scope": SCOPES,
"code_challenge_method": "S256",
"code_challenge": cookies.get("code_challenge"),
"state": cookies.get("oauth_state"),
}
login_url = f"https://{COGNITO_DOMAIN}/oauth2/authorize?{urlencode(login_params)}"
print(f"Login signed with state: {cookies.get('oauth_state')}")
# st.markdown(f"[Login with Cognito]({login_url})")
create_safe_markdown_text(
f'<meta http-equiv="refresh" content="0;url={login_url}">', st
)
if __name__ == "__main__":
main()
@@ -0,0 +1 @@
# Streamlit module for customer support assistant
@@ -0,0 +1,146 @@
import base64
import hashlib
import os
import uuid
import json
import jwt
from urllib.parse import urlencode
import requests
import streamlit as st
from streamlit_cookies_controller import CookieController
from scripts.utils import get_ssm_parameter
class AuthManager:
def __init__(self):
self.cognito_domain = get_ssm_parameter(
"/app/customersupport/agentcore/cognito_domain"
).replace("https://", "")
self.client_id = get_ssm_parameter(
"/app/customersupport/agentcore/web_client_id"
)
self.redirect_uri = "http://localhost:8501/"
self.scopes = "email openid profile"
self.cookies = CookieController()
def generate_pkce_pair(self):
code_verifier = (
base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8").rstrip("=")
)
code_challenge = (
base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest())
.decode("utf-8")
.rstrip("=")
)
return code_verifier, code_challenge
def logout(self):
self.cookies.remove("tokens")
# Clear session state
if "session_id" in st.session_state:
del st.session_state["session_id"]
if "messages" in st.session_state:
del st.session_state["messages"]
if "agent_arn" in st.session_state:
del st.session_state["agent_arn"]
if "pending_assistant" in st.session_state:
del st.session_state["pending_assistant"]
if "region" in st.session_state:
del st.session_state["region"]
logout_url = f"https://{self.cognito_domain}/logout?" + urlencode(
{"client_id": self.client_id, "logout_uri": self.redirect_uri}
)
st.markdown(
f'<meta http-equiv="refresh" content="0;url={logout_url}">',
unsafe_allow_html=True,
)
st.rerun()
def handle_oauth_callback(self):
query_params = st.query_params
if (
query_params.get("code")
and query_params.get("state")
and not self.cookies.get("tokens")
):
auth_code = query_params.get("code")
returned_state = query_params.get("state")
code_verifier = self.cookies.get("code_verifier")
state = self.cookies.get("oauth_state")
if not state:
st.stop()
if returned_state != state:
st.error("State mismatch - potential CSRF detected")
st.stop()
# Exchange authorization code for tokens
token_url = f"https://{self.cognito_domain}/oauth2/token"
data = {
"grant_type": "authorization_code",
"client_id": self.client_id,
"code": auth_code,
"redirect_uri": self.redirect_uri,
"code_verifier": code_verifier,
}
headers = {"Content-Type": "application/x-www-form-urlencoded"}
response = requests.post(token_url, data=data, headers=headers)
if response.ok:
tokens = response.json()
self.cookies.set("tokens", json.dumps(tokens))
self.cookies.remove("code_verifier")
self.cookies.remove("code_challenge")
self.cookies.remove("oauth_state")
st.query_params.clear()
else:
st.error(
f"Failed to exchange token: {response.status_code} - {response.text}"
)
def get_login_url(self):
code_verifier, code_challenge = self.generate_pkce_pair()
self.cookies.set("code_verifier", code_verifier)
self.cookies.set("code_challenge", code_challenge)
state = str(uuid.uuid4())
self.cookies.set("oauth_state", state)
login_params = {
"response_type": "code",
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"scope": self.scopes,
"code_challenge_method": "S256",
"code_challenge": self.cookies.get("code_challenge"),
"state": self.cookies.get("oauth_state"),
}
return (
f"https://{self.cognito_domain}/oauth2/authorize?{urlencode(login_params)}"
)
def is_authenticated(self):
return bool(self.cookies.get("tokens"))
def get_tokens(self):
tokens_data = self.cookies.get("tokens")
if not tokens_data:
return None
# Handle both string and dict cases
if isinstance(tokens_data, str):
return json.loads(tokens_data)
elif isinstance(tokens_data, dict):
return tokens_data
else:
return None
def get_user_claims(self):
tokens = self.get_tokens()
if tokens:
return jwt.decode(tokens["id_token"], options={"verify_signature": False})
return None
@@ -0,0 +1,266 @@
import json
import time
import uuid
import urllib.parse
from typing import Any, Optional
import requests
import streamlit as st
from scripts.utils import read_config, get_aws_region
from .utils import make_urls_clickable, create_safe_markdown_text
class ChatManager:
def __init__(self, agent_name: str = "default"):
self.auth_url_matching = ".amazonaws.com/identities/oauth2/authorize"
self.agent_name = agent_name
self._init_session_state()
def _init_session_state(self):
"""Initialize session state variables"""
if "session_id" not in st.session_state:
st.session_state["session_id"] = str(uuid.uuid4())
if "agent_arn" not in st.session_state:
runtime_config = read_config(".bedrock_agentcore.yaml")
st.session_state["agent_arn"] = runtime_config["agents"][self.agent_name][
"bedrock_agentcore"
]["agent_arn"]
if "region" not in st.session_state:
st.session_state["region"] = get_aws_region()
if "messages" not in st.session_state:
st.session_state["messages"] = []
if "pending_assistant" not in st.session_state:
st.session_state["pending_assistant"] = False
def invoke_endpoint(
self,
agent_arn: str,
payload,
session_id: str,
bearer_token: Optional[str],
endpoint_name: str = "DEFAULT",
) -> Any:
"""Invoke agent endpoint using HTTP request with bearer token."""
escaped_arn = urllib.parse.quote(agent_arn, safe="")
url = f"https://bedrock-agentcore.{st.session_state['region']}.amazonaws.com/runtimes/{escaped_arn}/invocations"
headers = {
"Authorization": f"Bearer {bearer_token}",
"Content-Type": "application/json",
"X-Amzn-Bedrock-AgentCore-Runtime-Session-Id": session_id,
}
try:
body = json.loads(payload) if isinstance(payload, str) else payload
except json.JSONDecodeError:
body = {"payload": payload}
try:
response = requests.post(
url,
params={"qualifier": endpoint_name},
headers=headers,
json=body,
timeout=100,
stream=True,
)
last_data = False
for line in response.iter_lines(chunk_size=1):
if line:
line = line.decode("utf-8")
if line.startswith("data: "):
last_data = True
line = line[6:]
line = line.replace('"', "")
yield line
elif line:
line = line.replace('"', "")
if last_data:
yield "\n" + line
last_data = False
except requests.exceptions.RequestException as e:
print("Failed to invoke agent endpoint: %s", str(e))
raise
def display_chat_history(self):
"""Display chat messages from history"""
messages_to_show = st.session_state.messages[:]
if (
st.session_state.get("pending_assistant", False)
and messages_to_show
and messages_to_show[-1]["role"] == "user"
):
messages_to_show = messages_to_show[:-1]
for message in messages_to_show:
bubble_class = (
"user-bubble" if message["role"] == "user" else "assistant-bubble"
)
emoji = "🧑‍💻" if message["role"] == "user" else "🤖"
with st.chat_message(message["role"]):
if message["role"] == "assistant" and "elapsed" in message:
clickable_content = make_urls_clickable(message["content"])
create_safe_markdown_text(
f'<div class="{bubble_class}">{emoji} {clickable_content}<br><span style="font-size:0.9em;color:#888;">⏱️ Response time: {message["elapsed"]:.2f} seconds</span></div>',
st,
)
else:
if message["role"] == "assistant":
clickable_content = make_urls_clickable(message["content"])
create_safe_markdown_text(
f'<div class="{bubble_class}">{emoji} {clickable_content}</div>',
st,
)
else:
create_safe_markdown_text(
f'<span class="{bubble_class}">{emoji} {message["content"]}</span>',
st,
)
def process_user_message(self, prompt: str, user_claims: dict, bearer_token: str):
"""Process a user message and get assistant response"""
st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("user"):
create_safe_markdown_text(
f'<span class="user-bubble">🧑‍💻 {prompt}</span>', st
)
st.session_state["pending_assistant"] = True
with st.chat_message("assistant"):
message_placeholder = st.empty()
start_time = time.time()
create_safe_markdown_text(
'<span class="thinking-bubble">🤖 💭 Customer Support Assistant is thinking...</span>',
message_placeholder,
)
chunk_count = 0
accumulated_response = ""
for chunk in self.invoke_endpoint(
agent_arn=st.session_state["agent_arn"],
payload=json.dumps(
{"prompt": prompt, "actor_id": user_claims.get("cognito:username")}
),
bearer_token=bearer_token,
session_id=st.session_state["session_id"],
):
chunk = str(chunk)
if chunk.strip():
if self.auth_url_matching in chunk:
accumulated_response = f"Please use {chunk}"
else:
accumulated_response += chunk
chunk_count += 1
if chunk_count % 3 == 0:
accumulated_response += ""
clickable_streaming_text = make_urls_clickable(accumulated_response)
create_safe_markdown_text(
f'<div class="assistant-bubble streaming typing-cursor">🤖 {clickable_streaming_text}</div>',
message_placeholder,
)
if self.auth_url_matching in accumulated_response:
accumulated_response = str()
time.sleep(0.02)
elapsed = time.time() - start_time
clickable_streaming_text = make_urls_clickable(accumulated_response)
create_safe_markdown_text(
f'<div class="assistant-bubble">🤖 {clickable_streaming_text}<br><span style="font-size:0.9em;color:#888;">⏱️ Response time: {elapsed:.2f} seconds</span></div>',
message_placeholder,
)
st.session_state.messages.append(
{
"role": "assistant",
"content": accumulated_response,
"elapsed": elapsed,
}
)
st.session_state["pending_assistant"] = False
def initialize_default_conversation(self, user_claims: dict, bearer_token: str):
"""Initialize the conversation with a default message"""
if not st.session_state.messages:
default_prompt = f"Hi my email is {user_claims.get('email')}"
st.session_state.messages = [{"role": "user", "content": default_prompt}]
with st.chat_message("user"):
create_safe_markdown_text(
f'<span class="user-bubble">🧑‍💻 {default_prompt}</span>', st
)
st.session_state["pending_assistant"] = True
with st.chat_message("assistant"):
message_placeholder = st.empty()
start_time = time.time()
create_safe_markdown_text(
'<span class="thinking-bubble">🤖 💭 Customer Support Assistant is thinking...</span>',
message_placeholder,
)
chunk_count = 0
accumulated_response = ""
for chunk in self.invoke_endpoint(
agent_arn=st.session_state["agent_arn"],
payload=json.dumps(
{
"prompt": default_prompt,
"actor_id": user_claims.get("cognito:username"),
}
),
bearer_token=bearer_token,
session_id=st.session_state["session_id"],
):
chunk = str(chunk)
if chunk.strip():
accumulated_response += chunk
chunk_count += 1
if chunk_count % 3 == 0:
accumulated_response += ""
clickable_streaming_text = make_urls_clickable(
accumulated_response
)
create_safe_markdown_text(
f'<div class="assistant-bubble streaming typing-cursor">🤖 {clickable_streaming_text}</div>',
message_placeholder,
)
time.sleep(0.02)
elapsed = time.time() - start_time
clickable_answer = make_urls_clickable(accumulated_response)
create_safe_markdown_text(
f'<div class="assistant-bubble">🤖 {clickable_answer}<br><span style="font-size:0.9em;color:#888;">⏱️ Response time: {elapsed:.2f} seconds</span></div>',
message_placeholder,
)
st.session_state.messages.append(
{
"role": "assistant",
"content": accumulated_response,
"elapsed": elapsed,
}
)
st.session_state["pending_assistant"] = False
st.rerun()
@@ -0,0 +1,93 @@
import sys
import streamlit as st
from .auth import AuthManager
from .chat import ChatManager
from .styles import apply_custom_styles
def main():
"""Main application entry point"""
# Parse command line arguments
agent_name = "default"
if len(sys.argv) > 1:
for arg in sys.argv:
if arg.startswith("--agent="):
agent_name = arg.split("=")[1]
# Configure page
st.set_page_config(layout="wide")
# Apply custom styles
apply_custom_styles()
# Initialize managers
auth_manager = AuthManager()
chat_manager = ChatManager(agent_name)
# Handle OAuth callback
auth_manager.handle_oauth_callback()
# Check authentication status
if auth_manager.is_authenticated():
# Authenticated user interface
render_authenticated_interface(auth_manager, chat_manager)
else:
# Login interface
render_login_interface(auth_manager)
def render_authenticated_interface(
auth_manager: AuthManager, chat_manager: ChatManager
):
"""Render the interface for authenticated users"""
# Sidebar
st.sidebar.title("Access Tokens")
st.sidebar.code(auth_manager.cookies.get("tokens"))
if st.sidebar.button("Logout"):
auth_manager.logout()
st.sidebar.write("Agent Arn")
st.sidebar.code(st.session_state["agent_arn"])
st.sidebar.write("Session Id")
st.sidebar.code(st.session_state["session_id"])
# Main content
st.title("Customer Support Assistant")
st.markdown(
"""
<hr style='border:1px solid #298dff;'>
""",
unsafe_allow_html=True,
)
# Get user info and tokens
tokens = auth_manager.get_tokens()
user_claims = auth_manager.get_user_claims()
# Initialize conversation if needed
if not st.session_state.get("messages"):
chat_manager.initialize_default_conversation(
user_claims, tokens["access_token"]
)
else:
# Display chat history
chat_manager.display_chat_history()
# Chat input
if prompt := st.chat_input("Ask customer support assistant questions!"):
chat_manager.process_user_message(prompt, user_claims, tokens["access_token"])
def render_login_interface(auth_manager: AuthManager):
"""Render the login interface"""
login_url = auth_manager.get_login_url()
st.markdown(
f'<meta http-equiv="refresh" content="0;url={login_url}">',
unsafe_allow_html=True,
)
if __name__ == "__main__":
main()
@@ -0,0 +1,122 @@
import streamlit as st
def apply_custom_styles():
"""Apply custom CSS styles to the Streamlit app"""
st.markdown(
"""
<style>
body {
background: #181c24 !important;
}
.stApp {
background: #181c24 !important;
}
.css-1v0mbdj, .css-1dp5vir {
border-radius: 14px !important;
padding: 0.5rem 1rem !important;
}
.user-bubble {
background: #23272f;
color: #e6e6e6;
border-radius: 16px;
padding: 0.7rem 1.2rem;
margin-bottom: 0.5rem;
display: inline-block;
border: 1px solid #3a3f4b;
}
.assistant-bubble {
background: #0b2545;
color: #e6e6e6;
border-radius: 16px;
padding: 0.7rem 1.2rem;
margin-bottom: 0.5rem;
display: block;
border: 1px solid #298dff;
animation: fadeInUp 0.3s ease-out;
white-space: pre-wrap;
word-wrap: break-word;
max-width: 100%;
}
.assistant-bubble.streaming {
border: 1px solid #4fc3f7;
box-shadow: 0 0 10px rgba(79, 195, 247, 0.3);
animation: pulse-border 2s infinite, fadeInUp 0.3s ease-out;
}
.thinking-bubble {
background: #0b2545;
color: #e6e6e6;
border-radius: 16px;
padding: 0.7rem 1.2rem;
margin-bottom: 0.5rem;
display: inline-block;
border: 1px solid #298dff;
animation: thinking-pulse 1.5s infinite, fadeInUp 0.3s ease-out;
}
.typing-cursor::after {
content: '';
color: #4fc3f7;
animation: cursor-blink 1s infinite;
margin-left: 2px;
}
@keyframes fadeInUp {
from {
opacity: 0;
transform: translateY(20px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
@keyframes pulse-border {
0%, 100% {
border-color: #298dff;
box-shadow: 0 0 5px rgba(41, 141, 255, 0.3);
}
50% {
border-color: #4fc3f7;
box-shadow: 0 0 15px rgba(79, 195, 247, 0.6);
}
}
@keyframes thinking-pulse {
0%, 100% {
opacity: 1;
transform: scale(1);
}
50% {
opacity: 0.8;
transform: scale(1.02);
}
}
@keyframes cursor-blink {
0%, 50% {
opacity: 1;
}
51%, 100% {
opacity: 0;
}
}
@keyframes slideInLeft {
from {
opacity: 0;
transform: translateX(-30px);
}
to {
opacity: 1;
transform: translateX(0);
}
}
.sidebar .sidebar-content {
background: #181c24 !important;
}
h1, h2, h3, h4, h5, h6, p, label, .css-10trblm, .css-1cpxqw2 {
color: #e6e6e6 !important;
}
hr {
border: 1px solid #298dff !important;
}
</style>
""",
unsafe_allow_html=True,
)
@@ -0,0 +1,18 @@
import re
def make_urls_clickable(text):
"""Convert URLs in text to clickable HTML links."""
url_pattern = r"https?://(?:[-\w.])+(?:\:[0-9]+)?(?:/(?:[\w/_.])*(?:\?(?:[\w&=%.])*)?(?:\#(?:[\w.])*)?)?"
def replace_url(match):
url = match.group(0)
return f'<a href="{url}" target="_blank" style="color:#4fc3f7;text-decoration:underline;">{url}</a>'
return re.sub(url_pattern, replace_url, text)
def create_safe_markdown_text(text, message_placeholder):
"""Create safe markdown text with proper encoding"""
safe_text = text.encode("utf-16", "surrogatepass").decode("utf-16")
message_placeholder.markdown(safe_text, unsafe_allow_html=True)
@@ -1,13 +1,48 @@
opensearch-py
requests-aws4auth
pyyaml
retrying
pandas
streamlit
streamlit-cookies-controller
boto3
click
bedrock-agentcore
bedrock-agentcore-starter-toolkit
botocore
boto3
# Development requirements for Customer Support Assistant
# Generated from pyproject.toml dependencies
# Core AgentCore dependencies
bedrock-agentcore>=0.1.0
bedrock-agentcore-starter-toolkit>=0.1.0
# AWS SDK
boto3>=1.39.7
botocore>=1.39.7
# CLI and utilities
click>=8.2.1
# Google APIs
google-api-python-client>=2.176.0
google-auth>=2.40.3
# Search and database
opensearch-py>=3.0.0
# Data processing
pandas>=2.3.1
# Configuration and serialization
pyyaml>=6.0.2
# AWS authentication
requests-aws4auth>=1.3.1
# Utilities
retrying>=1.4.0
# AI agents and tools
strands-agents>=1.0.0
strands-agents-tools>=0.2.0
# Web interface
streamlit>=1.47.0
streamlit-cookies-controller>=0.0.4
# Additional development dependencies
pytest>=7.0.0
pytest-cov>=4.0.0
black>=23.0.0
flake8>=6.0.0
mypy>=1.0.0
pre-commit>=3.0.0
Binary file not shown.

Before

Width:  |  Height:  |  Size: 159 KiB

After

Width:  |  Height:  |  Size: 111 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 14 KiB

+16 -254
View File
@@ -1,22 +1,17 @@
import os
import uuid
from agent_config.context import (
get_response_queue_ctx,
set_gateway_token_ctx,
set_response_queue_ctx,
)
from agent_config.access_token import get_gateway_access_token
from agent_config.agent_task import agent_task
from agent_config.streaming_queue import StreamingQueue
from bedrock_agentcore.runtime import BedrockAgentCoreApp
from scripts.utils import get_ssm_parameter
import asyncio
import logging
from bedrock_agentcore.identity.auth import requires_access_token
from agent import CustomerSupport # Your custom agent class
from datetime import datetime, timedelta
import json
from strands import tool
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
from scripts.utils import get_ssm_parameter
from tools.agent_core_memory import AgentCoreMemoryToolProvider
from memory_hook_provider import MemoryHook
from bedrock_agentcore.memory import MemoryClient
from bedrock_agentcore.runtime import BedrockAgentCoreApp
import os
import uuid
# Environment flags
os.environ["STRANDS_OTEL_ENABLE_CONSOLE_EXPORT"] = "true"
@@ -33,256 +28,23 @@ logger = logging.getLogger(__name__)
# Bedrock app and global agent instance
app = BedrockAgentCoreApp()
agent = None # Will be initialized with access token
gateway_access_token = None
google_access_token = None
memory_client = MemoryClient()
# Queue for streaming responses
class StreamingQueue:
def __init__(self):
self.finished = False
self.queue = asyncio.Queue()
async def put(self, item):
await self.queue.put(item)
async def finish(self):
self.finished = True
await self.queue.put(None)
async def stream(self):
while True:
item = await self.queue.get()
if item is None and self.finished:
break
yield item
response_queue = StreamingQueue()
@tool(
name="Create_calendar_event",
description="Creates a new event on your Google Calendar",
)
def create_calendar_event() -> str:
global google_access_token
print("create_calendar_event invoked")
print(f"google_access_token: {google_access_token}")
if not google_access_token:
return "Google Calendar authentication is required."
creds = Credentials(token=google_access_token, scopes=SCOPES)
try:
service = build("calendar", "v3", credentials=creds)
# Define event details
start_time = datetime.now() + timedelta(hours=1)
end_time = start_time + timedelta(hours=1)
event = {
"summary": "Customer Support Call - Maira Ladeira Tanke",
"location": "Virtual",
"description": "This event was created by Customer Support Assistant.",
"start": {
"dateTime": start_time.isoformat() + "Z", # UTC time
"timeZone": "UTC",
},
"end": {
"dateTime": end_time.isoformat() + "Z",
"timeZone": "UTC",
},
}
created_event = (
service.events().insert(calendarId="primary", body=event).execute()
)
return json.dumps(
{
"event_created": True,
"event_id": created_event.get("id"),
"htmlLink": created_event.get("htmlLink"),
}
)
except HttpError as error:
return json.dumps({"error": str(error), "event_created": False})
except Exception as e:
return json.dumps({"error": str(e), "event_created": False})
@tool(
name="Get_calendar_events_today",
description="Retrieves the calendar events for the day from your Google Calendar",
)
def get_calendar_events_today() -> str:
global google_access_token
print("get_calendar_events_today invoked")
print(f"google_access_token: {google_access_token}")
# Check if we already have a token
if not google_access_token:
return "Google Calendar authentication is required."
# Create credentials from the provided access token
creds = Credentials(token=google_access_token, scopes=SCOPES)
try:
service = build("calendar", "v3", credentials=creds)
# Call the Calendar API
today_start = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
today_end = today_start.replace(hour=23, minute=59, second=59)
# Format with CDT timezone (-05:00)
timeMin = today_start.strftime("%Y-%m-%dT00:00:00-05:00")
timeMax = today_end.strftime("%Y-%m-%dT23:59:59-05:00")
events_result = (
service.events()
.list(
calendarId="primary",
timeMin=timeMin,
timeMax=timeMax,
singleEvents=True,
orderBy="startTime",
)
.execute()
)
events = events_result.get("items", [])
print(events)
if not events:
return json.dumps({"events": []}) # Return empty events array as JSON
return json.dumps({"events": events}) # Return events wrapped in an object
except HttpError as error:
error_message = str(error)
return json.dumps({"error": error_message, "events": []})
except Exception as e:
error_message = str(e)
return json.dumps({"error": error_message, "events": []})
@requires_access_token(
provider_name=get_ssm_parameter("/app/customersupport/agentcore/cognito_provider"),
scopes=[], # Optional unless required
auth_flow="M2M",
)
async def _get_access_token_manually(*, access_token: str):
global gateway_access_token
gateway_access_token = access_token
return access_token # Update the global access token
async def on_auth_url(url: str):
print(f"Authorization url: {url}")
await response_queue.put(f"Authorization url: {url}")
SCOPES = ["https://www.googleapis.com/auth/calendar"]
# This annotation helps agent developer to obtain access tokens from external applications
@requires_access_token(
provider_name=get_ssm_parameter("/app/customersupport/agentcore/google_provider"),
scopes=SCOPES, # Google OAuth2 scopes
auth_flow="USER_FEDERATION", # On-behalf-of user (3LO) flow
on_auth_url=on_auth_url, # prints authorization URL to console
force_authentication=True,
)
async def need_token_3LO_async(*, access_token: str):
global google_access_token
google_access_token = access_token
print(f"google_access_token set: {google_access_token}")
return access_token
async def agent_task(
user_message: str, session_id: str, actor_id: str, access_token: str
):
global agent
global google_access_token
if not access_token:
raise RuntimeError("access_token is none")
try:
if agent is None:
provider = AgentCoreMemoryToolProvider(
memory_id=get_ssm_parameter("/app/customersupport/agentcore/memory_id"),
actor_id=actor_id,
session_id=session_id,
namespace=f"summaries/{actor_id}/{session_id}",
)
memory_hook = MemoryHook(
memory_client=memory_client,
memory_id=get_ssm_parameter("/app/customersupport/agentcore/memory_id"),
actor_id=actor_id,
session_id=session_id,
)
agent = CustomerSupport(
bearer_token=access_token,
memory_hook=memory_hook,
tools=[get_calendar_events_today, create_calendar_event]
+ provider.tools,
)
auth_keywords = ["authentication"]
needs_auth = False
async for chunk in agent.stream(user_query=user_message, session_id=session_id):
needs_auth = any(
keyword.lower() in chunk.lower() for keyword in auth_keywords
)
if needs_auth:
break
else:
await response_queue.put(chunk)
if needs_auth:
# Trigger the 3LO authentication flow
try:
google_access_token = await need_token_3LO_async(access_token="")
# Retry the agent call now that we have authentication
async for chunk in agent.stream(
user_query=user_message, session_id=session_id
):
await response_queue.put(chunk)
except Exception as auth_error:
# print("Exception occurred:")
# traceback.print_exc()
print("auth_error:", auth_error)
except Exception as e:
logger.exception("Agent execution failed.")
await response_queue.put(f"Error: {str(e)}")
finally:
await response_queue.finish()
set_response_queue_ctx(StreamingQueue())
@app.entrypoint
async def invoke(payload, context):
response_queue = get_response_queue_ctx()
set_gateway_token_ctx(await get_gateway_access_token())
user_message = payload["prompt"]
actor_id = payload["actor_id"]
session_id = context.session_id or str(uuid.uuid4())
access_token = await _get_access_token_manually()
task = asyncio.create_task(
agent_task(
user_message=user_message,
session_id=session_id,
access_token=access_token,
actor_id=actor_id,
)
)
@@ -1,62 +0,0 @@
from strands import Agent, tool
from strands.hooks.events import AgentInitializedEvent, MessageAddedEvent
from strands.hooks.registry import HookProvider, HookRegistry
from bedrock_agentcore.memory import MemoryClient
class MemoryHook(HookProvider):
def __init__(
self,
memory_client: MemoryClient,
memory_id: str,
actor_id: str,
session_id: str,
):
self.memory_client = memory_client
self.memory_id = memory_id
self.actor_id = actor_id
self.session_id = session_id
def on_agent_initialized(self, event: AgentInitializedEvent):
"""Load recent conversation history when agent starts"""
try:
# Load the last 5 conversation turns from memory
recent_turns = self.memory_client.get_last_k_turns(
memory_id=self.memory_id,
actor_id=self.actor_id,
session_id=self.session_id,
k=5,
)
if recent_turns:
# Format conversation history for context
context_messages = []
for turn in recent_turns:
for message in turn:
role = message["role"]
content = message["content"]["text"]
context_messages.append(f"{role}: {content}")
context = "\n".join(context_messages)
# Add context to agent's system prompt.
event.agent.system_prompt += f"\n\nRecent conversation:\n{context}"
except Exception as e:
print(f"Memory load error: {e}")
def on_message_added(self, event: MessageAddedEvent):
"""Store messages in memory"""
messages = event.agent.messages
try:
self.memory_client.save_conversation(
memory_id=self.memory_id,
actor_id=self.actor_id,
session_id=self.session_id,
messages=[(messages[-1]["content"][0]["text"], messages[-1]["role"])],
)
except Exception as e:
print(f"Memory save error: {e}")
def register_hooks(self, registry: HookRegistry):
registry.add_callback(MessageAddedEvent, self.on_message_added)
registry.add_callback(AgentInitializedEvent, self.on_agent_initialized)
@@ -36,8 +36,7 @@ Your new project will appear in the project list.
3. Choose Web application as the application type.
4. Enter a name for the credentials.
5. Under Authorized redirect URIs, add your following redirect URI:
- `https://us-west-2.prod.agent-credential-provider.cognito.aws.dev/identities/oauth2/callback`
- `http://localhost:64161/`
- `https://bedrock-agentcore.us-east-1.amazonaws.com/identities/oauth2/callback`
6. Click Create.
## 🔑 5. Obtain Client ID and Client Secret
@@ -98,22 +98,42 @@ Resources:
- bedrock:InvokeModel
- bedrock:InvokeModelWithResponseStream
Resource:
- "arn:aws:bedrock:*::foundation-model/*"
- !Sub arn:aws:bedrock:${AWS::Region}:${AWS::AccountId}:*
- Sid: AgentcoreAllowRuntime
Effect: Allow
Action:
- bedrock-agentcore:*
- iam:PassRole
Resource:
- "*"
- Sid: SSMGetparam
Effect: Allow
Action:
- ssm:GetParameter
Resource:
- !Sub arn:aws:ssm:${AWS::Region}:${AWS::AccountId}:parameter/app/customersupport/*
ManagedPolicyArns:
- arn:aws:iam::aws:policy/AdministratorAccess
- Sid: Identity
Effect: Allow
Action:
- bedrock-agentcore:GetResourceOauth2Token
Resource:
- !Sub arn:aws:bedrock-agentcore:${AWS::Region}:${AWS::AccountId}:token-vault/default/oauth2credentialprovider/customersupport*
- !Sub arn:aws:bedrock-agentcore:${AWS::Region}:${AWS::AccountId}:workload-identity-directory/default/workload-identity/customersupport*
- !Sub arn:aws:bedrock-agentcore:${AWS::Region}:${AWS::AccountId}:workload-identity-directory/default
- !Sub arn:aws:bedrock-agentcore:${AWS::Region}:${AWS::AccountId}:token-vault/default
- Sid: SecretManager
Effect: Allow
Action:
- secretsmanager:GetSecretValue
Resource:
- !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:bedrock-agentcore-identity!default/oauth2/customersupport*
- Sid: AgentCoreMemory
Effect: Allow
Action:
- bedrock-agentcore:ListMemories
- bedrock-agentcore:ListMemoryRecords
- bedrock-agentcore:RetrieveMemoryRecords
- bedrock-agentcore:GetMemory
- bedrock-agentcore:GetMemoryRecord
- bedrock-agentcore:CreateEvent
- bedrock-agentcore:GetEvent
Resource:
- !Sub arn:aws:bedrock-agentcore:${AWS::Region}:${AWS::AccountId}:memory/customersupport*
GatewayAgentCoreRole:
Type: AWS::IAM::Role
@@ -132,18 +152,12 @@ Resources:
PolicyDocument:
Version: '2012-10-17'
Statement:
- Sid: AgentcoreAllow
Effect: Allow
Action:
- bedrock-agentcore:*
Resource:
- "*"
- Sid: InvokeFunction
Effect: Allow
Action:
- lambda:InvokeFunction
Resource:
- "*"
- !GetAtt CustomerSupportLambda.Arn
# DynamoDB Table for Warranty Information
WarrantyTable:
@@ -602,7 +616,7 @@ Resources:
WarrantyTableNameParameter:
Type: AWS::SSM::Parameter
Properties:
Name: /app/customersupport/dynamodb/warranty-table-name
Name: /app/customersupport/dynamodb/warranty_table_name
Type: String
Value: !Ref WarrantyTable
Description: DynamoDB table name for warranty information
@@ -613,7 +627,7 @@ Resources:
CustomerProfileTableNameParameter:
Type: AWS::SSM::Parameter
Properties:
Name: /app/customersupport/dynamodb/customer-profile-table-name
Name: /app/customersupport/dynamodb/customer_profile_table_name
Type: String
Value: !Ref CustomerProfileTable
Description: DynamoDB table name for customer profiles
@@ -5,7 +5,7 @@
This module contains a helper class for building and using Knowledge Bases for Amazon Bedrock.
The KnowledgeBasesForAmazonBedrock class provides a convenient interface for working with Knowledge Bases.
It includes methods for creating, updating, and invoking Knowledge Bases, as well as managing
IAM roles and OpenSearch Serverless.
IAM roles and S3 Vectors.
"""
import json
@@ -13,15 +13,8 @@ import boto3
import time
import uuid
from botocore.exceptions import ClientError
from opensearchpy import (
OpenSearch,
RequestsHttpConnection,
AWSV4SignerAuth,
RequestError,
)
import pprint
from retrying import retry
import random
import yaml
import os
import argparse
@@ -66,7 +59,7 @@ class KnowledgeBasesForAmazonBedrock:
"""
Support class that allows for:
- creation (or retrieval) of a Knowledge Base for Amazon Bedrock with all its pre-requisites
(including OSS, IAM roles and Permissions and S3 bucket)
(including S3 Vectors, IAM roles and Permissions and S3 bucket)
- Ingestion of data into the Knowledge Base
- Deletion of all resources created
"""
@@ -90,19 +83,15 @@ class KnowledgeBasesForAmazonBedrock:
self.identity = boto3.client(
"sts", region_name=self.region_name
).get_caller_identity()["Arn"]
self.aoss_client = boto3_session.client(
"opensearchserverless", region_name=self.region_name
self.s3_vectors_client = boto3_session.client(
"s3vectors", region_name=self.region_name
)
self.s3_client = boto3.client("s3", region_name=self.region_name)
self.bedrock_agent_client = boto3.client(
"bedrock-agent", region_name=self.region_name
)
self.bedrock_agent_client = boto3.client(
"bedrock-agent", region_name=self.region_name
)
credentials = boto3.Session().get_credentials()
self.awsauth = AWSV4SignerAuth(credentials, self.region_name, "aoss")
self.oss_client = None
self.vector_bucket_name = None
self.index_name = None
self.data_bucket_name = None
def create_or_retrieve_knowledge_base(
@@ -163,10 +152,6 @@ class KnowledgeBasesForAmazonBedrock:
raise ValueError(
f"Invalid embedding model. Your embedding model should be one of {valid_embeddings_str}"
)
# self.embedding_model = embedding_model
encryption_policy_name = f"{kb_name}-sp-{self.suffix}"
network_policy_name = f"{kb_name}-np-{self.suffix}"
access_policy_name = f"{kb_name}-ap-{self.suffix}"
kb_execution_role_name = (
f"AmazonBedrockExecutionRoleForKnowledgeBase_{self.suffix}"
)
@@ -174,8 +159,10 @@ class KnowledgeBasesForAmazonBedrock:
f"AmazonBedrockFoundationModelPolicyForKnowledgeBase_{self.suffix}"
)
s3_policy_name = f"AmazonBedrockS3PolicyForKnowledgeBase_{self.suffix}"
oss_policy_name = f"AmazonBedrockOSSPolicyForKnowledgeBase_{self.suffix}"
vector_store_name = f"{kb_name}-{self.suffix}"
s3_vectors_policy_name = (
f"AmazonBedrockS3VectorsPolicyForKnowledgeBase_{self.suffix}"
)
vector_bucket_name = f"{kb_name}-vectors-{self.suffix}"
index_name = f"{kb_name}-index-{self.suffix}"
print(
"========================================================================================"
@@ -197,49 +184,28 @@ class KnowledgeBasesForAmazonBedrock:
s3_policy_name,
kb_execution_role_name,
)
print(time.sleep(10))
print(
"========================================================================================"
)
print(f"Step 3 - Creating OSS encryption, network and data access policies")
encryption_policy, network_policy, access_policy = (
self.create_policies_in_oss(
encryption_policy_name,
vector_store_name,
network_policy_name,
bedrock_kb_execution_role,
access_policy_name,
)
print("Step 3 - Creating S3 Vectors Bucket and Index")
vector_bucket_arn, index_arn = self.create_s3_vectors_bucket_and_index(
vector_bucket_name, index_name, bedrock_kb_execution_role
)
print(
"========================================================================================"
)
print(
f"Step 4 - Creating OSS Collection (this step takes a couple of minutes to complete)"
print("Step 4 - Creating S3 Vectors Policy")
self.create_s3_vectors_policy(
s3_vectors_policy_name, vector_bucket_arn, bedrock_kb_execution_role
)
host, collection, collection_id, collection_arn = self.create_oss(
vector_store_name, oss_policy_name, bedrock_kb_execution_role
)
# Build the OpenSearch client
self.oss_client = OpenSearch(
hosts=[{"host": host, "port": 443}],
http_auth=self.awsauth,
use_ssl=True,
verify_certs=True,
connection_class=RequestsHttpConnection,
timeout=300,
)
print(
"========================================================================================"
)
print(f"Step 5 - Creating OSS Vector Index")
self.create_vector_index(index_name)
print(
"========================================================================================"
)
print(f"Step 6 - Creating Knowledge Base")
print("Step 5 - Creating Knowledge Base")
knowledge_base, data_source = self.create_knowledge_base(
collection_arn,
vector_bucket_arn,
index_arn,
index_name,
data_bucket_name,
embedding_model,
@@ -265,7 +231,7 @@ class KnowledgeBasesForAmazonBedrock:
try:
self.s3_client.head_bucket(Bucket=bucket_name)
print(f"Bucket {bucket_name} already exists - retrieving it!")
except ClientError as e:
except ClientError:
print(f"Creating bucket {bucket_name}")
if self.region_name == "us-east-1":
self.s3_client.create_bucket(Bucket=bucket_name)
@@ -380,6 +346,12 @@ class KnowledgeBasesForAmazonBedrock:
"Effect": "Allow",
"Principal": {"Service": "bedrock.amazonaws.com"},
"Action": "sts:AssumeRole",
"Condition": {
"StringEquals": {"aws:SourceAccount": f"{self.account_number}"},
"ArnLike": {
"aws:SourceArn": f"arn:aws:bedrock:{self.region_name}:{self.account_number}:knowledge-base/*"
},
},
}
],
}
@@ -436,276 +408,140 @@ class KnowledgeBasesForAmazonBedrock:
)
return bedrock_kb_execution_role
def create_oss_policy_attach_bedrock_execution_role(
self, collection_id: str, oss_policy_name: str, bedrock_kb_execution_role: str
def create_s3_vectors_bucket_and_index(
self,
vector_bucket_name: str,
index_name: str,
bedrock_kb_execution_role: str,
):
"""
Create OpenSearch Serverless policy and attach it to the Knowledge Base Execution role.
If policy already exists, attaches it
Create S3 Vectors bucket and index.
Args:
collection_id: collection id
oss_policy_name: opensearch serverless policy name
vector_bucket_name: name of the S3 vectors bucket
index_name: name of the vector index
bedrock_kb_execution_role: knowledge base execution role
Returns:
created: bool - boolean to indicate if role was created
vector_bucket_arn, index_arn
"""
# define oss policy document
oss_policy_document = {
self.vector_bucket_name = vector_bucket_name
self.index_name = index_name
# Create S3 Vectors bucket
try:
self.s3_vectors_client.create_vector_bucket(
vectorBucketName=vector_bucket_name,
encryptionConfiguration={"sseType": "AES256"},
)
get_response = self.s3_vectors_client.get_vector_bucket(
vectorBucketName=vector_bucket_name
)
vector_bucket_arn = get_response["vectorBucket"]["vectorBucketArn"]
print(f"Created S3 Vectors bucket: {vector_bucket_name}")
except self.s3_vectors_client.exceptions.ConflictException:
print(f"S3 Vectors bucket {vector_bucket_name} already exists")
# Get the bucket ARN
vector_bucket_arn = f"arn:aws:s3vectors:{self.region_name}:{self.account_number}:vector-bucket/{vector_bucket_name}"
except Exception as e:
print(f"Error creating S3 vectors bucket: {e}")
raise
# Create vector index
try:
self.s3_vectors_client.create_index(
vectorBucketName=vector_bucket_name,
indexName=index_name,
dataType="float32",
dimension=1024, # Matching the OpenSearch configuration
distanceMetric="cosine",
metadataConfiguration={
"nonFilterableMetadataKeys": [
"AMAZON_BEDROCK_TEXT",
]
},
)
get_index_response = self.s3_vectors_client.get_index(
vectorBucketName=vector_bucket_name,
indexName=index_name,
)
time.sleep(10)
index_arn = get_index_response["index"]["indexArn"]
print(f"Created S3 Vectors index: {index_name}")
except self.s3_vectors_client.exceptions.ConflictException:
print(f"S3 Vectors index {index_name} already exists")
# Get the index ARN
index_arn = f"arn:aws:s3vectors:{self.region_name}:{self.account_number}:index/{vector_bucket_name}/{index_name}"
except Exception as e:
print(f"Error creating S3 vectors index: {e}")
raise
return vector_bucket_arn, index_arn
def create_s3_vectors_policy(
self,
s3_vectors_policy_name: str,
vector_bucket_arn: str,
bedrock_kb_execution_role: str,
):
"""
Create S3 Vectors policy and attach it to the Knowledge Base Execution role.
Args:
s3_vectors_policy_name: name of the S3 vectors policy
vector_bucket_arn: ARN of the S3 vectors bucket
bedrock_kb_execution_role: knowledge base execution role
"""
# Define S3 Vectors policy document
s3_vectors_policy_document = {
"Version": "2012-10-17",
"Statement": [
{
"Sid": "S3VectorsPermissions",
"Effect": "Allow",
"Action": ["aoss:APIAccessAll"],
"Resource": [
f"arn:aws:aoss:{self.region_name}:{self.account_number}:collection/{collection_id}"
"Action": [
"s3vectors:GetIndex",
"s3vectors:QueryVectors",
"s3vectors:PutVectors",
"s3vectors:GetVectors",
"s3vectors:DeleteVectors",
],
"Resource": f"{vector_bucket_arn}/index/*",
"Condition": {
"StringEquals": {
"aws:ResourceAccount": f"{self.account_number}"
}
},
}
],
}
oss_policy_arn = f"arn:aws:iam::{self.account_number}:policy/{oss_policy_name}"
created = False
try:
self.iam_client.create_policy(
PolicyName=oss_policy_name,
PolicyDocument=json.dumps(oss_policy_document),
Description="Policy for accessing opensearch serverless",
s3_vectors_policy = self.iam_client.create_policy(
PolicyName=s3_vectors_policy_name,
PolicyDocument=json.dumps(s3_vectors_policy_document),
Description="Policy for accessing S3 vectors",
)
created = True
print(f"Created S3 Vectors policy: {s3_vectors_policy_name}")
except self.iam_client.exceptions.EntityAlreadyExistsException:
print(f"Policy {oss_policy_arn} already exists, updating it")
print("Opensearch serverless arn: ", oss_policy_arn)
print(f"S3 Vectors policy {s3_vectors_policy_name} already exists")
s3_vectors_policy = self.iam_client.get_policy(
PolicyArn=f"arn:aws:iam::{self.account_number}:policy/{s3_vectors_policy_name}"
)
# Attach policy to Bedrock execution role
s3_vectors_policy_arn = s3_vectors_policy["Policy"]["Arn"]
self.iam_client.attach_role_policy(
RoleName=bedrock_kb_execution_role["Role"]["RoleName"],
PolicyArn=oss_policy_arn,
PolicyArn=s3_vectors_policy_arn,
)
print(
f"Attached S3 Vectors policy to role: {bedrock_kb_execution_role['Role']['RoleName']}"
)
return created
def create_policies_in_oss(
self,
encryption_policy_name: str,
vector_store_name: str,
network_policy_name: str,
bedrock_kb_execution_role: str,
access_policy_name: str,
):
"""
Create OpenSearch Serverless encryption, network and data access policies.
If policies already exist, retrieve them
Args:
encryption_policy_name: name of the data encryption policy
vector_store_name: name of the vector store
network_policy_name: name of the network policy
bedrock_kb_execution_role: name of the knowledge base execution role
access_policy_name: name of the data access policy
Returns:
encryption_policy, network_policy, access_policy
"""
try:
encryption_policy = self.aoss_client.create_security_policy(
name=encryption_policy_name,
policy=json.dumps(
{
"Rules": [
{
"Resource": ["collection/" + vector_store_name],
"ResourceType": "collection",
}
],
"AWSOwnedKey": True,
}
),
type="encryption",
)
except self.aoss_client.exceptions.ConflictException:
print(f"{encryption_policy_name} already exists, retrieving it!")
encryption_policy = self.aoss_client.get_security_policy(
name=encryption_policy_name, type="encryption"
)
try:
network_policy = self.aoss_client.create_security_policy(
name=network_policy_name,
policy=json.dumps(
[
{
"Rules": [
{
"Resource": ["collection/" + vector_store_name],
"ResourceType": "collection",
}
],
"AllowFromPublic": True,
}
]
),
type="network",
)
except self.aoss_client.exceptions.ConflictException:
print(f"{network_policy_name} already exists, retrieving it!")
network_policy = self.aoss_client.get_security_policy(
name=network_policy_name, type="network"
)
try:
access_policy = self.aoss_client.create_access_policy(
name=access_policy_name,
policy=json.dumps(
[
{
"Rules": [
{
"Resource": ["collection/" + vector_store_name],
"Permission": [
"aoss:CreateCollectionItems",
"aoss:DeleteCollectionItems",
"aoss:UpdateCollectionItems",
"aoss:DescribeCollectionItems",
],
"ResourceType": "collection",
},
{
"Resource": ["index/" + vector_store_name + "/*"],
"Permission": [
"aoss:CreateIndex",
"aoss:DeleteIndex",
"aoss:UpdateIndex",
"aoss:DescribeIndex",
"aoss:ReadDocument",
"aoss:WriteDocument",
],
"ResourceType": "index",
},
],
"Principal": [
self.identity,
bedrock_kb_execution_role["Role"]["Arn"],
],
"Description": "Easy data policy",
}
]
),
type="data",
)
except self.aoss_client.exceptions.ConflictException:
print(f"{access_policy_name} already exists, retrieving it!")
access_policy = self.aoss_client.get_access_policy(
name=access_policy_name, type="data"
)
return encryption_policy, network_policy, access_policy
def create_oss(
self,
vector_store_name: str,
oss_policy_name: str,
bedrock_kb_execution_role: str,
):
"""
Create OpenSearch Serverless Collection. If already existent, retrieve
Args:
vector_store_name: name of the vector store
oss_policy_name: name of the opensearch serverless access policy
bedrock_kb_execution_role: name of the knowledge base execution role
"""
try:
collection = self.aoss_client.create_collection(
name=vector_store_name, type="VECTORSEARCH"
)
collection_id = collection["createCollectionDetail"]["id"]
collection_arn = collection["createCollectionDetail"]["arn"]
except self.aoss_client.exceptions.ConflictException:
collection = self.aoss_client.batch_get_collection(
names=[vector_store_name]
)["collectionDetails"][0]
pp.pprint(collection)
collection_id = collection["id"]
collection_arn = collection["arn"]
pp.pprint(collection)
# Get the OpenSearch serverless collection URL
host = collection_id + "." + self.region_name + ".aoss.amazonaws.com"
print(host)
# wait for collection creation
# This can take couple of minutes to finish
response = self.aoss_client.batch_get_collection(names=[vector_store_name])
# Periodically check collection status
while (response["collectionDetails"][0]["status"]) == "CREATING":
print("Creating collection...")
interactive_sleep(30)
response = self.aoss_client.batch_get_collection(names=[vector_store_name])
print("\nCollection successfully created:")
pp.pprint(response["collectionDetails"])
# create opensearch serverless access policy and attach it to Bedrock execution role
try:
created = self.create_oss_policy_attach_bedrock_execution_role(
collection_id, oss_policy_name, bedrock_kb_execution_role
)
if created:
# It can take up to a minute for data access rules to be enforced
print(
"Sleeping for a minute to ensure data access rules have been enforced"
)
interactive_sleep(60)
return host, collection, collection_id, collection_arn
except Exception as e:
print("Policy already exists")
pp.pprint(e)
def create_vector_index(self, index_name: str):
"""
Create OpenSearch Serverless vector index. If existent, ignore
Args:
index_name: name of the vector index
"""
body_json = {
"settings": {
"index.knn": "true",
"number_of_shards": 1,
"knn.algo_param.ef_search": 512,
"number_of_replicas": 0,
},
"mappings": {
"properties": {
"vector": {
"type": "knn_vector",
"dimension": 1024,
"method": {
"name": "hnsw",
"engine": "faiss",
"space_type": "l2",
},
},
"text": {"type": "text"},
"text-metadata": {"type": "text"},
}
},
}
# Create index
try:
response = self.oss_client.indices.create(
index=index_name, body=json.dumps(body_json)
)
print("\nCreating index:")
pp.pprint(response)
# index creation can take up to a minute
interactive_sleep(60)
except RequestError as e:
# you can delete the index if its already exists
# oss_client.indices.delete(index=index_name)
print(
f"Error while trying to create the index, with error {e.error}\nyou may unmark the delete above to "
f"delete, and recreate the index"
)
@retry(wait_random_min=1000, wait_random_max=2000, stop_max_attempt_number=7)
def create_knowledge_base(
self,
collection_arn: str,
vector_bucket_arn: str,
index_arn: str,
index_name: str,
bucket_name: str,
embedding_model: str,
@@ -716,8 +552,9 @@ class KnowledgeBasesForAmazonBedrock:
"""
Create Knowledge Base and its Data Source. If existent, retrieve
Args:
collection_arn: ARN of the opensearch serverless collection
index_name: name of the opensearch serverless index
vector_bucket_arn: ARN of the S3 vectors bucket
index_arn: ARN of the S3 vectors index
index_name: name of the S3 vectors index
bucket_name: name of the s3 bucket containing the knowledge base data
embedding_model: id of the embedding model used
kb_name: knowledge base name
@@ -728,14 +565,12 @@ class KnowledgeBasesForAmazonBedrock:
knowledge base object,
data source object
"""
opensearch_serverless_configuration = {
"collectionArn": collection_arn,
"vectorIndexName": index_name,
"fieldMapping": {
"vectorField": "vector",
"textField": "text",
"metadataField": "text-metadata",
},
print(vector_bucket_arn)
print(index_name)
s3_vectors_configuration = {
"vectorBucketArn": vector_bucket_arn,
# "indexName": index_name,
"indexArn": index_arn,
}
# Ingest strategy - How to ingest data from the data source
@@ -770,6 +605,7 @@ class KnowledgeBasesForAmazonBedrock:
)
)
try:
print(bedrock_kb_execution_role["Role"]["Arn"])
create_kb_response = self.bedrock_agent_client.create_knowledge_base(
name=kb_name,
description=kb_description,
@@ -781,23 +617,24 @@ class KnowledgeBasesForAmazonBedrock:
},
},
storageConfiguration={
"type": "OPENSEARCH_SERVERLESS",
"opensearchServerlessConfiguration": opensearch_serverless_configuration,
"type": "S3_VECTORS",
"s3VectorsConfiguration": s3_vectors_configuration,
},
)
kb = create_kb_response["knowledgeBase"]
pp.pprint(kb)
except self.bedrock_agent_client.exceptions.ConflictException:
kbs = self.bedrock_agent_client.list_knowledge_bases(maxResults=100)
kb_id = None
for kb in kbs["knowledgeBaseSummaries"]:
if kb["name"] == kb_name:
kb_id = kb["knowledgeBaseId"]
response = self.bedrock_agent_client.get_knowledge_base(
knowledgeBaseId=kb_id
)
kb = response["knowledgeBase"]
pp.pprint(kb)
except Exception as e:
# kbs = self.bedrock_agent_client.list_knowledge_bases(maxResults=100)
# kb_id = None
# for kb in kbs["knowledgeBaseSummaries"]:
# if kb["name"] == kb_name:
# kb_id = kb["knowledgeBaseId"]
# response = self.bedrock_agent_client.get_knowledge_base(
# knowledgeBaseId=kb_id
# )
# kb = response["knowledgeBase"]
# pp.pprint(kb)
print(e)
# Create a DataSource in KnowledgeBase
try:
@@ -878,7 +715,7 @@ class KnowledgeBasesForAmazonBedrock:
kb_name: str,
delete_s3_bucket: bool = True,
delete_iam_roles_and_policies: bool = True,
delete_aoss: bool = True,
delete_s3_vector: bool = True,
):
"""
Delete the Knowledge Base resources
@@ -886,7 +723,7 @@ class KnowledgeBasesForAmazonBedrock:
kb_name: name of the knowledge base to delete
delete_s3_bucket (bool): boolean to indicate if s3 bucket should also be deleted
delete_iam_roles_and_policies (bool): boolean to indicate if IAM roles and Policies should also be deleted
delete_aoss: boolean to indicate if amazon opensearch serverless resources should also be deleted
delete_s3_vector: boolean to indicate if amazon Amazon S3 Vector
"""
kbs_available = self.bedrock_agent_client.list_knowledge_bases(
maxResults=100,
@@ -898,36 +735,16 @@ class KnowledgeBasesForAmazonBedrock:
kb_id = kb["knowledgeBaseId"]
kb_details = self.bedrock_agent_client.get_knowledge_base(knowledgeBaseId=kb_id)
kb_role = kb_details["knowledgeBase"]["roleArn"].split("/")[1]
collection_id = kb_details["knowledgeBase"]["storageConfiguration"][
"opensearchServerlessConfiguration"
]["collectionArn"].split("/")[1]
index_name = kb_details["knowledgeBase"]["storageConfiguration"][
"opensearchServerlessConfiguration"
]["vectorIndexName"]
encryption_policies = self.aoss_client.list_security_policies(
maxResults=100, type="encryption"
)
encryption_policy_name = None
for ep in encryption_policies["securityPolicySummaries"]:
if ep["name"].startswith(kb_name):
encryption_policy_name = ep["name"]
network_policies = self.aoss_client.list_security_policies(
maxResults=100, type="network"
)
network_policy_name = None
for np in network_policies["securityPolicySummaries"]:
if np["name"].startswith(kb_name):
network_policy_name = np["name"]
data_policies = self.aoss_client.list_access_policies(
maxResults=100, type="data"
)
access_policy_name = None
for dp in data_policies["accessPolicySummaries"]:
if dp["name"].startswith(kb_name):
access_policy_name = dp["name"]
vector_bucket_arn = kb_details["knowledgeBase"]["storageConfiguration"][
"s3VectorsConfiguration"
]["vectorBucketArn"]
# index_name = kb_details["knowledgeBase"]["storageConfiguration"][
# "s3VectorsConfiguration"
# ]["indexName"]
index_arn = kb_details["knowledgeBase"]["storageConfiguration"][
"s3VectorsConfiguration"
]["indexArn"]
ds_available = self.bedrock_agent_client.list_data_sources(
knowledgeBaseId=kb_id,
@@ -936,71 +753,41 @@ class KnowledgeBasesForAmazonBedrock:
for ds in ds_available["dataSourceSummaries"]:
if kb_id == ds["knowledgeBaseId"]:
ds_id = ds["dataSourceId"]
ds_details = self.bedrock_agent_client.get_data_source(
self.bedrock_agent_client.get_data_source(
dataSourceId=ds_id,
knowledgeBaseId=kb_id,
)
bucket_name = ds_details["dataSource"]["dataSourceConfiguration"][
"s3Configuration"
]["bucketArn"].replace("arn:aws:s3:::", "")
try:
self.bedrock_agent_client.delete_data_source(
dataSourceId=ds_id, knowledgeBaseId=kb_id
if (
delete_s3_vector
): # Renamed for backward compatibility, but now handles S3 vectors
self.s3_vectors_client.delete_index(
# vectorBucketName=vector_bucket_name,
# vectorBucketArn=vector_bucket_arn,
# indexName=index_name,
indexArn=index_arn,
)
print("Data Source deleted successfully!")
except Exception as e:
print(e)
try:
self.bedrock_agent_client.delete_knowledge_base(knowledgeBaseId=kb_id)
print("Knowledge Base deleted successfully!")
except Exception as e:
print(e)
if delete_aoss:
try:
self.oss_client.indices.delete(index=index_name)
print("OpenSource Serveless Index deleted successfully!")
except Exception as e:
print(e)
try:
self.aoss_client.delete_collection(id=collection_id)
print("OpenSource Collection Index deleted successfully!")
except Exception as e:
print(e)
try:
self.aoss_client.delete_access_policy(
type="data", name=access_policy_name
)
print("OpenSource Serveless access policy deleted successfully!")
except Exception as e:
print(e)
try:
self.aoss_client.delete_security_policy(
type="network", name=network_policy_name
)
print("OpenSource Serveless network policy deleted successfully!")
except Exception as e:
print(e)
try:
self.aoss_client.delete_security_policy(
type="encryption", name=encryption_policy_name
)
print("OpenSource Serveless encryption policy deleted successfully!")
except Exception as e:
print(e)
if delete_s3_bucket:
try:
self.delete_s3(bucket_name)
print("Knowledge Base S3 bucket deleted successfully!")
except Exception as e:
print(e)
print("S3 Vectors index deleted successfully!")
self.s3_vectors_client.delete_vector_bucket(
vectorBucketArn=vector_bucket_arn,
)
print("S3 Vectors bucket deleted successfully!")
if delete_iam_roles_and_policies:
try:
self.delete_iam_roles_and_policies(kb_role)
print("Knowledge Base Roles and Policies deleted successfully!")
except Exception as e:
print(e)
self.delete_iam_roles_and_policies(kb_role)
print("Knowledge Base Roles and Policies deleted successfully!")
print("Resources deleted successfully!")
self.bedrock_agent_client.delete_data_source(
dataSourceId=ds_id, knowledgeBaseId=kb_id
)
print("Data Source deleted successfully!")
self.bedrock_agent_client.delete_knowledge_base(knowledgeBaseId=kb_id)
print("Knowledge Base deleted successfully!")
def delete_iam_roles_and_policies(self, kb_execution_role_name: str):
"""
Delete IAM Roles and policies used by the Knowledge Base
@@ -18,7 +18,7 @@ smm_client = boto3.client("ssm")
# Get warranty table name from Parameter Store
warranty_table = smm_client.get_parameter(
Name="/app/customersupport/dynamodb/warranty-table-name", WithDecryption=False
Name="/app/customersupport/dynamodb/warranty_table_name", WithDecryption=False
)
warranty_table_name = warranty_table["Parameter"]["Value"]
@@ -18,7 +18,7 @@ smm_client = boto3.client("ssm")
# Get customer profile table name from Parameter Store
customer_table = smm_client.get_parameter(
Name="/app/customersupport/dynamodb/customer-profile-table-name",
Name="/app/customersupport/dynamodb/customer_profile_table_name",
WithDecryption=False,
)
customer_table_name = customer_table["Parameter"]["Value"]
@@ -1,6 +1,5 @@
from check_warranty import check_warranty_status
from get_customer_profile import get_customer_profile
import json
def get_named_parameter(event, name):
@@ -1,10 +0,0 @@
import boto3
from utils import get_aws_region
agentcore_control_client = boto3.client(
"bedrock-agentcore-control", region_name=get_aws_region()
)
# print(agentcore_control_client.list_agent_runtimes())
runtime_delete_response = agentcore_control_client.delete_agent_runtime()
@@ -0,0 +1,76 @@
#!/usr/bin/env python3
import sys
import boto3
import click
from utils import get_aws_region
@click.command()
@click.argument('agent_name', type=str)
@click.option('--dry-run', is_flag=True, help='Show what would be deleted without actually deleting')
def delete_agent_runtime(agent_name: str, dry_run: bool):
"""Delete an agent runtime by name from AWS Bedrock AgentCore.
AGENT_NAME: Name of the agent runtime to delete
"""
try:
agentcore_control_client = boto3.client(
"bedrock-agentcore-control", region_name=get_aws_region()
)
except Exception as e:
click.echo(f"Error creating AWS client: {e}", err=True)
sys.exit(1)
agent_id = None
found = False
next_token = None
click.echo(f"Searching for agent runtime: {agent_name}")
try:
while True:
kwargs = {"maxResults": 20}
if next_token:
kwargs["nextToken"] = next_token
agent_runtimes = agentcore_control_client.list_agent_runtimes(**kwargs)
for agent_runtime in agent_runtimes.get("agentRuntimes", []):
if agent_runtime["agentRuntimeName"] == agent_name:
agent_id = agent_runtime["agentRuntimeId"]
found = True
break
if found:
break
next_token = agent_runtimes.get("nextToken")
if not next_token:
break
except Exception as e:
click.echo(f"Error listing agent runtimes: {e}", err=True)
sys.exit(1)
if found:
click.echo(f"Found agent runtime '{agent_name}' with ID: {agent_id}")
if dry_run:
click.echo(f"[DRY RUN] Would delete agent runtime: {agent_name}")
return
try:
agentcore_control_client.delete_agent_runtime(agentRuntimeId=agent_id)
click.echo(f"Successfully deleted agent runtime: {agent_name}")
except Exception as e:
click.echo(f"Error deleting agent runtime: {e}", err=True)
sys.exit(1)
else:
click.echo(f"Agent runtime '{agent_name}' not found", err=True)
sys.exit(1)
if __name__ == "__main__":
delete_agent_runtime()
@@ -1,5 +1,5 @@
#!/usr/bin/python
from typing import List
import json
import os
import sys
import boto3
@@ -8,9 +8,10 @@ import click
from utils import (
get_aws_region,
get_ssm_parameter,
put_ssm_parameter,
delete_ssm_parameter,
load_api_spec,
save_config,
read_config,
get_cognito_client_secret,
)
@@ -40,7 +41,9 @@ def create_gateway(gateway_name: str, api_spec: List) -> dict:
auth_config = {
"customJWTAuthorizer": {
"allowedClients": [
get_ssm_parameter("/app/customersupport/agentcore/machine_client_id")
get_ssm_parameter(
"/app/customersupport/agentcore/machine_client_id"
)
],
"discoveryUrl": get_ssm_parameter(
"/app/customersupport/agentcore/cognito_discovery_url"
@@ -69,7 +72,7 @@ def create_gateway(gateway_name: str, api_spec: List) -> dict:
# Create gateway target
credential_config = [{"credentialProviderType": "GATEWAY_IAM_ROLE"}]
gateway_id = create_response["gatewayId"]
create_target_response = gateway_client.create_gateway_target(
gatewayIdentifier=gateway_id,
name="LambdaUsingSDK",
@@ -87,9 +90,23 @@ def create_gateway(gateway_name: str, api_spec: List) -> dict:
"gateway_arn": create_response["gatewayArn"],
}
save_config(gateway, "gateway.config")
click.echo("✅ Gateway configuration saved to gateway.config")
# Save gateway details to SSM parameters
put_ssm_parameter("/app/customersupport/agentcore/gateway_id", gateway_id)
put_ssm_parameter("/app/customersupport/agentcore/gateway_name", gateway_name)
put_ssm_parameter(
"/app/customersupport/agentcore/gateway_arn", create_response["gatewayArn"]
)
put_ssm_parameter(
"/app/customersupport/agentcore/gateway_url", create_response["gatewayUrl"]
)
put_ssm_parameter(
"/app/customersupport/agentcore/cognito_secret",
get_cognito_client_secret(),
with_encryption=True,
)
click.echo("✅ Gateway configuration saved to SSM parameters")
return gateway
except Exception as e:
@@ -101,12 +118,12 @@ def delete_gateway(gateway_id: str) -> bool:
"""Delete a gateway and all its targets."""
try:
click.echo(f"🗑️ Deleting all targets for gateway: {gateway_id}")
# List and delete all targets
list_response = gateway_client.list_gateway_targets(
gatewayIdentifier=gateway_id, maxResults=100
)
for item in list_response["items"]:
target_id = item["targetId"]
click.echo(f" Deleting target: {target_id}")
@@ -119,7 +136,7 @@ def delete_gateway(gateway_id: str) -> bool:
click.echo(f"🗑️ Deleting gateway: {gateway_id}")
gateway_client.delete_gateway(gatewayIdentifier=gateway_id)
click.echo(f"✅ Gateway {gateway_id} deleted successfully")
return True
except Exception as e:
@@ -128,12 +145,11 @@ def delete_gateway(gateway_id: str) -> bool:
def get_gateway_id_from_config() -> str:
"""Get gateway ID from config file."""
"""Get gateway ID from SSM parameter."""
try:
config = read_config("gateway.config")
return config["gateway"]["id"]
return get_ssm_parameter("/app/customersupport/agentcore/gateway_id")
except Exception as e:
click.echo(f"❌ Error reading gateway config: {str(e)}", err=True)
click.echo(f"❌ Error reading gateway ID from SSM: {str(e)}", err=True)
return None
@@ -141,38 +157,34 @@ def get_gateway_id_from_config() -> str:
@click.pass_context
def cli(ctx):
"""AgentCore Gateway Management CLI.
Create and delete AgentCore gateways for the customer support application.
"""
ctx.ensure_object(dict)
@cli.command()
@click.option(
"--name",
required=True,
help="Name for the gateway"
)
@click.option("--name", required=True, help="Name for the gateway")
@click.option(
"--api-spec-file",
default="lambda/api_spec.json",
help="Path to the API specification file (default: lambda/api_spec.json)"
default="prerequisite/lambda/api_spec.json",
help="Path to the API specification file (default: prerequisite/lambda/api_spec.json)",
)
def create(name, api_spec_file):
"""Create a new AgentCore gateway."""
click.echo(f"🚀 Creating AgentCore gateway: {name}")
click.echo(f"📍 Region: {REGION}")
# Validate API spec file exists
if not os.path.exists(api_spec_file):
click.echo(f"❌ API specification file not found: {api_spec_file}", err=True)
sys.exit(1)
try:
api_spec = load_api_spec(api_spec_file)
gateway = create_gateway(gateway_name=name, api_spec=api_spec)
click.echo(f"🎉 Gateway created successfully with ID: {gateway['id']}")
except Exception as e:
click.echo(f"❌ Failed to create gateway: {str(e)}", err=True)
sys.exit(1)
@@ -181,40 +193,49 @@ def create(name, api_spec_file):
@cli.command()
@click.option(
"--gateway-id",
help="Gateway ID to delete (if not provided, will read from gateway.config)"
)
@click.option(
"--confirm",
is_flag=True,
help="Skip confirmation prompt"
help="Gateway ID to delete (if not provided, will read from gateway.config)",
)
@click.option("--confirm", is_flag=True, help="Skip confirmation prompt")
def delete(gateway_id, confirm):
"""Delete an AgentCore gateway and all its targets."""
# If no gateway ID provided, try to read from config
if not gateway_id:
gateway_id = get_gateway_id_from_config()
if not gateway_id:
click.echo("❌ No gateway ID provided and couldn't read from gateway.config", err=True)
click.echo(
"❌ No gateway ID provided and couldn't read from SSM parameters",
err=True,
)
sys.exit(1)
click.echo(f"📖 Using gateway ID from config: {gateway_id}")
click.echo(f"📖 Using gateway ID from SSM: {gateway_id}")
# Confirmation prompt
if not confirm:
if not click.confirm(f"⚠️ Are you sure you want to delete gateway {gateway_id}? This action cannot be undone."):
if not click.confirm(
f"⚠️ Are you sure you want to delete gateway {gateway_id}? This action cannot be undone."
):
click.echo("❌ Operation cancelled")
sys.exit(0)
click.echo(f"🗑️ Deleting gateway: {gateway_id}")
if delete_gateway(gateway_id):
click.echo("✅ Gateway deleted successfully")
# Always clean up config file if it exists
# Clean up SSM parameters
delete_ssm_parameter("/app/customersupport/agentcore/gateway_id")
delete_ssm_parameter("/app/customersupport/agentcore/gateway_name")
delete_ssm_parameter("/app/customersupport/agentcore/gateway_arn")
delete_ssm_parameter("/app/customersupport/agentcore/gateway_url")
delete_ssm_parameter("/app/customersupport/agentcore/cognito_secret")
click.echo("🧹 Removed gateway SSM parameters")
# Clean up config file if it exists (backward compatibility)
if os.path.exists("gateway.config"):
os.remove("gateway.config")
click.echo("🧹 Removed gateway.config file")
click.echo("🎉 Gateway and configuration deleted successfully")
else:
click.echo("❌ Failed to delete gateway", err=True)
@@ -222,4 +243,4 @@ def delete(gateway_id, confirm):
if __name__ == "__main__":
cli()
cli()
@@ -1,3 +1,4 @@
#!/usr/bin/python
import click
import boto3
import sys
@@ -37,7 +38,7 @@ def delete_ssm_param(param_name: str):
@click.pass_context
def cli(ctx):
"""AgentCore Memory Management CLI.
Create and delete AgentCore memory resources for the customer support application.
"""
ctx.ensure_object(dict)
@@ -45,35 +46,47 @@ def cli(ctx):
@cli.command()
@click.option(
"--name",
default="CustomerSupportMemory",
help="Name of the memory resource"
"--name", default="CustomerSupportMemory", help="Name of the memory resource"
)
@click.option(
"--ssm-param",
default="/app/customersupport/agentcore/memory_id",
help="SSM parameter to store memory_id"
help="SSM parameter to store memory_id",
)
@click.option(
"--event-expiry-days",
default=30,
type=int,
help="Number of days before events expire (default: 30)"
help="Number of days before events expire (default: 30)",
)
def create(name, ssm_param, event_expiry_days):
"""Create a new AgentCore memory resource."""
click.echo(f"🚀 Creating AgentCore memory: {name}")
click.echo(f"📍 Region: {REGION}")
click.echo(f"⏱️ Event expiry: {event_expiry_days} days")
strategies = [
{
StrategyType.SEMANTIC.value: {
"name": "fact_extractor",
"description": "Extracts and stores factual information",
"namespaces": ["support/user/{actorId}/facts"],
},
},
{
StrategyType.SUMMARY.value: {
"name": "conversation_summary",
"description": "Captures summaries of conversations",
"namespaces": ["summaries/{actorId}/{sessionId}"],
}
}
"namespaces": ["support/user/{actorId}/{sessionId}"],
},
},
{
StrategyType.USER_PREFERENCE.value: {
"name": "user_preferences",
"description": "Captures user preferences and settings",
"namespaces": ["support/user/{actorId}/preferences"],
},
},
]
try:
@@ -86,7 +99,7 @@ def create(name, ssm_param, event_expiry_days):
)
memory_id = memory["memoryId"]
click.echo(f"✅ Memory created successfully: {memory_id}")
except Exception as e:
if "already exists" in str(e):
click.echo("📋 Memory already exists, finding existing resource...")
@@ -105,10 +118,10 @@ def create(name, ssm_param, event_expiry_days):
try:
store_memory_id_in_ssm(ssm_param, memory_id)
click.echo(f"🎉 Memory setup completed successfully!")
click.echo("🎉 Memory setup completed successfully!")
click.echo(f" Memory ID: {memory_id}")
click.echo(f" SSM Parameter: {ssm_param}")
except Exception as e:
click.echo(f"⚠️ Memory created but failed to store in SSM: {str(e)}", err=True)
@@ -116,33 +129,34 @@ def create(name, ssm_param, event_expiry_days):
@cli.command()
@click.option(
"--memory-id",
help="Memory ID to delete (if not provided, will read from SSM parameter)"
help="Memory ID to delete (if not provided, will read from SSM parameter)",
)
@click.option(
"--ssm-param",
default="/app/customersupport/agentcore/memory_id",
help="SSM parameter to retrieve memory_id from"
)
@click.option(
"--confirm",
is_flag=True,
help="Skip confirmation prompt"
help="SSM parameter to retrieve memory_id from",
)
@click.option("--confirm", is_flag=True, help="Skip confirmation prompt")
def delete(memory_id, ssm_param, confirm):
"""Delete an AgentCore memory resource."""
# If no memory ID provided, try to read from SSM
if not memory_id:
try:
memory_id = get_memory_id_from_ssm(ssm_param)
click.echo(f"📖 Using memory ID from SSM: {memory_id}")
except Exception:
click.echo("❌ No memory ID provided and couldn't read from SSM parameter", err=True)
click.echo(
"❌ No memory ID provided and couldn't read from SSM parameter",
err=True,
)
sys.exit(1)
# Confirmation prompt
if not confirm:
if not click.confirm(f"⚠️ Are you sure you want to delete memory {memory_id}? This action cannot be undone."):
if not click.confirm(
f"⚠️ Are you sure you want to delete memory {memory_id}? This action cannot be undone."
):
click.echo("❌ Operation cancelled")
sys.exit(0)
@@ -161,4 +175,4 @@ def delete(memory_id, ssm_param, confirm):
if __name__ == "__main__":
cli()
cli()
@@ -51,7 +51,9 @@ fi
echo "🗑️ Removing local file $ZIP_FILE..."
rm -f "$ZIP_FILE"
# ----- 5. Delete Knowledge Base -----
echo "🗑️ Deleting Knowledgebase"
python prerequisite/knowledge_base.py --mode delete
echo "✅ Cleanup complete."
echo "✅ Deployment complete."
@@ -1,8 +1,9 @@
#!/usr/bin/python
import boto3
import click
import sys
from botocore.exceptions import ClientError
from utils import get_ssm_parameter, read_config, get_aws_region
from utils import get_ssm_parameter, get_aws_region
REGION = get_aws_region()
@@ -48,17 +49,15 @@ def delete_ssm_param():
def create_cognito_provider(provider_name: str) -> dict:
"""Create a Cognito OAuth2 credential provider."""
try:
click.echo("🔧 Reading gateway configuration...")
gateway_config = read_config("gateway.config")
click.echo("✅ Gateway configuration loaded")
click.echo("📥 Fetching Cognito configuration from SSM...")
client_id = get_ssm_parameter(
"/app/customersupport/agentcore/machine_client_id"
)
click.echo(f"✅ Retrieved client ID: {client_id}")
client_secret = gateway_config["cognito"]["secret"]
client_secret = get_ssm_parameter(
"/app/customersupport/agentcore/cognito_secret"
)
click.echo(f"✅ Retrieved client secret: {client_secret[:4]}***")
issuer = get_ssm_parameter(
@@ -113,11 +112,9 @@ def delete_cognito_provider(provider_name: str) -> bool:
try:
click.echo(f"🗑️ Deleting OAuth2 credential provider: {provider_name}")
identity_client.delete_oauth2_credential_provider(
name=provider_name
)
identity_client.delete_oauth2_credential_provider(name=provider_name)
click.echo(f"✅ OAuth2 credential provider deleted successfully")
click.echo("✅ OAuth2 credential provider deleted successfully")
return True
except Exception as e:
@@ -1,3 +1,4 @@
#!/usr/bin/python
import json
import os
import sys
@@ -121,7 +122,7 @@ def delete_google_provider(provider_name: str) -> bool:
identity_client.delete_oauth2_credential_provider(name=provider_name)
click.echo(f"✅ Google OAuth2 credential provider deleted successfully")
click.echo("✅ Google OAuth2 credential provider deleted successfully")
return True
except Exception as e:
@@ -13,7 +13,7 @@ REGION=$(aws configure get region)
ACCOUNT_ID=$(aws sts get-caller-identity --query Account --output text)
FULL_BUCKET_NAME="${BUCKET_NAME}-${ACCOUNT_ID}"
ZIP_FILE="lambda.zip"
LAMBDA_SRC="lambda/python"
LAMBDA_SRC="prerequisite/lambda/python"
S3_KEY="${ZIP_FILE}"
# ----- 1. Create S3 bucket -----
@@ -33,7 +33,7 @@ fi
# ----- 2. Zip Lambda code -----
echo "📦 Zipping contents of $LAMBDA_SRC into $ZIP_FILE..."
cd "$LAMBDA_SRC"
zip -r "../../$ZIP_FILE" . > /dev/null
zip -r "../../../$ZIP_FILE" . > /dev/null
cd - > /dev/null
# ----- 3. Upload to S3 -----
@@ -1,7 +1,8 @@
import boto3
import json
import yaml
import os
from typing import Dict, Any
def get_ssm_parameter(name: str, with_decryption: bool = True) -> str:
@@ -12,6 +13,32 @@ def get_ssm_parameter(name: str, with_decryption: bool = True) -> str:
return response["Parameter"]["Value"]
def put_ssm_parameter(
name: str, value: str, parameter_type: str = "String", with_encryption: bool = False
) -> None:
ssm = boto3.client("ssm")
put_params = {
"Name": name,
"Value": value,
"Type": parameter_type,
"Overwrite": True,
}
if with_encryption:
put_params["Type"] = "SecureString"
ssm.put_parameter(**put_params)
def delete_ssm_parameter(name: str) -> None:
ssm = boto3.client("ssm")
try:
ssm.delete_parameter(Name=name)
except ssm.exceptions.ParameterNotFound:
pass
def load_api_spec(file_path: str) -> list:
with open(file_path, "r") as f:
data = json.load(f)
@@ -30,26 +57,6 @@ def get_aws_account_id() -> str:
return sts.get_caller_identity()["Account"]
class Gateway:
pass
def save_config(gateway: Gateway, filepath: str):
# Extract relevant data as a dict
config_data = {
"gateway": {
"id": gateway["id"],
"name": gateway["name"],
"gateway_url": gateway["gateway_url"],
"gateway_arn": gateway["gateway_arn"],
},
"cognito": {"secret": get_cognito_client_secret()},
}
# Write YAML file
with open(filepath, "w") as f:
yaml.dump(config_data, f, sort_keys=False)
def get_cognito_client_secret() -> str:
client = boto3.client("cognito-idp")
response = client.describe_user_pool_client(
@@ -59,7 +66,55 @@ def get_cognito_client_secret() -> str:
return response["UserPoolClient"]["ClientSecret"]
def read_config(file_path: str) -> dict:
with open(file_path, "r") as f:
config = yaml.safe_load(f)
return config
def read_config(file_path: str) -> Dict[str, Any]:
"""
Read configuration from a file path. Supports JSON, YAML, and YML formats.
Args:
file_path (str): Path to the configuration file
Returns:
Dict[str, Any]: Configuration data as a dictionary
Raises:
FileNotFoundError: If the file doesn't exist
ValueError: If the file format is not supported or invalid
yaml.YAMLError: If YAML parsing fails
json.JSONDecodeError: If JSON parsing fails
"""
if not os.path.exists(file_path):
raise FileNotFoundError(f"Configuration file not found: {file_path}")
# Get file extension to determine format
_, ext = os.path.splitext(file_path.lower())
try:
with open(file_path, "r", encoding="utf-8") as file:
if ext == ".json":
return json.load(file)
elif ext in [".yaml", ".yml"]:
return yaml.safe_load(file)
else:
# Try to auto-detect format by attempting JSON first, then YAML
content = file.read()
file.seek(0)
# Try JSON first
try:
return json.loads(content)
except json.JSONDecodeError:
# Try YAML
try:
return yaml.safe_load(content)
except yaml.YAMLError:
raise ValueError(
f"Unsupported configuration file format: {ext}. "
f"Supported formats: .json, .yaml, .yml"
)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in configuration file {file_path}: {e}")
except yaml.YAMLError as e:
raise ValueError(f"Invalid YAML in configuration file {file_path}: {e}")
except Exception as e:
raise ValueError(f"Error reading configuration file {file_path}: {e}")
@@ -1,6 +1,5 @@
#!/usr/bin/env python3
#!/usr/bin/python
import ast
import base64
import hashlib
from typing import Any, Optional
@@ -62,7 +61,6 @@ def invoke_endpoint(
)
logger = logging.getLogger("bedrock_agentcore.stream")
logger.setLevel(logging.INFO)
content = []
last_data = False
# for line in response.text.splitlines():
@@ -82,11 +80,11 @@ def invoke_endpoint(
# print(line)
if line.startswith("data: "):
last_data = True
line = line[6:]
line = line[6:].replace('"', "")
print(line, end="")
elif line:
if last_data:
print("\n" + line, end="")
print("\n" + line.replace('"', ""), end="")
last_data = False
# print({"response": "\n".join(content)})
@@ -103,6 +101,8 @@ def main(agent_name: str, prompt: str):
"""CLI tool to invoke a Bedrock agent by name."""
runtime_config = read_config(".bedrock_agentcore.yaml")
print(runtime_config)
if agent_name not in runtime_config["agents"]:
print(f"❌ Agent '{agent_name}' not found in config.")
sys.exit(1)
@@ -1,4 +1,4 @@
#!/usr/bin/env python3
#!/usr/bin/python
import asyncio
import click
@@ -10,7 +10,7 @@ import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from scripts.utils import read_config, get_ssm_parameter
from scripts.utils import get_ssm_parameter
gateway_access_token = None
@@ -34,14 +34,19 @@ def main(prompt: str):
# Fetch access token
asyncio.run(_get_access_token_manually(access_token=""))
# Load config
gateway_config = read_config("gateway.config")
print(f"Gateway Endpoint - MCP URL: {gateway_config['gateway']['gateway_url']}mcp")
# Load gateway configuration from SSM parameters
try:
gateway_url = get_ssm_parameter("/app/customersupport/agentcore/gateway_url")
except Exception as e:
print(f"❌ Error reading gateway URL from SSM: {str(e)}")
sys.exit(1)
print(f"Gateway Endpoint - MCP URL: {gateway_url}")
# Set up MCP client
client = MCPClient(
lambda: streamablehttp_client(
f"{gateway_config['gateway']['gateway_url']}",
gateway_url,
headers={"Authorization": f"Bearer {gateway_access_token}"},
)
)
@@ -1,11 +1,15 @@
import asyncio
#!/usr/bin/python
import json
from bedrock_agentcore.identity.auth import requires_access_token
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
from datetime import datetime, timedelta
from strands import tool
from strands import Agent
from strands.models import BedrockModel
import webbrowser
import sys
import os
@@ -15,33 +19,39 @@ from scripts.utils import get_ssm_parameter
async def on_auth_url(url: str):
print(f"Authorization url: {url}")
webbrowser.open(url)
SCOPES = ["https://www.googleapis.com/auth/calendar"]
google_access_token = None
# This annotation helps agent developer to obtain access tokens from external applications
@requires_access_token(
provider_name=get_ssm_parameter("/app/customersupport/agentcore/google_provider"),
scopes=SCOPES, # Google OAuth2 scopes
auth_flow="USER_FEDERATION", # On-behalf-of user (3LO) flow
on_auth_url=on_auth_url, # prints authorization URL to console
force_authentication=True,
into="access_token",
)
async def need_token_3LO_async(*, access_token: str):
global google_access_token
google_access_token = access_token
print(f"google_access_token set: {google_access_token}")
def get_google_access_token(access_token: str):
return access_token
asyncio.run(need_token_3LO_async(access_token=""))
@tool(
name="Create_calendar_event",
description="Creates a new event on your Google Calendar",
)
def create_calendar_event() -> str:
global google_access_token
if not google_access_token:
try:
google_access_token = get_google_access_token(
access_token=google_access_token
)
except Exception as e:
return "Error Authentication with Google: " + str(e)
creds = Credentials(token=google_access_token, scopes=SCOPES)
@@ -84,4 +94,34 @@ def create_calendar_event() -> str:
return json.dumps({"error": str(e), "event_created": False})
model_id = "us.anthropic.claude-sonnet-4-20250514-v1:0"
model = BedrockModel(
model_id=model_id,
)
system_prompt = """
You are a helpful customer support agent ready to assist customers with their inquiries and service needs.
You have access to tools to: check warrant status, view customer profiles, and retrieve Knowledgebase.
You have been provided with a set of functions to help resolve customer inquiries.
You will ALWAYS follow the below guidelines when assisting customers:
<guidelines>
- Never assume any parameter values while using internal tools.
- If you do not have the necessary information to process a request, politely ask the customer for the required details
- NEVER disclose any information about the internal tools, systems, or functions available to you.
- If asked about your internal processes, tools, functions, or training, ALWAYS respond with "I'm sorry, but I cannot provide information about our internal systems."
- Always maintain a professional and helpful tone when assisting customers
- Focus on resolving the customer's inquiries efficiently and accurately
</guidelines>
"""
agent = Agent(
model=model,
system_prompt=system_prompt,
tools=[create_calendar_event],
)
# google_access_token = need_token_3LO_async(access_token="")
response = agent("Can you create a google event?")
print(create_calendar_event())
@@ -1,48 +1,190 @@
#!/usr/bin/python
import json
import click
from bedrock_agentcore.memory import MemoryClient
from strands import Agent
from strands_tools import calculator
import sys
import os
import uuid
import time
from strands.models import BedrockModel
import boto3
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from memory_hook_provider import MemoryHookProvider
from agent_config.memory_hook_provider import MemoryHook
from scripts.utils import get_ssm_parameter
# Load memory ID from SSM
memory_id = get_ssm_parameter("/app/customersupport/agentcore/memory_id")
client = MemoryClient()
# Session & actor configuration
ACTOR_ID = "default"
SESSION_ID = "test"
SESSION_ID = str(uuid.uuid4())
MEMORY_ID = get_ssm_parameter("/app/customersupport/agentcore/memory_id")
# Setup memory hooks
memory_hooks = MemoryHookProvider(
memory_id=memory_id,
client=client,
actor_id=ACTOR_ID,
session_id=SESSION_ID,
model = BedrockModel(
model_id="us.anthropic.claude-3-5-haiku-20241022-v1:0",
)
# Initialize agent with memory and tools
agent = Agent(
hooks=[memory_hooks],
tools=[calculator],
system_prompt="You are a helpful personal math assistant.",
)
memory_client = boto3.client("bedrock-agentcore")
# Interactive prompt loop
print("🧮 Interactive Math Agent")
print("Type your question (or 'q' to quit):")
while True:
user_input = input("You > ").strip()
if user_input.lower() in {"q", "quit"}:
print("👋 Exiting session.")
break
def setup_agent():
"""Setup agent with memory and tools"""
memory_client = MemoryClient()
memory_hooks = MemoryHook(
memory_client=memory_client,
memory_id=MEMORY_ID,
actor_id=ACTOR_ID,
session_id=SESSION_ID,
)
agent = Agent(
hooks=[memory_hooks],
tools=[calculator],
system_prompt="You are a helpful personal assistant.",
model=model,
callback_handler=None,
)
return agent, memory_client
@click.group()
def cli():
"""Memory Testing CLI for Customer Support Assistant"""
pass
@cli.command()
def load_conversation():
"""Load and execute predefined mock conversations to test long-term memory"""
conversations = [
"Hi, how are you doing?",
"My name is John Smith and I'm having trouble with my account login. Can you help me?",
"I'm trying to reset my password but I'm not receiving the verification email.",
"My email address is john.smith@email.com and my account was created about 6 months ago.",
"Actually, let me also mention that I have a premium subscription plan.",
"Can you calculate what 15% of 240 would be? I need to figure out my discount.",
"Great! Now back to my login issue - I remember my username is johnsmith123.",
"I also want to update my billing address to 123 Main Street, New York, NY 10001.",
"By the way, do you remember what my subscription plan type is?",
"Perfect! Can you summarize all the information we discussed about my account today?",
]
click.echo("=== Testing Long-term Memory with Mock Conversations ===")
click.echo(f"Session ID: {SESSION_ID}")
click.echo(f"Actor ID: {ACTOR_ID}")
click.echo("=" * 60)
for i, conversation in enumerate(conversations, 1):
agent, _ = setup_agent()
click.echo(f"\n[{i}/10] You > {conversation}")
try:
response = str(agent(conversation))
click.echo(f"Agent > {response}")
except Exception as e:
click.echo(f"❌ Error: {e}")
# Add a small delay between conversations to simulate real interaction
time.sleep(1)
click.echo("\n" + "=" * 60)
click.echo("=== Memory Test Complete ===")
@cli.command()
@click.argument("prompt", type=str)
def load_prompt(prompt):
"""Load a custom prompt from user input and execute it with memory"""
click.echo("=== Processing Custom Prompt ===")
click.echo(f"Session ID: {SESSION_ID}")
click.echo(f"Actor ID: {ACTOR_ID}")
click.echo("=" * 40)
agent, _ = setup_agent()
click.echo(f"You > {prompt}")
try:
print(f"Agent >")
response = agent(user_input)
response = agent(prompt)
click.echo(f"Agent > {response}")
except Exception as e:
print(f"❌ Error: {e}")
click.echo(f"❌ Error: {e}")
click.echo("=" * 40)
click.echo("✓ Custom prompt processed successfully")
@cli.command()
def list_memory():
"""List all memory entries (not implemented yet)"""
click.echo("=== Memory List Command ===")
click.echo(f"Session ID: {SESSION_ID}")
click.echo(f"Actor ID: {ACTOR_ID}")
click.echo("=" * 30)
list_sessions = memory_client.list_sessions(
memoryId=MEMORY_ID, actorId=ACTOR_ID, maxResults=3
)
click.echo("All Sessions")
first_session = None
for list_session in list_sessions["sessionSummaries"]:
click.echo(f"Session ID: {list_session['sessionId']}")
if not first_session:
first_session = list_session["sessionId"]
click.echo("=" * 30)
click.echo(f"Events for session: {first_session}")
list_events = memory_client.list_events(
memoryId=MEMORY_ID,
sessionId=first_session,
actorId=ACTOR_ID,
includePayloads=True,
)
click.echo(json.dumps(list_events["events"], indent=2, default=str))
# for list_event in list_events["events"]:
# click.echo(f"Session ID: {list_session['sessionId']}")
# if not first_session:
# first_session = list_session["sessionId"]
click.echo("=" * 30)
click.echo(f"Actor facts {ACTOR_ID}")
list_memory_records = memory_client.list_memory_records(
memoryId=MEMORY_ID,
namespace=f"support/user/{ACTOR_ID}/facts",
)
for list_memory_record in list_memory_records["memoryRecordSummaries"]:
click.echo(f"Content: {list_memory_record['content']['text']}")
click.echo("=" * 30)
click.echo(f"Conversation Summary for {first_session}")
list_memory_records = memory_client.list_memory_records(
memoryId=MEMORY_ID,
namespace=f"support/user/{ACTOR_ID}/{first_session}",
)
for list_memory_record in list_memory_records["memoryRecordSummaries"]:
click.echo(f"Content: {list_memory_record['content']['text'][:200]}...")
click.echo("=" * 30)
click.echo(f"User Preferences {ACTOR_ID}")
list_memory_records = memory_client.list_memory_records(
memoryId=MEMORY_ID,
namespace=f"support/user/{ACTOR_ID}/preferences",
)
for list_memory_record in list_memory_records["memoryRecordSummaries"]:
click.echo(f"Content: {list_memory_record['content']['text']}")
click.echo("=" * 30)
if __name__ == "__main__":
cli()
@@ -1,690 +0,0 @@
"""
Tool for managing memories in Bedrock Agent Core Memory Service.
This module provides Bedrock Agent Core Memory capabilities with memory record
creation and retrieval.
Key Features:
------------
1. Event Management:
• create_event: Store events in memory sessions
2. Memory Record Operations:
• retrieve_memory_records: Semantic search for extracted memories
• list_memory_records: List all memory records
• get_memory_record: Get specific memory record
• delete_memory_record: Delete memory records
Usage Examples:
--------------
```python
from strands import Agent
from strands_tools.agent_core_memory import AgentCoreMemoryToolProvider
# Initialize with required parameters
provider = AgentCoreMemoryToolProvider(
memory_id="memory-123abc", # Required
actor_id="user-456", # Required
session_id="session-789", # Required
namespace="default", # Required
)
agent = Agent(tools=provider.tools)
# Create a memory using the default IDs from initialization
agent.tool.agent_core_memory(
action="RecordMemory",
payload=[{
"conversational": {
"content": {"text": "Hello, how are you?"},
"role": "USER"
}
}]
)
# Search memory records using the default namespace from initialization
agent.tool.agent_core_memory(
action="RetrieveMemory",
query="user preferences"
)
```
"""
import json
import logging
from datetime import datetime, timezone
from typing import Any, Dict, List, Literal, Optional, Union
# Use typing_extensions.TypedDict instead of typing.TypedDict for compatibility
try:
from typing_extensions import TypedDict
except ImportError:
from typing import TypedDict # Fallback for Python 3.12+
import boto3
from botocore.config import Config as BotocoreConfig
from strands import tool
from strands.types.tools import AgentTool
# Define payload structure types for conversational messages
class ConversationalContent(TypedDict):
"""Content structure for conversational messages."""
text: str
class ConversationalPayload(TypedDict):
"""Structure for conversational messages."""
content: ConversationalContent
role: Literal["USER", "ASSISTANT", "TOOL", "OTHER"]
class ConversationalPayloadItem(TypedDict):
"""Conversational payload item wrapper."""
conversational: ConversationalPayload
# Blob payload can be any type
class BlobPayloadItem(TypedDict):
"""Blob payload item wrapper."""
blob: Any
# Union type for payload items
PayloadItem = Union[ConversationalPayloadItem, BlobPayloadItem, Dict]
# Event payload is a list of payload items
EventPayload = List[PayloadItem]
# Set up logging
logger = logging.getLogger(__name__)
# Default region if not specified
DEFAULT_REGION = "us-west-2"
class AgentCoreMemoryToolProvider:
"""Provider for Agent Core Memory Service tools."""
def __init__(
self,
memory_id: str,
actor_id: str,
session_id: str,
namespace: str,
region: Optional[str] = None,
boto_client_config: Optional[BotocoreConfig] = None,
):
"""
Initialize the Agent Core Memory tool provider.
Args:
memory_id: Memory ID to use for operations (required)
actor_id: Actor ID to use for operations (required)
session_id: Session ID to use for operations (required)
namespace: Namespace for memory record operations (required)
region: AWS region for the service
boto_client_config: Optional boto client configuration
Raises:
ValueError: If any of the required parameters are missing or empty
"""
# Validate required parameters
if not memory_id:
raise ValueError("memory_id is required")
if not actor_id:
raise ValueError("actor_id is required")
if not session_id:
raise ValueError("session_id is required")
if not namespace:
raise ValueError("namespace is required")
self.memory_id = memory_id
self.actor_id = actor_id
self.session_id = session_id
self.namespace = namespace
# Set up client configuration with user agent
if boto_client_config:
existing_user_agent = getattr(boto_client_config, "user_agent_extra", None)
# Append 'strands-agents-memory' to existing user_agent_extra or set it if not present
if existing_user_agent:
new_user_agent = f"{existing_user_agent} strands-agents-memory"
else:
new_user_agent = "strands-agents-memory"
self.client_config = boto_client_config.merge(
BotocoreConfig(user_agent_extra=new_user_agent)
)
else:
self.client_config = BotocoreConfig(
user_agent_extra="strands-agents-memory"
)
# Resolve region from parameters, environment, or default
self.region = region or DEFAULT_REGION
# Initialize clients with None - they'll be created on first use
self._data_plane_client = None
self._control_plane_client = None
def _init_clients(self, region=None):
"""
Initialize the service clients.
Args:
region: Optional region override. If provided, reinitializes clients with this region.
"""
# Update region if provided
if region:
self.region = region
# Construct endpoint URLs based on the region
data_plane_endpoint = f"https://bedrock-agentcore.{self.region}.amazonaws.com"
control_plane_endpoint = (
f"https://bedrock-agentcore-control.{self.region}.amazonaws.com"
)
# Initialize clients with the appropriate region and endpoints
self._data_plane_client = boto3.client(
"bedrock-agentcore",
region_name=self.region,
endpoint_url=data_plane_endpoint,
config=self.client_config,
)
self._control_plane_client = boto3.client(
"bedrock-agentcore-control", # Agent Core Memory Control Plane
region_name=self.region,
endpoint_url=control_plane_endpoint,
config=self.client_config,
)
@property
def tools(self) -> list[AgentTool]:
"""Extract all @tool decorated methods from this instance."""
tools = []
for attr_name in dir(self):
if attr_name == "tools":
continue
attr = getattr(self, attr_name)
# Also check the original way for regular AgentTool instances
if isinstance(attr, AgentTool):
tools.append(attr)
return tools
@property
def data_plane_client(self):
"""Get the data plane service client, initializing if needed."""
if not self._data_plane_client:
self._init_clients()
return self._data_plane_client
@property
def control_plane_client(self):
"""Get the control plane client, initializing if needed."""
if not self._control_plane_client:
self._init_clients()
return self._control_plane_client
@tool
def agent_core_memory(
self,
action: str,
payload: Optional[EventPayload] = None,
query: Optional[str] = None,
memory_record_id: Optional[str] = None,
max_results: Optional[int] = None,
next_token: Optional[str] = None,
region: Optional[str] = None,
) -> Dict:
"""
Work with agent memories - create, search, retrieve, list, and manage memory records.
This tool helps agents store and access memories, allowing them to remember important
information across conversations and interactions.
Key Capabilities:
- Store new memories (text conversations or structured data)
- Search for memories using semantic search
- Browse and list all stored memories
- Retrieve specific memories by ID
- Delete unwanted memories
Supported Actions:
-----------------
Memory Management:
- RecordMemory: Store a new memory (conversation or data)
Use this when you need to save information for later recall.
- RetrieveMemory: Find relevant memories using semantic search
Use this when searching for specific information in memories.
This is the best action for queries like "find memories about X" or "search for memories related to Y".
- ListMemories: Browse all stored memories
Use this to see all available memories without filtering.
This is useful for getting an overview of what's been stored.
- GetMemory: Fetch a specific memory by ID
Use this when you already know the exact memory ID.
- DeleteMemory: Remove a specific memory
Use this to delete memories that are no longer needed.
Args:
action: The memory operation to perform (see Supported Actions)
payload: Memory content (required for RecordMemory). Must be a list of objects with specific structure:
- For conversational memories: [{"conversational": {"content": {"text": "message"},
"role": "USER"}}]
- For data memories: [{"blob": {"any": "data"}}]
query: Search terms for finding relevant memories (required for RetrieveMemory)
memory_record_id: ID of a specific memory (required for GetMemory, DeleteMemory)
max_results: Maximum number of results to return (optional)
next_token: Pagination token (optional)
region: AWS region (defaults to us-west-2)
Returns:
Dict: Response containing the requested memory information or operation status
Examples:
--------
# Store a new conversational memory
result = agent_core_memory(
action="RecordMemory",
payload=[{
"conversational": {
"content": {"text": "User prefers vegetarian pizza with extra cheese"},
"role": "USER"
}
}]
)
# Store a structured data memory
result = agent_core_memory(
action="RecordMemory",
payload=[{
"blob": {
"preferences": {
"food": "pizza",
"toppings": ["cheese", "mushrooms"],
"crust": "thin"
}
}
}]
)
# Search for relevant memories (use this for finding specific information)
result = agent_core_memory(
action="RetrieveMemory",
query="what food preferences does the user have"
)
# Browse all stored memories (use this for getting an overview)
result = agent_core_memory(
action="ListMemories"
)
# Get a specific memory by ID
result = agent_core_memory(
action="GetMemory",
memory_record_id="mr-12345"
)
# Delete a specific memory
result = agent_core_memory(
action="DeleteMemory",
memory_record_id="mr-12345"
)
"""
try:
# Use values from initialization
memory_id = self.memory_id
actor_id = self.actor_id
session_id = self.session_id
namespace = self.namespace
# Use provided values or defaults for other parameters
memory_record_id = memory_record_id
max_results = max_results
# Handle region override - reinitialize clients if a new region is provided
if region and region != self.region:
self._init_clients(region=region)
else:
region = self.region
# Define required parameters for each action
required_params = {
# New agent-friendly action names
"RecordMemory": ["memory_id", "actor_id", "session_id", "payload"],
"RetrieveMemory": ["memory_id", "namespace", "query"],
"ListMemories": ["memory_id"],
"GetMemory": ["memory_id", "memory_record_id"],
"DeleteMemory": ["memory_id", "memory_record_id", "namespace"],
}
# Map new action names to original API actions (internal use only)
action_mapping = {
"RecordMemory": "create_event",
"RetrieveMemory": "retrieve_memory_records",
"ListMemories": "list_memory_records",
"GetMemory": "get_memory_record",
"DeleteMemory": "delete_memory_record",
}
# Map the action to the API action
api_action = action_mapping.get(action, action)
# Validate action
if action not in required_params:
return {
"status": "error",
"content": [
{
"text": f"Action '{action}' is not supported. "
f"Supported actions: {', '.join(action_mapping.keys())}"
}
],
}
# Validate required parameters
if action in required_params:
missing_params = []
for param in required_params[action]:
param_value = locals().get(param)
if not param_value:
missing_params.append(param)
if missing_params:
return {
"status": "error",
"content": [
{
"text": (
f"The following parameters are required for {action} action: "
f"{', '.join(missing_params)}"
)
}
],
}
# Execute the appropriate action
try:
# Handle action names by mapping to API methods
if action == "RecordMemory" or api_action == "create_event":
response = self.create_event(
memory_id=memory_id,
actor_id=actor_id,
session_id=session_id,
payload=payload,
)
# Extract only the relevant "event" field from the response
event_data = (
response.get("event", {}) if isinstance(response, dict) else {}
)
return {
"status": "success",
"content": [
{
"text": f"Memory created successfully: {json.dumps(event_data, default=str)}"
}
],
}
elif (
action == "RetrieveMemory"
or api_action == "retrieve_memory_records"
):
response = self.retrieve_memory_records(
memory_id=memory_id,
namespace=namespace,
search_query=query,
max_results=max_results,
next_token=next_token,
)
# Extract only the relevant fields from the response
relevant_data = {}
if isinstance(response, dict):
if "memoryRecordSummaries" in response:
relevant_data["memoryRecordSummaries"] = response[
"memoryRecordSummaries"
]
if "nextToken" in response:
relevant_data["nextToken"] = response["nextToken"]
return {
"status": "success",
"content": [
{
"text": f"Memories retrieved successfully: {json.dumps(relevant_data, default=str)}"
}
],
}
elif action == "ListMemories" or api_action == "list_memory_records":
response = self.list_memory_records(
memory_id=memory_id,
namespace=namespace,
max_results=max_results,
next_token=next_token,
)
# Extract only the relevant fields from the response
relevant_data = {}
if isinstance(response, dict):
if "memoryRecordSummaries" in response:
relevant_data["memoryRecordSummaries"] = response[
"memoryRecordSummaries"
]
if "nextToken" in response:
relevant_data["nextToken"] = response["nextToken"]
return {
"status": "success",
"content": [
{
"text": f"Memories listed successfully: {json.dumps(relevant_data, default=str)}"
}
],
}
elif action == "GetMemory" or api_action == "get_memory_record":
response = self.get_memory_record(
memory_id=memory_id,
memory_record_id=memory_record_id,
)
# Extract only the relevant "memoryRecord" field from the response
memory_record = (
response.get("memoryRecord", {})
if isinstance(response, dict)
else {}
)
return {
"status": "success",
"content": [
{
"text": f"Memory retrieved successfully: {json.dumps(memory_record, default=str)}"
}
],
}
elif action == "DeleteMemory" or api_action == "delete_memory_record":
response = self.delete_memory_record(
memory_id=memory_id,
memory_record_id=memory_record_id,
namespace=namespace,
)
# Extract only the relevant "memoryRecordId" field from the response
memory_record_id = (
response.get("memoryRecordId", "")
if isinstance(response, dict)
else ""
)
return {
"status": "success",
"content": [
{"text": f"Memory deleted successfully: {memory_record_id}"}
],
}
except Exception as e:
error_msg = f"API error: {str(e)}"
logger.error(error_msg)
return {"status": "error", "content": [{"text": error_msg}]}
except Exception as e:
logger.error(f"Unexpected error in agent_core_memory tool: {str(e)}")
return {"status": "error", "content": [{"text": str(e)}]}
def create_event(
self,
memory_id: str,
actor_id: str,
session_id: str,
payload: EventPayload,
event_timestamp: Optional[datetime] = None,
) -> Dict:
"""
Create an event in a memory session.
Creates a new event record in the specified memory session. Events are immutable
records that capture interactions or state changes in your application.
Args:
memory_id: ID of the memory store
actor_id: ID of the actor (user, agent, etc.) creating the event
session_id: ID of the session this event belongs to
payload: List of event payload items. Each item can be:
- Conversational message (with enforced structure):
{
"conversational": {
"content": {"text": "Message text"},
"role": "USER" | "ASSISTANT" | "TOOL" | "OTHER"
}
}
- Blob (any structure):
{
"blob": <any data>
}
event_timestamp: Optional timestamp for the event (defaults to current time)
Returns:
Dict: Response containing the created event details
Raises:
ValueError: If required parameters are invalid
RuntimeError: If the API call fails
Example:
```python
# Example with conversational payload
payload = [{
"conversational": {
"content": {"text": "Hello, how are you?"},
"role": "USER"
}
}]
# Example with blob payload
blob_payload = [{
"blob": {"custom_data": "any structure can go here"}
}]
result = create_event(
memory_id="memory-123abc",
actor_id="user-456",
session_id="session-789",
payload=payload
)
```
"""
# Set default timestamp if not provided
if event_timestamp is None:
event_timestamp = datetime.now(timezone.utc)
return self.data_plane_client.create_event(
memoryId=memory_id,
actorId=actor_id,
sessionId=session_id,
eventTimestamp=event_timestamp,
payload=payload,
)
def retrieve_memory_records(
self,
memory_id: str,
namespace: str,
search_query: str,
max_results: Optional[int] = None,
next_token: Optional[str] = None,
) -> Dict:
"""
Retrieve memory records using semantic search.
Performs a semantic search across memory records in the specified namespace,
returning records that semantically match the search query. Results are ranked
by relevance to the query.
Args:
memory_id: ID of the memory store to search in
namespace: Namespace to search within (e.g., "actor/user123/userId")
search_query: Natural language query to search for
max_results: Maximum number of results to return (default: service default)
next_token: Pagination token for retrieving additional results
Returns:
Dict: Response containing matching memory records and optional next_token
"""
# Prepare request parameters
params = {
"memoryId": memory_id,
"namespace": namespace,
"searchCriteria": {"searchQuery": search_query},
}
if max_results is not None:
params["maxResults"] = max_results
if next_token is not None:
params["nextToken"] = next_token
# Direct API call without redundant try/except block
return self.data_plane_client.retrieve_memory_records(**params)
def get_memory_record(
self,
memory_id: str,
memory_record_id: str,
) -> Dict:
"""Get a specific memory record."""
return self.data_plane_client.get_memory_record(
memoryId=memory_id,
memoryRecordId=memory_record_id,
)
def list_memory_records(
self,
memory_id: str,
namespace: str,
max_results: Optional[int] = None,
next_token: Optional[str] = None,
) -> Dict:
"""List memory records."""
params = {"memoryId": memory_id}
if namespace is not None:
params["namespace"] = namespace
if max_results is not None:
params["maxResults"] = max_results
if next_token is not None:
params["nextToken"] = next_token
return self.data_plane_client.list_memory_records(**params)
def delete_memory_record(
self,
memory_id: str,
memory_record_id: str,
namespace: str,
) -> Dict:
"""Delete a specific memory record."""
return self.data_plane_client.delete_memory_record(
memoryId=memory_id,
memoryRecordId=memory_record_id,
namespace=namespace,
)