import mlflow
from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import (
ChatAgentMessage,
ChatAgentResponse,
ChatAgentChunk,
)
from databricks.sdk import WorkspaceClient
from typing import Generator
class CustomerServiceAgent(ChatAgent):
    """Customer-service agent built on the Databricks Agent Framework.

    Wraps a Databricks Foundation Model serving endpoint and exposes the
    MLflow ``ChatAgent`` interface (``predict`` for one-shot answers and
    ``predict_stream`` for chunked answers).
    """

    def __init__(self):
        # Initialize the ChatAgent base class before adding our own state
        # (the original skipped this call).
        super().__init__()
        # Uses the Databricks Foundation Model API; WorkspaceClient picks
        # up credentials from the environment / notebook context.
        self.client = WorkspaceClient()
        self.model_endpoint = "databricks-meta-llama-3-3-70b-instruct"

    def predict(
        self,
        messages: list[ChatAgentMessage],
        context=None,
        custom_inputs=None,
    ) -> ChatAgentResponse:
        """Synchronous response -- returns the full result at once.

        Args:
            messages: Conversation history as ``ChatAgentMessage`` objects.
            context: Optional request context (unused here).
            custom_inputs: Optional extra inputs (unused here).

        Returns:
            A ``ChatAgentResponse`` containing a single assistant message.
        """
        # Expose Unity Catalog functions as tools the model may call.
        response = self.client.serving_endpoints.query(
            name=self.model_endpoint,
            # NOTE(review): ChatAgentMessage is a pydantic model; if
            # ``to_dict`` is not available in your mlflow version, use
            # ``m.model_dump(exclude_none=True)`` instead — confirm.
            messages=[m.to_dict() for m in messages],
            tools=[
                {
                    "type": "uc_function",
                    "function": {
                        "name": "main.default.search_knowledge_base"
                    },
                },
                {
                    "type": "uc_function",
                    "function": {
                        "name": "main.default.get_order_status"
                    },
                },
            ],
        )
        return ChatAgentResponse(
            messages=[
                ChatAgentMessage(
                    role="assistant",
                    content=response.choices[0].message.content,
                )
            ]
        )

    def predict_stream(
        self,
        messages: list[ChatAgentMessage],
        context=None,
        custom_inputs=None,
    ) -> Generator[ChatAgentChunk, None, None]:
        """Streaming response -- yields the answer as chunks.

        Bug fix: the original body was ``pass`` with no ``yield``, so the
        call returned ``None`` instead of a generator and iterating the
        result raised ``TypeError``. Until true token-level streaming is
        implemented, fall back to the synchronous ``predict`` and emit
        each resulting message as one chunk.
        """
        result = self.predict(messages, context, custom_inputs)
        for message in result.messages:
            yield ChatAgentChunk(delta=message)
# Log the agent to MLflow, then register it in Unity Catalog.
EXPERIMENT_PATH = "/Users/user@company.com/customer-service-agent"
REGISTERED_MODEL_NAME = "main.default.customer_service_agent"

mlflow.set_experiment(EXPERIMENT_PATH)
with mlflow.start_run():
    # Package the agent as a pyfunc model with pinned dependencies.
    model_info = mlflow.pyfunc.log_model(
        artifact_path="agent",
        python_model=CustomerServiceAgent(),
        pip_requirements=["mlflow>=2.21", "databricks-sdk>=0.40"],
    )

# Point the model registry at Unity Catalog before registering.
mlflow.set_registry_uri("databricks-uc")
mlflow.register_model(model_uri=model_info.model_uri, name=REGISTERED_MODEL_NAME)