111 lines
3.3 KiB
Python
111 lines
3.3 KiB
Python
import asyncio
|
|
import random as random_module
|
|
from typing import Any, Optional
|
|
|
|
from app.context import FlowContext
|
|
from app.nodes.base import NodeExecutor
|
|
|
|
|
|
class SwitchExecutor(NodeExecutor):
    """Multi-branch switch based on variable value.

    Config keys:
        variable: name of the context variable to switch on.
        cases: list of ``{"value": ..., "branch": str}`` dicts, checked in order;
            the first matching case wins.
        default_branch: branch returned when no case matches (default "default").

    Both sides of the comparison are converted with ``str()``, lower-cased and
    stripped, so matching is case- and surrounding-whitespace-insensitive.
    When the variable's value is missing (None or an empty string/container),
    the incoming message's ``content`` is switched on instead.
    """

    @staticmethod
    def _is_missing(value: Any) -> bool:
        """Return True when *value* should fall back to the message content.

        Numbers (including 0) and booleans (including False) are real values
        and must be switched on; a plain ``or`` fallback would silently replace
        them with the message text. Only None and empty strings/containers
        count as "not set".
        """
        if value is None:
            return True
        if isinstance(value, (int, float)):  # bool is a subclass of int
            return False
        return not value

    @staticmethod
    def _normalize(value: Any) -> str:
        """Canonical form used for case comparison."""
        return str(value).lower().strip()

    async def execute(
        self, config: dict, context: FlowContext, session: Any
    ) -> Optional[str]:
        variable = config.get("variable", "")
        # Tolerate an explicit ``"cases": null`` in the node config.
        cases = config.get("cases") or []

        actual = context.get(variable)
        if self._is_missing(actual):
            actual = context.message.get("content", "")
        actual_str = self._normalize(actual)

        for case in cases:
            if self._normalize(case.get("value", "")) == actual_str:
                return case.get("branch", "default")

        return config.get("default_branch", "default")
|
|
|
|
|
|
class DelayExecutor(NodeExecutor):
    """Wait for a specified duration before continuing.

    Config keys:
        type: "fixed" (default) or "random".
        seconds: delay used in "fixed" mode (default 0).
        min_seconds / max_seconds: bounds of the uniform draw in "random" mode
            (defaults 1 and 5).

    Numeric strings (e.g. "5") are accepted; a missing or null value uses the
    default. Non-positive delays skip sleeping. Every sleep is capped at
    ``MAX_DELAY_SECONDS`` so a misconfigured node cannot stall the flow for long.
    The node always routes to "default".
    """

    # Upper bound on a single sleep, in seconds.
    MAX_DELAY_SECONDS = 30

    @staticmethod
    def _as_seconds(value: Any, default: float) -> float:
        """Coerce a config value to float seconds; None means "use default".

        Raises ValueError/TypeError for values that are not numeric, so a bad
        config fails loudly instead of being silently ignored.
        """
        if value is None:
            return float(default)
        return float(value)

    async def execute(
        self, config: dict, context: FlowContext, session: Any
    ) -> Optional[str]:
        delay_type = config.get("type", "fixed")

        if delay_type == "random":
            min_delay = self._as_seconds(config.get("min_seconds"), 1)
            max_delay = self._as_seconds(config.get("max_seconds"), 5)
            # uniform() accepts the bounds in either order.
            delay_seconds = random_module.uniform(min_delay, max_delay)
        else:
            delay_seconds = self._as_seconds(config.get("seconds"), 0)

        if delay_seconds > 0:
            await asyncio.sleep(min(delay_seconds, self.MAX_DELAY_SECONDS))

        return "default"
|
|
|
|
|
|
class RandomExecutor(NodeExecutor):
    """Random branch selection for A/B testing.

    Config keys:
        branches: list of ``{"branch": str, "weight": number}`` dicts; a missing
            or null weight counts as 1 and negative weights count as 0.
        test_name: suffix of the context variable that records the pick
            (default "ab_test").

    The chosen branch name is stored in the context as ``_ab_<test_name>`` on
    every path, then returned. With no branches configured, "default" is
    returned and nothing is recorded.
    """

    @staticmethod
    def _weight(branch: dict) -> float:
        """Return *branch*'s non-negative selection weight."""
        weight = branch.get("weight", 1)
        if weight is None:
            weight = 1
        return max(float(weight), 0.0)

    async def execute(
        self, config: dict, context: FlowContext, session: Any
    ) -> Optional[str]:
        branches = config.get("branches", [])

        if not branches:
            return "default"

        weights = [self._weight(b) for b in branches]
        total_weight = sum(weights)

        # random.choices needs a positive, finite total; NaN fails both checks.
        if 0 < total_weight < float("inf"):
            # Zero-weight branches can never be drawn here, unlike a manual
            # cumulative scan with ``<=`` where a draw of exactly 0 selects them.
            chosen = random_module.choices(branches, weights=weights, k=1)[0]
        else:
            # No usable weights: deterministically take the first branch.
            chosen = branches[0]

        branch_name = chosen.get("branch", "default")
        test_name = config.get("test_name", "ab_test")
        context.set(f"_ab_{test_name}", branch_name)
        return branch_name
|
|
|
|
|
|
class LoopExecutor(NodeExecutor):
    """Loop a certain number of times.

    A counter kept in the flow context (``counter_variable``, default
    "_loop_counter") tracks completed passes. While it is below
    ``max_iterations`` (default 10) the counter is incremented and the node
    routes to "continue". Once the limit is reached the counter is reset to 0
    and the node routes to "done".
    """

    async def execute(
        self, config: dict, context: FlowContext, session: Any
    ) -> Optional[str]:
        counter_name = config.get("counter_variable", "_loop_counter")
        limit = config.get("max_iterations", 10)

        # An unset (or falsy) counter starts the loop at zero.
        iteration = int(context.get(counter_name) or 0)

        if iteration < limit:
            context.set(counter_name, iteration + 1)
            return "continue"

        # Limit reached: reset so the loop can be entered again later.
        context.set(counter_name, 0)
        return "done"
|
|
|
|
|
|
class GoToExecutor(NodeExecutor):
    """Jump to another node or flow.

    A configured ``target_flow_id`` takes precedence: it is stored in the
    context as ``_goto_flow_id`` and the node routes to "sub_flow". Otherwise
    a ``target_node_id`` is stored as ``_goto_node_id`` and the node routes to
    "goto". With neither set, the node routes to "default".
    (Presumably the flow engine reads these context keys — confirm with caller.)
    """

    async def execute(
        self, config: dict, context: FlowContext, session: Any
    ) -> Optional[str]:
        node_target = config.get("target_node_id")
        flow_target = config.get("target_flow_id")

        # Checked in priority order: a sub-flow jump wins over a node jump.
        jump_table = (
            (flow_target, "_goto_flow_id", "sub_flow"),
            (node_target, "_goto_node_id", "goto"),
        )
        for target, context_key, route in jump_table:
            if target:
                context.set(context_key, target)
                return route

        return "default"
|