82 lines
3.1 KiB
Python
82 lines
3.1 KiB
Python
from typing import Optional, Any
|
|
from app.nodes.base import NodeExecutor
|
|
from app.context import FlowContext
|
|
|
|
|
|
class TriggerExecutor(NodeExecutor):
    """Trigger node: performs no work of its own and continues on the "default" branch."""

    async def execute(self, config: dict, context: FlowContext, session: Any) -> Optional[str]:
        """Return the "default" branch without inspecting any argument.

        Args:
            config: Node configuration (unused).
            context: Flow execution context (unused).
            session: Conversation session (unused).

        Returns:
            Always ``"default"``.
        """
        next_branch = "default"
        return next_branch
|
|
|
|
|
|
class MessageExecutor(NodeExecutor):
    """Send one interpolated text message, then continue on the "default" branch."""

    def __init__(self, send_message_fn):
        """Keep the delivery callable used by :meth:`execute`.

        Args:
            send_message_fn: Awaitable callable, invoked as
                ``await send_message_fn(conversation_id, text)``.
        """
        self.send_message = send_message_fn

    async def execute(self, config: dict, context: FlowContext, session: Any) -> Optional[str]:
        """Render ``config["text"]`` against the flow context and send it.

        Args:
            config: Node configuration; ``"text"`` is the message template
                (an empty string when absent).
            context: Flow context whose ``interpolate`` renders the template.
            session: Object exposing ``conversation_id``, the send target.

        Returns:
            Always ``"default"``.
        """
        rendered = context.interpolate(config.get("text", ""))
        await self.send_message(session.conversation_id, rendered)
        return "default"
|
|
|
|
|
|
class ButtonsExecutor(NodeExecutor):
    """Send a prompt followed by a bulleted list of button labels, then return "wait"."""

    def __init__(self, send_message_fn):
        """Keep the delivery callable used by :meth:`execute`.

        Args:
            send_message_fn: Awaitable callable, invoked as
                ``await send_message_fn(conversation_id, text)``.
        """
        self.send_message = send_message_fn

    async def execute(self, config: dict, context: FlowContext, session: Any) -> Optional[str]:
        """Render the prompt, append one bullet per labelled button and send it.

        Args:
            config: Node configuration. ``"text"`` is the prompt template and
                ``"buttons"`` is a list of dicts whose ``"label"`` is shown.
            context: Flow context whose ``interpolate`` renders the prompt.
            session: Object exposing ``conversation_id``, the send target.

        Returns:
            Always ``"wait"``. Presumably this tells the engine to pause for
            the user's reply; confirm against the flow runner.
        """
        text = context.interpolate(config.get("text", ""))
        # ``or []`` also covers an explicit ``"buttons": None`` in the config.
        buttons = config.get("buttons") or []
        # Skip buttons with no usable label instead of failing the node with KeyError.
        labels = [button.get("label") for button in buttons]
        labels = [label for label in labels if label not in (None, "")]

        if labels:
            button_text = "\n".join(f"• {label}" for label in labels)
            message = f"{text}\n\n{button_text}"
        else:
            # No buttons: send the prompt alone, without a dangling blank separator.
            message = text

        await self.send_message(session.conversation_id, message)
        return "wait"
|
|
|
|
|
|
class WaitInputExecutor(NodeExecutor):
    """Capture the current message's content into a flow variable."""

    async def execute(self, config: dict, context: FlowContext, session: Any) -> Optional[str]:
        """Store ``context.message["content"]`` under the configured variable name.

        Args:
            config: Node configuration; ``"variable"`` names the target
                variable (``"user_input"`` when absent).
            context: Flow context. Its ``message`` mapping supplies the
                content (``""`` when it has none), and ``set`` stores it.
            session: Conversation session (unused).

        Returns:
            Always ``"default"``.
        """
        target = config.get("variable", "user_input")
        incoming = context.message.get("content", "")
        context.set(target, incoming)
        return "default"
