LLM fine-tuning #1350
fn insert_logs(project_id: i64, model_id: i64, logs: String) -> PyResult<String> {

    let id_value = Spi::get_one_with_args::<i64>(
        "INSERT INTO pgml.logs (project_id, model_id, logs) VALUES ($1, $2, $3::JSONB) RETURNING id;",
Did we include a migration for this table somewhere? We need to make sure it's created on all databases running PostgresML.
Yes, need to add the following three to our migration once we freeze on the version number.
ALTER TYPE pgml.task ADD VALUE IF NOT EXISTS 'conversation';
ALTER TYPE pgml.task ADD VALUE IF NOT EXISTS 'text_pair_classification';
CREATE TABLE IF NOT EXISTS pgml.logs (
id SERIAL PRIMARY KEY,
model_id BIGINT,
project_id BIGINT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
logs JSONB
);
MarkupSafe==2.1.3
marshmallow==3.20.1
matplotlib==3.8.2
maturin==1.4.0
Don't think you need maturin inside PostgresML deployments. This may be a "leak" from the pypgrx extension venv.
task: default!(Option<&str>, "NULL"),
relation_name: default!(Option<&str>, "NULL"),
y_column_name: default!(Option<&str>, "NULL"),
_y_column_name: default!(Option<&str>, "NULL"),
Why the underscore? Is it because it's not used?
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from trl.trainer import ConstantLengthDataset
from peft import LoraConfig, get_peft_model
from pypgrx import print_info, insert_logs
We need to make sure we import this conditionally (only for fine-tuning) and include it in requirements.linux.txt. I didn't see a macOS build for this, and for the M1/M2 architecture we've been doing releases manually from our Macs (GitHub Actions doesn't have M1 builders).
This makes me think we should start cross-compiling soon. Rust supports this pretty well; maturin may need a patch.
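A minimal sketch of what a conditional import could look like, assuming a hypothetical helper `load_finetuning_deps` that is not part of this PR. It defers importing `trl`, `peft`, and `pypgrx` until fine-tuning is actually requested, so other tasks keep working on platforms where those packages are not installed.

```python
import importlib

# Hypothetical helper, not part of the PR: the fine-tuning-only
# dependencies are imported lazily instead of at module import time.
FINETUNING_MODULES = ("trl", "peft", "pypgrx")


def load_finetuning_deps(modules=FINETUNING_MODULES):
    """Import each module by name and return them keyed by name.

    Raises ImportError listing every package that could not be imported.
    """
    loaded, missing = {}, []
    for name in modules:
        try:
            loaded[name] = importlib.import_module(name)
        except ImportError:
            missing.append(name)
    if missing:
        raise ImportError(
            "fine-tuning requires missing packages: " + ", ".join(missing)
        )
    return loaded
```

The fine-tuning entry point would call `load_finetuning_deps()` first and fail with a single readable error rather than breaking the whole module at import time.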
I couldn't get fine-tuning to work on macOS. It keeps crashing. How about I check for the operating system and bail out if it is macOS?
requirements.linux.txt is updated with trl and peft.
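One possible shape for that operating-system check, as a sketch only. The function name `ensure_finetuning_supported` and the error message are assumptions; the PR may implement the bail-out differently.

```python
import platform


def ensure_finetuning_supported(system=None):
    """Raise early when fine-tuning is requested on macOS.

    `system` can be passed explicitly for testing; by default it is
    taken from platform.system(), which reports "Darwin" on macOS.
    """
    system = system or platform.system()
    if system == "Darwin":
        raise RuntimeError("Fine-tuning is not supported on macOS")
    return system
```

Calling this at the top of the fine-tuning path turns a crash into an explicit error.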
logs["step"] = state.global_step
logs["max_steps"] = state.max_steps
logs["timestamp"] = str(datetime.now())
print_info(json.dumps(logs))
If you use print(), this will appear in the Postgres logs. It won't be pretty, but we can add a function that formats it correctly.
I will add indent in json.dumps() to pretty print.
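A small sketch of the pretty-printing idea, assuming a hypothetical `format_logs` helper and an indent of 4 (neither is taken from the PR). It builds the same fields the callback adds and serializes them with `json.dumps(..., indent=...)`.

```python
import json
from datetime import datetime


def format_logs(logs, step, max_steps, indent=4):
    """Return the training logs as an indented JSON string."""
    payload = dict(logs)  # copy so the caller's dict is left untouched
    payload["step"] = step
    payload["max_steps"] = max_steps
    payload["timestamp"] = str(datetime.now())
    return json.dumps(payload, indent=indent)
```

The resulting string spans multiple lines, which is easier to read in the Postgres logs but also takes more room per log entry.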
trainable_model_params += param.numel()

# Calculate and print the number and percentage of trainable parameters
print_info(f"Trainable model parameters: {trainable_model_params}")
@kczimm This will require us to use the main thread for ML workloads in our cloud.
A PR with that is close. What's the reason we need main thread here?
We need logging visibility during fine tuning.
Thanks to a commit by @levkk, we should be able to log from any thread.
#######################


class PGMLCallback(TrainerCallback):
I wouldn't be opposed to this functionality living in its own file, like tune.py, since transformers is getting a bit beefy.
transformers.py is hardcoded in several places. Needs some more refactoring and testing to accomplish moving finetuning code to tune.py. Will revisit this in the next iteration. #1378
        self.model_id = model_id

    def on_log(self, args, state, control, logs=None, **kwargs):
        _ = logs.pop("total_flos", None)
Why throw away total_flos?
}

#[pyfunction]
fn print_info(info: String) -> PyResult<String> {
I think this would be more reusable as log(level, msg)
else:
    self.model_name = hyperparameters.pop("model_name")

if "token" in hyperparameters:
Isn't this a model init param, not a hyperparam, like many other things in this list? Maybe hyperparams covers everything?
That's correct. Moved all the parameters to hyperparams.
y_train,
x_test,
y_test,
Ok::<std::option::Option<()>, i64>(Some(())) // this return type is nonsense

let text1_column_value = dataset_args
    .0
    .get("text1_column")
Do we require these column names?
Yes. For text pair classification (natural language inference, QNLI, etc.), we need three columns: text1, text2, and the class.
let system_column_value = dataset_args
    .0
    .get("system_column")
How standard are these names these days?
For the conversation task, system, user, and assistant have become standard keys.
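To illustrate those keys, here is a rough sketch of turning one dataset row into a chat-style message list. The helper `row_to_messages` and its parameter names are hypothetical; only the system, user, and assistant roles come from this thread.

```python
def row_to_messages(
    row,
    system_column="system",
    user_column="user",
    assistant_column="assistant",
):
    """Map one row (a dict of column name to text) to role/content messages.

    Hypothetical helper: the column names are configurable, the roles are
    the system/user/assistant keys discussed above.
    """
    columns = (
        ("system", system_column),
        ("user", user_column),
        ("assistant", assistant_column),
    )
    return [{"role": role, "content": row[column]} for role, column in columns]
```

A row whose columns are named differently can still be mapped by passing the column names explicitly, for example `row_to_messages(row, user_column="question", assistant_column="answer")`.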
    Ok(info)
}

/// A Python module implemented in Rust.
#[pymodule]
Since this crate is interdependent, what if we moved this whole pymodule into the main pgml-extension crate, under bindings/python/mod.rs instead of publishing it as a separate crate?
# venv
pgml-venv
Example: https://github.com/postgresml/postgresml/tree/santi-llm-fine-tuning?tab=readme-ov-file#llm-fine-tuning
Refactored TextDataSet to handle different NLP tasks
Three tasks: text classification, text pair classification, conversation
PEFT/LoRA for conversation task
Pypgrx for callbacks to print info statements and insert logs into pgml.logs table
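New tasks have to be added to the pgml.task type:
ALTER TYPE pgml.task ADD VALUE IF NOT EXISTS 'conversation';
ALTER TYPE pgml.task ADD VALUE IF NOT EXISTS 'text_pair_classification';
New pgml.logs table has to be added:
CREATE TABLE IF NOT EXISTS pgml.logs (
id SERIAL PRIMARY KEY,
model_id BIGINT,
project_id BIGINT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
logs JSONB
);
Note: Training is initialized using a previous run and model from HF Hub.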