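1. Spring AI Alibaba Graph is one of the community's core implementations, and it is where the framework's design philosophy departs from Spring AI, which only provides low-level atomic abstractions: Spring AI Alibaba aims to make it easier for developers to build agent applications. With Graph, developers can build workflows and multi-agent applications. Its design borrows from LangGraph, so to some extent it can be regarded as a Java implementation of LangGraph. On top of that, the community has added a large number of prebuilt Nodes and simplified how State is defined, so developers can more easily write workflows and multi-agent applications comparable to those built on low-code platforms.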
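2. The framework's core concepts are StateGraph (the state graph, which defines nodes and edges), Node (a node, which encapsulates a concrete operation or model call), Edge (an edge, which expresses a transition between nodes), and OverAllState (the global state, which carries shared data through the whole flow). Together they let developers manage the state and control flow of a workflow in one place.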
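To make these four concepts concrete, here is a deliberately tiny state graph in plain Java. It is a minimal sketch under stated assumptions, not the Spring AI Alibaba API: the class `MiniStateGraph`, its `START`/`END` markers, and the rule that a node's returned map simply overwrites keys in the shared state are illustrative inventions.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical, simplified model of a state graph; NOT the Spring AI Alibaba API.
final class MiniStateGraph {
    static final String START = "__START__";
    static final String END = "__END__";

    // Node: reads the shared state and returns a partial update
    private final Map<String, Function<Map<String, Object>, Map<String, Object>>> nodes = new HashMap<>();
    // Edge: for each source node, a function that picks the next node from the current state
    private final Map<String, Function<Map<String, Object>, String>> edges = new HashMap<>();

    // Result of a run: the final shared state plus the nodes visited in order
    record Run(Map<String, Object> state, List<String> visited) {}

    MiniStateGraph addNode(String name, Function<Map<String, Object>, Map<String, Object>> action) {
        nodes.put(name, action);
        return this;
    }

    // Fixed edge: always jump to the same target
    MiniStateGraph addEdge(String from, String to) {
        edges.put(from, state -> to);
        return this;
    }

    // Conditional edge: the router returns a key, the mapping turns it into a node name
    MiniStateGraph addConditionalEdges(String from, Function<Map<String, Object>, String> router,
                                       Map<String, String> mapping) {
        edges.put(from, state -> {
            String target = mapping.get(router.apply(state));
            if (target == null) {
                throw new IllegalStateException("no mapping for route out of " + from);
            }
            return target;
        });
        return this;
    }

    // Runs from START to END, merging each node's output into the shared state (replace semantics)
    Run run(Map<String, Object> input) {
        Map<String, Object> state = new HashMap<>(input);
        List<String> visited = new ArrayList<>();
        String current = edges.get(START).apply(state);
        while (!END.equals(current)) {
            visited.add(current);
            state.putAll(nodes.get(current).apply(state));
            current = edges.get(current).apply(state);
        }
        return new Run(state, visited);
    }
}
```

Nodes here are plain functions from the current state to a partial update, which mirrors how the nodes later in this article return a `Map<String, Object>` from `apply(OverAllState)`.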
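3. Example: human feedback and resume
Real business scenarios often require a human to step in, and different human actions send the workflow down different paths.
The following simple example has three nodes: an expander node, a human node, and a translation node.
- Expander node: the AI model expands the question and streams the output
- Human node: based on the user's feedback, decides whether to end right away or continue to the translation node
- Translation node: translates the question into another language (English by default)
The full code is in the graph directory of spring-ai-alibaba-examples; this chapter uses its human-node module.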
4. pom file
This example uses 1.0.0.3-SNAPSHOT; the way a StateGraph is defined differs somewhat from 1.0.0.2.
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <parent>
        <groupId>com.alibaba.cloud.ai</groupId>
        <artifactId>spring-ai-alibaba-graph-example</artifactId>
        <version>${revision}</version>
        <relativePath>../pom.xml</relativePath>
    </parent>

    <groupId>com.alibaba.cloud.ai.graph</groupId>
    <artifactId>human-node</artifactId>

    <properties>
        <spring-ai-alibaba.version>1.0.0.3-SNAPSHOT</spring-ai-alibaba.version>
    </properties>

    <dependencies>
        <!-- DashScope model starter -->
        <dependency>
            <groupId>com.alibaba.cloud.ai</groupId>
            <artifactId>spring-ai-alibaba-starter-dashscope</artifactId>
        </dependency>
        <!-- ChatClient auto-configuration -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-autoconfigure-model-chat-client</artifactId>
        </dependency>
        <!-- Graph core -->
        <dependency>
            <groupId>com.alibaba.cloud.ai</groupId>
            <artifactId>spring-ai-alibaba-graph-core</artifactId>
            <version>${spring-ai-alibaba.version}</version>
        </dependency>
        <!-- Web layer for the SSE endpoints -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
    </dependencies>
</project>
5. StateGraph configuration
Fields stored in the OverAllState keyStrategies:
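- query: the user's question
- expander_number: the number of variants to generate
- expander_content: the expanded content
- feed_back: the human feedback
- human_next_node: the next node after the human feedback
- translate_language: the target language for translation, English by default
- translate_content: the translated content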
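Every key above is registered with `ReplaceStrategy`, meaning that when a node returns a new value for a key, it overwrites the old one. The sketch below models that per-key merge in plain Java; `KeyStrategyMerge`, the `BinaryOperator` representation, and the fallback to replace for unregistered keys are assumptions for illustration, not how `KeyStrategyFactory` is implemented.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.BinaryOperator;

// Hypothetical sketch of per-key merge strategies; the real KeyStrategy/ReplaceStrategy
// types live in spring-ai-alibaba-graph-core and may behave differently.
final class KeyStrategyMerge {
    // Replace semantics: the new value simply overwrites the old one
    static final BinaryOperator<Object> REPLACE = (oldValue, newValue) -> newValue;

    // Merges a node's partial update into a copy of the state, key by key
    static Map<String, Object> merge(Map<String, Object> state, Map<String, Object> update,
                                     Map<String, BinaryOperator<Object>> strategies) {
        Map<String, Object> merged = new HashMap<>(state);
        update.forEach((key, value) -> {
            // Assumption: keys without a registered strategy fall back to replace
            BinaryOperator<Object> strategy = strategies.getOrDefault(key, REPLACE);
            merged.put(key, strategy.apply(merged.get(key), value));
        });
        return merged;
    }
}
```

Keys a node does not return are left untouched, which is why each node only needs to return the fields it actually produces.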
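Register the three nodes (expander, human_feedback, translate) and connect them with the following edges:
START -> expander -> human_feedback
human_feedback -> translate
human_feedback -> END
translate -> END
The code is as follows: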
package com.alibaba.cloud.ai.graph.config;

import com.alibaba.cloud.ai.graph.GraphRepresentation;
import com.alibaba.cloud.ai.graph.KeyStrategy;
import com.alibaba.cloud.ai.graph.KeyStrategyFactory;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.action.AsyncEdgeAction;
import com.alibaba.cloud.ai.graph.dispatcher.HumanFeedbackDispatcher;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.node.ExpanderNode;
import com.alibaba.cloud.ai.graph.node.HumanFeedbackNode;
import com.alibaba.cloud.ai.graph.node.TranslateNode;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import java.util.HashMap;
import java.util.Map;

import static com.alibaba.cloud.ai.graph.action.AsyncNodeAction.node_async;

/**
 * @author yingzi
 * @since 2025/6/13
 */
@Configuration
public class GraphHumanConfiguration {

    private static final Logger logger = LoggerFactory.getLogger(GraphHumanConfiguration.class);

    @Bean
    public StateGraph humanGraph(ChatClient.Builder chatClientBuilder) throws GraphStateException {
        KeyStrategyFactory keyStrategyFactory = () -> {
            HashMap<String, KeyStrategy> keyStrategyHashMap = new HashMap<>();
            // User input
            keyStrategyHashMap.put("query", new ReplaceStrategy());
            keyStrategyHashMap.put("thread_id", new ReplaceStrategy());
            keyStrategyHashMap.put("expander_number", new ReplaceStrategy());
            keyStrategyHashMap.put("expander_content", new ReplaceStrategy());
            // Human feedback
            keyStrategyHashMap.put("feed_back", new ReplaceStrategy());
            keyStrategyHashMap.put("human_next_node", new ReplaceStrategy());
            // Translation
            keyStrategyHashMap.put("translate_language", new ReplaceStrategy());
            keyStrategyHashMap.put("translate_content", new ReplaceStrategy());
            return keyStrategyHashMap;
        };

        StateGraph stateGraph = new StateGraph(keyStrategyFactory)
                .addNode("expander", node_async(new ExpanderNode(chatClientBuilder)))
                .addNode("translate", node_async(new TranslateNode(chatClientBuilder)))
                .addNode("human_feedback", node_async(new HumanFeedbackNode()))
                // START -> expander
                .addEdge(StateGraph.START, "expander")
                // expander -> human_feedback
                .addEdge("expander", "human_feedback")
                // The edge out of the human node is conditional: HumanFeedbackDispatcher picks the next node
                // human_feedback -> translate
                // human_feedback -> END
                .addConditionalEdges("human_feedback", AsyncEdgeAction.edge_async(new HumanFeedbackDispatcher()),
                        Map.of("translate", "translate", StateGraph.END, StateGraph.END))
                // translate -> END
                .addEdge("translate", StateGraph.END);

        // Print the graph as PlantUML
        GraphRepresentation representation = stateGraph.getGraph(GraphRepresentation.Type.PLANTUML, "human flow");
        logger.info("\n=== expander UML Flow ===");
        logger.info(representation.content());
        logger.info("==================================\n");

        return stateGraph;
    }
}
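Because the only branch sits after `human_feedback`, this graph has exactly two possible paths. The hypothetical `HumanFlowPaths` class below walks the same topology in plain Java (it does not use `StateGraph`), which is a cheap way to sanity-check the wiring before running it against a model.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Plain-Java walk over the same topology as humanGraph; illustration only, not the library.
final class HumanFlowPaths {
    static final String START = "START";
    static final String END = "END";

    // Fixed edges from the configuration; human_feedback is handled separately as the conditional edge
    private static final Map<String, String> FIXED_EDGES = Map.of(
            START, "expander",
            "expander", "human_feedback",
            "translate", END);

    // Conditional edge mapping, mirroring Map.of("translate", "translate", END, END)
    private static final Map<String, String> HUMAN_ROUTES = Map.of("translate", "translate", END, END);

    // Returns the node sequence for a given feed_back value, including START and END
    static List<String> path(boolean feedBack) {
        List<String> visited = new ArrayList<>();
        String current = START;
        while (!END.equals(current)) {
            visited.add(current);
            if ("human_feedback".equals(current)) {
                // Same rule as HumanFeedbackNode: true -> translate, false -> END
                current = HUMAN_ROUTES.get(feedBack ? "translate" : END);
            } else {
                current = FIXED_EDGES.get(current);
            }
        }
        visited.add(END);
        return visited;
    }
}
```

`path(true)` returns `[START, expander, human_feedback, translate, END]` and `path(false)` returns `[START, expander, human_feedback, END]`.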
6. Node code
ExpanderNode
package com.alibaba.cloud.ai.graph.node;

import com.alibaba.cloud.ai.graph.NodeOutput;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.async.AsyncGenerator;
import com.alibaba.cloud.ai.graph.streaming.StreamingChatGenerator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.PromptTemplate;
import reactor.core.publisher.Flux;

import java.util.Arrays;
import java.util.List;
import java.util.Map;

/**
 * @author yingzi
 * @since 2025/6/13
 */
public class ExpanderNode implements NodeAction {

    private static final Logger logger = LoggerFactory.getLogger(ExpanderNode.class);

    private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("You are an expert at information retrieval and search optimization.\nYour task is to generate {number} different versions of the given query.\n\nEach variant must cover different perspectives or aspects of the topic,\nwhile maintaining the core intent of the original query. The goal is to\nexpand the search space and improve the chances of finding relevant information.\n\nDo not explain your choices or add any other text.\nProvide the query variants separated by newlines.\n\nOriginal query: {query}\n\nQuery variants:\n");

    private final ChatClient chatClient;

    // Default number of variants when expander_number is absent from the state
    private final Integer NUMBER = 3;

    public ExpanderNode(ChatClient.Builder chatClientBuilder) {
        this.chatClient = chatClientBuilder.build();
    }

    @Override
    public Map<String, Object> apply(OverAllState state) {
        // Fields stored in OverAllState:
        // query: the user's question
        // expander_number: the number of variants to generate
        // expander_content: the expanded content
        // feed_back: the human feedback
        // human_next_node: the next node after the human feedback
        // translate_language: the target language for translation, English by default
        // translate_content: the translated content
        logger.info("expander node is running.");

        // Expander node: the model expands the question and streams the output
        String query = state.value("query", "");
        Integer expanderNumber = state.value("expander_number", this.NUMBER);

        Flux<ChatResponse> chatResponseFlux = this.chatClient.prompt()
                .user((user) -> user.text(DEFAULT_PROMPT_TEMPLATE.getTemplate())
                        .param("number", expanderNumber)
                        .param("query", query))
                .stream()
                .chatResponse();

        // Wrap the Flux as a generator: chunks are streamed under "expander_llm_stream",
        // and mapResult turns the response text into the expander_content list
        AsyncGenerator<? extends NodeOutput> generator = StreamingChatGenerator.builder()
                .startingNode("expander_llm_stream")
                .startingState(state)
                .mapResult(response -> {
                    String text = response.getResult().getOutput().getText();
                    List<String> queryVariants = Arrays.asList(text.split("\n"));
                    return Map.of("expander_content", queryVariants);
                })
                .build(chatResponseFlux);

        return Map.of("expander_content", generator);
    }
}
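`mapResult` turns the model's text into the `expander_content` list with `text.split("\n")`, which keeps empty lines and any trailing `\r` the model emits. If that matters, a more defensive parse could look like the hypothetical helper below; trimming, dropping blank lines, and capping the list at the requested count are assumptions about what downstream code wants, not behavior of the original node.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

// Hypothetical helper: a more defensive alternative to text.split("\n") inside mapResult
final class QueryVariants {
    static List<String> parse(String text, int expected) {
        return Arrays.stream(text.split("\\R"))    // split on any line break (\n, \r\n, ...)
                .map(String::trim)
                .filter(line -> !line.isEmpty())    // drop blank lines the model may emit
                .limit(expected)                    // keep at most the requested number of variants
                .collect(Collectors.toList());
    }
}
```

For example, `parse("a\n\n b \r\nc\nd", 3)` yields `[a, b, c]`.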
HumanFeedbackNode
package com.alibaba.cloud.ai.graph.node;

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.HashMap;
import java.util.Map;

/**
 * @author yingzi
 * @since 2025/6/19
 */
public class HumanFeedbackNode implements NodeAction {

    private static final Logger logger = LoggerFactory.getLogger(HumanFeedbackNode.class);

    @Override
    public Map<String, Object> apply(OverAllState state) {
        // Human node: based on the user's feedback, either end the flow or continue to the translate node
        logger.info("human_feedback node is running.");

        HashMap<String, Object> resultMap = new HashMap<>();
        String nextStep = StateGraph.END;

        // Feedback injected by the resume endpoint via state.withHumanFeedback(...)
        Map<String, Object> feedBackData = state.humanFeedback().data();
        // A missing feed_back defaults to true, i.e. continue to translation
        boolean feedback = (boolean) feedBackData.getOrDefault("feed_back", true);
        if (feedback) {
            nextStep = "translate";
        }

        // HumanFeedbackDispatcher reads this key to choose the next node
        resultMap.put("human_next_node", nextStep);
        logger.info("human_feedback node -> {} node", nextStep);
        return resultMap;
    }
}
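The routing rule inside `HumanFeedbackNode` can be pulled out into a pure function and unit-tested without a graph or a model. The sketch below is hypothetical: `FeedbackRouting` is not part of the example, `END` is a local stand-in for `StateGraph.END`, and accepting `"true"`/`"false"` strings (where the original casts straight to `boolean`) is an added assumption.

```java
import java.util.Map;

// Hypothetical pure-function version of HumanFeedbackNode's routing rule
final class FeedbackRouting {
    static final String END = "__END__";      // local stand-in for StateGraph.END in this sketch
    static final String TRANSLATE = "translate";

    static String nextNode(Map<String, Object> feedbackData) {
        // A missing feed_back means "continue", as in the original node
        Object raw = feedbackData.getOrDefault("feed_back", true);
        // Assumption: also accept "true"/"false" strings, e.g. when feedback arrives as JSON text
        boolean continueToTranslate = raw instanceof Boolean b ? b : Boolean.parseBoolean(String.valueOf(raw));
        return continueToTranslate ? TRANSLATE : END;
    }
}
```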
TranslateNode
package com.alibaba.cloud.ai.graph.node;

import com.alibaba.cloud.ai.graph.NodeOutput;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.async.AsyncGenerator;
import com.alibaba.cloud.ai.graph.streaming.StreamingChatGenerator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.PromptTemplate;
import reactor.core.publisher.Flux;

import java.util.Arrays;
import java.util.List;
import java.util.Map;

/**
 * @author yingzi
 * @since 2025/6/13
 */
public class TranslateNode implements NodeAction {

    private static final Logger logger = LoggerFactory.getLogger(TranslateNode.class);

    // Prompt template for the translation request
    private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("Given a user query, translate it to {targetLanguage}.\nIf the query is already in {targetLanguage}, return it unchanged.\nIf you don't know the language of the query, return it unchanged.\nDo not add explanations nor any other text.\n\nOriginal query: {query}\n\nTranslated query:\n");

    private final ChatClient chatClient;

    // Default target language when translate_language is absent from the state
    private final String TARGET_LANGUAGE = "English";

    public TranslateNode(ChatClient.Builder chatClientBuilder) {
        this.chatClient = chatClientBuilder.build();
    }

    @Override
    public Map<String, Object> apply(OverAllState state) {
        logger.info("translate node is running.");

        // Translate node: translate the query into the target language (English by default)
        String query = state.value("query", "");
        String targetLanguage = state.value("translate_language", TARGET_LANGUAGE);

        Flux<ChatResponse> chatResponseFlux = this.chatClient.prompt()
                .user((user) -> user.text(DEFAULT_PROMPT_TEMPLATE.getTemplate())
                        .param("targetLanguage", targetLanguage)
                        .param("query", query))
                .stream()
                .chatResponse();

        AsyncGenerator<? extends NodeOutput> generator = StreamingChatGenerator.builder()
                .startingNode("translate_llm_stream")
                .startingState(state)
                .mapResult(response -> {
                    String text = response.getResult().getOutput().getText();
                    List<String> queryVariants = Arrays.asList(text.split("\n"));
                    return Map.of("translate_content", queryVariants);
                })
                .build(chatResponseFlux);

        return Map.of("translate_content", generator);
    }
}
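Both nodes fill `{number}`, `{targetLanguage}` and `{query}` through `.param(...)`. Spring AI renders these templates with its own template engine; the naive regex substitution below only illustrates the idea of named placeholders, and `NaiveTemplate` is a hypothetical helper, not part of Spring AI.

```java
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical, naive {placeholder} substitution to show what .param(...) feeds into a template
final class NaiveTemplate {
    private static final Pattern PLACEHOLDER = Pattern.compile("\\{(\\w+)\\}");

    static String render(String template, Map<String, Object> params) {
        Matcher m = PLACEHOLDER.matcher(template);
        StringBuilder out = new StringBuilder();
        while (m.find()) {
            Object value = params.get(m.group(1));
            if (value == null) {
                throw new IllegalArgumentException("missing param: " + m.group(1));
            }
            // quoteReplacement keeps '$' and '\' in values from being treated as group references
            m.appendReplacement(out, Matcher.quoteReplacement(String.valueOf(value)));
        }
        m.appendTail(out);
        return out.toString();
    }
}
```

A placeholder may appear more than once, as `{targetLanguage}` does in the TranslateNode prompt; each occurrence receives the same value.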
7. Edge code
The edge leaving the human node is a conditional edge: HumanFeedbackDispatcher decides which node to jump to next.
package com.alibaba.cloud.ai.graph.dispatcher;

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.action.EdgeAction;

/**
 * @author yingzi
 * @since 2025/6/19
 */
public class HumanFeedbackDispatcher implements EdgeAction {

    @Override
    public String apply(OverAllState state) throws Exception {
        // Conditional edge after the human node: return the node name stored in human_next_node,
        // defaulting to END when the key is absent
        return (String) state.value("human_next_node", StateGraph.END);
    }
}
8. Controller code
GraphHumanController
- CompileConfig.builder().saverConfig(saverConfig).interruptBefore("human_feedback"): pauses the flow before the human feedback node
- Sinks.Many<ServerSentEvent<String>> sink: receives the streamed data
package com.alibaba.cloud.ai.graph.controller;

import com.alibaba.cloud.ai.graph.CompileConfig;
import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.NodeOutput;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.RunnableConfig;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.async.AsyncGenerator;
import com.alibaba.cloud.ai.graph.checkpoint.config.SaverConfig;
import com.alibaba.cloud.ai.graph.checkpoint.constant.SaverConstant;
import com.alibaba.cloud.ai.graph.checkpoint.savers.MemorySaver;
import com.alibaba.cloud.ai.graph.controller.GraphProcess.GraphProcess;
import com.alibaba.cloud.ai.graph.exception.GraphRunnerException;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.state.StateSnapshot;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Sinks;

import java.util.HashMap;
import java.util.Map;

/**
 * @author yingzi
 * @since 2025/6/13
 */
@RestController
public class GraphHumanController {

    private static final Logger logger = LoggerFactory.getLogger(GraphHumanController.class);

    private final CompiledGraph compiledGraph;

    @Autowired
    public GraphHumanController(@Qualifier("humanGraph") StateGraph stateGraph) throws GraphStateException {
        // In-memory checkpoint saver, so an interrupted run can be resumed by thread_id
        SaverConfig saverConfig = SaverConfig.builder()
                .register(SaverConstant.MEMORY, new MemorySaver())
                .build();
        // interruptBefore("human_feedback"): pause the flow before the human feedback node
        this.compiledGraph = stateGraph.compile(CompileConfig.builder()
                .saverConfig(saverConfig)
                .interruptBefore("human_feedback")
                .build());
    }

    @GetMapping(value = "/graph/human/expand", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<ServerSentEvent<String>> expand(
            @RequestParam(value = "query", defaultValue = "你好,很高兴认识你,能简单介绍一下自己吗?", required = false) String query,
            @RequestParam(value = "expander_number", defaultValue = "3", required = false) Integer expanderNumber,
            @RequestParam(value = "thread_id", defaultValue = "yingzi", required = false) String threadId)
            throws GraphRunnerException {
        RunnableConfig runnableConfig = RunnableConfig.builder().threadId(threadId).build();
        Map<String, Object> objectMap = new HashMap<>();
        objectMap.put("query", query);
        objectMap.put("expander_number", expanderNumber);

        GraphProcess graphProcess = new GraphProcess(this.compiledGraph);
        // The sink receives the streamed data and is exposed to the client as SSE
        Sinks.Many<ServerSentEvent<String>> sink = Sinks.many().unicast().onBackpressureBuffer();
        AsyncGenerator<NodeOutput> resultFuture = compiledGraph.stream(objectMap, runnableConfig);
        graphProcess.processStream(resultFuture, sink);

        return sink.asFlux()
                .doOnCancel(() -> logger.info("Client disconnected from stream"))
                .doOnError(e -> logger.error("Error occurred during streaming", e));
    }

    @GetMapping(value = "/graph/human/resume", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<ServerSentEvent<String>> resume(
            @RequestParam(value = "thread_id", defaultValue = "yingzi", required = false) String threadId,
            @RequestParam(value = "feed_back", defaultValue = "true", required = false) boolean feedBack)
            throws GraphRunnerException {
        RunnableConfig runnableConfig = RunnableConfig.builder().threadId(threadId).build();
        // Load the state checkpointed when the flow was interrupted
        StateSnapshot stateSnapshot = this.compiledGraph.getState(runnableConfig);
        OverAllState state = stateSnapshot.state();
        state.withResume();

        // Attach the human feedback; HumanFeedbackNode reads it via state.humanFeedback()
        Map<String, Object> objectMap = new HashMap<>();
        objectMap.put("feed_back", feedBack);
        state.withHumanFeedback(new OverAllState.HumanFeedback(objectMap, ""));

        // Create a unicast sink to emit ServerSentEvents
        Sinks.Many<ServerSentEvent<String>> sink = Sinks.many().unicast().onBackpressureBuffer();
        GraphProcess graphProcess = new GraphProcess(this.compiledGraph);
        AsyncGenerator<NodeOutput> resultFuture = compiledGraph.streamFromInitialNode(state, runnableConfig);
        graphProcess.processStream(resultFuture, sink);

        return sink.asFlux()
                .doOnCancel(() -> logger.info("Client disconnected from stream"))
                .doOnError(e -> logger.error("Error occurred during streaming", e));
    }
}
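The expand/resume pair works because the compiled graph checkpoints its state per `thread_id` (here in a `MemorySaver`) and stops before `human_feedback`; `resume` then loads that snapshot and continues. The plain-Java model below captures only that control flow. `InterruptResumeSketch` is hypothetical, assumes a linear node order (the `feed_back=true` path), and stores just the index of the next node instead of a full state snapshot.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical model of interruptBefore + a memory checkpoint keyed by thread_id; not the CompiledGraph API
final class InterruptResumeSketch {
    private final List<String> order;                                 // linear node order for this sketch
    private final String interruptBefore;                             // node to pause in front of
    private final Map<String, Integer> checkpoints = new HashMap<>(); // thread_id -> index of next node

    InterruptResumeSketch(List<String> order, String interruptBefore) {
        this.order = order;
        this.interruptBefore = interruptBefore;
    }

    // First call: run until the interrupt node, then save where we stopped
    List<String> stream(String threadId) {
        return runFrom(threadId, 0, true);
    }

    // Resume: continue from the saved position, this time executing the interrupt node itself
    List<String> resume(String threadId) {
        Integer next = checkpoints.get(threadId);
        if (next == null) {
            throw new IllegalStateException("no checkpoint for thread " + threadId);
        }
        return runFrom(threadId, next, false);
    }

    private List<String> runFrom(String threadId, int start, boolean honorInterrupt) {
        List<String> executed = new ArrayList<>();
        for (int i = start; i < order.size(); i++) {
            if (honorInterrupt && order.get(i).equals(interruptBefore)) {
                checkpoints.put(threadId, i);   // pause and remember the next node
                return executed;
            }
            executed.add(order.get(i));
        }
        checkpoints.remove(threadId);           // finished: nothing left to resume
        return executed;
    }
}
```

Keying checkpoints by thread ID is what lets several conversations sit at the interrupt point at the same time, which is why both endpoints take a `thread_id` parameter.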
9. GraphProcess code
- ExecutorService executor: a thread pool that consumes the stream and writes the results into the sink
package com.alibaba.cloud.ai.graph.controller.GraphProcess;

import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.NodeOutput;
import com.alibaba.cloud.ai.graph.async.AsyncGenerator;
import com.alibaba.cloud.ai.graph.streaming.StreamingOutput;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.codec.ServerSentEvent;
import reactor.core.publisher.Sinks;

import java.util.Map;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/**
 * @author yingzi
 * @since 2025/6/13
 */
public class GraphProcess {

    private static final Logger logger = LoggerFactory.getLogger(GraphProcess.class);

    // Thread pool that consumes the stream and writes the results into the sink
    private final ExecutorService executor = Executors.newSingleThreadExecutor();

    private CompiledGraph compiledGraph;

    public GraphProcess(CompiledGraph compiledGraph) {
        this.compiledGraph = compiledGraph;
    }

    public void processStream(AsyncGenerator<NodeOutput> generator, Sinks.Many<ServerSentEvent<String>> sink) {
        executor.submit(() -> {
            generator.forEachAsync(output -> {
                try {
                    logger.info("output = {}", output);
                    String nodeName = output.node();
                    String content;
                    if (output instanceof StreamingOutput streamingOutput) {
                        // Streaming chunk from an LLM node: {"<node>": "<chunk>"}
                        content = JSON.toJSONString(Map.of(nodeName, streamingOutput.chunk()));
                    } else {
                        // Regular node output: the node name plus the full state data
                        JSONObject nodeOutput = new JSONObject();
                        nodeOutput.put("data", output.state().data());
                        nodeOutput.put("node", nodeName);
                        content = JSON.toJSONString(nodeOutput);
                    }
                    sink.tryEmitNext(ServerSentEvent.builder(content).build());
                } catch (Exception e) {
                    throw new CompletionException(e);
                }
            }).thenAccept(v -> {
                // Completed normally
                sink.tryEmitComplete();
            }).exceptionally(e -> {
                sink.tryEmitError(e);
                return null;
            });
        });
    }
}
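`GraphProcess` hands the generator to a dedicated thread, pushes every output into the sink, and completes the sink at the end. The JDK-only analogy below shows the same producer/sink handoff. It is a sketch under assumptions: a `BlockingQueue` stands in for `Sinks.Many`, an `Iterator` for `AsyncGenerator`, and `StreamBridge`, `COMPLETE`, and `drain` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

// Plain-JDK analogy of GraphProcess; not Reactor and not the graph library
final class StreamBridge {
    static final String COMPLETE = "__COMPLETE__"; // hypothetical end-of-stream marker

    // Producer side: a single worker thread copies every output into the sink
    static void process(Iterator<String> outputs, BlockingQueue<String> sink) {
        ExecutorService executor = Executors.newSingleThreadExecutor();
        executor.submit(() -> {
            try {
                while (outputs.hasNext()) {
                    sink.put(outputs.next());   // analogous to sink.tryEmitNext(...)
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            } finally {
                sink.offer(COMPLETE);           // analogous to tryEmitComplete(); always signal the end
            }
        });
        executor.shutdown();                    // the submitted task still runs; the thread exits afterwards
    }

    // Consumer side: collect items until the end marker, or fail after a timeout
    static List<String> drain(BlockingQueue<String> sink, long timeoutSeconds) throws InterruptedException {
        List<String> received = new ArrayList<>();
        while (true) {
            String item = sink.poll(timeoutSeconds, TimeUnit.SECONDS);
            if (item == null) {
                throw new IllegalStateException("timed out waiting for the stream");
            }
            if (COMPLETE.equals(item)) {
                return received;
            }
            received.add(item);
        }
    }
}
```

Calling `shutdown()` right after `submit` lets the queued task finish while ensuring the worker thread does not linger once it is done. Since the controller creates a new `GraphProcess` per request, the same consideration may be worth applying to its executor.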
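10. Results
Call http://127.0.0.1:8080/graph/human/expand
The expand endpoint streams its output, then the stream stops at the interrupt before human_feedback and returns the result so far.
Next, call the resume endpoint: the saved state is restored, the stream picks up where it left off, and the rest of the flow runs.
http://127.0.0.1:8080/graph/human/resume (feed_back defaults to true, so the flow continues to the translate node)
http://127.0.0.1:8080/graph/human/resume?feed_back=false (the flow ends right after the human node)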