|
|
|
|
|
|
class SetVariableExecutor(NodeExecutor):
    """Assign an interpolated value to a named flow variable."""

    async def execute(self, config: dict, context: FlowContext, session: Any) -> Optional[str]:
        """Render ``config["value"]`` and store it under ``config["variable"]``.

        Args:
            config: Node configuration. ``"variable"`` is the target name and
                ``"value"`` is the template to render (``""`` when absent).
            context: Flow context providing ``interpolate`` and ``set``.
            session: Conversation session (unused).

        Returns:
            Always ``"default"``. When no variable name is configured, nothing
            is stored.
        """
        var_name = config.get("variable", "")
        if not var_name:
            # Previously a missing name stored the value under the variable "".
            # A misconfigured node is now a no-op instead of creating a nameless variable.
            return "default"

        var_value = context.interpolate(config.get("value", ""))
        context.set(var_name, var_value)
        return "default"
|
|
|
|
|
|
class ConditionExecutor(NodeExecutor):
    """Return the branch of the first matching condition, else ``"default"``.

    Each entry of ``config["conditions"]`` is a dict with these keys:

    - ``field``: flow variable to test. When it is empty, or the variable
      resolves to ``None``, the incoming message content is tested instead.
    - ``operator``: one of the following (defaults to ``equals``):
      ``equals``, ``not_equals``, ``contains`` and ``starts_with`` compare
      case-insensitively on ``str()`` of both operands; ``greater_than`` and
      ``less_than`` compare numerically via ``float()``, and non-numeric
      operands never match.
    - ``value``: the operand to compare against (default ``""``).
    - ``branch``: the branch returned on a match (default ``"default"``).

    Conditions are evaluated in order. Unknown operators never match.
    """

    # Case-insensitive string tests; both operands arrive already lower-cased.
    _STRING_TESTS = {
        "equals": lambda actual, expected: actual == expected,
        "not_equals": lambda actual, expected: actual != expected,
        "contains": lambda actual, expected: expected in actual,
        "starts_with": lambda actual, expected: actual.startswith(expected),
    }

    # Numeric tests; both operands arrive already converted with float().
    _NUMERIC_TESTS = {
        "greater_than": lambda actual, expected: actual > expected,
        "less_than": lambda actual, expected: actual < expected,
    }

    async def execute(self, config: dict, context: FlowContext, session: Any) -> Optional[str]:
        """Evaluate the configured conditions in order and pick a branch.

        Args:
            config: Node configuration; ``"conditions"`` is the list described
                on the class.
            context: Flow context providing ``get`` and ``message``.
            session: Conversation session (unused).

        Returns:
            The ``branch`` of the first matching condition, or ``"default"``.
        """
        # ``or []`` also covers an explicit ``"conditions": None`` in the config.
        for cond in config.get("conditions") or []:
            op = cond.get("operator", "equals")
            expected = cond.get("value", "")
            actual = self._resolve_actual(cond.get("field", ""), context)
            if self._matches(op, actual, expected):
                return cond.get("branch", "default")
        return "default"

    @staticmethod
    def _resolve_actual(field, context):
        """Return the value of variable *field*, or the message content as a fallback."""
        # NOTE(review): assumes context.get returns None for an unset variable; confirm.
        # Only None triggers the fallback. Legitimately falsy values such as 0 or False
        # are compared as-is; the previous ``or`` silently replaced them with the message.
        actual = context.get(field) if field else None
        if actual is None:
            actual = context.message.get("content", "")
        return actual

    @classmethod
    def _matches(cls, op, actual, expected):
        """Return True when *actual* satisfies operator *op* against *expected*."""
        string_test = cls._STRING_TESTS.get(op)
        if string_test is not None:
            return string_test(str(actual).lower(), str(expected).lower())

        numeric_test = cls._NUMERIC_TESTS.get(op)
        if numeric_test is not None:
            try:
                left, right = float(actual), float(expected)
            except (TypeError, ValueError):
                # Non-numeric operands (text, None, containers) never match.
                # TypeError was previously uncaught and crashed the node.
                return False
            return numeric_test(left, right)

        return False
|