Skip to content

Add GRU classifier and GRU time series model#70

Merged
typhoonzero merged 3 commits intosql-machine-learning:developfrom
Derek-Wds:develop
Jun 9, 2020
Merged

Add GRU classifier and GRU time series model#70
typhoonzero merged 3 commits intosql-machine-learning:developfrom
Derek-Wds:develop

Conversation

@Derek-Wds
Copy link
Contributor

Hi, I have added two GRU models in this PR 😄 . Since, the architectures of the two LSTM models used in this repo are reasonable, I just wrote the GRU models based on them with simple modifications. I have also provided the unit test codes, and all tests passed on my machine.

If there are any problems or additional things to do, please let me know. Hope it helps!

Derek

Copy link
Collaborator

@sneaxiy sneaxiy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not using WITH statements to choose LSTM or GRU model, instead of copying most of the model codes from LSTM-based models?

@Derek-Wds
Copy link
Contributor Author

Yes, that is definitely a better solution. If this solution makes sense, we could even combine the RNN family models (RNN, LSTM, GRU) into one class. At the same time, we could give the user freedom to choose whether using the Bidirectional option in the model.

If this sounds good, I can update my PR to make an initial version of it.

@typhoonzero
Copy link
Collaborator

Yes, that is definitely a better solution. If this solution makes sense, we could even combine the RNN family models (RNN, LSTM, GRU) into one class. At the same time, we could give the user freedom to choose whether using the Bidirectional option in the model.

Sounds cool.

@Derek-Wds
Copy link
Contributor Author

Hi, I have just updated the codes for RNN classifier and RNN TS model, where I combine vanilla RNN, LSTM and GRU into one model respectively. Hope this is readable and more efficient.

Copy link
Collaborator

@typhoonzero typhoonzero left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM generally. Please be aware that the model name LSTMBasedTimeSeriesModel is already used in SQLFlow's unit test: https://github.com/sql-machine-learning/sqlflow/blob/c58f0a4f6b08984c667d3ace1aa21d09aa6112de/cmd/sqlflowserver/e2e_mysql_test.go#L194 and the tutorial https://github.com/sql-machine-learning/sqlflow/blob/develop/doc/tutorial/energe_lstmbasedtimeseries.md. Can you pleas also update those two places after this PR was merged?

:param n_features: number of features in every time window.
type n_features: int
type n_features: int.
:param model_type: Specific RNN model to be used.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a description about possible values, e.g. "rnn", "lstm" and "gru"?

:type stack_units: vector of ints.
:param n_classes: Target number of classes.
:type n_classes: int.
:param model_type: Specific RNN model to be used.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.

@Derek-Wds
Copy link
Contributor Author

Hi, I have just added one sentence for the description of the model type. Please take a look.

LGTM generally. Please be aware that the model name LSTMBasedTimeSeriesModel is already used in SQLFlow's unit test: https://github.com/sql-machine-learning/sqlflow/blob/c58f0a4f6b08984c667d3ace1aa21d09aa6112de/cmd/sqlflowserver/e2e_mysql_test.go#L194 and the tutorial https://github.com/sql-machine-learning/sqlflow/blob/develop/doc/tutorial/energe_lstmbasedtimeseries.md. Can you pleas also update those two places after this PR was merged?

Yeah, sure! Once this PR is being merged, I will try my best to modify the unit tests and tutorials in two days.

Copy link
Collaborator

@typhoonzero typhoonzero left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@typhoonzero typhoonzero merged commit 3ce1221 into sql-machine-learning:develop Jun 9, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants