京公网安备 11010802034615号
经营许可证编号:京B2-20210330
作者 | Francois Chollet
编译 | CDA数据分析师
A ten-minute introduction to sequence-to-sequence learning in Keras
什么是顺序学习?
序列到序列学习(Seq2Seq)是关于将模型从一个域(例如英语中的句子)转换为另一域(例如将相同句子翻译为法语的序列)的训练模型。
“猫坐在垫子上” -> [ Seq2Seq 模型] -> “在小吃中聊天
这可用于机器翻译或免费问答(在给定自然语言问题的情况下生成自然语言答案)-通常,它可在需要生成文本的任何时间使用。
有多种处理此任务的方法,可以使用RNN或使用一维卷积网络。在这里,我们将重点介绍RNN。
普通情况:输入和输出序列的长度相同
当输入序列和输出序列的长度相同时,您可以简单地使用Keras LSTM或GRU层(或其堆栈)来实现此类模型。在此示例脚本 中就是这种情况, 该脚本显示了如何教RNN学习加编码为字符串的数字:
该方法的一个警告是,它假定可以生成target[...t]给定input[...t]。在某些情况下(例如添加数字字符串),该方法有效,但在大多数用例中,则无效。在一般情况下,有关整个输入序列的信息是必需的,以便开始生成目标序列。
一般情况:规范序列间
在一般情况下,输入序列和输出序列具有不同的长度(例如,机器翻译),并且需要整个输入序列才能开始预测目标。这需要更高级的设置,这是人们在没有其他上下文的情况下提到“序列模型的序列”时通常所指的东西。运作方式如下:
RNN层(或其堆栈)充当“编码器”:它处理输入序列并返回其自己的内部状态。请注意,我们放弃了编码器RNN的输出,仅恢复 了状态。在下一步中,此状态将用作解码器的“上下文”或“条件”。
另一个RNN层(或其堆栈)充当“解码器”:在给定目标序列的先前字符的情况下,对其进行训练以预测目标序列的下一个字符。具体而言,它经过训练以将目标序列变成相同序列,但在将来会偏移一个时间步,在这种情况下,该训练过程称为“教师强迫”。重要的是,编码器使用来自编码器的状态向量作为初始状态,这就是解码器如何获取有关应该生成的信息的方式。有效地,解码器学会产生targets[t+1...] 给定的targets[...t],调节所述输入序列。
在推断模式下,即当我们想解码未知的输入序列时,我们会经历一个略有不同的过程:
同样的过程也可以用于训练Seq2Seq网络,而无需 “教师强制”,即通过将解码器的预测重新注入到解码器中。
一个Keras例子
因为训练过程和推理过程(解码句子)有很大的不同,所以我们对两者使用不同的模型,尽管它们都利用相同的内部层。
这是我们的训练模型。它利用Keras RNN的三个关键功能:
return_state构造器参数,配置RNN层返回一个列表,其中,第一项是输出与下一个条目是内部RNN状态。这用于恢复编码器的状态。
inital_state呼叫参数,指定一个RNN的初始状态(S)。这用于将编码器状态作为初始状态传递给解码器。
return_sequences构造函数的参数,配置RNN返回其输出全序列(而不只是最后的输出,其默认行为)。在解码器中使用。
from keras.models import Model from keras.layers import Input, LSTM, Dense # Define an input sequence and process it. encoder_inputs = Input(shape=(None, num_encoder_tokens)) encoder = LSTM(latent_dim, return_state=True) encoder_outputs, state_h, state_c = encoder(encoder_inputs) # We discard `encoder_outputs` and only keep the states. encoder_states = [state_h, state_c] # Set up the decoder, using `encoder_states` as initial state. decoder_inputs = Input(shape=(None, num_decoder_tokens)) # We set up our decoder to return full output sequences, # and to return internal states as well. We don't use the # return states in the training model, but we will use them in inference. decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True) decoder_outputs, _, _ = decoder_lstm(decoder_inputs, initial_state=encoder_states) decoder_dense = Dense(num_decoder_tokens, activation='softmax') decoder_outputs = decoder_dense(decoder_outputs) # Define the model that will turn # `encoder_input_data` & `decoder_input_data` into `decoder_target_data` model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
我们分两行训练我们的模型,同时监视20%的保留样本中的损失。
# Run training model.compile(optimizer='rmsprop', loss='categorical_crossentropy') model.fit([encoder_input_data, decoder_input_data], decoder_target_data, batch_size=batch_size, epochs=epochs, validation_split=0.2)
在MacBook CPU上运行大约一个小时后,我们就可以进行推断了。为了解码测试语句,我们将反复:
这是我们的推理设置:
encoder_model = Model(encoder_inputs, encoder_states) decoder_state_input_h = Input(shape=(latent_dim,)) decoder_state_input_c = Input(shape=(latent_dim,)) decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c] decoder_outputs, state_h, state_c = decoder_lstm( decoder_inputs, initial_state=decoder_states_inputs) decoder_states = [state_h, state_c] decoder_outputs = decoder_dense(decoder_outputs) decoder_model = Model( [decoder_inputs] + decoder_states_inputs, [decoder_outputs] + decoder_states)
我们使用它来实现上述推理循环:
def decode_sequence(input_seq): # Encode the input as state vectors. states_value = encoder_model.predict(input_seq) # Generate empty target sequence of length 1. target_seq = np.zeros((1, 1, num_decoder_tokens)) # Populate the first character of target sequence with the start character. target_seq[0, 0, target_token_index['\t']] = 1. # Sampling loop for a batch of sequences # (to simplify, here we assume a batch of size 1). stop_condition = False decoded_sentence = '' while not stop_condition: output_tokens, h, c = decoder_model.predict( [target_seq] + states_value) # Sample a token sampled_token_index = np.argmax(output_tokens[0, -1, :]) sampled_char = reverse_target_char_index[sampled_token_index] decoded_sentence += sampled_char # Exit condition: either hit max length # or find stop character. if (sampled_char == '\n' or len(decoded_sentence) > max_decoder_seq_length): stop_condition = True # Update the target sequence (of length 1). target_seq = np.zeros((1, 1, num_decoder_tokens)) target_seq[0, 0, sampled_token_index] = 1. # Update states states_value = [h, c] return decoded_sentence
我们得到了一些不错的结果-毫不奇怪,因为我们正在解码从训练测试中提取的样本
Input sentence: Be nice. Decoded sentence: Soyez gentil ! - Input sentence: Drop it! Decoded sentence: Laissez tomber ! - Input sentence: Get out! Decoded sentence: Sortez !
到此,我们结束了对Keras中序列到序列模型的十分钟介绍。提醒:此脚本的完整代码可以在GitHub上找到。
数据分析咨询请扫描二维码
若不方便扫码,搜微信号:CDAshujufenxi
数据分析的核心价值在于用数据驱动决策,而指标作为数据的“载体”,其选取的合理性直接决定分析结果的有效性。选对指标能精准定 ...
2026-01-23在MySQL查询编写中,我们习惯按“SELECT → FROM → WHERE → ORDER BY”的语法顺序组织语句,直觉上认为代码顺序即执行顺序。但 ...
2026-01-23数字化转型已从企业“可选项”升级为“必答题”,其核心本质是通过数据驱动业务重构、流程优化与模式创新,实现从传统运营向智能 ...
2026-01-23CDA持证人已遍布在世界范围各行各业,包括世界500强企业、顶尖科技独角兽、大型金融机构、国企事业单位、国家行政机关等等,“CDA数据分析师”人才队伍遵守着CDA职业道德准则,发挥着专业技能,已成为支撑科技发展的核心力量。 ...
2026-01-22在数字化时代,企业积累的海量数据如同散落的珍珠,而数据模型就是串联这些珍珠的线——它并非简单的数据集合,而是对现实业务场 ...
2026-01-22在数字化运营场景中,用户每一次点击、浏览、交互都构成了行为轨迹,这些轨迹交织成海量的用户行为路径。但并非所有路径都具备业 ...
2026-01-22在数字化时代,企业数据资产的价值持续攀升,数据安全已从“合规底线”升级为“生存红线”。企业数据安全管理方法论以“战略引领 ...
2026-01-22在SQL数据分析与业务查询中,日期数据是高频处理对象——订单创建时间、用户注册日期、数据统计周期等场景,都需对日期进行格式 ...
2026-01-21在实际业务数据分析中,单一数据表往往无法满足需求——用户信息存储在用户表、消费记录在订单表、商品详情在商品表,想要挖掘“ ...
2026-01-21在数字化转型浪潮中,企业数据已从“辅助资源”升级为“核心资产”,而高效的数据管理则是释放数据价值的前提。企业数据管理方法 ...
2026-01-21在数字化商业环境中,数据已成为企业优化运营、抢占市场、规避风险的核心资产。但商业数据分析绝非“堆砌数据、生成报表”的简单 ...
2026-01-20定量报告的核心价值是传递数据洞察,但密密麻麻的表格、复杂的计算公式、晦涩的数值罗列,往往让读者望而却步,导致核心信息被淹 ...
2026-01-20在CDA(Certified Data Analyst)数据分析师的工作场景中,“精准分类与回归预测”是高频核心需求——比如预测用户是否流失、判 ...
2026-01-20在建筑工程造价工作中,清单汇总分类是核心环节之一,尤其是针对楼梯、楼梯间这类包含多个分项工程(如混凝土浇筑、钢筋制作、扶 ...
2026-01-19数据清洗是数据分析的“前置必修课”,其核心目标是剔除无效信息、修正错误数据,让原始数据具备准确性、一致性与可用性。在实际 ...
2026-01-19在CDA(Certified Data Analyst)数据分析师的日常工作中,常面临“无标签高维数据难以归类、群体规律模糊”的痛点——比如海量 ...
2026-01-19在数据仓库与数据分析体系中,维度表与事实表是构建结构化数据模型的核心组件,二者如同“骨架”与“血肉”,协同支撑起各类业务 ...
2026-01-16在游戏行业“存量竞争”的当下,玩家留存率直接决定游戏的生命周期与商业价值。一款游戏即便拥有出色的画面与玩法,若无法精准识 ...
2026-01-16为配合CDA考试中心的 2025 版 CDA Level III 认证新大纲落地,CDA 网校正式推出新大纲更新后的第一套官方模拟题。该模拟题严格遵 ...
2026-01-16在数据驱动决策的时代,数据分析已成为企业运营、产品优化、业务增长的核心工具。但实际工作中,很多数据分析项目看似流程完整, ...
2026-01-15