Skip to content

Commit 39bfc8a

Browse files
committed
Merge branch 'main' into release/3.6
2 parents 0c94385 + 70bace1 commit 39bfc8a

File tree

6 files changed

+12
-9
lines changed

6 files changed

+12
-9
lines changed

swift/llm/dataset/loader.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,8 @@ def _load_dataset_path(
209209
kwargs = {'split': 'train', 'streaming': streaming, 'num_proc': num_proc}
210210
if file_type == 'csv':
211211
kwargs['na_filter'] = False
212-
dataset = hf_load_dataset(file_type, data_files=dataset_path, **kwargs)
212+
with safe_ddp_context(None, True):
213+
dataset = hf_load_dataset(file_type, data_files=dataset_path, **kwargs)
213214
if columns:
214215
dataset = RowPreprocessor.safe_rename_columns(dataset, columns)
215216
dataset = dataset_meta.preprocess_func(
@@ -315,7 +316,8 @@ def _select_subsets(subsets: List[str], dataset_meta: DatasetMeta) -> List[Subse
315316
@staticmethod
316317
def shuffle_dataset(dataset, seed: int, buffer_size: int = 1000):
317318
if isinstance(dataset, HfDataset):
318-
return dataset.shuffle(seed=seed)
319+
with safe_ddp_context(None, True):
320+
return dataset.shuffle(seed=seed)
319321
else:
320322
return dataset.shuffle(seed=seed, buffer_size=buffer_size)
321323

@@ -366,8 +368,9 @@ def post_process(
366368
val_sample = max(int(train_len * split_dataset_ratio), 1)
367369
train_sample = dataset_sample - val_sample
368370
assert train_sample > 0
369-
train_dataset, val_dataset = train_dataset.train_test_split(
370-
test_size=val_sample, shuffle=shuffle, seed=get_seed(random_state)).values()
371+
with safe_ddp_context(None, True):
372+
train_dataset, val_dataset = train_dataset.train_test_split(
373+
test_size=val_sample, shuffle=shuffle, seed=get_seed(random_state)).values()
371374
train_dataset = sample_dataset(train_dataset, train_sample, shuffle, random_state)
372375
return train_dataset, val_dataset
373376

swift/ui/llm_grpo/llm_grpo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
260260
LLMRollout.set_lang(cls.lang)
261261
LLMRollout.build_ui(LLMRollout)
262262
GRPOTuner.build_ui(base_tab)
263-
with gr.Accordion(elem_id='extra_params', open=True):
263+
with gr.Accordion(elem_id='extra_params', open=False):
264264
with gr.Tabs():
265265
GrpoAdvanced.build_ui(base_tab)
266266
GRPOAdvanced.build_ui(base_tab)

swift/ui/llm_grpo/rollout.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ class Rollout(BaseUI):
201201

202202
@classmethod
203203
def do_build_ui(cls, base_tab: Type['BaseUI']):
204-
with gr.Accordion(elem_id='rollout_param', open=True):
204+
with gr.Accordion(elem_id='rollout_param', open=False):
205205
with gr.Row():
206206
gr.Slider(elem_id='temperature', minimum=0.0, maximum=10, step=0.1, value=1.0)
207207
gr.Slider(elem_id='top_k', minimum=1, maximum=100, step=5, value=80)

swift/ui/llm_rlhf/llm_rlhf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
272272
RLHFTuner.build_ui(base_tab)
273273
RLHFOptimizer.build_ui(base_tab)
274274
RLHF.build_ui(base_tab)
275-
with gr.Accordion(elem_id='extra_params', open=True):
275+
with gr.Accordion(elem_id='extra_params', open=False):
276276
with gr.Tabs():
277277
RLHFAdvanced.build_ui(base_tab)
278278
RLHFQuantization.build_ui(base_tab)

swift/ui/llm_rlhf/rlhf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ class RLHF(BaseUI):
147147

148148
@classmethod
149149
def do_build_ui(cls, base_tab: Type['BaseUI']):
150-
with gr.Accordion(elem_id='rlhf_tab', open=True):
150+
with gr.Accordion(elem_id='rlhf_tab', open=False):
151151
with gr.Blocks():
152152
with gr.Row():
153153
gr.Slider(elem_id='beta', minimum=0., maximum=5.0, step=0.1, value=0.1, scale=10)

swift/ui/llm_train/llm_train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
281281
Tuner.build_ui(base_tab)
282282
Optimizer.build_ui(base_tab)
283283
Task.build_ui(base_tab)
284-
with gr.Accordion(elem_id='extra_params', open=True):
284+
with gr.Accordion(elem_id='extra_params', open=False):
285285
with gr.Tabs():
286286
Advanced.build_ui(base_tab)
287287
Quantization.build_ui(base_tab)

0 commit comments

Comments
 (0)